// Copyright 2019+ Klaus Post. All rights reserved.
// License information can be found in the LICENSE file.
// Based on work by Yann Collet, released under BSD License.

package zstd

import (
	
	
	
)

const (
	// Encoder-side limits. The table log selected by the encoder is clamped
	// to the range [minEncTablelog, maxEncTableLog].
	// NOTE(review): the table size and mask below are derived from maxTableLog
	// (declared elsewhere), not from maxEncTableLog — confirm this is intended.
	maxEncTableLog    = 8
	maxEncTablesize   = 1 << maxTableLog
	maxEncTableMask   = (1 << maxTableLog) - 1
	minEncTablelog    = 5
	maxEncSymbolValue = maxMatchLengthSymbol
)

// fseEncoder holds the state needed to build and emit one encoding table:
// the symbol histogram (count), the normalized counts (norm), the derived
// compression tables (ct) and flags describing how the table is used.
type fseEncoder struct {
	symbolLen      uint16 // Length of active part of the symbol table.
	actualTableLog uint8  // Selected tablelog.
	ct             cTable // Compression tables.
	maxCount       int    // count of the most probable symbol
	zeroBits       bool   // Set while building the table when a symbol may be output with 0 bits.
	clearCount     bool   // clear count
	useRLE         bool   // This encoder is for RLE
	preDefined     bool   // This encoder is predefined.
	reUsed         bool   // Set to know when the encoder has been reused.
	rleVal         uint8  // RLE Symbol
	maxBits        uint8  // Maximum output bits after transform.

	// TODO: Technically zstd should be fine with 64 bytes.
	count [256]uint32 // Histogram, indexed by symbol.
	norm  [256]int16  // Normalized counts, indexed by symbol; -1 marks a low probability symbol.
}

// cTable contains tables used for compression.
// The slices are (re)allocated by the encoder before a table is built and
// their backing storage is re-used when it is large enough.
type cTable struct {
	tableSymbol []byte            // Scratch for spreading symbols over table positions.
	stateTable  []uint16          // Sorted by symbol order; gives next state value.
	symbolTT    []symbolTransform // Per-symbol transform; always sliced to 256 entries.
}

// symbolTransform contains the state transform for a symbol.
type symbolTransform struct {
	deltaNbBits    uint32 // Bit-count information in 16.16 form; the upper 16 bits are read with >> 16.
	deltaFindState int16  // Offset added when looking up the next state in the state table.
	outBits        uint8  // Output bits for the symbol; assigned when output bits are set on the encoder.
}

// String returns the transform formatted as a human readable string:
// deltaNbBits as 8-digit hex, deltaFindState and outBits as decimal.
func ( symbolTransform) () string {
	return fmt.Sprintf("{deltabits: %08x, findstate:%d outbits:%d}", .deltaNbBits, .deltaFindState, .outBits)
}

// Histogram returns a pointer to the encoder's count array.
// This allows the caller to populate the histogram and skip that step in
// the compression, or to inspect the histogram when compression is done.
// To indicate that the histogram has been populated, call HistogramFinished
// with the value of the highest populated symbol, as well as the number of
// entries in the most populated entry. These are accepted at face value.
func ( *fseEncoder) () *[256]uint32 {
	return &.count
}

// HistogramFinished can be called to indicate that the histogram has been populated.
// maxSymbol is the index of the highest set symbol of the next data segment.
// maxCount is the number of entries in the most populated entry.
// These are accepted at face value; no validation against the histogram is done.
// It updates the maxCount, symbolLen and clearCount fields.
func ( *fseEncoder) ( uint8,  int) {
	.maxCount = 
	// symbolLen is a length, so it is one more than the converted value.
	.symbolLen = uint16() + 1
	.clearCount =  != 0
}

// allocCtable will allocate tables needed for compression.
// If existing tables are big enough, they are simply re-used.
// Sizes are derived from the current actualTableLog, so that field must be
// set before this is called. Re-used storage is not cleared here.
func ( *fseEncoder) () {
	 := 1 << .actualTableLog
	// get tableSymbol that is big enough.
	if cap(.ct.tableSymbol) <  {
		.ct.tableSymbol = make([]byte, )
	}
	.ct.tableSymbol = .ct.tableSymbol[:]

	 := 
	if cap(.ct.stateTable) <  {
		.ct.stateTable = make([]uint16, )
	}
	.ct.stateTable = .ct.stateTable[:]

	// The symbol transform table always has 256 entries, independent of the table log.
	if cap(.ct.symbolTT) < 256 {
		.ct.symbolTT = make([]symbolTransform, 256)
	}
	.ct.symbolTT = .ct.symbolTT[:256]
}

