// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build go1.18 && !noasm

package kernels

import (
	

	
	
	
	
	
)

// getAvx2ArithmeticBinaryNumeric builds the binaryOps callback set that routes
// a binary arithmetic operation to the AVX2 routines. Each of the three
// callbacks (arrArr, arrScalar, scalarArr) forwards to the matching
// arithmetic*Avx2 function and then reports success by returning nil.
//
// NOTE(review): the type-parameter, parameter and local identifiers are absent
// from this listing, so the comments here describe the visible call shapes
// only — confirm operand order against the upstream file.
func getAvx2ArithmeticBinaryNumeric[ arrow.NumericType]( ArithmeticOp) binaryOps[, , ] {
	// Result of arrow.GetType, computed once here rather than per callback
	// invocation.
	 := arrow.GetType[]()
	return binaryOps[, , ]{
		// Array/array form: three arrow.GetBytes views plus a length are handed
		// to arithmeticAvx2.
		arrArr: func( *exec.KernelCtx, , ,  []) error {
			arithmeticAvx2(, , arrow.GetBytes(), arrow.GetBytes(), arrow.GetBytes(), len())
			return nil
		},
		// Array/scalar form: the second operand slot is passed as an
		// unsafe.Pointer taken from an address rather than as a byte view.
		arrScalar: func( *exec.KernelCtx,  [],  ,  []) error {
			arithmeticArrScalarAvx2(, , arrow.GetBytes(), unsafe.Pointer(&), arrow.GetBytes(), len())
			return nil
		},
		// Scalar/array form: mirror of arrScalar, with the unsafe.Pointer in the
		// first operand slot.
		scalarArr: func( *exec.KernelCtx,  , ,  []) error {
			arithmeticScalarArrAvx2(, , unsafe.Pointer(&), arrow.GetBytes(), arrow.GetBytes(), len())
			return nil
		},
	}
}

// getSSE4ArithmeticBinaryNumeric builds the binaryOps callback set that routes
// a binary arithmetic operation to the SSE4 routines. It has the same shape as
// getAvx2ArithmeticBinaryNumeric, but forwards to the arithmetic*SSE4
// functions. Every callback returns nil.
//
// NOTE(review): the type-parameter, parameter and local identifiers are absent
// from this listing, so the comments here describe the visible call shapes
// only — confirm operand order against the upstream file.
func getSSE4ArithmeticBinaryNumeric[ arrow.NumericType]( ArithmeticOp) binaryOps[, , ] {
	// Result of arrow.GetType, computed once here rather than per callback
	// invocation.
	 := arrow.GetType[]()
	return binaryOps[, , ]{
		// Array/array form: three arrow.GetBytes views plus a length are handed
		// to arithmeticSSE4.
		arrArr: func( *exec.KernelCtx, , ,  []) error {
			arithmeticSSE4(, , arrow.GetBytes(), arrow.GetBytes(), arrow.GetBytes(), len())
			return nil
		},
		// Array/scalar form: the second operand slot is passed as an
		// unsafe.Pointer taken from an address rather than as a byte view.
		arrScalar: func( *exec.KernelCtx,  [],  ,  []) error {
			arithmeticArrScalarSSE4(, , arrow.GetBytes(), unsafe.Pointer(&), arrow.GetBytes(), len())
			return nil
		},
		// Scalar/array form: mirror of arrScalar, with the unsafe.Pointer in the
		// first operand slot.
		scalarArr: func( *exec.KernelCtx,  , ,  []) error {
			arithmeticScalarArrSSE4(, , unsafe.Pointer(&), arrow.GetBytes(), arrow.GetBytes(), len())
			return nil
		},
	}
}

