// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build go1.18

package kernels

import (
	

	
	
	
	
	
	
	
)

type HashState interface {
	// Reset for another run
	Reset() error
	// Flush out accumulated results from last invocation
	Flush(*exec.ExecResult) error
	// FlushFinal flushes the accumulated results across all invocations
	// of calls. The kernel should not be used again until after
	// Reset() is called.
	FlushFinal(out *exec.ExecResult) error
	// GetDictionary returns the values (keys) accumulated in the dictionary
	// so far.
	GetDictionary() (arrow.ArrayData, error)
	ValueType() arrow.DataType
	// Append prepares the action for the given input (reserving appropriately
	// sized data structures, etc.) and visits the input with the Action
	Append(*exec.KernelCtx, *exec.ArraySpan) error
	Allocator() memory.Allocator
}

// Action receives the outcome of every lookup a HashState performs and
// decides what output, if any, to produce from it. The int passed to the
// Observe* methods is the index of the value among the distinct values seen.
type Action interface {
	// Reset prepares the action for a new run.
	Reset() error
	// Reserve is called with the input length before an input is visited.
	Reserve(n int) error

	// Flush emits the output accumulated since the previous flush.
	Flush(out *exec.ExecResult) error
	// FlushFinal emits the output accumulated across all inputs.
	FlushFinal(out *exec.ExecResult) error

	// ObserveFound reports a non-null value that was already known.
	ObserveFound(idx int)
	// ObserveNotFound reports a non-null value that was just inserted.
	ObserveNotFound(idx int) error
	// ObserveNullFound reports a null that was already known.
	ObserveNullFound(idx int)
	// ObserveNullNotFound reports a null that was just inserted or, with
	// idx == -1, a null that is not being encoded (see ShouldEncodeNulls).
	ObserveNullNotFound(idx int) error

	// ShouldEncodeNulls reports whether nulls are inserted into the memo
	// table like ordinary values.
	ShouldEncodeNulls() bool
}

type emptyAction struct {
	mem memory.Allocator
	dt  arrow.DataType
}

func (emptyAction) () error                      { return nil }
func (emptyAction) (int) error                 { return nil }
func (emptyAction) (*exec.ExecResult) error      { return nil }
func (emptyAction) (*exec.ExecResult) error { return nil }
func (emptyAction) (int)                  {}
func (emptyAction) (int) error         { return nil }
func (emptyAction) (int)              {}
func (emptyAction) (int) error     { return nil }
func (emptyAction) () bool           { return true }

type uniqueAction = emptyAction

type regularHashState struct {
	mem       memory.Allocator
	typ       arrow.DataType
	memoTable hashing.MemoTable
	action    Action

	doAppend func(Action, hashing.MemoTable, *exec.ArraySpan) error
}

func ( *regularHashState) () memory.Allocator { return .mem }

func ( *regularHashState) () arrow.DataType { return .typ }

func ( *regularHashState) () error {
	.memoTable.Reset()
	return .action.Reset()
}

func ( *regularHashState) ( *exec.KernelCtx,  *exec.ArraySpan) error {
	if  := .action.Reserve(int(.Len));  != nil {
		return 
	}

	return .doAppend(.action, .memoTable, )
}

func ( *regularHashState) ( *exec.ExecResult) error { return .action.Flush() }
func ( *regularHashState) ( *exec.ExecResult) error {
	return .action.FlushFinal()
}

func ( *regularHashState) () (arrow.ArrayData, error) {
	return array.GetDictArrayData(.mem, .typ, .memoTable, 0)
}

func doAppendBinary[ int32 | int64]( Action,  hashing.MemoTable,  *exec.ArraySpan) error {
	var (
		            = .Buffers[0].Buf
		           = exec.GetSpanOffsets[](, 1)
		              = .Buffers[2].Buf
		 = .ShouldEncodeNulls()
	)

	return bitutils.VisitBitBlocksShort(, .Offset, .Len,
		func( int64) error {
			 := [[]:[+1]]
			, ,  := .GetOrInsert()
			if  != nil {
				return 
			}
			if  {
				.ObserveFound()
				return nil
			}
			return .ObserveNotFound()
		},
		func() error {
			if ! {
				return .ObserveNullNotFound(-1)
			}

			,  := .GetOrInsertNull()
			if  {
				.ObserveNullFound()
			}
			return .ObserveNullNotFound()
		})
}

