diff --git a/pkg/registry/client.go b/pkg/registry/client.go index f2bfd13b4..19f499eaa 100644 --- a/pkg/registry/client.go +++ b/pkg/registry/client.go @@ -24,10 +24,10 @@ import ( "errors" "fmt" "io" - "log/slog" "net/http" "net/url" "os" + "regexp" "sort" "strings" @@ -227,15 +227,44 @@ type ( } ) -// warnIfHostHasPath checks if the host contains a repository path and logs a warning if it does. -// Returns true if the host contains a path component (i.e., contains a '/'). -func warnIfHostHasPath(host string) bool { - if strings.Contains(host, "/") { - registryHost := strings.Split(host, "/")[0] - slog.Warn("registry login currently only supports registry hostname, not a repository path", "host", host, "suggested", registryHost) - return true +// schemeRegex is used to check if a schema is present within the host - we have to do so, to determine if we need +// to prepend a "dummy://" schema for url.Parse to not accidentally interpret the host as just a path +var schemeRegex = regexp.MustCompile(`^([a-zA-Z][a-zA-Z0-9+\-.]*:\/\/).*$`) + +// validateHost checks that the host matches some required pre-checks e.g. does not contain a scheme, query or path. +// While ORAS will also validate some of these things, the current errors are a bit opaque. +// By validating these things upfront, we can provide clearer error messages to users when they attempt to login with an invalid host string. +func validateHost(host string) error { + if host == "" { + return errors.New("host cannot be empty") + } + + // we pre-validate the scheme part here to make sure that we can prepend a dummy scheme for url.Parse without accidentally + // accepting invalid hosts that would be parsed with the scheme as just a path. By not just blindly prepending the scheme, + // we can produce a clearer error message for users who still include a scheme in the host string, which is a common thing. + if schemeRegex.MatchString(host) { + return fmt.Errorf("host should not contain a scheme, found %q", host) + } + + // keep the original host for potential error messages + originalHost := host + + // we have to prepend the dummy scheme here to make it an absolute url + parsed, err := url.Parse("dummy://" + host) + if err != nil { + return fmt.Errorf("invalid host %q: %w", originalHost, err) } - return false + + if parsed.RawQuery != "" || parsed.Fragment != "" { + return fmt.Errorf("host should not contain a query or fragment, found %q", originalHost) + } + + // currently, paths are also not supported + if parsed.Path != "" { + return fmt.Errorf("host should not contain a path, found %q", originalHost) + } + + return nil } // Login authenticates the client with a remote OCI registry using the provided host and options. @@ -244,7 +273,10 @@ func (c *Client) Login(host string, options ...LoginOption) error { option(&loginOperation{host, c}) } - warnIfHostHasPath(host) + host = strings.TrimSpace(host) + if err := validateHost(host); err != nil { + return err + } reg, err := remote.NewRegistry(host) if err != nil { diff --git a/pkg/registry/client_test.go b/pkg/registry/client_test.go index 702dfff69..4cd191d6e 100644 --- a/pkg/registry/client_test.go +++ b/pkg/registry/client_test.go @@ -121,47 +121,121 @@ func TestLogin_ResetsForceAttemptOAuth2_OnFailure(t *testing.T) { } } -// TestWarnIfHostHasPath verifies that warnIfHostHasPath correctly detects path components. -func TestWarnIfHostHasPath(t *testing.T) { +func TestValidateHost(t *testing.T) { t.Parallel() tests := []struct { - name string - host string - wantWarn bool + name string + host string + wantErr bool }{ { - name: "domain only", - host: "ghcr.io", - wantWarn: false, + name: "domain only", + host: "ghcr.io", + wantErr: false, }, { - name: "domain with port", - host: "localhost:8000", - wantWarn: false, + name: "domain with port", + host: "localhost:8000", + wantErr: false, }, { - name: "domain with repository path", - host: "ghcr.io/terryhowe", - wantWarn: true, + name: "domain with repository path", + host: "ghcr.io/terryhowe", + wantErr: true, }, { - name: "domain with nested path", - host: "ghcr.io/terryhowe/myrepo", - wantWarn: true, + name: "domain with nested path", + host: "ghcr.io/terryhowe/myrepo", + wantErr: true, }, { - name: "localhost with port and path", - host: "localhost:8000/myrepo", - wantWarn: true, + name: "localhost with port and path", + host: "localhost:8000/myrepo", + wantErr: true, + }, + { + name: "domain with http protocol", + host: "http://ghcr.io", + wantErr: true, + }, + { + name: "domain with https protocol", + host: "https://ghcr.io", + wantErr: true, + }, + { + name: "domain with oci protocol", + host: "oci://ghcr.io", + wantErr: true, + }, + { + name: "domain with uppercase scheme", + host: "HTTPS://ghcr.io", + wantErr: true, + }, + { + name: "scheme with plus sign", + host: "coap+tcp://ghcr.io", + wantErr: true, + }, + { + name: "scheme with dot and hyphen", + host: "my.custom-scheme://ghcr.io", + wantErr: true, + }, + { + name: "IPv6 loopback with port", + host: "[::1]:5000", + wantErr: false, + }, + { + name: "IPv6 full address with port", + host: "[2001:db8::1]:443", + wantErr: false, + }, + { + name: "IPv4 address with port", + host: "192.168.1.1:5000", + wantErr: false, + }, + { + name: "host with underscore", + host: "my_registry.local", + wantErr: false, + }, + { + name: "scheme with path", + host: "https://ghcr.io/myrepo", + wantErr: true, + }, + { + name: "trailing slash", + host: "ghcr.io/", + wantErr: true, + }, + { + name: "empty string", + host: "", + wantErr: true, + }, + { + name: "url with query parameters", + host: "ghcr.io?param=value", + wantErr: true, + }, + { + name: "url with fragment", + host: "ghcr.io#fragment", + wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := warnIfHostHasPath(tt.host) - if got != tt.wantWarn { - t.Errorf("warnIfHostHasPath(%q) = %v, want %v", tt.host, got, tt.wantWarn) + err := validateHost(tt.host) + if (err != nil) != tt.wantErr { + t.Errorf("validateHost(%q) error = %v, wantErr %v", tt.host, err, tt.wantErr) } }) }