mirror of https://github.com/helm/helm
Merge 5be055d4d7
into 23c5662019
commit
37234c7ac6
@ -0,0 +1,107 @@
|
||||
/*
|
||||
Copyright The Helm Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package registry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"helm.sh/helm/v4/internal/version"
|
||||
|
||||
"oras.land/oras-go/v2/registry/remote/auth"
|
||||
"oras.land/oras-go/v2/registry/remote/credentials"
|
||||
)
|
||||
|
||||
type Authorizer struct {
|
||||
auth.Client
|
||||
lock sync.RWMutex
|
||||
attemptBearerAuthentication bool
|
||||
}
|
||||
|
||||
func NewAuthorizer(httpClient *http.Client, credentialsStore credentials.Store, username, password string) *Authorizer {
|
||||
authorizer := Authorizer{
|
||||
Client: auth.Client{
|
||||
Client: httpClient,
|
||||
},
|
||||
}
|
||||
authorizer.SetUserAgent(version.GetUserAgent())
|
||||
|
||||
if username != "" && password != "" {
|
||||
authorizer.Credential = func(_ context.Context, _ string) (auth.Credential, error) {
|
||||
return auth.Credential{Username: username, Password: password}, nil
|
||||
}
|
||||
} else {
|
||||
authorizer.Credential = credentials.Credential(credentialsStore)
|
||||
}
|
||||
|
||||
authorizer.setAttemptBearerAuthentication(true)
|
||||
return &authorizer
|
||||
}
|
||||
|
||||
func (a *Authorizer) EnableCache() {
|
||||
a.Cache = auth.NewCache()
|
||||
}
|
||||
|
||||
func (a *Authorizer) getAttemptBearerAuthentication() bool {
|
||||
a.lock.RLock()
|
||||
defer a.lock.RUnlock()
|
||||
return a.attemptBearerAuthentication
|
||||
}
|
||||
|
||||
func (a *Authorizer) setAttemptBearerAuthentication(value bool) {
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
a.attemptBearerAuthentication = value
|
||||
}
|
||||
|
||||
func (a *Authorizer) getForceAttemptOAuth2() bool {
|
||||
a.lock.RLock()
|
||||
defer a.lock.RUnlock()
|
||||
return a.ForceAttemptOAuth2
|
||||
}
|
||||
|
||||
func (a *Authorizer) setForceAttemptOAuth2(value bool) {
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
a.ForceAttemptOAuth2 = value
|
||||
}
|
||||
|
||||
// Do This method wraps auth.Client.Do in attempt to retry authentication
|
||||
func (a *Authorizer) Do(originalReq *http.Request) (*http.Response, error) {
|
||||
if a.getAttemptBearerAuthentication() {
|
||||
needsAuthentication := originalReq.Header.Get("Authorization") == ""
|
||||
if needsAuthentication {
|
||||
a.setForceAttemptOAuth2(true)
|
||||
if originalReq.Host == "ghcr.io" {
|
||||
a.setForceAttemptOAuth2(false)
|
||||
a.setAttemptBearerAuthentication(false)
|
||||
}
|
||||
resp, err := a.Client.Do(originalReq)
|
||||
if err == nil {
|
||||
a.setAttemptBearerAuthentication(false)
|
||||
return resp, nil
|
||||
}
|
||||
if !strings.Contains(err.Error(), "response status code 401") &&
|
||||
!strings.Contains(err.Error(), "response status code 403") {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return a.Client.Do(originalReq)
|
||||
}
|
@ -0,0 +1,345 @@
|
||||
/*
|
||||
Copyright The Helm Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package registry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"oras.land/oras-go/v2/registry/remote/auth"
|
||||
)
|
||||
|
||||
type mockCredentialsStore struct {
|
||||
username string
|
||||
password string
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockCredentialsStore) Get(_ context.Context, _ string) (auth.Credential, error) {
|
||||
if m.err != nil {
|
||||
return auth.EmptyCredential, m.err
|
||||
}
|
||||
return auth.Credential{
|
||||
Username: m.username,
|
||||
Password: m.password,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockCredentialsStore) Put(_ context.Context, _ string, _ auth.Credential) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockCredentialsStore) Delete(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestNewAuthorizer(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
username string
|
||||
password string
|
||||
}{
|
||||
{
|
||||
name: "with username and password",
|
||||
username: "testuser",
|
||||
password: "testpass",
|
||||
},
|
||||
{
|
||||
name: "without credentials",
|
||||
username: "",
|
||||
password: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
httpClient := &http.Client{}
|
||||
credStore := &mockCredentialsStore{}
|
||||
|
||||
authorizer := NewAuthorizer(httpClient, credStore, tt.username, tt.password)
|
||||
|
||||
require.NotNil(t, authorizer)
|
||||
assert.Equal(t, httpClient, authorizer.Client.Client)
|
||||
assert.True(t, authorizer.getAttemptBearerAuthentication())
|
||||
assert.NotNil(t, authorizer.Credential)
|
||||
|
||||
if tt.username != "" && tt.password != "" {
|
||||
cred, err := authorizer.Credential(t.Context(), "")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.username, cred.Username)
|
||||
assert.Equal(t, tt.password, cred.Password)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAuthorizer_WithCredentialsStore(t *testing.T) {
|
||||
httpClient := &http.Client{}
|
||||
credStore := &mockCredentialsStore{
|
||||
username: "storeuser",
|
||||
password: "storepass",
|
||||
}
|
||||
|
||||
authorizer := NewAuthorizer(httpClient, credStore, "", "")
|
||||
|
||||
require.NotNil(t, authorizer)
|
||||
|
||||
cred, err := authorizer.Credential(t.Context(), "test.com")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "storeuser", cred.Username)
|
||||
assert.Equal(t, "storepass", cred.Password)
|
||||
}
|
||||
|
||||
func TestAuthorizer_EnableCache(t *testing.T) {
|
||||
httpClient := &http.Client{}
|
||||
credStore := &mockCredentialsStore{}
|
||||
|
||||
authorizer := NewAuthorizer(httpClient, credStore, "", "")
|
||||
assert.Nil(t, authorizer.Cache)
|
||||
|
||||
authorizer.EnableCache()
|
||||
assert.NotNil(t, authorizer.Cache)
|
||||
}
|
||||
|
||||
func TestAuthorizer_Do(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
authHeader string
|
||||
serverStatus int
|
||||
expectForceOAuth2 bool
|
||||
expectBearerAuthAfter bool
|
||||
}{
|
||||
{
|
||||
name: "successful request without auth header",
|
||||
host: "registry.example.com",
|
||||
authHeader: "",
|
||||
serverStatus: 200,
|
||||
expectForceOAuth2: true,
|
||||
expectBearerAuthAfter: false,
|
||||
},
|
||||
{
|
||||
name: "request with existing auth header",
|
||||
host: "registry.example.com",
|
||||
authHeader: "Bearer token123",
|
||||
serverStatus: 200,
|
||||
expectForceOAuth2: false,
|
||||
expectBearerAuthAfter: true,
|
||||
},
|
||||
{
|
||||
name: "ghcr.io special handling",
|
||||
host: "ghcr.io",
|
||||
authHeader: "",
|
||||
serverStatus: 200,
|
||||
expectForceOAuth2: false,
|
||||
expectBearerAuthAfter: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(tt.serverStatus)
|
||||
w.Write([]byte("success"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
httpClient := &http.Client{}
|
||||
credStore := &mockCredentialsStore{}
|
||||
|
||||
authorizer := NewAuthorizer(httpClient, credStore, "", "")
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, server.URL, nil)
|
||||
require.NoError(t, err)
|
||||
req.Host = tt.host
|
||||
|
||||
if tt.authHeader != "" {
|
||||
req.Header.Set("Authorization", tt.authHeader)
|
||||
}
|
||||
|
||||
resp, err := authorizer.Do(req)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, tt.expectBearerAuthAfter, authorizer.getAttemptBearerAuthentication())
|
||||
|
||||
if tt.authHeader == "" {
|
||||
assert.Equal(t, tt.expectForceOAuth2, authorizer.getForceAttemptOAuth2())
|
||||
}
|
||||
|
||||
resp.Body.Close()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizer_Do_WithBearerAttemptDisabled(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("success"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
httpClient := &http.Client{}
|
||||
credStore := &mockCredentialsStore{}
|
||||
|
||||
authorizer := NewAuthorizer(httpClient, credStore, "", "")
|
||||
authorizer.setAttemptBearerAuthentication(false)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, server.URL, nil)
|
||||
require.NoError(t, err)
|
||||
req.Host = "registry.example.com"
|
||||
|
||||
resp, err := authorizer.Do(req)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.False(t, authorizer.getAttemptBearerAuthentication())
|
||||
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
func TestAuthorizer_Do_NonRetryableError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte("internal server error"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
httpClient := &http.Client{}
|
||||
credStore := &mockCredentialsStore{}
|
||||
|
||||
authorizer := NewAuthorizer(httpClient, credStore, "", "")
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, server.URL, nil)
|
||||
require.NoError(t, err)
|
||||
req.Host = "registry.example.com"
|
||||
|
||||
resp, err := authorizer.Do(req)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
|
||||
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
func TestAuthorizer_ConcurrentAccess(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("success"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
httpClient := &http.Client{}
|
||||
credStore := &mockCredentialsStore{}
|
||||
authorizer := NewAuthorizer(httpClient, credStore, "", "")
|
||||
|
||||
const numGoroutines = 100
|
||||
const numRequests = 10
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines * 2)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < numRequests; j++ {
|
||||
req, err := http.NewRequest(http.MethodGet, server.URL, nil)
|
||||
require.NoError(t, err)
|
||||
req.Host = "registry.example.com"
|
||||
|
||||
resp, err := authorizer.Do(req)
|
||||
if err == nil && resp != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < numRequests; j++ {
|
||||
authorizer.setAttemptBearerAuthentication(true)
|
||||
val := authorizer.getAttemptBearerAuthentication()
|
||||
if val != true {
|
||||
t.Logf("Warning: Expected true but got %v", val)
|
||||
}
|
||||
|
||||
authorizer.setAttemptBearerAuthentication(false)
|
||||
val = authorizer.getAttemptBearerAuthentication()
|
||||
if val != false {
|
||||
t.Logf("Warning: Expected false but got %v", val)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAuthorizer_Do_StatusCodeErrorChecking(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
errorMsg string
|
||||
shouldRetry bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "retry on 401 error",
|
||||
errorMsg: "response status code 401",
|
||||
shouldRetry: true,
|
||||
description: "401 errors should trigger retry logic",
|
||||
},
|
||||
{
|
||||
name: "retry on 403 error",
|
||||
errorMsg: "response status code 403",
|
||||
shouldRetry: true,
|
||||
description: "403 errors should trigger retry logic",
|
||||
},
|
||||
{
|
||||
name: "no retry on 404 error",
|
||||
errorMsg: "response status code 404",
|
||||
shouldRetry: false,
|
||||
description: "404 errors should not trigger retry logic",
|
||||
},
|
||||
{
|
||||
name: "no retry on 500 error",
|
||||
errorMsg: "response status code 500",
|
||||
shouldRetry: false,
|
||||
description: "500 errors should not trigger retry logic",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := errors.New(tt.errorMsg)
|
||||
|
||||
should401Retry := strings.Contains(err.Error(), "response status code 401")
|
||||
should403Retry := strings.Contains(err.Error(), "response status code 403")
|
||||
actualShouldRetry := should401Retry || should403Retry
|
||||
|
||||
assert.Equal(t, tt.shouldRetry, actualShouldRetry, tt.description)
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in new issue