diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py index b12c045548..0df059125c 100644 --- a/google/cloud/aiplatform/models.py +++ b/google/cloud/aiplatform/models.py @@ -5240,8 +5240,8 @@ def __init__( # Create ModelRegistry with the unversioned resource name self._registry = ModelRegistry( self.resource_name, - location=location, - project=project, + location=location or self.location, + project=project or self.project, credentials=credentials, ) diff --git a/tests/unit/aiplatform/test_models.py b/tests/unit/aiplatform/test_models.py index 1eb57c7804..bd008e0ed8 100644 --- a/tests/unit/aiplatform/test_models.py +++ b/tests/unit/aiplatform/test_models.py @@ -4701,6 +4701,33 @@ def test_init_with_version_arg(self, get_model_with_version): # The Model yielded from upload SHOULD have a version in the versioned resource name assert model.versioned_resource_name.endswith(f"@{_TEST_VERSION_ID}") + def test_versioning_registry_uses_location_from_resource_name( + self, create_client_mock + ): + # Regression test for https://github.com/googleapis/python-aiplatform/issues/2608: + # When a Model is initialized with a fully-qualified resource name that encodes a + # non-default location, the versioning registry client must use that location, not + # the global default from aiplatform.init(). + models.Model(_TEST_MODEL_RESOURCE_NAME_CUSTOM_LOCATION) + create_client_mock.assert_any_call( + client_class=utils.ModelClientWithOverride, + credentials=initializer.global_config.credentials, + location_override=_TEST_LOCATION_2, + appended_user_agent=None, + ) + + def test_versioning_registry_uses_project_from_resource_name( + self, get_model_with_custom_project_mock + ): + # Regression test for https://github.com/googleapis/python-aiplatform/issues/2608: + # When a Model is initialized with a fully-qualified resource name that encodes a + # non-default project, the versioning registry must use that project, not the + # global default from aiplatform.init(). + model = models.Model(_TEST_MODEL_RESOURCE_NAME_CUSTOM_PROJECT) + assert model._registry.model_resource_name.startswith( + f"projects/{_TEST_PROJECT_2}/" + ) + @pytest.mark.parametrize( "parent,location,project", [