package lifecycle
import (
"context"
"errors"
"fmt"
"io"
"reflect"
"strings"
"sync"
"time"
"go.uber.org/fx/fxevent"
"go.uber.org/fx/internal/fxclock"
"go.uber.org/fx/internal/fxreflect"
"go.uber.org/multierr"
)
// Function shapes accepted as lifecycle hooks. They are aliases, not defined
// types, so a plain func literal of the right signature is exactly one of them.
type (
	// Func is a hook that takes no context and cannot fail.
	Func = func()
	// ErrorFunc is a hook that takes no context and may fail.
	ErrorFunc = func() error
	// ContextFunc is a hook that receives a context and cannot fail.
	ContextFunc = func(context.Context)
	// ContextErrorFunc is the canonical hook shape that every other shape is
	// adapted to.
	ContextErrorFunc = func(context.Context) error
)

// Callable is satisfied by any of the four hook shapes above, including
// user-defined named types whose underlying type is one of them.
type Callable interface {
	~Func | ~ErrorFunc | ~ContextFunc | ~ContextErrorFunc
}

// reflect.Type values for the hook shapes, used by Wrap to convert named
// function types back to their alias form.
var (
	_reflFunc             = reflect.TypeOf((*Func)(nil)).Elem()
	_reflErrorFunc        = reflect.TypeOf((*ErrorFunc)(nil)).Elem()
	_reflContextFunc      = reflect.TypeOf((*ContextFunc)(nil)).Elem()
	_reflContextErrorFunc = reflect.TypeOf((*ContextErrorFunc)(nil)).Elem()
)
// Wrap adapts any Callable hook to the canonical ContextErrorFunc shape. It
// also returns the function's name as reported by fxreflect.FuncName. A nil x
// yields (nil, "").
//
// Values whose dynamic type is exactly one of the alias shapes are adapted
// directly. Values of user-defined named function types are first converted
// with reflect to the matching alias shape, then wrapped by a recursive call.
func Wrap[T Callable](x T) (ContextErrorFunc, string) {
	if x == nil {
		return nil, ""
	}

	var adapted ContextErrorFunc
	switch fn := any(x).(type) {
	case ContextErrorFunc:
		adapted = fn
	case ContextFunc:
		adapted = func(ctx context.Context) error {
			fn(ctx)
			return nil
		}
	case ErrorFunc:
		adapted = func(context.Context) error {
			return fn()
		}
	case Func:
		adapted = func(context.Context) error {
			fn()
			return nil
		}
	default:
		// A named type (e.g. `type myHook func()`) does not match the alias
		// cases above; convert it to its alias form and recurse.
		v := reflect.ValueOf(x)
		if v.CanConvert(_reflFunc) {
			return Wrap(v.Convert(_reflFunc).Interface().(Func))
		}
		if v.CanConvert(_reflErrorFunc) {
			return Wrap(v.Convert(_reflErrorFunc).Interface().(ErrorFunc))
		}
		if v.CanConvert(_reflContextFunc) {
			return Wrap(v.Convert(_reflContextFunc).Interface().(ContextFunc))
		}
		return Wrap(v.Convert(_reflContextErrorFunc).Interface().(ContextErrorFunc))
	}
	return adapted, fxreflect.FuncName(x)
}
// Hook is a pair of start/stop callbacks registered with a Lifecycle through
// Append.
type Hook struct {
	// OnStart is run by Lifecycle.Start in append order. A nil OnStart is
	// skipped, but the hook is still counted as started, so its OnStop will
	// run on Stop.
	OnStart func(context.Context) error
	// OnStop is run by Lifecycle.Stop in reverse order, only for hooks that
	// were counted as started. A nil OnStop is skipped.
	OnStop func(context.Context) error
	// OnStartName, when non-empty, is the function name logged for OnStart
	// instead of the one derived via fxreflect.FuncName.
	OnStartName string
	// OnStopName is the OnStop counterpart of OnStartName.
	OnStopName string
	// callerFrame is filled in by Append from fxreflect.CallerStack(2, 0);
	// presumably the frame of the code that registered the hook (confirm the
	// skip semantics in fxreflect). Used in log events and hook records.
	callerFrame fxreflect.Frame
}
// appState tracks where a Lifecycle is in its start/stop sequence.
type appState int

