diff --git a/internal/conf/db.go b/internal/conf/db.go index 37304f4d..53ce0f16 100644 --- a/internal/conf/db.go +++ b/internal/conf/db.go @@ -5,16 +5,30 @@ package conf import ( + "database/sql" "sync" + "github.com/alimy/cfg" "github.com/go-redis/redis/v8" + "github.com/sirupsen/logrus" ) var ( - _redisClient *redis.Client - _onceRedis sync.Once + _sqldb *sql.DB + _redisClient *redis.Client + _onceSql, _onceRedis sync.Once ) +func MustSqlDB() *sql.DB { + _onceSql.Do(func() { + var err error + if _, _sqldb, err = newSqlDB(); err != nil { + logrus.Fatalf("new sql db failed: %s", err) + } + }) + return _sqldb +} + func MustRedis() *redis.Client { _onceRedis.Do(func() { _redisClient = redis.NewClient(&redis.Options{ @@ -25,3 +39,19 @@ func MustRedis() *redis.Client { }) return _redisClient } + +func newSqlDB() (driver string, db *sql.DB, err error) { + if cfg.If("MySQL") { + driver = "mysql" + db, err = sql.Open(driver, MysqlSetting.Dsn()) + } else if cfg.If("PostgreSQL") || cfg.If("Postgres") { + driver = "pgx" + db, err = sql.Open(driver, PostgresSetting.Dsn()) + } else if cfg.If("Sqlite3") { + driver, db, err = OpenSqlite3() + } else { + driver = "mysql" + db, err = sql.Open(driver, MysqlSetting.Dsn()) + } + return +} diff --git a/internal/conf/db_sqlx.go b/internal/conf/db_sqlx.go index 9b4b3c2c..c2699c84 100644 --- a/internal/conf/db_sqlx.go +++ b/internal/conf/db_sqlx.go @@ -5,10 +5,8 @@ package conf import ( - "database/sql" "sync" - "github.com/alimy/cfg" "github.com/jmoiron/sqlx" "github.com/sirupsen/logrus" ) @@ -20,28 +18,11 @@ var ( func MustSqlxDB() *sqlx.DB { _onceSqlx.Do(func() { - var err error - if _sqlxdb, err = newSqlxDB(); err != nil { + driver, db, err := newSqlDB() + if err != nil { logrus.Fatalf("new sqlx db failed: %s", err) } + _sqlxdb = sqlx.NewDb(db, driver) }) return _sqlxdb } - -func newSqlxDB() (db *sqlx.DB, err error) { - if cfg.If("MySQL") { - db, err = sqlx.Open("mysql", MysqlSetting.Dsn()) - } else if cfg.If("PostgreSQL") || cfg.If("Postgres") { - db, err = sqlx.Open("pgx", PostgresSetting.Dsn()) - } else if cfg.If("Sqlite3") { - var ( - driver string - sqldb *sql.DB - ) - driver, sqldb, err = OpenSqlite3() - db = sqlx.NewDb(sqldb, driver) - } else { - db, err = sqlx.Open("mysql", MysqlSetting.Dsn()) - } - return -}