pull/31212/merge
Terry Howe 2 days ago committed by GitHub
commit 37234c7ac6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -0,0 +1,107 @@
/*
Copyright The Helm Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package registry
import (
"context"
"net/http"
"strings"
"sync"
"helm.sh/helm/v4/internal/version"
"oras.land/oras-go/v2/registry/remote/auth"
"oras.land/oras-go/v2/registry/remote/credentials"
)
type Authorizer struct {
auth.Client
lock sync.RWMutex
attemptBearerAuthentication bool
}
func NewAuthorizer(httpClient *http.Client, credentialsStore credentials.Store, username, password string) *Authorizer {
authorizer := Authorizer{
Client: auth.Client{
Client: httpClient,
},
}
authorizer.SetUserAgent(version.GetUserAgent())
if username != "" && password != "" {
authorizer.Credential = func(_ context.Context, _ string) (auth.Credential, error) {
return auth.Credential{Username: username, Password: password}, nil
}
} else {
authorizer.Credential = credentials.Credential(credentialsStore)
}
authorizer.setAttemptBearerAuthentication(true)
return &authorizer
}
func (a *Authorizer) EnableCache() {
a.Cache = auth.NewCache()
}
func (a *Authorizer) getAttemptBearerAuthentication() bool {
a.lock.RLock()
defer a.lock.RUnlock()
return a.attemptBearerAuthentication
}
func (a *Authorizer) setAttemptBearerAuthentication(value bool) {
a.lock.Lock()
defer a.lock.Unlock()
a.attemptBearerAuthentication = value
}
func (a *Authorizer) getForceAttemptOAuth2() bool {
a.lock.RLock()
defer a.lock.RUnlock()
return a.ForceAttemptOAuth2
}
func (a *Authorizer) setForceAttemptOAuth2(value bool) {
a.lock.Lock()
defer a.lock.Unlock()
a.ForceAttemptOAuth2 = value
}
// Do This method wraps auth.Client.Do in attempt to retry authentication
func (a *Authorizer) Do(originalReq *http.Request) (*http.Response, error) {
if a.getAttemptBearerAuthentication() {
needsAuthentication := originalReq.Header.Get("Authorization") == ""
if needsAuthentication {
a.setForceAttemptOAuth2(true)
if originalReq.Host == "ghcr.io" {
a.setForceAttemptOAuth2(false)
a.setAttemptBearerAuthentication(false)
}
resp, err := a.Client.Do(originalReq)
if err == nil {
a.setAttemptBearerAuthentication(false)
return resp, nil
}
if !strings.Contains(err.Error(), "response status code 401") &&
!strings.Contains(err.Error(), "response status code 403") {
return nil, err
}
}
}
return a.Client.Do(originalReq)
}

@ -0,0 +1,345 @@
/*
Copyright The Helm Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package registry
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"oras.land/oras-go/v2/registry/remote/auth"
)
type mockCredentialsStore struct {
username string
password string
err error
}
func (m *mockCredentialsStore) Get(_ context.Context, _ string) (auth.Credential, error) {
if m.err != nil {
return auth.EmptyCredential, m.err
}
return auth.Credential{
Username: m.username,
Password: m.password,
}, nil
}
func (m *mockCredentialsStore) Put(_ context.Context, _ string, _ auth.Credential) error {
return nil
}
func (m *mockCredentialsStore) Delete(_ context.Context, _ string) error {
return nil
}
func TestNewAuthorizer(t *testing.T) {
tests := []struct {
name string
username string
password string
}{
{
name: "with username and password",
username: "testuser",
password: "testpass",
},
{
name: "without credentials",
username: "",
password: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
httpClient := &http.Client{}
credStore := &mockCredentialsStore{}
authorizer := NewAuthorizer(httpClient, credStore, tt.username, tt.password)
require.NotNil(t, authorizer)
assert.Equal(t, httpClient, authorizer.Client.Client)
assert.True(t, authorizer.getAttemptBearerAuthentication())
assert.NotNil(t, authorizer.Credential)
if tt.username != "" && tt.password != "" {
cred, err := authorizer.Credential(t.Context(), "")
require.NoError(t, err)
assert.Equal(t, tt.username, cred.Username)
assert.Equal(t, tt.password, cred.Password)
}
})
}
}
func TestNewAuthorizer_WithCredentialsStore(t *testing.T) {
httpClient := &http.Client{}
credStore := &mockCredentialsStore{
username: "storeuser",
password: "storepass",
}
authorizer := NewAuthorizer(httpClient, credStore, "", "")
require.NotNil(t, authorizer)
cred, err := authorizer.Credential(t.Context(), "test.com")
require.NoError(t, err)
assert.Equal(t, "storeuser", cred.Username)
assert.Equal(t, "storepass", cred.Password)
}
func TestAuthorizer_EnableCache(t *testing.T) {
httpClient := &http.Client{}
credStore := &mockCredentialsStore{}
authorizer := NewAuthorizer(httpClient, credStore, "", "")
assert.Nil(t, authorizer.Cache)
authorizer.EnableCache()
assert.NotNil(t, authorizer.Cache)
}
func TestAuthorizer_Do(t *testing.T) {
tests := []struct {
name string
host string
authHeader string
serverStatus int
expectForceOAuth2 bool
expectBearerAuthAfter bool
}{
{
name: "successful request without auth header",
host: "registry.example.com",
authHeader: "",
serverStatus: 200,
expectForceOAuth2: true,
expectBearerAuthAfter: false,
},
{
name: "request with existing auth header",
host: "registry.example.com",
authHeader: "Bearer token123",
serverStatus: 200,
expectForceOAuth2: false,
expectBearerAuthAfter: true,
},
{
name: "ghcr.io special handling",
host: "ghcr.io",
authHeader: "",
serverStatus: 200,
expectForceOAuth2: false,
expectBearerAuthAfter: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(tt.serverStatus)
w.Write([]byte("success"))
}))
defer server.Close()
httpClient := &http.Client{}
credStore := &mockCredentialsStore{}
authorizer := NewAuthorizer(httpClient, credStore, "", "")
req, err := http.NewRequest(http.MethodGet, server.URL, nil)
require.NoError(t, err)
req.Host = tt.host
if tt.authHeader != "" {
req.Header.Set("Authorization", tt.authHeader)
}
resp, err := authorizer.Do(req)
require.NoError(t, err)
require.NotNil(t, resp)
assert.Equal(t, tt.expectBearerAuthAfter, authorizer.getAttemptBearerAuthentication())
if tt.authHeader == "" {
assert.Equal(t, tt.expectForceOAuth2, authorizer.getForceAttemptOAuth2())
}
resp.Body.Close()
})
}
}
func TestAuthorizer_Do_WithBearerAttemptDisabled(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("success"))
}))
defer server.Close()
httpClient := &http.Client{}
credStore := &mockCredentialsStore{}
authorizer := NewAuthorizer(httpClient, credStore, "", "")
authorizer.setAttemptBearerAuthentication(false)
req, err := http.NewRequest(http.MethodGet, server.URL, nil)
require.NoError(t, err)
req.Host = "registry.example.com"
resp, err := authorizer.Do(req)
require.NoError(t, err)
require.NotNil(t, resp)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.False(t, authorizer.getAttemptBearerAuthentication())
resp.Body.Close()
}
func TestAuthorizer_Do_NonRetryableError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("internal server error"))
}))
defer server.Close()
httpClient := &http.Client{}
credStore := &mockCredentialsStore{}
authorizer := NewAuthorizer(httpClient, credStore, "", "")
req, err := http.NewRequest(http.MethodGet, server.URL, nil)
require.NoError(t, err)
req.Host = "registry.example.com"
resp, err := authorizer.Do(req)
require.NoError(t, err)
require.NotNil(t, resp)
assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
resp.Body.Close()
}
func TestAuthorizer_ConcurrentAccess(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("success"))
}))
defer server.Close()
httpClient := &http.Client{}
credStore := &mockCredentialsStore{}
authorizer := NewAuthorizer(httpClient, credStore, "", "")
const numGoroutines = 100
const numRequests = 10
var wg sync.WaitGroup
wg.Add(numGoroutines * 2)
for i := 0; i < numGoroutines; i++ {
go func() {
defer wg.Done()
for j := 0; j < numRequests; j++ {
req, err := http.NewRequest(http.MethodGet, server.URL, nil)
require.NoError(t, err)
req.Host = "registry.example.com"
resp, err := authorizer.Do(req)
if err == nil && resp != nil {
resp.Body.Close()
}
}
}()
go func() {
defer wg.Done()
for j := 0; j < numRequests; j++ {
authorizer.setAttemptBearerAuthentication(true)
val := authorizer.getAttemptBearerAuthentication()
if val != true {
t.Logf("Warning: Expected true but got %v", val)
}
authorizer.setAttemptBearerAuthentication(false)
val = authorizer.getAttemptBearerAuthentication()
if val != false {
t.Logf("Warning: Expected false but got %v", val)
}
}
}()
}
wg.Wait()
}
func TestAuthorizer_Do_StatusCodeErrorChecking(t *testing.T) {
tests := []struct {
name string
errorMsg string
shouldRetry bool
description string
}{
{
name: "retry on 401 error",
errorMsg: "response status code 401",
shouldRetry: true,
description: "401 errors should trigger retry logic",
},
{
name: "retry on 403 error",
errorMsg: "response status code 403",
shouldRetry: true,
description: "403 errors should trigger retry logic",
},
{
name: "no retry on 404 error",
errorMsg: "response status code 404",
shouldRetry: false,
description: "404 errors should not trigger retry logic",
},
{
name: "no retry on 500 error",
errorMsg: "response status code 500",
shouldRetry: false,
description: "500 errors should not trigger retry logic",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := errors.New(tt.errorMsg)
should401Retry := strings.Contains(err.Error(), "response status code 401")
should403Retry := strings.Contains(err.Error(), "response status code 403")
actualShouldRetry := should401Retry || should403Retry
assert.Equal(t, tt.shouldRetry, actualShouldRetry, tt.description)
})
}
}

@ -41,7 +41,6 @@ import (
"oras.land/oras-go/v2/registry/remote/credentials"
"oras.land/oras-go/v2/registry/remote/retry"
"helm.sh/helm/v4/internal/version"
chart "helm.sh/helm/v4/pkg/chart/v2"
"helm.sh/helm/v4/pkg/helmpath"
)
@ -70,7 +69,7 @@ type (
username string
password string
out io.Writer
authorizer *auth.Client
authorizer *Authorizer
registryAuthorizer RemoteClient
credentialsStore credentials.Store
httpClient *http.Client
@ -121,23 +120,11 @@ func NewClient(options ...ClientOption) (*Client, error) {
}
if client.authorizer == nil {
authorizer := auth.Client{
Client: client.httpClient,
}
authorizer.SetUserAgent(version.GetUserAgent())
if client.username != "" && client.password != "" {
authorizer.Credential = func(_ context.Context, _ string) (auth.Credential, error) {
return auth.Credential{Username: client.username, Password: client.password}, nil
}
} else {
authorizer.Credential = credentials.Credential(client.credentialsStore)
}
authorizer := NewAuthorizer(client.httpClient, client.credentialsStore, client.username, client.password)
if client.enableCache {
authorizer.Cache = auth.NewCache()
authorizer.EnableCache()
}
client.authorizer = &authorizer
client.authorizer = authorizer
}
return client, nil
@ -177,17 +164,17 @@ func ClientOptWriter(out io.Writer) ClientOption {
}
}
// ClientOptAuthorizer returns a function that sets the authorizer setting on a client options set. This
// ClientOptAuthorizer returns a function that sets the Authorizer setting on a client options set. This
// can be used to override the default authorization mechanism.
//
// Depending on the use-case you may need to set both ClientOptAuthorizer and ClientOptRegistryAuthorizer.
func ClientOptAuthorizer(authorizer auth.Client) ClientOption {
return func(client *Client) {
client.authorizer = &authorizer
client.authorizer = &Authorizer{Client: authorizer}
}
}
// ClientOptRegistryAuthorizer returns a function that sets the registry authorizer setting on a client options set. This
// ClientOptRegistryAuthorizer returns a function that sets the registry Authorizer setting on a client options set. This
// can be used to override the default authorization mechanism.
//
// Depending on the use-case you may need to set both ClientOptAuthorizer and ClientOptRegistryAuthorizer.
@ -239,18 +226,12 @@ func (c *Client) Login(host string, options ...LoginOption) error {
}
reg.PlainHTTP = c.plainHTTP
cred := auth.Credential{Username: c.username, Password: c.password}
c.authorizer.ForceAttemptOAuth2 = true
reg.Client = c.authorizer
ctx := context.Background()
if err := reg.Ping(ctx); err != nil {
c.authorizer.ForceAttemptOAuth2 = false
if err := reg.Ping(ctx); err != nil {
return fmt.Errorf("authenticating to %q: %w", host, err)
}
return fmt.Errorf("authenticating to %q: %w", host, err)
}
// Always restore to false after probing, to avoid forcing POST to token endpoints like GHCR.
c.authorizer.ForceAttemptOAuth2 = false
key := credentials.ServerAddressFromRegistry(host)
key = credentials.ServerAddressFromHostname(key)
@ -278,10 +259,10 @@ func LoginOptPlainText(isPlainText bool) LoginOption {
}
}
func ensureTLSConfig(client *auth.Client, setConfig *tls.Config) (*tls.Config, error) {
func ensureTLSConfig(client *Authorizer, setConfig *tls.Config) (*tls.Config, error) {
var transport *http.Transport
switch t := client.Client.Transport.(type) {
switch t := client.Client.Client.Transport.(type) {
case *http.Transport:
transport = t
case *retry.Transport:
@ -299,7 +280,7 @@ func ensureTLSConfig(client *auth.Client, setConfig *tls.Config) (*tls.Config, e
if transport == nil {
// we don't know how to access the http.Transport, most likely the
// auth.Client.Client was provided by API user
return nil, fmt.Errorf("unable to access TLS client configuration, the provided HTTP Transport is not supported, given: %T", client.Client.Transport)
return nil, fmt.Errorf("unable to access TLS client configuration, the provided HTTP Transport is not supported, given: %T", client.Client.Client.Transport)
}
switch {

@ -18,10 +18,6 @@ package registry
import (
"io"
"net/http"
"net/http/httptest"
"path/filepath"
"strings"
"testing"
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
@ -55,68 +51,3 @@ func TestTagManifestTransformsReferences(t *testing.T) {
_, err = memStore.Resolve(ctx, refWithPlus)
require.Error(t, err, "Should NOT find the reference with the original +")
}
// Verifies that Login always restores ForceAttemptOAuth2 to false on success.
func TestLogin_ResetsForceAttemptOAuth2_OnSuccess(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/v2/" {
// Accept either HEAD or GET
w.WriteHeader(http.StatusOK)
return
}
http.NotFound(w, r)
}))
defer srv.Close()
host := strings.TrimPrefix(srv.URL, "http://")
credFile := filepath.Join(t.TempDir(), "config.json")
c, err := NewClient(
ClientOptWriter(io.Discard),
ClientOptCredentialsFile(credFile),
)
if err != nil {
t.Fatalf("NewClient error: %v", err)
}
if c.authorizer == nil || c.authorizer.ForceAttemptOAuth2 {
t.Fatalf("expected ForceAttemptOAuth2 default to be false")
}
// Call Login with plain HTTP against our test server
if err := c.Login(host, LoginOptPlainText(true), LoginOptBasicAuth("u", "p")); err != nil {
t.Fatalf("Login error: %v", err)
}
if c.authorizer.ForceAttemptOAuth2 {
t.Errorf("ForceAttemptOAuth2 should be false after successful Login")
}
}
// Verifies that Login restores ForceAttemptOAuth2 to false even when ping fails.
func TestLogin_ResetsForceAttemptOAuth2_OnFailure(t *testing.T) {
t.Parallel()
// Start and immediately close, so connections will fail
srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}))
host := strings.TrimPrefix(srv.URL, "http://")
srv.Close()
credFile := filepath.Join(t.TempDir(), "config.json")
c, err := NewClient(
ClientOptWriter(io.Discard),
ClientOptCredentialsFile(credFile),
)
if err != nil {
t.Fatalf("NewClient error: %v", err)
}
// Invoke Login, expect an error but ForceAttemptOAuth2 must end false
_ = c.Login(host, LoginOptPlainText(true), LoginOptBasicAuth("u", "p"))
if c.authorizer.ForceAttemptOAuth2 {
t.Errorf("ForceAttemptOAuth2 should be false after failed Login")
}
}

@ -28,7 +28,6 @@ import (
"oras.land/oras-go/v2/content"
"oras.land/oras-go/v2/content/memory"
"oras.land/oras-go/v2/registry/remote"
"oras.land/oras-go/v2/registry/remote/auth"
"oras.land/oras-go/v2/registry/remote/credentials"
)
@ -40,7 +39,7 @@ type GenericClient struct {
username string
password string
out io.Writer
authorizer *auth.Client
authorizer *Authorizer
registryAuthorizer RemoteClient
credentialsStore credentials.Store
httpClient *http.Client

Loading…
Cancel
Save