// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build go1.18

package compute

import (
	
	
	

	
	
	
	
	
	
	
	
)

// bufferWriteSeeker pairs a memory.Buffer with an integer write position.
// The methods declared on it in this file copy bytes into buf starting at
// pos and reposition pos using the io.Seek* whence constants.
type bufferWriteSeeker struct {
	// buf is the backing buffer; it may be nil until space is first reserved.
	buf *memory.Buffer
	// pos is the byte offset into buf at which the next write lands.
	pos int
	// mem is the allocator handed to memory.NewResizableBuffer when buf is created.
	mem memory.Allocator
}

// Capacity-reservation method of bufferWriteSeeker (presumably the one the
// write path in this file invokes as .Reserve). It makes sure buf exists and
// then asks it to reserve a capacity derived from pos plus the int argument.
//
// NOTE(review): the receiver, method, parameter and local names are absent
// from this block as it appears here; restore them before building.
func ( *bufferWriteSeeker) ( int) {
	// Create the buffer on first use, backed by the configured allocator.
	if .buf == nil {
		.buf = memory.NewResizableBuffer(.mem)
	}
	// Start from the current capacity, but never from less than 256.
	 := utils.Max(.buf.Cap(), 256)
	// When that does not cover pos plus the request, switch to a
	// power-of-two size computed from pos plus the request.
	for  < .pos+ {
		 = bitutil.NextPowerOf2(.pos + )
	}
	.buf.Reserve()
}

// Checked write method of bufferWriteSeeker; its ([]byte) (int, error)
// signature matches io.Writer, so this is presumably Write. An empty input
// returns (0, nil) without touching the buffer. Otherwise capacity is
// reserved when needed and the copy is delegated to UnsafeWrite, whose
// results are returned unchanged.
//
// NOTE(review): the receiver, method, parameter and result names are absent
// from this block as it appears here; restore them before building.
func ( *bufferWriteSeeker) ( []byte) ( int,  error) {
	if len() == 0 {
		return 0, nil
	}

	// Both branches make the same Reserve call; they are kept apart so the
	// nil check runs before buf.Cap() is evaluated.
	if .buf == nil {
		.Reserve(len())
	} else if .pos+len() >= .buf.Cap() {
		.Reserve(len())
	}

	return .UnsafeWrite()
}

// Unchecked write method of bufferWriteSeeker (presumably the one the
// checked write path invokes as .UnsafeWrite). It copies the input into
// buf.Buf() starting at pos, advances pos by the input length, and extends
// the buffer's length when pos has moved past buf.Len().
//
// Nothing here checks that buf is non-nil or reserves capacity, so callers
// must have done that first. copy transfers only as many bytes as fit in the
// destination slice, while pos always advances by the full input length.
//
// NOTE(review): the receiver, method, parameter and result names are absent
// from this block as it appears here; restore them before building.
func ( *bufferWriteSeeker) ( []byte) ( int,  error) {
	 = copy(.buf.Buf()[.pos:], )
	.pos += len()
	// Grow the logical length only; ResizeNoShrink is used so that writing
	// in the middle of existing data never shortens the buffer.
	if .pos > .buf.Len() {
		.buf.ResizeNoShrink(.pos)
	}
	// Naked return of the named results.
	return
}

// Repositioning method of bufferWriteSeeker; its (int64, int) (int64, error)
// signature matches io.Seeker, so this is presumably Seek. The new position
// is computed from the offset relative to the start (io.SeekStart), the
// current pos (io.SeekCurrent) or buf.Len() (io.SeekEnd). A negative result
// is rejected with an error and leaves pos unchanged; otherwise pos is
// updated and the new position is returned as an int64.
//
// NOTE(review): the receiver, method, parameter and local names are absent
// from this block as it appears here; restore them before building.
func ( *bufferWriteSeeker) ( int64,  int) (int64, error) {
	// The offset is narrowed from int64 to int for arithmetic with pos.
	,  := 0, int()
	// There is no default case: for any other whence value the target
	// presumably keeps its initial 0 — TODO confirm that is intended.
	switch  {
	case io.SeekStart:
		 = 
	case io.SeekCurrent:
		 = .pos + 
	case io.SeekEnd:
		 = .buf.Len() + 
	}
	if  < 0 {
		return 0, xerrors.New("negative result pos")
	}
	.pos = 
	return int64(), nil
}

