From 45367ca9bfa5385e993dc695928db96f37c67598 Mon Sep 17 00:00:00 2001 From: Matthias Fehr Date: Tue, 22 Feb 2022 15:34:29 +0100 Subject: [PATCH] Add transport option and tests Signed-off-by: Matthias Fehr --- pkg/getter/getter.go | 8 ++++ pkg/getter/httpgetter.go | 30 +++++++++----- pkg/getter/httpgetter_test.go | 76 +++++++++++++++++++++++++++++++++++ 3 files changed, 104 insertions(+), 10 deletions(-) diff --git a/pkg/getter/getter.go b/pkg/getter/getter.go index 7f830bac7..4f8e9da52 100644 --- a/pkg/getter/getter.go +++ b/pkg/getter/getter.go @@ -18,6 +18,7 @@ package getter import ( "bytes" + "net/http" "time" "github.com/pkg/errors" @@ -43,6 +44,7 @@ type options struct { version string registryClient *registry.Client timeout time.Duration + transport *http.Transport } // Option allows specifying various settings configurable by the user for overriding the defaults @@ -119,6 +121,12 @@ func WithUntar() Option { } } +func WithTransport(transport *http.Transport) Option { + return func(opts *options) { + opts.transport = transport + } +} + // Getter is an interface to support GET to the specified URL. type Getter interface { // Get file content by url string diff --git a/pkg/getter/httpgetter.go b/pkg/getter/httpgetter.go index 9b8cc4ad9..bd2d663b5 100644 --- a/pkg/getter/httpgetter.go +++ b/pkg/getter/httpgetter.go @@ -21,6 +21,7 @@ import ( "io" "net/http" "net/url" + "sync" "github.com/pkg/errors" @@ -33,6 +34,7 @@ import ( type HTTPGetter struct { opts options transport *http.Transport + once sync.Once } // Get performs a Get from repo.Getter and returns the body. @@ -107,11 +109,19 @@ func NewHTTPGetter(options ...Option) (Getter, error) { } func (g *HTTPGetter) httpClient() (*http.Client, error) { - if g.transport == nil { - g.transport = &http.Transport{ - DisableCompression: true, - Proxy: http.ProxyFromEnvironment, - } + var transport *http.Transport + + if g.opts.transport != nil { + transport = g.opts.transport + } else { + g.once.Do(func() { + g.transport = &http.Transport{ + DisableCompression: true, + Proxy: http.ProxyFromEnvironment, + } + }) + + transport = g.transport } if (g.opts.certFile != "" && g.opts.keyFile != "") || g.opts.caFile != "" { @@ -127,21 +137,21 @@ func (g *HTTPGetter) httpClient() (*http.Client, error) { } tlsConf.ServerName = sni - g.transport.TLSClientConfig = tlsConf + transport.TLSClientConfig = tlsConf } if g.opts.insecureSkipVerifyTLS { - if g.transport.TLSClientConfig == nil { - g.transport.TLSClientConfig = &tls.Config{ + if transport.TLSClientConfig == nil { + transport.TLSClientConfig = &tls.Config{ InsecureSkipVerify: true, } } else { - g.transport.TLSClientConfig.InsecureSkipVerify = true + transport.TLSClientConfig.InsecureSkipVerify = true } } client := &http.Client{ - Transport: g.transport, + Transport: transport, Timeout: g.opts.timeout, } diff --git a/pkg/getter/httpgetter_test.go b/pkg/getter/httpgetter_test.go index d823aec0d..5add189d4 100644 --- a/pkg/getter/httpgetter_test.go +++ b/pkg/getter/httpgetter_test.go @@ -50,6 +50,7 @@ func TestHTTPGetter(t *testing.T) { ca, pub, priv := join(cd, "rootca.crt"), join(cd, "crt.pem"), join(cd, "key.pem") insecure := false timeout := time.Second * 5 + transport := &http.Transport{} // Test with options g, err = NewHTTPGetter( @@ -59,6 +60,7 @@ func TestHTTPGetter(t *testing.T) { WithTLSClientConfig(pub, priv, ca), WithInsecureSkipVerifyTLS(insecure), WithTimeout(timeout), + WithTransport(transport), ) if err != nil { t.Fatal(err) @@ -105,6 +107,10 @@ func TestHTTPGetter(t *testing.T) { t.Errorf("Expected NewHTTPGetter to contain %s as Timeout flag, got %s", timeout, hg.opts.timeout) } + if hg.opts.transport != transport { + t.Errorf("Expected NewHTTPGetter to contain %p as Transport, got %p", transport, hg.opts.transport) + } + // Test if setting insecureSkipVerifyTLS is being passed to the ops insecure = true @@ -443,3 +449,73 @@ func verifyInsecureSkipVerify(t *testing.T, g HTTPGetter, caseName string, expec } return transport } + +func TestDefaultHTTPTransportReuse(t *testing.T) { + g := HTTPGetter{} + + httpClient1, err := g.httpClient() + + if err != nil { + t.Fatal(err) + } + + if httpClient1 == nil { + t.Fatalf("Expected non nil value for http client") + } + + transport1 := (httpClient1.Transport).(*http.Transport) + + httpClient2, err := g.httpClient() + + if err != nil { + t.Fatal(err) + } + + if httpClient2 == nil { + t.Fatalf("Expected non nil value for http client") + } + + transport2 := (httpClient2.Transport).(*http.Transport) + + if transport1 != transport2 { + t.Fatalf("Expected default transport to be reused") + } +} + +func TestHTTPTransportOption(t *testing.T) { + transport := &http.Transport{} + + g := HTTPGetter{} + g.opts.transport = transport + httpClient1, err := g.httpClient() + + if err != nil { + t.Fatal(err) + } + + if httpClient1 == nil { + t.Fatalf("Expected non nil value for http client") + } + + transport1 := (httpClient1.Transport).(*http.Transport) + + if transport1 != transport { + t.Fatalf("Expected transport option to be applied") + } + + httpClient2, err := g.httpClient() + + if err != nil { + t.Fatal(err) + } + + if httpClient2 == nil { + t.Fatalf("Expected non nil value for http client") + } + + transport2 := (httpClient2.Transport).(*http.Transport) + + if transport1 != transport2 { + t.Fatalf("Expected applied transport to be reused") + } +}