// Copyright 2018 Klaus Post. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Based on work Copyright (c) 2013, Yann Collet, released under BSD License.

package fse

import (
	
	
)

// Compress the input bytes. Input must be < 2GB.
// Provide a Scratch buffer to avoid memory allocations.
// Note that the output is also kept in the scratch buffer:
// the slice returned on success is the one stored in its Out field.
//
// Returned errors:
//   - ErrIncompressible if the input is 0 or 1 bytes long, if the biggest
//     symbol count is 1 or is below len(input)>>7, or if the encoded output
//     is not smaller than the input.
//   - ErrUseRLE if the input is a single byte value repeated.
//   - A plain error if the input is longer than (2<<30)-1 bytes.
//   - Any error from the prepare, normalizeCount, writeCount, buildCTable
//     or compress steps is returned as-is.
//
// The returned slice is nil whenever the error is non-nil.
func ( []byte,  *Scratch) ([]byte, error) {
	// Inputs of 0 or 1 bytes are rejected up front.
	if len() <= 1 {
		return nil, ErrIncompressible
	}
	if len() > (2<<30)-1 {
		return nil, errors.New("input too big, must be < 2GB")
	}
	,  := .prepare()
	if  != nil {
		return nil, 
	}

	// Create histogram, if none was provided.
	// A non-zero maxCount is taken to mean a histogram is already present.
	 := .maxCount
	if  == 0 {
		 = .countSimple()
	}
	// Reset for next run.
	.clearCount = true
	.maxCount = 0
	if  == len() {
		// One symbol, use RLE
		return nil, ErrUseRLE
	}
	if  == 1 ||  < (len()>>7) {
		// Each symbol present maximum once or too well distributed.
		return nil, ErrIncompressible
	}
	// Choose the table log, normalize the histogram and write the table header.
	.optimalTableLog()
	 = .normalizeCount()
	if  != nil {
		return nil, 
	}
	 = .writeCount()
	if  != nil {
		return nil, 
	}

	// Debugging aid only: the constant condition keeps validateNorm disabled.
	if false {
		 = .validateNorm()
		if  != nil {
			return nil, 
		}
	}

	// Build the encoding tables, then encode the payload.
	 = .buildCTable()
	if  != nil {
		return nil, 
	}
	 = .compress()
	if  != nil {
		return nil, 
	}
	// The bit writer's buffer becomes the output.
	.Out = .bw.out
	// Check if we compressed.
	if len(.Out) >= len() {
		return nil, ErrIncompressible
	}
	return .Out, nil
}

// cState contains the compression state of a stream.
// It is set up by init, advanced once per symbol by the encode methods
// and written out by flush.
type cState struct {
	bw         *bitWriter // bit writer that receives the encoded bits
	stateTable []uint16   // next-state lookup table, assigned in init
	state      uint16     // current state; replaced by a stateTable entry after each symbol
}

// init will initialize the compression state to the first symbol of the stream.
// It stores the bit writer and a state table (presumably the cTable's) in the
// receiver, then derives the starting state from the symbol transform's
// deltaNbBits and deltaFindState. Nothing is written to the bit writer here.
func ( *cState) ( *bitWriter,  *cTable,  uint8,  symbolTransform) {
	.bw = 
	.stateTable = .stateTable

	// deltaNbBits plus 1<<15, shifted down by 16: the upper 16 bits, rounded.
	 := (.deltaNbBits + (1 << 15)) >> 16
	 := int32(( << 16) - .deltaNbBits)
	// The starting state is looked up in the state table at the shifted value
	// offset by deltaFindState.
	 := ( >> ) + .deltaFindState
	.state = .stateTable[]
}

// encode the output symbol provided and write it to the bitstream.
// The bits are added with addBits16NC; compress only uses this variant when
// zeroBits is false, see encodeZero for the other case.
func ( *cState) ( symbolTransform) {
	// Number of bits to output: upper 16 bits of state + deltaNbBits.
	 := (uint32(.state) + .deltaNbBits) >> 16
	// Index of the next state: the state shifted down (shift masked to 0..15)
	// plus deltaFindState.
	 := int32(.state>>(&15)) + .deltaFindState
	// The current state value is handed to the bit writer with that bit count.
	.bw.addBits16NC(.state, uint8())
	.state = .stateTable[]
}