// ensureDictionaryDecoded is used by DispatchBest to determine
// the proper types for promotion. Casting is then performed by
// the executor before continuing execution: see the implementation
// of execInternal in exec.go after calling DispatchBest.
//
// That casting is where actual decoding would be performed for
// the dictionary.
//
// Only the list of types is rewritten here: each entry whose ID is
// arrow.DICTIONARY is overwritten with the ValueType of its
// *arrow.DictionaryType. Entries of any other type are left alone.
//
// NOTE(review): the parameter and loop-variable names are absent from this
// block as it appears here; restore them before building.
func ensureDictionaryDecoded( ...arrow.DataType) {
	for ,  := range  {
		if .ID() == arrow.DICTIONARY {
			[] = .(*arrow.DictionaryType).ValueType
		}
	}
}

// ensureNoExtensionType rewrites the given list of types so that every entry
// whose ID is arrow.EXTENSION is overwritten with the StorageType() of its
// arrow.ExtensionType. Entries of any other type are left alone.
//
// NOTE(review): the parameter and loop-variable names are absent from this
// block as it appears here; restore them before building.
func ensureNoExtensionType( ...arrow.DataType) {
	for ,  := range  {
		if .ID() == arrow.EXTENSION {
			[] = .(arrow.ExtensionType).StorageType()
		}
	}
}

// replaceNullWithOtherType expects exactly two types (checked only through
// debug.Assert). If the first has ID arrow.NULL it is overwritten with the
// second; otherwise, if the second has ID arrow.NULL it is overwritten with
// the first. When both are NULL the first branch runs and nothing changes.
//
// NOTE(review): the parameter name is absent from this block as it appears
// here; restore it before building.
func replaceNullWithOtherType( ...arrow.DataType) {
	debug.Assert(len() == 2, "should be length 2")

	if [0].ID() == arrow.NULL {
		[0] = [1]
		return
	}

	if [1].ID() == arrow.NULL {
		[1] = [0]
		return
	}
}

// commonTemporalResolution scans the given types and returns a time unit
// together with a flag reporting whether any temporal type was seen.
//
// The unit starts at arrow.Second and is raised with max(): to at least
// arrow.Millisecond for a *arrow.Date64Type, and to at least TimeUnit() for
// any arrow.TemporalWithUnit. A *arrow.Date32Type sets the flag without
// changing the unit. All other types are skipped.
//
// NOTE(review): the parameter and local names are absent from this block as
// it appears here (only the type-switch binding dt survives); restore them
// before building.
func commonTemporalResolution( ...arrow.DataType) (arrow.TimeUnit, bool) {
	 := false
	 := arrow.Second
	for ,  := range  {
		switch dt := .(type) {
		case *arrow.Date32Type:
			// Counts as temporal but contributes no unit of its own.
			 = true
			continue
		case *arrow.Date64Type:
			 = max(, arrow.Millisecond)
			 = true
		case arrow.TemporalWithUnit:
			 = max(, .TimeUnit())
			 = true
		default:
			continue
		}
	}
	return , 
}

// replaceTemporalTypes rewrites the temporal entries of the given type list
// to use the supplied arrow.TimeUnit. Non-temporal entries are left alone.
//
//   - *arrow.TimestampType and *arrow.DurationType: a Unit field is assigned
//     and a value is stored back into the list. Presumably this mutates the
//     existing type value through its pointer — confirm that sharing the
//     pointer with other holders is acceptable.
//   - *arrow.Time32Type / *arrow.Time64Type: replaced with a new Time64Type
//     when the unit is greater than arrow.Millisecond, otherwise with a new
//     Time32Type.
//   - *arrow.Date32Type / *arrow.Date64Type: replaced with a new
//     TimestampType (no time zone is set).
//
// NOTE(review): the parameter and loop-variable names are absent from this
// block as it appears here (only the type-switch binding dt survives);
// restore them before building.
func replaceTemporalTypes( arrow.TimeUnit,  ...arrow.DataType) {
	for ,  := range  {
		switch dt := .(type) {
		case *arrow.TimestampType:
			.Unit = 
			[] = 
		case *arrow.Time32Type, *arrow.Time64Type:
			if  > arrow.Millisecond {
				[] = &arrow.Time64Type{Unit: }
			} else {
				[] = &arrow.Time32Type{Unit: }
			}
		case *arrow.DurationType:
			.Unit = 
			[] = 
		case *arrow.Date32Type, *arrow.Date64Type:
			[] = &arrow.TimestampType{Unit: }
		}
	}
}

