package avro
import (
"errors"
"fmt"
"io"
"reflect"
"unsafe"
"github.com/modern-go/reflect2"
)
// createDecoderOfRecord selects the decoder used to read an Avro record into
// a Go value of type typ, based on the kind of typ.
//
// Structs, pointers, string-keyed maps with interface elements, and
// interfaces backed by *reflect2.UnsafeIFaceType are dispatched to their
// dedicated decoders. Every other type yields an errorDecoder describing the
// unsupported type/schema pairing.
func createDecoderOfRecord(d *decoderContext, schema Schema, typ reflect2.Type) ValDecoder {
	switch typ.Kind() {
	case reflect.Struct:
		return decoderOfStruct(d, schema, typ)

	case reflect.Ptr:
		return decoderOfPtr(d, schema, typ)

	case reflect.Map:
		// Only maps keyed by string and holding interface values can
		// represent a record; anything else falls through to the error.
		mt := typ.(reflect2.MapType)
		if mt.Key().Kind() == reflect.String && mt.Elem().Kind() == reflect.Interface {
			return decoderOfRecord(d, schema, typ)
		}

	case reflect.Interface:
		ifaceType, ok := typ.(*reflect2.UnsafeIFaceType)
		if ok {
			return &recordIfaceDecoder{schema: schema, valType: ifaceType}
		}
	}

	err := fmt.Errorf("avro: %s is unsupported for avro %s", typ.String(), schema.Type())
	return &errorDecoder{err: err}
}
// createEncoderOfRecord selects the encoder used to write a Go value of type
// typ as the Avro record described by schema, based on the kind of typ.
//
// Structs, pointers, and string-keyed maps with interface elements are
// dispatched to their dedicated encoders. Every other type yields an
// errorEncoder describing the unsupported type/schema pairing.
func createEncoderOfRecord(e *encoderContext, schema *RecordSchema, typ reflect2.Type) ValEncoder {
	switch typ.Kind() {
	case reflect.Struct:
		return encoderOfStruct(e, schema, typ)

	case reflect.Ptr:
		return encoderOfPtr(e, schema, typ)

	case reflect.Map:
		// Only maps keyed by string and holding interface values can
		// represent a record; anything else falls through to the error.
		mt := typ.(reflect2.MapType)
		if mt.Key().Kind() == reflect.String && mt.Elem().Kind() == reflect.Interface {
			return encoderOfRecord(e, schema, typ)
		}
	}

	err := fmt.Errorf("avro: %s is unsupported for avro %s", typ.String(), schema.Type())
	return &errorEncoder{err: err}
}
// decoderOfStruct builds a decoder that reads an Avro record into the Go
// struct type typ.
//
// schema must be a *RecordSchema; the type assertion panics otherwise. Record
// fields are processed in schema order:
//   - a field whose action is FieldIgnore gets a decoder from
//     createSkipDecoder and no target struct field;
//   - otherwise the struct field is looked up by the schema field name and
//     then by each alias in turn;
//   - a schema field with no matching struct field also gets a skip decoder,
//     unless its action is FieldSetDefault, in which case no decoder is
//     planned for it at all (presumably because nothing is present in the
//     input for such a field — confirm against the schema resolution code);
//   - a matched field with action FieldSetDefault and hasDef set gets a
//     decoder from createDefaultDecoder; every other matched field gets a
//     regular decoder from decoderOfType.
func decoderOfStruct(d *decoderContext, schema Schema, typ reflect2.Type) ValDecoder {
	rec := schema.(*RecordSchema)
	desc := describeStruct(d.cfg.getTagKey(), typ)

	decoders := make([]*structFieldDecoder, 0, len(rec.Fields()))
	for _, schemaField := range rec.Fields() {
		if schemaField.action == FieldIgnore {
			decoders = append(decoders, &structFieldDecoder{
				decoder: createSkipDecoder(schemaField.Type()),
			})
			continue
		}

		// Resolve the target struct field: primary name first, then aliases.
		target := desc.Fields.Get(schemaField.Name())
		if target == nil {
			for _, alias := range schemaField.Aliases() {
				if target = desc.Fields.Get(alias); target != nil {
					break
				}
			}
		}

		if target == nil {
			// Nothing to decode into. FieldSetDefault fields get no decoder
			// at all; the rest are skipped over.
			if schemaField.action != FieldSetDefault {
				decoders = append(decoders, &structFieldDecoder{
					decoder: createSkipDecoder(schemaField.Type()),
				})
			}
			continue
		}

		// The last element of the chain is the field that receives the value.
		leafType := target.Field[len(target.Field)-1].Type()

		var dec ValDecoder
		if schemaField.action == FieldSetDefault && schemaField.hasDef {
			dec = createDefaultDecoder(d, schemaField, leafType)
		} else {
			dec = decoderOfType(d, schemaField.Type(), leafType)
		}
		decoders = append(decoders, &structFieldDecoder{
			field:   target.Field,
			decoder: dec,
		})
	}

	return &structDecoder{typ: typ, fields: decoders}
}
// structFieldDecoder pairs a decoder with the struct field it writes into.
type structFieldDecoder struct {
// field is the chain of struct fields leading from the outer struct to the
// target field. When nil, structDecoder.Decode runs the decoder with a nil
// pointer (this is how skip decoders are planned by decoderOfStruct).
field []*reflect2 .UnsafeStructField
// decoder reads the field's value from the Reader.
decoder ValDecoder
}
// structDecoder decodes an Avro record into a Go struct by running a list of
// per-field decoders in order.
type structDecoder struct {
	typ    reflect2.Type
	fields []*structFieldDecoder
}