// encodeZero will encode the output symbol provided and write it to the bitstream.
// The computation is the same as in encode, but the bits are added with
// addBits16ZeroNC instead of addBits16NC. compress uses this variant when
// zeroBits is set, i.e. when a symbol may produce 0 bits of output.
func ( *cState) ( symbolTransform) {
	// Number of bits to output: upper 16 bits of state + deltaNbBits.
	 := (uint32(.state) + .deltaNbBits) >> 16
	// Index of the next state: the state shifted down (shift masked to 0..15)
	// plus deltaFindState.
	 := int32(.state>>(&15)) + .deltaFindState
	.bw.addBits16ZeroNC(.state, uint8())
	.state = .stateTable[]
}

// flush will write the current state to the output and flush the bit writer.
// The state is added with addBits16NC; the uint8 argument is presumably the
// number of bits to use (compress passes actualTableLog).
// flush32 is called first, before the bits are added.
func ( *cState) ( uint8) {
	.bw.flush32()
	.bw.addBits16NC(.state, )
	.bw.flush()
}

// compress is the main compression loop that will encode the input from the last byte to the first.
// The bit writer is reset with Out, two cState encoders take turns encoding
// symbols into it, and finally both states are flushed and the bit writer is closed.
// An error is returned only if the input is 2 bytes or shorter.
func ( *Scratch) ( []byte) error {
	if len() <= 2 {
		return errors.New("compress: src too small")
	}
	// Symbol transforms, re-sliced to 256 entries.
	 := .ct.symbolTT[:256]
	.bw.reset(.Out)

	// Our two states each encodes every second byte.
	// Last byte encoded (first byte decoded) will always be encoded by c1.
	var ,  cState

	// Encode so remaining size is divisible by 4.
	 := len()
	if &1 == 1 {
		// Odd: two state initialisations plus one encoded symbol consume three bytes.
		.init(&.bw, &.ct, .actualTableLog, [[-1]])
		.init(&.bw, &.ct, .actualTableLog, [[-2]])
		.encodeZero([[-3]])
		 -= 3
	} else {
		// Even: the two state initialisations consume two bytes.
		.init(&.bw, &.ct, .actualTableLog, [[-1]])
		.init(&.bw, &.ct, .actualTableLog, [[-2]])
		 -= 2
	}
	// Encode two more symbols when bit 1 of the remaining count is set.
	if &2 != 0 {
		.encodeZero([[-1]])
		.encodeZero([[-2]])
		 -= 2
	}
	 = [:]

	// Main compression loop.
	// Four variants of the same loop, each consuming 4 bytes per iteration from
	// the end: zeroBits selects encodeZero over encode, and the extra flush32
	// in the middle of the iteration is skipped when actualTableLog <= 8.
	switch {
	case !.zeroBits && .actualTableLog <= 8:
		// We can encode 4 symbols without requiring a flush.
		// We do not need to check if any output is 0 bits.
		for ; len() >= 4;  = [:len()-4] {
			.bw.flush32()
			, , ,  := [len()-4], [len()-3], [len()-2], [len()-1]
			.encode([])
			.encode([])
			.encode([])
			.encode([])
		}
	case !.zeroBits:
		// We do not need to check if any output is 0 bits.
		for ; len() >= 4;  = [:len()-4] {
			.bw.flush32()
			, , ,  := [len()-4], [len()-3], [len()-2], [len()-1]
			.encode([])
			.encode([])
			.bw.flush32()
			.encode([])
			.encode([])
		}
	case .actualTableLog <= 8:
		// We can encode 4 symbols without requiring a flush
		for ; len() >= 4;  = [:len()-4] {
			.bw.flush32()
			, , ,  := [len()-4], [len()-3], [len()-2], [len()-1]
			.encodeZero([])
			.encodeZero([])
			.encodeZero([])
			.encodeZero([])
		}
	default:
		for ; len() >= 4;  = [:len()-4] {
			.bw.flush32()
			, , ,  := [len()-4], [len()-3], [len()-2], [len()-1]
			.encodeZero([])
			.encodeZero([])
			.bw.flush32()
			.encodeZero([])
			.encodeZero([])
		}
	}

	// Flush final state.
	// Used to initialize state when decoding.
	.flush(.actualTableLog)
	.flush(.actualTableLog)

	.bw.close()
	return nil
}

