package rpc
import (
"bufio"
"encoding/gob"
"errors"
"fmt"
"io"
"net"
"os"
"path/filepath"
"reflect"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/cenkalti/rpc2"
"github.com/orsinium-labs/enum"
amhelp "github.com/pancsta/asyncmachine-go/pkg/helpers"
am "github.com/pancsta/asyncmachine-go/pkg/machine"
"github.com/pancsta/asyncmachine-go/pkg/rpc/states"
)
// init registers the concrete types that are sent inside gob-encoded
// interface values (e.g. *ARpc stored in an am.A map by PassRpc), so that
// the decoding side can resolve them.
func init() {
	for _, v := range []any{&ARpc{}, am.Relation(0)} {
		gob.Register(v)
	}
}
// Names of environment variables used by the RPC package. AM_REPL_ADDR and
// AM_REPL_DIR are read by MachReplEnv; the others presumably toggle logging
// and debugging of the server, client and mux — TODO confirm at their use.
const (
	EnvAmRpcLogServer = "AM_RPC_LOG_SERVER"
	EnvAmRpcLogClient = "AM_RPC_LOG_CLIENT"
	EnvAmRpcLogMux    = "AM_RPC_LOG_MUX"
	EnvAmRpcDbg       = "AM_RPC_DBG"
	EnvAmReplAddr     = "AM_REPL_ADDR"
	EnvAmReplDir      = "AM_REPL_DIR"
)

// PrefixNetMach is presumably the ID prefix of network machines — TODO
// confirm at its use.
const PrefixNetMach = "rnm-"
// ss is a package-local shorthand for states.SharedStates, the state names
// shared by RPC machines (e.g. ss.ErrRpc and ss.ErrNetwork used by the
// error helpers in this file).
var ss = states.SharedStates
type (
	// ServerMethod names an RPC method handled by the server. Value is the
	// name used on the wire.
	ServerMethod enum.Member[string]
	// ClientMethod names an RPC method handled by the client and called by
	// the server (e.g. ClientSchemaChange, called from
	// sourceTracer.SchemaChange). Value is the name used on the wire.
	ClientMethod enum.Member[string]
)

var (
	ServerAdd       = ServerMethod{"Add"}
	ServerAddNS     = ServerMethod{"AddNS"}
	ServerRemove    = ServerMethod{"Remove"}
	ServerSet       = ServerMethod{"Set"}
	ServerHello     = ServerMethod{"Hello"}
	ServerHandshake = ServerMethod{"Handshake"}
	ServerLog       = ServerMethod{"Log"}
	ServerSync      = ServerMethod{"Sync"}
	ServerArgs      = ServerMethod{"Args"}
	// ServerBye goes over the wire as "Close".
	ServerBye = ServerMethod{"Close"}

	// ServerMethods enumerates every ServerMethod declared above.
	ServerMethods = enum.New(ServerAdd, ServerAddNS, ServerRemove, ServerSet,
		ServerHello, ServerHandshake, ServerLog, ServerSync, ServerArgs, ServerBye)

	ClientUpdate          = ClientMethod{"ClientSetClock"}
	ClientUpdateMutations = ClientMethod{"ClientSetClockMany"}
	ClientPushAllTicks    = ClientMethod{"ClientPushAllTicks"}
	ClientSendPayload     = ClientMethod{"ClientSendPayload"}
	ClientBye             = ClientMethod{"ClientBye"}
	ClientSchemaChange    = ClientMethod{"SchemaChange"}

	// ClientMethods enumerates every ClientMethod declared above, in
	// declaration order (including ClientUpdateMutations).
	ClientMethods = enum.New(ClientUpdate, ClientUpdateMutations,
		ClientPushAllTicks, ClientSendPayload, ClientBye, ClientSchemaChange)
)
// Wire messages exchanged between the RPC client and server. They are
// gob-encoded, and gob matches fields by name, so field names are part of
// the protocol.

// MsgCliHello is the client's hello, carrying its sync options.
//
// NOTE(review): the comments below assume these options end up in the
// Server's sync* fields read by sourceTracer (syncSchema, syncMutations,
// syncAllowedStates, syncSkippedStates, syncShallowClocks) — confirm in the
// server code.
type MsgCliHello struct {
	// Id identifies the client.
	Id string
	// SyncSchema: when false, exported state names and clocks are filtered
	// down to the tracked states; when true, the full schema is kept and
	// untracked clocks are zeroed.
	SyncSchema bool
	// SchemaHash is presumably the hash of the schema known to the client —
	// TODO confirm.
	SchemaHash string
	// SyncMutations enables queueing of per-mutation updates.
	SyncMutations bool
	// AllowedStates limits the tracked states (nil means no limit).
	AllowedStates am.S
	// SkippedStates are excluded from the tracked states.
	SkippedStates am.S
	// ShallowClocks reduces pushed clocks to the active states only.
	ShallowClocks bool
}

// MsgSrvHello is the server's hello reply. It is also sent to the client on
// a schema change (see sourceTracer.SchemaChange).
type MsgSrvHello struct {
	Schema am.Schema
	// Serialized is the exported source machine, possibly filtered to the
	// tracked states.
	Serialized *am.Serialized
	// StatesCount is the number of states before any filtering.
	StatesCount uint32
}

// MsgCliMutation is a mutation request from the client (the args of
// RemoteAdd/RemoteRemove/RemoteSet).
type MsgCliMutation struct {
	// States are presumably state indexes in the server's schema — TODO
	// confirm.
	States []int
	Args   am.A
	Event  *am.Event
}

