package machine

import (
	
	
	
	
	
)

type (
	// InternalLogFunc is the signature of the logging callback kept in
	// Subscriptions.log: a log level, a message, and the message's args.
	InternalLogFunc   func(level LogLevel, msg string, args ...any)
	// InternalCheckFunc is the signature of the state-check callbacks kept in
	// Subscriptions.is and Subscriptions.not: it receives a list of states and
	// reports a boolean result.
	InternalCheckFunc func(states S) bool
)

// Subscriptions is an embed responsible for binding subscriptions,
// managing their indexes, processing triggers, and garbage collection.
type Subscriptions struct {
	// Mx locks the subscription manager. TODO optimize?
	Mx sync.Mutex

	// mach is the owning machine; used here for Id() and Ctx().
	mach  Api
	// clock is indexed by state name to read a state's tick.
	clock Clock
	// is and not are state-check callbacks (see InternalCheckFunc).
	is    InternalCheckFunc
	not   InternalCheckFunc
	// log is the logging callback (see InternalLogFunc).
	log   InternalLogFunc

	// stateCtx indexes state contexts (Ctx + Cancel) by state name.
	stateCtx IndexStateCtx

	// when, whenTime and whenArgs are indexes of pending bindings, each
	// holding a list of bindings per key. The matching *Ctx maps group the same
	// bindings by their optional context, so that bindings of an expired
	// context can be found and garbage collected.
	when         IndexWhen
	whenCtx      map[context.Context][]*WhenBinding
	whenTime     IndexWhenTime
	whenTimeCtx  map[context.Context][]*WhenTimeBinding
	whenArgs     IndexWhenArgs
	whenArgsCtx  map[context.Context][]*WhenArgsBinding
	// whenQuery is a flat list of query bindings (not indexed by state).
	whenQuery    []*whenQueryBinding
	whenQueryCtx map[context.Context][]*whenQueryBinding

	// whenQueueEnds and whenQueue are flat lists of queue-related bindings.
	whenQueueEnds []*whenQueueEndsBinding
	whenQueue     []*whenQueueBinding
}

// Constructor for Subscriptions. Populates mach, clock, is, not and log
// (presumably from the parameters of the matching types), and allocates the
// when, whenTime, whenArgs and stateCtx indexes together with the whenCtx,
// whenTimeCtx and whenArgsCtx maps.
//
// NOTE(review): whenQueryCtx is not allocated here, while the query binding
// code in this file assigns into whenQueryCtx[...]; assigning to a nil map
// panics - confirm whether it gets initialized elsewhere.
func (
	 Api,  Clock, ,  InternalCheckFunc,  InternalLogFunc,
) *Subscriptions {
	return &Subscriptions{
		mach:  ,
		clock: ,
		is:    ,
		not:   ,
		log:   ,

		when:        IndexWhen{},
		whenTime:    IndexWhenTime{},
		whenArgs:    IndexWhenArgs{},
		stateCtx:    IndexStateCtx{},
		whenCtx:     map[context.Context][]*WhenBinding{},
		whenTimeCtx: map[context.Context][]*WhenTimeBinding{},
		whenArgsCtx: map[context.Context][]*WhenArgsBinding{},
	}
}

// ///// ///// /////

// ///// PROCESSING

// ///// ///// /////

// ProcessStateCtx collects all deactivated state contexts, and returns their
// cancel funcs. Uses transition caches.
//
// Holds Mx for the whole call. For every passed state that has an entry in
// stateCtx, the entry's Cancel func is collected and the entry is removed from
// the index. The cancel funcs are returned, not invoked here.
func ( *Subscriptions) ( S) []context.CancelFunc {
	// locks
	.Mx.Lock()
	defer .Mx.Unlock()

	var  []context.CancelFunc
	for ,  := range  {
		// skip states without a registered context
		if ,  := .stateCtx[]; ! {
			continue
		}

		// collect the cancel func and drop the index entry
		 = append(, .stateCtx[].Cancel)
		.log(LogOps, "[ctx:match] %s", )
		delete(.stateCtx, )
	}

	return 
}

