package avro

import (
	
	
	
	
	
	

	
)

// UnionConverter lets a Go type take over how an Avro union value is mapped
// onto it, so callers can work with unions in a type-safe way.
type UnionConverter interface {
	// FromAny populates the receiver from payload, the value produced by
	// decoding one of the union's branch types.
	FromAny(payload any) error

	// ToAny returns the value to encode for the union. The encoder requires
	// the returned value to be a pointer and rejects anything else.
	ToAny() (any, error)
}

// createDecoderOfUnion picks the decoder used to read an Avro union into a Go
// value, dispatching on the kind of the target reflect2 type:
//
//   - map: a map-union decoder, but only for maps with string keys and
//     interface-kind elements;
//   - slice: a nullable-union decoder, only when the union is nullable;
//   - pointer: a UnionConverter codec when the type implements
//     UnionConverter, otherwise a nullable-union decoder when the union is
//     nullable;
//   - interface: a resolved-union decoder, unless the type is a
//     *reflect2.UnsafeIFaceType (presumably a non-empty interface);
//   - struct: delegates using the pointer-to-struct type.
//
// Every other case, including the fall-through from the checks above,
// returns an errorDecoder that reports the type as unsupported.
func createDecoderOfUnion( *decoderContext,  *UnionSchema,  reflect2.Type) ValDecoder {
	switch .Kind() {
	case reflect.Map:
		// Only string-keyed maps with interface elements can hold a
		// branch-name -> value entry.
		if .(reflect2.MapType).Key().Kind() != reflect.String ||
			.(reflect2.MapType).Elem().Kind() != reflect.Interface {
			break
		}
		return decoderOfMapUnion(, , )
	case reflect.Slice:
		if !.Nullable() {
			break
		}
		return decoderOfNullableUnion(, , )
	case reflect.Ptr:
		// A user-supplied UnionConverter takes precedence over the generic
		// nullable handling.
		if .Implements(reflect2.Type2(reflect.TypeFor[UnionConverter]())) {
			return decoderOfUnionConverterCodec(, , )
		}

		if !.Nullable() {
			break
		}
		return decoderOfNullableUnion(, , )
	case reflect.Interface:
		// *reflect2.UnsafeIFaceType falls through to the unsupported-type
		// error below.
		if ,  := .(*reflect2.UnsafeIFaceType); ! {
			,  := decoderOfResolvedUnion(, , )
			if  != nil {
				return &errorDecoder{err: fmt.Errorf("avro: problem resolving decoder for Avro %s: %w", .Type(), )}
			}
			return 
		}
	case reflect.Struct:
		// Retry with the pointer type, presumably so that a pointer-receiver
		// UnionConverter implementation is matched by the Ptr case — confirm.
		return (, , reflect2.PtrTo())
	}

	return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", .String(), .Type())}
}

// createEncoderOfUnion picks the encoder used to write a Go value as an Avro
// union:
//
//   - maps with string keys and interface-kind elements use the map-union
//     encoder;
//   - slices use the nullable-union encoder when the union is nullable;
//   - pointers implementing UnionConverter use the converter codec;
//     otherwise they use the nullable-union encoder when the union is
//     nullable.
//
// Every other case, including the fall-through from the checks above, is
// handed to encoderOfResolverUnion.
func createEncoderOfUnion( *encoderContext,  *UnionSchema,  reflect2.Type) ValEncoder {
	switch .Kind() {
	case reflect.Map:
		if .(reflect2.MapType).Key().Kind() != reflect.String ||
			.(reflect2.MapType).Elem().Kind() != reflect.Interface {
			break
		}
		return encoderOfMapUnion(, , )
	case reflect.Slice:
		if !.Nullable() {
			break
		}
		return encoderOfNullableUnion(, , )
	case reflect.Ptr:
		// A user-supplied UnionConverter takes precedence over the generic
		// nullable handling.
		if .Implements(reflect2.Type2(reflect.TypeFor[UnionConverter]())) {
			return encoderOfUnionConverterCodec(, , )
		}

		if !.Nullable() {
			break
		}
		return encoderOfNullableUnion(, , )
	}

	return encoderOfResolverUnion(, , )
}

