// Package rpc is a transparent RPC for state machines.
package rpc import ( amhelp am ) func init() { gob.Register(&ARpc{}) gob.Register(am.Relation(0)) } const ( // EnvAmRpcLogServer enables machine logging for RPC server. EnvAmRpcLogServer = "AM_RPC_LOG_SERVER" // EnvAmRpcLogClient enables machine logging for RPC client. EnvAmRpcLogClient = "AM_RPC_LOG_CLIENT" // EnvAmRpcLogMux enables machine logging for RPC multiplexers. EnvAmRpcLogMux = "AM_RPC_LOG_MUX" // EnvAmRpcDbg enables env-based debugging for RPC components. EnvAmRpcDbg = "AM_RPC_DBG" // EnvAmReplAddr is a REPL address to listen on. "1" expands to 127.0.0.1:0. EnvAmReplAddr = "AM_REPL_ADDR" // EnvAmReplDir is a dir path to save the address file as // $AM_REPL_DIR/mach-id.addr. Optional. EnvAmReplDir = "AM_REPL_DIR" // PrefixNetMach is the "rnm-" ID prefix, presumably for network machines — // confirm at the call sites. PrefixNetMach = "rnm-" ) // ss is a shortcut to states.SharedStates (e.g. ss.ErrRpc, ss.ErrNetwork). var ss = states.SharedStates // ///// ///// ///// // ///// TYPES // ///// ///// ///// // RPC methods type ( ServerMethod enum.Member[string] ClientMethod enum.Member[string] ) var ( // methods defined on the server ServerAdd = ServerMethod{"Add"} ServerAddNS = ServerMethod{"AddNS"} ServerRemove = ServerMethod{"Remove"} ServerSet = ServerMethod{"Set"} ServerHello = ServerMethod{"Hello"} ServerHandshake = ServerMethod{"Handshake"} ServerLog = ServerMethod{"Log"} ServerSync = ServerMethod{"Sync"} ServerArgs = ServerMethod{"Args"} ServerBye = ServerMethod{"Close"} ServerMethods = enum.New(ServerAdd, ServerAddNS, ServerRemove, ServerSet, ServerHello, ServerHandshake, ServerLog, ServerSync, ServerArgs, ServerBye) // methods defined on the client ClientUpdate = ClientMethod{"ClientSetClock"} ClientUpdateMutations = ClientMethod{"ClientSetClockMany"} ClientPushAllTicks = ClientMethod{"ClientPushAllTicks"} ClientSendPayload = ClientMethod{"ClientSendPayload"} ClientBye = ClientMethod{"ClientBye"} ClientSchemaChange = ClientMethod{"SchemaChange"} // NOTE(review): ClientUpdateMutations is the only ClientMethod missing from // ClientMethods — confirm the omission is intended. ClientMethods = enum.New(ClientUpdate, ClientPushAllTicks, ClientSendPayload, ClientBye, ClientSchemaChange) ) // MSGS TODO migrate to msgpack, shorten names // MsgCliHello is the client 
saying hello to the server. type MsgCliHello struct { // ID of the client saying Hello. Id string // Client wants to synchronize the schema. SyncSchema bool // Hash of the current schema, or "". Schema is always full and not affected // by [MsgCliHello.AllowedStates] or [MsgCliHello.SkippedStates]. SchemaHash string // The 4 fields below presumably configure the server's source tracer // (syncMutations, syncAllowedStates, syncSkippedStates, syncShallowClocks) — // confirm on the server side. SyncMutations bool AllowedStates am.S SkippedStates am.S ShallowClocks bool // TODO WhenArgs: []{ map[string]string: token } } // MsgSrvHello is the server saying hello to the client. type MsgSrvHello struct { Schema am.Schema Serialized *am.Serialized // total source states count StatesCount uint32 } // MsgCliMutation is the client requesting a mutation from the server. type MsgCliMutation struct { States []int Args am.A Event *am.Event } // MsgSrvMutation is the server replying to a mutation request for the client. type MsgSrvMutation struct { Update *MsgSrvUpdate Mutations *MsgSrvUpdateMuts Result am.Result } // MsgSrvPayload is the server sending a payload to the client. type MsgSrvPayload struct { // Name is used to distinguish different payload types at the destination. Name string // Source is the machine ID that sent the payload. Source string // SourceTx is transition ID. SourceTx string // Destination is an optional machine ID that is supposed to receive the // payload. Useful when using rpc.Mux. Destination string // Data is the payload data. The Consumer has to know the type. Data any // internal // Token is a unique random ID for the payload. Autofilled by the server. Token string } // MsgSrvSync is the server replying to a full sync request from the client. type MsgSrvSync struct { Time am.Time QueueTick uint64 MachTick uint32 } // MsgSrvArgs carries a list of Args strings; presumably the reply to // ServerArgs — confirm. type MsgSrvArgs struct { Args []string } // TODO type MsgCliWhenArgs struct {} // MsgEmpty is an empty message of either the server or client. type MsgEmpty struct{} // MsgSrvUpdate is the server telling the client about a net source's update. type MsgSrvUpdate struct { // Indexes of incremented states. 
Indexes []uint16 // Clock diffs of incremented states. // TODO optimize: []uint16 and send 2 updates when needed Ticks []uint32 // TODO optimize: for shallow clocks // Active []bool // QueueTick is an incremental diff for the queue tick. QueueTick uint16 // MachTick is an incremental diff for the machine tick. MachTick uint8 // Checksum is the low byte (sum modulo 256) of // (TimeSum + QueueTick + MachTick), as computed by Checksum(). Checksum uint8 // DBGSum uint64 // DBGLastSum uint64 // DBGQTick uint64 // DBGLastQTick uint64 } // MsgSrvUpdateMuts is like [MsgSrvUpdate] but contains several clock updates // (one for each mutation), as well as extra mutation info. type MsgSrvUpdateMuts struct { // TODO mind partially accepted auto states (fake called states). // Auto bool MutationType []am.MutationType CalledStates [][]uint16 Updates []MsgSrvUpdate } // clientServerMethods is a shared interface for RPC client/server. type clientServerMethods interface { GetKind() Kind } // Kind of the RPC component. type Kind string const ( KindClient Kind = "client" KindServer Kind = "server" ) // ReplOpts are optional settings for MachRepl. type ReplOpts struct { // optional dir path to save the address file as addrDir/mach-id.addr AddrDir string // optional channel to send err to, once ready ErrCh chan<- error // optional channel to send the address to, once ready AddrCh chan<- string // optional prefix for typesafe args. Requires Args. ArgsPrefix string // optional typed args instance. Requires ArgsPrefix Args any } // ///// ///// ///// // ///// ARGS // ///// ///// ///// // APrefix is the [am.A] key under which A (or ARpc) is stored, see ParseArgs. const APrefix = "am_rpc" // A represents typed arguments of the RPC package. It's a typesafe alternative // to [am.A]. type A struct { Id string `log:"id"` Name string `log:"name"` MachTime am.Time QueueTick uint64 MachTick uint32 Payload *MsgSrvPayload Addr string `log:"addr"` Err error // NOTE(review): this log tag repeats Addr's "addr" (likely meant "method"); // confirm before changing, it presumably affects ArgsToLogMap output. Method string `log:"addr"` StartedAt time.Time Dispose bool // non-rpc fields Client *rpc2.Client } // ARpc is a subset of A, that can be passed over RPC.
type ARpc struct { Id string `log:"id"` Name string `log:"name"` MachTime am.Time QueueTick uint64 MachTick uint32 Payload *MsgSrvPayload Addr string `log:"addr"` Err error Method string `log:"addr"` StartedAt time.Time Dispose bool } // ParseArgs extracts A from [am.Event.Args][APrefix]. func ( am.A) *A { if , := [APrefix].(*ARpc); != nil { return amhelp.ArgsToArgs(, &A{}) } if , := [APrefix].(*A); != nil { return } return &A{} } // Pass prepares [am.A] from A to pass to further mutations. func ( *A) am.A { return am.A{APrefix: } } // PassRpc prepares [am.A] from A to pass over RPC. func ( *A) am.A { return am.A{APrefix: amhelp.ArgsToArgs(, &ARpc{})} } // LogArgs is an args logger for A. func ( am.A) map[string]string { := ParseArgs() if == nil { return nil } return amhelp.ArgsToLogMap(, 0) } // // DEBUG for perf testing TODO tag // type MsgSrvUpdate am.Time // ///// ///// ///// // ///// RPC APIS // ///// ///// ///// // serverRpcMethods is the main RPC server's exposed methods. type serverRpcMethods interface { // rpc RemoteHello(client *rpc2.Client, args *MsgCliHello, resp *MsgSrvHello) error // mutations RemoteAdd( client *rpc2.Client, args *MsgCliMutation, resp *MsgSrvMutation) error RemoteRemove( client *rpc2.Client, args *MsgCliMutation, resp *MsgSrvMutation) error RemoteSet( client *rpc2.Client, args *MsgCliMutation, reply *MsgSrvMutation) error } // clientRpcMethods is the RPC server exposed by the RPC client for bi-di comm. 
// Each method receives the rpc2 connection (worker) and replies with an empty
// MsgEmpty.
type clientRpcMethods interface { RemoteUpdate(worker *rpc2.Client, args *MsgSrvUpdate, resp *MsgEmpty) error RemoteUpdateMutations( worker *rpc2.Client, args *MsgSrvUpdateMuts, resp *MsgEmpty) error RemoteSendingPayload( worker *rpc2.Client, file *MsgSrvPayload, resp *MsgEmpty) error RemoteSendPayload( worker *rpc2.Client, file *MsgSrvPayload, resp *MsgEmpty) error } // ///// ///// ///// // ///// ERRORS // ///// ///// ///// // sentinel errors var ( // ErrClient group ErrInvalidParams = errors.New("invalid params") ErrInvalidResp = errors.New("invalid response") ErrRpc = errors.New("rpc") ErrNoAccess = errors.New("no access") ErrNoConn = errors.New("not connected") ErrDestination = errors.New("wrong destination") // ErrNetwork group ErrNetwork = errors.New("network error") ErrNetworkTimeout = errors.New("network timeout") // TODO ErrDelivery ) // wrapping error setters // Wraps msg with ErrRpc and adds ss.ErrRpc via EvAddErrState — the only // setter here that forwards the event. func ( *am.Event, *am.Machine, string) { := fmt.Errorf("%w: %s", ErrRpc, ) .EvAddErrState(, ss.ErrRpc, , nil) } // Wraps err with ErrInvalidParams and adds ss.ErrRpc; the event appears unused. func ( *am.Event, *am.Machine, error) { = fmt.Errorf("%w: %w", ErrInvalidParams, ) .AddErrState(ss.ErrRpc, , nil) } // Wraps err with ErrInvalidResp and adds ss.ErrRpc; the event appears unused. func ( *am.Event, *am.Machine, error) { = fmt.Errorf("%w: %w", ErrInvalidResp, ) .AddErrState(ss.ErrRpc, , nil) } // Adds ss.ErrNetwork with err as-is; the event appears unused. NOTE(review): // err is not wrapped with ErrNetwork here, so errors.Is(err, ErrNetwork) only // holds if the caller wrapped it — confirm this is intended. func ( *am.Event, *am.Machine, error) { .AddErrState(ss.ErrNetwork, , nil) } // Wraps err with ErrNoConn and adds ss.ErrNetwork; the event appears unused. func ( *am.Event, *am.Machine, error) { = fmt.Errorf("%w: %w", ErrNoConn, ) .AddErrState(ss.ErrNetwork, , nil) } // AddErr detects sentinels from error msgs and calls the proper error setter.
// TODO also return error for compat // // A non-empty msg is appended to err first, then the error is classified by // substrings of err.Error(); unmatched errors go to the machine's AddErr. // NOTE(review): with msg == "" a nil err panics on err.Error() — confirm // callers never pass nil. func ( *am.Event, *am.Machine, string, error) { if != "" { = fmt.Errorf("%w: %s", , ) } if strings.HasPrefix(.Error(), "gob: ") { AddErrResp(, , ) } else if strings.Contains(.Error(), "rpc2: can't find method") { AddErrRpcStr(, , .Error()) } else if strings.Contains(.Error(), "connection is shut down") || strings.Contains(.Error(), "unexpected EOF") { // TODO bind to sentinels io.ErrUnexpectedEOF, rpc2.ErrShutdown .AddErrState(ss.ErrRpc, , nil) } else if strings.Contains(.Error(), "timeout") { AddErrNetwork(, , errors.Join(, ErrNetworkTimeout)) } else if , := .(*net.OpError); { AddErrNetwork(, , ) } else { .AddErr(, nil) } } // ExceptionHandler is a shared exception handler for RPC server and // client. type ExceptionHandler struct { *am.ExceptionHandler } // Returns false (ignoring the exception) for ErrNetwork errors, read from the // RPC args (A.Err), while Disconnecting or Disconnected is active on a machine // that has those states (presumably a client — Has is assumed to check the // schema and Any1 the activation; confirm). Returns true otherwise. func ( *ExceptionHandler) ( *am.Event) bool { := ParseArgs(.Args) := .Machine() := .Has(am.S{ssC.Disconnecting, ssC.Disconnected}) if errors.Is(.Err, ErrNetwork) && && .Any1(ssC.Disconnecting, ssC.Disconnected) { // skip network errors on client disconnect .Machine().Log("ignoring ErrNetwork on Disconnecting/Disconnected") return false } return true } // ///// ///// ///// // ///// LOGGER // ///// ///// ///// // semLogger implements am.SemLogger for a NetworkMachine: the logger and level // are stored on the machine (mach.logger, mach.logLevel), while several // toggles are still TODO stubs. type semLogger struct { mach *NetworkMachine steps atomic.Bool graph atomic.Bool } // implement [SemLogger] var _ am.SemLogger = &semLogger{} func ( *semLogger) ( am.LogArgsMapperFn) { // TODO } func ( *semLogger) () am.LogArgsMapperFn { // TODO return nil } func ( *semLogger) ( bool) { // TODO } func ( *semLogger) () bool { return false } func ( *semLogger) ( am.LoggerFn) { if == nil { .mach.logger.Store(nil) return } .mach.logger.Store(&) } func ( *semLogger) () am.LoggerFn { if := .mach.logger.Load(); != nil { return * } return nil } func ( *semLogger) ( am.LogLevel) { .mach.logLevel.Store(&) } // NOTE(review): unlike the logger getter above, this dereferences // logLevel.Load() without a nil check, so it panics if no level was stored. func ( *semLogger) () am.LogLevel { return *.mach.logLevel.Load() } // Installs a no-op logger and stores the given level. func ( *semLogger) ( am.LogLevel) { var am.LoggerFn = func( am.LogLevel, string, ...any) { // no-op } 
.mach.logger.Store(&) .mach.logLevel.Store(&) } func ( *semLogger) ( func( string, ...any), am.LogLevel, ) { var am.LoggerFn = func( am.LogLevel, string, ...any) { (, ...) } .mach.logger.Store(&) .mach.logLevel.Store(&) } func ( *semLogger) ( bool, , string) { := "remove" if { = "add" } .mach.log(am.LogOps, "[pipe-out:%s] %s to %s", , , ) } func ( *semLogger) ( bool, , string) { := "remove" if { = "add" } .mach.log(am.LogOps, "[pipe-in:%s] %s from %s", , , ) } func ( *semLogger) ( string) { .mach.log(am.LogOps, "[pipe:gc] %s", ) } func ( *semLogger) () bool { return .steps.Load() } func ( *semLogger) ( bool) { .steps.Store() } func ( *semLogger) () bool { return .graph.Load() } func ( *semLogger) ( bool) { .graph.Store() } // TODO more data types // The toggles below are stubs: setters do nothing and getters return true. func ( *semLogger) ( bool) { // TODO } func ( *semLogger) () bool { return true } func ( *semLogger) ( bool) { // TODO } func ( *semLogger) () bool { return true } func ( *semLogger) ( bool) { // TODO params for synthetic log } func ( *semLogger) () bool { return true } func ( *semLogger) ( bool) { // TODO } func ( *semLogger) () bool { return true } func ( *semLogger) ( bool) { // TODO } func ( *semLogger) () bool { return true } // ///// ///// ///// // ///// REMOTE HANDLERS // ///// ///// ///// // handler represents a single event consumer, synchronized by channels.
// NOTE(review): handler holds a sync.Mutex (mx), so "synchronized by channels"
// above may be stale — confirm.
type handler struct { h any name string mx sync.Mutex methods *reflect.Value methodCache map[string]reflect.Value missingCache map[string]struct{} } func newHandler( any, string, *reflect.Value, ) *handler { return &handler{ name: , h: , methods: , methodCache: make(map[string]reflect.Value), missingCache: make(map[string]struct{}), } } // ///// ///// ///// // ///// TRACERS // ///// ///// ///// // tracerData is a snapshot of a source machine's time and ticks, produced by // sourceTracer. type tracerData struct { mTrackedTimeSum uint64 // time is either source-bound or client-bound (if no schema) mTime am.Time queueTick uint64 machTick uint32 // tracked-states-only checksum checksum uint8 tracked am.S // tracked idx -> machine idx trackedIdxs []int // mach time on the client according to tracked states // mTimeSumClient uint64 } // tracerMutation is tracerData recorded for a single mutation, together with // its type and called state indexes. type tracerMutation struct { mutType am.MutationType calledIdxs []int data tracerData } // sourceTracer is a tracer for source state-machines, used by the RPC server // to produce updates for the RPC client. type sourceTracer struct { *am.TracerNoOp s *Server // tracer needs explicit activation active bool // latest data, possibly already sent dataLatest *tracerData // unsent data generated for each mutations dataQueue []tracerMutation // list of states this tracer is syncing trackedStates am.S // tracked idx -> machine idx trackedStateIdxs []int } // getters // Returns a copy of dataLatest, or nil when unset, under s.lockCollection. func ( *sourceTracer) () *tracerData { // lock .s.lockCollection.Lock() defer .s.lockCollection.Unlock() // copy := .dataLatest if == nil { return nil } := * return & } // Returns the queued per-mutation data, under s.lockCollection. // NOTE(review): this resets dataLatest, not dataQueue, so the queue is never // flushed here despite the "copy and flush" comment — likely `dataQueue = nil` // was meant; confirm. func ( *sourceTracer) () []tracerMutation { // lock .s.lockCollection.Lock() defer .s.lockCollection.Unlock() // copy and flush := .dataQueue .dataLatest = nil return } // tracing // For a transition (when active): snapshots the source clocks into // dataLatest, optionally queues per-mutation data, then pushes to the client // in a goroutine. func ( *sourceTracer) ( *am.Transition) { := .s := .Source // lock .lockCollection.Lock() defer .lockCollection.Unlock() if !.active { return } // init cache := .Machine.StateNames() if .trackedStates == nil { .calcTrackedStates() } := .QueueTick() := .MachineTick() := .Time(nil) := .Filter(.trackedStateIdxs).Sum(nil) // filter the time 
slice if !.syncSchema { = .Filter(.trackedStateIdxs) } if .syncShallowClocks { = am.NewTime(, .ActiveStates(nil)) = .Sum(nil) } // update := &tracerData{ mTime: , mTrackedTimeSum: , // mTimeSumClient: mTimeClient.Sum(nil), queueTick: , machTick: , checksum: Checksum(, , ), tracked: .trackedStates, trackedIdxs: .trackedStateIdxs, } .dataLatest = // DEBUG // if srcMach.Id() == "ns-TestPartial" { // fmt.Printf("[T] [%+v] %d %d\n", d.mTime, qTick, machTick) // fmt.Printf("[T] check %d\n", d.checksum) // } // mutations if .syncMutations { := .Mutation // NOTE(review): slices.DeleteFunc filters in place, so this rewrites (and // zeroes the tail of) the transition's Mutation.Called — consider // slices.Clone first. // skip non-tracked called states := slices.DeleteFunc(.Called, func( int) bool { return !slices.Contains(.trackedStateIdxs, ) }) .dataQueue = append(.dataQueue, tracerMutation{ mutType: .Type, calledIdxs: , data: *, }) } := am.StatesShared(.TargetStates(), .trackedStates) // TODO optimize: fork max 1? go func() { .log("tracer push: tt%d q%d (check:%d) %s", , , .checksum, ) // try to push this tx to the client .s.pushClient() }() } // On a schema change (when active): re-exports the source, refreshes // lastPushData and sends ClientSchemaChange to the RPC client in a goroutine. func ( *sourceTracer) ( am.Api, am.Schema) { := .s // lock .lockCollection.Lock() defer .lockCollection.Unlock() if !.active { return } := &MsgSrvHello{} .Serialized, .Schema, _ = .Export() // NOTE(review): the Export error is dropped; should Serialized be nil, the // next line panics. .StatesCount = uint32(len(.Serialized.StateNames)) := .Serialized.StateNames // client-bound indexes when no schema synced := .Serialized if !.syncSchema { .StateNames = am.StatesShared(.StateNames, .tracer.trackedStates) .Time = .Time.Filter(.tracer.trackedStateIdxs) // zero non-tracked for consistent checksums } else { for := range { if slices.Contains(.tracer.trackedStateIdxs, ) { continue } .Time[] = 0 } } // NOTE(review): tracked states/indexes are already read above, before this // lazy init — confirm they are always initialized at this point. // init cache if .trackedStates == nil { .calcTrackedStates() } // memorize := .lastPushData .mTime = .Time .queueTick = .QueueTick .mTrackedTimeSum = .Time.Sum(nil) .checksum = Checksum(.mTrackedTimeSum, .queueTick, .machTick) // fork and push go func() { := .s.rpcClient.Load() if == nil { return } .lockExport.Lock() defer .lockExport.Unlock() // send // NOTE(review): the call's error is passed to AddErr even when nil — // confirm AddErr ignores nil errors. := .CallWithContext(.Ctx(), 
ClientSchemaChange.Value, , &MsgEmpty{}) .AddErr(, nil) }() } // internal // Sets trackedStates to the given state names (intersected with // syncAllowedStates when set, minus syncSkippedStates) and trackedStateIdxs // to their indexes within those names. func ( *sourceTracer) ( am.S) { := .s .trackedStates = if .syncAllowedStates != nil { .trackedStates = am.StatesShared(.trackedStates, .syncAllowedStates) } .trackedStates = am.StatesDiff(.trackedStates, .syncSkippedStates) .trackedStateIdxs = make([]int, len(.trackedStates)) for , := range .trackedStates { .trackedStateIdxs[] = slices.Index(, ) } } // ///// ///// ///// // ///// MSGPACK // ///// ///// ///// // TODO // msgpackCoded is an rpc2.Codec which, despite its name, currently uses // encoding/gob. type msgpackCoded struct { rwc io.ReadWriteCloser // TODO dec *gob.Decoder enc *gob.Encoder encBuf *bufio.Writer mutex sync.Mutex } // msgpackMsg is a decoded header: requests carry a Method, responses may carry // an Error. type msgpackMsg struct { Seq uint64 Method string Error string } // TODO optimize with msgpack // Wraps rwc in a gob codec; writes go through a bufio.Writer, flushed after // each message. func ( io.ReadWriteCloser) rpc2.Codec { := bufio.NewWriter() return &msgpackCoded{ rwc: , // TODO dec: gob.NewDecoder(), enc: gob.NewEncoder(), encBuf: , } } // Decodes a header into req when it has a Method, otherwise into resp. func ( *msgpackCoded) ( *rpc2.Request, *rpc2.Response, ) error { var msgpackMsg if := .dec.Decode(&); != nil { return } if .Method != "" { .Seq = .Seq .Method = .Method } else { .Seq = .Seq .Error = .Error } return nil } func ( *msgpackCoded) ( interface{}) error { return .dec.Decode() } func ( *msgpackCoded) ( interface{}) error { return .dec.Decode() } // Encodes the header and the body under the mutex, then flushes. func ( *msgpackCoded) ( *rpc2.Request, interface{}, ) ( error) { .mutex.Lock() defer .mutex.Unlock() if = .enc.Encode(); != nil { return } if = .enc.Encode(); != nil { return } return .encBuf.Flush() } func ( *msgpackCoded) ( *rpc2.Response, interface{}, ) ( error) { .mutex.Lock() defer .mutex.Unlock() if = .enc.Encode(); != nil { return } if = .enc.Encode(); != nil { return } return .encBuf.Flush() } func ( *msgpackCoded) () error { return .rwc.Close() } // ///// ///// ///// // ///// MISC // ///// ///// ///// // MachReplEnv sets up a machine for a REPL connection in case AM_REPL_ADDR env // var is set. See MachRepl.
func ( am.Api) (error, <-chan error) { := os.Getenv(EnvAmReplAddr) := os.Getenv(EnvAmReplDir) switch { case "": return nil, nil case "1": // expand 1 to default = "" } // MachRepl closes errCh := make(chan error, 1) := &ReplOpts{ AddrDir: , ErrCh: , } if := MachRepl(, , ); != nil { return , } return nil, } // MachRepl sets up a machine for a REPL connection, which allows for // mutations, like any other RPC connection. See [/tools/cmd/arpc] for usage. // This function is considered a debugging helper and can panic. // // addr: address to listen on, default to 127.0.0.1:0 // addrDir: optional dir path to save the address file as addrDir/mach-id.addr // addrCh: optional channel to send the address to, once ready // errCh: optional channel for errors func ( am.Api, string, *ReplOpts) error { if == nil { = &ReplOpts{} } := .AddrDir := .AddrCh := .ErrCh if amhelp.IsTestRunner() { return amhelp.ErrTestAutoDisable } if == "" { = "127.0.0.1:0" } if .HasHandlers() && !.Has(ssW.Names()) { := fmt.Errorf( "%w: REPL source has to implement pkg/rpc/states/NetSourceStatesDef", am.ErrSchema) return } // verify args is a value struct if .Args != nil { := reflect.TypeOf(.Args) if .Kind() != reflect.Struct { return fmt.Errorf("expected a struct, got %s", .Kind()) } } , := NewMux(.Ctx(), "repl-"+.Id(), nil, &MuxOpts{ Parent: , Args: .Args, ArgsPrefix: .ArgsPrefix, }) if != nil { return } .Addr = .Source = .Start() if == nil && == "" { if != nil { close() } return nil } go func() { // dispose ret channels defer func() { if != nil { close() } if != nil { close() } }() // prep the dir := false if != "" { if , := os.Stat(); os.IsNotExist() { := os.MkdirAll(, 0o755) if == nil { = true } else if != nil { <- } } else { = true } } // wait for an addr <-.Mach.When1(ssM.Ready, nil) if != nil { <- .Addr } // save to dir if && != "" { = os.WriteFile( filepath.Join(, .Id()+".addr"), []byte(.Addr), 0o644, ) if != nil { <- } } }() return nil } // // DEBUG for perf testing // func NewClockMsg(before, 
after am.Time) MsgSrvUpdate { // return MsgSrvUpdate(after) // } // // // DEBUG for perf testing // func clockFromUpdate(before am.Time, msg MsgSrvUpdate) am.Time { // return am.Time(msg) // } // Checksum calculates a short checksum of current machine time and ticks. func ( uint64, uint64, uint32) uint8 { return uint8( + + uint64()) } // TrafficMeter measures the traffic of a listener and forwards it to a // destination. Results are sent to the [counter] channel. Useful for testing // and benchmarking. func ( net.Listener, string, chan<- int64, <-chan struct{}, ) { defer .Close() // fmt.Println("Listening on " + listenOn) // callFailsafe the destination , := net.Dial("tcp4", ) if != nil { fmt.Println("Error connecting to destination:", .Error()) return } defer .Close() // wait for the connection , := .Accept() if != nil { fmt.Println("Error accepting connection:", .Error()) return } defer .Close() // forward data bidirectionally := sync.WaitGroup{} .Add(2) := atomic.Int64{} go func() { , := io.Copy(, ) .Add() .Done() }() go func() { , := io.Copy(, ) .Add() .Done() }() // wait for the test and forwarding to finish <- // fmt.Printf("Closing counter...\n") _ = .Close() _ = .Close() _ = .Close() .Wait() := .Load() // fmt.Printf("Forwarded %d bytes\n", c) <- } func newClosedChan() chan struct{} { := make(chan struct{}) close() return }