package mysql
import (
"database/sql"
"fmt"
"strconv"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/migrator"
"gorm.io/gorm/schema"
)
// indexSql lists, for a single table, one row per indexed column from
// information_schema.STATISTICS: the table name, column name, index name and
// the NON_UNIQUE flag. Rows are ordered by index name and then by the
// column's position inside the index. It takes two bind parameters: the
// schema name followed by the table name.
const indexSql = `
SELECT
TABLE_NAME,
COLUMN_NAME,
INDEX_NAME,
NON_UNIQUE
FROM
information_schema.STATISTICS
WHERE
TABLE_SCHEMA = ?
AND TABLE_NAME = ?
ORDER BY
INDEX_NAME,
SEQ_IN_INDEX`
// typeAliasMap maps a database type name to the other type names that are
// treated as its aliases. It is the lookup table behind GetTypeAliases;
// "bool" and "tinyint" are registered as aliases of each other.
var typeAliasMap = map [string ][]string {
"bool" : {"tinyint" },
"tinyint" : {"bool" },
}
// Migrator is the MySQL migrator. It embeds the generic gorm migrator and the
// MySQL Dialector, and overrides the operations in this file with
// MySQL-specific SQL.
type Migrator struct {
// Generic migrator whose behavior is reused or overridden by the methods below.
migrator .Migrator
// Dialector supplies the dialect settings read by the methods below
// (for example ServerVersion and the DontSupport* switches).
Dialector
}
// FullDataTypeOf returns the column definition produced by the embedded
// migrator for field. When the field carries a "COMMENT" tag setting, a
// " COMMENT <value>" clause is appended, with the value rendered through the
// dialector's Explain.
func (m Migrator) FullDataTypeOf(field *schema.Field) clause.Expr {
	dataType := m.Migrator.FullDataTypeOf(field)

	comment, hasComment := field.TagSettings["COMMENT"]
	if !hasComment {
		return dataType
	}

	dataType.SQL = dataType.SQL + " COMMENT " + m.Dialector.Explain("?", comment)
	return dataType
}
// MigrateColumnUnique reconciles the uniqueness of an existing column with
// what field declares.
//
// Nothing is done when the database cannot report whether the column is
// unique, or when the field is a primary key.
//
// If the column is currently not unique, the unique constraint and/or the
// unique index requested by field are created.
//
// If the column is currently unique, every unique single-column index on this
// column whose name is neither the generated constraint name nor
// field.UniqueIndex is dropped. The unique constraint is then created or
// dropped so that it matches field.Unique, and field.UniqueIndex is created
// when it is requested but missing.
func (m Migrator) MigrateColumnUnique(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
	isUnique, known := columnType.Unique()
	if field.PrimaryKey || !known {
		return nil
	}

	queryTx, execTx := m.GetQueryAndExecTx()
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		uniqueName := m.DB.NamingStrategy.UniqueName(stmt.Table, field.DBName)

		// Column is not unique yet: just add whatever the field asks for.
		if !isUnique {
			if field.Unique {
				if err := execTx.Migrator().CreateConstraint(value, uniqueName); err != nil {
					return err
				}
			}
			if field.UniqueIndex == "" {
				return nil
			}
			return execTx.Migrator().CreateIndex(value, field.UniqueIndex)
		}

		// The error from GetIndexes is deliberately discarded; the loop works
		// on whatever slice was returned.
		existing, _ := queryTx.Migrator().GetIndexes(value)
		for _, idx := range existing {
			uniqueIdx, ok := idx.Unique()
			if !ok || !uniqueIdx {
				continue
			}
			cols := idx.Columns()
			if len(cols) != 1 || cols[0] != field.DBName {
				continue
			}
			idxName := idx.Name()
			if idxName == uniqueName || idxName == field.UniqueIndex {
				// This is one of the indexes the field itself declares; keep it.
				continue
			}
			if err := execTx.Migrator().DropIndex(value, idxName); err != nil {
				return err
			}
		}

		hasConstraint := queryTx.Migrator().HasConstraint(value, uniqueName)
		if field.Unique && !hasConstraint {
			if err := execTx.Migrator().CreateConstraint(value, uniqueName); err != nil {
				return err
			}
		} else if !field.Unique && hasConstraint {
			if err := execTx.Migrator().DropConstraint(value, uniqueName); err != nil {
				return err
			}
			if field.UniqueIndex != "" {
				if err := execTx.Migrator().CreateIndex(value, field.UniqueIndex); err != nil {
					return err
				}
			}
		}

		if field.UniqueIndex != "" && !queryTx.Migrator().HasIndex(value, field.UniqueIndex) {
			return execTx.Migrator().CreateIndex(value, field.UniqueIndex)
		}
		return nil
	})
}
// AddColumn adds the column for the schema field called name to the table
// backing value.
//
// It returns an error when the field cannot be found, and does nothing when
// the field is marked IgnoreMigration. If the field is a primary key, or its
// data type contains "auto_increment" (case-insensitive), the same ALTER
// TABLE statement also adds a primary key on the new column.
func (m Migrator) AddColumn(value interface{}, name string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		field := stmt.Schema.LookUpField(name)
		switch {
		case field == nil:
			return fmt.Errorf("failed to look up field with name: %s", name)
		case field.IgnoreMigration:
			return nil
		}

		dataType := m.FullDataTypeOf(field)
		column := clause.Column{Name: field.DBName}
		args := []interface{}{m.CurrentTable(stmt), column, dataType}

		query := "ALTER TABLE ? ADD ? ?"
		if field.PrimaryKey || strings.Contains(strings.ToLower(dataType.SQL), "auto_increment") {
			query += ", ADD PRIMARY KEY (?)"
			args = append(args, column)
		}
		return m.DB.Exec(query, args...).Error
	})
}
// AlterColumn re-issues the full column definition of the schema field named
// field using "ALTER TABLE ... MODIFY COLUMN".
//
// When the dialector sets DontSupportRenameColumnUnique, the first
// " UNIQUE " token is removed from the definition before it is sent. An error
// is returned when the statement has no schema or the field is unknown.
func (m Migrator) AlterColumn(value interface{}, field string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		var target *schema.Field
		if stmt.Schema != nil {
			target = stmt.Schema.LookUpField(field)
		}
		if target == nil {
			return fmt.Errorf("failed to look up field with name: %s", field)
		}

		dataType := m.FullDataTypeOf(target)
		if m.Dialector.DontSupportRenameColumnUnique {
			dataType.SQL = strings.Replace(dataType.SQL, " UNIQUE ", " ", 1)
		}

		table := clause.Table{Name: stmt.Table}
		column := clause.Column{Name: target.DBName}
		return m.DB.Exec("ALTER TABLE ? MODIFY COLUMN ? ?", table, column, dataType).Error
	})
}
// TiDBVersion parses the dialector's ServerVersion as a TiDB version string.
//
// The server version is split on "-". It is treated as TiDB only when there
// are at least three parts and the second one is "TiDB"; the third part, with
// an optional leading "v" removed, must then be "major.minor.patch".
//
// When the server is not TiDB, all results are zero and err is nil. When the
// server looks like TiDB but the version cannot be parsed (too few
// components or a non-numeric component), isTiDB is false and err describes
// the offending part. On success isTiDB is true and major, minor and patch
// hold the parsed numbers.
func (m Migrator) TiDBVersion() (isTiDB bool, major, minor, patch int, err error) {
	parts := strings.Split(m.Dialector.ServerVersion, "-")
	if len(parts) < 3 || parts[1] != "TiDB" {
		return false, 0, 0, 0, nil
	}

	rawVersion := strings.TrimPrefix(parts[2], "v")
	components := strings.Split(rawVersion, ".")
	// Guard the indexing below: a truncated version such as "v6" or "v6.5"
	// must produce an error rather than an index-out-of-range panic.
	if len(components) < 3 {
		return false, 0, 0, 0, fmt.Errorf("failed to parse the version of TiDB, expected major.minor.patch but got: %s", rawVersion)
	}

	if major, err = strconv.Atoi(components[0]); err != nil {
		return false, major, minor, patch, fmt.Errorf("failed to parse the version of TiDB, the major version is: %s", components[0])
	}
	if minor, err = strconv.Atoi(components[1]); err != nil {
		return false, major, minor, patch, fmt.Errorf("failed to parse the version of TiDB, the minor version is: %s", components[1])
	}
	if patch, err = strconv.Atoi(components[2]); err != nil {
		return false, major, minor, patch, fmt.Errorf("failed to parse the version of TiDB, the patch version is: %s", components[2])
	}

	return true, major, minor, patch, nil
}
// RenameColumn renames a column of the table backing value.
//
// Unless the dialector sets DontSupportRenameColumn, the call is delegated to
// the embedded migrator. Otherwise the rename is done with
// "ALTER TABLE ... CHANGE old new <definition>": oldName and newName are each
// resolved to their DB column name when they match a schema field, and the
// definition comes from the last field that matched (newName wins over
// oldName). An error is returned when neither name matches a field.
func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		if !m.Dialector.DontSupportRenameColumn {
			return m.Migrator.RenameColumn(value, oldName, newName)
		}

		var target *schema.Field
		if sch := stmt.Schema; sch != nil {
			if oldField := sch.LookUpField(oldName); oldField != nil {
				oldName, target = oldField.DBName, oldField
			}
			if newField := sch.LookUpField(newName); newField != nil {
				newName, target = newField.DBName, newField
			}
		}
		if target == nil {
			return fmt.Errorf("failed to look up field with name: %s", newName)
		}

		return m.DB.Exec(
			"ALTER TABLE ? CHANGE ? ? ?",
			clause.Table{Name: stmt.Table},
			clause.Column{Name: oldName},
			clause.Column{Name: newName},
			m.FullDataTypeOf(target),
		).Error
	})
}
// DropConstraint removes the constraint called name from the table backing
// value.
//
// Unless the dialector config sets DontSupportDropConstraint, the call is
// delegated to the embedded migrator. Otherwise the constraint is looked up
// with GuessConstraintInterfaceAndTable and dropped with the statement that
// matches its kind: DROP FOREIGN KEY for *schema.Constraint, DROP CHECK for
// *schema.CheckConstraint. Anything else is dropped as an index if an index
// of that name exists; if none exists the call is a no-op.
func (m Migrator) DropConstraint(value interface{}, name string) error {
	if !m.Dialector.Config.DontSupportDropConstraint {
		return m.Migrator.DropConstraint(value, name)
	}

	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)

		var dropSQL string
		if constraint != nil {
			name = constraint.GetName()
			switch constraint.(type) {
			case *schema.Constraint:
				dropSQL = "ALTER TABLE ? DROP FOREIGN KEY ?"
			case *schema.CheckConstraint:
				dropSQL = "ALTER TABLE ? DROP CHECK ?"
			}
		}

		// Neither a foreign key nor a check constraint: fall back to an index.
		if dropSQL == "" {
			if !m.HasIndex(value, name) {
				return nil
			}
			dropSQL = "ALTER TABLE ? DROP INDEX ?"
		}

		return m.DB.Exec(dropSQL, clause.Table{Name: table}, clause.Column{Name: name}).Error
	})
}
// RenameIndex renames index oldName to newName on the table backing value.
//
// Unless the dialector sets DontSupportRenameIndex, a single
// "ALTER TABLE ... RENAME INDEX" is issued. Otherwise the old index is
// dropped and a replacement is created: when the schema has no index called
// newName but does have one called oldName, the new index is built from the
// old definition (class, fields and type) under the new name; in every other
// case CreateIndex(value, newName) is used.
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		if !m.Dialector.DontSupportRenameIndex {
			return m.DB.Exec(
				"ALTER TABLE ? RENAME INDEX ? TO ?",
				clause.Table{Name: stmt.Table},
				clause.Column{Name: oldName},
				clause.Column{Name: newName},
			).Error
		}

		if err := m.DropIndex(value, oldName); err != nil {
			return err
		}

		if stmt.Schema == nil || stmt.Schema.LookIndex(newName) != nil {
			return m.CreateIndex(value, newName)
		}
		oldIdx := stmt.Schema.LookIndex(oldName)
		if oldIdx == nil {
			return m.CreateIndex(value, newName)
		}

		// Recreate the old definition under the new name.
		var query strings.Builder
		query.WriteString("CREATE ")
		if oldIdx.Class != "" {
			query.WriteString(oldIdx.Class + " ")
		}
		query.WriteString("INDEX ? ON ??")
		if oldIdx.Type != "" {
			query.WriteString(" USING " + oldIdx.Type)
		}

		indexOptions := m.BuildIndexOptions(oldIdx.Fields, stmt)
		return m.DB.Exec(
			query.String(),
			clause.Column{Name: newName},
			clause.Table{Name: stmt.Table},
			indexOptions,
		).Error
	})
}
// DropTable drops the tables backing the given models.
//
// The models are first passed through ReorderModels and then dropped from
// the last to the first on a single connection, with FOREIGN_KEY_CHECKS
// switched off before the first drop and back on after the last one. The
// first failing drop aborts the operation and its error is returned.
//
// NOTE(review): the result of the initial "SET FOREIGN_KEY_CHECKS = 0" is not
// inspected, and when a drop fails the function returns before
// "SET FOREIGN_KEY_CHECKS = 1" is issued on this connection — confirm whether
// that is intended.
func (m Migrator) DropTable(values ...interface{}) error {
	ordered := m.ReorderModels(values, false)

	return m.DB.Connection(func(tx *gorm.DB) error {
		dropOne := func(model interface{}) error {
			return m.RunWithValue(model, func(stmt *gorm.Statement) error {
				return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", clause.Table{Name: stmt.Table}).Error
			})
		}

		tx.Exec("SET FOREIGN_KEY_CHECKS = 0;")
		for remaining := len(ordered); remaining > 0; remaining-- {
			if err := dropOne(ordered[remaining-1]); err != nil {
				return err
			}
		}
		return tx.Exec("SET FOREIGN_KEY_CHECKS = 1;").Error
	})
}
// ColumnTypes describes every column of the table backing value, in
// ORDINAL_POSITION order.
//
// Two queries are run. A one-row SELECT on the table provides the
// driver-level *sql.ColumnType of each column. A query against
// information_schema.columns provides the name, default, nullability, data
// type, length, full column type, key kind, extra flags, comment, numeric
// precision and scale, plus datetime_precision unless
// DisableDatetimePrecision is set.
//
// The returned slice is never nil. If an error occurs, the columns collected
// so far are returned together with the error.
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
	columnTypes := make([]gorm.ColumnType, 0)
	err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
		currentDatabase, table := m.CurrentSchema(stmt, stmt.Table)

		// This query is only used for its column metadata; no row is scanned.
		rows, err := m.DB.Session(&gorm.Session{}).Table(table).Limit(1).Rows()
		if err != nil {
			return err
		}
		rawColumnTypes, err := rows.ColumnTypes()
		if err != nil {
			// Release the result set before bailing out; the ColumnTypes
			// error is the one worth reporting.
			_ = rows.Close()
			return err
		}
		if err := rows.Close(); err != nil {
			return err
		}

		columnTypeSQL := "SELECT column_name, column_default, is_nullable = 'YES', data_type, character_maximum_length, column_type, column_key, extra, column_comment, numeric_precision, numeric_scale "
		if !m.DisableDatetimePrecision {
			columnTypeSQL += ", datetime_precision "
		}
		columnTypeSQL += "FROM information_schema.columns WHERE table_schema = ? AND table_name = ? ORDER BY ORDINAL_POSITION"

		columns, rowErr := m.DB.Table(table).Raw(columnTypeSQL, currentDatabase, table).Rows()
		if rowErr != nil {
			return rowErr
		}
		defer columns.Close()

		for columns.Next() {
			var (
				column            migrator.ColumnType
				datetimePrecision sql.NullInt64
				extraValue        sql.NullString
				columnKey         sql.NullString
				values            = []interface{}{
					&column.NameValue, &column.DefaultValueValue, &column.NullableValue, &column.DataTypeValue, &column.LengthValue, &column.ColumnTypeValue, &columnKey, &extraValue, &column.CommentValue, &column.DecimalSizeValue, &column.ScaleValue,
				}
			)

			// datetime_precision is only selected when it is not disabled,
			// so only then is there a column to scan it from.
			if !m.DisableDatetimePrecision {
				values = append(values, &datetimePrecision)
			}

			if scanErr := columns.Scan(values...); scanErr != nil {
				return scanErr
			}

			// Primary-key and unique flags default to a known "false" and
			// are flipped according to column_key.
			column.PrimaryKeyValue = sql.NullBool{Bool: false, Valid: true}
			column.UniqueValue = sql.NullBool{Bool: false, Valid: true}
			switch columnKey.String {
			case "PRI":
				column.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
			case "UNI":
				column.UniqueValue = sql.NullBool{Bool: true, Valid: true}
			}

			if strings.Contains(extraValue.String, "auto_increment") {
				column.AutoIncrementValue = sql.NullBool{Bool: true, Valid: true}
			}

			// Strip surrounding single quotes from the default value,
			// repeatedly, unless the closing quote is preceded by a
			// backslash. A bare '' collapses to the empty string.
			s := column.DefaultValueValue.String
			for (len(s) >= 3 && s[0] == '\'' && s[len(s)-1] == '\'' && s[len(s)-2] != '\\') ||
				(len(s) == 2 && s == "''") {
				s = s[1 : len(s)-1]
			}
			column.DefaultValueValue.String = s

			// With DontSupportNullAsDefaultValue set, a literal "NULL"
			// default is reported as "no default".
			if m.Dialector.DontSupportNullAsDefaultValue {
				if column.DefaultValueValue.Valid && column.DefaultValueValue.String == "NULL" {
					column.DefaultValueValue.Valid = false
					column.DefaultValueValue.String = ""
				}
			}

			// A non-NULL datetime precision replaces the numeric precision.
			if datetimePrecision.Valid {
				column.DecimalSizeValue = datetimePrecision
			}

			// Attach the driver-level type with the same column name, if any.
			for _, c := range rawColumnTypes {
				if c.Name() == column.NameValue.String {
					column.SQLColumnType = c
					break
				}
			}

			columnTypes = append(columnTypes, column)
		}

		// Next returning false can mean an iteration failure as well as the
		// end of the result set; surface the former instead of returning a
		// silently truncated list.
		return columns.Err()
	})

	return columnTypes, err
}
// CurrentDatabase returns the name of a schema from
// information_schema.SCHEMATA whose name starts with the database name
// reported by the embedded migrator. An exact match is preferred; otherwise
// the first name in ascending order is taken. The query error, if any, is not
// inspected, in which case name keeps whatever Scan left in it.
func (m Migrator) CurrentDatabase() (name string) {
	const lookupSQL = "SELECT SCHEMA_NAME from Information_schema.SCHEMATA where SCHEMA_NAME LIKE ? ORDER BY SCHEMA_NAME=? DESC,SCHEMA_NAME limit 1"

	base := m.Migrator.CurrentDatabase()
	prefixPattern := base + "%"
	m.DB.Raw(lookupSQL, prefixPattern, base).Scan(&name)
	return name
}
// GetTables returns the names of all tables that information_schema.tables
// lists for the current database (as resolved by CurrentDatabase).
func (m Migrator) GetTables() (tableList []string, err error) {
	const listSQL = "SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?"

	result := m.DB.Raw(listSQL, m.CurrentDatabase()).Scan(&tableList)
	return tableList, result.Error
}
// GetIndexes returns the indexes of the table backing value.
//
// The per-column rows returned by indexSql are grouped by index name; each
// group becomes one migrator.Index whose column list follows the row order.
// An index is flagged as primary when it is named "PRIMARY" and as unique
// when its NON_UNIQUE value is 0. The returned slice is never nil.
func (m Migrator) GetIndexes(value interface{}) ([]gorm.Index, error) {
	indexes := make([]gorm.Index, 0)

	err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
		dbName, table := m.CurrentSchema(stmt, stmt.Table)

		rows := make([]*Index, 0)
		if scanErr := m.DB.Table(table).Raw(indexSql, dbName, table).Scan(&rows).Error; scanErr != nil {
			return scanErr
		}

		byName, orderedNames := groupByIndexName(rows)
		for _, indexName := range orderedNames {
			members := byName[indexName]
			if len(members) == 0 {
				continue
			}

			columnNames := make([]string, 0, len(members))
			for _, member := range members {
				columnNames = append(columnNames, member.ColumnName)
			}

			// Table name, index name and uniqueness are shared by every row
			// of the group, so they are read from the first one.
			head := members[0]
			indexes = append(indexes, &migrator.Index{
				TableName:       head.TableName,
				NameValue:       head.IndexName,
				ColumnList:      columnNames,
				PrimaryKeyValue: sql.NullBool{Bool: head.IndexName == "PRIMARY", Valid: true},
				UniqueValue:     sql.NullBool{Bool: head.NonUnique == 0, Valid: true},
			})
		}
		return nil
	})

	return indexes, err
}
// Index is one row of the indexSql result set: a single column that takes
// part in an index. An index spanning several columns is represented by
// several Index values sharing the same IndexName.
type Index struct {
// TableName is read from the TABLE_NAME column.
TableName string `gorm:"column:TABLE_NAME"`
// ColumnName is read from the COLUMN_NAME column.
ColumnName string `gorm:"column:COLUMN_NAME"`
// IndexName is read from the INDEX_NAME column.
IndexName string `gorm:"column:INDEX_NAME"`
// NonUnique is read from the NON_UNIQUE column; GetIndexes treats 0 as "unique".
NonUnique int32 `gorm:"column:NON_UNIQUE"`
}
// groupByIndexName buckets indexList by IndexName.
//
// It returns the buckets, each keeping its entries in input order, together
// with the distinct index names in the order they were first encountered.
func groupByIndexName(indexList []*Index) (map[string][]*Index, []string) {
	grouped := make(map[string][]*Index, len(indexList))
	order := make([]string, 0, len(indexList))

	for i := range indexList {
		entry := indexList[i]
		name := entry.IndexName

		members, seen := grouped[name]
		if !seen {
			order = append(order, name)
		}
		grouped[name] = append(members, entry)
	}

	return grouped, order
}
// CurrentSchema splits table into a schema name and a table name.
//
// A name of the form "schema.table" (exactly one dot) is returned as its two
// halves. Any other name is returned unchanged, paired with the result of
// CurrentDatabase evaluated on a DB scoped to that table. Because the
// receiver is a value, the reassignment of m.DB only affects this call.
func (m Migrator) CurrentSchema(stmt *gorm.Statement, table string) (string, string) {
	parts := strings.Split(table, `.`)
	if len(parts) != 2 {
		m.DB = m.DB.Table(table)
		return m.CurrentDatabase(), table
	}
	return parts[0], parts[1]
}
// GetTypeAliases returns the alias type names registered in typeAliasMap for
// databaseTypeName, or nil when the type has no registered aliases.
func (m Migrator) GetTypeAliases(databaseTypeName string) []string {
	if aliases, found := typeAliasMap[databaseTypeName]; found {
		return aliases
	}
	return nil
}
// TableType looks up the table backing value in information_schema.tables
// and returns its schema, name, type and comment. A scan failure (including
// the table not being listed) is returned as err.
func (m Migrator) TableType(value interface{}) (tableType gorm.TableType, err error) {
	var info migrator.TableType

	err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
		const tableTypeSQL = "SELECT table_schema, table_name, table_type, table_comment FROM information_schema.tables WHERE table_schema = ? AND table_name = ?"

		dbName, tableName := m.CurrentSchema(stmt, stmt.Table)
		row := m.DB.Table(tableName).Raw(tableTypeSQL, dbName, tableName).Row()
		return row.Scan(&info.SchemaValue, &info.NameValue, &info.TypeValue, &info.CommentValue)
	})

	return info, err
}
The pages are generated with Golds v0.8.2. (GOOS=linux GOARCH=amd64)
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news of Golds.