package wal

import (
	"bytes"
	"container/heap"
	"context"
	"fmt"
	"math"
	"os"
	"sync"
	"time"

	// Third-party module paths below are inferred from how the packages are
	// used in this file; exact versions/forks may differ in go.mod.
	"github.com/apache/arrow/go/v14/arrow"
	"github.com/apache/arrow/go/v14/arrow/ipc"
	"github.com/go-kit/log"
	"github.com/go-kit/log/level"
	"github.com/polarsignals/wal"
	"github.com/polarsignals/wal/types"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"

	walpb "github.com/polarsignals/frostdb/gen/proto/go/frostdb/wal/v1alpha1"
)

type ReplayHandlerFunc func(tx uint64, record *walpb.Record) error
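
// exampleReplayHandler is an illustrative sketch, not part of the original
// API: a minimal ReplayHandlerFunc that only inspects write entries. A real
// handler would typically rebuild in-memory table state from each record.
func exampleReplayHandler(tx uint64, record *walpb.Record) error {
	if record == nil || record.Entry == nil {
		return nil
	}
	if write, ok := record.Entry.EntryType.(*walpb.Entry_Write_); ok {
		// write.Write.TableName names the table; write.Write.Data holds the
		// serialized payload (Arrow IPC when write.Write.Arrow is true).
		_ = write.Write.TableName
	}
	return nil
}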

type NopWAL struct{}

func (w *NopWAL) Close() error {
	return nil
}

func (w *NopWAL) Log(tx uint64, record *walpb.Record) error {
	return nil
}

func (w *NopWAL) Replay(tx uint64, handler ReplayHandlerFunc) error {
	return nil
}

func (w *NopWAL) LogRecord(tx uint64, table string, record arrow.Record) error {
	return nil
}

func (w *NopWAL) Truncate(tx uint64) error {
	return nil
}

func (w *NopWAL) Reset(nextTx uint64) error {
	return nil
}

func (w *NopWAL) FirstIndex() (uint64, error) {
	return 0, nil
}

func (w *NopWAL) LastIndex() (uint64, error) {
	return 0, nil
}

type Metrics struct {
	FailedLogs            prometheus.Counter
	LastTruncationAt      prometheus.Gauge
	WalRepairs            prometheus.Counter
	WalRepairsLostRecords prometheus.Counter
	WalCloseTimeouts      prometheus.Counter
	WalQueueSize          prometheus.Gauge
}

func newMetrics(reg prometheus.Registerer) *Metrics {
	return &Metrics{
		FailedLogs: promauto.With(reg).NewCounter(prometheus.CounterOpts{
			Name: "failed_logs_total",
			Help: "Number of failed WAL logs",
		}),
		LastTruncationAt: promauto.With(reg).NewGauge(prometheus.GaugeOpts{
			Name: "last_truncation_at",
			Help: "The last transaction the WAL was truncated to",
		}),
		WalRepairs: promauto.With(reg).NewCounter(prometheus.CounterOpts{
			Name: "repairs_total",
			Help: "The number of times the WAL had to be repaired (truncated) due to corrupt records",
		}),
		WalRepairsLostRecords: promauto.With(reg).NewCounter(prometheus.CounterOpts{
			Name: "repairs_lost_records_total",
			Help: "The number of WAL records lost due to WAL repairs (truncations)",
		}),
		WalCloseTimeouts: promauto.With(reg).NewCounter(prometheus.CounterOpts{
			Name: "close_timeouts_total",
			Help: "The number of times the WAL failed to close due to a timeout",
		}),
		WalQueueSize: promauto.With(reg).NewGauge(prometheus.GaugeOpts{
			Name: "queue_size",
			Help: "The number of unprocessed requests in the WAL queue",
		}),
	}
}

const (
	dirPerms           = os.FileMode(0o750)
	progressLogTimeout = 10 * time.Second
)

type FileWAL struct {
	logger log.Logger
	path   string
	log    wal.LogStore

	metrics      *Metrics
	storeMetrics *wal.Metrics

	logRequestCh   chan *logRequest
	logRequestPool *sync.Pool
	arrowBufPool   *sync.Pool
	protected      struct {
		sync.Mutex
		queue logRequestQueue
		// truncateTx is set when the caller wishes to perform a truncation. The
		// WAL will keep on logging records up to and including this txn and
		// then perform a truncation. If another truncate call occurs in the
		// meantime, the highest txn will be used.
		truncateTx uint64
		// nextTx is the next expected txn. The FileWAL will only log a record
		// with this txn.
		nextTx uint64
	}

	// scratch memory reused to reduce allocations.
	scratch struct {
		walBatch []types.LogEntry
		reqBatch []*logRequest
	}

	// segmentSize indicates what the underlying WAL segment size is. This helps
	// the run goroutine size batches more or less appropriately.
	segmentSize int
	// lastBatchWrite is used to determine when to force a close of the WAL.
	lastBatchWrite time.Time

	cancel       func()
	shutdownCh   chan struct{}
	closeTimeout time.Duration

	newLogStoreWrapper func(wal.LogStore) wal.LogStore
	ticker             Ticker
	testingDroppedLogs func([]types.LogEntry)
}

type logRequest struct {
	tx   uint64
	data []byte
}

// logRequestQueue is a min-heap-based priority queue that orders log requests
// by transaction id so they can be written to the WAL in txn order.
type logRequestQueue []*logRequest

func (q logRequestQueue) Len() int           { return len(q) }
func (q logRequestQueue) Less(i, j int) bool { return q[i].tx < q[j].tx }
func (q logRequestQueue) Swap(i, j int)      { q[i], q[j] = q[j], q[i] }

func (q *logRequestQueue) Push(x any) {
	// Push and Pop use pointer receivers because they modify the slice's length,
	// not just its contents.
	*q = append(*q, x.(*logRequest))
}

func (q *logRequestQueue) Pop() any {
	old := *q
	n := len(old)
	el := old[n-1]
	// Remove this reference to a logRequest since the GC considers the popped
	// element still accessible otherwise. Since these are sync pooled, we want
	// to defer object lifetime management to the pool without interfering.
	old[n-1] = nil
	*q = old[0 : n-1]
	return el
}
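
// exampleQueueOrdering is an illustrative sketch, not part of the original
// API: it shows that, via container/heap, pops from a logRequestQueue always
// yield the lowest outstanding txn regardless of push order, which is what
// lets the WAL persist records strictly in transaction order.
func exampleQueueOrdering() []uint64 {
	q := logRequestQueue{}
	heap.Push(&q, &logRequest{tx: 3})
	heap.Push(&q, &logRequest{tx: 1})
	heap.Push(&q, &logRequest{tx: 2})

	order := make([]uint64, 0, q.Len())
	for q.Len() > 0 {
		order = append(order, heap.Pop(&q).(*logRequest).tx)
	}
	return order // 1, 2, 3
}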

type Option func(*FileWAL)

// WithTestingLogStoreWrapper wraps the underlying log store, e.g. to inject
// failures in tests.
func WithTestingLogStoreWrapper(newLogStoreWrapper func(wal.LogStore) wal.LogStore) Option {
	return func(w *FileWAL) {
		w.newLogStoreWrapper = newLogStoreWrapper
	}
}

// WithMetrics sets the WAL metrics to use instead of creating unregistered
// defaults.
func WithMetrics(metrics *Metrics) Option {
	return func(w *FileWAL) {
		w.metrics = metrics
	}
}

// WithStoreMetrics sets the metrics of the underlying log store.
func WithStoreMetrics(metrics *wal.Metrics) Option {
	return func(w *FileWAL) {
		w.storeMetrics = metrics
	}
}

type Ticker interface {
	C() <-chan time.Time
	Stop()
}

type realTicker struct {
	*time.Ticker
}

func (t realTicker) C() <-chan time.Time {
	return t.Ticker.C
}

// WithTestingLoopTicker allows the caller to force processing of pending WAL
// entries by providing a custom ticker implementation.
func WithTestingLoopTicker(t Ticker) Option {
	return func(w *FileWAL) {
		w.ticker = t
	}
}
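
// manualTicker is an illustrative sketch, not part of the original API: a
// Ticker implementation a test could pass to WithTestingLoopTicker to drive
// the run loop deterministically instead of relying on wall-clock ticks.
type manualTicker struct {
	ch chan time.Time
}

func newManualTicker() manualTicker { return manualTicker{ch: make(chan time.Time)} }

func (t manualTicker) C() <-chan time.Time { return t.ch }

func (t manualTicker) Stop() {}

// tick forces one iteration of the WAL's processing loop.
func (t manualTicker) tick() { t.ch <- time.Now() }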

// WithTestingCallbackWithDroppedLogsOnClose registers a callback that is
// invoked when the WAL times out on close, with all the entries that could
// not be written.
func WithTestingCallbackWithDroppedLogsOnClose(cb func([]types.LogEntry)) Option {
	return func(w *FileWAL) {
		w.testingDroppedLogs = cb
	}
}

func Open(
	logger log.Logger,
	path string,
	options ...Option,
) (*FileWAL, error) {
	if err := os.MkdirAll(path, dirPerms); err != nil {
		return nil, err
	}

	segmentSize := wal.DefaultSegmentSize
	w := &FileWAL{
		logger:       logger,
		path:         path,
		logRequestCh: make(chan *logRequest),
		logRequestPool: &sync.Pool{
			New: func() any {
				return &logRequest{
					data: make([]byte, 1024),
				}
			},
		},
		arrowBufPool: &sync.Pool{
			New: func() any {
				return &bytes.Buffer{}
			},
		},
		closeTimeout: 1 * time.Second,
		segmentSize:  segmentSize,
		shutdownCh:   make(chan struct{}),
	}

	for _, opt := range options {
		opt(w)
	}

	logStore, err := wal.Open(path, wal.WithLogger(logger), wal.WithMetrics(w.storeMetrics), wal.WithSegmentSize(segmentSize))
	if err != nil {
		return nil, err
	}

	lastIndex, err := logStore.LastIndex()
	if err != nil {
		return nil, err
	}
	w.protected.nextTx = lastIndex + 1

	if w.newLogStoreWrapper != nil {
		w.log = w.newLogStoreWrapper(logStore)
	} else {
		w.log = logStore
	}
	if w.metrics == nil {
		w.metrics = newMetrics(prometheus.NewRegistry())
	}

	w.scratch.walBatch = make([]types.LogEntry, 0, 64)
	w.scratch.reqBatch = make([]*logRequest, 0, 64)

	return w, nil
}

func (w *FileWAL) run(ctx context.Context) {
	const pollInterval = 50 * time.Millisecond
	if w.ticker == nil {
		w.ticker = realTicker{Ticker: time.NewTicker(pollInterval)}
	}
	defer w.ticker.Stop()
	// lastQueueSize is only used on shutdown to reduce debug logging verbosity.
	lastQueueSize := 0
	w.lastBatchWrite = time.Now()

	for {
		select {
		case <-ctx.Done():
			// Need to drain the queue before we can shut down, so
			// proactively try to process entries.
			w.process()

			w.protected.Lock()
			queueSize := w.protected.queue.Len()
			w.protected.Unlock()
			if queueSize > 0 {
				// Force the WAL to close after a timeout.
				if w.closeTimeout > 0 && time.Since(w.lastBatchWrite) > w.closeTimeout {
					w.metrics.WalCloseTimeouts.Inc()
					level.Error(w.logger).Log(
						"msg", "WAL timed out attempting to close",
					)
					if w.testingDroppedLogs != nil {
						dropped := make([]types.LogEntry, 0, queueSize)
						w.protected.Lock()
						for _, r := range w.protected.queue {
							dropped = append(dropped, types.LogEntry{Index: r.tx, Data: r.data})
						}
						w.protected.Unlock()
						w.testingDroppedLogs(dropped)
					}
					return
				}

				if queueSize == lastQueueSize {
					// No progress made.
					time.Sleep(pollInterval)
					continue
				}

				lastQueueSize = queueSize
				w.protected.Lock()
				minTx := w.protected.queue[0].tx
				w.protected.Unlock()
				lastIndex, err := w.log.LastIndex()
				logArgs := []any{
					"msg", "WAL received shutdown request; waiting for log request queue to drain",
					"queueSize", queueSize,
					"minTx", minTx,
					"lastIndex", lastIndex,
				}
				if err != nil {
					logArgs = append(logArgs, "lastIndexErr", err)
				}
				level.Debug(w.logger).Log(logArgs...)
				continue
			}
			level.Debug(w.logger).Log("msg", "WAL shut down")
			return
		case <-w.ticker.C():
			w.process()
		}
	}
}

func (w *FileWAL) process() {
	w.scratch.reqBatch = w.scratch.reqBatch[:0]

	w.protected.Lock()
	batchSize := 0
	for w.protected.queue.Len() > 0 && batchSize < w.segmentSize {
		if minTx := w.protected.queue[0].tx; minTx != w.protected.nextTx {
			if minTx < w.protected.nextTx {
				// The next entry must be dropped, otherwise progress
				// will never be made. Log a warning given this could
				// lead to missing data.
				level.Warn(w.logger).Log(
					"msg", "WAL cannot log a txn id that has already been seen; dropping entry",
					"expected", w.protected.nextTx,
					"found", minTx,
				)
				w.logRequestPool.Put(heap.Pop(&w.protected.queue))
				w.metrics.WalQueueSize.Sub(1)
				// Keep going since there might be other transactions
				// below this one.
				continue
			}
			if since := time.Since(w.lastBatchWrite); since > progressLogTimeout {
				level.Info(w.logger).Log(
					"msg", "wal has not made progress",
					"since", since,
					"next_expected_tx", w.protected.nextTx,
					"min_tx", minTx,
				)
			}
			// Next expected tx has not yet been seen.
			break
		}
		r := heap.Pop(&w.protected.queue).(*logRequest)
		w.metrics.WalQueueSize.Sub(1)
		w.scratch.reqBatch = append(w.scratch.reqBatch, r)
		batchSize += len(r.data)
		w.protected.nextTx++
	}
	// truncateTx will be non-zero if we either are about to log a
	// record with a txn past the txn to truncate, or we have logged one
	// in the past.
	truncateTx := uint64(0)
	if w.protected.truncateTx != 0 {
		truncateTx = w.protected.truncateTx
		w.protected.truncateTx = 0
	}
	w.protected.Unlock()
	if len(w.scratch.reqBatch) == 0 && truncateTx == 0 {
		// No records to log and no truncation to perform.
		return
	}

	w.scratch.walBatch = w.scratch.walBatch[:0]
	for _, r := range w.scratch.reqBatch {
		w.scratch.walBatch = append(w.scratch.walBatch, types.LogEntry{
			Index: r.tx,
			// No copy is needed here since the log request is only
			// released once these bytes are persisted.
			Data: r.data,
		})
	}

	if len(w.scratch.walBatch) > 0 {
		if err := w.log.StoreLogs(w.scratch.walBatch); err != nil {
			w.metrics.FailedLogs.Add(float64(len(w.scratch.reqBatch)))
			lastIndex, lastIndexErr := w.log.LastIndex()
			level.Error(w.logger).Log(
				"msg", "failed to write WAL batch",
				"err", err,
				"lastIndex", lastIndex,
				"lastIndexErr", lastIndexErr,
			)
		}
	}

	if truncateTx != 0 {
		w.metrics.LastTruncationAt.Set(float64(truncateTx))
		level.Debug(w.logger).Log("msg", "truncating WAL", "tx", truncateTx)
		if err := w.log.TruncateFront(truncateTx); err != nil {
			level.Error(w.logger).Log("msg", "failed to truncate WAL", "tx", truncateTx, "err", err)
		} else {
			w.protected.Lock()
			if truncateTx > w.protected.nextTx {
				// truncateTx is the new firstIndex of the WAL. If it is
				// greater than the next expected transaction, this was
				// a full WAL truncation/reset so both the first and
				// last index are now 0. The underlying WAL will allow a
				// record with any index to be written, however we only
				// want to allow the next index to be logged.
				w.protected.nextTx = truncateTx
				// Remove any records that have not yet been written and
				// are now below the nextTx.
				for w.protected.queue.Len() > 0 {
					if minTx := w.protected.queue[0].tx; minTx >= w.protected.nextTx {
						break
					}
					w.logRequestPool.Put(heap.Pop(&w.protected.queue))
					w.metrics.WalQueueSize.Sub(1)
				}
			}
			w.protected.Unlock()
			level.Debug(w.logger).Log("msg", "truncated WAL", "tx", truncateTx)
		}
	}

	// Remove references to logRequests since the GC considers the
	// popped elements still accessible otherwise. Since these are sync
	// pooled, we want to defer object lifetime management to the pool
	// without interfering.
	for i := range w.scratch.walBatch {
		w.scratch.walBatch[i].Data = nil
	}
	for i, r := range w.scratch.reqBatch {
		w.scratch.reqBatch[i] = nil
		w.logRequestPool.Put(r)
	}

	w.lastBatchWrite = time.Now()
}

// Truncate queues a truncation of the WAL at the given tx. Note that the
// truncation will be performed asynchronously. A nil error does not indicate
// a successful truncation.
func (w *FileWAL) Truncate(tx uint64) error {
	w.protected.Lock()
	defer w.protected.Unlock()
	if tx > w.protected.truncateTx {
		w.protected.truncateTx = tx
	}
	return nil
}
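
// exampleTruncateUsage is an illustrative sketch, not part of the original
// API: since Truncate only records the requested txn, the actual truncation
// is performed later by the run loop, and a caller that needs to observe it
// must poll FirstIndex rather than rely on the returned error.
func exampleTruncateUsage(w *FileWAL, tx uint64) (uint64, error) {
	if err := w.Truncate(tx); err != nil {
		return 0, err
	}
	// FirstIndex only reflects the truncation after process() has handled it.
	return w.FirstIndex()
}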

func (w *FileWAL) Reset(nextTx uint64) error {
	w.protected.Lock()
	defer w.protected.Unlock()
	// Drain any pending records.
	for w.protected.queue.Len() > 0 {
		_ = heap.Pop(&w.protected.queue)
	}
	// Set the next expected transaction.
	w.protected.nextTx = nextTx
	// This truncation will fully reset the underlying WAL. Any index can be
	// logged, but setting the nextTx above will ensure that only a record with
	// a matching txn will be accepted as the first record.
	return w.log.TruncateFront(math.MaxUint64)
}

func (w *FileWAL) Close() error {
	if w.cancel == nil { // WAL was never started
		return nil
	}
	level.Debug(w.logger).Log("msg", "WAL received shutdown request; canceling run loop")
	w.cancel()
	<-w.shutdownCh
	return w.log.Close()
}

func (w *FileWAL) Log(tx uint64, record *walpb.Record) error {
	r := w.logRequestPool.Get().(*logRequest)
	r.tx = tx
	size := record.SizeVT()
	if cap(r.data) < size {
		r.data = make([]byte, size)
	}
	r.data = r.data[:size]
	_, err := record.MarshalToSizedBufferVT(r.data)
	if err != nil {
		return err
	}

	w.protected.Lock()
	heap.Push(&w.protected.queue, r)
	w.metrics.WalQueueSize.Add(1)
	w.protected.Unlock()

	return nil
}

func (w *FileWAL) getArrowBuf() *bytes.Buffer {
	return w.arrowBufPool.Get().(*bytes.Buffer)
}

func (w *FileWAL) putArrowBuf(b *bytes.Buffer) {
	b.Reset()
	w.arrowBufPool.Put(b)
}

func (w *FileWAL) writeRecord(buf *bytes.Buffer, record arrow.Record) error {
	writer := ipc.NewWriter(
		buf,
		ipc.WithSchema(record.Schema()),
	)
	defer writer.Close()

	return writer.Write(record)
}
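
// exampleDecodeWrite is an illustrative sketch, not part of the original API:
// the inverse of writeRecord, decoding the Arrow IPC stream stored in an
// Entry_Write back into records during replay. It assumes the same arrow/ipc
// package imported above.
func exampleDecodeWrite(data []byte) ([]arrow.Record, error) {
	rdr, err := ipc.NewReader(bytes.NewReader(data))
	if err != nil {
		return nil, err
	}
	defer rdr.Release()

	var records []arrow.Record
	for rdr.Next() {
		rec := rdr.Record()
		// Retain the record so it outlives the reader and the next call to Next.
		rec.Retain()
		records = append(records, rec)
	}
	return records, rdr.Err()
}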

func (w *FileWAL) LogRecord(tx uint64, table string, record arrow.Record) error {
	w.protected.Lock()
	nextTx := w.protected.nextTx
	w.protected.Unlock()
	if tx < nextTx {
		// Transaction should not be logged. This could happen if a truncation
		// was issued at the same time as this record was being logged.
		level.Warn(w.logger).Log(
			"msg", "attempted to log txn below next expected txn",
			"tx", tx,
			"next_tx", nextTx,
		)
		return nil
	}
	buf := w.getArrowBuf()
	defer w.putArrowBuf(buf)
	if err := w.writeRecord(buf, record); err != nil {
		return err
	}

	walrecord := &walpb.Record{
		Entry: &walpb.Entry{
			EntryType: &walpb.Entry_Write_{
				Write: &walpb.Entry_Write{
					Data:      buf.Bytes(),
					TableName: table,
					Arrow:     true,
				},
			},
		},
	}

	r := w.logRequestPool.Get().(*logRequest)
	r.tx = tx
	size := walrecord.SizeVT()
	if cap(r.data) < size {
		r.data = make([]byte, size)
	}
	r.data = r.data[:size]
	_, err := walrecord.MarshalToSizedBufferVT(r.data)
	if err != nil {
		return err
	}

	w.protected.Lock()
	heap.Push(&w.protected.queue, r)
	w.metrics.WalQueueSize.Add(1)
	w.protected.Unlock()

	return nil
}

func (w *FileWAL) FirstIndex() (uint64, error) {
	return w.log.FirstIndex()
}

func (w *FileWAL) LastIndex() (uint64, error) {
	return w.log.LastIndex()
}

func (w *FileWAL) Replay(tx uint64, handler ReplayHandlerFunc) (err error) {
	if handler == nil { // no handler provided
		return nil
	}

	firstIndex, err := w.log.FirstIndex()
	if err != nil {
		return fmt.Errorf("read first index: %w", err)
	}
	if tx == 0 || tx < firstIndex {
		tx = firstIndex
	}

	lastIndex, err := w.log.LastIndex()
	if err != nil {
		return fmt.Errorf("read last index: %w", err)
	}

	// FirstIndex and LastIndex return zero when there are no WAL files.
	if tx == 0 || lastIndex == 0 {
		return nil
	}

	level.Debug(w.logger).Log("msg", "replaying WAL", "first_index", tx, "last_index", lastIndex)

	defer func() {
		// Recover from a panic while reading a transaction and truncate the
		// WAL to the last valid transaction.
		if r := recover(); r != nil {
			level.Error(w.logger).Log(
				"msg", "replaying WAL failed",
				"path", w.path,
				"first_index", firstIndex,
				"last_index", lastIndex,
				"offending_index", tx,
				"err", r,
			)
			if err = w.log.TruncateBack(tx - 1); err != nil {
				return
			}
			w.metrics.WalRepairs.Inc()
			w.metrics.WalRepairsLostRecords.Add(float64((lastIndex - tx) + 1))
		}
	}()

	var entry types.LogEntry
	for ; tx <= lastIndex; tx++ {
		level.Debug(w.logger).Log("msg", "replaying WAL record", "tx", tx)
		if err := w.log.GetLog(tx, &entry); err != nil {
			// Panic since this is most likely a corruption issue. The recover
			// call above will truncate the WAL to the last valid transaction.
			panic(fmt.Sprintf("read index %d: %v", tx, err))
		}

		record := &walpb.Record{}
		if err := record.UnmarshalVT(entry.Data); err != nil {
			// Panic since this is most likely a corruption issue. The recover
			// call above will truncate the WAL to the last valid transaction.
			panic(fmt.Sprintf("unmarshal WAL record: %v", err))
		}

		if err := handler(tx, record); err != nil {
			return fmt.Errorf("call replay handler: %w", err)
		}
	}

	return nil
}

func (w *FileWAL) RunAsync() {
	ctx, cancel := context.WithCancel(context.Background())
	w.cancel = cancel
	go func() {
		w.run(ctx)
		close(w.shutdownCh)
	}()
}
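
// exampleUsage is an illustrative sketch, not part of the original API,
// outlining the intended lifecycle: open the WAL, replay persisted records,
// start the background run loop, log new records, then close. Error handling
// is minimal and the txn value is a placeholder for the caller's next txn.
func exampleUsage(logger log.Logger, path string) error {
	w, err := Open(logger, path)
	if err != nil {
		return err
	}
	if err := w.Replay(0, func(tx uint64, record *walpb.Record) error {
		// Rebuild in-memory state from each replayed record here.
		return nil
	}); err != nil {
		return err
	}
	// Start the goroutine that batches queued requests and persists them.
	w.RunAsync()
	defer w.Close()

	// Records must be logged with the next expected txn to be persisted.
	return w.Log(1, &walpb.Record{})
}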