package frostdb

import (
	"bytes"
	"context"
	"encoding/binary"
	"fmt"
	"hash"
	"hash/crc32"
	"io"
	"os"
	"path/filepath"
	"strconv"
	"time"

	"github.com/apache/arrow/go/v14/arrow/ipc"
	"github.com/apache/arrow/go/v14/arrow/util"

	"github.com/go-kit/log/level"
	"github.com/oklog/ulid/v2"
	"google.golang.org/protobuf/proto"

	"github.com/polarsignals/frostdb/dynparquet"
	snapshotpb "github.com/polarsignals/frostdb/gen/proto/go/frostdb/snapshot/v1alpha1"
	tablepb "github.com/polarsignals/frostdb/gen/proto/go/frostdb/table/v1alpha1"
	walpb "github.com/polarsignals/frostdb/gen/proto/go/frostdb/wal/v1alpha1"
	"github.com/polarsignals/frostdb/index"
	"github.com/polarsignals/frostdb/parts"
)

// This file implements writing and reading database snapshots from disk.
// The snapshot format at the time of writing is as follows:
// 4-byte magic "FDBS"
// <Table 1 Granule 1 Part 1>
// <Table 2 Granule 1 Part 1>
// <Table 2 Granule 1 Part 2>
// <Table 2 Granule 2 Part 1>
// ...
// Footer/File Metadata
// 4-byte length in bytes of footer/file metadata (little endian)
// 4-byte version number (little endian)
// 4-byte checksum (little endian)
// 4-byte magic "FDBS"
//
// Readers should start reading a snapshot by first verifying that the magic
// bytes are correct, followed by the version number to ensure that the snapshot
// was encoded using a version the reader supports. A version bump could, for
// example, add compression to the data bytes of the file.
// Refer to minReadVersion/snapshotVersion for more details.
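//
// For reference, the trailer of a version-1 snapshot file is laid out as
// follows (a sketch, mirroring what readFooter below does):
//
//	buf := make([]byte, 16)
//	_, _ = r.ReadAt(buf, size-16)
//	footerSize := binary.LittleEndian.Uint32(buf[0:4])  // length of the serialized footer metadata
//	version := binary.LittleEndian.Uint32(buf[4:8])     // snapshotVersion at write time
//	checksum := binary.LittleEndian.Uint32(buf[8:12])   // CRC-32 (Castagnoli) of all but the last 8 bytes
//	magic := buf[12:16]                                 // "FDBS"
//
// The serialized footer metadata itself is the footerSize bytes immediately
// preceding this trailer.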

const (
	snapshotMagic = "FDBS"
	dirPerms      = os.FileMode(0o755)
	filePerms     = os.FileMode(0o640)
	// When bumping the version number, please add a comment indicating the
	// reason for the bump. The version should only be bumped if the new
	// version introduces backwards-incompatible changes. Note that protobuf
	// changes are backwards-compatible; this version number is only necessary
	// for the non-proto format (e.g. if compression is introduced).
	// Version 1: Initial snapshot version with checksum and version number.
	snapshotVersion = 1
	// minReadVersion is bumped when deprecating older versions. For example,
	// a reader of the new version can choose to still support reading older
	// versions, but will bump this constant to the minimum version it claims
	// to support.
	minReadVersion = snapshotVersion
)

// snapshotFileName returns a 20-byte textual representation of a snapshot file
// name at the given txn, used for lexical ordering.
func snapshotFileName(tx uint64) string {
	return fmt.Sprintf("%020d.fdbs", tx)
}

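// getTxFromSnapshotFileName parses the transaction id encoded in the first 20
// characters of a snapshot file or directory name (as produced by
// snapshotFileName and SnapshotDir).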
func getTxFromSnapshotFileName(fileName string) (uint64, error) {
	parsedTx, err := strconv.ParseUint(fileName[:20], 10, 64)
	if err != nil {
		return 0, err
	}
	return parsedTx, nil
}

// asyncSnapshot begins a new transaction and takes a snapshot of the
// database at that txn in a new goroutine. If another snapshot is already in
// progress, this is a no-op. When the snapshot goroutine successfully
// completes a snapshot, onSuccess is called.
func (db *DB) asyncSnapshot(ctx context.Context, onSuccess func()) {
	db.snapshot(ctx, true, onSuccess)
}

// Snapshot performs a database snapshot and writes it to the database snapshots
// directory, as is done by automatic snapshots.
func (db *DB) Snapshot(ctx context.Context) error {
	db.snapshot(ctx, false, func() {})
	return db.reclaimDiskSpace(ctx, nil)
}

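// snapshot takes a snapshot of the database at a new transaction. If async is
// true, the snapshot is written in a new goroutine, otherwise the call blocks
// until the snapshot has been written. A snapshot WAL record is appended
// before the snapshot itself so that the 1:1 relationship between txn ids and
// WAL record indexes is preserved even if writing the snapshot fails.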
func (db *DB) snapshot(ctx context.Context, async bool, onSuccess func()) {
	if !db.columnStore.enableWAL {
		return
	}
	if !db.snapshotInProgress.CompareAndSwap(false, true) {
		// Snapshot already in progress.
		level.Debug(db.logger).Log(
			"msg", "cannot start snapshot; snapshot already in progress",
		)
		return
	}

	tx, _, commit := db.begin()
	level.Debug(db.logger).Log(
		"msg", "starting a new snapshot",
		"tx", tx,
	)
	doSnapshot := func(writeSnapshot func(context.Context, io.Writer) error) {
		db.Wait(tx - 1) // Wait for all previous transactions to complete before taking a snapshot.
		start := time.Now()
		defer db.snapshotInProgress.Store(false)
		defer commit()
		if db.columnStore.enableWAL {
			// Appending a snapshot record to the WAL is necessary,
			// since the WAL expects a 1:1 relationship between txn ids
			// and record indexes. This is done before the actual snapshot so
			// that a failure to snapshot still appends a record to the WAL,
			// avoiding a WAL deadlock.
			if err := db.wal.Log(
				tx,
				&walpb.Record{
					Entry: &walpb.Entry{
						EntryType: &walpb.Entry_Snapshot_{Snapshot: &walpb.Entry_Snapshot{Tx: tx}},
					},
				},
			); err != nil {
				level.Error(db.logger).Log(
					"msg", "failed to append snapshot record to WAL", "err", err,
				)
				return
			}
		}
		if err := db.snapshotAtTX(ctx, tx, writeSnapshot); err != nil {
			level.Error(db.logger).Log(
				"msg", "failed to snapshot database", "err", err,
			)
			return
		}
		level.Debug(db.logger).Log(
			"msg", "snapshot complete",
			"tx", tx,
			"duration", time.Since(start),
		)
		onSuccess()
	}

	if async {
		go doSnapshot(db.snapshotWriter(tx))
	} else {
		doSnapshot(db.offlineSnapshotWriter(tx))
	}
}

// snapshotAtTX takes a snapshot of the state of the database at transaction tx.
func (db *DB) snapshotAtTX(ctx context.Context, tx uint64, writeSnapshot func(context.Context, io.Writer) error) error {
	var fileSize int64
	start := time.Now()
	if err := func() error {
		dir := SnapshotDir(db, tx)
		fileName := filepath.Join(dir, snapshotFileName(tx))
		_, err := os.Stat(fileName)
		if err == nil { // Snapshot file already exists.
			if db.validateSnapshotTxn(ctx, tx) == nil {
				return nil // A valid snapshot already exists at tx; no need to re-snapshot.
			}

			// Snapshot exists but is invalid. Remove it.
			if err := os.RemoveAll(SnapshotDir(db, tx)); err != nil {
				return fmt.Errorf("failed to remove invalid snapshot %v: %w", fileName, err)
			}
		}
		if err := os.MkdirAll(dir, dirPerms); err != nil {
			return err
		}

		f, err := os.OpenFile(fileName, os.O_CREATE|os.O_RDWR|os.O_TRUNC, filePerms)
		if err != nil {
			return err
		}
		defer f.Close()

		if err := func() error {
			if err := writeSnapshot(ctx, f); err != nil {
				return err
			}
			if err := f.Sync(); err != nil {
				return err
			}
			info, err := f.Stat()
			if err != nil {
				return err
			}
			fileSize = info.Size()
			return nil
		}(); err != nil {
			err = fmt.Errorf("failed to write snapshot for tx %d: %w", tx, err)
			if removeErr := os.RemoveAll(dir); removeErr != nil {
				err = fmt.Errorf("%w: failed to remove snapshot directory: %v", err, removeErr)
			}
			return err
		}
		return nil
	}(); err != nil {
		db.metrics.snapshotsTotal.WithLabelValues("false").Inc()
		return err
	}
	db.metrics.snapshotsTotal.WithLabelValues("true").Inc()
	if fileSize > 0 {
		db.metrics.snapshotFileSizeBytes.Set(float64(fileSize))
	}
	db.metrics.snapshotDurationHistogram.Observe(time.Since(start).Seconds())
	// TODO(asubiotto): If snapshot file sizes become too large, investigate
	// adding compression.
	return nil
}

// loadLatestSnapshot loads the latest snapshot (i.e. the snapshot with the
// highest txn) from the snapshots dir into the database.
func (db *DB) loadLatestSnapshot(ctx context.Context) (uint64, error) {
	return db.loadLatestSnapshotFromDir(ctx, db.snapshotsDir())
}

func (db *DB) loadLatestSnapshotFromDir(ctx context.Context, dir string) (uint64, error) {
	var (
		lastErr  error
		loadedTx uint64
	)
	// No error should be returned from snapshotsDo.
	_ = db.snapshotsDo(ctx, dir, func(parsedTx uint64, entry os.DirEntry) (bool, error) {
		if err := func() error {
			f, err := os.Open(filepath.Join(dir, entry.Name(), snapshotFileName(parsedTx)))
			if err != nil {
				return err
			}
			defer f.Close()
			info, err := f.Stat()
			if err != nil {
				return err
			}
			tx, err := LoadSnapshot(ctx, db, parsedTx, f, info.Size(), filepath.Join(dir, entry.Name()), false)
			if err != nil {
				return err
			}
			// Success.
			loadedTx = tx
			return nil
		}(); err != nil {
			err = fmt.Errorf("unable to read snapshot file %s: %w", entry.Name(), err)
			level.Debug(db.logger).Log(
				"msg", "error reading snapshot",
				"error", err,
			)
			lastErr = err
			return true, nil
		}
		return false, nil
	})
	if loadedTx != 0 {
		// Successfully loaded a snapshot.
		return loadedTx, nil
	}

	errMsg := "no valid snapshots found"
	if lastErr != nil {
		return 0, fmt.Errorf("%s: lastErr: %w", errMsg, lastErr)
	}
	return 0, fmt.Errorf("%s", errMsg)
}

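// LoadSnapshot loads the snapshot of the given size from the given io.ReaderAt
// into the database and resets the database to the snapshot's transaction. If
// truncateWAL is true, the database's WAL is also passed to resetToTxn.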
func LoadSnapshot(ctx context.Context, db *DB, tx uint64, r io.ReaderAt, size int64, dir string, truncateWAL bool) (uint64, error) {
	if err := loadSnapshot(ctx, db, r, size, dir); err != nil {
		return 0, err
	}
	watermark := tx
	var wal WAL
	if truncateWAL {
		wal = db.wal
	}
	db.resetToTxn(watermark, wal)
	return watermark, nil
}

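// validateSnapshotTxn verifies that a snapshot taken at the given transaction
// exists on disk and has a valid footer (including its checksum).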
func (db *DB) validateSnapshotTxn(ctx context.Context, tx uint64) error {
	dir := db.snapshotsDir()

	return db.snapshotsDo(ctx, dir, func(parsedTx uint64, entry os.DirEntry) (bool, error) {
		if parsedTx != tx { // We're only trying to validate a single tx.
			return true, nil
		}

		return false, func() error {
			f, err := os.Open(filepath.Join(dir, entry.Name(), snapshotFileName(parsedTx)))
			if err != nil {
				return err
			}
			defer f.Close()
			info, err := f.Stat()
			if err != nil {
				return err
			}
			// readFooter validates the checksum.
			if _, err := readFooter(f, info.Size()); err != nil {
				return err
			}
			return nil
		}()
	})
}

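// getLatestValidSnapshotTxn returns the highest transaction at which a valid
// snapshot (i.e. one with a readable, checksum-verified footer) exists on
// disk, or 0 if no valid snapshot is found.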
func (db *DB) getLatestValidSnapshotTxn(ctx context.Context) (uint64, error) {
	dir := db.snapshotsDir()
	latestValidTxn := uint64(0)
	// No error should be returned from snapshotsDo.
	_ = db.snapshotsDo(ctx, dir, func(parsedTx uint64, entry os.DirEntry) (bool, error) {
		if err := func() error {
			f, err := os.Open(filepath.Join(dir, entry.Name(), snapshotFileName(parsedTx)))
			if err != nil {
				return err
			}
			defer f.Close()
			info, err := f.Stat()
			if err != nil {
				return err
			}
			// readFooter validates the checksum.
			if _, err := readFooter(f, info.Size()); err != nil {
				return err
			}
			return nil
		}(); err != nil {
			level.Debug(db.logger).Log(
				"msg", "error reading snapshot",
				"error", err,
			)
			// Continue to the next snapshot.
			return true, nil
		}
		// Valid snapshot found.
		latestValidTxn = parsedTx
		return false, nil
	})
	return latestValidTxn, nil
}

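// offsetWriter wraps an io.Writer, tracking the number of bytes written so far
// as well as a running checksum of everything written through it.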
type offsetWriter struct {
	w               io.Writer
	runningChecksum hash.Hash32
	offset          int
}

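// newChecksumWriter returns the hash used to checksum snapshot files: CRC-32
// with the Castagnoli polynomial.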
func newChecksumWriter() hash.Hash32 {
	return crc32.New(crc32.MakeTable(crc32.Castagnoli))
}

func newOffsetWriter(w io.Writer) *offsetWriter {
	return &offsetWriter{
		w:               w,
		runningChecksum: newChecksumWriter(),
	}
}

func (w *offsetWriter) Write(p []byte) (int, error) {
	if n, err := w.runningChecksum.Write(p); err != nil {
		return n, fmt.Errorf("error writing checksum: %w", err)
	}
	n, err := w.w.Write(p)
	w.offset += n
	return n, err
}

func (w *offsetWriter) checksum() uint32 {
	return w.runningChecksum.Sum32()
}

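// snapshotWriter returns a function that writes a snapshot of the database at
// the given transaction to the provided io.Writer.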
func (db *DB) snapshotWriter(tx uint64) func(context.Context, io.Writer) error {
	return func(ctx context.Context, w io.Writer) error {
		return WriteSnapshot(ctx, tx, db, w)
	}
}

// offlineSnapshotWriter is used when a database is closing after all the tables have closed.
func (db *DB) offlineSnapshotWriter(tx uint64) func(context.Context, io.Writer) error {
	return func(ctx context.Context, w io.Writer) error {
		return WriteSnapshot(ctx, tx, db, w)
	}
}

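// WriteSnapshot writes a snapshot of the database at the given transaction to
// w in the format described at the top of this file: the magic bytes, the
// serialized parts of each table's active block, and finally the footer
// metadata followed by its length, the version, the checksum, and the trailing
// magic bytes.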
func WriteSnapshot(ctx context.Context, tx uint64, db *DB, w io.Writer) error {
	offW := newOffsetWriter(w)
	w = offW
	var tables []*Table
	db.mtx.RLock()
	for _, t := range db.tables {
		tables = append(tables, t)
	}
	db.mtx.RUnlock()

	if _, err := w.Write([]byte(snapshotMagic)); err != nil {
		return err
	}

	footer := &snapshotpb.FooterData{}
	for _, t := range tables {
		if err := func() error {
			// Obtain a write block to prevent racing with
			// compaction/persistence.
			block, done, err := t.ActiveWriteBlock()
			if err != nil {
				return err
			}
			defer done()
			blockUlid, err := block.ulid.MarshalBinary()
			if err != nil {
				return err
			}

			tableMeta := &snapshotpb.Table{
				Name:   t.name,
				Config: t.config.Load(),
				ActiveBlock: &snapshotpb.Table_TableBlock{
					Ulid:   blockUlid,
					Size:   block.Size(),
					MinTx:  block.minTx,
					PrevTx: block.prevTx,
				},
			}

			if err := block.Index().Snapshot(tx, func(p parts.Part) error {
				granuleMeta := &snapshotpb.Granule{}
				partMeta := &snapshotpb.Part{
					StartOffset:     int64(offW.offset),
					Tx:              p.TX(),
					CompactionLevel: uint64(p.CompactionLevel()),
				}
				if err := ctx.Err(); err != nil {
					return err
				}

				if record := p.Record(); record != nil {
					partMeta.Encoding = snapshotpb.Part_ENCODING_ARROW
				} else {
					partMeta.Encoding = snapshotpb.Part_ENCODING_PARQUET
				}

				if err := p.Write(w); err != nil {
					return err
				}

				partMeta.EndOffset = int64(offW.offset)
				granuleMeta.PartMetadata = append(granuleMeta.PartMetadata, partMeta)
				tableMeta.GranuleMetadata = append(tableMeta.GranuleMetadata, granuleMeta) // TODO: we have one part per granule now.
				return nil
			}, snapshotIndexDir(db, tx, t.name, block.ulid.String())); err != nil {
				return fmt.Errorf("failed to snapshot table %s index: %w", t.name, err)
			}

			footer.TableMetadata = append(footer.TableMetadata, tableMeta)
			return nil
		}(); err != nil {
			return err
		}
	}
	footerBytes, err := footer.MarshalVT()
	if err != nil {
		return err
	}
	// Write footer + size.
	footerBytes = binary.LittleEndian.AppendUint32(footerBytes, uint32(len(footerBytes)))
	if _, err := w.Write(footerBytes); err != nil {
		return err
	}
	if _, err := w.Write(binary.LittleEndian.AppendUint32(nil, snapshotVersion)); err != nil {
		return err
	}
	if _, err := w.Write(binary.LittleEndian.AppendUint32(nil, offW.checksum())); err != nil {
		return err
	}
	if _, err := w.Write([]byte(snapshotMagic)); err != nil {
		return err
	}
	return nil
}

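// readFooter reads and validates the footer of a snapshot file: it checks the
// leading and trailing magic bytes, verifies the checksum over all but the
// last 8 bytes of the file, checks that the version is supported, and
// unmarshals the footer metadata.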
func readFooter(r io.ReaderAt, size int64) (*snapshotpb.FooterData, error) {
	buffer := make([]byte, 16)
	if _, err := r.ReadAt(buffer[:4], 0); err != nil {
		return nil, err
	}
	if string(buffer[:4]) != snapshotMagic {
		return nil, fmt.Errorf("invalid snapshot magic: %q", buffer[:4])
	}
	if _, err := r.ReadAt(buffer, size-int64(len(buffer))); err != nil {
		return nil, err
	}
	if string(buffer[12:]) != snapshotMagic {
		return nil, fmt.Errorf("invalid snapshot magic: %q", buffer[12:])
	}

	// The checksum does not cover the last 8 bytes of the file, which are the
	// checksum itself followed by the trailing magic. Create a section reader
	// over all but the last 8 bytes to compute the checksum and validate it
	// against the stored checksum.
	checksum := binary.LittleEndian.Uint32(buffer[8:12])
	checksumWriter := newChecksumWriter()
	if _, err := io.Copy(checksumWriter, io.NewSectionReader(r, 0, size-8)); err != nil {
		return nil, fmt.Errorf("failed to compute checksum: %w", err)
	}
	if checksum != checksumWriter.Sum32() {
		return nil, fmt.Errorf(
			"snapshot file corrupt: invalid checksum: expected %x, got %x", checksum, checksumWriter.Sum32(),
		)
	}

	version := binary.LittleEndian.Uint32(buffer[4:8])
	if version > snapshotVersion {
		return nil, fmt.Errorf(
			"cannot read snapshot with version %d: max version supported: %d", version, snapshotVersion,
		)
	}
	if version < minReadVersion {
		return nil, fmt.Errorf(
			"cannot read snapshot with version %d: min version supported: %d", version, minReadVersion,
		)
	}

	footerSize := binary.LittleEndian.Uint32(buffer[:4])
	footerBytes := make([]byte, footerSize)
	if _, err := r.ReadAt(footerBytes, size-(int64(len(buffer))+int64(footerSize))); err != nil {
		return nil, err
	}
	footer := &snapshotpb.FooterData{}
	if err := footer.UnmarshalVT(footerBytes); err != nil {
		return nil, fmt.Errorf("could not unmarshal footer: %v", err)
	}
	return footer, nil
}

// loadSnapshot loads a snapshot from the given io.ReaderAt into the database.
// dir is the snapshot's directory, used to restore the tables' on-disk index
// files.
func loadSnapshot(ctx context.Context, db *DB, r io.ReaderAt, size int64, dir string) error {
	footer, err := readFooter(r, size)
	if err != nil {
		return err
	}

	for _, tableMeta := range footer.TableMetadata {
		if err := func() error {
			var schemaMsg proto.Message
			switch v := tableMeta.Config.Schema.(type) {
			case *tablepb.TableConfig_DeprecatedSchema:
				schemaMsg = v.DeprecatedSchema
			case *tablepb.TableConfig_SchemaV2:
				schemaMsg = v.SchemaV2
			default:
				return fmt.Errorf("unhandled schema type: %T", v)
			}

			options := []TableOption{
				WithRowGroupSize(int(tableMeta.Config.RowGroupSize)),
				WithBlockReaderLimit(int(tableMeta.Config.BlockReaderLimit)),
			}
			if tableMeta.Config.DisableWal {
				options = append(options, WithoutWAL())
			}
			tableConfig := NewTableConfig(
				schemaMsg,
				options...,
			)

			var blockUlid ulid.ULID
			if err := blockUlid.UnmarshalBinary(tableMeta.ActiveBlock.Ulid); err != nil {
				return err
			}

			// Restore the table's index files from the snapshot directory for
			// this tx.
			if err := restoreIndexFilesFromSnapshot(db, tableMeta.Name, dir, blockUlid.String()); err != nil {
				return err
			}

			table, err := db.table(tableMeta.Name, tableConfig, blockUlid.String())
			if err != nil {
				return err
			}

			table.mtx.Lock()
			block := table.active
			block.mtx.Lock()
			// Store the last snapshot size so a snapshot is not triggered right
			// after loading this snapshot.
			table.lastSnapshotSize.Store(tableMeta.ActiveBlock.Size)
			block.minTx = tableMeta.ActiveBlock.MinTx
			block.prevTx = tableMeta.ActiveBlock.PrevTx
			blockIndex := block.Index()
			block.mtx.Unlock()
			table.mtx.Unlock()

			for _, granuleMeta := range tableMeta.GranuleMetadata {
				resultParts := make([]parts.Part, 0, len(granuleMeta.PartMetadata))
				for _, partMeta := range granuleMeta.PartMetadata {
					if err := ctx.Err(); err != nil {
						return err
					}
					startOffset := partMeta.StartOffset
					endOffset := partMeta.EndOffset
					partBytes := make([]byte, endOffset-startOffset)
					if _, err := r.ReadAt(partBytes, startOffset); err != nil {
						return err
					}
					partOptions := parts.WithCompactionLevel(int(partMeta.CompactionLevel))
					switch partMeta.Encoding {
					case snapshotpb.Part_ENCODING_PARQUET:
						serBuf, err := dynparquet.ReaderFromBytes(partBytes)
						if err != nil {
							return err
						}
						resultParts = append(resultParts, parts.NewParquetPart(partMeta.Tx, serBuf, partOptions))
					case snapshotpb.Part_ENCODING_ARROW:
						if err := func() error {
							arrowReader, err := ipc.NewReader(bytes.NewReader(partBytes))
							if err != nil {
								return err
							}
							defer arrowReader.Release()

							record, err := arrowReader.Read()
							if err != nil {
								return err
							}

							record.Retain()
							resultParts = append(
								resultParts,
								parts.NewArrowPart(partMeta.Tx, record, uint64(util.TotalRecordSize(record)), table.schema, partOptions),
							)
							return nil
						}(); err != nil {
							return err
						}
					default:
						return fmt.Errorf("unknown part encoding: %s", partMeta.Encoding)
					}
				}

				for _, part := range resultParts {
					blockIndex.InsertPart(part)
				}
			}

			return nil
		}(); err != nil {
			db.mtx.Lock()
			for _, cleanupTable := range footer.TableMetadata[:] {
				delete(db.tables, cleanupTable.Name)
			}
			db.mtx.Unlock()
			return err
		}
	}

	return nil
}

// cleanupSnapshotDir should be called with a tx at which the caller is certain
// a valid snapshot exists (e.g. the tx returned from
// getLatestValidSnapshotTxn). This method deletes all snapshots taken at any
// other transaction.
func (db *DB) cleanupSnapshotDir(ctx context.Context, tx uint64) error {
	dir := db.snapshotsDir()
	return db.snapshotsDo(ctx, dir, func(parsedTx uint64, entry os.DirEntry) (bool, error) {
		if parsedTx == tx {
			// Continue.
			return true, nil
		}
		if err := os.RemoveAll(filepath.Join(dir, entry.Name())); err != nil {
			return false, err
		}
		return true, nil
	})
}

// snapshotsDo executes the given callback with the directory of each snapshot
// in dir in reverse lexicographical order (most recent snapshot first). If
// false or an error is returned by the callback, the iteration is aborted and
// the error returned.
func (db *DB) snapshotsDo(ctx context.Context, dir string, callback func(parsedTx uint64, entry os.DirEntry) (bool, error)) error {
	files, err := os.ReadDir(dir)
	if err != nil {
		return err
	}
	for i := len(files) - 1; i >= 0; i-- {
		entry := files[i]
		if ctx.Err() != nil {
			return ctx.Err()
		}
		if filepath.Ext(entry.Name()) == ".fdbs" { // Legacy snapshots were stored at the top level; ignore these.
			continue
		}
		name := entry.Name()
		if len(name) < 20 {
			continue
		}
		parsedTx, err := getTxFromSnapshotFileName(name)
		if err != nil {
			continue
		}
		if ok, err := callback(parsedTx, entry); err != nil {
			return err
		} else if !ok {
			return nil
		}
	}
	return nil
}

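// StoreSnapshot writes an already-serialized snapshot (e.g. one produced by
// WriteSnapshot) to the snapshots directory at the given transaction.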
func StoreSnapshot(ctx context.Context, tx uint64, db *DB, snapshot io.Reader) error {
	return db.snapshotAtTX(ctx, tx, func(ctx context.Context, w io.Writer) error {
		_, err := io.Copy(w, snapshot)
		return err
	})
}

// restoreIndexFilesFromSnapshot restores the index files found in the given
// snapshot directory back into the table's index directory.
func restoreIndexFilesFromSnapshot(db *DB, tableName, snapshotDir, blockID string) error {
	// Remove the current index directory.
	if err := os.RemoveAll(filepath.Join(db.indexDir(), tableName)); err != nil {
		return fmt.Errorf("failed to remove index directory: %w", err)
	}

	srcIndexDir := filepath.Join(snapshotDir, "index", tableName, blockID)

	// Restore the index files from the snapshot files.
	return filepath.WalkDir(srcIndexDir, func(path string, d os.DirEntry, err error) error {
		if err != nil {
			if os.IsNotExist(err) {
				return nil // There is no index directory for this table.
			}
			return fmt.Errorf("failed to walk snapshot index directory: %w", err)
		}

		if d.IsDir() { // Level dirs expected.
			return nil
		}

		if filepath.Ext(path) != index.IndexFileExtension {
			return nil // Unknown file.
		}

		// Expected file path is .../<level>/<file>.
		fileName := filepath.Base(path)
		levelDir := filepath.Base(filepath.Dir(path))

		if err := os.MkdirAll(filepath.Join(db.indexDir(), tableName, blockID, levelDir), dirPerms); err != nil {
			return err
		}

		// Hard link the file back into the index directory.
		if err := os.Link(path, filepath.Join(db.indexDir(), tableName, blockID, levelDir, fileName)); err != nil {
			return fmt.Errorf("hard link file: %w", err)
		}

		return nil
	})
}

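// SnapshotDir returns the directory in which the snapshot taken at the given
// transaction is (or would be) stored.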
func SnapshotDir(db *DB, tx uint64) string {
	return filepath.Join(db.snapshotsDir(), fmt.Sprintf("%020d", tx))
}

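// snapshotIndexDir returns the directory within a snapshot's directory in
// which a table block's index files are stored.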
func snapshotIndexDir(db *DB, tx uint64, tableName, blockID string) string {
	return filepath.Join(SnapshotDir(db, tx), "index", tableName, blockID)
}