// MsgSrvMutation is the server's reply to a MsgCliMutation.
type MsgSrvMutation struct {
	// Update and Mutations presumably carry the resulting clock changes,
	// with at most one of them set — TODO confirm.
	Update    *MsgSrvUpdate
	Mutations *MsgSrvUpdateMuts
	Result    am.Result
}

// MsgSrvPayload is a named payload delivered to the client.
type MsgSrvPayload struct {
	Name string
	// Source and SourceTx presumably identify the sending machine and
	// transition — TODO confirm.
	Source      string
	SourceTx    string
	Destination string
	Data        any
	Token       string
}

// MsgSrvSync is a full clock snapshot of the source machine.
type MsgSrvSync struct {
	Time      am.Time
	QueueTick uint64
	MachTick  uint32
}

// MsgSrvArgs lists argument names; presumably the reply to ServerArgs.
type MsgSrvArgs struct {
	Args []string
}

// MsgEmpty is used where an RPC call carries no data (e.g. the reply of
// ClientSchemaChange).
type MsgEmpty struct{}

// MsgSrvUpdate is an incremental clock update.
//
// NOTE(review): Indexes and Ticks look like parallel slices (changed state
// index -> tick change), and the narrow QueueTick/MachTick types suggest
// deltas rather than absolute values — confirm in the encoder.
type MsgSrvUpdate struct {
	Indexes   []uint16
	Ticks     []uint32
	QueueTick uint16
	MachTick  uint8
	// Checksum has the same type as the result of Checksum().
	Checksum uint8
}

// MsgSrvUpdateMuts is a batch of per-mutation updates.
//
// NOTE(review): presumably three parallel slices, one entry per mutation
// (mirroring tracerMutation: type, called state indexes, clock update) —
// confirm in the encoder.
type MsgSrvUpdateMuts struct {
	MutationType []am.MutationType
	CalledStates [][]uint16
	Updates      []MsgSrvUpdate
}
// Kind tells which side of an RPC connection a component is on.
type Kind string

// The two Kind values.
const (
	KindClient Kind = "client"
	KindServer Kind = "server"
)

// clientServerMethods is implemented by RPC components that can report
// their Kind.
type clientServerMethods interface {
	GetKind() Kind
}
// ReplOpts configures MachRepl.
type ReplOpts struct {
	// AddrDir, when set, is created if missing, and the REPL address is
	// written to "<AddrDir>/<machine id>.addr".
	AddrDir string
	// ErrCh receives errors from creating AddrDir and writing the address
	// file (nil after a successful write). MachRepl closes it when done,
	// including when there is nothing to publish.
	ErrCh chan<- error
	// AddrCh receives the listening address once the REPL mux is Ready, and
	// is then closed by MachRepl.
	AddrCh chan<- string
	// ArgsPrefix is passed to the REPL mux as MuxOpts.ArgsPrefix.
	ArgsPrefix string
	// Args is passed to the REPL mux as MuxOpts.Args. When set, it must be
	// a struct value (a pointer to a struct is rejected).
	Args any
}
// APrefix is the am.A key under which RPC arguments are stored (see Pass,
// PassRpc and ParseArgs).
const APrefix = "am_rpc"

// A holds RPC-related mutation arguments, stored in am.A under APrefix.
// Fields with a `log` tag are meant to show up in logs: LogArgs passes A to
// amhelp.ArgsToLogMap, which presumably reads these tags, so each tag must
// be a distinct key.
//
// A is meant for in-process use, since it carries the live *rpc2.Client.
// ARpc is the gob-registered variant without Client (see PassRpc).
type A struct {
	Id        string `log:"id"`
	Name      string `log:"name"`
	MachTime  am.Time
	QueueTick uint64
	MachTick  uint32
	Payload   *MsgSrvPayload
	Addr      string `log:"addr"`
	Err       error
	// Method has its own log key, so it does not overwrite Addr.
	Method    string `log:"method"`
	StartedAt time.Time
	Dispose   bool
	Client    *rpc2.Client
}
// ARpc is the serializable counterpart of A: the same fields without the
// live Client. It is registered with gob in init so it can travel inside
// am.A over RPC. PassRpc converts A to ARpc, and ParseArgs converts it back.
//
// NOTE(review): Err is an interface field; gob can only encode it if its
// dynamic type is registered and has exported fields (errors.New values do
// not), so a non-nil Err presumably fails to encode — confirm.
type ARpc struct {
	Id        string `log:"id"`
	Name      string `log:"name"`
	MachTime  am.Time
	QueueTick uint64
	MachTick  uint32
	Payload   *MsgSrvPayload
	Addr      string `log:"addr"`
	Err       error
	// Method has its own log key, so it does not overwrite Addr.
	Method    string `log:"method"`
	StartedAt time.Time
	Dispose   bool
}
// ParseArgs extracts the RPC arguments stored under APrefix in args.
//
// An *ARpc (received over the wire) is converted into a new *A with
// amhelp.ArgsToArgs, and an *A is returned as-is. ParseArgs never returns
// nil: when nothing usable is stored, it returns an empty &A{}.
func ParseArgs(args am.A) *A {
	switch stored := args[APrefix].(type) {
	case *ARpc:
		if stored != nil {
			return amhelp.ArgsToArgs(stored, &A{})
		}
	case *A:
		if stored != nil {
			return stored
		}
	}
	return &A{}
}
// Pass packs args into a new am.A under APrefix, for in-process use. The
// pointer is stored as-is, so the receiver shares the same A.
func Pass(args *A) am.A {
	packed := make(am.A, 1)
	packed[APrefix] = args
	return packed
}
// PassRpc packs args into a new am.A under APrefix, converted to an *ARpc
// (no Client) so the result can be gob-encoded and sent over RPC.
func PassRpc(args *A) am.A {
	serializable := amhelp.ArgsToArgs(args, &ARpc{})
	return am.A{APrefix: serializable}
}
// LogArgs returns the loggable RPC arguments from args as a string map,
// built by amhelp.ArgsToLogMap from the `log` struct tags of A.
func LogArgs(args am.A) map[string]string {
	if parsed := ParseArgs(args); parsed != nil {
		return amhelp.ArgsToLogMap(parsed, 0)
	}
	return nil
}
// serverRpcMethods lists RPC handlers the server implements. Each takes the
// calling rpc2 client, the decoded args and a reply to fill, and returns an
// error — presumably the handler shape required by rpc2; confirm against
// rpc2's Handle docs.
type serverRpcMethods interface {
	// RemoteHello answers a client's MsgCliHello with a MsgSrvHello.
	RemoteHello(client *rpc2.Client, args *MsgCliHello, resp *MsgSrvHello) error
	// RemoteAdd, RemoteRemove and RemoteSet apply the mutation described by
	// MsgCliMutation and report the outcome in MsgSrvMutation.
	RemoteAdd(
		client *rpc2.Client, args *MsgCliMutation, resp *MsgSrvMutation) error
	RemoteRemove(
		client *rpc2.Client, args *MsgCliMutation, resp *MsgSrvMutation) error
	RemoteSet(
		client *rpc2.Client, args *MsgCliMutation, reply *MsgSrvMutation) error
}

