diff --git a/pkg/registry/client.go b/pkg/registry/client.go index 537f2a1a2..7eafe96c7 100644 --- a/pkg/registry/client.go +++ b/pkg/registry/client.go @@ -79,7 +79,6 @@ type ( credentialsStore credentials.Store httpClient *http.Client plainHTTP bool - wg sync.WaitGroup } // ClientOption allows specifying various settings configurable by the user for overriding the defaults @@ -685,8 +684,9 @@ func (c *Client) Push(data []byte, ref string, options ...PushOption) (*PushResu } layers := []ocispec.Descriptor{chartBlob.descriptor} + var wg sync.WaitGroup if !exists { - c.runWorker(ctx, chartBlob.push) + runWorker(ctx, &wg, chartBlob.push) } configData, err := json.Marshal(meta) @@ -694,14 +694,14 @@ func (c *Client) Push(data []byte, ref string, options ...PushOption) (*PushResu return nil, err } configBlob := newBlob(repository, ConfigMediaType, configData) - c.runWorker(ctx, configBlob.pushNew) + runWorker(ctx, &wg, configBlob.pushNew) var provBlob blob if operation.provData != nil { provBlob = newBlob(repository, ProvLayerMediaType, operation.provData) - c.runWorker(ctx, provBlob.pushNew) + runWorker(ctx, &wg, provBlob.pushNew) } - c.wg.Wait() + wg.Wait() if chartBlob.err != nil { return nil, chartBlob.err @@ -949,14 +949,20 @@ func (c *Client) tagManifest(ctx context.Context, memoryStore *memory.Store, manifestData, parsedRef.String()) } -func (c *Client) runWorker(ctx context.Context, worker func(context.Context)) { - c.wg.Add(1) +// runWorker spawns a goroutine to execute the worker function and tracks it +// with the provided WaitGroup. The WaitGroup counter is incremented before +// spawning and decremented when the worker completes. +func runWorker(ctx context.Context, wg *sync.WaitGroup, worker func(context.Context)) { + wg.Add(1) go func() { - defer c.wg.Done() + defer wg.Done() worker(ctx) }() } +// blob represents a content-addressable blob to be pushed to an OCI registry. +// It encapsulates the data, media type, and destination repository, and tracks +// the resulting descriptor and any error from push operations. type blob struct { mediaType string dst *remote.Repository @@ -965,6 +971,7 @@ type blob struct { err error } +// newBlob creates a new blob with the given repository, media type, and data. func newBlob(dst *remote.Repository, mediaType string, data []byte) blob { return blob{ mediaType: mediaType, @@ -973,6 +980,9 @@ func newBlob(dst *remote.Repository, mediaType string, data []byte) blob { } } +// exists checks if the blob already exists in the registry by computing its +// digest and querying the repository. It also populates the blob's descriptor +// with size, media type, and digest information. func (b *blob) exists(ctx context.Context) (bool, error) { hash := sha256.Sum256(b.data) b.descriptor.Size = int64(len(b.data)) @@ -981,6 +991,9 @@ func (b *blob) exists(ctx context.Context) (bool, error) { return b.dst.Exists(ctx, b.descriptor) } +// pushNew checks if the blob exists in the registry first, and only pushes +// if it doesn't exist. This avoids redundant uploads for blobs that are +// already present. Any error is stored in b.err. func (b *blob) pushNew(ctx context.Context) { var exists bool exists, b.err = b.exists(ctx) @@ -993,6 +1006,8 @@ func (b *blob) pushNew(ctx context.Context) { b.descriptor, b.err = oras.PushBytes(ctx, b.dst, b.mediaType, b.data) } +// push unconditionally pushes the blob to the registry without checking +// for existence first. Any error is stored in b.err. func (b *blob) push(ctx context.Context) { b.descriptor, b.err = oras.PushBytes(ctx, b.dst, b.mediaType, b.data) } diff --git a/pkg/registry/client_test.go b/pkg/registry/client_test.go index 98a8b2ea3..223d5b96b 100644 --- a/pkg/registry/client_test.go +++ b/pkg/registry/client_test.go @@ -17,11 +17,15 @@ limitations under the License. package registry import ( + "crypto/sha256" + "fmt" "io" "net/http" "net/http/httptest" + "os" "path/filepath" "strings" + "sync" "testing" ocispec "github.com/opencontainers/image-spec/specs-go/v1" @@ -166,3 +170,98 @@ func TestWarnIfHostHasPath(t *testing.T) { }) } } + +// TestPushConcurrent verifies that concurrent Push operations on the same Client +// do not interfere with each other. This test is designed to catch race conditions +// when run with -race flag. +func TestPushConcurrent(t *testing.T) { + t.Parallel() + + // Create a mock registry server that accepts pushes + var mu sync.Mutex + uploads := make(map[string][]byte) + var uploadCounter int + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodHead && strings.Contains(r.URL.Path, "/blobs/"): + // Blob existence check - return 404 to force upload + w.WriteHeader(http.StatusNotFound) + + case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/blobs/uploads/"): + // Start upload - return upload URL with unique ID + mu.Lock() + uploadCounter++ + uploadID := fmt.Sprintf("upload-%d", uploadCounter) + mu.Unlock() + w.Header().Set("Location", fmt.Sprintf("%s%s", r.URL.Path, uploadID)) + w.WriteHeader(http.StatusAccepted) + + case r.Method == http.MethodPut && strings.Contains(r.URL.Path, "/blobs/uploads/"): + // Complete upload - extract digest from query param + body, _ := io.ReadAll(r.Body) + digest := r.URL.Query().Get("digest") + mu.Lock() + uploads[r.URL.Path] = body + mu.Unlock() + w.Header().Set("Docker-Content-Digest", digest) + w.WriteHeader(http.StatusCreated) + + case r.Method == http.MethodPut && strings.Contains(r.URL.Path, "/manifests/"): + // Manifest push - compute actual sha256 digest of the body + body, _ := io.ReadAll(r.Body) + hash := sha256.Sum256(body) + digest := fmt.Sprintf("sha256:%x", hash) + w.Header().Set("Docker-Content-Digest", digest) + w.WriteHeader(http.StatusCreated) + + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + host := strings.TrimPrefix(srv.URL, "http://") + + // Create client + credFile := filepath.Join(t.TempDir(), "config.json") + client, err := NewClient( + ClientOptWriter(io.Discard), + ClientOptCredentialsFile(credFile), + ClientOptPlainHTTP(), + ) + require.NoError(t, err) + + // Load test chart + chartData, err := os.ReadFile("../downloader/testdata/local-subchart-0.1.0.tgz") + require.NoError(t, err, "no error loading test chart") + + meta, err := extractChartMeta(chartData) + require.NoError(t, err, "no error extracting chart meta") + + // Run concurrent pushes + const numGoroutines = 10 + var wg sync.WaitGroup + errs := make(chan error, numGoroutines) + + for i := range numGoroutines { + wg.Add(1) + go func(idx int) { + defer wg.Done() + // Each goroutine pushes to a different tag to avoid conflicts + ref := fmt.Sprintf("%s/testrepo/%s:%s-%d", host, meta.Name, meta.Version, idx) + _, err := client.Push(chartData, ref, PushOptStrictMode(false)) + if err != nil { + errs <- fmt.Errorf("goroutine %d: %w", idx, err) + } + }(i) + } + + wg.Wait() + close(errs) + + // Check for errors + for err := range errs { + t.Error(err) + } +}