Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion ax/adapter/adapter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,9 @@ def _roundtrip_transform(x: npt.NDArray) -> npt.NDArray:
observation_features
)
new_x: list[float] = [
float(observation_features[0].parameters[p]) for p in param_names
# pyrefly: ignore [bad-argument-type]
float(observation_features[0].parameters[p])
for p in param_names
]
# turn it back into an array
return np.array(new_x)
Expand Down
2 changes: 2 additions & 0 deletions ax/adapter/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,7 @@ def _set_model_space(self, arm_data: DataFrame) -> None:
# `sort_values=True` at construction time, so we can always
# sort here. Values are guaranteed numeric by the gate above,
# hence sortable.
# pyrefly: ignore [bad-argument-type]
p.set_values(sorted(cast(list[float], [*p.values, *extra_values])))
# Remove parameter constraints from the model space.
self._model_space.set_parameter_constraints([])
Expand Down Expand Up @@ -1237,6 +1238,7 @@ def gen_arms(
arms.append(arm)
if of.metadata:
candidate_metadata[arm.signature] = of.metadata
# pyrefly: ignore [bad-return]
return arms, candidate_metadata or None # None if empty cand. metadata.


Expand Down
2 changes: 2 additions & 0 deletions ax/adapter/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(
fit_on_init=fit_on_init,
)
# Re-assign for more precise typing.
# pyrefly: ignore [bad-override-mutable-attribute]
self.generator: DiscreteGenerator = generator

def _fit(
Expand Down Expand Up @@ -137,6 +138,7 @@ def _validate_gen_inputs(
"must be either a positive integer or -1."
)

# pyrefly: ignore [bad-override-param-name]
def _gen(
self,
n: int,
Expand Down
2 changes: 2 additions & 0 deletions ax/adapter/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(
fit_on_init=fit_on_init,
)
# Re-assign for more precise typing.
# pyrefly: ignore [bad-override-mutable-attribute]
self.generator: RandomGenerator = generator

def _fit(
Expand All @@ -78,6 +79,7 @@ def _fit(
"""Extracts the list of parameters from the search space."""
self.parameters = list(search_space.parameters.keys())

# pyrefly: ignore [bad-override-param-name]
def _gen(
self,
n: int,
Expand Down
2 changes: 2 additions & 0 deletions ax/adapter/tests/test_cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,7 @@ def test_efficient_loo_cv_is_attempted(self) -> None:
mock_posterior = GPyTorchPosterior(distribution=mock_mvn)

# Get the surrogate model from the adapter
# pyrefly: ignore [missing-attribute]
surrogate = self.adapter.generator.surrogate
model = surrogate.model

Expand Down Expand Up @@ -661,6 +662,7 @@ def _fold_gen(td: ExperimentData) -> Iterable[CVData]:
mock_posterior = mock.MagicMock()
mock_posterior.__class__.__name__ = "UnknownPosterior"

# pyrefly: ignore [missing-attribute]
surrogate = self.adapter.generator.surrogate
model = surrogate.model

Expand Down
6 changes: 5 additions & 1 deletion ax/adapter/tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,14 +285,18 @@ def test_ST_MTGP(self, use_saas: bool = False) -> None:
model,
SaasFullyBayesianMultiTaskGP if use_saas else MultiTaskGP,
)
# pyrefly: ignore [not-iterable]
data_covar_module, _ = model.covar_module.kernels
if use_saas is False and default_model is False:
self.assertIsInstance(data_covar_module, ScaleKernel)
base_kernel = data_covar_module.base_kernel
self.assertIsInstance(base_kernel, MaternKernel)
self.assertEqual(
base_kernel.lengthscale_prior.concentration, 6.0
# pyrefly: ignore [missing-attribute]
base_kernel.lengthscale_prior.concentration,
6.0,
)
# pyrefly: ignore [missing-attribute]
self.assertEqual(base_kernel.lengthscale_prior.rate, 3.0)
elif use_saas is False:
self.assertIsInstance(data_covar_module, RBFKernel)
Expand Down
9 changes: 8 additions & 1 deletion ax/adapter/tests/test_torch_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,10 @@ def test_best_point(self) -> None:
# UnitX removes 1 and divides by 5. Reversing here.
self.assertEqual(arm.parameters.keys(), {"x"})
self.assertAlmostEqual(
float(arm.parameters["x"]), (best_point_value * 5.0) + 1.0, places=5
# pyrefly: ignore [bad-argument-type]
float(arm.parameters["x"]),
(best_point_value * 5.0) + 1.0,
places=5,
)
# 1.0 in transformed space is 6.0 in original space.
self.assertEqual(run.arms[0].parameters, {"x": 6.0})
Expand Down Expand Up @@ -524,15 +527,19 @@ def test_candidate_metadata_propagation(self) -> None:
[list(arm.parameters.values()) for arm in exp.trials[0].arms],
dtype=torch.double,
)
# pyrefly: ignore [not-iterable]
for dataset in datasets:
self.assertTrue(torch.equal(dataset.X, X_expected))

candidate_metadata = mock_generator_fit.call_args.kwargs.get(
"candidate_metadata"
)
# pyrefly: ignore [bad-argument-type]
self.assertEqual(len(candidate_metadata), 1)
# pyrefly: ignore [unsupported-operation]
self.assertEqual(len(candidate_metadata[0]), len(exp.trials[0].arms))
self.assertEqual(
# pyrefly: ignore [unsupported-operation]
candidate_metadata[0][0],
{
"preexisting_batch_cand_metadata": "some_value",
Expand Down
5 changes: 5 additions & 0 deletions ax/adapter/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def __init__(
)

# Re-assign self.generator for more precise typing.
# pyrefly: ignore [bad-override-mutable-attribute]
self.generator: TorchGenerator = generator

def feature_importances(self, metric_signature: str) -> dict[str, float]:
Expand Down Expand Up @@ -877,6 +878,7 @@ def _fit(
**kwargs,
)

# pyrefly: ignore [bad-override-param-name]
def _gen(
self,
n: int,
Expand Down Expand Up @@ -1017,6 +1019,7 @@ def _transform_observation_features(
try:
tobfs = np.array(
[
# pyrefly: ignore [bad-argument-type]
[float(of.parameters[p]) for p in self.parameters]
for of in observation_features
]
Expand Down Expand Up @@ -1201,10 +1204,12 @@ def _untransform_objective_thresholds(
[fixed_features_obs]
)[0]
thresholds = t.untransform_outcome_constraints(
# pyrefly: ignore [bad-argument-type]
outcome_constraints=thresholds,
fixed_features=fixed_features_obs,
)

# pyrefly: ignore [bad-return]
return thresholds

def _validate_preference_config(
Expand Down
1 change: 1 addition & 0 deletions ax/adapter/transfer_learning/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,7 @@ def _fit(
)
# This is a bit of a hack to ensure that only the data for the target task
# is used in the X_baseline. It also avoids task feature in X_baseline.
# pyrefly: ignore [missing-attribute]
self.generator.surrogate._training_data = [
ds.datasets[ds.target_outcome_name] for ds in task_datasets
]
Expand Down
2 changes: 2 additions & 0 deletions ax/adapter/transforms/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def transform_search_space(self, search_space: SearchSpace) -> SearchSpace:

# Apply log10 transformation
transformed_values = [
# pyrefly: ignore [bad-argument-type]
assert_is_instance(math.log10(float(v)), TParamValue)
for v in values
]
Expand All @@ -119,6 +120,7 @@ def transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
if dependents is not None:
dependents = {
math.log10(
# pyrefly: ignore [bad-argument-type]
float(assert_is_instance_of_tuple(k, (float, int)))
): v
for k, v in dependents.items()
Expand Down
1 change: 1 addition & 0 deletions ax/adapter/transforms/map_key_to_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def __init__(
parameters = {MAP_KEY: {}}

self.parameters: dict[str, dict[str, Any]] = parameters
# pyrefly: ignore [bad-override-mutable-attribute]
self._parameter_list: list[RangeParameter] = []
# Construct the parameter if needed.
if is_map_data:
Expand Down
2 changes: 2 additions & 0 deletions ax/adapter/transforms/merge_repeated_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,11 @@ def __init__(
raise NotImplementedError(
"All metrics must have noise observations."
)
# pyrefly: ignore [bad-index]
arm_to_multi_obs[arm_name][m]["means"].extend(
df_m[("mean", m)].tolist()
)
# pyrefly: ignore [bad-index]
arm_to_multi_obs[arm_name][m]["vars"].extend(
(df_m[("sem", m)] ** 2).tolist()
)
Expand Down
1 change: 1 addition & 0 deletions ax/adapter/transforms/power_transform_y.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,4 +357,5 @@ def _compute_inverse_bounds(
elif lambda_ > 2.0 + tol:
bounds[0] = (1.0 / (2.0 - lambda_) - mu) / sigma
inv_bounds[k] = tuple(assert_is_instance_list(bounds, float))
# pyrefly: ignore [bad-return]
return inv_bounds
3 changes: 3 additions & 0 deletions ax/adapter/transforms/tests/test_add_execution_viability.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,11 @@ def test_transform_optimization_config(self) -> None:
config={"feasibility_threshold": 0.8},
)
original_constraints_count = len(
# pyrefly: ignore [missing-attribute]
self.experiment.optimization_config.outcome_constraints
)
new_opt_config = t.transform_optimization_config(
# pyrefly: ignore [bad-argument-type]
self.experiment.optimization_config,
)
self.assertEqual(
Expand All @@ -81,6 +83,7 @@ def test_transform_optimization_config(self) -> None:
config={},
)
new_opt_config = t.transform_optimization_config(
# pyrefly: ignore [bad-argument-type]
self.experiment.optimization_config,
)
feasibility_constraint = new_opt_config.outcome_constraints[-1]
Expand Down
5 changes: 5 additions & 0 deletions ax/adapter/transforms/tests/test_base_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def transform_observation_features(
) -> list[ObservationFeatures]:
for obs_feat in observation_features:
for param_name, param_value in obs_feat.parameters.items():
# pyrefly: ignore [bad-argument-type]
obs_feat.parameters[param_name] = float(param_value) * 2
return observation_features

Expand All @@ -45,9 +46,13 @@ def test_IdentityTransform(self) -> None:
x = MagicMock()
ys = []
ys.append(t.transform_search_space(x))
# pyrefly: ignore [bad-argument-type]
ys.append(t.transform_observation_features(x))
# pyrefly: ignore [bad-argument-type]
ys.append(t._transform_observation_data(x))
# pyrefly: ignore [bad-argument-type]
ys.append(t.untransform_observation_features(x))
# pyrefly: ignore [bad-argument-type]
ys.append(t._untransform_observation_data(x))
self.assertEqual(len(x.mock_calls), 0)
for y in ys:
Expand Down
3 changes: 3 additions & 0 deletions ax/adapter/transforms/tests/test_bilog_y.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def setUp(self) -> None:
with_relative_constraint=True,
)
self.data = self.exp.fetch_data()
# pyrefly: ignore [missing-attribute]
self.bound = self.exp.optimization_config.outcome_constraints[1].bound

def get_adapter(self) -> Adapter:
Expand Down Expand Up @@ -85,6 +86,7 @@ def test_Bilog(self) -> None:
)

def test_TransformUntransform(self) -> None:
# pyrefly: ignore [missing-attribute]
bound = self.exp.optimization_config.outcome_constraints[0].bound
observations = observations_from_data(
experiment=self.exp, data=self.exp.lookup_data()
Expand Down Expand Up @@ -154,6 +156,7 @@ def test_TransformOptimizationConfig(self) -> None:
)
oc = self.exp.optimization_config
# This should be a no-op
# pyrefly: ignore [bad-argument-type]
new_oc = t.transform_optimization_config(optimization_config=oc)
self.assertEqual(new_oc, oc)

Expand Down
1 change: 1 addition & 0 deletions ax/adapter/transforms/tests/test_cast_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,7 @@ def test_transform_experiment_data_cast(self) -> None:
.astype({"x": float, "y": float})
)
expected_arm_data["z"] = None
# pyrefly: ignore [missing-attribute]
expected_arm_data["z"] = expected_arm_data["z"].astype("Int64")
expected_arm_data = expected_arm_data[["x", "y", "z", "metadata"]]
assert_frame_equal(transformed.arm_data, expected_arm_data)
Expand Down
8 changes: 8 additions & 0 deletions ax/adapter/transforms/tests/test_choice_encode_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,13 @@ def test_transform_search_space(self) -> None:

for param_name in ["b", "c"]:
self.assertEqual(
# pyrefly: ignore [missing-attribute]
ss2.parameters[param_name].values,
assert_is_instance(
self.search_space[param_name], ChoiceParameter
).values,
)
# pyrefly: ignore [missing-attribute]
self.assertEqual(ss2.parameters["d"].values, [0, 1, 2])

# Fidelity parameter is transformed correctly.
Expand Down Expand Up @@ -365,9 +367,13 @@ def test_transform_search_space(self) -> None:
self.assertEqual(ss2.parameters[p].parameter_type, ParameterType.INT)
self.assertEqual(ss2.parameters["d"].parameter_type, ParameterType.STRING)

# pyrefly: ignore [missing-attribute]
self.assertEqual(ss2.parameters["b"].lower, 0)
# pyrefly: ignore [missing-attribute]
self.assertEqual(ss2.parameters["b"].upper, 2)
# pyrefly: ignore [missing-attribute]
self.assertEqual(ss2.parameters["c"].lower, 0)
# pyrefly: ignore [missing-attribute]
self.assertEqual(ss2.parameters["c"].upper, 2)
self.assertEqual(ss2.parameters["d"].values, ["q", "r", "z"])

Expand Down Expand Up @@ -419,5 +425,7 @@ def test_transform_search_space_with_different_values(self) -> None:
]
)
t_ss = self.t.transform_search_space(ss)
# pyrefly: ignore [missing-attribute]
self.assertEqual(t_ss.parameters["b"].lower, 1)
# pyrefly: ignore [missing-attribute]
self.assertEqual(t_ss.parameters["b"].upper, 2)
3 changes: 3 additions & 0 deletions ax/adapter/transforms/tests/test_fill_missing_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,13 @@ def test_init_with_deprecated_config(self) -> None:
with self.assertLogs(
"ax.adapter.transforms.fill_missing_parameters", level="ERROR"
) as lg:
# pyrefly: ignore [bad-argument-type]
t = FillMissingParameters(config=self.config)
self.assertIn("deprecated", lg.output[0])
self.assertEqual(t._fill_values, self.config_values)

def test_init_with_both_config_and_search_space(self) -> None:
# pyrefly: ignore [bad-argument-type]
t = FillMissingParameters(search_space=self.search_space, config=self.config)
# Search space values should override config values
self.assertEqual(t._fill_values, self.search_space_backfill_values)
Expand Down Expand Up @@ -166,6 +168,7 @@ def test_deprecated_config_behavior_still_works(self) -> None:
ObservationFeatures(parameters={"x": None}),
ObservationFeatures(parameters={"x": 0.0}),
]
# pyrefly: ignore [bad-argument-type]
t = FillMissingParameters(config=self.config)
expected = [
ObservationFeatures(parameters={"x": 2.0, "y": 3.0}),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@ def test_TransformObservationFeatures(self) -> None:
keep_indices = [
i
for i, obs in enumerate(self.observations)
# pyrefly: ignore [unsupported-operation]
if not isnan(obs.features.metadata["step"])
]
observation_features = [self.observations[i].features for i in keep_indices]
Expand All @@ -393,6 +394,7 @@ def test_TransformObservationFeatures(self) -> None:
for i in keep_indices:
obs = self.observations[i]
obsf = obs.features.clone()
# pyrefly: ignore [missing-attribute]
obsf.parameters[self.map_key] = obsf.metadata.pop(self.map_key)
expected.append(obsf)

Expand All @@ -404,6 +406,7 @@ def test_TransformObservationFeatures(self) -> None:
keep_indices = [
i
for i, obs in enumerate(self.observations)
# pyrefly: ignore [unsupported-operation]
if isnan(obs.features.metadata["step"])
]
observation_features = [self.observations[i].features for i in keep_indices]
Expand All @@ -423,6 +426,7 @@ def test_TransformObservationFeatures(self) -> None:
untransformed = self.t.untransform_observation_features(obs_ft2)
expected = observation_features
for obs in expected:
# pyrefly: ignore [unsupported-operation]
obs.metadata["step"] = 1.0

self.assertEqual(untransformed, observation_features)
Expand Down Expand Up @@ -466,9 +470,11 @@ def test_TransformObservationFeaturesKeyNotInMetadata(self) -> None:
obs_ft2 = deepcopy(observation_features)
# remove the key from metadata dicts
for obsf in obs_ft2:
# pyrefly: ignore [missing-attribute]
obsf.metadata.pop(self.map_key)
# To avoid this being treated as empty metadata.
# In typical experiment, trial completion timestamp would be here.
# pyrefly: ignore [unsupported-operation]
obsf.metadata["dummy"] = 1.0
# should be exactly one parameter
(p,) = self.t._parameter_list
Expand Down
1 change: 1 addition & 0 deletions ax/adapter/transforms/tests/test_metadata_to_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def test_TransformObservationFeatures(self) -> None:
Keys.TASK_FEATURE_NAME.value: i,
}
)
# pyrefly: ignore [unsupported-operation]
del new_obs_ft.metadata[Keys.TASK_FEATURE_NAME]
expected_obs_ft2.append(new_obs_ft)
self.assertEqual(obs_ft2, expected_obs_ft2)
Expand Down
Loading
Loading