// ProcessWhen collects all the matched active state subscriptions, and
// returns their channels.
//
// Holds Mx for the whole call. It first runs the expired-context sweep
// (processWhenCtx), then walks the concatenation of two state lists
// (presumably the two parameters) and, for every binding indexed under each
// state, updates the binding's States flag and Matched counter. A binding is
// finished once Matched reaches Total, or once its Ctx has expired; finished
// bindings are removed via gcWhenBinding and their Ch is added to the result.
// The channels are returned, not closed here.
func ( *Subscriptions) (,  S) []chan struct{} {
	// locks
	.Mx.Lock()
	defer .Mx.Unlock()

	// collect ctx expirations
	// TODO optimize by skipping
	 := .processWhenCtx()

	// collect matched bindings
	 := slices.Concat(, )
	for ,  := range  {
		// TODO optimize clone
		// iterate over a copy - gcWhenBinding below modifies the when index
		for ,  := range slices.Clone(.when[]) {

			if slices.Contains(, ) {

				// state activated, check the index
				if !.Negation {
					// match for When(
					// count only a transition from inactive to active
					if !.States[] {
						.Matched++
					}
				} else {
					// match for WhenNot(
					// a newly active state un-matches a negated binding
					if !.States[] {
						.Matched--
					}
				}

				// update index: mark as active
				.States[] = true
			} else {

				// state deactivated
				if !.Negation {
					// match for When(
					// a previously active state un-matches the binding
					if .States[] {
						.Matched--
					}
				} else {
					// match for WhenNot(
					// count only a transition from active to inactive
					if .States[] {
						.Matched++
					}
				}

				// update index: mark as inactive
				.States[] = false
			}

			// if not all matched, ignore for now
			// (an expired ctx finishes the binding regardless of matches)
			 := .Ctx != nil && .Ctx.Err() != nil
			if .Matched < .Total && ! {
				continue
			}

			// completed - rm binding and collect ch
			.gcWhenBinding(, true)
			 = append(, .Ch)
		}
	}

	return 
}

// Expired-context sweep for the when index: every context in whenCtx whose
// Err() is non-nil is removed from whenCtx, and each of its bindings is passed
// to gcWhenBinding with the second argument set to false. No locking is done
// here; the visible caller holds Mx.
//
// NOTE(review): nothing visible in this function appends to the returned
// slice, so it appears to always be nil and the channels of expired bindings
// are not handed back - confirm whether that is intended.
func ( *Subscriptions) () []chan struct{} {
	var  []chan struct{}

	// find expired ctxs
	for ,  := range .whenCtx {
		if .Err() == nil {
			continue
		}

		// delete the ctx and all the bindings
		delete(.whenCtx, )
		for ,  := range  {
			.gcWhenBinding(, false)
		}
	}

	return 
}

// Removes a WhenBinding from the indexes: for every state in the binding's
// States it drops the binding from the when index (deleting the index key when
// the list has exactly one element), and, when the bool parameter is true and
// the binding has a Ctx, it also drops the binding from whenCtx. Logs the
// involved states under "[when:match]" or "[whenNot:match]" depending on
// Negation. Does not close or return the binding's channel.
//
// NOTE(review): the whenCtx cleanup sits inside the per-state loop although it
// does not appear to depend on the state, so it is repeated for each state -
// confirm whether it can be hoisted.
func ( *Subscriptions) ( *WhenBinding,  bool) {
	// completed - close and delete indexes for all involved states
	var  []string
	for  := range .States {
		// remove GC ctx
		if .Ctx != nil &&  {
			.whenCtx[.Ctx] = slicesWithout(
				.whenCtx[.Ctx], )

			// drop the ctx key once it has no bindings left
			if len(.whenCtx[.Ctx]) == 0 {
				delete(.whenCtx, .Ctx)
			}
		}

		// delete state index
		 = append(, )
		// single-element list: remove the whole key (presumably the only
		// element is this binding - TODO confirm)
		if len(.when[]) == 1 {
			delete(.when, )
			continue
		}

		// delete with a lookup TODO optimize, GC later
		.when[] = slicesWithout(.when[], )
	}

	// log TODO sem logger
	if .Negation {
		.log(LogOps, "[whenNot:match] %s", j())
	} else {
		.log(LogOps, "[when:match] %s", j())
	}
}

