From 7faa8c9ebbbbad338fb4dc957e42e6171774cd18 Mon Sep 17 00:00:00 2001 From: Elliot Maincourt Date: Sat, 28 Mar 2020 14:29:25 +0100 Subject: [PATCH] Prevent potential SQL injection Signed-off-by: Elliot Maincourt --- pkg/storage/driver/sql.go | 12 +++++++----- pkg/storage/driver/sql_test.go | 7 ++++--- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/pkg/storage/driver/sql.go b/pkg/storage/driver/sql.go index a89cb5fa9..56642666e 100644 --- a/pkg/storage/driver/sql.go +++ b/pkg/storage/driver/sql.go @@ -180,7 +180,7 @@ type SQLReleaseWrapper struct { ModifiedAt int `db:"modifiedAt"` } -// NewSQL initializes a new memory driver. +// NewSQL initializes a new sql driver. func NewSQL(dialect, connectionString string, logger func(string, ...interface{}), namespace string) (*SQL, error) { if _, ok := supportedSQLDialects[dialect]; !ok { return nil, fmt.Errorf("%s dialect isn't supported, only \"postgres\" is available for now", dialect) @@ -235,20 +235,22 @@ func (s *SQL) Get(key string) (*rspb.Release, error) { // List returns the list of all releases such that filter(release) == true func (s *SQL) List(filter func(*rspb.Release) bool) ([]*rspb.Release, error) { query := fmt.Sprintf( - "SELECT %s FROM %s WHERE %s = '%s'", + "SELECT %s FROM %s WHERE %s = $1", sqlReleaseTableBodyColumn, sqlReleaseTableName, sqlReleaseTableOwnerColumn, - sqlReleaseDefaultOwner, ) + args := []interface{}{sqlReleaseDefaultOwner} + // If a namespace was specified, we only list releases from that namespace if s.namespace != "" { - query = fmt.Sprintf("%s AND %s = '%s'", query, sqlReleaseTableNamespaceColumn, s.namespace) + query = fmt.Sprintf("%s AND %s = $2", query, sqlReleaseTableNamespaceColumn) + args = append(args, s.namespace) } var records = []SQLReleaseWrapper{} - if err := s.db.Select(&records, query); err != nil { + if err := s.db.Select(&records, query, args...); err != nil { s.Log("list: failed to list: %v", err) return nil, err } diff --git a/pkg/storage/driver/sql_test.go b/pkg/storage/driver/sql_test.go index ea3cd50a7..6486e91a8 100644 --- a/pkg/storage/driver/sql_test.go +++ b/pkg/storage/driver/sql_test.go @@ -88,15 +88,16 @@ func TestSQLList(t *testing.T) { for i := 0; i < 3; i++ { query := fmt.Sprintf( - "SELECT %s FROM %s WHERE %s = '%s'", + "SELECT %s FROM %s WHERE %s = $1 AND %s = $2", sqlReleaseTableBodyColumn, sqlReleaseTableName, sqlReleaseTableOwnerColumn, - sqlReleaseDefaultOwner, + sqlReleaseTableNamespaceColumn, ) mock. - ExpectQuery(query). + ExpectQuery(regexp.QuoteMeta(query)). + WithArgs(sqlReleaseDefaultOwner, sqlDriver.namespace). WillReturnRows( mock.NewRows([]string{ sqlReleaseTableBodyColumn,