// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.

package mysql

import (
	"bytes"
	"context"
	"crypto/rsa"
	"crypto/tls"
	"errors"
	"fmt"
	"math/big"
	"net"
	"net/url"
	"sort"
	"strconv"
	"strings"
	"time"
)

// Sentinel errors reported while parsing or validating a DSN.
var (
	// errInvalidDSNAddr: an address was opened with '(' but the text before
	// the dbname slash does not end with ')'.
	errInvalidDSNAddr = errors.New("invalid DSN: network address not terminated (missing closing brace)")

	// errInvalidDSNNoSlash: a non-empty DSN contains no '/' at all.
	errInvalidDSNNoSlash = errors.New("invalid DSN: missing the slash separating the database name")

	// errInvalidDSNUnescaped: like errInvalidDSNAddr, but a ')' shows up
	// further on, which typically means a parameter value with a '/' in it
	// was not escaped.
	errInvalidDSNUnescaped = errors.New("invalid DSN: did you forget to escape a param value?")

	// errInvalidDSNUnsafeCollation: interpolateParams was requested together
	// with a collation listed as unsafe.
	errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations")
)

// Config is a configuration parsed from a DSN string.
// If a new Config is created instead of being parsed from a DSN string,
// the NewConfig function should be used, which sets default values.
//
// Config.FormatDSN turns a Config back into a DSN string. Fields that have no
// DSN representation (TLS, Logger and the BeforeConnect callback) are not
// part of that string.
type Config struct {
	// non boolean fields

	User                 string            // Username
	Passwd               string            // Password (requires User; only formatted into a DSN when User is set)
	Net                  string            // Network (e.g. "tcp", "tcp6", "unix". default: "tcp")
	Addr                 string            // Address (default: "127.0.0.1:3306" for "tcp" and "/tmp/mysql.sock" for "unix")
	DBName               string            // Database name
	Params               map[string]string // Additional connection parameters; unrecognized DSN keys are collected here (unescaped)
	ConnectionAttributes string            // Connection Attributes, comma-delimited string of user-defined "key:value" pairs
	Collation            string            // Connection collation (must not be an unsafe collation when InterpolateParams is set)
	Loc                  *time.Location    // Location for time.Time values (NewConfig default: time.UTC)
	MaxAllowedPacket     int               // Max packet size allowed (NewConfig default: defaultMaxAllowedPacket)
	ServerPubKey         string            // Server public key name, resolved into pubKey during normalization
	TLSConfig            string            // TLS configuration name: "true", "false", "skip-verify", "preferred" or a custom config name
	TLS                  *tls.Config       // TLS configuration; when non-nil it takes priority and TLSConfig is ignored
	Timeout              time.Duration     // Dial timeout
	ReadTimeout          time.Duration     // I/O read timeout
	WriteTimeout         time.Duration     // I/O write timeout
	Logger               Logger            // Logger (a nil Logger is replaced by defaultLogger during normalization)

	// boolean fields

	AllowAllFiles            bool // Allow all files to be used with LOAD DATA LOCAL INFILE
	AllowCleartextPasswords  bool // Allows the cleartext client side plugin
	AllowFallbackToPlaintext bool // Allows fallback to unencrypted connection if server does not support TLS (also set by TLSConfig "preferred")
	AllowNativePasswords     bool // Allows the native password authentication method (NewConfig default: true)
	AllowOldPasswords        bool // Allows the old insecure password method
	CheckConnLiveness        bool // Check connections for liveness before using them (NewConfig default: true)
	ClientFoundRows          bool // Return number of matching rows instead of rows changed
	ColumnsWithAlias         bool // Prepend table alias to column names
	InterpolateParams        bool // Interpolate placeholders into query string
	MultiStatements          bool // Allow multiple statements in one query
	ParseTime                bool // Parse time values to time.Time
	RejectReadOnly           bool // Reject read-only connections

	// unexported fields. new options should come here and be set via Option
	// functions (see TimeTruncate and BeforeConnect).

	beforeConnect func(context.Context, *Config) error // Invoked before a connection is established
	pubKey        *rsa.PublicKey                       // Server public key, looked up from ServerPubKey
	timeTruncate  time.Duration                        // Truncate time.Time values in query parameters to the specified duration
}

