// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build go1.18

package kernels

import (
	
	
	

	
	
	
	
	
	
	
	
	
)

// binaryKernel is the common signature of the byte-level comparison kernels
// built in this file: two input byte slices, an output byte slice and an int
// offset. The kernels constructed below take the offset modulo 8 and emit
// their results as individual bits (bitutil.SetBitTo / packBits).
type binaryKernel func(left, right, out []byte, offset int)

// cmpFn is an element-wise comparison over two typed slices whose results are
// stored in a []uint32. The implementations in this file store 1 when the
// comparison holds and 0 otherwise.
type cmpFn[,  arrow.FixedWidthType] func([], [], []uint32)

// cmpScalarLeft is the variant of cmpFn whose first operand is a single value
// and whose second operand is a slice.
type cmpScalarLeft[,  arrow.FixedWidthType] func(, [], []uint32)

// cmpScalarRight is the variant of cmpFn whose first operand is a slice and
// whose second operand is a single value.
type cmpScalarRight[,  arrow.FixedWidthType] func([], , []uint32)

// cmpOp bundles the three shapes of one comparison operation: slice against
// slice (arrArr), slice against a single value (arrScalar) and a single value
// against slice (scalarArr).
type cmpOp[ arrow.FixedWidthType] struct {
	arrArr    cmpFn[, ]
	arrScalar cmpScalarRight[, ]
	scalarArr cmpScalarLeft[, ]
}

// comparePrimitiveArrayArray adapts an element-wise cmpFn into a binaryKernel.
//
// The returned kernel reinterprets byte slices as typed slices through
// arrow.GetData and proceeds in three phases:
//  1. when the offset is not a multiple of 8, a short prefix of
//     (8 - offset%8) values is compared and written one bit at a time with
//     bitutil.SetBitTo, after which the output advances by one byte;
//  2. whole batches are compared into a []uint32 scratch area and handed to
//     packBits, the output advancing by (something)/8 each iteration —
//     presumably the batch size of 32, i.e. 4 bytes; TODO confirm;
//  3. the remaining values are compared and written bit by bit.
//
// NOTE(review): no length check is visible before the prefix slicing, so an
// unaligned offset combined with fewer than (8 - offset%8) values would
// presumably panic on the slice expression — confirm against callers.
func comparePrimitiveArrayArray[ arrow.FixedWidthType]( cmpFn[, ]) binaryKernel {
	return func(, ,  []byte,  int) {
		// Batch size; the only constant declared in this kernel.
		const  = 32
		var (
			      = arrow.GetData[]()
			     = arrow.GetData[]()
			     = len()
			  =  / 
			 []uint32
		)

		 := [:]
		// Phase 1: unaligned leading bits, written individually.
		if  :=  % 8;  != 0 {
			 := 8 - 
			([:], [:], [:])
			,  = [:], [:]

			for ,  := range [:] {
				bitutil.SetBitTo(, +,  != 0)
			}
			 = [1:]
		}

		// Phase 2: full batches, packed via packBits.
		for  := 0;  < ; ++ {
			(, , )
			,  = [:], [:]
			packBits(, )
			 = [/8:]
		}

		// Phase 3: trailing values that do not fill a whole batch.
		 :=  - ( * )
		(, , [:])
		for ,  := range [:] {
			bitutil.SetBitTo(, ,  != 0)
		}
	}
}

// comparePrimitiveArrayScalar adapts a cmpScalarRight (slice compared against
// a single value) into a binaryKernel.
//
// One input is viewed as a typed slice through arrow.GetData; the single value
// is read by reinterpreting the address of element 0 of a slice via
// unsafe.Pointer, so that slice must be non-empty or the index expression
// panics. The kernel then follows the same three phases as its array/array
// counterpart: an unaligned bit prefix when offset%8 != 0, full batches packed
// with packBits, and a bit-by-bit tail.
//
// NOTE(review): as in the array/array kernel, no length check precedes the
// prefix slicing — confirm callers guarantee enough values.
func comparePrimitiveArrayScalar[ arrow.FixedWidthType]( cmpScalarRight[, ]) binaryKernel {
	return func(, ,  []byte,  int) {
		// Batch size; the only constant declared in this kernel.
		const  = 32
		var (
			      = arrow.GetData[]()
			  = *(*)(unsafe.Pointer(&[0]))
			     = len()
			  =  / 
			 []uint32
		)

		 := [:]
		// Unaligned leading bits, written individually.
		if  :=  % 8;  != 0 {
			 := 8 - 
			([:], , [:])
			 = [:]

			for ,  := range [:] {
				bitutil.SetBitTo(, +,  != 0)
			}
			 = [1:]
		}

		// Full batches, packed via packBits.
		for  := 0;  < ; ++ {
			(, , )
			 = [:]
			packBits(, )
			 = [/8:]
		}

		// Trailing values that do not fill a whole batch.
		 :=  - ( * )
		(, , [:])
		for ,  := range [:] {
			bitutil.SetBitTo(, ,  != 0)
		}
	}
}

