package gorm

import (
	
	
	
	
	
	
	

	
)

// PreparedStmtDB bundles a connection pool with a store of prepared statements
// and the read-write mutex that methods in this file lock around lookups in,
// and bulk removal from, that store.
type PreparedStmtDB struct {
	// Stmts is the prepared-statement store. Code in this file calls its
	// Get, New, Delete and Keys methods.
	Stmts stmt_store.Store
	// Mux is locked by the methods in this file while they look up or clear
	// entries in Stmts; it is also handed to Stmts.New.
	Mux   *sync.RWMutex
	// ConnPool is the embedded underlying pool. Methods in this file
	// type-assert it to *sql.DB, GetDBConnector, TxBeginner and
	// ConnPoolBeginner.
	ConnPool
}

// NewPreparedStmtDB creates and initializes a new PreparedStmtDB.
//
// NOTE(review): the function name and parameter names are missing from the
// signature in the reviewed text. The names below are taken from the previous
// version of this comment and should be confirmed.
//
// Parameters (positional types: ConnPool, int, time.Duration):
//   - connPool: the connection pool stored in the result's ConnPool field.
//   - maxSize: presumably the maximum number of prepared statements kept in the
//     statement store; it is expected to be forwarded to stmt_store.New.
//   - ttl: presumably the time-to-live of each stored statement; expiry
//     behavior is implemented by stmt_store and is not visible here.
//
// Returns a pointer to a PreparedStmtDB whose Stmts comes from stmt_store.New
// and whose Mux is a freshly allocated sync.RWMutex.
func ( ConnPool,  int,  time.Duration) *PreparedStmtDB {
	return &PreparedStmtDB{
		ConnPool: ,                     // Assigns the provided connection pool to manage database connections.
		Stmts:    stmt_store.New(, ), // Initializes a new statement store with the specified maximum size and TTL.
		Mux:      &sync.RWMutex{},              // Sets up a read-write mutex for synchronizing access to the statement store.
	}
}

// GetDBConn returns the underlying *sql.DB connection.
//
// Resolution order:
//  1. If the embedded ConnPool is itself a *sql.DB, it is returned directly.
//  2. Otherwise, if ConnPool implements GetDBConnector (and the asserted value
//     is non-nil), the call is delegated to its GetDBConn method.
//  3. Otherwise nil and ErrInvalidDB are returned.
//
// NOTE(review): the method, receiver and local variable names are missing from
// this block in the reviewed text.
func ( *PreparedStmtDB) () (*sql.DB, error) {
	// Fast path: the pool is the *sql.DB itself.
	if ,  := .ConnPool.(*sql.DB);  {
		return , nil
	}

	// The pool wraps another pool that knows how to reach the *sql.DB.
	if ,  := .ConnPool.(GetDBConnector);  &&  != nil {
		return .GetDBConn()
	}

	return nil, ErrInvalidDB
}

// Close removes every prepared statement from the store.
//
// It holds the write lock on Mux for the whole operation, iterates over the
// keys reported by Stmts.Keys and calls Stmts.Delete once per iteration.
//
// NOTE(review): the argument passed to Delete is missing in the reviewed text
// (presumably the current key). Whether Delete also closes the underlying
// database statement is decided by stmt_store and is not visible here.
func ( *PreparedStmtDB) () {
	.Mux.Lock()
	defer .Mux.Unlock()

	for ,  := range .Stmts.Keys() {
		.Stmts.Delete()
	}
}

// Reset clears the prepared statements by delegating to Close.
//
// Deprecated: use Close instead.
func ( *PreparedStmtDB) () {
	.Close()
}