// decoderOfMapUnion builds a mapUnionDecoder for a map target.
//
// It prepares one decoder per union branch, indexed by branch position, via
// newEfaceDecoder. Positions of null branches are left nil because a null
// branch decodes no value. The target type is asserted to
// *reflect2.UnsafeMapType without a check, so any other type panics here.
func decoderOfMapUnion( *decoderContext,  *UnionSchema,  reflect2.Type) ValDecoder {
	 := .(*reflect2.UnsafeMapType)

	 := make([]ValDecoder, len(.Types()))
	for ,  := range .Types() {
		if .Type() == Null {
			continue
		}
		[] = newEfaceDecoder(, )
	}

	return &mapUnionDecoder{
		cfg:      .cfg,
		schema:   ,
		mapType:  ,
		elemType: .Elem(),
		typeDecs: ,
	}
}

// mapUnionDecoder decodes an Avro union into a map with string keys, storing
// the selected branch's value under that branch's type name.
type mapUnionDecoder struct {
	// cfg is the frozen configuration captured when the decoder was built.
	cfg *frozenConfig
	// schema is the union being decoded.
	schema *UnionSchema
	// mapType is the unsafe reflect2 view of the destination map type.
	mapType *reflect2.UnsafeMapType
	// elemType is the map's element type, used to allocate decoded values.
	elemType reflect2.Type
	// typeDecs holds one decoder per union branch, indexed by branch
	// position. Entries for null branches stay nil.
	typeDecs []ValDecoder
}

// Decode reads one union value from the Reader and stores it in the map
// behind the destination pointer.
//
// The value is stored under the key schemaTypeName returns for the selected
// branch. A null branch leaves the map untouched. A nil map is allocated
// (size hint 1) before the first insert. Existing entries are never cleared,
// so a reused map keeps keys from earlier decodes.
func ( *mapUnionDecoder) ( unsafe.Pointer,  *Reader) {
	,  := getUnionSchema(.schema, )
	if  == nil {
		// Invalid branch index; getUnionSchema already reported the error.
		return
	}

	// A null branch carries no value, so there is nothing to store.
	if .Type() == Null {
		return
	}

	if .mapType.UnsafeIsNil() {
		.mapType.UnsafeSet(, .mapType.UnsafeMakeMap(1))
	}

	 := schemaTypeName()
	 := reflect2.PtrOf()

	// Decode into a freshly allocated element using the branch's decoder.
	 := .elemType.UnsafeNew()
	.typeDecs[].Decode(, )

	.mapType.UnsafeSetIndex(, , )
}

// encoderOfMapUnion builds a mapUnionEncoder for the union schema.
//
// The target type is not consulted here. The encoder reinterprets the source
// as map[string]any at encode time.
func encoderOfMapUnion( *encoderContext,  *UnionSchema,  reflect2.Type) ValEncoder {
	return &mapUnionEncoder{
		cfg:    .cfg,
		schema: ,
	}
}

// mapUnionEncoder writes a map[string]any holding at most one entry as an
// Avro union. The entry's key selects the branch.
type mapUnionEncoder struct {
	// cfg is used to build an encoder for each payload's dynamic type.
	cfg *frozenConfig
	// schema is the union whose branches the map key is matched against.
	schema *UnionSchema
}