// comparePrimitiveScalarArray adapts a cmpScalarLeft (a single value compared
// against a slice) into a binaryKernel.
//
// The single value is read by reinterpreting the address of element 0 of a
// slice via unsafe.Pointer (that slice must therefore be non-empty), and the
// other input is viewed as a typed slice through arrow.GetData. Processing
// mirrors the other two kernels: unaligned bit prefix when offset%8 != 0,
// full batches packed with packBits, then a bit-by-bit tail.
//
// NOTE(review): the batch loop passes a sliced operand to the comparison
// where the array/scalar kernel passes the slice unsliced; with identifiers
// missing it is unclear whether this is an intentional bound — confirm.
func comparePrimitiveScalarArray[ arrow.FixedWidthType]( cmpScalarLeft[, ]) binaryKernel {
	return func(, ,  []byte,  int) {
		// Batch size; the only constant declared in this kernel.
		const  = 32
		var (
			 = *(*)(unsafe.Pointer(&[0]))
			   = arrow.GetData[]()

			     = len()
			  =  / 
			 []uint32
		)

		 := [:]
		// Unaligned leading bits, written individually.
		if  :=  % 8;  != 0 {
			 := 8 - 
			(, [:], [:])
			 = [:]

			for ,  := range [:] {
				bitutil.SetBitTo(, +,  != 0)
			}
			 = [1:]
		}

		// Full batches, packed via packBits.
		for  := 0;  < ; ++ {
			(, [:], )
			 = [:]
			packBits(, )
			 = [/8:]
		}

		// Trailing values that do not fill a whole batch.
		 :=  - ( * )
		(, , [:])
		for ,  := range [:] {
			bitutil.SetBitTo(, ,  != 0)
		}
	}
}

// CompareData holds the three binaryKernel variants of one comparison:
// funcAA, funcSA and funcAS. compareKernel calls funcAA when both inputs are
// arrays, funcAS when the second input is a scalar and funcSA otherwise.
type CompareData struct {
	funcAA, funcSA, funcAS binaryKernel
}

// This method returns a *CompareData; presumably it returns the receiver
// itself so that *CompareData satisfies CompareFuncData — TODO confirm, the
// method name and returned expression are missing from this copy.
func ( *CompareData) () *CompareData { return  }

// CompareFuncData is implemented by kernel data that can supply a
// *CompareData. compareKernel and isNanKernelExec type-assert the kernel's
// Data field to this interface and call Funcs to obtain the kernels to run.
type CompareFuncData interface {
	// Funcs returns the set of comparison kernels.
	Funcs() *CompareData
}

// getOffsetSpanBytes returns a sub-slice of an ArraySpan's Buffers[1].Buf, or
// nil when that buffer is empty.
//
// The bounds are derived from the span's Offset and Len and from the byte
// width reported by its type — presumably [Offset*width, (Offset+Len)*width),
// TODO confirm, the operands are missing from this copy.
//
// The span's Type is asserted to arrow.FixedWidthDataType without the
// comma-ok form, so a type that is not fixed-width panics here.
func getOffsetSpanBytes( *exec.ArraySpan) []byte {
	if len(.Buffers[1].Buf) == 0 {
		return nil
	}

	 := .Buffers[1].Buf
	 := int64(.Type.(arrow.FixedWidthDataType).Bytes())
	 := .Offset * 
	return [ : +(.Len*)]
}

// compareKernel is the exec function shared by the fixed-width comparison
// kernels. It asserts the context's Kernel to *exec.ScalarKernel, obtains the
// CompareData through the kernel's Data (asserted to CompareFuncData) and
// dispatches on the shape of the two input values:
//   - both arrays            -> funcAA
//   - second value is scalar -> funcAS
//   - otherwise              -> funcSA (first value read as a scalar)
//
// The output is addressed as a byte slice starting at Offset/8 of
// Buffers[1].Buf together with the bit remainder Offset%8.
//
// It always returns nil. The type assertions are unchecked, so kernel data
// that is not a CompareFuncData, or a scalar that is not a
// scalar.PrimitiveScalar, panics.
func compareKernel[ arrow.FixedWidthType]( *exec.KernelCtx,  *exec.ExecSpan,  *exec.ExecResult) error {
	 := .Kernel.(*exec.ScalarKernel)
	 := .Data.(CompareFuncData).Funcs()

	// Bit position within the first output byte, and the bytes from there on.
	 := int(.Offset % 8)
	 := .Buffers[1].Buf[.Offset/8:]

	if .Values[0].IsArray() && .Values[1].IsArray() {
		.funcAA(getOffsetSpanBytes(&.Values[0].Array),
			getOffsetSpanBytes(&.Values[1].Array), , )
	} else if .Values[1].IsScalar() {
		.funcAS(getOffsetSpanBytes(&.Values[0].Array),
			.Values[1].Scalar.(scalar.PrimitiveScalar).Data(), , )
	} else {
		.funcSA(.Values[0].Scalar.(scalar.PrimitiveScalar).Data(),
			getOffsetSpanBytes(&.Values[1].Array), , )
	}

	return nil
}

