Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions tests/unit/vertex_langchain/test_reasoning_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,109 @@ def test_create_reasoning_engine(
retry=_TEST_RETRY,
)

def test_create_reasoning_engine_with_env_vars(
self,
create_reasoning_engine_mock,
cloud_storage_create_bucket_mock,
tarfile_open_mock,
cloudpickle_dump_mock,
get_gca_resource_mock,
):
reasoning_engines.ReasoningEngine.create(
self.test_app,
display_name=_TEST_REASONING_ENGINE_DISPLAY_NAME,
requirements=_TEST_REASONING_ENGINE_REQUIREMENTS,
extra_packages=[_TEST_REASONING_ENGINE_EXTRA_PACKAGE_PATH],
env_vars={
"TEST_ENV_VAR": "TEST_ENV_VAR_VALUE",
"TEST_SECRET_ENV_VAR": types.SecretRef(
secret="TEST_SECRET_NAME",
version="TEST_SECRET_VERSION",
),
},
)
want_reasoning_engine = types.ReasoningEngine(
display_name=_TEST_REASONING_ENGINE_DISPLAY_NAME,
spec=types.ReasoningEngineSpec(
package_spec=types.ReasoningEngineSpec.PackageSpec(
python_version=f"{sys.version_info.major}.{sys.version_info.minor}",
pickle_object_gcs_uri=_TEST_REASONING_ENGINE_GCS_URI,
dependency_files_gcs_uri=_TEST_REASONING_ENGINE_DEPENDENCY_FILES_GCS_URI,
requirements_gcs_uri=_TEST_REASONING_ENGINE_REQUIREMENTS_GCS_URI,
),
deployment_spec=types.ReasoningEngineSpec.DeploymentSpec(
env=[
types.EnvVar(
name="TEST_ENV_VAR",
value="TEST_ENV_VAR_VALUE",
)
],
secret_env=[
types.SecretEnvVar(
name="TEST_SECRET_ENV_VAR",
secret_ref=types.SecretRef(
secret="TEST_SECRET_NAME",
version="TEST_SECRET_VERSION",
),
)
],
),
),
)
want_reasoning_engine.spec.class_methods.append(
_TEST_REASONING_ENGINE_QUERY_SCHEMA
)
create_reasoning_engine_mock.assert_called_with(
parent=_TEST_PARENT,
reasoning_engine=want_reasoning_engine,
)

@mock.patch.dict(os.environ, {"TEST_ENV_VAR_FROM_OS": "os-value"})
def test_generate_deployment_spec_from_env_var_names(self):
deployment_spec, update_masks = _utils._generate_deployment_spec_or_raise(
env_vars=["TEST_ENV_VAR_FROM_OS"],
)

assert _utils.to_dict(deployment_spec) == {
"env": [{"name": "TEST_ENV_VAR_FROM_OS", "value": "os-value"}]
}
assert update_masks == ["spec.deployment_spec.env"]

def test_generate_deployment_spec_from_secret_ref_dict(self):
deployment_spec, update_masks = _utils._generate_deployment_spec_or_raise(
env_vars={
"TEST_SECRET_ENV_VAR": {
"secret": "TEST_SECRET_NAME",
"version": "TEST_SECRET_VERSION",
},
},
)

assert _utils.to_dict(deployment_spec) == {
"secretEnv": [
{
"name": "TEST_SECRET_ENV_VAR",
"secretRef": {
"secret": "TEST_SECRET_NAME",
"version": "TEST_SECRET_VERSION",
},
}
]
}
assert update_masks == ["spec.deployment_spec.secret_env"]

def test_generate_deployment_spec_rejects_invalid_env_var_value_type(self):
with pytest.raises(TypeError, match="Unknown value type in env_vars"):
_utils._generate_deployment_spec_or_raise(env_vars={"TEST_ENV_VAR": 1})

def test_generate_deployment_spec_rejects_missing_env_var_name(self):
with pytest.raises(ValueError, match="Env var not found in os.environ"):
_utils._generate_deployment_spec_or_raise(env_vars=["MISSING_ENV_VAR"])

def test_generate_deployment_spec_rejects_string_env_vars(self):
with pytest.raises(TypeError, match="env_vars must be a list, tuple or a dict"):
_utils._generate_deployment_spec_or_raise(env_vars="TEST_ENV_VAR")

