package physicalplan
import (
"context"
"errors"
"fmt"
"strings"
"github.com/apache/arrow-go/v18/arrow"
"github.com/apache/arrow-go/v18/arrow/array"
"github.com/apache/arrow-go/v18/arrow/memory"
"go.opentelemetry.io/otel/trace"
"github.com/polarsignals/frostdb/pqarrow/arrowutils"
"github.com/polarsignals/frostdb/pqarrow/builder"
"github.com/polarsignals/frostdb/query/logicalplan"
)
// groupColInfo pairs a matched group-by column's schema field with that
// column's array in the record currently being processed by Callback.
type groupColInfo struct {
// field is the schema field of the matched group-by column.
field arrow .Field
// arr is the column's data in the current record.
arr arrow .Array
}
// OrderedAggregate is a physical plan operator that aggregates one column
// (columnToAggregate) grouped by every column matched by
// groupByColumnMatchers. Group boundaries are located with
// arrowutils.GetGroupsAndOrderedSetRanges, so the operator appears to expect
// input that is ordered by the group columns.
// NOTE(review): confirm the exact ordering contract with the planner.
// Results are buffered until Finish, which emits them to next.
type OrderedAggregate struct {
// pool allocates every Arrow builder and array this operator creates.
pool memory .Allocator
// tracer is used by Finish to open a tracing span.
tracer trace .Tracer
// resultColumnName names the output column when finalStage is set (see getResultColumnName).
resultColumnName string
// groupByColumnMatchers select which input columns are group-by columns.
groupByColumnMatchers []logicalplan .Expr
// aggregationFunction is forwarded to runAggregation.
aggregationFunction logicalplan .AggFunc
// next receives the aggregated output in Finish.
next PhysicalPlan
// columnToAggregate selects the input column whose values are aggregated.
columnToAggregate logicalplan .Expr
// finalStage is forwarded to runAggregation and also selects the output column name.
finalStage bool
// groupColOrdering lists group-by fields in the order they were first seen.
// It fixes the column order of the group-key arrays and of the output schema.
groupColOrdering []arrow .Field
// notFirstCall is set once Callback has found the aggregate column in a record.
notFirstCall bool
// curGroup maps column name to the group-key value of the group that was
// still open at the end of the last processed record.
curGroup map [string ]any
// groupBuilders accumulate group-key values per column until an
// ordered-set boundary (or Finish) flushes them into groupResults.
groupBuilders map [string ]builder .ColumnBuilder
// groupResults holds, per flushed set, one group-key array per entry in groupColOrdering.
groupResults [][]arrow .Array
// arrayToAggCarry buffers the rows of the trailing, still-open group across Callback invocations.
arrayToAggCarry builder .ColumnBuilder
// aggResultBuilder carries aggregation results of the current, not yet flushed set across calls.
aggResultBuilder arrowutils .ArrayConcatenator
// aggregationResults holds one aggregated array per flushed set, parallel to groupResults.
aggregationResults []arrow .Array
// scratch holds per-call buffers that Callback truncates and reuses on each call.
scratch struct {
groupByMap map [string ]groupColInfo
groupByArrays []arrow .Array
curGroup []any
indexes []int64
}
}
// NewOrderedAggregate builds an OrderedAggregate that applies aggregation to
// incoming rows, grouped by the columns matched by groupByColumnMatchers.
//
// finalStage is forwarded to runAggregation and also decides the name of the
// output column: the aggregation's result name when true, the aggregated
// expression's name otherwise.
func NewOrderedAggregate(
	pool memory.Allocator,
	tracer trace.Tracer,
	aggregation Aggregation,
	groupByColumnMatchers []logicalplan.Expr,
	finalStage bool,
) *OrderedAggregate {
	// Initial capacity hint for the group-key map and per-call scratch space.
	const initialCap = 10

	agg := new(OrderedAggregate)

	// Wiring and configuration.
	agg.pool = pool
	agg.tracer = tracer
	agg.finalStage = finalStage
	agg.groupByColumnMatchers = groupByColumnMatchers
	agg.columnToAggregate = aggregation.expr
	agg.aggregationFunction = aggregation.function
	agg.resultColumnName = aggregation.resultName

	// State carried across Callback invocations.
	agg.curGroup = make(map[string]any, initialCap)
	agg.groupBuilders = make(map[string]builder.ColumnBuilder)
	agg.aggregationResults = make([]arrow.Array, 0, 1)

	// Per-call buffers that Callback truncates and reuses.
	agg.scratch.groupByMap = make(map[string]groupColInfo, initialCap)
	agg.scratch.groupByArrays = make([]arrow.Array, 0, initialCap)
	agg.scratch.curGroup = make([]any, 0, initialCap)

	return agg
}
// Close propagates Close to the downstream operator.
//
// A nil next is tolerated, consistent with Draw. Closing an operator that was
// never wired with SetNext is therefore a no-op instead of a nil-interface panic.
func (a *OrderedAggregate) Close() {
	if a.next == nil {
		return
	}
	a.next.Close()
}
// SetNext registers next as the downstream operator. Finish emits the
// aggregated records to next, and Close and Draw are forwarded to it.
func (a *OrderedAggregate) SetNext(next PhysicalPlan) {
	a.next = next
}
// Draw renders this operator as a Diagram node. The node reads
// "OrderedAggregate (<aggregated expr> by <comma-separated group exprs>)".
// It is chained to the downstream operator's diagram when next is set.
func (a *OrderedAggregate) Draw() *Diagram {
	node := &Diagram{}
	if a.next != nil {
		node.Child = a.next.Draw()
	}

	names := make([]string, 0, len(a.groupByColumnMatchers))
	for _, matcher := range a.groupByColumnMatchers {
		names = append(names, matcher.Name())
	}

	node.Details = fmt.Sprintf(
		"OrderedAggregate (%s by %s)",
		a.columnToAggregate.Name(),
		strings.Join(names, ","),
	)
	return node
}
// Callback consumes one input record. It performs these steps:
//   - It locates the group-by columns (via groupByColumnMatchers) and the
//     column to aggregate (via columnToAggregate).
//   - It splits the record at the group boundaries reported by
//     arrowutils.GetGroupsAndOrderedSetRanges.
//   - It aggregates every group whose end was found in this record.
//   - It buffers the rows of the record's last group in arrayToAggCarry.
//     Finish (or a later call) aggregates them, since that group may
//     continue in the next record.
//
// Whenever a group boundary coincides with an ordered-set boundary, the
// group keys and aggregation results accumulated so far are flushed as one
// set into groupResults/aggregationResults. Finish later merges these sets.
//
// It returns an error if r has no column matching columnToAggregate.
// Errors from the builder/arrowutils helpers and runAggregation are also
// returned.
// NOTE(review): the meaning of "ordered set" is inferred from names and
// usage; confirm against arrowutils.GetGroupsAndOrderedSetRanges.
func (a *OrderedAggregate ) Callback (_ context .Context , r arrow .Record ) error {
// Reset the per-record lookup of group-by columns present in r.
for k := range a .scratch .groupByMap {
delete (a .scratch .groupByMap , k )
}
var columnToAggregate arrow .Array
aggregateFieldFound := false
foundNewColumns := false
// Scan r's schema for group-by columns and for the column to aggregate.
for i := 0 ; i < r .Schema ().NumFields (); i ++ {
field := r .Schema ().Field (i )
for _ , matcher := range a .groupByColumnMatchers {
if matcher .MatchColumn (field .Name ) {
a .scratch .groupByMap [field .Name ] = groupColInfo {field : field , arr : r .Column (i )}
if _ , ok := a .groupBuilders [field .Name ]; !ok {
// First time this group-by column is seen: fix its output position and give it a builder.
a .groupColOrdering = append (a .groupColOrdering , field )
b := builder .NewBuilder (a .pool , field .Type )
a .groupBuilders [field .Name ] = b
foundNewColumns = true
if a .notFirstCall {
// Earlier records already appended group keys to the other builders.
// Pad the new builder with nulls up to the length of the first
// registered builder so all builders stay row-aligned.
for i := 0 ; i < a .groupBuilders [a .groupColOrdering [0 ].Name ].Len (); i ++ {
b .AppendNull ()
}
}
}
}
}
if a .columnToAggregate .MatchColumn (field .Name ) {
columnToAggregate = r .Column (i )
// The carry builder is created lazily, typed after the first aggregate column seen.
if a .arrayToAggCarry == nil {
a .arrayToAggCarry = builder .NewBuilder (a .pool , columnToAggregate .DataType ())
}
aggregateFieldFound = true
}
}
if !aggregateFieldFound {
return errors .New ("aggregate field not found, aggregations are not possible without it" )
}
if foundNewColumns {
// Backfill already-flushed sets with null arrays for newly discovered
// columns, so every set has one array per entry in groupColOrdering.
// NOTE(review): the row count comes from a.groupResults[i][0]. This
// would panic if a set was flushed while no group-by column was known
// (leaving a.groupResults[i] empty). Confirm that cannot happen.
for i := range a .groupResults {
for j := len (a .groupResults [i ]); j < len (a .groupColOrdering ); j ++ {
a .groupResults [i ] = append (
a .groupResults [i ],
arrowutils .MakeNullArray (
a .pool ,
a .groupColOrdering [j ].Type ,
a .groupResults [i ][0 ].Len (),
),
)
}
}
}
a .scratch .groupByArrays = a .scratch .groupByArrays [:0 ]
a .scratch .curGroup = a .scratch .curGroup [:0 ]
// Collect one array per known group column, in groupColOrdering order.
// Columns absent from r are represented by a null array of r's length.
for _ , field := range a .groupColOrdering {
info , ok := a .scratch .groupByMap [field .Name ]
var arr arrow .Array
if !ok {
arr = arrowutils .MakeVirtualNullArray (field .Type , int (r .NumRows ()))
} else {
arr = info .arr
}
a .scratch .groupByArrays = append (a .scratch .groupByArrays , arr )
if !a .notFirstCall {
// Seed curGroup from row 0 of the very first record. []byte values
// are copied, presumably so curGroup does not alias r's buffers.
v := arr .GetOneForMarshal (0 )
switch concreteV := v .(type ) {
case []byte :
a .curGroup [field .Name ] = append ([]byte (nil ), concreteV ...)
default :
a .curGroup [field .Name ] = v
}
}
a .scratch .curGroup = append (a .scratch .curGroup , a .curGroup [field .Name ])
}
a .notFirstCall = true
// Compute group boundaries in r, starting from the key of the group that
// was open at the end of the previous record (scratch.curGroup). Also
// compute the boundaries where a new ordered set begins, and the key of r's
// last group.
groupRanges , wrappedSetRanges , lastGroup , err := arrowutils .GetGroupsAndOrderedSetRanges (
a .scratch .curGroup ,
a .scratch .groupByArrays ,
)
if err != nil {
return err
}
// On every return from here on, remember the key of r's last group so the
// next record (or Finish) can continue it.
// NOTE(review): unlike the first-call seeding above, []byte values in
// lastGroup are stored without copying. Confirm they do not alias r.
defer func () {
for i , v := range lastGroup {
a .curGroup [a .groupColOrdering [i ].Name ] = v
}
}()
setRanges := wrappedSetRanges .Unwrap (a .scratch .indexes )
// arraysToAggregate gets one entry per group completed in this call.
// arraysToAggregateSetIdxs stores, for each ordered-set boundary, the
// exclusive end index of that set within arraysToAggregate.
arraysToAggregate := make ([]arrow .Array , 0 , groupRanges .Len ())
var arraysToAggregateSetIdxs []int64
// Walk group boundaries. groupStart == -1 is a sentinel meaning "before
// the first boundary"; it is normalized to row 0 after the first pop.
for groupStart , setCursor := int64 (-1 ), 0 ; ; {
groupEnd , groupOk := groupRanges .PopNextNotEqual (groupStart )
if groupStart == -1 {
groupStart = 0
}
if !groupOk {
// No more boundaries. Rows from groupStart to the end of r belong to
// the last, still-open group, so buffer them instead of aggregating.
if err := builder .AppendArray (
a .arrayToAggCarry ,
array .NewSlice (columnToAggregate , groupStart , int64 (columnToAggregate .Len ())),
); err != nil {
return err
}
break
}
var toAgg arrow .Array
if groupEnd == 0 {
// A boundary at row 0: the group carried over from the previous
// record ended exactly at the start of r. Only the carried rows are
// aggregated. NewArray is assumed to reset the builder, as Arrow
// builders do.
toAgg = a .arrayToAggCarry .NewArray ()
} else {
toAgg = array .NewSlice (columnToAggregate , groupStart , groupEnd )
if a .arrayToAggCarry .Len () > 0 {
// Carried rows continue into this first group of r; aggregate them together.
if err := builder .AppendArray (a .arrayToAggCarry , toAgg ); err != nil {
return err
}
toAgg = a .arrayToAggCarry .NewArray ()
}
}
arraysToAggregate = append (arraysToAggregate , toAgg )
newOrderedSet := false
// If this group's end coincides with the next ordered-set boundary,
// this group closes the current set. Record where the set ends in
// arraysToAggregate and open a groupResults entry for its keys.
if len (setRanges ) > 0 && setCursor < len (setRanges ) && setRanges [setCursor ] == groupEnd {
setCursor ++
newOrderedSet = true
arraysToAggregateSetIdxs = append (arraysToAggregateSetIdxs , int64 (len (arraysToAggregate )))
a .groupResults = append (a .groupResults , nil )
}
// Append this group's key to every group builder. The carried group
// (groupEnd == 0) takes its key from curGroup; any other group takes
// the key from its first row in r.
for i , field := range a .groupColOrdering {
var v any
if groupEnd == 0 {
v = a .curGroup [field .Name ]
} else {
v = a .scratch .groupByArrays [i ].GetOneForMarshal (int (groupStart ))
}
if err := builder .AppendGoValue (
a .groupBuilders [field .Name ],
v ,
); err != nil {
return err
}
if newOrderedSet {
// Flush this column's keys into the closed set and start a fresh builder for the next set.
n := len (a .groupResults ) - 1
arr := a .groupBuilders [field .Name ].NewArray ()
a .groupBuilders [field .Name ] = builder .NewBuilder (a .pool , arr .DataType ())
a .groupResults [n ] = append (a .groupResults [n ], arr )
}
}
groupStart = groupEnd
}
// Every row of r went to the carry: no group was completed in this call.
if len (arraysToAggregate ) == 0 {
return nil
}
// Aggregate all completed groups in one call. results is assumed to hold
// one value per entry in arraysToAggregate (confirm with runAggregation).
results , err := runAggregation (a .finalStage , a .aggregationFunction , a .pool , arraysToAggregate )
if err != nil {
return err
}
// Split results at the recorded set boundaries. Results carried from
// earlier calls in aggResultBuilder are prepended to the first set
// closed here; ArrayConcatenator.NewArray is assumed to reset it.
setStart := int64 (0 )
for _ , setEnd := range arraysToAggregateSetIdxs {
set := array .NewSlice (results , setStart , setEnd )
if a .aggResultBuilder .Len () > 0 {
a .aggResultBuilder .Add (set )
var err error
set , err = a .aggResultBuilder .NewArray (a .pool )
if err != nil {
return err
}
}
a .aggregationResults = append (a .aggregationResults , set )
setStart = setEnd
}
// Results after the last set boundary belong to the still-open set; carry them forward.
a .aggResultBuilder .Add (array .NewSlice (results , setStart , int64 (results .Len ())))
return nil
}
// Finish flushes all buffered state to next inside an
// "OrderedAggregate/Finish" tracing span, then calls next.Finish.
//
// If Callback never got past its aggregate-column check, Finish only
// forwards Finish. Otherwise it proceeds as follows:
//   - The carried, still-open last group is aggregated and appended as a
//     final set.
//   - Each set (a groupResults entry plus the matching aggregationResults
//     entry) becomes one record. Its schema is groupColOrdering followed by
//     the result column.
//   - A single record is forwarded as-is.
//   - Multiple records are merged ordered by the group columns, re-grouped,
//     and re-aggregated with finalStage=true. This is presumably because the
//     same group key can appear in more than one set and the partial results
//     must be combined. Confirm with runAggregation.
func (a *OrderedAggregate ) Finish (ctx context .Context ) error {
ctx , span := a .tracer .Start (ctx , "OrderedAggregate/Finish" )
defer span .End ()
if !a .notFirstCall {
return a .next .Finish (ctx )
}
if a .arrayToAggCarry .Len () > 0 {
// Close the open set. Append the carried group's key (curGroup) to each
// builder and flush the builders into a new groupResults entry.
a .groupResults = append (a .groupResults , nil )
n := len (a .groupResults ) - 1
for _ , field := range a .groupColOrdering {
b := a .groupBuilders [field .Name ]
if err := builder .AppendGoValue (
b , a .curGroup [field .Name ],
); err != nil {
return err
}
a .groupResults [n ] = append (a .groupResults [n ], b .NewArray ())
}
// Aggregate the carried rows of the last group.
results , err := runAggregation (
a .finalStage , a .aggregationFunction , a .pool , []arrow .Array {a .arrayToAggCarry .NewArray ()},
)
if err != nil {
return err
}
// Prepend results of earlier groups in the same set that Callback carried in aggResultBuilder.
var lastResults arrow .Array
if a .aggResultBuilder .Len () > 0 {
a .aggResultBuilder .Add (results )
var err error
lastResults , err = a .aggResultBuilder .NewArray (a .pool )
if err != nil {
return err
}
} else {
lastResults = results
}
a .aggregationResults = append (a .aggregationResults , lastResults )
}
// Output schema: the group columns, then the result column typed after the
// first set's results.
// NOTE(review): assumes at least one set was produced; panics if
// aggregationResults is empty.
schema := arrow .NewSchema (
append (
a .groupColOrdering ,
arrow .Field {Name : a .getResultColumnName (), Type : a .aggregationResults [0 ].DataType ()},
),
nil ,
)
// One record per set; groupResults and aggregationResults are parallel.
records := make ([]arrow .Record , 0 , len (a .groupResults ))
for i := range a .groupResults {
records = append (
records ,
array .NewRecord (
schema ,
append (
a .groupResults [i ],
a .aggregationResults [i ],
),
int64 (a .aggregationResults [i ].Len ()),
),
)
}
if len (records ) == 1 {
if err := a .next .Callback (ctx , records [0 ]); err != nil {
return err
}
} else {
// Merge all sets ordered by every group column, in groupColOrdering order.
orderByCols := make ([]arrowutils .SortingColumn , len (a .groupColOrdering ))
for i := range orderByCols {
orderByCols [i ] = arrowutils .SortingColumn {Index : i }
}
mergedRecord , err := arrowutils .MergeRecords (a .pool , records , orderByCols , 0 )
if err != nil {
return err
}
// Recompute group boundaries over the merged record, starting from its first row's key.
firstGroup := make ([]any , len (a .groupColOrdering ))
groupArrs := mergedRecord .Columns ()[:len (a .groupColOrdering )]
for i , arr := range groupArrs {
firstGroup [i ] = arr .GetOneForMarshal (0 )
}
wrappedGroupRanges , _ , _ , err := arrowutils .GetGroupsAndOrderedSetRanges (firstGroup , groupArrs )
if err != nil {
return err
}
groupRanges := wrappedGroupRanges .Unwrap (a .scratch .indexes )
// Terminal boundary so the last merged group is covered by the loops below.
groupRanges = append (groupRanges , mergedRecord .NumRows ())
// Emit each merged group's key, taken from the group's first row.
for i , field := range a .groupColOrdering {
start := int64 (0 )
for _ , end := range groupRanges {
if err := builder .AppendValue (
a .groupBuilders [field .Name ], mergedRecord .Column (i ), int (start ),
); err != nil {
return err
}
start = end
}
}
// Slice the per-set results by merged group and combine them. finalStage
// is forced to true here, presumably because the inputs are already
// aggregated values. Confirm with runAggregation.
aggregationVals := mergedRecord .Columns ()[len (a .groupColOrdering )]
start := int64 (0 )
toAggregate := make ([]arrow .Array , 0 , len (groupRanges ))
for _ , end := range groupRanges {
toAggregate = append (toAggregate , array .NewSlice (aggregationVals , start , end ))
start = end
}
result , err := runAggregation (true , a .aggregationFunction , a .pool , toAggregate )
if err != nil {
return err
}
groups := make ([]arrow .Array , 0 , len (a .groupBuilders ))
for _ , field := range a .groupColOrdering {
groups = append (groups , a .groupBuilders [field .Name ].NewArray ())
}
if err := a .next .Callback (
ctx ,
array .NewRecord (
schema ,
append (groups , result ),
int64 (result .Len ()),
),
); err != nil {
return err
}
}
return a .next .Finish (ctx )
}
// getResultColumnName returns the name of the aggregation output column.
// In the final stage this is the configured result name. Otherwise it is
// the name of the aggregated expression, so later stages can match the
// column again.
func (a *OrderedAggregate) getResultColumnName() string {
	switch exprName := a.columnToAggregate.Name(); {
	case a.finalStage:
		return a.resultColumnName
	default:
		return exprName
	}
}
The pages are generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news about Golds.