// Decode reads every planned field from r into the struct at ptr.
//
// Entries with a nil field chain are decoded with a nil target pointer. For
// the others the field chain is walked from ptr down to the final field; any
// nil pointer met on the way (before the last element) is replaced with a
// freshly allocated value so that the walk can continue through it.
//
// If the Reader carries a non-EOF error after a field has been decoded, the
// error is wrapped with the name of the first field in the chain and
// decoding stops.
func (d *structDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
	for _, fd := range d.fields {
		if fd.field == nil {
			fd.decoder.Decode(nil, r)
			continue
		}

		target := ptr
		last := len(fd.field) - 1
		for depth, sf := range fd.field {
			target = sf.UnsafeGet(target)
			if depth == last {
				break
			}
			if sf.Type().Kind() != reflect.Ptr {
				continue
			}

			// Intermediate pointer field: allocate it when nil, then follow it.
			slot := (*unsafe.Pointer)(target)
			if *slot == nil {
				*slot = sf.Type().(*reflect2.UnsafePtrType).Elem().UnsafeNew()
			}
			target = *slot
		}
		fd.decoder.Decode(target, r)

		if r.Error == nil || errors.Is(r.Error, io.EOF) {
			continue
		}
		if len(fd.field) > 0 {
			r.Error = fmt.Errorf("%s: %w", fd.field[0].Name(), r.Error)
			return
		}
	}
}
// encoderOfStruct builds an encoder that writes the Go struct type typ as the
// Avro record rec.
//
// For each record field, in schema order:
//   - if the struct has a field with the same name, that field is encoded
//     with an encoder from encoderOfType;
//   - otherwise, a schema field without a default makes the whole encoder an
//     errorEncoder reporting the missing required field;
//   - a nil default on a Null-typed field plans nothing for that field;
//   - a nil default on a nullable union is encoded through
//     encoderOfNullableUnion, using a pointer to the default value;
//   - any other default is encoded with an encoder built for the default's
//     own Go type, wrapped in onePtrEncoder when that type reports LikePtr.
func encoderOfStruct(e *encoderContext, rec *RecordSchema, typ reflect2.Type) ValEncoder {
	desc := describeStruct(e.cfg.getTagKey(), typ)

	encoders := make([]*structFieldEncoder, 0, len(rec.Fields()))
	for _, schemaField := range rec.Fields() {
		if member := desc.Fields.Get(schemaField.Name()); member != nil {
			// The last element of the chain is the field holding the value.
			leaf := member.Field[len(member.Field)-1]
			encoders = append(encoders, &structFieldEncoder{
				field:   member.Field,
				encoder: encoderOfType(e, schemaField.Type(), leaf.Type()),
			})
			continue
		}

		if !schemaField.HasDefault() {
			return &errorEncoder{
				err: fmt.Errorf("avro: record %s is missing required field %q", rec.FullName(), schemaField.Name()),
			}
		}

		// def is declared per iteration because its address may be retained
		// by the planned encoder below.
		def := schemaField.Default()
		if schemaField.Default() == nil {
			if schemaField.Type().Type() == Null {
				continue
			}

			if schemaField.Type().Type() == Union && schemaField.Type().(*UnionSchema).Nullable() {
				ptrType := reflect2.TypeOf(&def)
				encoders = append(encoders, &structFieldEncoder{
					defaultPtr: reflect2.PtrOf(&def),
					encoder:    encoderOfNullableUnion(e, schemaField.Type(), ptrType),
				})
				continue
			}
		}

		defType := reflect2.TypeOf(def)
		defEnc := encoderOfType(e, schemaField.Type(), defType)
		if defType.LikePtr() {
			defEnc = &onePtrEncoder{defEnc}
		}
		encoders = append(encoders, &structFieldEncoder{
			defaultPtr: reflect2.PtrOf(def),
			encoder:    defEnc,
		})
	}

	return &structEncoder{typ: typ, fields: encoders}
}
// structFieldEncoder pairs an encoder with the source of the value it writes:
// either a struct field or a schema default.
type structFieldEncoder struct {
// field is the chain of struct fields leading from the outer struct to the
// field being encoded. When nil, structEncoder.Encode encodes defaultPtr
// instead.
field []*reflect2 .UnsafeStructField
// defaultPtr points at the schema default; set by encoderOfStruct for
// record fields that have no matching struct field.
defaultPtr unsafe .Pointer
// encoder writes the value to the Writer.
encoder ValEncoder
}
// structEncoder encodes a Go struct as an Avro record by running a list of
// per-field encoders in order.
type structEncoder struct {
	typ    reflect2.Type
	fields []*structFieldEncoder
}