// writeCount will write the normalized histogram count to header.
// This is read back by readNCount.
//
// The header is assembled in a slice taken from Out. Out is replaced by a
// freshly allocated, zero-length buffer first if its capacity is too small,
// and is re-sliced to the written length on success.
// Bits are collected in a 32-bit accumulator and written out 16 bits
// (two bytes, low byte first) at a time.
//
// An error is returned if the "remaining" bookkeeping drops below 1 or if
// more symbols were consumed than symbolLen (see the error messages below).
func ( *Scratch) () error {
	var (
		  = .actualTableLog
		 = 1 << 
		 bool
		   uint16

		 = ((int(.symbolLen)*int() + 4 + 2) >> 3) + 3

		// Write Table Size
		 = uint32( - minTablelog)
		  = uint(4)
		 = int16( + 1) /* +1 for extra accuracy */
		 = int16()
		    = uint( + 1)
	)
	if cap(.Out) <  {
		.Out = make([]byte, 0, .br.remain()+)
	}
	 := uint(0)
	 := .Out[:]

	// stops at 1
	for  > 1 {
		if  {
			// Skip over the run of symbols whose normalized count is 0.
			 := 
			for .norm[] == 0 {
				++
			}
			// Long runs: per 24 symbols add a 16-bit 0xFFFF marker and write two bytes.
			for  >= +24 {
				 += 24
				 += uint32(0xFFFF) << 
				[] = byte()
				[+1] = byte( >> 8)
				 += 2
				 >>= 16
			}
			// Medium runs: per 3 symbols add the 2-bit value 3.
			for  >= +3 {
				 += 3
				 += 3 << 
				 += 2
			}
			// What is left of the run is stored in 2 bits.
			 += uint32(-) << 
			 += 2
			// Write out 16 bits once more than 16 are pending.
			if  > 16 {
				[] = byte()
				[+1] = byte( >> 8)
				 += 2
				 >>= 16
				 -= 16
			}
		}

		 := .norm[]
		++
		 := (2* - 1) - 
		if  < 0 {
			 += 
		} else {
			 -= 
		}
		++ // +1 for extra accuracy
		if  >=  {
			 +=  // [0..max[ [max..threshold[ (...) [threshold+max 2*threshold[
		}
		 += uint32() << 
		 += 
		if  <  {
			--
		}

		 =  == 1
		if  < 1 {
			return errors.New("internal error: remaining<1")
		}
		for  <  {
			--
			 >>= 1
		}

		// Write out 16 bits once more than 16 are pending.
		if  > 16 {
			[] = byte()
			[+1] = byte( >> 8)
			 += 2
			 >>= 16
			 -= 16
		}
	}

	// Write the last two bytes of the accumulator; the length only advances by
	// the pending bit count rounded up to whole bytes.
	[] = byte()
	[+1] = byte( >> 8)
	 += ( + 7) / 8

	if  > .symbolLen {
		return errors.New("internal error: charnum > s.symbolLen")
	}
	.Out = [:]
	return nil
}

// symbolTransform contains the state transform for a symbol.
// Both fields are filled in by buildCTable and read by the cState methods.
type symbolTransform struct {
	deltaFindState int32  // added to the shifted state to index the state table
	deltaNbBits    uint32 // added to the state; the upper 16 bits of the sum give the bit count
}

// String returns the values as a human readable string:
// deltaNbBits as 8-digit zero-padded hex and deltaFindState as decimal.
func ( symbolTransform) () string {
	return fmt.Sprintf("dnbits: %08x, fs:%d", .deltaNbBits, .deltaFindState)
}

