package rpc
import (
"context"
"fmt"
"net"
"os"
"reflect"
"sync"
"sync/atomic"
"time"
"github.com/cenkalti/rpc2"
"github.com/pancsta/asyncmachine-go/internal/utils"
amhelp "github.com/pancsta/asyncmachine-go/pkg/helpers"
am "github.com/pancsta/asyncmachine-go/pkg/machine"
"github.com/pancsta/asyncmachine-go/pkg/rpc/states"
ampipe "github.com/pancsta/asyncmachine-go/pkg/states/pipes"
)
// Short aliases for the state-name definitions used throughout this file:
// ssS names the states of the RPC server's own machine (see NewServer),
// ssW names the states an RPC worker (the exposed source machine) implements.
var (
	ssS = states.ServerStates
	ssW = states.WorkerStates
)
// Server exposes a source state machine (Source) to a single RPC client.
//
// It runs its own state machine (Mach), serves RPC over Conn, Listener or a
// newly created TCP listener on Addr, and pushes clock updates of Source to
// the connected client.
type Server struct {
	*ExceptionHandler

	// Mach is the server's own state machine, created in NewServer.
	Mach *am.Machine
	// Source is the machine exposed over RPC.
	Source am.Api

	// Addr is the address to listen on. Once serving, it's overwritten with
	// the effective local address.
	Addr string
	// DeliveryTimeout limits a single payload delivery. Default 5s.
	DeliveryTimeout time.Duration
	// PushInterval is the period of clock pushes to the client, and the
	// minimum spacing between throttled clock snapshots. 0 disables periodic
	// pushes. Default 250ms.
	PushInterval time.Duration
	// PushAllTicks is set by the client via RemoteSetPushAllTicks.
	// NOTE(review): not read anywhere in this file.
	PushAllTicks bool
	// Listener is an optional, pre-made listener to serve on. It's closed when
	// serving ends, or when Mach is disposed.
	Listener atomic.Pointer[net.Listener]
	// Conn is an optional, pre-made connection, served instead of a listener.
	Conn net.Conn
	// NoNewListener forbids creating a listener when Listener is empty - RPC
	// won't start in that case.
	NoNewListener bool
	// LogEnabled enables the server's own log messages. Defaults to whether
	// the EnvAmRpcLogServer env var is set.
	LogEnabled bool
	// CallCount is incremented on every clock push notification sent to the
	// client (non-atomically).
	CallCount uint64
	// AllowId, when non-empty, is the only client ID accepted in the
	// handshake.
	AllowId string

	// rpcServer is the rpc2 server, recreated on every RPC start.
	rpcServer *rpc2.Server
	// rpcClient is the handshaken client, nil when there's none.
	rpcClient atomic.Pointer[rpc2.Client]
	// clockMx is held while generating clock updates and during RemoteHello.
	clockMx sync.Mutex
	// ticker drives periodic clock pushes while RpcReady is active.
	ticker *time.Ticker
	// mutMx serializes mutations requested by the client.
	mutMx sync.Mutex
	// skipClockPush suppresses non-forced clock pushes while a remote mutation
	// is in progress (its response carries the clock itself).
	skipClockPush atomic.Bool
	// tracer is bound to Source in NewServer and detached on dispose.
	tracer *WorkerTracer
	// clientId is the ID sent by the client during the handshake.
	clientId atomic.Pointer[string]
	// deliveryHandlers is detached from Source when Mach is disposed.
	deliveryHandlers any
	// lastClockHTime is the wall-clock time of the last clock snapshot.
	lastClockHTime time.Time
	// lastClock is Source's clock at the last snapshot.
	lastClock am.Time
	// lastClockSum is the sum of lastClock.
	lastClockSum atomic.Uint64
	// lastClockMsg is the last generated clock update message.
	lastClockMsg *ClockMsg
	// lastQueueTick is Source's queue tick at the last snapshot.
	lastQueueTick uint64
}
// Compile-time guarantees that *Server provides both RPC method sets; a nil
// typed pointer is enough for the check and allocates nothing.
var (
	_ serverRpcMethods    = (*Server)(nil)
	_ clientServerMethods = (*Server)(nil)
)
// NewServer creates an RPC server exposing sourceMach, to be served on addr.
//
// name is appended to "rs-" to form the server machine's ID and defaults to
// "rpc". opts may be nil. sourceMach must have verified states. When it has
// handlers bound, it also has to implement pkg/rpc/states/WorkerStatesDef,
// because a payload-delivery handler for opts.PayloadState (default
// WorkerStates.SendPayload) gets bound to it.
//
// The returned server isn't started, call Start. On error, nothing stays
// attached to sourceMach.
func NewServer(
	ctx context.Context, addr string, name string, sourceMach am.Api,
	opts *ServerOpts,
) (*Server, error) {
	if name == "" {
		name = "rpc"
	}
	if opts == nil {
		opts = &ServerOpts{}
	}

	// validate the source machine
	if !sourceMach.StatesVerified() {
		return nil, fmt.Errorf("worker states not verified, call VerifyStates()")
	}
	hasHandlers := sourceMach.HasHandlers()
	if hasHandlers && !sourceMach.Has(ssW.Names()) {
		err := fmt.Errorf(
			"%w: RPC worker with handlers has to implement "+
				"pkg/rpc/states/WorkerStatesDef",
			am.ErrSchema)
		return nil, err
	}

	// lastClockSum starts at its zero value (0)
	s := &Server{
		ExceptionHandler: &ExceptionHandler{},
		Addr:             addr,
		PushInterval:     250 * time.Millisecond,
		DeliveryTimeout:  5 * time.Second,
		LogEnabled:       os.Getenv(EnvAmRpcLogServer) != "",
		Source:           sourceMach,
		lastQueueTick:    1,
	}

	mach, err := am.NewCommon(ctx, "rs-"+name, states.ServerSchema, ssS.Names(),
		s, opts.Parent, &am.Opts{Tags: []string{"rpc-server"}})
	if err != nil {
		return nil, err
	}
	mach.SemLogger().SetArgsMapper(LogArgs)
	mach.OnDispose(func(id string, ctx context.Context) {
		if l := s.Listener.Load(); l != nil {
			_ = (*l).Close()
			s.Listener.Store(nil)
		}
		s.rpcServer = nil
		// best-effort cleanup, errors during disposal aren't actionable
		_ = s.Source.DetachTracer(s.tracer)
		if s.deliveryHandlers != nil {
			_ = s.Source.DetachHandlers(s.deliveryHandlers)
		}
	})
	s.Mach = mach
	if os.Getenv(EnvAmRpcDbg) != "" {
		amhelp.MachDebugEnv(mach)
	}

	s.tracer = &WorkerTracer{s: s}
	_ = sourceMach.BindTracer(s.tracer)

	// payload delivery
	if hasHandlers {
		payloadState := ssW.SendPayload
		if opts.PayloadState != "" {
			payloadState = opts.PayloadState
		}

		var h any
		if payloadState == ssW.SendPayload {
			h = &SendPayloadHandlers{
				SendPayloadState: getSendPayloadState(s, ssW.SendPayload),
			}
		} else {
			h = createSendPayloadHandlers(s, payloadState)
		}

		if err := sourceMach.BindHandlers(h); err != nil {
			// undo the partial setup, so sourceMach doesn't keep a tracer
			// pointing to a server nobody holds
			_ = sourceMach.DetachTracer(s.tracer)
			mach.Dispose()
			return nil, err
		}
		// detached by the OnDispose hook above
		s.deliveryHandlers = h
	}

	return s, nil
}
// StartEnd is the handler for Start's deactivation. It disposes the server
// machine when the deactivating mutation carried Dispose=true (see Stop).
func (s *Server) StartEnd(e *am.Event) {
	args := ParseArgs(e.Args)
	if !args.Dispose {
		return
	}
	s.Mach.Dispose()
}
// RpcStartingEnter rejects RpcStarting when there's no listener and creating
// a new one is forbidden (NoNewListener), or when no address is configured.
func (s *Server) RpcStartingEnter(e *am.Event) bool {
	noListener := s.Listener.Load() == nil
	if noListener && s.NoNewListener {
		return false
	}

	return s.Addr != ""
}
// RpcStartingState starts serving RPC in the background.
//
// It serves s.Conn when set, otherwise s.Listener when set, otherwise it
// creates a new TCP4 listener on s.Addr; s.Addr is then updated to the
// effective local address. Once serving, RpcReady is added. When serving
// stops while Start is still active, RpcReady is removed and RpcStarting
// re-added. A listener failure is reported as a network error and
// RpcStarting is removed.
func (s *Server) RpcStartingState(e *am.Event) {
	ctxRpcStarting := s.Mach.NewStateCtx(ssS.RpcStarting)
	ctxStart := s.Mach.NewStateCtx(ssS.Start)
	s.log("Starting RPC on %s", s.Addr)
	s.bindRpcHandlers()
	srv := s.rpcServer

	// Register connection callbacks before anything is served. Otherwise an
	// early client could miss ClientConnected, and the callbacks would be set
	// while the server may already be reading them.
	srv.OnDisconnect(func(client *rpc2.Client) {
		s.Mach.EvRemove1(e, ssS.ClientConnected, Pass(&A{Client: client}))
	})
	srv.OnConnect(func(client *rpc2.Client) {
		s.Mach.EvAdd1(e, ssS.ClientConnected, Pass(&A{Client: client}))
	})

	go func() {
		if ctxStart.Err() != nil {
			return // expired
		}

		// resolve the transport
		if s.Conn != nil {
			s.Addr = s.Conn.LocalAddr().String()
		} else if l := s.Listener.Load(); l != nil {
			s.Addr = (*l).Addr().String()
		} else {
			cfg := net.ListenConfig{}
			lis, err := cfg.Listen(ctxStart, "tcp4", s.Addr)
			if err != nil {
				AddErrNetwork(e, s.Mach, err)
				s.Mach.Remove1(ssS.RpcStarting, nil)
				return
			}
			s.Listener.Store(&lis)
			s.Addr = lis.Addr().String()
		}
		s.log("RPC started on %s", s.Addr)

		go func() {
			if ctxRpcStarting.Err() != nil {
				return // expired
			}
			s.Mach.EvAdd1(e, ssS.RpcReady, Pass(&A{Addr: s.Addr}))

			// serve until the connection / listener closes; the listener may
			// have been released concurrently (eg by disposal), so don't
			// dereference a nil pointer
			lisP := s.Listener.Load()
			switch {
			case s.Conn != nil:
				srv.ServeConn(s.Conn)
			case lisP != nil:
				srv.Accept(*lisP)
			}

			if ctxStart.Err() != nil {
				return // expired
			}
			if lisP != nil {
				_ = (*lisP).Close()
				s.Listener.Store(nil)
			}

			// serving ended while still started - restart
			if s.Mach.Is1(ssS.Start) {
				s.Mach.EvRemove1(e, ssS.RpcReady, nil)
				s.Mach.EvAdd1(e, ssS.RpcStarting, nil)
			}
		}()
	}()
}
// RpcReadyEnter accepts RpcReady only while RpcStarting is active.
func (s *Server) RpcReadyEnter(e *am.Event) bool {
	starting := s.Mach.Is1(ssS.RpcStarting)
	return starting
}
// RpcReadyState starts the periodic clock push loop, which calls
// pushClockUpdate every PushInterval for as long as RpcReady stays active.
// A PushInterval of 0 disables periodic pushes.
func (s *Server) RpcReadyState(e *am.Event) {
	if s.PushInterval == 0 {
		return
	}
	ctx := s.Mach.NewStateCtx(ssS.RpcReady)

	// A fresh ticker per activation: the loop stops its ticker on exit, and a
	// stopped ticker never fires again, so reusing it after an RPC restart
	// would silently disable clock pushes.
	t := time.NewTicker(s.PushInterval)
	s.ticker = t

	go func() {
		// stop this activation's ticker, not s.ticker, which may already
		// belong to a newer activation
		defer t.Stop()
		if ctx.Err() != nil {
			return // expired
		}

		for {
			select {
			case <-ctx.Done():
				return
			case <-t.C:
				s.pushClockUpdate(false)
			}
		}
	}()
}
// HandshakeDoneEnd closes the current client connection (if any) once
// HandshakeDone gets deactivated. Close errors are ignored.
func (s *Server) HandshakeDoneEnd(e *am.Event) {
	client := s.rpcClient.Load()
	if client == nil {
		return
	}
	_ = client.Close()
}
// Start activates the server's Start state and returns the mutation result.
func (s *Server) Start() am.Result {
	res := s.Mach.Add1(ssS.Start, nil)
	return res
}
// Stop deactivates Start. With dispose=true, the server machine also gets
// disposed (by StartEnd). Returns Canceled when the server has no machine.
func (s *Server) Stop(dispose bool) am.Result {
	if s.Mach == nil {
		return am.Canceled
	}

	if dispose {
		s.log("disposing")
	}
	args := &A{Dispose: dispose}

	return s.Mach.Remove1(ssS.Start, Pass(args))
}
// SendPayload delivers payload to the connected RPC client.
//
// It requires a connected and handshaken client. When payload.Destination is
// set, it must match the client's ID (ErrDestination otherwise). A random
// token is assigned to the payload and, when event is non-nil, its machine
// and transition IDs are recorded as the payload's source. Returns ErrNoConn
// when there's no client to deliver to.
func (s *Server) SendPayload(
	ctx context.Context, event *am.Event, payload *ArgsPayload,
) error {
	if s.Mach.Not1(ssS.ClientConnected) || s.Mach.Not1(ssS.HandshakeDone) {
		return ErrNoConn
	}
	id := s.ClientId()
	if payload.Destination != "" && id != payload.Destination {
		return fmt.Errorf("%w: %s != %s", ErrDestination, payload.Destination, id)
	}

	// Load the client once and check it: RemoteBye can clear it right after
	// the state checks above. A nil client would panic below, and the panic
	// would be swallowed, reporting success to the caller.
	client := s.rpcClient.Load()
	if client == nil {
		return ErrNoConn
	}

	defer s.Mach.PanicToErr(nil)

	payload.Token = utils.RandId(0)
	if event != nil {
		payload.Source = event.MachineId
		payload.SourceTx = event.TransitionId
	}
	s.log("sending payload %s from %s to %s", payload.Name, payload.Source,
		payload.Destination)

	return client.CallWithContext(ctx,
		ClientSendPayload.Value, payload, &Empty{})
}
// ClientId returns the ID the client sent during the handshake, or an empty
// string when no client is handshaken.
func (s *Server) ClientId() string {
	if id := s.clientId.Load(); id != nil {
		return *id
	}
	return ""
}
// GetKind reports this RPC endpoint as a server (KindServer).
func (s *Server) GetKind() Kind {
	kind := KindServer
	return kind
}
// log writes msg (formatted with args) to the server machine's log, but only
// when LogEnabled is set.
func (s *Server) log(msg string, args ...any) {
	if s.LogEnabled {
		s.Mach.Log(msg, args...)
	}
}
// bindRpcHandlers creates a fresh rpc2 server, stores it in s.rpcServer and
// registers the Remote* methods under their RPC method names.
func (s *Server) bindRpcHandlers() {
	s.rpcServer = rpc2.NewServer()
	srv := s.rpcServer

	// handshake
	srv.Handle(ServerHello.Value, s.RemoteHello)
	srv.Handle(ServerHandshake.Value, s.RemoteHandshake)
	// mutations
	srv.Handle(ServerAdd.Value, s.RemoteAdd)
	srv.Handle(ServerAddNS.Value, s.RemoteAddNS)
	srv.Handle(ServerRemove.Value, s.RemoteRemove)
	srv.Handle(ServerSet.Value, s.RemoteSet)
	// misc
	srv.Handle(ServerSync.Value, s.RemoteSync)
	srv.Handle(ServerBye.Value, s.RemoteBye)
}
// pushClockUpdate notifies the handshaken client about Source's current
// clock, when it changed since the last snapshot.
//
// Without force, the push is skipped while a remote mutation is in progress
// (skipClockPush), when PushInterval is 0, and when less than PushInterval
// passed since the last snapshot. force bypasses all three. A failed
// notification removes ClientConnected and records the error.
func (s *Server) pushClockUpdate(force bool) {
	c := s.rpcClient.Load()
	if c == nil {
		return
	}
	if s.skipClockPush.Load() && !force {
		return
	}
	if s.Mach.Not1(ssS.ClientConnected) ||
		s.Mach.Not1(ssS.HandshakeDone) {
		return
	}
	if s.PushInterval == 0 && !force {
		return
	}

	// A forced push must not be dropped by the PushInterval throttle, eg
	// right after RemoteHello refreshed lastClockHTime.
	clock := s.genClockUpdate(force)
	if clock == nil {
		return
	}

	defer s.Mach.PanicToErr(nil)
	s.log("pushClockUpdate %d", s.lastClockSum.Load())
	s.CallCount++
	err := c.Notify(ClientSetClock.Value, clock)
	if err != nil {
		s.Mach.Remove1(ssS.ClientConnected, nil)
		AddErr(nil, s.Mach, "pushClockUpdate", err)
	}
}
// genClockUpdate snapshots Source's clock and queue tick and returns a clock
// update relative to the previous snapshot, or nil when neither changed.
//
// Unless skipTimeCheck is set, it also returns nil when less than
// PushInterval passed since the last snapshot. Runs under clockMx.
func (s *Server) genClockUpdate(skipTimeCheck bool) *ClockMsg {
	s.clockMx.Lock()
	defer s.clockMx.Unlock()

	if !skipTimeCheck && time.Since(s.lastClockHTime) < s.PushInterval {
		return nil
	}

	hTime := time.Now()
	qTick := s.Source.QueueTick()
	mTime := s.Source.Time(nil)
	var tSum uint64
	for _, v := range mTime {
		tSum += v
	}
	if tSum == s.lastClockSum.Load() && qTick == s.lastQueueTick {
		return nil
	}

	msg := NewClockMsg(tSum, s.lastClock, mTime, s.lastQueueTick, qTick)
	s.lastClockMsg = msg
	s.lastClock = mTime
	s.lastClockHTime = hTime
	s.lastQueueTick = qTick
	s.lastClockSum.Store(tSum)

	// guard explicitly, so ActiveStates (and the log args) aren't computed on
	// this hot path when logging is disabled
	if s.LogEnabled {
		s.log("genClockUpdate: t%d q%d ch%d (%s)", tSum, qTick,
			msg.Checksum, s.Source.ActiveStates(nil))
	}

	return msg
}
// RemoteHello is the first step of the handshake. It returns Source's
// serialized export (and its schema, when requested), activates Handshaking
// and records the exported clock as the baseline for future clock updates.
// Fails with ErrCanceled when the server isn't started.
func (s *Server) RemoteHello(
	client *rpc2.Client, req *ArgsHello, resp *RespHandshake,
) error {
	if s.Mach.Not1(ssS.Start) {
		return am.ErrCanceled
	}

	s.clockMx.Lock()
	defer s.clockMx.Unlock()

	export := s.Source.Export()
	out := RespHandshake{Serialized: export}
	if req.ReqSchema {
		out.Schema = s.Source.Schema()
	}
	*resp = out

	sum := export.Time.Sum(nil)
	s.log("RemoteHello: t%v q%d", sum, export.QueueTick)
	s.Mach.Add1(ssS.Handshaking, nil)

	// baseline for genClockUpdate
	s.lastClock = export.Time
	s.lastQueueTick = export.QueueTick
	s.lastClockSum.Store(sum)
	s.lastClockHTime = time.Now()

	return nil
}
// RemoteHandshake finishes the handshake. It validates the client's ID
// (non-empty, and equal to AllowId when that's set), stores the client and
// its ID, and activates HandshakeDone. When Source changed since
// RemoteHello, a forced clock push follows.
func (s *Server) RemoteHandshake(
	client *rpc2.Client, id *string, _ *Empty,
) error {
	if s.Mach.Not1(ssS.Start) {
		return am.ErrCanceled
	}

	switch {
	case id == nil || *id == "":
		s.Mach.Remove1(ssS.Handshaking, nil)
		AddErrRpcStr(nil, s.Mach, "handshake failed: ID missing")
		return ErrInvalidParams
	case s.AllowId != "" && *id != s.AllowId:
		s.Mach.Remove1(ssS.Handshaking, nil)
		return fmt.Errorf("%w: %s != %s", ErrNoAccess, *id, s.AllowId)
	}

	sum := s.Source.Time(nil).Sum(nil)
	qTick := s.Source.QueueTick()
	s.log("RemoteHandshake: t%v q%d", sum, qTick)

	s.rpcClient.Store(client)
	s.clientId.Store(id)
	s.Mach.Add1(ssS.HandshakeDone, Pass(&A{Id: *id}))

	// A clock-sum change always triggers a push, a queue-tick-only change
	// only when periodic pushes are disabled. The grouping is Go's default
	// precedence, made explicit.
	// NOTE(review): confirm `(sum || qTick) && PushInterval == 0` wasn't
	// intended instead.
	if s.lastClockSum.Load() != sum ||
		(s.lastQueueTick != qTick && s.PushInterval == 0) {
		s.pushClockUpdate(true)
	}

	return nil
}
// RemoteAdd adds the states (given by index) to Source on the client's
// behalf, via EvAdd when the client sent an event. Responds with the
// mutation result and the resulting clock update.
func (s *Server) RemoteAdd(
	_ *rpc2.Client, args *ArgsMut, resp *RespResult,
) error {
	if s.Mach.Not1(ssS.Start) {
		return am.ErrCanceled
	}

	s.mutMx.Lock()
	defer s.mutMx.Unlock()
	if args.States == nil {
		return ErrInvalidParams
	}

	// no periodic pushes while mutating, the clock goes out with the response
	s.skipClockPush.Store(true)
	names := amhelp.IndexesToStates(s.Source.StateNames(), args.States)
	var res am.Result
	if args.Event == nil {
		res = s.Source.Add(names, args.Args)
	} else {
		res = s.Source.EvAdd(args.Event, names, args.Args)
	}

	*resp = RespResult{Result: res, Clock: s.genClockUpdate(true)}
	s.skipClockPush.Store(false)

	return nil
}
// RemoteAddNS adds the states (given by index) to Source without sending a
// result or clock back ("no sync" variant of RemoteAdd).
func (s *Server) RemoteAddNS(
	_ *rpc2.Client, args *ArgsMut, _ *Empty,
) error {
	if s.Mach.Not1(ssS.Start) {
		return am.ErrCanceled
	}

	s.mutMx.Lock()
	defer s.mutMx.Unlock()
	if args.States == nil {
		return ErrInvalidParams
	}

	// no periodic pushes while mutating
	s.skipClockPush.Store(true)
	names := amhelp.IndexesToStates(s.Source.StateNames(), args.States)
	_ = s.Source.Add(names, args.Args)
	s.skipClockPush.Store(false)

	return nil
}
// RemoteRemove removes the states (given by index) from Source on the
// client's behalf. Responds with the mutation result and the resulting clock
// update.
func (s *Server) RemoteRemove(
	_ *rpc2.Client, args *ArgsMut, resp *RespResult,
) error {
	if s.Mach.Not1(ssS.Start) {
		return am.ErrCanceled
	}

	s.mutMx.Lock()
	defer s.mutMx.Unlock()
	if args.States == nil {
		return ErrInvalidParams
	}

	// no periodic pushes while mutating, the clock goes out with the response
	s.skipClockPush.Store(true)
	names := amhelp.IndexesToStates(s.Source.StateNames(), args.States)
	res := s.Source.Remove(names, args.Args)
	s.skipClockPush.Store(false)

	*resp = RespResult{Result: res, Clock: s.genClockUpdate(true)}
	return nil
}
// RemoteSet sets the states (given by index) on Source on the client's
// behalf. Responds with the mutation result and the resulting clock update.
func (s *Server) RemoteSet(
	_ *rpc2.Client, args *ArgsMut, resp *RespResult,
) error {
	if s.Mach.Not1(ssS.Start) {
		return am.ErrCanceled
	}

	s.mutMx.Lock()
	defer s.mutMx.Unlock()
	if args.States == nil {
		return ErrInvalidParams
	}

	// no periodic pushes while mutating, the clock goes out with the response
	s.skipClockPush.Store(true)
	names := amhelp.IndexesToStates(s.Source.StateNames(), args.States)
	res := s.Source.Set(names, args.Args)
	s.skipClockPush.Store(false)

	*resp = RespResult{Result: res, Clock: s.genClockUpdate(true)}
	return nil
}
// RemoteSync returns Source's clock and queue tick when its time sum is
// greater than the client's sum. Otherwise it returns an empty response,
// meaning the client is up to date.
func (s *Server) RemoteSync(
	_ *rpc2.Client, sum uint64, resp *RespSync,
) error {
	if s.Mach.Not1(ssS.Start) {
		return am.ErrCanceled
	}
	s.log("RemoteSync")

	// Snapshot the clock once, so the returned time is exactly the one that
	// was compared against the client's sum (and it's copied only once).
	mTime := s.Source.Time(nil)
	if mTime.Sum(nil) > sum {
		*resp = RespSync{
			Time:      mTime,
			QueueTick: s.Source.QueueTick(),
		}
	} else {
		*resp = RespSync{}
	}

	s.log("RemoteSync: %v", resp.Time)
	return nil
}
// RemoteBye handles the client's goodbye.
//
// It removes ClientConnected and clears the stored client and client ID. In
// the background, it tries to close the client connection, waiting at most
// 100ms for that. After a further 100ms sleep, it removes HandshakeDone.
// Fails with ErrCanceled when the server isn't started.
//
// NOTE(review): the goroutine loads s.rpcClient only when it runs, while
// s.rpcClient.Store(nil) below executes right after the goroutine is
// started. The close can therefore observe nil and leave the connection
// open. Capturing the client before the Store would fix that, but may close
// the connection before the reply to this very call is written. Confirm the
// intended behavior before changing.
func (s *Server) RemoteBye(
	_ *rpc2.Client, _ *Empty, _ *Empty,
) error {
	if s.Mach.Not1(ssS.Start) {
		return am.ErrCanceled
	}
	s.log("RemoteBye")
	s.Mach.Remove1(ssS.ClientConnected, Pass(&A{
		Addr: s.Addr,
	}))
	go func() {
		select {
		case <-time.After(100 * time.Millisecond):
			s.log("rpc.Close timeout")
		case <-amhelp.ExecAndClose(func() {
			if c := s.rpcClient.Load(); c != nil {
				_ = c.Close()
			}
		}):
			s.log("rpc.Close")
		}
		// give in-flight traffic a moment before ending the handshake
		time.Sleep(100 * time.Millisecond)
		s.Mach.Remove1(ssS.HandshakeDone, nil)
	}()
	s.rpcClient.Store(nil)
	s.clientId.Store(nil)
	return nil
}
// RemoteSetPushAllTicks stores the client's PushAllTicks preference.
// Fails with ErrCanceled when the server isn't started.
// NOTE(review): this method isn't registered in bindRpcHandlers in this file.
// Confirm how clients reach it.
func (s *Server) RemoteSetPushAllTicks(
	_ *rpc2.Client, val bool, _ *Empty,
) error {
	if s.Mach.Not1(ssS.Start) {
		return am.ErrCanceled
	}

	s.log("RemoteSetPushAllTicks")
	s.PushAllTicks = val

	return nil
}
// BindServer pipes states of the RPC server machine (source) into target,
// using ampipe handlers bound to source.
//
// RpcReady's activation and deactivation are piped to rpcReady. HandshakeDone's
// activation and deactivation are piped to clientConn; the pipes are built
// from the ClientConnected state. Both target state names are required.
func BindServer(source, target *am.Machine, rpcReady, clientConn string) error {
	if rpcReady == "" || clientConn == "" {
		return fmt.Errorf("rpcReady and clientConn must be set")
	}

	readyOn := ampipe.Add(source, target, ssS.RpcReady, rpcReady)
	readyOff := ampipe.Remove(source, target, ssS.RpcReady, rpcReady)
	connOn := ampipe.Add(source, target, ssS.ClientConnected, clientConn)
	connOff := ampipe.Remove(source, target, ssS.ClientConnected, clientConn)

	return source.BindHandlers(&struct {
		RpcReadyState      am.HandlerFinal
		RpcReadyEnd        am.HandlerFinal
		HandshakeDoneState am.HandlerFinal
		HandshakeDoneEnd   am.HandlerFinal
	}{
		RpcReadyState:      readyOn,
		RpcReadyEnd:        readyOff,
		HandshakeDoneState: connOn,
		HandshakeDoneEnd:   connOff,
	})
}
// BindServerMulti is like BindServer, but it signals disconnects by adding a
// separate clientDisconn state on target instead of removing clientConn.
//
// RpcReady's activation and deactivation are piped to rpcReady. HandshakeDone's
// activation adds clientConn, and its deactivation adds clientDisconn. All
// three target state names are required.
func BindServerMulti(
	source, target *am.Machine, rpcReady, clientConn, clientDisconn string,
) error {
	if rpcReady == "" || clientConn == "" || clientDisconn == "" {
		return fmt.Errorf("rpcReady, clientConn, and clientDisconn must be set")
	}

	readyOn := ampipe.Add(source, target, ssS.RpcReady, rpcReady)
	readyOff := ampipe.Remove(source, target, ssS.RpcReady, rpcReady)
	connected := ampipe.Add(source, target, ssS.ClientConnected, clientConn)
	disconnected := ampipe.Add(source, target,
		ssS.ClientConnected, clientDisconn)

	return source.BindHandlers(&struct {
		RpcReadyState      am.HandlerFinal
		RpcReadyEnd        am.HandlerFinal
		HandshakeDoneState am.HandlerFinal
		HandshakeDoneEnd   am.HandlerFinal
	}{
		RpcReadyState:      readyOn,
		RpcReadyEnd:        readyOff,
		HandshakeDoneState: connected,
		HandshakeDoneEnd:   disconnected,
	})
}
// BindServerRpcReady pipes only the activation of the server's RpcReady
// state (source) to the rpcReady state of target. rpcReady is required,
// consistent with BindServer and BindServerMulti. Without the check, an empty
// name would be piped as a state name.
func BindServerRpcReady(source, target *am.Machine, rpcReady string) error {
	if rpcReady == "" {
		return fmt.Errorf("rpcReady must be set")
	}

	h := &struct {
		RpcReadyState am.HandlerFinal
	}{
		RpcReadyState: ampipe.Add(source, target, ssS.RpcReady, rpcReady),
	}

	return source.BindHandlers(h)
}
// ServerOpts are optional settings for NewServer.
type ServerOpts struct {
	// PayloadState is the source machine's state used for sending payloads to
	// the client. Empty means WorkerStates.SendPayload.
	PayloadState string
	// Parent is passed to am.NewCommon when creating the server machine
	// (presumably as its parent machine).
	Parent am.Api
}
// SendPayloadHandlers is the handler struct bound to the source machine by
// NewServer when the default WorkerStates.SendPayload state is used for
// payload delivery.
type SendPayloadHandlers struct {
	// SendPayloadState is built by getSendPayloadState.
	SendPayloadState am.HandlerFinal
}
// getSendPayloadState returns a final handler for the stateName state of the
// source machine.
//
// The handler immediately removes stateName again and validates the Name and
// Payload args. It then sends the payload to the RPC client in a goroutine,
// bounded by s.DeliveryTimeout and by the server's Start state context.
// Invalid args and delivery errors are reported as ErrSendPayload on the
// source machine.
func getSendPayloadState(s *Server, stateName string) am.HandlerFinal {
	return func(e *am.Event) {
		e.Machine().EvRemove1(e, stateName, nil)
		startCtx := s.Mach.NewStateCtx(ssS.Start)
		args := ParseArgs(e.Args)
		argsOut := &A{Name: args.Name}

		if args.Name == "" || args.Payload == nil {
			err := fmt.Errorf("invalid payload args [name, payload]")
			e.Machine().EvAddErrState(e, ssW.ErrSendPayload, err, Pass(argsOut))
			return
		}

		go func() {
			deliveryCtx, cancel := context.WithTimeout(startCtx, s.DeliveryTimeout)
			defer cancel()

			if err := s.SendPayload(deliveryCtx, e, args.Payload); err != nil {
				e.Machine().EvAddErrState(e, ssW.ErrSendPayload, err, Pass(argsOut))
			}
		}()
	}
}
// createSendPayloadHandlers builds, via reflection, a handler struct with a
// single field named stateName+"State", holding the send-payload handler for
// stateName, and returns a pointer to it. This allows binding delivery to a
// custom payload state. stateName has to start an exported Go identifier,
// otherwise reflect.StructOf panics.
func createSendPayloadHandlers(s *Server, stateName string) any {
	handler := getSendPayloadState(s, stateName)

	field := reflect.StructField{
		Name: stateName + "State",
		Type: reflect.TypeOf(handler),
	}
	ptr := reflect.New(reflect.StructOf([]reflect.StructField{field}))
	ptr.Elem().Field(0).Set(reflect.ValueOf(handler))

	return ptr.Interface()
}
The pages are generated with Golds v0.8.2. (GOOS=linux GOARCH=amd64)
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news of Golds.