package logicalplan
import (
"errors"
"fmt"
"github.com/apache/arrow-go/v18/arrow"
)
// Builder assembles a LogicalPlan step by step through chained method calls.
// Every method shown here has a value receiver and returns a Builder. Scan and
// ScanSchema start a fresh plan. Most other methods place a new node whose
// Input is the receiver's plan. An error recorded along the way is carried
// forward and returned by Build.
type Builder struct {
// plan is the root of the plan constructed so far.
plan *LogicalPlan
// err is an error recorded during construction; Build returns it instead of a plan.
err error
}
// Scan starts a new plan rooted at a TableScan of tableName through provider.
// Any plan already held by b is discarded; only b's recorded error carries over.
func (b Builder) Scan(
	provider TableProvider,
	tableName string,
) Builder {
	scan := &TableScan{
		TableProvider: provider,
		TableName:     tableName,
	}
	next := Builder{err: b.err}
	next.plan = &LogicalPlan{TableScan: scan}
	return next
}
// ScanSchema starts a new plan rooted at a SchemaScan of tableName through
// provider. Any plan already held by b is discarded; only b's recorded error
// carries over.
func (b Builder) ScanSchema(
	provider TableProvider,
	tableName string,
) Builder {
	schemaScan := &SchemaScan{
		TableProvider: provider,
		TableName:     tableName,
	}
	next := Builder{err: b.err}
	next.plan = &LogicalPlan{SchemaScan: schemaScan}
	return next
}
// Project places a Projection of exprs on top of the current plan.
func (b Builder) Project(
	exprs ...Expr,
) Builder {
	projection := &Projection{Exprs: exprs}
	return Builder{
		plan: &LogicalPlan{Projection: projection, Input: b.plan},
		err:  b.err,
	}
}
// Visitor receives callbacks while an expression tree is walked via
// Expr.Accept. It exposes three hooks: PreVisit, Visit and PostVisit.
// NOTE(review): presumably these fire before, at, and after an expression's
// children, and the bool result controls whether traversal continues. The
// Accept implementations are not visible here, so confirm both points.
type Visitor interface {
PreVisit (expr Expr ) bool
Visit (expr Expr ) bool
PostVisit (expr Expr ) bool
}
// ExprTypeFinder resolves the Arrow data type of an expression. It is the
// argument Expr.DataType takes. resolveAggregation passes a *LogicalPlan in
// this role, which suggests a plan supplies type information from its
// inputs — TODO confirm.
type ExprTypeFinder interface {
DataTypeForExpr (expr Expr ) (arrow .DataType , error )
}
// Expr is a node in a logical-plan expression tree (columns, literals,
// operators, aggregations, ...). It also satisfies fmt.Stringer.
type Expr interface {
// DataType reports the Arrow type this expression produces, using the
// given ExprTypeFinder for type lookups.
DataType (ExprTypeFinder ) (arrow .DataType , error )
// Accept walks the expression with the given Visitor.
// NOTE(review): meaning of the bool result is implementation-defined — confirm.
Accept (Visitor ) bool
// Name returns the expression's name.
Name () string
// Equal reports whether this expression equals another.
Equal (Expr ) bool
fmt .Stringer
// ColumnsUsedExprs returns the expressions this expression reads columns from.
ColumnsUsedExprs () []Expr
// MatchColumn reports whether this expression matches the named column.
MatchColumn (columnName string ) bool
// MatchPath reports whether this expression matches the given path.
MatchPath (path string ) bool
// Computed reports whether the expression is computed rather than read
// directly. NOTE(review): exact criterion is defined by implementations — confirm.
Computed () bool
// Clone returns a copy of the expression.
Clone () Expr
}
// Filter places a Filter node with predicate expr on top of the current plan.
// A nil expr is treated as "no filter" and b is returned unchanged.
func (b Builder) Filter(expr Expr) Builder {
	if expr != nil {
		node := &LogicalPlan{
			Filter: &Filter{Expr: expr},
			Input:  b.plan,
		}
		return Builder{plan: node, err: b.err}
	}
	return b
}
// Distinct adds a Distinct node over exprs. The Distinct node's input is a
// Projection of the same exprs, which in turn reads from the current plan.
func (b Builder) Distinct(
	exprs ...Expr,
) Builder {
	projected := &LogicalPlan{
		Input:      b.plan,
		Projection: &Projection{Exprs: exprs},
	}
	deduplicated := &LogicalPlan{
		Input:    projected,
		Distinct: &Distinct{Exprs: exprs},
	}
	return Builder{plan: deduplicated, err: b.err}
}
// Limit places a Limit node using expr as the limit on top of the current plan.
// A nil expr means "no limit" and b is returned unchanged.
func (b Builder) Limit(expr Expr) Builder {
	if expr != nil {
		node := &LogicalPlan{
			Limit: &Limit{Expr: expr},
			Input: b.plan,
		}
		return Builder{plan: node, err: b.err}
	}
	return b
}
// Aggregate adds an Aggregation node over the current plan that groups by
// groupExprs and computes aggExpr.
//
// Each aggregation is first passed through resolveAggregation, which may
// rewrite it (for example AVG becomes SUM and COUNT). If no aggregation is
// rewritten, a single Aggregation node with the original aggExpr is produced.
// Otherwise the Aggregation node computes the resolved aggregations, and a
// Projection on top emits groupExprs followed by the per-aggregation result
// expressions.
//
// Resolution errors are joined with any error already recorded on b. They
// are reported by Build.
func (b Builder) Aggregate(
	aggExpr []*AggregationFunction,
	groupExprs []Expr,
) Builder {
	resolvedAggExpr := make([]*AggregationFunction, 0, len(aggExpr))
	projectExprs := make([]Expr, 0, len(aggExpr))
	needsPostProcessing := false

	// Start from b.err so an earlier builder error is not lost, matching the
	// other builder methods.
	err := b.err
	for _, agg := range aggExpr {
		resolvedAggregations, projections, changed, rerr := resolveAggregation(b.plan, agg)
		if rerr != nil {
			err = errors.Join(err, rerr)
		}
		if changed {
			needsPostProcessing = true
		}
		resolvedAggExpr = append(resolvedAggExpr, resolvedAggregations...)
		projectExprs = append(projectExprs, projections...)
	}

	if !needsPostProcessing {
		return Builder{
			err: err,
			plan: &LogicalPlan{
				Aggregation: &Aggregation{
					GroupExprs: groupExprs,
					AggExprs:   aggExpr,
				},
				Input: b.plan,
			},
		}
	}

	// Copy into a fresh slice. Appending straight onto groupExprs could write
	// into the caller's backing array, which is also shared with
	// Aggregation.GroupExprs below.
	outputExprs := make([]Expr, 0, len(groupExprs)+len(projectExprs))
	outputExprs = append(outputExprs, groupExprs...)
	outputExprs = append(outputExprs, projectExprs...)

	return Builder{
		err: err,
		plan: &LogicalPlan{
			Projection: &Projection{
				Exprs: outputExprs,
			},
			Input: &LogicalPlan{
				Aggregation: &Aggregation{
					GroupExprs: groupExprs,
					AggExprs:   resolvedAggExpr,
				},
				Input: b.plan,
			},
		},
	}
}
// resolveAggregation rewrites agg into the aggregations that must actually be
// computed, plus the expressions that derive the requested result from them.
//
// AVG(x) becomes SUM(x) and COUNT(x). The result expression is
// SUM(x) / COUNT(x), aliased to agg.String(). COUNT is converted to x's data
// type whenever that type is not Int64. Every other function is returned
// unchanged. The bool result reports whether a rewrite happened.
//
// If x's data type cannot be determined, agg is returned unchanged (bool
// false) together with the error. Callers then keep a structurally sound plan
// while the failure is still reported, instead of building a Convert to an
// unknown type.
func resolveAggregation(plan *LogicalPlan, agg *AggregationFunction) ([]*AggregationFunction, []Expr, bool, error) {
	switch agg.Func {
	case AggFuncAvg:
		aggType, err := agg.Expr.DataType(plan)
		if err != nil {
			return []*AggregationFunction{agg}, []Expr{agg}, false, err
		}

		sum := &AggregationFunction{
			Func: AggFuncSum,
			Expr: agg.Expr,
		}
		count := &AggregationFunction{
			Func: AggFuncCount,
			Expr: agg.Expr,
		}

		var countExpr Expr = count
		if !arrow.TypeEqual(aggType, arrow.PrimitiveTypes.Int64) {
			// Cast the count so both division operands share x's type.
			// NOTE(review): assumes COUNT yields Int64 — confirm.
			countExpr = Convert(countExpr, aggType)
		}

		div := (&BinaryExpr{
			Left:  sum,
			Op:    OpDiv,
			Right: countExpr,
		}).Alias(agg.String())
		return []*AggregationFunction{sum, count}, []Expr{div}, true, nil
	default:
		return []*AggregationFunction{agg}, []Expr{agg}, false, nil
	}
}
// Sample places a Sample node with the given expr and limit on top of the
// current plan. If either argument is nil, b is returned unchanged.
func (b Builder) Sample(expr, limit Expr) Builder {
	if expr != nil && limit != nil {
		node := &LogicalPlan{
			Sample: &Sample{
				Expr:  expr,
				Limit: limit,
			},
			Input: b.plan,
		}
		return Builder{plan: node, err: b.err}
	}
	return b
}
// Build returns the constructed plan. An error recorded during construction
// takes precedence. Otherwise the plan is passed to Validate, and any
// validation error is returned in place of the plan.
func (b Builder) Build() (*LogicalPlan, error) {
	if recorded := b.err; recorded != nil {
		return nil, recorded
	}
	if verr := Validate(b.plan); verr != nil {
		return nil, verr
	}
	return b.plan, nil
}
// These pages were generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
// Golds is a Go 101 project developed by Tapir Liu.
// PRs and bug reports are welcome and can be submitted to the issue list.
// Please follow @zigo_101 (reachable from the left QR code) to get the latest news about Golds.