Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 60 additions & 4 deletions internal/client/refreshing_transport.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package client

import (
"bytes"
"context"
"io"
"net/http"
"strings"
"sync"
Expand Down Expand Up @@ -89,11 +91,65 @@ func (t *refreshingAuthTransport) getToken(ctx context.Context) string {
return t.cachedToken
}

// forceRefresh fetches a new token regardless of the cache TTL. It is storm-safe:
// if another goroutine already replaced staleToken (e.g. several concurrent
// requests hit a 401 at once), the already-refreshed token is reused instead of
// fetching again.
func (t *refreshingAuthTransport) forceRefresh(ctx context.Context, staleToken string) string {
t.mu.Lock()
defer t.mu.Unlock()
if t.cachedToken != staleToken {
return t.cachedToken
}
newToken, err := t.tokenProvider()
if err != nil {
logger.L(ctx).Warn("forced token refresh failed after 401", logger.Err(err))
return ""
}
t.cachedToken = newToken
t.cachedAt = time.Now()
return newToken
}

func (t *refreshingAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// Buffer the body so the request can be replayed if the first attempt 401s.
var bodyBytes []byte
if req.Body != nil {
b, err := io.ReadAll(req.Body)
if err != nil {
return nil, err
}
req.Body.Close()
bodyBytes = b
}

attempt := func(token string) (*http.Response, error) {
r := req.Clone(req.Context())
if bodyBytes != nil {
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
r.ContentLength = int64(len(bodyBytes))
}
r.Header.Set("Authorization", "Bearer "+token)
r.Header.Set("User-Agent", "Nullify-CLI/mcp")
return t.transport.RoundTrip(r)
}

token := t.getToken(req.Context())
resp, err := attempt(token)
if err != nil {
return nil, err
}

// The cached token can be invalid before its TTL elapses (revocation, server
// session kill, clock skew). On a 401, force a refresh and retry once before
// surfacing the failure.
if resp.StatusCode == http.StatusUnauthorized {
if newToken := t.forceRefresh(req.Context(), token); newToken != "" && newToken != token {
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4<<10))
resp.Body.Close()
return attempt(newToken)
}
}

r := req.Clone(req.Context())
r.Header.Set("Authorization", "Bearer "+token)
r.Header.Set("User-Agent", "Nullify-CLI/mcp")
return t.transport.RoundTrip(r)
return resp, nil
}
179 changes: 179 additions & 0 deletions internal/client/refreshing_transport_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
package client

import (
"io"
"net/http"
"strings"
"sync"
"testing"
"time"
)

func newRefreshingTransport(initial string, tp TokenProvider, inner http.RoundTripper) *refreshingAuthTransport {
return &refreshingAuthTransport{
nullifyHost: "acme.nullify.ai",
tokenProvider: tp,
transport: inner,
cachedToken: initial,
cachedAt: time.Now(),
cacheTTL: time.Hour, // keep getToken from refreshing on TTL during 401 tests
}
}

// bearerRouter returns the status mapped to the request's bearer token.
func bearerRouter(byToken map[string]int) (http.RoundTripper, func() int) {
var mu sync.Mutex
calls := 0
rt := rtFunc(func(r *http.Request) (*http.Response, error) {
mu.Lock()
calls++
mu.Unlock()
tok := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
status, ok := byToken[tok]
if !ok {
status = http.StatusUnauthorized
}
return newResp(status, nil), nil
})
return rt, func() int { mu.Lock(); defer mu.Unlock(); return calls }
}

func TestRefreshOn401RetriesWithNewToken(t *testing.T) {
refreshCalls := 0
tp := func() (string, error) { refreshCalls++; return "fresh", nil }
inner, _ := bearerRouter(map[string]int{"stale": 401, "fresh": 200})
tr := newRefreshingTransport("stale", tp, inner)

req, _ := http.NewRequest(http.MethodGet, "http://x", nil)
resp, err := tr.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
if resp.StatusCode != 200 {
t.Errorf("status = %d, want 200 after refresh", resp.StatusCode)
}
if refreshCalls != 1 {
t.Errorf("refreshCalls = %d, want 1", refreshCalls)
}
}

func TestNo401MeansNoRefresh(t *testing.T) {
refreshCalls := 0
tp := func() (string, error) { refreshCalls++; return "fresh", nil }
inner, calls := bearerRouter(map[string]int{"good": 200})
tr := newRefreshingTransport("good", tp, inner)

req, _ := http.NewRequest(http.MethodGet, "http://x", nil)
resp, err := tr.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
if resp.StatusCode != 200 {
t.Errorf("status = %d, want 200", resp.StatusCode)
}
if refreshCalls != 0 {
t.Errorf("refreshCalls = %d, want 0 (no 401)", refreshCalls)
}
if calls() != 1 {
t.Errorf("inner calls = %d, want 1", calls())
}
}

func TestRefreshFailureSurfaces401(t *testing.T) {
tp := func() (string, error) { return "", io.ErrUnexpectedEOF }
inner, _ := bearerRouter(map[string]int{"stale": 401})
tr := newRefreshingTransport("stale", tp, inner)

req, _ := http.NewRequest(http.MethodGet, "http://x", nil)
resp, err := tr.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
if resp.StatusCode != 401 {
t.Errorf("status = %d, want 401 surfaced when refresh fails", resp.StatusCode)
}
}

