diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py index 472c1661d5..db7527f037 100644 --- a/google/cloud/aiplatform/jobs.py +++ b/google/cloud/aiplatform/jobs.py @@ -2600,16 +2600,16 @@ def submit( self._gca_resource.job_spec.psc_interface_config = psc_interface_config if ( - timeout + timeout is not None or restart_job_on_worker_restart or disable_retries - or scheduling_strategy - or max_wait_duration + or scheduling_strategy is not None + or max_wait_duration is not None ): - timeout = duration_pb2.Duration(seconds=timeout) if timeout else None + timeout = duration_pb2.Duration(seconds=timeout) if timeout is not None else None max_wait_duration = ( duration_pb2.Duration(seconds=max_wait_duration) - if max_wait_duration + if max_wait_duration is not None else None ) self._gca_resource.job_spec.scheduling = gca_custom_job_compat.Scheduling( @@ -3130,16 +3130,16 @@ def _run( self._gca_resource.trial_job_spec.network = network if ( - timeout + timeout is not None or restart_job_on_worker_restart or disable_retries - or max_wait_duration - or scheduling_strategy + or max_wait_duration is not None + or scheduling_strategy is not None ): - timeout = duration_pb2.Duration(seconds=timeout) if timeout else None + timeout = duration_pb2.Duration(seconds=timeout) if timeout is not None else None max_wait_duration = ( duration_pb2.Duration(seconds=max_wait_duration) - if max_wait_duration + if max_wait_duration is not None else None ) self._gca_resource.trial_job_spec.scheduling = ( diff --git a/tests/unit/aiplatform/test_custom_job.py b/tests/unit/aiplatform/test_custom_job.py index f51ec9c948..20db320e0f 100644 --- a/tests/unit/aiplatform/test_custom_job.py +++ b/tests/unit/aiplatform/test_custom_job.py @@ -14,6 +14,7 @@ # limitations under the License. # +import datetime import pytest import logging @@ -730,6 +731,97 @@ def test_submit_custom_job(self, create_custom_job_mock, get_custom_job_mock): ) assert job.network == _TEST_NETWORK + def test_submit_custom_job_with_zero_max_wait_duration( + self, create_custom_job_mock, get_custom_job_mock): + + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_STAGING_BUCKET, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = aiplatform.CustomJob( + display_name=_TEST_DISPLAY_NAME, + worker_pool_specs=_TEST_WORKER_POOL_SPEC, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + labels=_TEST_LABELS, + ) + + job.submit( + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + timeout=_TEST_TIMEOUT, + restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, + create_request_timeout=None, + disable_retries=_TEST_DISABLE_RETRIES, + max_wait_duration=0, + ) + + job.wait_for_resource_creation() + + assert job.resource_name == _TEST_CUSTOM_JOB_NAME + + job.wait() + + expected_custom_job = _get_custom_job_proto() + expected_custom_job.job_spec.scheduling.max_wait_duration = datetime.timedelta(seconds=0) + + create_custom_job_mock.assert_called_once_with( + parent=_TEST_PARENT, + custom_job=expected_custom_job, + timeout=None, + ) + assert ( + job._gca_resource.state == gca_job_state_compat.JobState.JOB_STATE_PENDING + ) + + def test_submit_custom_job_with_default_max_wait_duration( + self, create_custom_job_mock, get_custom_job_mock): + + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_STAGING_BUCKET, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = aiplatform.CustomJob( + display_name=_TEST_DISPLAY_NAME, + worker_pool_specs=_TEST_WORKER_POOL_SPEC, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + labels=_TEST_LABELS, + ) + + job.submit( + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + timeout=_TEST_TIMEOUT, + restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, + create_request_timeout=None, + disable_retries=_TEST_DISABLE_RETRIES, + ) + + job.wait_for_resource_creation() + + assert job.resource_name == _TEST_CUSTOM_JOB_NAME + + job.wait() + + expected_custom_job = _get_custom_job_proto() + expected_custom_job.job_spec.scheduling.max_wait_duration = None + + create_custom_job_mock.assert_called_once_with( + parent=_TEST_PARENT, + custom_job=expected_custom_job, + timeout=None, + ) + + assert "max_wait_duration" not in expected_custom_job.job_spec.scheduling + assert ( + job._gca_resource.state == gca_job_state_compat.JobState.JOB_STATE_PENDING + ) + @pytest.mark.usefixtures( "get_experiment_run_mock", "get_tensorboard_run_artifact_not_found_mock" ) diff --git a/tests/unit/aiplatform/test_hyperparameter_tuning_job.py b/tests/unit/aiplatform/test_hyperparameter_tuning_job.py index 5631ad48d2..cb797aeb4c 100644 --- a/tests/unit/aiplatform/test_hyperparameter_tuning_job.py +++ b/tests/unit/aiplatform/test_hyperparameter_tuning_job.py @@ -14,9 +14,10 @@ # limitations under the License. # +import copy +import datetime import pytest -import copy from importlib import reload from unittest import mock from unittest.mock import patch @@ -523,6 +524,145 @@ def test_create_hyperparameter_tuning_job( assert job.network == _TEST_NETWORK assert job.trials == [] + def test_create_hyperparameter_tuning_job_with_zero_max_wait_duration( + self, + create_hyperparameter_tuning_job_mock, + get_hyperparameter_tuning_job_mock, + ): + + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_STAGING_BUCKET, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + custom_job = aiplatform.CustomJob( + display_name=test_constants.TrainingJobConstants._TEST_DISPLAY_NAME, + worker_pool_specs=test_constants.TrainingJobConstants._TEST_WORKER_POOL_SPEC, + base_output_dir=test_constants.TrainingJobConstants._TEST_BASE_OUTPUT_DIR, + ) + + job = aiplatform.HyperparameterTuningJob( + display_name=_TEST_DISPLAY_NAME, + custom_job=custom_job, + metric_spec={_TEST_METRIC_SPEC_KEY: _TEST_METRIC_SPEC_VALUE}, + parameter_spec={ + "lr": hpt.DoubleParameterSpec(min=0.001, max=0.1, scale="log"), + "units": hpt.IntegerParameterSpec(min=4, max=1028, scale="linear"), + "activation": hpt.CategoricalParameterSpec( + values=["relu", "sigmoid", "elu", "selu", "tanh"] + ), + "batch_size": hpt.DiscreteParameterSpec( + values=[4, 8, 16, 32, 64], + scale="linear", + conditional_parameter_spec={ + "decay": _TEST_CONDITIONAL_PARAMETER_DECAY, + "learning_rate": _TEST_CONDITIONAL_PARAMETER_LR, + }, + ), + }, + parallel_trial_count=_TEST_PARALLEL_TRIAL_COUNT, + max_trial_count=_TEST_MAX_TRIAL_COUNT, + max_failed_trial_count=_TEST_MAX_FAILED_TRIAL_COUNT, + search_algorithm=_TEST_SEARCH_ALGORITHM, + measurement_selection=_TEST_MEASUREMENT_SELECTION, + labels=_TEST_LABELS, + ) + + job.run( + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + timeout=_TEST_TIMEOUT, + restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, + sync=True, + create_request_timeout=None, + disable_retries=_TEST_DISABLE_RETRIES, + max_wait_duration=0, + ) + + job.wait() + + expected_hyperparameter_tuning_job = _get_hyperparameter_tuning_job_proto() + expected_hyperparameter_tuning_job.trial_job_spec.scheduling.max_wait_duration = datetime.timedelta(seconds=0) + + create_hyperparameter_tuning_job_mock.assert_called_once_with( + parent=_TEST_PARENT, + hyperparameter_tuning_job=expected_hyperparameter_tuning_job, + timeout=None, + ) + assert job.state == gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED + + def test_create_hyperparameter_tuning_job_with_default_max_wait_duration( + self, + create_hyperparameter_tuning_job_mock, + get_hyperparameter_tuning_job_mock, + ): + + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_STAGING_BUCKET, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + custom_job = aiplatform.CustomJob( + display_name=test_constants.TrainingJobConstants._TEST_DISPLAY_NAME, + worker_pool_specs=test_constants.TrainingJobConstants._TEST_WORKER_POOL_SPEC, + base_output_dir=test_constants.TrainingJobConstants._TEST_BASE_OUTPUT_DIR, + ) + + job = aiplatform.HyperparameterTuningJob( + display_name=_TEST_DISPLAY_NAME, + custom_job=custom_job, + metric_spec={_TEST_METRIC_SPEC_KEY: _TEST_METRIC_SPEC_VALUE}, + parameter_spec={ + "lr": hpt.DoubleParameterSpec(min=0.001, max=0.1, scale="log"), + "units": hpt.IntegerParameterSpec(min=4, max=1028, scale="linear"), + "activation": hpt.CategoricalParameterSpec( + values=["relu", "sigmoid", "elu", "selu", "tanh"] + ), + "batch_size": hpt.DiscreteParameterSpec( + values=[4, 8, 16, 32, 64], + scale="linear", + conditional_parameter_spec={ + "decay": _TEST_CONDITIONAL_PARAMETER_DECAY, + "learning_rate": _TEST_CONDITIONAL_PARAMETER_LR, + }, + ), + }, + parallel_trial_count=_TEST_PARALLEL_TRIAL_COUNT, + max_trial_count=_TEST_MAX_TRIAL_COUNT, + max_failed_trial_count=_TEST_MAX_FAILED_TRIAL_COUNT, + search_algorithm=_TEST_SEARCH_ALGORITHM, + measurement_selection=_TEST_MEASUREMENT_SELECTION, + labels=_TEST_LABELS, + ) + + job.run( + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + timeout=_TEST_TIMEOUT, + restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, + sync=True, + create_request_timeout=None, + disable_retries=_TEST_DISABLE_RETRIES, + ) + + job.wait() + + expected_hyperparameter_tuning_job = _get_hyperparameter_tuning_job_proto() + expected_hyperparameter_tuning_job.trial_job_spec.scheduling.max_wait_duration = None + + create_hyperparameter_tuning_job_mock.assert_called_once_with( + parent=_TEST_PARENT, + hyperparameter_tuning_job=expected_hyperparameter_tuning_job, + timeout=None, + ) + + assert "max_wait_duration" not in expected_hyperparameter_tuning_job.trial_job_spec.scheduling + assert job.state == gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED + @pytest.mark.parametrize("sync", [True, False]) def test_create_hyperparameter_tuning_job_with_timeout( self,