// clientRpcMethods lists RPC handlers the client implements, called by the
// server (the worker parameter). Replies carry no data (MsgEmpty).
type clientRpcMethods interface {
	// RemoteUpdate applies a single incremental clock update.
	RemoteUpdate(worker *rpc2.Client, args *MsgSrvUpdate, resp *MsgEmpty) error
	// RemoteUpdateMutations applies a batch of per-mutation updates.
	RemoteUpdateMutations(
		worker *rpc2.Client, args *MsgSrvUpdateMuts, resp *MsgEmpty) error
	// RemoteSendingPayload presumably announces a payload before it is sent
	// with RemoteSendPayload — TODO confirm in the client code.
	RemoteSendingPayload(
		worker *rpc2.Client, file *MsgSrvPayload, resp *MsgEmpty) error
	// RemoteSendPayload delivers a payload to the client.
	RemoteSendPayload(
		worker *rpc2.Client, file *MsgSrvPayload, resp *MsgEmpty) error
}
// Sentinel errors of the RPC package. Helpers in this file wrap them with
// %w, so test for them with errors.Is.
var (
	// malformed requests and replies
	ErrInvalidParams = errors.New("invalid params")
	ErrInvalidResp   = errors.New("invalid response")

	// RPC-level failures
	ErrRpc         = errors.New("rpc")
	ErrNoAccess    = errors.New("no access")
	ErrDestination = errors.New("wrong destination")

	// connectivity
	ErrNoConn         = errors.New("not connected")
	ErrNetwork        = errors.New("network error")
	ErrNetworkTimeout = errors.New("network timeout")
)
// AddErrRpcStr adds the ss.ErrRpc state to mach. The error is ErrRpc with
// msg appended, and it is passed along with the event e.
func AddErrRpcStr(e *am.Event, mach *am.Machine, msg string) {
	mach.EvAddErrState(e, ss.ErrRpc, fmt.Errorf("%w: %s", ErrRpc, msg), nil)
}
// AddErrParams adds ss.ErrRpc to mach with err wrapped in ErrInvalidParams.
// Like AddErrRpcStr, it passes the (possibly nil) event e to EvAddErrState
// instead of dropping it.
func AddErrParams(e *am.Event, mach *am.Machine, err error) {
	err = fmt.Errorf("%w: %w", ErrInvalidParams, err)
	mach.EvAddErrState(e, ss.ErrRpc, err, nil)
}

// AddErrResp adds ss.ErrRpc to mach with err wrapped in ErrInvalidResp,
// passing the event e along.
func AddErrResp(e *am.Event, mach *am.Machine, err error) {
	err = fmt.Errorf("%w: %w", ErrInvalidResp, err)
	mach.EvAddErrState(e, ss.ErrRpc, err, nil)
}

// AddErrNetwork adds ss.ErrNetwork to mach with err as-is, passing the
// event e along.
func AddErrNetwork(e *am.Event, mach *am.Machine, err error) {
	mach.EvAddErrState(e, ss.ErrNetwork, err, nil)
}