// getArithmeticOpIntegral selects the exec.ArrayKernelExec used for an
// arithmetic operation on integer types. It prefers an AVX2 implementation
// when cpu.X86.HasAVX2 is set, otherwise an SSE4 implementation when
// cpu.X86.HasSSE42 is set. Any operation not listed in the matching switch
// falls through to getGoArithmeticOpIntegral.
//
// NOTE(review): the type-parameter names, the parameter name and the switch
// operand are absent from this listing — presumably the switch is on the
// ArithmeticOp argument; confirm against the upstream file.
func getArithmeticOpIntegral[,  arrow.UintType | arrow.IntType]( ArithmeticOp) exec.ArrayKernelExec {
	if cpu.X86.HasAVX2 {
		switch  {
		case OpAdd, OpSub, OpMul:
			// Binary ops reuse the AVX2 binaryOps callback set.
			return ScalarBinary(getAvx2ArithmeticBinaryNumeric[]())
		case OpAbsoluteValue, OpNegate:
			// Unary ops with a single arrow.GetType lookup.
			 := arrow.GetType[]()
			return ScalarUnary(func( *exec.KernelCtx, ,  []) error {
				arithmeticUnaryAvx2(, , arrow.GetBytes(), arrow.GetBytes(), len())
				return nil
			})
		case OpSign:
			// OpSign takes two arrow.GetType lookups and uses the DiffTypes
			// routine, unlike the other unary ops above.
			,  := arrow.GetType[](), arrow.GetType[]()
			return ScalarUnary(func( *exec.KernelCtx,  [],  []) error {
				arithmeticUnaryDiffTypesAvx2(, , , arrow.GetBytes(), arrow.GetBytes(), len())
				return nil
			})
		}
	} else if cpu.X86.HasSSE42 {
		// Same dispatch as the AVX2 branch, using the SSE4 routines.
		switch  {
		case OpAdd, OpSub, OpMul:
			return ScalarBinary(getSSE4ArithmeticBinaryNumeric[]())
		case OpAbsoluteValue, OpNegate:
			 := arrow.GetType[]()
			return ScalarUnary(func( *exec.KernelCtx, ,  []) error {
				arithmeticUnarySSE4(, , arrow.GetBytes(), arrow.GetBytes(), len())
				return nil
			})
		case OpSign:
			,  := arrow.GetType[](), arrow.GetType[]()
			return ScalarUnary(func( *exec.KernelCtx,  [],  []) error {
				arithmeticUnaryDiffTypesSSE4(, , , arrow.GetBytes(), arrow.GetBytes(), len())
				return nil
			})
		}
	}

	// no SIMD for POWER or SQRT functions
	// integral checked funcs need to use NotNull versions
	// Reached when neither CPU feature is present, or when the operation is
	// not one of the cases listed above.
	return getGoArithmeticOpIntegral[, ]()
}

// getArithmeticOpFloating selects the exec.ArrayKernelExec used for an
// arithmetic operation on floating-point types. It prefers an AVX2
// implementation when cpu.X86.HasAVX2 is set, otherwise an SSE4
// implementation when cpu.X86.HasSSE42 is set. Any operation not listed in
// the matching switch falls through to getGoArithmeticOpFloating.
//
// Unlike the integral dispatcher, the checked variants (OpAddChecked,
// OpSubChecked, OpMulChecked, OpAbsoluteValueChecked, OpNegateChecked) are
// listed in the same cases as the unchecked ones and so take the SIMD path.
//
// Every SIMD case first compares two arrow.GetType results. If they differ
// it calls debug.Assert(false, "not implemented") and returns nil, so a
// caller can receive a nil kernel from this function.
//
// NOTE(review): the type-parameter names, the parameter name and the switch
// operand are absent from this listing. The guard presumably compares the
// two type parameters of this function — confirm against the upstream file.
func getArithmeticOpFloating[,  constraints.Float]( ArithmeticOp) exec.ArrayKernelExec {
	if cpu.X86.HasAVX2 {
		switch  {
		case OpAdd, OpSub, OpAddChecked, OpSubChecked, OpMul, OpMulChecked:
			// Mismatched types have no SIMD implementation here.
			if arrow.GetType[]() != arrow.GetType[]() {
				debug.Assert(false, "not implemented")
				return nil
			}
			return ScalarBinary(getAvx2ArithmeticBinaryNumeric[]())
		case OpAbsoluteValue, OpAbsoluteValueChecked, OpNegate, OpNegateChecked, OpSign:
			// OpSign shares the same-type unary routine here; there is no
			// DiffTypes call in this function.
			if arrow.GetType[]() != arrow.GetType[]() {
				debug.Assert(false, "not implemented")
				return nil
			}
			 := arrow.GetType[]()
			return ScalarUnary(func( *exec.KernelCtx, ,  []) error {
				arithmeticUnaryAvx2(, , arrow.GetBytes(), arrow.GetBytes(), len())
				return nil
			})
		}
	} else if cpu.X86.HasSSE42 {
		// Same dispatch as the AVX2 branch, using the SSE4 routines.
		switch  {
		case OpAdd, OpSub, OpAddChecked, OpSubChecked, OpMul, OpMulChecked:
			if arrow.GetType[]() != arrow.GetType[]() {
				debug.Assert(false, "not implemented")
				return nil
			}
			return ScalarBinary(getSSE4ArithmeticBinaryNumeric[]())
		case OpAbsoluteValue, OpAbsoluteValueChecked, OpNegate, OpNegateChecked, OpSign:
			if arrow.GetType[]() != arrow.GetType[]() {
				debug.Assert(false, "not implemented")
				return nil
			}
			 := arrow.GetType[]()
			return ScalarUnary(func( *exec.KernelCtx, ,  []) error {
				arithmeticUnarySSE4(, , arrow.GetBytes(), arrow.GetBytes(), len())
				return nil
			})
		}
	}

	// no SIMD for POWER or SQRT functions
	// Reached when neither CPU feature is present, or when the operation is
	// not one of the cases listed above.
	return getGoArithmeticOpFloating[, ]()
}