// buildCTable will populate the compression table so it is ready to be used.
// It (re)allocates the tables via allocCtable and then works from the norm
// field in four steps: compute symbol start positions, spread symbols over
// the table, build the state table and build the symbol transformation table.
// The zeroBits field is reset and then set if any normalized count exceeds
// half the table size.
// An error is returned if an internal consistency check fails: the cumulative
// total does not match the table size, the spreading step does not end at
// position 0, or the transform totals do not match the table size.
func ( *fseEncoder) () error {
	 := uint32(1 << .actualTableLog)
	 :=  - 1
	var  [256]int16

	.allocCtable()
	 := .ct.tableSymbol[:]
	// symbol start positions
	{
		[0] = 0
		for ,  := range .norm[:.symbolLen-1] {
			 := byte() // one less than reference
			if  == -1 {
				// Low proba symbol
				[+1] = [] + 1
				[] = 
				--
			} else {
				[+1] = [] + 
			}
		}
		// Encode last symbol separately to avoid overflowing the byte-typed
		// index used in the loop above; an int index is used here instead.
		 := int(.symbolLen - 1)
		 := .norm[.symbolLen-1]
		if  == -1 {
			// Low proba symbol
			[+1] = [] + 1
			[] = byte()
			--
		} else {
			[+1] = [] + 
		}
		if uint32([.symbolLen]) !=  {
			return fmt.Errorf("internal error: expected cumul[s.symbolLen] (%d) == tableSize (%d)", [.symbolLen], )
		}
		[.symbolLen] = int16() + 1
	}
	// Spread symbols
	.zeroBits = false
	{
		 := tableStep()
		 :=  - 1
		var  uint32
		// if any symbol > largeLimit, we may have 0 bits output.
		 := int16(1 << (.actualTableLog - 1))
		for ,  := range .norm[:.symbolLen] {
			 := byte()
			if  >  {
				.zeroBits = true
			}
			// Negative (low proba) entries run this loop zero times.
			for  := int16(0);  < ; ++ {
				[] = 
				 = ( + ) & 
				for  >  {
					 = ( + ) & 
				} /* Low proba area */
			}
		}

		// Check if we have gone through all positions
		if  != 0 {
			return errors.New("position!=0")
		}
	}

	// Build table
	 := .ct.stateTable
	{
		 := int()
		for ,  := range  {
			// TableU16 : sorted by symbol order; gives next state value
			[[]] = uint16( + )
			[]++
		}
	}

	// Build Symbol Transformation Table
	{
		 := int16(0)
		 := .ct.symbolTT[:.symbolLen]
		 := .actualTableLog
		 := (uint32() << 16) - (1 << )
		for ,  := range .norm[:.symbolLen] {
			switch  {
			case 0:
				// Symbol not present: its transform entry is left untouched.
			case -1, 1:
				[].deltaNbBits = 
				[].deltaFindState =  - 1
				++
			default:
				 := uint32() - highBit(uint32(-1))
				 := uint32() << 
				[].deltaNbBits = ( << 16) - 
				[].deltaFindState =  - 
				 += 
			}
		}
		if  != int16() {
			return fmt.Errorf("total mismatch %d (got) != %d (want)", , )
		}
	}
	return nil
}

// rtbTable holds eight thresholds that the primary normalization consults for
// values below 8 to decide whether such a value is incremented.
// NOTE(review): presumably the index is the truncated normalized probability
// and the increment is a round-up — confirm against the normalization code.
var rtbTable = [...]uint32{0, 473195, 504333, 520860, 550000, 700000, 750000, 830000}

