diff --git a/setup.py b/setup.py index 09f844d2..435f3652 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ if sys.version_info < (3, 8): raise RuntimeError("skyflow requires Python 3.8+") -current_version = '2.0.0' +current_version = '2.0.0.dev0+e2aa629' setup( name='skyflow', diff --git a/skyflow/error/_skyflow_error.py b/skyflow/error/_skyflow_error.py index 7b917fae..fca43935 100644 --- a/skyflow/error/_skyflow_error.py +++ b/skyflow/error/_skyflow_error.py @@ -15,5 +15,4 @@ def __init__(self, self.http_status = http_status if http_status else SkyflowMessages.HttpStatus.BAD_REQUEST.value self.details = details self.request_id = request_id - log_error(message, http_code, request_id, grpc_code, http_status, details) super().__init__() \ No newline at end of file diff --git a/skyflow/utils/_skyflow_messages.py b/skyflow/utils/_skyflow_messages.py index 3672cfa8..8aea3b8b 100644 --- a/skyflow/utils/_skyflow_messages.py +++ b/skyflow/utils/_skyflow_messages.py @@ -47,7 +47,7 @@ class Error(Enum): EMPTY_CREDENTIALS_TOKEN = f"{error_prefix} Initialization failed. Invalid token.Specify a valid credentials token." INVALID_CREDENTIALS_TOKEN_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid credentials token for {{}} with id {{}}. Expected token to be a string." INVALID_CREDENTIALS_TOKEN = f"{error_prefix} Initialization failed. Invalid credentials token. Expected token to be a string." - EXPIRED_TOKEN = f"${error_prefix} Initialization failed. Given token is expired. Specify a valid credentials token." + EXPIRED_TOKEN = f"{error_prefix} Initialization failed. Given token is expired. Specify a valid credentials token." EMPTY_API_KEY_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid api key for {{}} with id {{}}.Specify a valid api key." EMPTY_API_KEY= f"{error_prefix} Initialization failed. Invalid api key.Specify a valid api key." INVALID_API_KEY_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid api key for {{}} with id {{}}. Expected api key to be a string." diff --git a/skyflow/utils/_utils.py b/skyflow/utils/_utils.py index 4278357e..d8eedca2 100644 --- a/skyflow/utils/_utils.py +++ b/skyflow/utils/_utils.py @@ -30,26 +30,18 @@ invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value def get_credentials(config_level_creds = None, common_skyflow_creds = None, logger = None): - dotenv.load_dotenv() - dotenv_path = dotenv.find_dotenv(usecwd=True) - if dotenv_path: - load_dotenv(dotenv_path) - env_skyflow_credentials = os.getenv("SKYFLOW_CREDENTIALS") if config_level_creds: return config_level_creds if common_skyflow_creds: return common_skyflow_creds + dotenv_path = dotenv.find_dotenv(usecwd=True) + if dotenv_path: + load_dotenv(dotenv_path) + env_skyflow_credentials = os.getenv("SKYFLOW_CREDENTIALS") if env_skyflow_credentials: - env_skyflow_credentials.strip() - try: - env_creds = env_skyflow_credentials.replace('\n', '\\n') - return { - 'credentials_string': env_creds - } - except json.JSONDecodeError: - raise SkyflowError(SkyflowMessages.Error.INVALID_JSON_FORMAT_IN_CREDENTIALS_ENV.value, invalid_input_error_code) - else: - raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code) + env_creds = env_skyflow_credentials.strip().replace('\n', '\\n') + return {'credentials_string': env_creds} + raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code) def validate_api_key(api_key: str, logger = None) -> bool: if len(api_key) != 42: @@ -185,8 +177,12 @@ def get_data_from_content_type(data, content_type): return converted_data, files +_CACHED_METRICS: dict = {} + def get_metrics(): - sdk_name_version = "skyflow-python@" + SDK_VERSION + global _CACHED_METRICS + if _CACHED_METRICS: + return _CACHED_METRICS try: sdk_client_device_model = platform.node() @@ -203,13 +199,13 @@ def get_metrics(): except Exception: sdk_runtime_details = "" - details_dic = { - 'sdk_name_version': sdk_name_version, + _CACHED_METRICS = { + 'sdk_name_version': "skyflow-python@" + SDK_VERSION, 'sdk_client_device_model': sdk_client_device_model, 'sdk_client_os_details': sdk_client_os_details, 'sdk_runtime_details': "Python " + sdk_runtime_details, } - return details_dic + return _CACHED_METRICS def parse_insert_response(api_response, continue_on_error): # Retrieve the headers and data from the API response diff --git a/skyflow/utils/_version.py b/skyflow/utils/_version.py index 0d05fc30..dd1e4b72 100644 --- a/skyflow/utils/_version.py +++ b/skyflow/utils/_version.py @@ -1 +1 @@ -SDK_VERSION = '2.0.0' \ No newline at end of file +SDK_VERSION = '2.0.0.dev0+e2aa629' \ No newline at end of file diff --git a/skyflow/utils/validations/_validations.py b/skyflow/utils/validations/_validations.py index f3428f45..acca531f 100644 --- a/skyflow/utils/validations/_validations.py +++ b/skyflow/utils/validations/_validations.py @@ -122,8 +122,8 @@ def validate_credentials(logger, credentials, config_id_type=None, config_id=Non ) if is_expired(credentials.get("token"), logger): raise SkyflowError( - SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value.format(config_id_type, config_id) - if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value, + SkyflowMessages.Error.EXPIRED_TOKEN.value + if config_id_type and config_id else SkyflowMessages.Error.EXPIRED_TOKEN.value, invalid_input_error_code ) elif "api_key" in credentials: @@ -389,7 +389,7 @@ def validate_deidentify_file_request(logger, request: DeidentifyFileRequest): if hasattr(request, 'wait_time') and request.wait_time is not None: if not isinstance(request.wait_time, (int, float)): raise SkyflowError(SkyflowMessages.Error.INVALID_WAIT_TIME.value, invalid_input_error_code) - if request.wait_time < 0 and request.wait_time > 64: + if request.wait_time < 0 or request.wait_time > 64: raise SkyflowError(SkyflowMessages.Error.WAIT_TIME_GREATER_THEN_64.value, invalid_input_error_code) def validate_insert_request(logger, request): diff --git a/skyflow/vault/client/client.py b/skyflow/vault/client/client.py index f47a525c..0304c11a 100644 --- a/skyflow/vault/client/client.py +++ b/skyflow/vault/client/client.py @@ -14,6 +14,9 @@ def __init__(self, config): self.__logger = None self.__is_config_updated = False self.__bearer_token = None + self.__credentials = None + self.__vault_url = None + self.__is_static_token = None def set_common_skyflow_credentials(self, credentials): self.__common_skyflow_credentials = credentials @@ -23,16 +26,27 @@ def set_logger(self, log_level, logger): self.__logger = logger def initialize_client_configuration(self): - credentials = get_credentials(self.__config.get("credentials"), self.__common_skyflow_credentials, logger = self.__logger) - token = self.get_bearer_token(credentials) - vault_url = get_vault_url(self.__config.get("cluster_id"), - self.__config.get("env"), - self.__config.get("vault_id"), - logger = self.__logger) - self.initialize_api_client(vault_url, token) - - def initialize_api_client(self, vault_url, token): - self.__api_client = Skyflow(base_url=vault_url, token=token) + if self.__api_client is not None and not self.__is_config_updated: + if self.__is_static_token: + return + if self.__bearer_token is not None and not is_expired(self.__bearer_token): + return + + needs_reinit = self.__api_client is None or self.__is_config_updated + if needs_reinit: + self.__credentials = get_credentials(self.__config.get("credentials"), self.__common_skyflow_credentials, logger=self.__logger) + self.__vault_url = get_vault_url(self.__config.get("cluster_id"), + self.__config.get("env"), + self.__config.get("vault_id"), + logger=self.__logger) + self.__is_static_token = 'token' in self.__credentials or 'api_key' in self.__credentials + bearer_token = self.get_bearer_token(self.__credentials) + if needs_reinit: + self.initialize_api_client(self.__vault_url, bearer_token) + + def initialize_api_client(self, vault_url, bearer_token): + token_provider = lambda: self.__bearer_token if self.__bearer_token else bearer_token # noqa: E731 + self.__api_client = Skyflow(base_url=vault_url, token=token_provider) def get_records_api(self): return self.__api_client.records @@ -63,11 +77,10 @@ def get_bearer_token(self, credentials): "ctx": self.__config.get("ctx") } - if self.__bearer_token is None or self.__is_config_updated: + if self.__bearer_token is None or self.__is_config_updated or is_expired(self.__bearer_token): if 'path' in credentials: - path = credentials.get("path") self.__bearer_token, _ = generate_bearer_token( - path, + credentials.get("path"), options, self.__logger ) @@ -83,10 +96,6 @@ def get_bearer_token(self, credentials): else: log_info(SkyflowMessages.Info.REUSE_BEARER_TOKEN.value, self.__logger) - if is_expired(self.__bearer_token): - self.__is_config_updated = True - raise SyntaxError(SkyflowMessages.Error.EXPIRED_TOKEN.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) - return self.__bearer_token def update_config(self, config): diff --git a/skyflow/vault/controller/_detect.py b/skyflow/vault/controller/_detect.py index 44ef2540..cb5e8836 100644 --- a/skyflow/vault/controller/_detect.py +++ b/skyflow/vault/controller/_detect.py @@ -62,7 +62,7 @@ def __poll_for_processed_file(self, run_id, max_wait_time=64): current_wait_time = 1 # Start with 1 second try: while True: - response = files_api.get_run(run_id, vault_id=self.__vault_client.get_vault_id(), request_options=self.__get_headers()).data + response = files_api.get_run(run_id, vault_id=self.__vault_client.get_vault_id(), request_options={'additional_headers': self.__get_headers()}).data status = response.status if status == 'IN_PROGRESS': if current_wait_time >= max_wait_time: @@ -228,7 +228,7 @@ def deidentify_text(self, request: DeidentifyTextRequest) -> DeidentifyTextRespo restrict_regex=deidentify_text_body['restrict_regex'], token_type=deidentify_text_body['token_type'], transformations=deidentify_text_body['transformations'], - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) deidentify_text_response = parse_deidentify_text_response(api_response) log_info(SkyflowMessages.Info.DEIDENTIFY_TEXT_SUCCESS.value, self.__vault_client.get_logger()) @@ -252,7 +252,7 @@ def reidentify_text(self, request: ReidentifyTextRequest) -> ReidentifyTextRespo vault_id=self.__vault_client.get_vault_id(), text=reidentify_text_body['text'], format=reidentify_text_body['format'], - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) reidentify_text_response = parse_reidentify_text_response(api_response) log_info(SkyflowMessages.Info.REIDENTIFY_TEXT_SUCCESS.value, self.__vault_client.get_logger()) @@ -296,7 +296,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): 'allow_regex': request.allow_regex_list, 'restrict_regex': request.restrict_regex_list, 'transformations': self.__get_transformations(request), - 'request_options': self.__get_headers() + 'request_options': {'additional_headers': self.__get_headers()} } elif file_extension in ['mp3', 'wav']: @@ -316,7 +316,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): 'bleep_frequency': getattr(request, 'bleep', None).frequency if getattr(request, 'bleep', None) is not None else None, 'bleep_start_padding': getattr(request, 'bleep', None).start_padding if getattr(request, 'bleep', None) is not None else None, 'bleep_stop_padding': getattr(request, 'bleep', None).stop_padding if getattr(request, 'bleep', None) is not None else None, - 'request_options': self.__get_headers() + 'request_options': {'additional_headers': self.__get_headers()} } elif file_extension == 'pdf': @@ -331,7 +331,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): 'restrict_regex': request.restrict_regex_list, 'max_resolution': getattr(request, 'max_resolution', None), 'density': getattr(request, 'pixel_density', None), - 'request_options': self.__get_headers() + 'request_options': {'additional_headers': self.__get_headers()} } elif file_extension in ['jpeg', 'jpg', 'png', 'bmp', 'tif', 'tiff']: @@ -347,7 +347,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): 'masking_method': getattr(request, 'masking_method', None), 'output_ocr_text': getattr(request, 'output_ocr_text', None), 'output_processed_image': getattr(request, 'output_processed_image', None), - 'request_options': self.__get_headers() + 'request_options': {'additional_headers': self.__get_headers()} } elif file_extension in ['ppt', 'pptx']: @@ -360,7 +360,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): 'token_type': self.__get_token_format(request), 'allow_regex': request.allow_regex_list, 'restrict_regex': request.restrict_regex_list, - 'request_options': self.__get_headers() + 'request_options': {'additional_headers': self.__get_headers()} } elif file_extension in ['csv', 'xls', 'xlsx']: @@ -373,7 +373,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): 'token_type': self.__get_token_format(request), 'allow_regex': request.allow_regex_list, 'restrict_regex': request.restrict_regex_list, - 'request_options': self.__get_headers() + 'request_options': {'additional_headers': self.__get_headers()} } elif file_extension in ['doc', 'docx']: @@ -386,7 +386,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): 'token_type': self.__get_token_format(request), 'allow_regex': request.allow_regex_list, 'restrict_regex': request.restrict_regex_list, - 'request_options': self.__get_headers() + 'request_options': {'additional_headers': self.__get_headers()} } elif file_extension in ['json', 'xml']: @@ -400,7 +400,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): 'allow_regex': request.allow_regex_list, 'restrict_regex': request.restrict_regex_list, 'transformations': self.__get_transformations(request), - 'request_options': self.__get_headers() + 'request_options': {'additional_headers': self.__get_headers()} } else: @@ -414,7 +414,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): 'allow_regex': request.allow_regex_list, 'restrict_regex': request.restrict_regex_list, 'transformations': self.__get_transformations(request), - 'request_options': self.__get_headers() + 'request_options': {'additional_headers': self.__get_headers()} } log_info(SkyflowMessages.Info.DETECT_FILE_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) @@ -448,7 +448,7 @@ def get_detect_run(self, request: GetDetectRunRequest): response = files_api.get_run( run_id, vault_id=self.__vault_client.get_vault_id(), - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) if response.data.status == 'IN_PROGRESS': parsed_response = self.__parse_deidentify_file_response(DeidentifyFileResponse(run_id=run_id, status='IN_PROGRESS')) diff --git a/skyflow/vault/controller/_vault.py b/skyflow/vault/controller/_vault.py index 7cc9ec77..c757730a 100644 --- a/skyflow/vault/controller/_vault.py +++ b/skyflow/vault/controller/_vault.py @@ -89,10 +89,9 @@ def __get_file_for_file_upload(self, request: FileUploadRequest) -> Optional[Fil return None def __get_headers(self): - headers = { - SKY_META_DATA_HEADER: json.dumps(get_metrics()) - } - return headers + if not hasattr(self, '_cached_headers'): + self._cached_headers = {SKY_META_DATA_HEADER: json.dumps(get_metrics())} + return self._cached_headers def insert(self, request: InsertRequest): log_info(SkyflowMessages.Info.VALIDATE_INSERT_REQUEST.value, self.__vault_client.get_logger()) @@ -106,11 +105,11 @@ def insert(self, request: InsertRequest): log_info(SkyflowMessages.Info.INSERT_TRIGGERED.value, self.__vault_client.get_logger()) if request.continue_on_error: api_response = records_api.record_service_batch_operation(self.__vault_client.get_vault_id(), - records=insert_body, continue_on_error=request.continue_on_error, byot=request.token_mode.value, request_options=self.__get_headers()) + records=insert_body, continue_on_error=request.continue_on_error, byot=request.token_mode.value, request_options={'additional_headers': self.__get_headers()}) else: api_response = records_api.record_service_insert_record(self.__vault_client.get_vault_id(), - request.table, records=insert_body,tokenization= request.return_tokens, upsert=request.upsert, homogeneous=request.homogeneous, byot=request.token_mode.value, request_options=self.__get_headers()) + request.table, records=insert_body,tokenization= request.return_tokens, upsert=request.upsert, homogeneous=request.homogeneous, byot=request.token_mode.value, request_options={'additional_headers': self.__get_headers()}) insert_response = parse_insert_response(api_response, request.continue_on_error) log_info(SkyflowMessages.Info.INSERT_SUCCESS.value, self.__vault_client.get_logger()) @@ -138,7 +137,7 @@ def update(self, request: UpdateRequest): record=record, tokenization=request.return_tokens, byot=request.token_mode.value, - request_options = self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) log_info(SkyflowMessages.Info.UPDATE_SUCCESS.value, self.__vault_client.get_logger()) update_response = parse_update_record_response(api_response) @@ -159,7 +158,7 @@ def delete(self, request: DeleteRequest): self.__vault_client.get_vault_id(), request.table, skyflow_ids=request.ids, - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) log_info(SkyflowMessages.Info.DELETE_SUCCESS.value, self.__vault_client.get_logger()) delete_response = parse_delete_response(api_response) @@ -189,7 +188,7 @@ def get(self, request: GetRequest): download_url=request.download_url, column_name=request.column_name, column_values=request.column_values, - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) log_info(SkyflowMessages.Info.GET_SUCCESS.value, self.__vault_client.get_logger()) get_response = parse_get_response(api_response) @@ -209,7 +208,7 @@ def query(self, request: QueryRequest): api_response = query_api.query_service_execute_query( self.__vault_client.get_vault_id(), query=request.query, - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) log_info(SkyflowMessages.Info.QUERY_SUCCESS.value, self.__vault_client.get_logger()) query_response = parse_query_response(api_response) @@ -237,7 +236,7 @@ def detokenize(self, request: DetokenizeRequest): self.__vault_client.get_vault_id(), detokenization_parameters=tokens_list, continue_on_error = request.continue_on_error, - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) log_info(SkyflowMessages.Info.DETOKENIZE_SUCCESS.value, self.__vault_client.get_logger()) detokenize_response = parse_detokenize_response(api_response) @@ -262,7 +261,7 @@ def tokenize(self, request: TokenizeRequest): api_response = tokens_api.record_service_tokenize( self.__vault_client.get_vault_id(), tokenization_parameters=records_list, - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) tokenize_response = parse_tokenize_response(api_response) log_info(SkyflowMessages.Info.TOKENIZE_SUCCESS.value, self.__vault_client.get_logger()) @@ -285,7 +284,7 @@ def upload_file(self, request: FileUploadRequest): file=self.__get_file_for_file_upload(request), skyflow_id=request.skyflow_id, return_file_metadata= False, - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) log_info(SkyflowMessages.Info.FILE_UPLOAD_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) log_info(SkyflowMessages.Info.FILE_UPLOAD_SUCCESS.value, self.__vault_client.get_logger()) diff --git a/tests/utils/validations/test__validations.py b/tests/utils/validations/test__validations.py index 48332a55..b1247ebc 100644 --- a/tests/utils/validations/test__validations.py +++ b/tests/utils/validations/test__validations.py @@ -116,7 +116,7 @@ def test_validate_credentials_with_expired_token(self): with patch('skyflow.service_account.is_expired', return_value=True): with self.assertRaises(SkyflowError) as context: validate_credentials(self.logger, credentials) - self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value) + self.assertEqual(context.exception.message, SkyflowMessages.Error.EXPIRED_TOKEN.value) def test_validate_credentials_empty_credentials(self): credentials = {} diff --git a/tests/vault/client/test__client.py b/tests/vault/client/test__client.py index 565b1e6f..9d0d2520 100644 --- a/tests/vault/client/test__client.py +++ b/tests/vault/client/test__client.py @@ -1,5 +1,5 @@ import unittest -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, call from skyflow.vault.client.client import VaultClient CONFIG = { @@ -12,11 +12,19 @@ } CREDENTIALS_WITH_API_KEY = {"api_key": "dummy_api_key"} +CREDENTIALS_WITH_TOKEN = {"token": "dummy_static_token"} +CREDENTIALS_WITH_PATH = {"path": "/some/path/credentials.json"} +CREDENTIALS_WITH_STRING = {"credentials_string": '{"clientID": "x"}'} + class TestVaultClient(unittest.TestCase): def setUp(self): self.vault_client = VaultClient(CONFIG) + # ------------------------------------------------------------------ # + # Basic setters / getters # + # ------------------------------------------------------------------ # + def test_set_common_skyflow_credentials(self): credentials = {"api_key": "dummy_api_key"} self.vault_client.set_common_skyflow_credentials(credentials) @@ -28,73 +36,289 @@ def test_set_logger(self): self.assertEqual(self.vault_client.get_log_level(), "INFO") self.assertEqual(self.vault_client.get_logger(), mock_logger) + def test_get_vault_id(self): + self.assertEqual(self.vault_client.get_vault_id(), CONFIG["vault_id"]) + + def test_get_config(self): + self.assertEqual(self.vault_client.get_config(), CONFIG) + + def test_get_common_skyflow_credentials(self): + credentials = {"api_key": "dummy_api_key"} + self.vault_client.set_common_skyflow_credentials(credentials) + self.assertEqual(self.vault_client.get_common_skyflow_credentials(), credentials) + + def test_get_log_level(self): + self.vault_client.set_logger("DEBUG", MagicMock()) + self.assertEqual(self.vault_client.get_log_level(), "DEBUG") + + def test_get_logger(self): + mock_logger = MagicMock() + self.vault_client.set_logger("INFO", mock_logger) + self.assertEqual(self.vault_client.get_logger(), mock_logger) + + # ------------------------------------------------------------------ # + # initialize_client_configuration — first call (slow path) # + # ------------------------------------------------------------------ # + @patch("skyflow.vault.client.client.get_credentials") @patch("skyflow.vault.client.client.get_vault_url") @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") - def test_initialize_client_configuration(self, mock_init_api_client, mock_get_vault_url, mock_get_credentials): - mock_get_credentials.return_value = (CREDENTIALS_WITH_API_KEY) + def test_initialize_client_configuration_first_call( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials + ): + mock_get_credentials.return_value = CREDENTIALS_WITH_API_KEY mock_get_vault_url.return_value = "https://test-vault-url.com" self.vault_client.initialize_client_configuration() - mock_get_credentials.assert_called_once_with(CONFIG["credentials"], None, logger=None) - mock_get_vault_url.assert_called_once_with(CONFIG["cluster_id"], CONFIG["env"], CONFIG["vault_id"], logger=None) + mock_get_credentials.assert_called_once_with( + CONFIG["credentials"], None, logger=None + ) + mock_get_vault_url.assert_called_once_with( + CONFIG["cluster_id"], CONFIG["env"], CONFIG["vault_id"], logger=None + ) mock_init_api_client.assert_called_once() - @patch("skyflow.vault.client.client.Skyflow") - def test_initialize_api_client(self, mock_api_client): - self.vault_client.initialize_api_client("https://test-vault-url.com", "dummy_token") - mock_api_client.assert_called_once_with(base_url="https://test-vault-url.com", token="dummy_token") + # ------------------------------------------------------------------ # + # initialize_client_configuration — fast path (static token) # + # ------------------------------------------------------------------ # - def test_get_records_api(self): + @patch("skyflow.vault.client.client.get_credentials") + @patch("skyflow.vault.client.client.get_vault_url") + @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") + def test_initialize_client_configuration_fast_path_api_key( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials + ): + """Once initialized with api_key, subsequent calls skip all work.""" + mock_get_credentials.return_value = CREDENTIALS_WITH_API_KEY + mock_get_vault_url.return_value = "https://test-vault-url.com" + # Side-effect simulates initialize_api_client actually setting __api_client + mock_init_api_client.side_effect = lambda *_: setattr( + self.vault_client, "_VaultClient__api_client", MagicMock() + ) + + self.vault_client.initialize_client_configuration() # first call — slow path + mock_get_credentials.reset_mock() + mock_get_vault_url.reset_mock() + mock_init_api_client.reset_mock() + + self.vault_client.initialize_client_configuration() # second call — fast path + + mock_get_credentials.assert_not_called() + mock_get_vault_url.assert_not_called() + mock_init_api_client.assert_not_called() + + @patch("skyflow.vault.client.client.get_credentials") + @patch("skyflow.vault.client.client.get_vault_url") + @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") + def test_initialize_client_configuration_fast_path_static_token( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials + ): + """Once initialized with a static token, subsequent calls skip all work.""" + mock_get_credentials.return_value = CREDENTIALS_WITH_TOKEN + mock_get_vault_url.return_value = "https://test-vault-url.com" + mock_init_api_client.side_effect = lambda *_: setattr( + self.vault_client, "_VaultClient__api_client", MagicMock() + ) + + self.vault_client.initialize_client_configuration() + mock_get_credentials.reset_mock() + mock_get_vault_url.reset_mock() + mock_init_api_client.reset_mock() + + self.vault_client.initialize_client_configuration() + + mock_get_credentials.assert_not_called() + mock_get_vault_url.assert_not_called() + mock_init_api_client.assert_not_called() + + # ------------------------------------------------------------------ # + # initialize_client_configuration — fast path (service account) # + # ------------------------------------------------------------------ # + + @patch("skyflow.vault.client.client.is_expired", return_value=False) + @patch("skyflow.vault.client.client.get_credentials") + @patch("skyflow.vault.client.client.get_vault_url") + @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") + def test_initialize_client_configuration_fast_path_valid_sa_token( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials, mock_is_expired + ): + """Service account with a still-valid token skips get_bearer_token entirely.""" + mock_get_credentials.return_value = CREDENTIALS_WITH_PATH + mock_get_vault_url.return_value = "https://test-vault-url.com" + + # Seed the cached bearer token as if first call already ran self.vault_client._VaultClient__api_client = MagicMock() - self.vault_client._VaultClient__api_client.records = MagicMock() - records_api = self.vault_client.get_records_api() - self.assertIsNotNone(records_api) + self.vault_client._VaultClient__is_static_token = False + self.vault_client._VaultClient__bearer_token = "cached_sa_token" + self.vault_client._VaultClient__credentials = CREDENTIALS_WITH_PATH - def test_get_tokens_api(self): + self.vault_client.initialize_client_configuration() + + mock_get_credentials.assert_not_called() + mock_get_vault_url.assert_not_called() + mock_init_api_client.assert_not_called() + + # ------------------------------------------------------------------ # + # initialize_client_configuration — token expiry (no client reinit) # + # ------------------------------------------------------------------ # + + @patch("skyflow.vault.client.client.generate_bearer_token", return_value=("new_sa_token", None)) + @patch("skyflow.vault.client.client.is_expired", return_value=True) + @patch("skyflow.vault.client.client.get_credentials") + @patch("skyflow.vault.client.client.get_vault_url") + @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") + def test_initialize_client_configuration_expired_token_no_reinit( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials, + mock_is_expired, mock_generate_bearer_token + ): + """Expired service account token is regenerated in-place; httpx client is NOT recreated.""" + mock_get_credentials.return_value = CREDENTIALS_WITH_PATH + mock_get_vault_url.return_value = "https://test-vault-url.com" + + # Client already initialized — simulate warm state with an expired token self.vault_client._VaultClient__api_client = MagicMock() - self.vault_client._VaultClient__api_client.tokens = MagicMock() - tokens_api = self.vault_client.get_tokens_api() - self.assertIsNotNone(tokens_api) + self.vault_client._VaultClient__is_static_token = False + self.vault_client._VaultClient__bearer_token = "expired_sa_token" + self.vault_client._VaultClient__credentials = CREDENTIALS_WITH_PATH - def test_get_query_api(self): + self.vault_client.initialize_client_configuration() + + # Token was regenerated + mock_generate_bearer_token.assert_called_once() + self.assertEqual( + self.vault_client._VaultClient__bearer_token, "new_sa_token" + ) + # httpx client was NOT recreated + mock_init_api_client.assert_not_called() + + # ------------------------------------------------------------------ # + # initialize_client_configuration — config update forces reinit # + # ------------------------------------------------------------------ # + + @patch("skyflow.vault.client.client.get_credentials") + @patch("skyflow.vault.client.client.get_vault_url") + @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") + def test_initialize_client_configuration_reinit_after_update_config( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials + ): + """update_config() marks the client stale; next call must recreate it.""" + mock_get_credentials.return_value = CREDENTIALS_WITH_API_KEY + mock_get_vault_url.return_value = "https://test-vault-url.com" + + # Simulate already-initialized client self.vault_client._VaultClient__api_client = MagicMock() - self.vault_client._VaultClient__api_client.query = MagicMock() - query_api = self.vault_client.get_query_api() - self.assertIsNotNone(query_api) + self.vault_client._VaultClient__is_static_token = True - def test_get_vault_id(self): - self.assertEqual(self.vault_client.get_vault_id(), CONFIG["vault_id"]) + self.vault_client.update_config({"cluster_id": "new_cluster"}) + self.vault_client.initialize_client_configuration() + + mock_get_credentials.assert_called_once() + mock_get_vault_url.assert_called_once() + mock_init_api_client.assert_called_once() + + # ------------------------------------------------------------------ # + # initialize_api_client — lambda token provider # + # ------------------------------------------------------------------ # + + @patch("skyflow.vault.client.client.Skyflow") + def test_initialize_api_client_passes_callable_token(self, mock_skyflow): + """initialize_api_client must pass a callable (lambda) as token, not a string.""" + self.vault_client.initialize_api_client("https://test-vault-url.com", "initial_token") + + args, kwargs = mock_skyflow.call_args + self.assertEqual(kwargs["base_url"], "https://test-vault-url.com") + self.assertTrue(callable(kwargs["token"]), "token must be a callable (lambda)") + + @patch("skyflow.vault.client.client.Skyflow") + def test_initialize_api_client_lambda_returns_cached_bearer_token(self, mock_skyflow): + """Lambda returns __bearer_token when it is set (interceptor behaviour).""" + self.vault_client._VaultClient__bearer_token = "refreshed_token" + self.vault_client.initialize_api_client("https://test-vault-url.com", "initial_token") + + _, kwargs = mock_skyflow.call_args + self.assertEqual(kwargs["token"](), "refreshed_token") + + @patch("skyflow.vault.client.client.Skyflow") + def test_initialize_api_client_lambda_falls_back_to_initial_token(self, mock_skyflow): + """Lambda falls back to the initial token when __bearer_token is None.""" + self.vault_client._VaultClient__bearer_token = None + self.vault_client.initialize_api_client("https://test-vault-url.com", "initial_token") + + _, kwargs = mock_skyflow.call_args + self.assertEqual(kwargs["token"](), "initial_token") + + # ------------------------------------------------------------------ # + # get_bearer_token # + # ------------------------------------------------------------------ # + + def test_get_bearer_token_with_api_key(self): + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_API_KEY) + self.assertEqual(result, "dummy_api_key") + + def test_get_bearer_token_with_static_token(self): + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_TOKEN) + self.assertEqual(result, "dummy_static_token") + + @patch("skyflow.vault.client.client.generate_bearer_token", return_value=("sa_token", None)) + def test_get_bearer_token_generates_from_path_on_first_call(self, mock_generate): + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_PATH) + mock_generate.assert_called_once() + self.assertEqual(result, "sa_token") + self.assertEqual(self.vault_client._VaultClient__bearer_token, "sa_token") + + @patch("skyflow.vault.client.client.generate_bearer_token_from_creds", return_value=("sa_token_str", None)) + @patch("skyflow.vault.client.client.log_info") + def test_get_bearer_token_generates_from_credentials_string(self, mock_log, mock_generate): + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_STRING) + mock_generate.assert_called_once() + self.assertEqual(result, "sa_token_str") + + @patch("skyflow.vault.client.client.generate_bearer_token", return_value=("new_token", None)) + @patch("skyflow.vault.client.client.is_expired", return_value=True) + @patch("skyflow.vault.client.client.log_info") + def test_get_bearer_token_regenerates_on_expiry(self, mock_log, mock_is_expired, mock_generate): + """Expired token is regenerated silently — no exception raised.""" + self.vault_client._VaultClient__bearer_token = "expired_token" + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_PATH) + mock_generate.assert_called_once() + self.assertEqual(result, "new_token") @patch("skyflow.vault.client.client.generate_bearer_token") - @patch("skyflow.vault.client.client.generate_bearer_token_from_creds") + @patch("skyflow.vault.client.client.is_expired", return_value=False) @patch("skyflow.vault.client.client.log_info") - def test_get_bearer_token_with_api_key(self, mock_log_info, mock_generate_bearer_token, - mock_generate_bearer_token_from_creds): - token = self.vault_client.get_bearer_token(CREDENTIALS_WITH_API_KEY) - self.assertEqual(token, CREDENTIALS_WITH_API_KEY["api_key"]) - - def test_update_config(self): - new_config = {"credentials": "new_credentials"} - self.vault_client.update_config(new_config) + def test_get_bearer_token_reuses_valid_cached_token(self, mock_log, mock_is_expired, mock_generate): + """Valid cached token is reused without calling generate_bearer_token.""" + self.vault_client._VaultClient__bearer_token = "valid_token" + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_PATH) + mock_generate.assert_not_called() + self.assertEqual(result, "valid_token") + + # ------------------------------------------------------------------ # + # update_config # + # ------------------------------------------------------------------ # + + def test_update_config_sets_flag(self): + self.vault_client.update_config({"credentials": "new_credentials"}) self.assertTrue(self.vault_client._VaultClient__is_config_updated) self.assertEqual(self.vault_client.get_config()["credentials"], "new_credentials") - def test_get_config(self): - self.assertEqual(self.vault_client.get_config(), CONFIG) + # ------------------------------------------------------------------ # + # API accessor stubs # + # ------------------------------------------------------------------ # - def test_get_common_skyflow_credentials(self): - credentials = {"api_key": "dummy_api_key"} - self.vault_client.set_common_skyflow_credentials(credentials) - self.assertEqual(self.vault_client.get_common_skyflow_credentials(), credentials) + def test_get_records_api(self): + self.vault_client._VaultClient__api_client = MagicMock() + self.assertIsNotNone(self.vault_client.get_records_api()) - def test_get_log_level(self): - log_level = "DEBUG" - self.vault_client.set_logger(log_level, MagicMock()) - self.assertEqual(self.vault_client.get_log_level(), log_level) + def test_get_tokens_api(self): + self.vault_client._VaultClient__api_client = MagicMock() + self.assertIsNotNone(self.vault_client.get_tokens_api()) - def test_get_logger(self): - mock_logger = MagicMock() - self.vault_client.set_logger("INFO", mock_logger) - self.assertEqual(self.vault_client.get_logger(), mock_logger) \ No newline at end of file + def test_get_query_api(self): + self.vault_client._VaultClient__api_client = MagicMock() + self.assertIsNotNone(self.vault_client.get_query_api()) + + +if __name__ == "__main__": + unittest.main()