// cTable contains tables used for compression.
// The slices are sized by allocCtable and filled in by buildCTable.
type cTable struct {
	tableSymbol []byte            // symbols spread over the table positions
	stateTable  []uint16          // next-state values, read by cState
	symbolTT    []symbolTransform // per-symbol transforms; kept at 256 entries
}

// allocCtable will allocate tables needed for compression.
// If existing tables are big enough, they are simply re-used.
// Re-used tables are only re-sliced, not cleared.
// The table size is derived from 1 << actualTableLog; symbolTT always gets 256 entries.
func ( *Scratch) () {
	 := 1 << .actualTableLog
	// get tableSymbol that is big enough.
	if cap(.ct.tableSymbol) <  {
		.ct.tableSymbol = make([]byte, )
	}
	.ct.tableSymbol = .ct.tableSymbol[:]

	// Same treatment for stateTable.
	 := 
	if cap(.ct.stateTable) <  {
		.ct.stateTable = make([]uint16, )
	}
	.ct.stateTable = .ct.stateTable[:]

	// symbolTT has a fixed length of 256.
	if cap(.ct.symbolTT) < 256 {
		.ct.symbolTT = make([]symbolTransform, 256)
	}
	.ct.symbolTT = .ct.symbolTT[:256]
}

// buildCTable will populate the compression table so it is ready to be used.
// It reads norm[:symbolLen] and actualTableLog and works in four steps:
// compute the cumulative start position of every symbol, spread the symbols
// over ct.tableSymbol, fill ct.stateTable, and fill ct.symbolTT.
// Tables are (re)allocated through allocCtable first.
// zeroBits is reset here and set again during the spread step (see the
// largeLimit comment below).
// An error is returned if one of the internal consistency checks fails.
func ( *Scratch) () error {
	 := uint32(1 << .actualTableLog)
	 :=  - 1
	var  [maxSymbolValue + 2]int16

	.allocCtable()
	 := .ct.tableSymbol[:]
	// symbol start positions
	{
		[0] = 0
		for ,  := range .norm[:.symbolLen-1] {
			 := byte() // one less than reference
			if  == -1 {
				// Low proba symbol
				[+1] = [] + 1
				[] = 
				--
			} else {
				[+1] = [] + 
			}
		}
		// Encode last symbol separately to avoid overflowing u
		 := int(.symbolLen - 1)
		 := .norm[.symbolLen-1]
		if  == -1 {
			// Low proba symbol
			[+1] = [] + 1
			[] = byte()
			--
		} else {
			[+1] = [] + 
		}
		// The cumulative total must match the table size.
		if uint32([.symbolLen]) !=  {
			return fmt.Errorf("internal error: expected cumul[s.symbolLen] (%d) == tableSize (%d)", [.symbolLen], )
		}
		[.symbolLen] = int16() + 1
	}
	// Spread symbols
	.zeroBits = false
	{
		 := tableStep()
		 :=  - 1
		var  uint32
		// if any symbol > largeLimit, we may have 0 bits output.
		 := int16(1 << (.actualTableLog - 1))
		for ,  := range .norm[:.symbolLen] {
			 := byte()
			if  >  {
				.zeroBits = true
			}
			for range  {
				[] = 
				 = ( + ) & 
				for  >  {
					 = ( + ) & 
				} /* Low proba area */
			}
		}

		// Check if we have gone through all positions
		if  != 0 {
			return errors.New("position!=0")
		}
	}

	// Build table
	 := .ct.stateTable
	{
		 := int()
		for ,  := range  {
			// TableU16 : sorted by symbol order; gives next state value
			[[]] = uint16( + )
			[]++
		}
	}

	// Build Symbol Transformation Table
	{
		 := int16(0)
		 := .ct.symbolTT[:.symbolLen]
		 := .actualTableLog
		 := (uint32() << 16) - (1 << )
		for ,  := range .norm[:.symbolLen] {
			switch  {
			case 0:
				// Nothing is stored for this case.
			case -1, 1:
				[].deltaNbBits = 
				[].deltaFindState = int32( - 1)
				++
			default:
				 := uint32() - highBits(uint32(-1))
				 := uint32() << 
				[].deltaNbBits = ( << 16) - 
				[].deltaFindState = int32( - )
				 += 
			}
		}
		// The running total must end up at the expected value.
		if  != int16() {
			return fmt.Errorf("total mismatch %d (got) != %d (want)", , )
		}
	}
	return nil
}

