From 179140a0a21b2a28b3f517dadb4cdd58aece2398 Mon Sep 17 00:00:00 2001 From: Shubham Dhal Date: Wed, 18 Mar 2026 10:50:12 +0530 Subject: [PATCH 1/4] Add _retry_server_directed_only mode for Retry-After header compliance When enabled, the connector only retries on 429/503 if the server includes a Retry-After header in the response. This prevents duplicate side effects for non-idempotent ExecuteStatement operations where the server has not explicitly signaled that retry is safe. The new opt-in parameter `_retry_server_directed_only` threads through ClientContext, all three DatabricksRetryPolicy construction sites (Thrift, SEA, UnifiedHttpClient), and the retry policy's should_retry/is_retry methods. Default behavior (retry without requiring the header) is unchanged. Signed-off-by: Shubham Dhal --- src/databricks/sql/auth/common.py | 2 + src/databricks/sql/auth/retry.py | 15 ++- .../sql/backend/sea/utils/http_client.py | 4 + src/databricks/sql/backend/thrift_backend.py | 4 + .../sql/common/unified_http_client.py | 1 + src/databricks/sql/utils.py | 1 + tests/unit/test_retry.py | 108 ++++++++++++++++++ tests/unit/test_unified_http_client.py | 41 ++++--- 8 files changed, 160 insertions(+), 16 deletions(-) diff --git a/src/databricks/sql/auth/common.py b/src/databricks/sql/auth/common.py index 0e3a01918..2ae317390 100644 --- a/src/databricks/sql/auth/common.py +++ b/src/databricks/sql/auth/common.py @@ -47,6 +47,7 @@ def __init__( retry_stop_after_attempts_duration: Optional[float] = None, retry_delay_default: Optional[float] = None, retry_dangerous_codes: Optional[List[int]] = None, + retry_server_directed_only: Optional[bool] = None, proxy_auth_method: Optional[str] = None, pool_connections: Optional[int] = None, pool_maxsize: Optional[int] = None, @@ -80,6 +81,7 @@ def __init__( ) self.retry_delay_default = retry_delay_default or 5.0 self.retry_dangerous_codes = retry_dangerous_codes or [] + self.retry_server_directed_only = bool(retry_server_directed_only) self.proxy_auth_method = proxy_auth_method self.pool_connections = pool_connections or 10 self.pool_maxsize = pool_maxsize or 20 diff --git a/src/databricks/sql/auth/retry.py b/src/databricks/sql/auth/retry.py index b0c2f497d..72a2e87a3 100755 --- a/src/databricks/sql/auth/retry.py +++ b/src/databricks/sql/auth/retry.py @@ -94,6 +94,7 @@ def __init__( stop_after_attempts_duration: float, delay_default: float, force_dangerous_codes: List[int], + server_directed_only: bool = False, urllib3_kwargs: dict = {}, ): # These values do not change from one command to the next @@ -103,6 +104,7 @@ def __init__( self.stop_after_attempts_duration = stop_after_attempts_duration self._delay_default = delay_default self.force_dangerous_codes = force_dangerous_codes + self.server_directed_only = server_directed_only # the urllib3 kwargs are a mix of configuration (some of which we override) # and counters like `total` or `connect` which may change between successive retries @@ -202,6 +204,7 @@ def new( stop_after_attempts_duration=self.stop_after_attempts_duration, delay_default=self.delay_default, force_dangerous_codes=self.force_dangerous_codes, + server_directed_only=self.server_directed_only, urllib3_kwargs={}, ) @@ -323,7 +326,9 @@ def get_backoff_time(self) -> float: return proposed_backoff - def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]: + def should_retry( + self, method: str, status_code: int, has_retry_after: bool = False + ) -> Tuple[bool, str]: """This method encapsulates the connector's approach to retries. We always retry a request unless one of these conditions is met: @@ -388,6 +393,12 @@ def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]: if not self._is_method_retryable(method): return False, "Only POST requests are retried" + # In server_directed_only mode, only retry when the server explicitly signals + # it's safe via a Retry-After header. This prevents duplicate side effects for + # non-idempotent operations. + if self.server_directed_only and not has_retry_after: + return (False, "server_directed_only mode: no Retry-After header present") + # Request failed, was an ExecuteStatement and the command may have reached the server if ( self.command_type == CommandType.EXECUTE_STATEMENT @@ -430,7 +441,7 @@ def is_retry( Logs a debug message if the request will be retried """ - should_retry, msg = self.should_retry(method, status_code) + should_retry, msg = self.should_retry(method, status_code, has_retry_after) if should_retry: logger.debug(msg) diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index caefe9929..7f10549ec 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -92,6 +92,9 @@ def __init__( ) self._retry_delay_default = kwargs.get("_retry_delay_default", 5.0) self.force_dangerous_codes = kwargs.get("_retry_dangerous_codes", []) + self._retry_server_directed_only = kwargs.get( + "_retry_server_directed_only", False + ) # Connection pooling settings self.max_connections = kwargs.get("max_connections", 10) @@ -116,6 +119,7 @@ def __init__( stop_after_attempts_duration=self._retry_stop_after_attempts_duration, delay_default=self._retry_delay_default, force_dangerous_codes=self.force_dangerous_codes, + server_directed_only=self._retry_server_directed_only, urllib3_kwargs=urllib3_kwargs, ) else: diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index e23f3389b..bf3047e26 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -191,6 +191,9 @@ def __init__( " This behaviour is deprecated and will be removed in a future release." ) self.force_dangerous_codes = kwargs.get("_retry_dangerous_codes", []) + self._retry_server_directed_only = kwargs.get( + "_retry_server_directed_only", False + ) additional_transport_args = {} @@ -217,6 +220,7 @@ def __init__( stop_after_attempts_duration=self._retry_stop_after_attempts_duration, delay_default=self._retry_delay_default, force_dangerous_codes=self.force_dangerous_codes, + server_directed_only=self._retry_server_directed_only, urllib3_kwargs=urllib3_kwargs, ) diff --git a/src/databricks/sql/common/unified_http_client.py b/src/databricks/sql/common/unified_http_client.py index ef55564c8..474fe00e7 100644 --- a/src/databricks/sql/common/unified_http_client.py +++ b/src/databricks/sql/common/unified_http_client.py @@ -135,6 +135,7 @@ def _setup_pool_managers(self): stop_after_attempts_duration=self.config.retry_stop_after_attempts_duration, delay_default=self.config.retry_delay_default, force_dangerous_codes=self.config.retry_dangerous_codes, + server_directed_only=self.config.retry_server_directed_only, ) # Initialize the required attributes that DatabricksRetryPolicy expects diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index b1fff7202..89ce513c7 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -977,6 +977,7 @@ def build_client_context(server_hostname: str, version: str, **kwargs): ), retry_delay_default=kwargs.get("_retry_delay_default"), retry_dangerous_codes=kwargs.get("_retry_dangerous_codes"), + retry_server_directed_only=kwargs.get("_retry_server_directed_only"), proxy_auth_method=kwargs.get("_proxy_auth_method"), pool_connections=kwargs.get("_pool_connections"), pool_maxsize=kwargs.get("_pool_maxsize"), diff --git a/tests/unit/test_retry.py b/tests/unit/test_retry.py index 0d01d8675..180907605 100644 --- a/tests/unit/test_retry.py +++ b/tests/unit/test_retry.py @@ -84,6 +84,114 @@ def test_excessive_retry_attempts_error(self, t_mock, retry_policy): # Internally urllib3 calls the increment function generating a new instance for every retry retry_policy = retry_policy.increment() + @pytest.fixture() + def server_directed_retry_policy(self) -> DatabricksRetryPolicy: + return DatabricksRetryPolicy( + delay_min=1, + delay_max=30, + stop_after_attempts_count=3, + stop_after_attempts_duration=900, + delay_default=2, + force_dangerous_codes=[], + server_directed_only=True, + ) + + def test_server_directed_only__retries_with_retry_after( + self, server_directed_retry_policy + ): + """429 + Retry-After header → should retry""" + server_directed_retry_policy._retry_start_time = time.time() + server_directed_retry_policy.command_type = CommandType.OTHER + should_retry, msg = server_directed_retry_policy.should_retry( + "POST", 429, has_retry_after=True + ) + assert should_retry is True + + def test_server_directed_only__no_retry_without_retry_after( + self, server_directed_retry_policy + ): + """429 without Retry-After header → no retry""" + server_directed_retry_policy._retry_start_time = time.time() + server_directed_retry_policy.command_type = CommandType.OTHER + should_retry, msg = server_directed_retry_policy.should_retry( + "POST", 429, has_retry_after=False + ) + assert should_retry is False + assert "server_directed_only" in msg + + def test_server_directed_only__no_retry_503_without_header( + self, server_directed_retry_policy + ): + """503 without Retry-After header → no retry""" + server_directed_retry_policy._retry_start_time = time.time() + server_directed_retry_policy.command_type = CommandType.OTHER + should_retry, msg = server_directed_retry_policy.should_retry( + "POST", 503, has_retry_after=False + ) + assert should_retry is False + assert "server_directed_only" in msg + + def test_server_directed_only__overrides_dangerous_codes(self): + """force_dangerous_codes=[500] + no Retry-After → no retry in server_directed_only mode""" + policy = DatabricksRetryPolicy( + delay_min=1, + delay_max=30, + stop_after_attempts_count=3, + stop_after_attempts_duration=900, + delay_default=2, + force_dangerous_codes=[500], + server_directed_only=True, + ) + policy._retry_start_time = time.time() + policy.command_type = CommandType.EXECUTE_STATEMENT + should_retry, msg = policy.should_retry("POST", 500, has_retry_after=False) + assert should_retry is False + assert "server_directed_only" in msg + + def test_server_directed_only__non_retryable_codes_unaffected( + self, server_directed_retry_policy + ): + """401/403/501 still don't retry even with Retry-After header""" + server_directed_retry_policy._retry_start_time = time.time() + server_directed_retry_policy.command_type = CommandType.OTHER + for code in [401, 403, 501]: + should_retry, msg = server_directed_retry_policy.should_retry( + "POST", code, has_retry_after=True + ) + assert should_retry is False, f"Code {code} should never retry" + + def test_default_mode_unchanged(self, retry_policy): + """server_directed_only=False preserves existing behavior — 429 retries without header""" + retry_policy._retry_start_time = time.time() + retry_policy.command_type = CommandType.OTHER + should_retry, msg = retry_policy.should_retry( + "POST", 429, has_retry_after=False + ) + assert should_retry is True + + def test_server_directed_only__survives_new(self, server_directed_retry_policy): + """urllib3 calls .new() between retries to create a fresh policy instance. + Verify that server_directed_only is carried over and still enforced.""" + server_directed_retry_policy._retry_start_time = time.time() + server_directed_retry_policy.command_type = CommandType.OTHER + new_policy = server_directed_retry_policy.new() + assert new_policy.server_directed_only is True + # The new instance should still block retries without Retry-After + should_retry, msg = new_policy.should_retry("POST", 429, has_retry_after=False) + assert should_retry is False + assert "server_directed_only" in msg + + def test_server_directed_only__execute_statement_with_retry_after( + self, server_directed_retry_policy + ): + """EXECUTE_STATEMENT + 429 + Retry-After header → retry""" + server_directed_retry_policy._retry_start_time = time.time() + server_directed_retry_policy.command_type = CommandType.EXECUTE_STATEMENT + should_retry, msg = server_directed_retry_policy.should_retry( + "POST", 429, has_retry_after=True + ) + assert should_retry is True + def test_404_does_not_retry_for_any_command_type(self, retry_policy): """Test that 404 never retries for any CommandType""" retry_policy._retry_start_time = time.time() diff --git a/tests/unit/test_unified_http_client.py b/tests/unit/test_unified_http_client.py index 4e9ce1bbf..eae732431 100644 --- a/tests/unit/test_unified_http_client.py +++ b/tests/unit/test_unified_http_client.py @@ -37,6 +37,7 @@ def client_context(self): context.retry_stop_after_attempts_duration = 300.0 context.retry_delay_default = 5.0 context.retry_dangerous_codes = [] + context.retry_server_directed_only = False context.proxy_auth_method = None context.pool_connections = 10 context.pool_maxsize = 20 @@ -48,16 +49,19 @@ def http_client(self, client_context): """Create UnifiedHttpClient instance.""" return UnifiedHttpClient(client_context) - @pytest.mark.parametrize("status_code,path", [ - (429, "reason.response"), - (503, "reason.response"), - (500, "direct_response"), - ]) + @pytest.mark.parametrize( + "status_code,path", + [ + (429, "reason.response"), + (503, "reason.response"), + (500, "direct_response"), + ], + ) def test_max_retry_error_with_status_codes(self, http_client, status_code, path): """Test MaxRetryError with various status codes and response paths.""" mock_pool = Mock() max_retry_error = MaxRetryError(pool=mock_pool, url="http://test.com") - + if path == "reason.response": max_retry_error.reason = Mock() max_retry_error.reason.response = Mock() @@ -79,12 +83,21 @@ def test_max_retry_error_with_status_codes(self, http_client, status_code, path) assert "http-code" in error.context assert error.context["http-code"] == status_code - @pytest.mark.parametrize("setup_func", [ - lambda e: None, # No setup - error with no attributes - lambda e: setattr(e, "reason", None), # reason=None - lambda e: (setattr(e, "reason", Mock()), setattr(e.reason, "response", None)), # reason.response=None - lambda e: (setattr(e, "reason", Mock()), setattr(e.reason, "response", Mock(spec=[]))), # No status attr - ]) + @pytest.mark.parametrize( + "setup_func", + [ + lambda e: None, # No setup - error with no attributes + lambda e: setattr(e, "reason", None), # reason=None + lambda e: ( + setattr(e, "reason", Mock()), + setattr(e.reason, "response", None), + ), # reason.response=None + lambda e: ( + setattr(e, "reason", Mock()), + setattr(e.reason, "response", Mock(spec=[])), + ), # No status attr + ], + ) def test_max_retry_error_missing_status(self, http_client, setup_func): """Test MaxRetryError without status code (no crash, empty context).""" mock_pool = Mock() @@ -104,12 +117,12 @@ def test_max_retry_error_prefers_reason_response(self, http_client): """Test that e.reason.response.status is preferred over e.response.status.""" mock_pool = Mock() max_retry_error = MaxRetryError(pool=mock_pool, url="http://test.com") - + # Set both structures with different status codes max_retry_error.reason = Mock() max_retry_error.reason.response = Mock() max_retry_error.reason.response.status = 429 # Should use this - + max_retry_error.response = Mock() max_retry_error.response.status = 500 # Should be ignored From 0d583fa6d5553aecf8e662b8725da361b31c5f70 Mon Sep 17 00:00:00 2001 From: Shubham Dhal Date: Wed, 18 Mar 2026 15:25:57 +0530 Subject: [PATCH 2/4] Remove unnecessary _retry_server_directed_only instance variables Inline kwargs.get() at the single point of use in ThriftDatabricksClient and SeaHttpClient instead of storing as dead instance state. Signed-off-by: Shubham Dhal --- src/databricks/sql/backend/sea/utils/http_client.py | 5 +---- src/databricks/sql/backend/thrift_backend.py | 5 +---- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index 7f10549ec..a3e00ec67 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -92,9 +92,6 @@ def __init__( ) self._retry_delay_default = kwargs.get("_retry_delay_default", 5.0) self.force_dangerous_codes = kwargs.get("_retry_dangerous_codes", []) - self._retry_server_directed_only = kwargs.get( - "_retry_server_directed_only", False - ) # Connection pooling settings self.max_connections = kwargs.get("max_connections", 10) @@ -119,7 +116,7 @@ def __init__( stop_after_attempts_duration=self._retry_stop_after_attempts_duration, delay_default=self._retry_delay_default, force_dangerous_codes=self.force_dangerous_codes, - server_directed_only=self._retry_server_directed_only, + server_directed_only=kwargs.get("_retry_server_directed_only", False), urllib3_kwargs=urllib3_kwargs, ) else: diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index bf3047e26..eef90a2dc 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -191,9 +191,6 @@ def __init__( " This behaviour is deprecated and will be removed in a future release." ) self.force_dangerous_codes = kwargs.get("_retry_dangerous_codes", []) - self._retry_server_directed_only = kwargs.get( - "_retry_server_directed_only", False - ) additional_transport_args = {} @@ -220,7 +217,7 @@ def __init__( stop_after_attempts_duration=self._retry_stop_after_attempts_duration, delay_default=self._retry_delay_default, force_dangerous_codes=self.force_dangerous_codes, - server_directed_only=self._retry_server_directed_only, + server_directed_only=kwargs.get("_retry_server_directed_only", False), urllib3_kwargs=urllib3_kwargs, ) From 2b48f7db845ee962885b4cef6398c82444bb9d6e Mon Sep 17 00:00:00 2001 From: Shubham Dhal Date: Wed, 18 Mar 2026 16:49:52 +0530 Subject: [PATCH 3/4] Address PR feedback: rename and clean up retry-after parameter - Rename server_directed_only to respect_server_retry_after_header throughout for clarity - Store _respect_server_retry_after_header as instance variable in Thrift/SEA backends to match existing kwargs extraction pattern - Replace duplicate test fixture with _make_retry_policy(**overrides) helper for flexible policy construction in tests Signed-off-by: Shubham Dhal --- src/databricks/sql/auth/common.py | 4 +- src/databricks/sql/auth/retry.py | 16 +-- .../sql/backend/sea/utils/http_client.py | 5 +- src/databricks/sql/backend/thrift_backend.py | 5 +- .../sql/common/unified_http_client.py | 2 +- src/databricks/sql/utils.py | 2 +- tests/unit/test_retry.py | 123 +++++++----------- tests/unit/test_unified_http_client.py | 2 +- 8 files changed, 70 insertions(+), 89 deletions(-) diff --git a/src/databricks/sql/auth/common.py b/src/databricks/sql/auth/common.py index 2ae317390..cda30a528 100644 --- a/src/databricks/sql/auth/common.py +++ b/src/databricks/sql/auth/common.py @@ -47,7 +47,7 @@ def __init__( retry_stop_after_attempts_duration: Optional[float] = None, retry_delay_default: Optional[float] = None, retry_dangerous_codes: Optional[List[int]] = None, - retry_server_directed_only: Optional[bool] = None, + respect_server_retry_after_header: Optional[bool] = None, proxy_auth_method: Optional[str] = None, pool_connections: Optional[int] = None, pool_maxsize: Optional[int] = None, @@ -81,7 +81,7 @@ def __init__( ) self.retry_delay_default = retry_delay_default or 5.0 self.retry_dangerous_codes = retry_dangerous_codes or [] - self.retry_server_directed_only = bool(retry_server_directed_only) + self.respect_server_retry_after_header = bool(respect_server_retry_after_header) self.proxy_auth_method = proxy_auth_method self.pool_connections = pool_connections or 10 self.pool_maxsize = pool_maxsize or 20 diff --git a/src/databricks/sql/auth/retry.py b/src/databricks/sql/auth/retry.py index 72a2e87a3..695e7d35c 100755 --- a/src/databricks/sql/auth/retry.py +++ b/src/databricks/sql/auth/retry.py @@ -94,7 +94,7 @@ def __init__( stop_after_attempts_duration: float, delay_default: float, force_dangerous_codes: List[int], - server_directed_only: bool = False, + respect_server_retry_after_header: bool = False, urllib3_kwargs: dict = {}, ): # These values do not change from one command to the next @@ -104,7 +104,7 @@ def __init__( self.stop_after_attempts_duration = stop_after_attempts_duration self._delay_default = delay_default self.force_dangerous_codes = force_dangerous_codes - self.server_directed_only = server_directed_only + self.respect_server_retry_after_header = respect_server_retry_after_header # the urllib3 kwargs are a mix of configuration (some of which we override) # and counters like `total` or `connect` which may change between successive retries @@ -204,7 +204,7 @@ def new( stop_after_attempts_duration=self.stop_after_attempts_duration, delay_default=self.delay_default, force_dangerous_codes=self.force_dangerous_codes, - server_directed_only=self.server_directed_only, + respect_server_retry_after_header=self.respect_server_retry_after_header, urllib3_kwargs={}, ) @@ -393,11 +393,11 @@ def should_retry( if not self._is_method_retryable(method): return False, "Only POST requests are retried" - # In server_directed_only mode, only retry when the server explicitly signals - # it's safe via a Retry-After header. This prevents duplicate side effects for - # non-idempotent operations. - if self.server_directed_only and not has_retry_after: - return (False, "server_directed_only mode: no Retry-After header present") + # When respect_server_retry_after_header is enabled, only retry when the + # server explicitly signals it's safe via a Retry-After header. This prevents + # duplicate side effects for non-idempotent operations. + if self.respect_server_retry_after_header and not has_retry_after: + return (False, "respect_server_retry_after_header mode: no Retry-After header present") # Request failed, was an ExecuteStatement and the command may have reached the server if ( diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index a3e00ec67..99fe5edb5 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -92,6 +92,9 @@ def __init__( ) self._retry_delay_default = kwargs.get("_retry_delay_default", 5.0) self.force_dangerous_codes = kwargs.get("_retry_dangerous_codes", []) + self._respect_server_retry_after_header = kwargs.get( + "_respect_server_retry_after_header", False + ) # Connection pooling settings self.max_connections = kwargs.get("max_connections", 10) @@ -116,7 +119,7 @@ def __init__( stop_after_attempts_duration=self._retry_stop_after_attempts_duration, delay_default=self._retry_delay_default, force_dangerous_codes=self.force_dangerous_codes, - server_directed_only=kwargs.get("_retry_server_directed_only", False), + respect_server_retry_after_header=self._respect_server_retry_after_header, urllib3_kwargs=urllib3_kwargs, ) else: diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index eef90a2dc..776a7b5c8 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -191,6 +191,9 @@ def __init__( " This behaviour is deprecated and will be removed in a future release." ) self.force_dangerous_codes = kwargs.get("_retry_dangerous_codes", []) + self._respect_server_retry_after_header = kwargs.get( + "_respect_server_retry_after_header", False + ) additional_transport_args = {} @@ -217,7 +220,7 @@ def __init__( stop_after_attempts_duration=self._retry_stop_after_attempts_duration, delay_default=self._retry_delay_default, force_dangerous_codes=self.force_dangerous_codes, - server_directed_only=kwargs.get("_retry_server_directed_only", False), + respect_server_retry_after_header=self._respect_server_retry_after_header, urllib3_kwargs=urllib3_kwargs, ) diff --git a/src/databricks/sql/common/unified_http_client.py b/src/databricks/sql/common/unified_http_client.py index 474fe00e7..67c178e19 100644 --- a/src/databricks/sql/common/unified_http_client.py +++ b/src/databricks/sql/common/unified_http_client.py @@ -135,7 +135,7 @@ def _setup_pool_managers(self): stop_after_attempts_duration=self.config.retry_stop_after_attempts_duration, delay_default=self.config.retry_delay_default, force_dangerous_codes=self.config.retry_dangerous_codes, - server_directed_only=self.config.retry_server_directed_only, + respect_server_retry_after_header=self.config.respect_server_retry_after_header, ) # Initialize the required attributes that DatabricksRetryPolicy expects diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 89ce513c7..1c431e5eb 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -977,7 +977,7 @@ def build_client_context(server_hostname: str, version: str, **kwargs): ), retry_delay_default=kwargs.get("_retry_delay_default"), retry_dangerous_codes=kwargs.get("_retry_dangerous_codes"), - retry_server_directed_only=kwargs.get("_retry_server_directed_only"), + respect_server_retry_after_header=kwargs.get("_respect_server_retry_after_header"), proxy_auth_method=kwargs.get("_proxy_auth_method"), pool_connections=kwargs.get("_pool_connections"), pool_maxsize=kwargs.get("_pool_maxsize"), diff --git a/tests/unit/test_retry.py b/tests/unit/test_retry.py index 180907605..7abab509d 100644 --- a/tests/unit/test_retry.py +++ b/tests/unit/test_retry.py @@ -7,9 +7,8 @@ class TestRetry: - @pytest.fixture() - def retry_policy(self) -> DatabricksRetryPolicy: - return DatabricksRetryPolicy( + def _make_retry_policy(self, **overrides) -> DatabricksRetryPolicy: + defaults = dict( delay_min=1, delay_max=30, stop_after_attempts_count=3, @@ -17,6 +16,12 @@ def retry_policy(self) -> DatabricksRetryPolicy: delay_default=2, force_dangerous_codes=[], ) + defaults.update(overrides) + return DatabricksRetryPolicy(**defaults) + + @pytest.fixture() + def retry_policy(self) -> DatabricksRetryPolicy: + return self._make_retry_policy() @pytest.fixture() def error_history(self) -> RequestHistory: @@ -84,84 +89,56 @@ def test_excessive_retry_attempts_error(self, t_mock, retry_policy): # Internally urllib3 calls the increment function generating a new instance for every retry retry_policy = retry_policy.increment() - @pytest.fixture() - def server_directed_retry_policy(self) -> DatabricksRetryPolicy: - return DatabricksRetryPolicy( - delay_min=1, - delay_max=30, - stop_after_attempts_count=3, - stop_after_attempts_duration=900, - delay_default=2, - force_dangerous_codes=[], - server_directed_only=True, - ) - - def test_server_directed_only__retries_with_retry_after( - self, server_directed_retry_policy - ): + def test_respect_server_retry_after__retries_with_retry_after(self): """429 + Retry-After header → should retry""" - server_directed_retry_policy._retry_start_time = time.time() - server_directed_retry_policy.command_type = CommandType.OTHER - should_retry, msg = server_directed_retry_policy.should_retry( - "POST", 429, has_retry_after=True - ) + policy = self._make_retry_policy(respect_server_retry_after_header=True) + policy._retry_start_time = time.time() + policy.command_type = CommandType.OTHER + should_retry, msg = policy.should_retry("POST", 429, has_retry_after=True) assert should_retry is True - def test_server_directed_only__no_retry_without_retry_after( - self, server_directed_retry_policy - ): + def test_respect_server_retry_after__no_retry_without_retry_after(self): """429 without Retry-After header → no retry""" - server_directed_retry_policy._retry_start_time = time.time() - server_directed_retry_policy.command_type = CommandType.OTHER - should_retry, msg = server_directed_retry_policy.should_retry( - "POST", 429, has_retry_after=False - ) + policy = self._make_retry_policy(respect_server_retry_after_header=True) + policy._retry_start_time = time.time() + policy.command_type = CommandType.OTHER + should_retry, msg = policy.should_retry("POST", 429, has_retry_after=False) assert should_retry is False - assert "server_directed_only" in msg + assert "respect_server_retry_after_header" in msg - def test_server_directed_only__no_retry_503_without_header( - self, server_directed_retry_policy - ): + def test_respect_server_retry_after__no_retry_503_without_header(self): """503 without Retry-After header → no retry""" - server_directed_retry_policy._retry_start_time = time.time() - server_directed_retry_policy.command_type = CommandType.OTHER - should_retry, msg = server_directed_retry_policy.should_retry( - "POST", 503, has_retry_after=False - ) + policy = self._make_retry_policy(respect_server_retry_after_header=True) + policy._retry_start_time = time.time() + policy.command_type = CommandType.OTHER + should_retry, msg = policy.should_retry("POST", 503, has_retry_after=False) assert should_retry is False - assert "server_directed_only" in msg + assert "respect_server_retry_after_header" in msg - def test_server_directed_only__overrides_dangerous_codes(self): - """force_dangerous_codes=[500] + no Retry-After → no retry in server_directed_only mode""" - policy = DatabricksRetryPolicy( - delay_min=1, - delay_max=30, - stop_after_attempts_count=3, - stop_after_attempts_duration=900, - delay_default=2, - force_dangerous_codes=[500], - server_directed_only=True, + def test_respect_server_retry_after__overrides_dangerous_codes(self): + """force_dangerous_codes=[500] + no Retry-After → no retry in respect_server_retry_after_header mode""" + policy = self._make_retry_policy( + force_dangerous_codes=[500], respect_server_retry_after_header=True ) policy._retry_start_time = time.time() policy.command_type = CommandType.EXECUTE_STATEMENT should_retry, msg = policy.should_retry("POST", 500, has_retry_after=False) assert should_retry is False - assert "server_directed_only" in msg + assert "respect_server_retry_after_header" in msg - def test_server_directed_only__non_retryable_codes_unaffected( - self, server_directed_retry_policy - ): + def test_respect_server_retry_after__non_retryable_codes_unaffected(self): """401/403/501 still don't retry even with Retry-After header""" - server_directed_retry_policy._retry_start_time = time.time() - server_directed_retry_policy.command_type = CommandType.OTHER + policy = self._make_retry_policy(respect_server_retry_after_header=True) + policy._retry_start_time = time.time() + policy.command_type = CommandType.OTHER for code in [401, 403, 501]: - should_retry, msg = server_directed_retry_policy.should_retry( + should_retry, msg = policy.should_retry( "POST", code, has_retry_after=True ) assert should_retry is False, f"Code {code} should never retry" def test_default_mode_unchanged(self, retry_policy): - """server_directed_only=False preserves existing behavior — 429 retries without header""" + """respect_server_retry_after_header=False preserves existing behavior — 429 retries without header""" retry_policy._retry_start_time = time.time() retry_policy.command_type = CommandType.OTHER should_retry, msg = retry_policy.should_retry( @@ -169,27 +146,25 @@ def test_default_mode_unchanged(self, retry_policy): ) assert should_retry is True - def test_server_directed_only__survives_new(self, server_directed_retry_policy): + def test_respect_server_retry_after__survives_new(self): """urllib3 calls .new() between retries to create a fresh policy instance. - Verify that server_directed_only is carried over and still enforced.""" - server_directed_retry_policy._retry_start_time = time.time() - server_directed_retry_policy.command_type = CommandType.OTHER - new_policy = server_directed_retry_policy.new() - assert new_policy.server_directed_only is True + Verify that respect_server_retry_after_header is carried over and still enforced.""" + policy = self._make_retry_policy(respect_server_retry_after_header=True) + policy._retry_start_time = time.time() + policy.command_type = CommandType.OTHER + new_policy = policy.new() + assert new_policy.respect_server_retry_after_header is True # The new instance should still block retries without Retry-After should_retry, msg = new_policy.should_retry("POST", 429, has_retry_after=False) assert should_retry is False - assert "server_directed_only" in msg + assert "respect_server_retry_after_header" in msg - def test_server_directed_only__execute_statement_with_retry_after( - self, server_directed_retry_policy - ): + def test_respect_server_retry_after__execute_statement_with_retry_after(self): """EXECUTE_STATEMENT + 429 + Retry-After header → retry""" - server_directed_retry_policy._retry_start_time = time.time() - server_directed_retry_policy.command_type = CommandType.EXECUTE_STATEMENT - should_retry, msg = server_directed_retry_policy.should_retry( - "POST", 429, has_retry_after=True - ) + policy = self._make_retry_policy(respect_server_retry_after_header=True) + policy._retry_start_time = time.time() + policy.command_type = CommandType.EXECUTE_STATEMENT + should_retry, msg = policy.should_retry("POST", 429, has_retry_after=True) assert should_retry is True def test_404_does_not_retry_for_any_command_type(self, retry_policy): diff --git a/tests/unit/test_unified_http_client.py b/tests/unit/test_unified_http_client.py index eae732431..44d05178d 100644 --- a/tests/unit/test_unified_http_client.py +++ b/tests/unit/test_unified_http_client.py @@ -37,7 +37,7 @@ def client_context(self): context.retry_stop_after_attempts_duration = 300.0 context.retry_delay_default = 5.0 context.retry_dangerous_codes = [] - context.retry_server_directed_only = False + context.respect_server_retry_after_header = False context.proxy_auth_method = None context.pool_connections = 10 context.pool_maxsize = 20 From 0795d35f242fc48705552383803effb5f1019105 Mon Sep 17 00:00:00 2001 From: Shubham Dhal Date: Wed, 18 Mar 2026 17:47:33 +0530 Subject: [PATCH 4/4] Fix Black formatting for retry.py and utils.py Signed-off-by: Shubham Dhal --- src/databricks/sql/auth/retry.py | 5 ++++- src/databricks/sql/utils.py | 4 +++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/auth/retry.py b/src/databricks/sql/auth/retry.py index 695e7d35c..9cb29fdce 100755 --- a/src/databricks/sql/auth/retry.py +++ b/src/databricks/sql/auth/retry.py @@ -397,7 +397,10 @@ def should_retry( # server explicitly signals it's safe via a Retry-After header. This prevents # duplicate side effects for non-idempotent operations. if self.respect_server_retry_after_header and not has_retry_after: - return (False, "respect_server_retry_after_header mode: no Retry-After header present") + return ( + False, + "respect_server_retry_after_header mode: no Retry-After header present", + ) # Request failed, was an ExecuteStatement and the command may have reached the server if ( diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 1c431e5eb..7d7c96807 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -977,7 +977,9 @@ def build_client_context(server_hostname: str, version: str, **kwargs): ), retry_delay_default=kwargs.get("_retry_delay_default"), retry_dangerous_codes=kwargs.get("_retry_dangerous_codes"), - respect_server_retry_after_header=kwargs.get("_respect_server_retry_after_header"), + respect_server_retry_after_header=kwargs.get( + "_respect_server_retry_after_header" + ), proxy_auth_method=kwargs.get("_proxy_auth_method"), pool_connections=kwargs.get("_pool_connections"), pool_maxsize=kwargs.get("_pool_maxsize"),