def test_create_reasoning_engine_with_env_vars(
    self,
    create_reasoning_engine_mock,
    cloud_storage_create_bucket_mock,
    tarfile_open_mock,
    cloudpickle_dump_mock,
    get_gca_resource_mock,
):
    """Checks that `env_vars` passed to `create` lands in the deployment spec.

    A plain string value is expected under `deployment_spec.env` and a
    `SecretRef` value under `deployment_spec.secret_env` of the
    `ReasoningEngine` handed to the create mock.
    """
    reasoning_engines.ReasoningEngine.create(
        self.test_app,
        display_name=_TEST_REASONING_ENGINE_DISPLAY_NAME,
        requirements=_TEST_REASONING_ENGINE_REQUIREMENTS,
        extra_packages=[_TEST_REASONING_ENGINE_EXTRA_PACKAGE_PATH],
        env_vars={
            "TEST_ENV_VAR": "TEST_ENV_VAR_VALUE",
            "TEST_SECRET_ENV_VAR": types.SecretRef(
                secret="TEST_SECRET_NAME",
                version="TEST_SECRET_VERSION",
            ),
        },
    )
    want_reasoning_engine = types.ReasoningEngine(
        display_name=_TEST_REASONING_ENGINE_DISPLAY_NAME,
        spec=types.ReasoningEngineSpec(
            package_spec=types.ReasoningEngineSpec.PackageSpec(
                python_version=f"{sys.version_info.major}.{sys.version_info.minor}",
                pickle_object_gcs_uri=_TEST_REASONING_ENGINE_GCS_URI,
                dependency_files_gcs_uri=_TEST_REASONING_ENGINE_DEPENDENCY_FILES_GCS_URI,
                requirements_gcs_uri=_TEST_REASONING_ENGINE_REQUIREMENTS_GCS_URI,
            ),
            deployment_spec=types.ReasoningEngineSpec.DeploymentSpec(
                env=[
                    types.EnvVar(
                        name="TEST_ENV_VAR",
                        value="TEST_ENV_VAR_VALUE",
                    )
                ],
                secret_env=[
                    types.SecretEnvVar(
                        name="TEST_SECRET_ENV_VAR",
                        secret_ref=types.SecretRef(
                            secret="TEST_SECRET_NAME",
                            version="TEST_SECRET_VERSION",
                        ),
                    )
                ],
            ),
        ),
    )
    want_reasoning_engine.spec.class_methods.append(
        _TEST_REASONING_ENGINE_QUERY_SCHEMA
    )
    create_reasoning_engine_mock.assert_called_with(
        parent=_TEST_PARENT,
        reasoning_engine=want_reasoning_engine,
    )

@mock.patch.dict(os.environ, {"TEST_ENV_VAR_FROM_OS": "os-value"})
def test_generate_deployment_spec_from_env_var_names(self):
    """A list of names is resolved against the (patched) `os.environ`."""
    deployment_spec, update_masks = _utils._generate_deployment_spec_or_raise(
        env_vars=["TEST_ENV_VAR_FROM_OS"],
    )

    assert _utils.to_dict(deployment_spec) == {
        "env": [{"name": "TEST_ENV_VAR_FROM_OS", "value": "os-value"}]
    }
    assert update_masks == ["spec.deployment_spec.env"]

def test_generate_deployment_spec_from_secret_ref_dict(self):
    """A plain-dict value is converted into a secret env var entry."""
    deployment_spec, update_masks = _utils._generate_deployment_spec_or_raise(
        env_vars={
            "TEST_SECRET_ENV_VAR": {
                "secret": "TEST_SECRET_NAME",
                "version": "TEST_SECRET_VERSION",
            },
        },
    )

    assert _utils.to_dict(deployment_spec) == {
        "secretEnv": [
            {
                "name": "TEST_SECRET_ENV_VAR",
                "secretRef": {
                    "secret": "TEST_SECRET_NAME",
                    "version": "TEST_SECRET_VERSION",
                },
            }
        ]
    }
    assert update_masks == ["spec.deployment_spec.secret_env"]