// genGoCompareKernel builds a CompareData from a cmpOp by wrapping each of its
// three comparison functions in the matching byte-level kernel adapter
// (array/array, array/scalar and scalar/array).
func genGoCompareKernel[ arrow.FixedWidthType]( *cmpOp[]) *CompareData {
	return &CompareData{
		funcAA: comparePrimitiveArrayArray(.arrArr),
		funcAS: comparePrimitiveArrayScalar(.arrScalar),
		funcSA: comparePrimitiveScalarArray(.scalarArr),
	}
}

// decCmp carries the two ordering predicates used for decimal values, which
// getCmpDec calls for CmpGT (Gt) and CmpGE (Ge). It is instantiated for
// decimal128.Num and decimal256.Num.
type decCmp[ decimal128.Num | decimal256.Num] struct {
	Gt func(, ) bool
	Ge func(, ) bool
}

// dec128Cmp is the decCmp for decimal128.Num; its predicates delegate to the
// Greater and GreaterEqual methods.
var dec128Cmp = decCmp[decimal128.Num]{
	Gt: func(,  decimal128.Num) bool { return .Greater() },
	Ge: func(,  decimal128.Num) bool { return .GreaterEqual() },
}

// dec256Cmp is the decCmp for decimal256.Num; its predicates delegate to the
// Greater and GreaterEqual methods.
var dec256Cmp = decCmp[decimal256.Num]{
	Gt: func(,  decimal256.Num) bool { return .Greater() },
	Ge: func(,  decimal256.Num) bool { return .GreaterEqual() },
}

// getCmpDec returns the cmpOp for a decimal comparison.
//
// CmpEQ and CmpNE use the language's == and != on the decimal values; CmpGT
// and CmpGE call the Gt and Ge predicates of the supplied decCmp. Every
// function stores 1 in the result slice where the comparison holds and 0
// where it does not.
//
// Only these four operators have cases. Any other operator reaches
// debug.Assert(false, "") and the function returns nil.
func getCmpDec[ decimal128.Num | decimal256.Num]( CompareOperator,  decCmp[]) *cmpOp[] {
	switch  {
	// Equality via ==.
	case CmpEQ:
		return &cmpOp[]{
			arrArr: func(,  [],  []uint32) {
				for  := range  {
					if [] == [] {
						[] = 1
					} else {
						[] = 0
					}
				}
			},
			arrScalar: func( [],  ,  []uint32) {
				for  := range  {
					if [] ==  {
						[] = 1
					} else {
						[] = 0
					}
				}
			},
			scalarArr: func( ,  [],  []uint32) {
				for  := range  {
					if  == [] {
						[] = 1
					} else {
						[] = 0
					}
				}
			},
		}
	// Inequality via !=.
	case CmpNE:
		return &cmpOp[]{
			arrArr: func(,  [],  []uint32) {
				for  := range  {
					if [] != [] {
						[] = 1
					} else {
						[] = 0
					}
				}
			},
			arrScalar: func( [],  ,  []uint32) {
				for  := range  {
					if [] !=  {
						[] = 1
					} else {
						[] = 0
					}
				}
			},
			scalarArr: func( ,  [],  []uint32) {
				for  := range  {
					if  != [] {
						[] = 1
					} else {
						[] = 0
					}
				}
			},
		}
	// Greater-than via the decCmp's Gt predicate.
	case CmpGT:
		return &cmpOp[]{
			arrArr: func(,  [],  []uint32) {
				for  := range  {
					if .Gt([], []) {
						[] = 1
					} else {
						[] = 0
					}
				}
			},
			arrScalar: func( [],  ,  []uint32) {
				for  := range  {
					if .Gt([], ) {
						[] = 1
					} else {
						[] = 0
					}
				}
			},
			scalarArr: func( ,  [],  []uint32) {
				for  := range  {
					if .Gt(, []) {
						[] = 1
					} else {
						[] = 0
					}
				}
			},
		}
	// Greater-or-equal via the decCmp's Ge predicate.
	case CmpGE:
		return &cmpOp[]{
			arrArr: func(,  [],  []uint32) {
				for  := range  {
					if .Ge([], []) {
						[] = 1
					} else {
						[] = 0
					}
				}
			},
			arrScalar: func( [],  ,  []uint32) {
				for  := range  {
					if .Ge([], ) {
						[] = 1
					} else {
						[] = 0
					}
				}
			},
			scalarArr: func( ,  [],  []uint32) {
				for  := range  {
					if .Ge(, []) {
						[] = 1
					} else {
						[] = 0
					}
				}
			},
		}
	}
	// Unhandled operator: assert and return nil.
	debug.Assert(false, "")
	return nil
}