// ProcessWhenTime collects all the time-based subscriptions, and
// returns their channels.
//
// Holds Mx for the whole call. It first runs the expired-context sweep
// (processWhenTimeCtx), then builds the list of states whose clock value
// differs from the compared one, and checks every binding indexed under those
// states. A binding is finished once Matched reaches Total, or once its Ctx
// has expired; finished bindings are removed via gcWhenTimeBinding and their
// Ch is added to the result. The channels are returned, not closed here.
func ( *Subscriptions) ( Clock) []chan struct{} {
	// locks
	.Mx.Lock()
	defer .Mx.Unlock()

	// collect ctx expirations
	// TODO optimize by skipping
	 := .processWhenTimeCtx()

	// collect all the ticked states
	// TODO optimize?
	 := S{}
	for ,  := range  {
		// if changed, collect to check
		if .clock[] !=  {
			 = append(, )
		}
	}

	// check all the bindings for all the ticked states
	for ,  := range  {
		// TODO optimize clone
		// iterate over a copy - gcWhenTimeBinding below modifies whenTime
		for ,  := range slices.Clone(.whenTime[]) {

			// check if the requested time has passed
			// (Index maps the state to its position in Times)
			if !.Completed[] &&
				.clock[] >= .Times[.Index[]] {
				.Matched++
				// mark in the index as completed
				.Completed[] = true
			}

			// if not all matched, ignore for now
			// (an expired ctx finishes the binding regardless of matches)
			 := .Ctx != nil && .Ctx.Err() != nil
			if .Matched < .Total && ! {
				continue
			}

			// completed - rm binding and collect ch
			.gcWhenTimeBinding(, true)
			 = append(, .Ch)
		}
	}

	return 
}

// Expired-context sweep for the whenTime index: every context in whenTimeCtx
// whose Err() is non-nil is removed from whenTimeCtx, and each of its bindings
// is passed to gcWhenTimeBinding with the second argument set to false. No
// locking is done here; the visible caller holds Mx.
//
// NOTE(review): nothing visible in this function appends to the returned
// slice, so it appears to always be nil - confirm whether that is intended.
func ( *Subscriptions) () []chan struct{} {
	var  []chan struct{}

	// find expired ctxs
	for ,  := range .whenTimeCtx {
		if .Err() == nil {
			continue
		}

		// delete the ctx and all the bindings
		delete(.whenTimeCtx, )
		for ,  := range  {
			.gcWhenTimeBinding(, false)
		}
	}

	return 
}

// Removes a WhenTimeBinding from the indexes: for every state key in the
// binding's Index it drops the binding from the whenTime index (deleting the
// key when the list has exactly one element), and, when the bool parameter is
// true and the binding has a Ctx, it also drops the binding from whenTimeCtx.
// Logs the involved states and the binding's Times under "[whenTime:match]".
// Does not close or return the binding's channel.
//
// NOTE(review): the whenTimeCtx cleanup sits inside the per-state loop
// although it does not appear to depend on the state - confirm whether it can
// be hoisted.
func ( *Subscriptions) (
	 *WhenTimeBinding,  bool,
) {
	// completed - close and delete indexes for all involved states
	var  []string
	for  := range .Index {
		// remove GC ctx
		if .Ctx != nil &&  {
			.whenTimeCtx[.Ctx] = slicesWithout(
				.whenTimeCtx[.Ctx], )

			// drop the ctx key once it has no bindings left
			if len(.whenTimeCtx[.Ctx]) == 0 {
				delete(.whenTimeCtx, .Ctx)
			}
		}

		// remove state index
		 = append(, )
		// single-element list: remove the whole key (presumably the only
		// element is this binding - TODO confirm)
		if len(.whenTime[]) == 1 {
			delete(.whenTime, )
			continue
		}

		// delete with a lookup
		.whenTime[] = slicesWithout(.whenTime[], )
	}

	// log TODO sem logger
	.log(LogOps, "[whenTime:match] %s %d", j(), .Times)
}

