From 1ce877fb594a48e2374802dae6f8f8669d9cd703 Mon Sep 17 00:00:00 2001 From: Michael Li Date: Thu, 5 Jan 2023 14:21:29 +0800 Subject: [PATCH] sqlx: add sqlx db initial logic --- internal/conf/conf.go | 1 - internal/conf/db.go | 27 +++++++++++++++++++++++ internal/conf/db_cgo.go | 5 +++-- internal/conf/db_gorm.go | 15 ++----------- internal/conf/db_nocgo.go | 5 +++-- internal/conf/db_sqlx.go | 26 +++++++++++++++++++++- internal/migration/migration_embed.go | 2 +- internal/servants/web/broker/broker.go | 4 ++++ internal/servants/web/broker/message.go | 7 +++--- internal/servants/web/broker/post.go | 4 ++-- internal/servants/web/broker/user.go | 13 +++++------ internal/servants/web/broker/wallet.go | 8 +++---- internal/servants/web/routers/api/api.go | 3 +++ internal/servants/web/routers/api/home.go | 9 ++++---- internal/servants/web/web.go | 2 +- 15 files changed, 88 insertions(+), 43 deletions(-) create mode 100644 internal/conf/db.go diff --git a/internal/conf/conf.go b/internal/conf/conf.go index 950d2d8a..f1c84510 100644 --- a/internal/conf/conf.go +++ b/internal/conf/conf.go @@ -122,7 +122,6 @@ func Initialize(suite []string, noDefault bool) { } setupLogger() - setupDBEngine() } func GetOssDomain() string { diff --git a/internal/conf/db.go b/internal/conf/db.go new file mode 100644 index 00000000..37304f4d --- /dev/null +++ b/internal/conf/db.go @@ -0,0 +1,27 @@ +// Copyright 2023 ROC. All rights reserved. +// Use of this source code is governed by a MIT style +// license that can be found in the LICENSE file. + +package conf + +import ( + "sync" + + "github.com/go-redis/redis/v8" +) + +var ( + _redisClient *redis.Client + _onceRedis sync.Once +) + +func MustRedis() *redis.Client { + _onceRedis.Do(func() { + _redisClient = redis.NewClient(&redis.Options{ + Addr: redisSetting.Host, + Password: redisSetting.Password, + DB: redisSetting.DB, + }) + }) + return _redisClient +} diff --git a/internal/conf/db_cgo.go b/internal/conf/db_cgo.go index 218512a5..231dc16b 100644 --- a/internal/conf/db_cgo.go +++ b/internal/conf/db_cgo.go @@ -18,8 +18,9 @@ const ( sqlite3InCgoEnabled = true ) -func OpenSqlite3() (*sql.DB, error) { - return sql.Open("sqlite3", Sqlite3Setting.Dsn("sqlite3")) +func OpenSqlite3() (string, *sql.DB, error) { + db, err := sql.Open("sqlite3", Sqlite3Setting.Dsn("sqlite3")) + return "sqlite3", db, err } func gormOpenSqlite3(opts ...gorm.Option) (*gorm.DB, error) { diff --git a/internal/conf/db_gorm.go b/internal/conf/db_gorm.go index 5d28b6f2..f667ff61 100644 --- a/internal/conf/db_gorm.go +++ b/internal/conf/db_gorm.go @@ -9,7 +9,6 @@ import ( "time" "github.com/alimy/cfg" - "github.com/go-redis/redis/v8" "github.com/sirupsen/logrus" "gorm.io/driver/mysql" "gorm.io/driver/postgres" @@ -20,8 +19,6 @@ import ( ) var ( - Redis *redis.Client - _gormdb *gorm.DB _onceGorm sync.Once ) @@ -29,14 +26,14 @@ var ( func MustGormDB() *gorm.DB { _onceGorm.Do(func() { var err error - if _gormdb, err = newDBEngine(); err != nil { + if _gormdb, err = newGormDB(); err != nil { logrus.Fatalf("new gorm db failed: %s", err) } }) return _gormdb } -func newDBEngine() (*gorm.DB, error) { +func newGormDB() (*gorm.DB, error) { newLogger := logger.New( logrus.StandardLogger(), // io writer(日志输出的目标,前缀和日志包含的内容) logger.Config{ @@ -85,11 +82,3 @@ func newDBEngine() (*gorm.DB, error) { return db, err } - -func setupDBEngine() { - Redis = redis.NewClient(&redis.Options{ - Addr: redisSetting.Host, - Password: redisSetting.Password, - DB: redisSetting.DB, - }) -} diff --git a/internal/conf/db_nocgo.go b/internal/conf/db_nocgo.go index 977a6695..bbe79a50 100644 --- a/internal/conf/db_nocgo.go +++ b/internal/conf/db_nocgo.go @@ -19,8 +19,9 @@ const ( sqlite3InCgoEnabled = false ) -func OpenSqlite3() (*sql.DB, error) { - return sql.Open("sqlite", Sqlite3Setting.Dsn("sqlite")) +func OpenSqlite3() (string, *sql.DB, error) { + db, err := sql.Open("sqlite", Sqlite3Setting.Dsn("sqlite")) + return "sqlite", db, err } func gormOpenSqlite3(opts ...gorm.Option) (*gorm.DB, error) { diff --git a/internal/conf/db_sqlx.go b/internal/conf/db_sqlx.go index 941b219c..bfdc7a95 100644 --- a/internal/conf/db_sqlx.go +++ b/internal/conf/db_sqlx.go @@ -5,9 +5,12 @@ package conf import ( + "database/sql" "sync" + "github.com/alimy/cfg" "github.com/jmoiron/sqlx" + "github.com/sirupsen/logrus" ) var ( @@ -17,7 +20,28 @@ var ( func MustSqlxDB() *sqlx.DB { _onceSqlx.Do(func() { - // TODO: init sqlx.DB + var err error + if _sqlxdb, err = newSqlxDB(); err != nil { + logrus.Fatalf("new sqlx db failed: %s", err) + } }) return _sqlxdb } + +func newSqlxDB() (db *sqlx.DB, err error) { + if cfg.If("MySQL") { + db, err = sqlx.Open("mysql", MysqlSetting.Dsn()) + } else if cfg.If("PostgreSQL") || cfg.If("Postgres") { + db, err = sqlx.Open("postgres", PostgresSetting.Dsn()) + } else if cfg.If("Sqlite3") { + var ( + driver string + sqldb *sql.DB + ) + driver, sqldb, err = OpenSqlite3() + db = sqlx.NewDb(sqldb, driver) + } else { + db, err = sqlx.Open("mysql", MysqlSetting.Dsn()) + } + return +} diff --git a/internal/migration/migration_embed.go b/internal/migration/migration_embed.go index 02ede293..c57e6081 100644 --- a/internal/migration/migration_embed.go +++ b/internal/migration/migration_embed.go @@ -44,7 +44,7 @@ func Run() { dbName = (*conf.PostgresSetting)["DBName"] db, err = sql.Open("postgres", conf.PostgresSetting.Dsn()) } else if cfg.If("Sqlite3") { - db, err = conf.OpenSqlite3() + _, db, err = conf.OpenSqlite3() } else { dbName = conf.MysqlSetting.DBName db, err = sql.Open("mysql", conf.MysqlSetting.Dsn()) diff --git a/internal/servants/web/broker/broker.go b/internal/servants/web/broker/broker.go index fcc5846d..8cd39205 100644 --- a/internal/servants/web/broker/broker.go +++ b/internal/servants/web/broker/broker.go @@ -6,6 +6,8 @@ package broker import ( "github.com/alimy/cfg" + "github.com/go-redis/redis/v8" + "github.com/rocboss/paopao-ce/internal/conf" "github.com/rocboss/paopao-ce/internal/core" "github.com/rocboss/paopao-ce/internal/dao" "github.com/sirupsen/logrus" @@ -15,6 +17,7 @@ var ( ds core.DataService ts core.TweetSearchService oss core.ObjectStorageService + redisClient *redis.Client DisablePhoneVerify bool ) @@ -22,6 +25,7 @@ func Initialize() { ds = dao.DataService() ts = dao.TweetSearchService() oss = dao.ObjectStorageService() + redisClient = conf.MustRedis() DisablePhoneVerify = !cfg.If("Sms") } diff --git a/internal/servants/web/broker/message.go b/internal/servants/web/broker/message.go index 6e5dd1a1..9bb59c8b 100644 --- a/internal/servants/web/broker/message.go +++ b/internal/servants/web/broker/message.go @@ -9,7 +9,6 @@ import ( "time" "github.com/gin-gonic/gin" - "github.com/rocboss/paopao-ce/internal/conf" "github.com/rocboss/paopao-ce/internal/core" "github.com/rocboss/paopao-ce/pkg/convert" "github.com/rocboss/paopao-ce/pkg/errcode" @@ -31,7 +30,7 @@ func CreateWhisper(c *gin.Context, msg *core.Message) (*core.Message, error) { whisperKey := fmt.Sprintf("WhisperTimes:%d", msg.SenderUserID) // 今日频次限制 - if res, _ := conf.Redis.Get(c, whisperKey).Result(); convert.StrTo(res).MustInt() >= MAX_WHISPER_NUM_DAILY { + if res, _ := redisClient.Get(c, whisperKey).Result(); convert.StrTo(res).MustInt() >= MAX_WHISPER_NUM_DAILY { return nil, errcode.TooManyWhisperNum } @@ -42,11 +41,11 @@ func CreateWhisper(c *gin.Context, msg *core.Message) (*core.Message, error) { } // 写入当日(自然日)计数缓存 - conf.Redis.Incr(c, whisperKey).Result() + redisClient.Incr(c, whisperKey).Result() currentTime := time.Now() endTime := time.Date(currentTime.Year(), currentTime.Month(), currentTime.Day(), 23, 59, 59, 0, currentTime.Location()) - conf.Redis.Expire(c, whisperKey, endTime.Sub(currentTime)) + redisClient.Expire(c, whisperKey, endTime.Sub(currentTime)) return msg, err } diff --git a/internal/servants/web/broker/post.go b/internal/servants/web/broker/post.go index 9a73528b..59d87975 100644 --- a/internal/servants/web/broker/post.go +++ b/internal/servants/web/broker/post.go @@ -501,8 +501,8 @@ func DeleteSearchPost(post *core.Post) error { } func PushPostsToSearch(c *gin.Context) { - if ok, _ := conf.Redis.SetNX(c, "JOB_PUSH_TO_SEARCH", 1, time.Hour).Result(); ok { - defer conf.Redis.Del(c, "JOB_PUSH_TO_SEARCH") + if ok, _ := redisClient.SetNX(c, "JOB_PUSH_TO_SEARCH", 1, time.Hour).Result(); ok { + defer redisClient.Del(c, "JOB_PUSH_TO_SEARCH") splitNum := 1000 totalRows, _ := GetPostCount(&core.ConditionsT{ diff --git a/internal/servants/web/broker/user.go b/internal/servants/web/broker/user.go index 66919a99..655ce2ff 100644 --- a/internal/servants/web/broker/user.go +++ b/internal/servants/web/broker/user.go @@ -13,7 +13,6 @@ import ( "github.com/gin-gonic/gin" "github.com/gofrs/uuid" - "github.com/rocboss/paopao-ce/internal/conf" "github.com/rocboss/paopao-ce/internal/core" "github.com/rocboss/paopao-ce/pkg/convert" "github.com/rocboss/paopao-ce/pkg/errcode" @@ -102,7 +101,7 @@ func DoLogin(ctx *gin.Context, param *AuthRequest) (*core.User, error) { } if user.Model != nil && user.ID > 0 { - if errTimes, err := conf.Redis.Get(ctx, fmt.Sprintf("%s:%d", _LoginErrKey, user.ID)).Result(); err == nil { + if errTimes, err := redisClient.Get(ctx, fmt.Sprintf("%s:%d", _LoginErrKey, user.ID)).Result(); err == nil { if convert.StrTo(errTimes).MustInt() >= _MaxLoginErrTimes { return nil, errcode.TooManyLoginError } @@ -116,14 +115,14 @@ func DoLogin(ctx *gin.Context, param *AuthRequest) (*core.User, error) { } // 清空登录计数 - conf.Redis.Del(ctx, fmt.Sprintf("%s:%d", _LoginErrKey, user.ID)) + redisClient.Del(ctx, fmt.Sprintf("%s:%d", _LoginErrKey, user.ID)) return user, nil } // 登录错误计数 - _, err = conf.Redis.Incr(ctx, fmt.Sprintf("%s:%d", _LoginErrKey, user.ID)).Result() + _, err = redisClient.Incr(ctx, fmt.Sprintf("%s:%d", _LoginErrKey, user.ID)).Result() if err == nil { - conf.Redis.Expire(ctx, fmt.Sprintf("%s:%d", _LoginErrKey, user.ID), time.Hour).Result() + redisClient.Expire(ctx, fmt.Sprintf("%s:%d", _LoginErrKey, user.ID), time.Hour).Result() } return nil, errcode.UnauthorizedAuthFailed @@ -435,12 +434,12 @@ func SendPhoneCaptcha(ctx *gin.Context, phone string) error { } // 写入计数缓存 - conf.Redis.Incr(ctx, "PaoPaoSmsCaptcha:"+phone).Result() + redisClient.Incr(ctx, "PaoPaoSmsCaptcha:"+phone).Result() currentTime := time.Now() endTime := time.Date(currentTime.Year(), currentTime.Month(), currentTime.Day(), 23, 59, 59, 0, currentTime.Location()) - conf.Redis.Expire(ctx, "PaoPaoSmsCaptcha:"+phone, endTime.Sub(currentTime)) + redisClient.Expire(ctx, "PaoPaoSmsCaptcha:"+phone, endTime.Sub(currentTime)) return nil } diff --git a/internal/servants/web/broker/wallet.go b/internal/servants/web/broker/wallet.go index 6f89edf8..fc4392a2 100644 --- a/internal/servants/web/broker/wallet.go +++ b/internal/servants/web/broker/wallet.go @@ -5,11 +5,11 @@ package broker import ( - "github.com/rocboss/paopao-ce/internal/core" "time" + "github.com/rocboss/paopao-ce/internal/core" + "github.com/gin-gonic/gin" - "github.com/rocboss/paopao-ce/internal/conf" "github.com/rocboss/paopao-ce/pkg/errcode" ) @@ -26,7 +26,7 @@ func CreateRecharge(userID, amount int64) (*core.WalletRecharge, error) { } func FinishRecharge(ctx *gin.Context, id int64, tradeNo string) error { - if ok, _ := conf.Redis.SetNX(ctx, "PaoPaoRecharge:"+tradeNo, 1, time.Second*5).Result(); ok { + if ok, _ := redisClient.SetNX(ctx, "PaoPaoRecharge:"+tradeNo, 1, time.Second*5).Result(); ok { recharge, err := ds.GetRechargeByID(id) if err != nil { return err @@ -36,7 +36,7 @@ func FinishRecharge(ctx *gin.Context, id int64, tradeNo string) error { // 标记为已付款 err := ds.HandleRechargeSuccess(recharge, tradeNo) - defer conf.Redis.Del(ctx, "PaoPaoRecharge:"+tradeNo) + defer redisClient.Del(ctx, "PaoPaoRecharge:"+tradeNo) if err != nil { return err diff --git a/internal/servants/web/routers/api/api.go b/internal/servants/web/routers/api/api.go index 656ad0bf..bfd28666 100644 --- a/internal/servants/web/routers/api/api.go +++ b/internal/servants/web/routers/api/api.go @@ -6,6 +6,7 @@ package api import ( "github.com/alimy/cfg" + "github.com/go-redis/redis/v8" "github.com/rocboss/paopao-ce/internal/conf" "github.com/rocboss/paopao-ce/internal/core" "github.com/rocboss/paopao-ce/internal/dao" @@ -14,11 +15,13 @@ import ( ) var ( + redisClient *redis.Client alipayClient *alipay.Client objectStorage core.ObjectStorageService ) func Initialize() { + redisClient = conf.MustRedis() objectStorage = dao.ObjectStorageService() if cfg.If("Alipay") { diff --git a/internal/servants/web/routers/api/home.go b/internal/servants/web/routers/api/home.go index 30f115a9..c0a946da 100644 --- a/internal/servants/web/routers/api/home.go +++ b/internal/servants/web/routers/api/home.go @@ -14,7 +14,6 @@ import ( "github.com/afocus/captcha" "github.com/gin-gonic/gin" "github.com/gofrs/uuid" - "github.com/rocboss/paopao-ce/internal/conf" "github.com/rocboss/paopao-ce/internal/core" "github.com/rocboss/paopao-ce/internal/servants/web/assets" "github.com/rocboss/paopao-ce/internal/servants/web/broker" @@ -65,7 +64,7 @@ func GetCaptcha(c *gin.Context) { key := util.EncodeMD5(uuid.Must(uuid.NewV4()).String()) // 五分钟有效期 - conf.Redis.SetEX(c, "PaoPaoCaptcha:"+key, password, time.Minute*5) + redisClient.SetEX(c, "PaoPaoCaptcha:"+key, password, time.Minute*5) response := app.NewResponse(c) response.ToResponse(gin.H{ @@ -85,14 +84,14 @@ func PostCaptcha(c *gin.Context) { } // 验证图片验证码 - if res, err := conf.Redis.Get(c.Request.Context(), "PaoPaoCaptcha:"+param.ImgCaptchaID).Result(); err != nil || res != param.ImgCaptcha { + if res, err := redisClient.Get(c.Request.Context(), "PaoPaoCaptcha:"+param.ImgCaptchaID).Result(); err != nil || res != param.ImgCaptcha { response.ToErrorResponse(errcode.ErrorCaptchaPassword) return } - conf.Redis.Del(c.Request.Context(), "PaoPaoCaptcha:"+param.ImgCaptchaID).Result() + redisClient.Del(c.Request.Context(), "PaoPaoCaptcha:"+param.ImgCaptchaID).Result() // 今日频次限制 - if res, _ := conf.Redis.Get(c.Request.Context(), "PaoPaoSmsCaptcha:"+param.Phone).Result(); convert.StrTo(res).MustInt() >= MAX_PHONE_CAPTCHA { + if res, _ := redisClient.Get(c.Request.Context(), "PaoPaoSmsCaptcha:"+param.Phone).Result(); convert.StrTo(res).MustInt() >= MAX_PHONE_CAPTCHA { response.ToErrorResponse(errcode.TooManyPhoneCaptchaSend) return } diff --git a/internal/servants/web/web.go b/internal/servants/web/web.go index 286109d3..b385e8c9 100644 --- a/internal/servants/web/web.go +++ b/internal/servants/web/web.go @@ -19,7 +19,7 @@ import ( func RouteWeb(e *gin.Engine) { oss := dao.ObjectStorageService() ds := &base.DaoServant{ - Redis: conf.Redis, + Redis: conf.MustRedis(), Ds: dao.DataService(), Ts: dao.TweetSearchService(), }