// genDecimalCompareKernel returns the exec function and kernel state for a
// decimal comparison, using named results and a bare return.
//
// The exec function is an instantiation of compareKernel. The state is a
// CompareData assembled from getCmpDec with dec128Cmp or dec256Cmp, selected
// by a type switch over a value boxed into any (presumably a zero value of
// the type parameter — TODO confirm).
//
// NOTE(review): getCmpDec returns nil for operators it does not handle, and
// the results here are dereferenced without a nil check — confirm callers
// only pass CmpEQ/CmpNE/CmpGT/CmpGE.
func genDecimalCompareKernel[ decimal128.Num | decimal256.Num]( CompareOperator) ( exec.ArrayKernelExec,  exec.KernelState) {
	 = compareKernel[]

	var  
	switch any().(type) {
	case decimal128.Num:
		 := getCmpDec(, dec128Cmp)
		 = &CompareData{
			funcAA: comparePrimitiveArrayArray(.arrArr),
			funcAS: comparePrimitiveArrayScalar(.arrScalar),
			funcSA: comparePrimitiveScalarArray(.scalarArr),
		}
	case decimal256.Num:
		 := getCmpDec(, dec256Cmp)
		 = &CompareData{
			funcAA: comparePrimitiveArrayArray(.arrArr),
			funcAS: comparePrimitiveArrayScalar(.arrScalar),
			funcSA: comparePrimitiveScalarArray(.scalarArr),
		}
	}

	return
}

// getCmpOp returns the cmpOp for a comparison over a numeric type, using the
// language's native operators: == for CmpEQ, != for CmpNE, > for CmpGT and
// >= for CmpGE. Every function stores 1 in the result slice where the
// comparison holds and 0 where it does not.
//
// Any other operator yields nil. Unlike getCmpDec, no debug.Assert is issued
// on that path.
func getCmpOp[ arrow.NumericType]( CompareOperator) *cmpOp[] {
	switch  {
	// Equality.
	case CmpEQ:
		return &cmpOp[]{
			arrArr: func(,  [],  []uint32) {
				for  := range  {
					if [] == [] {
						[] = 1
					} else {
						[] = 0
					}
				}
			},
			arrScalar: func( [],  ,  []uint32) {
				for  := range  {
					if [] ==  {
						[] = 1
					} else {
						[] = 0
					}
				}
			},
			scalarArr: func( ,  [],  []uint32) {
				for  := range  {
					if  == [] {
						[] = 1
					} else {
						[] = 0
					}
				}
			},
		}
	// Inequality.
	case CmpNE:
		return &cmpOp[]{
			arrArr: func(,  [],  []uint32) {
				for  := range  {
					if [] != [] {
						[] = 1
					} else {
						[] = 0
					}
				}
			},
			arrScalar: func( [],  ,  []uint32) {
				for  := range  {
					if [] !=  {
						[] = 1
					} else {
						[] = 0
					}
				}
			},
			scalarArr: func( ,  [],  []uint32) {
				for  := range  {
					if  != [] {
						[] = 1
					} else {
						[] = 0
					}
				}
			},
		}
	// Greater-than.
	case CmpGT:
		return &cmpOp[]{
			arrArr: func(,  [],  []uint32) {
				for  := range  {
					if [] > [] {
						[] = 1
					} else {
						[] = 0
					}
				}
			},
			arrScalar: func( [],  ,  []uint32) {
				for  := range  {
					if [] >  {
						[] = 1
					} else {
						[] = 0
					}
				}
			},
			scalarArr: func( ,  [],  []uint32) {
				for  := range  {
					if  > [] {
						[] = 1
					} else {
						[] = 0
					}
				}
			},
		}
	// Greater-or-equal.
	case CmpGE:
		return &cmpOp[]{
			arrArr: func(,  [],  []uint32) {
				for  := range  {
					if [] >= [] {
						[] = 1
					} else {
						[] = 0
					}
				}
			},
			arrScalar: func( [],  ,  []uint32) {
				for  := range  {
					if [] >=  {
						[] = 1
					} else {
						[] = 0
					}
				}
			},
			scalarArr: func( ,  [],  []uint32) {
				for  := range  {
					if  >= [] {
						[] = 1
					} else {
						[] = 0
					}
				}
			},
		}
	}
	// Unhandled operator.
	return nil
}