// Encode writes every planned field of the struct at ptr to w.
//
// Entries with a nil field chain encode their stored default pointer. For the
// others the field chain is walked from ptr down to the final field; a nil
// pointer met on the way (before the last element) sets w.Error to an
// "embedded field ... is nil" error and stops encoding.
//
// If the Writer carries a non-EOF error after a field has been encoded, the
// error is wrapped with the name of the first field in the chain and
// encoding stops.
func (e *structEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
	for _, fe := range e.fields {
		if fe.field == nil {
			fe.encoder.Encode(fe.defaultPtr, w)
			continue
		}

		source := ptr
		last := len(fe.field) - 1
		for depth, sf := range fe.field {
			source = sf.UnsafeGet(source)
			if depth == last {
				break
			}
			if sf.Type().Kind() != reflect.Ptr {
				continue
			}

			// Intermediate pointer field: it must be set to be followed.
			next := *((*unsafe.Pointer)(source))
			if next == nil {
				w.Error = fmt.Errorf("embedded field %q is nil", sf.Name())
				return
			}
			source = next
		}
		fe.encoder.Encode(source, w)

		if w.Error == nil || errors.Is(w.Error, io.EOF) {
			continue
		}
		if len(fe.field) > 0 {
			w.Error = fmt.Errorf("%s: %w", fe.field[0].Name(), w.Error)
			return
		}
	}
}
// decoderOfRecord builds a decoder that reads an Avro record into a Go map.
//
// schema must be a *RecordSchema and typ a *reflect2.UnsafeMapType; both type
// assertions panic otherwise. One entry is planned per record field, in
// schema order:
//   - FieldIgnore fields get a decoder from createSkipDecoder and are marked
//     so that their value is not stored in the map;
//   - FieldSetDefault fields with hasDef set get a decoder from
//     createDefaultDecoder;
//   - all other fields get a regular decoder for the map's element type.
func decoderOfRecord(d *decoderContext, schema Schema, typ reflect2.Type) ValDecoder {
	rec := schema.(*RecordSchema)
	mapType := typ.(*reflect2.UnsafeMapType)

	entries := make([]recordMapDecoderField, len(rec.Fields()))
	for i, schemaField := range rec.Fields() {
		entry := recordMapDecoderField{name: schemaField.Name()}

		switch {
		case schemaField.action == FieldIgnore:
			entry.decoder = createSkipDecoder(schemaField.Type())
			entry.skip = true
		case schemaField.action == FieldSetDefault && schemaField.hasDef:
			entry.decoder = createDefaultDecoder(d, schemaField, mapType.Elem())
		default:
			entry.decoder = decoderOfType(d, schemaField.Type(), mapType.Elem())
		}

		entries[i] = entry
	}

	return &recordMapDecoder{
		mapType:  mapType,
		elemType: mapType.Elem(),
		fields:   entries,
	}
}
// recordMapDecoderField describes how one record field is decoded into a map.
type recordMapDecoderField struct {
// name is the record field name, used as the map key.
// NOTE(review): recordMapDecoder.Decode passes reflect2.PtrOf of the whole
// struct as the key pointer, which appears to rely on name being the first
// field — keep it first unless that call is changed.
name string
// decoder reads the field's value from the Reader.
decoder ValDecoder
// skip marks fields whose value is decoded but not stored in the map.
skip bool
}
// recordMapDecoder decodes an Avro record into a Go map, one entry per
// record field.
type recordMapDecoder struct {
	mapType  *reflect2.UnsafeMapType
	elemType reflect2.Type
	fields   []recordMapDecoderField
}

