diff --git a/pkg/kube/client.go b/pkg/kube/client.go index 4442a0d91..68f1e6475 100644 --- a/pkg/kube/client.go +++ b/pkg/kube/client.go @@ -84,6 +84,11 @@ type Client struct { // Namespace allows to bypass the kubeconfig file for the choice of the namespace Namespace string + // WaitContext is an optional context to use for wait operations. + // If not set, a context will be created internally using the + // timeout provided to the wait functions. + WaitContext context.Context + Waiter kubeClient kubernetes.Interface @@ -140,6 +145,7 @@ func (c *Client) newStatusWatcher() (*statusWaiter, error) { return &statusWaiter{ restMapper: restMapper, client: dynamicClient, + ctx: c.WaitContext, }, nil } @@ -150,7 +156,7 @@ func (c *Client) GetWaiter(strategy WaitStrategy) (Waiter, error) { if err != nil { return nil, err } - return &legacyWaiter{kubeClient: kc}, nil + return &legacyWaiter{kubeClient: kc, ctx: c.WaitContext}, nil case StatusWatcherStrategy: return c.newStatusWatcher() case HookOnlyStrategy: diff --git a/pkg/kube/client_test.go b/pkg/kube/client_test.go index 3934171be..d49e179e0 100644 --- a/pkg/kube/client_test.go +++ b/pkg/kube/client_test.go @@ -18,6 +18,7 @@ package kube import ( "bytes" + "context" "errors" "fmt" "io" @@ -1798,3 +1799,387 @@ func TestDetermineFieldValidationDirective(t *testing.T) { assert.Equal(t, FieldValidationDirectiveIgnore, determineFieldValidationDirective(false)) assert.Equal(t, FieldValidationDirectiveStrict, determineFieldValidationDirective(true)) } + +func TestClientWaitContextCancellationLegacy(t *testing.T) { + podList := newPodList("starfish", "otter") + + ctx, cancel := context.WithCancel(t.Context()) + + c := newTestClient(t) + c.WaitContext = ctx + + requestCount := 0 + c.Factory.(*cmdtesting.TestFactory).Client = &fake.RESTClient{ + NegotiatedSerializer: unstructuredSerializer, + Client: fake.CreateHTTPClient(func(req *http.Request) (*http.Response, error) { + requestCount++ + p, m := req.URL.Path, req.Method + t.Logf("got request %s %s", p, m) + + if requestCount == 2 { + cancel() + } + + switch { + case p == "/api/v1/namespaces/default/pods/starfish" && m == http.MethodGet: + pod := &podList.Items[0] + pod.Status.Conditions = []v1.PodCondition{ + { + Type: v1.PodReady, + Status: v1.ConditionFalse, + }, + } + return newResponse(http.StatusOK, pod) + case p == "/api/v1/namespaces/default/pods/otter" && m == http.MethodGet: + pod := &podList.Items[1] + pod.Status.Conditions = []v1.PodCondition{ + { + Type: v1.PodReady, + Status: v1.ConditionFalse, + }, + } + return newResponse(http.StatusOK, pod) + case p == "/namespaces/default/pods" && m == http.MethodPost: + resources, err := c.Build(req.Body, false) + if err != nil { + t.Fatal(err) + } + return newResponse(http.StatusOK, resources[0].Object) + default: + t.Logf("unexpected request: %s %s", req.Method, req.URL.Path) + return newResponse(http.StatusNotFound, notFoundBody()) + } + }), + } + + var err error + c.Waiter, err = c.GetWaiter(LegacyStrategy) + require.NoError(t, err) + + resources, err := c.Build(objBody(&podList), false) + require.NoError(t, err) + + result, err := c.Create( + resources, + ClientCreateOptionServerSideApply(false, false)) + require.NoError(t, err) + assert.Len(t, result.Created, 2, "expected 2 resources created, got %d", len(result.Created)) + + err = c.Wait(resources, time.Second*30) + require.Error(t, err) + assert.Contains(t, err.Error(), "context canceled", "expected context canceled error, got: %v", err) +} + +func TestClientWaitWithJobsContextCancellationLegacy(t *testing.T) { + job := newJob("starfish", 0, intToInt32(1), 0, 0) + + ctx, cancel := context.WithCancel(t.Context()) + + c := newTestClient(t) + c.WaitContext = ctx + + requestCount := 0 + c.Factory.(*cmdtesting.TestFactory).Client = &fake.RESTClient{ + NegotiatedSerializer: unstructuredSerializer, + Client: fake.CreateHTTPClient(func(req *http.Request) (*http.Response, error) { + requestCount++ + p, m := req.URL.Path, req.Method + t.Logf("got request %s %s", p, m) + + if requestCount == 2 { + cancel() + } + + switch { + case p == "/apis/batch/v1/namespaces/default/jobs/starfish" && m == http.MethodGet: + job.Status.Succeeded = 0 + return newResponse(http.StatusOK, job) + case p == "/namespaces/default/jobs" && m == http.MethodPost: + resources, err := c.Build(req.Body, false) + if err != nil { + t.Fatal(err) + } + return newResponse(http.StatusOK, resources[0].Object) + default: + t.Logf("unexpected request: %s %s", req.Method, req.URL.Path) + return newResponse(http.StatusNotFound, notFoundBody()) + } + }), + } + + var err error + c.Waiter, err = c.GetWaiter(LegacyStrategy) + require.NoError(t, err) + + resources, err := c.Build(objBody(job), false) + require.NoError(t, err) + + result, err := c.Create( + resources, + ClientCreateOptionServerSideApply(false, false)) + require.NoError(t, err) + assert.Len(t, result.Created, 1, "expected 1 resource created, got %d", len(result.Created)) + + err = c.WaitWithJobs(resources, time.Second*30) + require.Error(t, err) + assert.Contains(t, err.Error(), "context canceled", "expected context canceled error, got: %v", err) +} + +func TestClientWaitForDeleteContextCancellationLegacy(t *testing.T) { + pod := newPod("starfish") + + ctx, cancel := context.WithCancel(t.Context()) + + c := newTestClient(t) + c.WaitContext = ctx + + deleted := false + requestCount := 0 + c.Factory.(*cmdtesting.TestFactory).Client = &fake.RESTClient{ + NegotiatedSerializer: unstructuredSerializer, + Client: fake.CreateHTTPClient(func(req *http.Request) (*http.Response, error) { + requestCount++ + p, m := req.URL.Path, req.Method + t.Logf("got request %s %s", p, m) + + if requestCount == 3 { + cancel() + } + + switch { + case p == "/namespaces/default/pods/starfish" && m == http.MethodGet: + if deleted { + return newResponse(http.StatusOK, &pod) + } + return newResponse(http.StatusOK, &pod) + case p == "/namespaces/default/pods/starfish" && m == http.MethodDelete: + deleted = true + return newResponse(http.StatusOK, &pod) + case p == "/namespaces/default/pods" && m == http.MethodPost: + resources, err := c.Build(req.Body, false) + if err != nil { + t.Fatal(err) + } + return newResponse(http.StatusOK, resources[0].Object) + default: + t.Logf("unexpected request: %s %s", req.Method, req.URL.Path) + return newResponse(http.StatusNotFound, notFoundBody()) + } + }), + } + + var err error + c.Waiter, err = c.GetWaiter(LegacyStrategy) + require.NoError(t, err) + + resources, err := c.Build(objBody(&pod), false) + require.NoError(t, err) + + result, err := c.Create( + resources, + ClientCreateOptionServerSideApply(false, false)) + require.NoError(t, err) + assert.Len(t, result.Created, 1, "expected 1 resource created, got %d", len(result.Created)) + + if _, err := c.Delete(resources, metav1.DeletePropagationBackground); err != nil { + t.Fatal(err) + } + + err = c.WaitForDelete(resources, time.Second*30) + require.Error(t, err) + assert.Contains(t, err.Error(), "context canceled", "expected context canceled error, got: %v", err) +} + +func TestClientWaitContextNilDoesNotPanic(t *testing.T) { + podList := newPodList("starfish") + + var created *time.Time + + c := newTestClient(t) + c.WaitContext = nil + + c.Factory.(*cmdtesting.TestFactory).Client = &fake.RESTClient{ + NegotiatedSerializer: unstructuredSerializer, + Client: fake.CreateHTTPClient(func(req *http.Request) (*http.Response, error) { + p, m := req.URL.Path, req.Method + t.Logf("got request %s %s", p, m) + switch { + case p == "/api/v1/namespaces/default/pods/starfish" && m == http.MethodGet: + pod := &podList.Items[0] + if created != nil && time.Since(*created) >= time.Second*2 { + pod.Status.Conditions = []v1.PodCondition{ + { + Type: v1.PodReady, + Status: v1.ConditionTrue, + }, + } + } + return newResponse(http.StatusOK, pod) + case p == "/namespaces/default/pods" && m == http.MethodPost: + resources, err := c.Build(req.Body, false) + if err != nil { + t.Fatal(err) + } + now := time.Now() + created = &now + return newResponse(http.StatusOK, resources[0].Object) + default: + t.Fatalf("unexpected request: %s %s", req.Method, req.URL.Path) + return nil, nil + } + }), + } + + var err error + c.Waiter, err = c.GetWaiter(LegacyStrategy) + require.NoError(t, err) + + resources, err := c.Build(objBody(&podList), false) + require.NoError(t, err) + + result, err := c.Create( + resources, + ClientCreateOptionServerSideApply(false, false)) + require.NoError(t, err) + assert.Len(t, result.Created, 1, "expected 1 resource created, got %d", len(result.Created)) + + err = c.Wait(resources, time.Second*30) + require.NoError(t, err) + + assert.GreaterOrEqual(t, time.Since(*created), time.Second*2, "expected to wait at least 2 seconds") +} + +func TestClientWaitContextPreCancelledLegacy(t *testing.T) { + podList := newPodList("starfish") + + ctx, cancel := context.WithCancel(t.Context()) + cancel() + + c := newTestClient(t) + c.WaitContext = ctx + + c.Factory.(*cmdtesting.TestFactory).Client = &fake.RESTClient{ + NegotiatedSerializer: unstructuredSerializer, + Client: fake.CreateHTTPClient(func(req *http.Request) (*http.Response, error) { + p, m := req.URL.Path, req.Method + t.Logf("got request %s %s", p, m) + switch { + case p == "/api/v1/namespaces/default/pods/starfish" && m == http.MethodGet: + pod := &podList.Items[0] + return newResponse(http.StatusOK, pod) + case p == "/namespaces/default/pods" && m == http.MethodPost: + resources, err := c.Build(req.Body, false) + if err != nil { + t.Fatal(err) + } + return newResponse(http.StatusOK, resources[0].Object) + default: + t.Fatalf("unexpected request: %s %s", req.Method, req.URL.Path) + return nil, nil + } + }), + } + + var err error + c.Waiter, err = c.GetWaiter(LegacyStrategy) + require.NoError(t, err) + + resources, err := c.Build(objBody(&podList), false) + require.NoError(t, err) + + result, err := c.Create( + resources, + ClientCreateOptionServerSideApply(false, false)) + require.NoError(t, err) + assert.Len(t, result.Created, 1, "expected 1 resource created, got %d", len(result.Created)) + + err = c.Wait(resources, time.Second*30) + require.Error(t, err) + assert.Contains(t, err.Error(), "context canceled", "expected context canceled error, got: %v", err) +} + +func TestClientWaitContextCancellationStatusWatcher(t *testing.T) { + ctx, cancel := context.WithCancel(t.Context()) + + c := newTestClient(t) + c.WaitContext = ctx + + podManifest := ` +apiVersion: v1 +kind: Pod +metadata: + name: test-pod + namespace: default +` + var err error + c.Waiter, err = c.GetWaiter(StatusWatcherStrategy) + require.NoError(t, err) + + resources, err := c.Build(strings.NewReader(podManifest), false) + require.NoError(t, err) + + cancel() + + err = c.Wait(resources, time.Second*30) + require.Error(t, err) + assert.Contains(t, err.Error(), "context canceled", "expected context canceled error, got: %v", err) +} + +func TestClientWaitWithJobsContextCancellationStatusWatcher(t *testing.T) { + ctx, cancel := context.WithCancel(t.Context()) + + c := newTestClient(t) + c.WaitContext = ctx + + jobManifest := ` +apiVersion: batch/v1 +kind: Job +metadata: + name: test-job + namespace: default +` + var err error + c.Waiter, err = c.GetWaiter(StatusWatcherStrategy) + require.NoError(t, err) + + resources, err := c.Build(strings.NewReader(jobManifest), false) + require.NoError(t, err) + + cancel() + + err = c.WaitWithJobs(resources, time.Second*30) + require.Error(t, err) + assert.Contains(t, err.Error(), "context canceled", "expected context canceled error, got: %v", err) +} + +func TestClientWaitForDeleteContextCancellationStatusWatcher(t *testing.T) { + ctx, cancel := context.WithCancel(t.Context()) + + c := newTestClient(t) + c.WaitContext = ctx + + podManifest := ` +apiVersion: v1 +kind: Pod +metadata: + name: test-pod + namespace: default +status: + conditions: + - type: Ready + status: "True" + phase: Running +` + var err error + c.Waiter, err = c.GetWaiter(StatusWatcherStrategy) + require.NoError(t, err) + + resources, err := c.Build(strings.NewReader(podManifest), false) + require.NoError(t, err) + + cancel() + + err = c.WaitForDelete(resources, time.Second*30) + require.Error(t, err) + assert.Contains(t, err.Error(), "context canceled", "expected context canceled error, got: %v", err) +} diff --git a/pkg/kube/statuswait.go b/pkg/kube/statuswait.go index 2d7cfe971..cd9722eda 100644 --- a/pkg/kube/statuswait.go +++ b/pkg/kube/statuswait.go @@ -36,6 +36,7 @@ import ( "k8s.io/apimachinery/pkg/api/meta" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/client-go/dynamic" + watchtools "k8s.io/client-go/tools/watch" helmStatusReaders "helm.sh/helm/v4/internal/statusreaders" ) @@ -43,6 +44,7 @@ import ( type statusWaiter struct { client dynamic.Interface restMapper meta.RESTMapper + ctx context.Context } func alwaysReady(_ *unstructured.Unstructured) (*status.Result, error) { @@ -53,7 +55,7 @@ func alwaysReady(_ *unstructured.Unstructured) (*status.Result, error) { } func (w *statusWaiter) WatchUntilReady(resourceList ResourceList, timeout time.Duration) error { - ctx, cancel := context.WithTimeout(context.Background(), timeout) + ctx, cancel := w.contextWithTimeout(timeout) defer cancel() slog.Debug("waiting for resources", "count", len(resourceList), "timeout", timeout) sw := watcher.NewDefaultStatusWatcher(w.client, w.restMapper) @@ -74,7 +76,7 @@ func (w *statusWaiter) WatchUntilReady(resourceList ResourceList, timeout time.D } func (w *statusWaiter) Wait(resourceList ResourceList, timeout time.Duration) error { - ctx, cancel := context.WithTimeout(context.TODO(), timeout) + ctx, cancel := w.contextWithTimeout(timeout) defer cancel() slog.Debug("waiting for resources", "count", len(resourceList), "timeout", timeout) sw := watcher.NewDefaultStatusWatcher(w.client, w.restMapper) @@ -82,7 +84,7 @@ func (w *statusWaiter) Wait(resourceList ResourceList, timeout time.Duration) er } func (w *statusWaiter) WaitWithJobs(resourceList ResourceList, timeout time.Duration) error { - ctx, cancel := context.WithTimeout(context.TODO(), timeout) + ctx, cancel := w.contextWithTimeout(timeout) defer cancel() slog.Debug("waiting for resources", "count", len(resourceList), "timeout", timeout) sw := watcher.NewDefaultStatusWatcher(w.client, w.restMapper) @@ -93,7 +95,7 @@ func (w *statusWaiter) WaitWithJobs(resourceList ResourceList, timeout time.Dura } func (w *statusWaiter) WaitForDelete(resourceList ResourceList, timeout time.Duration) error { - ctx, cancel := context.WithTimeout(context.TODO(), timeout) + ctx, cancel := w.contextWithTimeout(timeout) defer cancel() slog.Debug("waiting for resources to be deleted", "count", len(resourceList), "timeout", timeout) sw := watcher.NewDefaultStatusWatcher(w.client, w.restMapper) @@ -179,6 +181,17 @@ func (w *statusWaiter) wait(ctx context.Context, resourceList ResourceList, sw w return nil } +func (w *statusWaiter) contextWithTimeout(timeout time.Duration) (context.Context, context.CancelFunc) { + return contextWithTimeout(w.ctx, timeout) +} + +func contextWithTimeout(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { + if ctx == nil { + ctx = context.Background() + } + return watchtools.ContextWithOptionalTimeout(ctx, timeout) +} + func statusObserver(cancel context.CancelFunc, desired status.Status) collector.ObserverFunc { return func(statusCollector *collector.ResourceStatusCollector, _ event.Event) { var rss []*event.ResourceStatus diff --git a/pkg/kube/wait.go b/pkg/kube/wait.go index 9bfa1ef6d..f776ae471 100644 --- a/pkg/kube/wait.go +++ b/pkg/kube/wait.go @@ -49,6 +49,7 @@ import ( type legacyWaiter struct { c ReadyChecker kubeClient *kubernetes.Clientset + ctx context.Context } func (hw *legacyWaiter) Wait(resources ResourceList, timeout time.Duration) error { @@ -66,7 +67,7 @@ func (hw *legacyWaiter) WaitWithJobs(resources ResourceList, timeout time.Durati func (hw *legacyWaiter) waitForResources(created ResourceList, timeout time.Duration) error { slog.Debug("beginning wait for resources", "count", len(created), "timeout", timeout) - ctx, cancel := context.WithTimeout(context.Background(), timeout) + ctx, cancel := hw.contextWithTimeout(timeout) defer cancel() numberOfErrors := make([]int, len(created)) @@ -121,7 +122,7 @@ func (hw *legacyWaiter) WaitForDelete(deleted ResourceList, timeout time.Duratio slog.Debug("beginning wait for resources to be deleted", "count", len(deleted), "timeout", timeout) startTime := time.Now() - ctx, cancel := context.WithTimeout(context.Background(), timeout) + ctx, cancel := hw.contextWithTimeout(timeout) defer cancel() err := wait.PollUntilContextCancel(ctx, 2*time.Second, true, func(_ context.Context) (bool, error) { @@ -246,7 +247,7 @@ func (hw *legacyWaiter) watchUntilReady(timeout time.Duration, info *resource.In // In the future, we might want to add some special logic for types // like Ingress, Volume, etc. - ctx, cancel := watchtools.ContextWithOptionalTimeout(context.Background(), timeout) + ctx, cancel := hw.contextWithTimeout(timeout) defer cancel() _, err = watchtools.UntilWithSync(ctx, lw, &unstructured.Unstructured{}, nil, func(e watch.Event) (bool, error) { // Make sure the incoming object is versioned as we use unstructured @@ -327,3 +328,7 @@ func (hw *legacyWaiter) waitForPodSuccess(obj runtime.Object, name string) (bool return false, nil } + +func (hw *legacyWaiter) contextWithTimeout(timeout time.Duration) (context.Context, context.CancelFunc) { + return contextWithTimeout(hw.ctx, timeout) +}