// replaceTypes takes a single arrow.DataType followed by a list of types and
// assigns on every iteration over that list; presumably it overwrites each
// entry of the list with the single type.
//
// NOTE(review): the parameter and loop-variable names are absent from this
// block as it appears here; restore them before building.
func replaceTypes( arrow.DataType,  ...arrow.DataType) {
	for  := range  {
		[] = 
	}
}

// commonNumeric picks a single numeric type for the given list of types, or
// returns nil when no such type applies.
//
// It returns nil if any entry is neither floating point nor integer, or is
// FLOAT16. Otherwise the result is Float64 if any entry is FLOAT64, else
// Float32 if any entry is FLOAT32, else an integer type chosen from the
// widest bit widths seen, tracked separately for unsigned integers and for
// the remaining (signed) integers.
//
// NOTE(review): the parameter and local names are absent from this block as
// it appears here, so the two width variables cannot be told apart in the
// code below; the notes on them rely on the assertion messages. Restore the
// names before building.
func commonNumeric( ...arrow.DataType) arrow.DataType {
	for ,  := range  {
		if !arrow.IsFloating(.ID()) && !arrow.IsInteger(.ID()) {
			// a common numeric type is only possible if all are numeric
			return nil
		}
		if .ID() == arrow.FLOAT16 {
			// float16 arithmetic is not currently supported
			return nil
		}
	}

	// Any float64 wins outright.
	for ,  := range  {
		if .ID() == arrow.FLOAT64 {
			return arrow.PrimitiveTypes.Float64
		}
	}

	// Failing that, any float32 wins.
	for ,  := range  {
		if .ID() == arrow.FLOAT32 {
			return arrow.PrimitiveTypes.Float32
		}
	}

	// Only integers remain: record the widest width per signedness class.
	,  := 0, 0
	for ,  := range  {
		if arrow.IsUnsignedInteger(.ID()) {
			 = exec.Max(.(arrow.FixedWidthDataType).BitWidth(), )
		} else {
			 = exec.Max(.(arrow.FixedWidthDataType).BitWidth(), )
		}
	}

	// Presumably the "no signed integer seen" case: only unsigned types are
	// returned from this branch, selected by width.
	if  == 0 {
		switch {
		case  >= 64:
			return arrow.PrimitiveTypes.Uint64
		case  == 32:
			return arrow.PrimitiveTypes.Uint32
		case  == 16:
			return arrow.PrimitiveTypes.Uint16
		default:
			debug.Assert( == 8, "bad maxWidthUnsigned")
			return arrow.PrimitiveTypes.Uint8
		}
	}

	// Presumably widens the signed width past the unsigned one so a signed
	// type can hold the unsigned values — confirm the operands on restore.
	if  <=  {
		 = bitutil.NextPowerOf2( + 1)
	}

	// Widths of 64 or more are capped at Int64.
	switch {
	case  >= 64:
		return arrow.PrimitiveTypes.Int64
	case  == 32:
		return arrow.PrimitiveTypes.Int32
	case  == 16:
		return arrow.PrimitiveTypes.Int16
	default:
		debug.Assert( == 8, "bad maxWidthSigned")
		return arrow.PrimitiveTypes.Int8
	}
}

// hasDecimal reports whether arrow.IsDecimal holds for the ID of at least
// one of the given types. It returns false for an empty list.
//
// NOTE(review): the parameter and loop-variable names are absent from this
// block as it appears here; restore them before building.
func hasDecimal( ...arrow.DataType) bool {
	for ,  := range  {
		if arrow.IsDecimal(.ID()) {
			return true
		}
	}

	return false
}

// decimalPromotion selects which scale-adjustment rule castBinaryDecimalArgs
// applies when it rewrites a pair of argument types.
type decimalPromotion uint8