// ProcessWhenArgs collects all the args-matching subscriptions, and
// returns their channels.
//
// Holds Mx for the whole call. It first runs the expired-context sweep
// (processWhenArgsCtx), then checks every binding registered in whenArgs under
// the event's Name. A binding is finished when compareArgs reports a match
// between the event's Args and the binding's args, or when the binding's ctx
// has expired; finished bindings are removed via gcWhenArgsBinding and their
// ch is added to the result. The channels are returned, not closed here.
func ( *Subscriptions) ( *Event) []chan struct{} {
	// locks
	.Mx.Lock()
	defer .Mx.Unlock()

	// collect ctx expirations
	// TODO optimize by skipping
	 := .processWhenArgsCtx()

	// collect arg matches
	// iterate over a copy - gcWhenArgsBinding below modifies whenArgs
	for ,  := range slices.Clone(.whenArgs[.Name]) {
		 := .ctx != nil && .ctx.Err() != nil
		// TODO better comparison
		// keep waiting only if neither the args matched nor the ctx expired
		if !compareArgs(.Args, .args) && ! {
			continue
		}

		// completed - rm binding and collect ch
		.gcWhenArgsBinding(, true)
		 = append(, .ch)
	}

	return 
}

// Expired-context sweep for the whenArgs index: every context in whenArgsCtx
// whose Err() is non-nil is removed from whenArgsCtx, and each of its bindings
// is passed to gcWhenArgsBinding with the second argument set to false. No
// locking is done here; the visible caller holds Mx.
//
// NOTE(review): nothing visible in this function appends to the returned
// slice, so it appears to always be nil - confirm whether that is intended.
func ( *Subscriptions) () []chan struct{} {
	var  []chan struct{}

	// find expired ctxs
	for ,  := range .whenArgsCtx {
		if .Err() == nil {
			continue
		}

		// delete the ctx and all the bindings
		delete(.whenArgsCtx, )
		for ,  := range  {
			.gcWhenArgsBinding(, false)
		}
	}

	return 
}

// Removes a WhenArgsBinding from the indexes: it drops the binding from the
// whenArgs list stored under the binding's handler (deleting the key when the
// list has exactly one element), and, when the bool parameter is true and the
// binding has a ctx, it also drops the binding from whenArgsCtx. Logs the
// handler name with SuffixState cut off, plus the binding's arg names, under
// "[whenArgs:match]". Does not close or return the binding's channel.
func ( *Subscriptions) (
	 *WhenArgsBinding,  bool,
) {
	// remove GC ctx
	if .ctx != nil &&  {
		.whenArgsCtx[.ctx] = slicesWithout(
			.whenArgsCtx[.ctx], )

		// drop the ctx key once it has no bindings left
		if len(.whenArgsCtx[.ctx]) == 0 {
			delete(.whenArgsCtx, .ctx)
		}
	}

	// GC
	// single-element list: remove the whole key (presumably the only element
	// is this binding - TODO confirm)
	if len(.whenArgs[.handler]) == 1 {
		delete(.whenArgs, .handler)
	} else {
		.whenArgs[.handler] = slicesWithout(
			.whenArgs[.handler], )
	}

	// log TODO sem logger
	// map iteration order is random, so the logged key order may vary
	 := jw(slices.Collect(maps.Keys(.args)), ",")
	// FooState -> Foo
	,  := strings.CutSuffix(.handler, SuffixState)
	.log(LogOps, "[whenArgs:match] %s (%s)", , )
}

// Processes the query subscriptions and returns the channels of the finished
// ones.
//
// Holds Mx for the whole call. It first runs the expired-context sweep
// (processWhenQueryCtx), then calls fn of every binding in whenQuery with a
// clock value. A binding is finished when fn returns true, or when the
// binding's ctx has expired; finished bindings are removed via
// gcWhenQueryBinding and their ch is added to the result. The channels are
// returned, not closed here.
func ( *Subscriptions) () []chan struct{} {
	// locks
	.Mx.Lock()
	defer .Mx.Unlock()

	// collect ctx expirations
	// TODO optimize by skipping
	 := .processWhenQueryCtx()

	// collect arg matches TODO optimize clone
	// iterate over a copy - gcWhenQueryBinding below modifies whenQuery
	for ,  := range slices.Clone(.whenQuery) {
		 := .ctx != nil && .ctx.Err() != nil
		// keep waiting only if neither the query passed nor the ctx expired
		if !.fn(.clock) && ! {
			continue
		}

		// completed - rm binding and collect ch
		.gcWhenQueryBinding(, true)
		 = append(, .ch)
	}

	return 
}