// Encode writes the map[string]any behind the source pointer as an Avro
// union.
//
// The map may hold at most one entry. Its key names the branch, which is
// looked up by name among the union's types, and its value is the branch
// payload. An empty map selects the branch named "null". The branch index is
// written first, followed by the payload. The payload is encoded with an
// encoder built from its dynamic Go type on every call.
func ( *mapUnionEncoder) ( unsafe.Pointer,  *Writer) {
	 := *((*map[string]any)())

	// A union value is exactly one branch, so several entries are ambiguous.
	if len() > 1 {
		.Error = errors.New("avro: cannot encode union map with multiple entries")
		return
	}

	// These defaults select the null branch when the map is empty.
	 := "null"
	 := any(nil)
	for ,  := range  {
		 = 
		 = 
		break
	}

	,  := .schema.Types().Get()
	if  == nil {
		.Error = fmt.Errorf("avro: unknown union type %s", )
		return
	}

	.WriteInt(int32())

	// The null branch has no payload.
	if .Type() == Null &&  == nil {
		return
	}

	// encode a nil slice as an empty array
	if .Type() == Array &&  == nil {
		// element data type doesn't matter since it skips iterating the slice
		 = []struct{}{}
	}

	// Give a registered type converter the chance to rewrite the payload.
	// errNoTypeConverter only means that no converter applies.
	,  := .cfg.typeConverters.EncodeTypeConvert(, .schema)
	if  != nil && !errors.Is(, errNoTypeConverter) {
		.Error = 
		return
	}

	 := reflect2.TypeOf()
	 := reflect2.PtrOf()

	 := encoderOfType(newEncoderContext(.cfg), , )
	// NOTE(review): onePtrEncoder presumably adapts pointer-shaped values,
	// which an interface stores directly in its data word — confirm.
	if .LikePtr() {
		 = &onePtrEncoder{}
	}
	.Encode(, )
}

// decoderOfNullableUnion builds a unionNullableDecoder for a union that has a
// null branch and a value branch.
//
// For pointer targets the value decoder targets the pointer's element type
// and isPtr is true. For slice targets it targets the slice type itself. The
// value branch is chosen with one of the positions returned by Indices()
// (presumably the non-null one — confirm).
//
// The schema is asserted to *UnionSchema without a check, so any other
// schema panics here. A target that is neither a reflect2 pointer nor a
// slice type leaves the element type nil.
func decoderOfNullableUnion( *decoderContext,  Schema,  reflect2.Type) ValDecoder {
	 := .(*UnionSchema)
	,  := .Indices()

	var (
		 reflect2.Type
		   bool
	)
	switch v := .(type) {
	case *reflect2.UnsafePtrType:
		 = .Elem()
		 = true
	case *reflect2.UnsafeSliceType:
		 = 
	}
	 := decoderOfType(, .Types()[], )

	return &unionNullableDecoder{
		schema:  ,
		typ:     ,
		isPtr:   ,
		decoder: ,
	}
}

// unionNullableDecoder decodes a nullable union into a pointer or slice
// target.
type unionNullableDecoder struct {
	// schema is the nullable union being decoded.
	schema *UnionSchema
	// typ is the pointer's element type for pointer targets, or the slice
	// type itself for slice targets.
	typ reflect2.Type
	// isPtr reports whether the target is a pointer rather than a slice.
	isPtr bool
	// decoder decodes the value branch into a value of typ.
	decoder ValDecoder
}