// getBinaryCmp returns the byte-slice predicate for a comparison operator:
//   - CmpEQ: bytes.Equal
//   - CmpNE: !bytes.Equal
//   - CmpGT: bytes.Compare(...) == 1
//   - CmpGE: bytes.Compare(...) != -1
//
// bytes.Compare orders lexicographically. Any other operator yields a nil
// function. The *exec.KernelCtx argument of each predicate is accepted but
// not referenced in the visible bodies.
func getBinaryCmp( CompareOperator) binaryBinOp[bool] {
	switch  {
	case CmpEQ:
		return func( *exec.KernelCtx, ,  []byte) bool {
			return bytes.Equal(, )
		}
	case CmpNE:
		return func( *exec.KernelCtx, ,  []byte) bool {
			return !bytes.Equal(, )
		}
	case CmpGT:
		return func( *exec.KernelCtx, ,  []byte) bool {
			return bytes.Compare(, ) == 1
		}
	case CmpGE:
		return func( *exec.KernelCtx, ,  []byte) bool {
			return bytes.Compare(, ) != -1
		}
	}
	return nil
}

// numericCompareKernel builds the exec.ScalarKernel for comparing two inputs
// of a numeric type. The signature lists two input types and a Boolean output
// type, the exec function is an instantiation of compareKernel, and the
// kernel's Data is set from genCompareKernel, which is not defined in this
// section of the file. The kernel is returned through the named result.
func numericCompareKernel[ arrow.NumericType]( exec.InputType,  CompareOperator) ( exec.ScalarKernel) {
	 := compareKernel[]
	 = exec.NewScalarKernelWithSig(&exec.KernelSignature{
		InputTypes: []exec.InputType{, },
		OutType:    exec.NewOutputType(arrow.FixedWidthTypes.Boolean),
	}, , nil)
	.Data = genCompareKernel[]()
	return
}

// decimalCompareKernel builds the exec.ScalarKernel for comparing two decimal
// inputs. The exec function and kernel data both come from
// genDecimalCompareKernel; the signature lists two input types and a Boolean
// output type. The kernel is returned through the named result.
func decimalCompareKernel[ decimal128.Num | decimal256.Num]( exec.InputType,  CompareOperator) ( exec.ScalarKernel) {
	,  := genDecimalCompareKernel[]()
	 = exec.NewScalarKernelWithSig(&exec.KernelSignature{
		InputTypes: []exec.InputType{, },
		OutType:    exec.NewOutputType(arrow.FixedWidthTypes.Boolean),
	}, , nil)
	.Data = 
	return
}

// NOTE(review): the function name is missing from this declaration; other
// code in this file calls GetCompareKernel with a matching
// (exec.InputType, arrow.Type, CompareOperator) argument list — confirm.
//
// The function selects the numericCompareKernel instantiation for an Arrow
// type ID. Several IDs share a physical representation:
//   - INT32, DATE32, TIME32                         -> int32
//   - INT64, DATE64, TIMESTAMP, TIME64, DURATION    -> int64
//
// The remaining integer and floating-point IDs map to the Go type of the same
// name. Any other ID reaches debug.Assert(false, "") and a zero-valued
// exec.ScalarKernel is returned.
func ( exec.InputType,  arrow.Type,  CompareOperator) exec.ScalarKernel {
	switch  {
	case arrow.INT8:
		return numericCompareKernel[int8](, )
	case arrow.INT16:
		return numericCompareKernel[int16](, )
	case arrow.INT32, arrow.DATE32, arrow.TIME32:
		return numericCompareKernel[int32](, )
	case arrow.INT64, arrow.DATE64, arrow.TIMESTAMP, arrow.TIME64, arrow.DURATION:
		return numericCompareKernel[int64](, )
	case arrow.UINT8:
		return numericCompareKernel[uint8](, )
	case arrow.UINT16:
		return numericCompareKernel[uint16](, )
	case arrow.UINT32:
		return numericCompareKernel[uint32](, )
	case arrow.UINT64:
		return numericCompareKernel[uint64](, )
	case arrow.FLOAT32:
		return numericCompareKernel[float32](, )
	case arrow.FLOAT64:
		return numericCompareKernel[float64](, )
	}
	debug.Assert(false, "")
	return exec.ScalarKernel{}
}

// compareTimestampKernel returns the TIMESTAMP comparison kernel with its
// ExecFn wrapped in a timezone consistency check.
//
// The wrapper asserts both input value types to *arrow.TimestampType and
// returns an error wrapping arrow.ErrInvalid when exactly one of them has an
// empty TimeZone. Otherwise it delegates to the ExecFn captured before the
// wrapper was installed. The assertions are unchecked, so a non-timestamp
// input type panics.
func compareTimestampKernel( exec.InputType,  CompareOperator) exec.ScalarKernel {
	 := GetCompareKernel(, arrow.TIMESTAMP, )
	// Keep the original exec function so the wrapper can call through to it.
	 := .ExecFn
	.ExecFn = func( *exec.KernelCtx,  *exec.ExecSpan,  *exec.ExecResult) error {
		,  := .Values[0].Type().(*arrow.TimestampType), .Values[1].Type().(*arrow.TimestampType)
		// Reject mixing a zoned timestamp with a zone-less one.
		if (len(.TimeZone) == 0) != (len(.TimeZone) == 0) {
			return fmt.Errorf("%w: cannot compare timestamp with timezone to timestamp without timezone, got: %s and %s",
				arrow.ErrInvalid, , )
		}
		return (, , )
	}
	return 
}