// Option is a functional option that modifies a Config. Config.Apply stops
// at the first Option that returns a non-nil error and returns that error.
//
// Functional Options Pattern:
// https://dave.cheney.net/2014/10/17/functional-options-for-friendly-apis
type Option func(*Config) error

// NewConfig creates a new Config and sets default values.
func () *Config {
	 := &Config{
		Loc:                  time.UTC,
		MaxAllowedPacket:     defaultMaxAllowedPacket,
		Logger:               defaultLogger,
		AllowNativePasswords: true,
		CheckConnLiveness:    true,
	}

	return 
}

// Apply applies the given options to the Config object.
func ( *Config) ( ...Option) error {
	for ,  := range  {
		 := ()
		if  != nil {
			return 
		}
	}
	return nil
}

// TimeTruncate sets the time duration to truncate time.Time values in
// query parameters.
func ( time.Duration) Option {
	return func( *Config) error {
		.timeTruncate = 
		return nil
	}
}

// BeforeConnect sets the function to be invoked before a connection is established.
func ( func(context.Context, *Config) error) Option {
	return func( *Config) error {
		.beforeConnect = 
		return nil
	}
}

func ( *Config) () *Config {
	 := *
	if .TLS != nil {
		.TLS = .TLS.Clone()
	}
	if len(.Params) > 0 {
		.Params = make(map[string]string, len(.Params))
		for ,  := range .Params {
			.Params[] = 
		}
	}
	if .pubKey != nil {
		.pubKey = &rsa.PublicKey{
			N: new(big.Int).Set(.pubKey.N),
			E: .pubKey.E,
		}
	}
	return &
}

func ( *Config) () error {
	if .InterpolateParams && .Collation != "" && unsafeCollations[.Collation] {
		return errInvalidDSNUnsafeCollation
	}

	// Set default network if empty
	if .Net == "" {
		.Net = "tcp"
	}

	// Set default address if empty
	if .Addr == "" {
		switch .Net {
		case "tcp":
			.Addr = "127.0.0.1:3306"
		case "unix":
			.Addr = "/tmp/mysql.sock"
		default:
			return errors.New("default addr for network '" + .Net + "' unknown")
		}
	} else if .Net == "tcp" {
		.Addr = ensureHavePort(.Addr)
	}

	if .TLS == nil {
		switch .TLSConfig {
		case "false", "":
			// don't set anything
		case "true":
			.TLS = &tls.Config{}
		case "skip-verify":
			.TLS = &tls.Config{InsecureSkipVerify: true}
		case "preferred":
			.TLS = &tls.Config{InsecureSkipVerify: true}
			.AllowFallbackToPlaintext = true
		default:
			.TLS = getTLSConfigClone(.TLSConfig)
			if .TLS == nil {
				return errors.New("invalid value / unknown config name: " + .TLSConfig)
			}
		}
	}

	if .TLS != nil && .TLS.ServerName == "" && !.TLS.InsecureSkipVerify {
		, ,  := net.SplitHostPort(.Addr)
		if  == nil {
			.TLS.ServerName = 
		}
	}

	if .ServerPubKey != "" {
		.pubKey = getServerPubKey(.ServerPubKey)
		if .pubKey == nil {
			return errors.New("invalid value / unknown server pub key name: " + .ServerPubKey)
		}
	}

	if .Logger == nil {
		.Logger = defaultLogger
	}

	return nil
}

