From a80739239fbcbb2e98ae33c2b56b9897872dc015 Mon Sep 17 00:00:00 2001 From: Michael Li Date: Wed, 5 Apr 2023 11:22:12 +0800 Subject: [PATCH] sqlx: optimize table prefix for query logic --- internal/dao/sakila/sakila_suite_test.go | 13 ++++++ internal/dao/sakila/sqlx.go | 37 ++++++++++++++--- internal/dao/sakila/sqlx_test.go | 51 ++++++++++++++++++++++++ 3 files changed, 95 insertions(+), 6 deletions(-) create mode 100644 internal/dao/sakila/sakila_suite_test.go create mode 100644 internal/dao/sakila/sqlx_test.go diff --git a/internal/dao/sakila/sakila_suite_test.go b/internal/dao/sakila/sakila_suite_test.go new file mode 100644 index 00000000..15a529fd --- /dev/null +++ b/internal/dao/sakila/sakila_suite_test.go @@ -0,0 +1,13 @@ +package sakila_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestSakila(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Sakila Suite") +} diff --git a/internal/dao/sakila/sqlx.go b/internal/dao/sakila/sqlx.go index 78598146..4e6b440c 100644 --- a/internal/dao/sakila/sqlx.go +++ b/internal/dao/sakila/sqlx.go @@ -5,10 +5,10 @@ package sakila import ( + "bytes" "context" "database/sql" _ "embed" - "strings" "github.com/alimy/yesql" "github.com/jmoiron/sqlx" @@ -17,7 +17,8 @@ import ( ) var ( - _db *sqlx.DB + _db *sqlx.DB + _tablePrefix string ) type sqlxSrv struct { @@ -110,7 +111,27 @@ func n(query string) *sqlx.NamedStmt { // t repace table prefix for query func t(query string) string { - return strings.Replace(query, "@", conf.DatabaseSetting.TablePrefix, -1) + buf := bytes.NewBuffer(make([]byte, 0, len(query))) + qr := make([]rune, 0, len(query)) + for _, c := range query { + qr = append(qr, c) + } + isPrevAt := false + size := len(qr) + for i := 0; i < size; i++ { + if qr[i] == '@' { + if next := i + 1; next == size || (!isPrevAt && qr[next] != '@') { + buf.WriteString(_tablePrefix) + } else { + buf.WriteRune('@') + } + isPrevAt = true + } else { + buf.WriteRune(qr[i]) + isPrevAt = false + } + } + return buf.String() } // yesqlScan yesql.Scan help function @@ -122,9 +143,7 @@ func yesqlScan[T any](query yesql.SQLQuery, obj T) T { } func mustBuild[T any](db *sqlx.DB, fn func(yesql.PreparexBuilder, ...context.Context) (T, error)) T { - p := yesql.NewPreparexBuilder(db, func(query string) string { - return strings.Replace(query, "@", conf.DatabaseSetting.TablePrefix, -1) - }) + p := yesql.NewPreparexBuilder(db, t) obj, err := fn(p) if err != nil { logrus.Fatalf("build object failure: %s", err) @@ -134,4 +153,10 @@ func mustBuild[T any](db *sqlx.DB, fn func(yesql.PreparexBuilder, ...context.Con func initSqlxDB() { _db = conf.MustSqlxDB() + _tablePrefix = conf.DatabaseSetting.TablePrefix +} + +// FnTest_t just for test t(...) function not use in out package +func FnTest_t(query string) string { + return t(query) } diff --git a/internal/dao/sakila/sqlx_test.go b/internal/dao/sakila/sqlx_test.go new file mode 100644 index 00000000..98947a9f --- /dev/null +++ b/internal/dao/sakila/sqlx_test.go @@ -0,0 +1,51 @@ +package sakila_test + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "github.com/rocboss/paopao-ce/internal/dao/sakila" +) + +var _ = Describe("Sqlx", Ordered, func() { + type queries []struct { + originQuery string + fixedQuery string + } + var samples queries + + BeforeAll(func() { + samples = queries{ + { + originQuery: `SELECT * FROM @user WHERE username=?@_`, + fixedQuery: `SELECT * FROM user WHERE username=?_`, + }, + { + originQuery: `SELECT * FROM @user WHERE username=?`, + fixedQuery: `SELECT * FROM user WHERE username=?`, + }, + { + originQuery: `SELECT * FROM @@user WHERE 用户名=?`, + fixedQuery: `SELECT * FROM @@user WHERE 用户名=?`, + }, + { + originQuery: `SELECT * FROM @@user, @@@@contact WHERE 用户名=?`, + fixedQuery: `SELECT * FROM @@user, @@@@contact WHERE 用户名=?`, + }, + { + originQuery: `SELECT @@name, @@@@@id FROM @@user, @@@@contact WHERE 用户名=?`, + fixedQuery: `SELECT @@name, @@@@@id FROM @@user, @@@@contact WHERE 用户名=?`, + }, + { + originQuery: `SELECT @name, @id FROM @user, @contact WHERE 用户名=?`, + fixedQuery: `SELECT name, id FROM user, contact WHERE 用户名=?`, + }, + } + }) + + It("test internal t func", func() { + for _, t := range samples { + Expect(sakila.FnTest_t(t.originQuery)).To(Equal(t.fixedQuery)) + } + }) +})