package mysql
import (
"bytes"
"context"
"crypto/rsa"
"crypto/tls"
"errors"
"fmt"
"math/big"
"net"
"net/url"
"sort"
"strconv"
"strings"
"time"
)
// Sentinel errors reported while parsing or validating a DSN.
var (
	// errInvalidDSNUnescaped is returned by ParseDSN when the address part
	// contains a ')' but does not end with one right before the database
	// slash, which usually means a parameter value containing '/' or ')'
	// was not escaped.
	errInvalidDSNUnescaped = errors.New("invalid DSN: did you forget to escape a param value?")

	// errInvalidDSNAddr is returned by ParseDSN when a '(' opens an address
	// that is never closed before the database slash.
	errInvalidDSNAddr = errors.New("invalid DSN: network address not terminated (missing closing brace)")

	// errInvalidDSNNoSlash is returned by ParseDSN for a non-empty DSN
	// without any '/'.
	errInvalidDSNNoSlash = errors.New("invalid DSN: missing the slash separating the database name")

	// errInvalidDSNUnsafeCollation is returned by normalize when
	// InterpolateParams is combined with a collation listed in
	// unsafeCollations.
	errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations")
)
// Config holds the settings of a MySQL connection as parsed from a DSN by
// ParseDSN and serialized back by FormatDSN. Use NewConfig to obtain a Config
// with the driver defaults rather than the zero value.
type Config struct {
// User is the user name: the text before the first ':' (or before the '@'
// when there is no ':') of the credentials part of the DSN.
User string
// Passwd is the password: the text between the first ':' and the last '@'
// before the database slash. FormatDSN omits the ':' when it is empty.
Passwd string
// Net is the network type; normalize defaults it to "tcp".
Net string
// Addr is the network address. normalize defaults it to "127.0.0.1:3306"
// for "tcp" and "/tmp/mysql.sock" for "unix", and appends port 3306 to a
// tcp address that has none.
Addr string
// DBName is the database name (path-escaped inside a DSN).
DBName string
// Params holds DSN parameters that parseDSNParams does not recognize,
// with query-unescaped values. FormatDSN writes them sorted by key.
Params map [string ]string
// ConnectionAttributes is the query-unescaped value of the
// connectionAttributes DSN parameter.
// NOTE(review): its inner format is not interpreted here — confirm with the
// code that consumes it.
ConnectionAttributes string
// Collation is the collation name from the collation DSN parameter. When
// InterpolateParams is set, normalize rejects names listed in
// unsafeCollations.
Collation string
// Loc is the location loaded from the loc DSN parameter; NewConfig sets
// time.UTC. NOTE(review): presumably used when interpreting time values —
// confirm against its users.
Loc *time .Location
// MaxAllowedPacket comes from the maxAllowedPacket DSN parameter;
// NewConfig sets defaultMaxAllowedPacket.
MaxAllowedPacket int
// ServerPubKey is the name that normalize resolves, via getServerPubKey,
// into pubKey.
ServerPubKey string
// TLSConfig selects TLS from a DSN: "true", "false", "skip-verify",
// "preferred", or a custom name resolved via getTLSConfigClone.
TLSConfig string
// TLS is the effective TLS configuration. When nil, normalize builds it
// from TLSConfig. FormatDSN does not serialize it.
TLS *tls .Config
// Timeout, ReadTimeout and WriteTimeout come from the timeout, readTimeout
// and writeTimeout DSN parameters (time.ParseDuration syntax).
// NOTE(review): presumably dial, read and write timeouts respectively —
// confirm where they are applied.
Timeout time .Duration
ReadTimeout time .Duration
WriteTimeout time .Duration
// Logger is the logger for this configuration; normalize falls back to
// defaultLogger when it is nil.
Logger Logger
// The boolean flags below each correspond to the DSN parameter of the same
// name with a lower-case first letter. AllowNativePasswords and
// CheckConnLiveness default to true in NewConfig. The tls=preferred setting
// also turns on AllowFallbackToPlaintext during normalize.
AllowAllFiles bool
AllowCleartextPasswords bool
AllowFallbackToPlaintext bool
AllowNativePasswords bool
AllowOldPasswords bool
CheckConnLiveness bool
ClientFoundRows bool
ColumnsWithAlias bool
InterpolateParams bool
MultiStatements bool
ParseTime bool
RejectReadOnly bool
// beforeConnect is the hook installed by the BeforeConnect option.
beforeConnect func (context .Context , *Config ) error
// pubKey is the key resolved by normalize from ServerPubKey; Clone
// deep-copies it.
pubKey *rsa .PublicKey
// timeTruncate is set by the TimeTruncate option or the timeTruncate DSN
// parameter; FormatDSN writes it only when positive.
timeTruncate time .Duration
}
// Option is a functional option that modifies a Config. Options are applied
// in order by (*Config).Apply, which stops at the first non-nil error.
type Option func (*Config ) error
// NewConfig returns a Config filled with the driver defaults:
//   - time.UTC as location;
//   - defaultMaxAllowedPacket;
//   - defaultLogger;
//   - native password authentication enabled;
//   - connection liveness checks enabled.
//
// All other fields keep their zero values.
func NewConfig() *Config {
	return &Config{
		AllowNativePasswords: true,
		CheckConnLiveness:    true,
		Loc:                  time.UTC,
		Logger:               defaultLogger,
		MaxAllowedPacket:     defaultMaxAllowedPacket,
	}
}
// Apply invokes each option on c in the order given. It returns the first
// error reported by an option; options after the failing one are not run.
func (c *Config) Apply(opts ...Option) error {
	for i := range opts {
		if err := opts[i](c); err != nil {
			return err
		}
	}
	return nil
}
// TimeTruncate returns an Option that stores d as the config's time
// truncation duration, the same field set by the timeTruncate DSN parameter.
// It never fails.
// NOTE(review): presumably time values are truncated to d before being
// sent — confirm against the code that reads timeTruncate.
func TimeTruncate(d time.Duration) Option {
	setTruncate := func(cfg *Config) error {
		cfg.timeTruncate = d
		return nil
	}
	return setTruncate
}
// BeforeConnect returns an Option that installs fn as the config's
// beforeConnect hook, replacing any previous hook. It never fails.
// NOTE(review): when the hook is invoked is not visible here — presumably
// before each new connection is established; confirm with its caller.
func BeforeConnect(fn func(context.Context, *Config) error) Option {
	setHook := func(cfg *Config) error {
		cfg.beforeConnect = fn
		return nil
	}
	return setHook
}
// Clone returns a copy of cfg that can be modified without affecting the
// original. The TLS config, the Params map and the resolved server public
// key are duplicated. Other reference-typed fields (Loc, Logger,
// beforeConnect) are copied by reference.
func (cfg *Config) Clone() *Config {
	cp := *cfg
	if cfg.TLS != nil {
		cp.TLS = cfg.TLS.Clone()
	}
	// Copy every non-nil map, including an empty one: sharing an empty map
	// would let writes made through the clone show up in the original.
	if cfg.Params != nil {
		cp.Params = make(map[string]string, len(cfg.Params))
		for k, v := range cfg.Params {
			cp.Params[k] = v
		}
	}
	if cfg.pubKey != nil {
		cp.pubKey = &rsa.PublicKey{
			N: new(big.Int).Set(cfg.pubKey.N),
			E: cfg.pubKey.E,
		}
	}
	return &cp
}
// normalize validates cfg and fills in derived defaults in place. ParseDSN
// calls it after all DSN parameters have been applied. In order, it:
//   - rejects InterpolateParams combined with a collation listed in
//     unsafeCollations;
//   - defaults Net to "tcp";
//   - defaults Addr to "127.0.0.1:3306" (tcp) or "/tmp/mysql.sock" (unix);
//     any other network requires an explicit Addr;
//   - appends port 3306 to a tcp Addr that has none;
//   - builds TLS from TLSConfig when TLS is unset. The choices are "true",
//     "skip-verify", "preferred" (which also enables
//     AllowFallbackToPlaintext), or a name looked up via getTLSConfigClone.
//   - fills in ServerName from Addr's host for TLS configs that verify the
//     server;
//   - resolves ServerPubKey via getServerPubKey;
//   - falls back to defaultLogger when Logger is nil.
func (cfg *Config) normalize() error {
	if cfg.InterpolateParams && cfg.Collation != "" && unsafeCollations[cfg.Collation] {
		return errInvalidDSNUnsafeCollation
	}

	if cfg.Net == "" {
		cfg.Net = "tcp"
	}

	switch {
	case cfg.Addr != "":
		if cfg.Net == "tcp" {
			cfg.Addr = ensureHavePort(cfg.Addr)
		}
	case cfg.Net == "tcp":
		cfg.Addr = "127.0.0.1:3306"
	case cfg.Net == "unix":
		cfg.Addr = "/tmp/mysql.sock"
	default:
		return errors.New("default addr for network '" + cfg.Net + "' unknown")
	}

	if cfg.TLS == nil {
		switch tlsName := cfg.TLSConfig; tlsName {
		case "", "false":
			// TLS stays disabled.
		case "true":
			cfg.TLS = &tls.Config{}
		case "skip-verify", "preferred":
			cfg.TLS = &tls.Config{InsecureSkipVerify: true}
			if tlsName == "preferred" {
				cfg.AllowFallbackToPlaintext = true
			}
		default:
			if cfg.TLS = getTLSConfigClone(tlsName); cfg.TLS == nil {
				return errors.New("invalid value / unknown config name: " + tlsName)
			}
		}
	}

	// A verifying TLS config needs a server name; derive it from the address
	// when possible (unix socket paths have no host part and are skipped).
	if tc := cfg.TLS; tc != nil && tc.ServerName == "" && !tc.InsecureSkipVerify {
		if host, _, err := net.SplitHostPort(cfg.Addr); err == nil {
			tc.ServerName = host
		}
	}

	if cfg.ServerPubKey != "" {
		if cfg.pubKey = getServerPubKey(cfg.ServerPubKey); cfg.pubKey == nil {
			return errors.New("invalid value / unknown server pub key name: " + cfg.ServerPubKey)
		}
	}

	if cfg.Logger == nil {
		cfg.Logger = defaultLogger
	}
	return nil
}
// writeDSNParam appends "name=value" to buf as a DSN query parameter. The
// first parameter (when *hasParam is false) is prefixed with '?' and flips
// *hasParam to true; later ones are prefixed with '&'. name and value are
// written verbatim, so callers must escape them as needed.
func writeDSNParam(buf *bytes.Buffer, hasParam *bool, name, value string) {
	sep := byte('&')
	if !*hasParam {
		sep = '?'
		*hasParam = true
	}
	buf.Grow(len(name) + len(value) + 2) // separator + '='
	buf.WriteByte(sep)
	buf.WriteString(name)
	buf.WriteByte('=')
	buf.WriteString(value)
}
// FormatDSN formats the configuration as a DSN string that ParseDSN accepts:
//
//	[user[:password]@][net[(addr)]]/dbname[?param1=value1&...&paramN=valueN]
//
// Only settings that differ from the NewConfig defaults are emitted, in a
// fixed order, followed by the entries of Params sorted by key, so the output
// is deterministic.
//
// Escaping:
//   - the database name is path-escaped;
//   - values that ParseDSN unescapes (loc, serverPubKey, tls,
//     connectionAttributes and Params values) are query-escaped;
//   - User and Passwd are written verbatim.
//
// A TLS config assigned directly to TLS, without a TLSConfig name, cannot be
// represented in a DSN and is not emitted.
func (cfg *Config) FormatDSN() string {
	var buf bytes.Buffer

	// [user[:password]@]
	if len(cfg.User) > 0 {
		buf.WriteString(cfg.User)
		if len(cfg.Passwd) > 0 {
			buf.WriteByte(':')
			buf.WriteString(cfg.Passwd)
		}
		buf.WriteByte('@')
	}

	// [net[(addr)]]
	if len(cfg.Net) > 0 {
		buf.WriteString(cfg.Net)
		if len(cfg.Addr) > 0 {
			buf.WriteByte('(')
			buf.WriteString(cfg.Addr)
			buf.WriteByte(')')
		}
	}

	// /dbname
	buf.WriteByte('/')
	buf.WriteString(url.PathEscape(cfg.DBName))

	// [?param1=value1&...&paramN=valueN]
	hasParam := false

	if cfg.AllowAllFiles {
		writeDSNParam(&buf, &hasParam, "allowAllFiles", "true")
	}
	if cfg.AllowCleartextPasswords {
		writeDSNParam(&buf, &hasParam, "allowCleartextPasswords", "true")
	}
	if cfg.AllowFallbackToPlaintext {
		writeDSNParam(&buf, &hasParam, "allowFallbackToPlaintext", "true")
	}
	if !cfg.AllowNativePasswords {
		writeDSNParam(&buf, &hasParam, "allowNativePasswords", "false")
	}
	if cfg.AllowOldPasswords {
		writeDSNParam(&buf, &hasParam, "allowOldPasswords", "true")
	}
	if !cfg.CheckConnLiveness {
		writeDSNParam(&buf, &hasParam, "checkConnLiveness", "false")
	}
	if cfg.ClientFoundRows {
		writeDSNParam(&buf, &hasParam, "clientFoundRows", "true")
	}
	if col := cfg.Collation; col != "" {
		writeDSNParam(&buf, &hasParam, "collation", col)
	}
	if cfg.ColumnsWithAlias {
		writeDSNParam(&buf, &hasParam, "columnsWithAlias", "true")
	}
	// ParseDSN reads connectionAttributes (query-unescaped), so write it back
	// to keep FormatDSN/ParseDSN round trips lossless.
	if len(cfg.ConnectionAttributes) > 0 {
		writeDSNParam(&buf, &hasParam, "connectionAttributes", url.QueryEscape(cfg.ConnectionAttributes))
	}
	if cfg.InterpolateParams {
		writeDSNParam(&buf, &hasParam, "interpolateParams", "true")
	}
	if cfg.Loc != time.UTC && cfg.Loc != nil {
		writeDSNParam(&buf, &hasParam, "loc", url.QueryEscape(cfg.Loc.String()))
	}
	if cfg.MultiStatements {
		writeDSNParam(&buf, &hasParam, "multiStatements", "true")
	}
	if cfg.ParseTime {
		writeDSNParam(&buf, &hasParam, "parseTime", "true")
	}
	if cfg.timeTruncate > 0 {
		writeDSNParam(&buf, &hasParam, "timeTruncate", cfg.timeTruncate.String())
	}
	if cfg.ReadTimeout > 0 {
		writeDSNParam(&buf, &hasParam, "readTimeout", cfg.ReadTimeout.String())
	}
	if cfg.RejectReadOnly {
		writeDSNParam(&buf, &hasParam, "rejectReadOnly", "true")
	}
	if len(cfg.ServerPubKey) > 0 {
		writeDSNParam(&buf, &hasParam, "serverPubKey", url.QueryEscape(cfg.ServerPubKey))
	}
	if cfg.Timeout > 0 {
		writeDSNParam(&buf, &hasParam, "timeout", cfg.Timeout.String())
	}
	if len(cfg.TLSConfig) > 0 {
		writeDSNParam(&buf, &hasParam, "tls", url.QueryEscape(cfg.TLSConfig))
	}
	if cfg.WriteTimeout > 0 {
		writeDSNParam(&buf, &hasParam, "writeTimeout", cfg.WriteTimeout.String())
	}
	if cfg.MaxAllowedPacket != defaultMaxAllowedPacket {
		writeDSNParam(&buf, &hasParam, "maxAllowedPacket", strconv.Itoa(cfg.MaxAllowedPacket))
	}

	// Remaining free-form parameters, sorted for deterministic output.
	if len(cfg.Params) > 0 {
		params := make([]string, 0, len(cfg.Params))
		for param := range cfg.Params {
			params = append(params, param)
		}
		sort.Strings(params)
		for _, param := range params {
			writeDSNParam(&buf, &hasParam, param, url.QueryEscape(cfg.Params[param]))
		}
	}

	return buf.String()
}
// ParseDSN parses a DSN of the form
//
//	[user[:password]@][net[(addr)]]/dbname[?param1=value1&...&paramN=valueN]
//
// into a new Config. It starts from NewConfig, applies the DSN, and finally
// calls normalize. The database name follows the last '/'. The credentials
// end at the last '@' before that slash. The user name ends at the first ':'.
//
// An empty DSN yields the normalized defaults. A non-empty DSN without any
// '/' is rejected with errInvalidDSNNoSlash. On any error the returned
// Config is nil.
func ParseDSN(dsn string) (cfg *Config, err error) {
	cfg = NewConfig()

	// Scan from the end: the database name follows the LAST slash.
	foundSlash := false
	for i := len(dsn) - 1; i >= 0; i-- {
		if dsn[i] == '/' {
			foundSlash = true
			var j, k int

			// Everything before the slash: [user[:password]@][net[(addr)]]
			if i > 0 {
				// The last '@' before the slash ends the credentials; j stays
				// -1 when there is none.
				for j = i; j >= 0; j-- {
					if dsn[j] == '@' {
						// The first ':' separates user from password, so the
						// password itself may contain ':'.
						for k = 0; k < j; k++ {
							if dsn[k] == ':' {
								cfg.Passwd = dsn[k+1 : j]
								break
							}
						}
						cfg.User = dsn[:k]
						break
					}
				}

				// [net[(addr)]]: the first '(' after the credentials opens the
				// address, which must be closed right before the slash.
				for k = j + 1; k < i; k++ {
					if dsn[k] == '(' {
						if dsn[i-1] != ')' {
							if strings.ContainsRune(dsn[k+1:i], ')') {
								return nil, errInvalidDSNUnescaped
							}
							return nil, errInvalidDSNAddr
						}
						cfg.Addr = dsn[k+1 : i-1]
						break
					}
				}
				cfg.Net = dsn[j+1 : k]
			}

			// dbname[?param1=value1&...&paramN=valueN]
			for j = i + 1; j < len(dsn); j++ {
				if dsn[j] == '?' {
					if err = parseDSNParams(cfg, dsn[j+1:]); err != nil {
						// Do not hand back a partially populated config.
						return nil, err
					}
					break
				}
			}
			dbname := dsn[i+1 : j]
			if cfg.DBName, err = url.PathUnescape(dbname); err != nil {
				return nil, fmt.Errorf("invalid dbname %q: %w", dbname, err)
			}
			break
		}
	}

	if !foundSlash && len(dsn) > 0 {
		return nil, errInvalidDSNNoSlash
	}
	if err = cfg.normalize(); err != nil {
		return nil, err
	}
	return cfg, nil
}
// parseDSNParams applies the query part of a DSN (everything after '?') to
// cfg. params is a '&'-separated list of key=value pairs; pairs without '='
// are ignored. Recognized keys set the matching Config field, and boolean
// values are parsed with readBool.
//
// The following values are query-unescaped before use:
//   - loc, serverPubKey and connectionAttributes;
//   - tls names other than a boolean, "skip-verify" or "preferred";
//   - values of unrecognized keys, which are stored in cfg.Params.
//
// It returns the first error encountered; fields set by earlier pairs remain
// applied to cfg.
func parseDSNParams(cfg *Config, params string) (err error) {
	// setBool parses a boolean flag into *dst. Like the former per-key code,
	// it stores readBool's result before reporting an invalid value.
	setBool := func(dst *bool, value string) error {
		var ok bool
		if *dst, ok = readBool(value); !ok {
			return errors.New("invalid bool value: " + value)
		}
		return nil
	}

	for _, v := range strings.Split(params, "&") {
		key, value, found := strings.Cut(v, "=")
		if !found {
			continue
		}

		switch key {
		case "allowAllFiles":
			err = setBool(&cfg.AllowAllFiles, value)
		case "allowCleartextPasswords":
			err = setBool(&cfg.AllowCleartextPasswords, value)
		case "allowFallbackToPlaintext":
			err = setBool(&cfg.AllowFallbackToPlaintext, value)
		case "allowNativePasswords":
			err = setBool(&cfg.AllowNativePasswords, value)
		case "allowOldPasswords":
			err = setBool(&cfg.AllowOldPasswords, value)
		case "checkConnLiveness":
			err = setBool(&cfg.CheckConnLiveness, value)
		case "clientFoundRows":
			err = setBool(&cfg.ClientFoundRows, value)
		case "collation":
			cfg.Collation = value
		case "columnsWithAlias":
			err = setBool(&cfg.ColumnsWithAlias, value)
		case "compress":
			return errors.New("compression not implemented yet")
		case "interpolateParams":
			err = setBool(&cfg.InterpolateParams, value)
		case "loc":
			if value, err = url.QueryUnescape(value); err != nil {
				return err
			}
			cfg.Loc, err = time.LoadLocation(value)
		case "multiStatements":
			err = setBool(&cfg.MultiStatements, value)
		case "parseTime":
			err = setBool(&cfg.ParseTime, value)
		case "timeTruncate":
			if cfg.timeTruncate, err = time.ParseDuration(value); err != nil {
				return fmt.Errorf("invalid timeTruncate value: %v, error: %w", value, err)
			}
		case "readTimeout":
			cfg.ReadTimeout, err = time.ParseDuration(value)
		case "rejectReadOnly":
			err = setBool(&cfg.RejectReadOnly, value)
		case "serverPubKey":
			name, uerr := url.QueryUnescape(value)
			if uerr != nil {
				return fmt.Errorf("invalid value for server pub key name: %v", uerr)
			}
			cfg.ServerPubKey = name
		case "strict":
			// This used to panic. A stale DSN is a configuration error, not a
			// programmer bug, so report it instead of crashing the caller.
			return errors.New("strict mode has been removed, see https://github.com/go-sql-driver/mysql/wiki/strict-mode")
		case "timeout":
			cfg.Timeout, err = time.ParseDuration(value)
		case "tls":
			boolValue, isBool := readBool(value)
			vl := strings.ToLower(value)
			switch {
			case isBool && boolValue:
				cfg.TLSConfig = "true"
			case isBool:
				cfg.TLSConfig = "false"
			case vl == "skip-verify" || vl == "preferred":
				cfg.TLSConfig = vl
			default:
				name, uerr := url.QueryUnescape(value)
				if uerr != nil {
					return fmt.Errorf("invalid value for TLS config name: %v", uerr)
				}
				cfg.TLSConfig = name
			}
		case "writeTimeout":
			cfg.WriteTimeout, err = time.ParseDuration(value)
		case "maxAllowedPacket":
			cfg.MaxAllowedPacket, err = strconv.Atoi(value)
		case "connectionAttributes":
			attrs, uerr := url.QueryUnescape(value)
			if uerr != nil {
				return fmt.Errorf("invalid connectionAttributes value: %v", uerr)
			}
			cfg.ConnectionAttributes = attrs
		default:
			// Unknown keys are kept verbatim (after unescaping) in Params.
			if cfg.Params == nil {
				cfg.Params = make(map[string]string)
			}
			if cfg.Params[key], err = url.QueryUnescape(value); err != nil {
				return err
			}
		}

		if err != nil {
			return err
		}
	}
	return nil
}
// ensureHavePort returns addr unchanged when net.SplitHostPort accepts it as
// host:port. Otherwise it returns addr joined with the default MySQL port
// 3306; net.JoinHostPort brackets an address that contains ':' (e.g. IPv6).
func ensureHavePort(addr string) string {
	_, _, err := net.SplitHostPort(addr)
	if err == nil {
		return addr
	}
	return net.JoinHostPort(addr, "3306")
}
The pages are generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news about Golds.