func doAppendFixedSize( Action,  hashing.MemoTable,  *exec.ArraySpan) error {
	 := int64(.Type.(arrow.FixedWidthDataType).Bytes())
	 := .Buffers[1].Buf[.Offset*:]
	 := .ShouldEncodeNulls()

	return bitutils.VisitBitBlocksShort(.Buffers[0].Buf, .Offset, .Len,
		func( int64) error {
			// fixed size type memo table we use a binary memo table
			// so get the raw bytes
			, ,  := .GetOrInsert([* : (+1)*])
			if  != nil {
				return 
			}
			if  {
				.ObserveFound()
				return nil
			}
			return .ObserveNotFound()
		}, func() error {
			if ! {
				return .ObserveNullNotFound(-1)
			}

			,  := .GetOrInsertNull()
			if  {
				.ObserveNullFound()
			}
			return .ObserveNullNotFound()
		})
}

func doAppendNumeric[ arrow.IntType | arrow.UintType | arrow.FloatType]( Action,  hashing.MemoTable,  *exec.ArraySpan) error {
	 := exec.GetSpanValues[](, 1)
	 := .ShouldEncodeNulls()
	return bitutils.VisitBitBlocksShort(.Buffers[0].Buf, .Offset, .Len,
		func( int64) error {
			, ,  := .GetOrInsert([])
			if  != nil {
				return 
			}
			if  {
				.ObserveFound()
				return nil
			}
			return .ObserveNotFound()
		}, func() error {
			if ! {
				return .ObserveNullNotFound(-1)
			}

			,  := .GetOrInsertNull()
			if  {
				.ObserveNullFound()
			}
			return .ObserveNullNotFound()
		})
}

type nullHashState struct {
	mem      memory.Allocator
	typ      arrow.DataType
	seenNull bool
	action   Action
}

func ( *nullHashState) () memory.Allocator { return .mem }

func ( *nullHashState) () arrow.DataType { return .typ }

func ( *nullHashState) () error {
	return .action.Reset()
}

func ( *nullHashState) ( *exec.KernelCtx,  *exec.ArraySpan) ( error) {
	if  := .action.Reserve(int(.Len));  != nil {
		return 
	}

	for  := 0;  < int(.Len); ++ {
		if  == 0 {
			.seenNull = true
			 = .action.ObserveNullNotFound(0)
		} else {
			.action.ObserveNullFound(0)
		}
	}
	return
}

func ( *nullHashState) ( *exec.ExecResult) error { return .action.Flush() }
func ( *nullHashState) ( *exec.ExecResult) error {
	return .action.FlushFinal()
}

func ( *nullHashState) () (arrow.ArrayData, error) {
	var  arrow.Array
	if .seenNull {
		 = array.NewNull(1)
	} else {
		 = array.NewNull(0)
	}
	 := .Data()
	.Retain()
	.Release()
	return , nil
}

type dictionaryHashState struct {
	indicesKernel HashState
	dictionary    arrow.Array
	dictValueType arrow.DataType
}

