diff --git a/pkg/registry/client.go b/pkg/registry/client.go index f2bfd13b4..cb89335a6 100644 --- a/pkg/registry/client.go +++ b/pkg/registry/client.go @@ -24,10 +24,10 @@ import ( "errors" "fmt" "io" - "log/slog" "net/http" "net/url" "os" + "regexp" "sort" "strings" @@ -227,15 +227,31 @@ type ( } ) -// warnIfHostHasPath checks if the host contains a repository path and logs a warning if it does. -// Returns true if the host contains a path component (i.e., contains a '/'). -func warnIfHostHasPath(host string) bool { - if strings.Contains(host, "/") { - registryHost := strings.Split(host, "/")[0] - slog.Warn("registry login currently only supports registry hostname, not a repository path", "host", host, "suggested", registryHost) - return true +// hostRegex is a pretty naive regex for validating host urls. The goal of it is not to ultimately validate all possible valid hosts, +// but to catch common user errors such as including a scheme or path in the host string. +var hostRegex = regexp.MustCompile(`^(?P[a-zA-Z][a-zA-Z0-9+\-.]*:\/\/)?(?P[a-zA-Z0-9\-._:\[\]]+)(?P\/.*)?$`) + +// validateHost checks that the host matches some required pre-checks e.g. does not contain a scheme or path. +// While ORAS will also validate some of these things, the current errors are a bit opaque. +// By validating these things upfront, we can provide clearer error messages to users when they attempt to login with an invalid host string. +func validateHost(host string) error { + matches := hostRegex.FindStringSubmatch(host) + if len(matches) == 0 { + return fmt.Errorf("invalid host: %q", host) + } + + scheme := matches[hostRegex.SubexpIndex("scheme")] + path := matches[hostRegex.SubexpIndex("path")] + + if scheme != "" { + return fmt.Errorf("host should not contain a scheme (e.g. http://), found %q", scheme) + } + + if path != "" { + return fmt.Errorf("host should not contain a path, found %q", path) } - return false + + return nil } // Login authenticates the client with a remote OCI registry using the provided host and options. @@ -244,7 +260,9 @@ func (c *Client) Login(host string, options ...LoginOption) error { option(&loginOperation{host, c}) } - warnIfHostHasPath(host) + if err := validateHost(host); err != nil { + return err + } reg, err := remote.NewRegistry(host) if err != nil { diff --git a/pkg/registry/client_test.go b/pkg/registry/client_test.go index 702dfff69..0140b038c 100644 --- a/pkg/registry/client_test.go +++ b/pkg/registry/client_test.go @@ -121,47 +121,106 @@ func TestLogin_ResetsForceAttemptOAuth2_OnFailure(t *testing.T) { } } -// TestWarnIfHostHasPath verifies that warnIfHostHasPath correctly detects path components. -func TestWarnIfHostHasPath(t *testing.T) { +func TestValidateHost(t *testing.T) { t.Parallel() tests := []struct { - name string - host string - wantWarn bool + name string + host string + wantErr bool }{ { - name: "domain only", - host: "ghcr.io", - wantWarn: false, + name: "domain only", + host: "ghcr.io", + wantErr: false, }, { - name: "domain with port", - host: "localhost:8000", - wantWarn: false, + name: "domain with port", + host: "localhost:8000", + wantErr: false, }, { - name: "domain with repository path", - host: "ghcr.io/terryhowe", - wantWarn: true, + name: "domain with repository path", + host: "ghcr.io/terryhowe", + wantErr: true, }, { - name: "domain with nested path", - host: "ghcr.io/terryhowe/myrepo", - wantWarn: true, + name: "domain with nested path", + host: "ghcr.io/terryhowe/myrepo", + wantErr: true, }, { - name: "localhost with port and path", - host: "localhost:8000/myrepo", - wantWarn: true, + name: "localhost with port and path", + host: "localhost:8000/myrepo", + wantErr: true, + }, + { + name: "domain with http protocol", + host: "http://ghcr.io", + wantErr: true, + }, + { + name: "domain with https protocol", + host: "https://ghcr.io", + wantErr: true, + }, + { + name: "domain with oci protocol", + host: "oci://ghcr.io", + wantErr: true, + }, + { + name: "domain with uppercase scheme", + host: "HTTPS://ghcr.io", + wantErr: true, + }, + { + name: "scheme with plus sign", + host: "coap+tcp://ghcr.io", + wantErr: true, + }, + { + name: "scheme with dot and hyphen", + host: "my.custom-scheme://ghcr.io", + wantErr: true, + }, + { + name: "IPv6 loopback with port", + host: "[::1]:5000", + wantErr: false, + }, + { + name: "IPv6 full address with port", + host: "[2001:db8::1]:443", + wantErr: false, + }, + { + name: "IPv4 address with port", + host: "192.168.1.1:5000", + wantErr: false, + }, + { + name: "host with underscore", + host: "my_registry.local", + wantErr: false, + }, + { + name: "scheme with path", + host: "https://ghcr.io/myrepo", + wantErr: true, + }, + { + name: "empty string", + host: "", + wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := warnIfHostHasPath(tt.host) - if got != tt.wantWarn { - t.Errorf("warnIfHostHasPath(%q) = %v, want %v", tt.host, got, tt.wantWarn) + err := validateHost(tt.host) + if (err != nil) != tt.wantErr { + t.Errorf("validateHost(%q) error = %v, wantErr %v", tt.host, err, tt.wantErr) } }) }