// boolEQ and boolNE are the binaryBoolOps used for comparing boolean
// (bitmap-backed) inputs; each supplies array/array, array/scalar and
// scalar/array variants, and every variant returns a nil error.
var (
	// boolEQ: the scalar variants read the input with a bitmap reader and
	// emit Set() == scalar for each position via GenerateBitsUnrolled.
	//
	// NOTE(review): arrArr calls bitutil.BitmapAnd. A bitwise AND is set only
	// where both inputs are set, whereas equality would also be true where
	// both are clear; boolNE.arrArr uses BitmapXor, whose complement would be
	// the equality result. Confirm whether BitmapAnd is intended here.
	boolEQ = binaryBoolOps{
		arrArr: func( *exec.KernelCtx, , ,  bitutil.Bitmap) error {
			bitutil.BitmapAnd(.Data, .Data, .Offset, .Offset, .Data, .Offset, .Len)
			return nil
		},
		arrScalar: func( *exec.KernelCtx,  bitutil.Bitmap,  bool,  bitutil.Bitmap) error {
			 := bitutil.NewBitmapReader(.Data, int(.Offset), int(.Len))
			bitutils.GenerateBitsUnrolled(.Data, .Offset, .Len, func() ( bool) {
				 = .Set() == 
				.Next()
				return
			})
			return nil
		},
		scalarArr: func( *exec.KernelCtx,  bool, ,  bitutil.Bitmap) error {
			 := bitutil.NewBitmapReader(.Data, int(.Offset), int(.Len))
			bitutils.GenerateBitsUnrolled(.Data, .Offset, .Len, func() ( bool) {
				 =  == .Set()
				.Next()
				return
			})
			return nil
		},
	}
	// boolNE: arrArr computes the bitwise XOR of the two bitmaps; the scalar
	// variants emit Set() != scalar for each position.
	boolNE = binaryBoolOps{
		arrArr: func( *exec.KernelCtx, , ,  bitutil.Bitmap) error {
			bitutil.BitmapXor(.Data, .Data, .Offset, .Offset, .Data, .Offset, .Len)
			return nil
		},
		arrScalar: func( *exec.KernelCtx,  bitutil.Bitmap,  bool,  bitutil.Bitmap) error {
			 := bitutil.NewBitmapReader(.Data, int(.Offset), int(.Len))
			bitutils.GenerateBitsUnrolled(.Data, .Offset, .Len, func() ( bool) {
				 = .Set() != 
				.Next()
				return
			})
			return nil
		},
		scalarArr: func( *exec.KernelCtx,  bool, ,  bitutil.Bitmap) error {
			 := bitutil.NewBitmapReader(.Data, int(.Offset), int(.Len))
			bitutils.GenerateBitsUnrolled(.Data, .Offset, .Len, func() ( bool) {
				 =  != .Set()
				.Next()
				return
			})
			return nil
		},
	}
)