// AddErrNoConn adds ss.ErrNetwork to mach with err wrapped in ErrNoConn,
// passing the event e along.
func AddErrNoConn(e *am.Event, mach *am.Machine, err error) {
	err = fmt.Errorf("%w: %w", ErrNoConn, err)
	mach.EvAddErrState(e, ss.ErrNetwork, err, nil)
}
// AddErr classifies an RPC-related error and adds the matching error state
// to mach. If msg is non-empty, it is appended to err first. A nil err is
// ignored.
//
// The classification is done on the error text, in order:
//   - "gob: " prefix: decoding problem -> AddErrResp (ss.ErrRpc)
//   - rpc2 "can't find method": ss.ErrRpc via AddErrRpcStr
//   - "connection is shut down" / "unexpected EOF": ss.ErrRpc
//   - "timeout": ss.ErrNetwork, joined with ErrNetworkTimeout
//   - a *net.OpError anywhere in the chain: ss.ErrNetwork
//   - anything else: the machine's generic error state via mach.AddErr
func AddErr(e *am.Event, mach *am.Machine, msg string, err error) {
	if err == nil {
		// nothing to report (and err.Error() below would panic)
		return
	}
	if msg != "" {
		err = fmt.Errorf("%w: %s", err, msg)
	}

	errStr := err.Error()
	// errors.As also matches a wrapped *net.OpError, e.g. after msg was
	// appended above
	var opErr *net.OpError
	switch {
	case strings.HasPrefix(errStr, "gob: "):
		AddErrResp(e, mach, err)
	case strings.Contains(errStr, "rpc2: can't find method"):
		AddErrRpcStr(e, mach, errStr)
	case strings.Contains(errStr, "connection is shut down") ||
		strings.Contains(errStr, "unexpected EOF"):
		mach.EvAddErrState(e, ss.ErrRpc, err, nil)
	case strings.Contains(errStr, "timeout"):
		AddErrNetwork(e, mach, errors.Join(err, ErrNetworkTimeout))
	case errors.As(err, &opErr):
		AddErrNetwork(e, mach, err)
	default:
		mach.AddErr(err, nil)
	}
}
// ExceptionHandler embeds am.ExceptionHandler and adds an ExceptionEnter
// handler that filters out network errors on disconnecting RPC clients.
type ExceptionHandler struct {
	*am.ExceptionHandler
}

// ExceptionEnter returns false (after logging) when the RPC args of e carry
// an ErrNetwork, the machine has the ssC Disconnecting/Disconnected states,
// and one of them is active. Otherwise it returns true. A false return
// presumably rejects the Exception activation, per asyncmachine's Enter
// handler convention.
func (h *ExceptionHandler) ExceptionEnter(e *am.Event) bool {
	args := ParseArgs(e.Args)
	mach := e.Machine()
	// presumably only RPC clients define these states in their schema
	isRpcClient := mach.Has(am.S{ssC.Disconnecting, ssC.Disconnected})

	ignore := errors.Is(args.Err, ErrNetwork) && isRpcClient &&
		mach.Any1(ssC.Disconnecting, ssC.Disconnected)
	if !ignore {
		return true
	}
	mach.Log("ignoring ErrNetwork on Disconnecting/Disconnected")
	return false
}
// semLogger implements am.SemLogger for a NetworkMachine. The logger
// function and log level are stored on the NetworkMachine itself (its
// logger and logLevel fields). steps and graph are local flags. The other
// options are fixed: their Enable* setters do nothing and their Is* getters
// return constants.
type semLogger struct {
	mach  *NetworkMachine
	steps atomic.Bool
	graph atomic.Bool
}

// compile-time check that semLogger implements am.SemLogger
var _ am.SemLogger = &semLogger{}
// SetArgsMapper ignores its argument: custom args mappers are not supported.
func (s *semLogger) SetArgsMapper(am.LogArgsMapperFn) {}

// ArgsMapper always returns nil.
func (s *semLogger) ArgsMapper() am.LogArgsMapperFn { return nil }

// EnableId ignores its argument.
func (s *semLogger) EnableId(bool) {}

// IsId always returns false.
func (s *semLogger) IsId() bool { return false }
// SetLogger stores fn as the logger of the network machine. A nil fn clears
// the stored logger.
func (s *semLogger) SetLogger(fn am.LoggerFn) {
	var stored *am.LoggerFn
	if fn != nil {
		stored = &fn
	}
	s.mach.logger.Store(stored)
}

// Logger returns the stored logger, or nil when none is set.
func (s *semLogger) Logger() am.LoggerFn {
	ptr := s.mach.logger.Load()
	if ptr == nil {
		return nil
	}
	return *ptr
}
// SetLevel stores lvl as the log level of the network machine.
func (s *semLogger) SetLevel(lvl am.LogLevel) {
	s.mach.logLevel.Store(&lvl)
}

// Level returns the stored log level. If no level has been stored yet, it
// returns am.LogNothing instead of dereferencing a nil pointer.
func (s *semLogger) Level() am.LogLevel {
	if lvl := s.mach.logLevel.Load(); lvl != nil {
		return *lvl
	}
	return am.LogNothing
}
// SetEmpty installs a logger that discards every message, and stores lvl as
// the log level.
func (s *semLogger) SetEmpty(lvl am.LogLevel) {
	discard := am.LoggerFn(func(_ am.LogLevel, msg string, args ...any) {})
	s.mach.logger.Store(&discard)
	s.mach.logLevel.Store(&lvl)
}

// SetSimple installs a logger that forwards the message and its args to
// the printf-style logf (ignoring the per-message level), and stores level
// as the log level.
func (s *semLogger) SetSimple(
	logf func(format string, args ...any), level am.LogLevel,
) {
	forward := am.LoggerFn(func(_ am.LogLevel, msg string, args ...any) {
		logf(msg, args...)
	})
	s.mach.logger.Store(&forward)
	s.mach.logLevel.Store(&level)
}
// AddPipeOut logs, at am.LogOps, an outgoing pipe of sourceState to
// targetMach, as either an "add" or a "remove" mutation.
func (s *semLogger) AddPipeOut(addMut bool, sourceState, targetMach string) {
	var mutKind string
	if addMut {
		mutKind = "add"
	} else {
		mutKind = "remove"
	}
	s.mach.log(am.LogOps, "[pipe-out:%s] %s to %s", mutKind, sourceState,
		targetMach)
}

