// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build go1.18

package compute

import (
	
	
	

	
	
	
	
	
)

var (
	castTable map[arrow.Type]*castFunction
	castInit  sync.Once

	castDoc = FunctionDoc{
		Summary:         "cast values to another data type",
		Description:     "Behavior when values wouldn't fit in the target type\ncan be controlled through CastOptions.",
		ArgNames:        []string{"input"},
		OptionsType:     "CastOptions",
		OptionsRequired: true,
	}
	castMetaFunc = NewMetaFunction("cast", Unary(), castDoc,
		func( context.Context,  FunctionOptions,  ...Datum) (Datum, error) {
			 := .(*CastOptions)
			if  == nil || .ToType == nil {
				return nil, fmt.Errorf("%w: cast requires that options be passed with a ToType", arrow.ErrInvalid)
			}

			if arrow.TypeEqual([0].(ArrayLikeDatum).Type(), .ToType) {
				return NewDatum([0]), nil
			}

			,  := getCastFunction(.ToType)
			if  != nil {
				return nil, fmt.Errorf("%w from %s", , [0].(ArrayLikeDatum).Type())
			}

			return .Execute(, , ...)
		})
)

func ( FunctionRegistry) {
	.AddFunction(castMetaFunc, false)
}

// castFunction is a ScalarFunction whose kernels all produce values of a
// single output type ID, stored in out. inIDs records, in registration
// order, the input type ID given to each successful AddTypeCast call;
// CanCast consults it to decide whether a cast is supported.
type castFunction struct {
	ScalarFunction

	inIDs []arrow.Type
	out   arrow.Type
}

func newCastFunction( string,  arrow.Type) *castFunction {
	return &castFunction{
		ScalarFunction: *NewScalarFunction(, Unary(), EmptyFuncDoc),
		out:            ,
		inIDs:          make([]arrow.Type, 0, 1),
	}
}

func ( *castFunction) ( arrow.Type,  exec.ScalarKernel) error {
	.Init = exec.OptionsInit[kernels.CastState]
	if  := .AddKernel();  != nil {
		return 
	}
	.inIDs = append(.inIDs, )
	return nil
}

func ( *castFunction) ( arrow.Type,  []exec.InputType,  exec.OutputType,
	 exec.ArrayKernelExec,  exec.NullHandling,  exec.MemAlloc) error {

	 := exec.NewScalarKernel(, , , nil)
	.NullHandling = 
	.MemAlloc = 
	return .AddTypeCast(, )
}

func ( *castFunction) ( ...arrow.DataType) (exec.Kernel, error) {
	if  := .checkArity(len());  != nil {
		return nil, 
	}

	 := make([]*exec.ScalarKernel, 0, 1)
	for  := range .kernels {
		if .kernels[].Signature.MatchesInputs() {
			 = append(, &.kernels[])
		}
	}

	if len() == 0 {
		return nil, fmt.Errorf("%w: unsupported cast from %s to %s using function %s",
			arrow.ErrNotImplemented, [0], .out, .name)
	}

	if len() == 1 {
		// one match!
		return [0], nil
	}

	// in this situation we may have both an EXACT type and
	// a SAME_TYPE_ID match. So we will see if there is an exact
	// match among the candidates and if not, we just return the
	// first one
	for ,  := range  {
		 := .Signature.InputTypes[0]
		if .Kind == exec.InputExact {
			// found one!
			return , nil
		}
	}

	// just return some kernel that matches since we didn't find an exact
	return [0], nil
}

func unpackDictionary( *exec.KernelCtx,  *exec.ExecSpan,  *exec.ExecResult) error {
	var (
		  = .Values[0].Array.MakeArray().(*array.Dictionary)
		     = .State.(kernels.CastState)
		 = .DataType().(*arrow.DictionaryType)
		   = .ToType
	)
	defer .Release()

	if !arrow.TypeEqual(, ) && !CanCast(, ) {
		return fmt.Errorf("%w: cast type %s incompatible with dictionary type %s",
			arrow.ErrInvalid, , )
	}

	,  := TakeArray(.Ctx, .Dictionary(), .Indices())
	if  != nil {
		return 
	}
	defer .Release()

	if !arrow.TypeEqual(, ) {
		,  = CastArray(.Ctx, , &)
		if  != nil {
			return 
		}
		defer .Release()
	}

	.TakeOwnership(.Data())
	return nil
}