def test_generate_deployment_spec_rejects_invalid_env_var_value_type(self):
    """A dict value that is neither str, dict nor SecretRef raises TypeError."""
    with pytest.raises(TypeError, match="Unknown value type in env_vars"):
        _utils._generate_deployment_spec_or_raise(env_vars={"TEST_ENV_VAR": 1})

@mock.patch.dict(os.environ)
def test_generate_deployment_spec_rejects_missing_env_var_name(self):
    """A listed name that is absent from `os.environ` raises ValueError."""
    # Guarantee the name is absent regardless of the ambient environment the
    # test runner was started with; `patch.dict` restores `os.environ` on exit,
    # so the removal cannot leak into other tests.
    os.environ.pop("MISSING_ENV_VAR", None)
    with pytest.raises(ValueError, match="Env var not found in os.environ"):
        _utils._generate_deployment_spec_or_raise(env_vars=["MISSING_ENV_VAR"])

def test_generate_deployment_spec_rejects_string_env_vars(self):
    """A bare string for `env_vars` raises TypeError."""
    with pytest.raises(TypeError, match="env_vars must be a list, tuple or a dict"):
        _utils._generate_deployment_spec_or_raise(env_vars="TEST_ENV_VAR")
def _update_deployment_spec_with_env_vars_dict_or_raise(
    *,
    deployment_spec: aip_types.ReasoningEngineSpec.DeploymentSpec,
    env_vars: Dict[str, Union[str, aip_types.SecretRef, Dict[str, Any]]],
) -> None:
    """Appends the entries of an `env_vars` dict to `deployment_spec` in place.

    Args:
        deployment_spec: The deployment spec to mutate.
        env_vars: Maps an env var name to its value. A `str` value is appended
            to `deployment_spec.env`. A `SecretRef`, or a plain dict that
            `to_proto` can convert into a `SecretRef`, is appended to
            `deployment_spec.secret_env`.

    Raises:
        ValueError: If a dict value cannot be converted to a `SecretRef`.
        TypeError: If a value is neither a str, a dict nor a `SecretRef`.
    """
    for key, value in env_vars.items():
        if isinstance(value, dict):
            try:
                secret_ref = to_proto(value, aip_types.SecretRef())
            except Exception as e:
                # Any conversion failure is reported uniformly; the original
                # error stays attached as the cause.
                raise ValueError(f"Failed to convert to secret ref: {value}") from e
            deployment_spec.secret_env.append(
                aip_types.SecretEnvVar(name=key, secret_ref=secret_ref)
            )
        elif isinstance(value, aip_types.SecretRef):
            deployment_spec.secret_env.append(
                aip_types.SecretEnvVar(name=key, secret_ref=value)
            )
        elif isinstance(value, str):
            deployment_spec.env.append(aip_types.EnvVar(name=key, value=value))
        else:
            raise TypeError(
                f"Unknown value type in env_vars for {key}. "
                f"Must be a str or SecretRef: {value}"
            )


def _update_deployment_spec_with_env_vars_list_or_raise(
    *,
    deployment_spec: aip_types.ReasoningEngineSpec.DeploymentSpec,
    env_vars: Sequence[str],
) -> None:
    """Appends env vars, looked up by name in `os.environ`, to `deployment_spec`.

    Args:
        deployment_spec: The deployment spec to mutate; entries are appended
            to `deployment_spec.env` in the order given.
        env_vars: Names of variables whose values are read from `os.environ`.

    Raises:
        TypeError: If an entry is not a str.
        ValueError: If a name is not present in `os.environ`.
    """
    for env_var in env_vars:
        if not isinstance(env_var, str):
            # Fail with an explicit message instead of the opaque TypeError
            # that `os.environ` raises for non-str keys.
            raise TypeError(
                "env_vars entries must be str names of variables in "
                f"os.environ, but got {type(env_var)}: {env_var!r}."
            )
        try:
            value = os.environ[env_var]
        except KeyError:
            raise ValueError(
                f"Env var not found in os.environ: {env_var}."
            ) from None
        deployment_spec.env.append(aip_types.EnvVar(name=env_var, value=value))