// AddPipeIn logs, at am.LogOps, an incoming pipe into targetState from
// sourceMach, as either an "add" or a "remove" mutation.
func (s *semLogger) AddPipeIn(addMut bool, targetState, sourceMach string) {
	var mutKind string
	if addMut {
		mutKind = "add"
	} else {
		mutKind = "remove"
	}
	s.mach.log(am.LogOps, "[pipe-in:%s] %s from %s", mutKind, targetState,
		sourceMach)
}

// RemovePipes logs, at am.LogOps, the removal of pipes of machId.
func (s *semLogger) RemovePipes(machId string) {
	s.mach.log(am.LogOps, "[pipe:gc] %s", machId)
}
// IsSteps returns the steps flag set by EnableSteps.
func (s *semLogger) IsSteps() bool { return s.steps.Load() }

// EnableSteps sets the steps flag.
func (s *semLogger) EnableSteps(enable bool) { s.steps.Store(enable) }

// IsGraph returns the graph flag set by EnableGraph.
func (s *semLogger) IsGraph() bool { return s.graph.Load() }

// EnableGraph sets the graph flag.
func (s *semLogger) EnableGraph(enable bool) { s.graph.Store(enable) }

// The options below are fixed for network machines: the setters ignore
// their argument and the getters always return true.

// EnableStateCtx ignores its argument.
func (s *semLogger) EnableStateCtx(bool) {}

// IsStateCtx always returns true.
func (s *semLogger) IsStateCtx() bool { return true }

// EnableWhen ignores its argument.
func (s *semLogger) EnableWhen(bool) {}

// IsWhen always returns true.
func (s *semLogger) IsWhen() bool { return true }

// EnableArgs ignores its argument.
func (s *semLogger) EnableArgs(bool) {}

// IsArgs always returns true.
func (s *semLogger) IsArgs() bool { return true }

// EnableQueued ignores its argument.
func (s *semLogger) EnableQueued(bool) {}

// IsQueued always returns true.
func (s *semLogger) IsQueued() bool { return true }

// EnableCan ignores its argument.
func (s *semLogger) EnableCan(bool) {}

// IsCan always returns true.
func (s *semLogger) IsCan() bool { return true }
// handler wraps a handler object with per-name method lookup caches.
type handler struct {
	// h is the wrapped handler object.
	h any
	// name labels the handler.
	name string
	// mx presumably guards the caches below — TODO confirm at their use.
	mx sync.Mutex
	// methods is the reflected value whose methods are looked up.
	methods *reflect.Value
	// methodCache holds methods already resolved by name.
	methodCache map[string]reflect.Value
	// missingCache holds names already known not to resolve.
	missingCache map[string]struct{}
}

// newHandler wraps handlers under name, with methods as the reflected
// lookup target and empty (non-nil) caches.
func newHandler(
	handlers any, name string, methods *reflect.Value,
) *handler {
	hd := &handler{h: handlers, name: name, methods: methods}
	hd.methodCache = map[string]reflect.Value{}
	hd.missingCache = map[string]struct{}{}
	return hd
}
// tracerData is a snapshot of the source machine taken by
// sourceTracer.TransitionEnd.
type tracerData struct {
	// mTrackedTimeSum is the sum of the tracked states' clocks (of the
	// shallow clock when shallow clocks are synced).
	mTrackedTimeSum uint64
	// mTime is the clock; it is filtered to the tracked states unless the
	// schema is synced.
	mTime     am.Time
	queueTick uint64
	machTick  uint32
	// checksum is Checksum(mTrackedTimeSum, queueTick, machTick).
	checksum uint8
	// tracked and trackedIdxs are the tracked state names and their indexes
	// in the full schema.
	tracked     am.S
	trackedIdxs []int
}

// tracerMutation records one mutation for mutation syncing.
type tracerMutation struct {
	mutType am.MutationType
	// calledIdxs are the called state indexes, limited to tracked states.
	calledIdxs []int
	// data is the snapshot taken after the mutation's transition.
	data tracerData
}

