diff --git a/src/databricks/sql/auth/common.py b/src/databricks/sql/auth/common.py index 0e3a01918..cda30a528 100644 --- a/src/databricks/sql/auth/common.py +++ b/src/databricks/sql/auth/common.py @@ -47,6 +47,7 @@ def __init__( retry_stop_after_attempts_duration: Optional[float] = None, retry_delay_default: Optional[float] = None, retry_dangerous_codes: Optional[List[int]] = None, + respect_server_retry_after_header: Optional[bool] = None, proxy_auth_method: Optional[str] = None, pool_connections: Optional[int] = None, pool_maxsize: Optional[int] = None, @@ -80,6 +81,7 @@ def __init__( ) self.retry_delay_default = retry_delay_default or 5.0 self.retry_dangerous_codes = retry_dangerous_codes or [] + self.respect_server_retry_after_header = bool(respect_server_retry_after_header) self.proxy_auth_method = proxy_auth_method self.pool_connections = pool_connections or 10 self.pool_maxsize = pool_maxsize or 20 diff --git a/src/databricks/sql/auth/retry.py b/src/databricks/sql/auth/retry.py index b0c2f497d..9cb29fdce 100755 --- a/src/databricks/sql/auth/retry.py +++ b/src/databricks/sql/auth/retry.py @@ -94,6 +94,7 @@ def __init__( stop_after_attempts_duration: float, delay_default: float, force_dangerous_codes: List[int], + respect_server_retry_after_header: bool = False, urllib3_kwargs: dict = {}, ): # These values do not change from one command to the next @@ -103,6 +104,7 @@ def __init__( self.stop_after_attempts_duration = stop_after_attempts_duration self._delay_default = delay_default self.force_dangerous_codes = force_dangerous_codes + self.respect_server_retry_after_header = respect_server_retry_after_header # the urllib3 kwargs are a mix of configuration (some of which we override) # and counters like `total` or `connect` which may change between successive retries @@ -202,6 +204,7 @@ def new( stop_after_attempts_duration=self.stop_after_attempts_duration, delay_default=self.delay_default, force_dangerous_codes=self.force_dangerous_codes, + 
respect_server_retry_after_header=self.respect_server_retry_after_header, urllib3_kwargs={}, ) @@ -323,7 +326,9 @@ def get_backoff_time(self) -> float: return proposed_backoff - def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]: + def should_retry( + self, method: str, status_code: int, has_retry_after: bool = False + ) -> Tuple[bool, str]: """This method encapsulates the connector's approach to retries. We always retry a request unless one of these conditions is met: @@ -388,6 +393,15 @@ def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]: if not self._is_method_retryable(method): return False, "Only POST requests are retried" + # When respect_server_retry_after_header is enabled, only retry when the + # server explicitly signals it's safe via a Retry-After header. This prevents + # duplicate side effects for non-idempotent operations. + if self.respect_server_retry_after_header and not has_retry_after: + return ( + False, + "respect_server_retry_after_header mode: no Retry-After header present", + ) + # Request failed, was an ExecuteStatement and the command may have reached the server if ( self.command_type == CommandType.EXECUTE_STATEMENT @@ -430,7 +444,7 @@ def is_retry( Logs a debug message if the request will be retried """ - should_retry, msg = self.should_retry(method, status_code) + should_retry, msg = self.should_retry(method, status_code, has_retry_after) if should_retry: logger.debug(msg) diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index caefe9929..99fe5edb5 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -92,6 +92,9 @@ def __init__( ) self._retry_delay_default = kwargs.get("_retry_delay_default", 5.0) self.force_dangerous_codes = kwargs.get("_retry_dangerous_codes", []) + self._respect_server_retry_after_header = kwargs.get( + 
"_respect_server_retry_after_header", False + ) # Connection pooling settings self.max_connections = kwargs.get("max_connections", 10) @@ -116,6 +119,7 @@ def __init__( stop_after_attempts_duration=self._retry_stop_after_attempts_duration, delay_default=self._retry_delay_default, force_dangerous_codes=self.force_dangerous_codes, + respect_server_retry_after_header=self._respect_server_retry_after_header, urllib3_kwargs=urllib3_kwargs, ) else: diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index e23f3389b..776a7b5c8 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -191,6 +191,9 @@ def __init__( " This behaviour is deprecated and will be removed in a future release." ) self.force_dangerous_codes = kwargs.get("_retry_dangerous_codes", []) + self._respect_server_retry_after_header = kwargs.get( + "_respect_server_retry_after_header", False + ) additional_transport_args = {} @@ -217,6 +220,7 @@ def __init__( stop_after_attempts_duration=self._retry_stop_after_attempts_duration, delay_default=self._retry_delay_default, force_dangerous_codes=self.force_dangerous_codes, + respect_server_retry_after_header=self._respect_server_retry_after_header, urllib3_kwargs=urllib3_kwargs, ) diff --git a/src/databricks/sql/common/unified_http_client.py b/src/databricks/sql/common/unified_http_client.py index ef55564c8..67c178e19 100644 --- a/src/databricks/sql/common/unified_http_client.py +++ b/src/databricks/sql/common/unified_http_client.py @@ -135,6 +135,7 @@ def _setup_pool_managers(self): stop_after_attempts_duration=self.config.retry_stop_after_attempts_duration, delay_default=self.config.retry_delay_default, force_dangerous_codes=self.config.retry_dangerous_codes, + respect_server_retry_after_header=self.config.respect_server_retry_after_header, ) # Initialize the required attributes that DatabricksRetryPolicy expects diff --git a/src/databricks/sql/utils.py 
class TestRetry:
    # NOTE(review): only the members that are fully visible in the change set are
    # reproduced here. `_make_retry_policy` (called below), `error_history` and the
    # older tests have lines outside the visible hunks and are deliberately not
    # restated -- confirm against the full test module before applying.

    @pytest.fixture()
    def retry_policy(self) -> DatabricksRetryPolicy:
        """Return a policy built with the suite's default settings."""
        return self._make_retry_policy()

    def _strict_policy(self, command_type=None, **overrides):
        """Build a policy with ``respect_server_retry_after_header`` enabled.

        The policy is created through ``_make_retry_policy`` and then given the
        two attributes every test in this group sets before calling
        ``should_retry``: ``_retry_start_time`` (current time) and
        ``command_type``.

        Args:
            command_type: Value assigned to ``policy.command_type``;
                ``CommandType.OTHER`` when omitted.
            **overrides: Extra keyword arguments forwarded to
                ``_make_retry_policy`` (for example ``force_dangerous_codes``).

        Returns:
            The configured ``DatabricksRetryPolicy``.
        """
        overrides.setdefault("respect_server_retry_after_header", True)
        policy = self._make_retry_policy(**overrides)
        policy._retry_start_time = time.time()
        policy.command_type = (
            CommandType.OTHER if command_type is None else command_type
        )
        return policy

    def test_respect_server_retry_after__retries_with_retry_after(self):
        """429 + Retry-After header → should retry"""
        policy = self._strict_policy()
        should_retry, msg = policy.should_retry("POST", 429, has_retry_after=True)
        # Surface the policy's own reason if it unexpectedly refuses to retry.
        assert should_retry is True, msg

    def test_respect_server_retry_after__no_retry_without_retry_after(self):
        """429 without Retry-After header → no retry"""
        policy = self._strict_policy()
        should_retry, msg = policy.should_retry("POST", 429, has_retry_after=False)
        assert should_retry is False
        assert "respect_server_retry_after_header" in msg

    def test_respect_server_retry_after__no_retry_503_without_header(self):
        """503 without Retry-After header → no retry"""
        policy = self._strict_policy()
        should_retry, msg = policy.should_retry("POST", 503, has_retry_after=False)
        assert should_retry is False
        assert "respect_server_retry_after_header" in msg

    def test_respect_server_retry_after__overrides_dangerous_codes(self):
        """force_dangerous_codes=[500] + no Retry-After → no retry in respect_server_retry_after_header mode"""
        policy = self._strict_policy(
            CommandType.EXECUTE_STATEMENT, force_dangerous_codes=[500]
        )
        should_retry, msg = policy.should_retry("POST", 500, has_retry_after=False)
        assert should_retry is False
        assert "respect_server_retry_after_header" in msg

    def test_respect_server_retry_after__non_retryable_codes_unaffected(self):
        """401/403/501 still don't retry even with Retry-After header"""
        policy = self._strict_policy()
        for code in [401, 403, 501]:
            # The reason string is not part of this check, so it is discarded.
            should_retry, _ = policy.should_retry("POST", code, has_retry_after=True)
            assert should_retry is False, f"Code {code} should never retry"

    def test_default_mode_unchanged(self, retry_policy):
        """respect_server_retry_after_header=False preserves existing behavior — 429 retries without header"""
        retry_policy._retry_start_time = time.time()
        retry_policy.command_type = CommandType.OTHER
        should_retry, msg = retry_policy.should_retry(
            "POST", 429, has_retry_after=False
        )
        assert should_retry is True, msg

    def test_respect_server_retry_after__survives_new(self):
        """urllib3 calls .new() between retries to create a fresh policy instance.
        Verify that respect_server_retry_after_header is carried over and still enforced."""
        new_policy = self._strict_policy().new()
        assert new_policy.respect_server_retry_after_header is True
        # The new instance should still block retries without Retry-After
        should_retry, msg = new_policy.should_retry("POST", 429, has_retry_after=False)
        assert should_retry is False
        assert "respect_server_retry_after_header" in msg

    def test_respect_server_retry_after__execute_statement_with_retry_after(self):
        """EXECUTE_STATEMENT + 429 + Retry-After header → retry"""
        policy = self._strict_policy(CommandType.EXECUTE_STATEMENT)
        should_retry, msg = policy.should_retry("POST", 429, has_retry_after=True)
        assert should_retry is True, msg
