diff --git a/internal/client/refreshing_transport.go b/internal/client/refreshing_transport.go
index 9567f24..cca08fa 100644
--- a/internal/client/refreshing_transport.go
+++ b/internal/client/refreshing_transport.go
@@ -1,7 +1,9 @@
 package client
 
 import (
+	"bytes"
 	"context"
+	"io"
 	"net/http"
 	"strings"
 	"sync"
@@ -89,11 +91,68 @@ func (t *refreshingAuthTransport) getToken(ctx context.Context) string {
 	return t.cachedToken
 }
 
+// forceRefresh fetches a new token regardless of the cache TTL. It is storm-safe:
+// if another goroutine already replaced staleToken (e.g. several concurrent
+// requests hit a 401 at once), the already-refreshed token is reused instead of
+// fetching again.
+func (t *refreshingAuthTransport) forceRefresh(ctx context.Context, staleToken string) string {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	if t.cachedToken != staleToken {
+		return t.cachedToken
+	}
+	newToken, err := t.tokenProvider()
+	if err != nil {
+		logger.L(ctx).Warn("forced token refresh failed after 401", logger.Err(err))
+		return ""
+	}
+	t.cachedToken = newToken
+	t.cachedAt = time.Now()
+	return newToken
+}
+
 func (t *refreshingAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
+	// Buffer the body so the request can be replayed if the first attempt 401s.
+	// http.NoBody is left alone: it is safe to re-read and keeps ContentLength 0
+	// meaning "no body" rather than "unknown length".
+	var bodyBytes []byte
+	if req.Body != nil && req.Body != http.NoBody {
+		b, err := io.ReadAll(req.Body)
+		// RoundTrip must close the body even when reading it fails.
+		req.Body.Close()
+		if err != nil {
+			return nil, err
+		}
+		bodyBytes = b
+	}
+
+	attempt := func(token string) (*http.Response, error) {
+		r := req.Clone(req.Context())
+		if bodyBytes != nil {
+			r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+			r.ContentLength = int64(len(bodyBytes))
+		}
+		r.Header.Set("Authorization", "Bearer "+token)
+		r.Header.Set("User-Agent", "Nullify-CLI/mcp")
+		return t.transport.RoundTrip(r)
+	}
+
 	token := t.getToken(req.Context())
+	resp, err := attempt(token)
+	if err != nil {
+		return nil, err
+	}
+
+	// The cached token can be invalid before its TTL elapses (revocation, server
+	// session kill, clock skew). On a 401, force a refresh and retry once before
+	// surfacing the failure.
+	if resp.StatusCode == http.StatusUnauthorized {
+		if newToken := t.forceRefresh(req.Context(), token); newToken != "" && newToken != token {
+			_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4<<10))
+			resp.Body.Close()
+			return attempt(newToken)
+		}
+	}
 
-	r := req.Clone(req.Context())
-	r.Header.Set("Authorization", "Bearer "+token)
-	r.Header.Set("User-Agent", "Nullify-CLI/mcp")
-	return t.transport.RoundTrip(r)
+	return resp, nil
 }
