diff --git a/pkg/getter/getter.go b/pkg/getter/getter.go index 8ee08cb7f..cdf00c190 100644 --- a/pkg/getter/getter.go +++ b/pkg/getter/getter.go @@ -18,6 +18,7 @@ package getter import ( "bytes" + "context" "time" "github.com/pkg/errors" @@ -38,6 +39,7 @@ type options struct { password string userAgent string timeout time.Duration + context context.Context } // Option allows specifying various settings configurable by the user for overriding the defaults @@ -90,6 +92,13 @@ func WithTimeout(timeout time.Duration) Option { } } +// WithContext sets the context for requests +func WithContext(ctx context.Context) Option { + return func(opts *options) { + opts.context = ctx + } +} + // Getter is an interface to support GET to the specified URL. type Getter interface { // Get file content by url string diff --git a/pkg/getter/httpgetter.go b/pkg/getter/httpgetter.go index c100b2cc0..de14914b1 100644 --- a/pkg/getter/httpgetter.go +++ b/pkg/getter/httpgetter.go @@ -51,6 +51,9 @@ func (g *HTTPGetter) get(href string) (*bytes.Buffer, error) { return buf, err } + if g.opts.context != nil { + req = req.WithContext(g.opts.context) + } req.Header.Set("User-Agent", version.GetUserAgent()) if g.opts.userAgent != "" { req.Header.Set("User-Agent", g.opts.userAgent) diff --git a/pkg/getter/httpgetter_test.go b/pkg/getter/httpgetter_test.go index 90578f7b7..509793f88 100644 --- a/pkg/getter/httpgetter_test.go +++ b/pkg/getter/httpgetter_test.go @@ -16,7 +16,9 @@ limitations under the License. package getter import ( + "context" "fmt" + "github.com/pkg/errors" "io" "net/http" "net/http/httptest" @@ -28,8 +30,6 @@ import ( "testing" "time" - "github.com/pkg/errors" - "helm.sh/helm/v3/internal/tlsutil" "helm.sh/helm/v3/internal/version" "helm.sh/helm/v3/pkg/cli" @@ -50,6 +50,7 @@ func TestHTTPGetter(t *testing.T) { ca, pub, priv := join(cd, "rootca.crt"), join(cd, "crt.pem"), join(cd, "key.pem") insecure := false timeout := time.Second * 5 + ctx := context.TODO() // Test with options g, err = NewHTTPGetter( @@ -58,6 +59,7 @@ func TestHTTPGetter(t *testing.T) { WithTLSClientConfig(pub, priv, ca), WithInsecureSkipVerifyTLS(insecure), WithTimeout(timeout), + WithContext(ctx), ) if err != nil { t.Fatal(err) @@ -100,6 +102,10 @@ func TestHTTPGetter(t *testing.T) { t.Errorf("Expected NewHTTPGetter to contain %s as Timeout flag, got %s", timeout, hg.opts.timeout) } + if hg.opts.context != ctx { + t.Errorf("Expected NewHTTPGetter to contain %s as Context flag, got %s", ctx, hg.opts.context) + } + // Test if setting insecureSkipVerifyTLS is being passed to the ops insecure = true diff --git a/pkg/getter/plugingetter.go b/pkg/getter/plugingetter.go index 0d13ade57..4ac5e1996 100644 --- a/pkg/getter/plugingetter.go +++ b/pkg/getter/plugingetter.go @@ -70,6 +70,9 @@ func (p *pluginGetter) Get(href string, options ...Option) (*bytes.Buffer, error commands := strings.Split(p.command, " ") argv := append(commands[1:], p.opts.certFile, p.opts.keyFile, p.opts.caFile, href) prog := exec.Command(filepath.Join(p.base, commands[0]), argv...) + if p.opts.context != nil { + prog = exec.CommandContext(p.opts.context, filepath.Join(p.base, commands[0]), argv...) + } plugin.SetupPluginEnv(p.settings, p.name, p.base) prog.Env = os.Environ() buf := bytes.NewBuffer(nil)