diff --git a/internal/dao/sakila/sqlx.go b/internal/dao/sakila/sqlx.go index 5b76e5aa..d5c2db8e 100644 --- a/internal/dao/sakila/sqlx.go +++ b/internal/dao/sakila/sqlx.go @@ -48,6 +48,38 @@ func (s *sqlxServant) withTx(ctx context.Context, opts *sql.TxOptions, handle fu return tx.Commit() } +func (s *sqlxServant) in(query string, args ...any) (string, []any, error) { + q, params, err := sqlx.In(query, args...) + if err != nil { + return "", nil, err + } + return s.db.Rebind(q), params, nil +} + +func (s *sqlxServant) inExec(execer sqlx.Execer, query string, args ...any) (sql.Result, error) { + q, params, err := sqlx.In(query, args...) + if err != nil { + return nil, err + } + return execer.Exec(s.db.Rebind(q), params...) +} + +func (s *sqlxServant) inSelect(queryer sqlx.Queryer, dest any, query string, args ...any) error { + q, params, err := sqlx.In(query, args...) + if err != nil { + return err + } + return sqlx.Select(queryer, dest, s.db.Rebind(q), params...) +} + +func (s *sqlxServant) inGet(queryer sqlx.Queryer, dest any, query string, args ...any) error { + q, params, err := sqlx.In(query, args...) + if err != nil { + return err + } + return sqlx.Get(queryer, dest, s.db.Rebind(q), params...) +} + func newSqlxServant(db *sqlx.DB) *sqlxServant { return &sqlxServant{ db: db, @@ -61,14 +93,6 @@ func sqlxDB() *sqlx.DB { return _db } -func in(db *sqlx.DB, query string, args ...interface{}) (string, []interface{}, error) { - q, params, err := sqlx.In(query, args...) - if err != nil { - return "", nil, err - } - return db.Rebind(q), params, nil -} - func r(query string) string { db := sqlxDB() return db.Rebind(t(query)) diff --git a/internal/dao/sakila/topics.go b/internal/dao/sakila/topics.go index 05c4613a..36f2cce4 100644 --- a/internal/dao/sakila/topics.go +++ b/internal/dao/sakila/topics.go @@ -26,24 +26,23 @@ type topicServant struct { stmtTagsByIdA string stmtTagsByIdB string stmtDecrTagsById string - stmtTagsByName string + stmtTagsForIncr string stmtIncrTagsById string } func (s *topicServant) UpsertTags(userId int64, tags []string) (res []*core.Tag, xerr error) { - if len(tags) <= 0 { + if len(tags) == 0 { return nil, nil } xerr = s.with(func(tx *sqlx.Tx) error { - query, args, err := in(s.db, s.stmtTagsByName, tags) - var ts []*core.Tag - if err = tx.Select(&ts, query, args...); err != nil { + var upTags []*core.Tag + if err := s.inSelect(tx, &upTags, s.stmtTagsForIncr, tags); err != nil { return err } now := time.Now().Unix() - if len(ts) > 0 { + if len(upTags) > 0 { var ids []int64 - for _, t := range ts { + for _, t := range upTags { ids = append(ids, t.ID) t.QuoteNum++ // prepare remain tags just delete updated tag @@ -57,13 +56,10 @@ func (s *topicServant) UpsertTags(userId int64, tags []string) (res []*core.Tag, } } } - if query, args, err = in(s.db, s.stmtIncrTagsById, now, ids); err != nil { + if _, err := s.inExec(tx, s.stmtIncrTagsById, now, ids); err != nil { return err } - if _, err = tx.Exec(query, args...); err != nil { - return err - } - res = append(res, ts...) + res = append(res, upTags...) } // process remain tags if tags is not empty if len(tags) == 0 { @@ -82,11 +78,7 @@ func (s *topicServant) UpsertTags(userId int64, tags []string) (res []*core.Tag, ids = append(ids, id) } var newTags []*core.Tag - query, args, err = in(s.db, s.stmtTagsByIdB, ids) - if err != nil { - return err - } - if err = tx.Select(&newTags, query, args...); err != nil { + if err := s.inSelect(tx, &newTags, s.stmtTagsByIdB, ids); err != nil { return err } res = append(res, newTags...) @@ -97,16 +89,12 @@ func (s *topicServant) UpsertTags(userId int64, tags []string) (res []*core.Tag, func (s *topicServant) DecrTagsById(ids []int64) error { return s.with(func(tx *sqlx.Tx) error { - query, args, err := in(s.db, s.stmtTagsByIdA, ids) - if err != nil { - return err - } var ids []int64 - if err = tx.Select(&ids, query, args...); err != nil { + err := s.inSelect(tx, &ids, s.stmtTagsByIdA, ids) + if err != nil { return err } - query, args, err = in(s.db, s.stmtDecrTagsById, time.Now().Unix(), ids) - _, err = tx.Exec(query, args...) + _, err = s.inExec(tx, s.stmtDecrTagsById, time.Now().Unix(), ids) return err }) } @@ -116,7 +104,7 @@ func (s *topicServant) GetTags(category core.TagCategory, offset int, limit int) case core.TagCategoryHot: err = s.stmtHotTags.Select(&res, offset, limit) case core.TagCategoryNew: - err = s.stmtHotTags.Select(&res, offset, limit) + err = s.stmtNewestTags.Select(&res, offset, limit) } return } @@ -139,10 +127,10 @@ func newTopicService(db *sqlx.DB) core.TopicService { stmtTagsByKeywordA: c(`SELECT id, user_id, tag, quote_num FROM @tag WHERE is_del = 0 ORDER BY quote_num DESC OFFSET 0 LIMIT 6`), stmtTagsByKeywordB: c(`SELECT id, user_id, tag, quote_num FROM @tag WHERE is_del = 0 AND tag LIKE ? ORDER BY quote_num DESC OFFSET 0 LIMIT 6`), stmtInsertTag: c(`INSERT INTO @tag (user_id, tag, created_on, modified_on, quote_num) VALUES (?, ?, ?, ?, 1)`), - stmtTagsByIdA: r(`SELECT id FROM @tag WHERE id IN (?) AND is_del = 0 AND quote_num >= 0`), + stmtTagsByIdA: r(`SELECT id FROM @tag WHERE id IN (?) AND is_del = 0 AND quote_num > 0`), stmtTagsByIdB: r(`SELECT id, user_id, tag, quote_num FROM @tag WHERE id IN (?)`), stmtDecrTagsById: r(`UPDATE @tag SET quote_num=quote_num-1, modified_on=? WHERE id IN (?)`), - stmtTagsByName: r(`SELECT id, user_id, tag, quote_num FROM @tag WHERE tag IN (?) AND is_del = 0 AND quote_num >= 0`), - stmtIncrTagsById: r(`UPDATE @tag SET quote_num=quote_num+1, modified_on=? WHERE id IN (?)`), + stmtTagsForIncr: r(`SELECT id, user_id, tag, quote_num FROM @tag WHERE tag IN (?)`), + stmtIncrTagsById: r(`UPDATE @tag SET quote_num=quote_num+1, is_del=0, modified_on=? WHERE id IN (?)`), } }