diff --git a/internal/client/refreshing_transport_test.go b/internal/client/refreshing_transport_test.go
new file mode 100644
index 0000000..3b6caac
--- /dev/null
+++ b/internal/client/refreshing_transport_test.go
@@ -0,0 +1,180 @@
+package client
+
+import (
+	"context"
+	"io"
+	"net/http"
+	"strings"
+	"sync"
+	"testing"
+	"time"
+)
+
+func newRefreshingTransport(initial string, tp TokenProvider, inner http.RoundTripper) *refreshingAuthTransport {
+	return &refreshingAuthTransport{
+		nullifyHost:   "acme.nullify.ai",
+		tokenProvider: tp,
+		transport:     inner,
+		cachedToken:   initial,
+		cachedAt:      time.Now(),
+		cacheTTL:      time.Hour, // keep getToken from refreshing on TTL during 401 tests
+	}
+}
+
+// bearerRouter returns the status mapped to the request's bearer token.
+func bearerRouter(byToken map[string]int) (http.RoundTripper, func() int) {
+	var mu sync.Mutex
+	calls := 0
+	rt := rtFunc(func(r *http.Request) (*http.Response, error) {
+		mu.Lock()
+		calls++
+		mu.Unlock()
+		tok := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
+		status, ok := byToken[tok]
+		if !ok {
+			status = http.StatusUnauthorized
+		}
+		return newResp(status, nil), nil
+	})
+	return rt, func() int { mu.Lock(); defer mu.Unlock(); return calls }
+}
+
+func TestRefreshOn401RetriesWithNewToken(t *testing.T) {
+	refreshCalls := 0
+	tp := func() (string, error) { refreshCalls++; return "fresh", nil }
+	inner, _ := bearerRouter(map[string]int{"stale": 401, "fresh": 200})
+	tr := newRefreshingTransport("stale", tp, inner)
+
+	req, _ := http.NewRequest(http.MethodGet, "http://x", nil)
+	resp, err := tr.RoundTrip(req)
+	if err != nil {
+		t.Fatalf("unexpected err: %v", err)
+	}
+	if resp.StatusCode != 200 {
+		t.Errorf("status = %d, want 200 after refresh", resp.StatusCode)
+	}
+	if refreshCalls != 1 {
+		t.Errorf("refreshCalls = %d, want 1", refreshCalls)
+	}
+}
+
+func TestNo401MeansNoRefresh(t *testing.T) {
+	refreshCalls := 0
+	tp := func() (string, error) { refreshCalls++; return "fresh", nil }
+	inner, calls := bearerRouter(map[string]int{"good": 200})
+	tr := newRefreshingTransport("good", tp, inner)
+
+	req, _ := http.NewRequest(http.MethodGet, "http://x", nil)
+	resp, err := tr.RoundTrip(req)
+	if err != nil {
+		t.Fatalf("unexpected err: %v", err)
+	}
+	if resp.StatusCode != 200 {
+		t.Errorf("status = %d, want 200", resp.StatusCode)
+	}
+	if refreshCalls != 0 {
+		t.Errorf("refreshCalls = %d, want 0 (no 401)", refreshCalls)
+	}
+	if calls() != 1 {
+		t.Errorf("inner calls = %d, want 1", calls())
+	}
+}
+
+func TestRefreshFailureSurfaces401(t *testing.T) {
+	tp := func() (string, error) { return "", io.ErrUnexpectedEOF }
+	inner, _ := bearerRouter(map[string]int{"stale": 401})
+	tr := newRefreshingTransport("stale", tp, inner)
+
+	req, _ := http.NewRequest(http.MethodGet, "http://x", nil)
+	resp, err := tr.RoundTrip(req)
+	if err != nil {
+		t.Fatalf("unexpected err: %v", err)
+	}
+	if resp.StatusCode != 401 {
+		t.Errorf("status = %d, want 401 surfaced when refresh fails", resp.StatusCode)
+	}
+}
+
+func TestRefreshReplaysBodyOn401(t *testing.T) {
+	tp := func() (string, error) { return "fresh", nil }
+	var bodies []string
+	var mu sync.Mutex
+	inner := rtFunc(func(r *http.Request) (*http.Response, error) {
+		b, _ := io.ReadAll(r.Body)
+		mu.Lock()
+		bodies = append(bodies, string(b))
+		mu.Unlock()
+		if strings.HasSuffix(r.Header.Get("Authorization"), "stale") {
+			return newResp(401, nil), nil
+		}
+		return newResp(200, nil), nil
+	})
+	tr := newRefreshingTransport("stale", tp, inner)
+
+	req, _ := http.NewRequest(http.MethodPost, "http://x", strings.NewReader("payload"))
+	resp, err := tr.RoundTrip(req)
+	if err != nil {
+		t.Fatalf("unexpected err: %v", err)
+	}
+	if resp.StatusCode != 200 {
+		t.Errorf("status = %d, want 200", resp.StatusCode)
+	}
+	if len(bodies) != 2 {
+		t.Fatalf("attempts = %d, want 2", len(bodies))
+	}
+	if bodies[0] != "payload" || bodies[1] != "payload" {
+		t.Errorf("body not replayed across 401 retry: %q", bodies)
+	}
+}
+
+func TestForceRefreshIsStormSafe(t *testing.T) {
+	var refreshCalls int
+	var mu sync.Mutex
+	tp := func() (string, error) {
+		mu.Lock()
+		refreshCalls++
+		mu.Unlock()
+		return "fresh", nil
+	}
+	inner, _ := bearerRouter(map[string]int{"stale": 401, "fresh": 200})
+	tr := newRefreshingTransport("stale", tp, inner)
+
+	const n = 20
+	var wg sync.WaitGroup
+	wg.Add(n)
+	for range n {
+		go func() {
+			defer wg.Done()
+			req, _ := http.NewRequest(http.MethodGet, "http://x", nil)
+			resp, err := tr.RoundTrip(req)
+			if err == nil {
+				resp.Body.Close()
+			}
+		}()
+	}
+	wg.Wait()
+
+	mu.Lock()
+	defer mu.Unlock()
+	if refreshCalls != 1 {
+		t.Errorf("refreshCalls = %d, want 1 (storm should dedup)", refreshCalls)
+	}
+}
+
+func TestGetTokenRefreshesOnTTLExpiry(t *testing.T) {
+	refreshCalls := 0
+	tp := func() (string, error) { refreshCalls++; return "new", nil }
+	tr := newRefreshingTransport("old", tp, rtFunc(func(*http.Request) (*http.Response, error) {
+		return newResp(200, nil), nil
+	}))
+	tr.cacheTTL = time.Millisecond
+	tr.cachedAt = time.Now().Add(-time.Hour) // expired
+
+	got := tr.getToken(context.Background())
+	if got != "new" {
+		t.Errorf("getToken = %q, want new (TTL expired)", got)
+	}
+	if refreshCalls != 1 {
+		t.Errorf("refreshCalls = %d, want 1", refreshCalls)
+	}
+}
diff --git a/internal/client/retry.go b/internal/client/retry.go
index 1a6301a..0bdcba9 100644
--- a/internal/client/retry.go
+++ b/internal/client/retry.go
@@ -4,13 +4,16 @@ import (
 	"bytes"
 	"io"
 	"math"
-	"math/rand"
+	"math/rand/v2"
 	"net/http"
+	"strconv"
+	"strings"
 	"time"
 )
 
 // retryTransport wraps an http.RoundTripper and retries on 429 and 5xx errors