// Looks up a prepared statement in Stmts and, when none is usable, asks the
// store to create one.
//
// NOTE(review): the method name, receiver name, parameter names and most call
// arguments are missing in the reviewed text. The parameters are, in order, a
// context.Context, a ConnPool, a bool and a string. Callers in this file invoke
// it as ".prepare(...)" and pass either (.ConnPool, false) or (.Tx, true) in
// the second and third positions.
//
// Flow, as far as the visible code shows:
//  1. Under the read lock, if Stmts is non-nil and Get finds an entry that is
//     either not marked Transaction or satisfies the second condition of the
//     "||" (operand missing, presumably the bool parameter), the entry is
//     returned together with the result of its Error method.
//  2. Otherwise the read lock is released, the write lock is taken, and the
//     same lookup is repeated.
//  3. If the second lookup also misses, Stmts.New is called and its results are
//     returned.
//
// Returns a *stmt_store.Stmt and an error.
func ( *PreparedStmtDB) ( context.Context,  ConnPool,  bool,  string) ( *stmt_store.Stmt,  error) {
	.Mux.RLock()
	if .Stmts != nil {
		if ,  := .Stmts.Get();  && (!.Transaction || ) {
			// Hit: release the read lock before returning the cached entry.
			.Mux.RUnlock()
			return , .Error()
		}
	}
	.Mux.RUnlock()

	// retry
	// Same lookup again, this time holding the write lock.
	.Mux.Lock()
	if .Stmts != nil {
		if ,  := .Stmts.Get();  && (!.Transaction || ) {
			.Mux.Unlock()
			return , .Error()
		}
	}

	// The write lock is still held here and is not released in this function;
	// Mux is passed to New, which presumably unlocks it — TODO confirm in
	// stmt_store.
	// NOTE(review): unlike the lookups above, this call is not guarded by the
	// ".Stmts != nil" check.
	return .Stmts.New(, , , , .Mux)
}

// Starts a transaction on the underlying pool and wraps it in a PreparedStmtTX
// that points back at this PreparedStmtDB.
//
// NOTE(review): the method name, receiver name, parameter names and local
// variable names are missing in the reviewed text. The parameters are a
// context.Context and a *sql.TxOptions; the body calls BeginTx on the pool.
//
// Two pool kinds are supported:
//   - TxBeginner: the result of its BeginTx is wrapped and returned together
//     with the second return value of that call.
//   - ConnPoolBeginner: its BeginTx result must additionally satisfy the Tx
//     interface; otherwise ErrInvalidTransaction is returned.
//
// If the pool implements neither interface, nil and ErrInvalidTransaction are
// returned.
func ( *PreparedStmtDB) ( context.Context,  *sql.TxOptions) (ConnPool, error) {
	if ,  := .ConnPool.(TxBeginner);  {
		// NOTE(review): this branch builds the wrapper without checking the
		// BeginTx error first, so a non-nil ConnPool is returned even when the
		// second return value (presumably the error) is set.
		,  := .BeginTx(, )
		return &PreparedStmtTX{PreparedStmtDB: , Tx: }, 
	}

	,  := .ConnPool.(ConnPoolBeginner)
	if ! {
		return nil, ErrInvalidTransaction
	}

	,  := .BeginTx(, )
	if  != nil {
		return nil, 
	}
	// The ConnPoolBeginner hands back a ConnPool; it is only usable here if it
	// also implements Tx.
	if ,  := .(Tx);  {
		return &PreparedStmtTX{PreparedStmtDB: , Tx: }, nil
	}
	return nil, ErrInvalidTransaction
}

// Executes a statement through a prepared statement obtained from ".prepare"
// and returns a sql.Result and an error.
//
// The prepared statement is requested with the embedded ConnPool and false in
// the bool position. If preparing succeeds, ExecContext is called on it with
// the variadic arguments. When errors.Is reports driver.ErrBadConn, an entry is
// deleted from Stmts.
//
// NOTE(review): the method name, parameter names, local names and the argument
// to Delete are missing in the reviewed text. The deleted entry is presumably
// the one for the current query string.
func ( *PreparedStmtDB) ( context.Context,  string,  ...interface{}) ( sql.Result,  error) {
	,  := .prepare(, .ConnPool, false, )
	if  == nil {
		,  = .ExecContext(, ...)
		// Drop the cached entry after a bad-connection error.
		if errors.Is(, driver.ErrBadConn) {
			.Stmts.Delete()
		}
	}
	return , 
}