// Values are assigned with iota starting at zero, so decPromoteNone is the
// zero value of decimalPromotion.
const (
	// decPromoteNone has no case of its own in castBinaryDecimalArgs and
	// falls into that switch's default (assertion) branch.
	decPromoteNone decimalPromotion = iota
	// decPromoteAdd selects the rule that computes a scale-up for both sides.
	decPromoteAdd
	// decPromoteMultiply selects the empty case: no scale-up is computed.
	decPromoteMultiply
	// decPromoteDivide selects the rule that computes a single scale-up.
	decPromoteDivide
)

// castBinaryDecimalArgs rewrites a pair of argument types, at least one of
// which is decimal (checked only through debug.Assert), so that both sides
// can be handled by a decimal kernel.
//
// If either side is floating point, one slot is overwritten with the other
// so both hold the same type (the floating one, per the comment in the
// body), and nil is returned. Otherwise a precision and scale are gathered
// for each side: from GetPrecision/GetScale for a decimal, or from
// kernels.MaxDecimalDigitsForInt for an integer. Negative scales are
// rejected with an error wrapping arrow.ErrNotImplemented. The result type
// ID is DECIMAL256 if either side is DECIMAL256, else DECIMAL128. Both slots
// are finally replaced by arrow.NewDecimalType results whose precision and
// scale include the scale-up selected by the decimalPromotion argument.
//
// Returns the error from MaxDecimalDigitsForInt or NewDecimalType, if any.
//
// NOTE(review): the parameter and local names are absent from this block as
// it appears here, so the operands of the scale-up arithmetic cannot be read
// from the code below. Restore the names before building.
func castBinaryDecimalArgs( decimalPromotion,  ...arrow.DataType) error {
	,  := [0], [1]
	debug.Assert(arrow.IsDecimal(.ID()) || arrow.IsDecimal(.ID()), "at least one of the types should be decimal")

	// decimal + float = float
	if arrow.IsFloating(.ID()) {
		[1] = [0]
		return nil
	} else if arrow.IsFloating(.ID()) {
		[0] = [1]
		return nil
	}

	var , , ,  int32
	var  error
	// decimal + integer = decimal
	if arrow.IsDecimal(.ID()) {
		 := .(arrow.DecimalType)
		,  = .GetPrecision(), .GetScale()
	} else {
		// Integer side: only one value is obtained here, so the matching
		// scale presumably stays at its zero value.
		debug.Assert(arrow.IsInteger(.ID()), "floats were already handled, this should be an int")
		if ,  = kernels.MaxDecimalDigitsForInt(.ID());  != nil {
			return 
		}
	}
	if arrow.IsDecimal(.ID()) {
		 := .(arrow.DecimalType)
		,  = .GetPrecision(), .GetScale()
	} else {
		debug.Assert(arrow.IsInteger(.ID()), "float already handled, should be ints")
		if ,  = kernels.MaxDecimalDigitsForInt(.ID());  != nil {
			return 
		}
	}

	if  < 0 ||  < 0 {
		return fmt.Errorf("%w: decimals with negative scales not supported", arrow.ErrNotImplemented)
	}

	// decimal128 + decimal256 = decimal256
	 := arrow.DECIMAL128
	if .ID() == arrow.DECIMAL256 || .ID() == arrow.DECIMAL256 {
		 = arrow.DECIMAL256
	}

	// decimal promotion rules compatible with amazon redshift
	// https://docs.aws.amazon.com/redshift/latest/dg/r_numeric_computations201.html
	var ,  int32

	switch  {
	case decPromoteAdd:
		 = exec.Max(, ) - 
		 = exec.Max(, ) - 
	case decPromoteMultiply:
		// Intentionally empty: both scale-ups stay at zero.
	case decPromoteDivide:
		 = exec.Max(4, +-+1) +  - 
	default:
		debug.Assert(false, fmt.Sprintf("invalid DecimalPromotion value %d", ))
	}

	// The same scale-up is added to both the precision and the scale of a side.
	[0],  = arrow.NewDecimalType(, +, +)
	if  != nil {
		return 
	}
	[1],  = arrow.NewDecimalType(, +, +)
	return 
}