context.respect_server_retry_after_header = False context.proxy_auth_method = None context.pool_connections = 10 context.pool_maxsize = 20 @@ -48,16 +49,19 @@ def http_client(self, client_context): """Create UnifiedHttpClient instance.""" return UnifiedHttpClient(client_context) - @pytest.mark.parametrize("status_code,path", [ - (429, "reason.response"), - (503, "reason.response"), - (500, "direct_response"), - ]) + @pytest.mark.parametrize( + "status_code,path", + [ + (429, "reason.response"), + (503, "reason.response"), + (500, "direct_response"), + ], + ) def test_max_retry_error_with_status_codes(self, http_client, status_code, path): """Test MaxRetryError with various status codes and response paths.""" mock_pool = Mock() max_retry_error = MaxRetryError(pool=mock_pool, url="http://test.com") - + if path == "reason.response": max_retry_error.reason = Mock() max_retry_error.reason.response = Mock() @@ -79,12 +83,21 @@ def test_max_retry_error_with_status_codes(self, http_client, status_code, path) assert "http-code" in error.context assert error.context["http-code"] == status_code - @pytest.mark.parametrize("setup_func", [ - lambda e: None, # No setup - error with no attributes - lambda e: setattr(e, "reason", None), # reason=None - lambda e: (setattr(e, "reason", Mock()), setattr(e.reason, "response", None)), # reason.response=None - lambda e: (setattr(e, "reason", Mock()), setattr(e.reason, "response", Mock(spec=[]))), # No status attr - ]) + @pytest.mark.parametrize( + "setup_func", + [ + lambda e: None, # No setup - error with no attributes + lambda e: setattr(e, "reason", None), # reason=None + lambda e: ( + setattr(e, "reason", Mock()), + setattr(e.reason, "response", None), + ), # reason.response=None + lambda e: ( + setattr(e, "reason", Mock()), + setattr(e.reason, "response", Mock(spec=[])), + ), # No status attr + ], + ) def test_max_retry_error_missing_status(self, http_client, setup_func): """Test MaxRetryError without status code (no crash, empty 
context).""" mock_pool = Mock() @@ -104,12 +117,12 @@ def test_max_retry_error_prefers_reason_response(self, http_client): """Test that e.reason.response.status is preferred over e.response.status.""" mock_pool = Mock() max_retry_error = MaxRetryError(pool=mock_pool, url="http://test.com") - + # Set both structures with different status codes max_retry_error.reason = Mock() max_retry_error.reason.response = Mock() max_retry_error.reason.response.status = 429 # Should use this - + max_retry_error.response = Mock() max_retry_error.response.status = 500 # Should be ignored