const (
	stopped appState = iota
	starting
	incompleteStart
	started
	stopping
)

// String returns the identifier-style name of the state, or "invalidState"
// for any value outside the declared constants.
func (as appState) String() string {
	names := [...]string{
		stopped:         "stopped",
		starting:        "starting",
		incompleteStart: "incompleteStart",
		started:         "started",
		stopping:        "stopping",
	}
	if as < 0 || int(as) >= len(names) {
		return "invalidState"
	}
	return names[as]
}
// Lifecycle holds an ordered list of Hooks and runs their OnStart callbacks
// (Start) and OnStop callbacks (Stop), logging each execution and recording
// how long it took.
type Lifecycle struct {
	// clock measures hook runtimes.
	clock fxclock.Clock
	// logger receives OnStart/OnStop Executing and Executed events.
	logger fxevent.Logger
	// state is read and written under mu by Start and Stop.
	state appState
	// hooks are appended by Append and run in order by Start and in reverse
	// by Stop.
	hooks []Hook
	// numStarted counts hooks that Start got past (successful OnStart or nil
	// OnStart); Stop unwinds exactly this many hooks.
	numStarted int
	// startRecords and stopRecords hold one HookRecord per hook callback
	// that ran, in execution order.
	startRecords HookRecords
	stopRecords  HookRecords
	// runningHook is the hook whose callback is currently (or was most
	// recently) executing; reported by RunningHookCaller.
	runningHook Hook
	// mu guards state, runningHook and the hook records; Stop also reads
	// numStarted and hooks under it.
	mu sync.Mutex
}
// New returns an empty, stopped Lifecycle that reports events to logger and
// times hooks with clock.
func New(logger fxevent.Logger, clock fxclock.Clock) *Lifecycle {
	lc := new(Lifecycle)
	lc.logger = logger
	lc.clock = clock
	return lc
}
// Append registers hook with the lifecycle. Start runs hooks in the order
// they were appended, and Stop runs them in reverse.
//
// The caller's stack frame is attached to the hook for logging. The skip
// count given to fxreflect.CallerStack is relative to this function, so the
// call must stay directly in Append rather than move into a helper.
//
// NOTE(review): Append does not take l.mu; it presumably must not run
// concurrently with Start or Stop — confirm with callers.
func (l *Lifecycle) Append(hook Hook) {
	frames := fxreflect.CallerStack(2, 0)
	if len(frames) != 0 {
		hook.callerFrame = frames[0]
	}
	l.hooks = append(l.hooks, hook)
}
// Start runs the OnStart callback of every registered hook, in append order.
//
// It returns an error without doing anything if ctx is nil or if the
// lifecycle is not in the stopped state. Start stops at the first OnStart
// error, or as soon as ctx reports an error between hooks, and returns that
// error. On success the lifecycle ends in the started state; otherwise it
// ends in incompleteStart.
//
// Only hooks that Start got past (successful or nil OnStart) are counted in
// numStarted, so Stop unwinds exactly those hooks.
func (l *Lifecycle) Start(ctx context.Context) error {
	if ctx == nil {
		return errors.New("called OnStart with nil context")
	}

	l.mu.Lock()
	if l.state != stopped {
		defer l.mu.Unlock()
		return fmt.Errorf("attempted to start lifecycle when in state: %v", l.state)
	}
	l.numStarted = 0
	l.state = starting
	l.startRecords = make(HookRecords, 0, len(l.hooks))
	l.mu.Unlock()

	returnState := incompleteStart
	defer func() {
		l.mu.Lock()
		l.state = returnState
		l.mu.Unlock()
	}()

	for _, hook := range l.hooks {
		// Bail out between hooks once the context is done.
		if err := ctx.Err(); err != nil {
			return err
		}

		if hook.OnStart != nil {
			l.mu.Lock()
			l.runningHook = hook
			l.mu.Unlock()

			runtime, err := l.runStartHook(ctx, hook)
			if err != nil {
				return err
			}

			l.mu.Lock()
			l.startRecords = append(l.startRecords, HookRecord{
				CallerFrame: hook.callerFrame,
				Func:        hook.OnStart,
				Runtime:     runtime,
			})
			l.mu.Unlock()
		}

		// Stop accepts the starting state and reads numStarted under l.mu, so
		// it can run concurrently with this loop; the increment must be
		// locked as well to avoid a data race.
		l.mu.Lock()
		l.numStarted++
		l.mu.Unlock()
	}

	returnState = started
	return nil
}
// runStartHook runs hook.OnStart with ctx and returns how long it took (per
// l.clock) along with its error. It logs an OnStartExecuting event before
// the call and an OnStartExecuted event (with runtime and error) after it.
// The function name logged is hook.OnStartName if set, else the name derived
// from the function value.
func (l *Lifecycle) runStartHook(ctx context.Context, hook Hook) (runtime time.Duration, err error) {
	name := hook.OnStartName
	if name == "" {
		name = fxreflect.FuncName(hook.OnStart)
	}
	caller := hook.callerFrame.Function

	l.logger.LogEvent(&fxevent.OnStartExecuting{
		CallerName:   caller,
		FunctionName: name,
	})
	// Deferred so the named results are observed after they are assigned.
	defer func() {
		l.logger.LogEvent(&fxevent.OnStartExecuted{
			CallerName:   caller,
			FunctionName: name,
			Runtime:      runtime,
			Err:          err,
		})
	}()

	start := l.clock.Now()
	err = hook.OnStart(ctx)
	runtime = l.clock.Since(start)
	return runtime, err
}
// Stop runs the OnStop callbacks of the hooks counted as started by Start,
// in reverse order.
//
// It returns an error if ctx is nil. If the lifecycle is not starting,
// incompleteStart or started, Stop is a no-op that returns nil. Errors from
// individual OnStop callbacks do not halt the unwinding; they are combined
// with multierr and returned. If ctx reports an error between hooks, Stop
// returns early with ctx's error combined with any OnStop errors collected
// so far. In every case the lifecycle ends in the stopped state.
func (l *Lifecycle) Stop(ctx context.Context) error {
	if ctx == nil {
		return errors.New("called OnStop with nil context")
	}

	l.mu.Lock()
	if l.state != started && l.state != incompleteStart && l.state != starting {
		defer l.mu.Unlock()
		return nil
	}
	l.state = stopping
	l.stopRecords = make(HookRecords, 0, l.numStarted)
	allHooks := l.hooks[:]
	numStarted := l.numStarted
	l.mu.Unlock()

	defer func() {
		l.mu.Lock()
		l.state = stopped
		l.mu.Unlock()
	}()

	var errs []error
	for ; numStarted > 0; numStarted-- {
		if err := ctx.Err(); err != nil {
			// Keep the failures of hooks that already ran instead of dropping
			// them. With no prior failures, Combine returns err unchanged.
			return multierr.Combine(append(errs, err)...)
		}

		hook := allHooks[numStarted-1]
		if hook.OnStop == nil {
			continue
		}

		l.mu.Lock()
		l.runningHook = hook
		l.mu.Unlock()

		runtime, err := l.runStopHook(ctx, hook)
		if err != nil {
			errs = append(errs, err)
		}

		l.mu.Lock()
		l.stopRecords = append(l.stopRecords, HookRecord{
			CallerFrame: hook.callerFrame,
			Func:        hook.OnStop,
			Runtime:     runtime,
		})
		l.mu.Unlock()
	}

	return multierr.Combine(errs...)
}
// runStopHook runs hook.OnStop with ctx and returns how long it took (per
// l.clock) along with its error. It logs an OnStopExecuting event before the
// call and an OnStopExecuted event (with runtime and error) after it. The
// function name logged is hook.OnStopName if set, else the name derived from
// the function value.
func (l *Lifecycle) runStopHook(ctx context.Context, hook Hook) (runtime time.Duration, err error) {
	name := hook.OnStopName
	if name == "" {
		name = fxreflect.FuncName(hook.OnStop)
	}
	caller := hook.callerFrame.Function

	l.logger.LogEvent(&fxevent.OnStopExecuting{
		CallerName:   caller,
		FunctionName: name,
	})
	// Deferred so the named results are observed after they are assigned.
	defer func() {
		l.logger.LogEvent(&fxevent.OnStopExecuted{
			CallerName:   caller,
			FunctionName: name,
			Runtime:      runtime,
			Err:          err,
		})
	}()

	start := l.clock.Now()
	err = hook.OnStop(ctx)
	runtime = l.clock.Since(start)
	return runtime, err
}
// RunningHookCaller returns the Function of the caller frame recorded for the
// hook whose OnStart/OnStop callback is running or ran most recently. If no
// hook has run yet, it returns the zero-value frame's Function. Safe to call
// while Start or Stop is in progress because it reads under l.mu.
func (l *Lifecycle) RunningHookCaller() string {
	l.mu.Lock()
	caller := l.runningHook.callerFrame.Function
	l.mu.Unlock()
	return caller
}
// HookRecord describes one executed hook callback.
type HookRecord struct {
	// CallerFrame is the frame recorded on the hook by Append.
	CallerFrame fxreflect.Frame
	// Func is the OnStart or OnStop callback that ran.
	Func func(context.Context) error
	// Runtime is how long Func took, as measured by the Lifecycle's clock.
	Runtime time.Duration
}
// HookRecords is a list of HookRecord values. It implements sort.Interface,
// ordering records from the longest Runtime to the shortest.
type HookRecords []HookRecord