// commonTemporal picks a single temporal type for the given list of types,
// or returns nil when there is none.
//
// Every entry must be DATE32, DATE64, TIMESTAMP, TIME32, TIME64 or DURATION;
// any other ID yields nil. While scanning, a time unit starting at
// arrow.Second is raised with max() and flags record which kinds were seen.
// Timestamps with a non-empty TimeZone are compared by their GetZone()
// location and a mismatch yields nil. Mixing timestamps or dates with
// time-of-day or duration types also yields nil.
//
// Results, in order of precedence: a TimestampType (when a timestamp was
// seen), Date64, Date32, a Time32Type for Second/Millisecond or a Time64Type
// for Microsecond/Nanosecond, or a DurationType.
//
// NOTE(review): the parameter and local names are absent from this block as
// it appears here, so which flag each statement sets or tests cannot be read
// from the code below. Restore the names before building.
func commonTemporal( ...arrow.DataType) arrow.DataType {
	var (
		           = arrow.Second
		                 *string
		                  *time.Location
		,  bool
		,  bool
	)

	for ,  := range  {
		switch .ID() {
		case arrow.DATE32:
			// date32's unit is days, but the coarsest we have is seconds
			 = true
		case arrow.DATE64:
			 = max(, arrow.Millisecond)
			 = true
		case arrow.TIMESTAMP:
			 := .(*arrow.TimestampType)
			if .TimeZone != "" {
				// The second result of GetZone (presumably an error) is
				// discarded. Locations are compared by pointer identity.
				,  := .GetZone()
				if  != nil &&  !=  {
					return nil
				}
				 = 
			}
			// Taken for every timestamp, including one with an empty zone,
			// so the zone string of the last timestamp seen is the one kept.
			 = &.TimeZone
			 = max(, .Unit)
		case arrow.TIME32, arrow.TIME64:
			 := .(arrow.TemporalWithUnit)
			 = max(, .TimeUnit())
			 = true
		case arrow.DURATION:
			 := .(*arrow.DurationType)
			 = max(, .Unit)
			 = true
		default:
			return nil
		}
	}

	 :=  != nil ||  || 

	if  && ( || ) {
		// no common type possible
		return nil
	}

	if  {
		switch {
		case  != nil:
			// at least one timestamp seen
			return &arrow.TimestampType{Unit: , TimeZone: *}
		case :
			return arrow.FixedWidthTypes.Date64
		case :
			return arrow.FixedWidthTypes.Date32
		}
	} else if  {
		switch  {
		case arrow.Second, arrow.Millisecond:
			return &arrow.Time32Type{Unit: }
		case arrow.Microsecond, arrow.Nanosecond:
			return &arrow.Time64Type{Unit: }
		}
	} else if  {
		// we can only get here if we ONLY saw durations
		return &arrow.DurationType{Unit: }
	}
	// Also reached for an empty input list.
	return nil
}

// commonBinary picks a single binary-like type for the given list of types,
// or returns nil when no cast target applies.
//
// Every entry must be STRING, BINARY, FIXED_SIZE_BINARY, LARGE_BINARY or
// LARGE_STRING; any other ID yields nil. Three boolean flags start out true
// and each accepted ID clears a subset of them. The surviving flags select
// the result: nil (no cast needed), String, LargeString, Binary, or, when
// none of the flags applies, LargeBinary.
//
// NOTE(review): the parameter, loop-variable and flag names are absent from
// this block as it appears here, so which flag each case clears cannot be
// read from the code below. Restore the names before building.
func commonBinary( ...arrow.DataType) arrow.DataType {
	var (
		, ,  = true, true, true
	)

	for ,  := range  {
		switch .ID() {
		case arrow.STRING:
			 = false
		case arrow.BINARY:
			,  = false, false
		case arrow.FIXED_SIZE_BINARY:
			 = false
		case arrow.LARGE_BINARY:
			, ,  = false, false, false
		case arrow.LARGE_STRING:
			,  = false, false
		default:
			return nil
		}
	}

	// Cases are tested in order, so the first flag still set decides.
	switch {
	case :
		// at least for the purposes of comparison, no need to cast
		return nil
	case :
		if  {
			return arrow.BinaryTypes.String
		}
		return arrow.BinaryTypes.LargeString
	case :
		return arrow.BinaryTypes.Binary
	}
	return arrow.BinaryTypes.LargeBinary
}