diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py index b12c045548..8ebd9080ef 100644 --- a/google/cloud/aiplatform/models.py +++ b/google/cloud/aiplatform/models.py @@ -5155,6 +5155,14 @@ def resource_name(self) -> str: self._assert_gca_resource_is_available() return ModelRegistry._parse_versioned_name(self._gca_resource.name)[0] + def _sync_gca_resource(self) -> None: + """Sync GAPIC service representation of client class resource. + Uses versioned resource name so the non-default version is not lost. + """ + self._gca_resource = self._get_gca_resource( + resource_name=self.versioned_resource_name + ) + @property def name(self) -> str: """Name of this resource.""" diff --git a/tests/unit/aiplatform/test_models.py b/tests/unit/aiplatform/test_models.py index 1eb57c7804..411b1e1e15 100644 --- a/tests/unit/aiplatform/test_models.py +++ b/tests/unit/aiplatform/test_models.py @@ -4701,6 +4701,42 @@ def test_init_with_version_arg(self, get_model_with_version): # The Model yielded from upload SHOULD have a version in the versioned resource name assert model.versioned_resource_name.endswith(f"@{_TEST_VERSION_ID}") + def test_sync_gca_resource_uses_versioned_name(self, get_model_with_version): + # Regression test for https://github.com/googleapis/python-aiplatform/issues/2619 + # _sync_gca_resource must use versioned_resource_name so the non-default + # version is not silently replaced by the default version. + model = models.Model(model_name=_TEST_MODEL_NAME, version=_TEST_VERSION_ID) + get_model_with_version.reset_mock() + + model._sync_gca_resource() + + versioned_name = models.ModelRegistry._get_versioned_name( + _TEST_MODEL_PARENT, _TEST_VERSION_ID + ) + get_model_with_version.assert_called_once_with( + name=versioned_name, retry=base._DEFAULT_RETRY + ) + + def test_update_preserves_version( + self, update_model_mock, get_model_with_version + ): + # Regression test for https://github.com/googleapis/python-aiplatform/issues/2619 + # Model.update() calls _sync_gca_resource(); verify it fetches the versioned name. + model = models.Model(model_name=_TEST_MODEL_NAME, version=_TEST_VERSION_ID) + get_model_with_version.reset_mock() + + model.update(display_name=_TEST_MODEL_NAME) + + versioned_name = models.ModelRegistry._get_versioned_name( + _TEST_MODEL_PARENT, _TEST_VERSION_ID + ) + get_model_with_version.assert_called_once_with( + name=versioned_name, retry=base._DEFAULT_RETRY + ) + # Version must still be intact after update + assert model.version_id == _TEST_VERSION_ID + assert model.versioned_resource_name.endswith(f"@{_TEST_VERSION_ID}") + @pytest.mark.parametrize( "parent,location,project", [