// Expired-context sweep for the query bindings: for every context in
// whenQueryCtx whose Err() is non-nil, each of its bindings is passed to
// gcWhenQueryBinding with the second argument set to false. No locking is done
// here; the visible caller holds Mx.
//
// NOTE(review): nothing visible in this function appends to the returned
// slice, so it appears to always be nil - confirm whether that is intended.
func ( *Subscriptions) () []chan struct{} {
	var  []chan struct{}

	// find expired ctxs
	for ,  := range .whenQueryCtx {
		if .Err() == nil {
			continue
		}

		// delete the ctx and all the bindings
		// NOTE(review): the loop ranges over whenQueryCtx, but this delete
		// targets whenArgsCtx - looks like a copy-paste slip that would leave the
		// expired key in whenQueryCtx; confirm.
		delete(.whenArgsCtx, )
		for ,  := range  {
			.gcWhenQueryBinding(, false)
		}
	}

	return 
}

// Removes a whenQueryBinding from the whenQuery list, and, when the bool
// parameter is true and the binding has a ctx, also from whenQueryCtx. Logs
// the list position of the removed binding under "[whenQuery:match]" (0 when
// the list had a single element). Does not close or return the binding's
// channel.
func ( *Subscriptions) (
	 *whenQueryBinding,  bool,
) {
	// remove GC ctx
	if .ctx != nil &&  {
		.whenQueryCtx[.ctx] = slicesWithout(
			.whenQueryCtx[.ctx], )

		// drop the ctx key once it has no bindings left
		if len(.whenQueryCtx[.ctx]) == 0 {
			delete(.whenQueryCtx, .ctx)
		}
	}

	// GC
	 := 0
	// single-element list: reset it (presumably the only element is this
	// binding - TODO confirm)
	if len(.whenQuery) == 1 {
		.whenQuery = nil
	} else {
		// NOTE(review): slices.Index returns -1 when the binding is not in the
		// list, and slices.Delete panics on an out-of-range index - confirm the
		// binding is always present here.
		 = slices.Index(.whenQuery, )
		.whenQuery = slices.Delete(.whenQuery, , +1)
	}

	// log TODO sem logger
	.log(LogOps, "[whenQuery:match] %d", )
}

// ProcessWhenQueueEnds collects all queue-end subscriptions, and
// returns their channels.
//
// Holds Mx for the whole call. Every binding in whenQueueEnds contributes its
// ch, and the list is reset afterwards, so each binding is returned once. The
// channels are returned, not closed here.
func ( *Subscriptions) () []chan struct{} {
	// locks
	.Mx.Lock()
	defer .Mx.Unlock()

	// collect chans
	// (pre-sized to the number of bindings and filled by position)
	 := make([]chan struct{}, len(.whenQueueEnds))
	for ,  := range .whenQueueEnds {
		[] = .ch
	}

	// clean up
	.whenQueueEnds = nil

	return 
}

// Processes the queue-tick subscriptions: every binding in whenQueue whose
// tick is not greater than the passed value is logged under
// "[whenQueue:match]", removed from the list, and has its ch added to the
// result. Holds Mx for the whole call. The channels are returned, not closed
// here.
func ( *Subscriptions) ( uint64) []chan struct{} {
	// locks
	.Mx.Lock()
	defer .Mx.Unlock()

	// collect
	var  []chan struct{}
	var  []int
	for ,  := range .whenQueue {
		// not reached yet - keep waiting
		if uint64(.tick) >  {
			continue
		}

		// TODO sem logger
		.log(LogOps, "[whenQueue:match] %d", .tick)
		 = append(, .ch)
		 = append(, )
	}

	// GC
	// delete from the highest index down, so earlier deletions don't shift the
	// positions still to be removed
	slices.Reverse()
	for ,  := range  {
		.whenQueue = slices.Delete(.whenQueue, , +1)
	}

	// close and execute waits
	return 
}

// ///// ///// /////

// ///// BINDING

// ///// ///// /////