// NOTE(review): the function name is missing from this declaration.
//
// The function assembles the list of scalar kernels for one comparison
// operator, all with a Boolean output type, in this order:
//   - a boolean kernel, only for CmpEQ (boolEQ) and CmpNE (boolNE);
//   - one kernel per entry of numericTypes;
//   - Date32 and Date64;
//   - for each arrow.TimeUnitValues entry, a timestamp kernel (with the
//     timezone check from compareTimestampKernel) and a duration kernel
//     backed by the INT64 implementation;
//   - Time32 for Second/Millisecond (INT32 implementation) and Time64 for
//     Microsecond/Nanosecond (INT64 implementation);
//   - one kernel per entry of baseBinaryTypes, using getBinaryCmp;
//   - Decimal128 and Decimal256;
//   - fixed-size binary, using getBinaryCmp.
func ( CompareOperator) []exec.ScalarKernel {
	 := make([]exec.ScalarKernel, 0)

	 := exec.NewOutputType(arrow.FixedWidthTypes.Boolean)
	// Boolean inputs are only covered for equality and inequality.
	switch  {
	case CmpEQ:
		 := exec.NewExactInput(arrow.FixedWidthTypes.Boolean)
		 = append(, exec.NewScalarKernel([]exec.InputType{, }, ,
			ScalarBinaryBools(&boolEQ), nil))
	case CmpNE:
		 := exec.NewExactInput(arrow.FixedWidthTypes.Boolean)
		 = append(, exec.NewScalarKernel([]exec.InputType{, }, ,
			ScalarBinaryBools(&boolNE), nil))
	}

	for ,  := range numericTypes {
		 := exec.NewExactInput()
		 = append(, GetCompareKernel(, .ID(), ))
	}
	 = append(,
		GetCompareKernel(exec.NewExactInput(arrow.FixedWidthTypes.Date32), arrow.DATE32, ),
		GetCompareKernel(exec.NewExactInput(arrow.FixedWidthTypes.Date64), arrow.DATE64, ))

	// Timestamp and duration kernels are matched per time unit.
	for ,  := range arrow.TimeUnitValues {
		 := exec.NewMatchedInput(exec.TimestampTypeUnit())
		 = append(, compareTimestampKernel(, ))

		 = exec.NewMatchedInput(exec.DurationTypeUnit())
		 = append(, GetCompareKernel(, arrow.INT64, ))
	}

	for ,  := range []arrow.TimeUnit{arrow.Second, arrow.Millisecond} {
		 := exec.NewMatchedInput(exec.Time32TypeUnit())
		 = append(, GetCompareKernel(, arrow.INT32, ))
	}
	for ,  := range []arrow.TimeUnit{arrow.Microsecond, arrow.Nanosecond} {
		 := exec.NewMatchedInput(exec.Time64TypeUnit())
		 = append(, GetCompareKernel(, arrow.INT64, ))
	}

	for ,  := range baseBinaryTypes {
		var  exec.ArrayKernelExec
		// The byte width of layout buffer 1 selects the iterator: a width of
		// 4 uses the int32 iterator, any other width the int64 iterator.
		switch .Layout().Buffers[1].ByteWidth {
		case 4:
			 = ScalarBinaryBinaryArgsBoolOut(exec.NewVarBinaryIter[int32], getBinaryCmp())
		default:
			 = ScalarBinaryBinaryArgsBoolOut(exec.NewVarBinaryIter[int64], getBinaryCmp())
		}
		 := exec.NewExactInput()
		 = append(, exec.NewScalarKernel([]exec.InputType{, },
			, , nil))
	}

	,  := exec.NewIDInput(arrow.DECIMAL128), exec.NewIDInput(arrow.DECIMAL256)
	 = append(, decimalCompareKernel[decimal128.Num](, ),
		decimalCompareKernel[decimal256.Num](, ))

	 := exec.NewIDInput(arrow.FIXED_SIZE_BINARY)
	 = append(, exec.NewScalarKernel([]exec.InputType{, }, ,
		ScalarBinaryBinaryArgsBoolOut(exec.NewFSBIter, getBinaryCmp()), nil))

	return 
}

// isNullExec is the exec function for the null-check kernel.
//
// It first calls Release on one of its arguments (presumably the output —
// TODO confirm), then wraps a freshly allocated bitmap into Buffers[1]. When
// buffer 0 of the first input array is non-nil, that buffer is inverted into
// the new bitmap with bitutil.InvertBitmap, written starting at bit 0.
//
// When buffer 0 is nil the new bitmap is left exactly as AllocateBitmap
// returned it; presumably that is zero-filled, i.e. "no nulls" — TODO confirm.
// The function always returns nil.
func isNullExec( *exec.KernelCtx,  *exec.ExecSpan,  *exec.ExecResult) error {
	.Release()
	 := .Values[0].Array

	 := .GetBuffer(0)
	.Buffers[1].WrapBuffer(.AllocateBitmap(.Len))
	if  != nil {
		bitutil.InvertBitmap(.Bytes(), int(.Offset), int(.Len),
			.Buffers[1].Buf, 0)
	}

	return nil
}

// isNotNullExec is the exec function for the not-null-check kernel.
//
// It first calls Release on one of its arguments (presumably the output —
// TODO confirm). When buffer 0 of the first input array is nil, a new bitmap
// is allocated, wrapped into Buffers[1] and every byte is set to 0xFF.
// Otherwise Buffers[1] is pointed at a buffer through SetBuffer rather than
// copied — presumably the input's buffer 0; TODO confirm.
//
// NOTE(review): on the SetBuffer path no offset adjustment is visible in this
// function; confirm how a non-zero input offset is carried to the result.
// The function always returns nil.
func isNotNullExec( *exec.KernelCtx,  *exec.ExecSpan,  *exec.ExecResult) error {
	.Release()
	 := .Values[0].Array

	 := .GetBuffer(0)
	if  == nil {
		.Buffers[1].WrapBuffer(.AllocateBitmap(.Len))
		memory.Set(.Buffers[1].Buf, 0xFF)
	} else {
		.Buffers[1].SetBuffer()
	}

	return nil
}