func ( *dictionaryHashState) () memory.Allocator { return .indicesKernel.Allocator() }
func ( *dictionaryHashState) () error                { return .indicesKernel.Reset() }
func ( *dictionaryHashState) ( *exec.ExecResult) error {
	return .indicesKernel.Flush()
}
func ( *dictionaryHashState) ( *exec.ExecResult) error {
	return .indicesKernel.FlushFinal()
}
func ( *dictionaryHashState) () (arrow.ArrayData, error) {
	return .indicesKernel.GetDictionary()
}
func ( *dictionaryHashState) () arrow.DataType           { return .indicesKernel.ValueType() }
func ( *dictionaryHashState) () arrow.DataType { return .dictValueType }
func ( *dictionaryHashState) () arrow.Array             { return .dictionary }
func ( *dictionaryHashState) ( *exec.KernelCtx,  *exec.ArraySpan) error {
	 := .Dictionary().MakeArray()
	if .dictionary == nil || array.Equal(.dictionary, ) {
		.dictionary = 
		return .indicesKernel.Append(, )
	}

	defer .Release()

	// NOTE: this approach computes a new dictionary unification per chunk
	// this is in effect O(n*k) where n is the total chunked array length
	// and k is the number of chunks (therefore O(n**2) if chunks have a fixed size).
	//
	// A better approach may be to run the kernel over each individual chunk,
	// and then hash-aggregate all results (for example sum-group-by for
	// the "value_counts" kernel)
	,  := array.NewDictionaryUnifier(.indicesKernel.Allocator(), .dictValueType)
	if  != nil {
		return 
	}
	defer .Release()

	if  := .Unify(.dictionary);  != nil {
		return 
	}
	,  := .UnifyAndTranspose()
	if  != nil {
		return 
	}
	defer .Release()
	, ,  := .GetResult()
	if  != nil {
		return 
	}
	defer func() {
		.dictionary.Release()
		.dictionary = 
	}()

	 := .MakeData()
	defer .Release()
	,  := array.TransposeDictIndices(.Allocator(), , .Type, .Type, .Data(), arrow.Int32Traits.CastFromBytes(.Bytes()))
	if  != nil {
		return 
	}
	defer .Release()

	var  exec.ArraySpan
	.SetMembers()
	return .indicesKernel.Append(, &)
}

func nullHashInit( initAction) exec.KernelInitFn {
	return func( *exec.KernelCtx,  exec.KernelInitArgs) (exec.KernelState, error) {
		 := exec.GetAllocator(.Ctx)
		 := &nullHashState{
			mem:    ,
			typ:    .Inputs[0],
			action: (.Inputs[0], .Options, ),
		}
		.Reset()
		return , nil
	}
}

func newMemoTable( memory.Allocator,  arrow.Type) (hashing.MemoTable, error) {
	switch  {
	case arrow.INT8, arrow.UINT8:
		return hashing.NewMemoTable[uint8](0), nil
	case arrow.INT16, arrow.UINT16:
		return hashing.NewMemoTable[uint16](0), nil
	case arrow.INT32, arrow.UINT32, arrow.FLOAT32, arrow.DECIMAL32,
		arrow.DATE32, arrow.TIME32, arrow.INTERVAL_MONTHS:
		return hashing.NewMemoTable[uint32](0), nil
	case arrow.INT64, arrow.UINT64, arrow.FLOAT64, arrow.DECIMAL64,
		arrow.DATE64, arrow.TIME64, arrow.TIMESTAMP,
		arrow.DURATION, arrow.INTERVAL_DAY_TIME:
		return hashing.NewMemoTable[uint64](0), nil
	case arrow.BINARY, arrow.STRING, arrow.FIXED_SIZE_BINARY, arrow.DECIMAL128,
		arrow.DECIMAL256, arrow.INTERVAL_MONTH_DAY_NANO:
		return hashing.NewBinaryMemoTable(0, 0,
			array.NewBinaryBuilder(, arrow.BinaryTypes.Binary)), nil
	case arrow.LARGE_BINARY, arrow.LARGE_STRING:
		return hashing.NewBinaryMemoTable(0, 0,
			array.NewBinaryBuilder(, arrow.BinaryTypes.LargeBinary)), nil
	default:
		return nil, fmt.Errorf("%w: unsupported type %s", arrow.ErrNotImplemented, )
	}
}