// NewStateCtx returns a new sub-context, bound to the current clock's tick of
// the passed state.
//
// Context cancels when the state has been deactivated, or right away,
// if it isn't currently active.
//
// State contexts are used to check state expirations and should be checked
// often inside goroutines.
//
// Holds Mx for the whole call. An already registered context for the state is
// returned as-is. Otherwise a cancelable child of the machine's context is
// created, carrying a CtxValue (machine Id, State, Tick) under CtxKey; it is
// stored in stateCtx only when the is() check passes.
func ( *Subscriptions) ( string) context.Context {
	// locks
	.Mx.Lock()
	defer .Mx.Unlock()

	// reuse an existing context for this state
	if ,  := .stateCtx[];  {
		return .stateCtx[].Ctx
	}

	// store a fingerprint
	 := CtxValue{
		Id:    .mach.Id(),
		State: ,
		Tick:  .clock[],
	}
	,  := context.WithCancel(context.WithValue(.mach.Ctx(),
		CtxKey, ))

	// cancel early
	// (the canceled context is returned without being indexed)
	if !.is(S{}) {
		// TODO decision msg
		()
		return 
	}

	 := &CtxBinding{
		Ctx:    ,
		Cancel: ,
	}

	// add an index
	.stateCtx[] = 
	.log(LogOps, "[ctx:new] %s", )

	return 
}

// Registers a non-negated WhenBinding for the passed states and returns its
// channel. If the is() check on the passed states succeeds, an already closed
// channel is returned and nothing is registered. Otherwise the binding records
// the current activity of each state (States), the number of already active
// ones (Matched) and the number of states (Total), and is added to the when
// index under every state.
//
// ctx: optional context, stored on the binding and indexed in whenCtx.
//
// NOTE(review): the early is() check runs before Mx is taken - confirm that
// the check is safe to call without the lock.
func ( *Subscriptions) ( S,  context.Context) <-chan struct{} {
	// TODO re-use channels with the same state set and context

	// if all active, close early
	if .is() {
		return newClosedChan()
	}

	// locks
	.Mx.Lock()
	defer .Mx.Unlock()

	 := make(chan struct{})

	// snapshot the activity of each state and count the active ones
	 := StateIsActive{}
	 := 0
	for ,  := range  {
		[] = .is(S{})
		if [] {
			++
		}
	}

	// add the binding to an index of each state
	 := &WhenBinding{
		Ch:       ,
		Negation: false,
		States:   ,
		Total:    len(),
		Matched:  ,
		Ctx:      ,
	}
	.log(LogOps, "[when:new] %s", j())

	// insert the binding
	for ,  := range  {
		.when[] = append(.when[], )

		// NOTE(review): this append sits inside the per-state loop, so the
		// binding seems to be added to whenCtx once per state; the WhenNot
		// variant does it once after the loop - confirm.
		if  != nil {
			.whenCtx[] = append(.whenCtx[], )
		}
	}

	return 
}

// WhenNot returns a channel that will be closed when all the passed states
// become inactive or the machine gets disposed.
//
// If the not() check on the passed states succeeds, an already closed channel
// is returned and nothing is registered. Otherwise a negated WhenBinding is
// created, recording the current activity of each state (States), the number
// of already inactive ones (Matched) and the number of states (Total).
//
// ctx: optional context that will close the channel early.
//
// NOTE(review): the early not() check runs before Mx is taken - confirm that
// the check is safe to call without the lock.
func ( *Subscriptions) (
	 S,  context.Context,
) <-chan struct{} {
	// if all inactive, close early
	if .not() {
		return newClosedChan()
	}

	// locks
	.Mx.Lock()
	defer .Mx.Unlock()

	 := make(chan struct{})
	// snapshot the activity of each state and count the inactive ones
	 := StateIsActive{}
	 := 0
	for ,  := range  {
		[] = .is(S{})
		if ![] {
			++
		}
	}

	// add the binding to an index of each state
	 := &WhenBinding{
		Ch:       ,
		Negation: true,
		States:   ,
		Total:    len(),
		Matched:  ,
		Ctx:      ,
	}
	.log(LogOps, "[whenNot:new] %s", j())

	// insert the binding
	// NOTE(review): when a state has no index entry yet, an empty list is
	// stored and the binding does not seem to be added to it, so it would never
	// be found under that state; the When variant appends unconditionally -
	// confirm.
	for ,  := range  {
		if ,  := .when[]; ! {
			.when[] = []*WhenBinding{}
		} else {
			.when[] = append(.when[], )
		}
	}
	if  != nil {
		.whenCtx[] = append(.whenCtx[], )
	}

	return 
}

