package gorm
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"log"
"os"
"slices"
"sync"
"sync/atomic"
"time"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/gormlite"
"github.com/ncruces/go-sqlite3/vfs"
"golang.org/x/sync/errgroup"
"gorm.io/datatypes"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"github.com/pancsta/asyncmachine-go/internal/utils"
amhist "github.com/pancsta/asyncmachine-go/pkg/history"
am "github.com/pancsta/asyncmachine-go/pkg/machine"
)
// MatcherFn narrows query, a gorm query on the Time model, down to the
// matching records and returns it. now is the machine's current time,
// indexed by state names (see Memory.Match).
type MatcherFn func(now *am.TimeIndex, query *gorm.DB) *gorm.DB
// Config configures the gorm-backed history Memory.
type Config struct {
	amhist.BaseConfig
	// QueueBatch is the number of pending Time records which triggers a DB
	// write (NewMemory defaults it to 10).
	QueueBatch int32
	// SavePool limits the number of concurrent DB writes (NewMemory defaults
	// it to 10).
	SavePool int
}
// Machine is the DB record (gorm model) of a tracked machine, one per
// machine ID. Field order and tags define the DB schema.
type Machine struct {
	ID uint32 `gorm:"primaryKey"`
	// Times are the recorded transitions (Time.MachineID is set to ID).
	Times []Time
	// States are the tracked states, preloaded by GetMachine on request.
	States []State
	// MachId is the ID of the asyncmachine (am.Api.Id).
	MachId string `gorm:"column:mach_id;index:mach_id"`
	// StateNames and Schema are JSON-encoded and only filled when
	// Config.StoreSchema is set.
	StateNames datatypes.JSON
	Schema     datatypes.JSON
	// FirstTracking is set when the record is created.
	FirstTracking time.Time
	// LastTracking is set on every MachineInit.
	LastTracking time.Time
	// LastSync is updated with every recorded transition.
	LastSync time.Time
	MachTick uint32
	// MTime is the JSON-encoded machine time at the last record.
	MTime    datatypes.JSON
	MTimeSum uint64
	// NextId is the ID of the next Time record.
	NextId uint64
	// cacheMTime is an in-memory copy of MTime (unexported, not persisted).
	cacheMTime am.Time
}
// Time is a single recorded transition of a machine (gorm table "times"),
// keyed by (ID, MachineID). Sums and diffs are precomputed for querying, Tx*
// fields are only filled when Config.StoreTransitions is set.
type Time struct {
	// ID is assigned sequentially by the Memory (see Machine.NextId), not by
	// the DB.
	ID        uint64 `gorm:"primaryKey;autoIncrement:false"`
	MachineID uint32 `gorm:"primaryKey"`
	// Ticks are the per-state tick records of this Time.
	Ticks   []Tick `gorm:"foreignKey:TimeID,MachineID;references:ID,MachineID"`
	MutType am.MutationType
	// MTimeSum is the sum of the machine time after the transition.
	MTimeSum uint64
	// MTimeTrackedSum is the sum of the tracked states' time after the
	// transition.
	MTimeTrackedSum uint64
	// MTimeDiffSum is MTimeSum minus the sum before the transition.
	MTimeDiffSum uint64
	// MTimeTrackedDiffSum is MTimeTrackedSum minus the tracked sum before the
	// transition.
	MTimeTrackedDiffSum uint64
	// MTimeRecordDiffSum is MTimeSum minus MTimeSum of the previous record
	// (0 for the first record).
	MTimeRecordDiffSum uint64
	// HTime is the UTC human time of the record.
	HTime time.Time
	// MTimeTracked is the JSON-encoded time of the tracked states.
	MTimeTracked datatypes.JSON
	// MTimeTrackedDiff is the JSON-encoded diff of the tracked states' time.
	MTimeTrackedDiff datatypes.JSON
	MachTick         uint32
	// cacheMTimeTracked is an in-memory copy of MTimeTracked (unexported, not
	// persisted).
	cacheMTimeTracked am.Time
	TxId              string `gorm:"index:tx_id"`
	TxSourceTx        *string
	TxSourceMach      *string
	TxIsAuto          bool
	TxIsAccepted      bool
	TxIsCheck         bool
	TxIsBroken        bool
	TxQueueLen        uint16
	// TxQueuedAt isn't assigned anywhere in this file.
	TxQueuedAt *uint64
	// TxExecutedAt holds Mutation.QueueTick, when non-zero.
	TxExecutedAt *uint64
	// TxCalled is the JSON-encoded list of called states.
	TxCalled datatypes.JSON
	// TxArguments is nil when the mutation has no (mapped) arguments.
	TxArguments *datatypes.JSON
}
// State is a tracked state of a machine (gorm model), referenced by Tick.
type State struct {
	ID uint `gorm:"primaryKey"`
	// MachineID links the state to its Machine.
	// NOTE(review): declared as string while Machine.ID is uint32 - confirm
	// the has-many association stores the expected value.
	MachineID string `gorm:"index:machine_state"`
	// Index is the state's index in the machine (mach.Index1 in MachineInit).
	Index int `gorm:"index:machine_state"`
	Name  string
}
// Tick is the clock value of one tracked state at one recorded Time (gorm
// table "ticks"), keyed by (TimeID, MachineID, StateID).
type Tick struct {
	TimeID    uint64 `gorm:"primaryKey;autoIncrement:false;index:activated"`
	MachineID uint32 `gorm:"primaryKey;autoIncrement:false;index:activated"`
	StateID   uint   `gorm:"primaryKey;autoIncrement:false;index:activated"`
	// State is the referenced State row.
	State State
	// Tick is the state's tick at this Time.
	Tick uint64
	// Activated is set when the state is active and its tick differs from the
	// previous record (or there's no previous record).
	Activated bool `gorm:"index:activated"`
	// Deactivated is set when the state is inactive and its tick differs from
	// the previous record.
	Deactivated bool
	// Active is the result of am.IsActiveTick for Tick.
	Active bool
}
// tracer is the am tracer bound to the machine, which records transitions
// into mem. Hooks not implemented here are promoted from the embedded
// am.TracerNoOp.
type tracer struct {
	*am.TracerNoOp
	mem *Memory
}
// MachineInit loads the Machine record of mach (creating a new one when it
// doesn't exist yet), refreshes its tracking data and upserts it together
// with any tracked states missing from the DB. It also fills m.cacheDbIdxs
// with the DB IDs of the stored states. Errors are reported via onErr, in
// which case m.machRec isn't set. The returned context is always nil.
func (t *tracer) MachineInit(mach am.Api) context.Context {
	m := t.mem
	now := time.Now().UTC()
	m.mx.Lock()
	defer m.mx.Unlock()

	// a missing record is expected on the first run, other errors aren't
	rec, err := GetMachine(m.Db, mach.Id(), true)
	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
		m.onErr(fmt.Errorf("failed to load machine: %w", err))
		return nil
	}
	if rec == nil {
		rec = &Machine{
			MachId:        mach.Id(),
			FirstTracking: now,
			NextId:        1,
		}
	}

	// am.Time holds uint64 ticks, so marshaling it can't fail
	mTime := mach.Time(nil)
	mTimeBt, _ := json.Marshal(mTime)
	rec.MachId = m.Mach.Id()
	// FirstTracking is only set for new records, so the history span of an
	// existing machine is preserved
	rec.LastTracking = now
	rec.LastSync = now
	rec.MTime = mTimeBt
	rec.MTimeSum = mTime.Sum(nil)
	rec.MachTick = mach.MachineTick()
	rec.cacheMTime = mTime
	m.nextId.Store(rec.NextId)

	if m.Cfg.StoreSchema {
		if rec.Schema, err = json.Marshal(mach.Schema()); err != nil {
			m.onErr(err)
			return nil
		}
		if rec.StateNames, err = json.Marshal(mach.StateNames()); err != nil {
			m.onErr(err)
			return nil
		}
	}

	// cache the DB IDs of already stored states and append the missing ones
	existing := make(map[string]bool, len(rec.States))
	for _, state := range rec.States {
		m.cacheDbIdxs[state.Name] = state.ID
		existing[state.Name] = true
	}
	added := false
	for _, state := range m.Cfg.TrackedStates {
		if existing[state] {
			continue
		}
		added = true
		rec.States = append(rec.States, State{
			MachineID: m.Mach.Id(),
			Index:     m.Mach.Index1(state),
			Name:      state,
		})
	}

	// first_tracking is intentionally not updated on conflict
	err = m.Db.Clauses(clause.OnConflict{
		Columns: []clause.Column{{Name: "id"}},
		DoUpdates: clause.AssignmentColumns([]string{
			"last_sync", "last_tracking", "mach_tick", "m_time", "m_time_sum",
			"schema", "state_names"}),
	}).Create(rec).Error
	if err != nil {
		m.onErr(err)
		return nil
	}
	m.machRec = rec

	if added {
		// re-read the record to get the IDs of the freshly inserted states
		fresh, err := GetMachine(m.Db, mach.Id(), true)
		if err != nil {
			m.onErr(fmt.Errorf("failed to reload machine: %w", err))
			return nil
		}
		for _, state := range fresh.States {
			m.cacheDbIdxs[state.Name] = state.ID
		}
	}

	return nil
}
// SchemaChange re-encodes the machine's schema and state names into the
// machine record and saves it. It does nothing unless Cfg.StoreSchema is set;
// errors are passed to onErr.
func (t *tracer) SchemaChange(machine am.Api, old am.Schema) {
	mem := t.mem
	mem.mx.Lock()
	defer mem.mx.Unlock()

	if !mem.Cfg.StoreSchema {
		return
	}

	var err error
	rec := mem.machRec
	if rec.Schema, err = json.Marshal(mem.Mach.Schema()); err != nil {
		mem.onErr(err)
		return
	}
	if rec.StateNames, err = json.Marshal(mem.Mach.StateNames()); err != nil {
		mem.onErr(err)
		return
	}

	if err = mem.Db.Save(rec).Error; err != nil {
		mem.onErr(fmt.Errorf("failed to save: %w", err))
	}
}
// TransitionEnd records a finished transition as a Time row plus one Tick row
// per tracked state. Records are queued in memory and flushed by writeDb
// once Cfg.QueueBatch of them are pending.
//
// Rejected transitions are skipped unless Cfg.TrackRejected is set, and
// check-only mutations are always skipped. Cfg.Called and Cfg.Changed filter
// transitions by called and changed states: include-lists by default, or
// exclude-lists when CalledExclude / ChangedExclude are set.
func (t *tracer) TransitionEnd(tx *am.Transition) {
	m := t.mem
	if m.Ctx.Err() != nil {
		if err := m.Dispose(); err != nil {
			m.onErr(err)
		}
		return
	}
	if m.disposed.Load() {
		return
	}
	if (!tx.IsAccepted.Load() && !m.Cfg.TrackRejected) || tx.Mutation.IsCheck {
		return
	}

	m.mx.Lock()
	defer m.mx.Unlock()
	// MachineInit failed, there's no machine record to attach to
	if m.machRec == nil {
		return
	}

	mach := m.Mach
	cfg := m.Cfg
	mut := tx.Mutation
	called := tx.CalledStates()
	changed := tx.TimeAfter.DiffSince(tx.TimeBefore).
		ToIndex(mach.StateNames()).NonZeroStates()

	// without include-lists, everything matches by default
	match := (cfg.ChangedExclude || len(cfg.Changed) == 0) &&
		(cfg.CalledExclude || len(cfg.Called) == 0)
	for _, name := range cfg.Called {
		if slices.Contains(called, name) {
			// a listed state was called: include-lists match, exclude-lists don't
			match = !cfg.CalledExclude
			break
		}
	}
	for _, name := range cfg.Changed {
		if slices.Contains(changed, name) {
			// a listed state changed: include-lists match, exclude-lists don't
			match = !cfg.ChangedExclude
			break
		}
	}
	if !match {
		return
	}

	mTime := tx.TimeAfter
	mTimeTracked := mTime.Filter(m.cacheTrackedIdxs)
	mTimeTrackedBefore := tx.TimeBefore.Filter(m.cacheTrackedIdxs)
	sum := mTime.Sum(nil)
	sumTracked := mTimeTracked.Sum(nil)

	// NOTE(review): marshaling errors are ignored; a failed encoding results in
	// an empty (NULL) column
	calledBt, _ := json.Marshal(mut.Called)
	args := mut.MapArgs(mach.SemLogger().ArgsMapper())
	argsBt, _ := json.Marshal(args)
	mTimeBt, _ := json.Marshal(mTime)
	mTimeTrackedBt, _ := json.Marshal(mTimeTracked)
	mTimeTrackedDiffBt, _ := json.Marshal(
		mTimeTracked.DiffSince(mTimeTrackedBefore))

	var recordDiff uint64
	if m.lastRec != nil {
		recordDiff = sum - m.lastRec.MTimeSum
	}

	machTick := mach.MachineTick()
	now := time.Now().UTC()
	id := m.nextId.Load()
	timeRec := Time{
		ID:                  id,
		MachineID:           m.machRec.ID,
		MutType:             mut.Type,
		MTimeSum:            sum,
		MTimeTrackedSum:     sumTracked,
		MTimeDiffSum:        sum - tx.TimeBefore.Sum(nil),
		MTimeTrackedDiffSum: sumTracked - tx.TimeBefore.Sum(m.cacheTrackedIdxs),
		MTimeRecordDiffSum:  recordDiff,
		HTime:               now,
		MTimeTracked:        mTimeTrackedBt,
		MTimeTrackedDiff:    mTimeTrackedDiffBt,
		MachTick:            machTick,
		cacheMTimeTracked:   mTimeTracked,
	}

	if cfg.StoreTransitions {
		timeRec.TxId = tx.Id
		timeRec.TxCalled = calledBt
		timeRec.TxIsAuto = mut.IsAuto
		timeRec.TxIsAccepted = tx.IsAccepted.Load()
		timeRec.TxIsCheck = mut.IsCheck
		timeRec.TxIsBroken = tx.IsBroken.Load()
		timeRec.TxQueueLen = tx.QueueLen
		if len(args) > 0 {
			j := datatypes.JSON(argsBt)
			timeRec.TxArguments = &j
		}
		if mut.Source != nil {
			timeRec.TxSourceMach = &mut.Source.MachId
			timeRec.TxSourceTx = &mut.Source.TxId
		}
		if mut.QueueTick > 0 {
			// NOTE(review): the queue tick is stored as "executed at", while
			// TxQueuedAt stays unset - confirm which one QueueTick represents
			queueTick := mut.QueueTick
			timeRec.TxExecutedAt = &queueTick
		}
	}

	// one Tick per tracked state, with (de)activation detected against the
	// previous record
	ticks := make([]Tick, len(mTimeTracked))
	for i, state := range cfg.TrackedStates {
		tick := mTimeTracked[i]
		isActive := am.IsActiveTick(tick)
		prevDiffers := m.lastRec != nil && m.lastRec.cacheMTimeTracked[i] != tick
		ticks[i] = Tick{
			StateID:     m.cacheDbIdxs[state],
			Tick:        tick,
			TimeID:      id,
			MachineID:   m.machRec.ID,
			Active:      isActive,
			Activated:   isActive && (m.lastRec == nil || prevDiffers),
			Deactivated: !isActive && prevDiffers,
		}
	}

	m.queue.ticks = append(m.queue.ticks, ticks...)
	m.machRec.MTime = mTimeBt
	m.machRec.MTimeSum = sum
	m.machRec.LastSync = now
	m.machRec.MachTick = machTick
	m.machRec.NextId = id + 1
	m.machRec.cacheMTime = mTime
	m.nextId.Store(id + 1)
	m.queue.times = append(m.queue.times, timeRec)
	m.SavePending.Add(1)
	m.lastRec = &timeRec

	if m.SavePending.Load() >= cfg.QueueBatch {
		// released by writeDb once the batch is persisted
		m.syncMx.RLock()
		m.writeDb(true)
		m.checkGc()
	}
}
// queue holds the Time and Tick records waiting to be written by writeDb.
type queue struct {
	times []Time
	ticks []Tick
}
// Memory is a machine history stored in a SQL database via gorm.
// Transitions are recorded by a bound tracer, queued in memory and written
// to the DB in batches.
type Memory struct {
	*amhist.BaseMemory
	Db  *gorm.DB
	Cfg *Config
	// SavePending is the number of queued, not yet written Time records.
	SavePending atomic.Int32
	// SaveInProgress isn't referenced in this file.
	SaveInProgress atomic.Bool
	// Saved and SavedGc are compared by checkGc to decide when to delete old
	// records; SavedGc is the value of Saved at the last GC.
	Saved   atomic.Uint64
	SavedGc atomic.Uint64
	// savePool runs the async DB writes, limited to Cfg.SavePool at a time.
	savePool *errgroup.Group
	// syncMx is read-locked by TransitionEnd until a batch write finishes,
	// and write-locked by Sync.
	syncMx sync.RWMutex
	// gcMx prevents concurrent GC runs (TryLock in checkGc).
	gcMx     sync.RWMutex
	disposed atomic.Bool
	// machRec is the DB record of the tracked machine, set by MachineInit.
	machRec *Machine
	// queue holds the records waiting for writeDb.
	queue *queue
	// onErr receives errors from the tracer hooks and async writes.
	onErr func(err error)
	// mx is held by the tracer hooks, Sync, Dispose and MachineRecord while
	// they access machRec, queue and lastRec.
	mx sync.Mutex
	// cacheTrackedIdxs are the machine indexes of Cfg.TrackedStates.
	cacheTrackedIdxs []int
	// cacheDbIdxs maps state names to their State.ID in the DB.
	cacheDbIdxs map[string]uint
	tr          *tracer
	// lastRec is the most recently recorded Time, used for diffs and
	// (de)activation detection.
	lastRec *Time
	// nextId is the ID for the next Time record.
	nextId atomic.Uint64
}
// NewMemory creates a gorm-backed history of mach, stored in db. It migrates
// the schema and normalizes cfg: it applies defaults for MaxRecords,
// QueueBatch and SavePool and resolves the tracked states. It then
// initializes the machine record and binds itself as a tracer of mach. The
// memory is disposed together with mach.
//
// onErr receives errors from tracer hooks and async writes; when nil, they
// are logged with the standard logger.
func NewMemory(
	ctx context.Context, db *gorm.DB, mach am.Api, cfg Config,
	onErr func(err error),
) (*Memory, error) {
	err := db.AutoMigrate(
		&Machine{}, &Time{}, &State{}, &Tick{},
	)
	if err != nil {
		return nil, err
	}
	if onErr == nil {
		// every error path calls onErr, so it can't stay nil
		onErr = func(err error) {
			log.Printf("amhist/gorm: %s", err)
		}
	}

	c := cfg
	if c.MaxRecords <= 0 {
		c.MaxRecords = 1000
	}
	if c.QueueBatch <= 0 {
		c.QueueBatch = 10
	}
	if c.SavePool <= 0 {
		c.SavePool = 10
	}
	// states from include-lists are always tracked
	if !c.CalledExclude {
		c.TrackedStates = slices.Concat(c.TrackedStates, c.Called)
	}
	if !c.ChangedExclude {
		c.TrackedStates = slices.Concat(c.TrackedStates, c.Changed)
	}
	// nothing specified - track all the states
	if c.TrackedStates == nil {
		c.TrackedStates = mach.StateNames()
	}
	c.TrackedStates = mach.ParseStates(c.TrackedStates)
	if len(c.TrackedStates) == 0 {
		return nil, fmt.Errorf("%w: no states to track", am.ErrStateMissing)
	}

	mem := &Memory{
		Cfg:              &c,
		Db:               db,
		savePool:         &errgroup.Group{},
		onErr:            onErr,
		queue:            &queue{},
		cacheTrackedIdxs: mach.Index(c.TrackedStates),
		cacheDbIdxs:      make(map[string]uint),
	}
	// use the normalized config, the same one returned by Config()
	mem.BaseMemory = amhist.NewBaseMemory(ctx, mach, c.BaseConfig, mem)
	mem.savePool.SetLimit(c.SavePool)
	tr := &tracer{
		mem: mem,
	}
	mem.tr = tr

	tr.MachineInit(mach)
	if mem.machRec == nil {
		// the cause has already been passed to onErr
		return nil, errors.New("failed to init the machine record")
	}

	mach.OnDispose(func(id string, ctx context.Context) {
		if err := mem.Dispose(); err != nil {
			mem.onErr(err)
		}
	})

	return mem, mach.BindTracer(tr)
}
// FindLatest returns up to limit (0 means no limit) of the newest records
// matching query, newest first.
//
// State conditions are resolved by joining the ticks table once per
// referenced state, aliased by the state name. These are Active, Activated,
// Inactive, Deactivated and the per-state tick ranges. Range conditions on
// the times table apply only when both the start and end bounds are set.
// retTx requests transition data, when Cfg.StoreTransitions is set.
func (m *Memory) FindLatest(
	ctx context.Context, retTx bool, limit int, query amhist.Query,
) ([]*amhist.MemoryRecord, error) {
	if err := m.ValidateQuery(query); err != nil {
		return nil, err
	}
	s := query.Start
	e := query.End

	return m.Match(ctx, limit, func(_ *am.TimeIndex, db *gorm.DB) *gorm.DB {
		joins := []string{}
		for _, state := range query.Active {
			db = db.Where(state+".active = ?", true)
			joins = append(joins, state)
		}
		for _, state := range query.Activated {
			db = db.Where(state+".activated = ?", true)
			joins = append(joins, state)
		}
		for _, state := range query.Inactive {
			db = db.Where(state+".active = ?", false)
			joins = append(joins, state)
		}
		for _, state := range query.Deactivated {
			db = db.Where(state+".deactivated = ?", true)
			joins = append(joins, state)
		}
		// a state's time lives in the "tick" column of the ticks table
		for i, state := range s.MTimeStates {
			db = db.Where(state+".tick >= ?", s.MTime[i])
			joins = append(joins, state)
		}
		for i, state := range e.MTimeStates {
			db = db.Where(state+".tick <= ?", e.MTime[i])
			joins = append(joins, state)
		}
		for _, state := range utils.SlicesUniq(joins) {
			db = m.JoinState(db, state)
		}
		if retTx && m.Cfg.StoreTransitions {
			db = m.JoinTransition(db)
		}

		// ranges on the times table, applied only when both bounds are set
		if !s.HTime.IsZero() && !e.HTime.IsZero() {
			db = db.Where(
				"times.h_time >= ? AND times.h_time <= ?",
				s.HTime, e.HTime,
			)
		}
		if s.MTimeSum != 0 && e.MTimeSum != 0 {
			db = db.Where(
				"times.m_time_sum >= ? AND times.m_time_sum <= ?",
				s.MTimeSum, e.MTimeSum,
			)
		}
		if s.MTimeTrackedSum != 0 && e.MTimeTrackedSum != 0 {
			db = db.Where(
				"times.m_time_tracked_sum >= ? AND times.m_time_tracked_sum <= ?",
				s.MTimeTrackedSum, e.MTimeTrackedSum,
			)
		}
		// diffs are stored as sums (Time.MTime*DiffSum)
		if s.MTimeDiff != 0 && e.MTimeDiff != 0 {
			db = db.Where(
				"times.m_time_diff_sum >= ? AND times.m_time_diff_sum <= ?",
				s.MTimeDiff, e.MTimeDiff,
			)
		}
		if s.MTimeTrackedDiff != 0 && e.MTimeTrackedDiff != 0 {
			db = db.Where(
				"times.m_time_tracked_diff_sum >= ? AND "+
					"times.m_time_tracked_diff_sum <= ?",
				s.MTimeTrackedDiff, e.MTimeTrackedDiff,
			)
		}
		if s.MTimeRecordDiff != 0 && e.MTimeRecordDiff != 0 {
			db = db.Where(
				"times.m_time_record_diff_sum >= ? AND "+
					"times.m_time_record_diff_sum <= ?",
				s.MTimeRecordDiff, e.MTimeRecordDiff,
			)
		}
		if s.MachTick != 0 && e.MachTick != 0 {
			db = db.Where(
				"times.mach_tick >= ? AND times.mach_tick <= ?",
				s.MachTick, e.MachTick,
			)
		}

		// joined ticks share times.id via time_id, so any alias orders the same
		if len(joins) > 0 {
			db = db.Order(joins[len(joins)-1] + ".time_id DESC")
		} else {
			db = db.Order("times.id DESC")
		}

		return db
	})
}
// MachineRecord returns a snapshot of the tracked machine's record. Schema
// and StateNames are decoded when Config().StoreSchema is set. It returns nil
// when the record was never initialized, or when decoding fails; the
// decoding error is passed to onErr.
func (m *Memory) MachineRecord() *amhist.MachineRecord {
	m.mx.Lock()
	defer m.mx.Unlock()

	r := m.machRec
	// MachineInit failed, there's nothing to report
	if r == nil {
		return nil
	}

	ret := &amhist.MachineRecord{
		MachId:        r.MachId,
		FirstTracking: r.FirstTracking,
		LastTracking:  r.LastTracking,
		LastSync:      r.LastSync,
		MachTick:      r.MachTick,
		MTime:         r.cacheMTime,
		MTimeSum:      r.MTimeSum,
		NextId:        r.NextId,
	}

	if m.Config().StoreSchema {
		if err := json.Unmarshal(r.Schema, &ret.Schema); err != nil {
			m.onErr(err)
			return nil
		}
		if err := json.Unmarshal(r.StateNames, &ret.StateNames); err != nil {
			m.onErr(err)
			return nil
		}
	}

	return ret
}
// Dispose detaches the tracer from the machine and closes the underlying DB
// connection pool. Only the first call has any effect; later calls return
// nil. Errors from detaching the tracer and from closing the DB are joined.
func (m *Memory) Dispose() error {
	if !m.disposed.CompareAndSwap(false, true) {
		return nil
	}
	m.mx.Lock()
	defer m.mx.Unlock()

	trErr := m.Mach.DetachTracer(m.tr)
	stdDb, err := m.Db.DB()
	if err != nil {
		return errors.Join(err, trErr)
	}

	// keep the tracer error, which would otherwise be lost on a clean close
	return errors.Join(stdDb.Close(), trErr)
}
// JoinState joins the ticks table to query under the alias name, matching
// the current times row, and limits the joined rows to the DB ID of the
// state called name.
func (m *Memory) JoinState(query *gorm.DB, name string) *gorm.DB {
	stateDbId := m.cacheDbIdxs[name]
	join := "\nJOIN ticks " + name +
		"\nON " + name + ".time_id = times.id" +
		"\nAND " + name + ".machine_id = times.machine_id\n"

	joined := query.Joins(utils.Sp(join))
	return joined.Where(name+".state_id = ?", stateDbId)
}
// JoinTransition prepares query to return transition data. Transitions are
// stored as the tx_* columns of the times table itself (see Time), which are
// already selected. No join or preload is needed, so query is returned
// unchanged; Time has no "transitions" relation to preload.
func (m *Memory) JoinTransition(query *gorm.DB) *gorm.DB {
	return query
}
// Match queries Time records narrowed by matcherFn, capped by limit when
// limit > 0, and converts them into memory records. The Transition part is
// filled only when Cfg.StoreTransitions is set.
//
// It returns nil, nil when nothing matches. When either ctx or the memory's
// context is done, it returns the joined context errors.
func (m *Memory) Match(
	ctx context.Context, limit int, matcherFn MatcherFn,
) ([]*amhist.MemoryRecord, error) {
	q := m.Db.Model(&Time{})
	now := m.Mach.Time(nil).ToIndex(m.Mach.StateNames())
	q = matcherFn(now, q)
	if limit > 0 {
		q = q.Limit(limit)
	}

	rows := []Time{}
	err := q.WithContext(ctx).Find(&rows).Error
	if ctx.Err() != nil || m.Ctx.Err() != nil {
		return nil, errors.Join(ctx.Err(), m.Ctx.Err())
	}
	if err != nil {
		return nil, err
	}
	if len(rows) == 0 {
		return nil, nil
	}

	ret := make([]*amhist.MemoryRecord, len(rows))
	for i := range rows {
		r := &rows[i]

		tTracked := am.Time{}
		if err := json.Unmarshal(r.MTimeTracked, &tTracked); err != nil {
			return nil, err
		}
		tTrackedDiff := am.Time{}
		if err := json.Unmarshal(r.MTimeTrackedDiff, &tTrackedDiff); err != nil {
			return nil, err
		}

		ret[i] = &amhist.MemoryRecord{
			Time: &amhist.TimeRecord{
				MutType:             r.MutType,
				MTimeSum:            r.MTimeSum,
				MTimeTrackedSum:     r.MTimeTrackedSum,
				MTimeDiffSum:        r.MTimeDiffSum,
				MTimeTrackedDiffSum: r.MTimeTrackedDiffSum,
				MTimeRecordDiffSum:  r.MTimeRecordDiffSum,
				MTimeTracked:        tTracked,
				MTimeTrackedDiff:    tTrackedDiff,
				HTime:               r.HTime,
				MachTick:            r.MachTick,
			},
		}
		if !m.Cfg.StoreTransitions {
			continue
		}

		retTx := &amhist.TransitionRecord{
			TransitionId: r.TxId,
			IsAuto:       r.TxIsAuto,
			IsAccepted:   r.TxIsAccepted,
			IsCheck:      r.TxIsCheck,
			IsBroken:     r.TxIsBroken,
			QueueLen:     r.TxQueueLen,
			Called:       nil,
			Arguments:    nil,
		}
		ret[i].Transition = retTx
		if r.TxSourceTx != nil && r.TxSourceMach != nil {
			retTx.SourceTx = *r.TxSourceTx
			retTx.SourceMach = *r.TxSourceMach
		}
		// both are optional and stored independently
		if r.TxQueuedAt != nil {
			retTx.QueuedAt = *r.TxQueuedAt
		}
		if r.TxExecutedAt != nil {
			retTx.ExecutedAt = *r.TxExecutedAt
		}
		// NULL / empty JSON columns mean "not stored"
		if len(r.TxCalled) > 0 {
			if err := json.Unmarshal(r.TxCalled, &retTx.Called); err != nil {
				return nil, err
			}
		}
		if r.TxArguments != nil && len(*r.TxArguments) > 0 {
			if err := json.Unmarshal(*r.TxArguments, &retTx.Arguments); err != nil {
				return nil, err
			}
		}
	}

	return ret, nil
}
// checkGc deletes old records once more than 1.5 × Cfg.MaxRecords records
// were saved since the last GC. It keeps the newest Cfg.MaxRecords Time rows
// of this machine, plus their Ticks. Concurrent runs are skipped via gcMx.
// Errors are passed to onErr.
func (m *Memory) checkGc() {
	savedAtGc := m.SavedGc.Load()
	saved := m.Saved.Load()
	if float32(saved-savedAtGc) <= float32(m.Cfg.MaxRecords)*1.5 {
		return
	}
	if !m.gcMx.TryLock() {
		return
	}
	// release, so the next threshold crossing can GC again
	defer m.gcMx.Unlock()

	machId := m.machRec.ID
	// IDs of the newest MaxRecords times, which survive the GC
	keep := m.Db.Model(&Time{}).
		Select("id").
		Where("machine_id = ?", machId).
		Order("id DESC").
		Limit(m.Cfg.MaxRecords)

	// ticks reference times by (time_id, machine_id) and aren't removed with
	// them, so prune them first
	err := m.Db.WithContext(m.Ctx).
		Where("machine_id = ?", machId).
		Where("time_id NOT IN (?)", keep).
		Delete(&Tick{}).Error
	if err == nil {
		_, err = gorm.G[Time](m.Db).
			Where("machine_id = ?", machId).
			Where("id NOT IN (?)", keep).
			Delete(m.Ctx)
	}
	if err != nil {
		m.onErr(fmt.Errorf("failed to GC: %w", err))
	}

	m.SavedGc.Store(m.Saved.Load())
}
// Config returns the base part of the memory's (normalized) configuration.
func (m *Memory) Config() amhist.BaseConfig {
	cfg := m.Cfg
	return cfg.BaseConfig
}
// Machine returns the tracked machine, as held by the embedded BaseMemory.
func (m *Memory) Machine() am.Api {
	return m.BaseMemory.Mach
}
// Sync flushes the queued records via writeDb. It takes m.mx, then the write
// side of syncMx, which first waits for batch writes started by
// TransitionEnd; those hold a read lock until persisted.
//
// NOTE(review): the flush started here runs asynchronously in savePool, so
// Sync returns before it is persisted and always returns nil - confirm this
// is intended.
func (m *Memory) Sync() error {
	m.log("sync...")

	flush := func() {
		m.mx.Lock()
		defer m.mx.Unlock()
		m.syncMx.Lock()
		defer m.syncMx.Unlock()

		m.writeDb(false)
	}
	flush()

	m.log("sync OK")
	return nil
}
func (m *Memory ) writeDb (rLocked bool ) {
if m .SavePending .Load () <= 0 {
return
}
q := m .queue
machRec := *m .machRec
times := q .times
q .times = nil
ticks := q .ticks
q .ticks = nil
m .log ("writeDb for %d record" , len (times ))
l := len (times )
m .SavePending .Add (-int32 (l ))
go m .savePool .Go (func () error {
if m .disposed .Load () {
return nil
}
if rLocked {
defer m .syncMx .RUnlock ()
}
if err := m .Db .Save (machRec ).Error ; err != nil {
m .onErr (fmt .Errorf ("failed to save: %w" , err ))
return err
}
dbTimes := gorm .G [Time ](m .Db )
err := dbTimes .CreateInBatches (m .Mach .Ctx (), × , 100 )
if err != nil {
m .onErr (err )
return err
}
dbTicks := gorm .G [Tick ](m .Db )
err = dbTicks .CreateInBatches (m .Mach .Ctx (), &ticks , 100 )
if err != nil {
m .onErr (err )
return err
}
return nil
})
}
// log prints a printf-style message via the standard logger, but only when
// Cfg.Log is enabled.
func (m *Memory) log(msg string, args ...any) {
	if m.Cfg.Log {
		log.Printf(msg, args...)
	}
}
// NewDb opens (or creates) the SQLite database "<name>.sqlite" (name defaults
// to "amhist") and switches it to WAL journaling. It returns both the gorm
// handle and the underlying *sql.DB. With debug, gorm logs every query to
// stdout; otherwise logging is silent.
//
// When the setup fails after opening, the connection pool is closed before
// returning the error.
func NewDb(name string, debug bool) (*gorm.DB, *sql.DB, error) {
	if name == "" {
		name = "amhist"
	}

	cfg := logger.Config{
		SlowThreshold:             time.Second,
		LogLevel:                  logger.Silent,
		Colorful:                  true,
		IgnoreRecordNotFoundError: true,
	}
	if debug {
		cfg.LogLevel = logger.Info
		cfg.IgnoreRecordNotFoundError = false
	}

	dbG, err := gorm.Open(gormlite.Open(name+".sqlite"), &gorm.Config{
		Logger: logger.New(log.New(os.Stdout, "\r\n", log.LstdFlags), cfg),
	})
	if err != nil {
		return nil, nil, err
	}
	dbSql, err := dbG.DB()
	if err != nil {
		return nil, nil, err
	}

	// SQLite's WAL mode requires exclusive locking without shared memory
	if !vfs.SupportsSharedMemory {
		if err := dbG.Exec(`PRAGMA locking_mode=exclusive`).Error; err != nil {
			return nil, nil, errors.Join(err, dbSql.Close())
		}
	}
	if err := dbG.Exec(`PRAGMA journal_mode=wal;`).Error; err != nil {
		return nil, nil, errors.Join(err, dbSql.Close())
	}

	return dbG, dbSql, nil
}
// GetMachine fetches the Machine row whose mach_id equals id, optionally
// preloading its States. On failure (including gorm.ErrRecordNotFound), it
// returns a nil record and gorm's error.
func GetMachine(db *gorm.DB, id string, inclStates bool) (*Machine, error) {
	query := db.Where("mach_id = ?", id)
	if inclStates {
		query = query.Preload("States")
	}

	rec := &Machine{}
	if err := query.First(rec).Error; err != nil {
		return nil, err
	}
	return rec, nil
}
// ListMachines returns the records of all the machines stored in db. The
// JSON columns MTime, Schema and StateNames are decoded when non-empty. It
// returns an error when the query or the decoding fails.
func ListMachines(db *gorm.DB) ([]*amhist.MachineRecord, error) {
	var rows []Machine
	if err := db.Find(&rows).Error; err != nil {
		return nil, err
	}

	ret := make([]*amhist.MachineRecord, 0, len(rows))
	for i := range rows {
		r := &rows[i]
		rec := &amhist.MachineRecord{
			MachId:        r.MachId,
			FirstTracking: r.FirstTracking,
			LastTracking:  r.LastTracking,
			LastSync:      r.LastSync,
			MachTick:      r.MachTick,
			MTimeSum:      r.MTimeSum,
			NextId:        r.NextId,
		}
		if len(r.MTime) > 0 {
			if err := json.Unmarshal(r.MTime, &rec.MTime); err != nil {
				return nil, fmt.Errorf("machine %s: m_time: %w", r.MachId, err)
			}
		}
		// schema and state names are only stored with Config.StoreSchema
		if len(r.Schema) > 0 {
			if err := json.Unmarshal(r.Schema, &rec.Schema); err != nil {
				return nil, fmt.Errorf("machine %s: schema: %w", r.MachId, err)
			}
		}
		if len(r.StateNames) > 0 {
			if err := json.Unmarshal(r.StateNames, &rec.StateNames); err != nil {
				return nil, fmt.Errorf("machine %s: state names: %w", r.MachId, err)
			}
		}
		ret = append(ret, rec)
	}

	return ret, nil
}
The pages are generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news of Golds.