func regularHashInit( arrow.DataType,  initAction,  func(Action, hashing.MemoTable, *exec.ArraySpan) error) exec.KernelInitFn {
	return func( *exec.KernelCtx,  exec.KernelInitArgs) (exec.KernelState, error) {
		 := exec.GetAllocator(.Ctx)
		,  := newMemoTable(, .ID())
		if  != nil {
			return nil, 
		}

		 := &regularHashState{
			mem:       ,
			typ:       .Inputs[0],
			memoTable: ,
			action:    (.Inputs[0], .Options, ),
			doAppend:  ,
		}
		.Reset()
		return , nil
	}
}

func dictionaryHashInit( initAction) exec.KernelInitFn {
	return func( *exec.KernelCtx,  exec.KernelInitArgs) (exec.KernelState, error) {
		var (
			      = .Inputs[0].(*arrow.DictionaryType)
			 exec.KernelState
			           error
		)

		switch .IndexType.ID() {
		case arrow.INT8, arrow.UINT8:
			,  = getHashInit(arrow.UINT8, )(, )
		case arrow.INT16, arrow.UINT16:
			,  = getHashInit(arrow.UINT16, )(, )
		case arrow.INT32, arrow.UINT32:
			,  = getHashInit(arrow.UINT32, )(, )
		case arrow.INT64, arrow.UINT64:
			,  = getHashInit(arrow.UINT64, )(, )
		default:
			return nil, fmt.Errorf("%w: unsupported dictionary index type", arrow.ErrInvalid)
		}
		if  != nil {
			return nil, 
		}

		return &dictionaryHashState{
			indicesKernel: .(HashState),
			dictValueType: .ValueType,
		}, nil
	}
}

// initAction constructs the Action a hash kernel reports to. It receives the
// kernel input type, the kernel options, and the allocator to use.
type initAction func(dt arrow.DataType, options any, mem memory.Allocator) Action

func getHashInit( arrow.Type,  initAction) exec.KernelInitFn {
	switch  {
	case arrow.NULL:
		return nullHashInit()
	case arrow.INT8, arrow.UINT8:
		return regularHashInit(arrow.PrimitiveTypes.Uint8, , doAppendNumeric[uint8])
	case arrow.INT16, arrow.UINT16:
		return regularHashInit(arrow.PrimitiveTypes.Uint16, , doAppendNumeric[uint16])
	case arrow.INT32, arrow.UINT32, arrow.FLOAT32,
		arrow.DATE32, arrow.TIME32, arrow.INTERVAL_MONTHS:
		return regularHashInit(arrow.PrimitiveTypes.Uint32, , doAppendNumeric[uint32])
	case arrow.INT64, arrow.UINT64, arrow.FLOAT64,
		arrow.DATE64, arrow.TIME64, arrow.TIMESTAMP,
		arrow.DURATION, arrow.INTERVAL_DAY_TIME:
		return regularHashInit(arrow.PrimitiveTypes.Uint64, , doAppendNumeric[uint64])
	case arrow.BINARY, arrow.STRING:
		return regularHashInit(arrow.BinaryTypes.Binary, , doAppendBinary[int32])
	case arrow.LARGE_BINARY, arrow.LARGE_STRING:
		return regularHashInit(arrow.BinaryTypes.LargeBinary, , doAppendBinary[int64])
	case arrow.FIXED_SIZE_BINARY, arrow.DECIMAL128, arrow.DECIMAL256:
		return regularHashInit(arrow.BinaryTypes.Binary, , doAppendFixedSize)
	case arrow.INTERVAL_MONTH_DAY_NANO:
		return regularHashInit(arrow.FixedWidthTypes.MonthDayNanoInterval, , doAppendFixedSize)
	default:
		debug.Assert(false, "unsupported hash init type")
		return nil
	}
}

func hashExec( *exec.KernelCtx,  *exec.ExecSpan,  *exec.ExecResult) error {
	,  := .State.(HashState)
	if ! {
		return fmt.Errorf("%w: bad initialization of hash state", arrow.ErrInvalid)
	}

	if  := .Append(, &.Values[0].Array);  != nil {
		return 
	}

	return .Flush()
}