-// with exponential backoff.
+// with exponential backoff. 5xx is only retried for idempotent methods so a
+// POST/PATCH that may have committed server-side is never replayed.
 type retryTransport struct {
 	transport    http.RoundTripper
 	maxRetries   int
@@ -54,14 +57,17 @@ func (t *retryTransport) RoundTrip(req *http.Request) (*http.Response, error) {
 			return nil, err
 		}
 
-		if !t.shouldRetry(resp.StatusCode) || attempt == t.maxRetries {
+		if !shouldRetry(req.Method, resp.StatusCode) || attempt == t.maxRetries {
 			return resp, nil
 		}
 
-		// Drain and close the response body before retrying
+		delay := t.retryDelay(resp, attempt)
+
+		// Drain and close the response body before retrying so the underlying
+		// connection can be reused.
+		_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4<<10))
 		resp.Body.Close()
 
-		delay := t.backoffDelay(attempt)
 		select {
 		case <-time.After(delay):
 		case <-req.Context().Done():
@@ -72,8 +78,68 @@ func (t *retryTransport) RoundTrip(req *http.Request) (*http.Response, error) {
 	return resp, err
 }
 
-func (t *retryTransport) shouldRetry(statusCode int) bool {
-	return statusCode == http.StatusTooManyRequests || statusCode >= 500
+// shouldRetry decides whether a response is worth retrying. 429 is always safe to
+// retry (the server rejected the request without processing it). 5xx is retried
+// only for idempotent methods; replaying a POST/PATCH risks duplicating work that
+// may have already committed before the server errored.
+func shouldRetry(method string, statusCode int) bool {
+	if statusCode == http.StatusTooManyRequests {
+		return true
+	}
+	if statusCode >= 500 {
+		return isIdempotent(method)
+	}
+	return false
+}
+
+// isIdempotent reports whether method is idempotent per RFC 9110, i.e. safe to
+// replay after a 5xx. An empty method counts as GET, matching how net/http
+// treats it on client requests.
+func isIdempotent(method string) bool {
+	switch strings.ToUpper(method) {
+	case "", http.MethodGet, http.MethodHead, http.MethodPut, http.MethodDelete, http.MethodOptions, http.MethodTrace:
+		return true
+	}
+	return false
+}
+
+// retryDelay honors a Retry-After header when the server provides one (capped at
+// maxDelay), otherwise falls back to jittered exponential backoff.
+func (t *retryTransport) retryDelay(resp *http.Response, attempt int) time.Duration {
+	if d := parseRetryAfter(resp.Header.Get("Retry-After"), time.Now()); d > 0 {
+		if d > t.maxDelay {
+			return t.maxDelay
+		}
+		return d
+	}
+	return t.backoffDelay(attempt)
+}
+
+// parseRetryAfter parses a Retry-After header value, which may be either an
+// integer number of seconds or an HTTP-date. Returns 0 if absent/invalid.
+// Second counts too large for a time.Duration are clamped to the maximum.
+func parseRetryAfter(v string, now time.Time) time.Duration {
+	v = strings.TrimSpace(v)
+	if v == "" {
+		return 0
+	}
+	if secs, err := strconv.ParseInt(v, 10, 64); err == nil {
+		if secs <= 0 {
+			return 0
+		}
+		// Clamp before multiplying: a huge value would overflow time.Duration
+		// and wrap around to a negative or arbitrary delay.
+		if secs > int64(math.MaxInt64/time.Second) {
+			return math.MaxInt64
+		}
+		return time.Duration(secs) * time.Second
+	}
+	if when, err := http.ParseTime(v); err == nil {
+		if d := when.Sub(now); d > 0 {
+			return d
+		}
+	}
+	return 0
 }
 
 func (t *retryTransport) backoffDelay(attempt int) time.Duration {
diff --git a/internal/client/retry_test.go b/internal/client/retry_test.go
new file mode 100644
index 0000000..8741c85
--- /dev/null
+++ b/internal/client/retry_test.go
@@ -0,0 +1,177 @@
+package client
+
+import (
+	"context"
+	"io"
+	"net/http"
+	"strings"
+	"sync"
+	"testing"
+	"time"
+)
+
+// rtFunc adapts a function to http.RoundTripper.
+type rtFunc func(*http.Request) (*http.Response, error)
+
+func (f rtFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }
+
+func newResp(status int, hdr map[string]string) *http.Response {
+	h := http.Header{}
+	for k, v := range hdr {
+		h.Set(k, v)
+	}
+	return &http.Response{StatusCode: status, Header: h, Body: io.NopCloser(strings.NewReader(""))}
+}
+
+// fastRetryTransport returns a retryTransport with negligible delays for tests.
+func fastRetryTransport(inner http.RoundTripper) *retryTransport {
+	return &retryTransport{
+		transport:    inner,
+		maxRetries:   3,
+		initialDelay: time.Millisecond,
+		maxDelay:     5 * time.Millisecond,
+	}
+}
+
+func TestShouldRetry(t *testing.T) {
+	cases := []struct {
+		method string
+		status int
+		want   bool
+	}{
+		{"", 500, true}, // empty method means GET on client requests
+		{http.MethodGet, 500, true},
+		{http.MethodGet, 503, true},
+		{http.MethodPut, 500, true},
+		{http.MethodDelete, 502, true},
+		{http.MethodPost, 500, false},  // non-idempotent: never replay on 5xx
+		{http.MethodPatch, 503, false}, // non-idempotent
+		{http.MethodPost, 429, true},   // 429 = not processed, safe to retry
+		{http.MethodPatch, 429, true},
+		{http.MethodGet, 200, false},
+		{http.MethodGet, 404, false},
+		{http.MethodGet, 400, false},
+	}
+	for _, c := range cases {
+		if got := shouldRetry(c.method, c.status); got != c.want {
+			t.Errorf("shouldRetry(%s, %d) = %v, want %v", c.method, c.status, got, c.want)
+		}
+	}
+}
+
+func TestParseRetryAfter(t *testing.T) {
+	now := time.Date(2026, 5, 26, 12, 0, 0, 0, time.UTC)
+	cases := []struct {
+		in   string
+		want time.Duration
+	}{
+		{"", 0},
+		{"5", 5 * time.Second},
+		{"0", 0},
+		{"-3", 0},
+		{" 10 ", 10 * time.Second},
+		{"notanumber", 0},
+		{"9223372036854775807", time.Duration(1<<63 - 1)}, // clamped instead of overflowing
+		{now.Add(30 * time.Second).Format(http.TimeFormat), 30 * time.Second},
+		{now.Add(-30 * time.Second).Format(http.TimeFormat), 0}, // past date
+	}
+	for _, c := range cases {
+		if got := parseRetryAfter(c.in, now); got != c.want {
+			t.Errorf("parseRetryAfter(%q) = %v, want %v", c.in, got, c.want)
+		}
+	}
+}
+
+func TestRetryDelayHonorsRetryAfterCappedAtMax(t *testing.T) {
+	tr := fastRetryTransport(nil) // maxDelay 5ms
+	resp := newResp(429, map[string]string{"Retry-After": "100"})
+	if d := tr.retryDelay(resp, 0); d != tr.maxDelay {
+		t.Errorf("retryDelay capped = %v, want %v", d, tr.maxDelay)
+	}
+	// No header -> falls back to backoff (>0).
+	if d := tr.retryDelay(newResp(429, nil), 0); d <= 0 {
+		t.Errorf("backoff fallback = %v, want > 0", d)
+	}
+}
+
+func TestRoundTripRetriesIdempotentOn5xx(t *testing.T) {
+	calls := 0
+	inner := rtFunc(func(r *http.Request) (*http.Response, error) {
+		calls++
+		return newResp(500, nil), nil
+	})
+	req, _ := http.NewRequest(http.MethodGet, "http://x", nil)
+	resp, err := fastRetryTransport(inner).RoundTrip(req)
+	if err != nil {
+		t.Fatalf("unexpected err: %v", err)
+	}
+	if resp.StatusCode != 500 {
+		t.Errorf("status = %d, want 500", resp.StatusCode)
+	}
+	if calls != 4 { // 1 initial + 3 retries
+		t.Errorf("calls = %d, want 4", calls)
+	}
+}
+
+func TestRoundTripDoesNotRetryPostOn5xx(t *testing.T) {
+	calls := 0
+	inner := rtFunc(func(r *http.Request) (*http.Response, error) {
+		calls++
+		return newResp(503, nil), nil
+	})
+	req, _ := http.NewRequest(http.MethodPost, "http://x", strings.NewReader("payload"))
+	_, err := fastRetryTransport(inner).RoundTrip(req)
+	if err != nil {
+		t.Fatalf("unexpected err: %v", err)
+	}
+	if calls != 1 {
+		t.Errorf("POST retried on 5xx: calls = %d, want 1", calls)
+	}
+}
+
+func TestRoundTripRetriesPostOn429AndReplaysBody(t *testing.T) {
+	var bodies []string
+	var mu sync.Mutex
+	calls := 0
+	inner := rtFunc(func(r *http.Request) (*http.Response, error) {
+		b, _ := io.ReadAll(r.Body)
+		mu.Lock()
+		bodies = append(bodies, string(b))
+		calls++
+		c := calls
+		mu.Unlock()
+		if c < 3 {
+			return newResp(429, nil), nil
+		}
+		return newResp(200, nil), nil
+	})
+	req, _ := http.NewRequest(http.MethodPost, "http://x", strings.NewReader("payload"))
+	resp, err := fastRetryTransport(inner).RoundTrip(req)
+	if err != nil {
+		t.Fatalf("unexpected err: %v", err)
+	}
+	if resp.StatusCode != 200 {
+		t.Errorf("status = %d, want 200", resp.StatusCode)
+	}
+	if len(bodies) != 3 {
+		t.Fatalf("attempts = %d, want 3", len(bodies))
+	}
+	for i, b := range bodies {
+		if b != "payload" {
+			t.Errorf("attempt %d body = %q, want payload (body not replayed)", i, b)
+		}
+	}
+}
+
+func TestRoundTripStopsOnContextCancel(t *testing.T) {
+	inner := rtFunc(func(r *http.Request) (*http.Response, error) {
+		return newResp(500, nil), nil
+	})
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel() // cancelled before the backoff wait
+	req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "http://x", nil)
+	_, err := fastRetryTransport(inner).RoundTrip(req)
+	if err == nil {
+		t.Fatal("expected context error, got nil")
+	}
+}