@pytest.mark.usefixtures("caplog")
def test_create_reasoning_engine_warn_resource_name(
self,
Expand Down
8 changes: 8 additions & 0 deletions vertexai/reasoning_engines/_reasoning_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ def create(
gcs_dir_name: str = _DEFAULT_GCS_DIR_NAME,
sys_version: Optional[str] = None,
extra_packages: Optional[Sequence[str]] = None,
env_vars: Optional[
Union[Sequence[str], Dict[str, Union[str, aip_types.SecretRef]]]
] = None,
) -> "ReasoningEngine":
"""Creates a new ReasoningEngine.

Expand Down Expand Up @@ -301,6 +304,11 @@ def create(
reasoning_engine_spec = aip_types.ReasoningEngineSpec(
package_spec=package_spec,
)
if env_vars:
deployment_spec, _ = _utils._generate_deployment_spec_or_raise(
env_vars=env_vars,
)
reasoning_engine_spec.deployment_spec = deployment_spec
class_methods_spec = _generate_class_methods_spec_or_raise(
reasoning_engine, _get_registered_operations(reasoning_engine)
)
Expand Down
107 changes: 106 additions & 1 deletion vertexai/reasoning_engines/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,26 @@
import dataclasses
import inspect
import json
import os
import types
import typing
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Union
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
Union,
)

import proto

from google.cloud.aiplatform import base
from google.cloud.aiplatform_v1beta1 import types as aip_types
from google.api import httpbody_pb2
from google.protobuf import struct_pb2
from google.protobuf import json_format
Expand Down Expand Up @@ -54,6 +67,98 @@
_LOGGER = base.Logger(__name__)


def _update_deployment_spec_with_env_vars_dict_or_raise(
    *,
    deployment_spec: aip_types.ReasoningEngineSpec.DeploymentSpec,
    env_vars: Dict[str, Union[str, aip_types.SecretRef]],
) -> None:
    """Appends entries from ``env_vars`` onto the deployment spec in place.

    String values become plain ``env`` entries; ``SecretRef`` values (or dicts
    convertible to one) become ``secret_env`` entries.

    Raises:
        ValueError: If a dict value cannot be converted to a ``SecretRef``.
        TypeError: If a value is neither a str, a dict, nor a ``SecretRef``.
    """
    for name, value in env_vars.items():
        # Plain string: a regular environment variable.
        if isinstance(value, str):
            deployment_spec.env.append(aip_types.EnvVar(name=name, value=value))
            continue
        # A dict is treated as a serialized SecretRef; convert it first.
        if isinstance(value, dict):
            try:
                value = to_proto(value, aip_types.SecretRef())
            except Exception as e:
                raise ValueError(f"Failed to convert to secret ref: {value}") from e
        if isinstance(value, aip_types.SecretRef):
            deployment_spec.secret_env.append(
                aip_types.SecretEnvVar(name=name, secret_ref=value)
            )
        else:
            raise TypeError(
                f"Unknown value type in env_vars for {name}. "
                f"Must be a str or SecretRef: {value}"
            )


def _update_deployment_spec_with_env_vars_list_or_raise(
    *,
    deployment_spec: aip_types.ReasoningEngineSpec.DeploymentSpec,
    env_vars: Sequence[str],
) -> None:
    """Copies the current os.environ value of each named variable into the spec.

    Raises:
        ValueError: If a named variable is not present in ``os.environ``.
            Variables earlier in ``env_vars`` will already have been appended.
    """
    for name in env_vars:
        # os.environ values are always strings, so None means "not set".
        value = os.environ.get(name)
        if value is None:
            raise ValueError(f"Env var not found in os.environ: {name}.")
        deployment_spec.env.append(aip_types.EnvVar(name=name, value=value))


def _generate_deployment_spec_or_raise(
    *,
    env_vars: Optional[
        Union[Sequence[str], Dict[str, Union[str, aip_types.SecretRef]]]
    ] = None,
    psc_interface_config: Optional[aip_types.PscInterfaceConfig] = None,
    min_instances: Optional[int] = None,
    max_instances: Optional[int] = None,
    resource_limits: Optional[Dict[str, str]] = None,
    container_concurrency: Optional[int] = None,
) -> Tuple[aip_types.ReasoningEngineSpec.DeploymentSpec, List[str]]:
    """Builds a DeploymentSpec and the update-mask paths for fields it sets.

    Args:
        env_vars: Either a list/tuple of variable names whose values are read
            from ``os.environ``, or a dict mapping names to values, where a
            value may be a plain string, a ``SecretRef``, or a dict
            convertible to one (for secret-backed variables).
        psc_interface_config: Optional PSC interface configuration.
        min_instances: Optional minimum replica count.
        max_instances: Optional maximum replica count.
        resource_limits: Optional container resource limits.
        container_concurrency: Optional per-container concurrency.

    Returns:
        A tuple ``(deployment_spec, update_masks)`` where ``update_masks``
        lists the ``spec.deployment_spec.*`` paths that were populated.

    Raises:
        TypeError: If ``env_vars`` is neither a list/tuple nor a dict, or a
            dict value has an unsupported type.
        ValueError: If a named variable is missing from ``os.environ``, or a
            dict value fails to convert to a ``SecretRef``.
    """
    deployment_spec = aip_types.ReasoningEngineSpec.DeploymentSpec()
    update_masks: List[str] = []
    if env_vars:
        # Note: a freshly constructed spec already has empty env/secret_env
        # repeated fields, so no explicit reset is needed before appending.
        if isinstance(env_vars, dict):
            _update_deployment_spec_with_env_vars_dict_or_raise(
                deployment_spec=deployment_spec,
                env_vars=env_vars,
            )
        elif isinstance(env_vars, (list, tuple)):
            _update_deployment_spec_with_env_vars_list_or_raise(
                deployment_spec=deployment_spec,
                env_vars=env_vars,
            )
        else:
            raise TypeError(
                f"env_vars must be a list, tuple or a dict, but got {type(env_vars)}."
            )
        # Only report masks for the sub-fields that actually got entries.
        if deployment_spec.env:
            update_masks.append("spec.deployment_spec.env")
        if deployment_spec.secret_env:
            update_masks.append("spec.deployment_spec.secret_env")
    if psc_interface_config:
        deployment_spec.psc_interface_config = psc_interface_config
        update_masks.append("spec.deployment_spec.psc_interface_config")
    if min_instances is not None:
        deployment_spec.min_instances = min_instances
        update_masks.append("spec.deployment_spec.min_instances")
    if max_instances is not None:
        deployment_spec.max_instances = max_instances
        update_masks.append("spec.deployment_spec.max_instances")
    if resource_limits:
        deployment_spec.resource_limits = resource_limits
        update_masks.append("spec.deployment_spec.resource_limits")
    if container_concurrency is not None:
        deployment_spec.container_concurrency = container_concurrency
        update_masks.append("spec.deployment_spec.container_concurrency")
    return deployment_spec, update_masks


def to_proto(
obj: Union[JsonDict, proto.Message],
message: Optional[proto.Message] = None,
Expand Down