diff --git a/cmd/helm/pull_test.go b/cmd/helm/pull_test.go index d8b1f57a7..aa75cad49 100644 --- a/cmd/helm/pull_test.go +++ b/cmd/helm/pull_test.go @@ -183,15 +183,20 @@ func TestPullCmd(t *testing.T) { wantError: true, }, { - name: "Fail fetching OCI chart without version specified", - args: fmt.Sprintf("oci://%s/u/ocitestuser/oci-dependent-chart:0.1.0", ociSrv.RegistryURL), - wantErrorMsg: "Error: --version flag is explicitly required for OCI registries", - wantError: true, + name: "Fetching OCI chart without version option specified", + args: fmt.Sprintf("oci://%s/u/ocitestuser/oci-dependent-chart:0.1.0", ociSrv.RegistryURL), + expectFile: "./oci-dependent-chart-0.1.0.tgz", }, { - name: "Fail fetching OCI chart without version specified", - args: fmt.Sprintf("oci://%s/u/ocitestuser/oci-dependent-chart:0.1.0 --version 0.1.0", ociSrv.RegistryURL), - wantError: true, + name: "Fetching OCI chart with version specified", + args: fmt.Sprintf("oci://%s/u/ocitestuser/oci-dependent-chart:0.1.0 --version 0.1.0", ociSrv.RegistryURL), + expectFile: "./oci-dependent-chart-0.1.0.tgz", + }, + { + name: "Fail fetching OCI chart with version mismatch", + args: fmt.Sprintf("oci://%s/u/ocitestuser/oci-dependent-chart:0.2.0 --version 0.1.0", ociSrv.RegistryURL), + wantErrorMsg: "Error: chart reference and version mismatch: 0.2.0 is not 0.1.0", + wantError: true, }, } diff --git a/pkg/downloader/chart_downloader.go b/pkg/downloader/chart_downloader.go index 910942cde..f5d1deac9 100644 --- a/pkg/downloader/chart_downloader.go +++ b/pkg/downloader/chart_downloader.go @@ -23,7 +23,6 @@ import ( "path/filepath" "strings" - "github.com/Masterminds/semver/v3" "github.com/pkg/errors" "helm.sh/helm/v4/internal/fileutil" @@ -143,39 +142,6 @@ func (c *ChartDownloader) DownloadTo(ref, version, dest string) (string, *proven return destfile, ver, nil } -func (c *ChartDownloader) getOciURI(ref, version string, u *url.URL) (*url.URL, error) { - var tag string - var err error - - // Evaluate whether an explicit version has been provided. Otherwise, determine version to use - _, errSemVer := semver.NewVersion(version) - if errSemVer == nil { - tag = version - } else { - // Retrieve list of repository tags - tags, err := c.RegistryClient.Tags(strings.TrimPrefix(ref, fmt.Sprintf("%s://", registry.OCIScheme))) - if err != nil { - return nil, err - } - if len(tags) == 0 { - return nil, errors.Errorf("Unable to locate any tags in provided repository: %s", ref) - } - - // Determine if version provided - // If empty, try to get the highest available tag - // If exact version, try to find it - // If semver constraint string, try to find a match - tag, err = registry.GetTagMatchingVersionOrConstraint(tags, version) - if err != nil { - return nil, err - } - } - - u.Path = fmt.Sprintf("%s:%s", u.Path, tag) - - return u, err -} - // ResolveChartVersion resolves a chart reference to a URL. // // It returns the URL and sets the ChartDownloader's Options that can fetch @@ -198,7 +164,7 @@ func (c *ChartDownloader) ResolveChartVersion(ref, version string) (*url.URL, er } if registry.IsOCI(u.String()) { - return c.getOciURI(ref, version, u) + return c.RegistryClient.ValidateReference(ref, version, u) } rf, err := loadRepoConfig(c.RepositoryConfig) diff --git a/pkg/downloader/chart_downloader_test.go b/pkg/downloader/chart_downloader_test.go index a8c359411..1e989ce6f 100644 --- a/pkg/downloader/chart_downloader_test.go +++ b/pkg/downloader/chart_downloader_test.go @@ -53,6 +53,12 @@ func TestResolveChartRef(t *testing.T) { {name: "full URL, file", ref: "file:///foo-1.2.3.tgz", fail: true}, {name: "invalid", ref: "invalid-1.2.3", fail: true}, {name: "not found", ref: "nosuchthing/invalid-1.2.3", fail: true}, + {name: "ref with tag", ref: "oci://example.com/helm-charts/nginx:15.4.2", expect: "oci://example.com/helm-charts/nginx:15.4.2"}, + {name: "no repository", ref: "oci://", fail: true}, + {name: "oci ref", ref: "oci://example.com/helm-charts/nginx", version: "15.4.2", expect: "oci://example.com/helm-charts/nginx:15.4.2"}, + {name: "oci ref with sha256", ref: "oci://example.com/install/by/sha@sha256:d234555386402a5867ef0169fefe5486858b6d8d209eaf32fd26d29b16807fd6", version: "0.1.1", expect: "oci://example.com/install/by/sha@sha256:d234555386402a5867ef0169fefe5486858b6d8d209eaf32fd26d29b16807fd6"}, + {name: "oci ref with sha256 and version", ref: "oci://example.com/install/by/sha:0.1.1@sha256:d234555386402a5867ef0169fefe5486858b6d8d209eaf32fd26d29b16807fd6", version: "0.1.1", expect: "oci://example.com/install/by/sha:0.1.1@sha256:d234555386402a5867ef0169fefe5486858b6d8d209eaf32fd26d29b16807fd6"}, + {name: "oci ref with sha256 and version mismatch", ref: "oci://example.com/install/by/sha:0.1.1@sha256:d234555386402a5867ef0169fefe5486858b6d8d209eaf32fd26d29b16807fd6", version: "0.1.2", fail: true}, } c := ChartDownloader{ diff --git a/pkg/getter/ocigetter.go b/pkg/getter/ocigetter.go index 5787dc909..1e3d9a0b4 100644 --- a/pkg/getter/ocigetter.go +++ b/pkg/getter/ocigetter.go @@ -20,6 +20,7 @@ import ( "fmt" "net" "net/http" + "path" "strings" "sync" "time" @@ -58,6 +59,9 @@ func (g *OCIGetter) get(href string) (*bytes.Buffer, error) { ref := strings.TrimPrefix(href, fmt.Sprintf("%s://", registry.OCIScheme)) + if version := g.opts.version; version != "" && !strings.Contains(path.Base(ref), ":") { + ref = fmt.Sprintf("%s:%s", ref, version) + } var pullOpts []registry.PullOption requestingProv := strings.HasSuffix(ref, ".prov") if requestingProv { diff --git a/pkg/registry/client.go b/pkg/registry/client.go index f51529965..5cb8d1bb4 100644 --- a/pkg/registry/client.go +++ b/pkg/registry/client.go @@ -23,6 +23,7 @@ import ( "fmt" "io" "net/http" + "net/url" "sort" "strings" @@ -377,7 +378,7 @@ type ( // Pull downloads a chart from a registry func (c *Client) Pull(ref string, options ...PullOption) (*PullResult, error) { - parsedRef, err := parseReference(ref) + parsedRef, err := newReference(ref) if err != nil { return nil, err } @@ -409,7 +410,7 @@ func (c *Client) Pull(ref string, options ...PullOption) (*PullResult, error) { } var descriptors, layers []ocispec.Descriptor - remotesResolver, err := c.resolver(parsedRef) + remotesResolver, err := c.resolver(parsedRef.orasReference) if err != nil { return nil, err } @@ -593,7 +594,7 @@ type ( // Push uploads a chart to a registry. func (c *Client) Push(data []byte, ref string, options ...PushOption) (*PushResult, error) { - parsedRef, err := parseReference(ref) + parsedRef, err := newReference(ref) if err != nil { return nil, err } @@ -652,12 +653,12 @@ func (c *Client) Push(data []byte, ref string, options ...PushOption) (*PushResu return nil, err } - remotesResolver, err := c.resolver(parsedRef) + remotesResolver, err := c.resolver(parsedRef.orasReference) if err != nil { return nil, err } registryStore := content.Registry{Resolver: remotesResolver} - _, err = oras.Copy(ctx(c.out, c.debug), memoryStore, parsedRef.String(), registryStore, "", + _, err = oras.Copy(ctx(c.out, c.debug), memoryStore, parsedRef.orasReference.String(), registryStore, "", oras.WithNameValidation(nil)) if err != nil { return nil, err @@ -688,7 +689,7 @@ func (c *Client) Push(data []byte, ref string, options ...PushOption) (*PushResu } fmt.Fprintf(c.out, "Pushed: %s\n", result.Ref) fmt.Fprintf(c.out, "Digest: %s\n", result.Manifest.Digest) - if strings.Contains(parsedRef.Reference, "_") { + if strings.Contains(parsedRef.orasReference.Reference, "_") { fmt.Fprintf(c.out, "%s contains an underscore.\n", result.Ref) fmt.Fprint(c.out, registryUnderscoreMessage+"\n") } @@ -759,3 +760,89 @@ func (c *Client) Tags(ref string) ([]string, error) { return tags, nil } + +// Resolve a reference to a descriptor. +func (c *Client) Resolve(ref string) (*ocispec.Descriptor, error) { + ctx := context.Background() + parsedRef, err := newReference(ref) + if err != nil { + return nil, err + } + if parsedRef.Registry == "" { + return nil, nil + } + + remotesResolver, err := c.resolver(parsedRef.orasReference) + if err != nil { + return nil, err + } + + _, desc, err := remotesResolver.Resolve(ctx, ref) + return &desc, err +} + +// ValidateReference for path and version +func (c *Client) ValidateReference(ref, version string, u *url.URL) (*url.URL, error) { + var tag string + + registryReference, err := newReference(u.Path) + if err != nil { + return nil, err + } + + if version == "" { + // Use OCI URI tag as default + version = registryReference.Tag + } else { + if registryReference.Tag != "" && registryReference.Tag != version { + return nil, errors.Errorf("chart reference and version mismatch: %s is not %s", version, registryReference.Tag) + } + } + + if registryReference.Digest != "" { + if registryReference.Tag == "" { + // Install by digest only + return u, nil + } + + // Validate the tag if it was specified + path := registryReference.Registry + "/" + registryReference.Repository + ":" + registryReference.Tag + desc, err := c.Resolve(path) + if err != nil { + // The resource does not have to be tagged when digest is specified + return u, nil + } + if desc != nil && desc.Digest.String() != registryReference.Digest { + return nil, errors.Errorf("chart reference digest mismatch: %s is not %s", desc.Digest.String(), registryReference.Digest) + } + return u, nil + } + + // Evaluate whether an explicit version has been provided. Otherwise, determine version to use + _, errSemVer := semver.NewVersion(version) + if errSemVer == nil { + tag = version + } else { + // Retrieve list of repository tags + tags, err := c.Tags(strings.TrimPrefix(ref, fmt.Sprintf("%s://", OCIScheme))) + if err != nil { + return nil, err + } + if len(tags) == 0 { + return nil, errors.Errorf("Unable to locate any tags in provided repository: %s", ref) + } + + // Determine if version provided + // If empty, try to get the highest available tag + // If exact version, try to find it + // If semver constraint string, try to find a match + tag, err = GetTagMatchingVersionOrConstraint(tags, version) + if err != nil { + return nil, err + } + } + + u.Path = fmt.Sprintf("%s/%s:%s", registryReference.Registry, registryReference.Repository, tag) + + return u, err +} diff --git a/pkg/registry/reference.go b/pkg/registry/reference.go new file mode 100644 index 000000000..9b99d73bf --- /dev/null +++ b/pkg/registry/reference.go @@ -0,0 +1,78 @@ +/* +Copyright The Helm Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package registry + +import ( + "strings" + + orasregistry "oras.land/oras-go/pkg/registry" +) + +type reference struct { + orasReference orasregistry.Reference + Registry string + Repository string + Tag string + Digest string +} + +// newReference will parse and validate the reference, and clean tags when +// applicable tags are only cleaned when plus (+) signs are present, and are +// converted to underscores (_) before pushing +// See https://github.com/helm/helm/issues/10166 +func newReference(raw string) (result reference, err error) { + // Remove oci:// prefix if it is there + raw = strings.TrimPrefix(raw, OCIScheme+"://") + + // The sole possible reference modification is replacing plus (+) signs + // present in tags with underscores (_). To do this properly, we first + // need to identify a tag, and then pass it on to the reference parser + // NOTE: Passing immediately to the reference parser will fail since (+) + // signs are an invalid tag character, and simply replacing all plus (+) + // occurrences could invalidate other portions of the URI + lastIndex := strings.LastIndex(raw, "@") + if lastIndex >= 0 { + result.Digest = raw[(lastIndex + 1):] + raw = raw[:lastIndex] + } + parts := strings.Split(raw, ":") + if len(parts) > 1 && !strings.Contains(parts[len(parts)-1], "/") { + tag := parts[len(parts)-1] + + if tag != "" { + // Replace any plus (+) signs with known underscore (_) conversion + newTag := strings.ReplaceAll(tag, "+", "_") + raw = strings.ReplaceAll(raw, tag, newTag) + } + } + + result.orasReference, err = orasregistry.ParseReference(raw) + if err != nil { + return result, err + } + result.Registry = result.orasReference.Registry + result.Repository = result.orasReference.Repository + result.Tag = result.orasReference.Reference + return result, nil +} + +func (r *reference) String() string { + if r.Tag == "" { + return r.orasReference.String() + "@" + r.Digest + } + return r.orasReference.String() +} diff --git a/pkg/registry/reference_test.go b/pkg/registry/reference_test.go new file mode 100644 index 000000000..31317d18f --- /dev/null +++ b/pkg/registry/reference_test.go @@ -0,0 +1,99 @@ +/* +Copyright The Helm Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package registry + +import "testing" + +func verify(t *testing.T, actual reference, registry, repository, tag, digest string) { + if registry != actual.orasReference.Registry { + t.Errorf("Oras reference registry expected %v actual %v", registry, actual.Registry) + } + if repository != actual.orasReference.Repository { + t.Errorf("Oras reference repository expected %v actual %v", repository, actual.Repository) + } + if tag != actual.orasReference.Reference { + t.Errorf("Oras reference reference expected %v actual %v", tag, actual.Tag) + } + if registry != actual.Registry { + t.Errorf("Registry expected %v actual %v", registry, actual.Registry) + } + if repository != actual.Repository { + t.Errorf("Repository expected %v actual %v", repository, actual.Repository) + } + if tag != actual.Tag { + t.Errorf("Tag expected %v actual %v", tag, actual.Tag) + } + if digest != actual.Digest { + t.Errorf("Digest expected %v actual %v", digest, actual.Digest) + } + expectedString := registry + if repository != "" { + expectedString = expectedString + "/" + repository + } + if tag != "" { + expectedString = expectedString + ":" + tag + } else { + expectedString = expectedString + "@" + digest + } + if actual.String() != expectedString { + t.Errorf("String expected %s actual %s", expectedString, actual.String()) + } +} + +func TestNewReference(t *testing.T) { + actual, err := newReference("registry.example.com/repository:1.0@sha256:c6841b3a895f1444a6738b5d04564a57e860ce42f8519c3be807fb6d9bee7888") + if err != nil { + t.Errorf("Unexpected error %v", err) + } + verify(t, actual, "registry.example.com", "repository", "1.0", "sha256:c6841b3a895f1444a6738b5d04564a57e860ce42f8519c3be807fb6d9bee7888") + + actual, err = newReference("oci://registry.example.com/repository:1.0@sha256:c6841b3a895f1444a6738b5d04564a57e860ce42f8519c3be807fb6d9bee7888") + if err != nil { + t.Errorf("Unexpected error %v", err) + } + verify(t, actual, "registry.example.com", "repository", "1.0", "sha256:c6841b3a895f1444a6738b5d04564a57e860ce42f8519c3be807fb6d9bee7888") + + actual, err = newReference("a/b:1@c") + if err != nil { + t.Errorf("Unexpected error %v", err) + } + verify(t, actual, "a", "b", "1", "c") + + actual, err = newReference("a/b:@") + if err != nil { + t.Errorf("Unexpected error %v", err) + } + verify(t, actual, "a", "b", "", "") + + actual, err = newReference("registry.example.com/repository:1.0+001") + if err != nil { + t.Errorf("Unexpected error %v", err) + } + verify(t, actual, "registry.example.com", "repository", "1.0_001", "") + + actual, err = newReference("thing:1.0") + if err == nil { + t.Errorf("Expect error error %v", err) + } + verify(t, actual, "", "", "", "") + + actual, err = newReference("registry.example.com/the/repository@sha256:c6841b3a895f1444a6738b5d04564a57e860ce42f8519c3be807fb6d9bee7888") + if err != nil { + t.Errorf("Unexpected error %v", err) + } + verify(t, actual, "registry.example.com", "the/repository", "", "sha256:c6841b3a895f1444a6738b5d04564a57e860ce42f8519c3be807fb6d9bee7888") +} diff --git a/pkg/registry/util.go b/pkg/registry/util.go index 5180b3313..78b7d4385 100644 --- a/pkg/registry/util.go +++ b/pkg/registry/util.go @@ -32,7 +32,6 @@ import ( "github.com/pkg/errors" "github.com/sirupsen/logrus" orascontext "oras.land/oras-go/pkg/context" - "oras.land/oras-go/pkg/registry" "helm.sh/helm/v4/internal/tlsutil" "helm.sh/helm/v4/pkg/chart" @@ -115,31 +114,6 @@ func ctx(out io.Writer, debug bool) context.Context { return ctx } -// parseReference will parse and validate the reference, and clean tags when -// applicable tags are only cleaned when plus (+) signs are present, and are -// converted to underscores (_) before pushing -// See https://github.com/helm/helm/issues/10166 -func parseReference(raw string) (registry.Reference, error) { - // The sole possible reference modification is replacing plus (+) signs - // present in tags with underscores (_). To do this properly, we first - // need to identify a tag, and then pass it on to the reference parser - // NOTE: Passing immediately to the reference parser will fail since (+) - // signs are an invalid tag character, and simply replacing all plus (+) - // occurrences could invalidate other portions of the URI - parts := strings.Split(raw, ":") - if len(parts) > 1 && !strings.Contains(parts[len(parts)-1], "/") { - tag := parts[len(parts)-1] - - if tag != "" { - // Replace any plus (+) signs with known underscore (_) conversion - newTag := strings.ReplaceAll(tag, "+", "_") - raw = strings.ReplaceAll(raw, tag, newTag) - } - } - - return registry.ParseReference(raw) -} - // NewRegistryClientWithTLS is a helper function to create a new registry client with TLS enabled. func NewRegistryClientWithTLS(out io.Writer, certFile, keyFile, caFile string, insecureSkipTLSverify bool, registryConfig string, debug bool) (*Client, error) { tlsConf, err := tlsutil.NewClientTLS(certFile, keyFile, caFile, insecureSkipTLSverify)