func ( *exec.KernelCtx,  *exec.ExecSpan,  *exec.ExecResult) error {
	 := .State.(kernels.CastState)

	 := .Values[0].Array.MakeArray().(array.ExtensionArray)
	defer .Release()

	 := CastOptions()
	,  := CastArray(.Ctx, .Storage(), &)
	if  != nil {
		return 
	}
	defer .Release()

	.TakeOwnership(.Data())
	return nil
}

func [,  int32 | int64]( *exec.KernelCtx,  *exec.ExecSpan,  *exec.ExecResult) error {
	var (
		       = .State.(kernels.CastState)
		  = .Type.(arrow.NestedType).Fields()[0].Type
		      = &.Values[0].Array
		    = exec.GetSpanOffsets[](, 1)
		 = kernels.SizeOf[]() > kernels.SizeOf[]()
	)

	.Buffers[0] = .Buffers[0]
	.Buffers[1] = .Buffers[1]

	if .Offset != 0 && len(.Buffers[0].Buf) > 0 {
		.Buffers[0].WrapBuffer(.AllocateBitmap(.Len))
		bitutil.CopyBitmap(.Buffers[0].Buf, int(.Offset), int(.Len),
			.Buffers[0].Buf, 0)
	}

	// Handle list offsets
	// Several cases possible:
	//	- The source offset is non-zero, in which case we slice the
	//	  underlying values and shift the list offsets (regardless of
	//	  their respective types)
	//	- the source offset is zero but the source and destination types
	//	  have different list offset types, in which case we cast the offsets
	//  - otherwise we simply keep the original offsets
	if  {
		if [.Len] > (kernels.MaxOf[]()) {
			return fmt.Errorf("%w: array of type %s too large to convert to %s",
				arrow.ErrInvalid, .Type, .Type)
		}
	}

	 := .Children[0].MakeArray()
	defer .Release()

	if .Offset != 0 {
		.Buffers[1].WrapBuffer(
			.Allocate(.Type.(arrow.OffsetsDataType).
				OffsetTypeTraits().BytesRequired(int(.Len) + 1)))

		 := exec.GetSpanOffsets[](, 1)
		for  := 0;  < int(.Len)+1; ++ {
			[] = ([] - [0])
		}

		 = array.NewSlice(, int64([0]), int64([.Len]))
		defer .Release()
	} else if kernels.SizeOf[]() != kernels.SizeOf[]() {
		.Buffers[1].WrapBuffer(.Allocate(.Type.(arrow.OffsetsDataType).
			OffsetTypeTraits().BytesRequired(int(.Len) + 1)))

		kernels.DoStaticCast(exec.GetSpanOffsets[](, 1),
			exec.GetSpanOffsets[](, 1))
	}

	// handle values
	.ToType = 

	,  := CastArray(.Ctx, , &)
	if  != nil {
		return 
	}
	defer .Release()

	.Children = make([]exec.ArraySpan, 1)
	.Children[0].SetMembers(.Data())
	for ,  := range .Children[0].Buffers {
		if .Owner != nil && .Owner != .Data().Buffers()[] {
			.Owner.Retain()
			.SelfAlloc = true
		}
	}
	return nil
}

func ( *exec.KernelCtx,  *exec.ExecSpan,  *exec.ExecResult) error {
	var (
		          = .State.(kernels.CastState)
		        = .Values[0].Array.Type.(*arrow.StructType)
		       = .Type.(*arrow.StructType)
		  = .NumFields()
		 = .NumFields()
	)

	 := make([]int, )
	for  := range  {
		[] = -1
	}

	 := 0
	for  := 0;  <  &&  < ; ++ {
		 := .Field()
		 := .Field()
		if .Name == .Name {
			if .Nullable && !.Nullable {
				return fmt.Errorf("%w: cannot cast nullable field to non-nullable field: %s %s",
					arrow.ErrType, , )
			}
			[] = 
			++
		}
	}

	if  <  {
		return fmt.Errorf("%w: struct fields don't match or are in the wrong order: Input: %s Output: %s",
			arrow.ErrType, , )
	}

	 := &.Values[0].Array
	if len(.Buffers[0].Buf) > 0 {
		.Buffers[0].WrapBuffer(.AllocateBitmap(.Len))
		bitutil.CopyBitmap(.Buffers[0].Buf, int(.Offset), int(.Len),
			.Buffers[0].Buf, 0)
	}

	.Children = make([]exec.ArraySpan, )
	for ,  := range  {
		 := .Children[].MakeArray()
		defer .Release()
		 = array.NewSlice(, .Offset, .Len)
		defer .Release()

		.ToType = .Field().Type
		,  := CastArray(.Ctx, , &)
		if  != nil {
			return 
		}
		defer .Release()

		.Children[].TakeOwnership(.Data())
	}
	return nil
}

