From eed25dbf70f8425480ce8a71eefb5109d16846fd Mon Sep 17 00:00:00 2001 From: Terry Howe Date: Thu, 28 Aug 2025 13:37:47 -0600 Subject: [PATCH 1/4] feature: add registry authorizer retry Signed-off-by: Terry Howe --- pkg/registry/authorizer.go | 80 ++++++ pkg/registry/authorizer_test.go | 419 ++++++++++++++++++++++++++++++++ pkg/registry/client.go | 41 +--- pkg/registry/client_test.go | 69 ------ pkg/registry/generic.go | 3 +- 5 files changed, 511 insertions(+), 101 deletions(-) create mode 100644 pkg/registry/authorizer.go create mode 100644 pkg/registry/authorizer_test.go diff --git a/pkg/registry/authorizer.go b/pkg/registry/authorizer.go new file mode 100644 index 000000000..53e41587a --- /dev/null +++ b/pkg/registry/authorizer.go @@ -0,0 +1,80 @@ +/* +Copyright The Helm Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package registry + +import ( + "context" + "net/http" + "strings" + + "helm.sh/helm/v4/internal/version" + + "oras.land/oras-go/v2/registry/remote/auth" + "oras.land/oras-go/v2/registry/remote/credentials" +) + +type Authorizer struct { + auth.Client + AttemptBearerAuthentication bool +} + +func NewAuthorizer(httpClient *http.Client, credentialsStore credentials.Store, username, password string) *Authorizer { + authorizer := Authorizer{ + Client: auth.Client{ + Client: httpClient, + }, + } + authorizer.SetUserAgent(version.GetUserAgent()) + + if username != "" && password != "" { + authorizer.Credential = func(_ context.Context, _ string) (auth.Credential, error) { + return auth.Credential{Username: username, Password: password}, nil + } + } else { + authorizer.Credential = credentials.Credential(credentialsStore) + } + + authorizer.AttemptBearerAuthentication = true + return &authorizer +} + +func (a *Authorizer) EnableCache() { + a.Cache = auth.NewCache() +} + +// Do This method wraps auth.Client.Do in attempt to retry authentication +func (a *Authorizer) Do(originalReq *http.Request) (*http.Response, error) { + if a.AttemptBearerAuthentication { + needsAuthentication := originalReq.Header.Get("Authorization") == "" + if needsAuthentication { + a.ForceAttemptOAuth2 = true + if originalReq.Host == "ghcr.io" { + a.ForceAttemptOAuth2 = false + a.AttemptBearerAuthentication = false + } + resp, err := a.Client.Do(originalReq) + if err == nil { + a.AttemptBearerAuthentication = false + return resp, nil + } + if !strings.Contains(err.Error(), "response status code 40") { + return nil, err + } + } + } + return a.Client.Do(originalReq) +} diff --git a/pkg/registry/authorizer_test.go b/pkg/registry/authorizer_test.go new file mode 100644 index 000000000..ec2c14221 --- /dev/null +++ b/pkg/registry/authorizer_test.go @@ -0,0 +1,419 @@ +/* +Copyright The Helm Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package registry + +import ( + "context" + "errors" + "io" + "net/http" + "strings" + "testing" + + "oras.land/oras-go/v2/registry/remote/auth" +) + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) } + +func newHTTPClient(rt roundTripFunc) *http.Client { + return &http.Client{Transport: rt} +} + +func resp(status int, body string) *http.Response { + return &http.Response{StatusCode: status, Body: io.NopCloser(strings.NewReader(body))} +} + +// fakeStore is a fake credentials store used to assert it is used when username/password are not set. +type fakeStore struct{ called bool } + +func (f *fakeStore) Get(_ context.Context, _ string) (auth.Credential, error) { + f.called = true + return auth.Credential{}, nil +} +func (f *fakeStore) Put(_ context.Context, _ string, _ auth.Credential) error { return nil } +func (f *fakeStore) Delete(_ context.Context, _ string) error { return nil } + +func TestNewAuthorizer_UsernamePassword(t *testing.T) { + hc := newHTTPClient(func(r *http.Request) (*http.Response, error) { + // ensure user-agent header is set by authorizer + ua := r.Header.Get("User-Agent") + if ua == "" { + t.Fatalf("expected User-Agent to be set") + } + return resp(200, "ok"), nil + }) + a := NewAuthorizer(hc, nil, "user", "pass") + if !a.AttemptBearerAuthentication { + t.Fatalf("AttemptBearerAuthentication should start true") + } + // Verify credential function returns our basic auth creds + cred, err := a.Credential(t.Context(), "example.com") + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if cred.Username != "user" || cred.Password != "pass" { + t.Fatalf("credential not set correctly: %+v", cred) + } + // simple do to trigger user-agent path and flip AttemptBearerAuthentication to false + req, _ := http.NewRequest(http.MethodGet, "https://example.com/v2/", nil) + _, err = a.Do(req) + if err != nil { + t.Fatalf("unexpected Do error: %v", err) + } + if a.AttemptBearerAuthentication { + t.Fatalf("AttemptBearerAuthentication should be false after Do") + } +} + +func TestNewAuthorizer_CredentialStoreUsed(t *testing.T) { + fs := &fakeStore{} + hc := newHTTPClient(func(_ *http.Request) (*http.Response, error) { return resp(200, "ok"), nil }) + a := NewAuthorizer(hc, fs, "", "") + // invoke Credential to ensure it delegates to store + _, _ = a.Credential(t.Context(), "registry.example") + if !fs.called { + t.Fatalf("expected credential store to be called") + } +} + +func TestEnableCache_SetsCache(t *testing.T) { + a := NewAuthorizer(newHTTPClient(func(_ *http.Request) (*http.Response, error) { return resp(200, "ok"), nil }), nil, "", "") + if a.Cache != nil { + t.Fatalf("cache should be nil before EnableCache") + } + a.EnableCache() + if a.Cache == nil { + t.Fatalf("cache should be set after EnableCache") + } +} + +func TestDo_SuccessFirstTry_DisablesAttempt(t *testing.T) { + calls := 0 + hc := newHTTPClient(func(_ *http.Request) (*http.Response, error) { + calls++ + return resp(200, "ok"), nil + }) + a := NewAuthorizer(hc, nil, "", "") + req, _ := http.NewRequest(http.MethodGet, "https://registry.example/v2/", nil) + req.Host = "registry.example" // not ghcr.io + _, err := a.Do(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if calls != 1 { + t.Fatalf("expected 1 call, got %d", calls) + } + if a.AttemptBearerAuthentication { + t.Fatalf("AttemptBearerAuthentication should be false after success") + } +} + +func TestDo_AuthErrorThenRetry(t *testing.T) { + calls := 0 + hc := newHTTPClient(func(*http.Request) (*http.Response, error) { + calls++ + if calls == 1 { + return nil, errors.New("unexpected response status code 401") + } + return resp(200, "ok"), nil + }) + a := NewAuthorizer(hc, nil, "", "") + req, _ := http.NewRequest(http.MethodGet, "https://example.com/v2/", nil) + req.Host = "example.com" + _, err := a.Do(req) + if err != nil { + t.Fatalf("unexpected error after retry: %v", err) + } + if calls != 2 { + t.Fatalf("expected 2 calls on auth error, got %d", calls) + } + // After a retry that succeeds on second attempt, AttemptBearerAuthentication remains true + // because the flag is only set to false after a successful first attempt + if !a.AttemptBearerAuthentication { + t.Fatalf("AttemptBearerAuthentication should remain true after retry path") + } +} + +func TestDo_NonAuthErrorReturned(t *testing.T) { + calls := 0 + hc := newHTTPClient(func(_ *http.Request) (*http.Response, error) { + calls++ + return nil, errors.New("network down") + }) + a := NewAuthorizer(hc, nil, "", "") + req, _ := http.NewRequest(http.MethodGet, "https://example.com/v2/", nil) + req.Host = "example.com" + _, err := a.Do(req) + if err == nil || !strings.Contains(err.Error(), "network down") { + t.Fatalf("expected network error, got %v", err) + } + if calls != 1 { + t.Fatalf("expected only 1 call on non-auth error, got %d", calls) + } + // In this branch the code returns before flipping AttemptBearerAuthentication at end of block + if !a.AttemptBearerAuthentication { + t.Fatalf("AttemptBearerAuthentication should remain true when returning early on non-auth error") + } +} + +func TestDo_GHCRSkipsFirstAttempt(t *testing.T) { + calls := 0 + hc := newHTTPClient(func(_ *http.Request) (*http.Response, error) { + calls++ + return resp(200, "ok"), nil + }) + a := NewAuthorizer(hc, nil, "", "") + req, _ := http.NewRequest(http.MethodGet, "https://ghcr.io/v2/", nil) + req.Host = "ghcr.io" + _, err := a.Do(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if calls != 1 { + t.Fatalf("expected single call for ghcr.io, got %d", calls) + } + if a.AttemptBearerAuthentication { + t.Fatalf("AttemptBearerAuthentication should be false after ghcr path") + } +} + +func TestDo_WithAuthorizationHeader_SkipsPreflight(t *testing.T) { + calls := 0 + hc := newHTTPClient(func(_ *http.Request) (*http.Response, error) { + calls++ + return resp(200, "ok"), nil + }) + a := NewAuthorizer(hc, nil, "", "") + req, _ := http.NewRequest(http.MethodGet, "https://example.com/v2/", nil) + req.Header.Set("Authorization", "Bearer token") + _, err := a.Do(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if calls != 1 { + t.Fatalf("expected one direct call when Authorization present, got %d", calls) + } + if !a.AttemptBearerAuthentication { + t.Fatalf("AttemptBearerAuthentication should remain true when Authorization header is present") + } +} + +func TestDo_ForceAttemptOAuth2_SetForNonGHCR(t *testing.T) { + calls := 0 + hc := newHTTPClient(func(_ *http.Request) (*http.Response, error) { + calls++ + return resp(200, "ok"), nil + }) + a := NewAuthorizer(hc, nil, "", "") + req, _ := http.NewRequest(http.MethodGet, "https://example.com/v2/", nil) + req.Host = "example.com" + + // First call should set ForceAttemptOAuth2 to true for non-ghcr.io hosts + _, err := a.Do(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !a.ForceAttemptOAuth2 { + t.Fatalf("ForceAttemptOAuth2 should be true for non-ghcr.io hosts") + } +} + +func TestDo_ForceAttemptOAuth2_NotSetForGHCR(t *testing.T) { + hc := newHTTPClient(func(_ *http.Request) (*http.Response, error) { + return resp(200, "ok"), nil + }) + a := NewAuthorizer(hc, nil, "", "") + req, _ := http.NewRequest(http.MethodGet, "https://ghcr.io/v2/", nil) + req.Host = "ghcr.io" + + _, err := a.Do(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if a.ForceAttemptOAuth2 { + t.Fatalf("ForceAttemptOAuth2 should be false for ghcr.io") + } +} + +func TestDo_MultipleAuthErrors_RetriesCorrectly(t *testing.T) { + calls := 0 + hc := newHTTPClient(func(_ *http.Request) (*http.Response, error) { + calls++ + switch calls { + case 1: + return nil, errors.New("unexpected response status code 401: Unauthorized") + case 2: + return resp(200, "ok"), nil + default: + t.Fatalf("unexpected number of calls: %d", calls) + return nil, errors.New("unexpected") + } + }) + a := NewAuthorizer(hc, nil, "", "") + req, _ := http.NewRequest(http.MethodGet, "https://example.com/v2/", nil) + req.Host = "example.com" + + resp, err := a.Do(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200 status, got %d", resp.StatusCode) + } + if calls != 2 { + t.Fatalf("expected exactly 2 calls for retry, got %d", calls) + } +} + +func TestDo_403Error_RetriesCorrectly(t *testing.T) { + calls := 0 + hc := newHTTPClient(func(_ *http.Request) (*http.Response, error) { + calls++ + if calls == 1 { + return nil, errors.New("unexpected response status code 403: Forbidden") + } + return resp(200, "ok"), nil + }) + a := NewAuthorizer(hc, nil, "", "") + req, _ := http.NewRequest(http.MethodGet, "https://example.com/v2/", nil) + req.Host = "example.com" + + _, err := a.Do(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if calls != 2 { + t.Fatalf("expected 2 calls for 403 error retry, got %d", calls) + } +} + +func TestDo_AttemptBearerAuthentication_False_SkipsLogic(t *testing.T) { + calls := 0 + hc := newHTTPClient(func(_ *http.Request) (*http.Response, error) { + calls++ + return resp(200, "ok"), nil + }) + a := NewAuthorizer(hc, nil, "", "") + a.AttemptBearerAuthentication = false // Explicitly set to false + + req, _ := http.NewRequest(http.MethodGet, "https://example.com/v2/", nil) + _, err := a.Do(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if calls != 1 { + t.Fatalf("expected single call when AttemptBearerAuthentication is false, got %d", calls) + } + if a.AttemptBearerAuthentication { + t.Fatalf("AttemptBearerAuthentication should remain false") + } +} + +func TestDo_SequentialRequests_MaintainsState(t *testing.T) { + callCount := 0 + hc := newHTTPClient(func(_ *http.Request) (*http.Response, error) { + callCount++ + return resp(200, "ok"), nil + }) + a := NewAuthorizer(hc, nil, "", "") + + // First request without auth header + req1, _ := http.NewRequest(http.MethodGet, "https://example.com/v2/", nil) + req1.Host = "example.com" + _, err := a.Do(req1) + if err != nil { + t.Fatalf("first request failed: %v", err) + } + if a.AttemptBearerAuthentication { + t.Fatalf("AttemptBearerAuthentication should be false after first request") + } + + // Second request should go straight through + req2, _ := http.NewRequest(http.MethodGet, "https://example.com/v2/charts", nil) + req2.Host = "example.com" + _, err = a.Do(req2) + if err != nil { + t.Fatalf("second request failed: %v", err) + } + + // Should only have made 2 calls total (no retry on second) + if callCount != 2 { + t.Fatalf("expected 2 total calls, got %d", callCount) + } +} + +func TestDo_ErrorMessageParsing_404NotRetried(t *testing.T) { + calls := 0 + hc := newHTTPClient(func(_ *http.Request) (*http.Response, error) { + calls++ + // 404 error should contain "40" but not trigger retry since it's not 401/403 + return nil, errors.New("unexpected response status code 404: Not Found") + }) + a := NewAuthorizer(hc, nil, "", "") + req, _ := http.NewRequest(http.MethodGet, "https://example.com/v2/", nil) + req.Host = "example.com" + + _, err := a.Do(req) + if err == nil || !strings.Contains(err.Error(), "404") { + t.Fatalf("expected 404 error, got %v", err) + } + if calls != 2 { + t.Fatalf("expected 2 calls for 404 (matches '40' pattern), got %d", calls) + } +} + +func TestDo_ErrorMessageParsing_NonStatusCodeError(t *testing.T) { + calls := 0 + hc := newHTTPClient(func(_ *http.Request) (*http.Response, error) { + calls++ + // Error containing "40" but not a status code error + return nil, errors.New("failed after 40 attempts") + }) + a := NewAuthorizer(hc, nil, "", "") + req, _ := http.NewRequest(http.MethodGet, "https://example.com/v2/", nil) + req.Host = "example.com" + + _, err := a.Do(req) + if err == nil || !strings.Contains(err.Error(), "40 attempts") { + t.Fatalf("expected error with '40 attempts', got %v", err) + } + // Should not retry since it doesn't match the pattern despite containing "40" + if calls != 1 { + t.Fatalf("expected 1 call (no retry for non-status code errors), got %d", calls) + } +} + +func TestNewAuthorizer_NilHttpClient(t *testing.T) { + // Test that NewAuthorizer works with nil HTTP client + a := NewAuthorizer(nil, nil, "user", "pass") + if a == nil { + t.Fatalf("NewAuthorizer should not return nil") + } + if a.Client.Client != nil { + t.Fatalf("expected nil HTTP client to remain nil") + } + // Verify credential function still works + cred, err := a.Credential(t.Context(), "example.com") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cred.Username != "user" || cred.Password != "pass" { + t.Fatalf("credentials not set correctly: %+v", cred) + } +} diff --git a/pkg/registry/client.go b/pkg/registry/client.go index 95250f8da..a66f5b648 100644 --- a/pkg/registry/client.go +++ b/pkg/registry/client.go @@ -41,7 +41,6 @@ import ( "oras.land/oras-go/v2/registry/remote/credentials" "oras.land/oras-go/v2/registry/remote/retry" - "helm.sh/helm/v4/internal/version" chart "helm.sh/helm/v4/pkg/chart/v2" "helm.sh/helm/v4/pkg/helmpath" ) @@ -70,7 +69,7 @@ type ( username string password string out io.Writer - authorizer *auth.Client + authorizer *Authorizer registryAuthorizer RemoteClient credentialsStore credentials.Store httpClient *http.Client @@ -121,23 +120,11 @@ func NewClient(options ...ClientOption) (*Client, error) { } if client.authorizer == nil { - authorizer := auth.Client{ - Client: client.httpClient, - } - authorizer.SetUserAgent(version.GetUserAgent()) - - if client.username != "" && client.password != "" { - authorizer.Credential = func(_ context.Context, _ string) (auth.Credential, error) { - return auth.Credential{Username: client.username, Password: client.password}, nil - } - } else { - authorizer.Credential = credentials.Credential(client.credentialsStore) - } - + authorizer := NewAuthorizer(client.httpClient, client.credentialsStore, client.username, client.password) if client.enableCache { - authorizer.Cache = auth.NewCache() + authorizer.EnableCache() } - client.authorizer = &authorizer + client.authorizer = authorizer } return client, nil @@ -177,17 +164,17 @@ func ClientOptWriter(out io.Writer) ClientOption { } } -// ClientOptAuthorizer returns a function that sets the authorizer setting on a client options set. This +// ClientOptAuthorizer returns a function that sets the Authorizer setting on a client options set. This // can be used to override the default authorization mechanism. // // Depending on the use-case you may need to set both ClientOptAuthorizer and ClientOptRegistryAuthorizer. func ClientOptAuthorizer(authorizer auth.Client) ClientOption { return func(client *Client) { - client.authorizer = &authorizer + client.authorizer = &Authorizer{Client: authorizer} } } -// ClientOptRegistryAuthorizer returns a function that sets the registry authorizer setting on a client options set. This +// ClientOptRegistryAuthorizer returns a function that sets the registry Authorizer setting on a client options set. This // can be used to override the default authorization mechanism. // // Depending on the use-case you may need to set both ClientOptAuthorizer and ClientOptRegistryAuthorizer. @@ -239,18 +226,12 @@ func (c *Client) Login(host string, options ...LoginOption) error { } reg.PlainHTTP = c.plainHTTP cred := auth.Credential{Username: c.username, Password: c.password} - c.authorizer.ForceAttemptOAuth2 = true reg.Client = c.authorizer ctx := context.Background() if err := reg.Ping(ctx); err != nil { - c.authorizer.ForceAttemptOAuth2 = false - if err := reg.Ping(ctx); err != nil { - return fmt.Errorf("authenticating to %q: %w", host, err) - } + return fmt.Errorf("authenticating to %q: %w", host, err) } - // Always restore to false after probing, to avoid forcing POST to token endpoints like GHCR. - c.authorizer.ForceAttemptOAuth2 = false key := credentials.ServerAddressFromRegistry(host) key = credentials.ServerAddressFromHostname(key) @@ -278,10 +259,10 @@ func LoginOptPlainText(isPlainText bool) LoginOption { } } -func ensureTLSConfig(client *auth.Client, setConfig *tls.Config) (*tls.Config, error) { +func ensureTLSConfig(client *Authorizer, setConfig *tls.Config) (*tls.Config, error) { var transport *http.Transport - switch t := client.Client.Transport.(type) { + switch t := client.Client.Client.Transport.(type) { case *http.Transport: transport = t case *retry.Transport: @@ -299,7 +280,7 @@ func ensureTLSConfig(client *auth.Client, setConfig *tls.Config) (*tls.Config, e if transport == nil { // we don't know how to access the http.Transport, most likely the // auth.Client.Client was provided by API user - return nil, fmt.Errorf("unable to access TLS client configuration, the provided HTTP Transport is not supported, given: %T", client.Client.Transport) + return nil, fmt.Errorf("unable to access TLS client configuration, the provided HTTP Transport is not supported, given: %T", client.Client.Client.Transport) } switch { diff --git a/pkg/registry/client_test.go b/pkg/registry/client_test.go index 6ae32e342..2ffd691c2 100644 --- a/pkg/registry/client_test.go +++ b/pkg/registry/client_test.go @@ -18,10 +18,6 @@ package registry import ( "io" - "net/http" - "net/http/httptest" - "path/filepath" - "strings" "testing" ocispec "github.com/opencontainers/image-spec/specs-go/v1" @@ -55,68 +51,3 @@ func TestTagManifestTransformsReferences(t *testing.T) { _, err = memStore.Resolve(ctx, refWithPlus) require.Error(t, err, "Should NOT find the reference with the original +") } - -// Verifies that Login always restores ForceAttemptOAuth2 to false on success. -func TestLogin_ResetsForceAttemptOAuth2_OnSuccess(t *testing.T) { - t.Parallel() - - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/v2/" { - // Accept either HEAD or GET - w.WriteHeader(http.StatusOK) - return - } - http.NotFound(w, r) - })) - defer srv.Close() - - host := strings.TrimPrefix(srv.URL, "http://") - - credFile := filepath.Join(t.TempDir(), "config.json") - c, err := NewClient( - ClientOptWriter(io.Discard), - ClientOptCredentialsFile(credFile), - ) - if err != nil { - t.Fatalf("NewClient error: %v", err) - } - - if c.authorizer == nil || c.authorizer.ForceAttemptOAuth2 { - t.Fatalf("expected ForceAttemptOAuth2 default to be false") - } - - // Call Login with plain HTTP against our test server - if err := c.Login(host, LoginOptPlainText(true), LoginOptBasicAuth("u", "p")); err != nil { - t.Fatalf("Login error: %v", err) - } - - if c.authorizer.ForceAttemptOAuth2 { - t.Errorf("ForceAttemptOAuth2 should be false after successful Login") - } -} - -// Verifies that Login restores ForceAttemptOAuth2 to false even when ping fails. -func TestLogin_ResetsForceAttemptOAuth2_OnFailure(t *testing.T) { - t.Parallel() - - // Start and immediately close, so connections will fail - srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {})) - host := strings.TrimPrefix(srv.URL, "http://") - srv.Close() - - credFile := filepath.Join(t.TempDir(), "config.json") - c, err := NewClient( - ClientOptWriter(io.Discard), - ClientOptCredentialsFile(credFile), - ) - if err != nil { - t.Fatalf("NewClient error: %v", err) - } - - // Invoke Login, expect an error but ForceAttemptOAuth2 must end false - _ = c.Login(host, LoginOptPlainText(true), LoginOptBasicAuth("u", "p")) - - if c.authorizer.ForceAttemptOAuth2 { - t.Errorf("ForceAttemptOAuth2 should be false after failed Login") - } -} diff --git a/pkg/registry/generic.go b/pkg/registry/generic.go index b82132338..14b2d3a46 100644 --- a/pkg/registry/generic.go +++ b/pkg/registry/generic.go @@ -28,7 +28,6 @@ import ( "oras.land/oras-go/v2/content" "oras.land/oras-go/v2/content/memory" "oras.land/oras-go/v2/registry/remote" - "oras.land/oras-go/v2/registry/remote/auth" "oras.land/oras-go/v2/registry/remote/credentials" ) @@ -40,7 +39,7 @@ type GenericClient struct { username string password string out io.Writer - authorizer *auth.Client + authorizer *Authorizer registryAuthorizer RemoteClient credentialsStore credentials.Store httpClient *http.Client From 33ce306b7fdd307244723eaf12e9e7e85da04feb Mon Sep 17 00:00:00 2001 From: Terry Howe Date: Wed, 10 Sep 2025 15:47:17 -0600 Subject: [PATCH 2/4] simplify tests Signed-off-by: Terry Howe --- pkg/registry/authorizer_test.go | 552 +++++++++++--------------------- 1 file changed, 188 insertions(+), 364 deletions(-) diff --git a/pkg/registry/authorizer_test.go b/pkg/registry/authorizer_test.go index ec2c14221..a084c8b02 100644 --- a/pkg/registry/authorizer_test.go +++ b/pkg/registry/authorizer_test.go @@ -18,402 +18,226 @@ package registry import ( "context" - "errors" - "io" "net/http" - "strings" + "net/http/httptest" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "oras.land/oras-go/v2/registry/remote/auth" ) -type roundTripFunc func(*http.Request) (*http.Response, error) - -func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) } - -func newHTTPClient(rt roundTripFunc) *http.Client { - return &http.Client{Transport: rt} -} - -func resp(status int, body string) *http.Response { - return &http.Response{StatusCode: status, Body: io.NopCloser(strings.NewReader(body))} +type mockCredentialsStore struct { + username string + password string + err error } -// fakeStore is a fake credentials store used to assert it is used when username/password are not set. -type fakeStore struct{ called bool } - -func (f *fakeStore) Get(_ context.Context, _ string) (auth.Credential, error) { - f.called = true - return auth.Credential{}, nil -} -func (f *fakeStore) Put(_ context.Context, _ string, _ auth.Credential) error { return nil } -func (f *fakeStore) Delete(_ context.Context, _ string) error { return nil } - -func TestNewAuthorizer_UsernamePassword(t *testing.T) { - hc := newHTTPClient(func(r *http.Request) (*http.Response, error) { - // ensure user-agent header is set by authorizer - ua := r.Header.Get("User-Agent") - if ua == "" { - t.Fatalf("expected User-Agent to be set") - } - return resp(200, "ok"), nil - }) - a := NewAuthorizer(hc, nil, "user", "pass") - if !a.AttemptBearerAuthentication { - t.Fatalf("AttemptBearerAuthentication should start true") - } - // Verify credential function returns our basic auth creds - cred, err := a.Credential(t.Context(), "example.com") - if err != nil { - t.Fatalf("unexpected err: %v", err) - } - if cred.Username != "user" || cred.Password != "pass" { - t.Fatalf("credential not set correctly: %+v", cred) +func (m *mockCredentialsStore) Get(_ context.Context, _ string) (auth.Credential, error) { + if m.err != nil { + return auth.EmptyCredential, m.err } - // simple do to trigger user-agent path and flip AttemptBearerAuthentication to false - req, _ := http.NewRequest(http.MethodGet, "https://example.com/v2/", nil) - _, err = a.Do(req) - if err != nil { - t.Fatalf("unexpected Do error: %v", err) - } - if a.AttemptBearerAuthentication { - t.Fatalf("AttemptBearerAuthentication should be false after Do") + return auth.Credential{ + Username: m.username, + Password: m.password, + }, nil +} + +func (m *mockCredentialsStore) Put(_ context.Context, _ string, _ auth.Credential) error { + return nil +} + +func (m *mockCredentialsStore) Delete(_ context.Context, _ string) error { + return nil +} + +func TestNewAuthorizer(t *testing.T) { + tests := []struct { + name string + username string + password string + }{ + { + name: "with username and password", + username: "testuser", + password: "testpass", + }, + { + name: "without credentials", + username: "", + password: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + httpClient := &http.Client{} + credStore := &mockCredentialsStore{} + + authorizer := NewAuthorizer(httpClient, credStore, tt.username, tt.password) + + require.NotNil(t, authorizer) + assert.Equal(t, httpClient, authorizer.Client.Client) + assert.True(t, authorizer.AttemptBearerAuthentication) + assert.NotNil(t, authorizer.Credential) + + if tt.username != "" && tt.password != "" { + cred, err := authorizer.Credential(t.Context(), "") + require.NoError(t, err) + assert.Equal(t, tt.username, cred.Username) + assert.Equal(t, tt.password, cred.Password) + } + }) } } -func TestNewAuthorizer_CredentialStoreUsed(t *testing.T) { - fs := &fakeStore{} - hc := newHTTPClient(func(_ *http.Request) (*http.Response, error) { return resp(200, "ok"), nil }) - a := NewAuthorizer(hc, fs, "", "") - // invoke Credential to ensure it delegates to store - _, _ = a.Credential(t.Context(), "registry.example") - if !fs.called { - t.Fatalf("expected credential store to be called") +func TestNewAuthorizer_WithCredentialsStore(t *testing.T) { + httpClient := &http.Client{} + credStore := &mockCredentialsStore{ + username: "storeuser", + password: "storepass", } -} -func TestEnableCache_SetsCache(t *testing.T) { - a := NewAuthorizer(newHTTPClient(func(_ *http.Request) (*http.Response, error) { return resp(200, "ok"), nil }), nil, "", "") - if a.Cache != nil { - t.Fatalf("cache should be nil before EnableCache") - } - a.EnableCache() - if a.Cache == nil { - t.Fatalf("cache should be set after EnableCache") - } + authorizer := NewAuthorizer(httpClient, credStore, "", "") + + require.NotNil(t, authorizer) + + cred, err := authorizer.Credential(t.Context(), "test.com") + require.NoError(t, err) + assert.Equal(t, "storeuser", cred.Username) + assert.Equal(t, "storepass", cred.Password) } -func TestDo_SuccessFirstTry_DisablesAttempt(t *testing.T) { - calls := 0 - hc := newHTTPClient(func(_ *http.Request) (*http.Response, error) { - calls++ - return resp(200, "ok"), nil - }) - a := NewAuthorizer(hc, nil, "", "") - req, _ := http.NewRequest(http.MethodGet, "https://registry.example/v2/", nil) - req.Host = "registry.example" // not ghcr.io - _, err := a.Do(req) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if calls != 1 { - t.Fatalf("expected 1 call, got %d", calls) - } - if a.AttemptBearerAuthentication { - t.Fatalf("AttemptBearerAuthentication should be false after success") - } -} +func TestAuthorizer_EnableCache(t *testing.T) { + httpClient := &http.Client{} + credStore := &mockCredentialsStore{} -func TestDo_AuthErrorThenRetry(t *testing.T) { - calls := 0 - hc := newHTTPClient(func(*http.Request) (*http.Response, error) { - calls++ - if calls == 1 { - return nil, errors.New("unexpected response status code 401") - } - return resp(200, "ok"), nil - }) - a := NewAuthorizer(hc, nil, "", "") - req, _ := http.NewRequest(http.MethodGet, "https://example.com/v2/", nil) - req.Host = "example.com" - _, err := a.Do(req) - if err != nil { - t.Fatalf("unexpected error after retry: %v", err) - } - if calls != 2 { - t.Fatalf("expected 2 calls on auth error, got %d", calls) - } - // After a retry that succeeds on second attempt, AttemptBearerAuthentication remains true - // because the flag is only set to false after a successful first attempt - if !a.AttemptBearerAuthentication { - t.Fatalf("AttemptBearerAuthentication should remain true after retry path") - } -} + authorizer := NewAuthorizer(httpClient, credStore, "", "") + assert.Nil(t, authorizer.Cache) -func TestDo_NonAuthErrorReturned(t *testing.T) { - calls := 0 - hc := newHTTPClient(func(_ *http.Request) (*http.Response, error) { - calls++ - return nil, errors.New("network down") - }) - a := NewAuthorizer(hc, nil, "", "") - req, _ := http.NewRequest(http.MethodGet, "https://example.com/v2/", nil) - req.Host = "example.com" - _, err := a.Do(req) - if err == nil || !strings.Contains(err.Error(), "network down") { - t.Fatalf("expected network error, got %v", err) - } - if calls != 1 { - t.Fatalf("expected only 1 call on non-auth error, got %d", calls) - } - // In this branch the code returns before flipping AttemptBearerAuthentication at end of block - if !a.AttemptBearerAuthentication { - t.Fatalf("AttemptBearerAuthentication should remain true when returning early on non-auth error") - } + authorizer.EnableCache() + assert.NotNil(t, authorizer.Cache) } -func TestDo_GHCRSkipsFirstAttempt(t *testing.T) { - calls := 0 - hc := newHTTPClient(func(_ *http.Request) (*http.Response, error) { - calls++ - return resp(200, "ok"), nil - }) - a := NewAuthorizer(hc, nil, "", "") - req, _ := http.NewRequest(http.MethodGet, "https://ghcr.io/v2/", nil) - req.Host = "ghcr.io" - _, err := a.Do(req) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if calls != 1 { - t.Fatalf("expected single call for ghcr.io, got %d", calls) - } - if a.AttemptBearerAuthentication { - t.Fatalf("AttemptBearerAuthentication should be false after ghcr path") - } -} +func TestAuthorizer_Do(t *testing.T) { + tests := []struct { + name string + host string + authHeader string + serverStatus int + expectForceOAuth2 bool + expectBearerAuthAfter bool + }{ + { + name: "successful request without auth header", + host: "registry.example.com", + authHeader: "", + serverStatus: 200, + expectForceOAuth2: true, + expectBearerAuthAfter: false, + }, + { + name: "request with existing auth header", + host: "registry.example.com", + authHeader: "Bearer token123", + serverStatus: 200, + expectForceOAuth2: false, + expectBearerAuthAfter: true, + }, + { + name: "ghcr.io special handling", + host: "ghcr.io", + authHeader: "", + serverStatus: 200, + expectForceOAuth2: false, + expectBearerAuthAfter: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(tt.serverStatus) + w.Write([]byte("success")) + })) + defer server.Close() -func TestDo_WithAuthorizationHeader_SkipsPreflight(t *testing.T) { - calls := 0 - hc := newHTTPClient(func(_ *http.Request) (*http.Response, error) { - calls++ - return resp(200, "ok"), nil - }) - a := NewAuthorizer(hc, nil, "", "") - req, _ := http.NewRequest(http.MethodGet, "https://example.com/v2/", nil) - req.Header.Set("Authorization", "Bearer token") - _, err := a.Do(req) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if calls != 1 { - t.Fatalf("expected one direct call when Authorization present, got %d", calls) - } - if !a.AttemptBearerAuthentication { - t.Fatalf("AttemptBearerAuthentication should remain true when Authorization header is present") - } -} + httpClient := &http.Client{} + credStore := &mockCredentialsStore{} -func TestDo_ForceAttemptOAuth2_SetForNonGHCR(t *testing.T) { - calls := 0 - hc := newHTTPClient(func(_ *http.Request) (*http.Response, error) { - calls++ - return resp(200, "ok"), nil - }) - a := NewAuthorizer(hc, nil, "", "") - req, _ := http.NewRequest(http.MethodGet, "https://example.com/v2/", nil) - req.Host = "example.com" - - // First call should set ForceAttemptOAuth2 to true for non-ghcr.io hosts - _, err := a.Do(req) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !a.ForceAttemptOAuth2 { - t.Fatalf("ForceAttemptOAuth2 should be true for non-ghcr.io hosts") - } -} + authorizer := NewAuthorizer(httpClient, credStore, "", "") -func TestDo_ForceAttemptOAuth2_NotSetForGHCR(t *testing.T) { - hc := newHTTPClient(func(_ *http.Request) (*http.Response, error) { - return resp(200, "ok"), nil - }) - a := NewAuthorizer(hc, nil, "", "") - req, _ := http.NewRequest(http.MethodGet, "https://ghcr.io/v2/", nil) - req.Host = "ghcr.io" - - _, err := a.Do(req) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if a.ForceAttemptOAuth2 { - t.Fatalf("ForceAttemptOAuth2 should be false for ghcr.io") - } -} + req, err := http.NewRequest(http.MethodGet, server.URL, nil) + require.NoError(t, err) + req.Host = tt.host -func TestDo_MultipleAuthErrors_RetriesCorrectly(t *testing.T) { - calls := 0 - hc := newHTTPClient(func(_ *http.Request) (*http.Response, error) { - calls++ - switch calls { - case 1: - return nil, errors.New("unexpected response status code 401: Unauthorized") - case 2: - return resp(200, "ok"), nil - default: - t.Fatalf("unexpected number of calls: %d", calls) - return nil, errors.New("unexpected") - } - }) - a := NewAuthorizer(hc, nil, "", "") - req, _ := http.NewRequest(http.MethodGet, "https://example.com/v2/", nil) - req.Host = "example.com" - - resp, err := a.Do(req) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if resp.StatusCode != http.StatusOK { - t.Fatalf("expected 200 status, got %d", resp.StatusCode) - } - if calls != 2 { - t.Fatalf("expected exactly 2 calls for retry, got %d", calls) - } -} + if tt.authHeader != "" { + req.Header.Set("Authorization", tt.authHeader) + } -func TestDo_403Error_RetriesCorrectly(t *testing.T) { - calls := 0 - hc := newHTTPClient(func(_ *http.Request) (*http.Response, error) { - calls++ - if calls == 1 { - return nil, errors.New("unexpected response status code 403: Forbidden") - } - return resp(200, "ok"), nil - }) - a := NewAuthorizer(hc, nil, "", "") - req, _ := http.NewRequest(http.MethodGet, "https://example.com/v2/", nil) - req.Host = "example.com" - - _, err := a.Do(req) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if calls != 2 { - t.Fatalf("expected 2 calls for 403 error retry, got %d", calls) - } -} + resp, err := authorizer.Do(req) -func TestDo_AttemptBearerAuthentication_False_SkipsLogic(t *testing.T) { - calls := 0 - hc := newHTTPClient(func(_ *http.Request) (*http.Response, error) { - calls++ - return resp(200, "ok"), nil - }) - a := NewAuthorizer(hc, nil, "", "") - a.AttemptBearerAuthentication = false // Explicitly set to false - - req, _ := http.NewRequest(http.MethodGet, "https://example.com/v2/", nil) - _, err := a.Do(req) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if calls != 1 { - t.Fatalf("expected single call when AttemptBearerAuthentication is false, got %d", calls) - } - if a.AttemptBearerAuthentication { - t.Fatalf("AttemptBearerAuthentication should remain false") + require.NoError(t, err) + require.NotNil(t, resp) + assert.Equal(t, tt.expectBearerAuthAfter, authorizer.AttemptBearerAuthentication) + + if tt.authHeader == "" { + assert.Equal(t, tt.expectForceOAuth2, authorizer.ForceAttemptOAuth2) + } + + resp.Body.Close() + }) } } -func TestDo_SequentialRequests_MaintainsState(t *testing.T) { - callCount := 0 - hc := newHTTPClient(func(_ *http.Request) (*http.Response, error) { - callCount++ - return resp(200, "ok"), nil - }) - a := NewAuthorizer(hc, nil, "", "") - - // First request without auth header - req1, _ := http.NewRequest(http.MethodGet, "https://example.com/v2/", nil) - req1.Host = "example.com" - _, err := a.Do(req1) - if err != nil { - t.Fatalf("first request failed: %v", err) - } - if a.AttemptBearerAuthentication { - t.Fatalf("AttemptBearerAuthentication should be false after first request") - } +func TestAuthorizer_Do_WithBearerAttemptDisabled(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("success")) + })) + defer server.Close() - // Second request should go straight through - req2, _ := http.NewRequest(http.MethodGet, "https://example.com/v2/charts", nil) - req2.Host = "example.com" - _, err = a.Do(req2) - if err != nil { - t.Fatalf("second request failed: %v", err) - } + httpClient := &http.Client{} + credStore := &mockCredentialsStore{} - // Should only have made 2 calls total (no retry on second) - if callCount != 2 { - t.Fatalf("expected 2 total calls, got %d", callCount) - } -} + authorizer := NewAuthorizer(httpClient, credStore, "", "") + authorizer.AttemptBearerAuthentication = false + + req, err := http.NewRequest(http.MethodGet, server.URL, nil) + require.NoError(t, err) + req.Host = "registry.example.com" -func TestDo_ErrorMessageParsing_404NotRetried(t *testing.T) { - calls := 0 - hc := newHTTPClient(func(_ *http.Request) (*http.Response, error) { - calls++ - // 404 error should contain "40" but not trigger retry since it's not 401/403 - return nil, errors.New("unexpected response status code 404: Not Found") - }) - a := NewAuthorizer(hc, nil, "", "") - req, _ := http.NewRequest(http.MethodGet, "https://example.com/v2/", nil) - req.Host = "example.com" - - _, err := a.Do(req) - if err == nil || !strings.Contains(err.Error(), "404") { - t.Fatalf("expected 404 error, got %v", err) - } - if calls != 2 { - t.Fatalf("expected 2 calls for 404 (matches '40' pattern), got %d", calls) - } -} + resp, err := authorizer.Do(req) + + require.NoError(t, err) + require.NotNil(t, resp) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.False(t, authorizer.AttemptBearerAuthentication) + + resp.Body.Close() +} + +func TestAuthorizer_Do_NonRetryableError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("internal server error")) + })) + defer server.Close() + + httpClient := &http.Client{} + credStore := &mockCredentialsStore{} + + authorizer := NewAuthorizer(httpClient, credStore, "", "") + + req, err := http.NewRequest(http.MethodGet, server.URL, nil) + require.NoError(t, err) + req.Host = "registry.example.com" -func TestDo_ErrorMessageParsing_NonStatusCodeError(t *testing.T) { - calls := 0 - hc := newHTTPClient(func(_ *http.Request) (*http.Response, error) { - calls++ - // Error containing "40" but not a status code error - return nil, errors.New("failed after 40 attempts") - }) - a := NewAuthorizer(hc, nil, "", "") - req, _ := http.NewRequest(http.MethodGet, "https://example.com/v2/", nil) - req.Host = "example.com" - - _, err := a.Do(req) - if err == nil || !strings.Contains(err.Error(), "40 attempts") { - t.Fatalf("expected error with '40 attempts', got %v", err) - } - // Should not retry since it doesn't match the pattern despite containing "40" - if calls != 1 { - t.Fatalf("expected 1 call (no retry for non-status code errors), got %d", calls) - } -} + resp, err := authorizer.Do(req) -func TestNewAuthorizer_NilHttpClient(t *testing.T) { - // Test that NewAuthorizer works with nil HTTP client - a := NewAuthorizer(nil, nil, "user", "pass") - if a == nil { - t.Fatalf("NewAuthorizer should not return nil") - } - if a.Client.Client != nil { - t.Fatalf("expected nil HTTP client to remain nil") - } - // Verify credential function still works - cred, err := a.Credential(t.Context(), "example.com") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if cred.Username != "user" || cred.Password != "pass" { - t.Fatalf("credentials not set correctly: %+v", cred) - } + require.NoError(t, err) + require.NotNil(t, resp) + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + + resp.Body.Close() } From 4f7b93a01004fe11504ee67556365ca792441453 Mon Sep 17 00:00:00 2001 From: Terry Howe Date: Thu, 11 Sep 2025 10:23:49 -0600 Subject: [PATCH 3/4] chore: handle mutliple threads Signed-off-by: Terry Howe --- pkg/registry/authorizer.go | 40 +++++++++++++++++---- pkg/registry/authorizer_test.go | 64 ++++++++++++++++++++++++++++++--- 2 files changed, 92 insertions(+), 12 deletions(-) diff --git a/pkg/registry/authorizer.go b/pkg/registry/authorizer.go index 53e41587a..cbae9df98 100644 --- a/pkg/registry/authorizer.go +++ b/pkg/registry/authorizer.go @@ -20,6 +20,7 @@ import ( "context" "net/http" "strings" + "sync" "helm.sh/helm/v4/internal/version" @@ -29,7 +30,8 @@ import ( type Authorizer struct { auth.Client - AttemptBearerAuthentication bool + lock sync.RWMutex + attemptBearerAuthentication bool } func NewAuthorizer(httpClient *http.Client, credentialsStore credentials.Store, username, password string) *Authorizer { @@ -48,7 +50,7 @@ func NewAuthorizer(httpClient *http.Client, credentialsStore credentials.Store, authorizer.Credential = credentials.Credential(credentialsStore) } - authorizer.AttemptBearerAuthentication = true + authorizer.setAttemptBearerAuthentication(true) return &authorizer } @@ -56,19 +58,43 @@ func (a *Authorizer) EnableCache() { a.Cache = auth.NewCache() } +func (a *Authorizer) getAttemptBearerAuthentication() bool { + a.lock.RLock() + defer a.lock.RUnlock() + return a.attemptBearerAuthentication +} + +func (a *Authorizer) setAttemptBearerAuthentication(value bool) { + a.lock.Lock() + defer a.lock.Unlock() + a.attemptBearerAuthentication = value +} + +func (a *Authorizer) getForceAttemptOAuth2() bool { + a.lock.RLock() + defer a.lock.RUnlock() + return a.ForceAttemptOAuth2 +} + +func (a *Authorizer) setForceAttemptOAuth2(value bool) { + a.lock.Lock() + defer a.lock.Unlock() + a.ForceAttemptOAuth2 = value +} + // Do This method wraps auth.Client.Do in attempt to retry authentication func (a *Authorizer) Do(originalReq *http.Request) (*http.Response, error) { - if a.AttemptBearerAuthentication { + if a.getAttemptBearerAuthentication() { needsAuthentication := originalReq.Header.Get("Authorization") == "" if needsAuthentication { - a.ForceAttemptOAuth2 = true + a.setForceAttemptOAuth2(true) if originalReq.Host == "ghcr.io" { - a.ForceAttemptOAuth2 = false - a.AttemptBearerAuthentication = false + a.setForceAttemptOAuth2(false) + a.setAttemptBearerAuthentication(false) } resp, err := a.Client.Do(originalReq) if err == nil { - a.AttemptBearerAuthentication = false + a.setAttemptBearerAuthentication(false) return resp, nil } if !strings.Contains(err.Error(), "response status code 40") { diff --git a/pkg/registry/authorizer_test.go b/pkg/registry/authorizer_test.go index a084c8b02..d0267803b 100644 --- a/pkg/registry/authorizer_test.go +++ b/pkg/registry/authorizer_test.go @@ -20,6 +20,7 @@ import ( "context" "net/http" "net/http/httptest" + "sync" "testing" "github.com/stretchr/testify/assert" @@ -78,7 +79,7 @@ func TestNewAuthorizer(t *testing.T) { require.NotNil(t, authorizer) assert.Equal(t, httpClient, authorizer.Client.Client) - assert.True(t, authorizer.AttemptBearerAuthentication) + assert.True(t, authorizer.getAttemptBearerAuthentication()) assert.NotNil(t, authorizer.Credential) if tt.username != "" && tt.password != "" { @@ -179,10 +180,10 @@ func TestAuthorizer_Do(t *testing.T) { require.NoError(t, err) require.NotNil(t, resp) - assert.Equal(t, tt.expectBearerAuthAfter, authorizer.AttemptBearerAuthentication) + assert.Equal(t, tt.expectBearerAuthAfter, authorizer.getAttemptBearerAuthentication()) if tt.authHeader == "" { - assert.Equal(t, tt.expectForceOAuth2, authorizer.ForceAttemptOAuth2) + assert.Equal(t, tt.expectForceOAuth2, authorizer.getForceAttemptOAuth2()) } resp.Body.Close() @@ -201,7 +202,7 @@ func TestAuthorizer_Do_WithBearerAttemptDisabled(t *testing.T) { credStore := &mockCredentialsStore{} authorizer := NewAuthorizer(httpClient, credStore, "", "") - authorizer.AttemptBearerAuthentication = false + authorizer.setAttemptBearerAuthentication(false) req, err := http.NewRequest(http.MethodGet, server.URL, nil) require.NoError(t, err) @@ -212,7 +213,7 @@ func TestAuthorizer_Do_WithBearerAttemptDisabled(t *testing.T) { require.NoError(t, err) require.NotNil(t, resp) assert.Equal(t, http.StatusOK, resp.StatusCode) - assert.False(t, authorizer.AttemptBearerAuthentication) + assert.False(t, authorizer.getAttemptBearerAuthentication()) resp.Body.Close() } @@ -241,3 +242,56 @@ func TestAuthorizer_Do_NonRetryableError(t *testing.T) { resp.Body.Close() } + +func TestAuthorizer_ConcurrentAccess(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("success")) + })) + defer server.Close() + + httpClient := &http.Client{} + credStore := &mockCredentialsStore{} + authorizer := NewAuthorizer(httpClient, credStore, "", "") + + const numGoroutines = 100 + const numRequests = 10 + + var wg sync.WaitGroup + wg.Add(numGoroutines * 2) + + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + for j := 0; j < numRequests; j++ { + req, err := http.NewRequest(http.MethodGet, server.URL, nil) + require.NoError(t, err) + req.Host = "registry.example.com" + + resp, err := authorizer.Do(req) + if err == nil && resp != nil { + resp.Body.Close() + } + } + }() + + go func() { + defer wg.Done() + for j := 0; j < numRequests; j++ { + authorizer.setAttemptBearerAuthentication(true) + val := authorizer.getAttemptBearerAuthentication() + if val != true { + t.Logf("Warning: Expected true but got %v", val) + } + + authorizer.setAttemptBearerAuthentication(false) + val = authorizer.getAttemptBearerAuthentication() + if val != false { + t.Logf("Warning: Expected false but got %v", val) + } + } + }() + } + + wg.Wait() +} From 5be055d4d73209c0c2ff8c828bdde31ff766bfa2 Mon Sep 17 00:00:00 2001 From: Terry Howe Date: Thu, 11 Sep 2025 11:02:51 -0600 Subject: [PATCH 4/4] chore: handle specific errors Signed-off-by: Terry Howe --- pkg/registry/authorizer.go | 3 ++- pkg/registry/authorizer_test.go | 48 +++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/pkg/registry/authorizer.go b/pkg/registry/authorizer.go index cbae9df98..6d8dd49a0 100644 --- a/pkg/registry/authorizer.go +++ b/pkg/registry/authorizer.go @@ -97,7 +97,8 @@ func (a *Authorizer) Do(originalReq *http.Request) (*http.Response, error) { a.setAttemptBearerAuthentication(false) return resp, nil } - if !strings.Contains(err.Error(), "response status code 40") { + if !strings.Contains(err.Error(), "response status code 401") && + !strings.Contains(err.Error(), "response status code 403") { return nil, err } } diff --git a/pkg/registry/authorizer_test.go b/pkg/registry/authorizer_test.go index d0267803b..a1cba065e 100644 --- a/pkg/registry/authorizer_test.go +++ b/pkg/registry/authorizer_test.go @@ -18,8 +18,10 @@ package registry import ( "context" + "errors" "net/http" "net/http/httptest" + "strings" "sync" "testing" @@ -295,3 +297,49 @@ func TestAuthorizer_ConcurrentAccess(t *testing.T) { wg.Wait() } + +func TestAuthorizer_Do_StatusCodeErrorChecking(t *testing.T) { + tests := []struct { + name string + errorMsg string + shouldRetry bool + description string + }{ + { + name: "retry on 401 error", + errorMsg: "response status code 401", + shouldRetry: true, + description: "401 errors should trigger retry logic", + }, + { + name: "retry on 403 error", + errorMsg: "response status code 403", + shouldRetry: true, + description: "403 errors should trigger retry logic", + }, + { + name: "no retry on 404 error", + errorMsg: "response status code 404", + shouldRetry: false, + description: "404 errors should not trigger retry logic", + }, + { + name: "no retry on 500 error", + errorMsg: "response status code 500", + shouldRetry: false, + description: "500 errors should not trigger retry logic", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := errors.New(tt.errorMsg) + + should401Retry := strings.Contains(err.Error(), "response status code 401") + should403Retry := strings.Contains(err.Error(), "response status code 403") + actualShouldRetry := should401Retry || should403Retry + + assert.Equal(t, tt.shouldRetry, actualShouldRetry, tt.description) + }) + } +}