package badger
import (
"bytes"
"context"
"encoding/hex"
"errors"
"fmt"
"math"
"sort"
"strconv"
"sync"
"sync/atomic"
"github.com/dgraph-io/badger/v4/y"
"github.com/dgraph-io/ristretto/v2/z"
)
// oracle hands out read and commit timestamps for transactions and keeps the
// bookkeeping needed for conflict detection between committed transactions.
type oracle struct {
// isManaged is copied from Options.managedTxns; when set, commit timestamps
// come from the transaction itself rather than from nextTxnTs.
isManaged bool
// detectConflicts is copied from Options.DetectConflicts; committedTxns is
// only populated and pruned when it is true.
detectConflicts bool
// The embedded mutex is taken by the methods that read or write nextTxnTs,
// discardTs, committedTxns and lastCleanupTs.
sync .Mutex
// writeChLock is held by Txn.commitAndSend from before the commit timestamp
// is obtained until the entries have been handed to sendToWriteCh.
writeChLock sync .Mutex
// nextTxnTs is the timestamp the next non-managed commit will receive.
nextTxnTs uint64
// txnMark has Begin called for each non-managed commit timestamp and Done
// called from doneCommit; readTs waits on it.
txnMark *y .WaterMark
// discardTs is set through setDiscardTs and used only in managed mode.
discardTs uint64
// readMark has Begin called in readTs and Done called in doneRead.
readMark *y .WaterMark
// committedTxns records the commit timestamp and written-key fingerprints of
// committed transactions that hasConflict still has to consider.
committedTxns []committedTxn
// lastCleanupTs is the cut-off used by the most recent pruning of
// committedTxns.
lastCleanupTs uint64
// closer is shared by both watermarks (passed to their Init) and is
// signalled by Stop.
closer *z .Closer
}
// committedTxn is the per-transaction record kept in oracle.committedTxns for
// conflict detection.
type committedTxn struct {
// ts is the commit timestamp assigned in newCommitTs.
ts uint64
// conflictKeys is the transaction's conflictKeys map: fingerprints
// (z.MemHash) of the keys it wrote.
conflictKeys map [uint64 ]struct {}
}
// newOracle builds an oracle configured from opt. Both watermarks are
// initialised with the same closer, which Stop later signals.
func newOracle(opt Options) *oracle {
	pendingReads := &y.WaterMark{Name: "badger.PendingReads"}
	txnTimestamps := &y.WaterMark{Name: "badger.TxnTimestamp"}

	// NewCloser(2): presumably one slot per watermark initialised below.
	closer := z.NewCloser(2)
	pendingReads.Init(closer)
	txnTimestamps.Init(closer)

	return &oracle{
		isManaged:       opt.managedTxns,
		detectConflicts: opt.DetectConflicts,
		readMark:        pendingReads,
		txnMark:         txnTimestamps,
		closer:          closer,
	}
}
// Stop signals the oracle's closer and waits for it to finish.
func (o *oracle) Stop() {
	closer := o.closer
	closer.SignalAndWait()
}
// readTs returns the read timestamp for a new non-managed transaction, which
// is nextTxnTs-1. The timestamp is registered with readMark while the oracle
// lock is held, and the call then waits on txnMark for that timestamp before
// returning. It panics when the oracle is in managed mode.
func (o *oracle) readTs() uint64 {
	if o.isManaged {
		panic("ReadTs should not be retrieved for managed DB")
	}

	o.Lock()
	ts := o.nextTxnTs - 1
	o.readMark.Begin(ts)
	o.Unlock()

	// The wait happens outside the lock so commits are not blocked meanwhile.
	waitErr := o.txnMark.WaitForMark(context.Background(), ts)
	y.Check(waitErr)
	return ts
}
// nextTs returns the current value of nextTxnTs, read under the oracle lock.
func (o *oracle) nextTs() uint64 {
	o.Lock()
	ts := o.nextTxnTs
	o.Unlock()
	return ts
}
// incrementNextTs advances nextTxnTs by one under the oracle lock.
func (o *oracle) incrementNextTs() {
	o.Lock()
	o.nextTxnTs += 1
	o.Unlock()
}
// setDiscardTs stores ts as the discard timestamp and then prunes
// committedTxns, all under the oracle lock.
func (o *oracle) setDiscardTs(ts uint64) {
	o.Lock()
	// Deferred so the lock is released even if the cleanup assertion fires.
	defer o.Unlock()

	o.discardTs = ts
	o.cleanupCommittedTransactions()
}
// discardAtOrBelow returns the timestamp at or below which versions may be
// discarded: discardTs (read under the lock) in managed mode, otherwise
// readMark.DoneUntil().
func (o *oracle) discardAtOrBelow() uint64 {
	if !o.isManaged {
		return o.readMark.DoneUntil()
	}

	o.Lock()
	ts := o.discardTs
	o.Unlock()
	return ts
}
// hasConflict reports whether any transaction recorded in committedTxns with a
// commit timestamp greater than txn.readTs wrote a key whose fingerprint is in
// txn.reads. It does no locking itself; newCommitTs calls it with the oracle
// lock held.
func (o *oracle) hasConflict(txn *Txn) bool {
	if len(txn.reads) == 0 {
		return false
	}
	for i := range o.committedTxns {
		committed := &o.committedTxns[i]
		// Commits at or before our read timestamp were already visible to txn.
		if committed.ts <= txn.readTs {
			continue
		}
		for _, fp := range txn.reads {
			if _, written := committed.conflictKeys[fp]; written {
				return true
			}
		}
	}
	return false
}
// newCommitTs picks the commit timestamp for txn. The second result is true,
// with a zero timestamp, when txn conflicts with an already committed
// transaction.
//
// In non-managed mode the transaction's read is marked done, committedTxns is
// pruned, and the timestamp is taken from nextTxnTs (which is then advanced
// and registered with txnMark). In managed mode txn.commitTs is used as is.
// When conflict detection is on, the transaction is recorded in committedTxns.
func (o *oracle) newCommitTs(txn *Txn) (uint64, bool) {
	o.Lock()
	defer o.Unlock()

	if o.hasConflict(txn) {
		return 0, true
	}

	// Managed mode: the caller supplied the commit timestamp.
	ts := txn.commitTs
	if !o.isManaged {
		o.doneRead(txn)
		o.cleanupCommittedTransactions()

		ts = o.nextTxnTs
		o.nextTxnTs++
		o.txnMark.Begin(ts)
	}

	y.AssertTrue(ts >= o.lastCleanupTs)

	// Only record the commit when it will actually be consulted; otherwise the
	// slice would never be pruned (cleanup is a no-op without detectConflicts).
	if o.detectConflicts {
		record := committedTxn{ts: ts, conflictKeys: txn.conflictKeys}
		o.committedTxns = append(o.committedTxns, record)
	}
	return ts, false
}
// doneRead marks txn's read timestamp as done on readMark. The txn.doneRead
// flag makes repeated calls for the same transaction a no-op.
func (o *oracle) doneRead(txn *Txn) {
	if txn.doneRead {
		return
	}
	txn.doneRead = true
	o.readMark.Done(txn.readTs)
}
// cleanupCommittedTransactions drops from committedTxns every record whose
// commit timestamp is at or below the current cut-off: discardTs in managed
// mode, readMark.DoneUntil() otherwise. It is a no-op when conflict detection
// is off or the cut-off has not moved since the last cleanup. It does no
// locking itself; the callers in this file hold the oracle lock.
func (o *oracle) cleanupCommittedTransactions() {
	if !o.detectConflicts {
		return
	}

	cutoff := o.discardTs
	if !o.isManaged {
		cutoff = o.readMark.DoneUntil()
	}

	y.AssertTrue(cutoff >= o.lastCleanupTs)
	if cutoff == o.lastCleanupTs {
		return
	}
	o.lastCleanupTs = cutoff

	// Compact in place, reusing the existing backing array.
	kept := 0
	for _, committed := range o.committedTxns {
		if committed.ts > cutoff {
			o.committedTxns[kept] = committed
			kept++
		}
	}
	o.committedTxns = o.committedTxns[:kept]
}
// doneCommit marks commit timestamp cts as done on txnMark. It does nothing in
// managed mode, where newCommitTs never registered the timestamp.
func (o *oracle) doneCommit(cts uint64) {
	if !o.isManaged {
		o.txnMark.Done(cts)
	}
}
// Txn is a single transaction against a DB.
type Txn struct {
// readTs is the timestamp used for reads; in non-managed mode it is obtained
// from the oracle in newTransaction.
readTs uint64
// commitTs is the commit timestamp used by the oracle in managed mode.
commitTs uint64
// size and count track the running estimated size and number of entries,
// checked against the batch limits in checkSize.
size int64
count int64
// db is the database this transaction belongs to.
db *DB
// reads holds fingerprints of keys read by an update transaction
// (appended in addReadKey under readsLock).
reads []uint64
// conflictKeys holds fingerprints of keys written; allocated only for update
// transactions when DetectConflicts is enabled.
conflictKeys map [uint64 ]struct {}
// readsLock guards appends to reads.
readsLock sync .Mutex
// pendingWrites maps key -> latest entry written in this transaction;
// allocated only for update transactions.
pendingWrites map [string ]*Entry
// duplicateWrites keeps earlier entries for a key that were replaced in
// pendingWrites by an entry with a different version.
duplicateWrites []*Entry
// numIterators is checked in Discard, which panics if it is above zero.
numIterators atomic .Int32
// discarded is set by Discard; further Discard calls return immediately.
discarded bool
// doneRead records that the oracle has already marked readTs as done.
doneRead bool
// update is true for read-write transactions.
update bool
}
// pendingWritesIterator iterates over a sorted snapshot of a transaction's
// pending writes.
type pendingWritesIterator struct {
// entries is sorted by key, ascending or descending according to reversed.
entries []*Entry
// nextIdx is the index of the current entry; the iterator is valid while it
// is below len(entries).
nextIdx int
// readTs is used as the timestamp suffix of returned keys and as the
// version of returned values.
readTs uint64
// reversed selects descending key order.
reversed bool
}
// Next advances the iterator by one entry.
func (pi *pendingWritesIterator) Next() {
	pi.nextIdx += 1
}
// Rewind moves the iterator back to the first entry.
func (pi *pendingWritesIterator) Rewind() {
	const first = 0
	pi.nextIdx = first
}
// Seek positions the iterator on the first entry that is at or past key in
// iteration order: the first entry >= key going forward, or the first entry
// <= key when reversed. key is passed through y.ParseKey before comparison.
// If no entry qualifies the iterator becomes invalid.
func (pi *pendingWritesIterator) Seek(key []byte) {
	target := y.ParseKey(key)
	atOrPast := func(idx int) bool {
		order := bytes.Compare(pi.entries[idx].Key, target)
		if pi.reversed {
			return order <= 0
		}
		return order >= 0
	}
	pi.nextIdx = sort.Search(len(pi.entries), atOrPast)
}
// Key returns the current entry's key with the iterator's readTs appended via
// y.KeyWithTs. The iterator must be valid.
func (pi *pendingWritesIterator) Key() []byte {
	y.AssertTrue(pi.Valid())
	return y.KeyWithTs(pi.entries[pi.nextIdx].Key, pi.readTs)
}
// Value returns the current entry as a y.ValueStruct whose Version is the
// iterator's readTs. The iterator must be valid.
func (pi *pendingWritesIterator) Value() y.ValueStruct {
	y.AssertTrue(pi.Valid())

	cur := pi.entries[pi.nextIdx]
	vs := y.ValueStruct{
		Value:     cur.Value,
		Meta:      cur.meta,
		UserMeta:  cur.UserMeta,
		ExpiresAt: cur.ExpiresAt,
		Version:   pi.readTs,
	}
	return vs
}
// Valid reports whether the iterator currently points at an entry.
func (pi *pendingWritesIterator) Valid() bool {
	return len(pi.entries) > pi.nextIdx
}
// Close is a no-op and always returns nil.
func (pi *pendingWritesIterator) Close() error {
	var err error
	return err
}
// newPendingWritesIterator returns an iterator over a key-sorted snapshot of
// the transaction's pending writes (descending when reversed is true). It
// returns nil for read-only transactions and when nothing has been written.
func (txn *Txn) newPendingWritesIterator(reversed bool) *pendingWritesIterator {
	if !txn.update || len(txn.pendingWrites) == 0 {
		return nil
	}

	sorted := make([]*Entry, 0, len(txn.pendingWrites))
	for _, pending := range txn.pendingWrites {
		sorted = append(sorted, pending)
	}

	less := func(i, j int) bool {
		order := bytes.Compare(sorted[i].Key, sorted[j].Key)
		if reversed {
			return order > 0
		}
		return order < 0
	}
	sort.Slice(sorted, less)

	it := &pendingWritesIterator{
		entries:  sorted,
		readTs:   txn.readTs,
		reversed: reversed,
	}
	return it
}
// checkSize accounts for adding e to the transaction. It returns ErrTxnTooBig,
// leaving the counters untouched, if the new entry count or estimated size
// would reach the DB's batch limits; otherwise it commits the new totals.
func (txn *Txn) checkSize(e *Entry) error {
	newCount := txn.count + 1
	// 10 extra bytes per entry on top of the estimate.
	newSize := txn.size + e.estimateSizeAndSetThreshold(txn.db.valueThreshold()) + 10

	tooMany := newCount >= txn.db.opt.maxBatchCount
	tooLarge := newSize >= txn.db.opt.maxBatchSize
	if tooMany || tooLarge {
		return ErrTxnTooBig
	}

	txn.count = newCount
	txn.size = newSize
	return nil
}
// exceedsSize builds the error returned when a key or value is larger than
// allowed.
//
// prefix names what was too large ("Key" or "Value"), max is the limit that
// was exceeded and key is the offending payload. The message includes a hex
// dump of at most the first 1 KiB of the payload.
//
// The dump is capped at min(len(key), 1 KiB). Unconditionally slicing
// key[:1<<10] panics when the payload's capacity is below 1 KiB (possible when
// the limit itself is small) and would otherwise dump bytes past len(key).
func exceedsSize(prefix string, max int64, key []byte) error {
	const maxDumpLen = 1 << 10

	dump := key
	if len(dump) > maxDumpLen {
		dump = dump[:maxDumpLen]
	}
	return fmt.Errorf("%s with size %d exceeded %d limit. %s:\n%s",
		prefix, len(key), max, prefix, hex.Dump(dump))
}
// modify validates e and records it as a pending write of the transaction.
//
// It returns ErrReadOnlyTxn, ErrDiscardedTxn, ErrEmptyKey or ErrInvalidKey for
// the corresponding precondition failures, a size error when the key exceeds
// maxKeySize or the value exceeds the applicable limit, the error from
// db.isBanned, or ErrTxnTooBig from checkSize. On success the key's
// fingerprint is added to conflictKeys (when conflict detection is enabled)
// and e replaces any earlier pending entry for the same key; a replaced entry
// with a different version is kept in duplicateWrites.
func (txn *Txn) modify(e *Entry) error {
	const maxKeySize = 65000

	if !txn.update {
		return ErrReadOnlyTxn
	}
	if txn.discarded {
		return ErrDiscardedTxn
	}

	keyLen, valLen := len(e.Key), int64(len(e.Value))
	if keyLen == 0 {
		return ErrEmptyKey
	}
	if bytes.HasPrefix(e.Key, badgerPrefix) {
		return ErrInvalidKey
	}
	if keyLen > maxKeySize {
		return exceedsSize("Key", maxKeySize, e.Key)
	}
	if limit := txn.db.opt.ValueLogFileSize; valLen > limit {
		return exceedsSize("Value", limit, e.Value)
	}
	if txn.db.opt.InMemory && valLen > txn.db.valueThreshold() {
		return exceedsSize("Value", txn.db.valueThreshold(), e.Value)
	}

	if err := txn.db.isBanned(e.Key); err != nil {
		return err
	}
	if err := txn.checkSize(e); err != nil {
		return err
	}

	// Key fingerprints are only needed when conflicts are being detected.
	if txn.db.opt.DetectConflicts {
		txn.conflictKeys[z.MemHash(e.Key)] = struct{}{}
	}

	// An earlier write to the same key with a different version is preserved
	// separately; with the same version it is simply overwritten.
	mapKey := string(e.Key)
	if prev, exists := txn.pendingWrites[mapKey]; exists && prev.version != e.version {
		txn.duplicateWrites = append(txn.duplicateWrites, prev)
	}
	txn.pendingWrites[mapKey] = e
	return nil
}
// Set records a write of val under key in the transaction. It is shorthand
// for SetEntry(NewEntry(key, val)) and returns the same errors.
func (txn *Txn) Set(key, val []byte) error {
	entry := NewEntry(key, val)
	return txn.SetEntry(entry)
}
// SetEntry records e as a pending write of the transaction. Validation and
// bookkeeping are done by modify, whose error is returned unchanged.
func (txn *Txn) SetEntry(e *Entry) error {
	err := txn.modify(e)
	return err
}
// Delete records a deletion of key in the transaction by writing an entry
// whose meta is bitDelete. It returns the same errors as modify.
func (txn *Txn) Delete(key []byte) error {
	return txn.modify(&Entry{Key: key, meta: bitDelete})
}
// Get looks up key as of the transaction's read timestamp.
//
// It returns ErrEmptyKey for an empty key, ErrDiscardedTxn on a discarded
// transaction, the error from db.isBanned, a wrapped error from the underlying
// lookup, or ErrKeyNotFound when the key is absent, deleted or expired.
//
// In an update transaction a pending write to the same key is served directly
// from pendingWrites; otherwise the key is recorded as read (for conflict
// detection) before the DB is consulted.
func (txn *Txn) Get(key []byte) (item *Item, rerr error) {
	switch {
	case len(key) == 0:
		return nil, ErrEmptyKey
	case txn.discarded:
		return nil, ErrDiscardedTxn
	}

	if err := txn.db.isBanned(key); err != nil {
		return nil, err
	}

	item = new(Item)
	if txn.update {
		pending, found := txn.pendingWrites[string(key)]
		if found && bytes.Equal(key, pending.Key) {
			if isDeletedOrExpired(pending.meta, pending.ExpiresAt) {
				return nil, ErrKeyNotFound
			}
			// Serve the value written earlier in this same transaction.
			item.key = key
			item.val = pending.Value
			item.meta = pending.meta
			item.userMeta = pending.UserMeta
			item.expiresAt = pending.ExpiresAt
			item.version = txn.readTs
			item.status = prefetched
			return item, nil
		}
		// Reads that reach the DB are tracked; ones answered above are not.
		txn.addReadKey(key)
	}

	vs, err := txn.db.get(y.KeyWithTs(key, txn.readTs))
	if err != nil {
		return nil, y.Wrapf(err, "DB::Get key: %q", key)
	}

	missing := vs.Value == nil && vs.Meta == 0
	if missing || isDeletedOrExpired(vs.Meta, vs.ExpiresAt) {
		return nil, ErrKeyNotFound
	}

	item.txn = txn
	item.key = key
	item.meta = vs.Meta
	item.userMeta = vs.UserMeta
	item.version = vs.Version
	item.expiresAt = vs.ExpiresAt
	item.vptr = y.SafeCopy(item.vptr, vs.Value)
	return item, nil
}
// addReadKey appends the fingerprint of key to txn.reads under readsLock. It
// does nothing for read-only transactions.
func (txn *Txn) addReadKey(key []byte) {
	if !txn.update {
		return
	}
	fingerprint := z.MemHash(key)

	txn.readsLock.Lock()
	txn.reads = append(txn.reads, fingerprint)
	txn.readsLock.Unlock()
}
// Discard marks the transaction as discarded and, in non-managed mode, tells
// the oracle that its read is done. Calling it again is a no-op. It panics if
// the transaction still has open iterators.
func (txn *Txn) Discard() {
	if txn.discarded {
		return
	}
	if open := txn.numIterators.Load(); open > 0 {
		panic("Unclosed iterator at time of Txn.Discard.")
	}

	txn.discarded = true
	if orc := txn.db.orc; !orc.isManaged {
		orc.doneRead(txn)
	}
}
// commitAndSend obtains a commit timestamp for the transaction, stamps its
// entries and hands them to the DB's write channel.
//
// It returns ErrConflict when the oracle detects a conflict, or the error from
// sendToWriteCh. On success it returns a function that waits for the write
// request to finish, marks the commit timestamp done on the oracle and returns
// the write's error.
//
// Entries with version 0 receive the commit timestamp. If every entry had
// version 0, all entries are flagged with bitTxn and a trailing bitFinTxn
// marker entry is appended.
func (txn *Txn) commitAndSend() (func() error, error) {
	orc := txn.db.orc

	// Held from timestamp assignment until the entries are queued, so both
	// happen in the same order across transactions.
	orc.writeChLock.Lock()
	defer orc.writeChLock.Unlock()

	commitTs, conflict := orc.newCommitTs(txn)
	if conflict {
		return nil, ErrConflict
	}

	// Pending writes first, then the duplicates; one spare slot for the marker.
	total := len(txn.pendingWrites) + len(txn.duplicateWrites)
	entries := make([]*Entry, 0, total+1)
	for _, pending := range txn.pendingWrites {
		entries = append(entries, pending)
	}
	entries = append(entries, txn.duplicateWrites...)

	// First pass: assign versions. Any pre-versioned entry disables the
	// transaction markers for the whole batch.
	keepTogether := true
	for _, e := range entries {
		if e.version != 0 {
			keepTogether = false
			continue
		}
		e.version = commitTs
	}

	// Second pass: needs the final keepTogether, hence a separate loop.
	for _, e := range entries {
		e.Key = y.KeyWithTs(e.Key, e.version)
		if keepTogether {
			e.meta |= bitTxn
		}
	}

	if keepTogether {
		y.AssertTrue(commitTs != 0)
		fin := &Entry{
			Key:   y.KeyWithTs(txnKey, commitTs),
			Value: []byte(strconv.FormatUint(commitTs, 10)),
			meta:  bitFinTxn,
		}
		entries = append(entries, fin)
	}

	req, err := txn.db.sendToWriteCh(entries)
	if err != nil {
		orc.doneCommit(commitTs)
		return nil, err
	}

	wait := func() error {
		writeErr := req.Wait()
		// Only mark the commit done once the write has actually completed.
		orc.doneCommit(commitTs)
		return writeErr
	}
	return wait, nil
}
// commitPrecheck validates that the transaction can be committed. It fails if
// the transaction was already discarded, or if the DB uses managed
// transactions, no pending write carries an explicit version and commitTs is
// still zero.
func (txn *Txn) commitPrecheck() error {
	if txn.discarded {
		return errors.New("Trying to commit a discarded txn")
	}

	allUnversioned := true
	for _, pending := range txn.pendingWrites {
		if pending.version != 0 {
			allUnversioned = false
			break
		}
	}
	if !allUnversioned {
		return nil
	}

	// Every entry would take its version from commitTs, so it must be set.
	if txn.db.opt.managedTxns && txn.commitTs == 0 {
		return errors.New("CommitTs cannot be zero. Please use commitAt instead")
	}
	return nil
}
// Commit commits the transaction and waits for the write to complete.
//
// With no pending writes the transaction is simply discarded and nil is
// returned. Otherwise it returns the error from commitPrecheck (leaving the
// transaction undiscarded), from commitAndSend, or from waiting on the write.
// Once the precheck has passed the transaction is always discarded on return.
func (txn *Txn) Commit() error {
	if len(txn.pendingWrites) == 0 {
		txn.Discard()
		return nil
	}

	if err := txn.commitPrecheck(); err != nil {
		return err
	}
	defer txn.Discard()

	waitForWrite, err := txn.commitAndSend()
	if err != nil {
		return err
	}
	return waitForWrite()
}
// txnCb bundles what runTxnCallback needs to finish an asynchronous commit.
type txnCb struct {
	// commit, when non-nil, is invoked to obtain the commit result.
	commit func() error
	// user is the caller-supplied callback; it must not be nil.
	user func(error)
	// err, when non-nil, is reported to user instead of running commit.
	err error
}

