package pond

import (
	"context"
	"errors"
	"sync"
	"sync/atomic"

	"github.com/alitto/pond/v2/internal/future"
)

// ErrGroupStopped is the error a task group reports once Stop has been called
// on it: Stop cancels the group's future with this error, and Wait returns it
// when the group has been stopped.
var ErrGroupStopped = errors.New("task group stopped")

// TaskGroup represents a group of tasks that can be executed concurrently.
// The group can be waited on to block until all tasks have completed.
// If any of the tasks return an error, the group will return the first error encountered.
type TaskGroup interface {
	// Submit adds one or more tasks to the group.
	// It returns the group itself so calls can be chained.
	Submit(tasks ...func()) TaskGroup

	// SubmitErr adds one or more tasks that can return an error to the group.
	// It returns the group itself so calls can be chained.
	SubmitErr(tasks ...func() error) TaskGroup

	// Wait blocks until all tasks in the group have completed.
	// If any of the tasks return an error, Wait returns the first error encountered.
	// If the context is cancelled, Wait returns the context error.
	// If the group is stopped, Wait returns ErrGroupStopped.
	// A task that is already running when the context is cancelled or the group
	// is stopped is allowed to complete before Wait returns.
	Wait() error

	// Done returns a channel that is closed when all tasks in the group have
	// completed, a task returns an error, or the group is stopped.
	Done() <-chan struct{}

	// Stop stops the group and cancels all remaining tasks.
	// Running tasks are not interrupted.
	Stop()

	// Context returns the context associated with this group.
	// This context is cancelled when either the parent context is cancelled
	// or any task in the group returns an error, whichever comes first.
	Context() context.Context
}

// ResultTaskGroup represents a group of tasks that can be executed concurrently.
// As opposed to TaskGroup, the tasks in a ResultTaskGroup yield a result of type O.
// The group can be waited on to block until all tasks have completed.
// If any of the tasks return an error, the group will return the first error encountered.
type ResultTaskGroup[O any] interface {
	// Submit adds one or more tasks, each producing a value of type O, to the group.
	// It returns the group itself so calls can be chained.
	Submit(tasks ...func() O) ResultTaskGroup[O]

	// SubmitErr adds one or more tasks, each producing a value of type O or an
	// error, to the group. It returns the group itself so calls can be chained.
	SubmitErr(tasks ...func() (O, error)) ResultTaskGroup[O]

	// Wait blocks until all tasks in the group have completed and returns the
	// results of each task in the order they were submitted.
	// If any of the tasks return an error, Wait returns the first error encountered.
	// If the context is cancelled, Wait returns the context error.
	// If the group is stopped, Wait returns ErrGroupStopped.
	// A task that is already running when the context is cancelled or the group
	// is stopped is allowed to complete before Wait returns.
	Wait() ([]O, error)

	// Done returns a channel that is closed when all tasks in the group have
	// completed, a task returns an error, or the group is stopped.
	Done() <-chan struct{}

	// Stop stops the group and cancels all remaining tasks.
	// Running tasks are not interrupted.
	Stop()
}

// result holds the outcome of a single task: the value it produced (the zero
// value of O when it produced none) and the error recorded for it, if any.
type result[O any] struct {
	Output O
	Err    error
}

type abstractTaskGroup[ func() | func() ,  func() error | func() (, error),  any] struct {
	pool           *pool
	nextIndex      atomic.Int64
	taskWaitGroup  sync.WaitGroup
	future         *future.CompositeFuture[*result[]]
	futureResolver future.CompositeFutureResolver[*result[]]
}

func ( *abstractTaskGroup[, , ]) () <-chan struct{} {
	return .future.Done(int(.nextIndex.Load()))
}

func ( *abstractTaskGroup[, , ]) () {
	.future.Cancel(ErrGroupStopped)
}

func ( *abstractTaskGroup[, , ]) () context.Context {
	return .future.Context()
}

func ( *abstractTaskGroup[, , ]) ( ...) *abstractTaskGroup[, , ] {
	for ,  := range  {
		.submit()
	}

	return 
}

func ( *abstractTaskGroup[, , ]) ( ...) *abstractTaskGroup[, , ] {
	for ,  := range  {
		.submit()
	}

	return 
}

func ( *abstractTaskGroup[, , ]) ( any) {
	 := int(.nextIndex.Add(1) - 1)

	.taskWaitGroup.Add(1)

	 := .pool.submit(func() error {
		defer .taskWaitGroup.Done()

		// Check if the context has been cancelled to prevent running tasks that are not needed
		if  := .future.Context().Err();  != nil {
			// Wrap the error with the context canceled error to reflect that the task was canceled.
			 = errors.Join(ErrContextCanceled, )

			.futureResolver(, &result[]{
				Err: ,
			}, )

			return 
		}

		// Invoke the task
		,  := invokeTask[](, .pool.panicRecovery)

		.futureResolver(, &result[]{
			Output: ,
			Err:    ,
		}, )

		return 
	}, .pool.nonBlocking)

	if  != nil {
		.taskWaitGroup.Done()

		.futureResolver(, &result[]{
			Err: ,
		}, )
	}
}

// taskGroup is the concrete TaskGroup returned by newTaskGroup. Its tasks
// produce no value, so the shared implementation records struct{} results.
type taskGroup struct {
	abstractTaskGroup[func(), func() error, struct{}]
}

func ( *taskGroup) ( ...func()) TaskGroup {
	.abstractTaskGroup.Submit(...)
	return 
}

func ( *taskGroup) ( ...func() error) TaskGroup {
	.abstractTaskGroup.SubmitErr(...)
	return 
}

func ( *taskGroup) () error {
	,  := .future.Wait(int(.nextIndex.Load()))
	// This wait group could reach zero before the future is resolved if called in between tasks being submitted and the future being resolved.
	// That's why we wait for the future to be resolved before waiting for the wait group.
	.taskWaitGroup.Wait()
	return 
}

type resultTaskGroup[ any] struct {
	abstractTaskGroup[func() , func() (, error), ]
}

func ( *resultTaskGroup[]) ( ...func() ) ResultTaskGroup[] {
	.abstractTaskGroup.Submit(...)
	return 
}

func ( *resultTaskGroup[]) ( ...func() (, error)) ResultTaskGroup[] {
	.abstractTaskGroup.SubmitErr(...)
	return 
}

func ( *resultTaskGroup[]) () ([], error) {
	,  := .future.Wait(int(.nextIndex.Load()))

	// This wait group could reach zero before the future is resolved if called in between tasks being submitted and the future being resolved.
	// That's why we wait for the future to be resolved before waiting for the wait group.
	.taskWaitGroup.Wait()

	 := make([], len())

	for ,  := range  {
		if  != nil {
			[] = .Output
		}
	}

	return , 
}

func newTaskGroup( *pool,  context.Context) TaskGroup {
	,  := future.NewCompositeFuture[*result[struct{}]]()

	return &taskGroup{
		abstractTaskGroup: abstractTaskGroup[func(), func() error, struct{}]{
			pool:           ,
			future:         ,
			futureResolver: ,
		},
	}
}

func newResultTaskGroup[ any]( *pool,  context.Context) ResultTaskGroup[] {
	,  := future.NewCompositeFuture[*result[]]()

	return &resultTaskGroup[]{
		abstractTaskGroup: abstractTaskGroup[func() , func() (, error), ]{
			pool:           ,
			future:         ,
			futureResolver: ,
		},
	}
}