// writeDSNParam appends one "name=value" pair to the parameter section of a
// DSN being built in buf.
//
// The first call (while *hasParam is false) opens the section with '?' and
// sets *hasParam; later calls separate pairs with '&'. name and value are
// written verbatim, so callers must escape them beforehand.
func writeDSNParam(buf *bytes.Buffer, hasParam *bool, name, value string) {
	buf.Grow(1 + len(name) + 1 + len(value))
	sep := byte('&')
	if !*hasParam {
		*hasParam = true
		sep = '?'
	}
	buf.WriteByte(sep)
	buf.WriteString(name)
	buf.WriteByte('=')
	buf.WriteString(value)
}

// FormatDSN formats the given Config into a DSN string which can be passed to
// the driver.
//
// Note: use [NewConnector] and [database/sql.OpenDB] to open a connection from a [*Config].
func ( *Config) () string {
	var  bytes.Buffer

	// [username[:password]@]
	if len(.User) > 0 {
		.WriteString(.User)
		if len(.Passwd) > 0 {
			.WriteByte(':')
			.WriteString(.Passwd)
		}
		.WriteByte('@')
	}

	// [protocol[(address)]]
	if len(.Net) > 0 {
		.WriteString(.Net)
		if len(.Addr) > 0 {
			.WriteByte('(')
			.WriteString(.Addr)
			.WriteByte(')')
		}
	}

	// /dbname
	.WriteByte('/')
	.WriteString(url.PathEscape(.DBName))

	// [?param1=value1&...&paramN=valueN]
	 := false

	if .AllowAllFiles {
		 = true
		.WriteString("?allowAllFiles=true")
	}

	if .AllowCleartextPasswords {
		writeDSNParam(&, &, "allowCleartextPasswords", "true")
	}

	if .AllowFallbackToPlaintext {
		writeDSNParam(&, &, "allowFallbackToPlaintext", "true")
	}

	if !.AllowNativePasswords {
		writeDSNParam(&, &, "allowNativePasswords", "false")
	}

	if .AllowOldPasswords {
		writeDSNParam(&, &, "allowOldPasswords", "true")
	}

	if !.CheckConnLiveness {
		writeDSNParam(&, &, "checkConnLiveness", "false")
	}

	if .ClientFoundRows {
		writeDSNParam(&, &, "clientFoundRows", "true")
	}

	if  := .Collation;  != "" {
		writeDSNParam(&, &, "collation", )
	}

	if .ColumnsWithAlias {
		writeDSNParam(&, &, "columnsWithAlias", "true")
	}

	if .InterpolateParams {
		writeDSNParam(&, &, "interpolateParams", "true")
	}

	if .Loc != time.UTC && .Loc != nil {
		writeDSNParam(&, &, "loc", url.QueryEscape(.Loc.String()))
	}

	if .MultiStatements {
		writeDSNParam(&, &, "multiStatements", "true")
	}

	if .ParseTime {
		writeDSNParam(&, &, "parseTime", "true")
	}

	if .timeTruncate > 0 {
		writeDSNParam(&, &, "timeTruncate", .timeTruncate.String())
	}

	if .ReadTimeout > 0 {
		writeDSNParam(&, &, "readTimeout", .ReadTimeout.String())
	}

	if .RejectReadOnly {
		writeDSNParam(&, &, "rejectReadOnly", "true")
	}

	if len(.ServerPubKey) > 0 {
		writeDSNParam(&, &, "serverPubKey", url.QueryEscape(.ServerPubKey))
	}

	if .Timeout > 0 {
		writeDSNParam(&, &, "timeout", .Timeout.String())
	}

	if len(.TLSConfig) > 0 {
		writeDSNParam(&, &, "tls", url.QueryEscape(.TLSConfig))
	}

	if .WriteTimeout > 0 {
		writeDSNParam(&, &, "writeTimeout", .WriteTimeout.String())
	}

	if .MaxAllowedPacket != defaultMaxAllowedPacket {
		writeDSNParam(&, &, "maxAllowedPacket", strconv.Itoa(.MaxAllowedPacket))
	}

	// other params
	if .Params != nil {
		var  []string
		for  := range .Params {
			 = append(, )
		}
		sort.Strings()
		for ,  := range  {
			writeDSNParam(&, &, , url.QueryEscape(.Params[]))
		}
	}

	return .String()
}