// countSimple will create a simple histogram in s.count.
// Returns the biggest count.
// Does not update s.clearCount.
// Counts are incremented on top of whatever count already holds; nothing is
// cleared here. symbolLen is written back after the scan: it starts from its
// current value and is reassigned for every non-zero count (presumably to the
// symbol index + 1, so it ends one past the highest symbol seen — confirm).
func ( *Scratch) ( []byte) ( int) {
	for ,  := range  {
		.count[]++
	}
	,  := uint32(0), .symbolLen
	for ,  := range .count[:] {
		// Entries with a zero count are skipped.
		if  == 0 {
			continue
		}
		if  >  {
			 = 
		}
		 = uint16() + 1
	}
	.symbolLen = 
	return int()
}

// minTableLog provides the minimum logSize to safely represent a distribution.
// Two candidates are computed, one from the input size
// (highBits(br.remain()-1) + 1) and one from the symbol count
// (highBits(symbolLen-1) + 2); one of them is returned depending on which is
// smaller (presumably the smaller one — confirm).
func ( *Scratch) () uint8 {
	 := highBits(uint32(.br.remain()-1)) + 1
	 := highBits(uint32(.symbolLen-1)) + 2
	if  <  {
		return uint8()
	}
	return uint8()
}

// optimalTableLog calculates and sets the optimal tableLog in s.actualTableLog
// The starting point is TableLog. It is adjusted using a limit derived from
// the input size and the value from minTableLog, and finally clamped to
// [minTablelog, maxTableLog].
func ( *Scratch) () {
	 := .TableLog
	 := .minTableLog()
	// NOTE(review): this is uint8 arithmetic, so it wraps around when
	// highBits(...) is below 2 — confirm that small inputs are covered by the
	// comparisons and clamps that follow.
	 := uint8(highBits(uint32(.br.remain()-1))) - 2
	if  <  {
		// Accuracy can be reduced
		 = 
	}
	if  >  {
		 = 
	}
	// Need a minimum to safely represent all symbol values
	if  < minTablelog {
		 = minTablelog
	}
	if  > maxTableLog {
		 = maxTableLog
	}
	.actualTableLog = 
}

// rtbTable holds eight thresholds used by the count normalization code.
// It is only consulted in the branch taken for values below 8, where it
// decides whether a value is incremented by one.
var rtbTable = [8]uint32{
	0, 473195, 504333, 520860,
	550000, 700000, 750000, 830000,
}

// normalizeCount will normalize the count of the symbols so
// the total is equal to the table size.
// It reads count[:symbolLen], br.remain() and actualTableLog and writes norm.
// Symbols with a zero count get norm 0, symbols at or below a low threshold
// get -1, and the rest get a scaled value. The largest scaled value is
// tracked so that the left-over can be added to it at the end.
// If that left-over correction would be too large, normalizeCount2 is used
// instead and its result is returned.
func ( *Scratch) () error {
	var (
		          = .actualTableLog
		             = 62 - uint64()
		              = (1 << 62) / uint64(.br.remain())
		             = uint64(1) << ( - 20)
		 = int16(1 << )
		           int
		          int16
		      = (uint32)(.br.remain() >> )
	)

	for ,  := range .count[:.symbolLen] {
		// already handled
		// if (count[s] == s.length) return 0;   /* rle special case */

		if  == 0 {
			.norm[] = 0
			continue
		}
		if  <=  {
			.norm[] = -1
			--
		} else {
			 := (int16)((uint64() * ) >> )
			// Values below 8 may be bumped by one, decided with rtbTable.
			if  < 8 {
				 :=  * uint64(rtbTable[])
				 := uint64()* - (uint64() << )
				if  >  {
					++
				}
			}
			// Remember the largest value seen so far.
			if  >  {
				 = 
				 = 
			}
			.norm[] = 
			 -= 
		}
	}

	if - >= (.norm[] >> 1) {
		// corner case, need another normalization method
		return .normalizeCount2()
	}
	.norm[] += 
	return nil
}

