From 2d473a76c7db35d02569d84369ff2431ccd20f61 Mon Sep 17 00:00:00 2001 From: Christian Leopoldseder Date: Mon, 27 Apr 2026 07:44:39 -0700 Subject: [PATCH] feat: GenAI SDK client(multimodal) - Add `to_batch_job_source` and `get_batch_job_destination` to `MultimodalDataset` PiperOrigin-RevId: 906352851 --- .../test_create_multimodal_datasets.py | 45 +++--- .../genai/test_multimodal_datasets_genai.py | 137 +++++++++++------- vertexai/_genai/_datasets_utils.py | 11 +- vertexai/_genai/types/common.py | 22 +++ 4 files changed, 142 insertions(+), 73 deletions(-) diff --git a/tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py b/tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py index 9fa1711ac9..830390507f 100644 --- a/tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py +++ b/tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py @@ -74,11 +74,20 @@ def mock_import_bigframes(is_replay_mode): @pytest.fixture def mock_generate_multimodal_dataset_display_name(): - with mock.patch.object( + with mock.patch.object( _datasets_utils, "generate_multimodal_dataset_display_name" ) as mock_generate: - mock_generate.return_value = "test-generated-name" - yield mock_generate + mock_generate.return_value = "test-generated-name" + yield mock_generate + + +@pytest.fixture +def mock_get_batch_job_unique_name(): + with mock.patch.object( + _datasets_utils, "get_batch_job_unique_name" + ) as mock_unique_name: + mock_unique_name.return_value = "12345678901234_abcde" + yield mock_unique_name def test_create_dataset(client): @@ -169,21 +178,21 @@ def test_create_dataset_from_pandas(client, is_replay_mode): ) @pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes") def test_create_dataset_from_bigframes(client, is_replay_mode): - import bigframes.pandas + import bigframes.pandas - dataframe = pd.DataFrame( + dataframe = pd.DataFrame( { "col1": ["col1"], "col2": ["col2"], } ) - if is_replay_mode: - bf_dataframe = mock.MagicMock() - bf_dataframe.to_gbq.return_value = "temp_table_id" - else: - bf_dataframe = bigframes.pandas.DataFrame(dataframe) + if is_replay_mode: + bf_dataframe = mock.MagicMock() + bf_dataframe.to_gbq.return_value = "temp_table_id" + else: + bf_dataframe = bigframes.pandas.DataFrame(dataframe) - dataset = client.datasets.create_from_bigframes( + dataset = client.datasets.create_from_bigframes( dataframe=bf_dataframe, target_table_id=BIGQUERY_TABLE_NAME, multimodal_dataset={ @@ -191,21 +200,21 @@ def test_create_dataset_from_bigframes(client, is_replay_mode): }, ) - assert isinstance(dataset, types.MultimodalDataset) - assert dataset.display_name == "test-from-bigframes" - assert dataset.metadata.input_config.bigquery_source.uri == ( + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.display_name == "test-from-bigframes" + assert dataset.metadata.input_config.bigquery_source.uri == ( f"bq://{BIGQUERY_TABLE_NAME}" ) - if not is_replay_mode: - bigquery_client = bigquery.Client( + if not is_replay_mode: + bigquery_client = bigquery.Client( project=client._api_client.project, location=client._api_client.location, credentials=client._api_client._credentials, ) - rows = bigquery_client.list_rows( + rows = bigquery_client.list_rows( dataset.metadata.input_config.bigquery_source.uri[5:] ) - pd.testing.assert_frame_equal( + pd.testing.assert_frame_equal( rows.to_dataframe(), dataframe, check_index_type=False ) diff --git a/tests/unit/vertexai/genai/test_multimodal_datasets_genai.py b/tests/unit/vertexai/genai/test_multimodal_datasets_genai.py index c120bcc95c..28a961ba85 100644 --- a/tests/unit/vertexai/genai/test_multimodal_datasets_genai.py +++ b/tests/unit/vertexai/genai/test_multimodal_datasets_genai.py @@ -23,23 +23,32 @@ @pytest.fixture def mock_import_bigframes(): - with mock.patch.object( + with mock.patch.object( _datasets_utils, "_try_import_bigframes" ) as mock_import_bigframes: - mock_read_gbq_table_result = mock.MagicMock() - mock_read_gbq_table_result.sql = "SELECT * FROM `project.dataset.table`" + mock_read_gbq_table_result = mock.MagicMock() + mock_read_gbq_table_result.sql = "SELECT * FROM `project.dataset.table`" - bigframes = mock.MagicMock() - bigframes.pandas.read_gbq_table.return_value = mock_read_gbq_table_result + bigframes = mock.MagicMock() + bigframes.pandas.read_gbq_table.return_value = mock_read_gbq_table_result - mock_import_bigframes.return_value = bigframes - yield mock_import_bigframes + mock_import_bigframes.return_value = bigframes + yield mock_import_bigframes + + +@pytest.fixture +def mock_get_batch_job_unique_name(): + with mock.patch.object( + _datasets_utils, "get_batch_job_unique_name" + ) as mock_unique_name: + mock_unique_name.return_value = "12345678901234_abcde" + yield mock_unique_name class TestMultimodalDataset: - def test_read_config(self): - dataset = types.MultimodalDataset( + def test_read_config(self): + dataset = types.MultimodalDataset( metadata={ "gemini_request_read_config": { "assembled_request_column_name": "test_column", @@ -47,30 +56,30 @@ def test_read_config(self): }, ) - assert isinstance(dataset.read_config, types.GeminiRequestReadConfig) - assert dataset.read_config.assembled_request_column_name == "test_column" + assert isinstance(dataset.read_config, types.GeminiRequestReadConfig) + assert dataset.read_config.assembled_request_column_name == "test_column" - def test_read_config_empty(self): - dataset = types.MultimodalDataset() - assert dataset.read_config is None + def test_read_config_empty(self): + dataset = types.MultimodalDataset() + assert dataset.read_config is None - def test_set_read_config(self): - dataset = types.MultimodalDataset() + def test_set_read_config(self): + dataset = types.MultimodalDataset() - dataset.set_read_config( + dataset.set_read_config( read_config={ "assembled_request_column_name": "test_column", }, ) - assert isinstance(dataset, types.MultimodalDataset) - assert ( + assert isinstance(dataset, types.MultimodalDataset) + assert ( dataset.metadata.gemini_request_read_config.assembled_request_column_name == "test_column" ) - def test_set_read_config_preserves_other_fields(self): - dataset = types.MultimodalDataset( + def test_set_read_config_preserves_other_fields(self): + dataset = types.MultimodalDataset( metadata={ "inputConfig": { "bigquerySource": {"uri": "bq://test_table"}, @@ -78,21 +87,21 @@ def test_set_read_config_preserves_other_fields(self): }, ) - dataset.set_read_config( + dataset.set_read_config( read_config={ "assembled_request_column_name": "test_column", }, ) - assert isinstance(dataset, types.MultimodalDataset) - assert ( + assert isinstance(dataset, types.MultimodalDataset) + assert ( dataset.metadata.gemini_request_read_config.assembled_request_column_name == "test_column" ) - assert dataset.metadata.input_config.bigquery_source.uri == "bq://test_table" + assert dataset.metadata.input_config.bigquery_source.uri == "bq://test_table" - def test_bigquery_uri(self): - dataset = types.MultimodalDataset( + def test_bigquery_uri(self): + dataset = types.MultimodalDataset( metadata={ "inputConfig": { "bigquerySource": {"uri": "bq://project.dataset.table"}, @@ -100,36 +109,36 @@ def test_bigquery_uri(self): }, ) - assert dataset.bigquery_uri == "bq://project.dataset.table" + assert dataset.bigquery_uri == "bq://project.dataset.table" - def test_bigquery_uri_empty(self): - dataset = types.MultimodalDataset() - assert dataset.bigquery_uri is None + def test_bigquery_uri_empty(self): + dataset = types.MultimodalDataset() + assert dataset.bigquery_uri is None - def test_set_bigquery_uri(self): - dataset = types.MultimodalDataset() + def test_set_bigquery_uri(self): + dataset = types.MultimodalDataset() - dataset.set_bigquery_uri("bq://project.dataset.table") + dataset.set_bigquery_uri("bq://project.dataset.table") - assert isinstance(dataset, types.MultimodalDataset) - assert ( + assert isinstance(dataset, types.MultimodalDataset) + assert ( dataset.metadata.input_config.bigquery_source.uri == "bq://project.dataset.table" ) - def test_set_bigquery_uri_without_prefix(self): - dataset = types.MultimodalDataset() + def test_set_bigquery_uri_without_prefix(self): + dataset = types.MultimodalDataset() - dataset.set_bigquery_uri("project.dataset.table") + dataset.set_bigquery_uri("project.dataset.table") - assert isinstance(dataset, types.MultimodalDataset) - assert ( + assert isinstance(dataset, types.MultimodalDataset) + assert ( dataset.metadata.input_config.bigquery_source.uri == "bq://project.dataset.table" ) - def test_set_bigquery_uri_preserves_other_fields(self): - dataset = types.MultimodalDataset( + def test_set_bigquery_uri_preserves_other_fields(self): + dataset = types.MultimodalDataset( metadata={ "gemini_request_read_config": { "assembled_request_column_name": "test_column", @@ -137,26 +146,48 @@ def test_set_bigquery_uri_preserves_other_fields(self): }, ) - dataset.set_bigquery_uri("bq://test_table") + dataset.set_bigquery_uri("bq://test_table") - assert isinstance(dataset, types.MultimodalDataset) - assert dataset.metadata.input_config.bigquery_source.uri == "bq://test_table" - assert ( + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.metadata.input_config.bigquery_source.uri == "bq://test_table" + assert ( dataset.metadata.gemini_request_read_config.assembled_request_column_name == "test_column" ) - def test_to_bigframes(self, mock_import_bigframes): - dataset = types.MultimodalDataset() - dataset.set_bigquery_uri("bq://project.dataset.table") + def test_to_bigframes(self, mock_import_bigframes): + dataset = types.MultimodalDataset() + dataset.set_bigquery_uri("bq://project.dataset.table") - df = dataset.to_bigframes() + df = dataset.to_bigframes() - assert "project.dataset.table" in df.sql - mock_import_bigframes.return_value.pandas.read_gbq_table.assert_called_once_with( + assert "project.dataset.table" in df.sql + mock_import_bigframes.return_value.pandas.read_gbq_table.assert_called_once_with( "project.dataset.table" ) + def test_get_batch_job_destination(self, mock_get_batch_job_unique_name): + dataset = types.MultimodalDataset( + name="projects/vertex-sdk-dev/locations/us-central1/datasets/12345", + display_name="test_multimodal_dataset", + metadata={ + "inputConfig": { + "bigquerySource": { + "uri": "bq://target_project.target_dataset.target_table" + }, + }, + }, + ) + destination = dataset.get_batch_job_destination() + assert ( + destination.vertex_dataset.display_name + == "test_multimodal_dataset_batch_output_12345678901234_abcde" + ) + assert ( + destination.vertex_dataset.bigquery_destination + == "bq://target_project.target_dataset.target_table_batch_output_12345678901234_abcde" + ) + class TestGeminiRequestReadConfig: def test_single_turn_template(self): diff --git a/vertexai/_genai/_datasets_utils.py b/vertexai/_genai/_datasets_utils.py index e063e6802a..aae54c6e53 100644 --- a/vertexai/_genai/_datasets_utils.py +++ b/vertexai/_genai/_datasets_utils.py @@ -229,8 +229,15 @@ def _generate_target_table_id(dataset_id: str) -> str: def generate_multimodal_dataset_display_name() -> str: - """Generates a display name with a timestamp.""" - return f"MultimodalDataset {datetime.datetime.now().isoformat(sep=' ')}" + """Generates a display name with a timestamp.""" + return f"MultimodalDataset {datetime.datetime.now().isoformat(sep=' ')}" + + +def get_batch_job_unique_name() -> str: + """Generates a unique name suffix for a batch job destination.""" + timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + unique_id = uuid.uuid4().hex[0:5] + return f"{timestamp}_{unique_id}" def save_dataframe_to_bigquery( diff --git a/vertexai/_genai/types/common.py b/vertexai/_genai/types/common.py index 126eeec663..d0732bfcef 100644 --- a/vertexai/_genai/types/common.py +++ b/vertexai/_genai/types/common.py @@ -14094,6 +14094,28 @@ def to_bigframes( raise ValueError("Multimodal dataset bigquery source uri is not set.") return bigframes.pandas.read_gbq_table(self.bigquery_uri.removeprefix("bq://")) + def to_batch_job_source(self) -> "genai_types.BatchJobSource": + """Converts the dataset to a BatchJobSource.""" + return genai_types.BatchJobSource( + vertex_dataset_name=self.name, + ) + + def get_batch_job_destination(self) -> "genai_types.BatchJobDestination": + """Converts the dataset to a BatchJobDestination.""" + from .. import _datasets_utils + + unique_name = _datasets_utils.get_batch_job_unique_name() + bigquery_uri = self.bigquery_uri + if bigquery_uri is None: + raise ValueError("Multimodal dataset bigquery source uri is not set.") + curr_display_name = self.display_name or "genai_batch_job" + return genai_types.BatchJobDestination( + vertex_dataset=genai_types.VertexMultimodalDatasetDestination( + display_name=f"{curr_display_name}_batch_output_{unique_name}", + bigquery_destination=f"{bigquery_uri}_batch_output_{unique_name}", + ) + ) + class MultimodalDatasetDict(TypedDict, total=False): """Represents a multimodal dataset."""