// ParseDSN parses the DSN string to a Config
func ( string) ( *Config,  error) {
	// New config with some default values
	 = NewConfig()

	// [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
	// Find the last '/' (since the password or the net addr might contain a '/')
	 := false
	for  := len() - 1;  >= 0; -- {
		if [] == '/' {
			 = true
			var ,  int

			// left part is empty if i <= 0
			if  > 0 {
				// [username[:password]@][protocol[(address)]]
				// Find the last '@' in dsn[:i]
				for  = ;  >= 0; -- {
					if [] == '@' {
						// username[:password]
						// Find the first ':' in dsn[:j]
						for  = 0;  < ; ++ {
							if [] == ':' {
								.Passwd = [+1 : ]
								break
							}
						}
						.User = [:]

						break
					}
				}

				// [protocol[(address)]]
				// Find the first '(' in dsn[j+1:i]
				for  =  + 1;  < ; ++ {
					if [] == '(' {
						// dsn[i-1] must be == ')' if an address is specified
						if [-1] != ')' {
							if strings.ContainsRune([+1:], ')') {
								return nil, errInvalidDSNUnescaped
							}
							return nil, errInvalidDSNAddr
						}
						.Addr = [+1 : -1]
						break
					}
				}
				.Net = [+1 : ]
			}

			// dbname[?param1=value1&...&paramN=valueN]
			// Find the first '?' in dsn[i+1:]
			for  =  + 1;  < len(); ++ {
				if [] == '?' {
					if  = parseDSNParams(, [+1:]);  != nil {
						return
					}
					break
				}
			}

			 := [+1 : ]
			if .DBName,  = url.PathUnescape();  != nil {
				return nil, fmt.Errorf("invalid dbname %q: %w", , )
			}

			break
		}
	}

	if ! && len() > 0 {
		return nil, errInvalidDSNNoSlash
	}

	if  = .normalize();  != nil {
		return nil, 
	}
	return
}

