diff --git a/ax/api/configs.py b/ax/api/configs.py index 4937f38aa19..e9da44e8dc6 100644 --- a/ax/api/configs.py +++ b/ax/api/configs.py @@ -40,6 +40,7 @@ class ChoiceParameterConfig: values: list[float] | list[int] | list[str] | list[bool] parameter_type: Literal["float", "int", "str", "bool"] is_ordered: bool | None = None + log_scale: bool | None = None dependent_parameters: Mapping[TParameterValue, Sequence[str]] | None = None diff --git a/ax/api/utils/instantiation/from_config.py b/ax/api/utils/instantiation/from_config.py index 75774277549..3d1cdef426e 100644 --- a/ax/api/utils/instantiation/from_config.py +++ b/ax/api/utils/instantiation/from_config.py @@ -93,6 +93,7 @@ def parameter_from_config( parameter_type=_parameter_type_converter(config.parameter_type), values=cast(list[TParamValue], config.values), is_ordered=config.is_ordered, + log_scale=config.log_scale, dependents=cast( dict[TParamValue, list[str]] | None, config.dependent_parameters, diff --git a/ax/api/utils/instantiation/tests/test_from_config.py b/ax/api/utils/instantiation/tests/test_from_config.py index 96cf460db99..4dc33345913 100644 --- a/ax/api/utils/instantiation/tests/test_from_config.py +++ b/ax/api/utils/instantiation/tests/test_from_config.py @@ -26,6 +26,7 @@ from ax.core.search_space import SearchSpace from ax.exceptions.core import UserInputError from ax.utils.common.testutils import TestCase +from pyre_extensions import assert_is_instance class TestFromConfig(TestCase): @@ -222,6 +223,46 @@ def test_create_choice_parameter(self) -> None: value="a", ), ) + + choice_config_with_log_scale = ChoiceParameterConfig( + name="choice_param_with_log_scale", + parameter_type="float", + values=[1.0, 2.0, 3.0], + log_scale=True, + ) + self.assertEqual( + parameter_from_config(config=choice_config_with_log_scale), + ChoiceParameter( + name="choice_param_with_log_scale", + parameter_type=CoreParameterType.FLOAT, + values=[1.0, 2.0, 3.0], + log_scale=True, + ), + ) + + # log_scale=False should override the auto-detection default. + choice_config_without_log_scale = ChoiceParameterConfig( + name="choice_param_without_log_scale", + parameter_type="int", + values=[1, 10, 100], + log_scale=False, + ) + parameter = parameter_from_config(config=choice_config_without_log_scale) + self.assertFalse(assert_is_instance(parameter, ChoiceParameter).log_scale) + + # log_scale is only supported for numeric parameters. + with self.assertRaisesRegex( + UserInputError, "log_scale is only supported for numerical parameters" + ): + parameter_from_config( + config=ChoiceParameterConfig( + name="str_choice_param_with_log_scale", + parameter_type="str", + values=["a", "b", "c"], + log_scale=True, + ) + ) + self.assertFalse(any("sort_values" in str(w.message) for w in ws)) def test_experiment_from_config(self) -> None: