diff --git a/CHANGELOG.md b/CHANGELOG.md index 517a60bf..4fa44f85 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,11 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Added - New feature: Support for macOS and Linux. - Documentation: Added API documentation in the Wiki. +- Bulk copy now supports `Authentication=ActiveDirectoryServicePrincipal` + via an `entra_id_token_factory` callback registered on the mssql-py-core + connection. The callback is invoked by mssql-tds mid-handshake (FedAuth + workflow 0x02) so the tenant id can be resolved from the server-supplied + STS URL. Requires `mssql-py-core` 0.1.5+. Partial fix for #534. ### Changed - Improved error handling in the connection module. diff --git a/mssql_python/auth.py b/mssql_python/auth.py index 9b488c6d..f564bfb9 100644 --- a/mssql_python/auth.py +++ b/mssql_python/auth.py @@ -4,6 +4,7 @@ This module handles authentication for the mssql_python package. """ +import hashlib import platform import struct import threading @@ -154,6 +155,143 @@ def _acquire_token( raise RuntimeError(f"Failed to create {credential_class.__name__}: {e}") from e +def _parse_tenant_id(sts_url: str) -> Optional[str]: + """Extract tenant ID (GUID or domain) from a FedAuthInfo STS URL. + + Expected formats: + https://login.microsoftonline.com/{tenant}/ + https://login.microsoftonline.com/{tenant}/?... + https://login.microsoftonline.com/{tenant} + where {tenant} is either a GUID (e.g. ``aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee``) + or a verified domain (e.g. ``contoso.onmicrosoft.com``). Both forms are + accepted by ``azure.identity.ClientSecretCredential``. + """ + # pylint: disable=import-outside-toplevel + from urllib.parse import urlparse + + try: + parsed = urlparse(sts_url) + except (ValueError, AttributeError): + return None + # Reject anything that isn't an https URL with a netloc. ``urlparse`` will + # happily put a bare string like ``"tenant-guid"`` into ``path``, which + # would then look like a valid tenant. 
Azure AD STS URLs are always https. + if parsed.scheme != "https" or not parsed.netloc: + return None + path = (parsed.path or "").strip("/") + if not path: + return None + first_segment = path.split("/", 1)[0] + return first_segment or None + + +class ServicePrincipalAuth: + """Builds an ``entra_id_token_factory`` callable for ActiveDirectoryServicePrincipal. + + The bulkcopy path through mssql-py-core uses callback-based token + acquisition (FedAuth workflow ``0x02``) because tenant_id is only known + from the STS URL that the server returns during the TDS handshake. + """ + + @staticmethod + def make_token_factory(client_id: str, client_secret: str): + """Return a callable suitable for ``entra_id_token_factory``. + + Signature: ``(spn: str, sts_url: str, auth_method: str) -> bytes``. + Returns the JWT encoded as UTF-16LE bytes (the TDS FedAuth wire format). + + ``ClientSecretCredential`` instances are reused across calls via the + module-level ``_credential_cache``, keyed by the auth type, tenant_id, + client_id and a SHA-256 hash of client_secret, so that azure-identity's + in-memory token cache (which is per-credential-instance) actually + works across handshake retries, reconnects, and separate bulkcopy + invocations using the same identity. + """ + if not client_id: + raise ValueError("ServicePrincipal auth requires a non-empty client_id (UID)") + if not client_secret: + raise ValueError("ServicePrincipal auth requires a non-empty client_secret (PWD)") + + def _factory(spn: str, sts_url: str, auth_method: str) -> bytes: + # pylint: disable=import-outside-toplevel,unused-argument + try: + from azure.identity import ClientSecretCredential + from azure.core.exceptions import ClientAuthenticationError + except ImportError as e: + raise RuntimeError( + "Azure authentication libraries are not installed. 
" + "Please install with: pip install azure-identity azure-core" + ) from e + + if not spn: + raise RuntimeError( + "ServicePrincipal token factory: empty SPN from server " + "(cannot construct token scope)" + ) + tenant_id = _parse_tenant_id(sts_url) + if not tenant_id: + raise RuntimeError(f"Could not extract tenant_id from STS URL: {sts_url!r}") + + logger.info( + "ServicePrincipal token factory: acquiring token for tenant=%s, spn=%s", + tenant_id, + spn, + ) + try: + # Reuse the shared credential cache (introduced for MSI in PR #573) + # so SP credentials get the same per-instance token reuse semantics + # as the other AD methods. + # + # The cache key includes a hash of client_secret so a rotated + # secret produces a different cache entry. Without this, an + # external secret rotation would not invalidate the cached + # ClientSecretCredential: azure-identity's internal token cache + # would keep returning the previously-issued token (good for + # up to ~1 hour) until expiry, masking the rotation. Hashing + # avoids storing the raw secret in the dict key. + secret_hash = hashlib.sha256(client_secret.encode("utf-8")).hexdigest() + cache_key = _credential_cache_key( + "serviceprincipal", + { + "tenant_id": tenant_id, + "client_id": client_id, + "secret_hash": secret_hash, + }, + ) + with _credential_cache_lock: + credential = _credential_cache.get(cache_key) + if credential is None: + credential = ClientSecretCredential( + tenant_id=tenant_id, + client_id=client_id, + client_secret=client_secret, + ) + _credential_cache[cache_key] = credential + # mssql-tds passes the resource SPN; azure-identity wants a scope. + scope = spn if spn.endswith("/.default") else spn.rstrip("/") + "/.default" + token = credential.get_token(scope).token + logger.info( + "ServicePrincipal token factory: token acquired, length=%d chars", + len(token), + ) + return token.encode("utf-16-le") + except ClientAuthenticationError as e: + # Keep the detailed provider error in debug logs only. 
The + # surfaced message is intentionally generic so that any + # secret-bearing provider text never reaches the user-facing + # exception chain. + logger.error( + "ServicePrincipal authentication failed: tenant=%s, error=%s", + tenant_id, + str(e), + ) + raise RuntimeError( + "ServicePrincipal authentication failed; " "see debug logs for provider details" + ) from None + + return _factory + + def _extract_msi_client_id(connection_string: str) -> Optional[str]: """Pull UID out of a connection string for user-assigned MSI. @@ -230,6 +368,17 @@ def process_auth_parameters(parameters: List[str]) -> Tuple[List[str], Optional[ # Managed identity authentication (system- or user-assigned) logger.debug("process_auth_parameters: Managed identity authentication detected") auth_type = "msi" + elif value_lower == AuthType.SERVICE_PRINCIPAL.value: + # ServicePrincipal authentication. ODBC (msodbcsql 17.3+) + # handles this natively for regular queries, so leave + # auth_type=None to let ODBC own the query path. + # Bulkcopy still needs the auth type — extract_auth_type() + # propagates it as "serviceprincipal" so the bulkcopy path + # can register an entra_id_token_factory callback (Model B, + # required because tenant_id is only known from the STS URL + # that the server returns during the FedAuth handshake). + logger.debug("process_auth_parameters: Service principal authentication detected") + auth_type = None modified_parameters.append(param) logger.debug( @@ -299,6 +448,7 @@ def extract_auth_type(connection_string: str) -> Optional[str]: AuthType.DEVICE_CODE.value: "devicecode", AuthType.DEFAULT.value: "default", AuthType.MSI.value: "msi", + AuthType.SERVICE_PRINCIPAL.value: "serviceprincipal", } for part in connection_string.split(";"): key, _, value = part.strip().partition("=") @@ -313,13 +463,6 @@ def process_connection_string( """ Process connection string and handle authentication. - NOTE: Returns a 4-tuple. Callers must unpack all four elements. 
- Destructuring with three names raises ``ValueError: too many values - to unpack``. The fourth element (``credential_kwargs``) is needed by - Connection.__init__ to persist credential constructor args (e.g. the - user-assigned MSI ``client_id``) for the bulkcopy fresh-token path, - since UID is stripped from the sanitized connection string. - Args: connection_string: The connection string to process diff --git a/mssql_python/constants.py b/mssql_python/constants.py index 5de02ece..f9f9331d 100644 --- a/mssql_python/constants.py +++ b/mssql_python/constants.py @@ -338,6 +338,7 @@ class AuthType(Enum): DEVICE_CODE = "activedirectorydevicecode" DEFAULT = "activedirectorydefault" MSI = "activedirectorymsi" + SERVICE_PRINCIPAL = "activedirectoryserviceprincipal" class SQLTypes: diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index ece27c61..9915eea2 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -2933,31 +2933,60 @@ def bulkcopy( # Token acquisition — only thing cursor must handle (needs azure-identity SDK) if self.connection._auth_type: - # Fresh token acquisition for mssql-py-core connection. credential - # kwargs (e.g. user-assigned MSI client_id) were captured by - # Connection.__init__ before remove_sensitive_params stripped UID - # from connection_str — re-parsing here would miss them. - from mssql_python.auth import AADAuth - - try: - raw_token = AADAuth.get_raw_token( + # Fresh token acquisition for mssql-py-core connection + from mssql_python.auth import AADAuth, ServicePrincipalAuth + + if self.connection._auth_type == "serviceprincipal": + # Model B: callback-based. tenant_id is only known from the + # STS URL the server returns mid-handshake, so we register a + # factory that py-core invokes during FedAuth (workflow 0x02). 
+ client_id = params.get("uid", "") + client_secret = params.get("pwd", "") + if not client_id or not client_secret: + raise RuntimeError( + "Bulk copy with Authentication=ActiveDirectoryServicePrincipal " + "requires UID (client_id) and PWD (client_secret) in the " + "connection string." + ) + try: + factory = ServicePrincipalAuth.make_token_factory(client_id, client_secret) + except (RuntimeError, ValueError) as e: + raise RuntimeError( + f"Bulk copy failed: unable to build ServicePrincipal token factory: {e}" + ) from e + pycore_context["entra_id_token_factory"] = factory + # Keep authentication/user_name/password in pycore_context — + # py-core's auth validator + transformer need them to resolve + # the auth method to ActiveDirectoryServicePrincipal before + # the factory is dispatched at handshake time. + logger.debug("Bulk copy: registered ServicePrincipal token factory") + else: + # Model A: pre-acquired token. Used for Default, DeviceCode, + # Interactive (non-Windows), MSI (system- or user-assigned), + # and any other AD method whose tenant_id is discoverable + # client-side via Azure Identity SDK. credential kwargs + # (e.g. user-assigned MSI client_id) were captured by + # Connection.__init__ before remove_sensitive_params stripped + # UID from connection_str — re-parsing here would miss them. + try: + raw_token = AADAuth.get_raw_token( + self.connection._auth_type, + self.connection._credential_kwargs, + ) + except (RuntimeError, ValueError) as e: + raise RuntimeError( + f"Bulk copy failed: unable to acquire Azure AD token " + f"for auth_type '{self.connection._auth_type}': {e}" + ) from e + pycore_context["access_token"] = raw_token + # Token replaces credential fields — py-core's validator rejects + # access_token combined with authentication/user_name/password. 
+ for key in ("authentication", "user_name", "password"): + pycore_context.pop(key, None) + logger.debug( + "Bulk copy: acquired fresh Azure AD token for auth_type=%s", self.connection._auth_type, - self.connection._credential_kwargs, ) - except (RuntimeError, ValueError) as e: - raise RuntimeError( - f"Bulk copy failed: unable to acquire Azure AD token " - f"for auth_type '{self.connection._auth_type}': {e}" - ) from e - pycore_context["access_token"] = raw_token - # Token replaces credential fields — py-core's validator rejects - # access_token combined with authentication/user_name/password. - for key in ("authentication", "user_name", "password"): - pycore_context.pop(key, None) - logger.debug( - "Bulk copy: acquired fresh Azure AD token for auth_type=%s", - self.connection._auth_type, - ) pycore_connection = None pycore_cursor = None @@ -3007,9 +3036,17 @@ def bulkcopy( raise type(e)(str(e)) from None finally: - # Clear sensitive data to minimize memory exposure + # Clear sensitive data to minimize memory exposure. The + # entra_id_token_factory closure captures client_secret, so drop + # our dict reference to it (Rust still holds an Arc until the + # connection is dropped, but at least we don't keep an extra ref). 
if pycore_context: - for key in ("password", "user_name", "access_token"): + for key in ( + "password", + "user_name", + "access_token", + "entra_id_token_factory", + ): pycore_context.pop(key, None) # Clean up bulk copy resources for resource in (pycore_cursor, pycore_connection): diff --git a/tests/test_008_auth.py b/tests/test_008_auth.py index f8df6f6f..7d013abb 100644 --- a/tests/test_008_auth.py +++ b/tests/test_008_auth.py @@ -11,6 +11,8 @@ from unittest.mock import patch, MagicMock from mssql_python.auth import ( AADAuth, + ServicePrincipalAuth, + _parse_tenant_id, process_auth_parameters, remove_sensitive_params, get_auth_token, @@ -44,6 +46,20 @@ class MockInteractiveBrowserCredential: def get_token(self, scope): return MockToken() + class MockClientSecretCredential: + # Captures construction kwargs and get_token args so ServicePrincipal + # tests can assert the right tenant/client_id/secret/scope flowed + # through from the connection string + STS URL. + last_init_kwargs = None + last_scope = None + + def __init__(self, **kwargs): + MockClientSecretCredential.last_init_kwargs = kwargs + + def get_token(self, scope): + MockClientSecretCredential.last_scope = scope + return MockToken() + class MockManagedIdentityCredential: # Captures construction kwargs so user-assigned MSI tests can assert # client_id was forwarded correctly. 
@@ -63,6 +79,7 @@ class MockIdentity: DefaultAzureCredential = MockDefaultAzureCredential DeviceCodeCredential = MockDeviceCodeCredential InteractiveBrowserCredential = MockInteractiveBrowserCredential + ClientSecretCredential = MockClientSecretCredential ManagedIdentityCredential = MockManagedIdentityCredential class MockCore: @@ -100,6 +117,7 @@ def test_auth_type_constants(self): assert AuthType.DEVICE_CODE.value == "activedirectorydevicecode" assert AuthType.DEFAULT.value == "activedirectorydefault" assert AuthType.MSI.value == "activedirectorymsi" + assert AuthType.SERVICE_PRINCIPAL.value == "activedirectoryserviceprincipal" class TestAADAuth: @@ -330,6 +348,20 @@ def test_default_auth(self): _, auth_type = process_auth_parameters(params) assert auth_type == "default" + def test_service_principal_auth_leaves_odbc_path_alone(self): + """ServicePrincipal is handled natively by ODBC. process_auth_parameters + must return auth_type=None so the ODBC path doesn't pre-acquire a token + (which would require tenant_id we don't have client-side).""" + params = ["Authentication=ActiveDirectoryServicePrincipal", "Server=test"] + modified_params, auth_type = process_auth_parameters(params) + assert "Authentication=ActiveDirectoryServicePrincipal" in modified_params + assert auth_type is None + + def test_service_principal_auth_case_insensitive(self): + params = ["authentication=activedirectoryserviceprincipal", "Server=test"] + _, auth_type = process_auth_parameters(params) + assert auth_type is None + def test_msi_auth(self): params = ["Authentication=ActiveDirectoryMSI", "Server=test"] _, auth_type = process_auth_parameters(params) @@ -433,6 +465,12 @@ def test_devicecode(self): == "devicecode" ) + def test_serviceprincipal(self): + assert ( + extract_auth_type("Server=test;Authentication=ActiveDirectoryServicePrincipal;") + == "serviceprincipal" + ) + def test_msi(self): assert extract_auth_type("Server=test;Authentication=ActiveDirectoryMSI;") == "msi" @@ -1012,3 
+1050,301 @@ def __init__(self): assert credential_kwargs is None finally: azure_identity.DefaultAzureCredential = original + + +class TestParseTenantId: + def test_guid_tenant(self): + url = "https://login.microsoftonline.com/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee/" + assert _parse_tenant_id(url) == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + + def test_guid_tenant_no_trailing_slash(self): + url = "https://login.microsoftonline.com/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + assert _parse_tenant_id(url) == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + + def test_domain_tenant(self): + url = "https://login.microsoftonline.com/contoso.onmicrosoft.com/" + assert _parse_tenant_id(url) == "contoso.onmicrosoft.com" + + def test_tenant_with_query_string(self): + url = "https://login.microsoftonline.com/tenant-guid/?foo=bar" + assert _parse_tenant_id(url) == "tenant-guid" + + def test_extra_path_segments_after_tenant(self): + url = "https://login.microsoftonline.com/tenant-guid/oauth2/authorize" + assert _parse_tenant_id(url) == "tenant-guid" + + def test_empty_string(self): + assert _parse_tenant_id("") is None + + def test_no_path(self): + assert _parse_tenant_id("https://login.microsoftonline.com/") is None + + def test_rejects_bare_string_without_scheme(self): + # urlparse puts a bare string into path; without a scheme/netloc check + # this would be silently treated as a tenant id. + assert _parse_tenant_id("tenant-guid") is None + + def test_rejects_path_only_url(self): + assert _parse_tenant_id("/tenant-guid/oauth2") is None + + def test_rejects_http_scheme(self): + # Azure AD STS URLs are always https. Reject http to avoid trusting + # a downgraded URL. 
+ assert _parse_tenant_id("http://login.microsoftonline.com/tenant/") is None + + +class TestServicePrincipalAuth: + """Tests for the ActiveDirectoryServicePrincipal token factory.""" + + def test_make_token_factory_returns_callable(self): + factory = ServicePrincipalAuth.make_token_factory("client-id", "client-secret") + assert callable(factory) + + def test_factory_requires_client_id(self): + with pytest.raises(ValueError, match="client_id"): + ServicePrincipalAuth.make_token_factory("", "client-secret") + + def test_factory_requires_client_secret(self): + with pytest.raises(ValueError, match="client_secret"): + ServicePrincipalAuth.make_token_factory("client-id", "") + + def test_factory_returns_utf16le_bytes(self): + factory = ServicePrincipalAuth.make_token_factory("client-id", "client-secret") + result = factory( + "https://database.windows.net/", + "https://login.microsoftonline.com/tenant-guid/", + "activedirectoryserviceprincipal", + ) + assert isinstance(result, bytes) + # SAMPLE_TOKEN is hex chars (ASCII). UTF-16LE encoding doubles each byte + # and inserts a 0x00 high byte after each ASCII char. 
+ assert result == SAMPLE_TOKEN.encode("utf-16-le") + assert len(result) == len(SAMPLE_TOKEN) * 2 + + def test_factory_forwards_credentials_to_ClientSecretCredential(self): + az = sys.modules["azure.identity"] + az.ClientSecretCredential.last_init_kwargs = None + az.ClientSecretCredential.last_scope = None + + factory = ServicePrincipalAuth.make_token_factory( + "11111111-2222-3333-4444-555555555555", "my-secret" + ) + factory( + "https://database.windows.net/", + "https://login.microsoftonline.com/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee/", + "activedirectoryserviceprincipal", + ) + + assert az.ClientSecretCredential.last_init_kwargs == { + "tenant_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee", + "client_id": "11111111-2222-3333-4444-555555555555", + "client_secret": "my-secret", + } + + def test_factory_builds_scope_from_spn(self): + az = sys.modules["azure.identity"] + az.ClientSecretCredential.last_scope = None + + factory = ServicePrincipalAuth.make_token_factory("cid", "secret") + factory( + "https://database.windows.net/", + "https://login.microsoftonline.com/tenant/", + "activedirectoryserviceprincipal", + ) + assert az.ClientSecretCredential.last_scope == "https://database.windows.net/.default" + + def test_factory_keeps_existing_default_suffix(self): + az = sys.modules["azure.identity"] + az.ClientSecretCredential.last_scope = None + + factory = ServicePrincipalAuth.make_token_factory("cid", "secret") + factory( + "https://database.windows.net/.default", + "https://login.microsoftonline.com/tenant/", + "activedirectoryserviceprincipal", + ) + assert az.ClientSecretCredential.last_scope == "https://database.windows.net/.default" + + def test_factory_errors_on_unparseable_sts_url(self): + factory = ServicePrincipalAuth.make_token_factory("cid", "secret") + with pytest.raises(RuntimeError, match="Could not extract tenant_id"): + factory( + "https://database.windows.net/", + "https://login.microsoftonline.com/", # no tenant segment + 
"activedirectoryserviceprincipal", + ) + + def test_factory_propagates_authentication_error(self): + from azure.core.exceptions import ClientAuthenticationError + + class FailingCred: + def __init__(self, **kwargs): + pass + + def get_token(self, scope): + raise ClientAuthenticationError("AADSTS7000215: Invalid client secret") + + original = sys.modules["azure.identity"].ClientSecretCredential + sys.modules["azure.identity"].ClientSecretCredential = FailingCred + try: + factory = ServicePrincipalAuth.make_token_factory("cid", "secret") + with pytest.raises(RuntimeError, match="ServicePrincipal authentication failed"): + factory( + "https://database.windows.net/", + "https://login.microsoftonline.com/tenant-guid/", + "activedirectoryserviceprincipal", + ) + finally: + sys.modules["azure.identity"].ClientSecretCredential = original + + def test_factory_does_not_leak_provider_message_in_runtime_error(self): + """The user-facing RuntimeError must not echo the provider message + (which can carry tenant ids, claims, or other sensitive context). 
+ Provider detail is preserved in debug logs only.""" + from azure.core.exceptions import ClientAuthenticationError + + secret_marker = "AADSTS7000215_SECRET_MARKER_in_provider_message" + + class FailingCred: + def __init__(self, **kwargs): + pass + + def get_token(self, scope): + raise ClientAuthenticationError(secret_marker) + + original = sys.modules["azure.identity"].ClientSecretCredential + sys.modules["azure.identity"].ClientSecretCredential = FailingCred + try: + factory = ServicePrincipalAuth.make_token_factory("cid", "secret") + try: + factory( + "https://database.windows.net/", + "https://login.microsoftonline.com/tenant-guid/", + "activedirectoryserviceprincipal", + ) + except RuntimeError as e: + full_chain = str(e) + cause = e.__cause__ + while cause is not None: + full_chain += " || " + str(cause) + cause = getattr(cause, "__cause__", None) + assert ( + secret_marker not in full_chain + ), f"Provider message leaked into surfaced exception chain: {full_chain}" + finally: + sys.modules["azure.identity"].ClientSecretCredential = original + + def test_factory_rejects_empty_spn(self): + factory = ServicePrincipalAuth.make_token_factory("cid", "secret") + with pytest.raises(RuntimeError, match="empty SPN"): + factory( + "", + "https://login.microsoftonline.com/tenant-guid/", + "activedirectoryserviceprincipal", + ) + + def test_factory_caches_credential_per_tenant(self): + """ClientSecretCredential must be reused across calls for the same + tenant so azure-identity's per-instance token cache actually works.""" + az = sys.modules["azure.identity"] + construction_count = {"n": 0} + + original = az.ClientSecretCredential + + class _Tok: + token = SAMPLE_TOKEN + + class CountingCred: + def __init__(self, **kwargs): + construction_count["n"] += 1 + + def get_token(self, scope): + return _Tok() + + az.ClientSecretCredential = CountingCred + try: + factory = ServicePrincipalAuth.make_token_factory("cid", "secret") + sts = 
"https://login.microsoftonline.com/tenant-guid/" + for _ in range(3): + factory("https://database.windows.net/", sts, "activedirectoryserviceprincipal") + assert construction_count["n"] == 1, ( + f"Expected 1 ClientSecretCredential construction across 3 calls, " + f"got {construction_count['n']}" + ) + # A different tenant should produce a second instance. + factory( + "https://database.windows.net/", + "https://login.microsoftonline.com/other-tenant/", + "activedirectoryserviceprincipal", + ) + assert construction_count["n"] == 2 + finally: + az.ClientSecretCredential = original + + def test_factory_rotates_credential_when_secret_changes(self): + """A new client_secret for the same tenant+client_id MUST produce a new + ClientSecretCredential instance. Without this, an external secret + rotation would not invalidate the cached credential: azure-identity's + internal token cache would keep returning the previously-issued token + (good for up to ~1 hour) until expiry, masking the rotation.""" + az = sys.modules["azure.identity"] + construction_count = {"n": 0} + + original = az.ClientSecretCredential + + class _Tok: + token = SAMPLE_TOKEN + + class CountingCred: + def __init__(self, **kwargs): + construction_count["n"] += 1 + + def get_token(self, scope): + return _Tok() + + az.ClientSecretCredential = CountingCred + try: + sts = "https://login.microsoftonline.com/tenant-guid/" + spn = "https://database.windows.net/" + + # Old secret, two calls -> 1 construction (cached) + factory_old = ServicePrincipalAuth.make_token_factory("cid", "old-secret") + factory_old(spn, sts, "activedirectoryserviceprincipal") + factory_old(spn, sts, "activedirectoryserviceprincipal") + assert construction_count["n"] == 1 + + # Rotate the secret. Same tenant + client_id, different secret. + # MUST produce a fresh ClientSecretCredential so azure-identity + # cannot serve a stale token from its internal cache. 
+ factory_new = ServicePrincipalAuth.make_token_factory("cid", "new-secret") + factory_new(spn, sts, "activedirectoryserviceprincipal") + assert construction_count["n"] == 2, ( + f"Expected 2 ClientSecretCredential constructions after secret rotation, " + f"got {construction_count['n']}. A rotated secret was silently ignored." + ) + + # Calling the new factory again should hit cache (1 more = 2 total) + factory_new(spn, sts, "activedirectoryserviceprincipal") + assert construction_count["n"] == 2 + + # Calling the OLD factory again should still hit the OLD cache entry + # (it's keyed on the hash of "old-secret"), not construct again. + factory_old(spn, sts, "activedirectoryserviceprincipal") + assert construction_count["n"] == 2 + finally: + az.ClientSecretCredential = original + + def test_factory_cache_key_does_not_contain_raw_secret(self): + """The cache key must hash the secret, never store it raw. Otherwise + the secret is visible in process memory as part of the dict key.""" + from mssql_python.auth import _credential_cache + + secret_marker = "RAW_SECRET_MARKER_must_not_appear_in_cache_key" + factory = ServicePrincipalAuth.make_token_factory("cid", secret_marker) + factory( + "https://database.windows.net/", + "https://login.microsoftonline.com/tenant-guid/", + "activedirectoryserviceprincipal", + ) + for key in _credential_cache.keys(): + assert secret_marker not in repr(key), f"Raw secret leaked into cache key: {key!r}"