// NOTE(review): the function name is missing from this declaration.
//
// The function returns two scalar kernels that accept any input type
// (exec.InputAny) and produce a Boolean output: index 0 runs isNullExec and
// index 1 runs isNotNullExec. Both are configured with
// exec.NullComputedNoPrealloc and exec.MemNoPrealloc; the exec functions
// allocate or assign the output buffer themselves.
func () []exec.ScalarKernel {
	 := exec.InputType{Kind: exec.InputAny}
	 := exec.NewOutputType(arrow.FixedWidthTypes.Boolean)

	 := make([]exec.ScalarKernel, 2)
	[0] = exec.NewScalarKernel([]exec.InputType{}, , isNullExec, nil)
	[0].NullHandling = exec.NullComputedNoPrealloc
	[0].MemAlloc = exec.MemNoPrealloc

	[1] = exec.NewScalarKernel([]exec.InputType{}, , isNotNullExec, nil)
	[1].NullHandling = exec.NullComputedNoPrealloc
	[1].MemAlloc = exec.MemNoPrealloc

	return 
}

// NOTE(review): the function name is missing from this declaration; other
// code in this file calls ConstBoolExec(false) where an exec function is
// expected — confirm.
//
// The function returns an exec function that fills a run of bits in
// Buffers[1].Buf with one constant boolean using bitutil.SetBitsTo, taking
// the start and length from Offset and Len fields, and then returns nil. The
// input span is not read in the visible body.
func ( bool) func(*exec.KernelCtx, *exec.ExecSpan, *exec.ExecResult) error {
	return func( *exec.KernelCtx,  *exec.ExecSpan,  *exec.ExecResult) error {
		bitutil.SetBitsTo(.Buffers[1].Buf, .Offset, .Len, )
		return nil
	}
}

// isNanKernelExec is the exec function for the float32/float64 NaN-check
// kernels. It fetches the CompareData from the kernel's Data and calls its
// array/array kernel (funcAA) on the bytes of the single input array,
// writing to the output bitmap at Offset/8 with bit remainder Offset%8.
//
// The kernels using this function in this file set Data from
// genCompareKernel[...](CmpNE); presumably the same input bytes are passed as
// both operands so that a value is flagged when it differs from itself,
// which holds only for NaN — TODO confirm, the call's arguments are missing
// from this copy. The function always returns nil.
func isNanKernelExec[ float32 | float64]( *exec.KernelCtx,  *exec.ExecSpan,  *exec.ExecResult) error {
	 := .Kernel.(*exec.ScalarKernel)
	 := .Data.(CompareFuncData).Funcs()

	// Bit position within the first output byte, and the bytes from there on.
	 := int(.Offset % 8)
	 := .Buffers[1].Buf[.Offset/8:]

	 := getOffsetSpanBytes(&.Values[0].Array)
	.funcAA(, , , )
	return nil
}

// NOTE(review): the function name is missing from this declaration.
//
// The function returns the scalar kernels for a NaN check, all with a Boolean
// output type and NullHandling set to exec.NullNoOutput:
//   - Float32 and Float64 kernels running isNanKernelExec, with Data produced
//     by genCompareKernel for CmpNE;
//   - one kernel per entry of intTypes, running ConstBoolExec(false);
//   - one kernel per type ID in {NULL, DURATION, DECIMAL32, DECIMAL64,
//     DECIMAL128, DECIMAL256}, also running ConstBoolExec(false).
func () []exec.ScalarKernel {
	 := exec.NewOutputType(arrow.FixedWidthTypes.Boolean)

	 := exec.NewScalarKernel([]exec.InputType{exec.NewExactInput(arrow.PrimitiveTypes.Float32)},
		, isNanKernelExec[float32], nil)
	.Data = genCompareKernel[float32](CmpNE)
	.NullHandling = exec.NullNoOutput
	 := exec.NewScalarKernel([]exec.InputType{exec.NewExactInput(arrow.PrimitiveTypes.Float64)},
		, isNanKernelExec[float64], nil)
	.Data = genCompareKernel[float64](CmpNE)
	.NullHandling = exec.NullNoOutput

	 := []exec.ScalarKernel{, }

	// These kernels get a constant-false exec function.
	for ,  := range intTypes {
		 := exec.NewScalarKernel(
			[]exec.InputType{exec.NewExactInput()},
			, ConstBoolExec(false), nil)
		.NullHandling = exec.NullNoOutput
		 = append(, )
	}

	// Likewise constant-false, matched by type ID rather than exact type.
	for ,  := range []arrow.Type{arrow.NULL, arrow.DURATION, arrow.DECIMAL32, arrow.DECIMAL64, arrow.DECIMAL128, arrow.DECIMAL256} {
		 := exec.NewScalarKernel(
			[]exec.InputType{exec.NewIDInput()},
			, ConstBoolExec(false), nil)
		.NullHandling = exec.NullNoOutput
		 = append(, )
	}

	return 
}