def _generate_deployment_spec_or_raise(
    *,
    env_vars: Optional[
        Union[
            Sequence[str],
            Dict[str, Union[str, aip_types.SecretRef, Dict[str, Any]]],
        ]
    ] = None,
    psc_interface_config: Optional[aip_types.PscInterfaceConfig] = None,
    min_instances: Optional[int] = None,
    max_instances: Optional[int] = None,
    resource_limits: Optional[Dict[str, str]] = None,
    container_concurrency: Optional[int] = None,
) -> Tuple[aip_types.ReasoningEngineSpec.DeploymentSpec, List[str]]:
    """Builds a `DeploymentSpec` and the field-mask paths for what was set.

    Args:
        env_vars: Either a list/tuple of names to read from `os.environ`, or a
            dict mapping names to a str value, a `SecretRef`, or a dict
            convertible to a `SecretRef`. Skipped when falsy.
        psc_interface_config: Copied onto the spec when truthy.
        min_instances: Copied onto the spec when not None (so 0 is kept).
        max_instances: Copied onto the spec when not None (so 0 is kept).
        resource_limits: Copied onto the spec when truthy.
        container_concurrency: Copied onto the spec when not None.

    Returns:
        A tuple `(deployment_spec, update_masks)`, where `update_masks` lists a
        `spec.deployment_spec.<field>` path for every field populated here.
        For `env_vars`, the `env` / `secret_env` paths are included only when
        the corresponding repeated field ended up non-empty.

    Raises:
        TypeError: If `env_vars` is truthy but not a list, tuple or dict, or
            contains a value of an unsupported type.
        ValueError: If a listed env var is missing from `os.environ`, or a
            dict value cannot be converted to a `SecretRef`.
    """
    deployment_spec = aip_types.ReasoningEngineSpec.DeploymentSpec()
    update_masks: List[str] = []
    if env_vars:
        # Start from empty repeated fields so that only entries derived from
        # `env_vars` are present.
        deployment_spec.env = []
        deployment_spec.secret_env = []
        if isinstance(env_vars, dict):
            _update_deployment_spec_with_env_vars_dict_or_raise(
                deployment_spec=deployment_spec,
                env_vars=env_vars,
            )
        elif isinstance(env_vars, (list, tuple)):
            # Deliberately narrower than `Sequence`: a bare str is a sequence
            # too and is rejected by the branch below.
            _update_deployment_spec_with_env_vars_list_or_raise(
                deployment_spec=deployment_spec,
                env_vars=env_vars,
            )
        else:
            raise TypeError(
                f"env_vars must be a list, tuple or a dict, but got {type(env_vars)}."
            )
        if deployment_spec.env:
            update_masks.append("spec.deployment_spec.env")
        if deployment_spec.secret_env:
            update_masks.append("spec.deployment_spec.secret_env")
    if psc_interface_config:
        deployment_spec.psc_interface_config = psc_interface_config
        update_masks.append("spec.deployment_spec.psc_interface_config")
    if min_instances is not None:
        deployment_spec.min_instances = min_instances
        update_masks.append("spec.deployment_spec.min_instances")
    if max_instances is not None:
        deployment_spec.max_instances = max_instances
        update_masks.append("spec.deployment_spec.max_instances")
    if resource_limits:
        deployment_spec.resource_limits = resource_limits
        update_masks.append("spec.deployment_spec.resource_limits")
    if container_concurrency is not None:
        deployment_spec.container_concurrency = container_concurrency
        update_masks.append("spec.deployment_spec.container_concurrency")
    return deployment_spec, update_masks