// Decode reads a nullable union from the Reader into the pointer or slice
// behind the destination pointer.
//
// A null branch writes nil into the destination's first word. Otherwise the
// value branch is decoded, either into a newly allocated value when the
// destination is nil, or in place when it is not. A deferred step then runs
// any registered decode type converter over the result.
//
// NOTE(review): for slice targets (isPtr == false), writing nil through
// *unsafe.Pointer clears only the slice header's data pointer. Len and Cap
// of a previously populated slice are left as they were. Confirm whether a
// reused non-empty slice can reach this path.
func ( *unionNullableDecoder) ( unsafe.Pointer,  *Reader) {
	,  := getUnionSchema(.schema, )
	if  == nil {
		// Invalid branch index; getUnionSchema already reported the error.
		return
	}

	if .Type() == Null {
		*((*unsafe.Pointer)()) = nil
		return
	}

	// After decoding, apply any registered decode type converter. This also
	// runs if the decode below reported an error on the Reader.
	defer func() {
		if !.isPtr {
			// Non-pointer (slice) target: convert the value itself.
			 := .typ.UnsafeIndirect()
			,  := .cfg.typeConverters.DecodeTypeConvert(, .schema)
			if errors.Is(, errNoTypeConverter) {
				return
			}
			// A conversion error is recorded but does not stop the store below.
			if  != nil {
				.Error = 
			}
			// Presumably a nil converted value: clear the destination — confirm.
			if  == nil {
				*(*unsafe.Pointer)() = nil
				return
			}
			.typ.UnsafeSet(, reflect2.PtrOf())
			return
		}
		// Pointer target: convert the pointed-to value and repoint the
		// destination at the converted result, even if conversion errored.
		 := .typ.UnsafeIndirect(*((*unsafe.Pointer)()))
		,  := .cfg.typeConverters.DecodeTypeConvert(, .schema)
		if errors.Is(, errNoTypeConverter) {
			return
		}
		if  != nil {
			.Error = 
		}
		*((*unsafe.Pointer)()) = reflect2.PtrOf()
	}()

	// Handle the non-ptr case separately.
	if !.isPtr {
		if .typ.UnsafeIsNil() {
			// Create a new instance.
			 := .typ.UnsafeNew()
			.decoder.Decode(, )
			.typ.UnsafeSet(, )
			return
		}

		// Reuse the existing instance.
		.decoder.Decode(, )
		return
	}

	if *((*unsafe.Pointer)()) == nil {
		// Create new instance.
		 := .typ.UnsafeNew()
		.decoder.Decode(, )
		*((*unsafe.Pointer)()) = 
		return
	}

	// Reuse existing instance.
	.decoder.Decode(*((*unsafe.Pointer)()), )
}

// encoderOfUnionConverterCodec builds a unionConverterToAnyCodec for a pointer
// type implementing UnionConverter.
//
// It records the position of the union's null branch, if there is one, so
// that nil pointers can be encoded as null. The schema is asserted to
// *UnionSchema without a check.
func encoderOfUnionConverterCodec( *encoderContext,  Schema,  reflect2.Type) ValEncoder {
	 := .(*UnionSchema)
	var  int32
	var  bool

	for ,  := range .Types() {
		if .Type() == Null {
			 = int32()
			 = true
		}
	}

	return &unionConverterToAnyCodec{
		schema:   ,
		typ:      ,
		nullable: ,
		nullIdx:  ,
	}
}

// unionConverterToAnyCodec encodes a pointer to a UnionConverter by asking it,
// through ToAny, for the value to write.
type unionConverterToAnyCodec struct {
	// schema is the union being encoded.
	schema *UnionSchema
	// typ is the pointer type that implements UnionConverter.
	typ reflect2.Type
	// nullable reports whether the union has a null branch.
	nullable bool
	// nullIdx is the null branch position; meaningful only when nullable.
	nullIdx int32
}

// Encode writes the UnionConverter behind the source pointer as an Avro union.
//
// A nil pointer selects the null branch. If the union has no null branch, it
// sets an error on the Writer instead. Otherwise ToAny is called. Its result
// must be a pointer, and the value that pointer refers to is written with
// WriteVal against the whole union schema.
func ( *unionConverterToAnyCodec) ( unsafe.Pointer,  *Writer) {
	if *((*unsafe.Pointer)()) == nil {
		if !.nullable {
			.Error = errors.New("avro: unionConverterToAnyCodec: encoding nil value for non nillable union")
			return
		}
		.WriteInt(.nullIdx)
		return
	}

	// The unchecked assertion relies on typ implementing UnionConverter,
	// which createEncoderOfUnion checks before choosing this codec.
	 := .typ.UnsafeIndirect()
	 := .(UnionConverter)
	,  := .ToAny()
	if  != nil {
		.Error = fmt.Errorf("avro: unable to convert union: %w", )
		return
	}

	// ToAny must hand back a pointer; anything else is rejected.
	 := reflect2.TypeOf()
	,  := .(*reflect2.UnsafePtrType)
	if ! {
		.Error = fmt.Errorf("avro: expected ptr but received %q", .String())
		return
	}

	 := .Elem()
	.WriteVal(.schema, .Indirect())
}

