Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions google/cloud/aiplatform/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5155,6 +5155,14 @@ def resource_name(self) -> str:
self._assert_gca_resource_is_available()
return ModelRegistry._parse_versioned_name(self._gca_resource.name)[0]

def _sync_gca_resource(self) -> None:
"""Sync GAPIC service representation of client class resource.
Uses versioned resource name so the non-default version is not lost.
"""
self._gca_resource = self._get_gca_resource(
resource_name=self.versioned_resource_name
)

@property
def name(self) -> str:
"""Name of this resource."""
Expand Down
36 changes: 36 additions & 0 deletions tests/unit/aiplatform/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4701,6 +4701,42 @@ def test_init_with_version_arg(self, get_model_with_version):
# The Model yielded from upload SHOULD have a version in the versioned resource name
assert model.versioned_resource_name.endswith(f"@{_TEST_VERSION_ID}")

def test_sync_gca_resource_uses_versioned_name(self, get_model_with_version):
# Regression test for https://github.com/googleapis/python-aiplatform/issues/2619
# _sync_gca_resource must use versioned_resource_name so the non-default
# version is not silently replaced by the default version.
model = models.Model(model_name=_TEST_MODEL_NAME, version=_TEST_VERSION_ID)
get_model_with_version.reset_mock()

model._sync_gca_resource()

versioned_name = models.ModelRegistry._get_versioned_name(
_TEST_MODEL_PARENT, _TEST_VERSION_ID
)
get_model_with_version.assert_called_once_with(
name=versioned_name, retry=base._DEFAULT_RETRY
)

def test_update_preserves_version(
self, update_model_mock, get_model_with_version
):
# Regression test for https://github.com/googleapis/python-aiplatform/issues/2619
# Model.update() calls _sync_gca_resource(); verify it fetches the versioned name.
model = models.Model(model_name=_TEST_MODEL_NAME, version=_TEST_VERSION_ID)
get_model_with_version.reset_mock()

model.update(display_name=_TEST_MODEL_NAME)

versioned_name = models.ModelRegistry._get_versioned_name(
_TEST_MODEL_PARENT, _TEST_VERSION_ID
)
get_model_with_version.assert_called_once_with(
name=versioned_name, retry=base._DEFAULT_RETRY
)
# Version must still be intact after update
assert model.version_id == _TEST_VERSION_ID
assert model.versioned_resource_name.endswith(f"@{_TEST_VERSION_ID}")

@pytest.mark.parametrize(
"parent,location,project",
[
Expand Down