// setRLE configures the encoder as an RLE encoder for the given byte value.
// Tables are allocated, actualTableLog is set to 0, the state table is cut
// down to a single entry and the symbol's transform is reset to zero values.
// Finally rleVal and useRLE are set.
// Note that allocCtable runs before actualTableLog is reset, so the table
// sizes follow the table log the encoder had on entry.
func ( *fseEncoder) ( byte) {
	.allocCtable()
	.actualTableLog = 0
	.ct.stateTable = .ct.stateTable[:1]
	.ct.symbolTT[] = symbolTransform{
		deltaFindState: 0,
		deltaNbBits:    0,
	}
	if debugEncoder {
		println("setRLE: val", , "symbolTT", .ct.symbolTT[])
	}
	.rleVal = 
	.useRLE = true
}

// setBits will set output bits for the transform.
// if nil is provided, the number of bits is equal to the index.
// It is a no-op for reused and predefined encoders.
// For RLE encoders only the entry of the RLE symbol is updated.
// The maxBits field is updated to the largest output bit count assigned.
func ( *fseEncoder) ( []byte) {
	if .reUsed || .preDefined {
		return
	}
	if .useRLE {
		if  == nil {
			// No table given: the bit count equals the symbol value itself.
			.ct.symbolTT[.rleVal].outBits = .rleVal
			.maxBits = .rleVal
			return
		}
		.maxBits = [.rleVal]
		.ct.symbolTT[.rleVal].outBits = .maxBits
		return
	}
	if  == nil {
		for  := range .ct.symbolTT[:.symbolLen] {
			.ct.symbolTT[].outBits = uint8()
		}
		// With bits == index, the last active symbol has the largest value.
		.maxBits = uint8(.symbolLen - 1)
		return
	}
	.maxBits = 0
	for ,  := range [:.symbolLen] {
		.ct.symbolTT[].outBits = 
		if  > .maxBits {
			// We could assume bits always going up, but we play safe.
			.maxBits = 
		}
	}
}

// normalizeCount will normalize the count of the symbols so
// the total is equal to the table size.
// If successful, compression tables will also be made ready.
//
// Reused encoders return nil immediately without any work.
// The table log is selected via optimalTableLog first. If maxCount matches
// the compared value (presumably the input length, i.e. a single repeated
// symbol — TODO confirm), useRLE is set and nil is returned without building
// any table. If the primary method hits its corner case, normalizeCount2 is
// used instead. With debugAsserts enabled the result is checked with
// validateNorm before the tables are built with buildCTable.
func ( *fseEncoder) ( int) error {
	if .reUsed {
		return nil
	}
	.optimalTableLog()
	var (
		          = .actualTableLog
		             = 62 - uint64()
		              = (1 << 62) / uint64()
		             = uint64(1) << ( - 20)
		 = int16(1 << )
		           int
		          int16
		      = (uint32)( >> )
	)
	if .maxCount ==  {
		.useRLE = true
		return nil
	}
	.useRLE = false
	for ,  := range .count[:.symbolLen] {
		// already handled
		// if (count[s] == s.length) return 0;   /* rle special case */

		if  == 0 {
			.norm[] = 0
			continue
		}
		if  <=  {
			// Low count: stored as -1 in norm.
			.norm[] = -1
			--
		} else {
			 := (int16)((uint64() * ) >> )
			if  < 8 {
				// Small values consult rtbTable to decide on an increment.
				 :=  * uint64(rtbTable[])
				 := uint64()* - (uint64() << )
				if  >  {
					++
				}
			}
			if  >  {
				 = 
				 = 
			}
			.norm[] = 
			 -= 
		}
	}

	if - >= (.norm[] >> 1) {
		// corner case, need another normalization method
		 := .normalizeCount2()
		if  != nil {
			return 
		}
		if debugAsserts {
			 = .validateNorm()
			if  != nil {
				return 
			}
		}
		return .buildCTable()
	}
	.norm[] += 
	if debugAsserts {
		 := .validateNorm()
		if  != nil {
			return 
		}
	}
	return .buildCTable()
}