// encoderOfNullableUnion builds a unionNullableEncoder for a union that has a
// null branch and a value branch.
//
// For pointer sources the value encoder targets the pointer's element type
// and isPtr is true. For slice sources it targets the slice type itself. Both
// branch positions come from Indices(). The schema is asserted to
// *UnionSchema without a check.
func encoderOfNullableUnion( *encoderContext,  Schema,  reflect2.Type) ValEncoder {
	 := .(*UnionSchema)
	,  := .Indices()

	var (
		 reflect2.Type
		   bool
	)
	switch v := .(type) {
	case *reflect2.UnsafePtrType:
		 = .Elem()
		 = true
	case *reflect2.UnsafeSliceType:
		 = 
	}
	 := encoderOfType(, .Types()[], )

	return &unionNullableEncoder{
		schema:  ,
		encoder: ,
		isPtr:   ,
		nullIdx: int32(),
		typeIdx: int32(),
	}
}

// unionNullableEncoder writes a nullable union from a pointer or slice source.
// It picks the null branch when the source's first word is nil.
type unionNullableEncoder struct {
	// schema is the nullable union being encoded.
	schema *UnionSchema
	// encoder writes the value branch's payload.
	encoder ValEncoder
	// isPtr reports whether the source is a pointer that must be dereferenced
	// before encoding, as opposed to a slice.
	isPtr bool
	// nullIdx and typeIdx are the branch positions written for null and
	// non-null values respectively.
	nullIdx, typeIdx int32
}

// Encode writes a nullable union.
//
// A nil first word at the source selects the null branch: this is a nil
// pointer, or a slice with a nil data pointer. An empty but non-nil slice is
// therefore encoded as the value branch. Otherwise the value branch index is
// written and the payload is encoded, through the stored pointer for pointer
// sources or from the slice header itself for slice sources.
func ( *unionNullableEncoder) ( unsafe.Pointer,  *Writer) {
	if *((*unsafe.Pointer)()) == nil {
		.WriteInt(.nullIdx)
		return
	}

	.WriteInt(.typeIdx)
	 := 
	if .isPtr {
		 = *((*unsafe.Pointer)())
	}
	.encoder.Encode(, )
}

// decoderOfResolvedUnion builds a unionResolvedDecoder.
//
// It maps each union branch to the Go type registered with the
// configuration's type resolver, looked up by unionResolutionName. When a
// branch cannot be resolved:
//
//   - with UnionResolutionError set, the resolver's error is returned;
//   - with PartialUnionTypeResolution set, only that branch is left
//     unresolved (nil entries);
//   - otherwise resolution is abandoned for the whole union (empty slices),
//     so every non-null value falls back to generic map decoding.
func decoderOfResolvedUnion( *decoderContext,  Schema,  reflect2.Type) (ValDecoder, error) {
	 := .(*UnionSchema)

	 := make([]reflect2.Type, len(.Types()))
	 := make([]ValDecoder, len(.Types()))
	for ,  := range .Types() {
		 := unionResolutionName()

		,  := .cfg.resolver.Type()
		if  != nil {
			if .cfg.config.UnionResolutionError {
				return nil, 
			}

			if .cfg.config.PartialUnionTypeResolution {
				[] = nil
				[] = nil
				continue
			}

			// Drop every resolved branch; the decoder will fall back to maps.
			 = []ValDecoder{}
			 = []reflect2.Type{}
			break
		}

		 := decoderOfType(, , )
		[] = 
		[] = 
	}

	return &unionResolvedDecoder{
		cfg:      .cfg,
		schema:   ,
		types:    ,
		decoders: ,
	}, nil
}

