Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

if sys.version_info < (3, 8):
raise RuntimeError("skyflow requires Python 3.8+")
current_version = '2.0.0'
current_version = '2.0.0.dev0+e2aa629'

setup(
name='skyflow',
Expand Down
1 change: 0 additions & 1 deletion skyflow/error/_skyflow_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,4 @@ def __init__(self,
self.http_status = http_status if http_status else SkyflowMessages.HttpStatus.BAD_REQUEST.value
self.details = details
self.request_id = request_id
log_error(message, http_code, request_id, grpc_code, http_status, details)
super().__init__()
2 changes: 1 addition & 1 deletion skyflow/utils/_skyflow_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class Error(Enum):
EMPTY_CREDENTIALS_TOKEN = f"{error_prefix} Initialization failed. Invalid token.Specify a valid credentials token."
INVALID_CREDENTIALS_TOKEN_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid credentials token for {{}} with id {{}}. Expected token to be a string."
INVALID_CREDENTIALS_TOKEN = f"{error_prefix} Initialization failed. Invalid credentials token. Expected token to be a string."
EXPIRED_TOKEN = f"${error_prefix} Initialization failed. Given token is expired. Specify a valid credentials token."
EXPIRED_TOKEN = f"{error_prefix} Initialization failed. Given token is expired. Specify a valid credentials token."
EMPTY_API_KEY_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid api key for {{}} with id {{}}.Specify a valid api key."
EMPTY_API_KEY= f"{error_prefix} Initialization failed. Invalid api key.Specify a valid api key."
INVALID_API_KEY_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid api key for {{}} with id {{}}. Expected api key to be a string."
Expand Down
34 changes: 15 additions & 19 deletions skyflow/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,26 +30,18 @@
invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value

def get_credentials(config_level_creds = None, common_skyflow_creds = None, logger = None):
dotenv.load_dotenv()
dotenv_path = dotenv.find_dotenv(usecwd=True)
if dotenv_path:
load_dotenv(dotenv_path)
env_skyflow_credentials = os.getenv("SKYFLOW_CREDENTIALS")
if config_level_creds:
return config_level_creds
if common_skyflow_creds:
return common_skyflow_creds
dotenv_path = dotenv.find_dotenv(usecwd=True)
if dotenv_path:
load_dotenv(dotenv_path)
env_skyflow_credentials = os.getenv("SKYFLOW_CREDENTIALS")
if env_skyflow_credentials:
env_skyflow_credentials.strip()
try:
env_creds = env_skyflow_credentials.replace('\n', '\\n')
return {
'credentials_string': env_creds
}
except json.JSONDecodeError:
raise SkyflowError(SkyflowMessages.Error.INVALID_JSON_FORMAT_IN_CREDENTIALS_ENV.value, invalid_input_error_code)
else:
raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code)
env_creds = env_skyflow_credentials.strip().replace('\n', '\\n')
return {'credentials_string': env_creds}
raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code)

def validate_api_key(api_key: str, logger = None) -> bool:
if len(api_key) != 42:
Expand Down Expand Up @@ -185,8 +177,12 @@ def get_data_from_content_type(data, content_type):
return converted_data, files


_CACHED_METRICS: dict = {}

def get_metrics():
sdk_name_version = "skyflow-python@" + SDK_VERSION
global _CACHED_METRICS
if _CACHED_METRICS:
return _CACHED_METRICS

try:
sdk_client_device_model = platform.node()
Expand All @@ -203,13 +199,13 @@ def get_metrics():
except Exception:
sdk_runtime_details = ""

