package callbacks
import (
"fmt"
"reflect"
"sort"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
// parsePreloadMap groups a flat preload table by the first segment of each
// preload path.
//
// The returned map is keyed by that first segment (a relation or embedded
// group name); each value maps the remaining dotted path (if any) to the
// condition arguments registered for the original preload. A path whose first
// segment is clause.Associations is fanned out to every relation whose Schema
// is s itself and to every embedded relation group of s (using the nested
// paths produced by embeddedValues).
func parsePreloadMap(s *schema.Schema, preloads map[string][]interface{}) map[string]map[string][]interface{} {
	grouped := make(map[string]map[string][]interface{})

	// register guarantees a bucket exists for head and, when rest is
	// non-empty, records rest as a nested preload of head with args.
	register := func(head, rest string, args []interface{}) {
		nested, exists := grouped[head]
		if !exists {
			nested = make(map[string][]interface{})
			grouped[head] = nested
		}
		if rest == "" {
			return
		}
		nested[rest] = args
	}

	for path, args := range preloads {
		head, rest, _ := strings.Cut(path, ".")
		if head != clause.Associations {
			register(head, rest, args)
			continue
		}

		for _, relation := range s.Relationships.Relations {
			if relation.Schema != s {
				continue
			}
			register(relation.Name, rest, args)
		}
		for embedded, group := range s.Relationships.EmbeddedRelations {
			for _, nestedPath := range embeddedValues(group) {
				register(embedded, nestedPath, args)
			}
		}
	}
	return grouped
}
// embeddedValues flattens an embedded relationship tree into dotted preload
// paths. Every relation in the tree contributes its Field.EmbeddedBindNames
// with the first element dropped, joined by "."; nested embedded groups are
// visited depth-first after the relations of their parent. A nil tree yields
// nil; a non-nil tree always yields a non-nil (possibly empty) slice.
//
// NOTE(review): assumes EmbeddedBindNames is non-empty for every embedded
// relation (indexing [1:] would panic otherwise) — confirm in schema parsing.
func embeddedValues(embeddedRelations *schema.Relationships) []string {
	if embeddedRelations == nil {
		return nil
	}

	paths := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations))

	var walk func(group *schema.Relationships)
	walk = func(group *schema.Relationships) {
		if group == nil {
			return
		}
		for _, rel := range group.Relations {
			bindNames := rel.Field.EmbeddedBindNames
			paths = append(paths, strings.Join(bindNames[1:], "."))
		}
		for _, child := range group.EmbeddedRelations {
			walk(child)
		}
	}

	walk(embeddedRelations)
	return paths
}
// preloadEntryPoint resolves every preload requested for one level of the
// relationship tree.
//
// preloads maps preload paths (for example "Orders.Items") to their condition
// arguments. The paths are grouped by first segment with parsePreloadMap, and
// the groups are visited in sorted order so that preloading is deterministic
// despite Go's random map iteration. For each group name:
//   - an embedded relation group is handled by recursing with the same db and
//     joins;
//   - a relation already loaded through a JOIN (listed in joins) is not queried
//     again; its joined values are collected and only the nested preloads are
//     resolved against them in a fresh session;
//   - any other relation is loaded by preload, using the conditions registered
//     under that name followed by associationsConds.
//
// A name that is neither an embedded group nor a relation yields an error
// wrapping gorm.ErrUnsupportedRelation. A joined relation whose statement value
// is not a struct, pointer, slice or array yields gorm.ErrInvalidData. If the
// session built for a joined relation fails to parse its destination, that
// error is returned.
func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error {
	preloadMap := parsePreloadMap(db.Statement.Schema, preloads)

	// Sort the group names so the traversal order is stable between runs.
	preloadNames := make([]string, 0, len(preloadMap))
	for key := range preloadMap {
		preloadNames = append(preloadNames, key)
	}
	sort.Strings(preloadNames)

	// isJoined reports whether name was loaded by one of joins at this level.
	// It also returns the remainder of every dotted join rooted at name
	// ("Orders.Items" -> "Items") so those can be matched at the next level.
	isJoined := func(name string) (joined bool, nestedJoins []string) {
		for _, join := range joins {
			if _, ok := relationships.Relations[join]; ok && name == join {
				joined = true
				continue
			}
			join0, join1, cut := strings.Cut(join, ".")
			if cut {
				if _, ok := relationships.Relations[join0]; ok && name == join0 {
					joined = true
					nestedJoins = append(nestedJoins, join1)
				}
			}
		}
		return joined, nestedJoins
	}

	// preloadJoined resolves nested preloads against values that a JOIN has
	// already populated. preloadDB records a failure to parse the destination
	// on the returned session. Stop there instead of dereferencing
	// tx.Statement.Schema, which cannot be relied on after a failed parse.
	preloadJoined := func(reflectValue reflect.Value, nestedJoins []string, nestedPreloads map[string][]interface{}) error {
		tx := preloadDB(db, reflectValue, reflectValue.Interface())
		if tx.Error != nil {
			return tx.Error
		}
		return preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, nestedPreloads, associationsConds)
	}

	for _, name := range preloadNames {
		if embedded := relationships.EmbeddedRelations[name]; embedded != nil {
			if err := preloadEntryPoint(db, joins, embedded, preloadMap[name], associationsConds); err != nil {
				return err
			}
			continue
		}

		rel := relationships.Relations[name]
		if rel == nil {
			return fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)
		}

		if joined, nestedJoins := isJoined(name); joined {
			switch rv := db.Statement.ReflectValue; rv.Kind() {
			case reflect.Slice, reflect.Array:
				if rv.Len() > 0 {
					// Collect a pointer to each record's joined value; nil
					// pointer fields are skipped.
					joinedValues := rel.FieldSchema.MakeSlice().Elem()
					for i := 0; i < rv.Len(); i++ {
						frv := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i))
						if frv.Kind() != reflect.Ptr {
							joinedValues = reflect.Append(joinedValues, frv.Addr())
						} else if !frv.IsNil() {
							joinedValues = reflect.Append(joinedValues, frv)
						}
					}
					if err := preloadJoined(joinedValues, nestedJoins, preloadMap[name]); err != nil {
						return err
					}
				}
			case reflect.Struct, reflect.Pointer:
				joinedValue := rel.Field.ReflectValueOf(db.Statement.Context, rv)
				if err := preloadJoined(joinedValue, nestedJoins, preloadMap[name]); err != nil {
					return err
				}
			default:
				return gorm.ErrInvalidData
			}
			continue
		}

		// Not loaded by a JOIN: query the relation separately.
		tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks})
		tx.Statement.ReflectValue = db.Statement.ReflectValue
		tx.Statement.Unscoped = db.Statement.Unscoped

		// Clip to the current length so that appending associationsConds
		// allocates a new slice. Otherwise it could write into spare capacity
		// of the caller-owned preloads[name] slice.
		conds := preloads[name]
		conds = append(conds[:len(conds):len(conds)], associationsConds...)
		if err := preload(tx, rel, conds, preloadMap[name]); err != nil {
			return err
		}
	}
	return nil
}
// preloadDB derives a new, initialized session from db for running a nested
// preload against dest.
//
// The session inherits db's context and SkipHooks flag, and every entry of
// db.Statement.Settings is copied into it. dest is parsed to set up the
// session's schema. On success the session's ReflectValue is set to
// reflectValue and Unscoped is copied from db. On a parse failure the error is
// recorded on the returned session via AddError and those two fields are left
// untouched, so callers should check the returned session's Error.
func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest interface{}) *gorm.DB {
	sessionConfig := &gorm.Session{
		Context:     db.Statement.Context,
		NewDB:       true,
		SkipHooks:   db.Statement.SkipHooks,
		Initialized: true,
	}
	tx := db.Session(sessionConfig)

	copySetting := func(key, val interface{}) bool {
		tx.Statement.Settings.Store(key, val)
		return true
	}
	db.Statement.Settings.Range(copySetting)

	parseErr := tx.Statement.Parse(dest)
	if parseErr == nil {
		tx.Statement.ReflectValue = reflectValue
		tx.Statement.Unscoped = db.Statement.Unscoped
	} else {
		tx.AddError(parseErr)
	}
	return tx
}
// preload loads relationship rel for the owner records held in
// tx.Statement.ReflectValue (a struct, or a slice/array of records) and
// assigns the loaded related records back onto those owners.
//
// If rel has a join table (rel.JoinTable != nil), it first queries the join
// table for rows that match the owners' key values. It then queries the
// related table using the related-side keys collected from those join rows.
// Otherwise it queries the related table directly using the owners' key
// values.
//
// conds are extra conditions for the related query. Entries of type
// func(*gorm.DB) *gorm.DB are applied to tx in order. All other entries are
// passed together to a single tx.Where call. preloads holds nested preload
// paths (relative to rel) and their arguments; each is registered with
// tx.Preload before the related query runs.
//
// It returns nil early, without querying the related table or touching the
// owners, when the owners yield no key values. Otherwise it resets the
// relation field on every owner (struct or slice/array ReflectValue only),
// assigns each loaded record to the owners whose key it matches, and returns
// tx.Error. It returns an error if a loaded record's key matches no owner.
func preload(tx *gorm .DB , rel *schema .Relationship , conds []interface {}, preloads map [string ][]interface {}) error {
var (
// Owner records being preloaded.
reflectValue = tx .Statement .ReflectValue
// Column names on the related model used in the final IN query.
relForeignKeys []string
// Fields read from each loaded related record to build its lookup key.
relForeignFields []*schema .Field
// Fields read from the owner records to collect key values.
foreignFields []*schema .Field
// Key value tuples used in the final IN query.
foreignValues [][]interface {}
// Lookup key (utils.ToStringKey) -> owner values to assign to.
identityMap = map [string ][]reflect .Value {}
// conds entries that are not scope functions.
inlineConds []interface {}
)
if rel .JoinTable != nil {
var (
// Join-table fields that hold owner-side key values.
joinForeignFields = make ([]*schema .Field , 0 , len (rel .References ))
// Join-table fields that hold related-side key values.
joinRelForeignFields = make ([]*schema .Field , 0 , len (rel .References ))
// Join-table column names matched against the owners' key values.
joinForeignKeys = make ([]string , 0 , len (rel .References ))
)
// Split references into owner-side keys (matched in the join table),
// fixed-value conditions, and related-side keys.
for _ , ref := range rel .References {
if ref .OwnPrimaryKey {
joinForeignKeys = append (joinForeignKeys , ref .ForeignKey .DBName )
joinForeignFields = append (joinForeignFields , ref .ForeignKey )
foreignFields = append (foreignFields , ref .PrimaryKey )
} else if ref .PrimaryValue != "" {
// NOTE(review): tx is reassigned here, so this equality stays on tx and is
// also present when the related query is built from tx below. Confirm
// that is intended for the related table.
tx = tx .Where (clause .Eq {Column : ref .ForeignKey .DBName , Value : ref .PrimaryValue })
} else {
joinRelForeignFields = append (joinRelForeignFields , ref .ForeignKey )
relForeignKeys = append (relForeignKeys , ref .PrimaryKey .DBName )
relForeignFields = append (relForeignFields , ref .PrimaryKey )
}
}
// Collect the owners' key values, indexed by their lookup key.
joinIdentityMap , joinForeignValues := schema .GetIdentityFieldValuesMap (tx .Statement .Context , reflectValue , foreignFields )
if len (joinForeignValues ) == 0 {
return nil
}
// Load every join-table row that references one of the owners.
joinResults := rel .JoinTable .MakeSlice ().Elem ()
column , values := schema .ToQueryValues (clause .CurrentTable , joinForeignKeys , joinForeignValues )
if err := tx .Where (clause .IN {Column : column , Values : values }).Find (joinResults .Addr ().Interface ()).Error ; err != nil {
return err
}
// Re-key the owners by the related-side key of each join row, so related
// records can later be matched to their owners. The scratch slices are
// reused for every row; ToStringKey turns them into string keys right away.
fieldValues := make ([]interface {}, len (joinForeignFields ))
joinFieldValues := make ([]interface {}, len (joinRelForeignFields ))
for i := 0 ; i < joinResults .Len (); i ++ {
joinIndexValue := joinResults .Index (i )
for idx , field := range joinForeignFields {
fieldValues [idx ], _ = field .ValueOf (tx .Statement .Context , joinIndexValue )
}
for idx , field := range joinRelForeignFields {
joinFieldValues [idx ], _ = field .ValueOf (tx .Statement .Context , joinIndexValue )
}
if results , ok := joinIdentityMap [utils .ToStringKey (fieldValues ...)]; ok {
joinKey := utils .ToStringKey (joinFieldValues ...)
identityMap [joinKey ] = append (identityMap [joinKey ], results ...)
}
}
// The related-side keys of all join rows become the IN values of the
// related query.
_, foreignValues = schema .GetIdentityFieldValuesMap (tx .Statement .Context , joinResults , joinRelForeignFields )
} else {
// OwnPrimaryKey: match the related model's ForeignKey column against the
// owners' PrimaryKey values. Otherwise, match the related model's
// PrimaryKey column against the owners' ForeignKey values. References with
// a fixed PrimaryValue become equality conditions on tx.
for _ , ref := range rel .References {
if ref .OwnPrimaryKey {
relForeignKeys = append (relForeignKeys , ref .ForeignKey .DBName )
relForeignFields = append (relForeignFields , ref .ForeignKey )
foreignFields = append (foreignFields , ref .PrimaryKey )
} else if ref .PrimaryValue != "" {
tx = tx .Where (clause .Eq {Column : ref .ForeignKey .DBName , Value : ref .PrimaryValue })
} else {
relForeignKeys = append (relForeignKeys , ref .PrimaryKey .DBName )
relForeignFields = append (relForeignFields , ref .PrimaryKey )
foreignFields = append (foreignFields , ref .ForeignKey )
}
}
identityMap , foreignValues = schema .GetIdentityFieldValuesMap (tx .Statement .Context , reflectValue , foreignFields )
if len (foreignValues ) == 0 {
return nil
}
}
// Register nested preloads so they are resolved along with the related query.
for p , pvs := range preloads {
tx = tx .Preload (p , pvs ...)
}
// Query the related model for every record whose key is among foreignValues.
// The query is skipped when ToQueryValues produces no values.
reflectResults := rel .FieldSchema .MakeSlice ().Elem ()
column , values := schema .ToQueryValues (clause .CurrentTable , relForeignKeys , foreignValues )
if len (values ) != 0 {
tx = tx .Model (reflectResults .Addr ().Interface ()).Where (clause .IN {Column : column , Values : values })
// Scope functions are applied to tx; everything else is collected for one Where call.
for _ , cond := range conds {
if fc , ok := cond .(func (*gorm .DB ) *gorm .DB ); ok {
tx = fc (tx )
} else {
inlineConds = append (inlineConds , cond )
}
}
if len (inlineConds ) > 0 {
tx = tx .Where (inlineConds [0 ], inlineConds [1 :]...)
}
if err := tx .Find (reflectResults .Addr ().Interface ()).Error ; err != nil {
return err
}
}
fieldValues := make ([]interface {}, len (relForeignFields ))
// Reset the relation field on every owner before assigning. HasMany and
// Many2Many get an empty slice; other types get reflect.New(FieldType)
// (presumably clearing any previously loaded value). Errors from Set
// accumulate on tx.
switch reflectValue .Kind () {
case reflect .Struct :
switch rel .Type {
case schema .HasMany , schema .Many2Many :
tx .AddError (rel .Field .Set (tx .Statement .Context , reflectValue , reflect .MakeSlice (rel .Field .IndirectFieldType , 0 , 10 ).Interface ()))
default :
tx .AddError (rel .Field .Set (tx .Statement .Context , reflectValue , reflect .New (rel .Field .FieldType ).Interface ()))
}
case reflect .Slice , reflect .Array :
for i := 0 ; i < reflectValue .Len (); i ++ {
switch rel .Type {
case schema .HasMany , schema .Many2Many :
tx .AddError (rel .Field .Set (tx .Statement .Context , reflectValue .Index (i ), reflect .MakeSlice (rel .Field .IndirectFieldType , 0 , 10 ).Interface ()))
default :
tx .AddError (rel .Field .Set (tx .Statement .Context , reflectValue .Index (i ), reflect .New (rel .Field .FieldType ).Interface ()))
}
}
}
// Assign each loaded record to every owner whose lookup key it matches.
for i := 0 ; i < reflectResults .Len (); i ++ {
elem := reflectResults .Index (i )
for idx , field := range relForeignFields {
fieldValues [idx ], _ = field .ValueOf (tx .Statement .Context , elem )
}
datas , ok := identityMap [utils .ToStringKey (fieldValues ...)]
if !ok {
return fmt .Errorf ("failed to assign association %#v, make sure foreign fields exists" , elem .Interface ())
}
for _ , data := range datas {
reflectFieldValue := rel .Field .ReflectValueOf (tx .Statement .Context , data )
// Allocate a nil pointer field so there is something to assign into.
if reflectFieldValue .Kind () == reflect .Ptr && reflectFieldValue .IsNil () {
reflectFieldValue .Set (reflect .New (rel .Field .FieldType .Elem ()))
}
reflectFieldValue = reflect .Indirect (reflectFieldValue )
// Single-valued relations are set directly. Slice relations are appended
// to, either as the element itself or dereferenced, depending on whether
// the slice holds pointers.
switch reflectFieldValue .Kind () {
case reflect .Struct :
tx .AddError (rel .Field .Set (tx .Statement .Context , data , elem .Interface ()))
case reflect .Slice , reflect .Array :
if reflectFieldValue .Type ().Elem ().Kind () == reflect .Ptr {
tx .AddError (rel .Field .Set (tx .Statement .Context , data , reflect .Append (reflectFieldValue , elem ).Interface ()))
} else {
tx .AddError (rel .Field .Set (tx .Statement .Context , data , reflect .Append (reflectFieldValue , elem .Elem ()).Interface ()))
}
}
}
}
return tx .Error
}
These pages were generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the QR code on the left) to get the latest news about Golds.