// normalizeCount2 is the secondary normalization method.
// To be used when primary method fails.
// It writes the norm entries for the active symbols: zero counts get 0,
// counts at or below two successive thresholds get -1 and 1 respectively,
// and the remaining symbols are marked with a sentinel (-2) and receive
// their share in a later pass.
// It returns the error "weight < 1" if a remaining symbol would end up with a
// weight below 1; otherwise nil. Compression tables are not built here.
func ( *fseEncoder) ( int) error {
	const  = -2
	var (
		  uint32
		        = uint32()
		     = .actualTableLog
		 =  >> 
		       = ( * 3) >> ( + 1)
	)
	for ,  := range .count[:.symbolLen] {
		if  == 0 {
			.norm[] = 0
			continue
		}
		if  <=  {
			.norm[] = -1
			++
			 -= 
			continue
		}
		if  <=  {
			.norm[] = 1
			++
			 -= 
			continue
		}
		// Mark with the sentinel; resolved in a later pass.
		.norm[] = 
	}
	 := (1 << ) - 

	if ( / ) >  {
		// risk of rounding to zero
		 = ( * 3) / ( * 2)
		for ,  := range .count[:.symbolLen] {
			if (.norm[] == ) && ( <= ) {
				.norm[] = 1
				++
				 -= 
				continue
			}
		}
		 = (1 << ) - 
	}
	if  == uint32(.symbolLen)+1 {
		// all values are pretty poor;
		//   probably incompressible data (should have already been detected);
		//   find max, then give all remaining points to max
		var  int
		var  uint32
		for ,  := range .count[:.symbolLen] {
			if  >  {
				 = 
				 = 
			}
		}
		.norm[] += int16()
		return nil
	}

	if  == 0 {
		// all of the symbols were low enough for the lowOne or lowThreshold
		// Cycle over the active symbols, adding one to each positive entry
		// until the loop condition is no longer met.
		for  := uint32(0);  > 0;  = ( + 1) % (uint32(.symbolLen)) {
			if .norm[] > 0 {
				--
				.norm[]++
			}
		}
		return nil
	}

	var (
		 = 62 - uint64()
		      = uint64((1 << ( - 1)) - 1)
		    = (((1 << ) * uint64()) + ) / uint64() // scale on remaining
		 = 
	)
	for ,  := range .count[:.symbolLen] {
		if .norm[] ==  {
			var (
				    =  + uint64()*
				 = uint32( >> )
				   = uint32( >> )
				 =  - 
			)
			if  < 1 {
				return errors.New("weight < 1")
			}
			.norm[] = int16()
			 = 
		}
	}
	return nil
}

// optimalTableLog calculates and sets the optimal tableLog in s.actualTableLog
// Candidate values are derived with highBit from the int argument and from
// symbolLen. The result is always clamped to the range
// [minEncTablelog, maxEncTableLog] before it is stored.
// NOTE(review): highBit is defined elsewhere; its behavior for an input of 0
// (e.g. an argument of 1 leading to highBit(0)) cannot be confirmed from here.
func ( *fseEncoder) ( int) {
	 := uint8(maxEncTableLog)
	 := highBit(uint32()) + 1
	 := highBit(uint32(.symbolLen-1)) + 2
	 := uint8()
	if  <  {
		 = uint8()
	}

	 := uint8(highBit(uint32(-1))) - 2
	if  <  {
		// Accuracy can be reduced
		 = 
	}
	if  >  {
		 = 
	}
	// Need a minimum to safely represent all symbol values
	if  < minEncTablelog {
		 = minEncTablelog
	}
	if  > maxEncTableLog {
		 = maxEncTableLog
	}
	.actualTableLog = 
}

// validateNorm validates the normalized histogram table.
// It accumulates the norm entries of the active symbols, adding non-negative
// values and subtracting negative ones, and returns an error if the total
// differs from 1<<actualTableLog. It also returns an error if any count entry
// at or beyond symbolLen is non-zero.
// When the deferred check sees a non-nil value (presumably the named error
// result), the selected table log, the symbol length and one line per active
// symbol are printed to standard output via fmt.Printf.
func ( *fseEncoder) () ( error) {
	var  int
	for ,  := range .norm[:.symbolLen] {
		if  >= 0 {
			 += int()
		} else {
			 -= int()
		}
	}
	defer func() {
		if  == nil {
			return
		}
		fmt.Printf("selected TableLog: %d, Symbol length: %d\n", .actualTableLog, .symbolLen)
		for ,  := range .norm[:.symbolLen] {
			fmt.Printf("%3d: %5d -> %4d \n", , .count[], )
		}
	}()
	if  != (1 << .actualTableLog) {
		return fmt.Errorf("warning: Total == %d != %d", , 1<<.actualTableLog)
	}
	for ,  := range .count[.symbolLen:] {
		if  != 0 {
			return fmt.Errorf("warning: Found symbol out of range, %d after cut", )
		}
	}
	return nil
}