details_dic = {
'sdk_name_version': sdk_name_version,
_CACHED_METRICS = {
'sdk_name_version': "skyflow-python@" + SDK_VERSION,
'sdk_client_device_model': sdk_client_device_model,
'sdk_client_os_details': sdk_client_os_details,
'sdk_runtime_details': "Python " + sdk_runtime_details,
}
return details_dic
return _CACHED_METRICS

def parse_insert_response(api_response, continue_on_error):
# Retrieve the headers and data from the API response
Expand Down
2 changes: 1 addition & 1 deletion skyflow/utils/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
SDK_VERSION = '2.0.0'
SDK_VERSION = '2.0.0.dev0+e2aa629'
6 changes: 3 additions & 3 deletions skyflow/utils/validations/_validations.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ def validate_credentials(logger, credentials, config_id_type=None, config_id=Non
)
if is_expired(credentials.get("token"), logger):
raise SkyflowError(
SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value.format(config_id_type, config_id)
if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value,
SkyflowMessages.Error.EXPIRED_TOKEN.value
if config_id_type and config_id else SkyflowMessages.Error.EXPIRED_TOKEN.value,
invalid_input_error_code
)
elif "api_key" in credentials:
Expand Down Expand Up @@ -389,7 +389,7 @@ def validate_deidentify_file_request(logger, request: DeidentifyFileRequest):
if hasattr(request, 'wait_time') and request.wait_time is not None:
if not isinstance(request.wait_time, (int, float)):
raise SkyflowError(SkyflowMessages.Error.INVALID_WAIT_TIME.value, invalid_input_error_code)
if request.wait_time < 0 and request.wait_time > 64:
if request.wait_time < 0 or request.wait_time > 64:
raise SkyflowError(SkyflowMessages.Error.WAIT_TIME_GREATER_THEN_64.value, invalid_input_error_code)

def validate_insert_request(logger, request):
Expand Down
43 changes: 26 additions & 17 deletions skyflow/vault/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ def __init__(self, config):
self.__logger = None
self.__is_config_updated = False
self.__bearer_token = None
self.__credentials = None
self.__vault_url = None
self.__is_static_token = None

def set_common_skyflow_credentials(self, credentials):
self.__common_skyflow_credentials = credentials
Expand All @@ -23,16 +26,27 @@ def set_logger(self, log_level, logger):
self.__logger = logger

def initialize_client_configuration(self):
credentials = get_credentials(self.__config.get("credentials"), self.__common_skyflow_credentials, logger = self.__logger)
token = self.get_bearer_token(credentials)
vault_url = get_vault_url(self.__config.get("cluster_id"),
self.__config.get("env"),
self.__config.get("vault_id"),
logger = self.__logger)
self.initialize_api_client(vault_url, token)

def initialize_api_client(self, vault_url, token):
self.__api_client = Skyflow(base_url=vault_url, token=token)
if self.__api_client is not None and not self.__is_config_updated:
if self.__is_static_token:
return
if self.__bearer_token is not None and not is_expired(self.__bearer_token):
return

needs_reinit = self.__api_client is None or self.__is_config_updated
if needs_reinit:
self.__credentials = get_credentials(self.__config.get("credentials"), self.__common_skyflow_credentials, logger=self.__logger)
self.__vault_url = get_vault_url(self.__config.get("cluster_id"),
self.__config.get("env"),
self.__config.get("vault_id"),
logger=self.__logger)
self.__is_static_token = 'token' in self.__credentials or 'api_key' in self.__credentials
bearer_token = self.get_bearer_token(self.__credentials)
if needs_reinit:
self.initialize_api_client(self.__vault_url, bearer_token)

def initialize_api_client(self, vault_url, bearer_token):
token_provider = lambda: self.__bearer_token if self.__bearer_token else bearer_token # noqa: E731
self.__api_client = Skyflow(base_url=vault_url, token=token_provider)

def get_records_api(self):
return self.__api_client.records
Expand Down Expand Up @@ -63,11 +77,10 @@ def get_bearer_token(self, credentials):
"ctx": self.__config.get("ctx")
}

if self.__bearer_token is None or self.__is_config_updated:
if self.__bearer_token is None or self.__is_config_updated or is_expired(self.__bearer_token):
if 'path' in credentials:
path = credentials.get("path")
self.__bearer_token, _ = generate_bearer_token(
path,
credentials.get("path"),
options,
self.__logger
)
Expand All @@ -83,10 +96,6 @@ def get_bearer_token(self, credentials):
else:
log_info(SkyflowMessages.Info.REUSE_BEARER_TOKEN.value, self.__logger)

if is_expired(self.__bearer_token):
self.__is_config_updated = True
raise SyntaxError(SkyflowMessages.Error.EXPIRED_TOKEN.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value)

return self.__bearer_token

def update_config(self, config):
Expand Down
26 changes: 13 additions & 13 deletions skyflow/vault/controller/_detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __poll_for_processed_file(self, run_id, max_wait_time=64):
current_wait_time = 1 # Start with 1 second
try:
while True:
response = files_api.get_run(run_id, vault_id=self.__vault_client.get_vault_id(), request_options=self.__get_headers()).data
response = files_api.get_run(run_id, vault_id=self.__vault_client.get_vault_id(), request_options={'additional_headers': self.__get_headers()}).data
status = response.status
if status == 'IN_PROGRESS':
if current_wait_time >= max_wait_time:
Expand Down Expand Up @@ -228,7 +228,7 @@ def deidentify_text(self, request: DeidentifyTextRequest) -> DeidentifyTextRespo
restrict_regex=deidentify_text_body['restrict_regex'],
token_type=deidentify_text_body['token_type'],
transformations=deidentify_text_body['transformations'],
request_options=self.__get_headers()
request_options={'additional_headers': self.__get_headers()}
)
deidentify_text_response = parse_deidentify_text_response(api_response)
log_info(SkyflowMessages.Info.DEIDENTIFY_TEXT_SUCCESS.value, self.__vault_client.get_logger())
Expand All @@ -252,7 +252,7 @@ def reidentify_text(self, request: ReidentifyTextRequest) -> ReidentifyTextRespo
vault_id=self.__vault_client.get_vault_id(),
text=reidentify_text_body['text'],
format=reidentify_text_body['format'],
request_options=self.__get_headers()
request_options={'additional_headers': self.__get_headers()}
)
reidentify_text_response = parse_reidentify_text_response(api_response)
log_info(SkyflowMessages.Info.REIDENTIFY_TEXT_SUCCESS.value, self.__vault_client.get_logger())
Expand Down Expand Up @@ -296,7 +296,7 @@ def deidentify_file(self, request: DeidentifyFileRequest):
'allow_regex': request.allow_regex_list,
'restrict_regex': request.restrict_regex_list,
'transformations': self.__get_transformations(request),
'request_options': self.__get_headers()
'request_options': {'additional_headers': self.__get_headers()}
}

