package pond
import (
"context"
"errors"
"fmt"
"math"
"sync"
"sync/atomic"
"github.com/alitto/pond/v2/internal/future"
"github.com/alitto/pond/v2/internal/linkedbuffer"
)
// Unbounded is the largest int value. It is the default queue size and the
// effective maxConcurrency used when 0 is requested.
const Unbounded = math.MaxInt

// Defaults that newPool applies to root pools before any options run
// (subpools inherit these settings from their parent instead).
const (
	// DefaultQueueSize is the default maximum number of queued tasks.
	DefaultQueueSize = Unbounded
	// DefaultNonBlocking is the default submission mode: submitters wait for room.
	DefaultNonBlocking = false
)

// Sizing parameters for the linked buffer that holds queued tasks.
const (
	// LinkedBufferInitialSize is the initial size passed to linkedbuffer.NewLinkedBuffer.
	LinkedBufferInitialSize = 1024
	// LinkedBufferMaxCapacity is the max capacity passed to linkedbuffer.NewLinkedBuffer.
	LinkedBufferMaxCapacity = 100 * 1024
)
// Sentinel errors produced by pools.
var (
	// ErrQueueFull is returned when a task cannot be accepted without
	// waiting: the queue is at queueSize, or queueing is disabled and every
	// worker slot is taken.
	ErrQueueFull = errors.New("queue is full")
	// ErrQueueEmpty is returned by readTask when no queued task is left.
	ErrQueueEmpty = errors.New("queue is empty")
	// ErrPoolStopped is returned for submissions to a stopped pool and is
	// the cancellation cause of a pool's context after Stop.
	ErrPoolStopped = errors.New("pool stopped")
	// ErrMaxConcurrencyReached is returned by readTask when the pool has more
	// workers than maxConcurrency allows (e.g. after Resize lowered the limit).
	ErrMaxConcurrencyReached = errors.New("max concurrency reached")
)

