diff --git a/pkg/registry/client.go b/pkg/registry/client.go index 750bb9715..7eafe96c7 100644 --- a/pkg/registry/client.go +++ b/pkg/registry/client.go @@ -18,6 +18,7 @@ package registry // import "helm.sh/helm/v4/pkg/registry" import ( "context" + "crypto/sha256" "crypto/tls" "crypto/x509" "encoding/json" @@ -30,8 +31,10 @@ import ( "os" "sort" "strings" + "sync" "github.com/Masterminds/semver/v3" + "github.com/opencontainers/go-digest" "github.com/opencontainers/image-spec/specs-go" ocispec "github.com/opencontainers/image-spec/specs-go/v1" "oras.land/oras-go/v2" @@ -664,33 +667,53 @@ func (c *Client) Push(data []byte, ref string, options ...PushOption) (*PushResu } } - ctx := context.Background() - - memoryStore := memory.New() - chartDescriptor, err := oras.PushBytes(ctx, memoryStore, ChartLayerMediaType, data) + repository, err := remote.NewRepository(parsedRef.String()) if err != nil { return nil, err } + repository.PlainHTTP = c.plainHTTP + repository.Client = c.authorizer - configData, err := json.Marshal(meta) + ctx := context.Background() + ctx = auth.AppendRepositoryScope(ctx, repository.Reference, auth.ActionPull, auth.ActionPush) + + chartBlob := newBlob(repository, ChartLayerMediaType, data) + exists, err := chartBlob.exists(ctx) if err != nil { return nil, err } - configDescriptor, err := oras.PushBytes(ctx, memoryStore, ConfigMediaType, configData) + layers := []ocispec.Descriptor{chartBlob.descriptor} + var wg sync.WaitGroup + if !exists { + runWorker(ctx, &wg, chartBlob.push) + } + + configData, err := json.Marshal(meta) if err != nil { return nil, err } + configBlob := newBlob(repository, ConfigMediaType, configData) + runWorker(ctx, &wg, configBlob.pushNew) - layers := []ocispec.Descriptor{chartDescriptor} - var provDescriptor ocispec.Descriptor + var provBlob blob if operation.provData != nil { - provDescriptor, err = oras.PushBytes(ctx, memoryStore, ProvLayerMediaType, operation.provData) - if err != nil { - return nil, err - } + provBlob = newBlob(repository, ProvLayerMediaType, operation.provData) + runWorker(ctx, &wg, provBlob.pushNew) + } + wg.Wait() - layers = append(layers, provDescriptor) + if chartBlob.err != nil { + return nil, chartBlob.err + } + if configBlob.err != nil { + return nil, configBlob.err + } + if provBlob.err != nil { + return nil, provBlob.err + } + if operation.provData != nil { + layers = append(layers, provBlob.descriptor) } // sort layers for determinism, similar to how ORAS v1 does it @@ -700,20 +723,20 @@ func (c *Client) Push(data []byte, ref string, options ...PushOption) (*PushResu ociAnnotations := generateOCIAnnotations(meta, operation.creationTime) - manifestDescriptor, err := c.tagManifest(ctx, memoryStore, configDescriptor, - layers, ociAnnotations, parsedRef) - if err != nil { - return nil, err + manifest := ocispec.Manifest{ + Versioned: specs.Versioned{SchemaVersion: 2}, + Config: configBlob.descriptor, + Layers: layers, + Annotations: ociAnnotations, } - repository, err := remote.NewRepository(parsedRef.String()) + manifestData, err := json.Marshal(manifest) if err != nil { return nil, err } - repository.PlainHTTP = c.plainHTTP - repository.Client = c.authorizer - manifestDescriptor, err = oras.ExtendedCopy(ctx, memoryStore, parsedRef.String(), repository, parsedRef.String(), oras.DefaultExtendedCopyOptions) + manifestDescriptor, err := oras.TagBytes(ctx, repository, ocispec.MediaTypeImageManifest, + manifestData, parsedRef.String()) if err != nil { return nil, err } @@ -721,16 +744,16 @@ func (c *Client) Push(data []byte, ref string, options ...PushOption) (*PushResu chartSummary := &descriptorPushSummaryWithMeta{ Meta: meta, } - chartSummary.Digest = chartDescriptor.Digest.String() - chartSummary.Size = chartDescriptor.Size + chartSummary.Digest = chartBlob.descriptor.Digest.String() + chartSummary.Size = chartBlob.descriptor.Size result := &PushResult{ Manifest: &descriptorPushSummary{ Digest: manifestDescriptor.Digest.String(), Size: manifestDescriptor.Size, }, Config: &descriptorPushSummary{ - Digest: configDescriptor.Digest.String(), - Size: configDescriptor.Size, + Digest: configBlob.descriptor.Digest.String(), + Size: configBlob.descriptor.Size, }, Chart: chartSummary, Prov: &descriptorPushSummary{}, // prevent nil references @@ -738,8 +761,8 @@ func (c *Client) Push(data []byte, ref string, options ...PushOption) (*PushResu } if operation.provData != nil { result.Prov = &descriptorPushSummary{ - Digest: provDescriptor.Digest.String(), - Size: provDescriptor.Size, + Digest: provBlob.descriptor.Digest.String(), + Size: provBlob.descriptor.Size, } } _, _ = fmt.Fprintf(c.out, "Pushed: %s\n", result.Ref) @@ -925,3 +948,66 @@ func (c *Client) tagManifest(ctx context.Context, memoryStore *memory.Store, return oras.TagBytes(ctx, memoryStore, ocispec.MediaTypeImageManifest, manifestData, parsedRef.String()) } + +// runWorker spawns a goroutine to execute the worker function and tracks it +// with the provided WaitGroup. The WaitGroup counter is incremented before +// spawning and decremented when the worker completes. +func runWorker(ctx context.Context, wg *sync.WaitGroup, worker func(context.Context)) { + wg.Add(1) + go func() { + defer wg.Done() + worker(ctx) + }() +} + +// blob represents a content-addressable blob to be pushed to an OCI registry. +// It encapsulates the data, media type, and destination repository, and tracks +// the resulting descriptor and any error from push operations. +type blob struct { + mediaType string + dst *remote.Repository + data []byte + descriptor ocispec.Descriptor + err error +} + +// newBlob creates a new blob with the given repository, media type, and data. +func newBlob(dst *remote.Repository, mediaType string, data []byte) blob { + return blob{ + mediaType: mediaType, + dst: dst, + data: data, + } +} + +// exists checks if the blob already exists in the registry by computing its +// digest and querying the repository. It also populates the blob's descriptor +// with size, media type, and digest information. +func (b *blob) exists(ctx context.Context) (bool, error) { + hash := sha256.Sum256(b.data) + b.descriptor.Size = int64(len(b.data)) + b.descriptor.MediaType = b.mediaType + b.descriptor.Digest = digest.NewDigestFromBytes(digest.SHA256, hash[:]) + return b.dst.Exists(ctx, b.descriptor) +} + +// pushNew checks if the blob exists in the registry first, and only pushes +// if it doesn't exist. This avoids redundant uploads for blobs that are +// already present. Any error is stored in b.err. +func (b *blob) pushNew(ctx context.Context) { + var exists bool + exists, b.err = b.exists(ctx) + if b.err != nil { + return + } + if exists { + return + } + b.descriptor, b.err = oras.PushBytes(ctx, b.dst, b.mediaType, b.data) +} + +// push unconditionally pushes the blob to the registry without checking +// for existence first. Any error is stored in b.err. +func (b *blob) push(ctx context.Context) { + b.descriptor, b.err = oras.PushBytes(ctx, b.dst, b.mediaType, b.data) +} diff --git a/pkg/registry/client_test.go b/pkg/registry/client_test.go index 98a8b2ea3..223d5b96b 100644 --- a/pkg/registry/client_test.go +++ b/pkg/registry/client_test.go @@ -17,11 +17,15 @@ limitations under the License. package registry import ( + "crypto/sha256" + "fmt" "io" "net/http" "net/http/httptest" + "os" "path/filepath" "strings" + "sync" "testing" ocispec "github.com/opencontainers/image-spec/specs-go/v1" @@ -166,3 +170,98 @@ func TestWarnIfHostHasPath(t *testing.T) { }) } } + +// TestPushConcurrent verifies that concurrent Push operations on the same Client +// do not interfere with each other. This test is designed to catch race conditions +// when run with -race flag. +func TestPushConcurrent(t *testing.T) { + t.Parallel() + + // Create a mock registry server that accepts pushes + var mu sync.Mutex + uploads := make(map[string][]byte) + var uploadCounter int + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodHead && strings.Contains(r.URL.Path, "/blobs/"): + // Blob existence check - return 404 to force upload + w.WriteHeader(http.StatusNotFound) + + case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/blobs/uploads/"): + // Start upload - return upload URL with unique ID + mu.Lock() + uploadCounter++ + uploadID := fmt.Sprintf("upload-%d", uploadCounter) + mu.Unlock() + w.Header().Set("Location", fmt.Sprintf("%s%s", r.URL.Path, uploadID)) + w.WriteHeader(http.StatusAccepted) + + case r.Method == http.MethodPut && strings.Contains(r.URL.Path, "/blobs/uploads/"): + // Complete upload - extract digest from query param + body, _ := io.ReadAll(r.Body) + digest := r.URL.Query().Get("digest") + mu.Lock() + uploads[r.URL.Path] = body + mu.Unlock() + w.Header().Set("Docker-Content-Digest", digest) + w.WriteHeader(http.StatusCreated) + + case r.Method == http.MethodPut && strings.Contains(r.URL.Path, "/manifests/"): + // Manifest push - compute actual sha256 digest of the body + body, _ := io.ReadAll(r.Body) + hash := sha256.Sum256(body) + digest := fmt.Sprintf("sha256:%x", hash) + w.Header().Set("Docker-Content-Digest", digest) + w.WriteHeader(http.StatusCreated) + + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + host := strings.TrimPrefix(srv.URL, "http://") + + // Create client + credFile := filepath.Join(t.TempDir(), "config.json") + client, err := NewClient( + ClientOptWriter(io.Discard), + ClientOptCredentialsFile(credFile), + ClientOptPlainHTTP(), + ) + require.NoError(t, err) + + // Load test chart + chartData, err := os.ReadFile("../downloader/testdata/local-subchart-0.1.0.tgz") + require.NoError(t, err, "no error loading test chart") + + meta, err := extractChartMeta(chartData) + require.NoError(t, err, "no error extracting chart meta") + + // Run concurrent pushes + const numGoroutines = 10 + var wg sync.WaitGroup + errs := make(chan error, numGoroutines) + + for i := range numGoroutines { + wg.Add(1) + go func(idx int) { + defer wg.Done() + // Each goroutine pushes to a different tag to avoid conflicts + ref := fmt.Sprintf("%s/testrepo/%s:%s-%d", host, meta.Name, meta.Version, idx) + _, err := client.Push(chartData, ref, PushOptStrictMode(false)) + if err != nil { + errs <- fmt.Errorf("goroutine %d: %w", idx, err) + } + }(i) + } + + wg.Wait() + close(errs) + + // Check for errors + for err := range errs { + t.Error(err) + } +}