From b0b35f1231b0b885b1624c5586938cfa69d30995 Mon Sep 17 00:00:00 2001 From: Matheus Pimenta Date: Thu, 15 Jan 2026 14:53:15 +0000 Subject: [PATCH] feat(kstatus): fine-grained context options for waiting Signed-off-by: Matheus Pimenta --- pkg/kube/client.go | 12 +- pkg/kube/options.go | 41 +++- pkg/kube/statuswait.go | 27 ++- pkg/kube/statuswait_test.go | 469 ++++++++++++++++++++++++++++++++++++ 4 files changed, 533 insertions(+), 16 deletions(-) diff --git a/pkg/kube/client.go b/pkg/kube/client.go index 9af4bbcb3..ad8f851d1 100644 --- a/pkg/kube/client.go +++ b/pkg/kube/client.go @@ -167,10 +167,14 @@ func (c *Client) newStatusWatcher(opts ...WaitOption) (*statusWaiter, error) { waitContext = c.WaitContext } return &statusWaiter{ - restMapper: restMapper, - client: dynamicClient, - ctx: waitContext, - readers: o.statusReaders, + restMapper: restMapper, + client: dynamicClient, + ctx: waitContext, + watchUntilReadyCtx: o.watchUntilReadyCtx, + waitCtx: o.waitCtx, + waitWithJobsCtx: o.waitWithJobsCtx, + waitForDeleteCtx: o.waitForDeleteCtx, + readers: o.statusReaders, }, nil } diff --git a/pkg/kube/options.go b/pkg/kube/options.go index 49c6229ba..3326c284b 100644 --- a/pkg/kube/options.go +++ b/pkg/kube/options.go @@ -26,12 +26,45 @@ import ( type WaitOption func(*waitOptions) // WithWaitContext sets the context for waiting on resources. +// If unset, context.Background() will be used. func WithWaitContext(ctx context.Context) WaitOption { return func(wo *waitOptions) { wo.ctx = ctx } } +// WithWatchUntilReadyMethodContext sets the context specifically for the WatchUntilReady method. +// If unset, the context set by `WithWaitContext` will be used (falling back to `context.Background()`). +func WithWatchUntilReadyMethodContext(ctx context.Context) WaitOption { + return func(wo *waitOptions) { + wo.watchUntilReadyCtx = ctx + } +} + +// WithWaitMethodContext sets the context specifically for the Wait method. +// If unset, the context set by `WithWaitContext` will be used (falling back to `context.Background()`). +func WithWaitMethodContext(ctx context.Context) WaitOption { + return func(wo *waitOptions) { + wo.waitCtx = ctx + } +} + +// WithWaitWithJobsMethodContext sets the context specifically for the WaitWithJobs method. +// If unset, the context set by `WithWaitContext` will be used (falling back to `context.Background()`). +func WithWaitWithJobsMethodContext(ctx context.Context) WaitOption { + return func(wo *waitOptions) { + wo.waitWithJobsCtx = ctx + } +} + +// WithWaitForDeleteMethodContext sets the context specifically for the WaitForDelete method. +// If unset, the context set by `WithWaitContext` will be used (falling back to `context.Background()`). +func WithWaitForDeleteMethodContext(ctx context.Context) WaitOption { + return func(wo *waitOptions) { + wo.waitForDeleteCtx = ctx + } +} + // WithKStatusReaders sets the status readers to be used while waiting on resources. func WithKStatusReaders(readers ...engine.StatusReader) WaitOption { return func(wo *waitOptions) { @@ -40,6 +73,10 @@ func WithKStatusReaders(readers ...engine.StatusReader) WaitOption { } type waitOptions struct { - ctx context.Context - statusReaders []engine.StatusReader + ctx context.Context + watchUntilReadyCtx context.Context + waitCtx context.Context + waitWithJobsCtx context.Context + waitForDeleteCtx context.Context + statusReaders []engine.StatusReader } diff --git a/pkg/kube/statuswait.go b/pkg/kube/statuswait.go index bd6e4f93a..84819492b 100644 --- a/pkg/kube/statuswait.go +++ b/pkg/kube/statuswait.go @@ -42,10 +42,14 @@ import ( ) type statusWaiter struct { - client dynamic.Interface - restMapper meta.RESTMapper - ctx context.Context - readers []engine.StatusReader + client dynamic.Interface + restMapper meta.RESTMapper + ctx context.Context + watchUntilReadyCtx context.Context + waitCtx context.Context + waitWithJobsCtx context.Context + waitForDeleteCtx context.Context + readers []engine.StatusReader } // DefaultStatusWatcherTimeout is the timeout used by the status waiter when a @@ -66,7 +70,7 @@ func (w *statusWaiter) WatchUntilReady(resourceList ResourceList, timeout time.D if timeout == 0 { timeout = DefaultStatusWatcherTimeout } - ctx, cancel := w.contextWithTimeout(timeout) + ctx, cancel := w.contextWithTimeout(w.watchUntilReadyCtx, timeout) defer cancel() slog.Debug("waiting for resources", "count", len(resourceList), "timeout", timeout) sw := watcher.NewDefaultStatusWatcher(w.client, w.restMapper) @@ -88,7 +92,7 @@ func (w *statusWaiter) Wait(resourceList ResourceList, timeout time.Duration) er if timeout == 0 { timeout = DefaultStatusWatcherTimeout } - ctx, cancel := w.contextWithTimeout(timeout) + ctx, cancel := w.contextWithTimeout(w.waitCtx, timeout) defer cancel() slog.Debug("waiting for resources", "count", len(resourceList), "timeout", timeout) sw := watcher.NewDefaultStatusWatcher(w.client, w.restMapper) @@ -100,7 +104,7 @@ func (w *statusWaiter) WaitWithJobs(resourceList ResourceList, timeout time.Dura if timeout == 0 { timeout = DefaultStatusWatcherTimeout } - ctx, cancel := w.contextWithTimeout(timeout) + ctx, cancel := w.contextWithTimeout(w.waitWithJobsCtx, timeout) defer cancel() slog.Debug("waiting for resources", "count", len(resourceList), "timeout", timeout) sw := watcher.NewDefaultStatusWatcher(w.client, w.restMapper) @@ -116,7 +120,7 @@ func (w *statusWaiter) WaitForDelete(resourceList ResourceList, timeout time.Dur if timeout == 0 { timeout = DefaultStatusWatcherTimeout } - ctx, cancel := w.contextWithTimeout(timeout) + ctx, cancel := w.contextWithTimeout(w.waitForDeleteCtx, timeout) defer cancel() slog.Debug("waiting for resources to be deleted", "count", len(resourceList), "timeout", timeout) sw := watcher.NewDefaultStatusWatcher(w.client, w.restMapper) @@ -210,8 +214,11 @@ func (w *statusWaiter) wait(ctx context.Context, resourceList ResourceList, sw w return nil } -func (w *statusWaiter) contextWithTimeout(timeout time.Duration) (context.Context, context.CancelFunc) { - return contextWithTimeout(w.ctx, timeout) +func (w *statusWaiter) contextWithTimeout(methodCtx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { + if methodCtx == nil { + methodCtx = w.ctx + } + return contextWithTimeout(methodCtx, timeout) } func contextWithTimeout(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { diff --git a/pkg/kube/statuswait_test.go b/pkg/kube/statuswait_test.go index 0d60a526f..83aaa357a 100644 --- a/pkg/kube/statuswait_test.go +++ b/pkg/kube/statuswait_test.go @@ -1246,6 +1246,475 @@ func TestStatusWaitWithFailedResources(t *testing.T) { } } +func TestWaitOptionFunctions(t *testing.T) { + t.Parallel() + + t.Run("WithWatchUntilReadyMethodContext sets watchUntilReadyCtx", func(t *testing.T) { + t.Parallel() + type contextKey struct{} + ctx := context.WithValue(context.Background(), contextKey{}, "test") + opts := &waitOptions{} + WithWatchUntilReadyMethodContext(ctx)(opts) + assert.Equal(t, ctx, opts.watchUntilReadyCtx) + }) + + t.Run("WithWaitMethodContext sets waitCtx", func(t *testing.T) { + t.Parallel() + type contextKey struct{} + ctx := context.WithValue(context.Background(), contextKey{}, "test") + opts := &waitOptions{} + WithWaitMethodContext(ctx)(opts) + assert.Equal(t, ctx, opts.waitCtx) + }) + + t.Run("WithWaitWithJobsMethodContext sets waitWithJobsCtx", func(t *testing.T) { + t.Parallel() + type contextKey struct{} + ctx := context.WithValue(context.Background(), contextKey{}, "test") + opts := &waitOptions{} + WithWaitWithJobsMethodContext(ctx)(opts) + assert.Equal(t, ctx, opts.waitWithJobsCtx) + }) + + t.Run("WithWaitForDeleteMethodContext sets waitForDeleteCtx", func(t *testing.T) { + t.Parallel() + type contextKey struct{} + ctx := context.WithValue(context.Background(), contextKey{}, "test") + opts := &waitOptions{} + WithWaitForDeleteMethodContext(ctx)(opts) + assert.Equal(t, ctx, opts.waitForDeleteCtx) + }) +} + +func TestMethodSpecificContextCancellation(t *testing.T) { + t.Parallel() + + t.Run("WatchUntilReady uses method-specific context", func(t *testing.T) { + t.Parallel() + c := newTestClient(t) + fakeClient := dynamicfake.NewSimpleDynamicClient(scheme.Scheme) + fakeMapper := testutil.NewFakeRESTMapper( + v1.SchemeGroupVersion.WithKind("Pod"), + ) + + // Create a cancelled method-specific context + methodCtx, methodCancel := context.WithCancel(context.Background()) + methodCancel() // Cancel immediately + + sw := statusWaiter{ + client: fakeClient, + restMapper: fakeMapper, + ctx: context.Background(), // General context is not cancelled + watchUntilReadyCtx: methodCtx, // Method context is cancelled + } + + objs := getRuntimeObjFromManifests(t, []string{podCompleteManifest}) + for _, obj := range objs { + u := obj.(*unstructured.Unstructured) + gvr := getGVR(t, fakeMapper, u) + err := fakeClient.Tracker().Create(gvr, u, u.GetNamespace()) + require.NoError(t, err) + } + resourceList := getResourceListFromRuntimeObjs(t, c, objs) + + err := sw.WatchUntilReady(resourceList, time.Second*3) + // Should fail due to cancelled method context + require.Error(t, err) + assert.Contains(t, err.Error(), "context canceled") + }) + + t.Run("Wait uses method-specific context", func(t *testing.T) { + t.Parallel() + c := newTestClient(t) + fakeClient := dynamicfake.NewSimpleDynamicClient(scheme.Scheme) + fakeMapper := testutil.NewFakeRESTMapper( + v1.SchemeGroupVersion.WithKind("Pod"), + ) + + // Create a cancelled method-specific context + methodCtx, methodCancel := context.WithCancel(context.Background()) + methodCancel() // Cancel immediately + + sw := statusWaiter{ + client: fakeClient, + restMapper: fakeMapper, + ctx: context.Background(), // General context is not cancelled + waitCtx: methodCtx, // Method context is cancelled + } + + objs := getRuntimeObjFromManifests(t, []string{podCurrentManifest}) + for _, obj := range objs { + u := obj.(*unstructured.Unstructured) + gvr := getGVR(t, fakeMapper, u) + err := fakeClient.Tracker().Create(gvr, u, u.GetNamespace()) + require.NoError(t, err) + } + resourceList := getResourceListFromRuntimeObjs(t, c, objs) + + err := sw.Wait(resourceList, time.Second*3) + // Should fail due to cancelled method context + require.Error(t, err) + assert.Contains(t, err.Error(), "context canceled") + }) + + t.Run("WaitWithJobs uses method-specific context", func(t *testing.T) { + t.Parallel() + c := newTestClient(t) + fakeClient := dynamicfake.NewSimpleDynamicClient(scheme.Scheme) + fakeMapper := testutil.NewFakeRESTMapper( + batchv1.SchemeGroupVersion.WithKind("Job"), + ) + + // Create a cancelled method-specific context + methodCtx, methodCancel := context.WithCancel(context.Background()) + methodCancel() // Cancel immediately + + sw := statusWaiter{ + client: fakeClient, + restMapper: fakeMapper, + ctx: context.Background(), // General context is not cancelled + waitWithJobsCtx: methodCtx, // Method context is cancelled + } + + objs := getRuntimeObjFromManifests(t, []string{jobCompleteManifest}) + for _, obj := range objs { + u := obj.(*unstructured.Unstructured) + gvr := getGVR(t, fakeMapper, u) + err := fakeClient.Tracker().Create(gvr, u, u.GetNamespace()) + require.NoError(t, err) + } + resourceList := getResourceListFromRuntimeObjs(t, c, objs) + + err := sw.WaitWithJobs(resourceList, time.Second*3) + // Should fail due to cancelled method context + require.Error(t, err) + assert.Contains(t, err.Error(), "context canceled") + }) + + t.Run("WaitForDelete uses method-specific context", func(t *testing.T) { + t.Parallel() + c := newTestClient(t) + fakeClient := dynamicfake.NewSimpleDynamicClient(scheme.Scheme) + fakeMapper := testutil.NewFakeRESTMapper( + v1.SchemeGroupVersion.WithKind("Pod"), + ) + + // Create a cancelled method-specific context + methodCtx, methodCancel := context.WithCancel(context.Background()) + methodCancel() // Cancel immediately + + sw := statusWaiter{ + client: fakeClient, + restMapper: fakeMapper, + ctx: context.Background(), // General context is not cancelled + waitForDeleteCtx: methodCtx, // Method context is cancelled + } + + objs := getRuntimeObjFromManifests(t, []string{podCurrentManifest}) + for _, obj := range objs { + u := obj.(*unstructured.Unstructured) + gvr := getGVR(t, fakeMapper, u) + err := fakeClient.Tracker().Create(gvr, u, u.GetNamespace()) + require.NoError(t, err) + } + resourceList := getResourceListFromRuntimeObjs(t, c, objs) + + err := sw.WaitForDelete(resourceList, time.Second*3) + // Should fail due to cancelled method context + require.Error(t, err) + assert.Contains(t, err.Error(), "context canceled") + }) +} + +func TestMethodContextFallbackToGeneralContext(t *testing.T) { + t.Parallel() + + t.Run("WatchUntilReady falls back to general context when method context is nil", func(t *testing.T) { + t.Parallel() + c := newTestClient(t) + fakeClient := dynamicfake.NewSimpleDynamicClient(scheme.Scheme) + fakeMapper := testutil.NewFakeRESTMapper( + v1.SchemeGroupVersion.WithKind("Pod"), + ) + + // Create a cancelled general context + generalCtx, generalCancel := context.WithCancel(context.Background()) + generalCancel() // Cancel immediately + + sw := statusWaiter{ + client: fakeClient, + restMapper: fakeMapper, + ctx: generalCtx, // General context is cancelled + watchUntilReadyCtx: nil, // Method context is nil, should fall back + } + + objs := getRuntimeObjFromManifests(t, []string{podCompleteManifest}) + for _, obj := range objs { + u := obj.(*unstructured.Unstructured) + gvr := getGVR(t, fakeMapper, u) + err := fakeClient.Tracker().Create(gvr, u, u.GetNamespace()) + require.NoError(t, err) + } + resourceList := getResourceListFromRuntimeObjs(t, c, objs) + + err := sw.WatchUntilReady(resourceList, time.Second*3) + // Should fail due to cancelled general context + require.Error(t, err) + assert.Contains(t, err.Error(), "context canceled") + }) + + t.Run("Wait falls back to general context when method context is nil", func(t *testing.T) { + t.Parallel() + c := newTestClient(t) + fakeClient := dynamicfake.NewSimpleDynamicClient(scheme.Scheme) + fakeMapper := testutil.NewFakeRESTMapper( + v1.SchemeGroupVersion.WithKind("Pod"), + ) + + // Create a cancelled general context + generalCtx, generalCancel := context.WithCancel(context.Background()) + generalCancel() // Cancel immediately + + sw := statusWaiter{ + client: fakeClient, + restMapper: fakeMapper, + ctx: generalCtx, // General context is cancelled + waitCtx: nil, // Method context is nil, should fall back + } + + objs := getRuntimeObjFromManifests(t, []string{podCurrentManifest}) + for _, obj := range objs { + u := obj.(*unstructured.Unstructured) + gvr := getGVR(t, fakeMapper, u) + err := fakeClient.Tracker().Create(gvr, u, u.GetNamespace()) + require.NoError(t, err) + } + resourceList := getResourceListFromRuntimeObjs(t, c, objs) + + err := sw.Wait(resourceList, time.Second*3) + // Should fail due to cancelled general context + require.Error(t, err) + assert.Contains(t, err.Error(), "context canceled") + }) + + t.Run("WaitWithJobs falls back to general context when method context is nil", func(t *testing.T) { + t.Parallel() + c := newTestClient(t) + fakeClient := dynamicfake.NewSimpleDynamicClient(scheme.Scheme) + fakeMapper := testutil.NewFakeRESTMapper( + batchv1.SchemeGroupVersion.WithKind("Job"), + ) + + // Create a cancelled general context + generalCtx, generalCancel := context.WithCancel(context.Background()) + generalCancel() // Cancel immediately + + sw := statusWaiter{ + client: fakeClient, + restMapper: fakeMapper, + ctx: generalCtx, // General context is cancelled + waitWithJobsCtx: nil, // Method context is nil, should fall back + } + + objs := getRuntimeObjFromManifests(t, []string{jobCompleteManifest}) + for _, obj := range objs { + u := obj.(*unstructured.Unstructured) + gvr := getGVR(t, fakeMapper, u) + err := fakeClient.Tracker().Create(gvr, u, u.GetNamespace()) + require.NoError(t, err) + } + resourceList := getResourceListFromRuntimeObjs(t, c, objs) + + err := sw.WaitWithJobs(resourceList, time.Second*3) + // Should fail due to cancelled general context + require.Error(t, err) + assert.Contains(t, err.Error(), "context canceled") + }) + + t.Run("WaitForDelete falls back to general context when method context is nil", func(t *testing.T) { + t.Parallel() + c := newTestClient(t) + fakeClient := dynamicfake.NewSimpleDynamicClient(scheme.Scheme) + fakeMapper := testutil.NewFakeRESTMapper( + v1.SchemeGroupVersion.WithKind("Pod"), + ) + + // Create a cancelled general context + generalCtx, generalCancel := context.WithCancel(context.Background()) + generalCancel() // Cancel immediately + + sw := statusWaiter{ + client: fakeClient, + restMapper: fakeMapper, + ctx: generalCtx, // General context is cancelled + waitForDeleteCtx: nil, // Method context is nil, should fall back + } + + objs := getRuntimeObjFromManifests(t, []string{podCurrentManifest}) + for _, obj := range objs { + u := obj.(*unstructured.Unstructured) + gvr := getGVR(t, fakeMapper, u) + err := fakeClient.Tracker().Create(gvr, u, u.GetNamespace()) + require.NoError(t, err) + } + resourceList := getResourceListFromRuntimeObjs(t, c, objs) + + err := sw.WaitForDelete(resourceList, time.Second*3) + // Should fail due to cancelled general context + require.Error(t, err) + assert.Contains(t, err.Error(), "context canceled") + }) +} + +func TestMethodContextOverridesGeneralContext(t *testing.T) { + t.Parallel() + + t.Run("method-specific context overrides general context for WatchUntilReady", func(t *testing.T) { + t.Parallel() + c := newTestClient(t) + fakeClient := dynamicfake.NewSimpleDynamicClient(scheme.Scheme) + fakeMapper := testutil.NewFakeRESTMapper( + v1.SchemeGroupVersion.WithKind("Pod"), + ) + + // General context is cancelled, but method context is not + generalCtx, generalCancel := context.WithCancel(context.Background()) + generalCancel() + + sw := statusWaiter{ + client: fakeClient, + restMapper: fakeMapper, + ctx: generalCtx, // Cancelled + watchUntilReadyCtx: context.Background(), // Not cancelled - should be used + } + + objs := getRuntimeObjFromManifests(t, []string{podCompleteManifest}) + for _, obj := range objs { + u := obj.(*unstructured.Unstructured) + gvr := getGVR(t, fakeMapper, u) + err := fakeClient.Tracker().Create(gvr, u, u.GetNamespace()) + require.NoError(t, err) + } + resourceList := getResourceListFromRuntimeObjs(t, c, objs) + + err := sw.WatchUntilReady(resourceList, time.Second*3) + // Should succeed because method context is used and it's not cancelled + assert.NoError(t, err) + }) + + t.Run("method-specific context overrides general context for Wait", func(t *testing.T) { + t.Parallel() + c := newTestClient(t) + fakeClient := dynamicfake.NewSimpleDynamicClient(scheme.Scheme) + fakeMapper := testutil.NewFakeRESTMapper( + v1.SchemeGroupVersion.WithKind("Pod"), + ) + + // General context is cancelled, but method context is not + generalCtx, generalCancel := context.WithCancel(context.Background()) + generalCancel() + + sw := statusWaiter{ + client: fakeClient, + restMapper: fakeMapper, + ctx: generalCtx, // Cancelled + waitCtx: context.Background(), // Not cancelled - should be used + } + + objs := getRuntimeObjFromManifests(t, []string{podCurrentManifest}) + for _, obj := range objs { + u := obj.(*unstructured.Unstructured) + gvr := getGVR(t, fakeMapper, u) + err := fakeClient.Tracker().Create(gvr, u, u.GetNamespace()) + require.NoError(t, err) + } + resourceList := getResourceListFromRuntimeObjs(t, c, objs) + + err := sw.Wait(resourceList, time.Second*3) + // Should succeed because method context is used and it's not cancelled + assert.NoError(t, err) + }) + + t.Run("method-specific context overrides general context for WaitWithJobs", func(t *testing.T) { + t.Parallel() + c := newTestClient(t) + fakeClient := dynamicfake.NewSimpleDynamicClient(scheme.Scheme) + fakeMapper := testutil.NewFakeRESTMapper( + batchv1.SchemeGroupVersion.WithKind("Job"), + ) + + // General context is cancelled, but method context is not + generalCtx, generalCancel := context.WithCancel(context.Background()) + generalCancel() + + sw := statusWaiter{ + client: fakeClient, + restMapper: fakeMapper, + ctx: generalCtx, // Cancelled + waitWithJobsCtx: context.Background(), // Not cancelled - should be used + } + + objs := getRuntimeObjFromManifests(t, []string{jobCompleteManifest}) + for _, obj := range objs { + u := obj.(*unstructured.Unstructured) + gvr := getGVR(t, fakeMapper, u) + err := fakeClient.Tracker().Create(gvr, u, u.GetNamespace()) + require.NoError(t, err) + } + resourceList := getResourceListFromRuntimeObjs(t, c, objs) + + err := sw.WaitWithJobs(resourceList, time.Second*3) + // Should succeed because method context is used and it's not cancelled + assert.NoError(t, err) + }) + + t.Run("method-specific context overrides general context for WaitForDelete", func(t *testing.T) { + t.Parallel() + c := newTestClient(t) + timeout := time.Second + timeUntilPodDelete := time.Millisecond * 500 + fakeClient := dynamicfake.NewSimpleDynamicClient(scheme.Scheme) + fakeMapper := testutil.NewFakeRESTMapper( + v1.SchemeGroupVersion.WithKind("Pod"), + ) + + // General context is cancelled, but method context is not + generalCtx, generalCancel := context.WithCancel(context.Background()) + generalCancel() + + sw := statusWaiter{ + client: fakeClient, + restMapper: fakeMapper, + ctx: generalCtx, // Cancelled + waitForDeleteCtx: context.Background(), // Not cancelled - should be used + } + + objs := getRuntimeObjFromManifests(t, []string{podCurrentManifest}) + for _, obj := range objs { + u := obj.(*unstructured.Unstructured) + gvr := getGVR(t, fakeMapper, u) + err := fakeClient.Tracker().Create(gvr, u, u.GetNamespace()) + require.NoError(t, err) + } + + // Schedule deletion + for _, obj := range objs { + u := obj.(*unstructured.Unstructured) + gvr := getGVR(t, fakeMapper, u) + go func(gvr schema.GroupVersionResource, u *unstructured.Unstructured) { + time.Sleep(timeUntilPodDelete) + err := fakeClient.Tracker().Delete(gvr, u.GetNamespace(), u.GetName()) + assert.NoError(t, err) + }(gvr, u) + } + + resourceList := getResourceListFromRuntimeObjs(t, c, objs) + err := sw.WaitForDelete(resourceList, timeout) + // Should succeed because method context is used and it's not cancelled + assert.NoError(t, err) + }) +} + func TestWatchUntilReadyWithCustomReaders(t *testing.T) { t.Parallel() tests := []struct {