// Runs a query through a prepared statement obtained from ".prepare" and
// returns *sql.Rows and an error.
//
// The prepared statement is requested with the embedded ConnPool and false in
// the bool position. If preparing succeeds, QueryContext is called on it with
// the variadic arguments. When errors.Is reports driver.ErrBadConn, an entry is
// deleted from Stmts.
//
// NOTE(review): the method name, parameter names, local names and the argument
// to Delete are missing in the reviewed text. The deleted entry is presumably
// the one for the current query string.
func ( *PreparedStmtDB) ( context.Context,  string,  ...interface{}) ( *sql.Rows,  error) {
	,  := .prepare(, .ConnPool, false, )
	if  == nil {
		,  = .QueryContext(, ...)
		// Drop the cached entry after a bad-connection error.
		if errors.Is(, driver.ErrBadConn) {
			.Stmts.Delete()
		}
	}
	return , 
}

// Runs a single-row query through a prepared statement obtained from
// ".prepare" and returns the resulting *sql.Row.
//
// The prepared statement is requested with the embedded ConnPool and false in
// the bool position. If preparing succeeds, the result of QueryRowContext on
// that statement is returned.
//
// NOTE(review): the method name, parameter names and local names are missing
// in the reviewed text.
func ( *PreparedStmtDB) ( context.Context,  string,  ...interface{}) *sql.Row {
	,  := .prepare(, .ConnPool, false, )
	if  == nil {
		return .QueryRowContext(, ...)
	}
	// The prepare error is not carried by the return value: the caller gets a
	// zero-value sql.Row. NOTE(review): confirm how callers behave when they
	// Scan such a Row.
	return &sql.Row{}
}

// Resolves the underlying connection via GetDBConn and calls Ping on it.
//
// Returns the error from GetDBConn if the connection cannot be resolved,
// otherwise the result of Ping.
//
// NOTE(review): the method name, receiver name and local names are missing in
// the reviewed text.
func ( *PreparedStmtDB) () error {
	,  := .GetDBConn()
	if  != nil {
		return 
	}
	return .Ping()
}

// PreparedStmtTX pairs a transaction with the PreparedStmtDB that created it,
// so that statements run inside the transaction can use the shared
// prepared-statement store.
type PreparedStmtTX struct {
	// Tx is the embedded transaction. Methods in this file call its Commit,
	// Rollback and StmtContext methods.
	Tx
	// PreparedStmtDB is the owning wrapper. Its statement lookup and Stmts
	// store are used by the transaction's exec/query methods.
	PreparedStmtDB *PreparedStmtDB
}

// Returns the underlying *sql.DB by delegating to GetDBConn on the owning
// PreparedStmtDB.
//
// NOTE(review): the method and receiver names are missing in the reviewed
// text.
func ( *PreparedStmtTX) () (*sql.DB, error) {
	return .PreparedStmtDB.GetDBConn()
}

// Commits the embedded transaction.
//
// Returns the result of Tx.Commit, or ErrInvalidTransaction when Tx is unset.
//
// NOTE(review): the method and receiver names are missing in the reviewed
// text.
func ( *PreparedStmtTX) () error {
	// The plain nil comparison misses an interface that holds a typed nil
	// pointer, so the dynamic value is checked through reflection as well.
	// NOTE(review): reflect.Value.IsNil panics for kinds that cannot be nil
	// (e.g. a struct value); confirm that every Tx implementation is a
	// pointer or another nil-able kind.
	if .Tx != nil && !reflect.ValueOf(.Tx).IsNil() {
		return .Tx.Commit()
	}
	return ErrInvalidTransaction
}