// runTxnCallback invokes cb.user exactly once with the outcome of the commit:
// cb.err if it is set, otherwise the result of cb.commit if present, otherwise
// nil. It panics when cb or cb.user is nil.
func runTxnCallback(cb *txnCb) {
	if cb == nil {
		panic("txn callback is nil")
	}
	if cb.user == nil {
		panic("Must have caught a nil callback for txn.CommitWith")
	}

	if cb.err != nil {
		cb.user(cb.err)
		return
	}
	if cb.commit == nil {
		cb.user(nil)
		return
	}
	result := cb.commit()
	cb.user(result)
}
// CommitWith commits the transaction and reports the outcome through cb
// instead of blocking on the write. It panics if cb is nil.
//
// With no pending writes, cb(nil) is run on a new goroutine and the
// transaction is discarded. A commitPrecheck failure is reported by calling cb
// directly on the caller's goroutine, and the transaction is not discarded.
// Otherwise the entries are sent for writing, the transaction is discarded on
// return, and cb is run on a new goroutine with either the send error or the
// result of waiting for the write.
func (txn *Txn) CommitWith(cb func(error)) {
	if cb == nil {
		panic("Nil callback provided to CommitWith")
	}

	if len(txn.pendingWrites) == 0 {
		// Run on a separate goroutine rather than from inside CommitWith.
		go runTxnCallback(&txnCb{user: cb, err: nil})
		txn.Discard()
		return
	}

	if precheckErr := txn.commitPrecheck(); precheckErr != nil {
		cb(precheckErr)
		return
	}
	defer txn.Discard()

	outcome := &txnCb{user: cb}
	waitForWrite, err := txn.commitAndSend()
	if err != nil {
		outcome.err = err
	} else {
		outcome.commit = waitForWrite
	}
	go runTxnCallback(outcome)
}
// ReadTs returns the transaction's read timestamp.
func (txn *Txn) ReadTs() uint64 {
	ts := txn.readTs
	return ts
}
// NewTransaction creates a non-managed transaction on db. Pass update=true for
// a read-write transaction and false for a read-only one.
func (db *DB) NewTransaction(update bool) *Txn {
	const managed = false
	return db.newTransaction(update, managed)
}
// newTransaction creates a transaction on db.
//
// update requests a read-write transaction; it is forced to false when the DB
// was opened read-only. When isManaged is false the read timestamp is fetched
// from the oracle; otherwise readTs is left at zero for the caller to set.
func (db *DB) newTransaction(update, isManaged bool) *Txn {
	// A read-only DB can only hand out read-only transactions.
	update = update && !db.opt.ReadOnly

	txn := &Txn{
		db:     db,
		update: update,
		// Start at one entry and a little size: room for the extra txnKey
		// entry that commitAndSend may append.
		count: 1,
		size:  int64(len(txnKey) + 10),
	}

	if update {
		txn.pendingWrites = make(map[string]*Entry)
		if db.opt.DetectConflicts {
			txn.conflictKeys = make(map[uint64]struct{})
		}
	}

	if !isManaged {
		txn.readTs = db.orc.readTs()
	}
	return txn
}
// View runs fn inside a read-only transaction, discards the transaction
// afterwards and returns fn's error. It returns ErrDBClosed if the DB is
// closed. With managed transactions the read timestamp is math.MaxUint64.
func (db *DB) View(fn func(txn *Txn) error) error {
	if db.IsClosed() {
		return ErrDBClosed
	}

	var txn *Txn
	switch {
	case db.opt.managedTxns:
		txn = db.NewTransactionAt(math.MaxUint64, false)
	default:
		txn = db.NewTransaction(false)
	}
	defer txn.Discard()

	return fn(txn)
}
// Update runs fn inside a read-write transaction and commits it if fn returns
// nil. It returns ErrDBClosed if the DB is closed, fn's error if fn fails (the
// transaction is then discarded without committing), or the error from Commit.
// It panics when the DB uses managed transactions.
func (db *DB) Update(fn func(txn *Txn) error) error {
	switch {
	case db.IsClosed():
		return ErrDBClosed
	case db.opt.managedTxns:
		panic("Update can only be used with managedDB=false.")
	}

	txn := db.NewTransaction(true)
	// Harmless after a successful Commit: Discard returns early once discarded.
	defer txn.Discard()

	fnErr := fn(txn)
	if fnErr != nil {
		return fnErr
	}
	return txn.Commit()
}
// The pages are generated with Golds v0.8.4. (GOOS=linux GOARCH=amd64)
// Golds is a Go 101 project developed by Tapir Liu.
// PR and bug reports are welcome and can be submitted to the issue list.
// Please follow @zigo_101 (reachable from the left QR code) to get the latest news of Golds.