// writeCount will write the normalized histogram count to header.
// This is read back by readNCount.
//
// The header is appended to the supplied slice and the extended slice is
// returned. For RLE encoders only the RLE symbol byte is appended. For
// predefined or reused encoders nothing is written and the input is returned
// unchanged. Otherwise the table log (relative to minEncTablelog, in 4 bits)
// is written first, followed by the entries of the norm field.
// A nil slice and an error are returned if an internal consistency check
// fails while writing.
func ( *fseEncoder) ( []byte) ([]byte, error) {
	if .useRLE {
		return append(, .rleVal), nil
	}
	if .preDefined || .reUsed {
		// Never write predefined.
		return , nil
	}

	var (
		  = .actualTableLog
		 = 1 << 
		 bool
		   uint16

		// maximum header size plus 2 extra bytes for final output if bitCount == 0.
		 = ((int(.symbolLen) * int()) >> 3) + 3 + 2

		// Write Table Size
		 = uint32( - minEncTablelog)
		  = uint(4)
		 = int16( + 1) /* +1 for extra accuracy */
		 = int16()
		    = uint( + 1)
		      = len()
	)
	// Make sure there is capacity for the largest possible header, then
	// extend the slice so it can be written by index.
	if cap() < + {
		 = append(, make([]byte, *3)...)
		 = [:len()-*3]
	}
	 = [:+]

	// stops at 1
	for  > 1 {
		if  {
			// Emit the length of a run of zero entries in norm.
			 := 
			for .norm[] == 0 {
				++
			}
			for  >= +24 {
				 += 24
				 += uint32(0xFFFF) << 
				[] = byte()
				[+1] = byte( >> 8)
				 += 2
				 >>= 16
			}
			for  >= +3 {
				 += 3
				 += 3 << 
				 += 2
			}
			 += uint32(-) << 
			 += 2
			if  > 16 {
				// Write out 16 bits to make room in the accumulator.
				[] = byte()
				[+1] = byte( >> 8)
				 += 2
				 >>= 16
				 -= 16
			}
		}

		 := .norm[]
		++
		 := (2* - 1) - 
		if  < 0 {
			 += 
		} else {
			 -= 
		}
		++ // +1 for extra accuracy
		if  >=  {
			 +=  // [0..max[ [max..threshold[ (...) [threshold+max 2*threshold[
		}
		 += uint32() << 
		 += 
		if  <  {
			--
		}

		 =  == 1
		if  < 1 {
			return nil, errors.New("internal error: remaining < 1")
		}
		for  <  {
			--
			 >>= 1
		}

		if  > 16 {
			// Write out 16 bits to make room in the accumulator.
			[] = byte()
			[+1] = byte( >> 8)
			 += 2
			 >>= 16
			 -= 16
		}
	}

	if +2 > len() {
		return nil, fmt.Errorf("internal error: %d > %d, maxheader: %d, sl: %d, tl: %d, normcount: %v", +2, len(), , .symbolLen, int(), .norm[:.symbolLen])
	}
	// Two bytes are always stored; only the bytes that hold pending bits are
	// counted in the final length.
	[] = byte()
	[+1] = byte( >> 8)
	 += int(( + 7) / 8)

	if  > .symbolLen {
		return nil, errors.New("internal error: charnum > s.symbolLen")
	}
	return [:], nil
}