// sourceTracer observes the Server's source machine and collects clock
// snapshots to push to the client. Its mutable fields are accessed under
// s.lockCollection.
type sourceTracer struct {
	// TracerNoOp presumably supplies no-op versions of the tracer hooks not
	// implemented here.
	*am.TracerNoOp
	s *Server
	// active gates the hooks; they return early while it is false.
	active bool
	// dataLatest is the most recent snapshot.
	dataLatest *tracerData
	// dataQueue collects mutations when mutation syncing is enabled.
	dataQueue []tracerMutation
	// trackedStates / trackedStateIdxs are computed lazily by
	// calcTrackedStates.
	trackedStates    am.S
	trackedStateIdxs []int
}
// DataLatest returns a copy of the latest snapshot, or nil if there is none.
// The copy is shallow: slices such as mTime are shared with the tracer.
func (t *sourceTracer) DataLatest() *tracerData {
	t.s.lockCollection.Lock()
	defer t.s.lockCollection.Unlock()

	if t.dataLatest == nil {
		return nil
	}
	snapshot := *t.dataLatest
	return &snapshot
}
// DataQueue returns the queued mutations (the slice itself, not a copy) and
// clears the latest snapshot.
//
// NOTE(review): this resets dataLatest but leaves dataQueue as-is, so the
// same mutations are returned again on the next call unless the queue is
// cleared elsewhere — confirm this is intended.
func (t *sourceTracer) DataQueue() []tracerMutation {
	t.s.lockCollection.Lock()
	defer t.s.lockCollection.Unlock()

	t.dataLatest = nil
	return t.dataQueue
}
// TransitionEnd is the tracer hook run after each transition of the source
// machine. While the tracer is active, it:
//   - computes the tracked states on first use
//   - snapshots the source clock into dataLatest (filtered to tracked states
//     unless the schema is synced, and reduced to active states when
//     shallow clocks are synced)
//   - appends a tracerMutation to dataQueue when mutations are synced
//   - asks the server to push to the client, in a new goroutine
func (t *sourceTracer) TransitionEnd(tx *am.Transition) {
	s := t.s
	srcMach := s.Source
	s.lockCollection.Lock()
	defer s.lockCollection.Unlock()

	if !t.active {
		return
	}
	if t.trackedStates == nil {
		t.calcTrackedStates(tx.Machine.StateNames())
	}

	qTick := srcMach.QueueTick()
	machTick := srcMach.MachineTick()
	mTime := srcMach.Time(nil)
	trackedTSum := mTime.Filter(t.trackedStateIdxs).Sum(nil)
	if !s.syncSchema {
		mTime = mTime.Filter(t.trackedStateIdxs)
	}
	if s.syncShallowClocks {
		mTime = am.NewTime(mTime, mTime.ActiveStates(nil))
		trackedTSum = mTime.Sum(nil)
	}

	d := &tracerData{
		mTime:           mTime,
		mTrackedTimeSum: trackedTSum,
		queueTick:       qTick,
		machTick:        machTick,
		checksum:        Checksum(trackedTSum, qTick, machTick),
		tracked:         t.trackedStates,
		trackedIdxs:     t.trackedStateIdxs,
	}
	t.dataLatest = d

	if s.syncMutations {
		mut := tx.Mutation
		// filter a copy: DeleteFunc works in place, and mut.Called belongs to
		// the source machine's mutation, which others may still read
		called := slices.DeleteFunc(slices.Clone(mut.Called), func(idx int) bool {
			return !slices.Contains(t.trackedStateIdxs, idx)
		})
		t.dataQueue = append(t.dataQueue, tracerMutation{
			mutType:    mut.Type,
			calledIdxs: called,
			data:       *d,
		})
	}

	calledTracked := am.StatesShared(tx.TargetStates(), t.trackedStates)
	// push outside of the tracer hook, which runs with lockCollection held
	go func() {
		s.log("tracer push: tt%d q%d (check:%d) %s", trackedTSum, qTick,
			d.checksum, calledTracked)
		s.pushClient()
	}()
}
// SchemaChange is the tracer hook run when the source machine's schema
// changes. While the tracer is active, it:
//   - exports mach and limits the export to the tracked states (without
//     syncSchema, untracked names and clocks are dropped; with it,
//     untracked clocks are zeroed)
//   - refreshes the clock and checksum in s.lastPushData
//   - sends the export to the client as ClientSchemaChange, in a new
//     goroutine
//
// oldSchema is not used.
func (t *sourceTracer) SchemaChange(mach am.Api, oldSchema am.Schema) {
	s := t.s
	s.lockCollection.Lock()
	defer s.lockCollection.Unlock()

	if !t.active {
		return
	}

	msg := &MsgSrvHello{}
	var err error
	msg.Serialized, msg.Schema, err = mach.Export()
	if err != nil {
		// report from a goroutine, like the RPC error below, since adding an
		// error mutates the machine while this hook holds lockCollection
		go mach.AddErr(err, nil)
		return
	}
	if msg.Serialized == nil {
		return
	}
	allStates := msg.Serialized.StateNames
	// StatesCount is the full count, taken before any filtering
	msg.StatesCount = uint32(len(allStates))

	// the tracked states have to be known before they are used to filter
	if t.trackedStates == nil {
		t.calcTrackedStates(allStates)
	}

	export := msg.Serialized
	if !s.syncSchema {
		export.StateNames = am.StatesShared(export.StateNames, t.trackedStates)
		export.Time = export.Time.Filter(t.trackedStateIdxs)
	} else {
		for i := range allStates {
			if !slices.Contains(t.trackedStateIdxs, i) {
				export.Time[i] = 0
			}
		}
	}

	// NOTE(review): if lastPushData is a struct value rather than a pointer,
	// d is a copy and these updates are lost — confirm its type.
	d := s.lastPushData
	d.mTime = export.Time
	d.queueTick = export.QueueTick
	d.mTrackedTimeSum = export.Time.Sum(nil)
	d.checksum = Checksum(d.mTrackedTimeSum, d.queueTick, d.machTick)

	go func() {
		client := s.rpcClient.Load()
		if client == nil {
			return
		}
		s.lockExport.Lock()
		defer s.lockExport.Unlock()
		err := client.CallWithContext(mach.Ctx(), ClientSchemaChange.Value, msg,
			&MsgEmpty{})
		if err != nil {
			mach.AddErr(err, nil)
		}
	}()
}
// calcTrackedStates derives the tracked states from the full state list
// states. It keeps only the states in syncAllowedStates (when set), removes
// syncSkippedStates, and records each tracked state's index in states.
// Callers hold lockCollection.
func (t *sourceTracer) calcTrackedStates(states am.S) {
	srv := t.s

	tracked := states
	if srv.syncAllowedStates != nil {
		tracked = am.StatesShared(tracked, srv.syncAllowedStates)
	}
	tracked = am.StatesDiff(tracked, srv.syncSkippedStates)

	idxs := make([]int, 0, len(tracked))
	for _, name := range tracked {
		idxs = append(idxs, slices.Index(states, name))
	}

	t.trackedStates = tracked
	t.trackedStateIdxs = idxs
}
// msgpackCoded is the rpc2.Codec returned by NewMsgpackCodec. Despite its
// name, it encodes with encoding/gob, not msgpack. Reads decode straight
// from rwc. Writes go through the buffered encBuf, are flushed once per
// message, and are serialized by mutex.
type msgpackCoded struct {
	rwc    io.ReadWriteCloser
	dec    *gob.Decoder
	enc    *gob.Encoder
	encBuf *bufio.Writer
	mutex  sync.Mutex
}