func TestRefreshReplaysBodyOn401(t *testing.T) {
tp := func() (string, error) { return "fresh", nil }
var bodies []string
var mu sync.Mutex
inner := rtFunc(func(r *http.Request) (*http.Response, error) {
b, _ := io.ReadAll(r.Body)
mu.Lock()
bodies = append(bodies, string(b))
mu.Unlock()
if strings.HasSuffix(r.Header.Get("Authorization"), "stale") {
return newResp(401, nil), nil
}
return newResp(200, nil), nil
})
tr := newRefreshingTransport("stale", tp, inner)

req, _ := http.NewRequest(http.MethodPost, "http://x", strings.NewReader("payload"))
resp, err := tr.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
if resp.StatusCode != 200 {
t.Errorf("status = %d, want 200", resp.StatusCode)
}
if len(bodies) != 2 {
t.Fatalf("attempts = %d, want 2", len(bodies))
}
if bodies[0] != "payload" || bodies[1] != "payload" {
t.Errorf("body not replayed across 401 retry: %q", bodies)
}
}

func TestForceRefreshIsStormSafe(t *testing.T) {
var refreshCalls int
var mu sync.Mutex
tp := func() (string, error) {
mu.Lock()
refreshCalls++
mu.Unlock()
return "fresh", nil
}
inner, _ := bearerRouter(map[string]int{"stale": 401, "fresh": 200})
tr := newRefreshingTransport("stale", tp, inner)

const n = 20
var wg sync.WaitGroup
wg.Add(n)
for range n {
go func() {
defer wg.Done()
req, _ := http.NewRequest(http.MethodGet, "http://x", nil)
resp, err := tr.RoundTrip(req)
if err == nil {
resp.Body.Close()
}
}()
}
wg.Wait()

mu.Lock()
defer mu.Unlock()
if refreshCalls != 1 {
t.Errorf("refreshCalls = %d, want 1 (storm should dedup)", refreshCalls)
}
}

func TestGetTokenRefreshesOnTTLExpiry(t *testing.T) {
refreshCalls := 0
tp := func() (string, error) { refreshCalls++; return "new", nil }
tr := newRefreshingTransport("old", tp, rtFunc(func(*http.Request) (*http.Response, error) {
return newResp(200, nil), nil
}))
tr.cacheTTL = time.Millisecond
tr.cachedAt = time.Now().Add(-time.Hour) // expired

got := tr.getToken(nil)

Check failure on line 172 in internal/client/refreshing_transport_test.go

View workflow job for this annotation

GitHub Actions / build-test

SA1012: do not pass a nil Context, even if a function permits it; pass context.TODO if you are unsure about which Context to use (staticcheck)
if got != "new" {
t.Errorf("getToken = %q, want new (TTL expired)", got)
}
if refreshCalls != 1 {
t.Errorf("refreshCalls = %d, want 1", refreshCalls)
}
}
71 changes: 64 additions & 7 deletions internal/client/retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@ import (
"bytes"
"io"
"math"
"math/rand"
"math/rand/v2"
"net/http"
"strconv"
"strings"
"time"
)

// retryTransport wraps an http.RoundTripper and retries on 429 and 5xx errors
// with exponential backoff.
// with exponential backoff. 5xx is only retried for idempotent methods so a
// POST/PATCH that may have committed server-side is never replayed.
type retryTransport struct {
transport http.RoundTripper
maxRetries int
Expand Down Expand Up @@ -54,14 +57,17 @@ func (t *retryTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return nil, err
}

if !t.shouldRetry(resp.StatusCode) || attempt == t.maxRetries {
if !shouldRetry(req.Method, resp.StatusCode) || attempt == t.maxRetries {
return resp, nil
}

// Drain and close the response body before retrying
delay := t.retryDelay(resp, attempt)

// Drain and close the response body before retrying so the underlying
// connection can be reused.
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4<<10))
resp.Body.Close()

delay := t.backoffDelay(attempt)
select {
case <-time.After(delay):
case <-req.Context().Done():
Expand All @@ -72,8 +78,59 @@ func (t *retryTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return resp, err
}

func (t *retryTransport) shouldRetry(statusCode int) bool {
return statusCode == http.StatusTooManyRequests || statusCode >= 500
// shouldRetry decides whether a response is worth retrying. 429 is always safe to
// retry (the server rejected the request without processing it). 5xx is retried
// only for idempotent methods; replaying a POST/PATCH risks duplicating work that
// may have already committed before the server errored.
func shouldRetry(method string, statusCode int) bool {
if statusCode == http.StatusTooManyRequests {
return true
}
if statusCode >= 500 {
return isIdempotent(method)
}
return false
}

func isIdempotent(method string) bool {
switch strings.ToUpper(method) {
case http.MethodGet, http.MethodHead, http.MethodPut, http.MethodDelete, http.MethodOptions, http.MethodTrace:
return true
}
return false
}

// retryDelay honors a Retry-After header when the server provides one (capped at
// maxDelay), otherwise falls back to jittered exponential backoff.
func (t *retryTransport) retryDelay(resp *http.Response, attempt int) time.Duration {
if d := parseRetryAfter(resp.Header.Get("Retry-After"), time.Now()); d > 0 {
if d > t.maxDelay {
return t.maxDelay
}
return d
}
return t.backoffDelay(attempt)
}

// parseRetryAfter parses a Retry-After header value, which may be either an
// integer number of seconds or an HTTP-date. Returns 0 if absent/invalid.
func parseRetryAfter(v string, now time.Time) time.Duration {
v = strings.TrimSpace(v)
if v == "" {
return 0
}
if secs, err := strconv.Atoi(v); err == nil {
if secs <= 0 {
return 0
}
return time.Duration(secs) * time.Second
}
if when, err := http.ParseTime(v); err == nil {
if d := when.Sub(now); d > 0 {
return d
}
}
return 0
}

func (t *retryTransport) backoffDelay(attempt int) time.Duration {
Expand Down
Loading
Loading