// unionResolvedDecoder decodes a union into an `any`, using Go types looked up
// through the configuration's type resolver.
type unionResolvedDecoder struct {
	// cfg is the frozen configuration captured when the decoder was built.
	cfg *frozenConfig
	// schema is the union being decoded.
	schema *UnionSchema
	// types and decoders are indexed by branch position. They may be shorter
	// than the union, or hold nil entries, for branches that did not resolve.
	types    []reflect2.Type
	decoders []ValDecoder
}

// Decode reads a union into the `any` behind the destination pointer.
//
// A null branch stores nil. A branch with a resolved Go type is decoded into
// a freshly allocated value of that type. A branch that could not be
// resolved is handled in one of two ways: it is reported as an error when
// UnionResolutionError is set, or otherwise decoded generically into a
// single-entry map[string]any keyed by the branch's type name.
//
// In every non-null case a deferred step runs any registered decode type
// converter over the destination and stores its result back. This includes
// the error paths, where the destination holds whatever it held before.
func ( *unionResolvedDecoder) ( unsafe.Pointer,  *Reader) {
	,  := getUnionSchema(.schema, )
	if  == nil {
		// Invalid branch index; getUnionSchema already reported the error.
		return
	}

	 := (*any)()

	if .Type() == Null {
		* = nil
		return
	}

	defer func() {
		,  := .cfg.typeConverters.DecodeTypeConvert(*, .schema)
		if  != nil && !errors.Is(, errNoTypeConverter) {
			.Error = 
		}
		// Presumably the input is returned unchanged when no converter is
		// registered — confirm.
		* = 
	}()

	if  >= len(.decoders) || .decoders[] == nil {
		if .cfg.config.UnionResolutionError {
			.ReportError("decode union type", "unknown union type")
			return
		}

		// We cannot resolve this, set it to the map type
		 := schemaTypeName()
		 := map[string]any{}
		,  := genericReceiver()
		if  != nil {
			.ReportError("Union", .Error())
			return
		}
		[] = genericDecode(, decoderOfType(newDecoderContext(.cfg), , ), )

		* = 
		return
	}

	// Allocate an empty value of the resolved type to decode into. Maps and
	// slices are created empty; for pointer types the pointee is allocated.
	 := .types[]
	var  unsafe.Pointer
	switch .Kind() {
	case reflect.Map:
		 := .(*reflect2.UnsafeMapType)
		 = .UnsafeMakeMap(0)

	case reflect.Slice:
		 := .(*reflect2.UnsafeSliceType)
		 = .UnsafeMakeSlice(0, 0)

	case reflect.Ptr:
		 := .(*reflect2.UnsafePtrType).Elem()
		 = .UnsafeNew()

	default:
		 = .UnsafeNew()
	}

	.decoders[].Decode(, )

	* = .UnsafeIndirect()
}

// decoderOfUnionConverterCodec builds a unionConverterFromAnyCodec for a type
// implementing UnionConverter.
//
// The union is first decoded with the decoder createDecoderOfUnion picks for
// an `any` target. Whether the union contains a null branch decides if a nil
// decoded value is accepted.
func decoderOfUnionConverterCodec( *decoderContext,  *UnionSchema,  reflect2.Type) ValDecoder {
	 := createDecoderOfUnion(, , reflect2.Type2(reflect.TypeFor[any]()))
	 := slices.ContainsFunc(.Types(), func( Schema) bool {
		return .Type() == Null
	})

	return &unionConverterFromAnyCodec{
		decoder:  ,
		schema:   ,
		nullable: ,
		typ:      ,
	}
}

// unionConverterFromAnyCodec decodes a union generically and hands the result
// to the target's UnionConverter.FromAny.
type unionConverterFromAnyCodec struct {
	// decoder decodes the union into an `any`.
	decoder ValDecoder
	// schema is the union being decoded.
	schema *UnionSchema
	// nullable reports whether a nil decoded value is acceptable.
	nullable bool
	// typ is the target type implementing UnionConverter.
	typ reflect2.Type
}

