// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package compress

import (
	
	

	
	
)

type (
	// zstdCodec is the stateless zstd implementation that init registers
	// under Codecs.Zstd.
	zstdCodec struct{}

	// zstdcloser is what zstdCodec.NewReader hands out as an io.ReadCloser:
	// reading is delegated to the embedded *zstd.Decoder, and closing goes
	// through zstdcloser's own error-returning Close method.
	zstdcloser struct {
		*zstd.Decoder
	}
)

// Process-wide zstd encoder and decoder used by Encode and Decode. Each is
// built lazily on first use and then shared. The construction errors are
// retained next to them so every caller observes a failed construction
// instead of silently receiving a nil encoder/decoder.
var (
	enc         *zstd.Encoder
	dec         *zstd.Decoder
	encErr      error
	decErr      error
	initEncoder sync.Once
	initDecoder sync.Once
)

// getencoder returns the shared EncodeAll-style encoder (no output writer,
// WithZeroFrames(true), no explicit level option), creating it on the first
// call. It panics with the construction error if the encoder could not be
// built, matching Decode's panic-on-error policy since callers have no error
// return to propagate it through.
func getencoder() *zstd.Encoder {
	initEncoder.Do(func() {
		enc, encErr = zstd.NewWriter(nil, zstd.WithZeroFrames(true))
	})
	if encErr != nil {
		panic(encErr)
	}
	return enc
}

// getdecoder returns the shared DecodeAll-style decoder (no input reader,
// default options), creating it on the first call. It panics with the
// construction error if the decoder could not be built.
func getdecoder() *zstd.Decoder {
	initDecoder.Do(func() {
		dec, decErr = zstd.NewReader(nil)
	})
	if decErr != nil {
		panic(decErr)
	}
	return dec
}

func (zstdCodec) (,  []byte) []byte {
	,  := getdecoder().DecodeAll(, [:0])
	if  != nil {
		panic()
	}
	return 
}

func ( *zstdcloser) () error {
	.Decoder.Close()
	return nil
}

func (zstdCodec) ( io.Reader) io.ReadCloser {
	,  := zstd.NewReader()
	return &zstdcloser{}
}

func (zstdCodec) ( io.Writer) io.WriteCloser {
	,  := zstd.NewWriter()
	return 
}

func (zstdCodec) ( io.Writer,  int) (io.WriteCloser, error) {
	var  zstd.EncoderLevel
	if  == DefaultCompressionLevel {
		 = zstd.SpeedDefault
	} else {
		 = zstd.EncoderLevelFromZstd()
	}
	return zstd.NewWriter(, zstd.WithEncoderLevel())
}

func ( zstdCodec) (,  []byte) []byte {
	return getencoder().EncodeAll(, [:0])
}

func ( zstdCodec) (,  []byte,  int) []byte {
	 := zstd.EncoderLevelFromZstd()
	if  == DefaultCompressionLevel {
		 = zstd.SpeedDefault
	}
	,  := zstd.NewWriter(nil, zstd.WithZeroFrames(true), zstd.WithEncoderLevel())
	return .EncodeAll(, [:0])
}

// from zstd.h, ZSTD_COMPRESSBOUND
func (zstdCodec) ( int64) int64 {
	debug.Assert( > 0, "len for zstd CompressBound should be > 0")
	 := ((128 << 10) - ) >> 11
	if  >= (128 << 10) {
		 = 0
	}
	return  + ( >> 8) + 
}

func init() {
	RegisterCodec(Codecs.Zstd, zstdCodec{})
}