// parseDSNParams parses the DSN "query string"
// Values must be url.QueryEscape'ed
func parseDSNParams( *Config,  string) ( error) {
	for ,  := range strings.Split(, "&") {
		, ,  := strings.Cut(, "=")
		if ! {
			continue
		}

		// cfg params
		switch  {
		// Disable INFILE allowlist / enable all files
		case "allowAllFiles":
			var  bool
			.AllowAllFiles,  = readBool()
			if ! {
				return errors.New("invalid bool value: " + )
			}

		// Use cleartext authentication mode (MySQL 5.5.10+)
		case "allowCleartextPasswords":
			var  bool
			.AllowCleartextPasswords,  = readBool()
			if ! {
				return errors.New("invalid bool value: " + )
			}

		// Allow fallback to unencrypted connection if server does not support TLS
		case "allowFallbackToPlaintext":
			var  bool
			.AllowFallbackToPlaintext,  = readBool()
			if ! {
				return errors.New("invalid bool value: " + )
			}

		// Use native password authentication
		case "allowNativePasswords":
			var  bool
			.AllowNativePasswords,  = readBool()
			if ! {
				return errors.New("invalid bool value: " + )
			}

		// Use old authentication mode (pre MySQL 4.1)
		case "allowOldPasswords":
			var  bool
			.AllowOldPasswords,  = readBool()
			if ! {
				return errors.New("invalid bool value: " + )
			}

		// Check connections for Liveness before using them
		case "checkConnLiveness":
			var  bool
			.CheckConnLiveness,  = readBool()
			if ! {
				return errors.New("invalid bool value: " + )
			}

		// Switch "rowsAffected" mode
		case "clientFoundRows":
			var  bool
			.ClientFoundRows,  = readBool()
			if ! {
				return errors.New("invalid bool value: " + )
			}

		// Collation
		case "collation":
			.Collation = 

		case "columnsWithAlias":
			var  bool
			.ColumnsWithAlias,  = readBool()
			if ! {
				return errors.New("invalid bool value: " + )
			}

		// Compression
		case "compress":
			return errors.New("compression not implemented yet")

		// Enable client side placeholder substitution
		case "interpolateParams":
			var  bool
			.InterpolateParams,  = readBool()
			if ! {
				return errors.New("invalid bool value: " + )
			}

		// Time Location
		case "loc":
			if ,  = url.QueryUnescape();  != nil {
				return
			}
			.Loc,  = time.LoadLocation()
			if  != nil {
				return
			}

		// multiple statements in one query
		case "multiStatements":
			var  bool
			.MultiStatements,  = readBool()
			if ! {
				return errors.New("invalid bool value: " + )
			}

		// time.Time parsing
		case "parseTime":
			var  bool
			.ParseTime,  = readBool()
			if ! {
				return errors.New("invalid bool value: " + )
			}

		// time.Time truncation
		case "timeTruncate":
			.timeTruncate,  = time.ParseDuration()
			if  != nil {
				return fmt.Errorf("invalid timeTruncate value: %v, error: %w", , )
			}

		// I/O read Timeout
		case "readTimeout":
			.ReadTimeout,  = time.ParseDuration()
			if  != nil {
				return
			}

		// Reject read-only connections
		case "rejectReadOnly":
			var  bool
			.RejectReadOnly,  = readBool()
			if ! {
				return errors.New("invalid bool value: " + )
			}

		// Server public key
		case "serverPubKey":
			,  := url.QueryUnescape()
			if  != nil {
				return fmt.Errorf("invalid value for server pub key name: %v", )
			}
			.ServerPubKey = 

		// Strict mode
		case "strict":
			panic("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode")

		// Dial Timeout
		case "timeout":
			.Timeout,  = time.ParseDuration()
			if  != nil {
				return
			}

		// TLS-Encryption
		case "tls":
			,  := readBool()
			if  {
				if  {
					.TLSConfig = "true"
				} else {
					.TLSConfig = "false"
				}
			} else if  := strings.ToLower();  == "skip-verify" ||  == "preferred" {
				.TLSConfig = 
			} else {
				,  := url.QueryUnescape()
				if  != nil {
					return fmt.Errorf("invalid value for TLS config name: %v", )
				}
				.TLSConfig = 
			}

		// I/O write Timeout
		case "writeTimeout":
			.WriteTimeout,  = time.ParseDuration()
			if  != nil {
				return
			}
		case "maxAllowedPacket":
			.MaxAllowedPacket,  = strconv.Atoi()
			if  != nil {
				return
			}

		// Connection attributes
		case "connectionAttributes":
			,  := url.QueryUnescape()
			if  != nil {
				return fmt.Errorf("invalid connectionAttributes value: %v", )
			}
			.ConnectionAttributes = 

		default:
			// lazy init
			if .Params == nil {
				.Params = make(map[string]string)
			}

			if .Params[],  = url.QueryUnescape();  != nil {
				return
			}
		}
	}

	return
}

// ensureHavePort returns addr unchanged when it already carries a port, and
// otherwise appends the default MySQL port 3306.
//
// A bracketed IPv6 literal without a port (e.g. "[::1]") is unwrapped first:
// net.JoinHostPort adds its own brackets to hosts containing ':', so passing
// the bracketed form straight through would yield "[[::1]]:3306".
func ensureHavePort(addr string) string {
	if _, _, err := net.SplitHostPort(addr); err == nil {
		return addr
	}
	host := addr
	if n := len(host); n >= 2 && host[0] == '[' && host[n-1] == ']' {
		host = host[1 : n-1]
	}
	return net.JoinHostPort(host, "3306")
}