// Len returns the number of records.
func (rs HookRecords) Len() int { return len(rs) }

// Less reports whether record i ran strictly longer than record j, so that
// sorting puts the slowest hooks first.
func (rs HookRecords) Less(i, j int) bool {
	return rs[j].Runtime < rs[i].Runtime
}

// Swap exchanges records i and j in place.
func (rs HookRecords) Swap(i, j int) {
	tmp := rs[i]
	rs[i] = rs[j]
	rs[j] = tmp
}
// String renders the records on one line as
// "<func> took <runtime> from <caller frame>". Consecutive records are
// separated by "; ". Without a separator, records used to run together
// (e.g. "...from Xb took..."), which made startup error messages unreadable.
func (rs HookRecords) String() string {
	var b strings.Builder
	for i, r := range rs {
		if i > 0 {
			b.WriteString("; ")
		}
		fmt.Fprintf(&b, "%s took %v from %s",
			fxreflect.FuncName(r.Func), r.Runtime, r.CallerFrame)
	}
	return b.String()
}
// Format implements fmt.Formatter. With the '+' flag (e.g. "%+v") each record
// is written on its own lines, with the caller frame formatted via %+v and
// indented. A trailing newline ends the output. Without '+', the output is
// the same as rs.String(). The verb itself is ignored.
func (rs HookRecords) Format(w fmt.State, c rune) {
	if w.Flag('+') {
		for _, r := range rs {
			fmt.Fprintf(w, "\n%s took %v from:\n\t%+v",
				fxreflect.FuncName(r.Func),
				r.Runtime,
				r.CallerFrame)
		}
		io.WriteString(w, "\n")
		return
	}
	io.WriteString(w, rs.String())
}
These pages were generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
Pull requests and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the QR code on the left) to get the latest news about Golds.