elif file_extension in ['mp3', 'wav']:
Expand All @@ -316,7 +316,7 @@ def deidentify_file(self, request: DeidentifyFileRequest):
'bleep_frequency': getattr(request, 'bleep', None).frequency if getattr(request, 'bleep', None) is not None else None,
'bleep_start_padding': getattr(request, 'bleep', None).start_padding if getattr(request, 'bleep', None) is not None else None,
'bleep_stop_padding': getattr(request, 'bleep', None).stop_padding if getattr(request, 'bleep', None) is not None else None,
'request_options': self.__get_headers()
'request_options': {'additional_headers': self.__get_headers()}
}

elif file_extension == 'pdf':
Expand All @@ -331,7 +331,7 @@ def deidentify_file(self, request: DeidentifyFileRequest):
'restrict_regex': request.restrict_regex_list,
'max_resolution': getattr(request, 'max_resolution', None),
'density': getattr(request, 'pixel_density', None),
'request_options': self.__get_headers()
'request_options': {'additional_headers': self.__get_headers()}
}

elif file_extension in ['jpeg', 'jpg', 'png', 'bmp', 'tif', 'tiff']:
Expand All @@ -347,7 +347,7 @@ def deidentify_file(self, request: DeidentifyFileRequest):
'masking_method': getattr(request, 'masking_method', None),
'output_ocr_text': getattr(request, 'output_ocr_text', None),
'output_processed_image': getattr(request, 'output_processed_image', None),
'request_options': self.__get_headers()
'request_options': {'additional_headers': self.__get_headers()}
}

elif file_extension in ['ppt', 'pptx']:
Expand All @@ -360,7 +360,7 @@ def deidentify_file(self, request: DeidentifyFileRequest):
'token_type': self.__get_token_format(request),
'allow_regex': request.allow_regex_list,
'restrict_regex': request.restrict_regex_list,
'request_options': self.__get_headers()
'request_options': {'additional_headers': self.__get_headers()}
}

elif file_extension in ['csv', 'xls', 'xlsx']:
Expand All @@ -373,7 +373,7 @@ def deidentify_file(self, request: DeidentifyFileRequest):
'token_type': self.__get_token_format(request),
'allow_regex': request.allow_regex_list,
'restrict_regex': request.restrict_regex_list,
'request_options': self.__get_headers()
'request_options': {'additional_headers': self.__get_headers()}
}

elif file_extension in ['doc', 'docx']:
Expand All @@ -386,7 +386,7 @@ def deidentify_file(self, request: DeidentifyFileRequest):
'token_type': self.__get_token_format(request),
'allow_regex': request.allow_regex_list,
'restrict_regex': request.restrict_regex_list,
'request_options': self.__get_headers()
'request_options': {'additional_headers': self.__get_headers()}
}

elif file_extension in ['json', 'xml']:
Expand All @@ -400,7 +400,7 @@ def deidentify_file(self, request: DeidentifyFileRequest):
'allow_regex': request.allow_regex_list,
'restrict_regex': request.restrict_regex_list,
'transformations': self.__get_transformations(request),
'request_options': self.__get_headers()
'request_options': {'additional_headers': self.__get_headers()}
}

else:
Expand All @@ -414,7 +414,7 @@ def deidentify_file(self, request: DeidentifyFileRequest):
'allow_regex': request.allow_regex_list,
'restrict_regex': request.restrict_regex_list,
'transformations': self.__get_transformations(request),
'request_options': self.__get_headers()
'request_options': {'additional_headers': self.__get_headers()}
}

log_info(SkyflowMessages.Info.DETECT_FILE_REQUEST_RESOLVED.value, self.__vault_client.get_logger())
Expand Down Expand Up @@ -448,7 +448,7 @@ def get_detect_run(self, request: GetDetectRunRequest):
response = files_api.get_run(
run_id,
vault_id=self.__vault_client.get_vault_id(),
request_options=self.__get_headers()
request_options={'additional_headers': self.__get_headers()}
)
if response.data.status == 'IN_PROGRESS':
parsed_response = self.__parse_deidentify_file_response(DeidentifyFileResponse(run_id=run_id, status='IN_PROGRESS'))
Expand Down
Loading
Loading