func addListCast[,  int32 | int64]( *castFunction,  arrow.Type) error {
	 := exec.NewScalarKernel([]exec.InputType{exec.NewIDInput()},
		kernels.OutputTargetType, CastList[, ], nil)
	.NullHandling = exec.NullComputedNoPrealloc
	.MemAlloc = exec.MemNoPrealloc
	return .AddTypeCast(, )
}

func addStructToStructCast( *castFunction) error {
	 := exec.NewScalarKernel([]exec.InputType{exec.NewIDInput(arrow.STRUCT)},
		kernels.OutputTargetType, CastStruct, nil)
	.NullHandling = exec.NullComputedNoPrealloc
	return .AddTypeCast(arrow.STRUCT, )
}

func addCastFuncs( []*castFunction) {
	for ,  := range  {
		.AddNewTypeCast(arrow.EXTENSION, []exec.InputType{exec.NewIDInput(arrow.EXTENSION)},
			.kernels[0].Signature.OutType, CastFromExtension,
			exec.NullComputedNoPrealloc, exec.MemNoPrealloc)
		castTable[.out] = 
	}
}

func initCastTable() {
	castTable = make(map[arrow.Type]*castFunction)
	addCastFuncs(getBooleanCasts())
	addCastFuncs(getNumericCasts())
	addCastFuncs(getBinaryLikeCasts())
	addCastFuncs(getTemporalCasts())
	addCastFuncs(getNestedCasts())

	 := newCastFunction("cast_extension", arrow.EXTENSION)
	.AddNewTypeCast(arrow.NULL, []exec.InputType{exec.NewExactInput(arrow.Null)},
		kernels.OutputTargetType, kernels.CastFromNull, exec.NullComputedNoPrealloc, exec.MemNoPrealloc)
	castTable[arrow.EXTENSION] = 
}

func getCastFunction( arrow.DataType) (*castFunction, error) {
	castInit.Do(initCastTable)

	,  := castTable[.ID()]
	if  {
		return , nil
	}

	return nil, fmt.Errorf("%w: unsupported cast to %s", arrow.ErrNotImplemented, )
}

func getNestedCasts() []*castFunction {
	 := make([]*castFunction, 0)

	 := func( *castFunction,  []exec.ScalarKernel) {
		for ,  := range  {
			if  := .AddTypeCast(.Signature.InputTypes[0].MatchID(), );  != nil {
				panic()
			}
		}
	}

	 := newCastFunction("cast_list", arrow.LIST)
	(, kernels.GetCommonCastKernels(arrow.LIST, kernels.OutputTargetType))
	if  := addListCast[int32, int32](, arrow.LIST);  != nil {
		panic()
	}
	if  := addListCast[int64, int32](, arrow.LARGE_LIST);  != nil {
		panic()
	}
	 = append(, )

	 := newCastFunction("cast_large_list", arrow.LARGE_LIST)
	(, kernels.GetCommonCastKernels(arrow.LARGE_LIST, kernels.OutputTargetType))
	if  := addListCast[int32, int64](, arrow.LIST);  != nil {
		panic()
	}
	if  := addListCast[int64, int64](, arrow.LARGE_LIST);  != nil {
		panic()
	}
	 = append(, )

	 := newCastFunction("cast_fixed_size_list", arrow.FIXED_SIZE_LIST)
	(, kernels.GetCommonCastKernels(arrow.FIXED_SIZE_LIST, kernels.OutputTargetType))
	 = append(, )

	 := newCastFunction("cast_struct", arrow.STRUCT)
	(, kernels.GetCommonCastKernels(arrow.STRUCT, kernels.OutputTargetType))
	if  := addStructToStructCast();  != nil {
		panic()
	}
	 = append(, )

	return 
}