// Rolls back the embedded transaction.
//
// Returns the result of Tx.Rollback, or ErrInvalidTransaction when Tx is
// unset.
//
// NOTE(review): the method and receiver names are missing in the reviewed
// text.
func ( *PreparedStmtTX) () error {
	// The plain nil comparison misses an interface that holds a typed nil
	// pointer, so the dynamic value is checked through reflection as well.
	// NOTE(review): reflect.Value.IsNil panics for kinds that cannot be nil
	// (e.g. a struct value); confirm that every Tx implementation is a
	// pointer or another nil-able kind.
	if .Tx != nil && !reflect.ValueOf(.Tx).IsNil() {
		return .Tx.Rollback()
	}
	return ErrInvalidTransaction
}

// Executes a statement inside the transaction using a prepared statement from
// the owning PreparedStmtDB, and returns a sql.Result and an error.
//
// The statement is requested from PreparedStmtDB.prepare with Tx as the pool
// and true in the bool position. If preparing succeeds, Tx.StmtContext is
// called with the statement's Stmt field and ExecContext is called on the
// result. When errors.Is reports driver.ErrBadConn, an entry is deleted from
// the shared Stmts store.
//
// NOTE(review): the method name, parameter names, local names and the argument
// to Delete are missing in the reviewed text. The deleted entry is presumably
// the one for the current query string.
func ( *PreparedStmtTX) ( context.Context,  string,  ...interface{}) ( sql.Result,  error) {
	,  := .PreparedStmtDB.prepare(, .Tx, true, )
	if  == nil {
		,  = .Tx.StmtContext(, .Stmt).ExecContext(, ...)
		// Drop the cached entry after a bad-connection error.
		if errors.Is(, driver.ErrBadConn) {
			.PreparedStmtDB.Stmts.Delete()
		}
	}
	return , 
}

// Runs a query inside the transaction using a prepared statement from the
// owning PreparedStmtDB, and returns *sql.Rows and an error.
//
// The statement is requested from PreparedStmtDB.prepare with Tx as the pool
// and true in the bool position. If preparing succeeds, Tx.StmtContext is
// called with the statement's Stmt field and QueryContext is called on the
// result. When errors.Is reports driver.ErrBadConn, an entry is deleted from
// the shared Stmts store.
//
// NOTE(review): the method name, parameter names, local names and the argument
// to Delete are missing in the reviewed text. The deleted entry is presumably
// the one for the current query string.
func ( *PreparedStmtTX) ( context.Context,  string,  ...interface{}) ( *sql.Rows,  error) {
	,  := .PreparedStmtDB.prepare(, .Tx, true, )
	if  == nil {
		,  = .Tx.StmtContext(, .Stmt).QueryContext(, ...)
		// Drop the cached entry after a bad-connection error.
		if errors.Is(, driver.ErrBadConn) {
			.PreparedStmtDB.Stmts.Delete()
		}
	}
	return , 
}

// Runs a single-row query inside the transaction using a prepared statement
// from the owning PreparedStmtDB, and returns the resulting *sql.Row.
//
// The statement is requested from PreparedStmtDB.prepare with Tx as the pool
// and true in the bool position. If preparing succeeds, Tx.StmtContext is
// called with the statement's Stmt field and the result of QueryRowContext on
// it is returned.
//
// NOTE(review): the method name, parameter names and local names are missing
// in the reviewed text.
func ( *PreparedStmtTX) ( context.Context,  string,  ...interface{}) *sql.Row {
	,  := .PreparedStmtDB.prepare(, .Tx, true, )
	if  == nil {
		return .Tx.StmtContext(, .Stmt).QueryRowContext(, ...)
	}
	// The prepare error is not carried by the return value: the caller gets a
	// zero-value sql.Row. NOTE(review): confirm how callers behave when they
	// Scan such a Row.
	return &sql.Row{}
}

// Resolves the underlying connection via GetDBConn and calls Ping on it.
//
// Returns the error from GetDBConn if the connection cannot be resolved,
// otherwise the result of Ping.
//
// NOTE(review): the method name, receiver name and local names are missing in
// the reviewed text.
func ( *PreparedStmtTX) () error {
	,  := .GetDBConn()
	if  != nil {
		return 
	}
	return .Ping()
}