// Secondary normalization method.
// To be used when primary method fails.
// It reads count[:symbolLen], br.remain() and actualTableLog and rewrites norm.
// Symbols with a zero count get 0, very low counts get -1, low counts get 1,
// and everything else is first marked with a placeholder (-2) and assigned a
// weight in the final pass.
// An error is returned if a weight below 1 is computed in that pass.
func ( *Scratch) () error {
	const  = -2
	var (
		  uint32
		        = uint32(.br.remain())
		     = .actualTableLog
		 =  >> 
		       = ( * 3) >> ( + 1)
	)
	// First pass: settle the trivial cases and mark the rest with the placeholder.
	for ,  := range .count[:.symbolLen] {
		if  == 0 {
			.norm[] = 0
			continue
		}
		if  <=  {
			.norm[] = -1
			++
			 -= 
			continue
		}
		if  <=  {
			.norm[] = 1
			++
			 -= 
			continue
		}
		.norm[] = 
	}
	 := (1 << ) - 

	if ( / ) >  {
		// risk of rounding to zero
		 = ( * 3) / ( * 2)
		for ,  := range .count[:.symbolLen] {
			if (.norm[] == ) && ( <= ) {
				.norm[] = 1
				++
				 -= 
				continue
			}
		}
		 = (1 << ) - 
	}
	if  == uint32(.symbolLen)+1 {
		// all values are pretty poor;
		//   probably incompressible data (should have already been detected);
		//   find max, then give all remaining points to max
		var  int
		var  uint32
		for ,  := range .count[:.symbolLen] {
			if  >  {
				 = 
				 = 
			}
		}
		.norm[] += int16()
		return nil
	}

	if  == 0 {
		// all of the symbols were low enough for the lowOne or lowThreshold
		// Cycle over the symbols, incrementing positive entries until the counter reaches 0.
		for  := uint32(0);  > 0;  = ( + 1) % (uint32(.symbolLen)) {
			if .norm[] > 0 {
				--
				.norm[]++
			}
		}
		return nil
	}

	var (
		 = 62 - uint64()
		      = uint64((1 << ( - 1)) - 1)
		    = (((1 << ) * uint64()) + ) / uint64() // scale on remaining
		 = 
	)
	// Final pass: give every placeholder entry a weight.
	for ,  := range .count[:.symbolLen] {
		if .norm[] ==  {
			var (
				    =  + uint64()*
				 = uint32( >> )
				   = uint32( >> )
				 =  - 
			)
			if  < 1 {
				return errors.New("weight < 1")
			}
			.norm[] = int16()
			 = 
		}
	}
	return nil
}

// validateNorm validates the normalized histogram table.
// It sums norm[:symbolLen], with negative entries contributing their
// magnitude, and returns an error if the sum differs from
// 1 << actualTableLog or if any entry of count[symbolLen:] is non-zero.
// When an error is returned, a deferred function first prints the selected
// table log, the symbol length and the norm table to standard output with
// fmt.Printf.
func ( *Scratch) () ( error) {
	var  int
	for ,  := range .norm[:.symbolLen] {
		if  >= 0 {
			 += int()
		} else {
			 -= int()
		}
	}
	// Diagnostic dump; does nothing when the named result is nil.
	defer func() {
		if  == nil {
			return
		}
		fmt.Printf("selected TableLog: %d, Symbol length: %d\n", .actualTableLog, .symbolLen)
		for ,  := range .norm[:.symbolLen] {
			fmt.Printf("%3d: %5d -> %4d \n", , .count[], )
		}
	}()
	if  != (1 << .actualTableLog) {
		return fmt.Errorf("warning: Total == %d != %d", , 1<<.actualTableLog)
	}
	for ,  := range .count[.symbolLen:] {
		if  != 0 {
			return fmt.Errorf("warning: Found symbol out of range, %d after cut", )
		}
	}
	return nil
}