Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2600,16 +2600,16 @@ def submit(
self._gca_resource.job_spec.psc_interface_config = psc_interface_config

if (
timeout
timeout is not None
or restart_job_on_worker_restart
or disable_retries
or scheduling_strategy
or max_wait_duration
or scheduling_strategy is not None
or max_wait_duration is not None
):
timeout = duration_pb2.Duration(seconds=timeout) if timeout else None
timeout = duration_pb2.Duration(seconds=timeout) if timeout is not None else None
max_wait_duration = (
duration_pb2.Duration(seconds=max_wait_duration)
if max_wait_duration
if max_wait_duration is not None
else None
)
self._gca_resource.job_spec.scheduling = gca_custom_job_compat.Scheduling(
Expand Down Expand Up @@ -3130,16 +3130,16 @@ def _run(
self._gca_resource.trial_job_spec.network = network

if (
timeout
timeout is not None
or restart_job_on_worker_restart
or disable_retries
or max_wait_duration
or scheduling_strategy
or max_wait_duration is not None
or scheduling_strategy is not None
):
timeout = duration_pb2.Duration(seconds=timeout) if timeout else None
timeout = duration_pb2.Duration(seconds=timeout) if timeout is not None else None
max_wait_duration = (
duration_pb2.Duration(seconds=max_wait_duration)
if max_wait_duration
if max_wait_duration is not None
else None
)
self._gca_resource.trial_job_spec.scheduling = (
Expand Down
92 changes: 92 additions & 0 deletions tests/unit/aiplatform/test_custom_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
#

import datetime
import pytest
import logging

Expand Down Expand Up @@ -730,6 +731,97 @@ def test_submit_custom_job(self, create_custom_job_mock, get_custom_job_mock):
)
assert job.network == _TEST_NETWORK

def test_submit_custom_job_with_zero_max_wait_duration(
    self, create_custom_job_mock, get_custom_job_mock
):
    """An explicit max_wait_duration of 0 must survive into the request proto.

    Zero is falsy, so this guards against the regression where a bare
    ``if max_wait_duration:`` check silently dropped the value.
    """
    aiplatform.init(
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
        staging_bucket=_TEST_STAGING_BUCKET,
        encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
    )

    custom_job = aiplatform.CustomJob(
        display_name=_TEST_DISPLAY_NAME,
        worker_pool_specs=_TEST_WORKER_POOL_SPEC,
        base_output_dir=_TEST_BASE_OUTPUT_DIR,
        labels=_TEST_LABELS,
    )

    custom_job.submit(
        service_account=_TEST_SERVICE_ACCOUNT,
        network=_TEST_NETWORK,
        timeout=_TEST_TIMEOUT,
        restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
        create_request_timeout=None,
        disable_retries=_TEST_DISABLE_RETRIES,
        max_wait_duration=0,
    )

    custom_job.wait_for_resource_creation()
    assert custom_job.resource_name == _TEST_CUSTOM_JOB_NAME
    custom_job.wait()

    # The scheduling field must carry a zero-length duration, not be unset.
    expected = _get_custom_job_proto()
    expected.job_spec.scheduling.max_wait_duration = datetime.timedelta(
        seconds=0
    )

    create_custom_job_mock.assert_called_once_with(
        parent=_TEST_PARENT,
        custom_job=expected,
        timeout=None,
    )
    assert (
        custom_job._gca_resource.state
        == gca_job_state_compat.JobState.JOB_STATE_PENDING
    )

def test_submit_custom_job_with_default_max_wait_duration(
    self, create_custom_job_mock, get_custom_job_mock
):
    """When max_wait_duration is omitted, scheduling must leave it unset."""
    aiplatform.init(
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
        staging_bucket=_TEST_STAGING_BUCKET,
        encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
    )

    custom_job = aiplatform.CustomJob(
        display_name=_TEST_DISPLAY_NAME,
        worker_pool_specs=_TEST_WORKER_POOL_SPEC,
        base_output_dir=_TEST_BASE_OUTPUT_DIR,
        labels=_TEST_LABELS,
    )

    # No max_wait_duration argument: the default path is under test here.
    custom_job.submit(
        service_account=_TEST_SERVICE_ACCOUNT,
        network=_TEST_NETWORK,
        timeout=_TEST_TIMEOUT,
        restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
        create_request_timeout=None,
        disable_retries=_TEST_DISABLE_RETRIES,
    )

    custom_job.wait_for_resource_creation()
    assert custom_job.resource_name == _TEST_CUSTOM_JOB_NAME
    custom_job.wait()

    # Assigning None clears the message field in the expected proto.
    expected = _get_custom_job_proto()
    expected.job_spec.scheduling.max_wait_duration = None

    create_custom_job_mock.assert_called_once_with(
        parent=_TEST_PARENT,
        custom_job=expected,
        timeout=None,
    )

    assert "max_wait_duration" not in expected.job_spec.scheduling
    assert (
        custom_job._gca_resource.state
        == gca_job_state_compat.JobState.JOB_STATE_PENDING
    )

@pytest.mark.usefixtures(
"get_experiment_run_mock", "get_tensorboard_run_artifact_not_found_mock"
)
Expand Down
142 changes: 141 additions & 1 deletion tests/unit/aiplatform/test_hyperparameter_tuning_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
# limitations under the License.
#

import copy
import datetime
import pytest

import copy
from importlib import reload
from unittest import mock
from unittest.mock import patch
Expand Down Expand Up @@ -523,6 +524,145 @@ def test_create_hyperparameter_tuning_job(
assert job.network == _TEST_NETWORK
assert job.trials == []

def test_create_hyperparameter_tuning_job_with_zero_max_wait_duration(
    self,
    create_hyperparameter_tuning_job_mock,
    get_hyperparameter_tuning_job_mock,
):
    """max_wait_duration=0 must reach the trial job spec's scheduling proto.

    Zero is falsy, so this guards against a bare truthiness check dropping
    an explicitly-requested zero wait duration from the request.
    """
    aiplatform.init(
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
        staging_bucket=_TEST_STAGING_BUCKET,
        encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
    )

    trial_job = aiplatform.CustomJob(
        display_name=test_constants.TrainingJobConstants._TEST_DISPLAY_NAME,
        worker_pool_specs=test_constants.TrainingJobConstants._TEST_WORKER_POOL_SPEC,
        base_output_dir=test_constants.TrainingJobConstants._TEST_BASE_OUTPUT_DIR,
    )

    # Build the study search space up front to keep the constructor readable.
    search_space = {
        "lr": hpt.DoubleParameterSpec(min=0.001, max=0.1, scale="log"),
        "units": hpt.IntegerParameterSpec(min=4, max=1028, scale="linear"),
        "activation": hpt.CategoricalParameterSpec(
            values=["relu", "sigmoid", "elu", "selu", "tanh"]
        ),
        "batch_size": hpt.DiscreteParameterSpec(
            values=[4, 8, 16, 32, 64],
            scale="linear",
            conditional_parameter_spec={
                "decay": _TEST_CONDITIONAL_PARAMETER_DECAY,
                "learning_rate": _TEST_CONDITIONAL_PARAMETER_LR,
            },
        ),
    }

    tuning_job = aiplatform.HyperparameterTuningJob(
        display_name=_TEST_DISPLAY_NAME,
        custom_job=trial_job,
        metric_spec={_TEST_METRIC_SPEC_KEY: _TEST_METRIC_SPEC_VALUE},
        parameter_spec=search_space,
        parallel_trial_count=_TEST_PARALLEL_TRIAL_COUNT,
        max_trial_count=_TEST_MAX_TRIAL_COUNT,
        max_failed_trial_count=_TEST_MAX_FAILED_TRIAL_COUNT,
        search_algorithm=_TEST_SEARCH_ALGORITHM,
        measurement_selection=_TEST_MEASUREMENT_SELECTION,
        labels=_TEST_LABELS,
    )

    tuning_job.run(
        service_account=_TEST_SERVICE_ACCOUNT,
        network=_TEST_NETWORK,
        timeout=_TEST_TIMEOUT,
        restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
        sync=True,
        create_request_timeout=None,
        disable_retries=_TEST_DISABLE_RETRIES,
        max_wait_duration=0,
    )
    tuning_job.wait()

    # The scheduling field must carry a zero-length duration, not be unset.
    expected = _get_hyperparameter_tuning_job_proto()
    expected.trial_job_spec.scheduling.max_wait_duration = datetime.timedelta(
        seconds=0
    )

    create_hyperparameter_tuning_job_mock.assert_called_once_with(
        parent=_TEST_PARENT,
        hyperparameter_tuning_job=expected,
        timeout=None,
    )
    assert tuning_job.state == gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED

def test_create_hyperparameter_tuning_job_with_default_max_wait_duration(
    self,
    create_hyperparameter_tuning_job_mock,
    get_hyperparameter_tuning_job_mock,
):
    """When max_wait_duration is omitted, trial scheduling must leave it unset."""
    aiplatform.init(
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
        staging_bucket=_TEST_STAGING_BUCKET,
        encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
    )

    trial_job = aiplatform.CustomJob(
        display_name=test_constants.TrainingJobConstants._TEST_DISPLAY_NAME,
        worker_pool_specs=test_constants.TrainingJobConstants._TEST_WORKER_POOL_SPEC,
        base_output_dir=test_constants.TrainingJobConstants._TEST_BASE_OUTPUT_DIR,
    )

    # Build the study search space up front to keep the constructor readable.
    search_space = {
        "lr": hpt.DoubleParameterSpec(min=0.001, max=0.1, scale="log"),
        "units": hpt.IntegerParameterSpec(min=4, max=1028, scale="linear"),
        "activation": hpt.CategoricalParameterSpec(
            values=["relu", "sigmoid", "elu", "selu", "tanh"]
        ),
        "batch_size": hpt.DiscreteParameterSpec(
            values=[4, 8, 16, 32, 64],
            scale="linear",
            conditional_parameter_spec={
                "decay": _TEST_CONDITIONAL_PARAMETER_DECAY,
                "learning_rate": _TEST_CONDITIONAL_PARAMETER_LR,
            },
        ),
    }

    tuning_job = aiplatform.HyperparameterTuningJob(
        display_name=_TEST_DISPLAY_NAME,
        custom_job=trial_job,
        metric_spec={_TEST_METRIC_SPEC_KEY: _TEST_METRIC_SPEC_VALUE},
        parameter_spec=search_space,
        parallel_trial_count=_TEST_PARALLEL_TRIAL_COUNT,
        max_trial_count=_TEST_MAX_TRIAL_COUNT,
        max_failed_trial_count=_TEST_MAX_FAILED_TRIAL_COUNT,
        search_algorithm=_TEST_SEARCH_ALGORITHM,
        measurement_selection=_TEST_MEASUREMENT_SELECTION,
        labels=_TEST_LABELS,
    )

    # No max_wait_duration argument: the default path is under test here.
    tuning_job.run(
        service_account=_TEST_SERVICE_ACCOUNT,
        network=_TEST_NETWORK,
        timeout=_TEST_TIMEOUT,
        restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
        sync=True,
        create_request_timeout=None,
        disable_retries=_TEST_DISABLE_RETRIES,
    )
    tuning_job.wait()

    # Assigning None clears the message field in the expected proto.
    expected = _get_hyperparameter_tuning_job_proto()
    expected.trial_job_spec.scheduling.max_wait_duration = None

    create_hyperparameter_tuning_job_mock.assert_called_once_with(
        parent=_TEST_PARENT,
        hyperparameter_tuning_job=expected,
        timeout=None,
    )

    assert "max_wait_duration" not in expected.trial_job_spec.scheduling
    assert tuning_job.state == gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED

@pytest.mark.parametrize("sync", [True, False])
def test_create_hyperparameter_tuning_job_with_timeout(
self,
Expand Down
Loading