Prevent potential SQL injection

Signed-off-by: Elliot Maincourt <e.maincourt@gmail.com>
pull/7635/head
Elliot Maincourt 6 years ago
parent 2c962d7bfc
commit 7faa8c9ebb

@ -180,7 +180,7 @@ type SQLReleaseWrapper struct {
ModifiedAt int `db:"modifiedAt"` ModifiedAt int `db:"modifiedAt"`
} }
// NewSQL initializes a new memory driver. // NewSQL initializes a new sql driver.
func NewSQL(dialect, connectionString string, logger func(string, ...interface{}), namespace string) (*SQL, error) { func NewSQL(dialect, connectionString string, logger func(string, ...interface{}), namespace string) (*SQL, error) {
if _, ok := supportedSQLDialects[dialect]; !ok { if _, ok := supportedSQLDialects[dialect]; !ok {
return nil, fmt.Errorf("%s dialect isn't supported, only \"postgres\" is available for now", dialect) return nil, fmt.Errorf("%s dialect isn't supported, only \"postgres\" is available for now", dialect)
@ -235,20 +235,22 @@ func (s *SQL) Get(key string) (*rspb.Release, error) {
// List returns the list of all releases such that filter(release) == true // List returns the list of all releases such that filter(release) == true
func (s *SQL) List(filter func(*rspb.Release) bool) ([]*rspb.Release, error) { func (s *SQL) List(filter func(*rspb.Release) bool) ([]*rspb.Release, error) {
query := fmt.Sprintf( query := fmt.Sprintf(
"SELECT %s FROM %s WHERE %s = '%s'", "SELECT %s FROM %s WHERE %s = $1",
sqlReleaseTableBodyColumn, sqlReleaseTableBodyColumn,
sqlReleaseTableName, sqlReleaseTableName,
sqlReleaseTableOwnerColumn, sqlReleaseTableOwnerColumn,
sqlReleaseDefaultOwner,
) )
args := []interface{}{sqlReleaseDefaultOwner}
// If a namespace was specified, we only list releases from that namespace // If a namespace was specified, we only list releases from that namespace
if s.namespace != "" { if s.namespace != "" {
query = fmt.Sprintf("%s AND %s = '%s'", query, sqlReleaseTableNamespaceColumn, s.namespace) query = fmt.Sprintf("%s AND %s = $2", query, sqlReleaseTableNamespaceColumn)
args = append(args, s.namespace)
} }
var records = []SQLReleaseWrapper{} var records = []SQLReleaseWrapper{}
if err := s.db.Select(&records, query); err != nil { if err := s.db.Select(&records, query, args...); err != nil {
s.Log("list: failed to list: %v", err) s.Log("list: failed to list: %v", err)
return nil, err return nil, err
} }

@ -88,15 +88,16 @@ func TestSQLList(t *testing.T) {
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
query := fmt.Sprintf( query := fmt.Sprintf(
"SELECT %s FROM %s WHERE %s = '%s'", "SELECT %s FROM %s WHERE %s = $1 AND %s = $2",
sqlReleaseTableBodyColumn, sqlReleaseTableBodyColumn,
sqlReleaseTableName, sqlReleaseTableName,
sqlReleaseTableOwnerColumn, sqlReleaseTableOwnerColumn,
sqlReleaseDefaultOwner, sqlReleaseTableNamespaceColumn,
) )
mock. mock.
ExpectQuery(query). ExpectQuery(regexp.QuoteMeta(query)).
WithArgs(sqlReleaseDefaultOwner, sqlDriver.namespace).
WillReturnRows( WillReturnRows(
mock.NewRows([]string{ mock.NewRows([]string{
sqlReleaseTableBodyColumn, sqlReleaseTableBodyColumn,

Loading…
Cancel
Save