// Decode reads every planned field from r and stores it in the map at ptr,
// creating the map first when it is nil.
//
// A fresh element is allocated for each field and handed to the field's
// decoder; fields marked skip are decoded but not stored. After all fields
// have been processed, a non-EOF Reader error is wrapped with the map type.
func (d *recordMapDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
	if d.mapType.UnsafeIsNil(ptr) {
		fresh := d.mapType.UnsafeMakeMap(len(d.fields))
		d.mapType.UnsafeSet(ptr, fresh)
	}

	for _, entry := range d.fields {
		valuePtr := d.elemType.UnsafeNew()
		entry.decoder.Decode(valuePtr, r)
		if !entry.skip {
			// The pointer to the entry doubles as the pointer to its name,
			// which serves as the map key.
			d.mapType.UnsafeSetIndex(ptr, reflect2.PtrOf(entry), valuePtr)
		}
	}

	if r.Error == nil || errors.Is(r.Error, io.EOF) {
		return
	}
	r.Error = fmt.Errorf("%v: %w", d.mapType, r.Error)
}
// encoderOfRecord builds an encoder that writes a Go map as the Avro record
// rec.
//
// typ must be a *reflect2.UnsafeMapType; the type assertion panics otherwise.
// Every record field gets an encoder for the map's element type. Fields with
// a default additionally get a default encoder, used when the key is absent:
//   - for union-typed fields the default is wrapped in a single-entry map
//     keyed by the type name of the union's first member;
//   - for other fields a nil default gets no default encoder;
//   - the default encoder is built for the default's own Go type and wrapped
//     in onePtrEncoder when that type reports LikePtr.
func encoderOfRecord(e *encoderContext, rec *RecordSchema, typ reflect2.Type) ValEncoder {
	mapType := typ.(*reflect2.UnsafeMapType)

	entries := make([]mapEncoderField, len(rec.Fields()))
	for i, schemaField := range rec.Fields() {
		entry := &entries[i]
		*entry = mapEncoderField{
			name:    schemaField.Name(),
			hasDef:  schemaField.HasDefault(),
			def:     schemaField.Default(),
			encoder: encoderOfType(e, schemaField.Type(), mapType.Elem()),
		}

		if !schemaField.HasDefault() {
			continue
		}

		if schemaField.Type().Type() == Union {
			union := schemaField.Type().(*UnionSchema)
			entry.def = map[string]any{
				string(union.Types()[0].Type()): schemaField.Default(),
			}
		} else if schemaField.Default() == nil {
			continue
		}

		defType := reflect2.TypeOf(entry.def)
		defEnc := encoderOfType(e, schemaField.Type(), defType)
		if defType.LikePtr() {
			defEnc = &onePtrEncoder{defEnc}
		}
		entry.defEncoder = defEnc
	}

	return &recordMapEncoder{
		mapType: mapType,
		fields:  entries,
	}
}
// mapEncoderField describes how one record field is encoded from a map.
type mapEncoderField struct {
// name is the record field name, used as the map key.
// NOTE(review): recordMapEncoder.Encode passes reflect2.PtrOf of the whole
// struct as the key pointer, which appears to rely on name being the first
// field — keep it first unless that call is changed.
name string
// hasDef reports whether the schema field declares a default.
hasDef bool
// def is the default value; encoderOfRecord wraps it in a single-entry map
// for union-typed fields.
def any
// defEncoder encodes def when the key is absent from the map; it is left
// unset when encoderOfRecord finds a nil default on a non-union field.
defEncoder ValEncoder
// encoder encodes the value found in the map.
encoder ValEncoder
}
// recordMapEncoder encodes a Go map as an Avro record, one entry per record
// field.
type recordMapEncoder struct {
	mapType *reflect2.UnsafeMapType
	fields  []mapEncoderField
}

// Encode writes every record field to w, taking values from the map at ptr.
//
// A key present in the map is encoded with the field's encoder; a non-EOF
// Writer error is then wrapped with the field name and encoding stops. A
// missing key is an error when the field has no default, is silently passed
// over when the stored default is nil, and is otherwise encoded from the
// stored default with the default encoder.
func (e *recordMapEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
	for _, entry := range e.fields {
		// The pointer to the entry doubles as the pointer to its name, which
		// serves as the map key.
		valuePtr := e.mapType.UnsafeGetIndex(ptr, reflect2.PtrOf(entry))

		if valuePtr != nil {
			entry.encoder.Encode(valuePtr, w)
			if w.Error != nil && !errors.Is(w.Error, io.EOF) {
				w.Error = fmt.Errorf("%s: %w", entry.name, w.Error)
				return
			}
			continue
		}

		if !entry.hasDef {
			w.Error = fmt.Errorf("avro: missing required field %s", entry.name)
			return
		}
		if entry.def != nil {
			entry.defEncoder.Encode(reflect2.PtrOf(entry.def), w)
		}
	}
}
// recordIfaceDecoder decodes an Avro record into whatever value is currently
// held by a non-empty interface.
type recordIfaceDecoder struct {
	schema  Schema
	valType *reflect2.UnsafeIFaceType
}