// WhenArgs returns a channel that will be closed when the passed state
// becomes active with all the passed args. Args are compared using the native
// '=='. It's meant to be used with async Multi states, to filter out
// a specific call.
//
// Holds Mx for the whole call. Bindings are indexed in whenArgs under the
// state name with SuffixState appended. If a binding with args accepted by
// compareArgs already exists under that key, its channel is returned instead
// of registering a new one.
//
// ctx: optional context that will close the channel when handler loop ends.
func ( *Subscriptions) (
	 string,  A,  context.Context,
) <-chan struct{} {
	// TODO better val comparisons
	//  support regexp for strings
	// TODO support typed args
	// locks
	.Mx.Lock()
	defer .Mx.Unlock()

	 := make(chan struct{})
	// index key: the state's handler name
	 :=  + SuffixState

	// log TODO pass through logArgs?
	 := jw(slices.Collect(maps.Keys()), ",")
	.log(LogOps, "[whenArgs:new] %s (%s)", , )

	// try to reuse an existing channel
	// NOTE(review): on reuse the passed ctx is not recorded anywhere, so the
	// returned channel follows the ctx of the earlier binding - confirm.
	for ,  := range .whenArgs[] {
		if compareArgs(.args, ) {
			return .ch
		}
	}

	 := &WhenArgsBinding{
		ch:      ,
		handler: ,
		args:    ,
		ctx:     ,
	}

	// insert the binding
	.whenArgs[] = append(.whenArgs[], )

	// index by ctx, so an expired ctx can GC the binding
	if  != nil {
		.whenArgsCtx[] = append(.whenArgsCtx[], )
	}

	return 
}

// WhenTime returns a channel that will be closed when all the passed states
// have passed the specified time. The time is a logical clock of the state.
// Machine time can be sourced from [Machine.Time](), or [Machine.Clock]().
//
// Holds Mx for the whole call. If no clock value is below its requested time,
// the channel is closed and returned right away without registering anything.
// Otherwise a WhenTimeBinding is created, recording which states have already
// reached their time (Completed, Matched), the position of each state in the
// times list (Index), and the requested Times.
//
// ctx: optional context that will close the channel early.
func ( *Subscriptions) (
	 S,  Time,  context.Context,
) <-chan struct{} {
	// locks
	.Mx.Lock()
	defer .Mx.Unlock()

	 := make(chan struct{})
	// local alias of the whenTime index (bindings are inserted through it)
	 := .whenTime

	// if all times passed, close early
	 := true
	for ,  := range  {
		if .clock[] < [] {
			 = false
			break
		}
	}
	if  {
		// TODO decision msg
		close()
		return 
	}

	// snapshot which states already reached their time
	 := StateIsActive{}
	 := 0
	 := map[string]int{}
	for ,  := range  {
		[] = .clock[] >= []
		if [] {
			++
		}
		[] = 
	}

	// add the binding to an index of each state
	 := &WhenTimeBinding{
		Ch:        ,
		Index:     ,
		Completed: ,
		Total:     len(),
		Matched:   ,
		Times:     ,
		Ctx:       ,
	}
	.log(LogOps, "[whenTime:new] %s %s", jw(, ","), )

	// insert the binding
	for ,  := range  {
		[] = append([], )
	}
	// index by ctx, so an expired ctx can GC the binding
	if  != nil {
		.whenTimeCtx[] = append(.whenTimeCtx[], )
	}

	return 
}