func uniqueFinalize( *exec.KernelCtx,  []*exec.ArraySpan) ([]*exec.ArraySpan, error) {
	,  := .State.(HashState)
	if ! {
		return nil, fmt.Errorf("%w: HashState in invalid state", arrow.ErrInvalid)
	}

	for ,  := range  {
		// release any pre-allocation we did
		.Release()
	}

	,  := .GetDictionary()
	if  != nil {
		return nil, 
	}
	defer .Release()

	var  exec.ArraySpan
	.TakeOwnership()
	return []*exec.ArraySpan{&}, nil
}

func ensureHashDictionary( *exec.KernelCtx,  *dictionaryHashState) (*exec.ArraySpan, error) {
	 := &exec.ArraySpan{}

	if .dictionary != nil {
		.TakeOwnership(.dictionary.Data())
		.dictionary.Release()
		return , nil
	}

	exec.FillZeroLength(.DictionaryValueType(), )
	return , nil
}

func uniqueFinalizeDictionary( *exec.KernelCtx,  []*exec.ArraySpan) ( []*exec.ArraySpan,  error) {
	if ,  = uniqueFinalize(, );  != nil {
		return
	}

	,  := .State.(*dictionaryHashState)
	if ! {
		return nil, fmt.Errorf("%w: state should be *dictionaryHashState", arrow.ErrInvalid)
	}

	,  := ensureHashDictionary(, )
	if  != nil {
		return nil, 
	}
	[0].SetDictionary()
	return
}

func addHashKernels( exec.VectorKernel,  initAction,  exec.OutputType) []exec.VectorKernel {
	 := make([]exec.VectorKernel, 0)
	for ,  := range primitiveTypes {
		.Init = getHashInit(.ID(), )
		.Signature = &exec.KernelSignature{
			InputTypes: []exec.InputType{exec.NewExactInput()},
			OutType:    ,
		}
		 = append(, )
	}

	 := []arrow.Type{arrow.TIME32, arrow.TIME64, arrow.TIMESTAMP,
		arrow.DURATION, arrow.FIXED_SIZE_BINARY, arrow.DECIMAL128, arrow.DECIMAL256,
		arrow.INTERVAL_DAY_TIME, arrow.INTERVAL_MONTHS, arrow.INTERVAL_MONTH_DAY_NANO}
	for ,  := range  {
		.Init = getHashInit(, )
		.Signature = &exec.KernelSignature{
			InputTypes: []exec.InputType{exec.NewIDInput()},
			OutType:    ,
		}
		 = append(, )
	}

	return 
}

func initUnique( arrow.DataType,  any,  memory.Allocator) Action {
	return uniqueAction{mem: , dt: }
}

// This function assembles the vector hash kernels; its name is not visible
// in this view. It declares three []exec.VectorKernel results, and the
// identifiers assigned below are also not visible. NOTE(review): presumably
// only the first result ("unique", per the comment at its first assignment)
// is built here and the others are returned nil — TODO confirm against the
// full source.
func () (, ,  []exec.VectorKernel) {
	var  exec.VectorKernel
	.ExecFn = hashExec

	// unique: one kernel per supported primitive/parametric input type, all
	// sharing hashExec and finalized by uniqueFinalize. They may execute
	// chunk-wise and produce non-chunked output.
	.Finalize = uniqueFinalize
	.OutputChunked = false
	.CanExecuteChunkWise = true
	 = addHashKernels(, initUnique, OutputFirstType)

	// dictionary unique: a single kernel matching any dictionary input by
	// type ID. The indices are hashed, and uniqueFinalizeDictionary attaches
	// the accumulated dictionary to the result.
	.Init = dictionaryHashInit(initUnique)
	.Finalize = uniqueFinalizeDictionary
	.Signature = &exec.KernelSignature{
		InputTypes: []exec.InputType{exec.NewIDInput(arrow.DICTIONARY)},
		OutType:    OutputFirstType,
	}
	 = append(, )

	return
}