// msgpackMsg is the decoding target for message headers. It holds the union
// of the request fields (Seq, Method) and the response fields (Seq, Error).
// Gob matches fields by name, so either header decodes into it, and
// ReadHeader treats a non-empty Method as a request.
type msgpackMsg struct {
	Seq    uint64
	Method string
	Error  string
}
// NewMsgpackCodec returns an rpc2.Codec over conn. Despite the name, it
// uses encoding/gob. Decoding reads conn directly, and encoding goes
// through a bufio.Writer that is flushed after each message.
func NewMsgpackCodec(conn io.ReadWriteCloser) rpc2.Codec {
	codec := &msgpackCoded{
		rwc: conn,
		dec: gob.NewDecoder(conn),
	}
	codec.encBuf = bufio.NewWriter(conn)
	codec.enc = gob.NewEncoder(codec.encBuf)
	return codec
}
// ReadHeader decodes the next message header. A header with a non-empty
// Method fills req (Seq, Method); otherwise it fills resp (Seq, Error).
// A decode error is returned unchanged.
func (c *msgpackCoded) ReadHeader(
	req *rpc2.Request, resp *rpc2.Response,
) error {
	var hdr msgpackMsg
	if err := c.dec.Decode(&hdr); err != nil {
		return err
	}
	if hdr.Method == "" {
		resp.Seq, resp.Error = hdr.Seq, hdr.Error
		return nil
	}
	req.Seq, req.Method = hdr.Seq, hdr.Method
	return nil
}
// ReadRequestBody decodes the body that follows a request header into body.
// Per gob's Decoder.Decode, a nil body discards the value.
func (c *msgpackCoded) ReadRequestBody(body any) error {
	return c.dec.Decode(body)
}

// ReadResponseBody decodes the body that follows a response header into
// body. Per gob's Decoder.Decode, a nil body discards the value.
func (c *msgpackCoded) ReadResponseBody(body any) error {
	return c.dec.Decode(body)
}
// WriteRequest encodes the request header r followed by body, and flushes
// both to the connection.
func (c *msgpackCoded) WriteRequest(
	r *rpc2.Request, body interface{},
) (err error) {
	return c.writeMsg(r, body)
}

// WriteResponse encodes the response header r followed by body, and flushes
// both to the connection.
func (c *msgpackCoded) WriteResponse(
	r *rpc2.Response, body interface{},
) (err error) {
	return c.writeMsg(r, body)
}