// Registers a query subscription and returns its channel: the passed func
// receives a Clock and reports whether the subscription is fulfilled. The
// binding is appended to whenQuery; the func is not evaluated here. Holds Mx
// for the whole call. Logs the list length before the insert under
// "[whenQuery:new]".
//
// ctx: optional context, stored on the binding and indexed in whenQueryCtx.
func ( *Subscriptions) (
	 func( Clock) bool,  context.Context,
) <-chan struct{} {
	// locks
	.Mx.Lock()
	defer .Mx.Unlock()

	// add the binding to an index of each state
	 := make(chan struct{})
	 := &whenQueryBinding{
		ch:  ,
		ctx: ,
		fn:  ,
	}

	// insert the binding
	.log(LogOps, "[whenQuery:new] %d", len(.whenQuery))
	.whenQuery = append(.whenQuery, )
	// NOTE(review): the visible constructor does not allocate whenQueryCtx;
	// assigning into a nil map panics - confirm it is initialized elsewhere.
	if  != nil {
		.whenQueryCtx[] = append(.whenQueryCtx[], )
	}

	return 
}

// WhenQueueEnds closes every time the queue ends, or the optional ctx expires.
// This function assumes the queue is running, and wont close early.
//
// Holds Mx for the whole call. Creates a whenQueueEndsBinding with a fresh
// channel and appends it to whenQueueEnds.
//
// ctx: optional context that will close the channel early.
//
// NOTE(review): the ctx is stored on the binding, but the code reacting to its
// expiration is commented out below, and the *sync.RWMutex parameter is
// referenced only in that commented-out code - confirm whether ctx expiration
// is handled elsewhere.
func ( *Subscriptions) (
	 context.Context,  *sync.RWMutex,
) <-chan struct{} {
	// locks
	.Mx.Lock()
	defer .Mx.Unlock()

	// add the binding to an index of each state
	 := make(chan struct{})
	 := &whenQueueEndsBinding{
		ch:  ,
		ctx: ,
	}

	// insert the binding
	.whenQueueEnds = append(.whenQueueEnds, )

	// TODO remove ctx
	// if ctx != nil {
	// 	// fork in this special case
	// 	go func() {
	// 		<-ctx.Done()
	// 		mx.Lock()
	// 		defer mx.Unlock()
	// 		sm.whenQueueEnds = slicesWithout(sm.whenQueueEnds, binding)
	// 		close(ch)
	// 	}()
	// }

	return 
}

// WhenQueue waits until the passed queueTick gets processed.
//
// Holds Mx for the whole call. Creates a whenQueueBinding with a fresh channel
// and the passed tick, logs it under "[whenQueue:new]", and appends it to
// whenQueue. The channel is returned open; nothing here checks whether the
// tick has already been processed.
func ( *Subscriptions) ( Result) <-chan struct{} {
	// TODO add gc ctx (just in case)
	// locks
	.Mx.Lock()
	defer .Mx.Unlock()

	// add the binding to an index of each state
	 := make(chan struct{})
	 := &whenQueueBinding{
		ch:   ,
		tick: ,
	}
	.log(LogOps, "[whenQueue:new] %d", )

	// insert the binding
	.whenQueue = append(.whenQueue, )

	return 
}

// ///// ///// /////

// ///// MISC

// ///// ///// /////

// Reports whether the whenArgs index has at least one key, i.e. whether any
// args-matching binding is currently registered. Takes Mx for the check.
func ( *Subscriptions) () bool {
	.Mx.Lock()
	defer .Mx.Unlock()

	return len(.whenArgs) > 0
}

// Tears down the subscriptions: calls Cancel on every entry of stateCtx and
// passes the channels of all bindings in when, whenTime, whenArgs,
// whenQueueEnds and whenQueue to closeSafe. The indexes themselves are not
// cleared.
//
// NOTE(review): Mx is not taken here, and the whenQuery bindings' channels are
// not closed - confirm both are intended.
func ( *Subscriptions) () {
	// cancel ctx
	for ,  := range .stateCtx {
		.Cancel()
	}

	// close channels
	// (a binding can be indexed under several states, so the same channel may
	// be visited more than once - presumably why closeSafe is used)
	for  := range .when {
		for ,  := range .when[] {
			closeSafe(.Ch)
		}
	}
	for  := range .whenTime {
		for ,  := range .whenTime[] {
			closeSafe(.Ch)
		}
	}
	for  := range .whenArgs {
		for ,  := range .whenArgs[] {
			closeSafe(.ch)
		}
	}
	for ,  := range .whenQueueEnds {
		closeSafe(.ch)
	}
	for ,  := range .whenQueue {
		closeSafe(.ch)
	}
}