// bitCost returns the approximate symbol cost, as a fractional value, using
// fixed-point format (accuracyLog fractional bits).
// note 1 : assume symbolValue is valid (<= maxSymbolValue)
// note 2 : if freq[symbolValue]==0, @return a fake cost of tableLog+1 bits
// The cost is read from the symbol's deltaNbBits in the symbol transform
// table and refined by a linear interpolation.
// When debugAsserts is enabled, the function panics if actualTableLog is not
// below 16, if the accuracy is too large for the shifts, or if the computed
// intermediate values are out of their expected range.
func ( *fseEncoder) ( uint8,  uint32) uint32 {
	 := .ct.symbolTT[].deltaNbBits >> 16
	 := ( + 1) << 16
	if debugAsserts {
		if !(.actualTableLog < 16) {
			panic("!s.actualTableLog < 16")
		}
		// ensure enough room for renormalization double shift
		if !(uint8() < 31-.actualTableLog) {
			panic("!uint8(accuracyLog) < 31-s.actualTableLog")
		}
	}
	 := uint32(1) << .actualTableLog
	 :=  - (.ct.symbolTT[].deltaNbBits + )
	// linear interpolation (very approximate)
	 := ( << ) >> .actualTableLog
	 := uint32(1) << 
	if debugAsserts {
		if .ct.symbolTT[].deltaNbBits+ >  {
			panic("s.ct.symbolTT[symbolValue].deltaNbBits+tableSize > threshold")
		}
		if  >  {
			panic("normalizedDeltaFromThreshold > bitMultiplier")
		}
	}
	return (+1)* - 
}

// approxSize returns the cost in bits of encoding the distribution in count
// using ctable.
// Histogram should only be up to the last non-zero symbol.
// Returns math.MaxUint32 if the table cannot be used for the histogram:
// the histogram has more entries than symbolLen, this is an RLE encoder,
// a symbol present in the histogram has a zero norm entry, or the bit cost
// of a symbol exceeds the limit derived from actualTableLog+1.
func ( *fseEncoder) ( []uint32) uint32 {
	if int(.symbolLen) < len() {
		// More symbols than we have.
		return math.MaxUint32
	}
	if .useRLE {
		// We will never reuse RLE encoders.
		return math.MaxUint32
	}
	const  = 8
	 := (uint32(.actualTableLog) + 1) << 
	var  uint32
	for ,  := range  {
		if  == 0 {
			continue
		}
		if .norm[] == 0 {
			// Symbol is present in the histogram but has no table entry.
			return math.MaxUint32
		}
		 := .bitCost(uint8(), )
		if  >  {
			return math.MaxUint32
		}
		 +=  * 
	}
	return  >> 
}

// maxHeaderSize returns the maximum header size in bits.
// This is not exact size, but we want a penalty for new tables anyway.
// Predefined encoders report 0 and RLE encoders report 8 (one byte).
// Otherwise the result is ((symbolLen*actualTableLog)/8 + 3) bytes,
// expressed in bits.
func ( *fseEncoder) () uint32 {
	if .preDefined {
		return 0
	}
	if .useRLE {
		return 8
	}
	return (((uint32(.symbolLen) * uint32(.actualTableLog)) >> 3) + 3) * 8
}

// cState contains the compression state of a stream.
type cState struct {
	bw         *bitWriter // Bit writer the state is flushed to.
	stateTable []uint16   // State table taken from the cTable given at initialization.
	state      uint16     // Current state value.
}

// init will initialize the compression state to the first symbol of the stream.
// It stores the bit writer and takes the state table from the given cTable.
// A state table with exactly one entry is treated as RLE: that entry and the
// state are set to 0. Otherwise the initial state is looked up in the state
// table using the deltaNbBits and deltaFindState of the supplied transform.
func ( *cState) ( *bitWriter,  *cTable,  symbolTransform) {
	.bw = 
	.stateTable = .stateTable
	if len(.stateTable) == 1 {
		// RLE
		.stateTable[0] = uint16(0)
		.state = 0
		return
	}
	// Round the 16.16 value in deltaNbBits to a whole number.
	 := (.deltaNbBits + (1 << 15)) >> 16
	 := int32(( << 16) - .deltaNbBits)
	 := ( >> ) + int32(.deltaFindState)
	.state = .stateTable[]
}

// flush will flush the bit writer via flush32 and then add the current state
// to the output using addBits16NC, passing the supplied uint8 along.
// NOTE(review): the uint8 is presumably the table log, i.e. the number of
// bits used to store the state — confirm against addBits16NC.
func ( *cState) ( uint8) {
	.bw.flush32()
	.bw.addBits16NC(.state, )
}