package rpc
import (
"context"
"errors"
"fmt"
"maps"
"net"
"os"
"slices"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/cenkalti/rpc2"
amhelp "github.com/pancsta/asyncmachine-go/pkg/helpers"
am "github.com/pancsta/asyncmachine-go/pkg/machine"
"github.com/pancsta/asyncmachine-go/pkg/rpc/states"
)
// ssC holds the state names of the RPC client machine.
var ssC = states.ClientStates

// ssCo holds the state names a consumer machine is checked against.
var ssCo = states.ConsumerStates
// Client is the RPC client. It owns a state machine (Mach) driving the
// connection lifecycle and a Worker created when the Start state activates.
type Client struct {
*ExceptionHandler
// Mach is the client's own state machine, created in NewClient.
Mach *am .Machine
// Name is the client name (defaults to "rpc" in NewClient).
Name string
// Addr is the address dialed over tcp4 when connecting.
Addr string
// RequestSchema makes the hello message ask for the schema. It is reset to
// false after a validated hello response.
RequestSchema bool
// Worker is created in StartState.
Worker *Worker
// Consumer is an optional machine which receives WorkerPayload.
Consumer *am .Machine
// CallCount is incremented on every call and notify attempt.
CallCount uint64
// LogEnabled gates the internal log method. Set from EnvAmRpcLogClient.
LogEnabled bool
// DisconnCooldown is the wait between sending ServerBye and closing the RPC
// client. Default 10ms.
DisconnCooldown time .Duration
// LastMsgAt is not written by any code visible in this file.
LastMsgAt time .Time
// HelloDelay is an optional delay before the first hello message.
HelloDelay time .Duration
// ReconnectOn makes Disconnected add RetryingConn when the client was
// connected or connecting before. Default true.
ReconnectOn bool
// ConnTimeout is the dial timeout. Default 3s.
ConnTimeout time .Duration
// ConnRetries is the max number of connection retry rounds. Default 15.
ConnRetries int
// ConnRetryTimeout is the total time budget of one retrying loop (0 disables
// the check). Default 1 minute.
ConnRetryTimeout time .Duration
// ConnRetryDelay is the initial delay between connection retries. Default
// 100ms.
ConnRetryDelay time .Duration
// ConnRetryBackoff is the upper limit of the doubled retry delay (0 disables
// doubling). Default 3s.
ConnRetryBackoff time .Duration
// CallTimeout is the default timeout of a single call. Default 3s (100s when
// amhelp.IsDebug).
CallTimeout time .Duration
// CallRetries is the max number of call retries. Default 15.
CallRetries int
// CallRetryTimeout is the total time budget for retrying a call (0 disables
// the check). Default 1 minute.
CallRetryTimeout time .Duration
// CallRetryDelay is the initial delay between call retries. Default 100ms.
CallRetryDelay time .Duration
// CallRetryBackoff is the upper limit of the doubled call retry delay (0
// disables doubling). Default 3s.
CallRetryBackoff time .Duration
// DisconnTimeout is the max wait for closing the RPC client. Default 3s.
DisconnTimeout time .Duration
// callLock serializes callFailsafe and notifyFailsafe.
callLock sync .Mutex
// rpc is the rpc2 client, created per connection in bindRpcHandlers.
rpc *rpc2 .Client
// workerStates and workerSchema start as clones of the NewClient params and
// are replaced by updateSchema.
workerStates am .S
workerSchema am .Schema
// conn is the current network connection.
conn net .Conn
// tmpTestErr, when set, fails the next call and is then cleared.
tmpTestErr error
// permTestErr, when set, fails every call.
permTestErr error
// connRetryRound counts connection retry rounds; reset to 0 on Connected.
connRetryRound atomic .Int32
}
// Compile-time checks that *Client satisfies both RPC method interfaces.
var (
	_ clientRpcMethods    = (*Client)(nil)
	_ clientServerMethods = (*Client)(nil)
)
// NewClient creates a new RPC client for a worker listening on workerAddr.
//
// workerAddr, stateStruct and stateNames are required; name defaults to "rpc"
// and opts may be nil. The schema and state names are cloned, so later changes
// to the passed values do not affect the client. When opts.Consumer is set, it
// has to implement the consumer states, otherwise an error is returned.
func NewClient(
	ctx context.Context, workerAddr string, name string, stateStruct am.Schema,
	stateNames am.S, opts *ClientOpts,
) (*Client, error) {
	// required params, checked in order
	switch {
	case workerAddr == "":
		return nil, errors.New("rpcc: workerAddr required")
	case stateStruct == nil:
		return nil, errors.New("rpcc: stateStruct required")
	case stateNames == nil:
		return nil, errors.New("rpcc: stateNames required")
	}

	// optional params
	if name == "" {
		name = "rpc"
	}
	if opts == nil {
		opts = &ClientOpts{}
	}

	client := &Client{
		ExceptionHandler: &ExceptionHandler{},
		Name:             name,
		Addr:             workerAddr,
		LogEnabled:       os.Getenv(EnvAmRpcLogClient) != "",
		ReconnectOn:      true,

		// timeouts
		CallTimeout:     3 * time.Second,
		ConnTimeout:     3 * time.Second,
		DisconnTimeout:  3 * time.Second,
		DisconnCooldown: 10 * time.Millisecond,

		// connection retries
		ConnRetries:      15,
		ConnRetryTimeout: 1 * time.Minute,
		ConnRetryDelay:   100 * time.Millisecond,
		ConnRetryBackoff: 3 * time.Second,

		// call retries
		CallRetries:      15,
		CallRetryTimeout: 1 * time.Minute,
		CallRetryDelay:   100 * time.Millisecond,
		CallRetryBackoff: 3 * time.Second,

		// private copies of the worker definition
		workerStates: slices.Clone(stateNames),
		workerSchema: maps.Clone(stateStruct),
	}
	if amhelp.IsDebug() {
		client.CallTimeout = 100 * time.Second
	}

	// the client's own state machine
	machOpts := &am.Opts{Tags: []string{"rpc-client", "addr:" + workerAddr}}
	mach, err := am.NewCommon(ctx, GetClientId(name), states.ClientSchema,
		ssC.Names(), client, opts.Parent, machOpts)
	if err != nil {
		return nil, err
	}
	mach.SemLogger().SetArgsMapper(LogArgs)
	client.Mach = mach
	if os.Getenv(EnvAmRpcDbg) != "" {
		amhelp.MachDebugEnv(mach)
	}

	// optional consumer
	consumer := opts.Consumer
	if consumer == nil {
		return client, nil
	}
	if err := amhelp.Implements(consumer.StateNames(), ssCo.Names()); err != nil {
		return nil, fmt.Errorf(
			"consumer has to implement pkg/rpc/states/ConsumerStatesDef: %w", err)
	}
	client.Consumer = consumer

	return client, nil
}
// StartState is the handler of the Start state. It creates the Worker using
// the client's copy of the worker schema and state names, bound to a context
// of the Start state.
//
// When the worker cannot be created, the error is added to the machine and the
// handler returns without touching c.Worker, so a nil worker is never passed
// to the debug helper.
func (c *Client) StartState(e *am.Event) {
	ctx := c.Mach.NewStateCtx(ssC.Start)

	worker, err := NewWorker(ctx, "", c, c.workerSchema, c.workerStates, c.Mach,
		nil)
	if err != nil {
		c.Mach.AddErr(err, nil)
		return
	}
	c.Worker = worker

	// optional env-based debugging of the worker
	if os.Getenv(EnvAmRpcDbg) != "" {
		amhelp.MachDebugEnv(worker)
	}
}
// StartEnd is the end handler of the Start state. It adds Disconnecting when
// Connecting or Connected was active before the transition, unless both
// Connecting and Exception were active.
func (c *Client) StartEnd(e *am.Event) {
	prev := e.Transition().TimeBefore
	index := e.Machine().Index1

	// Connecting together with Exception: nothing to disconnect
	failedConnecting := prev.Is(
		[]int{index(ssC.Connecting), index(ssC.Exception)})
	if failedConnecting {
		return
	}

	// no connection (attempt) before the transition
	if !prev.Is1(index(ssC.Connecting)) && !prev.Is1(index(ssC.Connected)) {
		return
	}

	c.Mach.EvAdd1(e, ssC.Disconnecting, nil)
}
// ConnectingState is the handler of the Connecting state. In a goroutine it
// dials Addr over tcp4, creates the rpc2 client with bound handlers, starts
// its loop and adds Connected. A dial error adds Disconnected and a network
// error instead.
func (c *Client ) ConnectingState (e *am .Event ) {
ctx := c .Mach .NewStateCtx (ssC .Connecting )
go func () {
// the state context may be already done
if ctx .Err () != nil {
return
}
// ConnTimeout is overridden with 100s when amhelp.IsDebug
timeout := c .ConnTimeout
if amhelp .IsDebug () {
timeout = 100 * time .Second
}
d := net .Dialer {
Timeout : timeout ,
}
c .Mach .Log ("dialing %s" , c .Addr )
conn , err := d .DialContext (ctx , "tcp4" , c .Addr )
// the state context expired while dialing, drop the result
if ctx .Err () != nil {
return
}
if err != nil {
c .Mach .EvAdd1 (e , ssC .Disconnected , nil )
AddErrNetwork (e , c .Mach , err )
return
}
// keep the connection and run the rpc2 client on it
c .conn = conn
c .bindRpcHandlers (conn )
go c .rpc .Run ()
c .Mach .EvAdd1 (e , ssC .Connected , nil )
}()
}
// DisconnectingEnter accepts the Disconnecting state only when both the rpc2
// client and the network connection exist.
func (c *Client) DisconnectingEnter(e *am.Event) bool {
	if c.rpc == nil {
		return false
	}

	return c.conn != nil
}
// DisconnectingState is the handler of the Disconnecting state. In a
// goroutine it notifies the server with ServerBye, waits DisconnCooldown,
// closes the rpc2 client (bounded by DisconnTimeout) and adds Disconnected.
func (c *Client ) DisconnectingState (e *am .Event ) {
ctx := c .Mach .NewStateCtx (ssC .Disconnecting )
go func () {
if ctx .Err () != nil {
return
}
// say goodbye, the result is ignored
c .notify (ctx , ServerBye .Value , &Empty {})
// presumably Wait returns false when ctx ends before the delay - confirm
if !amhelp .Wait (ctx , c .DisconnCooldown ) {
c .ensureGroupConnected (e )
return
}
if c .rpc != nil {
// close the rpc2 client, but wait no longer than DisconnTimeout and stop
// waiting when the state context ends
select {
case <- time .After (c .DisconnTimeout ):
c .log ("rpc.Close timeout" )
case <- amhelp .ExecAndClose (func () {
_ = c .rpc .Close ()
}):
c .log ("rpc.Close" )
case <- ctx .Done ():
}
}
// the state context ended meanwhile
if ctx .Err () != nil {
c .ensureGroupConnected (e )
return
}
c .Mach .EvAdd1 (e , ssC .Disconnected , nil )
}()
}
// ConnectedState is the handler of the Connected state. It resets the
// connection retry counter and watches the rpc2 client for a disconnect, which
// adds Disconnected. The watcher exits when the Connected context ends.
func (c *Client) ConnectedState(e *am.Event) {
	ctx := c.Mach.NewStateCtx(ssC.Connected)
	closed := c.rpc.DisconnectNotify()

	// connected, start counting retries from scratch
	c.connRetryRound.Store(0)

	go func() {
		select {
		case <-closed:
			c.log("rpc.DisconnectNotify")
			c.Mach.EvAdd1(e, ssC.Disconnected, nil)

		case <-ctx.Done():
			// state context ended, nothing to do
		}
	}()
}
// DisconnectedEnter rejects the Disconnected state when the machine reports
// WillBe1 for Disconnecting.
func (c *Client) DisconnectedEnter(e *am.Event) bool {
	if c.Mach.WillBe1(ssC.Disconnecting) {
		return false
	}

	return true
}
// DisconnectedState is the handler of the Disconnected state. When the client
// was connected or connecting before the transition and ReconnectOn is set, it
// adds RetryingConn. Otherwise it closes the network connection, if any.
func (c *Client) DisconnectedState(e *am.Event) {
	prev := e.Transition().TimeBefore

	hadConn := prev.Any1(
		c.Mach.Index1(ssC.Connected), c.Mach.Index1(ssC.Connecting))
	if hadConn && c.ReconnectOn {
		// try to reconnect
		c.Mach.EvAdd1(e, ssC.RetryingConn, nil)
		return
	}

	if c.conn == nil {
		return
	}
	// best-effort close, the error is ignored
	_ = c.conn.Close()
}
// HandshakingState is the handler of the Handshaking state. In a goroutine it
// sends ServerHello (with retries), validates the response, optionally applies
// the received schema, sends ServerHandshake and finally adds HandshakeDone
// with the remote ID, machine time and queue tick.
func (c *Client ) HandshakingState (e *am .Event ) {
// NOTE(review): the context is bound to Connected, not to Handshaking -
// confirm this is intended.
ctx := c .Mach .NewStateCtx (ssC .Connected )
go func () {
if ctx .Err () != nil {
return
}
resp := &RespHandshake {}
// optional delay before the first hello
if c .HelloDelay > 0 {
if !amhelp .Wait (ctx , c .HelloDelay ) {
return
}
}
ok := false
delay := c .CallRetryDelay
// each hello attempt gets half of the regular call timeout
timeout := c .CallTimeout / 2
// NOTE(review): attempts are bounded by ConnRetries, while the delay and
// backoff use the CallRetry* settings - confirm.
for i := 0 ; i < c .ConnRetries ; i ++ {
rcpArgs := ArgsHello {ReqSchema : c .RequestSchema }
if c .call (ctx , ServerHello .Value , rcpArgs , resp , timeout ) {
ok = true
c .log ("hello ok on %d try" , i +1 )
break
}
if !amhelp .Wait (ctx , delay ) {
return
}
// double the delay, capped at CallRetryBackoff
if c .CallRetryBackoff > 0 {
delay *= 2
if delay > c .CallRetryBackoff {
delay = c .CallRetryBackoff
}
}
}
// all hello attempts failed, retry the whole connection
if !ok {
c .Mach .EvAdd1 (e , ssC .RetryingConn , nil )
return
}
// validate the response
stateNames := resp .Serialized .StateNames
if len (stateNames ) == 0 {
AddErrRpcStr (e , c .Mach , "states missing" )
return
}
if resp .Serialized .ID == "" {
AddErrRpcStr (e , c .Mach , "ID missing" )
return
}
if c .RequestSchema && resp .Schema == nil {
AddErrRpcStr (e , c .Mach , "schema missing" )
return
}
// the schema is requested only once
c .RequestSchema = false
if resp .Schema != nil {
c .updateSchema (resp )
}
// the worker is updated with the remote ID before the states are compared
c .Worker .tags [1 ] = "src-id:" + resp .Serialized .ID
c .Worker .remoteId = resp .Serialized .ID
if !am .StatesEqual (c .workerStates , stateNames ) {
AddErrRpcStr (e , c .Mach , "States differ on client/server" )
return
}
// confirm the handshake, passing this client's machine ID
if !c .call (ctx , ServerHandshake .Value , c .Mach .Id (), &Empty {}, 0 ) {
c .Mach .EvAdd1 (e , ssC .RetryingConn , nil )
return
}
c .Mach .EvAdd1 (e , ssC .HandshakeDone , Pass (&A {
Id : resp .Serialized .ID ,
MachTime : resp .Serialized .Time ,
QueueTick : resp .Serialized .QueueTick ,
}))
}()
}
// updateSchema replaces the schema, state names, queue tick and machine time
// of both the client and the worker with the values from resp, and rebuilds
// the worker's per-state clock. The worker's schemaMx and clockMx are held
// (in that order) for the whole update.
func (c *Client ) updateSchema (resp *RespHandshake ) {
w := c .Worker
w .schemaMx .Lock ()
defer w .schemaMx .Unlock ()
w .clockMx .Lock ()
defer w .clockMx .Unlock ()
c .workerSchema = resp .Schema
c .workerStates = resp .Serialized .StateNames
w .schema = resp .Schema
w .stateNames = resp .Serialized .StateNames
w .queueTick = resp .Serialized .QueueTick
w .machTime = resp .Serialized .Time
// map each state name to its tick, by position
// assumes machTime has at least as many entries as stateNames - TODO confirm
for idx , state := range w .stateNames {
w .machClock [state ] = w .machTime [idx ]
}
}
// HandshakeDoneEnter accepts the HandshakeDone state only when the event args
// carry a non-empty Id, a non-nil MachTime and a positive QueueTick.
func (c *Client) HandshakeDoneEnter(e *am.Event) bool {
	args := ParseArgs(e.Args)
	if args.Id == "" || args.MachTime == nil {
		return false
	}

	return args.QueueTick > 0
}
// HandshakeDoneState is the handler of the HandshakeDone state. It names the
// worker after the client ("rw-" prefix) and applies the full machine time and
// queue tick received in the event args.
func (c *Client) HandshakeDoneState(e *am.Event) {
	a := ParseArgs(e.Args)

	// worker ID derives from the client's name
	c.Worker.id = "rw-" + c.Name

	// full clock update (no diff message)
	c.updateClock(nil, a.MachTime, a.QueueTick)

	c.log("connected to %s", c.Worker.id)
	timeSum := c.Worker.Time(nil).Sum(nil)
	qTick := c.Worker.QueueTick()
	c.log("time t%d q%d: %v", timeSum, qTick, a.MachTime)
}
// CallRetryFailedState is the handler of the CallRetryFailed state. It
// requests removal of the state right away.
func (c *Client) CallRetryFailedState(e *am.Event) {
	const state = ssC.CallRetryFailed
	c.Mach.EvRemove1(e, state, nil)
}
// RetryingCallEnter accepts the RetryingCall state only when Connected or
// RetryingConn is active.
func (c *Client) RetryingCallEnter(e *am.Event) bool {
	canRetry := c.Mach.Any1(ssC.Connected, ssC.RetryingConn)

	return canRetry
}
// ExceptionState is the handler of the Exception state. It delegates to the
// embedded ExceptionHandler and then requests removal of the Exception state.
func (c *Client) ExceptionState(e *am.Event) {
	// embedded handler first
	handler := c.ExceptionHandler
	handler.ExceptionState(e)

	c.Mach.EvRemove1(e, am.StateException, nil)
}
// RetryingConnState is the handler of the RetryingConn state. In a goroutine
// it re-adds Connecting in rounds, until the state context ends, ConnRetries
// rounds were used, or ConnRetryTimeout elapsed. When the loop ends with the
// context still alive, RetryingConn is replaced with ConnRetryFailed.
func (c *Client ) RetryingConnState (e *am .Event ) {
ctx := c .Mach .NewStateCtx (ssC .RetryingConn )
delay := c .ConnRetryDelay
start := time .Now ()
go func () {
// connRetryRound is shared between activations of this state and is reset
// only in ConnectedState
for ctx .Err () == nil && c .connRetryRound .Load () < int32 (c .ConnRetries ) {
c .connRetryRound .Add (1 )
// wait before the attempt
if !amhelp .Wait (ctx , delay ) {
return
}
// presumably blocks until Connecting is added or ctx ends - confirm
amhelp .Add1Block (ctx , c .Mach , ssC .Connecting , nil )
if ctx .Err () != nil {
return
}
// wait for Connecting to end, for up to twice the dial timeout; the
// result is ignored
_ = amhelp .WaitForErrAny (ctx , c .ConnTimeout *2 , c .Mach ,
c .Mach .WhenNot1 (ssC .Connecting , ctx ))
if ctx .Err () != nil {
return
}
// clear the Exception state before the next round
c .Mach .EvRemove1 (e , ssC .Exception , nil )
// double the delay, capped at ConnRetryBackoff
if c .ConnRetryBackoff > 0 {
delay *= 2
if delay > c .ConnRetryBackoff {
delay = c .ConnRetryBackoff
}
}
// total time budget
if c .ConnRetryTimeout > 0 && time .Since (start ) > c .ConnRetryTimeout {
break
}
}
if ctx .Err () != nil {
return
}
// retries exhausted
c .Mach .EvRemove1 (e , ssC .RetryingConn , nil )
c .Mach .EvAdd1 (e , ssC .ConnRetryFailed , nil )
}()
}
// WorkerPayloadEnter accepts the WorkerPayload state only when a consumer is
// set and the event args carry a payload. A missing payload adds the
// ErrDelivery error state with the payload's name.
func (c *Client) WorkerPayloadEnter(e *am.Event) bool {
	if c.Consumer == nil {
		return false
	}

	in := ParseArgs(e.Args)
	if in.Payload != nil {
		return true
	}

	// report the undeliverable payload
	out := &A{Name: in.Name}
	c.Mach.AddErrState(ssC.ErrDelivery, errors.New("invalid payload"), Pass(out))

	return false
}
// WorkerPayloadState is the handler of the WorkerPayload state. It forwards
// the payload and its name to the consumer's WorkerPayload state.
func (c *Client) WorkerPayloadState(e *am.Event) {
	in := ParseArgs(e.Args)

	forwarded := Pass(&A{Name: in.Name, Payload: in.Payload})
	c.Consumer.EvAdd1(e, ssCo.WorkerPayload, forwarded)
}
// HealthcheckState is the handler of the Healthcheck state. It requests
// removal of the state and makes sure one of the connection group states is
// (or will be) active.
func (c *Client) HealthcheckState(e *am.Event) {
	const state = ssC.Healthcheck
	c.Mach.EvRemove1(e, state, nil)

	c.ensureGroupConnected(e)
}
// Start adds the Start and Connecting states to the client's machine and
// returns the result of that mutation.
func (c *Client) Start() am.Result {
	initial := am.S{ssC.Start, ssC.Connecting}

	return c.Mach.Add(initial, nil)
}
// Stop removes the Start state from the client's machine and returns the
// result of that mutation.
//
// When waitTillExit is non-nil and the mutation was not canceled, Stop waits
// for the Disconnected state (up to 2 seconds, or until waitTillExit ends).
// When dispose is true, the client's machine and the worker are disposed. The
// worker exists only after the Start state has been activated, so it is
// skipped when nil.
func (c *Client) Stop(waitTillExit context.Context, dispose bool) am.Result {
	res := c.Mach.Remove1(ssC.Start, nil)

	// wait for the disconnect, the result of the wait is ignored
	if res != am.Canceled && waitTillExit != nil {
		_ = amhelp.WaitForAll(waitTillExit, 2*time.Second,
			c.Mach.When1(ssC.Disconnected, nil))
	}

	if dispose {
		c.log("disposing")
		c.Mach.Dispose()
		// Stop can be called on a client which was never started
		if c.Worker != nil {
			c.Worker.Dispose()
		}
	}

	return res
}
// GetKind returns KindClient.
func (c *Client) GetKind() Kind {
	var kind Kind = KindClient

	return kind
}
// ensureGroupConnected adds Disconnected when none of the states from the
// Connected group is active and none is about to be.
func (c *Client) ensureGroupConnected(e *am.Event) {
	group := states.ClientGroups.Connected

	if c.Mach.Any1(group...) || c.Mach.WillBe(group) {
		return
	}

	c.Mach.EvAdd1(e, ssC.Disconnected, nil)
}
// log writes a message to the machine's log, but only when LogEnabled is set.
func (c *Client) log(msg string, args ...any) {
	if c.LogEnabled {
		c.Mach.Log(msg, args...)
	}
}
// bindRpcHandlers creates a new rpc2 client on top of conn, stores it in c.rpc
// and registers the client-side handlers for the methods called by the server.
// The rpc2 client is switched to blocking mode.
func (c *Client) bindRpcHandlers(conn net.Conn) {
	c.log("new rpc2 client")
	client := rpc2.NewClient(conn)
	c.rpc = client

	// method name -> handler, registered in order
	handlers := []struct {
		method string
		fn     any
	}{
		{ClientSetClock.Value, c.RemoteSetClock},
		{ClientPushAllTicks.Value, c.RemotePushAllTicks},
		{ClientSendPayload.Value, c.RemoteSendPayload},
		{ClientBye.Value, c.RemoteBye},
		{ClientSchemaChange.Value, c.RemoteSchemaChange},
	}
	for _, h := range handlers {
		client.Handle(h.method, h.fn)
	}

	client.SetBlocking(true)
}
// updateClock applies a clock update to the worker. With a non-nil msg the new
// clock is computed from the worker's current time and the diff message, and
// verified with a checksum. With a nil msg, fullTime and fullQTick are used
// as-is. Updates are ignored until HandshakeDone is active.
func (c *Client ) updateClock (
msg *ClockMsg , fullTime am .Time , fullQTick uint64 ,
) {
if c .Mach .Not1 (ssC .HandshakeDone ) {
return
}
// the lock is released manually on every early return below
c .Worker .clockMx .Lock ()
var clock am .Time
var qTick uint64
if msg != nil {
// diff update
clock , qTick = ClockFromMsg (c .Worker .machTime , c .Worker .queueTick , msg )
check := Checksum (clock .Sum (nil ), qTick )
// a checksum mismatch drops the update
if check != msg .Checksum {
c .Mach .Log ("updateClock mismatch %d != %d" , msg .Checksum , check )
c .log ("msg q%d ch%d %+v" , msg .QueueTick , msg .Checksum , msg .Updates )
c .log ("clock t%d q%d ch%d (%+v)" , clock .Sum (nil ), qTick , check , clock )
c .Worker .clockMx .Unlock ()
return
}
} else {
// full update
clock = fullTime
qTick = fullQTick
}
if clock == nil {
c .Worker .clockMx .Unlock ()
return
}
// sum of all the ticks, for logging only
var sum uint64
for _ , v := range clock {
sum += v
}
if msg != nil {
c .log ("updateClock diff OK t%d q%d" , sum , qTick )
} else {
c .log ("updateClock full OK t%d q%d" , sum , fullQTick )
}
// NOTE(review): clockMx is still held here; presumably InternalUpdateClock
// releases it (the last param being false) - confirm.
c .Worker .InternalUpdateClock (clock , qTick , false )
}
// callFailsafe performs an RPC call and retries it on failure, up to
// CallRetries times or until CallRetryTimeout elapses. Calls are serialized
// with callLock. While retrying, RetryingCall is active; it is removed after a
// successful retry, otherwise CallRetryFailed is added. Returns true when the
// call succeeded.
func (c *Client ) callFailsafe (
ctx context .Context , method string , args , resp any ,
) bool {
mName := ServerMethods .Parse (method ).Value
// no connection yet
if c .rpc == nil {
AddErrNoConn (nil , c .Mach , errors .New (mName ))
return false
}
c .callLock .Lock ()
defer c .callLock .Unlock ()
// first attempt
if c .call (ctx , method , args , resp , 0 ) {
return true
}
start := time .Now ()
worked := false
delay := c .CallRetryDelay
c .Mach .Add1 (ssC .RetryingCall , Pass (&A {
Method : mName ,
StartedAt : start ,
}))
// settle the retry states on every exit path
defer func () {
if worked {
c .Mach .Remove1 (ssC .RetryingCall , nil )
} else {
c .Mach .Add1 (ssC .CallRetryFailed , Pass (&A {Method : mName }))
}
}()
for i := 0 ; i < c .CallRetries ; i ++ {
// wait before the retry
if !amhelp .Wait (ctx , delay ) {
return false
}
// wait for the Ready state or the end of ctx
<-c .Mach .When1 (ssC .Ready , ctx )
if ctx .Err () != nil {
return false
}
if c .call (ctx , method , args , resp , 0 ) {
worked = true
return true
}
// double the delay, capped at CallRetryBackoff
if c .CallRetryBackoff > 0 {
delay *= 2
if delay > c .CallRetryBackoff {
delay = c .CallRetryBackoff
}
}
// total time budget
if c .CallRetryTimeout > 0 && time .Since (start ) > c .CallRetryTimeout {
break
}
}
return false
}
// call performs a single RPC call of method with args, decoding the result
// into resp. A zero timeout means CallTimeout. Returns true on success.
//
// On failure an error is added to the machine, except when the parent ctx has
// expired, which returns false silently. A call which exceeded its timeout
// adds the ErrNetworkTimeout state. Panics are converted into machine errors.
func (c *Client) call(
	ctx context.Context, method string, args, resp any, timeout time.Duration,
) bool {
	defer c.Mach.PanicToErr(nil)
	mName := ServerMethods.Parse(method).Value

	c.CallCount++
	if timeout == 0 {
		timeout = c.CallTimeout
	}

	// the call has to run under callCtx, otherwise the timeout is never applied
	callCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()
	err := c.rpc.CallWithContext(callCtx, method, args, resp)

	// parent context expired, not an error
	if ctx.Err() != nil {
		return false
	}
	// only the call's own context expired, i.e. a timeout
	if callCtx.Err() != nil {
		c.Mach.AddErrState(ssC.ErrNetworkTimeout, callCtx.Err(), nil)
		return false
	}

	// one-time test error
	if c.tmpTestErr != nil {
		AddErrNetwork(nil, c.Mach, fmt.Errorf("%w: %s", c.tmpTestErr, mName))
		c.tmpTestErr = nil
		return false
	}
	// permanent test error
	if c.permTestErr != nil {
		AddErrNetwork(nil, c.Mach, fmt.Errorf("%w: %s", c.permTestErr, mName))
		return false
	}

	if err != nil {
		AddErr(nil, c.Mach, mName, err)
		return false
	}

	return true
}
// notifyFailsafe sends an RPC notification and retries it on failure, up to
// CallRetries times or until CallRetryTimeout elapses. Notifications are
// serialized with callLock. While retrying, RetryingCall is active; it is
// removed after a successful retry, otherwise CallRetryFailed is added.
// Returns true when the notification was sent.
func (c *Client) notifyFailsafe(
	ctx context.Context, method string, args any,
) bool {
	mName := ServerMethods.Parse(method).Value

	// no connection yet
	if c.rpc == nil {
		AddErrNoConn(nil, c.Mach, errors.New(mName))
		return false
	}

	c.callLock.Lock()
	defer c.callLock.Unlock()

	// first attempt
	if c.notify(ctx, method, args) {
		return true
	}

	// failed, start retrying
	start := time.Now()
	worked := false
	delay := c.CallRetryDelay
	c.Mach.Add1(ssC.RetryingCall, Pass(&A{
		Method:    mName,
		StartedAt: start,
	}))

	// settle the retry states on every exit path
	defer func() {
		if worked {
			c.Mach.Remove1(ssC.RetryingCall, nil)
		} else {
			c.Mach.Add1(ssC.CallRetryFailed, Pass(&A{Method: mName}))
		}
	}()

	for i := 0; i < c.CallRetries; i++ {
		// wait before the retry, but stop as soon as ctx ends
		if !amhelp.Wait(ctx, delay) {
			return false
		}

		if c.notify(ctx, method, args) {
			// mark the success, so the deferred func does not report a failure
			worked = true
			return true
		}

		// double the delay, capped at CallRetryBackoff
		if c.CallRetryBackoff > 0 {
			delay *= 2
			if delay > c.CallRetryBackoff {
				delay = c.CallRetryBackoff
			}
		}

		// total time budget
		if c.CallRetryTimeout > 0 && time.Since(start) > c.CallRetryTimeout {
			break
		}
	}

	return false
}
// notify sends a single RPC notification of method with args. The connection
// gets a deadline of CallTimeout for the duration of the send, which is
// cleared afterwards. Returns true on success; failures add an error to the
// machine, except when ctx has expired. Panics are converted into machine
// errors.
func (c *Client) notify(
	ctx context.Context, method string, args any,
) bool {
	defer c.Mach.PanicToErr(nil)
	mName := ServerMethods.Parse(method).Value

	// limit the send with a deadline
	deadline := time.Now().Add(c.CallTimeout)
	if err := c.conn.SetDeadline(deadline); err != nil {
		AddErr(nil, c.Mach, mName, err)
		return false
	}

	c.CallCount++
	sendErr := c.rpc.Notify(method, args)
	switch {
	case ctx.Err() != nil:
		// expired, not an error
		return false
	case sendErr != nil:
		AddErr(nil, c.Mach, method, sendErr)
		return false
	}

	// remove the deadline
	if err := c.conn.SetDeadline(time.Time{}); err != nil {
		AddErr(nil, c.Mach, mName, err)
		return false
	}

	return true
}
// RemoteSetClock is the RPC handler for clock diff updates sent by the server.
// A nil message adds a params error; otherwise the diff is applied via
// updateClock. It always returns nil.
func (c *Client) RemoteSetClock(
	_ *rpc2.Client, clock *ClockMsg, _ *Empty,
) error {
	if clock != nil {
		c.updateClock(clock, nil, 0)
		return nil
	}

	// invalid params
	AddErrParams(nil, c.Mach, nil)

	return nil
}
// RemotePushAllTicks is the RPC handler for ClientPushAllTicks. The received
// clocks are currently ignored and the handler always returns nil.
func (c *Client) RemotePushAllTicks(
	_ *rpc2.Client, clocks []PushAllTicks, _ *Empty,
) error {
	// no-op
	return nil
}
// RemoteSendingPayload adds the WorkerDelivering state with the payload and
// its name. It always returns nil.
func (c *Client) RemoteSendingPayload(
	_ *rpc2.Client, payload *ArgsPayload, _ *Empty,
) error {
	c.log("RemoteSendingPayload %s", payload.Name)

	args := &A{Name: payload.Name, Payload: payload}
	c.Mach.Add1(ssC.WorkerDelivering, Pass(args))

	return nil
}
// RemoteSendPayload is the RPC handler for payloads sent by the server. It
// adds the WorkerPayload state with the payload and its name. It always
// returns nil.
func (c *Client) RemoteSendPayload(
	_ *rpc2.Client, payload *ArgsPayload, _ *Empty,
) error {
	c.log("RemoteSendPayload %s:%s", payload.Name, payload.Token)

	args := &A{Name: payload.Name, Payload: payload}
	c.Mach.Add1(ssC.WorkerPayload, Pass(args))

	return nil
}
// RemoteBye is the RPC handler for the server's goodbye. It removes the Start
// state from the client's machine and always returns nil.
func (c *Client) RemoteBye(
	_ *rpc2.Client, _ *Empty, _ *Empty,
) error {
	const state = ssC.Start
	c.Mach.Remove1(state, nil)

	return nil
}
// RemoteSchemaChange is the RPC handler for schema changes pushed by the
// server. It logs the number of states of the new schema and applies the
// schema via updateSchema. It always returns nil.
func (c *Client) RemoteSchemaChange(
	_ *rpc2.Client, msg *RespHandshake, _ *Empty,
) error {
	stateCount := len(msg.Serialized.StateNames)
	c.log("new schema v" + strconv.Itoa(stateCount))

	c.updateSchema(msg)

	return nil
}
// ClientOpts are the optional settings of NewClient.
type ClientOpts struct {
// Consumer is an optional machine which has to implement the consumer
// states; it receives WorkerPayload.
Consumer *am .Machine
// Parent is passed to am.NewCommon when creating the client's machine.
Parent am .Api
}
// GetClientId returns the machine ID of an RPC client with the given name,
// which is the name prefixed with "rc-".
func GetClientId(name string) string {
	const prefix = "rc-"

	return prefix + name
}
The pages are generated with Golds v0.8.2 . (GOOS=linux GOARCH=amd64)
Golds is a Go 101 project developed by Tapir Liu .
PR and bug reports are welcome and can be submitted to the issue list .
Please follow @zigo_101 (reachable from the left QR code) to get the latest news of Golds .