// writeMsg encodes header and body into the write buffer and flushes them
// as one unit, under the codec's write lock.
//
// If either value fails to encode, the connection is closed and the encode
// error is returned. The header may already sit in the buffer without its
// body, and flushing it with a later message would desynchronize the peer's
// decoder. net/rpc's gob server codec closes the connection in the same
// situation.
func (c *msgpackCoded) writeMsg(header, body any) error {
	c.mutex.Lock()
	defer c.mutex.Unlock()

	if err := c.enc.Encode(header); err != nil {
		_ = c.rwc.Close()
		return err
	}
	if err := c.enc.Encode(body); err != nil {
		_ = c.rwc.Close()
		return err
	}
	return c.encBuf.Flush()
}
// Close closes the underlying connection and returns its error.
func (c *msgpackCoded) Close() error {
	conn := c.rwc
	return conn.Close()
}
// MachReplEnv starts a REPL for mach when the AM_REPL_ADDR environment
// variable is set. "1" means the default address (see MachRepl); any other
// non-empty value is used as the listening address. AM_REPL_DIR is passed as
// ReplOpts.AddrDir.
//
// It returns (nil, nil) when AM_REPL_ADDR is empty. Otherwise it returns
// MachRepl's error together with a channel of background errors (see
// ReplOpts.ErrCh). The channel is always eventually closed: by MachRepl
// when it succeeds, and here when it fails.
func MachReplEnv(mach am.Api) (error, <-chan error) {
	addr := os.Getenv(EnvAmReplAddr)
	switch addr {
	case "":
		return nil, nil
	case "1":
		addr = ""
	}

	errCh := make(chan error, 1)
	opts := &ReplOpts{
		AddrDir: os.Getenv(EnvAmReplDir),
		ErrCh:   errCh,
	}
	if err := MachRepl(mach, addr, opts); err != nil {
		// MachRepl returns its errors before it uses or closes ErrCh, so
		// close the channel here; otherwise a reader would block forever
		close(errCh)
		return err, errCh
	}
	return nil, errCh
}
// MachRepl starts a REPL mux named "repl-<machine id>" for mach, listening
// on addr ("127.0.0.1:0" when empty).
//
// When opts.AddrCh is set, the listening address is sent to it once the mux
// is Ready. When opts.AddrDir is set, the directory is created if needed and
// the address is written to "<AddrDir>/<machine id>.addr". Errors from that
// background work go to opts.ErrCh (nil after a successful write). MachRepl
// closes the non-nil opts channels once it is done with them, and closes
// ErrCh right away when there is nothing to publish.
//
// It returns an error, leaving the channels open, in these cases:
//   - amhelp.ErrTestAutoDisable when running under a test runner
//   - am.ErrSchema when mach has handlers but lacks the ssW states
//   - when opts.Args is set and is not a struct
//   - NewMux's error
func MachRepl(mach am.Api, addr string, opts *ReplOpts) error {
	if opts == nil {
		opts = &ReplOpts{}
	}
	// copy what the background goroutine needs, so it never reads *opts
	addrDir := opts.AddrDir
	addrCh := opts.AddrCh
	errCh := opts.ErrCh

	if amhelp.IsTestRunner() {
		return amhelp.ErrTestAutoDisable
	}
	if addr == "" {
		addr = "127.0.0.1:0"
	}
	if mach.HasHandlers() && !mach.Has(ssW.Names()) {
		return fmt.Errorf(
			"%w: REPL source has to implement pkg/rpc/states/NetSourceStatesDef",
			am.ErrSchema)
	}
	if opts.Args != nil {
		if kind := reflect.TypeOf(opts.Args).Kind(); kind != reflect.Struct {
			return fmt.Errorf("expected a struct, got %s", kind)
		}
	}

	mux, err := NewMux(mach.Ctx(), "repl-"+mach.Id(), nil, &MuxOpts{
		Parent:     mach,
		Args:       opts.Args,
		ArgsPrefix: opts.ArgsPrefix,
	})
	if err != nil {
		return err
	}
	mux.Addr = addr
	mux.Source = mach
	mux.Start()

	if addrCh == nil && addrDir == "" {
		// nothing to publish
		if errCh != nil {
			close(errCh)
		}
		return nil
	}

	go func() {
		defer func() {
			if errCh != nil {
				close(errCh)
			}
			if addrCh != nil {
				close(addrCh)
			}
		}()

		// MkdirAll also succeeds for an existing directory, so no separate
		// Stat (and no check-then-create race) is needed
		dirOk := false
		if addrDir != "" {
			if mkErr := os.MkdirAll(addrDir, 0o755); mkErr != nil {
				if errCh != nil {
					errCh <- mkErr
				}
			} else {
				dirOk = true
			}
		}

		<-mux.Mach.When1(ssM.Ready, nil)
		if addrCh != nil {
			addrCh <- mux.Addr
		}
		if dirOk {
			// use a local error instead of writing MachRepl's err after it
			// has returned
			writeErr := os.WriteFile(
				filepath.Join(addrDir, mach.Id()+".addr"),
				[]byte(mux.Addr), 0o644,
			)
			if errCh != nil {
				errCh <- writeErr
			}
		}
	}()

	return nil
}
// Checksum combines the tracked clock sum, the queue tick and the machine
// tick into a single byte: the low 8 bits of their (wrapping) sum.
func Checksum(mTime uint64, qTick uint64, machTick uint32) uint8 {
	total := mTime
	total += qTick
	total += uint64(machTick)
	return uint8(total % 256)
}
// TrafficMeter proxies one connection accepted on listener to fwdTo (dialed
// over tcp4) and counts the bytes copied in both directions.
//
// When end is closed, it closes the listener and both connections, waits
// for both copies to finish, and sends the total byte count to counter. If
// dialing fwdTo or accepting fails, the error is printed, the listener is
// closed, and 0 is sent to counter once end is closed. Either way, counter
// receives exactly one value after end closes.
func TrafficMeter(
	listener net.Listener, fwdTo string, counter chan<- int64,
	end <-chan struct{},
) {
	defer listener.Close()

	// fail keeps the one-value-on-counter contract when setup fails;
	// previously nothing was sent and a caller waiting on counter hung
	fail := func() {
		_ = listener.Close()
		<-end
		counter <- 0
	}

	destination, err := net.Dial("tcp4", fwdTo)
	if err != nil {
		fmt.Println("Error connecting to destination:", err.Error())
		fail()
		return
	}
	defer destination.Close()

	conn, err := listener.Accept()
	if err != nil {
		fmt.Println("Error accepting connection:", err.Error())
		fail()
		return
	}
	defer conn.Close()

	var (
		wg    sync.WaitGroup
		total atomic.Int64
	)
	pipe := func(dst io.Writer, src io.Reader) {
		defer wg.Done()
		n, _ := io.Copy(dst, src)
		total.Add(n)
	}
	wg.Add(2)
	go pipe(destination, conn)
	go pipe(conn, destination)

	<-end
	// closing both ends unblocks the io.Copy calls
	_ = listener.Close()
	_ = destination.Close()
	_ = conn.Close()
	wg.Wait()
	counter <- total.Load()
}
// newClosedChan returns an already-closed signal channel, so receives on it
// never block.
func newClosedChan() chan struct{} {
	done := make(chan struct{})
	defer close(done)
	return done
}
The pages are generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news of Golds.