// poolStoppedFuture is an already-resolved Task, resolved with
// ErrPoolStopped, handed back for submissions to a pool that is stopped.
var poolStoppedFuture = func() Task {
	stopped, resolve := future.NewFuture(context.Background())
	resolve(ErrPoolStopped)
	return stopped
}()
// BasePool groups the operations every pool exposes: lifecycle control,
// configuration accessors and task metrics.
type BasePool interface {
	// Context returns the context bound to the pool's lifetime.
	Context() context.Context
	// Stop stops accepting new tasks and returns a Task that completes once
	// all workers have exited.
	Stop() Task
	// StopAndWait stops the pool and blocks until it has fully stopped.
	StopAndWait()
	// Stopped reports whether the pool no longer accepts tasks.
	Stopped() bool
	// Resize changes the maximum number of concurrent workers; 0 means
	// unbounded and negative values panic.
	Resize(maxConcurrency int)

	// MaxConcurrency returns the current worker limit.
	MaxConcurrency() int
	// QueueSize returns the maximum number of queued tasks (0 disables queueing).
	QueueSize() int
	// NonBlocking reports whether submissions fail instead of waiting when
	// the pool is full.
	NonBlocking() bool

	// RunningWorkers returns the number of workers currently holding a slot.
	RunningWorkers() int64
	// SubmittedTasks counts every submission attempt, including dropped ones.
	SubmittedTasks() uint64
	// WaitingTasks returns the number of tasks queued but not yet picked up.
	WaitingTasks() uint64
	// FailedTasks counts tasks that ended with a non-cancellation error.
	FailedTasks() uint64
	// SuccessfulTasks counts tasks that ended without error.
	SuccessfulTasks() uint64
	// CompletedTasks is SuccessfulTasks plus FailedTasks.
	CompletedTasks() uint64
	// DroppedTasks counts submissions that were rejected.
	DroppedTasks() uint64
	// CanceledTasks counts tasks that ended with a context cancellation.
	CanceledTasks() uint64
}
// Pool is a BasePool that runs func() and func() error tasks and can spawn
// subpools and task groups.
type Pool interface {
	BasePool

	// Go submits task without creating a Task handle and returns the
	// submission error, if any.
	Go(task func()) error
	// Submit submits task and returns a Task that tracks its completion.
	Submit(task func()) Task
	// SubmitErr is like Submit for tasks that return an error.
	SubmitErr(task func() error) Task
	// TrySubmit submits task without waiting; the bool is false if it was
	// rejected.
	TrySubmit(task func()) (Task, bool)
	// TrySubmitErr is like TrySubmit for tasks that return an error.
	TrySubmitErr(task func() error) (Task, bool)

	// NewSubpool creates a pool whose workers run as tasks on this pool.
	NewSubpool(maxConcurrency int, options ...Option) Pool
	// NewGroup creates a task group bound to this pool's context.
	NewGroup() TaskGroup
	// NewGroupContext creates a task group bound to ctx.
	NewGroupContext(ctx context.Context) TaskGroup
}
// pool is the concrete Pool implementation. A root pool (parent == nil) runs
// each worker on its own goroutine; a subpool runs its workers as tasks
// submitted to its parent.
type pool struct {
	// mutex guards maxConcurrency, the tasks queue and the worker-slot
	// transitions (workerCount/workerWaitGroup) in trySubmit, readTask and
	// Resize; Stop also sets closed while holding it.
	mutex sync.Mutex
	// parent is non-nil for subpools; launchWorker submits workers to it.
	parent *pool
	// ctx is derived in newPool via context.WithCancelCause; Stop cancels it
	// through cancel with ErrPoolStopped once all workers have exited.
	ctx    context.Context
	cancel context.CancelCauseFunc
	// nonBlocking is the submission mode used by Go, Submit and SubmitErr.
	nonBlocking bool
	// panicRecovery is forwarded to invokeTask/wrapTask.
	panicRecovery bool
	// maxConcurrency is the worker limit; math.MaxInt stands for unbounded.
	maxConcurrency int
	// closed is set by Stop; Stopped reports it.
	closed atomic.Bool
	// workerCount is the number of worker slots currently held; a slot is
	// reserved in trySubmit/Resize and released in readTask.
	workerCount atomic.Int64
	// workerWaitGroup mirrors workerCount so Stop can wait for workers.
	workerWaitGroup sync.WaitGroup
	// submitWaiters (buffered, capacity 1) wakes a submitter blocked in
	// blockingTrySubmit.
	submitWaiters chan struct{}
	// queueSize is the maximum number of queued tasks; 0 disables queueing.
	queueSize int
	// tasks holds queued tasks (NOTE(review): presumably FIFO — confirm
	// linkedbuffer semantics).
	tasks *linkedbuffer.LinkedBuffer[any]
	// Metric counters, see the matching BasePool accessors.
	submittedTaskCount  atomic.Uint64
	successfulTaskCount atomic.Uint64
	failedTaskCount     atomic.Uint64
	droppedTaskCount    atomic.Uint64
	canceledTaskCount   atomic.Uint64
}
// Context returns the pool's context. It is created in newPool and canceled
// with ErrPoolStopped by Stop once every worker has exited.
func (p *pool) Context() context.Context {
	ctx := p.ctx
	return ctx
}
// Stopped reports whether the pool no longer accepts tasks: Stop has marked
// it closed, or its context has been canceled.
func (p *pool) Stopped() bool {
	if p.closed.Load() {
		return true
	}
	return p.ctx.Err() != nil
}
// MaxConcurrency returns the current worker limit. It is read under the
// mutex because Resize updates it under the same lock.
func (p *pool) MaxConcurrency() int {
	p.mutex.Lock()
	limit := p.maxConcurrency
	p.mutex.Unlock()
	return limit
}
// Resize changes the maximum number of workers the pool may run at once.
//
// A maxConcurrency of 0 means unbounded (math.MaxInt); negative values panic.
// When the limit grows, extra workers are launched right away so queued work
// can start. Their number is the increase in the limit, capped by the number
// of queued tasks. When the limit shrinks, surplus workers exit on their next
// readTask.
func (p *pool) Resize(maxConcurrency int) {
	if maxConcurrency == 0 {
		maxConcurrency = math.MaxInt
	}
	if maxConcurrency < 0 {
		panic(errors.New("maxConcurrency must be greater than or equal to 0"))
	}

	p.mutex.Lock()
	// Integer min of (limit increase, queued tasks). Both limits are >= 1, so
	// the subtraction cannot overflow. A float64 round-trip would lose
	// precision above 2^53.
	newWorkers := maxConcurrency - p.maxConcurrency
	if queued := p.tasks.Len(); newWorkers > 0 && uint64(newWorkers) > queued {
		newWorkers = int(queued)
	}
	p.maxConcurrency = maxConcurrency
	if newWorkers > 0 {
		// Reserve the slots while holding the lock; each worker releases its
		// slot in readTask.
		p.workerCount.Add(int64(newWorkers))
		p.workerWaitGroup.Add(newWorkers)
	}
	p.mutex.Unlock()

	// Launch outside the lock: for subpools launchWorker submits to the
	// parent, which may block.
	for i := 0; i < newWorkers; i++ {
		p.launchWorker(nil)
	}
}
// QueueSize returns the maximum number of queued tasks; 0 means queueing is
// disabled. It is set when the pool is built and read without locking.
func (p *pool) QueueSize() int {
	size := p.queueSize
	return size
}
// NonBlocking reports whether Go, Submit and SubmitErr fail with an error
// instead of waiting when the pool cannot accept a task.
func (p *pool) NonBlocking() bool {
	nonBlocking := p.nonBlocking
	return nonBlocking
}
// RunningWorkers returns the number of worker slots currently held. This
// includes slots just reserved by trySubmit or Resize for workers being
// launched.
func (p *pool) RunningWorkers() int64 {
	workers := p.workerCount.Load()
	return workers
}
// SubmittedTasks returns how many submissions were attempted. submit counts
// a task before trying to place it, so dropped tasks are included.
func (p *pool) SubmittedTasks() uint64 {
	submitted := p.submittedTaskCount.Load()
	return submitted
}
// WaitingTasks returns the number of queued tasks that no worker has picked
// up yet.
// NOTE(review): read without p.mutex; presumably linkedbuffer's Len is safe
// for concurrent use — confirm.
func (p *pool) WaitingTasks() uint64 {
	queued := p.tasks.Len()
	return queued
}
// FailedTasks returns how many tasks ended with an error that is not a
// context cancellation.
func (p *pool) FailedTasks() uint64 {
	failed := p.failedTaskCount.Load()
	return failed
}
// SuccessfulTasks returns how many tasks ended without an error.
func (p *pool) SuccessfulTasks() uint64 {
	succeeded := p.successfulTaskCount.Load()
	return succeeded
}
// CompletedTasks returns successful plus failed tasks; canceled tasks are not
// included. The two counters are loaded separately, so the sum is not an
// atomic snapshot.
func (p *pool) CompletedTasks() uint64 {
	succeeded := p.successfulTaskCount.Load()
	failed := p.failedTaskCount.Load()
	return succeeded + failed
}
// DroppedTasks returns how many submissions were rejected by submit.
func (p *pool) DroppedTasks() uint64 {
	dropped := p.droppedTaskCount.Load()
	return dropped
}
// CanceledTasks returns how many tasks ended with an error matching
// ErrContextCanceled.
func (p *pool) CanceledTasks() uint64 {
	canceled := p.canceledTaskCount.Load()
	return canceled
}
// worker is the goroutine body of a root-pool worker. It runs task (if
// non-nil), then keeps pulling queued tasks until readTask reports that the
// worker should exit. readTask has already released the worker's slot
// (workerCount/workerWaitGroup) when it returns an error.
//
// If a task panic escapes invokeTask (presumably only when panicRecovery is
// disabled — confirm), the deferred handler records a failure. It then hands
// this worker's slot to a replacement if more work is queued. The panic is
// not recovered here and keeps propagating.
func (p *pool) worker(task any) {
	exitedNormally := false
	defer func() {
		if exitedNormally {
			return
		}
		// The panicking invocation never returned, so there is no error value
		// for it. Report a dedicated failure rather than the previous task's
		// stale result: that result may be nil (formatting as "%!w(<nil>)")
		// or a cancellation, which would be miscounted as CanceledTasks.
		p.updateMetrics(errors.New("worker exited abnormally"))

		// Transfer the slot to a new worker if a task is queued; otherwise
		// readTask releases it.
		next, err := p.readTask()
		if err != nil {
			return
		}
		if next != nil {
			p.launchWorker(next)
			p.notifySubmitWaiter()
		}
	}()

	for {
		if task != nil {
			_, taskErr := invokeTask[any](task, p.panicRecovery)
			p.updateMetrics(taskErr)
		}
		next, err := p.readTask()
		if err != nil {
			// readTask already released this worker's slot.
			exitedNormally = true
			return
		}
		task = next
	}
}
// subpoolWorker returns the function a subpool submits to its parent pool to
// act as one of its workers.
//
// The returned function runs task (if non-nil) and records its metrics. It
// then reads the next queued subpool task and submits a fresh subpoolWorker
// for it to the parent. Each parent task therefore runs at most one subpool
// task, and the worker slot is handed along the chain. The returned
// (output, err) are those of task itself, not of the resubmission.
func (p *pool) subpoolWorker(task any) func() (output any, err error) {
	return func() (output any, err error) {
		if task != nil {
			output, err = invokeTask[any](task, p.panicRecovery)
			p.updateMetrics(err)
		}
		// task and err below shadow the outer variables, so resubmission
		// errors never reach the named result err. When readTask fails it has
		// already released this worker's slot.
		if task, err := p.readTask(); err == nil {
			for {
				submitErr := p.parent.submit(p.subpoolWorker(task), p.nonBlocking)
				if submitErr == nil {
					break
				}
				// The parent rejected the worker (the parent's submit also
				// counted it as dropped). Only a stopped parent is recorded
				// here, as a cancellation on both pools. In every case the
				// rejected task is skipped and the next queued task is tried.
				// NOTE(review): a skipped task is never invoked — confirm how
				// its Task/future gets resolved.
				if errors.Is(submitErr, ErrPoolStopped) {
					err = errors.Join(ErrContextCanceled, submitErr)
					p.updateMetrics(err)
					p.parent.updateMetrics(err)
				}
				task, err = p.readTask()
				if err != nil {
					break
				}
			}
		}
		return
	}
}
// Go submits task without creating a Task handle for it, using the pool's
// NonBlocking mode. It returns the submission error, if any (for example
// ErrPoolStopped, ErrQueueFull, or a context error while waiting).
func (p *pool) Go(task func()) error {
	err := p.submit(task, p.nonBlocking)
	return err
}
// Submit submits task using the pool's NonBlocking mode and returns a Task
// tracking it. If the submission is rejected, the returned Task is already
// resolved with the rejection error.
func (p *pool) Submit(task func()) Task {
	handle, _ := p.wrapAndSubmit(task, p.nonBlocking)
	return handle
}
// SubmitErr is Submit for tasks that return an error. The returned Task is
// already resolved with the rejection error if the submission fails.
func (p *pool) SubmitErr(task func() error) Task {
	handle, _ := p.wrapAndSubmit(task, p.nonBlocking)
	return handle
}
// TrySubmit submits task without waiting, regardless of the pool's
// NonBlocking setting. The bool reports whether the task was accepted.
func (p *pool) TrySubmit(task func()) (Task, bool) {
	handle, accepted := p.wrapAndSubmit(task, true)
	return handle, accepted
}
// TrySubmitErr is TrySubmit for tasks that return an error.
func (p *pool) TrySubmitErr(task func() error) (Task, bool) {
	handle, accepted := p.wrapAndSubmit(task, true)
	return handle, accepted
}
// wrapAndSubmit wraps task in a future-backed Task and submits it.
//
// A stopped pool short-circuits to the shared, pre-resolved
// poolStoppedFuture; that path touches no counters. If submit rejects the
// task, its Task is resolved with the error and false is returned.
func (p *pool) wrapAndSubmit(task any, nonBlocking bool) (Task, bool) {
	if p.Stopped() {
		return poolStoppedFuture, false
	}

	handle, wrapped, resolve := p.wrapTask(task)
	err := p.submit(wrapped, nonBlocking)
	if err == nil {
		return handle, true
	}
	resolve(err)
	return handle, false
}
// wrapTask creates a future bound to the pool's context and wraps task so
// that running it resolves that future. It returns the future (as a Task),
// the wrapped function to submit, and the resolver for failing the future
// early.
func (p *pool) wrapTask(task any) (Task, func() error, func(error)) {
	ctx := p.Context()
	handle, resolve := future.NewFuture(ctx)
	return handle, wrapTask[struct{}, func(error)](task, resolve, ctx, p.panicRecovery), resolve
}
// submit counts task as submitted and tries to place it. In non-blocking
// mode it tries once; otherwise it waits for room or cancellation. Rejected
// tasks are also counted as dropped.
func (p *pool) submit(task any, nonBlocking bool) (err error) {
	p.submittedTaskCount.Add(1)

	switch {
	case nonBlocking:
		err = p.trySubmit(task)
	default:
		err = p.blockingTrySubmit(task)
	}

	if err != nil {
		p.droppedTaskCount.Add(1)
	}
	return err
}
// blockingTrySubmit retries trySubmit until it returns anything other than
// ErrQueueFull. Between attempts it waits for either a submitWaiters signal
// or cancellation of the pool's context, in which case it returns ctx.Err().
func (p *pool) blockingTrySubmit(task any) error {
	for {
		err := p.trySubmit(task)
		if err != ErrQueueFull {
			return err
		}

		select {
		case <-p.ctx.Done():
			return p.ctx.Err()
		case <-p.submitWaiters:
		}

		// Woken by a signal: if the context was canceled meanwhile, report
		// that instead of retrying.
		if p.ctx.Err() != nil {
			return p.ctx.Err()
		}
	}
}
// trySubmit places task without waiting. It either starts a new worker for
// it (when a worker slot is free), or appends it to the queue (when all slots
// are taken and queueing is enabled). Otherwise it returns ErrPoolStopped or
// ErrQueueFull.
func (p *pool) trySubmit(task any) error {
	p.mutex.Lock()
	// Checked under the mutex: Stop sets closed while holding the same lock,
	// so no worker slot can be reserved after Stop has started.
	if p.Stopped() {
		p.mutex.Unlock()
		return ErrPoolStopped
	}
	// queueSize == 0 disables the queue: tasks are accepted only when a
	// worker slot is free.
	queueEnabled := p.queueSize > 0
	tasksLen := int(p.tasks.Len())
	if queueEnabled && tasksLen >= p.queueSize {
		p.mutex.Unlock()
		return ErrQueueFull
	}
	if int(p.workerCount.Load()) >= p.maxConcurrency {
		if !queueEnabled {
			p.mutex.Unlock()
			return ErrQueueFull
		}
		// Every slot is busy: leave the task for an existing worker.
		p.tasks.Write(task)
		p.mutex.Unlock()
		return nil
	}
	// Reserve a worker slot while holding the lock; the worker releases it
	// in readTask.
	p.workerCount.Add(1)
	p.workerWaitGroup.Add(1)
	// If tasks are already waiting, enqueue this one and give the new worker
	// the one read from the buffer instead, so waiting tasks are not
	// overtaken (assumes the linked buffer is FIFO — TODO confirm).
	if queueEnabled && tasksLen > 0 {
		p.tasks.Write(task)
		task, _ = p.tasks.Read()
	}
	p.mutex.Unlock()
	// Launch outside the lock: for subpools launchWorker submits to the
	// parent, which may block.
	p.launchWorker(task)
	p.notifySubmitWaiter()
	return nil
}
// launchWorker starts a worker that runs task (if non-nil) and then drains
// the queue. The caller must already have reserved a worker slot
// (workerCount and workerWaitGroup incremented), as trySubmit and Resize do.
//
// Root pools run the worker on a new goroutine. Subpools submit it as a task
// to the parent pool. If the parent rejects that submission, the reserved
// slot must still be released. Otherwise StopAndWait would wait on
// workerWaitGroup forever, and the subpool would count a worker that does not
// exist against maxConcurrency.
func (p *pool) launchWorker(task any) {
	if p.parent == nil {
		go p.worker(task)
		return
	}

	err := p.parent.submit(p.subpoolWorker(task), p.nonBlocking)
	if err == nil {
		return
	}

	// task is dropped. Account for it the same way subpoolWorker does when a
	// resubmission to a stopped parent fails.
	if task != nil && errors.Is(err, ErrPoolStopped) {
		canceledErr := errors.Join(ErrContextCanceled, err)
		p.updateMetrics(canceledErr)
		p.parent.updateMetrics(canceledErr)
	}

	// Run the worker loop inline with no task of its own. It either passes
	// the slot to a worker for the next queued task (if the parent accepts
	// it), or releases the slot once readTask finds nothing left.
	p.subpoolWorker(nil)()
}
// readTask hands the calling worker its next task. When the queue is empty,
// or the pool has more workers than maxConcurrency allows, it releases the
// caller's worker slot and returns ErrQueueEmpty or ErrMaxConcurrencyReached;
// the worker must then exit.
func (p *pool) readTask() (task any, err error) {
	p.mutex.Lock()

	if p.tasks.Len() == 0 {
		p.workerCount.Add(-1)
		p.workerWaitGroup.Done()
		p.mutex.Unlock()
		p.notifySubmitWaiter()
		return nil, ErrQueueEmpty
	}

	overLimit := p.maxConcurrency > 0 && int(p.workerCount.Load()) > p.maxConcurrency
	if overLimit {
		p.workerCount.Add(-1)
		p.workerWaitGroup.Done()
		p.mutex.Unlock()
		return nil, ErrMaxConcurrencyReached
	}

	next, _ := p.tasks.Read()
	p.mutex.Unlock()
	// A queue slot was freed: wake a submitter waiting for room.
	p.notifySubmitWaiter()
	return next, nil
}
// notifySubmitWaiter signals submitWaiters without blocking. If a signal is
// already pending, this one is dropped.
func (p *pool) notifySubmitWaiter() {
	select {
	case p.submitWaiters <- struct{}{}:
	default:
	}
}
// updateMetrics classifies a finished task by its error. nil counts as
// successful, anything matching ErrContextCanceled as canceled, and any other
// error as failed.
func (p *pool) updateMetrics(err error) {
	switch {
	case err == nil:
		p.successfulTaskCount.Add(1)
	case errors.Is(err, ErrContextCanceled):
		p.canceledTaskCount.Add(1)
	default:
		p.failedTaskCount.Add(1)
	}
}
// Stop marks the pool closed, waits for all of its workers to exit, and then
// cancels the pool's context with ErrPoolStopped. The shutdown runs as a task
// submitted through the package-level Submit (defined elsewhere; presumably a
// default pool — confirm). The returned Task completes when shutdown is done.
func (p *pool) Stop() Task {
	shutdown := func() {
		// Set closed under the mutex so trySubmit cannot reserve a worker
		// slot after this point.
		p.mutex.Lock()
		p.closed.Store(true)
		p.mutex.Unlock()

		p.workerWaitGroup.Wait()
		p.cancel(ErrPoolStopped)
	}
	return Submit(shutdown)
}
// StopAndWait calls Stop and blocks until the returned Task completes.
func (p *pool) StopAndWait() {
	stopping := p.Stop()
	stopping.Wait()
}
// NewSubpool creates a pool whose workers run as tasks on p.
//
// maxConcurrency must not exceed p's limit (newPool panics otherwise); 0
// inherits p's limit. Queue size, non-blocking mode and panic recovery are
// inherited from p unless options override them.
func (p *pool) NewSubpool(maxConcurrency int, options ...Option) Pool {
	sub := newPool(maxConcurrency, p, options...)
	return sub
}
// NewGroup creates a TaskGroup on p that uses p's own context.
func (p *pool) NewGroup() TaskGroup {
	return p.NewGroupContext(p.ctx)
}
// NewGroupContext creates a TaskGroup on p that uses ctx.
func (p *pool) NewGroupContext(ctx context.Context) TaskGroup {
	group := newTaskGroup(p, ctx)
	return group
}
// newPool builds a pool. parent is nil for root pools and non-nil for
// subpools.
//
// A maxConcurrency of 0 means the parent's limit for subpools and unbounded
// (math.MaxInt) for root pools. Negative values, or a subpool limit above the
// parent's, panic. Subpools inherit the parent's context, queue size,
// non-blocking mode and panic recovery. Options are applied afterwards and can
// override these settings. The pool's context is then wrapped with
// context.WithCancelCause so Stop can cancel it.
func newPool(maxConcurrency int, parent *pool, options ...Option) *pool {
	if parent != nil {
		// Read the parent's limit once. MaxConcurrency takes the parent's
		// lock, and a concurrent Resize could otherwise make the comparison,
		// the panic message and the inherited default disagree.
		parentMaxConcurrency := parent.MaxConcurrency()
		if maxConcurrency > parentMaxConcurrency {
			panic(fmt.Errorf("maxConcurrency cannot be greater than the parent pool's maxConcurrency (%d)", parentMaxConcurrency))
		}
		if maxConcurrency == 0 {
			maxConcurrency = parentMaxConcurrency
		}
	}
	if maxConcurrency == 0 {
		maxConcurrency = math.MaxInt
	}
	if maxConcurrency < 0 {
		panic(errors.New("maxConcurrency must be greater than or equal to 0"))
	}

	pool := &pool{
		ctx:            context.Background(),
		nonBlocking:    DefaultNonBlocking,
		panicRecovery:  true,
		maxConcurrency: maxConcurrency,
		queueSize:      DefaultQueueSize,
		submitWaiters:  make(chan struct{}, 1),
	}

	if parent != nil {
		pool.parent = parent
		pool.ctx = parent.Context()
		pool.queueSize = parent.queueSize
		pool.nonBlocking = parent.nonBlocking
		pool.panicRecovery = parent.panicRecovery
	}

	for _, option := range options {
		option(pool)
	}

	pool.ctx, pool.cancel = context.WithCancelCause(pool.ctx)
	pool.tasks = linkedbuffer.NewLinkedBuffer[any](LinkedBufferInitialSize, LinkedBufferMaxCapacity)

	return pool
}
// NewPool creates a root pool that runs at most maxConcurrency tasks at once.
// 0 means unbounded, and negative values panic. Options are applied after
// the defaults.
func NewPool(maxConcurrency int, options ...Option) Pool {
	var noParent *pool
	return newPool(maxConcurrency, noParent, options...)
}
// The pages are generated with Golds v0.8.4 (GOOS=linux GOARCH=amd64).
// Golds is a Go 101 project developed by Tapir Liu.
// PRs and bug reports are welcome and can be submitted to the issue list.
// Please follow @zigo_101 (reachable from the left QR code) to get the latest news of Golds.