fix: create database name (#1285)

pull/1295/head
kvii 1 year ago committed by GitHub
parent 38ab3e0ed7
commit 7722714251
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -18,15 +18,12 @@ import (
"fmt" "fmt"
"time" "time"
mysqldriver "github.com/go-sql-driver/mysql"
"gorm.io/driver/mysql"
"github.com/OpenIMSDK/tools/errs" "github.com/OpenIMSDK/tools/errs"
"github.com/OpenIMSDK/tools/log" "github.com/OpenIMSDK/tools/log"
"github.com/OpenIMSDK/tools/mw/specialerror" "github.com/OpenIMSDK/tools/mw/specialerror"
mysqldriver "github.com/go-sql-driver/mysql"
"github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/common/config"
"gorm.io/driver/mysql"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
) )
@ -35,56 +32,80 @@ const (
maxRetry = 100 // number of retries maxRetry = 100 // number of retries
) )
// newMysqlGormDB Initialize the database connection. type option struct {
func newMysqlGormDB() (*gorm.DB, error) { Username string
dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=true&loc=Local", Password string
config.Config.Mysql.Username, config.Config.Mysql.Password, config.Config.Mysql.Address[0], "mysql") Address []string
Database string
LogLevel int
SlowThreshold int
MaxLifeTime int
MaxOpenConn int
MaxIdleConn int
Connect func(dsn string, maxRetry int) (*gorm.DB, error)
}
db, err := connectToDatabase(dsn, maxRetry) // newMysqlGormDB Initialize the database connection.
if err != nil { func newMysqlGormDB(o *option) (*gorm.DB, error) {
panic(err.Error() + " Open failed " + dsn) err := maybeCreateTable(o)
}
sqlDB, err := db.DB()
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer sqlDB.Close() dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=true&loc=Local",
sql := fmt.Sprintf( o.Username, o.Password, o.Address[0], o.Database)
"CREATE DATABASE IF NOT EXISTS %s default charset utf8mb4 COLLATE utf8mb4_unicode_ci;",
config.Config.Mysql.Database,
)
err = db.Exec(sql).Error
if err != nil {
return nil, fmt.Errorf("init db %w", err)
}
dsn = fmt.Sprintf(
"%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=true&loc=Local",
config.Config.Mysql.Username,
config.Config.Mysql.Password,
config.Config.Mysql.Address[0],
config.Config.Mysql.Database,
)
sqlLogger := log.NewSqlLogger( sqlLogger := log.NewSqlLogger(
logger.LogLevel(config.Config.Mysql.LogLevel), logger.LogLevel(o.LogLevel),
true, true,
time.Duration(config.Config.Mysql.SlowThreshold)*time.Millisecond, time.Duration(o.SlowThreshold)*time.Millisecond,
) )
db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{ db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
Logger: sqlLogger, Logger: sqlLogger,
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
sqlDB, err = db.DB() sqlDB, err := db.DB()
if err != nil { if err != nil {
return nil, err return nil, err
} }
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(config.Config.Mysql.MaxLifeTime)) sqlDB.SetConnMaxLifetime(time.Second * time.Duration(o.MaxLifeTime))
sqlDB.SetMaxOpenConns(config.Config.Mysql.MaxOpenConn) sqlDB.SetMaxOpenConns(o.MaxOpenConn)
sqlDB.SetMaxIdleConns(config.Config.Mysql.MaxIdleConn) sqlDB.SetMaxIdleConns(o.MaxIdleConn)
return db, nil return db, nil
} }
// maybeCreateTable creates a database if it does not exists.
func maybeCreateTable(o *option) error {
dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=true&loc=Local",
o.Username, o.Password, o.Address[0], "mysql")
var db *gorm.DB
var err error
if f := o.Connect; f != nil {
db, err = f(dsn, maxRetry)
} else {
db, err = connectToDatabase(dsn, maxRetry)
}
if err != nil {
panic(err.Error() + " Open failed " + dsn)
}
sqlDB, err := db.DB()
if err != nil {
return err
}
defer sqlDB.Close()
sql := fmt.Sprintf(
"CREATE DATABASE IF NOT EXISTS `%s` default charset utf8mb4 COLLATE utf8mb4_unicode_ci",
o.Database,
)
err = db.Exec(sql).Error
if err != nil {
return fmt.Errorf("init db %w", err)
}
return nil
}
// connectToDatabase Connection retry for mysql. // connectToDatabase Connection retry for mysql.
func connectToDatabase(dsn string, maxRetry int) (*gorm.DB, error) { func connectToDatabase(dsn string, maxRetry int) (*gorm.DB, error) {
var db *gorm.DB var db *gorm.DB
@ -106,7 +127,18 @@ func connectToDatabase(dsn string, maxRetry int) (*gorm.DB, error) {
func NewGormDB() (*gorm.DB, error) { func NewGormDB() (*gorm.DB, error) {
specialerror.AddReplace(gorm.ErrRecordNotFound, errs.ErrRecordNotFound) specialerror.AddReplace(gorm.ErrRecordNotFound, errs.ErrRecordNotFound)
specialerror.AddErrHandler(replaceDuplicateKey) specialerror.AddErrHandler(replaceDuplicateKey)
return newMysqlGormDB()
return newMysqlGormDB(&option{
Username: config.Config.Mysql.Username,
Password: config.Config.Mysql.Password,
Address: config.Config.Mysql.Address,
Database: config.Config.Mysql.Database,
LogLevel: config.Config.Mysql.LogLevel,
SlowThreshold: config.Config.Mysql.SlowThreshold,
MaxLifeTime: config.Config.Mysql.MaxLifeTime,
MaxOpenConn: config.Config.Mysql.MaxOpenConn,
MaxIdleConn: config.Config.Mysql.MaxIdleConn,
})
} }
func replaceDuplicateKey(err error) errs.CodeError { func replaceDuplicateKey(err error) errs.CodeError {

@ -0,0 +1,121 @@
package relation
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"reflect"
"testing"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
func TestMaybeCreateTable(t *testing.T) {
t.Run("normal", func(t *testing.T) {
err := maybeCreateTable(&option{
Username: "root",
Password: "openIM123",
Address: []string{"172.28.0.1:13306"},
Database: "openIM_v3",
LogLevel: 4,
SlowThreshold: 500,
MaxOpenConn: 1000,
MaxIdleConn: 100,
MaxLifeTime: 60,
Connect: connect(expectExec{
query: "CREATE DATABASE IF NOT EXISTS `openIM_v3` default charset utf8mb4 COLLATE utf8mb4_unicode_ci",
args: nil,
}),
})
if err != nil {
t.Fatal(err)
}
})
t.Run("im-db", func(t *testing.T) {
err := maybeCreateTable(&option{
Username: "root",
Password: "openIM123",
Address: []string{"172.28.0.1:13306"},
Database: "im-db",
LogLevel: 4,
SlowThreshold: 500,
MaxOpenConn: 1000,
MaxIdleConn: 100,
MaxLifeTime: 60,
Connect: connect(expectExec{
query: "CREATE DATABASE IF NOT EXISTS `im-db` default charset utf8mb4 COLLATE utf8mb4_unicode_ci",
args: nil,
}),
})
if err != nil {
t.Fatal(err)
}
})
t.Run("err", func(t *testing.T) {
e := errors.New("e")
err := maybeCreateTable(&option{
Username: "root",
Password: "openIM123",
Address: []string{"172.28.0.1:13306"},
Database: "openIM_v3",
LogLevel: 4,
SlowThreshold: 500,
MaxOpenConn: 1000,
MaxIdleConn: 100,
MaxLifeTime: 60,
Connect: connect(expectExec{
err: e,
}),
})
if !errors.Is(err, e) {
t.Fatalf("err not is e: %v", err)
}
})
}
func connect(e expectExec) func(string, int) (*gorm.DB, error) {
return func(string, int) (*gorm.DB, error) {
return gorm.Open(mysql.New(mysql.Config{
SkipInitializeWithVersion: true,
Conn: sql.OpenDB(e),
}), &gorm.Config{
Logger: logger.Discard,
})
}
}
type expectExec struct {
err error
query string
args []driver.NamedValue
}
func (c expectExec) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
if c.err != nil {
return nil, c.err
}
if query != c.query {
return nil, fmt.Errorf("query mismatch. expect: %s, got: %s", c.query, query)
}
if reflect.DeepEqual(args, c.args) {
return nil, fmt.Errorf("args mismatch. expect: %v, got: %v", c.args, args)
}
return noEffectResult{}, nil
}
func (e expectExec) Connect(context.Context) (driver.Conn, error) { return e, nil }
func (expectExec) Driver() driver.Driver { panic("not implemented") }
func (expectExec) Prepare(query string) (driver.Stmt, error) { panic("not implemented") }
func (expectExec) Close() (e error) { return }
func (expectExec) Begin() (driver.Tx, error) { panic("not implemented") }
type noEffectResult struct{}
func (noEffectResult) LastInsertId() (i int64, e error) { return }
func (noEffectResult) RowsAffected() (i int64, e error) { return }
Loading…
Cancel
Save