diff --git a/README.md b/README.md index 92811ffa..dbe654cc 100644 --- a/README.md +++ b/README.md @@ -193,7 +193,7 @@ The following example replaces the default `requests` backend with `httpx` and r import httpx from sinch import SinchClient from sinch.core.ports.http_transport import HTTPTransport -from sinch.core.endpoint import HTTPEndpoint +from sinch.core.models.http_request import HttpRequest from sinch.core.models.http_response import HTTPResponse @@ -204,9 +204,7 @@ class MyHTTPImplementation(HTTPTransport): proxy=f"http://{proxy_user}:{proxy_password}@{proxy_url}" ) - def send(self, endpoint: HTTPEndpoint) -> HTTPResponse: - request_data = self.prepare_request(endpoint) - request_data = self.authenticate(endpoint, request_data) + def send(self, request_data: HttpRequest) -> HTTPResponse: body = request_data.request_body response = self.http_client.request( diff --git a/sinch/core/adapters/requests_http_transport.py b/sinch/core/adapters/requests_http_transport.py index 6fc62a77..ef233641 100644 --- a/sinch/core/adapters/requests_http_transport.py +++ b/sinch/core/adapters/requests_http_transport.py @@ -1,21 +1,28 @@ import requests from sinch.core.ports.http_transport import HTTPTransport, HttpRequest -from sinch.core.endpoint import HTTPEndpoint from sinch.core.models.http_response import HTTPResponse - class HTTPTransportRequests(HTTPTransport): + """ + Sync HTTP transport using the requests library. + """ + def __init__(self, sinch): super().__init__(sinch) self.http_session = requests.Session() - def send(self, endpoint: HTTPEndpoint) -> HTTPResponse: - request_data: HttpRequest = self.prepare_request(endpoint) - request_data: HttpRequest = self.authenticate(endpoint, request_data) + def send(self, request_data: HttpRequest) -> HTTPResponse: + """ + Performs the HTTP call with requests and maps the result to an HTTPResponse. + :param request_data: The prepared request to send. + :type request_data: HttpRequest + :returns: The HTTP response. + :rtype: HTTPResponse + """ self.sinch.configuration.logger.debug( - f"Sync HTTP {request_data.http_method} call with headers:" - f" {request_data.headers}, body: {request_data.request_body} and query_params: {request_data.query_params} to URL: {request_data.url}" + f"Sync HTTP request {request_data.http_method} call with headers:" + f" {request_data.headers} and body: {request_data.request_body} to URL: {request_data.url}" ) response = self.http_session.request( method=request_data.http_method, @@ -30,8 +37,8 @@ def send(self, endpoint: HTTPEndpoint) -> HTTPResponse: response_body = self.deserialize_json_response(response) self.sinch.configuration.logger.debug( - f"Sync HTTP {response.status_code} response with headers: {response.headers}" - f"and body: {response_body} from URL: {request_data.url}" + f"Sync HTTP response {response.status_code} with headers: {response.headers}" + f" and body: {response_body} from URL: {request_data.url}" ) return HTTPResponse( diff --git a/sinch/core/ports/http_transport.py b/sinch/core/ports/http_transport.py index ec0edafc..01e6cb87 100644 --- a/sinch/core/ports/http_transport.py +++ b/sinch/core/ports/http_transport.py @@ -1,70 +1,74 @@ from abc import ABC, abstractmethod from platform import python_version +from typing import Optional + +from requests import Response from sinch.core.endpoint import HTTPEndpoint from sinch.core.models.http_request import HttpRequest from sinch.core.models.http_response import HTTPResponse from sinch.core.exceptions import ValidationException, SinchException from sinch.core.enums import HTTPAuthentication -from sinch.core.token_manager import TokenState from sinch import __version__ as sdk_version class HTTPTransport(ABC): - """Base class for HTTP transports. + """ + Base class for HTTP transports. - Subclasses implement ``send`` to perform the raw HTTP call. - The public ``request`` method adds cross-cutting concerns on top: - authentication, logging hooks, and automatic token refresh on 401. + Subclasses implement :meth:`send` to perform the raw HTTP call. The public + :meth:`request` method adds cross-cutting concerns on top: request + preparation, authentication, and automatic token refresh on 401. """ def __init__(self, sinch): self.sinch = sinch - # ------------------------------------------------------------------ - # Subclass contract - # ------------------------------------------------------------------ @abstractmethod - def send(self, endpoint: HTTPEndpoint) -> HTTPResponse: - """Execute a single HTTP round-trip and return the response. - - Implementations must prepare the request, authenticate, perform the - HTTP call, deserialize the response, and return an ``HTTPResponse``. - They should **not** handle token refresh — that is done by - ``request``. + def send(self, request_data: HttpRequest) -> HTTPResponse: """ + Performs a single HTTP round-trip for an already-prepared, authenticated request. - # ------------------------------------------------------------------ - # Public API - # ------------------------------------------------------------------ + :param request_data: The prepared request to send. + :type request_data: HttpRequest + :returns: The HTTP response. + :rtype: HTTPResponse + """ def request(self, endpoint: HTTPEndpoint) -> HTTPResponse: - """Send a request with automatic OAuth token refresh on 401. + """ + Sends a request, renewing the token and retrying once on an expired-token 401. - If the server responds with 401 *and* the token is detected as - expired, the token is invalidated and **one** retry is attempted - with a fresh token. A second consecutive 401 is handed straight - to the endpoint's error handler — no further retries. + :param endpoint: The endpoint to call. + :type endpoint: HTTPEndpoint + :returns: The handled HTTP response. + :rtype: HTTPResponse """ - http_response = self.send(endpoint) + request_data = self.prepare_request(endpoint) + request_data = self.authenticate(endpoint, request_data) + http_response = self.send(request_data) if self._should_refresh_token(endpoint, http_response): - self.sinch.configuration.token_manager.handle_invalid_token( - http_response - ) - if ( - self.sinch.configuration.token_manager.token_state - == TokenState.EXPIRED - ): - http_response = self.send(endpoint) + used_token = self._get_bearer_token_from_request(request_data) + new_token = self.sinch.configuration.token_manager.refresh_auth_token(used_token) + self._set_bearer_token(request_data, new_token.access_token) + http_response = self.send(request_data) return endpoint.handle_response(http_response) - # ------------------------------------------------------------------ - # Internals - # ------------------------------------------------------------------ - def authenticate(self, endpoint, request_data): + def authenticate(self, endpoint: HTTPEndpoint, request_data: HttpRequest) -> HttpRequest: + """ + Stamps the credentials required by the endpoint's auth scheme onto the request. + + :param endpoint: The endpoint being called, whose HTTP_AUTHENTICATION selects the scheme. + :type endpoint: HTTPEndpoint + :param request_data: The request to authenticate, mutated in place. + :type request_data: HttpRequest + :returns: The same request, with auth applied. + :rtype: HttpRequest + :raises ValidationException: If the credentials required by the scheme are missing. + """ if endpoint.HTTP_AUTHENTICATION in (HTTPAuthentication.BASIC.value, HTTPAuthentication.OAUTH.value): if ( not self.sinch.configuration.key_id @@ -87,10 +91,7 @@ def authenticate(self, endpoint, request_data): if endpoint.HTTP_AUTHENTICATION == HTTPAuthentication.OAUTH.value: token = self.sinch.authentication.get_auth_token().access_token - request_data.headers.update({ - "Authorization": f"Bearer {token}", - "Content-Type": "application/json" - }) + self._set_bearer_token(request_data, token) elif endpoint.HTTP_AUTHENTICATION == HTTPAuthentication.SMS_TOKEN.value: if not self.sinch.configuration.sms_api_token or not self.sinch.configuration.service_plan_id: raise ValidationException( @@ -101,14 +102,19 @@ def authenticate(self, endpoint, request_data): is_from_server=False, response=None ) - request_data.headers.update({ - "Authorization": f"Bearer {self.sinch.configuration.sms_api_token}", - "Content-Type": "application/json" - }) + self._set_bearer_token(request_data, self.sinch.configuration.sms_api_token) return request_data def prepare_request(self, endpoint: HTTPEndpoint) -> HttpRequest: + """ + Builds the HttpRequest for an endpoint. + + :param endpoint: The endpoint to build the request for. + :type endpoint: HTTPEndpoint + :returns: The prepared request. + :rtype: HttpRequest + """ url_query_params = endpoint.build_query_params() return HttpRequest( @@ -124,7 +130,16 @@ def prepare_request(self, endpoint: HTTPEndpoint) -> HttpRequest: ) @staticmethod - def deserialize_json_response(response): + def deserialize_json_response(response: Response) -> dict: + """ + Parses the JSON body of a response. + + :param response: The raw HTTP response. + :type response: Response + :returns: The parsed body. + :rtype: dict + :raises SinchException: If the body is present but not valid JSON. + """ if response.content: try: response_body = response.json() @@ -140,10 +155,48 @@ def deserialize_json_response(response): return response_body @staticmethod - def _should_refresh_token(endpoint, http_response): - """Return True when a 401 response should trigger a token refresh.""" - return ( - http_response.status_code == 401 - and endpoint.HTTP_AUTHENTICATION - == HTTPAuthentication.OAUTH.value - ) + def _should_refresh_token(endpoint: HTTPEndpoint, http_response: HTTPResponse) -> bool: + """ + Returns True for an OAuth endpoint that got a 401 with an expired-token header. + + :param endpoint: The endpoint that was called. + :type endpoint: HTTPEndpoint + :param http_response: The response received. + :type http_response: HTTPResponse + :returns: Whether the token should be refreshed and the request retried. + :rtype: bool + """ + if endpoint.HTTP_AUTHENTICATION != HTTPAuthentication.OAUTH.value: + return False + if http_response.status_code != 401: + return False + www_authenticate = http_response.headers.get("www-authenticate") or "" + return "expired" in www_authenticate + + @staticmethod + def _get_bearer_token_from_request(request_data: HttpRequest) -> Optional[str]: + """ + Extracts the bearer token from the request's Authorization header. + + :param request_data: The request. + :type request_data: HttpRequest + :returns: The bearer token, or None if absent or not a bearer. + :rtype: Optional[str] + """ + auth = request_data.headers.get("Authorization", "") + return auth.removeprefix("Bearer ") if auth.startswith("Bearer ") else None + + @staticmethod + def _set_bearer_token(request_data: HttpRequest, token: str) -> None: + """ + Stamps the bearer token onto the request's Authorization header. + + :param request_data: The request. + :type request_data: HttpRequest + :param token: The bearer token. + :type token: str + """ + request_data.headers.update({ + "Authorization": f"Bearer {token}", + "Content-Type": "application/json" + }) diff --git a/sinch/core/token_manager.py b/sinch/core/token_manager.py index b66bba85..b898de2d 100644 --- a/sinch/core/token_manager.py +++ b/sinch/core/token_manager.py @@ -1,35 +1,57 @@ from enum import Enum from abc import ABC, abstractmethod +import threading from sinch.domains.authentication.models.v1.authentication import OAuthToken from sinch.domains.authentication.endpoints.v1.oauth import OAuthEndpoint from sinch.core.exceptions import ValidationException class TokenState(Enum): + """ + Lifecycle state of the cached OAuth token. + """ + VALID = "VALID" + """ + A usable token is currently cached. + """ INVALID = "INVALID" - EXPIRED = "EXPIRED" + """ + No token has been obtained yet. + """ class TokenManagerBase(ABC): + """ + Base class for OAuth token managers. + + Holds the cached access token together with the lock that guards every + token mutation. + """ + def __init__(self, sinch): self.sinch = sinch - self.token = None - self.token_state = TokenState.INVALID + self.token: OAuthToken | None = None + self.token_state: TokenState = TokenState.INVALID + self._lock: threading.Lock = threading.Lock() @abstractmethod def get_auth_token(self) -> OAuthToken: pass - def invalidate_expired_token(self): - self.token = None - self.token_state = TokenState.EXPIRED + @abstractmethod + def refresh_auth_token(self, used_token: str) -> OAuthToken: + pass + - def handle_invalid_token(self, http_response): - if http_response.headers.get("www-authenticate") and "expired" in http_response.headers["www-authenticate"]: - self.invalidate_expired_token() + def set_auth_token(self, token: dict): + """ + Sets the OAuth token and marks the token_state as VALID. - def set_auth_token(self, token) -> None: + :param token: The token fields. + :type token: dict + :raises ValidationException: If the fields do not match the OAuthToken structure. + """ try: self.token = OAuthToken(**token) self.token_state = TokenState.VALID @@ -39,13 +61,49 @@ def set_auth_token(self, token) -> None: is_from_server=False, response=None ) + + def _fetch_new_token(self) -> OAuthToken: + """ + Requests a new token from the OAuth endpoint and stores it as the current token. + + :returns: The freshly fetched token. + :rtype: OAuthToken + """ + self.token = self.sinch.configuration.transport.request(OAuthEndpoint()) + self.token_state = TokenState.VALID + return self.token class TokenManager(TokenManagerBase): + """ + Thread-safe synchronous OAuth token manager. + """ + def get_auth_token(self) -> OAuthToken: - if self.token: + """ + Returns the stored token, fetching one on first use. Uses double-checked locking + + :returns: A valid OAuth token. + :rtype: OAuthToken + """ + if self.token is not None: return self.token + + with self._lock: + if self.token is not None: + return self.token + return self._fetch_new_token() + + def refresh_auth_token(self, used_token: str) -> OAuthToken: + """ + Renews the token after an expired-token 401, deduping concurrent renewals. - self.token = self.sinch.configuration.transport.request(OAuthEndpoint()) - self.token_state = TokenState.VALID - return self.token + :param used_token: The access token used by the request that received the 401. + :type used_token: str + :returns: A valid token. + :rtype: OAuthToken + """ + with self._lock: + if self.token is not None and self.token.access_token != used_token: + return self.token + return self._fetch_new_token() diff --git a/tests/unit/http_transport_tests.py b/tests/unit/http_transport_tests.py deleted file mode 100644 index c567dcb9..00000000 --- a/tests/unit/http_transport_tests.py +++ /dev/null @@ -1,253 +0,0 @@ -import pytest -from unittest.mock import Mock, call -from sinch.core.enums import HTTPAuthentication -from sinch.core.exceptions import ValidationException -from sinch.core.models.http_request import HttpRequest -from sinch.core.endpoint import HTTPEndpoint -from sinch.core.models.http_response import HTTPResponse -from sinch.core.ports.http_transport import HTTPTransport -from sinch.core.token_manager import TokenState - - -# Mock classes and fixtures -def _make_mock_endpoint(auth_type, error_on_4xx=False): - """Create a MockEndpoint that satisfies the abstract property contract.""" - - class _Endpoint(HTTPEndpoint): - HTTP_AUTHENTICATION = auth_type - HTTP_METHOD = "GET" - - def __init__(self): - # Skip super().__init__ — we don't need project_id / request_data - pass - - def build_url(self, sinch): - return "api.sinch.com/test" - - def get_url_without_origin(self, sinch): - return "/test" - - def request_body(self): - return {} - - def build_query_params(self): - return {} - - def handle_response(self, response: HTTPResponse): - if error_on_4xx and response.status_code >= 400: - raise ValidationException( - message=f"HTTP {response.status_code}", - is_from_server=True, - response=response, - ) - return response - - return _Endpoint() - - -@pytest.fixture -def mock_sinch(): - sinch = Mock() - sinch.configuration = Mock() - sinch.configuration.key_id = "test_key_id" - sinch.configuration.key_secret = "test_key_secret" - sinch.configuration.project_id = "test_project_id" - sinch.configuration.sms_api_token = "test_sms_token" - sinch.configuration.service_plan_id = "test_service_plan" - return sinch - - -@pytest.fixture -def base_request(): - return HttpRequest( - headers={}, - url="https://api.sinch.com/test", - http_method="GET", - request_body={}, - query_params={}, - auth=() - ) - - -class MockHTTPTransport(HTTPTransport): - """Transport whose send() returns from a pre-configured list of responses.""" - - def __init__(self, sinch, responses=None): - super().__init__(sinch) - self._responses = list(responses or []) - self._call_count = 0 - - def send(self, endpoint: HTTPEndpoint) -> HTTPResponse: - if self._call_count < len(self._responses): - resp = self._responses[self._call_count] - else: - resp = HTTPResponse(status_code=200, body={}, headers={}) - self._call_count += 1 - return resp - - @property - def call_count(self): - return self._call_count - - -# Synchronous Transport Tests -class TestHTTPTransport: - @pytest.mark.parametrize("auth_type", [ - HTTPAuthentication.BASIC.value, - HTTPAuthentication.OAUTH.value, - HTTPAuthentication.SMS_TOKEN.value - ]) - def test_authenticate(self, mock_sinch, base_request, auth_type): - transport = MockHTTPTransport(mock_sinch) - endpoint = _make_mock_endpoint(auth_type) - - if auth_type == HTTPAuthentication.BASIC.value: - result = transport.authenticate(endpoint, base_request) - assert result.auth == ("test_key_id", "test_key_secret") - - elif auth_type == HTTPAuthentication.OAUTH.value: - mock_sinch.authentication.get_auth_token.return_value.access_token = "test_token" - result = transport.authenticate(endpoint, base_request) - assert result.headers["Authorization"] == "Bearer test_token" - assert result.headers["Content-Type"] == "application/json" - - elif auth_type == HTTPAuthentication.SMS_TOKEN.value: - result = transport.authenticate(endpoint, base_request) - assert result.headers["Authorization"] == "Bearer test_sms_token" - assert result.headers["Content-Type"] == "application/json" - - @pytest.mark.parametrize("auth_type,missing_creds", [ - (HTTPAuthentication.BASIC.value, {"key_id": None}), - (HTTPAuthentication.OAUTH.value, {"key_secret": None}), - (HTTPAuthentication.SMS_TOKEN.value, {"sms_api_token": None}) - ]) - def test_authenticate_missing_credentials(self, mock_sinch, base_request, auth_type, missing_creds): - transport = MockHTTPTransport(mock_sinch) - endpoint = _make_mock_endpoint(auth_type) - - for cred, value in missing_creds.items(): - setattr(mock_sinch.configuration, cred, value) - - with pytest.raises(ValidationException): - transport.authenticate(endpoint, base_request) - - -class TestTokenRefreshRetry: - """Tests for the automatic token refresh on 401 expired responses.""" - - @staticmethod - def _expired_401(): - return HTTPResponse( - status_code=401, - body={"error": "token expired"}, - headers={"www-authenticate": "Bearer error=\"expired\""}, - ) - - @staticmethod - def _non_expired_401(): - return HTTPResponse( - status_code=401, - body={"error": "unauthorized"}, - headers={"www-authenticate": "Bearer error=\"invalid_token\""}, - ) - - @staticmethod - def _ok_200(): - return HTTPResponse(status_code=200, body={"ok": True}, headers={}) - - def test_retry_succeeds_after_expired_token(self, mock_sinch): - """A single 401-expired followed by a 200 should retry once and succeed.""" - from sinch.core.token_manager import TokenManager - - token_manager = Mock(spec=TokenManager) - token_manager.token_state = TokenState.VALID - - def mark_expired(http_response): - token_manager.token_state = TokenState.EXPIRED - - token_manager.handle_invalid_token.side_effect = mark_expired - mock_sinch.configuration.token_manager = token_manager - - transport = MockHTTPTransport( - mock_sinch, - responses=[self._expired_401(), self._ok_200()], - ) - endpoint = _make_mock_endpoint(HTTPAuthentication.OAUTH.value) - - result = transport.request(endpoint) - - assert result.status_code == 200 - assert transport.call_count == 2 - token_manager.handle_invalid_token.assert_called_once() - - def test_no_infinite_loop_on_persistent_401(self, mock_sinch): - """Two consecutive 401-expired must NOT cause infinite retries. - - The second 401 should be handed to the endpoint's error handler - and send() should be called at most twice. - """ - from sinch.core.token_manager import TokenManager - - token_manager = Mock(spec=TokenManager) - token_manager.token_state = TokenState.VALID - - def mark_expired(http_response): - token_manager.token_state = TokenState.EXPIRED - - token_manager.handle_invalid_token.side_effect = mark_expired - mock_sinch.configuration.token_manager = token_manager - - transport = MockHTTPTransport( - mock_sinch, - responses=[self._expired_401(), self._expired_401()], - ) - endpoint = _make_mock_endpoint(HTTPAuthentication.OAUTH.value, error_on_4xx=True) - - with pytest.raises(ValidationException, match="401"): - transport.request(endpoint) - - # send() must have been called exactly twice: initial + one retry - assert transport.call_count == 2 - - def test_no_retry_when_401_is_not_expired(self, mock_sinch): - """A 401 without 'expired' in WWW-Authenticate should NOT trigger a retry.""" - from sinch.core.token_manager import TokenManager - - token_manager = Mock(spec=TokenManager) - token_manager.token_state = TokenState.VALID - - # handle_invalid_token inspects the header but does NOT set EXPIRED - # because the header says "invalid_token", not "expired" - token_manager.handle_invalid_token.side_effect = lambda r: None - mock_sinch.configuration.token_manager = token_manager - - transport = MockHTTPTransport( - mock_sinch, - responses=[self._non_expired_401()], - ) - endpoint = _make_mock_endpoint(HTTPAuthentication.OAUTH.value, error_on_4xx=True) - - with pytest.raises(ValidationException, match="401"): - transport.request(endpoint) - - # send() called only once — no retry - assert transport.call_count == 1 - - def test_no_retry_for_non_oauth_endpoint(self, mock_sinch): - """A 401 on a BASIC-auth endpoint should NOT trigger token refresh.""" - from sinch.core.token_manager import TokenManager - - token_manager = Mock(spec=TokenManager) - mock_sinch.configuration.token_manager = token_manager - - transport = MockHTTPTransport( - mock_sinch, - responses=[self._expired_401()], - ) - endpoint = _make_mock_endpoint(HTTPAuthentication.BASIC.value, error_on_4xx=True) - - with pytest.raises(ValidationException, match="401"): - transport.request(endpoint) - - assert transport.call_count == 1 - token_manager.handle_invalid_token.assert_not_called() diff --git a/tests/unit/test_http_transport.py b/tests/unit/test_http_transport.py new file mode 100644 index 00000000..3afdf533 --- /dev/null +++ b/tests/unit/test_http_transport.py @@ -0,0 +1,275 @@ +import json +import pytest +from unittest.mock import Mock +from sinch.core.enums import HTTPAuthentication +from sinch.core.exceptions import ValidationException, SinchException +from sinch.core.models.http_request import HttpRequest +from sinch.core.endpoint import HTTPEndpoint +from sinch.core.models.http_response import HTTPResponse +from sinch.core.adapters.requests_http_transport import HTTPTransportRequests +from sinch.core.token_manager import TokenManager +from sinch.domains.authentication.models.v1.authentication import OAuthToken + + +# Mock classes and fixtures +def _make_mock_endpoint(auth_type, error_on_4xx=False): + """Create a MockEndpoint that satisfies the abstract property contract.""" + + class _Endpoint(HTTPEndpoint): + HTTP_AUTHENTICATION = auth_type + HTTP_METHOD = "GET" + + def __init__(self): + # Skip super().__init__ — we don't need project_id / request_data + pass + + def build_url(self, sinch): + return "api.sinch.com/test" + + def get_url_without_origin(self, sinch): + return "/test" + + def request_body(self): + return {} + + def build_query_params(self): + return {} + + def handle_response(self, response: HTTPResponse): + if error_on_4xx and response.status_code >= 400: + raise ValidationException( + message=f"HTTP {response.status_code}", + is_from_server=True, + response=response, + ) + return response + + return _Endpoint() + + +def _requests_response(status_code, body=None, headers=None): + """Fake of a requests.Response, just enough for deserialize_json_response.""" + resp = Mock() + resp.status_code = status_code + resp.content = json.dumps(body or {}).encode() + resp.json.return_value = body or {} + resp.headers = headers or {} + return resp + + +def _server_rejecting_expired_token(accepted_token): + """Fake http_session.request: 200 only when the request carries `accepted_token`, + otherwise a 401-expired — like a server that rejects the stale token.""" + def respond(*args, **kwargs): + if kwargs["headers"].get("Authorization") == accepted_token: + return _requests_response(200, body={"ok": True}) + return _requests_response(401, headers={"www-authenticate": 'Bearer error="expired"'}) + return respond + + +def _token_manager(mock_sinch, *, old="old", new="new"): + """Mock TokenManager that hands out `old` and renews to `new`.""" + token_manager = Mock(spec=TokenManager) + token_manager.refresh_auth_token.return_value = OAuthToken( + access_token=new, expires_in=3599, scope="", token_type="bearer" + ) + mock_sinch.configuration.token_manager = token_manager + # authenticate() reads the initial token via sinch.authentication.get_auth_token() + mock_sinch.authentication.get_auth_token.return_value.access_token = old + return token_manager + + +@pytest.fixture +def mock_sinch(): + sinch = Mock() + sinch.configuration = Mock() + sinch.configuration.key_id = "test_key_id" + sinch.configuration.key_secret = "test_key_secret" + sinch.configuration.project_id = "test_project_id" + sinch.configuration.sms_api_token = "test_sms_token" + sinch.configuration.service_plan_id = "test_service_plan" + return sinch + + +@pytest.fixture +def base_request(): + return HttpRequest( + headers={}, + url="https://api.sinch.com/test", + http_method="GET", + request_body={}, + query_params={}, + auth=() + ) + + +class TestHTTPTransport: + @pytest.mark.parametrize("auth_type", [ + HTTPAuthentication.BASIC.value, + HTTPAuthentication.OAUTH.value, + HTTPAuthentication.SMS_TOKEN.value + ]) + def test_authenticate(self, mock_sinch, base_request, auth_type): + transport = HTTPTransportRequests(mock_sinch) + endpoint = _make_mock_endpoint(auth_type) + + if auth_type == HTTPAuthentication.BASIC.value: + result = transport.authenticate(endpoint, base_request) + assert result.auth == ("test_key_id", "test_key_secret") + + elif auth_type == HTTPAuthentication.OAUTH.value: + mock_sinch.authentication.get_auth_token.return_value.access_token = "test_token" + result = transport.authenticate(endpoint, base_request) + assert result.headers["Authorization"] == "Bearer test_token" + assert result.headers["Content-Type"] == "application/json" + + elif auth_type == HTTPAuthentication.SMS_TOKEN.value: + result = transport.authenticate(endpoint, base_request) + assert result.headers["Authorization"] == "Bearer test_sms_token" + assert result.headers["Content-Type"] == "application/json" + + @pytest.mark.parametrize("auth_type,missing_creds", [ + (HTTPAuthentication.BASIC.value, {"key_id": None}), + (HTTPAuthentication.OAUTH.value, {"key_secret": None}), + (HTTPAuthentication.SMS_TOKEN.value, {"sms_api_token": None}) + ]) + def test_authenticate_missing_credentials(self, mock_sinch, base_request, auth_type, missing_creds): + transport = HTTPTransportRequests(mock_sinch) + endpoint = _make_mock_endpoint(auth_type) + + for cred, value in missing_creds.items(): + setattr(mock_sinch.configuration, cred, value) + + with pytest.raises(ValidationException): + transport.authenticate(endpoint, base_request) + + +class TestSend: + def test_send_maps_requests_response(self, mock_sinch, base_request): + transport = HTTPTransportRequests(mock_sinch) + transport.http_session.request = Mock( + return_value=_requests_response(200, body={"x": 1}) + ) + + result = transport.send(base_request) + + assert isinstance(result, HTTPResponse) + assert result.status_code == 200 + assert result.body == {"x": 1} + + def test_send_empty_body_returns_empty_dict(self, mock_sinch, base_request): + transport = HTTPTransportRequests(mock_sinch) + transport.http_session.request = Mock( + return_value=Mock(status_code=204, content=b"", headers={}) + ) + + result = transport.send(base_request) + + assert result.status_code == 204 + assert result.body == {} + + def test_send_raises_on_invalid_json(self, mock_sinch, base_request): + transport = HTTPTransportRequests(mock_sinch) + bad_response = Mock(status_code=200, content=b"not json", headers={}) + bad_response.json.side_effect = ValueError("bad json") + transport.http_session.request = Mock(return_value=bad_response) + + with pytest.raises(SinchException): + transport.send(base_request) + + +class TestTokenRefreshRetry: + """Tests for the automatic token refresh on 401-expired responses.""" + + @staticmethod + def _expired_401(): + return _requests_response( + 401, + body={"error": "token expired"}, + headers={"www-authenticate": 'Bearer error="expired"'}, + ) + + @staticmethod + def _non_expired_401(): + return _requests_response( + 401, + body={"error": "unauthorized"}, + headers={"www-authenticate": 'Bearer error="invalid_token"'}, + ) + + def test_retry_succeeds_after_expired_token(self, mock_sinch): + token_manager = _token_manager(mock_sinch) + transport = HTTPTransportRequests(mock_sinch) + # The server accepts only the renewed token, so a 200 proves the retry re-stamped it. + transport.http_session.request = Mock(side_effect=_server_rejecting_expired_token("Bearer new")) + endpoint = _make_mock_endpoint(HTTPAuthentication.OAUTH.value) + + result = transport.request(endpoint) + + assert result.status_code == 200 + assert transport.http_session.request.call_count == 2 + token_manager.refresh_auth_token.assert_called_once_with("old") + + def test_no_retry_when_401_is_not_expired(self, mock_sinch): + token_manager = _token_manager(mock_sinch) + transport = HTTPTransportRequests(mock_sinch) + transport.http_session.request = Mock(side_effect=[self._non_expired_401()]) + endpoint = _make_mock_endpoint(HTTPAuthentication.OAUTH.value, error_on_4xx=True) + + with pytest.raises(ValidationException, match="401"): + transport.request(endpoint) + + assert transport.http_session.request.call_count == 1 + token_manager.refresh_auth_token.assert_not_called() + + def test_no_retry_for_non_oauth_endpoint(self, mock_sinch): + token_manager = _token_manager(mock_sinch) + transport = HTTPTransportRequests(mock_sinch) + transport.http_session.request = Mock(side_effect=[self._expired_401()]) + endpoint = _make_mock_endpoint(HTTPAuthentication.BASIC.value, error_on_4xx=True) + + with pytest.raises(ValidationException, match="401"): + transport.request(endpoint) + + assert transport.http_session.request.call_count == 1 + token_manager.refresh_auth_token.assert_not_called() + + def test_only_one_retry_on_persistent_401(self, mock_sinch): + token_manager = _token_manager(mock_sinch) + transport = HTTPTransportRequests(mock_sinch) + transport.http_session.request = Mock(side_effect=[self._expired_401(), self._expired_401()]) + endpoint = _make_mock_endpoint(HTTPAuthentication.OAUTH.value, error_on_4xx=True) + + with pytest.raises(ValidationException, match="401"): + transport.request(endpoint) + + assert transport.http_session.request.call_count == 2 + token_manager.refresh_auth_token.assert_called_once() + + def test_no_refresh_on_successful_request(self, mock_sinch): + token_manager = _token_manager(mock_sinch) + transport = HTTPTransportRequests(mock_sinch) + transport.http_session.request = Mock( + return_value=_requests_response(200, body={"ok": True}) + ) + endpoint = _make_mock_endpoint(HTTPAuthentication.OAUTH.value) + + result = transport.request(endpoint) + + assert result.status_code == 200 + assert transport.http_session.request.call_count == 1 + token_manager.refresh_auth_token.assert_not_called() + + def test_no_refresh_on_401_without_www_authenticate(self, mock_sinch): + token_manager = _token_manager(mock_sinch) + transport = HTTPTransportRequests(mock_sinch) + transport.http_session.request = Mock( + return_value=_requests_response(401, body={}) + ) + endpoint = _make_mock_endpoint(HTTPAuthentication.OAUTH.value, error_on_4xx=True) + + with pytest.raises(ValidationException, match="401"): + transport.request(endpoint) + + assert transport.http_session.request.call_count == 1 + token_manager.refresh_auth_token.assert_not_called() diff --git a/tests/unit/test_token_manager.py b/tests/unit/test_token_manager.py index 502d5ad1..03c7e82a 100644 --- a/tests/unit/test_token_manager.py +++ b/tests/unit/test_token_manager.py @@ -1,3 +1,5 @@ +import threading +import time import pytest from unittest.mock import Mock @@ -30,3 +32,63 @@ def test_get_auth_token_and_check_if_cached(sinch_client_sync, auth_token): assert isinstance(access_token, OAuthToken) assert token_manager.token is auth_token + + +def test_get_auth_token_fetches_once_under_concurrency(auth_token): + num_threads = 20 + barrier = threading.Barrier(num_threads) + sinch = Mock() + + def slow_fetch(endpoint): + time.sleep(0.05) + return auth_token + + sinch.configuration.transport.request.side_effect = slow_fetch + + token_manager = TokenManager(sinch) + + results = [] + + def worker(): + barrier.wait() + results.append(token_manager.get_auth_token()) + + threads = [threading.Thread(target=worker) for _ in range(num_threads)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + assert sinch.configuration.transport.request.call_count == 1 + assert all(result is auth_token for result in results) + + +def test_refresh_auth_token_renews_once_under_concurrency(auth_token): + num_threads = 20 + barrier = threading.Barrier(num_threads) + sinch = Mock() + + def slow_fetch(endpoint): + time.sleep(0.05) + return auth_token + + sinch.configuration.transport.request.side_effect = slow_fetch + token_manager = TokenManager(sinch) + token_manager.token = OAuthToken( + access_token="old", expires_in=1, scope="", token_type="bearer" + ) + + results = [] + + def worker(): + barrier.wait() + results.append(token_manager.refresh_auth_token("old")) + + threads = [threading.Thread(target=worker) for _ in range(num_threads)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + assert sinch.configuration.transport.request.call_count == 1 + assert all(result is auth_token for result in results)