// Decode reads a union generically into an `any` and passes the result to
// the destination's FromAny.
//
// A nil result is accepted only for nullable unions, and in that case the
// destination is left untouched. For pointer destinations a fresh element is
// always allocated and stored first, replacing any existing pointee. Errors
// returned by FromAny are reported on the Reader.
func ( *unionConverterFromAnyCodec) ( unsafe.Pointer,  *Reader) {
	 := new(any)
	 := reflect2.PtrOf()
	.decoder.Decode(, )

	if * == nil {
		if .nullable {
			return
		}

		.Error = errors.New("avro: cannot decode nil value in non-nullable union type")
		return
	}

	if .typ.Kind() == reflect.Ptr {
		 := .typ.(*reflect2.UnsafePtrType).Elem()
		 := .UnsafeNew()
		*((*unsafe.Pointer)()) = 
	}
	 := .typ.UnsafeIndirect()

	// The unchecked assertion relies on typ implementing UnionConverter.
	 := .(UnionConverter)
	if  := .FromAny(*);  != nil {
		.ReportError("Union", .Error())
		return
	}
}

// unionResolutionName returns the name used to look up a union branch's Go
// type in the type resolver.
//
// The name is schemaTypeName of the schema. For maps it is extended with ":"
// and the schemaTypeName of the values; for arrays, with ":" and the
// schemaTypeName of the items.
func unionResolutionName( Schema) string {
	 := schemaTypeName()
	switch .Type() {
	case Map:
		 += ":"
		 := .(*MapSchema).Values()
		 := schemaTypeName()

		 += 

	case Array:
		 += ":"
		 := .(*ArraySchema).Items()
		 := schemaTypeName()

		 += 
	}

	return 
}

// encoderOfResolverUnion builds an encoder for a Go type that is not handled
// by the map, nullable or UnionConverter paths.
//
// It asks the type resolver for the names registered for the type and uses
// the first name that matches a union branch. Anything from the first ":"
// onward (the suffix format unionResolutionName produces for maps and
// arrays) is stripped before the lookup. The matching branch position is
// fixed at build time. When the resolver fails, or no name matches, an
// errorEncoder is returned instead.
func encoderOfResolverUnion( *encoderContext,  Schema,  reflect2.Type) ValEncoder {
	 := .(*UnionSchema)

	,  := .cfg.resolver.Name()
	if  != nil {
		return &errorEncoder{err: }
	}

	var  int
	for ,  := range  {
		// Compare only the base type name, e.g. the part before ":".
		if  := strings.Index(, ":");  > 0 {
			 = [:]
		}

		,  = .Types().Get()
		if  != nil {
			break
		}
	}
	if  == nil {
		return &errorEncoder{err: fmt.Errorf("avro: unknown union type %s", [0])}
	}

	 := encoderOfType(, , )

	return &unionResolverEncoder{
		pos:     ,
		encoder: ,
	}
}

// unionResolverEncoder writes a fixed branch index followed by the value. It
// serves Go types whose union branch was resolved when the encoder was built.
type unionResolverEncoder struct {
	// pos is the union branch index written before every value.
	pos int
	// encoder writes the value using that branch's schema.
	encoder ValEncoder
}

// Encode writes the branch index fixed at build time, then encodes the value
// with that branch's encoder.
func ( *unionResolverEncoder) ( unsafe.Pointer,  *Writer) {
	.WriteInt(int32(.pos))

	.encoder.Encode(, )
}

// getUnionSchema reads a union branch index from the Reader and returns it
// together with that branch's schema.
//
// An index outside the union's types is reported on the Reader as "unknown
// union type" and yields (0, nil). Callers treat a nil schema as a signal to
// stop decoding.
func getUnionSchema( *UnionSchema,  *Reader) (int, Schema) {
	 := .Types()

	 := int(.ReadInt())
	if  < 0 ||  > len()-1 {
		.ReportError("decode union type", "unknown union type")
		return 0, nil
	}

	return , []
}