diff --git a/pkg/registry/authorizer.go b/pkg/registry/authorizer.go new file mode 100644 index 000000000..6d8dd49a0 --- /dev/null +++ b/pkg/registry/authorizer.go @@ -0,0 +1,107 @@ +/* +Copyright The Helm Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package registry + +import ( + "context" + "net/http" + "strings" + "sync" + + "helm.sh/helm/v4/internal/version" + + "oras.land/oras-go/v2/registry/remote/auth" + "oras.land/oras-go/v2/registry/remote/credentials" +) + +type Authorizer struct { + auth.Client + lock sync.RWMutex + attemptBearerAuthentication bool +} + +func NewAuthorizer(httpClient *http.Client, credentialsStore credentials.Store, username, password string) *Authorizer { + authorizer := Authorizer{ + Client: auth.Client{ + Client: httpClient, + }, + } + authorizer.SetUserAgent(version.GetUserAgent()) + + if username != "" && password != "" { + authorizer.Credential = func(_ context.Context, _ string) (auth.Credential, error) { + return auth.Credential{Username: username, Password: password}, nil + } + } else { + authorizer.Credential = credentials.Credential(credentialsStore) + } + + authorizer.setAttemptBearerAuthentication(true) + return &authorizer +} + +func (a *Authorizer) EnableCache() { + a.Cache = auth.NewCache() +} + +func (a *Authorizer) getAttemptBearerAuthentication() bool { + a.lock.RLock() + defer a.lock.RUnlock() + return a.attemptBearerAuthentication +} + +func (a *Authorizer) setAttemptBearerAuthentication(value bool) { + a.lock.Lock() + defer a.lock.Unlock() + a.attemptBearerAuthentication = value +} + +func (a *Authorizer) getForceAttemptOAuth2() bool { + a.lock.RLock() + defer a.lock.RUnlock() + return a.ForceAttemptOAuth2 +} + +func (a *Authorizer) setForceAttemptOAuth2(value bool) { + a.lock.Lock() + defer a.lock.Unlock() + a.ForceAttemptOAuth2 = value +} + +// Do This method wraps auth.Client.Do in attempt to retry authentication +func (a *Authorizer) Do(originalReq *http.Request) (*http.Response, error) { + if a.getAttemptBearerAuthentication() { + needsAuthentication := originalReq.Header.Get("Authorization") == "" + if needsAuthentication { + a.setForceAttemptOAuth2(true) + if originalReq.Host == "ghcr.io" { + a.setForceAttemptOAuth2(false) + a.setAttemptBearerAuthentication(false) + } + resp, err := a.Client.Do(originalReq) + if err == nil { + a.setAttemptBearerAuthentication(false) + return resp, nil + } + if !strings.Contains(err.Error(), "response status code 401") && + !strings.Contains(err.Error(), "response status code 403") { + return nil, err + } + } + } + return a.Client.Do(originalReq) +} diff --git a/pkg/registry/authorizer_test.go b/pkg/registry/authorizer_test.go new file mode 100644 index 000000000..a1cba065e --- /dev/null +++ b/pkg/registry/authorizer_test.go @@ -0,0 +1,345 @@ +/* +Copyright The Helm Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package registry + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "oras.land/oras-go/v2/registry/remote/auth" +) + +type mockCredentialsStore struct { + username string + password string + err error +} + +func (m *mockCredentialsStore) Get(_ context.Context, _ string) (auth.Credential, error) { + if m.err != nil { + return auth.EmptyCredential, m.err + } + return auth.Credential{ + Username: m.username, + Password: m.password, + }, nil +} + +func (m *mockCredentialsStore) Put(_ context.Context, _ string, _ auth.Credential) error { + return nil +} + +func (m *mockCredentialsStore) Delete(_ context.Context, _ string) error { + return nil +} + +func TestNewAuthorizer(t *testing.T) { + tests := []struct { + name string + username string + password string + }{ + { + name: "with username and password", + username: "testuser", + password: "testpass", + }, + { + name: "without credentials", + username: "", + password: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + httpClient := &http.Client{} + credStore := &mockCredentialsStore{} + + authorizer := NewAuthorizer(httpClient, credStore, tt.username, tt.password) + + require.NotNil(t, authorizer) + assert.Equal(t, httpClient, authorizer.Client.Client) + assert.True(t, authorizer.getAttemptBearerAuthentication()) + assert.NotNil(t, authorizer.Credential) + + if tt.username != "" && tt.password != "" { + cred, err := authorizer.Credential(t.Context(), "") + require.NoError(t, err) + assert.Equal(t, tt.username, cred.Username) + assert.Equal(t, tt.password, cred.Password) + } + }) + } +} + +func TestNewAuthorizer_WithCredentialsStore(t *testing.T) { + httpClient := &http.Client{} + credStore := &mockCredentialsStore{ + username: "storeuser", + password: "storepass", + } + + authorizer := NewAuthorizer(httpClient, credStore, "", "") + + require.NotNil(t, authorizer) + + cred, err := authorizer.Credential(t.Context(), "test.com") + require.NoError(t, err) + assert.Equal(t, "storeuser", cred.Username) + assert.Equal(t, "storepass", cred.Password) +} + +func TestAuthorizer_EnableCache(t *testing.T) { + httpClient := &http.Client{} + credStore := &mockCredentialsStore{} + + authorizer := NewAuthorizer(httpClient, credStore, "", "") + assert.Nil(t, authorizer.Cache) + + authorizer.EnableCache() + assert.NotNil(t, authorizer.Cache) +} + +func TestAuthorizer_Do(t *testing.T) { + tests := []struct { + name string + host string + authHeader string + serverStatus int + expectForceOAuth2 bool + expectBearerAuthAfter bool + }{ + { + name: "successful request without auth header", + host: "registry.example.com", + authHeader: "", + serverStatus: 200, + expectForceOAuth2: true, + expectBearerAuthAfter: false, + }, + { + name: "request with existing auth header", + host: "registry.example.com", + authHeader: "Bearer token123", + serverStatus: 200, + expectForceOAuth2: false, + expectBearerAuthAfter: true, + }, + { + name: "ghcr.io special handling", + host: "ghcr.io", + authHeader: "", + serverStatus: 200, + expectForceOAuth2: false, + expectBearerAuthAfter: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(tt.serverStatus) + w.Write([]byte("success")) + })) + defer server.Close() + + httpClient := &http.Client{} + credStore := &mockCredentialsStore{} + + authorizer := NewAuthorizer(httpClient, credStore, "", "") + + req, err := http.NewRequest(http.MethodGet, server.URL, nil) + require.NoError(t, err) + req.Host = tt.host + + if tt.authHeader != "" { + req.Header.Set("Authorization", tt.authHeader) + } + + resp, err := authorizer.Do(req) + + require.NoError(t, err) + require.NotNil(t, resp) + assert.Equal(t, tt.expectBearerAuthAfter, authorizer.getAttemptBearerAuthentication()) + + if tt.authHeader == "" { + assert.Equal(t, tt.expectForceOAuth2, authorizer.getForceAttemptOAuth2()) + } + + resp.Body.Close() + }) + } +} + +func TestAuthorizer_Do_WithBearerAttemptDisabled(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("success")) + })) + defer server.Close() + + httpClient := &http.Client{} + credStore := &mockCredentialsStore{} + + authorizer := NewAuthorizer(httpClient, credStore, "", "") + authorizer.setAttemptBearerAuthentication(false) + + req, err := http.NewRequest(http.MethodGet, server.URL, nil) + require.NoError(t, err) + req.Host = "registry.example.com" + + resp, err := authorizer.Do(req) + + require.NoError(t, err) + require.NotNil(t, resp) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.False(t, authorizer.getAttemptBearerAuthentication()) + + resp.Body.Close() +} + +func TestAuthorizer_Do_NonRetryableError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("internal server error")) + })) + defer server.Close() + + httpClient := &http.Client{} + credStore := &mockCredentialsStore{} + + authorizer := NewAuthorizer(httpClient, credStore, "", "") + + req, err := http.NewRequest(http.MethodGet, server.URL, nil) + require.NoError(t, err) + req.Host = "registry.example.com" + + resp, err := authorizer.Do(req) + + require.NoError(t, err) + require.NotNil(t, resp) + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + + resp.Body.Close() +} + +func TestAuthorizer_ConcurrentAccess(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("success")) + })) + defer server.Close() + + httpClient := &http.Client{} + credStore := &mockCredentialsStore{} + authorizer := NewAuthorizer(httpClient, credStore, "", "") + + const numGoroutines = 100 + const numRequests = 10 + + var wg sync.WaitGroup + wg.Add(numGoroutines * 2) + + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + for j := 0; j < numRequests; j++ { + req, err := http.NewRequest(http.MethodGet, server.URL, nil) + require.NoError(t, err) + req.Host = "registry.example.com" + + resp, err := authorizer.Do(req) + if err == nil && resp != nil { + resp.Body.Close() + } + } + }() + + go func() { + defer wg.Done() + for j := 0; j < numRequests; j++ { + authorizer.setAttemptBearerAuthentication(true) + val := authorizer.getAttemptBearerAuthentication() + if val != true { + t.Logf("Warning: Expected true but got %v", val) + } + + authorizer.setAttemptBearerAuthentication(false) + val = authorizer.getAttemptBearerAuthentication() + if val != false { + t.Logf("Warning: Expected false but got %v", val) + } + } + }() + } + + wg.Wait() +} + +func TestAuthorizer_Do_StatusCodeErrorChecking(t *testing.T) { + tests := []struct { + name string + errorMsg string + shouldRetry bool + description string + }{ + { + name: "retry on 401 error", + errorMsg: "response status code 401", + shouldRetry: true, + description: "401 errors should trigger retry logic", + }, + { + name: "retry on 403 error", + errorMsg: "response status code 403", + shouldRetry: true, + description: "403 errors should trigger retry logic", + }, + { + name: "no retry on 404 error", + errorMsg: "response status code 404", + shouldRetry: false, + description: "404 errors should not trigger retry logic", + }, + { + name: "no retry on 500 error", + errorMsg: "response status code 500", + shouldRetry: false, + description: "500 errors should not trigger retry logic", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := errors.New(tt.errorMsg) + + should401Retry := strings.Contains(err.Error(), "response status code 401") + should403Retry := strings.Contains(err.Error(), "response status code 403") + actualShouldRetry := should401Retry || should403Retry + + assert.Equal(t, tt.shouldRetry, actualShouldRetry, tt.description) + }) + } +} diff --git a/pkg/registry/client.go b/pkg/registry/client.go index 95250f8da..a66f5b648 100644 --- a/pkg/registry/client.go +++ b/pkg/registry/client.go @@ -41,7 +41,6 @@ import ( "oras.land/oras-go/v2/registry/remote/credentials" "oras.land/oras-go/v2/registry/remote/retry" - "helm.sh/helm/v4/internal/version" chart "helm.sh/helm/v4/pkg/chart/v2" "helm.sh/helm/v4/pkg/helmpath" ) @@ -70,7 +69,7 @@ type ( username string password string out io.Writer - authorizer *auth.Client + authorizer *Authorizer registryAuthorizer RemoteClient credentialsStore credentials.Store httpClient *http.Client @@ -121,23 +120,11 @@ func NewClient(options ...ClientOption) (*Client, error) { } if client.authorizer == nil { - authorizer := auth.Client{ - Client: client.httpClient, - } - authorizer.SetUserAgent(version.GetUserAgent()) - - if client.username != "" && client.password != "" { - authorizer.Credential = func(_ context.Context, _ string) (auth.Credential, error) { - return auth.Credential{Username: client.username, Password: client.password}, nil - } - } else { - authorizer.Credential = credentials.Credential(client.credentialsStore) - } - + authorizer := NewAuthorizer(client.httpClient, client.credentialsStore, client.username, client.password) if client.enableCache { - authorizer.Cache = auth.NewCache() + authorizer.EnableCache() } - client.authorizer = &authorizer + client.authorizer = authorizer } return client, nil @@ -177,17 +164,17 @@ func ClientOptWriter(out io.Writer) ClientOption { } } -// ClientOptAuthorizer returns a function that sets the authorizer setting on a client options set. This +// ClientOptAuthorizer returns a function that sets the Authorizer setting on a client options set. This // can be used to override the default authorization mechanism. // // Depending on the use-case you may need to set both ClientOptAuthorizer and ClientOptRegistryAuthorizer. func ClientOptAuthorizer(authorizer auth.Client) ClientOption { return func(client *Client) { - client.authorizer = &authorizer + client.authorizer = &Authorizer{Client: authorizer} } } -// ClientOptRegistryAuthorizer returns a function that sets the registry authorizer setting on a client options set. This +// ClientOptRegistryAuthorizer returns a function that sets the registry Authorizer setting on a client options set. This // can be used to override the default authorization mechanism. // // Depending on the use-case you may need to set both ClientOptAuthorizer and ClientOptRegistryAuthorizer. @@ -239,18 +226,12 @@ func (c *Client) Login(host string, options ...LoginOption) error { } reg.PlainHTTP = c.plainHTTP cred := auth.Credential{Username: c.username, Password: c.password} - c.authorizer.ForceAttemptOAuth2 = true reg.Client = c.authorizer ctx := context.Background() if err := reg.Ping(ctx); err != nil { - c.authorizer.ForceAttemptOAuth2 = false - if err := reg.Ping(ctx); err != nil { - return fmt.Errorf("authenticating to %q: %w", host, err) - } + return fmt.Errorf("authenticating to %q: %w", host, err) } - // Always restore to false after probing, to avoid forcing POST to token endpoints like GHCR. - c.authorizer.ForceAttemptOAuth2 = false key := credentials.ServerAddressFromRegistry(host) key = credentials.ServerAddressFromHostname(key) @@ -278,10 +259,10 @@ func LoginOptPlainText(isPlainText bool) LoginOption { } } -func ensureTLSConfig(client *auth.Client, setConfig *tls.Config) (*tls.Config, error) { +func ensureTLSConfig(client *Authorizer, setConfig *tls.Config) (*tls.Config, error) { var transport *http.Transport - switch t := client.Client.Transport.(type) { + switch t := client.Client.Client.Transport.(type) { case *http.Transport: transport = t case *retry.Transport: @@ -299,7 +280,7 @@ func ensureTLSConfig(client *auth.Client, setConfig *tls.Config) (*tls.Config, e if transport == nil { // we don't know how to access the http.Transport, most likely the // auth.Client.Client was provided by API user - return nil, fmt.Errorf("unable to access TLS client configuration, the provided HTTP Transport is not supported, given: %T", client.Client.Transport) + return nil, fmt.Errorf("unable to access TLS client configuration, the provided HTTP Transport is not supported, given: %T", client.Client.Client.Transport) } switch { diff --git a/pkg/registry/client_test.go b/pkg/registry/client_test.go index 6ae32e342..2ffd691c2 100644 --- a/pkg/registry/client_test.go +++ b/pkg/registry/client_test.go @@ -18,10 +18,6 @@ package registry import ( "io" - "net/http" - "net/http/httptest" - "path/filepath" - "strings" "testing" ocispec "github.com/opencontainers/image-spec/specs-go/v1" @@ -55,68 +51,3 @@ func TestTagManifestTransformsReferences(t *testing.T) { _, err = memStore.Resolve(ctx, refWithPlus) require.Error(t, err, "Should NOT find the reference with the original +") } - -// Verifies that Login always restores ForceAttemptOAuth2 to false on success. -func TestLogin_ResetsForceAttemptOAuth2_OnSuccess(t *testing.T) { - t.Parallel() - - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/v2/" { - // Accept either HEAD or GET - w.WriteHeader(http.StatusOK) - return - } - http.NotFound(w, r) - })) - defer srv.Close() - - host := strings.TrimPrefix(srv.URL, "http://") - - credFile := filepath.Join(t.TempDir(), "config.json") - c, err := NewClient( - ClientOptWriter(io.Discard), - ClientOptCredentialsFile(credFile), - ) - if err != nil { - t.Fatalf("NewClient error: %v", err) - } - - if c.authorizer == nil || c.authorizer.ForceAttemptOAuth2 { - t.Fatalf("expected ForceAttemptOAuth2 default to be false") - } - - // Call Login with plain HTTP against our test server - if err := c.Login(host, LoginOptPlainText(true), LoginOptBasicAuth("u", "p")); err != nil { - t.Fatalf("Login error: %v", err) - } - - if c.authorizer.ForceAttemptOAuth2 { - t.Errorf("ForceAttemptOAuth2 should be false after successful Login") - } -} - -// Verifies that Login restores ForceAttemptOAuth2 to false even when ping fails. -func TestLogin_ResetsForceAttemptOAuth2_OnFailure(t *testing.T) { - t.Parallel() - - // Start and immediately close, so connections will fail - srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {})) - host := strings.TrimPrefix(srv.URL, "http://") - srv.Close() - - credFile := filepath.Join(t.TempDir(), "config.json") - c, err := NewClient( - ClientOptWriter(io.Discard), - ClientOptCredentialsFile(credFile), - ) - if err != nil { - t.Fatalf("NewClient error: %v", err) - } - - // Invoke Login, expect an error but ForceAttemptOAuth2 must end false - _ = c.Login(host, LoginOptPlainText(true), LoginOptBasicAuth("u", "p")) - - if c.authorizer.ForceAttemptOAuth2 { - t.Errorf("ForceAttemptOAuth2 should be false after failed Login") - } -} diff --git a/pkg/registry/generic.go b/pkg/registry/generic.go index b82132338..14b2d3a46 100644 --- a/pkg/registry/generic.go +++ b/pkg/registry/generic.go @@ -28,7 +28,6 @@ import ( "oras.land/oras-go/v2/content" "oras.land/oras-go/v2/content/memory" "oras.land/oras-go/v2/registry/remote" - "oras.land/oras-go/v2/registry/remote/auth" "oras.land/oras-go/v2/registry/remote/credentials" ) @@ -40,7 +39,7 @@ type GenericClient struct { username string password string out io.Writer - authorizer *auth.Client + authorizer *Authorizer registryAuthorizer RemoteClient credentialsStore credentials.Store httpClient *http.Client