From aa6104e442ef8b574ae88efd8b5c41004437ac88 Mon Sep 17 00:00:00 2001 From: Matthew Fisher Date: Fri, 8 Nov 2019 07:32:50 -0800 Subject: [PATCH] fix(tlsutil): accept only a CA certificate for validation Signed-off-by: Matthew Fisher --- internal/tlsutil/tls.go | 16 ++++++---- internal/tlsutil/tlsutil_test.go | 51 ++++++++++++++++++++++++++++++++ pkg/getter/httpgetter.go | 2 +- pkg/getter/httpgetter_test.go | 10 +++++++ 4 files changed, 72 insertions(+), 7 deletions(-) diff --git a/internal/tlsutil/tls.go b/internal/tlsutil/tls.go index dc123e1e5..ed7795dbe 100644 --- a/internal/tlsutil/tls.go +++ b/internal/tlsutil/tls.go @@ -26,13 +26,16 @@ import ( // NewClientTLS returns tls.Config appropriate for client auth. func NewClientTLS(certFile, keyFile, caFile string) (*tls.Config, error) { - cert, err := CertFromFilePair(certFile, keyFile) - if err != nil { - return nil, err - } - config := tls.Config{ - Certificates: []tls.Certificate{*cert}, + config := tls.Config{} + + if certFile != "" && keyFile != "" { + cert, err := CertFromFilePair(certFile, keyFile) + if err != nil { + return nil, err + } + config.Certificates = []tls.Certificate{*cert} } + if caFile != "" { cp, err := CertPoolFromFile(caFile) if err != nil { @@ -40,6 +43,7 @@ func NewClientTLS(certFile, keyFile, caFile string) (*tls.Config, error) { } config.RootCAs = cp } + return &config, nil } diff --git a/internal/tlsutil/tlsutil_test.go b/internal/tlsutil/tlsutil_test.go index c2a477272..a737226fd 100644 --- a/internal/tlsutil/tlsutil_test.go +++ b/internal/tlsutil/tlsutil_test.go @@ -81,3 +81,54 @@ func testfile(t *testing.T, file string) (path string) { } return path } + +func TestNewClientTLS(t *testing.T) { + certFile := testfile(t, testCertFile) + keyFile := testfile(t, testKeyFile) + caCertFile := testfile(t, testCaCertFile) + + cfg, err := NewClientTLS(certFile, keyFile, caCertFile) + if err != nil { + t.Error(err) + } + + if got := len(cfg.Certificates); got != 1 { + t.Fatalf("expecting 1 client certificates, got %d", got) + } + if cfg.InsecureSkipVerify { + t.Fatalf("insecure skip verify mistmatch, expecting false") + } + if cfg.RootCAs == nil { + t.Fatalf("mismatch tls RootCAs, expecting non-nil") + } + + cfg, err = NewClientTLS("", "", caCertFile) + if err != nil { + t.Error(err) + } + + if got := len(cfg.Certificates); got != 0 { + t.Fatalf("expecting 0 client certificates, got %d", got) + } + if cfg.InsecureSkipVerify { + t.Fatalf("insecure skip verify mistmatch, expecting false") + } + if cfg.RootCAs == nil { + t.Fatalf("mismatch tls RootCAs, expecting non-nil") + } + + cfg, err = NewClientTLS(certFile, keyFile, "") + if err != nil { + t.Error(err) + } + + if got := len(cfg.Certificates); got != 1 { + t.Fatalf("expecting 1 client certificates, got %d", got) + } + if cfg.InsecureSkipVerify { + t.Fatalf("insecure skip verify mistmatch, expecting false") + } + if cfg.RootCAs != nil { + t.Fatalf("mismatch tls RootCAs, expecting nil") + } +} diff --git a/pkg/getter/httpgetter.go b/pkg/getter/httpgetter.go index a36ab0321..89abfb1cf 100644 --- a/pkg/getter/httpgetter.go +++ b/pkg/getter/httpgetter.go @@ -89,7 +89,7 @@ func NewHTTPGetter(options ...Option) (Getter, error) { } func (g *HTTPGetter) httpClient() (*http.Client, error) { - if g.opts.certFile != "" && g.opts.keyFile != "" { + if (g.opts.certFile != "" && g.opts.keyFile != "") || g.opts.caFile != "" { tlsConf, err := tlsutil.NewClientTLS(g.opts.certFile, g.opts.keyFile, g.opts.caFile) if err != nil { return nil, errors.Wrap(err, "can't create TLS config for client") diff --git a/pkg/getter/httpgetter_test.go b/pkg/getter/httpgetter_test.go index 0c2c55ea3..b20085574 100644 --- a/pkg/getter/httpgetter_test.go +++ b/pkg/getter/httpgetter_test.go @@ -180,4 +180,14 @@ func TestDownloadTLS(t *testing.T) { if _, err := g.Get(u.String(), WithURL(u.String()), WithTLSClientConfig(pub, priv, ca)); err != nil { t.Error(err) } + + // test with only the CA file (see also #6635) + g, err = NewHTTPGetter() + if err != nil { + t.Fatal(err) + } + + if _, err := g.Get(u.String(), WithURL(u.String()), WithTLSClientConfig("", "", ca)); err != nil { + t.Error(err) + } }