From 6c594a8b24d349a866da05ada453c1537b10b5b2 Mon Sep 17 00:00:00 2001 From: Michael Li Date: Thu, 24 Aug 2023 09:05:40 +0800 Subject: [PATCH] sqlx: optimize get query for stmt --- go.mod | 2 +- go.sum | 4 ++-- internal/dao/sakila/contacts.go | 3 +-- internal/dao/sakila/messages.go | 2 +- internal/dao/sakila/security.go | 3 +-- internal/dao/sakila/sqlx.go | 18 ++++++++++++++++++ internal/dao/sakila/tweets.go | 6 ++---- internal/dao/sakila/user.go | 9 +++------ internal/dao/sakila/wallet.go | 2 +- 9 files changed, 30 insertions(+), 19 deletions(-) diff --git a/go.mod b/go.mod index fd829383..20de538b 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/alimy/yesql v1.9.0 github.com/aliyun/aliyun-oss-go-sdk v2.2.8+incompatible github.com/allegro/bigcache/v3 v3.1.0 - github.com/bitbus/sqlx v1.6.0 + github.com/bitbus/sqlx v1.7.0 github.com/bufbuild/connect-go v1.10.0 github.com/bytedance/sonic v1.10.0 github.com/cockroachdb/errors v1.10.0 diff --git a/go.sum b/go.sum index 662f72e0..b53b09e9 100644 --- a/go.sum +++ b/go.sum @@ -178,8 +178,8 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24 github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= -github.com/bitbus/sqlx v1.6.0 h1:ewrBydRkyHZqfOqvHVYpiBlqQtLn93B/loa2EwjpZ74= -github.com/bitbus/sqlx v1.6.0/go.mod h1:MemKLfQ600g6PxUVsIDe48PlY3wOquyW2ApeiXoynFo= +github.com/bitbus/sqlx v1.7.0 h1:n/hAlfY9bI29J9uObqAtjfITgNU2+XtY1ECnJUdmCZc= +github.com/bitbus/sqlx v1.7.0/go.mod h1:MemKLfQ600g6PxUVsIDe48PlY3wOquyW2ApeiXoynFo= github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k= github.com/bitly/go-simplejson v0.5.0/go.mod h1:cXHtHw4XUPsvGaxgjIAn8PhEWG9NfngEKAMDJEczWVA= github.com/bits-and-blooms/bitset v1.2.0/go.mod h1:gIdJ4wp64HaoK2YrL1Q5/N7Y16edYb8uY+O0FJTyyDA= diff --git a/internal/dao/sakila/contacts.go b/internal/dao/sakila/contacts.go index cf8f21ee..51a35c55 100644 --- a/internal/dao/sakila/contacts.go +++ b/internal/dao/sakila/contacts.go @@ -160,8 +160,7 @@ func (s *contactManageSrv) IsFriend(userId int64, friendId int64) (res bool) { } func (s *contactManageSrv) fetchOrNewContact(tx *sqlx.Tx, userId int64, friendId int64, status int8) (res *cs.Contact, err error) { - res = &cs.Contact{} - if err = tx.Stmtx(s.q.GetContact).Get(res, userId, friendId); err == nil { + if err = stmtGet(tx.Stmtx(s.q.GetContact), res, userId, friendId); err == nil { return } result, xerr := tx.Stmtx(s.q.CreateContact).Exec(userId, friendId, status, time.Now().Unix()) diff --git a/internal/dao/sakila/messages.go b/internal/dao/sakila/messages.go index 34f42a59..8de72d59 100644 --- a/internal/dao/sakila/messages.go +++ b/internal/dao/sakila/messages.go @@ -45,7 +45,7 @@ func (s *messageSrv) GetUnreadCount(userID int64) (res int64, err error) { } func (s *messageSrv) GetMessageByID(id int64) (res *ms.Message, err error) { - err = s.q.GetMessageById.Get(res, id) + err = stmtGet(s.q.GetMessageById, res, id) return } diff --git a/internal/dao/sakila/security.go b/internal/dao/sakila/security.go index 0cb08d08..476f5475 100644 --- a/internal/dao/sakila/security.go +++ b/internal/dao/sakila/security.go @@ -29,8 +29,7 @@ type securitySrv struct { // GetLatestPhoneCaptcha 获取最新短信验证码 func (s *securitySrv) GetLatestPhoneCaptcha(phone string) (res *ms.Captcha, err error) { - res = &ms.Captcha{} - err = s.q.GetLatestPhoneCaptcha.Get(res, phone) + err = stmtGet(s.q.GetLatestPhoneCaptcha, res, phone) return } diff --git a/internal/dao/sakila/sqlx.go b/internal/dao/sakila/sqlx.go index b1fcde49..92cc4711 100644 --- a/internal/dao/sakila/sqlx.go +++ b/internal/dao/sakila/sqlx.go @@ -35,6 +35,24 @@ func newSqlxSrv(db *sqlx.DB) *sqlxSrv { } } +//lint:ignore U1000 stmtGet +func stmtGet[T any](stmt *sqlx.Stmt, dest *T, args ...any) error { + *dest = *new(T) + return stmt.Get(dest, args...) +} + +//lint:ignore U1000 stmtGetContext +func stmtGetContext[T any](ctx context.Context, stmt *sqlx.Stmt, dest *T, args ...any) error { + *dest = *new(T) + return stmt.GetContext(ctx, dest, args...) +} + +//lint:ignore U1000 inGet +func inGet[T any](q sqlx.Queryable, dest *T, query string, args ...any) error { + *dest = *new(T) + return q.InGet(dest, query, args...) +} + //lint:ignore U1000 r func r(query string) string { return _db.Rebind(t(query)) diff --git a/internal/dao/sakila/tweets.go b/internal/dao/sakila/tweets.go index c0e16ff6..7e387805 100644 --- a/internal/dao/sakila/tweets.go +++ b/internal/dao/sakila/tweets.go @@ -440,8 +440,7 @@ func (s *tweetSrv) GetUserPostStarCount(userID int64) (res int64, err error) { } func (s *tweetSrv) GetUserPostCollection(postID, userID int64) (res *ms.PostCollection, err error) { - res = &ms.PostCollection{} - err = s.q.GetUserPostCollection.Get(res, postID, userID, userID) + err = stmtGet(s.q.GetUserPostCollection, res, postID, userID, userID) return } @@ -467,8 +466,7 @@ func (s *tweetSrv) GetPostContentsByIDs(ids []int64) (res []*ms.PostContent, err } func (s *tweetSrv) GetPostContentByID(id int64) (res *ms.PostContent, err error) { - res = &ms.PostContent{} - err = s.q.GetPostContentById.Get(res, id) + err = stmtGet(s.q.GetPostContentById, res, id) return } diff --git a/internal/dao/sakila/user.go b/internal/dao/sakila/user.go index 275624e0..e2dd7370 100644 --- a/internal/dao/sakila/user.go +++ b/internal/dao/sakila/user.go @@ -24,20 +24,17 @@ type userManageSrv struct { } func (s *userManageSrv) GetUserByID(id int64) (res *ms.User, err error) { - res = &ms.User{} - err = s.q.GetUserById.Get(res, id) + err = stmtGet(s.q.GetUserById, res, id) return } func (s *userManageSrv) GetUserByUsername(username string) (res *ms.User, err error) { - res = &ms.User{} - err = s.q.GetUserByUsername.Get(res, username) + err = stmtGet(s.q.GetUserByUsername, res, username) return } func (s *userManageSrv) GetUserByPhone(phone string) (res *ms.User, err error) { - res = &ms.User{} - err = s.q.GetUserByPhone.Get(res, phone) + err = stmtGet(s.q.GetUserByPhone, res, phone) return } diff --git a/internal/dao/sakila/wallet.go b/internal/dao/sakila/wallet.go index ab7f09a2..9801732d 100644 --- a/internal/dao/sakila/wallet.go +++ b/internal/dao/sakila/wallet.go @@ -25,7 +25,7 @@ type walletSrv struct { } func (s *walletSrv) GetRechargeByID(id int64) (res *ms.WalletRecharge, err error) { - err = s.q.GetRechargeById.Get(res, id) + err = stmtGet(s.q.GetRechargeById, res, id) return }