func getBooleanCasts() []*castFunction {
	 := newCastFunction("cast_boolean", arrow.BOOL)
	 := kernels.GetBooleanCastKernels()

	for ,  := range  {
		if  := .AddTypeCast(.Signature.InputTypes[0].Type.ID(), );  != nil {
			panic()
		}
	}

	return []*castFunction{}
}

func getTemporalCasts() []*castFunction {
	 := make([]*castFunction, 0)
	 := func( string,  arrow.Type,  []exec.ScalarKernel) {
		 := newCastFunction(, )
		for ,  := range  {
			if  := .AddTypeCast(.Signature.InputTypes[0].MatchID(), );  != nil {
				panic()
			}
		}
		.AddNewTypeCast(arrow.DICTIONARY, []exec.InputType{exec.NewIDInput(arrow.DICTIONARY)},
			[0].Signature.OutType, unpackDictionary, exec.NullComputedNoPrealloc, exec.MemNoPrealloc)
		 = append(, )
	}

	("cast_timestamp", arrow.TIMESTAMP, kernels.GetTimestampCastKernels())
	("cast_date32", arrow.DATE32, kernels.GetDate32CastKernels())
	("cast_date64", arrow.DATE64, kernels.GetDate64CastKernels())
	("cast_time32", arrow.TIME32, kernels.GetTime32CastKernels())
	("cast_time64", arrow.TIME64, kernels.GetTime64CastKernels())
	("cast_duration", arrow.DURATION, kernels.GetDurationCastKernels())
	("cast_month_day_nano_interval", arrow.INTERVAL_MONTH_DAY_NANO, kernels.GetIntervalCastKernels())
	return 
}

func getNumericCasts() []*castFunction {
	 := make([]*castFunction, 0)

	 := func( string,  arrow.Type,  []exec.ScalarKernel) *castFunction {
		 := newCastFunction(, )
		for ,  := range  {
			if  := .AddTypeCast(.Signature.InputTypes[0].MatchID(), );  != nil {
				panic()
			}
		}

		.AddNewTypeCast(arrow.DICTIONARY, []exec.InputType{exec.NewIDInput(arrow.DICTIONARY)},
			[0].Signature.OutType, unpackDictionary, exec.NullComputedNoPrealloc, exec.MemNoPrealloc)

		return 
	}

	 = append(, ("cast_int8", arrow.INT8, kernels.GetCastToInteger[int8](arrow.PrimitiveTypes.Int8)))
	 = append(, ("cast_int16", arrow.INT16, kernels.GetCastToInteger[int8](arrow.PrimitiveTypes.Int16)))

	 := ("cast_int32", arrow.INT32, kernels.GetCastToInteger[int32](arrow.PrimitiveTypes.Int32))
	.AddTypeCast(arrow.DATE32,
		kernels.GetZeroCastKernel(arrow.DATE32,
			exec.NewExactInput(arrow.FixedWidthTypes.Date32),
			exec.NewOutputType(arrow.PrimitiveTypes.Int32)))
	.AddTypeCast(arrow.TIME32,
		kernels.GetZeroCastKernel(arrow.TIME32,
			exec.NewIDInput(arrow.TIME32), exec.NewOutputType(arrow.PrimitiveTypes.Int32)))
	 = append(, )

	 := ("cast_int64", arrow.INT64, kernels.GetCastToInteger[int64](arrow.PrimitiveTypes.Int64))
	.AddTypeCast(arrow.DATE64,
		kernels.GetZeroCastKernel(arrow.DATE64,
			exec.NewIDInput(arrow.DATE64),
			exec.NewOutputType(arrow.PrimitiveTypes.Int64)))
	.AddTypeCast(arrow.TIME64,
		kernels.GetZeroCastKernel(arrow.TIME64,
			exec.NewIDInput(arrow.TIME64),
			exec.NewOutputType(arrow.PrimitiveTypes.Int64)))
	.AddTypeCast(arrow.DURATION,
		kernels.GetZeroCastKernel(arrow.DURATION,
			exec.NewIDInput(arrow.DURATION),
			exec.NewOutputType(arrow.PrimitiveTypes.Int64)))
	.AddTypeCast(arrow.TIMESTAMP,
		kernels.GetZeroCastKernel(arrow.TIMESTAMP,
			exec.NewIDInput(arrow.TIMESTAMP),
			exec.NewOutputType(arrow.PrimitiveTypes.Int64)))
	 = append(, )

	 = append(, ("cast_uint8", arrow.UINT8, kernels.GetCastToInteger[uint8](arrow.PrimitiveTypes.Uint8)))
	 = append(, ("cast_uint16", arrow.UINT16, kernels.GetCastToInteger[uint16](arrow.PrimitiveTypes.Uint16)))
	 = append(, ("cast_uint32", arrow.UINT32, kernels.GetCastToInteger[uint32](arrow.PrimitiveTypes.Uint32)))
	 = append(, ("cast_uint64", arrow.UINT64, kernels.GetCastToInteger[uint64](arrow.PrimitiveTypes.Uint64)))

	 = append(, ("cast_half_float", arrow.FLOAT16, kernels.GetCommonCastKernels(arrow.FLOAT16, exec.NewOutputType(arrow.FixedWidthTypes.Float16))))
	 = append(, ("cast_float", arrow.FLOAT32, kernels.GetCastToFloating[float32](arrow.PrimitiveTypes.Float32)))
	 = append(, ("cast_double", arrow.FLOAT64, kernels.GetCastToFloating[float64](arrow.PrimitiveTypes.Float64)))

	// cast to decimal128
	 = append(, ("cast_decimal", arrow.DECIMAL128, kernels.GetCastToDecimal128()))
	// cast to decimal256
	 = append(, ("cast_decimal256", arrow.DECIMAL256, kernels.GetCastToDecimal256()))
	return 
}

