package physicalplan
import (
"context"
"fmt"
"math"
"math/rand"
"slices"
"github.com/apache/arrow-go/v18/arrow"
"github.com/apache/arrow-go/v18/arrow/array"
"github.com/apache/arrow-go/v18/arrow/memory"
"github.com/apache/arrow-go/v18/arrow/util"
"github.com/polarsignals/frostdb/pqarrow/builder"
)
// ReservoirSampler is a plan node that keeps a bounded set of the rows passed
// to Callback and forwards the kept rows to the next plan in Finish.
type ReservoirSampler struct {
// next is the downstream plan that receives the sampled rows.
next PhysicalPlan
// allocator is passed to materialize to build the compacted record.
allocator memory .Allocator
// size is the number of rows the reservoir holds once it is full.
size int64
// sizeInBytes is the running total of the sizes reported by Retain and
// Release for the records the reservoir currently references.
sizeInBytes int64
// sizeLimit is the sizeInBytes threshold at or above which Callback
// materializes the reservoir.
sizeLimit int64
// reservoir holds the current samples.
reservoir []sample
// w is the weight used to draw skip distances in sample; it is seeded in
// NewReservoirSampler and scaled after every replacement.
w float64
// n is the number of rows accounted for so far.
n int64
// i is the running row number of the next replacement candidate; 0 means it
// has not been initialised yet.
i float64
}
// sample is a single reservoir entry.
type sample struct {
// i is the row index inside ref selected by this entry, or -1 when the entry
// stands for every row of ref (Finish forwards the whole record in that case).
i int64
// ref is the reference-counted record this entry points into.
ref *referencedRecord
}
// referencedRecord pairs an arrow.Record with a local reference count so that
// the record's byte size is reported once when the first reference is taken
// and once when the last reference is dropped.
type referencedRecord struct {
arrow .Record
// size is the value of util.TotalRecordSize captured by refPtr.
size int64
// ref counts Retain calls that have not yet been matched by Release.
ref int64
}
// Release drops one local reference and one reference on the wrapped record.
// It returns the record's size when the local count reaches zero and 0
// otherwise, so callers can subtract the result from their byte accounting.
func (s *referencedRecord) Release() int64 {
	s.ref--

	var freed int64
	if s.ref == 0 {
		freed = s.size
	}

	// The wrapped record is released last, after the result is decided.
	s.Record.Release()
	return freed
}
// Retain takes one local reference and one reference on the wrapped record.
// It returns the record's size when this is the first local reference and 0
// otherwise, so callers can add the result to their byte accounting.
func (s *referencedRecord) Retain() int64 {
	s.ref++

	var added int64
	if s.ref == 1 {
		added = s.size
	}

	// The wrapped record is retained last, after the result is decided.
	s.Record.Retain()
	return added
}
// NewReservoirSampler returns a sampler whose reservoir holds size rows.
// limit is the byte threshold at which Callback compacts the reservoir, and
// allocator is used to build the compacted record.
func NewReservoirSampler(size, limit int64, allocator memory.Allocator) *ReservoirSampler {
	// Initial weight: a uniform draw raised to the power 1/size.
	initialW := math.Exp(math.Log(rand.Float64()) / float64(size))

	sampler := new(ReservoirSampler)
	sampler.allocator = allocator
	sampler.size = size
	sampler.sizeLimit = limit
	sampler.w = initialW
	return sampler
}
// SetNext sets the downstream plan that receives the sampled rows.
func (s *ReservoirSampler ) SetNext (p PhysicalPlan ) {
s .next = p
}
// Draw returns the diagram node for this sampler, labelled with the reservoir
// size, with the next plan's diagram attached as its child when a next plan
// is set.
func (s *ReservoirSampler) Draw() *Diagram {
	node := &Diagram{}
	if s.next != nil {
		node.Child = s.next.Draw()
	}
	node.Details = fmt.Sprintf("Reservoir Sampler (%v)", s.size)
	return node
}
// Close releases every record still referenced by the reservoir and then
// closes the next plan, if one is set.
//
// The reservoir is emptied once its references are dropped so that a repeated
// Close does not release the same records a second time.
func (s *ReservoirSampler) Close() {
	for _, smpl := range s.reservoir {
		s.sizeInBytes -= smpl.ref.Release()
	}
	s.reservoir = nil

	// Draw already treats a missing next plan as valid; do the same here.
	if s.next != nil {
		s.next.Close()
	}
}
// Callback feeds one record into the sampler.
//
// While the reservoir still has room the record (or its leading part) is
// stored as-is by fill. Once the reservoir is full, the remaining rows are run
// through sample, and the reservoir is compacted with materialize when the
// referenced records reach sizeLimit bytes.
//
// It returns an error only when materialize fails.
func (s *ReservoirSampler) Callback(_ context.Context, r arrow.Record) error {
	// fill creates new slices only when the record straddles the point at which
	// the reservoir becomes full. In that case the remainder it returns is a
	// fresh record whose creation reference belongs to this call.
	straddles := s.n < s.size && s.n+r.NumRows() > s.size

	rest, ref := s.fill(r)
	if rest == nil {
		// The whole record fit into the reservoir.
		return nil
	}
	if straddles {
		defer rest.Release()
	}

	// An empty record has nothing to sample. Skipping it also keeps s.n from
	// staying equal to s.size, which would make the next call slice an
	// already row-addressed reservoir a second time and multiply its entries.
	if rest.NumRows() == 0 {
		return nil
	}

	if s.n == s.size {
		// The reservoir has just filled up: switch it to one entry per row so
		// that individual rows can be replaced.
		s.sliceReservoir()
	}

	s.sample(rest, ref)

	if s.sizeInBytes >= s.sizeLimit {
		return s.materialize(s.allocator)
	}
	return nil
}
// refPtr wraps r in a referencedRecord with a zero local reference count and
// the record's total size as reported by util.TotalRecordSize.
func refPtr(r arrow.Record) *referencedRecord {
	wrapped := new(referencedRecord)
	wrapped.Record = r
	wrapped.size = util.TotalRecordSize(r)
	return wrapped
}
// fill stores rows of r in the reservoir while it still has room.
//
// It returns the part of r that did not go into the reservoir together with a
// referencedRecord wrapping exactly that part, so that row indices computed
// against the returned record are valid for the returned reference:
//
//   - reservoir already full: r itself and a fresh reference to it;
//   - r fits entirely: (nil, nil);
//   - r straddles the boundary: the leading rows are stored and the remaining
//     rows are returned as a new slice.
//
// NOTE(review): in the straddling case the returned remainder is a newly
// created slice; its creation reference must be released by the caller once
// the remainder has been sampled.
func (s *ReservoirSampler) fill(r arrow.Record) (arrow.Record, *referencedRecord) {
	if s.n >= s.size {
		return r, refPtr(r)
	}

	room := s.size - s.n
	if r.NumRows() <= room {
		smpl := sample{i: -1, ref: refPtr(r)}
		s.reservoir = append(s.reservoir, smpl)
		s.sizeInBytes += smpl.ref.Retain()
		s.n += r.NumRows()
		return nil, nil
	}

	// Leading rows fill the reservoir. The reservoir takes its own reference,
	// so the slice's creation reference is dropped right away (the same
	// Retain-then-Release sequence materialize uses for the record it builds).
	head := r.NewSlice(0, room)
	smpl := sample{i: -1, ref: refPtr(head)}
	s.reservoir = append(s.reservoir, smpl)
	s.sizeInBytes += smpl.ref.Retain()
	head.Release()

	// The reference must wrap the remainder, not r: sample derives row indices
	// relative to the record it is handed, and those indices are later applied
	// to the reference's record.
	rest := r.NewSlice(room, r.NumRows())
	s.n = s.size
	return rest, refPtr(rest)
}
// sliceReservoir rebuilds the reservoir with one entry per row: every row of
// every record currently referenced becomes its own sample pointing at a
// shared, newly created reference for that record. The previous reference of
// each entry is released after its rows have been added.
func (s *ReservoirSampler) sliceReservoir() {
	perRow := make([]sample, 0, s.size)

	for idx := range s.reservoir {
		previous := s.reservoir[idx].ref
		shared := refPtr(previous.Record)

		rowCount := previous.NumRows()
		for row := int64(0); row < rowCount; row++ {
			perRow = append(perRow, sample{i: row, ref: shared})
			s.sizeInBytes += shared.Retain()
		}

		s.sizeInBytes -= previous.Release()
	}

	s.reservoir = perRow
}
// sample runs the rows of r through the skip-based reservoir update: s.i is
// advanced by randomly drawn skip distances, and every time it lands inside
// the rows covered by r a random reservoir slot is replaced with that row and
// the weight s.w is scaled down. Rows are recorded as an index relative to r
// together with ref. Nothing happens when the reservoir size is 0.
//
// NOTE(review): the update rule looks like "Algorithm L" reservoir sampling —
// confirm against the intended reference.
func (s *ReservoirSampler) sample(r arrow.Record, ref *referencedRecord) {
	if s.size == 0 {
		return
	}

	total := s.n + r.NumRows()
	bound := float64(total)

	// swapIn replaces a random slot with the row s.i currently points at and
	// then rescales the weight.
	swapIn := func() {
		slot := rand.Intn(int(s.size))
		s.replace(slot, sample{i: int64(s.i) - s.n, ref: ref})
		s.w *= math.Exp(math.Log(rand.Float64()) / float64(s.size))
	}

	switch {
	case s.i == 0:
		// First use: start from the last row that went into the reservoir.
		s.i = float64(s.n) - 1
	case s.i < bound:
		// A skip drawn during an earlier call lands inside this record.
		swapIn()
	}

	for s.i < bound {
		skip := math.Floor(math.Log(rand.Float64())/math.Log(1-s.w)) + 1
		s.i += skip
		if s.i < bound {
			swapIn()
		}
	}

	s.n = total
}
// Finish forwards the contents of the reservoir to the next plan and then
// finishes it.
//
// Entries with index -1 are forwarded as whole records; every other entry is
// forwarded as a one-row slice of its record. It returns the first error
// reported by the next plan's Callback, or the result of its Finish.
func (s *ReservoirSampler) Finish(ctx context.Context) error {
	// forward emits a single entry. The one-row slice is released as soon as
	// the downstream Callback returns, instead of deferring every release to
	// the end of Finish, which would keep all slices alive for the whole loop
	// and through the downstream Finish call.
	forward := func(smpl sample) error {
		if smpl.i == -1 {
			return s.next.Callback(ctx, smpl.ref.Record)
		}
		row := smpl.ref.NewSlice(smpl.i, smpl.i+1)
		defer row.Release()
		return s.next.Callback(ctx, row)
	}

	for _, smpl := range s.reservoir {
		if err := forward(smpl); err != nil {
			return err
		}
	}
	return s.next.Finish(ctx)
}
// replace overwrites reservoir slot i with newRow and updates the byte
// accounting with the sizes reported by Retain and Release.
//
// The incoming reference is retained before the outgoing one is released. When
// both samples share the same referencedRecord and the slot holds its only
// local reference, releasing first would drop the count to zero (and release
// the wrapped record) before taking the reference again; retaining first
// avoids that transient state. The net effect on sizeInBytes is unchanged.
func (s *ReservoirSampler) replace(i int, newRow sample) {
	outgoing := s.reservoir[i]

	s.sizeInBytes += newRow.ref.Retain()
	s.reservoir[i] = newRow
	s.sizeInBytes -= outgoing.ref.Release()
}
// materialize copies every sampled row into a single new record built with
// allocator, releases the records the reservoir referenced so far, and
// replaces the reservoir with per-row entries over the new record.
//
// The new record's schema is the union of the fields of all referenced
// records, sorted by field name. A row whose source record lacks a field gets
// a null appended for that field. It returns the first error reported while
// appending a value; the reservoir is left untouched in that case. An empty
// reservoir is a no-op.
func (s *ReservoirSampler) materialize(allocator memory.Allocator) error {
	if len(s.reservoir) == 0 {
		return nil
	}

	// Collect the union of all fields, keeping the first definition seen for
	// each name.
	fields := s.reservoir[0].ref.Schema().Fields()
	seen := make(map[string]struct{}, len(fields))
	for _, f := range fields {
		seen[f.Name] = struct{}{}
	}
	for _, smpl := range s.reservoir[1:] {
		sc := smpl.ref.Schema()
		for j := 0; j < sc.NumFields(); j++ {
			f := sc.Field(j)
			if _, ok := seen[f.Name]; ok {
				continue
			}
			seen[f.Name] = struct{}{}
			fields = append(fields, f)
		}
	}

	slices.SortFunc(fields, func(a, b arrow.Field) int {
		switch {
		case a.Name < b.Name:
			return -1
		case a.Name > b.Name:
			return 1
		default:
			return 0
		}
	})

	schema := arrow.NewSchema(fields, nil)
	bldr := array.NewRecordBuilder(allocator, schema)
	defer bldr.Release()

	for _, smpl := range s.reservoir {
		source := smpl.ref.Schema()
		for i, fieldBuilder := range bldr.Fields() {
			// Resolve the column by name in the source record's own schema:
			// the merged schema's position i does not have to match the
			// position of the same field in the source record.
			cols := source.FieldIndices(schema.Field(i).Name)
			if len(cols) == 0 {
				if err := builder.AppendValue(fieldBuilder, nil, -1); err != nil {
					return err
				}
				continue
			}
			if err := builder.AppendValue(fieldBuilder, smpl.ref.Column(cols[0]), int(smpl.i)); err != nil {
				return err
			}
		}
	}

	// All rows are copied; drop the references to the source records.
	for _, smpl := range s.reservoir {
		s.sizeInBytes -= smpl.ref.Release()
	}

	// The reservoir takes its own reference on the new record, after which the
	// builder's creation reference is released.
	merged := sample{i: -1, ref: refPtr(bldr.NewRecord())}
	s.sizeInBytes += merged.ref.Retain()
	merged.ref.Record.Release()

	s.reservoir = []sample{merged}
	s.sliceReservoir()
	return nil
}
The pages are generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news about Golds.