// Decode reads a value of the decoder's schema from r into the object held
// by the interface at ptr. A nil object is reported on r as an error, since
// there is nothing to decode into.
func (d *recordIfaceDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
	held := d.valType.UnsafeIndirect(ptr)
	if !reflect2.IsNil(held) {
		r.ReadVal(d.schema, held)
		return
	}

	r.ReportError("decode non empty interface", "can not unmarshal into nil")
}
// structDescriptor is the result of describeStruct: a struct type together
// with the fields collected from it.
type structDescriptor struct {
// Type is the described struct type.
Type reflect2 .Type
// Fields holds the collected fields, including those reached through
// embedded structs.
Fields structFields
}
// structFields is an ordered list of described struct fields.
type structFields []*structField

// Get returns the first field whose Name equals name, or nil when there is
// no such field. The comparison is exact and case-sensitive.
func (sf structFields) Get(name string) *structField {
	for i := range sf {
		if candidate := sf[i]; candidate.Name == name {
			return candidate
		}
	}
	return nil
}
// structField describes one field found by describeStruct.
type structField struct {
// Name is the Go field name, or the value of the configured struct tag
// when the field carries that tag.
Name string
// Field is the chain of struct fields leading from the outer struct to
// this field; it has more than one element for fields reached through
// embedded structs.
Field []*reflect2 .UnsafeStructField
// anon is the struct type still to be visited; describeStruct sets it
// only on its internal work-queue entries, not on the fields it returns.
anon *reflect2 .UnsafeStructType
}
// describeStruct collects the fields of the struct type typ, descending into
// embedded structs level by level (breadth-first).
//
// typ must be a *reflect2.UnsafeStructType; the type assertion panics
// otherwise. For every visited struct:
//   - anonymous fields whose type is a struct or a pointer to a struct are
//     queued for the next level, whether exported or not; other anonymous
//     fields are dropped;
//   - unexported named fields are dropped;
//   - every remaining field is recorded under the value of its tagKey tag
//     when that tag is present, or under its Go name otherwise, together
//     with the chain of fields leading to it.
//
// Each struct type is visited at most once, keyed by its RType.
func describeStruct(tagKey string, typ reflect2.Type) *structDescriptor {
	structType := typ.(*reflect2.UnsafeStructType)
	collected := structFields{}

	// level and queue swap roles on every pass so that their backing arrays
	// are reused.
	var level []structField
	queue := []structField{{anon: structType}}

	seen := map[uintptr]bool{}

	for len(queue) > 0 {
		level, queue = queue, level[:0]

		for _, parent := range level {
			rtype := parent.anon.RType()
			if seen[rtype] {
				continue
			}
			seen[rtype] = true

			count := parent.anon.NumField()
			for i := 0; i < count; i++ {
				member := parent.anon.Field(i).(*reflect2.UnsafeStructField)

				// Path from the outer struct down to this member.
				path := make([]*reflect2.UnsafeStructField, 0, len(parent.Field)+1)
				path = append(path, parent.Field...)
				path = append(path, member)

				if member.Anonymous() {
					embedded := member.Type()
					if embedded.Kind() == reflect.Ptr {
						embedded = embedded.(*reflect2.UnsafePtrType).Elem()
					}
					if embedded.Kind() == reflect.Struct {
						queue = append(queue, structField{
							Field: path,
							anon:  embedded.(*reflect2.UnsafeStructType),
						})
					}
					continue
				}

				// A non-empty PkgPath marks an unexported field.
				if member.PkgPath() != "" {
					continue
				}

				name := member.Name()
				if tagged, ok := member.Tag().Lookup(tagKey); ok {
					name = tagged
				}

				collected = append(collected, &structField{
					Name:  name,
					Field: path,
				})
			}
		}
	}

	return &structDescriptor{
		Type:   structType,
		Fields: collected,
	}
}
// The pages are generated with Golds v0.8.2. (GOOS=linux GOARCH=amd64)
// Golds is a Go 101 project developed by Tapir Liu.
// PRs and bug reports are welcome and can be submitted to the issue list.
// Please follow @zigo_101 (reachable from the left QR code) to get the latest news of Golds.