func getBinaryLikeCasts() []*castFunction {
	 := make([]*castFunction, 0)

	 := func( string,  arrow.Type,  []exec.ScalarKernel) {
		 := newCastFunction(, )
		for ,  := range  {
			if  := .AddTypeCast(.Signature.InputTypes[0].MatchID(), );  != nil {
				panic()
			}
		}

		.AddNewTypeCast(arrow.DICTIONARY, []exec.InputType{exec.NewIDInput(arrow.DICTIONARY)},
			[0].Signature.OutType, unpackDictionary, exec.NullComputedNoPrealloc, exec.MemNoPrealloc)

		 = append(, )
	}

	("cast_binary", arrow.BINARY, kernels.GetToBinaryKernels(arrow.BinaryTypes.Binary))
	("cast_large_binary", arrow.LARGE_BINARY, kernels.GetToBinaryKernels(arrow.BinaryTypes.LargeBinary))
	("cast_string", arrow.STRING, kernels.GetToBinaryKernels(arrow.BinaryTypes.String))
	("cast_large_string", arrow.LARGE_STRING, kernels.GetToBinaryKernels(arrow.BinaryTypes.LargeString))
	("cast_fixed_sized_binary", arrow.FIXED_SIZE_BINARY, kernels.GetFsbCastKernels())
	return 
}

// CastDatum is a convenience function for casting a Datum to another type.
// It is equivalent to calling CallFunction(ctx, "cast", opts, Datum) and
// should work for Scalar, Array or ChunkedArray Datums.
func ( context.Context,  Datum,  *CastOptions) (Datum, error) {
	return CallFunction(, "cast", , )
}

// CastArray is a convenience function for casting an Array to another type.
// It is equivalent to constructing a Datum for the array and using
// CallFunction(ctx, "cast", ...).
func ( context.Context,  arrow.Array,  *CastOptions) (arrow.Array, error) {
	 := NewDatum()
	defer .Release()

	,  := CastDatum(, , )
	if  != nil {
		return nil, 
	}

	defer .Release()
	return .(*ArrayDatum).MakeArray(), nil
}

// CastToType is a convenience function equivalent to calling
// CastArray(ctx, val, compute.SafeCastOptions(toType))
func ( context.Context,  arrow.Array,  arrow.DataType) (arrow.Array, error) {
	return CastArray(, , SafeCastOptions())
}

// CanCast returns true if there is an implementation for casting an array
// or scalar value from the specified DataType to the other data type.
func (,  arrow.DataType) bool {
	,  := getCastFunction()
	if  != nil {
		return false
	}

	for ,  := range .inIDs {
		if .ID() ==  {
			return true
		}
	}
	return false
}