From 0e62ad579091d3766ef07dae81772bd11c3026a4 Mon Sep 17 00:00:00 2001 From: Maggie Moss Date: Mon, 22 Jun 2026 11:41:19 -0700 Subject: [PATCH] Enable Pyrefly for fbcode/ax (#5236) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/5236 Migrates `fbcode/ax` from Pyre to Pyrefly by removing the root `python.set_pyrefly(False)` opt-out so the tree inherits `set_pyrefly(True)` from the fbcode root default. Pre-existing type errors are suppressed via annotation-only `# pyrefly: ignore` comments; no runtime changes. ___ Reviewed By: connernilsen Differential Revision: D109082535 --- ax/adapter/adapter_utils.py | 4 +- ax/adapter/base.py | 2 + ax/adapter/discrete.py | 2 + ax/adapter/random.py | 2 + ax/adapter/tests/test_cross_validation.py | 2 + ax/adapter/tests/test_registry.py | 6 +- ax/adapter/tests/test_torch_adapter.py | 9 ++- ax/adapter/torch.py | 5 ++ ax/adapter/transfer_learning/adapter.py | 1 + ax/adapter/transforms/log.py | 2 + ax/adapter/transforms/map_key_to_float.py | 1 + .../transforms/merge_repeated_measurements.py | 2 + ax/adapter/transforms/power_transform_y.py | 1 + .../tests/test_add_execution_viability.py | 3 + .../transforms/tests/test_base_transform.py | 5 ++ ax/adapter/transforms/tests/test_bilog_y.py | 3 + .../transforms/tests/test_cast_transform.py | 1 + .../tests/test_choice_encode_transform.py | 8 +++ .../tests/test_fill_missing_parameters.py | 3 + .../tests/test_map_key_to_float_transform.py | 6 ++ .../transforms/tests/test_metadata_to_task.py | 1 + .../tests/test_objective_as_constraint.py | 1 + .../tests/test_relativize_transform.py | 1 + .../tests/test_transform_to_new_sq.py | 6 ++ ax/adapter/transforms/trial_as_task.py | 1 + ax/adapter/transforms/unit_x.py | 2 + ax/analysis/analysis.py | 1 + ax/analysis/graphviz/graphviz_analysis.py | 1 + .../healthcheck/can_generate_candidates.py | 1 + .../healthcheck/constraints_feasibility.py | 1 + .../healthcheck/metric_fetching_errors.py | 3 + .../healthcheck/no_effects_analysis.py | 1 + .../healthcheck/regression_analysis.py | 3 + .../healthcheck/regression_detection_utils.py | 1 + .../healthcheck/search_space_analysis.py | 2 + .../tests/test_baseline_improvement.py | 1 + .../tests/test_can_generate_candidates.py | 4 ++ .../tests/test_constraints_feasibility.py | 2 + .../tests/test_healtheck_exception.py | 1 + .../tests/test_metric_fetching_errors.py | 9 +++ .../tests/test_no_effects_analysis.py | 1 + .../tests/test_regression_analysis.py | 2 + .../tests/test_search_space_analysis.py | 3 + .../tests/test_transfer_learning_analysis.py | 20 ++++++ .../healthcheck/transfer_learning_analysis.py | 1 + ax/analysis/overview.py | 4 ++ .../test_objective_p_feasible_frontier.py | 6 ++ ax/analysis/results.py | 2 + ax/analysis/tests/test_diagnostics.py | 1 + ax/analysis/tests/test_overview.py | 1 + ax/analysis/tests/test_results.py | 5 ++ ax/analysis/utils.py | 1 + ax/api/client.py | 4 ++ .../instantiation/tests/test_from_config.py | 1 + ax/api/utils/storage.py | 5 ++ ax/benchmark/benchmark_metric.py | 1 + ax/benchmark/benchmark_result.py | 2 + ax/benchmark/noise.py | 1 + ax/benchmark/problems/data.py | 1 + .../problems/synthetic/from_botorch.py | 1 + ax/benchmark/testing/benchmark_stubs.py | 1 + ax/benchmark/tests/test_benchmark.py | 6 ++ ax/benchmark/tests/test_benchmark_metric.py | 3 + ax/benchmark/tests/test_benchmark_runner.py | 4 +- ax/core/__init__.py | 1 + ax/core/auxiliary.py | 1 + ax/core/data.py | 7 ++ ax/core/experiment.py | 1 + ax/core/experiment_status.py | 1 + ax/core/multi_type_experiment.py | 2 + ax/core/observation_utils.py | 6 ++ ax/core/parameter.py | 10 ++- ax/core/runner.py | 1 + ax/core/search_space.py | 4 ++ ax/core/tests/test_auxiliary_source.py | 1 + ax/core/tests/test_batch_trial.py | 5 ++ ax/core/tests/test_experiment.py | 35 +++++++++- ax/core/tests/test_multi_type_experiment.py | 8 +++ ax/core/tests/test_observation.py | 2 + ax/core/tests/test_optimization_config.py | 12 ++++ ax/core/tests/test_search_space.py | 9 +++ ax/core/tests/test_trial.py | 18 ++++- ax/core/tests/test_utils.py | 24 ++++++- ax/core/trial_status.py | 1 + ax/core/utils.py | 1 + ax/early_stopping/simulation.py | 3 + ax/early_stopping/strategies/base.py | 1 + ax/generation_strategy/dispatch_utils.py | 5 ++ .../external_generation_node.py | 2 + ax/generation_strategy/generator_spec.py | 1 + .../tests/test_best_model_selector.py | 7 +- .../tests/test_center_generation_node.py | 2 + .../tests/test_dispatch_utils.py | 21 +++++- .../tests/test_transition_criterion.py | 4 ++ ax/generators/discrete/thompson.py | 2 + ax/generators/random/base.py | 1 + ax/generators/random/uniform.py | 1 + ax/generators/tests/test_botorch_moo_utils.py | 2 + ax/generators/tests/test_thompson.py | 2 + ax/generators/tests/test_torch_model_utils.py | 1 + ax/generators/tests/test_torch_utils.py | 1 + ax/generators/tests/test_utils.py | 8 +++ .../torch/botorch_modular/acquisition.py | 5 ++ .../torch/botorch_modular/generator.py | 2 + .../botorch_modular/optimizer_argparse.py | 2 + .../torch/botorch_modular/surrogate.py | 8 +++ ax/generators/torch/botorch_modular/utils.py | 1 + ax/generators/torch/tests/test_acquisition.py | 65 ++++++++++++++++++- ax/generators/torch/tests/test_generator.py | 11 ++++ ax/generators/torch/tests/test_kernels.py | 7 +- .../torch/tests/test_optimizer_argparse.py | 1 + ax/generators/torch/tests/test_surrogate.py | 65 +++++++++++++++++-- ax/generators/torch/tests/test_utils.py | 1 + ax/generators/utils.py | 17 +++++ ax/metrics/map_replay.py | 1 + ax/metrics/noisy_function.py | 1 + ax/metrics/tensorboard.py | 2 + ax/metrics/tests/test_noisy_function.py | 6 ++ ax/orchestration/orchestrator.py | 6 ++ ax/orchestration/tests/test_orchestrator.py | 37 ++++++++++- ax/plot/contour.py | 5 ++ ax/plot/diagnostic.py | 1 + ax/plot/feature_importances.py | 7 ++ ax/plot/helper.py | 5 ++ ax/plot/pareto_frontier.py | 25 ++++++- ax/plot/scatter.py | 9 +++ ax/plot/tests/test_contours.py | 4 ++ ax/plot/tests/test_slices.py | 1 + ax/plot/tests/test_traces.py | 3 + ax/plot/trace.py | 8 +++ ax/service/ax_client.py | 5 ++ ax/service/managed_loop.py | 3 + ax/service/tests/test_ax_client.py | 5 ++ ax/service/tests/test_best_point.py | 1 + ax/service/tests/test_best_point_utils.py | 3 + ax/service/tests/test_interactive_loop.py | 2 + ax/service/tests/test_managed_loop.py | 24 +++++-- ax/service/utils/best_point.py | 1 + ax/service/utils/instantiation.py | 2 + ax/service/utils/report_utils.py | 1 + ax/storage/json_store/decoder.py | 11 ++++ ax/storage/json_store/decoders.py | 2 + .../json_store/tests/test_json_store.py | 13 ++++ ax/storage/registry_bundle.py | 3 + ax/storage/sqa_store/db.py | 1 + ax/storage/sqa_store/decoder.py | 5 ++ ax/storage/sqa_store/encoder.py | 5 ++ ax/storage/sqa_store/json.py | 1 + ax/storage/sqa_store/load.py | 4 ++ ax/storage/sqa_store/tests/test_sqa_store.py | 29 ++++++++- .../tests/test_with_db_settings_base.py | 63 +++++++++++++++--- ax/storage/sqa_store/with_db_settings_base.py | 54 +++++++++++++-- ax/storage/utils.py | 1 + ax/utils/common/mock.py | 1 + ax/utils/common/random.py | 1 + ax/utils/common/tests/test_executils.py | 15 +++++ ax/utils/common/tests/test_random.py | 1 + ax/utils/common/tests/test_result.py | 7 ++ ax/utils/common/testutils.py | 5 +- ax/utils/common/timeutils.py | 1 + ax/utils/measurement/synthetic_functions.py | 4 ++ ax/utils/sensitivity/sobol_measures.py | 1 + .../sensitivity/tests/test_sensitivity.py | 2 + ax/utils/stats/no_effects.py | 2 + ax/utils/stats/tests/test_model_fit_stats.py | 1 + ax/utils/testing/core_stubs.py | 21 +++++- ax/utils/testing/modeling_stubs.py | 1 + ax/utils/testing/preference_stubs.py | 1 + 168 files changed, 965 insertions(+), 46 deletions(-) diff --git a/ax/adapter/adapter_utils.py b/ax/adapter/adapter_utils.py index 95bd08be0c5..66df17eeec7 100644 --- a/ax/adapter/adapter_utils.py +++ b/ax/adapter/adapter_utils.py @@ -639,7 +639,9 @@ def _roundtrip_transform(x: npt.NDArray) -> npt.NDArray: observation_features ) new_x: list[float] = [ - float(observation_features[0].parameters[p]) for p in param_names + # pyrefly: ignore [bad-argument-type] + float(observation_features[0].parameters[p]) + for p in param_names ] # turn it back into an array return np.array(new_x) diff --git a/ax/adapter/base.py b/ax/adapter/base.py index 75ff506e6f3..f590a311167 100644 --- a/ax/adapter/base.py +++ b/ax/adapter/base.py @@ -488,6 +488,7 @@ def _set_model_space(self, arm_data: DataFrame) -> None: # `sort_values=True` at construction time, so we can always # sort here. Values are guaranteed numeric by the gate above, # hence sortable. + # pyrefly: ignore [bad-argument-type] p.set_values(sorted(cast(list[float], [*p.values, *extra_values]))) # Remove parameter constraints from the model space. self._model_space.set_parameter_constraints([]) @@ -1237,6 +1238,7 @@ def gen_arms( arms.append(arm) if of.metadata: candidate_metadata[arm.signature] = of.metadata + # pyrefly: ignore [bad-return] return arms, candidate_metadata or None # None if empty cand. metadata. diff --git a/ax/adapter/discrete.py b/ax/adapter/discrete.py index 8183c8f2efd..ba33ad5c97e 100644 --- a/ax/adapter/discrete.py +++ b/ax/adapter/discrete.py @@ -73,6 +73,7 @@ def __init__( fit_on_init=fit_on_init, ) # Re-assign for more precise typing. + # pyrefly: ignore [bad-override-mutable-attribute] self.generator: DiscreteGenerator = generator def _fit( @@ -137,6 +138,7 @@ def _validate_gen_inputs( "must be either a positive integer or -1." ) + # pyrefly: ignore [bad-override-param-name] def _gen( self, n: int, diff --git a/ax/adapter/random.py b/ax/adapter/random.py index 32166af0f16..001ee4553dd 100644 --- a/ax/adapter/random.py +++ b/ax/adapter/random.py @@ -68,6 +68,7 @@ def __init__( fit_on_init=fit_on_init, ) # Re-assign for more precise typing. + # pyrefly: ignore [bad-override-mutable-attribute] self.generator: RandomGenerator = generator def _fit( @@ -78,6 +79,7 @@ def _fit( """Extracts the list of parameters from the search space.""" self.parameters = list(search_space.parameters.keys()) + # pyrefly: ignore [bad-override-param-name] def _gen( self, n: int, diff --git a/ax/adapter/tests/test_cross_validation.py b/ax/adapter/tests/test_cross_validation.py index d888d2dc2a4..895e2a32864 100644 --- a/ax/adapter/tests/test_cross_validation.py +++ b/ax/adapter/tests/test_cross_validation.py @@ -543,6 +543,7 @@ def test_efficient_loo_cv_is_attempted(self) -> None: mock_posterior = GPyTorchPosterior(distribution=mock_mvn) # Get the surrogate model from the adapter + # pyrefly: ignore [missing-attribute] surrogate = self.adapter.generator.surrogate model = surrogate.model @@ -661,6 +662,7 @@ def _fold_gen(td: ExperimentData) -> Iterable[CVData]: mock_posterior = mock.MagicMock() mock_posterior.__class__.__name__ = "UnknownPosterior" + # pyrefly: ignore [missing-attribute] surrogate = self.adapter.generator.surrogate model = surrogate.model diff --git a/ax/adapter/tests/test_registry.py b/ax/adapter/tests/test_registry.py index 1657fe8e2ea..3aec9918e22 100644 --- a/ax/adapter/tests/test_registry.py +++ b/ax/adapter/tests/test_registry.py @@ -285,14 +285,18 @@ def test_ST_MTGP(self, use_saas: bool = False) -> None: model, SaasFullyBayesianMultiTaskGP if use_saas else MultiTaskGP, ) + # pyrefly: ignore [not-iterable] data_covar_module, _ = model.covar_module.kernels if use_saas is False and default_model is False: self.assertIsInstance(data_covar_module, ScaleKernel) base_kernel = data_covar_module.base_kernel self.assertIsInstance(base_kernel, MaternKernel) self.assertEqual( - base_kernel.lengthscale_prior.concentration, 6.0 + # pyrefly: ignore [missing-attribute] + base_kernel.lengthscale_prior.concentration, + 6.0, ) + # pyrefly: ignore [missing-attribute] self.assertEqual(base_kernel.lengthscale_prior.rate, 3.0) elif use_saas is False: self.assertIsInstance(data_covar_module, RBFKernel) diff --git a/ax/adapter/tests/test_torch_adapter.py b/ax/adapter/tests/test_torch_adapter.py index 63b5e55195d..adfc19ddc1a 100644 --- a/ax/adapter/tests/test_torch_adapter.py +++ b/ax/adapter/tests/test_torch_adapter.py @@ -452,7 +452,10 @@ def test_best_point(self) -> None: # UnitX removes 1 and divides by 5. Reversing here. self.assertEqual(arm.parameters.keys(), {"x"}) self.assertAlmostEqual( - float(arm.parameters["x"]), (best_point_value * 5.0) + 1.0, places=5 + # pyrefly: ignore [bad-argument-type] + float(arm.parameters["x"]), + (best_point_value * 5.0) + 1.0, + places=5, ) # 1.0 in transformed space is 6.0 in original space. self.assertEqual(run.arms[0].parameters, {"x": 6.0}) @@ -524,15 +527,19 @@ def test_candidate_metadata_propagation(self) -> None: [list(arm.parameters.values()) for arm in exp.trials[0].arms], dtype=torch.double, ) + # pyrefly: ignore [not-iterable] for dataset in datasets: self.assertTrue(torch.equal(dataset.X, X_expected)) candidate_metadata = mock_generator_fit.call_args.kwargs.get( "candidate_metadata" ) + # pyrefly: ignore [bad-argument-type] self.assertEqual(len(candidate_metadata), 1) + # pyrefly: ignore [unsupported-operation] self.assertEqual(len(candidate_metadata[0]), len(exp.trials[0].arms)) self.assertEqual( + # pyrefly: ignore [unsupported-operation] candidate_metadata[0][0], { "preexisting_batch_cand_metadata": "some_value", diff --git a/ax/adapter/torch.py b/ax/adapter/torch.py index 8383d65b120..b8c33ae09be 100644 --- a/ax/adapter/torch.py +++ b/ax/adapter/torch.py @@ -211,6 +211,7 @@ def __init__( ) # Re-assign self.generator for more precise typing. + # pyrefly: ignore [bad-override-mutable-attribute] self.generator: TorchGenerator = generator def feature_importances(self, metric_signature: str) -> dict[str, float]: @@ -877,6 +878,7 @@ def _fit( **kwargs, ) + # pyrefly: ignore [bad-override-param-name] def _gen( self, n: int, @@ -1017,6 +1019,7 @@ def _transform_observation_features( try: tobfs = np.array( [ + # pyrefly: ignore [bad-argument-type] [float(of.parameters[p]) for p in self.parameters] for of in observation_features ] @@ -1201,10 +1204,12 @@ def _untransform_objective_thresholds( [fixed_features_obs] )[0] thresholds = t.untransform_outcome_constraints( + # pyrefly: ignore [bad-argument-type] outcome_constraints=thresholds, fixed_features=fixed_features_obs, ) + # pyrefly: ignore [bad-return] return thresholds def _validate_preference_config( diff --git a/ax/adapter/transfer_learning/adapter.py b/ax/adapter/transfer_learning/adapter.py index b92ff2aea2c..676f70ad46d 100644 --- a/ax/adapter/transfer_learning/adapter.py +++ b/ax/adapter/transfer_learning/adapter.py @@ -599,6 +599,7 @@ def _fit( ) # This is a bit of a hack to ensure that only the data for the target task # is used in the X_baseline. It also avoids task feature in X_baseline. + # pyrefly: ignore [missing-attribute] self.generator.surrogate._training_data = [ ds.datasets[ds.target_outcome_name] for ds in task_datasets ] diff --git a/ax/adapter/transforms/log.py b/ax/adapter/transforms/log.py index c03d45e150e..c78fbd78c82 100644 --- a/ax/adapter/transforms/log.py +++ b/ax/adapter/transforms/log.py @@ -109,6 +109,7 @@ def transform_search_space(self, search_space: SearchSpace) -> SearchSpace: # Apply log10 transformation transformed_values = [ + # pyrefly: ignore [bad-argument-type] assert_is_instance(math.log10(float(v)), TParamValue) for v in values ] @@ -119,6 +120,7 @@ def transform_search_space(self, search_space: SearchSpace) -> SearchSpace: if dependents is not None: dependents = { math.log10( + # pyrefly: ignore [bad-argument-type] float(assert_is_instance_of_tuple(k, (float, int))) ): v for k, v in dependents.items() diff --git a/ax/adapter/transforms/map_key_to_float.py b/ax/adapter/transforms/map_key_to_float.py index 1ac256554f5..df3c4edd484 100644 --- a/ax/adapter/transforms/map_key_to_float.py +++ b/ax/adapter/transforms/map_key_to_float.py @@ -106,6 +106,7 @@ def __init__( parameters = {MAP_KEY: {}} self.parameters: dict[str, dict[str, Any]] = parameters + # pyrefly: ignore [bad-override-mutable-attribute] self._parameter_list: list[RangeParameter] = [] # Construct the parameter if needed. if is_map_data: diff --git a/ax/adapter/transforms/merge_repeated_measurements.py b/ax/adapter/transforms/merge_repeated_measurements.py index d61109372f4..bfe425995eb 100644 --- a/ax/adapter/transforms/merge_repeated_measurements.py +++ b/ax/adapter/transforms/merge_repeated_measurements.py @@ -65,9 +65,11 @@ def __init__( raise NotImplementedError( "All metrics must have noise observations." ) + # pyrefly: ignore [bad-index] arm_to_multi_obs[arm_name][m]["means"].extend( df_m[("mean", m)].tolist() ) + # pyrefly: ignore [bad-index] arm_to_multi_obs[arm_name][m]["vars"].extend( (df_m[("sem", m)] ** 2).tolist() ) diff --git a/ax/adapter/transforms/power_transform_y.py b/ax/adapter/transforms/power_transform_y.py index 8f93d561d8a..d0dd1695121 100644 --- a/ax/adapter/transforms/power_transform_y.py +++ b/ax/adapter/transforms/power_transform_y.py @@ -357,4 +357,5 @@ def _compute_inverse_bounds( elif lambda_ > 2.0 + tol: bounds[0] = (1.0 / (2.0 - lambda_) - mu) / sigma inv_bounds[k] = tuple(assert_is_instance_list(bounds, float)) + # pyrefly: ignore [bad-return] return inv_bounds diff --git a/ax/adapter/transforms/tests/test_add_execution_viability.py b/ax/adapter/transforms/tests/test_add_execution_viability.py index 5c68cad6756..98ee0eff1f9 100644 --- a/ax/adapter/transforms/tests/test_add_execution_viability.py +++ b/ax/adapter/transforms/tests/test_add_execution_viability.py @@ -56,9 +56,11 @@ def test_transform_optimization_config(self) -> None: config={"feasibility_threshold": 0.8}, ) original_constraints_count = len( + # pyrefly: ignore [missing-attribute] self.experiment.optimization_config.outcome_constraints ) new_opt_config = t.transform_optimization_config( + # pyrefly: ignore [bad-argument-type] self.experiment.optimization_config, ) self.assertEqual( @@ -81,6 +83,7 @@ def test_transform_optimization_config(self) -> None: config={}, ) new_opt_config = t.transform_optimization_config( + # pyrefly: ignore [bad-argument-type] self.experiment.optimization_config, ) feasibility_constraint = new_opt_config.outcome_constraints[-1] diff --git a/ax/adapter/transforms/tests/test_base_transform.py b/ax/adapter/transforms/tests/test_base_transform.py index f4802aafe5e..8712e69aab7 100644 --- a/ax/adapter/transforms/tests/test_base_transform.py +++ b/ax/adapter/transforms/tests/test_base_transform.py @@ -34,6 +34,7 @@ def transform_observation_features( ) -> list[ObservationFeatures]: for obs_feat in observation_features: for param_name, param_value in obs_feat.parameters.items(): + # pyrefly: ignore [bad-argument-type] obs_feat.parameters[param_name] = float(param_value) * 2 return observation_features @@ -45,9 +46,13 @@ def test_IdentityTransform(self) -> None: x = MagicMock() ys = [] ys.append(t.transform_search_space(x)) + # pyrefly: ignore [bad-argument-type] ys.append(t.transform_observation_features(x)) + # pyrefly: ignore [bad-argument-type] ys.append(t._transform_observation_data(x)) + # pyrefly: ignore [bad-argument-type] ys.append(t.untransform_observation_features(x)) + # pyrefly: ignore [bad-argument-type] ys.append(t._untransform_observation_data(x)) self.assertEqual(len(x.mock_calls), 0) for y in ys: diff --git a/ax/adapter/transforms/tests/test_bilog_y.py b/ax/adapter/transforms/tests/test_bilog_y.py index eceb7a7db40..23a83e08d36 100644 --- a/ax/adapter/transforms/tests/test_bilog_y.py +++ b/ax/adapter/transforms/tests/test_bilog_y.py @@ -33,6 +33,7 @@ def setUp(self) -> None: with_relative_constraint=True, ) self.data = self.exp.fetch_data() + # pyrefly: ignore [missing-attribute] self.bound = self.exp.optimization_config.outcome_constraints[1].bound def get_adapter(self) -> Adapter: @@ -85,6 +86,7 @@ def test_Bilog(self) -> None: ) def test_TransformUntransform(self) -> None: + # pyrefly: ignore [missing-attribute] bound = self.exp.optimization_config.outcome_constraints[0].bound observations = observations_from_data( experiment=self.exp, data=self.exp.lookup_data() @@ -154,6 +156,7 @@ def test_TransformOptimizationConfig(self) -> None: ) oc = self.exp.optimization_config # This should be a no-op + # pyrefly: ignore [bad-argument-type] new_oc = t.transform_optimization_config(optimization_config=oc) self.assertEqual(new_oc, oc) diff --git a/ax/adapter/transforms/tests/test_cast_transform.py b/ax/adapter/transforms/tests/test_cast_transform.py index d41de7085ee..695712a59ef 100644 --- a/ax/adapter/transforms/tests/test_cast_transform.py +++ b/ax/adapter/transforms/tests/test_cast_transform.py @@ -490,6 +490,7 @@ def test_transform_experiment_data_cast(self) -> None: .astype({"x": float, "y": float}) ) expected_arm_data["z"] = None + # pyrefly: ignore [missing-attribute] expected_arm_data["z"] = expected_arm_data["z"].astype("Int64") expected_arm_data = expected_arm_data[["x", "y", "z", "metadata"]] assert_frame_equal(transformed.arm_data, expected_arm_data) diff --git a/ax/adapter/transforms/tests/test_choice_encode_transform.py b/ax/adapter/transforms/tests/test_choice_encode_transform.py index 5bf8ed02208..a0a229b1731 100644 --- a/ax/adapter/transforms/tests/test_choice_encode_transform.py +++ b/ax/adapter/transforms/tests/test_choice_encode_transform.py @@ -157,11 +157,13 @@ def test_transform_search_space(self) -> None: for param_name in ["b", "c"]: self.assertEqual( + # pyrefly: ignore [missing-attribute] ss2.parameters[param_name].values, assert_is_instance( self.search_space[param_name], ChoiceParameter ).values, ) + # pyrefly: ignore [missing-attribute] self.assertEqual(ss2.parameters["d"].values, [0, 1, 2]) # Fidelity parameter is transformed correctly. @@ -365,9 +367,13 @@ def test_transform_search_space(self) -> None: self.assertEqual(ss2.parameters[p].parameter_type, ParameterType.INT) self.assertEqual(ss2.parameters["d"].parameter_type, ParameterType.STRING) + # pyrefly: ignore [missing-attribute] self.assertEqual(ss2.parameters["b"].lower, 0) + # pyrefly: ignore [missing-attribute] self.assertEqual(ss2.parameters["b"].upper, 2) + # pyrefly: ignore [missing-attribute] self.assertEqual(ss2.parameters["c"].lower, 0) + # pyrefly: ignore [missing-attribute] self.assertEqual(ss2.parameters["c"].upper, 2) self.assertEqual(ss2.parameters["d"].values, ["q", "r", "z"]) @@ -419,5 +425,7 @@ def test_transform_search_space_with_different_values(self) -> None: ] ) t_ss = self.t.transform_search_space(ss) + # pyrefly: ignore [missing-attribute] self.assertEqual(t_ss.parameters["b"].lower, 1) + # pyrefly: ignore [missing-attribute] self.assertEqual(t_ss.parameters["b"].upper, 2) diff --git a/ax/adapter/transforms/tests/test_fill_missing_parameters.py b/ax/adapter/transforms/tests/test_fill_missing_parameters.py index 82f7b16cd76..46e51a10212 100644 --- a/ax/adapter/transforms/tests/test_fill_missing_parameters.py +++ b/ax/adapter/transforms/tests/test_fill_missing_parameters.py @@ -56,11 +56,13 @@ def test_init_with_deprecated_config(self) -> None: with self.assertLogs( "ax.adapter.transforms.fill_missing_parameters", level="ERROR" ) as lg: + # pyrefly: ignore [bad-argument-type] t = FillMissingParameters(config=self.config) self.assertIn("deprecated", lg.output[0]) self.assertEqual(t._fill_values, self.config_values) def test_init_with_both_config_and_search_space(self) -> None: + # pyrefly: ignore [bad-argument-type] t = FillMissingParameters(search_space=self.search_space, config=self.config) # Search space values should override config values self.assertEqual(t._fill_values, self.search_space_backfill_values) @@ -166,6 +168,7 @@ def test_deprecated_config_behavior_still_works(self) -> None: ObservationFeatures(parameters={"x": None}), ObservationFeatures(parameters={"x": 0.0}), ] + # pyrefly: ignore [bad-argument-type] t = FillMissingParameters(config=self.config) expected = [ ObservationFeatures(parameters={"x": 2.0, "y": 3.0}), diff --git a/ax/adapter/transforms/tests/test_map_key_to_float_transform.py b/ax/adapter/transforms/tests/test_map_key_to_float_transform.py index 5ba81413886..d8b25eb0891 100644 --- a/ax/adapter/transforms/tests/test_map_key_to_float_transform.py +++ b/ax/adapter/transforms/tests/test_map_key_to_float_transform.py @@ -383,6 +383,7 @@ def test_TransformObservationFeatures(self) -> None: keep_indices = [ i for i, obs in enumerate(self.observations) + # pyrefly: ignore [unsupported-operation] if not isnan(obs.features.metadata["step"]) ] observation_features = [self.observations[i].features for i in keep_indices] @@ -393,6 +394,7 @@ def test_TransformObservationFeatures(self) -> None: for i in keep_indices: obs = self.observations[i] obsf = obs.features.clone() + # pyrefly: ignore [missing-attribute] obsf.parameters[self.map_key] = obsf.metadata.pop(self.map_key) expected.append(obsf) @@ -404,6 +406,7 @@ def test_TransformObservationFeatures(self) -> None: keep_indices = [ i for i, obs in enumerate(self.observations) + # pyrefly: ignore [unsupported-operation] if isnan(obs.features.metadata["step"]) ] observation_features = [self.observations[i].features for i in keep_indices] @@ -423,6 +426,7 @@ def test_TransformObservationFeatures(self) -> None: untransformed = self.t.untransform_observation_features(obs_ft2) expected = observation_features for obs in expected: + # pyrefly: ignore [unsupported-operation] obs.metadata["step"] = 1.0 self.assertEqual(untransformed, observation_features) @@ -466,9 +470,11 @@ def test_TransformObservationFeaturesKeyNotInMetadata(self) -> None: obs_ft2 = deepcopy(observation_features) # remove the key from metadata dicts for obsf in obs_ft2: + # pyrefly: ignore [missing-attribute] obsf.metadata.pop(self.map_key) # To avoid this being treated as empty metadata. # In typical experiment, trial completion timestamp would be here. + # pyrefly: ignore [unsupported-operation] obsf.metadata["dummy"] = 1.0 # should be exactly one parameter (p,) = self.t._parameter_list diff --git a/ax/adapter/transforms/tests/test_metadata_to_task.py b/ax/adapter/transforms/tests/test_metadata_to_task.py index c1409e80d24..fe1944963c0 100644 --- a/ax/adapter/transforms/tests/test_metadata_to_task.py +++ b/ax/adapter/transforms/tests/test_metadata_to_task.py @@ -96,6 +96,7 @@ def test_TransformObservationFeatures(self) -> None: Keys.TASK_FEATURE_NAME.value: i, } ) + # pyrefly: ignore [unsupported-operation] del new_obs_ft.metadata[Keys.TASK_FEATURE_NAME] expected_obs_ft2.append(new_obs_ft) self.assertEqual(obs_ft2, expected_obs_ft2) diff --git a/ax/adapter/transforms/tests/test_objective_as_constraint.py b/ax/adapter/transforms/tests/test_objective_as_constraint.py index 37a70dddc93..6cf93534587 100644 --- a/ax/adapter/transforms/tests/test_objective_as_constraint.py +++ b/ax/adapter/transforms/tests/test_objective_as_constraint.py @@ -123,6 +123,7 @@ def test_no_op_when_feasible_points_exist(self) -> None: # transform_optimization_config should not modify the config opt_config = none_throws(deepcopy(adapter._experiment.optimization_config)) + # pyrefly: ignore [bad-argument-type] transformed = t.transform_optimization_config(opt_config, adapter) self.assertEqual(len(transformed.outcome_constraints), 1) diff --git a/ax/adapter/transforms/tests/test_relativize_transform.py b/ax/adapter/transforms/tests/test_relativize_transform.py index a4ce889f3dd..0b2e089b9d9 100644 --- a/ax/adapter/transforms/tests/test_relativize_transform.py +++ b/ax/adapter/transforms/tests/test_relativize_transform.py @@ -289,6 +289,7 @@ class BadRelativize(BaseRelativize): for abstract_cls in [BaseRelativize, BadRelativize]: with self.assertRaisesRegex(TypeError, "Can't instantiate abstract class"): + # pyrefly: ignore [bad-instantiation] abstract_cls(search_space=None, adapter=None) def test_transform_status_quos_always_zero(self) -> None: diff --git a/ax/adapter/transforms/tests/test_transform_to_new_sq.py b/ax/adapter/transforms/tests/test_transform_to_new_sq.py index 1243860245c..7369636b8ca 100644 --- a/ax/adapter/transforms/tests/test_transform_to_new_sq.py +++ b/ax/adapter/transforms/tests/test_transform_to_new_sq.py @@ -237,8 +237,10 @@ def test_transform_experiment_data(self) -> None: assert_frame_equal(target_trial_data, transformed_target_trial_data) # Check that the data for trials 0 and 1 are transformed correctly. + # pyrefly: ignore [unsupported-operation] sq_data_target = self.adapter.status_quo_data_by_trial[2] for t_idx in (0, 1): + # pyrefly: ignore [unsupported-operation] sq_data = self.adapter.status_quo_data_by_trial[t_idx] # Get the data for the non-sq arms. trial_data = experiment_data.observation_data.loc[t_idx] @@ -315,7 +317,9 @@ def test_transform_experiment_data_retains_sq_less_trials(self) -> None: assert_frame_equal(orig_trial_1, transformed_trial_1) # Trial 0's data should still be transformed (it has SQ data). + # pyrefly: ignore [unsupported-operation] sq_data_0 = self.adapter.status_quo_data_by_trial[0] + # pyrefly: ignore [unsupported-operation] sq_data_target = self.adapter.status_quo_data_by_trial[2] orig_trial_0 = experiment_data.observation_data.loc[0] orig_trial_0_non_sq = orig_trial_0[ @@ -397,7 +401,9 @@ def test_non_relativizable_trial_preserved(self) -> None: self.assertIn("status_quo", arms) # Trial 0 should still be transformed normally. + # pyrefly: ignore [unsupported-operation] sq_data_0 = self.adapter.status_quo_data_by_trial[0] + # pyrefly: ignore [unsupported-operation] sq_data_target = self.adapter.status_quo_data_by_trial[2] orig_trial_0 = experiment_data.observation_data.loc[0] orig_trial_0_non_sq = orig_trial_0[ diff --git a/ax/adapter/transforms/trial_as_task.py b/ax/adapter/transforms/trial_as_task.py index fd251875746..0cdb1656d80 100644 --- a/ax/adapter/transforms/trial_as_task.py +++ b/ax/adapter/transforms/trial_as_task.py @@ -133,6 +133,7 @@ def __init__( del self.trial_level_map[p_name] continue if "target_trial" in self.config: + # pyrefly: ignore [bad-argument-type] target_trial = int(self.config["target_trial"]) else: target_trial = none_throws( diff --git a/ax/adapter/transforms/unit_x.py b/ax/adapter/transforms/unit_x.py index a0b6df828d4..4cce885f436 100644 --- a/ax/adapter/transforms/unit_x.py +++ b/ax/adapter/transforms/unit_x.py @@ -64,6 +64,7 @@ def transform_observation_features( for obsf in observation_features: for p_name, (l, u) in self.bounds.items(): if p_name in obsf.parameters: + # pyrefly: ignore [bad-argument-type] param = float(obsf.parameters[p_name]) obsf.parameters[p_name] = self._normalize_value(param, (l, u)) return observation_features @@ -120,6 +121,7 @@ def untransform_observation_features( for obsf in observation_features: for p_name, (l, u) in self.bounds.items(): if p_name in obsf.parameters: + # pyrefly: ignore [bad-argument-type] param = float(obsf.parameters[p_name]) obsf.parameters[p_name] = param * (u - l) + l return observation_features diff --git a/ax/analysis/analysis.py b/ax/analysis/analysis.py index d1b7c22d4ac..64e28171668 100644 --- a/ax/analysis/analysis.py +++ b/ax/analysis/analysis.py @@ -162,6 +162,7 @@ def _create_analysis_card( title=title, subtitle=subtitle, df=df, + # pyrefly: ignore [bad-argument-type] blob=df.to_json(), ) diff --git a/ax/analysis/graphviz/graphviz_analysis.py b/ax/analysis/graphviz/graphviz_analysis.py index 81bb81b2a18..24aadcf65b3 100644 --- a/ax/analysis/graphviz/graphviz_analysis.py +++ b/ax/analysis/graphviz/graphviz_analysis.py @@ -12,6 +12,7 @@ class GraphvizAnalysisCard(AnalysisCard): def get_digraph(self) -> Digraph: + # pyrefly: ignore [bad-return] return Source(self.blob) def _body_html(self, depth: int) -> str: diff --git a/ax/analysis/healthcheck/can_generate_candidates.py b/ax/analysis/healthcheck/can_generate_candidates.py index fbd775123fd..0fb626e13b9 100644 --- a/ax/analysis/healthcheck/can_generate_candidates.py +++ b/ax/analysis/healthcheck/can_generate_candidates.py @@ -22,6 +22,7 @@ @final +# pyrefly: ignore [bad-class-definition] class CanGenerateCandidatesAnalysis(Analysis): REASON_PREFIX: str = "This experiment cannot generate candidates.\nREASON: " LAST_RUN_TEMPLATE: str = "\n\nLAST TRIAL RUN: {days} day(s) ago" diff --git a/ax/analysis/healthcheck/constraints_feasibility.py b/ax/analysis/healthcheck/constraints_feasibility.py index ae4bda5e311..d1521ca3155 100644 --- a/ax/analysis/healthcheck/constraints_feasibility.py +++ b/ax/analysis/healthcheck/constraints_feasibility.py @@ -221,6 +221,7 @@ def compute( if isinstance(oc, ScalarizedOutcomeConstraint): constraint_map[str(oc)] = oc else: + # pyrefly: ignore [unsupported-operation] constraint_map[oc.metric_names[0]] = oc for col in p_feasible_cols: diff --git a/ax/analysis/healthcheck/metric_fetching_errors.py b/ax/analysis/healthcheck/metric_fetching_errors.py index f46f984cdde..88b51009567 100644 --- a/ax/analysis/healthcheck/metric_fetching_errors.py +++ b/ax/analysis/healthcheck/metric_fetching_errors.py @@ -27,6 +27,7 @@ @final +# pyrefly: ignore [bad-class-definition] class MetricFetchingErrorsAnalysis(Analysis): """ Analysis to check if any metric fetch errors occurred. @@ -125,6 +126,7 @@ def compute( # Escalate to FAIL if any metric on the optimization config has errors. status = HealthcheckStatus.WARNING if (opt_config := experiment.optimization_config) is not None: + # pyrefly: ignore [bad-assignment] errored_metric_names: set[str] = { e["metric_name"] for e in metric_fetch_errors_for_card } @@ -149,6 +151,7 @@ def compute( return create_healthcheck_analysis_card( name=self.__class__.__name__, title="Metric Fetch Errors", + # pyrefly: ignore [unsupported-operation] subtitle=subtitle.to_markdown(index=False) + remediation, df=df, status=status, diff --git a/ax/analysis/healthcheck/no_effects_analysis.py b/ax/analysis/healthcheck/no_effects_analysis.py index 4c9fcff4d24..d4bb35960ac 100644 --- a/ax/analysis/healthcheck/no_effects_analysis.py +++ b/ax/analysis/healthcheck/no_effects_analysis.py @@ -23,6 +23,7 @@ @final +# pyrefly: ignore [bad-class-definition] class TestOfNoEffectAnalysis(Analysis): """ Analysis for checking whether a randomization test can show that there are any diff --git a/ax/analysis/healthcheck/regression_analysis.py b/ax/analysis/healthcheck/regression_analysis.py index d1cc2603c1d..bdcfdf2093a 100644 --- a/ax/analysis/healthcheck/regression_analysis.py +++ b/ax/analysis/healthcheck/regression_analysis.py @@ -25,6 +25,7 @@ @final +# pyrefly: ignore [bad-class-definition] class RegressionAnalysis(Analysis): r""" Analysis for detecting the regressing arm, metric pairs across all trials with data. @@ -46,7 +47,9 @@ def __init__(self, prob_threshold: float = 0.95) -> None: """ self.prob_threshold = prob_threshold + # pyrefly: ignore [bad-override] @override + # pyrefly: ignore [bad-override] def compute( self, experiment: Experiment | None, diff --git a/ax/analysis/healthcheck/regression_detection_utils.py b/ax/analysis/healthcheck/regression_detection_utils.py index 4db141d8807..c04252486ae 100644 --- a/ax/analysis/healthcheck/regression_detection_utils.py +++ b/ax/analysis/healthcheck/regression_detection_utils.py @@ -50,6 +50,7 @@ def detect_regressions_by_trial( thresholds=thresholds, ) + # pyrefly: ignore [bad-return] return regressing_arms_metrics_by_trial diff --git a/ax/analysis/healthcheck/search_space_analysis.py b/ax/analysis/healthcheck/search_space_analysis.py index 329c9ae5b67..605441d0fa0 100644 --- a/ax/analysis/healthcheck/search_space_analysis.py +++ b/ax/analysis/healthcheck/search_space_analysis.py @@ -26,6 +26,7 @@ @final +# pyrefly: ignore [bad-class-definition] class SearchSpaceAnalysis(Analysis): r""" Analysis for checking wehther the search space of the experiment should be expanded. @@ -155,6 +156,7 @@ def search_space_boundary_proportions( num_ub = 0 # counts how many parameters are equal to the boundary's upper bound for parameterization in parameterizations: value = parameterization[parameter_name] + # pyrefly: ignore [bad-argument-type] value = float(value) # for choice parameters, we check if the value is equal to the lower # or upper bound diff --git a/ax/analysis/healthcheck/tests/test_baseline_improvement.py b/ax/analysis/healthcheck/tests/test_baseline_improvement.py index 4fb13636fed..bc1e25b6600 100644 --- a/ax/analysis/healthcheck/tests/test_baseline_improvement.py +++ b/ax/analysis/healthcheck/tests/test_baseline_improvement.py @@ -93,6 +93,7 @@ def test_multi_objective_partial_improvement(self) -> None: "branin_b": [(50.0, 0.1), (100.0, 0.1)], }, arm_names=["status_quo", "0_0"], + # pyrefly: ignore [bad-argument-type] experiment=self.moo_experiment, ) diff --git a/ax/analysis/healthcheck/tests/test_can_generate_candidates.py b/ax/analysis/healthcheck/tests/test_can_generate_candidates.py index 4837489de83..a95dab4c4d2 100644 --- a/ax/analysis/healthcheck/tests/test_can_generate_candidates.py +++ b/ax/analysis/healthcheck/tests/test_can_generate_candidates.py @@ -20,6 +20,7 @@ class TestCanGenerateCandidates(TestCase): def test_passes_if_can_generate(self) -> None: # GIVEN we can generate candidates # WHEN we run the healthcheck + # pyrefly: ignore [bad-instantiation] card = CanGenerateCandidatesAnalysis( can_generate_candidates=True, reason="No problems found.", @@ -52,6 +53,7 @@ def test_warns_if_a_trial_was_recently_run(self) -> None: trial.mark_running(no_runner_required=True) trial._time_run_started = datetime.now() - timedelta(days=1) # WHEN we run the healthcheck + # pyrefly: ignore [bad-instantiation] card = CanGenerateCandidatesAnalysis( can_generate_candidates=False, reason="The data is borked.", @@ -86,6 +88,7 @@ def test_is_fail_no_trials_have_been_run(self) -> None: trial = experiment.trials[0] self.assertEqual(trial.status, TrialStatus.CANDIDATE) # WHEN we run the healthcheck + # pyrefly: ignore [bad-instantiation] card = CanGenerateCandidatesAnalysis( can_generate_candidates=False, reason="The data is gone.", @@ -121,6 +124,7 @@ def test_is_fail_if_no_trial_was_recently_run(self) -> None: trial._time_run_started = datetime.now() - timedelta(days=3) trial.mark_completed() # WHEN we run the healthcheck + # pyrefly: ignore [bad-instantiation] card = CanGenerateCandidatesAnalysis( can_generate_candidates=False, reason="The data is old.", diff --git a/ax/analysis/healthcheck/tests/test_constraints_feasibility.py b/ax/analysis/healthcheck/tests/test_constraints_feasibility.py index b085c1e0249..a6da214c079 100644 --- a/ax/analysis/healthcheck/tests/test_constraints_feasibility.py +++ b/ax/analysis/healthcheck/tests/test_constraints_feasibility.py @@ -54,6 +54,7 @@ def _create_experiment_with_data( { "arm_name": ["status_quo", "0_0", "0_1", "0_2", "0_3", "0_4"], "metric_name": ["branin_a"] * 6, + # pyrefly: ignore [bad-argument-type] "mean": list(np.random.normal(0, 1, 6)), "sem": [0.1] * 6, "trial_index": [0] * 6, @@ -64,6 +65,7 @@ def _create_experiment_with_data( { "arm_name": ["status_quo", "0_0", "0_1", "0_2", "0_3", "0_4"], "metric_name": ["branin_b"] * 6, + # pyrefly: ignore [bad-argument-type] "mean": list(np.random.normal(0, 1, 6)), "sem": [0.1] * 6, "trial_index": [0] * 6, diff --git a/ax/analysis/healthcheck/tests/test_healtheck_exception.py b/ax/analysis/healthcheck/tests/test_healtheck_exception.py index d37b84cd364..24bfad1fac0 100644 --- a/ax/analysis/healthcheck/tests/test_healtheck_exception.py +++ b/ax/analysis/healthcheck/tests/test_healtheck_exception.py @@ -27,6 +27,7 @@ def compute( raise ValueError(ERROR_MESSAGE) def test_error_analysis_card_on_exception(self) -> None: + # pyrefly: ignore [bad-instantiation] analysis = self.DummyAnalysis() with self.assertLogs("ax.analysis.analysis", "ERROR") as logs: analysis_cards = analysis.compute_or_error_card().flatten() diff --git a/ax/analysis/healthcheck/tests/test_metric_fetching_errors.py b/ax/analysis/healthcheck/tests/test_metric_fetching_errors.py index 72b9f31bb18..88b5efe0e86 100644 --- a/ax/analysis/healthcheck/tests/test_metric_fetching_errors.py +++ b/ax/analysis/healthcheck/tests/test_metric_fetching_errors.py @@ -128,6 +128,7 @@ def test_metric_fetching_errors_with_traceback(self) -> None: orchestrator.poll_and_process_results() self.assertEqual(len(exp._metric_fetching_errors), 1) # WHEN we compute MetricFetchingErrorsAnalysis with a traceback creator + # pyrefly: ignore [bad-instantiation] card = MetricFetchingErrorsAnalysis( add_traceback_paste_callable=create_dummy_traceback_pastes ).compute(experiment=exp) @@ -188,7 +189,9 @@ def test_metric_fetching_errors_without_traceback(self) -> None: ) orchestrator.poll_and_process_results() self.assertEqual(len(exp._metric_fetching_errors), 1) + # pyrefly: ignore [bad-instantiation] # WHEN we compute MetricFetchingErrorsAnalysis without a traceback creator + # pyrefly: ignore [bad-instantiation] card = MetricFetchingErrorsAnalysis().compute(experiment=exp) # THEN we get a card with a dataframe of errors self.assertEqual(len(card.df), 1) @@ -243,8 +246,10 @@ def test_error_order(self) -> None: options=OrchestratorOptions(), ) orchestrator.poll_and_process_results() + # pyrefly: ignore [bad-instantiation] self.assertEqual(len(exp._metric_fetching_errors), 2) # WHEN we compute MetricFetchingErrorsAnalysis + # pyrefly: ignore [bad-instantiation] card = MetricFetchingErrorsAnalysis().compute(experiment=exp) # THEN we get a cards in descending ts order self.assertEqual(len(card.df), 2) @@ -269,9 +274,11 @@ def test_error_gets_updated_for_same_metric(self) -> None: orchestrator.poll_and_process_results() original_ts = exp._metric_fetching_errors[(0, "test_metric")]["timestamp"] exp.trials[0].mark_running(no_runner_required=True, unsafe=True) + # pyrefly: ignore [bad-instantiation] orchestrator.poll_and_process_results() self.assertEqual(len(exp._metric_fetching_errors), 1) + # pyrefly: ignore [bad-instantiation] card = MetricFetchingErrorsAnalysis().compute(experiment=exp) self.assertEqual(len(card.df), 1) self.assertGreater(card.df["timestamp"].iloc[0], original_ts) @@ -335,9 +342,11 @@ def test_critical_metric_errors_returns_fail(self) -> None: for case in cases: with self.subTest(case=case["label"]): exp = get_branin_experiment(**case["exp_kwargs"]) + # pyrefly: ignore [bad-instantiation] for metric_name in case["error_metrics"]: exp._metric_fetching_errors[(0, metric_name)] = ( self._make_metric_fetching_error(0, metric_name) ) + # pyrefly: ignore [bad-instantiation] card = MetricFetchingErrorsAnalysis().compute(experiment=exp) self.assertEqual(card.get_status(), HealthcheckStatus.FAIL) diff --git a/ax/analysis/healthcheck/tests/test_no_effects_analysis.py b/ax/analysis/healthcheck/tests/test_no_effects_analysis.py index 8abf2832caf..5dce9606cea 100644 --- a/ax/analysis/healthcheck/tests/test_no_effects_analysis.py +++ b/ax/analysis/healthcheck/tests/test_no_effects_analysis.py @@ -24,6 +24,7 @@ def setUp(self) -> None: self.moo_experiment = get_branin_experiment_with_multi_objective( with_trial=True ) + # pyrefly: ignore [bad-instantiation] self.tone = TestOfNoEffectAnalysis() def test_effects_detected(self) -> None: diff --git a/ax/analysis/healthcheck/tests/test_regression_analysis.py b/ax/analysis/healthcheck/tests/test_regression_analysis.py index 01ca0bf20c3..8b6dc07958a 100644 --- a/ax/analysis/healthcheck/tests/test_regression_analysis.py +++ b/ax/analysis/healthcheck/tests/test_regression_analysis.py @@ -34,6 +34,7 @@ def test_regression_analysis(self) -> None: ) experiment.attach_data(Data(df=df)) + # pyrefly: ignore [bad-instantiation] ra = RegressionAnalysis(prob_threshold=0.90) card = ra.compute(experiment=experiment, generation_strategy=None) self.assertEqual(card.name, "RegressionAnalysis") @@ -54,6 +55,7 @@ def test_regression_analysis(self) -> None: } ) experiment.attach_data(Data(df=df)) + # pyrefly: ignore [bad-instantiation] ra = RegressionAnalysis(prob_threshold=0.90) card = ra.compute(experiment=experiment, generation_strategy=None) self.assertEqual(card.name, "RegressionAnalysis") diff --git a/ax/analysis/healthcheck/tests/test_search_space_analysis.py b/ax/analysis/healthcheck/tests/test_search_space_analysis.py index c1571852c3f..01602a8688f 100644 --- a/ax/analysis/healthcheck/tests/test_search_space_analysis.py +++ b/ax/analysis/healthcheck/tests/test_search_space_analysis.py @@ -30,6 +30,7 @@ def test_search_space_analysis(self) -> None: Arm(name="1_2", parameters={"x1": -5.0, "x2": 1.0}), ] experiment.new_batch_trial(generator_run=GeneratorRun(arms=arms)) + # pyrefly: ignore [bad-instantiation] ssa = SearchSpaceAnalysis(trial_index=0) card = ssa.compute(experiment=experiment) @@ -55,6 +56,7 @@ def test_search_space_analysis(self) -> None: Arm(name="2_2", parameters={"x1": -5.0, "x2": 2.0}), ] experiment.new_batch_trial(generator_run=GeneratorRun(arms=arms)) + # pyrefly: ignore [bad-instantiation] ssa = SearchSpaceAnalysis(trial_index=1) card = ssa.compute(experiment=experiment) self.assertEqual(card.name, "SearchSpaceAnalysis") @@ -67,6 +69,7 @@ def test_search_space_analysis(self) -> None: Arm(name="2_2", parameters={"x1": -5.0, "x2": 2.0}), ] experiment.new_batch_trial(generator_run=GeneratorRun(arms=arms)) + # pyrefly: ignore [bad-instantiation] ssa = SearchSpaceAnalysis(trial_index=2) card = ssa.compute(experiment=experiment) self.assertEqual(card.name, "SearchSpaceAnalysis") diff --git a/ax/analysis/healthcheck/tests/test_transfer_learning_analysis.py b/ax/analysis/healthcheck/tests/test_transfer_learning_analysis.py index d7926503985..b347b38a7d5 100644 --- a/ax/analysis/healthcheck/tests/test_transfer_learning_analysis.py +++ b/ax/analysis/healthcheck/tests/test_transfer_learning_analysis.py @@ -54,6 +54,7 @@ def test_no_experiment_type_returns_pass(self) -> None: """When no experiment_type is set and no experiment_types provided, return PASS.""" experiment = _make_experiment(["x1", "x2"], experiment_type=None) + # pyrefly: ignore [bad-instantiation] analysis = TransferLearningAnalysis() card = analysis.compute(experiment=experiment) self.assertEqual(card.get_status(), HealthcheckStatus.PASS) @@ -62,7 +63,9 @@ def test_no_experiment_type_returns_pass(self) -> None: @patch(_MOCK_TARGET, return_value={}) def test_no_candidates_returns_pass(self, mock_identify: object) -> None: + # pyrefly: ignore [bad-instantiation] experiment = _make_experiment(["x1", "x2"], experiment_type="my_type") + # pyrefly: ignore [bad-instantiation] analysis = TransferLearningAnalysis() card = analysis.compute(experiment=experiment) self.assertEqual(card.get_status(), HealthcheckStatus.PASS) @@ -77,8 +80,10 @@ def test_single_candidate_returns_warning(self, mock_identify: object) -> None: mock_identify.return_value = { # pyre-ignore[16] "source_exp": TransferLearningMetadata( overlap_parameters=["x1", "x2", "x3", "x4"], + # pyrefly: ignore [bad-instantiation] ), } + # pyrefly: ignore [bad-instantiation] analysis = TransferLearningAnalysis() card = analysis.compute(experiment=experiment) self.assertEqual(card.get_status(), HealthcheckStatus.WARNING) @@ -107,9 +112,11 @@ def test_multiple_candidates_preserves_order(self, mock_identify: object) -> Non overlap_parameters=["x1", "x2", "x3"], ), "exp_low": TransferLearningMetadata( + # pyrefly: ignore [bad-instantiation] overlap_parameters=["x1"], ), } + # pyrefly: ignore [bad-instantiation] analysis = TransferLearningAnalysis() card = analysis.compute(experiment=experiment) self.assertEqual(card.get_status(), HealthcheckStatus.WARNING) @@ -132,10 +139,12 @@ def test_multiple_candidates_preserves_order(self, mock_identify: object) -> Non def test_percentage_calculation(self, mock_identify: object) -> None: experiment = _make_experiment(["x1", "x2", "x3"], experiment_type="my_type") mock_identify.return_value = { # pyre-ignore[16] + # pyrefly: ignore [bad-instantiation] "exp_a": TransferLearningMetadata( overlap_parameters=["x1"], ), } + # pyrefly: ignore [bad-instantiation] analysis = TransferLearningAnalysis() card = analysis.compute(experiment=experiment) self.assertEqual(card.df.iloc[0]["Overlap (%)"], 33.3) @@ -145,18 +154,22 @@ def test_parameters_listed_alphabetically(self, mock_identify: object) -> None: experiment = _make_experiment( ["alpha", "beta", "gamma", "delta"], experiment_type="my_type" ) + # pyrefly: ignore [bad-instantiation] mock_identify.return_value = { # pyre-ignore[16] "exp_a": TransferLearningMetadata( overlap_parameters=["gamma", "alpha", "delta"], ), } + # pyrefly: ignore [bad-instantiation] analysis = TransferLearningAnalysis() card = analysis.compute(experiment=experiment) self.assertEqual(card.df.iloc[0]["Parameters"], "alpha, delta, gamma") def test_requires_experiment(self) -> None: + # pyrefly: ignore [bad-instantiation] analysis = TransferLearningAnalysis() with self.assertRaises(UserInputError): + # pyrefly: ignore [bad-instantiation] analysis.compute(experiment=None) @patch(_MOCK_TARGET, return_value={}) @@ -164,6 +177,7 @@ def test_experiment_name_passed_to_identify(self, mock_identify: object) -> None """Verify that experiment.name is forwarded to identify_transferable_experiments so it can filter the target out.""" experiment = _make_experiment(["x1", "x2", "x3"], experiment_type="my_type") + # pyrefly: ignore [bad-instantiation] analysis = TransferLearningAnalysis() analysis.compute(experiment=experiment) mock_identify.assert_called_once() # pyre-ignore[16] @@ -176,6 +190,7 @@ def test_diff_paste_callable_adds_comparison_column( ) -> None: """When create_diff_paste_callable is provided, a 'Comparison' column should be added alongside the existing 'Parameters' column.""" + # pyrefly: ignore [bad-instantiation] experiment = _make_experiment( ["x1", "x2", "x3", "x4"], experiment_type="my_type" ) @@ -184,6 +199,7 @@ def test_diff_paste_callable_adds_comparison_column( overlap_parameters=["x1", "x2", "x3"], ), } + # pyrefly: ignore [bad-instantiation] analysis = TransferLearningAnalysis( create_diff_paste_callable=_dummy_create_diff_paste, ) @@ -204,6 +220,7 @@ def test_diff_paste_callable_receives_correct_content( ) mock_identify.return_value = { # pyre-ignore[16] "source_exp": TransferLearningMetadata( + # pyrefly: ignore [bad-instantiation] overlap_parameters=["gamma", "alpha"], ), } @@ -213,6 +230,7 @@ def _capture_callable(before: str, after: str, title: str) -> str: captured_args.append((before, after, title)) return "https://example.com/diff" + # pyrefly: ignore [bad-instantiation] analysis = TransferLearningAnalysis( create_diff_paste_callable=_capture_callable, ) @@ -233,6 +251,7 @@ def _capture_callable(before: str, after: str, title: str) -> str: self.assertIn("source_exp", title) self.assertIn("test_experiment", title) + # pyrefly: ignore [bad-instantiation] @patch(_MOCK_TARGET) def test_no_callable_has_no_comparison_column(self, mock_identify: object) -> None: """Without callable, the 'Parameters' column should be present @@ -243,6 +262,7 @@ def test_no_callable_has_no_comparison_column(self, mock_identify: object) -> No overlap_parameters=["x1", "x2"], ), } + # pyrefly: ignore [bad-instantiation] analysis = TransferLearningAnalysis() card = analysis.compute(experiment=experiment) self.assertIn("Parameters", card.df.columns) diff --git a/ax/analysis/healthcheck/transfer_learning_analysis.py b/ax/analysis/healthcheck/transfer_learning_analysis.py index 325f3755a1b..5dddbec36d7 100644 --- a/ax/analysis/healthcheck/transfer_learning_analysis.py +++ b/ax/analysis/healthcheck/transfer_learning_analysis.py @@ -40,6 +40,7 @@ def _body_html(self, depth: int) -> str: @final +# pyrefly: ignore [bad-class-definition] class TransferLearningAnalysis(Analysis): def __init__( self, diff --git a/ax/analysis/overview.py b/ax/analysis/overview.py index a7e454c33df..81a927e98ec 100644 --- a/ax/analysis/overview.py +++ b/ax/analysis/overview.py @@ -189,6 +189,7 @@ def compute( ) health_check_analyses = [ + # pyrefly: ignore [bad-instantiation] MetricFetchingErrorsAnalysis(), ( EarlyStoppingAnalysis( @@ -199,6 +200,7 @@ def compute( if has_map_data and has_map_metrics and not has_batch_trials else None ), + # pyrefly: ignore [bad-instantiation] CanGenerateCandidatesAnalysis( can_generate_candidates=self.can_generate, reason=self.can_generate_reason, @@ -229,11 +231,13 @@ def compute( if not has_batch_trials else None, BaselineImprovementAnalysis() if not has_batch_trials else None, + # pyrefly: ignore [bad-instantiation] TransferLearningAnalysis( config=self.sqa_config, create_diff_paste_callable=self.create_diff_paste_callable, ), *[ + # pyrefly: ignore [bad-instantiation] SearchSpaceAnalysis(trial_index=trial.index) for trial in candidate_trials ], diff --git a/ax/analysis/plotly/tests/test_objective_p_feasible_frontier.py b/ax/analysis/plotly/tests/test_objective_p_feasible_frontier.py index 5e214230424..59f54571e35 100644 --- a/ax/analysis/plotly/tests/test_objective_p_feasible_frontier.py +++ b/ax/analysis/plotly/tests/test_objective_p_feasible_frontier.py @@ -37,6 +37,7 @@ def setUp(self) -> None: ) self.experiment.optimization_config = OptimizationConfig( objective=Objective(metric=self.experiment.metrics["branin_a"]), + # pyrefly: ignore [missing-attribute] outcome_constraints=self.experiment.optimization_config.outcome_constraints, ) opt_config = none_throws(self.experiment.optimization_config) @@ -136,6 +137,7 @@ def test_compute(self) -> None: for pruning in (False, True): target = Arm(parameters={"x1": 0.0, "x2": 0.0}) if pruning else None + # pyrefly: ignore [missing-attribute] self.experiment.optimization_config.pruning_target_parameterization = target adapter = Generators.BOTORCH_MODULAR( experiment=self.experiment, @@ -188,7 +190,9 @@ def test_validate_applicable_state(self) -> None: ), ) + # pyrefly: ignore [bad-argument-type] self.experiment.optimization_config = opt_config + # pyrefly: ignore [missing-attribute] opt_config.outcome_constraints = [] self.assertIn( "requires at least one outcome constraint.", @@ -198,6 +202,7 @@ def test_validate_applicable_state(self) -> None: ) ), ) + # pyrefly: ignore [missing-attribute] opt_config.outcome_constraints = [ OutcomeConstraint( expression="1.0*branin_b + 1.0*branin_c <= 10.0", @@ -219,6 +224,7 @@ def test_validate_applicable_state(self) -> None: # Restore valid constraints and verify type().__name__ renders correctly # for the adapter/generator type check metric_name = self.experiment.metrics["branin_b"].name + # pyrefly: ignore [missing-attribute] opt_config.outcome_constraints = [ OutcomeConstraint( expression=f"{metric_name} <= 10.0", diff --git a/ax/analysis/results.py b/ax/analysis/results.py index c579a189027..babae5ba26f 100644 --- a/ax/analysis/results.py +++ b/ax/analysis/results.py @@ -128,6 +128,7 @@ def compute( relativize = experiment.status_quo is not None and has_batch_trials # Compute both observed and modeled effects for each objective and constraint. arm_effect_pair_group = ( + # pyrefly: ignore [bad-instantiation] ArmEffectsPair( metric_names=[ *regression_objective_names, @@ -342,6 +343,7 @@ def compute( @final +# pyrefly: ignore [bad-class-definition] class ArmEffectsPair(Analysis): """ Compute two ArmEffectsPlots in a single AnalysisCardGroup, one plotting model diff --git a/ax/analysis/tests/test_diagnostics.py b/ax/analysis/tests/test_diagnostics.py index 07221d955c0..038b28b79c0 100644 --- a/ax/analysis/tests/test_diagnostics.py +++ b/ax/analysis/tests/test_diagnostics.py @@ -202,6 +202,7 @@ def test_compute_bandit(self) -> None: trial.add_arms_and_weights(arms=arms).mark_running(no_runner_required=True) for arm in trial.arms: + # pyrefly: ignore [bad-argument-type] x1, x2 = float(arm.parameters["x1"]), float(arm.parameters["x2"]) data_rows.append( { diff --git a/ax/analysis/tests/test_overview.py b/ax/analysis/tests/test_overview.py index dbd145f02c9..6ea1caec48a 100644 --- a/ax/analysis/tests/test_overview.py +++ b/ax/analysis/tests/test_overview.py @@ -230,6 +230,7 @@ def test_bandit_experiment_dispatch(self) -> None: # Generate data rows in same loop for arm in trial.arms: + # pyrefly: ignore [bad-argument-type] x1, x2 = float(arm.parameters["x1"]), float(arm.parameters["x2"]) data_rows.append( { diff --git a/ax/analysis/tests/test_results.py b/ax/analysis/tests/test_results.py index 9eb21a29efc..6b38aa78cea 100644 --- a/ax/analysis/tests/test_results.py +++ b/ax/analysis/tests/test_results.py @@ -404,6 +404,7 @@ def test_compute_with_bandit_experiment(self) -> None: "arm_name": arm.name, "metric_name": "foo", "metric_signature": "foo", + # pyrefly: ignore [bad-argument-type] "mean": float(arm.parameters["x1"]), "sem": 0.1, } @@ -633,6 +634,7 @@ def test_compute(self) -> None: ) with self.subTest("valid_experiment"): + # pyrefly: ignore [bad-instantiation] analysis = ArmEffectsPair(metric_names=["branin"]) card_group = analysis.compute( experiment=experiment, @@ -651,6 +653,7 @@ def test_compute(self) -> None: ) with self.subTest("requires_experiment"): + # pyrefly: ignore [bad-instantiation] analysis = ArmEffectsPair(metric_names=["test_metric"]) with self.assertRaisesRegex(UserInputError, "requires an Experiment"): analysis.compute() @@ -684,6 +687,7 @@ def test_compute_with_status_quo(self) -> None: ) with self.subTest("relativization"): + # pyrefly: ignore [bad-instantiation] analysis = ArmEffectsPair(metric_names=["branin"], relativize=True) card_group = analysis.compute( experiment=experiment, @@ -703,6 +707,7 @@ def test_compute_with_status_quo(self) -> None: experiment.attach_data(data2) trial2.mark_completed() + # pyrefly: ignore [bad-instantiation] analysis = ArmEffectsPair(metric_names=["branin"], trial_index=0) card_group = analysis.compute( experiment=experiment, diff --git a/ax/analysis/utils.py b/ax/analysis/utils.py index 79353958684..86b8a29b3e9 100644 --- a/ax/analysis/utils.py +++ b/ax/analysis/utils.py @@ -1051,6 +1051,7 @@ def _get_status_quo_df( col for col in df.columns if col.endswith("_mean") or col.endswith("_sem") ] status_quo_df = pd.DataFrame(status_quo_rows)[["trial_index", *all_metric_cols]] + # pyrefly: ignore [bad-return] return status_quo_df diff --git a/ax/api/client.py b/ax/api/client.py index f059d2dc0ac..e491919cc15 100644 --- a/ax/api/client.py +++ b/ax/api/client.py @@ -1279,8 +1279,10 @@ def _to_json_snapshot(self) -> dict[str, Any]: and self._storage_config.registry_bundle is not None ): encoder_registry = ( + # pyrefly: ignore [missing-attribute] self._storage_config.registry_bundle.sqa_config.json_encoder_registry ) + # pyrefly: ignore [missing-attribute] class_encoder_registry = self._storage_config.registry_bundle.sqa_config.json_class_encoder_registry # noqa: E501 else: encoder_registry = CORE_ENCODER_REGISTRY @@ -1312,9 +1314,11 @@ def _from_json_snapshot( # the core encoder registries. if storage_config is not None and storage_config.registry_bundle is not None: decoder_registry = ( + # pyrefly: ignore [missing-attribute] storage_config.registry_bundle.sqa_config.json_decoder_registry ) class_decoder_registry = ( + # pyrefly: ignore [missing-attribute] storage_config.registry_bundle.sqa_config.json_class_decoder_registry ) else: diff --git a/ax/api/utils/instantiation/tests/test_from_config.py b/ax/api/utils/instantiation/tests/test_from_config.py index 4dc33345913..61737195238 100644 --- a/ax/api/utils/instantiation/tests/test_from_config.py +++ b/ax/api/utils/instantiation/tests/test_from_config.py @@ -263,6 +263,7 @@ def test_create_choice_parameter(self) -> None: ) ) + # pyrefly: ignore [not-iterable] self.assertFalse(any("sort_values" in str(w.message) for w in ws)) def test_experiment_from_config(self) -> None: diff --git a/ax/api/utils/storage.py b/ax/api/utils/storage.py index 4389e85721d..81617d4a2c4 100644 --- a/ax/api/utils/storage.py +++ b/ax/api/utils/storage.py @@ -28,12 +28,17 @@ def db_settings_from_storage_config( encoder = bundle.encoder decoder = bundle.decoder else: + # pyrefly: ignore [not-callable] encoder = Encoder(config=SQAConfig()) + # pyrefly: ignore [not-callable] decoder = Decoder(config=SQAConfig()) + # pyrefly: ignore [not-callable] return DBSettings( creator=storage_config.creator, url=storage_config.url, + # pyrefly: ignore [bad-argument-type] encoder=encoder, + # pyrefly: ignore [bad-argument-type] decoder=decoder, ) diff --git a/ax/benchmark/benchmark_metric.py b/ax/benchmark/benchmark_metric.py index 39125682d58..e27fa586444 100644 --- a/ax/benchmark/benchmark_metric.py +++ b/ax/benchmark/benchmark_metric.py @@ -125,6 +125,7 @@ def __init__( """ super().__init__(name=name, lower_is_better=lower_is_better) # Declare `lower_is_better` as bool (rather than optional as in the base class) + # pyrefly: ignore [bad-override-mutable-attribute] self.lower_is_better: bool = lower_is_better self.observe_noise_sd: bool = observe_noise_sd diff --git a/ax/benchmark/benchmark_result.py b/ax/benchmark/benchmark_result.py index aa3720c1f79..32c00cafb0a 100644 --- a/ax/benchmark/benchmark_result.py +++ b/ax/benchmark/benchmark_result.py @@ -179,6 +179,7 @@ def from_benchmark_results( trace_stats = {} for name in ("optimization_trace", "score_trace"): step_data = zip(*(getattr(res, name) for res in results)) + # pyrefly: ignore [bad-argument-type] stats = _get_stats(step_data=step_data, percentiles=PERCENTILES) trace_stats[name] = stats @@ -218,5 +219,6 @@ def _get_stats( stats["mean"].append(nanmean(step_vals)) stats["sem"].append(sem(step_vals, ddof=1, nan_policy="propagate")) quantiles.append(nanquantile(step_vals, q=percentiles)) + # pyrefly: ignore [no-matching-overload] stats.update({f"P{100 * p:.0f}": q for p, q in zip(percentiles, zip(*quantiles))}) return stats diff --git a/ax/benchmark/noise.py b/ax/benchmark/noise.py index 912f10562f5..3359535f524 100644 --- a/ax/benchmark/noise.py +++ b/ax/benchmark/noise.py @@ -185,6 +185,7 @@ def _get_noise_and_sem( sem = noise_std_ser noise = np.random.normal(loc=0, scale=sem) + # pyrefly: ignore [bad-return] return noise, sem.to_numpy() diff --git a/ax/benchmark/problems/data.py b/ax/benchmark/problems/data.py index cf48bf124cc..a5dbaa55b04 100644 --- a/ax/benchmark/problems/data.py +++ b/ax/benchmark/problems/data.py @@ -120,6 +120,7 @@ def _fetch_and_cache(self) -> pd.DataFrame: pd.DataFrame: The downloaded parquet data. """ # Download the data from the URL + # pyrefly: ignore [bad-argument-type] data = pd.read_parquet(self.url, engine="pyarrow") # Create the cache directory if needed self.cache_path.parent.mkdir(parents=True, exist_ok=True) diff --git a/ax/benchmark/problems/synthetic/from_botorch.py b/ax/benchmark/problems/synthetic/from_botorch.py index b41b43d9c8f..3d5a648a873 100644 --- a/ax/benchmark/problems/synthetic/from_botorch.py +++ b/ax/benchmark/problems/synthetic/from_botorch.py @@ -217,6 +217,7 @@ def create_problem_from_botorch( num_constraints = test_problem.num_constraints if is_constrained else 0 if isinstance(test_problem, MultiObjectiveTestProblem): + # pyrefly: ignore [bad-argument-type] objective_names = [f"{name}_{i}" for i in range(n_obj)] else: objective_names = [name] diff --git a/ax/benchmark/testing/benchmark_stubs.py b/ax/benchmark/testing/benchmark_stubs.py index 60b01e9c7dc..d80f7792f72 100644 --- a/ax/benchmark/testing/benchmark_stubs.py +++ b/ax/benchmark/testing/benchmark_stubs.py @@ -197,6 +197,7 @@ def get_aggregated_benchmark_result() -> AggregatedBenchmarkResult: @dataclass(kw_only=True) class DummyTestFunction(BenchmarkTestFunction): + # pyrefly: ignore [bad-override-mutable-attribute] outcome_names: list[str] = field(default_factory=list) num_outcomes: int = 1 dim: int = 6 diff --git a/ax/benchmark/tests/test_benchmark.py b/ax/benchmark/tests/test_benchmark.py index 7eb2e0180bc..fdf666e9c34 100644 --- a/ax/benchmark/tests/test_benchmark.py +++ b/ax/benchmark/tests/test_benchmark.py @@ -604,6 +604,7 @@ def test_early_stopping(self) -> None: for trial_index, sub_df in grouped: self.assertEqual( sub_df["step"].tolist(), + # pyrefly: ignore [bad-index] list(range(expected_n_steps[trial_index])), msg=f"Trial {trial_index}", ) @@ -673,6 +674,7 @@ def test_replication_variable_runtime(self) -> None: with self.subTest(map_data=map_data): problem = get_async_benchmark_problem( map_data=map_data, + # pyrefly: ignore [unsupported-operation] step_runtime_fn=lambda params: params["x0"] + 1, ) experiment = self.run_optimization_with_orchestrator( @@ -1257,6 +1259,7 @@ def test_get_opt_trace_by_cumulative_epochs(self) -> None: n_steps=2, # Ensure we don't have two finishing at the same time, for # determinism + # pyrefly: ignore [unsupported-operation] step_runtime_fn=lambda params: params["x0"] * (1 - 0.01 * params["x0"]), ) method = get_async_benchmark_method() @@ -1291,6 +1294,7 @@ def test_get_opt_trace_by_cumulative_epochs(self) -> None: num_objectives=2, # Ensure we don't have two finishing at the same time, for # determinism + # pyrefly: ignore [unsupported-operation] step_runtime_fn=lambda params: params["x0"] * (1 - 0.01 * params["x0"]), ) experiment = self.run_optimization_with_orchestrator( @@ -1331,6 +1335,7 @@ def test_get_opt_trace_by_cumulative_epochs(self) -> None: num_constraints=1, # Ensure we don't have two finishing at the same time, for # determinism + # pyrefly: ignore [unsupported-operation] step_runtime_fn=lambda params: params["x0"] * (1 - 0.01 * params["x0"]), ) experiment = self.run_optimization_with_orchestrator( @@ -1370,6 +1375,7 @@ def test_get_benchmark_result_with_cumulative_steps(self) -> None: n_steps=2, # Ensure we don't have two finishing at the same time, for # determinism + # pyrefly: ignore [unsupported-operation] step_runtime_fn=lambda params: params["x0"] * (1 - 0.01 * params["x0"]), ) method = get_async_benchmark_method() diff --git a/ax/benchmark/tests/test_benchmark_metric.py b/ax/benchmark/tests/test_benchmark_metric.py index 7bd2ebd2e26..c6b46321fa3 100644 --- a/ax/benchmark/tests/test_benchmark_metric.py +++ b/ax/benchmark/tests/test_benchmark_metric.py @@ -278,6 +278,7 @@ def _test_fetch_trial_multiple_time_steps_with_simulator(self, batch: bool) -> N has_simulator=has_simulator, ) data = metric.fetch_trial_data(trial=trial).value + # pyrefly: ignore [missing-attribute] df_or_map_df = data.full_df if isinstance(metric, MapMetric) else data.df returns_full_data = (not has_simulator) and isinstance(metric, MapMetric) self.assertEqual( @@ -327,9 +328,11 @@ def _test_fetch_trial_multiple_time_steps_with_simulator(self, batch: bool) -> N self.assertEqual(backend_simulator.time, 2) data = metric.fetch_trial_data(trial=trial).value if isinstance(metric, MapMetric): + # pyrefly: ignore [missing-attribute] full_df = data.full_df self.assertEqual(len(full_df), 2 * len(trial.arms)) self.assertEqual(set(full_df["step"].tolist()), {0, 1}) + # pyrefly: ignore [missing-attribute] df = data.df self.assertEqual(len(df), len(trial.arms)) expected_df = _get_one_step_df( diff --git a/ax/benchmark/tests/test_benchmark_runner.py b/ax/benchmark/tests/test_benchmark_runner.py index a07ad5666e1..6431088a1b3 100644 --- a/ax/benchmark/tests/test_benchmark_runner.py +++ b/ax/benchmark/tests/test_benchmark_runner.py @@ -50,6 +50,7 @@ def test_simulated_backend_runner(self) -> None: # Initialize runner = BenchmarkRunner( test_function=Jenatton(outcome_names=["objective"]), + # pyrefly: ignore [unsupported-operation] step_runtime_function=lambda params: params["x1"] + 1, max_concurrency=2, ) @@ -238,7 +239,7 @@ def test_runner(self) -> None: nullcontext() if not isinstance(test_function, SurrogateTestFunction) else patch.object( - runner.test_function._surrogate, + runner.test_function._surrogate, # pyrefly: ignore [missing-attribute] "predict", return_value=({"branin": [4.2]}, None), ) @@ -416,6 +417,7 @@ def test_heterogeneous_step_runtime(self) -> None: runner = BenchmarkRunner( test_function=test_function, noise=GaussianNoise(noise_std=0.0), + # pyrefly: ignore [bad-argument-type] step_runtime_function=lambda params: params["x0"], ) experiment = Experiment( diff --git a/ax/core/__init__.py b/ax/core/__init__.py index 29af7b135e5..ea9dbb9b2e9 100644 --- a/ax/core/__init__.py +++ b/ax/core/__init__.py @@ -63,6 +63,7 @@ "RangeParameter", "Runner", "SearchSpace", + # pyrefly: ignore [bad-dunder-all] "SimpleExperiment", "Trial", ] diff --git a/ax/core/auxiliary.py b/ax/core/auxiliary.py index 52f463bf799..95118e1e9f8 100644 --- a/ax/core/auxiliary.py +++ b/ax/core/auxiliary.py @@ -42,6 +42,7 @@ def __init__( self.data: Data = data or experiment.lookup_data() self.is_active = is_active + # pyrefly: ignore [bad-override] def _unique_id(self) -> str: # While there can be multiple `AuxiliarySource`-s made from the same # experiment (and thus sharing the experiment name), the uniqueness diff --git a/ax/core/data.py b/ax/core/data.py index dc9b84650de..b17dfa1f818 100644 --- a/ax/core/data.py +++ b/ax/core/data.py @@ -323,6 +323,7 @@ def deserialize_init_args( if "df" in args and not isinstance(args["df"], pd.DataFrame): # NOTE: Need dtype=False, otherwise infers arm_names like # "4_1" should be int 41. + # pyrefly: ignore [no-matching-overload] args["df"] = pd.read_json(StringIO(args["df"]["value"]), dtype=False) return extract_init_args(args=args, class_=cls) @@ -379,7 +380,9 @@ def from_multiple_data(cls, data: Iterable[Data]) -> Data: def __repr__(self) -> str: """String representation of the subclass, inheriting from this base.""" df_markdown = self.df.to_markdown() + # pyrefly: ignore [bad-argument-type] if len(df_markdown) > DF_REPR_MAX_LENGTH: + # pyrefly: ignore [unsupported-operation] df_markdown = df_markdown[:DF_REPR_MAX_LENGTH] + "..." return f"{self.__class__.__name__}(df=\n{df_markdown})" @@ -429,6 +432,7 @@ def clone(self) -> Data: """Returns a new Data object with the same underlying dataframe.""" return self.__class__(df=deepcopy(self.full_df)) + # pyrefly: ignore [bad-override] def __eq__(self, o: Data) -> bool: return type(self) is type(o) and dataframe_equals(self.full_df, o.full_df) @@ -741,6 +745,7 @@ def relativize_dataframe( # metrics and pass through raw data (with or without SQ row). if metric_names_set is not None and "metric_name" in grp_cols: grp_metric = ( + # pyrefly: ignore [bad-index] grp if isinstance(grp, str) else grp[grp_cols.index("metric_name")] ) if grp_metric not in metric_names_set: @@ -791,6 +796,7 @@ def relativize_dataframe( df_rel.loc[sq_mask, "sem"] = 0.0 df_rel.reset_index(inplace=True, drop=True) # Reorder columns to match expected order (reuses Data class logic) + # pyrefly: ignore [bad-argument-type] df_rel = Data._get_df_with_cols_in_expected_order(df_rel) return df_rel @@ -865,6 +871,7 @@ def _subsample_one_metric( else: filtered_df = df_g.iloc[:: int(derived_keep_every)] filtered_dfs.append(filtered_df) + # pyrefly: ignore [bad-return] return pd.concat(filtered_dfs) diff --git a/ax/core/experiment.py b/ax/core/experiment.py index d73a9187488..bc232d201a9 100644 --- a/ax/core/experiment.py +++ b/ax/core/experiment.py @@ -1603,6 +1603,7 @@ def extract_relevant_trials( return filtered_by_status + # pyrefly: ignore [not-callable] @retry_on_exception(retries=3, no_retry_on_exception_types=NO_RETRY_EXCEPTIONS) def stop_trial_runs( self, trials: list[BaseTrial], reasons: list[str | None] | None = None diff --git a/ax/core/experiment_status.py b/ax/core/experiment_status.py index 25a11c81bcd..0a08e46b39f 100644 --- a/ax/core/experiment_status.py +++ b/ax/core/experiment_status.py @@ -69,6 +69,7 @@ def is_completed(self) -> bool: """True if experiment has successfully completed.""" return self == ExperimentStatus.COMPLETED + # pyrefly: ignore [bad-override-param-name] def __format__(self, fmt: str) -> str: """Define `__format__` to avoid pulling the `__format__` from the `int` mixin (since its better for statuses to show up as `DRAFT` than as diff --git a/ax/core/multi_type_experiment.py b/ax/core/multi_type_experiment.py index 7003c29167d..b7e0ce0b1dd 100644 --- a/ax/core/multi_type_experiment.py +++ b/ax/core/multi_type_experiment.py @@ -312,8 +312,10 @@ def metric_to_trial_type(self) -> dict[str, str]: """ opt_config_types = { metric_name: self.default_trial_type + # pyrefly: ignore [missing-attribute] for metric_name in self.optimization_config.metric_names } + # pyrefly: ignore [bad-return] return {**opt_config_types, **self._metric_to_trial_type} # -- Overridden functions from Base Experiment Class -- diff --git a/ax/core/observation_utils.py b/ax/core/observation_utils.py index 3541b5a20f3..9e75399b7aa 100644 --- a/ax/core/observation_utils.py +++ b/ax/core/observation_utils.py @@ -64,6 +64,7 @@ def _observations_from_dataframe( # a feature column is filled with NaN / NaT values. for g, d in df.groupby(by=cols, dropna=False): obs_kwargs = {} + # pyrefly: ignore [bad-argument-type] features = dict(zip(cols, g, strict=True)) arm_name = features["arm_name"] trial_index = features.get("trial_index", None) @@ -93,9 +94,11 @@ def _observations_from_dataframe( obs_parameters = experiment.arms_by_name[arm_name].parameters.copy() if obs_parameters: + # pyrefly: ignore [unsupported-operation] obs_kwargs["parameters"] = obs_parameters for f, val in features.items(): if f in OBS_KWARGS and not pd.isna(val): + # pyrefly: ignore [unsupported-operation] obs_kwargs[f] = val # add start and end time of trial if the start and end time # is the same for all metrics and arms @@ -103,6 +106,7 @@ def _observations_from_dataframe( if col in d.columns: times = d[col] if times.nunique() == 1 and not times.isnull().any(): + # pyrefly: ignore [unsupported-operation] obs_kwargs[col] = times.iloc[0] if is_map_data: @@ -119,6 +123,7 @@ def _observations_from_dataframe( continue observations.append( Observation( + # pyrefly: ignore [bad-argument-type] features=ObservationFeatures(**obs_kwargs), data=ObservationData( metric_signatures=d["metric_signature"].tolist(), @@ -179,6 +184,7 @@ def _filter_data_on_status( f"metric {metric_signature}." ) continue + # pyrefly: ignore [bad-index] metric = experiment.metrics[metric_signature_to_name[metric_signature]] statuses_to_include_metric = ( statuses_to_include_map_metric diff --git a/ax/core/parameter.py b/ax/core/parameter.py index 1e73f1720f0..6b55163841f 100644 --- a/ax/core/parameter.py +++ b/ax/core/parameter.py @@ -383,12 +383,15 @@ def __init__( self._log_scale = log_scale self._logit_scale = logit_scale self._is_fidelity = is_fidelity + # pyrefly: ignore [bad-override-mutable-attribute] self._target_value: TNumeric | None = ( self.cast(target_value) if target_value is not None else None ) + # pyrefly: ignore [bad-override-mutable-attribute] self._backfill_value: TNumeric | None = ( self.cast(backfill_value) if backfill_value is not None else None ) + # pyrefly: ignore [bad-override-mutable-attribute] self._default_value: TNumeric | None = ( self.cast(default_value) if default_value is not None else None ) @@ -656,6 +659,7 @@ def clone(self) -> RangeParameter: def cast(self, value: TParamValue) -> TNumeric: value = super().cast(value=value) if self.parameter_type is ParameterType.FLOAT and self._digits is not None: + # pyrefly: ignore [bad-argument-type] return round(float(value), none_throws(self._digits)) return assert_is_instance(value, TNumeric) @@ -822,6 +826,7 @@ def __init__( ) # Check that all values are positive for value in self._values: + # pyrefly: ignore [bad-argument-type] if float(value) <= 0: raise UserInputError( f"log_scale requires all values to be positive. " @@ -912,6 +917,7 @@ def _get_default_log_scale( if len(values) < 3: # Need at least 3 values to detect a pattern. return False + # pyrefly: ignore [bad-argument-type] vals = [float(v) for v in values] # refine type. if any(v <= 0.0 for v in vals): # All values must be positive. @@ -1450,7 +1456,9 @@ def compute(self, parameters: TParameterization) -> TParamValue: self._intercept + sum( self._parameter_names_to_weights[parameter_name] - * float(parameters[parameter_name]) + * float( + parameters[parameter_name] # pyrefly: ignore [bad-argument-type] + ) # pyrefly: ignore [bad-argument-type] for parameter_name in self._parameter_names_to_weights ) ) diff --git a/ax/core/runner.py b/ax/core/runner.py index 3f55c638c5c..854f1478c49 100644 --- a/ax/core/runner.py +++ b/ax/core/runner.py @@ -287,6 +287,7 @@ def clone(self) -> Self: **cls.deserialize_init_args(args=cls.serialize_init_args(obj=self)), ) + # pyrefly: ignore [bad-override] def __eq__(self, other: Runner) -> bool: same_class = self.__class__ == other.__class__ same_init_args = self.serialize_init_args( diff --git a/ax/core/search_space.py b/ax/core/search_space.py index 54c34142250..c62dd9bf210 100644 --- a/ax/core/search_space.py +++ b/ax/core/search_space.py @@ -748,17 +748,21 @@ def compute_naive_center(self) -> TParameterization: center = (float(p.lower) + float(p.upper)) / 2.0 parameters[name] = p.cast(center) elif isinstance(p, ChoiceParameter): + # pyrefly: ignore [unsupported-operation] parameters[name] = p.values[int(len(p.values) / 2)] elif isinstance(p, FixedParameter): + # pyrefly: ignore [unsupported-operation] parameters[name] = p.value elif isinstance(p, DerivedParameter): derived_params.append(p) else: raise NotImplementedError(f"Parameter type {type(p)} is not supported.") for p in derived_params: + # pyrefly: ignore [bad-argument-type, unsupported-operation] parameters[p.name] = p.compute(parameters=parameters) if self.is_hierarchical: parameters = self._cast_parameterization(parameters=parameters) + # pyrefly: ignore [bad-return] return parameters def compute_chebyshev_center(self) -> dict[str, float] | None: diff --git a/ax/core/tests/test_auxiliary_source.py b/ax/core/tests/test_auxiliary_source.py index 6a0bc21f194..5a29a05da61 100644 --- a/ax/core/tests/test_auxiliary_source.py +++ b/ax/core/tests/test_auxiliary_source.py @@ -294,6 +294,7 @@ def test_map_observations_and_experiment_data(self) -> None: arm_name="ood", ) ) + # pyrefly: ignore [bad-argument-type] rval = float(observations[0].features.parameters["rp1"]) self.assertEqual( set(observations[0].features.parameters.keys()), diff --git a/ax/core/tests/test_batch_trial.py b/ax/core/tests/test_batch_trial.py index fbceaf1d27d..a83ee53dd41 100644 --- a/ax/core/tests/test_batch_trial.py +++ b/ax/core/tests/test_batch_trial.py @@ -359,6 +359,7 @@ def test_AbandonArm(self) -> None: # Fail to abandon arm not in BatchTrial with self.assertRaises(ValueError): self.batch.mark_arm_abandoned( + # pyrefly: ignore [bad-argument-type] Arm(parameters={"x": 3, "y": "fooz", "z": False}) ) @@ -574,7 +575,9 @@ def test_get_candidate_metadata_from_all_generator_runs(self) -> None: self.batch.add_generator_run(gr_2) gr_2 = self.batch._generator_runs[-1] # gr_2 has no candidate metadata; all candidate metadata should come from gr_1 + # pyrefly: ignore [unsupported-operation] cand_metadata_expected = { + # pyrefly: ignore [unsupported-operation] a.name: gr_1.candidate_metadata_by_arm_signature[a.signature] for a in gr_1.arms } @@ -600,8 +603,10 @@ def test_get_candidate_metadata_from_all_generator_runs(self) -> None: gr_3._candidate_metadata_by_arm_signature = new_cand_metadata self.batch.add_generator_run(gr_3) gr_3 = self.batch._generator_runs[-1] + # pyrefly: ignore [unsupported-operation] cand_metadata_expected.update( { + # pyrefly: ignore [unsupported-operation] a.name: gr_1.candidate_metadata_by_arm_signature[a.signature] for a in gr_1.arms } diff --git a/ax/core/tests/test_experiment.py b/ax/core/tests/test_experiment.py index 31432d5a3ef..c9935da0184 100644 --- a/ax/core/tests/test_experiment.py +++ b/ax/core/tests/test_experiment.py @@ -153,6 +153,7 @@ def test_experiment_init(self) -> None: def test_experiment_name(self) -> None: self.assertTrue(self.experiment.has_name) + # pyrefly: ignore [bad-argument-type] self.experiment.name = None self.assertFalse(self.experiment.has_name) with self.assertRaises(ValueError): @@ -490,7 +491,9 @@ def test_add_derived_parameter_to_search_space_with_trials(self) -> None: # Verify "w" exists in the arm parameters self.assertIn("w", arm.parameters) w_value = arm.parameters["w"] + # pyrefly: ignore [unsupported-operation] # Compute expected derived value + # pyrefly: ignore [unsupported-operation] expected_d_value = 2.0 * w_value + 1.0 # The derived parameter value should be computed correctly self.assertEqual( @@ -569,17 +572,25 @@ def test_optimization_config_setter(self) -> None: with self.assertRaisesRegex(ValueError, "not found on experiment"): self.experiment.optimization_config = new_opt_config + # pyrefly: ignore [missing-attribute] + def test_status_quo_setter(self) -> None: + # pyrefly: ignore [missing-attribute] sq_parameters = self.experiment.status_quo.parameters + # pyrefly: ignore [missing-attribute] # Verify normal update when no trials exist + # pyrefly: ignore [missing-attribute] sq_parameters["w"] = 3.5 self.experiment.status_quo = Arm(sq_parameters) + # pyrefly: ignore [missing-attribute] self.assertEqual(self.experiment.status_quo.parameters["w"], 3.5) + # pyrefly: ignore [missing-attribute] self.assertEqual(self.experiment.status_quo.name, "status_quo_e0") # Verify all None values self.experiment.status_quo = Arm(dict.fromkeys(sq_parameters)) + # pyrefly: ignore [missing-attribute] self.assertIsNone(self.experiment.status_quo.parameters["w"]) # Switch back to sq with values @@ -611,16 +622,20 @@ def test_status_quo_setter(self) -> None: sq_parameters["w"] = 3.7 with self.assertRaises(UnsupportedError) as e: self.experiment.status_quo = Arm(sq_parameters) + # pyrefly: ignore [missing-attribute] self.assertIn( "Modifications of status_quo are disabled after trials have been created", str(e.exception), ) + # pyrefly: ignore [missing-attribute] # Verify status_quo wasn't changed + # pyrefly: ignore [missing-attribute] self.assertEqual(self.experiment.status_quo.parameters["w"], 3.5) def test_register_arm(self) -> None: # Create a new arm, register on experiment + # pyrefly: ignore [missing-attribute] parameters = self.experiment.status_quo.parameters parameters["w"] = 3.5 arm = Arm(name="my_arm_name", parameters=parameters) @@ -1230,6 +1245,7 @@ def test_attach_single_arm_trial_with_arm_name(self) -> None: _, trial_index = self.experiment.attach_trial( parameterizations=[{"w": 5.3, "x": 5, "y": "baz", "z": True, "d": 11.6}], arm_names=["arm1"], + # pyrefly: ignore [missing-attribute] ttl_seconds=3600, run_metadata={"test_metadata_field": 1}, ) @@ -1238,6 +1254,7 @@ def test_attach_single_arm_trial_with_arm_name(self) -> None: self.assertEqual(type(self.experiment.trials[trial_index]), Trial) self.assertEqual( "arm1", + # pyrefly: ignore [missing-attribute] self.experiment.trials[trial_index].arm.name, ) @@ -1282,6 +1299,7 @@ def test_prefer_lookup_where_possible( exp.fetch_data() # 1. No completed trials => no fetch case. mock_bulk_fetch_experiment_data.reset_mock() + # pyrefly: ignore [missing-attribute] dat = exp.fetch_data() mock_bulk_fetch_experiment_data.assert_not_called() # Data should be empty since there are no completed trials. @@ -1291,6 +1309,7 @@ def test_prefer_lookup_where_possible( mock_bulk_fetch_experiment_data.reset_mock() # pyre-fixme[16]: Optional type has no attribute `mark_completed`. exp.trials.get(0).mark_completed() + # pyrefly: ignore [missing-attribute] exp.trials.get(1).mark_completed() dat = exp.fetch_data() # `bulk_fetch_experiment_data` should be called N=number of trials times. @@ -1350,6 +1369,7 @@ def test_warm_start_from_old_experiment(self) -> None: # check that all non-failed trials are copied to new_experiment new_experiment = get_branin_experiment() + # pyrefly: ignore [missing-attribute] # make metric noiseless for exact reproducibility _obj_name = none_throws( new_experiment.optimization_config @@ -1360,7 +1380,9 @@ def test_warm_start_from_old_experiment(self) -> None: for _, trial in old_experiment.trials.items(): trial._run_metadata = DUMMY_RUN_METADATA # name one arm to test name-preserving logic. + # pyrefly: ignore [missing-attribute] old_experiment.trials[0].arm._name = DUMMY_ARM_NAME + # pyrefly: ignore [missing-attribute] new_experiment.warm_start_from_old_experiment( old_experiment=old_experiment, ) @@ -1372,8 +1394,10 @@ def test_warm_start_from_old_experiment(self) -> None: # pyre-fixme[16]: `BaseTrial` has no attribute `arm`. old_arm = old_experiment.trials[i_old_trial].arm self.assertEqual( + # pyrefly: ignore [missing-attribute] trial.arm.parameters, old_arm.parameters, + # pyrefly: ignore [missing-attribute] ) self.assertRegex( trial._properties["source"], "Warm start.*Experiment.*trial" @@ -1383,10 +1407,13 @@ def test_warm_start_from_old_experiment(self) -> None: # Check naming logic. if idx == 0: + # pyrefly: ignore [missing-attribute] self.assertEqual(trial.arm.name, DUMMY_ARM_NAME) else: self.assertEqual( - trial.arm.name, f"{old_arm.name}_{old_experiment.name}" + # pyrefly: ignore [missing-attribute] + trial.arm.name, + f"{old_arm.name}_{old_experiment.name}", ) # Check that the data was attached for correct trials @@ -1687,9 +1714,11 @@ def test_metric_summary_df_scalarized_objective(self) -> None: tracking_metrics=[ Metric(name="metric_a", lower_is_better=False), Metric(name="metric_b", lower_is_better=True), + # pyrefly: ignore [missing-attribute] ], ) df = experiment.metric_config_summary_df + # pyrefly: ignore [missing-attribute] # metric_a has positive weight -> maximize # metric_b has negative weight -> minimize goal_by_name = dict(zip(df["Name"], df["Goal"])) @@ -1702,9 +1731,11 @@ def test_arms_by_signature_for_deduplication(self) -> None: arm = Arm({"w": 1, "x": 2, "y": "foo", "z": True}) trial.add_arm(arm) expected_with_failed = { + # pyrefly: ignore [missing-attribute] experiment.status_quo.signature: experiment.status_quo, } expected_with_other = { + # pyrefly: ignore [missing-attribute] experiment.status_quo.signature: experiment.status_quo, arm.signature: arm, } @@ -2341,6 +2372,7 @@ def test_warm_start_map_data(self) -> None: # check that all non-failed trials are copied to new_experiment new_experiment = get_branin_experiment_with_timestamp_map_metric() # make metric noiseless for exact reproducibility + # pyrefly: ignore [missing-attribute] _obj_name = none_throws( new_experiment.optimization_config ).objective.metric_names[0] @@ -2358,6 +2390,7 @@ def test_warm_start_map_data(self) -> None: self.assertEqual( # pyre-fixme[16]: `BaseTrial` has no attribute `arm`. trial.arm.parameters, + # pyrefly: ignore [missing-attribute] old_experiment.trials[i_old_trial].arm.parameters, ) self.assertRegex( diff --git a/ax/core/tests/test_multi_type_experiment.py b/ax/core/tests/test_multi_type_experiment.py index b314cfd5b75..1b2c2ecd843 100644 --- a/ax/core/tests/test_multi_type_experiment.py +++ b/ax/core/tests/test_multi_type_experiment.py @@ -182,6 +182,7 @@ def test_add_tracking_metrics(self) -> None: BraninMetric("m5_default_type", ["x1", "x2"]), ] self.experiment.add_tracking_metrics( + # pyrefly: ignore [bad-argument-type] metrics=type1_metrics + type2_metrics + default_type_metrics, metrics_to_trial_types={ "m3_type1": "type1", @@ -235,15 +236,22 @@ def setUp(self) -> None: def test_filter_trials_by_type(self) -> None: trials = self.experiment.trials.values() + # pyrefly: ignore [bad-argument-type] self.assertEqual(len(trials), 2) + # pyrefly: ignore [bad-argument-type] filtered = filter_trials_by_type(trials, trial_type="type1") + # pyrefly: ignore [bad-argument-type] self.assertEqual(len(filtered), 1) self.assertEqual(filtered[0].trial_type, "type1") + # pyrefly: ignore [bad-argument-type] filtered = filter_trials_by_type(trials, trial_type="type2") self.assertEqual(len(filtered), 1) + # pyrefly: ignore [bad-argument-type] self.assertEqual(filtered[0].trial_type, "type2") + # pyrefly: ignore [bad-argument-type] filtered = filter_trials_by_type(trials, trial_type="invalid") self.assertEqual(len(filtered), 0) + # pyrefly: ignore [bad-argument-type] filtered = filter_trials_by_type(trials, trial_type=None) self.assertEqual(len(filtered), 2) diff --git a/ax/core/tests/test_observation.py b/ax/core/tests/test_observation.py index abbc19cea32..c88cd886149 100644 --- a/ax/core/tests/test_observation.py +++ b/ax/core/tests/test_observation.py @@ -107,7 +107,9 @@ def test_UpdateFeatures(self) -> None: new_obsf = ObservationFeatures( parameters=new_parameters, trial_index=4, + # pyrefly: ignore [bad-argument-type] start_time=pd.Timestamp("2005-02-25"), + # pyrefly: ignore [bad-argument-type] end_time=pd.Timestamp("2005-02-26"), ) obsf.update_features(new_obsf) diff --git a/ax/core/tests/test_optimization_config.py b/ax/core/tests/test_optimization_config.py index 76c5e196035..5f961ef9a56 100644 --- a/ax/core/tests/test_optimization_config.py +++ b/ax/core/tests/test_optimization_config.py @@ -374,13 +374,16 @@ def test_Init(self) -> None: # construct constraints with objective_thresholds: config3 = MultiObjectiveOptimizationConfig( objective=self.multi_objective, + # pyrefly: ignore [bad-argument-type] objective_thresholds=self.objective_thresholds, ) self.assertEqual(config3.all_constraints, self.objective_thresholds) # objective_thresholds and outcome constraints together. config4 = MultiObjectiveOptimizationConfig( + # pyrefly: ignore [bad-argument-type] objective=self.multi_objective, + # pyrefly: ignore [bad-argument-type] objective_thresholds=self.objective_thresholds, outcome_constraints=[self.m3_constraint], ) @@ -391,8 +394,10 @@ def test_Init(self) -> None: self.assertEqual(config4.objective_thresholds, self.objective_thresholds) # verify relative_objective_thresholds works: + # pyrefly: ignore [bad-argument-type] config5 = MultiObjectiveOptimizationConfig( objective=self.multi_objective, + # pyrefly: ignore [bad-argument-type] objective_thresholds=self.relative_objective_thresholds, ) threshold = config5.objective_thresholds[0] @@ -558,9 +563,11 @@ def test_ConstraintAgainstOptimizationDirection(self) -> None: bound=100.0, relative=False, ) + # pyrefly: ignore [bad-argument-type] config = MultiObjectiveOptimizationConfig( objective=self.multi_objective, outcome_constraints=[lower_bound_on_m1], + # pyrefly: ignore [bad-argument-type] objective_thresholds=self.objective_thresholds, ) self.assertEqual(config.outcome_constraints, [lower_bound_on_m1]) @@ -580,9 +587,11 @@ def test_Clone(self) -> None: self.assertEqual(config1.outcome_constraints, cloned1.outcome_constraints) cloned1_moo = assert_is_instance(cloned1, MultiObjectiveOptimizationConfig) self.assertEqual(config1.objective_thresholds, cloned1_moo.objective_thresholds) + # pyrefly: ignore [bad-argument-type] config2 = MultiObjectiveOptimizationConfig( objective=self.multi_objective, + # pyrefly: ignore [bad-argument-type] objective_thresholds=self.objective_thresholds, ) cloned2 = config2.clone() @@ -591,9 +600,12 @@ def test_Clone(self) -> None: cloned2_moo = assert_is_instance(cloned2, MultiObjectiveOptimizationConfig) self.assertEqual(config2.objective_thresholds, cloned2_moo.objective_thresholds) + # pyrefly: ignore [bad-argument-type] + def test_CloneWithArgs(self) -> None: config1 = MultiObjectiveOptimizationConfig( objective=self.multi_objective, + # pyrefly: ignore [bad-argument-type] objective_thresholds=self.objective_thresholds, outcome_constraints=self.outcome_constraints, ) diff --git a/ax/core/tests/test_search_space.py b/ax/core/tests/test_search_space.py index 8e53737a8f1..7a06c8f9a11 100644 --- a/ax/core/tests/test_search_space.py +++ b/ax/core/tests/test_search_space.py @@ -512,6 +512,7 @@ def test_CheckTypes(self) -> None: self.assertTrue(self.ss2.check_types(p_dict)) # Invalid type + # pyrefly: ignore [bad-assignment] p_dict["b"] = 5.2 # pyre-fixme[6]: For 1st param expected `Dict[str, Union[None, bool, float, # int, str]]` but got `Dict[str, Union[float, str]]`. @@ -1205,7 +1206,9 @@ def test_SearchSpaceDigest(self) -> None: for arg in self.kwargs: if arg in {"feature_names", "bounds"}: continue + # pyrefly: ignore [bad-argument-type] ssd = SearchSpaceDigest( + # pyrefly: ignore [bad-argument-type] **{k: v for k, v in self.kwargs.items() if k != arg} ) @@ -1446,8 +1449,10 @@ def test_cast_observation_features(self) -> None: self.assertEqual( # Check one subtree. hss_1_obs_feats_1_cast.parameters, ObservationFeatures.from_arm(arm=self.hss_1_arm_1_cast).parameters, + # pyrefly: ignore [missing-attribute] ) self.assertEqual( # Check one subtree. + # pyrefly: ignore [missing-attribute] hss_1_obs_feats_1_cast.metadata.get(Keys.FULL_PARAMETERIZATION), hss_1_obs_feats_1.parameters, ) @@ -1493,10 +1498,14 @@ def test_flatten_observation_features(self) -> None: ) self.assertEqual( # Cast-flatten roundtrip. hss_1_obs_feats_1.parameters, + # pyrefly: ignore [missing-attribute] hss_1_obs_feats_1_flattened.parameters, + # pyrefly: ignore [missing-attribute] ) self.assertEqual( # Check that both cast and flattened have full params. + # pyrefly: ignore [missing-attribute] hss_1_obs_feats_1_cast.metadata.get(Keys.FULL_PARAMETERIZATION), + # pyrefly: ignore [missing-attribute] hss_1_obs_feats_1_flattened.metadata.get(Keys.FULL_PARAMETERIZATION), ) # Check that flattening observation features without metadata does nothing. diff --git a/ax/core/tests/test_trial.py b/ax/core/tests/test_trial.py index cdf67c2113f..7a3ac21bc14 100644 --- a/ax/core/tests/test_trial.py +++ b/ax/core/tests/test_trial.py @@ -95,13 +95,17 @@ def test_basic_properties(self) -> None: self.assertEqual(self.trial.arms[0].signature, self.arm.signature) self.assertEqual(self.trial.abandoned_arms, []) self.assertEqual( - self.trial.generator_run.generator_run_type, GeneratorRunType.MANUAL.name + # pyrefly: ignore [missing-attribute] + self.trial.generator_run.generator_run_type, + GeneratorRunType.MANUAL.name, ) self.assertEqual(self.trial.generation_method_str, MANUAL_GENERATION_METHOD_STR) # Test empty arms + # pyrefly: ignore [missing-attribute] t = self.experiment.new_trial() with self.assertRaises(AttributeError): + # pyrefly: ignore [missing-attribute] t.arm_weights self.assertEqual(t.generation_method_str, UNKNOWN_GENERATION_METHOD_STR) @@ -142,15 +146,19 @@ def test_adding_new_trials(self) -> None: def test_add_trial_same_arm(self) -> None: # Check that adding new arm w/out name works correctly. + # pyrefly: ignore [missing-attribute] new_trial1 = self.experiment.new_trial( generator_run=GeneratorRun(arms=[self.arm.clone(clear_name=True)]) ) + # pyrefly: ignore [missing-attribute] self.assertEqual(new_trial1.arm.name, self.trial.arm.name) self.assertFalse(new_trial1.arm is self.trial.arm) + # pyrefly: ignore [missing-attribute] # Check that adding new arm with name works correctly. new_trial2 = self.experiment.new_trial( generator_run=GeneratorRun(arms=[self.arm.clone()]) ) + # pyrefly: ignore [missing-attribute] self.assertEqual(new_trial2.arm.name, self.trial.arm.name) self.assertFalse(new_trial2.arm is self.trial.arm) arm_wrong_name = self.arm.clone(clear_name=True) @@ -242,13 +250,17 @@ def test_mark_as(self) -> None: else: status_transition_sequence = (TrialStatus.RUNNING, terminal_status) + # pyrefly: ignore [unsupported-operation] for status in status_transition_sequence: kwargs = {} + # pyrefly: ignore [unsupported-operation] if status == TrialStatus.RUNNING: kwargs["no_runner_required"] = True if status == TrialStatus.ABANDONED: + # pyrefly: ignore [unsupported-operation] kwargs["reason"] = "test_reason_abandon" if status == TrialStatus.FAILED: + # pyrefly: ignore [unsupported-operation] kwargs["reason"] = "test_reason_failed" # Trial must have data before it can be marked EARLY_STOPPED @@ -297,6 +309,7 @@ def test_stop(self) -> None: # test bad new status self.trial.mark_running(no_runner_required=True) with self.assertRaisesRegex(ValueError, "New status of a stopped trial must"): + # pyrefly: ignore [bad-override] self.trial.stop(new_status=TrialStatus.CANDIDATE) # dummy runner for testing stopping functionality @@ -304,6 +317,7 @@ class DummyStopRunner(Runner): def run(self, trial): pass + # pyrefly: ignore [bad-override] def stop(self, trial, reason): return {"reason": reason} if reason else {} @@ -369,6 +383,7 @@ def test_update_run_metadata(self) -> None: def test_update_stop_metadata(self) -> None: self.assertEqual(len(self.trial.stop_metadata), 1) old_stop_metadata = deepcopy(self.trial.stop_metadata) + # pyrefly: ignore [missing-attribute] self.trial.update_stop_metadata({"something": "new"}) self.assertEqual( self.trial.stop_metadata, {**old_stop_metadata, "something": "new"} @@ -377,6 +392,7 @@ def test_update_stop_metadata(self) -> None: def test_update_trial_data(self) -> None: # Verify components before we attach trial data self.assertEqual(1, len(self.trial.arms)) + # pyrefly: ignore [missing-attribute] arm_name = self.trial.arm.name self.assertEqual( diff --git a/ax/core/tests/test_utils.py b/ax/core/tests/test_utils.py index af1aba3d414..fc4da284bc7 100644 --- a/ax/core/tests/test_utils.py +++ b/ax/core/tests/test_utils.py @@ -64,7 +64,9 @@ def setUp(self) -> None: self.batch_trial = self.experiment_2.new_batch_trial(GeneratorRun([self.arm])) self.batch_trial.add_status_quo_arm(weight=1) self.obs_feat = ObservationFeatures.from_arm( - arm=self.trial.arm, trial_index=self.trial.index + # pyrefly: ignore [bad-argument-type] + arm=self.trial.arm, + trial_index=self.trial.index, ) self.hss_arm = Arm({"model": "XGBoost", "num_boost_rounds": 12}) self.hss_exp = get_hierarchical_search_space_experiment() @@ -83,8 +85,10 @@ def setUp(self) -> None: ) self.hss_trial = self.hss_exp.new_trial(self.hss_gr) self.hss_cand_metadata = self.hss_trial._get_candidate_metadata( + # pyrefly: ignore [missing-attribute] arm_name=self.hss_arm.name ) + # pyrefly: ignore [missing-attribute] self.hss_full_parameterization = self.hss_cand_metadata.get( Keys.FULL_PARAMETERIZATION ).copy() @@ -126,9 +130,11 @@ def test_get_pending_observation_features(self) -> None: # With data for metric "m2", that metric should no longer have pending # observation features. with patch.object( + # pyrefly: ignore [missing-attribute] self.experiment, "lookup_data", return_value=raw_evaluations_to_data( + # pyrefly: ignore [missing-attribute] {self.trial.arm.name: {"m2": (1, 0)}}, trial_index=self.trial.index, metric_name_to_signature={"m2": "m2"}, @@ -146,10 +152,12 @@ def test_get_pending_observation_features(self) -> None: ) # A completed trial with data for some metrics should be pending only # for metrics without data. + # pyrefly: ignore [missing-attribute] with patch.object( self.experiment, "lookup_data", return_value=raw_evaluations_to_data( + # pyrefly: ignore [missing-attribute] {self.trial.arm.name: {"m2": (1, 0)}}, trial_index=self.trial.index, metric_name_to_signature={"m2": "m2"}, @@ -169,11 +177,13 @@ def test_get_pending_observation_features(self) -> None: {"tracking": [self.obs_feat], "m2": [self.obs_feat], "m1": [self.obs_feat]}, ) # Abandoned trials with data for some metrics should only be pending + # pyrefly: ignore [missing-attribute] # for metrics without data. with patch.object( self.experiment, "lookup_data", return_value=raw_evaluations_to_data( + # pyrefly: ignore [missing-attribute] {self.trial.arm.name: {"m2": (1, 0)}}, trial_index=self.trial.index, metric_name_to_signature={"m2": "m2"}, @@ -376,33 +386,41 @@ def test_completed_complete_trials_not_pending(self) -> None: def test_get_pending_observation_features_multi_trial(self) -> None: # With data for metric "m2", that metric should no longer have pending + # pyrefly: ignore [missing-attribute] # observation features. self.trial.mark_running(no_runner_required=True) with patch.object( self.experiment, "lookup_data", return_value=raw_evaluations_to_data( + # pyrefly: ignore [missing-attribute] {self.trial.arm.name: {"m2": (1, 0)}}, trial_index=self.trial.index, metric_name_to_signature={"m2": "m2"}, ), + # pyrefly: ignore [bad-argument-type] ): self.assertEqual( get_pending_observation_features(self.experiment), {"tracking": [self.obs_feat], "m2": [], "m1": [self.obs_feat]}, ) + # pyrefly: ignore [missing-attribute] # Make sure that trial_index is set correctly + # pyrefly: ignore [bad-argument-type] other_obs_feat = ObservationFeatures.from_arm(arm=self.trial.arm, trial_index=1) other_trial = self.experiment.new_trial(GeneratorRun([self.arm])) + # pyrefly: ignore [missing-attribute] other_trial.mark_running(no_runner_required=True) trial_0_data = raw_evaluations_to_data( + # pyrefly: ignore [missing-attribute] {self.trial.arm.name: {"m2": (1, 0)}}, trial_index=self.trial.index, metric_name_to_signature={"m2": "m2"}, ) trial_1_data = raw_evaluations_to_data( + # pyrefly: ignore [missing-attribute] {other_trial.arm.name: {"m2": (1, 0), "tracking": (1, 0)}}, trial_index=other_trial.index, metric_name_to_signature={"m2": "m2", "tracking": "tracking"}, @@ -472,6 +490,7 @@ def test_get_pending_observation_features_hss(self) -> None: none_throws(pf.metadata), none_throws(self.hss_gr.candidate_metadata_by_arm_signature)[ self.hss_arm.signature + # pyrefly: ignore [missing-attribute] ], ) @@ -482,6 +501,7 @@ def test_get_pending_observation_features_hss(self) -> None: self.hss_exp, "lookup_data", return_value=raw_evaluations_to_data( + # pyrefly: ignore [missing-attribute] {self.hss_trial.arm.name: {"m2": (1, 0)}}, trial_index=self.hss_trial.index, metric_name_to_signature={"m2": "m2"}, @@ -508,6 +528,7 @@ def test_get_pending_observation_features_batch_trial(self) -> None: # Status quo of this experiment is out-of-design, so it shouldn't be # among the pending points. self.assertEqual( + # pyrefly: ignore [bad-argument-type] get_pending_observation_features(self.experiment_2), { "tracking": [self.obs_feat], @@ -519,6 +540,7 @@ def test_get_pending_observation_features_batch_trial(self) -> None: # Status quo of this experiment is out-of-design, so it shouldn't be # among the pending points. sq_obs_feat = ObservationFeatures.from_arm( + # pyrefly: ignore [bad-argument-type] self.batch_trial.arms_by_name.get("status_quo"), trial_index=self.batch_trial.index, ) diff --git a/ax/core/trial_status.py b/ax/core/trial_status.py index 5b94c9e9a84..c3f74101122 100644 --- a/ax/core/trial_status.py +++ b/ax/core/trial_status.py @@ -130,6 +130,7 @@ def is_stale(self) -> bool: """True if this trial is a stale one.""" return self == TrialStatus.STALE + # pyrefly: ignore [bad-override-param-name] def __format__(self, fmt: str) -> str: """Define `__format__` to avoid pulling the `__format__` from the `int` mixin (since its better for statuses to show up as `RUNNING` than as diff --git a/ax/core/utils.py b/ax/core/utils.py index 60d67bc10e6..41c0639895a 100644 --- a/ax/core/utils.py +++ b/ax/core/utils.py @@ -251,6 +251,7 @@ def compute_metric_availability( if len(data.metric_names) > 0: df = data.full_df for trial_idx, group in df.groupby("trial_index")["metric_name"]: + # pyrefly: ignore [bad-argument-type] metrics_per_trial[int(trial_idx)] = set(group.unique()) # Compute availability for each trial. diff --git a/ax/early_stopping/simulation.py b/ax/early_stopping/simulation.py index db62df71a1c..6d87fc4b7c8 100644 --- a/ax/early_stopping/simulation.py +++ b/ax/early_stopping/simulation.py @@ -249,6 +249,7 @@ def best_trial_vulnerable( # Check if best trial should be stopped (against reference trials) stop_selector = _check_patience_window( wide_df=wide_df, + # pyrefly: ignore [bad-argument-type] trial_indices={best_trial_index}, progression=progression, patience=patience, @@ -262,6 +263,7 @@ def best_trial_vulnerable( if stop_best: return EarlyStoppingSimulationResult( best_stopped=True, + # pyrefly: ignore [bad-argument-type] best_trial_index=best_trial_index, best_stop_progression=progression, ) @@ -283,5 +285,6 @@ def best_trial_vulnerable( return EarlyStoppingSimulationResult( best_stopped=False, + # pyrefly: ignore [bad-argument-type] best_trial_index=best_trial_index, ) diff --git a/ax/early_stopping/strategies/base.py b/ax/early_stopping/strategies/base.py index 6acfe28aa30..554f889ff54 100644 --- a/ax/early_stopping/strategies/base.py +++ b/ax/early_stopping/strategies/base.py @@ -411,6 +411,7 @@ def is_eligible( # Check eligibility of each metric. for metric_signature, metric_df in df.groupby("metric_signature"): # check for no data + # pyrefly: ignore [bad-index] metric_name = experiment.signature_to_metric[metric_signature].name df_trial = metric_df[metric_df["trial_index"] == trial_index] df_trial = df_trial.dropna(subset=["mean"]) diff --git a/ax/generation_strategy/dispatch_utils.py b/ax/generation_strategy/dispatch_utils.py index c1b7494c9d8..1d25169b81a 100644 --- a/ax/generation_strategy/dispatch_utils.py +++ b/ax/generation_strategy/dispatch_utils.py @@ -55,6 +55,7 @@ def _make_sobol_step( should_deduplicate: bool = False, ) -> GenerationStep: """Shortcut for creating a Sobol generation step.""" + # pyrefly: ignore [bad-return] return GenerationStep( generator=Generators.SOBOL, num_trials=num_trials, @@ -102,6 +103,7 @@ def _make_botorch_step( generator_kwargs["transform_configs"].setdefault("Winsorize", {}) # Add manually specified winsorization config. generator_kwargs["transform_configs"]["Winsorize"]["winsorization_config"] = ( + # pyrefly: ignore [unsupported-operation] winsorization_config ) @@ -124,6 +126,7 @@ def _make_botorch_step( "`disable_progbar`, and `jit_compile` are only supported with" " fully Bayesian models. These are being ignored." ) + # pyrefly: ignore [bad-return] return GenerationStep( generator=generator, num_trials=num_trials, @@ -524,9 +527,11 @@ def choose_generation_strategy_legacy( ) # set name for GS bo_step = nodes[-1] + # pyrefly: ignore [missing-attribute] surrogate_spec = bo_step.generator_spec.generator_kwargs.get("surrogate_spec") name = None if ( + # pyrefly: ignore [missing-attribute] bo_step.generator_spec.generator_enum is Generators.BOTORCH_MODULAR and surrogate_spec is not None and (model_config := surrogate_spec.model_configs[0]).botorch_model_class diff --git a/ax/generation_strategy/external_generation_node.py b/ax/generation_strategy/external_generation_node.py index 311896da90e..c8d69faa139 100644 --- a/ax/generation_strategy/external_generation_node.py +++ b/ax/generation_strategy/external_generation_node.py @@ -120,6 +120,7 @@ def _fitted_adapter(self) -> None: return None @property + # pyrefly: ignore [bad-override] def generator_spec_to_gen_from(self) -> GeneratorSpec | None: return self._generator_spec_to_gen_from @@ -155,6 +156,7 @@ def _fit( ) self.fit_time_since_gen += time.monotonic() - t_fit_start + # pyrefly: ignore [bad-override-param-name] def _gen( self, experiment: Experiment, diff --git a/ax/generation_strategy/generator_spec.py b/ax/generation_strategy/generator_spec.py index bbc328af790..2ee146fadcc 100644 --- a/ax/generation_strategy/generator_spec.py +++ b/ax/generation_strategy/generator_spec.py @@ -376,6 +376,7 @@ def __repr__(self) -> str: def __hash__(self) -> int: return hash(repr(self)) + # pyrefly: ignore [bad-override] def __eq__(self, other: GeneratorSpec) -> bool: return repr(self) == repr(other) diff --git a/ax/generation_strategy/tests/test_best_model_selector.py b/ax/generation_strategy/tests/test_best_model_selector.py index 19e2f6f7c51..ab1143240be 100644 --- a/ax/generation_strategy/tests/test_best_model_selector.py +++ b/ax/generation_strategy/tests/test_best_model_selector.py @@ -39,7 +39,12 @@ def setUp(self) -> None: def test_user_input_error(self) -> None: with self.assertRaisesRegex(UserInputError, "ReductionCriterion"): SingleDiagnosticBestModelSelector( - "Fisher exact test p", metric_aggregation=min, criterion=max + # pyrefly: ignore [bad-argument-type] + "Fisher exact test p", + # pyrefly: ignore [bad-argument-type] + metric_aggregation=min, + # pyrefly: ignore [bad-argument-type] + criterion=max, ) with self.assertRaisesRegex(UserInputError, "use MIN or MAX"): SingleDiagnosticBestModelSelector( diff --git a/ax/generation_strategy/tests/test_center_generation_node.py b/ax/generation_strategy/tests/test_center_generation_node.py index 462762a0ee4..82504969e84 100644 --- a/ax/generation_strategy/tests/test_center_generation_node.py +++ b/ax/generation_strategy/tests/test_center_generation_node.py @@ -167,6 +167,7 @@ def test_center_generation_with_logit_scale(self) -> None: # logit(0.9) = log(0.9 / 0.1) ≈ 2.197 # center in logit space = 0 # inverse_logit(0) = 1 / (1 + exp(0)) = 0.5 + # pyrefly: ignore [bad-argument-type] self.assertAlmostEqual(float(params["x1"]), 0.5, places=5) self.assertEqual(params["x2"], 0.5) @@ -194,6 +195,7 @@ def test_center_generation_with_logit_scale_extreme_bounds(self) -> None: # logit(0.999) = log(0.999) - log(0.001) ≈ 6.906 # center in logit space = (-9.210 + 6.906) / 2 ≈ -1.152 # expit(-1.152) ≈ 0.240 + # pyrefly: ignore [bad-argument-type] center = float(params["x1"]) self.assertGreater(center, 0.0001) # Above lower bound self.assertLess(center, 0.999) # Below upper bound diff --git a/ax/generation_strategy/tests/test_dispatch_utils.py b/ax/generation_strategy/tests/test_dispatch_utils.py index 8620dd82bd3..7ca89e4f5da 100644 --- a/ax/generation_strategy/tests/test_dispatch_utils.py +++ b/ax/generation_strategy/tests/test_dispatch_utils.py @@ -342,6 +342,7 @@ def test_choose_generation_strategy_legacy(self) -> None: ) with self.subTest("BO_MIXED (mixed multi-objective optimization)"): search_space = get_branin_search_space(with_choice_parameter=True) + # pyrefly: ignore [missing-attribute] search_space.parameters["x2"]._is_ordered = False optimization_config = MultiObjectiveOptimizationConfig( objective=Objective( @@ -486,7 +487,9 @@ def test_choose_generation_strategy_legacy(self) -> None: with self.subTest("num_initialization_trials"): ss = get_large_factorial_search_space() + # pyrefly: ignore [missing-attribute] for _, param in ss.parameters.items(): + # pyrefly: ignore [missing-attribute] param._is_ordered = True # 2 * len(ss.parameters) init trials are performed if num_trials is large gs_12_init_trials = choose_generation_strategy_legacy( @@ -576,16 +579,21 @@ def test_make_botorch_step_extra(self) -> None: # Step.__new__` actually returns a `GenerationNode`. none_throws(bo_step.generator_spec.generator_kwargs)["transforms"], [LogY], + # pyrefly: ignore [missing-attribute] ) self.assertEqual( + # pyrefly: ignore [missing-attribute] none_throws(bo_step.generator_spec.generator_kwargs)["transform_configs"], {}, ) # With derelativize_with_raw_status_quo. bo_step = _make_botorch_step( - generator_kwargs=generator_kwargs, derelativize_with_raw_status_quo=True + # pyrefly: ignore [missing-attribute] + generator_kwargs=generator_kwargs, + derelativize_with_raw_status_quo=True, ) self.assertEqual( + # pyrefly: ignore [missing-attribute] none_throws(bo_step.generator_spec.generator_kwargs)["transform_configs"], { "Derelativize": {"use_raw_status_quo": True}, @@ -737,12 +745,16 @@ def test_winsorization(self) -> None: winsorized = choose_generation_strategy_legacy( search_space=get_branin_search_space(), winsorization_config=WinsorizationConfig(upper_quantile_margin=2), + # pyrefly: ignore [bad-argument-type] ) tc = none_throws(winsorized._nodes[1].generator_specs[0].generator_kwargs).get( + # pyrefly: ignore [unsupported-operation] "transform_configs" ) + # pyrefly: ignore [bad-argument-type] self.assertIn("Winsorize", tc) self.assertDictEqual( + # pyrefly: ignore [unsupported-operation] tc["Winsorize"], { "winsorization_config": WinsorizationConfig( @@ -757,22 +769,29 @@ def test_winsorization(self) -> None: winsorized = choose_generation_strategy_legacy( search_space=get_branin_search_space(), derelativize_with_raw_status_quo=True, + # pyrefly: ignore [bad-argument-type] ) tc = none_throws(winsorized._nodes[1].generator_specs[0].generator_kwargs).get( "transform_configs" + # pyrefly: ignore [unsupported-operation] ) self.assertIn( "Winsorize", + # pyrefly: ignore [bad-argument-type] tc, + # pyrefly: ignore [bad-argument-type] ) self.assertDictEqual( + # pyrefly: ignore [unsupported-operation] tc["Winsorize"], {"derelativize_with_raw_status_quo": True}, ) self.assertIn( "Derelativize", + # pyrefly: ignore [bad-argument-type] tc, ) + # pyrefly: ignore [unsupported-operation] self.assertDictEqual(tc["Derelativize"], {"use_raw_status_quo": True}) def test_num_trials(self) -> None: diff --git a/ax/generation_strategy/tests/test_transition_criterion.py b/ax/generation_strategy/tests/test_transition_criterion.py index 362626dd932..1378b532775 100644 --- a/ax/generation_strategy/tests/test_transition_criterion.py +++ b/ax/generation_strategy/tests/test_transition_criterion.py @@ -280,6 +280,7 @@ def test_min_trials_is_met(self) -> None: # Attach data for both completed trials experiment.attach_data( get_branin_data( + # pyrefly: ignore [bad-argument-type] trials=[experiment.trials[0], experiment.trials[1]], metrics=["branin"], ) @@ -353,6 +354,7 @@ def test_min_trials_count_only_with_data(self) -> None: # Attach data for "branin" (the opt config metric) to 1 trial only experiment.attach_data( + # pyrefly: ignore [bad-argument-type] get_branin_data(trials=[experiment.trials[0]], metrics=["branin"]) ) # Still not met — only 1 trial has data, need 2 @@ -367,6 +369,7 @@ def test_min_trials_count_only_with_data(self) -> None: [ { "trial_index": 1, + # pyrefly: ignore [missing-attribute] "arm_name": experiment.trials[1].arm.name, "metric_name": "not_branin", "mean": 1.0, @@ -384,6 +387,7 @@ def test_min_trials_count_only_with_data(self) -> None: # Attach "branin" data to trial 1 too experiment.attach_data( + # pyrefly: ignore [bad-argument-type] get_branin_data(trials=[experiment.trials[1]], metrics=["branin"]) ) # Now 2 trials have "branin" data — criterion should be met diff --git a/ax/generators/discrete/thompson.py b/ax/generators/discrete/thompson.py index 0b9314af5d2..98e57a34995 100644 --- a/ax/generators/discrete/thompson.py +++ b/ax/generators/discrete/thompson.py @@ -183,6 +183,7 @@ def predict( f"(X: {X[j]} - note that this is post-transform application)." ) f[j, i], cov[j, i, i] = X_to_Y_and_Yvar[ + # pyrefly: ignore [bad-index] assert_is_instance(x, TParamValue) ] return f, cov @@ -328,6 +329,7 @@ def _fit_X_to_Ys_and_Yvars( hashableX = [self._hash_TParamValueList(x) for x in X] for Y, Yvar in zip(Ys, Yvars): X_to_Ys_and_Yvars.append(dict(zip(hashableX, zip(Y, Yvar)))) + # pyrefly: ignore [bad-return] return X_to_Ys_and_Yvars def _hash_TParamValueList(self, x: Iterable[TParamValue]) -> str: diff --git a/ax/generators/random/base.py b/ax/generators/random/base.py index 48d3560d256..762f27fdaad 100644 --- a/ax/generators/random/base.py +++ b/ax/generators/random/base.py @@ -152,6 +152,7 @@ def gen( has_continuous_parameters = len(continuous_indices) > 0 if model_gen_options: max_draws = model_gen_options.get("max_rs_draws", DEFAULT_MAX_RS_DRAWS) + # pyrefly: ignore [bad-argument-type] max_draws = int(assert_is_instance_of_tuple(max_draws, (int, float))) try: # With equality constraints, unconstrained sampling has probability diff --git a/ax/generators/random/uniform.py b/ax/generators/random/uniform.py index d87c4136763..eefaf0193c8 100644 --- a/ax/generators/random/uniform.py +++ b/ax/generators/random/uniform.py @@ -34,6 +34,7 @@ def __init__( init_position=init_position, fallback_to_sample_polytope=fallback_to_sample_polytope, ) + # pyrefly: ignore [bad-argument-type] self._rs = np.random.RandomState(seed=self.seed) if self.init_position > 0: # Fast-forward the random state by generating & discarding samples. diff --git a/ax/generators/tests/test_botorch_moo_utils.py b/ax/generators/tests/test_botorch_moo_utils.py index 4af791486ef..cf82b1050b5 100644 --- a/ax/generators/tests/test_botorch_moo_utils.py +++ b/ax/generators/tests/test_botorch_moo_utils.py @@ -219,7 +219,9 @@ def test_get_weighted_mc_objective(self) -> None: weighted_obj = get_weighted_mc_objective( objective_weights=objective_weights, ) + # pyrefly: ignore [bad-argument-type, not-callable] self.assertTrue(torch.equal(weighted_obj.weights, torch.tensor([1.0, 1.0]))) + # pyrefly: ignore [not-callable] self.assertEqual(weighted_obj.outcomes.tolist(), [1, 3]) # test infer objective thresholds alone diff --git a/ax/generators/tests/test_thompson.py b/ax/generators/tests/test_thompson.py index d29bf198abb..7e41359b9f6 100644 --- a/ax/generators/tests/test_thompson.py +++ b/ax/generators/tests/test_thompson.py @@ -105,6 +105,7 @@ def test_ThompsonSamplerTopKError(self) -> None: ) def test_TopTwo_alters_weights_vs_TopOne(self) -> None: + # pyrefly: ignore [bad-argument-type] np.random.seed(0) # Compare TTTS results to the vanilla TS @@ -157,6 +158,7 @@ def test_TopTwo_alters_weights_vs_TopOne(self) -> None: self.assertTrue(full_w2[3] > full_w2[2] > full_w2[1] > full_w2[0]) def test_ThompsonSamplerMinWeight(self) -> None: + # pyrefly: ignore [bad-argument-type] np.random.seed(0) generator = ThompsonSampler(min_weight=0.01) generator.fit( diff --git a/ax/generators/tests/test_torch_model_utils.py b/ax/generators/tests/test_torch_model_utils.py index e82d0542f46..081a3d25f4a 100644 --- a/ax/generators/tests/test_torch_model_utils.py +++ b/ax/generators/tests/test_torch_model_utils.py @@ -71,6 +71,7 @@ def test_with_outcome_constraints_can_subset(self) -> None: self.assertTrue(torch.equal(obj_weights_sub, torch.tensor([[1.0]]))) # pyre-fixme[16]: Optional type has no attribute `__getitem__`. self.assertTrue(torch.equal(ocs_sub[0], torch.tensor([[1.0]]))) + # pyrefly: ignore [unsupported-operation] self.assertTrue(torch.equal(ocs_sub[1], torch.tensor([1.0]))) self.assertTrue(torch.equal(subset_model_results.indices, torch.tensor([0]))) diff --git a/ax/generators/tests/test_torch_utils.py b/ax/generators/tests/test_torch_utils.py index 3a035d3eb16..371ee3c03f0 100644 --- a/ax/generators/tests/test_torch_utils.py +++ b/ax/generators/tests/test_torch_utils.py @@ -48,6 +48,7 @@ def setUp(self) -> None: def test_get_X_pending_and_observed(self) -> None: def _to_obs_set(X: torch.Tensor) -> set[tuple[float]]: + # pyrefly: ignore [bad-return] return {tuple(float(x_i) for x_i in x) for x in X} # Apply filter normally diff --git a/ax/generators/tests/test_utils.py b/ax/generators/tests/test_utils.py index 0ba8753662a..4bf7e9882cf 100644 --- a/ax/generators/tests/test_utils.py +++ b/ax/generators/tests/test_utils.py @@ -214,6 +214,7 @@ def test_rejection_sample(self) -> None: # 1. (0.6, 0.6): sum=1.2 satisfies, but rounds to (1,1): sum=2 violates # 2. (0.4, 0.4): sum=0.8 satisfies, rounds to (0,0): sum=0 satisfies call_count = 0 + # pyrefly: ignore [bad-specialization] values_to_return: list[npt.NDArray[np.floating[Any]]] = [ np.array([[0.6, 0.6]]), np.array([[0.4, 0.4]]), @@ -223,15 +224,22 @@ def mock_gen_unconstrained( n: int, d: int, tunable_feature_indices: npt.NDArray[np.intp], + # pyrefly: ignore [bad-specialization] fixed_features: dict[int, float] | None, + # pyrefly: ignore [bad-specialization] ) -> npt.NDArray[np.floating[Any]]: nonlocal call_count result = values_to_return[min(call_count, len(values_to_return) - 1)] call_count += 1 return result + # pyrefly: ignore [bad-specialization] + + # pyrefly: ignore [bad-specialization] def rounding_func( + # pyrefly: ignore [bad-specialization] point: npt.NDArray[np.floating[Any]], + # pyrefly: ignore [bad-specialization] ) -> npt.NDArray[np.floating[Any]]: return np.round(point) diff --git a/ax/generators/torch/botorch_modular/acquisition.py b/ax/generators/torch/botorch_modular/acquisition.py index 97abfb68cd6..6f06a112377 100644 --- a/ax/generators/torch/botorch_modular/acquisition.py +++ b/ax/generators/torch/botorch_modular/acquisition.py @@ -134,6 +134,7 @@ def determine_optimizer( # the remaining parameters. cardinalities = [len(c) for c in discrete_choices.values()] max_cardinality = max(cardinalities) + # pyrefly: ignore [incompatible-overload-residual] total_discrete_choices = reduce(operator.mul, cardinalities) if total_discrete_choices > MAX_CHOICES_ENUMERATE: if max_cardinality <= MAX_CARDINALITY_FOR_LOCAL_SEARCH: @@ -864,12 +865,14 @@ def optimize( self.acqf, MultiOutputAcquisitionFunctionWrapper ) candidates, acqf_values = optimize_with_nsgaii( + # pyrefly: ignore [bad-argument-type] acq_function=self.acqf, bounds=bounds, q=n, fixed_features=fixed_features, inequality_constraints=inequality_constraints, num_objectives=len(acqf.acqfs), + # pyrefly: ignore [bad-argument-type] discrete_choices=discrete_choices if discrete_choices else None, post_processing_func=rounding_func, **optimizer_options_with_defaults, @@ -913,6 +916,7 @@ def optimize( ) n_candidates = candidates.shape[0] + # pyrefly: ignore [bad-return] return candidates, acqf_values, arm_weights[:n_candidates] * n_candidates / n def evaluate(self, X: Tensor) -> Tensor: @@ -928,6 +932,7 @@ def evaluate(self, X: Tensor) -> Tensor: model and input `X`. """ if isinstance(self.acqf, qKnowledgeGradient): + # pyrefly: ignore [bad-argument-count, unexpected-keyword] return self.acqf.evaluate(X=X) else: # NOTE: `AcquisitionFunction.__call__` calls `forward`, diff --git a/ax/generators/torch/botorch_modular/generator.py b/ax/generators/torch/botorch_modular/generator.py index 456ac78778c..70a2ef744f9 100644 --- a/ax/generators/torch/botorch_modular/generator.py +++ b/ax/generators/torch/botorch_modular/generator.py @@ -527,6 +527,7 @@ def cross_validate( return X_test_prediction @property + # pyrefly: ignore [bad-override] def dtype(self) -> torch.dtype: """Torch data type of the tensors in the training data used in the model, of which this ``Acquisition`` is a subcomponent. @@ -534,6 +535,7 @@ def dtype(self) -> torch.dtype: return self.surrogate.dtype @property + # pyrefly: ignore [bad-override] def device(self) -> torch.device: """Torch device type of the tensors in the training data used in the model, of which this ``Acquisition`` is a subcomponent. diff --git a/ax/generators/torch/botorch_modular/optimizer_argparse.py b/ax/generators/torch/botorch_modular/optimizer_argparse.py index c885744c9a4..c4e46797540 100644 --- a/ax/generators/torch/botorch_modular/optimizer_argparse.py +++ b/ax/generators/torch/botorch_modular/optimizer_argparse.py @@ -106,6 +106,7 @@ def optimizer_argparse( "optimize_acqf_mixed", "optimize_acqf_mixed_alternating", ]: + # pyrefly: ignore [unsupported-operation] options["options"] = { "init_batch_limit": INIT_BATCH_LIMIT, "batch_limit": BATCH_LIMIT, @@ -125,5 +126,6 @@ def optimizer_argparse( if optimizer == "optimize_acqf": options["sequential"] = True + # pyrefly: ignore [bad-argument-type] options.update(**{k: v for k, v in provided_options.items() if k != "options"}) return options diff --git a/ax/generators/torch/botorch_modular/surrogate.py b/ax/generators/torch/botorch_modular/surrogate.py index 09800143e2b..17c047ad7e0 100644 --- a/ax/generators/torch/botorch_modular/surrogate.py +++ b/ax/generators/torch/botorch_modular/surrogate.py @@ -556,6 +556,7 @@ def _construct_model( ) botorch_model_class = none_throws(model_config.botorch_model_class) if self._dataset_matches_cache(dataset=dataset): + # pyrefly: ignore [bad-index] return self._submodels[outcome_names] formatted_model_inputs = submodel_input_constructor( botorch_model_class, # Do not pass as kwarg since this is used to dispatch. @@ -756,11 +757,14 @@ def fit( self.metric_to_best_model_config[metric_signature] = none_throws( best_model_config ) + # pyrefly: ignore [unsupported-operation] self._submodels[outcome_name_tuple] = model + # pyrefly: ignore [unsupported-operation] self._last_datasets[outcome_name_tuple] = dataset if should_use_model_list: if all(isinstance(model, GPyTorchModel) for model in models): + # pyrefly: ignore [bad-argument-type] self._model = ModelListGP(*models) else: self._model = ModelList(*models) @@ -1173,8 +1177,10 @@ def models_for_gen(self, n: int) -> tuple[list[dict[str, str]], list[Model]]: models_i.append(self._model_name_to_model[outcome][model_name]) model_names_i[outcome] = model_name if isinstance(self._model, ModelListGP): + # pyrefly: ignore [bad-argument-type] models.append(ModelListGP(*models_i)) elif isinstance(self._model, ModelList): + # pyrefly: ignore [bad-argument-type] models.append(ModelList(*models_i)) elif len(models_i) > 1: # If MBM supports ModelList in the future, this will need to be @@ -1183,8 +1189,10 @@ def models_for_gen(self, n: int) -> tuple[list[dict[str, str]], list[Model]]: "Got multiple models but not a ModelListGP or ModelList." ) # pragma: no cover else: + # pyrefly: ignore [bad-argument-type] models.append(models_i[0]) model_names.append(model_names_i) + # pyrefly: ignore [bad-return] return model_names, models diff --git a/ax/generators/torch/botorch_modular/utils.py b/ax/generators/torch/botorch_modular/utils.py index 36b16557ac3..4675f0ea637 100644 --- a/ax/generators/torch/botorch_modular/utils.py +++ b/ax/generators/torch/botorch_modular/utils.py @@ -212,6 +212,7 @@ def use_model_list( # e.g. a contextual model, where we want to jointly model the metric # each context (and context-level metrics are different outcomes). return False + # pyrefly: ignore [bad-argument-type] elif issubclass(botorch_model_class, BatchedMultiOutputGPyTorchModel) and all( torch.equal(datasets[0].X, ds.X) for ds in datasets[1:] ): diff --git a/ax/generators/torch/tests/test_acquisition.py b/ax/generators/torch/tests/test_acquisition.py index c4232c21f7f..f0eeecf0a48 100644 --- a/ax/generators/torch/tests/test_acquisition.py +++ b/ax/generators/torch/tests/test_acquisition.py @@ -87,6 +87,7 @@ # Used to avoid going through BoTorch `Acquisition.__init__` which # requires valid kwargs (correct sizes and lengths of tensors, etc). class DummyAcquisitionFunction(AcquisitionFunction): + # pyrefly: ignore [bad-override-mutable-attribute] X_pending: Tensor | None = None def __init__(self, eta: float = 1e-3, model: Any = None, **kwargs: Any) -> None: @@ -105,6 +106,10 @@ def forward(self, X: Tensor) -> Tensor: return torch.atleast_1d(res) +# pyrefly: ignore [inconsistent-inheritance] + + +# pyrefly: ignore [inconsistent-inheritance] class DummyOneShotAcquisitionFunction(DummyAcquisitionFunction, qKnowledgeGradient): def evaluate(self, X: Tensor, **kwargs: Any) -> Tensor: return X.sum(dim=-1) @@ -361,11 +366,14 @@ def test_optimize(self) -> None: ) as mock_prune_irrelevant_parameters, ): acquisition.optimize( + # pyrefly: ignore [bad-argument-type] n=n, search_space_digest=self.search_space_digest, + # pyrefly: ignore [bad-argument-type] inequality_constraints=self.inequality_constraints, fixed_features=self.fixed_features, rounding_func=self.rounding_func, + # pyrefly: ignore [bad-argument-type] optimizer_options=self.optimizer_options, ) mock_optimizer_argparse.assert_called_once_with( @@ -816,13 +824,17 @@ def test_optimize_acqf_discrete_local_search(self) -> None: mock.patch( f"{ACQUISITION_PATH}.optimizer_argparse", wraps=optimizer_argparse ) as mock_optimizer_argparse, + # pyrefly: ignore [bad-argument-type] ): acquisition.optimize( n=3, + # pyrefly: ignore [bad-argument-type] search_space_digest=ssd, + # pyrefly: ignore [bad-argument-type] inequality_constraints=self.inequality_constraints, fixed_features=None, rounding_func=self.rounding_func, + # pyrefly: ignore [bad-argument-type] optimizer_options=self.optimizer_options, ) mock_optimizer_argparse.assert_called_once_with( @@ -845,16 +857,25 @@ def test_optimize_acqf_discrete_local_search(self) -> None: }, set(kwargs.keys()), ) + # pyrefly: ignore [bad-index] self.assertEqual(kwargs["acq_function"], acquisition.acqf) self.assertEqual(kwargs["q"], 3) self.assertEqual( - kwargs["inequality_constraints"], self.inequality_constraints + # pyrefly: ignore [bad-index] + kwargs["inequality_constraints"], + self.inequality_constraints, ) self.assertEqual( - kwargs["num_restarts"], self.optimizer_options["num_restarts"] + # pyrefly: ignore [bad-index] + kwargs["num_restarts"], + # pyrefly: ignore [bad-index] + self.optimizer_options["num_restarts"], ) self.assertEqual( - kwargs["raw_samples"], self.optimizer_options["raw_samples"] + # pyrefly: ignore [bad-index] + kwargs["raw_samples"], + # pyrefly: ignore [bad-index] + self.optimizer_options["raw_samples"], ) self.assertTrue( all( @@ -962,9 +983,11 @@ def test_optimize_acqf_discrete_too_many_choices(self) -> None: mock.patch( f"{ACQUISITION_PATH}.optimizer_argparse", wraps=optimizer_argparse ) as mock_optimizer_argparse, + # pyrefly: ignore [bad-argument-type] mock.patch( f"{ACQUISITION_PATH}.optimize_acqf_discrete_local_search", return_value=(valid_candidates, torch.rand(3)), + # pyrefly: ignore [bad-argument-type] ), mock.patch( f"{ACQUISITION_PATH}.optimize_acqf_mixed_alternating", @@ -974,9 +997,11 @@ def test_optimize_acqf_discrete_too_many_choices(self) -> None: acquisition.optimize( n=3, search_space_digest=ssd, + # pyrefly: ignore [bad-argument-type] inequality_constraints=self.inequality_constraints, fixed_features=None, rounding_func=self.rounding_func, + # pyrefly: ignore [bad-argument-type] optimizer_options=self.optimizer_options, ) mock_optimizer_argparse.assert_called_once_with( @@ -986,9 +1011,11 @@ def test_optimize_acqf_discrete_too_many_choices(self) -> None: ) @mock_botorch_optimize + # pyrefly: ignore [bad-argument-type] def test_optimize_mixed(self) -> None: ssd = SearchSpaceDigest( feature_names=["a", "b"], + # pyrefly: ignore [bad-argument-type] bounds=[(0, 1), (0, 2)], categorical_features=[1], discrete_choices={1: [0, 1, 2]}, @@ -1000,9 +1027,11 @@ def test_optimize_mixed(self) -> None: acquisition.optimize( n=3, search_space_digest=ssd, + # pyrefly: ignore [bad-argument-type] inequality_constraints=self.inequality_constraints, fixed_features=None, rounding_func=self.rounding_func, + # pyrefly: ignore [bad-argument-type] optimizer_options=self.optimizer_options, ) mock_optimize_acqf_mixed.assert_called_with( @@ -1027,6 +1056,7 @@ def test_optimize_mixed(self) -> None: @mock_botorch_optimize def test_optimize_acqf_mixed_alternating(self) -> None: b_upper_bound = 15 + # pyrefly: ignore [bad-argument-type] ssd = SearchSpaceDigest( feature_names=["a", "b", "c"], bounds=[(0, 1), (0, b_upper_bound), (0, 5)], @@ -1043,6 +1073,7 @@ def test_optimize_acqf_mixed_alternating(self) -> None: acquisition.optimize( n=3, search_space_digest=ssd, + # pyrefly: ignore [bad-argument-type] inequality_constraints=self.inequality_constraints, fixed_features={0: 0.5}, rounding_func=self.rounding_func, @@ -1070,6 +1101,7 @@ def test_optimize_acqf_mixed_alternating(self) -> None: num_restarts=2, raw_samples=4, ) + # pyrefly: ignore [bad-argument-type] # Check with cateogrial features but no non-integer features. ssd_categorical = dataclasses.replace( @@ -1087,6 +1119,7 @@ def test_optimize_acqf_mixed_alternating(self) -> None: candidates, acqf_values, arm_weights = acquisition.optimize( n=3, search_space_digest=ssd_categorical, + # pyrefly: ignore [bad-argument-type] inequality_constraints=self.inequality_constraints, fixed_features={0: 0.5}, rounding_func=self.rounding_func, @@ -1136,11 +1169,13 @@ def test_optimize_acqf_mixed_alternating(self) -> None: mock_alternating.assert_called() # Check if the `fixed_features` argument works for discrete features. + # pyrefly: ignore [bad-argument-type] ub = 10 ssd_many_combinations = SearchSpaceDigest( feature_names=["a", "b", "c"], bounds=[(0, 1), (0, ub), (0, ub)], ordinal_features=[1, 2], + # pyrefly: ignore [bad-argument-type] discrete_choices={1: list(range(ub + 1)), 2: list(range(ub + 1))}, ) dict_args = { @@ -1154,11 +1189,13 @@ def test_optimize_acqf_mixed_alternating(self) -> None: f"{ACQUISITION_PATH}.optimize_acqf_mixed_alternating", wraps=optimize_acqf_mixed_alternating, ) as mock_alternating: + # pyrefly: ignore [bad-argument-type] acquisition.optimize(**dict_args) mock_alternating.assert_called() # Now that we have made sure alternating minimization is called, call the # optimizer for real. + # pyrefly: ignore [bad-argument-type] candidates, _, _ = acquisition.optimize(**dict_args) self.assertTrue((candidates[:, 1] == 0).all()) @@ -2291,9 +2328,11 @@ def test_no_pruning_with_qLogProbabilityOfFeasibility(self) -> None: acquisition = self.get_acquisition_function( fixed_features=self.fixed_features, ) + # pyrefly: ignore [bad-argument-type] n = 1 # Create valid mock candidates that satisfy: # - Bounds: [(0, 10), (0, 10), (0, 10)] + # pyrefly: ignore [bad-argument-type] # - Fixed features: x[1] = 2.0 # - Inequality constraint: -x[0] + x[1] >= 1 => x[0] <= x[1] - 1 = 1.0 valid_candidates = torch.tensor([[0.5, 2.0, 5.0]], **self.tkwargs) @@ -2311,9 +2350,11 @@ def test_no_pruning_with_qLogProbabilityOfFeasibility(self) -> None: acquisition.optimize( n=n, search_space_digest=self.search_space_digest, + # pyrefly: ignore [bad-argument-type] inequality_constraints=self.inequality_constraints, fixed_features=self.fixed_features, rounding_func=self.rounding_func, + # pyrefly: ignore [bad-argument-type] optimizer_options=self.optimizer_options, ) mock_prune_irrelevant_parameters.assert_not_called() @@ -2370,10 +2411,12 @@ def test_validate_candidates(self) -> None: ) @mock_botorch_optimize + # pyrefly: ignore [bad-argument-type] def test_optimize_with_equality_constraints(self) -> None: """Test that equality_constraints are forwarded to optimize_acqf.""" acquisition = self.get_acquisition_function( fixed_features=self.fixed_features, + # pyrefly: ignore [bad-argument-type] ) # Equality constraint: x[0] + x[2] = 4.0 # Compatible with fixed_features={1: 2.0} and @@ -2392,10 +2435,12 @@ def test_optimize_with_equality_constraints(self) -> None: acquisition.optimize( n=n, search_space_digest=self.search_space_digest, + # pyrefly: ignore [bad-argument-type] inequality_constraints=self.inequality_constraints, equality_constraints=equality_constraints, fixed_features=self.fixed_features, rounding_func=self.rounding_func, + # pyrefly: ignore [bad-argument-type] optimizer_options=self.optimizer_options, ) mock_optimize_acqf.assert_called_with( @@ -2421,10 +2466,12 @@ def test_optimize_mixed_with_equality_constraints(self) -> None: """Test that equality_constraints are forwarded to optimize_acqf_mixed.""" ssd = SearchSpaceDigest( feature_names=["a", "b"], + # pyrefly: ignore [bad-argument-type] bounds=[(0, 1), (0, 2)], categorical_features=[1], discrete_choices={1: [0, 1, 2]}, ) + # pyrefly: ignore [bad-argument-type] acquisition = self.get_acquisition_function() equality_constraints = [ ( @@ -2445,10 +2492,12 @@ def test_optimize_mixed_with_equality_constraints(self) -> None: acquisition.optimize( n=3, search_space_digest=ssd, + # pyrefly: ignore [bad-argument-type] inequality_constraints=self.inequality_constraints, equality_constraints=equality_constraints, fixed_features=None, rounding_func=self.rounding_func, + # pyrefly: ignore [bad-argument-type] optimizer_options=self.optimizer_options, ) mock_optimize_acqf_mixed.assert_called_with( @@ -2468,6 +2517,7 @@ def test_optimize_acqf_mixed_alternating_with_equality_constraints( self, ) -> None: """Test equality_constraints forwarded to optimize_acqf_mixed_alternating.""" + # pyrefly: ignore [bad-argument-type] ssd = SearchSpaceDigest( feature_names=["a", "b", "c"], bounds=[(0, 1), (0, 15), (0, 5)], @@ -2494,6 +2544,7 @@ def test_optimize_acqf_mixed_alternating_with_equality_constraints( acquisition.optimize( n=3, search_space_digest=ssd, + # pyrefly: ignore [bad-argument-type] inequality_constraints=self.inequality_constraints, equality_constraints=equality_constraints, fixed_features={0: 0.5}, @@ -2648,6 +2699,7 @@ def test_init_with_subset_model_false( mock_get_objective_and_transform.return_value = (botorch_objective, None) mock_get_X.return_value = (self.pending_observations[0], self.X[:1]) self.options[Keys.SUBSET_MODEL] = False + # pyrefly: ignore [bad-argument-type] with mock.patch( f"{ACQUISITION_PATH}.get_outcome_constraint_transforms", return_value=self.constraints, @@ -2675,6 +2727,7 @@ def test_init_with_subset_model_false( self.assertEqual(self.mock_input_constructor.call_count, 2) for call, (_, botorch_acqf_options) in zip( self.mock_input_constructor.call_args_list, + # pyrefly: ignore [bad-argument-type] self.botorch_acqf_classes_with_options, ): ckwargs = call.kwargs @@ -2700,6 +2753,7 @@ def test_optimize(self) -> None: acquisition = self.get_acquisition_function(fixed_features=self.fixed_features) n = 5 # Use more generations and larger population to reliably find feasible + # pyrefly: ignore [bad-argument-type] # candidates that satisfy the inequality constraint optimizer_options = {"max_gen": 10, "population_size": 50} # Mock candidates that satisfy constraints: @@ -2728,6 +2782,7 @@ def test_optimize(self) -> None: acquisition.optimize( n=n, search_space_digest=self.search_space_digest, + # pyrefly: ignore [bad-argument-type] inequality_constraints=self.inequality_constraints, fixed_features=self.fixed_features, rounding_func=self.rounding_func, @@ -2763,6 +2818,7 @@ def test_optimize(self) -> None: @skip_if_import_error def test_optimize_with_nsgaii_features(self) -> None: + # pyrefly: ignore [bad-argument-type] """Test that optimize_with_nsgaii correctly handles all features. This tests that candidates generated by optimize_with_nsgaii: @@ -2793,6 +2849,7 @@ def rounding_func(X: Tensor) -> Tensor: candidates, _, _ = acquisition.optimize( n=n, search_space_digest=discrete_search_space_digest, + # pyrefly: ignore [bad-argument-type] inequality_constraints=self.inequality_constraints, fixed_features=self.fixed_features, rounding_func=rounding_func, @@ -2820,6 +2877,7 @@ def rounding_func(X: Tensor) -> Tensor: ) # 3. Verify discrete choices: dimension 0 should only have allowed values + # pyrefly: ignore [bad-argument-type] allowed_values = torch.tensor( discrete_search_space_digest.discrete_choices[0], **self.tkwargs ) @@ -2850,6 +2908,7 @@ def test_optimize_nsgaii_raises_with_equality_constraints(self) -> None: acquisition.optimize( n=3, search_space_digest=self.search_space_digest, + # pyrefly: ignore [bad-argument-type] inequality_constraints=self.inequality_constraints, equality_constraints=equality_constraints, fixed_features=self.fixed_features, diff --git a/ax/generators/torch/tests/test_generator.py b/ax/generators/torch/tests/test_generator.py index b91b98bd0df..05ef349de65 100644 --- a/ax/generators/torch/tests/test_generator.py +++ b/ax/generators/torch/tests/test_generator.py @@ -183,13 +183,16 @@ def setUp(self) -> None: linear_constraints=self.linear_constraints, fixed_features=self.fixed_features, pending_observations=self.pending_observations, + # pyrefly: ignore [bad-argument-type] model_gen_options=self.model_gen_options, ) self.moo_torch_opt_config = dataclasses.replace( self.torch_opt_config, objective_weights=self.moo_objective_weights, objective_thresholds=self.moo_objective_thresholds, + # pyrefly: ignore [bad-argument-type] outcome_constraints=self.moo_outcome_constraints, + # pyrefly: ignore [bad-argument-type] model_gen_options={ Keys.OPTIMIZER_KWARGS: self.optimizer_options, Keys.ACQF_KWARGS: {"eta": 3.0}, @@ -814,8 +817,10 @@ def test_feature_importances(self) -> None: ), ) ) + # pyrefly: ignore [bad-argument-type] self.assertEqual(importances.shape, (2, 1, 3)) # Add model we don't support + # pyrefly: ignore [bad-argument-type] vanilla_model.covar_module = None model.surrogate._model = vanilla_model # pyre-ignore with self.assertRaisesRegex( @@ -1101,9 +1106,11 @@ def test_MOO(self) -> None: torch.cat([ds.Y for ds in self.moo_training_data], dim=-1), ) ) + # pyrefly: ignore [bad-argument-type] self.assertTrue( torch.equal( training_data.Yvar, + # pyrefly: ignore [bad-argument-type] torch.cat([ds.Yvar for ds in self.moo_training_data], dim=-1), ) ) @@ -1119,10 +1126,12 @@ def test_MOO(self) -> None: # gen_metadata stores maximization-aligned thresholds. obj_t = gen_results.gen_metadata["objective_thresholds"] self.assertTrue(torch.equal(obj_t, self.moo_objective_thresholds)) + # pyrefly: ignore [bad-argument-type] self.assertIsInstance(ckwargs["objective"], WeightedMCMultiOutputObjective) self.assertTrue( torch.equal( + # pyrefly: ignore [bad-argument-type] ckwargs["objective"].weights, extract_objectives(self.moo_objective_weights)[1], ) @@ -1213,11 +1222,13 @@ def test_gen_multi_acquisition(self) -> None: botorch_acqf_classes_with_options = [ (PosteriorMean, {}), (qLogNoisyExpectedImprovement, self.botorch_acqf_options), + # pyrefly: ignore [bad-argument-type] ] surrogate = Surrogate() model = BoTorchGenerator( surrogate=surrogate, acquisition_class=MultiAcquisition, + # pyrefly: ignore [bad-argument-type] botorch_acqf_classes_with_options=botorch_acqf_classes_with_options, ) diff --git a/ax/generators/torch/tests/test_kernels.py b/ax/generators/torch/tests/test_kernels.py index 0aab3c5a6d0..310bc1ebea0 100644 --- a/ax/generators/torch/tests/test_kernels.py +++ b/ax/generators/torch/tests/test_kernels.py @@ -193,7 +193,10 @@ def test_default_kernel(self) -> None: self.assertIsNone(botorch_kernel.active_dims) else: self.assertEqual( - ax_kernel.active_dims.tolist(), botorch_kernel.active_dims.tolist() + # pyrefly: ignore [not-callable] + ax_kernel.active_dims.tolist(), + # pyrefly: ignore [not-callable] + botorch_kernel.active_dims.tolist(), ) def test_default_mle(self) -> None: @@ -205,7 +208,9 @@ def test_default_mle(self) -> None: batch_shape=torch.Size([3]), mle=True, ) + # pyrefly: ignore [not-callable] self.assertTrue((kernel.lengthscale == sqrt(2) / 10).all()) self.assertEqual(kernel.lengthscale.shape, torch.Size([3, 1, 2])) self.assertFalse(hasattr(kernel, "lengthscale_prior")) + # pyrefly: ignore [not-callable] self.assertEqual(kernel.active_dims.tolist(), active_dims) diff --git a/ax/generators/torch/tests/test_optimizer_argparse.py b/ax/generators/torch/tests/test_optimizer_argparse.py index 4036c9fb4df..49cd02274f5 100644 --- a/ax/generators/torch/tests/test_optimizer_argparse.py +++ b/ax/generators/torch/tests/test_optimizer_argparse.py @@ -114,6 +114,7 @@ def test_optimizer_options(self) -> None: expected_options = {k: v for k, v in default.items() if k != "options"} if "options" in default: expected_options["options"] = { + # pyrefly: ignore [invalid-argument] **default["options"], **inner_options, } diff --git a/ax/generators/torch/tests/test_surrogate.py b/ax/generators/torch/tests/test_surrogate.py index 80dfa36b178..859ecc08e33 100644 --- a/ax/generators/torch/tests/test_surrogate.py +++ b/ax/generators/torch/tests/test_surrogate.py @@ -280,6 +280,7 @@ def test__make_botorch_input_transform(self) -> None: dataset=dataset, ) transform = assert_is_instance(transform, Normalize) + # pyrefly: ignore [not-callable] self.assertEqual(transform.indices.tolist(), [0]) self.assertEqual(transform.bounds.tolist(), [[1.0], [5.0]]) @@ -435,7 +436,9 @@ def _get_surrogate( if use_outcome_transform: outcome_transform_classes: list[type[OutcomeTransform]] = [Standardize] outcome_transform_options = {"Standardize": {"m": n_outcomes}} + # pyrefly: ignore [bad-assignment] else: + # pyrefly: ignore [bad-assignment] outcome_transform_classes = None outcome_transform_options = {} @@ -529,41 +532,61 @@ def test_copy_options(self) -> None: refit=True, ) models = assert_is_instance(surrogate.model.models, ModuleList) + # pyrefly: ignore [missing-attribute] model1_old_lengtscale = ( + # pyrefly: ignore [missing-attribute] models[1].covar_module.base_kernel.lengthscale.detach().clone() ) # Change the lengthscales of one model and make sure the other isn't changed + # pyrefly: ignore [missing-attribute] models[0].covar_module.base_kernel.lengthscale += 1 self.assertAllClose( + # pyrefly: ignore [missing-attribute] model1_old_lengtscale, + # pyrefly: ignore [missing-attribute] models[1].covar_module.base_kernel.lengthscale, ) # Test the same thing with the likelihood noise constraint + # pyrefly: ignore [missing-attribute] models[0].likelihood.noise_covar.raw_noise_constraint.lower_bound.fill_(1e-4) self.assertEqual( - models[0].likelihood.noise_covar.raw_noise_constraint.lower_bound, 1e-4 + # pyrefly: ignore [missing-attribute] + models[0].likelihood.noise_covar.raw_noise_constraint.lower_bound, + 1e-4, + # pyrefly: ignore [missing-attribute] ) self.assertEqual( - models[1].likelihood.noise_covar.raw_noise_constraint.lower_bound, 1e-3 + # pyrefly: ignore [missing-attribute] + models[1].likelihood.noise_covar.raw_noise_constraint.lower_bound, + # pyrefly: ignore [missing-attribute] + 1e-3, ) # Check input transform # bounds will be taken from the search space digest + # pyrefly: ignore [missing-attribute] self.assertAllClose( + # pyrefly: ignore [missing-attribute] models[0].input_transform.offset, + # pyrefly: ignore [missing-attribute] torch.tensor([[0, 1, 2]], **self.tkwargs), ) self.assertAllClose( + # pyrefly: ignore [missing-attribute] models[1].input_transform.offset, torch.tensor([[0, 1, 2]], **self.tkwargs), ) # Check outcome transform self.assertAllClose( - models[0].outcome_transform.means, torch.tensor([[3.5]], **self.tkwargs) + # pyrefly: ignore [missing-attribute] + models[0].outcome_transform.means, + torch.tensor([[3.5]], **self.tkwargs), ) self.assertAllClose( - models[1].outcome_transform.means, torch.tensor([[7]], **self.tkwargs) + # pyrefly: ignore [missing-attribute] + models[1].outcome_transform.means, + torch.tensor([[7]], **self.tkwargs), ) def test_botorch_transforms(self) -> None: @@ -645,7 +668,9 @@ def test_dtype_and_device_properties(self) -> None: f"{SURROGATE_PATH}.submodel_input_constructor", wraps=submodel_input_constructor, ) + # pyrefly: ignore [bad-index] @patch(f"{SURROGATE_PATH}.fit_botorch_model", wraps=fit_botorch_model) + # pyrefly: ignore [bad-index] def test_fit_model_reuse(self, mock_fit: Mock, mock_constructor: Mock) -> None: surrogate, _ = self._get_surrogate( botorch_model_class=SingleTaskGP, use_outcome_transform=False @@ -658,10 +683,13 @@ def test_fit_model_reuse(self, mock_fit: Mock, mock_constructor: Mock) -> None: datasets=self.training_data, search_space_digest=search_space_digest, ) + # pyrefly: ignore [bad-index] mock_fit.assert_called_once() mock_constructor.assert_called_once() key = tuple(self.training_data[0].outcome_names) + # pyrefly: ignore [bad-index] submodel = surrogate._submodels[key] + # pyrefly: ignore [bad-index] self.assertIs(surrogate._last_datasets[key], self.training_data[0]) self.assertIs(surrogate._last_search_space_digest, search_space_digest) @@ -674,6 +702,7 @@ def test_fit_model_reuse(self, mock_fit: Mock, mock_constructor: Mock) -> None: mock_fit.assert_called_once() mock_constructor.assert_called_once() # Model is still the same object. + # pyrefly: ignore [bad-index] self.assertIs(submodel, surrogate._submodels[key]) # Change the search space digest. @@ -692,6 +721,7 @@ def test_fit_model_reuse(self, mock_fit: Mock, mock_constructor: Mock) -> None: self.assertIn( "Discarding all previously trained models", mock_log.call_args[0][0] ) + # pyrefly: ignore [bad-index] self.assertIsNot(submodel, surrogate._submodels[key]) self.assertIs(surrogate._last_search_space_digest, search_space_digest) @@ -738,7 +768,9 @@ def test_construct_model(self) -> None: len(call_kwargs), 6 if botorch_model_class is SaasFullyBayesianSingleTaskGP else 4, ) + # pyrefly: ignore [unsupported-operation] + # pyrefly: ignore [unsupported-operation] mock_construct_inputs.assert_called_with( training_data=self.training_data[0], ) @@ -758,7 +790,9 @@ def test_construct_model(self) -> None: # Cache the model & dataset as we would in `Surrogate.fit``. outcomes = self.training_data[0].outcome_names key = tuple(outcomes) + # pyrefly: ignore [unsupported-operation] surrogate._submodels[key] = model + # pyrefly: ignore [unsupported-operation] surrogate._last_datasets[key] = self.training_data[0] surrogate.metric_to_best_model_config[outcomes[0]] = ( surrogate.surrogate_spec.model_configs[0] @@ -987,6 +1021,7 @@ def test_construct_model_remove_task_features(self) -> None: patch.object( botorch_model_class, "__init__", return_value=None, autospec=True ) as mock_init, + # pyrefly: ignore [bad-argument-type] patch(f"{SURROGATE_PATH}.fit_botorch_model") as mock_fit, ): surrogate._construct_model( @@ -1009,6 +1044,7 @@ def test_construct_model_remove_task_features(self) -> None: # check that active_dims is set to omit task feature if remove_task_features: self.assertTrue( + # pyrefly: ignore [bad-argument-type] torch.equal(covar_module.active_dims, torch.tensor([0, 1, 2])) ) else: @@ -1032,6 +1068,7 @@ def test_fit_multiple_model_configs( self, mock_diag_dict: Mock, mock_in_sample_metric: Mock ) -> None: # These mocks are used because we control which kernel is selected by + # pyrefly: ignore [unbound-name] # changing the values of model diagnostics side_effect_dict = { "MSE": [0.2, 0.1], @@ -1055,6 +1092,7 @@ def test_fit_multiple_model_configs( elif eval_criterion == "Log likelihood": mock_diag_fn = mock_ll + # pyrefly: ignore [unbound-name] mock_diag_fn.side_effect = side_effect_dict[eval_criterion] mock_diag_fn.reset_mock() mock_in_sample_metric.reset_mock() @@ -1152,6 +1190,7 @@ def test_fit_multiple_model_configs( warnings.filterwarnings("always") surrogate.fit( [dataset], + # pyrefly: ignore [bad-index] search_space_digest=search_space_digest, ) @@ -1176,7 +1215,7 @@ def test_fit_multiple_model_configs( covar_module = ( model.covar_module if not isinstance(model, MultiTaskGP) - else model.covar_module.kernels[0] + else model.covar_module.kernels[0] # pyrefly: ignore [bad-index] ) self.assertIsInstance( @@ -2249,16 +2288,21 @@ def test_with_botorch_transforms(self) -> None: model_configs=[ ModelConfig( botorch_model_class=SingleTaskGP, + # pyrefly: ignore [missing-attribute] mll_class=ExactMarginalLogLikelihood, + # pyrefly: ignore [missing-attribute] input_transform_classes=[Normalize], input_transform_options={ + # pyrefly: ignore [missing-attribute] "Normalize": {"d": 3, "bounds": None, "indices": None} }, outcome_transform_classes=[Standardize], + # pyrefly: ignore [missing-attribute] outcome_transform_options={"Standardize": {"m": 1}}, ) ] ) + # pyrefly: ignore [missing-attribute] ) surrogate.fit( datasets=self.supervised_training_data, @@ -2273,16 +2317,23 @@ def test_with_botorch_transforms(self) -> None: for i in range(2): self.assertIsInstance(models[i].outcome_transform, Standardize) self.assertIsInstance(models[i].input_transform, Normalize) + # pyrefly: ignore [missing-attribute] self.assertEqual(models[0].outcome_transform.means.item(), 4.5) + # pyrefly: ignore [missing-attribute] self.assertEqual(models[1].outcome_transform.means.item(), 3.5) self.assertAlmostEqual( - models[0].outcome_transform.stdvs.item(), 1 / math.sqrt(2) + # pyrefly: ignore [missing-attribute] + models[0].outcome_transform.stdvs.item(), + 1 / math.sqrt(2), ) self.assertAlmostEqual( - models[1].outcome_transform.stdvs.item(), 1 / math.sqrt(2) + # pyrefly: ignore [missing-attribute] + models[1].outcome_transform.stdvs.item(), + 1 / math.sqrt(2), ) self.assertTrue( torch.allclose( + # pyrefly: ignore [missing-attribute] models[0].input_transform.bounds, models[1].input_transform.bounds + 1.0, # pyre-ignore ) diff --git a/ax/generators/torch/tests/test_utils.py b/ax/generators/torch/tests/test_utils.py index 8126998aa73..43a6671d5dd 100644 --- a/ax/generators/torch/tests/test_utils.py +++ b/ax/generators/torch/tests/test_utils.py @@ -787,6 +787,7 @@ def test_construct_acquisition_and_optimizer_options(self) -> None: construct_acquisition_and_optimizer_options( acqf_options=acqf_options, botorch_acqf_options=botorch_acqf_options, + # pyrefly: ignore [bad-argument-type] model_gen_options={**model_gen_options, "extra": "key"}, ) diff --git a/ax/generators/utils.py b/ax/generators/utils.py index f23ea25a3fc..b2baacece00 100644 --- a/ax/generators/utils.py +++ b/ax/generators/utils.py @@ -121,9 +121,13 @@ def rejection_sample( # _gen_unconstrained returns points including fixed features. # pyre-ignore[28]: Unexpected keyword argument to anonymous call. point = gen_unconstrained( + # pyrefly: ignore [bad-argument-count, unexpected-keyword] n=1, + # pyrefly: ignore [unexpected-keyword] d=d, + # pyrefly: ignore [unexpected-keyword] tunable_feature_indices=tunable_feature_indices, + # pyrefly: ignore [unexpected-keyword] fixed_features=fixed_features, )[0] @@ -378,10 +382,12 @@ def best_observed_point( bounds=bounds, objective_weights=objective_weights, outcome_constraints=outcome_constraints, + # pyrefly: ignore [bad-return] linear_constraints=linear_constraints, fixed_features=fixed_features, options=options, ) + # pyrefly: ignore [bad-return] return None if best_point_and_value is None else best_point_and_value[0] @@ -395,6 +401,7 @@ def best_in_sample_point( fixed_features: dict[int, float] | None = None, options: TConfig | None = None, ) -> tuple[TTensoray, float] | None: + # pyrefly: ignore [bad-assignment] """Select the best point that has been observed. Implements two approaches to selecting the best point. @@ -446,12 +453,18 @@ def best_in_sample_point( - d-array of the best point, - utility at the best point. """ + # pyrefly: ignore [bad-assignment] # Parse options + # pyrefly: ignore [bad-assignment] if options is None: options = {} + # pyrefly: ignore [bad-assignment] method: str = options.get("best_point_method", "max_utility") + # pyrefly: ignore [bad-assignment] B: float | None = options.get("utility_baseline", None) + # pyrefly: ignore [bad-assignment] threshold: float = options.get("probability_threshold", 0.95) + # pyrefly: ignore [bad-assignment] nsamp: int = options.get("feasibility_mc_samples", 10000) # Get points observed for all objective and constraint outcomes if objective_weights is None: @@ -469,6 +482,7 @@ def best_in_sample_point( X=X_obs, bounds=bounds, linear_constraints=linear_constraints, + # pyrefly: ignore [bad-argument-type] fixed_features=fixed_features, ) if len(X_obs) == 0: @@ -478,6 +492,7 @@ def best_in_sample_point( if isinstance(Xs[0], torch.Tensor): X_obs = assert_is_instance(X_obs, torch.Tensor).detach().clone() # (n_feasible x n_outcomes), (n_feasible x n_outcomes x n_outcomes) + # pyrefly: ignore [bad-argument-type] f, cov = as_array(model.predict(X_obs)) # (n_outcomes,) x (n_outcomes, n_feasible) => (n_feasible,) obj = objective_weights_np @ f.transpose() @@ -498,6 +513,7 @@ def best_in_sample_point( if method == "feasible_threshold": utility = obj utility[pfeas < threshold] = -np.inf + # pyrefly: ignore [bad-return] elif method == "max_utility": if B is None: B = obj.min() @@ -508,6 +524,7 @@ def best_in_sample_point( if utility[i] == -np.inf: return None else: + # pyrefly: ignore [bad-return] return X_obs[i, :], utility[i] diff --git a/ax/metrics/map_replay.py b/ax/metrics/map_replay.py index a05e36e9266..b7d238184de 100644 --- a/ax/metrics/map_replay.py +++ b/ax/metrics/map_replay.py @@ -56,6 +56,7 @@ def __init__( ) # Pre-group by trial_index for O(1) trial lookups instead of O(n) filtering self._trial_groups: dict[int, pd.DataFrame] = { + # pyrefly: ignore [bad-argument-type] int(trial_idx): group for trial_idx, group in self._replay_df.groupby("trial_index") } diff --git a/ax/metrics/noisy_function.py b/ax/metrics/noisy_function.py index 37cd1cf0cf3..3aec14e4478 100644 --- a/ax/metrics/noisy_function.py +++ b/ax/metrics/noisy_function.py @@ -137,6 +137,7 @@ def __init__( Metric.__init__(self, name=name, lower_is_better=lower_is_better) @property + # pyrefly: ignore [bad-override] def param_names(self) -> list[str]: raise NotImplementedError( "GenericNoisyFunctionMetric does not implement a param_names attribute" diff --git a/ax/metrics/tensorboard.py b/ax/metrics/tensorboard.py index ac9d2c1b7a4..b457b7d16d8 100644 --- a/ax/metrics/tensorboard.py +++ b/ax/metrics/tensorboard.py @@ -208,6 +208,7 @@ def bulk_fetch_trial_data( res[metric.signature] = Ok(Data(df=df)) except Exception as e: + # pyrefly: ignore [unsupported-operation] res[metric.signature] = Err( MetricFetchE( message=f"Failed to fetch data for {metric.name}", @@ -217,6 +218,7 @@ def bulk_fetch_trial_data( self._clear_multiplexer_if_possible(multiplexer=mul) + # pyrefly: ignore [bad-return] return res def fetch_trial_data( diff --git a/ax/metrics/tests/test_noisy_function.py b/ax/metrics/tests/test_noisy_function.py index d5817c71b90..321ef567018 100644 --- a/ax/metrics/tests/test_noisy_function.py +++ b/ax/metrics/tests/test_noisy_function.py @@ -18,6 +18,7 @@ class GenericNoisyFunctionMetricTest(TestCase): def test_GenericNoisyFunctionMetric(self) -> None: def f(params: dict[str, TParamValue]) -> float: + # pyrefly: ignore [bad-argument-type] return float(params["x"]) + 1.0 # noiseless @@ -31,6 +32,7 @@ def f(params: dict[str, TParamValue]) -> float: self.assertEqual(df["metric_name"].tolist(), ["test_metric"]) self.assertEqual( df["mean"].tolist(), + # pyrefly: ignore [bad-argument-type] [float(none_throws(trial.arm).parameters["x"]) + 1.0], ) self.assertEqual(df["sem"].tolist(), [0.0]) @@ -47,6 +49,7 @@ def f(params: dict[str, TParamValue]) -> float: self.assertEqual(df["metric_name"].tolist(), ["test_metric"]) self.assertNotEqual( df["mean"].tolist(), + # pyrefly: ignore [bad-argument-type] [float(none_throws(trial.arm).parameters["x"]) + 1.0], ) self.assertEqual(df["sem"].tolist(), [1.0]) @@ -54,6 +57,7 @@ def f(params: dict[str, TParamValue]) -> float: self.assertEqual(df["arm_name"].tolist(), ["0_0"]) self.assertEqual(df["metric_name"].tolist(), ["test_metric"]) arm = none_throws(trial.arm) + # pyrefly: ignore [bad-argument-type] self.assertEqual(df["mean"].tolist(), [float(arm.parameters["x"]) + 1.0]) self.assertEqual(df["sem"].tolist(), [0.0]) @@ -68,6 +72,8 @@ def f(params: dict[str, TParamValue]) -> float: self.assertEqual(df["arm_name"].tolist(), ["0_0"]) self.assertEqual(df["metric_name"].tolist(), ["test_metric"]) arm = none_throws(trial.arm) + # pyrefly: ignore [bad-argument-type] self.assertEqual(df["mean"].tolist(), [float(arm.parameters["x"]) + 1.0]) + # pyrefly: ignore [bad-argument-type] self.assertEqual(df["mean"].tolist(), [float(arm.parameters["x"]) + 1.0]) self.assertTrue(math.isnan(df["sem"].tolist()[0])) diff --git a/ax/orchestration/orchestrator.py b/ax/orchestration/orchestrator.py index 046d17b059b..831908a47ee 100644 --- a/ax/orchestration/orchestrator.py +++ b/ax/orchestration/orchestrator.py @@ -777,6 +777,7 @@ def run_trials_and_yield_results( # -------- II. Methods that are typically called within the `Orchestrator`. ------- + # pyrefly: ignore [not-callable] @retry_on_exception(retries=3, no_retry_on_exception_types=NO_RETRY_EXCEPTIONS) def run_trials( self, @@ -842,6 +843,9 @@ def run_trials( self._log_next_no_trials_reason = True return metadata + # pyrefly: ignore [not-callable] + + # pyrefly: ignore [not-callable] @retry_on_exception(retries=3, no_retry_on_exception_types=NO_RETRY_EXCEPTIONS) def poll_trial_status( self, poll_all_trial_statuses: bool = False @@ -1835,8 +1839,10 @@ def _get_next_trials( ttl_seconds=self.options.ttl_seconds_for_trials, trial_type=self.trial_type, ) + # pyrefly: ignore [bad-return] trials.append(trial) + # pyrefly: ignore [bad-return] return trials, None def _gen_new_trials_from_generation_strategy( diff --git a/ax/orchestration/tests/test_orchestrator.py b/ax/orchestration/tests/test_orchestrator.py index 652315eea60..0a29ff034c0 100644 --- a/ax/orchestration/tests/test_orchestrator.py +++ b/ax/orchestration/tests/test_orchestrator.py @@ -134,6 +134,7 @@ class TestAxOrchestrator(TestCase): PENDING_FEATURES_EXTRACTOR: tuple[ # pyre-ignore[8] str, Callable[ + # pyrefly: ignore [invalid-argument] [...], dict[str, list[ObservationFeatures]] | None, ], @@ -145,6 +146,7 @@ class TestAxOrchestrator(TestCase): PENDING_FEATURES_BATCH_EXTRACTOR: tuple[ # pyre-ignore[8] str, Callable[ + # pyrefly: ignore [invalid-argument] [...], dict[str, list[ObservationFeatures]] | None, ], @@ -345,6 +347,7 @@ def test_init_with_no_impl_with_runner(self) -> None: ): Orchestrator( experiment=self.branin_experiment_no_impl_runner_or_metrics, + # pyrefly: ignore [bad-argument-type] generation_strategy=generation_strategy, options=OrchestratorOptions( total_trials=10, **self.orchestrator_options_kwargs @@ -359,6 +362,7 @@ def test_init_with_no_impl_with_runner(self) -> None: ): Orchestrator( experiment=self.branin_experiment_no_impl_runner_or_metrics, + # pyrefly: ignore [bad-argument-type] generation_strategy=generation_strategy, options=OrchestratorOptions( total_trials=10, **self.orchestrator_options_kwargs @@ -387,6 +391,7 @@ def test_init_with_branin_experiment(self) -> None: self.assertIsNone(orchestrator._latest_optimization_start_timestamp) orchestrator.run_all_trials() # Runs no trials since total trials is 0. # `_latest_optimization_start_timestamp` should be set now. + # pyrefly: ignore [no-matching-overload] self.assertLessEqual( orchestrator._latest_optimization_start_timestamp, # pyre-fixme[6]: For 2nd param expected `SupportsDunderGT[Variable[_T]]` @@ -555,7 +560,9 @@ def test_run_n_trials_callback(self) -> None: # pyre-fixme[53]: Captured variable `test_obj` is not annotated. def _callback(orchestrator: Orchestrator) -> None: + # pyrefly: ignore [unsupported-operation] test_obj[0] = orchestrator._latest_optimization_start_timestamp + # pyrefly: ignore [unsupported-operation] test_obj[1] = "apple" return @@ -588,6 +595,7 @@ def test_run_n_trials_single_step_existing_experiment( experiment=self.branin_experiment, # Has runner and metrics. generation_strategy=gs, options=OrchestratorOptions( + # pyrefly: ignore [bad-argument-type] init_seconds_between_polls=0.1, # Short between polls so test is fast. wait_for_running_trials=False, enforce_immutable_search_space_and_opt_config=False, @@ -1665,6 +1673,7 @@ def test_batch_trial(self, status_quo_weight: float = 0.0) -> None: ) if status_quo_weight > 0: self.assertEqual( + # pyrefly: ignore [bad-index] trial.arm_weights[self.branin_experiment.status_quo], 1.0, ) @@ -2470,6 +2479,7 @@ def test_it_works_with_multitask_models( generation_strategy=gs, options=OrchestratorOptions( total_trials=3, + # pyrefly: ignore [bad-argument-type] init_seconds_between_polls=0.1, # Short between polls so test is fast. **self.orchestrator_options_kwargs, ), @@ -2479,6 +2489,7 @@ def test_it_works_with_multitask_models( # for the MTGP step. with patch( "ax.adapter.random.RandomAdapter.gen", + # pyrefly: ignore [bad-argument-type] return_value=GeneratorRun(arms=[experiment.status_quo]), ): orchestrator.run_n_trials(max_trials=3) @@ -3153,64 +3164,77 @@ def setUp(self) -> None: self._mock_orchestrator_poll_sleep() def test_init_with_no_impl_with_runner(self) -> None: + # pyrefly: ignore [missing-attribute] self.branin_experiment_no_impl_runner_or_metrics.update_runner( trial_type="type1", runner=self.runner ) super().test_init_with_no_impl_with_runner() def test_update_options_with_validate_metrics(self) -> None: + # pyrefly: ignore [missing-attribute] self.branin_experiment_no_impl_runner_or_metrics.update_runner( trial_type="type1", runner=self.runner ) super().test_update_options_with_validate_metrics() def test_retries(self) -> None: + # pyrefly: ignore [missing-attribute] self.branin_experiment.update_runner("type1", BrokenRunnerRuntimeError()) super().test_retries() def test_retries_nonretriable_error(self) -> None: + # pyrefly: ignore [missing-attribute] self.branin_experiment.update_runner("type1", BrokenRunnerValueError()) super().test_retries_nonretriable_error() def test_failure_rate_some_failed(self) -> None: + # pyrefly: ignore [missing-attribute] self.branin_experiment.update_runner("type1", RunnerWithFrequentFailedTrials()) super().test_failure_rate_some_failed() def test_failure_rate_all_failed(self) -> None: + # pyrefly: ignore [missing-attribute] self.branin_experiment.update_runner("type1", RunnerWithAllFailedTrials()) super().test_failure_rate_all_failed() def test_run_trials_and_yield_results_with_early_stopper(self) -> None: + # pyrefly: ignore [missing-attribute] self.branin_experiment.update_runner("type1", InfinitePollRunner()) super().test_run_trials_and_yield_results_with_early_stopper() def test_orchestrator_with_metric_with_new_data_after_completion(self) -> None: + # pyrefly: ignore [missing-attribute] self.branin_experiment.update_runner( "type1", SyntheticRunnerWithPredictableStatusPolling() ) super().test_orchestrator_with_metric_with_new_data_after_completion() def test_poll_and_process_results_with_reasons(self) -> None: + # pyrefly: ignore [missing-attribute] self.branin_experiment.update_runner( "type1", RunnerWithFailedAndAbandonedTrials() ) super().test_poll_and_process_results_with_reasons() def test_poll_trial_status_fallback_to_individual_polling(self) -> None: + # pyrefly: ignore [missing-attribute] self.branin_experiment.update_runner( "type1", RunnerWithFailingPollTrialStatus() ) super().test_poll_trial_status_fallback_to_individual_polling() def test_poll_trial_status_abandons_trial_on_individual_failure(self) -> None: + # pyrefly: ignore [missing-attribute] self.branin_experiment.update_runner("type1", RunnerWithAllPollsFailing()) super().test_poll_trial_status_abandons_trial_on_individual_failure() def test_generate_candidates_works_for_iteration(self) -> None: + # pyrefly: ignore [missing-attribute] self.branin_experiment.update_runner("type1", InfinitePollRunner()) super().test_generate_candidates_works_for_iteration() def test_orchestrator_with_odd_index_early_stopping_strategy(self) -> None: + # pyrefly: ignore [missing-attribute] self.branin_timestamp_map_metric_experiment.update_runner( "type1", RunnerWithEarlyStoppingStrategy() ) @@ -3221,7 +3245,10 @@ def test_fetch_and_process_trials_data_results_failed_non_objective( ) -> None: # add a tracking metric self.branin_timestamp_map_metric_experiment.add_tracking_metric( - BraninMetric("branin", ["x1", "x2"]), trial_type="type1" + # pyrefly: ignore [unexpected-keyword] + BraninMetric("branin", ["x1", "x2"]), + # pyrefly: ignore [unexpected-keyword] + trial_type="type1", ) super().test_fetch_and_process_trials_data_results_failed_non_objective() @@ -3237,6 +3264,7 @@ def test_validate_options_not_none_mt_trial_type( def test_run_n_trials_single_step_existing_experiment( self, all_completed_trials: bool = False ) -> None: + # pyrefly: ignore [missing-attribute] self.branin_experiment.update_runner( "type1", SyntheticRunnerWithSingleRunningTrial() ) @@ -3249,12 +3277,17 @@ def test_run_n_trials_single_step_existing_experiment( self.assertEqual(metric_names, ["m1"]) def test_generate_candidates_does_not_generate_if_missing_data(self) -> None: + # pyrefly: ignore [missing-attribute] self.branin_experiment.update_runner("type1", InfinitePollRunner()) super().test_generate_candidates_does_not_generate_if_missing_data() def test_generate_candidates_does_not_generate_if_missing_opt_config(self) -> None: + # pyrefly: ignore [missing-attribute] self.branin_experiment.update_runner("type1", InfinitePollRunner()) self.branin_experiment.add_tracking_metric( - get_branin_metric(), trial_type="type1" + # pyrefly: ignore [unexpected-keyword] + get_branin_metric(), + # pyrefly: ignore [unexpected-keyword] + trial_type="type1", ) super().test_generate_candidates_does_not_generate_if_missing_opt_config() diff --git a/ax/plot/contour.py b/ax/plot/contour.py index ecec3d1f2b8..727bd5f0b57 100644 --- a/ax/plot/contour.py +++ b/ax/plot/contour.py @@ -492,6 +492,7 @@ def interact_contour_plotly( "ticksuffix": "%" if rel else "", "tickfont": {"size": 8}, }, + # pyrefly: ignore [bad-argument-type] "colorscale": [(i / (len(f_scale) - 1), rgb(v)) for i, v in enumerate(f_scale)], "xaxis": "x", "yaxis": "y", @@ -679,7 +680,9 @@ def interact_contour_plotly( for yvar_idx, yvar in enumerate(param_names): cur_visible = yvar_idx == 1 + # pyrefly: ignore [bad-index] f_start = xbuttons[0]["args"][0]["z"][trace_cnt * yvar_idx] + # pyrefly: ignore [bad-index] sd_start = xbuttons[0]["args"][0]["z"][trace_cnt * yvar_idx + 1] # create traces @@ -731,7 +734,9 @@ def interact_contour_plotly( } for key in base_in_sample_arm_config.keys(): + # pyrefly: ignore [bad-assignment] f_in_sample_arm_trace[key] = base_in_sample_arm_config[key] + # pyrefly: ignore [unsupported-operation] sd_in_sample_arm_trace[key] = base_in_sample_arm_config[key] traces += [f_trace, sd_trace, f_in_sample_arm_trace, sd_in_sample_arm_trace] diff --git a/ax/plot/diagnostic.py b/ax/plot/diagnostic.py index 9b6ea9940a3..0338ff06228 100644 --- a/ax/plot/diagnostic.py +++ b/ax/plot/diagnostic.py @@ -147,6 +147,7 @@ def _obs_vs_pred_dropdown_plot( ) ) else: + # pyrefly: ignore [bad-argument-type] layout_axis_range.append(None) traces.append(_diagonal_trace(min_, max_, visible=(i == 0))) diff --git a/ax/plot/feature_importances.py b/ax/plot/feature_importances.py index 53680dfce26..de8b64bb54e 100644 --- a/ax/plot/feature_importances.py +++ b/ax/plot/feature_importances.py @@ -53,6 +53,7 @@ def plot_feature_importance_by_feature_plotly( model._experiment.signature_to_metric[signature].name for signature in model.metric_signatures ] + # pyrefly: ignore [bad-assignment] sensitivity_values = { metric_name: model.feature_importances(metric_name) for i, metric_name in enumerate(sorted(metric_names)) @@ -65,6 +66,7 @@ def plot_feature_importance_by_feature_plotly( if label_dict is not None: sensitivity_values = { # pyre-ignore label_dict.get(metric_name, metric_name): v + # pyrefly: ignore [missing-attribute] for metric_name, v in sensitivity_values.items() } traces = [] @@ -77,7 +79,9 @@ def plot_feature_importance_by_feature_plotly( if isinstance(par, ChoiceParameter) and not par.is_ordered ] + # pyrefly: ignore [missing-attribute] for i, metric_name in enumerate(sorted(sensitivity_values.keys())): + # pyrefly: ignore [unsupported-operation] importances = sensitivity_values[metric_name] factor_col = "Factor" importance_col = "Importance" @@ -154,6 +158,7 @@ def plot_feature_importance_by_feature_plotly( ) legend_counter[row[sign_col]] += 1 + # pyrefly: ignore [bad-argument-type] is_visible = [False] * (len(sensitivity_values) * len(df)) for j in range(i * len(df), (i + 1) * len(df)): is_visible[j] = True @@ -175,9 +180,11 @@ def plot_feature_importance_by_feature_plotly( }, # hack to put dropdown below title regardless of number of features } ] + # pyrefly: ignore [missing-attribute] features = list(list(sensitivity_values.values())[0].keys()) longest_label = max(len(f) for f in features) + # pyrefly: ignore [missing-attribute] longest_metric = max(len(m) for m in sensitivity_values.keys()) layout = go.Layout( diff --git a/ax/plot/helper.py b/ax/plot/helper.py index 2ab93d24064..c7307b6dd5b 100644 --- a/ax/plot/helper.py +++ b/ax/plot/helper.py @@ -252,7 +252,9 @@ def _get_in_sample_arms( cov_dict=obs.data.covariance_matrix, agg_metric_weight_dict=agg_metric["weight"], ) + # pyrefly: ignore [unsupported-operation] obs_y[agg_metric_name] = agg_mean + # pyrefly: ignore [unsupported-operation] obs_se[agg_metric_name] = np.sqrt(agg_var) if training_in_design[i]: # Update with the input fixed features @@ -281,6 +283,7 @@ def _get_in_sample_arms( se_hat=pred_se, context_stratum=None, ) + # pyrefly: ignore [bad-return] return in_sample_plot, raw_data, arm_name_to_parameters @@ -596,6 +599,7 @@ def contour_config_to_trace(config) -> list[dict[str, Any]]: "ticksuffix": "%" if rel else "", "tickfont": {"size": 8}, }, + # pyrefly: ignore [bad-argument-type] "colorscale": [(i / (len(f_scale) - 1), rgb(v)) for i, v in enumerate(f_scale)], "xaxis": "x", "yaxis": "y", @@ -622,6 +626,7 @@ def contour_config_to_trace(config) -> list[dict[str, Any]]: } f_trace.update(CONTOUR_CONFIG) + # pyrefly: ignore [no-matching-overload] sd_trace.update(CONTOUR_CONFIG) # get in-sample arms diff --git a/ax/plot/pareto_frontier.py b/ax/plot/pareto_frontier.py index 5d11638f4e5..fb635d99cac 100644 --- a/ax/plot/pareto_frontier.py +++ b/ax/plot/pareto_frontier.py @@ -155,6 +155,7 @@ def scatter_plot_with_pareto_frontier_plotly( y=[reference_point[1]], mode="markers", marker={ + # pyrefly: ignore [bad-argument-type] "color": rgba(COLORS.STEELBLUE.value), "size": 25, "symbol": "star", @@ -166,6 +167,7 @@ def scatter_plot_with_pareto_frontier_plotly( x=[extra_point_x, reference_point[0]], y=[reference_point[1], reference_point[1]], mode="lines", + # pyrefly: ignore [bad-argument-type] marker={"color": rgba(COLORS.STEELBLUE.value)}, ) extra_point_y = min(Y_pareto[:, 1]) if minimize[1] else max(Y_pareto[:, 1]) @@ -173,6 +175,7 @@ def scatter_plot_with_pareto_frontier_plotly( x=[reference_point[0], reference_point[0]], y=[extra_point_y, reference_point[1]], mode="lines", + # pyrefly: ignore [bad-argument-type] marker={"color": rgba(COLORS.STEELBLUE.value)}, ) reference_point_lines = [reference_point_line_1, reference_point_line_2] @@ -190,6 +193,7 @@ def scatter_plot_with_pareto_frontier_plotly( y=Y_pareto_with_extra[:, 1], mode="lines", line_shape="hv", + # pyrefly: ignore [bad-argument-type] marker={"color": rgba(COLORS.STEELBLUE.value)}, ) ] @@ -211,6 +215,7 @@ def scatter_plot_with_pareto_frontier_plotly( y=Y_pareto[:, 1], mode="lines", line_shape="hv", + # pyrefly: ignore [bad-argument-type] marker={"color": rgba(COLORS.STEELBLUE.value)}, ) ] @@ -237,6 +242,7 @@ def _get_single_pareto_trace( frontier: ParetoFrontierResults, CI_level: float, legend_label: str = "mean", + # pyrefly: ignore [bad-function-definition] trace_color: tuple[int] = COLORS.STEELBLUE.value, show_parameterization_on_hover: bool = True, ) -> go.Scatter: @@ -286,17 +292,20 @@ def _get_single_pareto_trace( "type": "data", "array": Z * np.array(secondary_sems), "thickness": 2, + # pyrefly: ignore [bad-argument-type] "color": rgba(trace_color, CI_OPACITY), }, error_y={ "type": "data", "array": Z * np.array(primary_sems), "thickness": 2, + # pyrefly: ignore [bad-argument-type] "color": rgba(trace_color, CI_OPACITY), }, mode="markers", text=labels, hoverinfo="text", + # pyrefly: ignore [bad-argument-type] marker={"color": rgba(trace_color)}, ) @@ -348,6 +357,7 @@ def plot_pareto_frontier( "yref": "y", "y0": primary_threshold, "y1": primary_threshold, + # pyrefly: ignore [bad-argument-type] "line": {"color": rgba(COLORS.CORAL.value), "width": 3}, } ) @@ -361,6 +371,7 @@ def plot_pareto_frontier( "xref": "x", "x0": secondary_threshold, "x1": secondary_threshold, + # pyrefly: ignore [bad-argument-type] "line": {"color": rgba(COLORS.CORAL.value), "width": 3}, } ) @@ -584,6 +595,7 @@ def _validate_and_maybe_get_default_metric_names( multi_objective = assert_is_instance( none_throws(optimization_config).objective, MultiObjective ) + # pyrefly: ignore [bad-assignment] metric_names = tuple(multi_objective.metric_names) else: raise UserInputError( @@ -637,6 +649,7 @@ def _validate_experiment_and_maybe_get_objective_thresholds( "objective threshold for each metric. Returning an empty list." ) + # pyrefly: ignore [bad-return] return objective_thresholds @@ -646,10 +659,12 @@ def _validate_and_maybe_get_default_reference_point( metric_names: tuple[str, str], ) -> tuple[float, float] | None: if reference_point is None: + # pyrefly: ignore [bad-assignment] reference_point = { objective_threshold.metric_names[0]: objective_threshold.bound for objective_threshold in objective_thresholds } + # pyrefly: ignore [bad-argument-type] missing_metric_names = set(metric_names) - set(reference_point) if missing_metric_names: warnings.warn( @@ -660,7 +675,9 @@ def _validate_and_maybe_get_default_reference_point( ) return None reference_point = tuple( - reference_point[metric_name] for metric_name in metric_names + # pyrefly: ignore [bad-index, unsupported-operation] + reference_point[metric_name] + for metric_name in metric_names ) if len(reference_point) != 2: warnings.warn( @@ -681,6 +698,7 @@ def _validate_and_maybe_get_default_minimize( ) -> tuple[bool, bool] | None: if minimize is None: # Determine `minimize` defaults + # pyrefly: ignore [bad-assignment] minimize = tuple( _maybe_get_default_minimize_single_metric( metric_name=metric_name, @@ -690,6 +708,7 @@ def _validate_and_maybe_get_default_minimize( for metric_name in metric_names ) # If either value of minimize is missing, return `None` + # pyrefly: ignore [not-iterable] if any(i_min is None for i_min in minimize): warnings.warn( "Extraction of default `minimize` failed. Please specify `minimize` " @@ -697,12 +716,15 @@ def _validate_and_maybe_get_default_minimize( "includes 2 objectives. Returning None." ) return None + # pyrefly: ignore [bad-assignment, not-iterable] minimize = tuple(none_throws(i_min) for i_min in minimize) # If only one bool provided, use for both dimensions elif isinstance(minimize, bool): minimize = (minimize, minimize) + # pyrefly: ignore [bad-argument-type] if len(minimize) != 2: warnings.warn( + # pyrefly: ignore [bad-argument-type] f"Expected 2-dimensional `minimize` but got {len(minimize)} dimensions: " f"{minimize}. Please specify `minimize` of length 2 or provide an " "experiment whose `optimization_config` includes 2 objectives. Returning " @@ -710,6 +732,7 @@ def _validate_and_maybe_get_default_minimize( ) return None + # pyrefly: ignore [bad-return] return minimize diff --git a/ax/plot/scatter.py b/ax/plot/scatter.py index 76a3f826707..95d2c4cbd38 100644 --- a/ax/plot/scatter.py +++ b/ax/plot/scatter.py @@ -95,6 +95,7 @@ def _error_scatter_data( ) x = x_rel.tolist() x_se = x_se_rel.tolist() + # pyrefly: ignore [bad-return] return x, x_se, y, y_se @@ -243,15 +244,19 @@ def _error_scatter_trace( ) i += 1 + # pyrefly: ignore [bad-argument-type] if color_metric or color_parameter: + # pyrefly: ignore [bad-argument-type] rgba_blue_scale = [rgba(c) for c in BLUE_SCALE] marker = { "color": colors, "colorscale": rgba_blue_scale, "colorbar": {"title": color_metric or color_parameter}, "showscale": True, + # pyrefly: ignore [bad-argument-type] } else: + # pyrefly: ignore [bad-argument-type] marker = {"color": rgba(color)} trace = go.Scatter( @@ -267,17 +272,21 @@ def _error_scatter_trace( if show_CI: if x_se is not None: trace.update( + # pyrefly: ignore [bad-argument-type] error_x={ "type": "data", "array": np.multiply(x_se, Z), + # pyrefly: ignore [bad-argument-type] "color": rgba(color, CI_OPACITY), } ) if y_se is not None: + # pyrefly: ignore [bad-argument-type] trace.update( error_y={ "type": "data", "array": np.multiply(y_se, Z), + # pyrefly: ignore [bad-argument-type] "color": rgba(color, CI_OPACITY), } ) diff --git a/ax/plot/tests/test_contours.py b/ax/plot/tests/test_contours.py index f5cd8f4ca7d..2d4f756337b 100644 --- a/ax/plot/tests/test_contours.py +++ b/ax/plot/tests/test_contours.py @@ -37,6 +37,7 @@ def test_Contours(self) -> None: model, # pyre-fixme[16]: `Adapter` has no attribute `parameters`. model.parameters[0], + # pyrefly: ignore [missing-attribute] model.parameters[1], model_metric_names[0], ) @@ -45,7 +46,9 @@ def test_Contours(self) -> None: self.assertIsInstance(plot, go.Figure) plot = plot_contour( model, + # pyrefly: ignore [missing-attribute] model.parameters[0], + # pyrefly: ignore [missing-attribute] model.parameters[1], model_metric_names[0], ) @@ -80,6 +83,7 @@ def test_Contours(self) -> None: parameters_to_use=["foo"], ) for i in [2, 3]: + # pyrefly: ignore [missing-attribute] parameters_to_use = model.parameters[:i] plot = interact_contour_plotly( model, diff --git a/ax/plot/tests/test_slices.py b/ax/plot/tests/test_slices.py index 9789d2ec172..ebfaceab6fe 100644 --- a/ax/plot/tests/test_slices.py +++ b/ax/plot/tests/test_slices.py @@ -44,6 +44,7 @@ def test_Slices(self) -> None: self.assertIsInstance(plot, go.Figure) plot = interact_slice_plotly(model) self.assertIsInstance(plot, go.Figure) + # pyrefly: ignore [missing-attribute] plot = plot_slice(model, model.parameters[0], model_metric_names[0]) self.assertIsInstance(plot, AxPlotConfig) plot = interact_slice(model) diff --git a/ax/plot/tests/test_traces.py b/ax/plot/tests/test_traces.py index b501285c360..213b9d90671 100644 --- a/ax/plot/tests/test_traces.py +++ b/ax/plot/tests/test_traces.py @@ -41,6 +41,7 @@ def test_Traces(self) -> None: # Assert that each type of plot can be constructed successfully plot = optimization_trace_single_method_plotly( np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), + # pyrefly: ignore [bad-argument-type] self.model_metric_names[0], optimization_direction="minimize", autoset_axis_limits=False, @@ -48,6 +49,7 @@ def test_Traces(self) -> None: self.assertIsInstance(plot, go.Figure) plot = optimization_trace_single_method( np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), + # pyrefly: ignore [bad-argument-type] self.model_metric_names[0], optimization_direction="minimize", autoset_axis_limits=False, @@ -58,6 +60,7 @@ def test_TracesAutoAxes(self) -> None: for optimization_direction in ["minimize", "maximize", "passthrough"]: plot = optimization_trace_single_method_plotly( np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), + # pyrefly: ignore [bad-argument-type] self.model_metric_names[0], optimization_direction=optimization_direction, autoset_axis_limits=True, diff --git a/ax/plot/trace.py b/ax/plot/trace.py index e106f1357c1..399ce128eee 100644 --- a/ax/plot/trace.py +++ b/ax/plot/trace.py @@ -58,6 +58,7 @@ def map_data_single_trace_scatters( x=x, y=y, mode="lines+markers", + # pyrefly: ignore [bad-argument-type] line={"color": rgba(trace_color)}, opacity=opacity, hovertemplate=f"{legend_label}
" @@ -159,6 +160,7 @@ def map_data_multiple_metrics_dropdown_plotly( optimization_direction="minimize" if lower_is_better else "maximize", ) else: + # pyrefly: ignore [unsupported-operation] layout_yaxis_ranges[metric_name] = None metric_dropdown = [] @@ -229,7 +231,9 @@ def mean_trace_scatter( x=np.arange(1, y.shape[1] + 1), y=np.mean(y, axis=0), mode="lines", + # pyrefly: ignore [bad-argument-type] line={"color": rgba(trace_color)}, + # pyrefly: ignore [bad-argument-type] fillcolor=rgba(trace_color, 0.3), fill="tonexty", text=hover_labels, @@ -271,6 +275,7 @@ def sem_range_scatter( legendgroup=legend_label, mode="lines", line={"width": 0}, + # pyrefly: ignore [bad-argument-type] fillcolor=rgba(trace_color, 0.3), fill="tonexty", showlegend=False, @@ -311,6 +316,7 @@ def mean_markers_scatter( "visible": True, }, mode="markers", + # pyrefly: ignore [bad-argument-type] marker={"color": rgba(marker_color)}, text=hover_labels, ) @@ -337,6 +343,7 @@ def optimum_objective_scatter( x=[1, num_iterations], y=[optimum] * 2, mode="lines", + # pyrefly: ignore [bad-argument-type] line={"dash": "dash", "color": rgba(optimum_color)}, name="Optimum", ) @@ -656,6 +663,7 @@ def optimization_times( textposition="auto", error_y={"type": "data", "array": res["2sems"], "visible": True}, marker={ + # pyrefly: ignore [bad-argument-type] "color": rgba(DISCRETE_COLOR_SCALE[i]), "line": {"color": "rgb(0,0,0)", "width": 1.0}, }, diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index 7e07eaef80b..3f4883f326a 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -336,6 +336,7 @@ def create_experiment( } if len(objectives.keys()) > 1: objective_kwargs["objective_thresholds"] = ( + # pyrefly: ignore [unsupported-operation] self.build_objective_thresholds(objectives) ) @@ -355,6 +356,7 @@ def create_experiment( is_test=is_test, default_trial_type=default_trial_type, default_runner=default_runner, + # pyrefly: ignore [bad-argument-type] **objective_kwargs, ) self._set_runner(experiment=experiment) @@ -711,6 +713,7 @@ def runner_config_type(self) -> type[RunnerConfig] | None: return None return self.experiment.runner.config_type + # pyrefly: ignore [not-callable] @retry_on_exception( logger=logger, exception_types=(RuntimeError,), @@ -1541,6 +1544,7 @@ def from_json_snapshot( # ---------------------- Private helper methods. --------------------- @property + # pyrefly: ignore [bad-override] def experiment(self) -> Experiment: """Returns the experiment set on this Ax client.""" return none_throws( @@ -1556,6 +1560,7 @@ def get_trial(self, trial_index: int) -> Trial: return assert_is_instance(self.experiment.trials[trial_index], Trial) @property + # pyrefly: ignore [bad-override] def generation_strategy(self) -> GenerationStrategy: """Returns the generation strategy, set on this experiment.""" return none_throws( diff --git a/ax/service/managed_loop.py b/ax/service/managed_loop.py index 8f6c8850767..d9201495a88 100644 --- a/ax/service/managed_loop.py +++ b/ax/service/managed_loop.py @@ -174,11 +174,14 @@ def _outcome_to_dict( # (float, float) or (float, None) mean, sem = outcome if sem is None: + # pyrefly: ignore [bad-argument-type] return {objective_name: float(mean)} else: + # pyrefly: ignore [bad-argument-type] return {objective_name: (float(mean), float(sem))} else: # Single float + # pyrefly: ignore [bad-argument-type] return {objective_name: float(outcome)} diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index a4b96d6dca8..9bc7bc13641 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -156,6 +156,7 @@ def get_branin_currin_optimization_with_N_sobol_trials( "branin": ObjectiveProperties( minimize=minimize, threshold=( + # pyrefly: ignore [bad-index] float(branin_currin.ref_point[0]) if include_objective_thresholds else None @@ -164,6 +165,7 @@ def get_branin_currin_optimization_with_N_sobol_trials( "currin": ObjectiveProperties( minimize=minimize, threshold=( + # pyrefly: ignore [bad-index] float(branin_currin.ref_point[1]) if include_objective_thresholds else None @@ -184,6 +186,7 @@ def get_branin_currin_optimization_with_N_sobol_trials( currin = float(branin_currin(torch.tensor([x, y]))[1]) raw_data: TTrialEvaluation = {"branin": branin, "currin": currin} if tracking_metric_names is not None: + # pyrefly: ignore [unsupported-operation] raw_data["c"] = branin + currin ax_client.complete_trial(trial_index, raw_data=raw_data) return ax_client, branin_currin @@ -3061,7 +3064,9 @@ def helper_test_get_pareto_optimal_points( for obs in observed_pareto.values(): branin: float = obs[1][0]["branin"] currin: float = obs[1][0]["currin"] + # pyrefly: ignore [bad-index] self.assertGreater(branin, branin_currin.ref_point[0].item()) + # pyrefly: ignore [bad-index] self.assertGreater(currin, branin_currin.ref_point[1].item()) if outcome_constraints is not None: self.assertEqual(branin + currin, obs[1][0]["c"]) diff --git a/ax/service/tests/test_best_point.py b/ax/service/tests/test_best_point.py index e4a14ebfba5..e27d9644bd2 100644 --- a/ax/service/tests/test_best_point.py +++ b/ax/service/tests/test_best_point.py @@ -712,6 +712,7 @@ def test_constrained_infer_reference_point_from_experiment(self) -> None: for experiment in experiments: # special case logs a warning message. data = experiment.fetch_data() + # pyrefly: ignore [missing-attribute] if experiment.optimization_config.outcome_constraints[0].bound == 1000.0: with self.assertLogs(logger, "WARNING"): inferred_reference_point = infer_reference_point_from_experiment( diff --git a/ax/service/tests/test_best_point_utils.py b/ax/service/tests/test_best_point_utils.py index ce82aaabd7b..70ed60f9696 100644 --- a/ax/service/tests/test_best_point_utils.py +++ b/ax/service/tests/test_best_point_utils.py @@ -468,6 +468,7 @@ def test_best_from_model_prediction(self) -> None: # It works even when there are no predictions already stored on the # GeneratorRun for trial in exp.trials.values(): + # pyrefly: ignore [missing-attribute] trial.generator_run._best_arm_predictions = None res = get_best_parameters_from_model_predictions_with_trial_index( experiment=exp, adapter=gs.adapter @@ -1040,6 +1041,7 @@ def test_best_parameters_from_model_predictions_scalarized(self) -> None: metric1 = get_branin_metric(name="branin") metric2 = GenericNoisyFunctionMetric( name="distance_from_origin", + # pyrefly: ignore [unsupported-operation] f=lambda params: (params["x1"] ** 2 + params["x2"] ** 2) ** 0.5, noise_sd=0.01, lower_is_better=True, @@ -1184,6 +1186,7 @@ def test_get_best_trial_with_scalarized_objective(self) -> None: metric1 = get_branin_metric(name="branin") metric2 = GenericNoisyFunctionMetric( name="distance", + # pyrefly: ignore [unsupported-operation] f=lambda params: (params["x1"] ** 2 + params["x2"] ** 2) ** 0.5, noise_sd=0.01, lower_is_better=True, diff --git a/ax/service/tests/test_interactive_loop.py b/ax/service/tests/test_interactive_loop.py index fac29c73f12..9823ac35324 100644 --- a/ax/service/tests/test_interactive_loop.py +++ b/ax/service/tests/test_interactive_loop.py @@ -68,6 +68,7 @@ def _elicit( parameterization, trial_index = parameterization_with_trial_index x = np.array([parameterization.get(f"x{i + 1}") for i in range(6)]) + # pyrefly: ignore [bad-return] return ( trial_index, { @@ -140,6 +141,7 @@ def _sleep_elicit( x = np.array([parameterization.get(f"x{i + 1}") for i in range(6)]) + # pyrefly: ignore [bad-return] return ( trial_index, { diff --git a/ax/service/tests/test_managed_loop.py b/ax/service/tests/test_managed_loop.py index 71796abeafb..79c536e56b4 100644 --- a/ax/service/tests/test_managed_loop.py +++ b/ax/service/tests/test_managed_loop.py @@ -25,6 +25,7 @@ def _branin_evaluation_function( parameterization: TParameterization, ) -> dict[str, tuple[float, float]]: + # pyrefly: ignore [bad-argument-type] x1, x2 = float(parameterization["x1"]), float(parameterization["x2"]) return { "branin": (float(branin(x1, x2)), 0.0), @@ -54,8 +55,10 @@ def test_optimize_returns_deprecation_warning(self) -> None: {"name": "x1", "type": "range", "bounds": [-10.0, 10.0]}, {"name": "x2", "type": "range", "bounds": [-10.0, 10.0]}, ], + # pyrefly: ignore [unsupported-operation] evaluation_function=lambda p: (p["x1"] + 2 * p["x2"] - 7) ** 2 - + (2 * p["x1"] + p["x2"] - 5) ** 2, + + (2 * p["x1"] + p["x2"] - 5) # pyrefly: ignore [unsupported-operation] + ** 2, # pyrefly: ignore [unsupported-operation] minimize=True, total_trials=5, ) @@ -67,8 +70,10 @@ def test_optimize_single_metric(self) -> None: {"name": "x1", "type": "range", "bounds": [-10.0, 10.0]}, {"name": "x2", "type": "range", "bounds": [-10.0, 10.0]}, ], + # pyrefly: ignore [unsupported-operation] evaluation_function=lambda p: (p["x1"] + 2 * p["x2"] - 7) ** 2 - + (2 * p["x1"] + p["x2"] - 5) ** 2, + + (2 * p["x1"] + p["x2"] - 5) # pyrefly: ignore [unsupported-operation] + ** 2, # pyrefly: ignore [unsupported-operation] minimize=True, total_trials=5, ) @@ -85,6 +90,7 @@ def test_optimize_tuple_return(self) -> None: {"name": "x2", "type": "range", "bounds": [-10.0, 10.0]}, ], evaluation_function=lambda p: ( + # pyrefly: ignore [unsupported-operation] (p["x1"] + 2 * p["x2"] - 7) ** 2 + (2 * p["x1"] + p["x2"] - 5) ** 2, 0.0, ), @@ -102,6 +108,7 @@ def test_optimize_tuple_none_sem(self) -> None: {"name": "x2", "type": "range", "bounds": [-10.0, 10.0]}, ], evaluation_function=lambda p: ( + # pyrefly: ignore [unsupported-operation] (p["x1"] + 2 * p["x2"] - 7) ** 2 + (2 * p["x1"] + p["x2"] - 5) ** 2, None, ), @@ -140,8 +147,10 @@ def test_optimize_with_parameter_constraints(self) -> None: {"name": "x1", "type": "range", "bounds": [-10.0, 10.0]}, {"name": "x2", "type": "range", "bounds": [-10.0, 10.0]}, ], + # pyrefly: ignore [unsupported-operation] evaluation_function=lambda p: (p["x1"] + 2 * p["x2"] - 7) ** 2 - + (2 * p["x1"] + p["x2"] - 5) ** 2, + + (2 * p["x1"] + p["x2"] - 5) # pyrefly: ignore [unsupported-operation] + ** 2, # pyrefly: ignore [unsupported-operation] minimize=True, parameter_constraints=["x1 + x2 <= 5"], total_trials=5, @@ -166,6 +175,7 @@ def test_optimize_choice_parameters(self) -> None: "value_type": "int", }, ], + # pyrefly: ignore [unsupported-operation] evaluation_function=lambda p: (p["x1"] - 3) ** 2 + (p["x2"] - 2) ** 2, minimize=True, total_trials=5, @@ -190,6 +200,7 @@ def test_optimize_search_space_exhausted(self) -> None: "value_type": "int", }, ], + # pyrefly: ignore [unsupported-operation] evaluation_function=lambda p: (p["x1"] - 1) ** 2 + (p["x2"] - 1) ** 2, minimize=True, total_trials=10, @@ -210,6 +221,7 @@ def test_optimize_int_range_parameter(self) -> None: "value_type": "int", }, ], + # pyrefly: ignore [unsupported-operation] evaluation_function=lambda p: (p["k"] - 5) ** 2, minimize=True, total_trials=5, @@ -228,6 +240,7 @@ def test_optimize_log_scale(self) -> None: "log_scale": True, }, ], + # pyrefly: ignore [unsupported-operation] evaluation_function=lambda p: (p["lr"] - 0.01) ** 2, minimize=True, total_trials=5, @@ -241,6 +254,7 @@ def test_optimize_rejects_batch_trials(self) -> None: parameters=[ {"name": "x1", "type": "range", "bounds": [-10.0, 10.0]}, ], + # pyrefly: ignore [unsupported-operation] evaluation_function=lambda p: p["x1"] ** 2, total_trials=5, arms_per_trial=3, @@ -261,8 +275,10 @@ def test_optimize_with_custom_generation_strategy(self) -> None: {"name": "x1", "type": "range", "bounds": [-10.0, 10.0]}, {"name": "x2", "type": "range", "bounds": [-10.0, 10.0]}, ], + # pyrefly: ignore [unsupported-operation] evaluation_function=lambda p: (p["x1"] + 2 * p["x2"] - 7) ** 2 - + (2 * p["x1"] + p["x2"] - 5) ** 2, + + (2 * p["x1"] + p["x2"] - 5) # pyrefly: ignore [unsupported-operation] + ** 2, # pyrefly: ignore [unsupported-operation] minimize=True, total_trials=5, generation_strategy=gs, diff --git a/ax/service/utils/best_point.py b/ax/service/utils/best_point.py index 083159dfe33..0942d18cdea 100644 --- a/ax/service/utils/best_point.py +++ b/ax/service/utils/best_point.py @@ -1355,6 +1355,7 @@ def get_trace( keep_order=False, # sort by trial index ) + # pyrefly: ignore [bad-argument-type] return {int(k): float(v) for k, v in cumulative_value.items()} diff --git a/ax/service/utils/instantiation.py b/ax/service/utils/instantiation.py index 6e2399967de..ecff2df7190 100644 --- a/ax/service/utils/instantiation.py +++ b/ax/service/utils/instantiation.py @@ -243,7 +243,9 @@ def _make_range_param( parameter_type=cls._to_parameter_type( bounds, parameter_type, name, "bounds" ), + # pyrefly: ignore [bad-argument-type] lower=assert_is_instance_of_tuple(bounds[0], (float, int)), + # pyrefly: ignore [bad-argument-type] upper=assert_is_instance_of_tuple(bounds[1], (float, int)), log_scale=assert_is_instance(representation.get("log_scale", False), bool), digits=assert_is_instance_optional(representation.get("digits", None), int), diff --git a/ax/service/utils/report_utils.py b/ax/service/utils/report_utils.py index 0fad0f7ae7f..bc848747180 100644 --- a/ax/service/utils/report_utils.py +++ b/ax/service/utils/report_utils.py @@ -455,6 +455,7 @@ def get_standard_plots( for by_walltime in [False, True]: logger.debug(f"Starting MapMetric plot {by_walltime=}.") output_plot_list.append( + # pyrefly: ignore [bad-argument-type] _get_curve_plot_dropdown( experiment=experiment, map_metrics=map_metrics, diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index ce73f7e74b0..89ed3879a38 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -203,6 +203,7 @@ def object_from_json( elif _type == "DataFrame": # Need dtype=False, otherwise infers arm_names like "4_1" # should be int 41 + # pyrefly: ignore [no-matching-overload] return pd.read_json(StringIO(object_json["value"]), dtype=False) elif _type == "ndarray": return np.array(object_json["value"]) @@ -433,6 +434,7 @@ def generator_run_from_json( generator_run._best_arm_predictions = (arm, arm_prediction) if isinstance(generator_run._model_predictions, list): + # pyrefly: ignore [bad-assignment] generator_run._model_predictions = tuple(generator_run._model_predictions) # Remove deprecated kwargs from generator kwargs & adapter kwargs. @@ -503,8 +505,10 @@ def transition_criterion_from_json( class_decoder_registry=class_decoder_registry, ) return MinTrials( + # pyrefly: ignore [bad-argument-type] threshold=object_json.get("threshold"), only_in_statuses=[status], + # pyrefly: ignore [bad-argument-type] transition_to=object_json.get("transition_to"), use_all_trials_in_exp=True, ) @@ -659,6 +663,7 @@ def trials_from_json( if is_trial else batch_trial_from_json(experiment=experiment, **trial_json) ) + # pyrefly: ignore [bad-return] return loaded_trials @@ -810,6 +815,7 @@ def _load_experiment_info( ) exp._trials = trials_from_json( exp, + # pyrefly: ignore [bad-argument-type] exp_info.get("trials_json"), decoder_registry=decoder_registry, class_decoder_registry=class_decoder_registry, @@ -1176,6 +1182,7 @@ def generation_step_from_json( generator_name=generation_step_json.pop("generator_name", None), use_all_trials_in_exp=generation_step_json.pop("use_all_trials_in_exp", False), ) + # pyrefly: ignore [bad-return] return generation_step @@ -1338,12 +1345,14 @@ def surrogate_from_list_surrogate_json( decoder_registry=decoder_registry, class_decoder_registry=class_decoder_registry, ), + # pyrefly: ignore [bad-argument-type] model_options=list_surrogate_json.get("submodel_options"), mll_class=object_from_json( object_json=list_surrogate_json.get("mll_class"), decoder_registry=decoder_registry, class_decoder_registry=class_decoder_registry, ), + # pyrefly: ignore [bad-argument-type] mll_options=list_surrogate_json.get("mll_options"), input_transform_classes=object_from_json( object_json=list_surrogate_json.get( @@ -1380,6 +1389,7 @@ def surrogate_from_list_surrogate_json( decoder_registry=decoder_registry, class_decoder_registry=class_decoder_registry, ), + # pyrefly: ignore [bad-argument-type] covar_module_options=list_surrogate_json.get( "submodel_covar_module_options" ), @@ -1390,6 +1400,7 @@ def surrogate_from_list_surrogate_json( decoder_registry=decoder_registry, class_decoder_registry=class_decoder_registry, ), + # pyrefly: ignore [bad-argument-type] likelihood_options=list_surrogate_json.get( "submodel_likelihood_options" ), diff --git a/ax/storage/json_store/decoders.py b/ax/storage/json_store/decoders.py index b31be6bb617..1d7fd35a181 100644 --- a/ax/storage/json_store/decoders.py +++ b/ax/storage/json_store/decoders.py @@ -350,6 +350,7 @@ def botorch_component_from_json(botorch_class: type[T], json: dict[str, Any]) -> `CLASS_DECODER_REGISTRY` from state dict.""" state_dict = json.pop("state_dict") if issubclass(botorch_class, ChainedInputTransform): + # pyrefly: ignore [bad-return] return botorch_class( **{ k: botorch_component_from_json( @@ -360,6 +361,7 @@ def botorch_component_from_json(botorch_class: type[T], json: dict[str, Any]) -> } ) if issubclass(botorch_class, ChainedOutcomeTransform): + # pyrefly: ignore [bad-return] return botorch_class( **{ k: botorch_component_from_json( diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py index d40598dcf47..b2a04dc4226 100644 --- a/ax/storage/json_store/tests/test_json_store.py +++ b/ax/storage/json_store/tests/test_json_store.py @@ -566,6 +566,7 @@ def test_SaveAndLoad(self) -> None: def test_SaveValidation(self) -> None: with self.assertRaises(ValueError): save_experiment( + # pyrefly: ignore [bad-argument-type] self.experiment.trials[0], "test.json", encoder_registry=CORE_ENCODER_REGISTRY, @@ -657,7 +658,9 @@ class TestDataclass: a_field: int not_a_field: dataclasses.InitVar[int | None] = None + # pyrefly: ignore [bad-function-definition] def __post_init__(self, doesnt_serialize: None) -> None: + # pyrefly: ignore [missing-attribute] self.not_a_field = 1 obj = TestDataclass(a_field=-1) @@ -952,9 +955,11 @@ def test_encode_decode_surrogate_spec(self) -> None: def test_RegistryAdditions(self) -> None: class MyRunner(Runner): + # pyrefly: ignore [bad-override] def run(): pass + # pyrefly: ignore [bad-override] def staging_required(): return False @@ -995,9 +1000,11 @@ class MyMetric(Metric): pass class MyRunner(Runner): + # pyrefly: ignore [bad-override] def run(): pass + # pyrefly: ignore [bad-override] def staging_required(): return False @@ -1571,12 +1578,15 @@ def test_mbm_backwards_compatibility(self) -> None: "warm_start_refit": True, } expected_object = get_botorch_model_with_surrogate_spec(with_covar_module=False) + # pyrefly: ignore [missing-attribute] expected_object.surrogate_spec.model_configs[0].input_transform_classes = None + # pyrefly: ignore [missing-attribute] expected_object.surrogate_spec.model_configs[0].name = "from deprecated args" # The new default value is None; we need to manually set it to the old value self.assertIsNone( none_throws(expected_object.surrogate_spec).model_configs[0].mll_class ) + # pyrefly: ignore [missing-attribute] expected_object.surrogate_spec.model_configs[ 0 ].mll_class = ExactMarginalLogLikelihood @@ -1628,6 +1638,7 @@ def test_mbm_backwards_compatibility_2(self) -> None: extra_args = {} if legacy_input_transform: extra_args["input_transform_classes"] = [Normalize] + # pyrefly: ignore [unsupported-operation] extra_args["input_transform_options"] = { "Normalize": { "d": 7, @@ -1643,6 +1654,7 @@ def test_mbm_backwards_compatibility_2(self) -> None: } } else: + # pyrefly: ignore [unsupported-operation] extra_args["input_transform_classes"] = None new_object = SurrogateSpec( model_configs=[ @@ -1650,6 +1662,7 @@ def test_mbm_backwards_compatibility_2(self) -> None: botorch_model_class=SingleTaskGP, mll_class=ExactMarginalLogLikelihood, name="from deprecated args", + # pyrefly: ignore [bad-argument-type] **extra_args, ) ], diff --git a/ax/storage/registry_bundle.py b/ax/storage/registry_bundle.py index dbde9930738..48a287df170 100644 --- a/ax/storage/registry_bundle.py +++ b/ax/storage/registry_bundle.py @@ -190,6 +190,7 @@ def __init__( json_decoder_registry=json_decoder_registry, json_class_decoder_registry=json_class_decoder_registry, ) + # pyrefly: ignore [not-callable] self._sqa_config = SQAConfig( json_encoder_registry={**self.encoder_registry, **CORE_ENCODER_REGISTRY}, json_decoder_registry={**self.decoder_registry, **CORE_DECODER_REGISTRY}, @@ -199,7 +200,9 @@ def __init__( json_class_decoder_registry=self.class_decoder_registry, ) + # pyrefly: ignore [not-callable] self._encoder = Encoder(self._sqa_config) + # pyrefly: ignore [not-callable] self._decoder = Decoder(self._sqa_config) @cached_property diff --git a/ax/storage/sqa_store/db.py b/ax/storage/sqa_store/db.py index c7864a1acd6..e527d5ea4b9 100644 --- a/ax/storage/sqa_store/db.py +++ b/ax/storage/sqa_store/db.py @@ -70,6 +70,7 @@ class SQABase: class Base(SQABase): metadata: Any = ... + # pyrefly: ignore [bad-assignment] __tablename__: str = ... __table__: Any = ... __table_args__: Any = ... diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index d564c764765..6d0c1c3731b 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -200,6 +200,7 @@ def _auxiliary_experiments_by_purpose_from_experiment_sqa( reduced_state=reduced_state, ) auxiliary_experiments_by_purpose[purpose].append(aux_experiment) + # pyrefly: ignore [bad-return] return auxiliary_experiments_by_purpose # TODO[@mpolson64]: Stop storing target arm in experiment properties @@ -419,6 +420,7 @@ def experiment_from_sqa( data_by_trial[trial_index][timestamp] = self.data_from_sqa( data_sqa=data_sqa ) + # pyrefly: ignore [bad-argument-type] experiment.data = data_by_trial_to_data(data_by_trial=data_by_trial) trial_type_to_runner = { @@ -767,6 +769,7 @@ def opt_config_and_tracking_metrics_from_sqa( objective, Union[MultiObjective, ScalarizedObjective] ), outcome_constraints=outcome_constraints, + # pyrefly: ignore [bad-argument-type] objective_thresholds=objective_thresholds, pruning_target_parameterization=pruning_target_parameterization, ) @@ -870,6 +873,7 @@ def generator_run_from_sqa( generator_run = GeneratorRun( arms=arms, + # pyrefly: ignore [bad-argument-type] weights=weights, optimization_config=opt_config, search_space=search_space, @@ -1143,6 +1147,7 @@ def data_from_sqa(self, data_sqa: SQAData) -> Data: # Override df from deserialize_init_args with `data_json`. # NOTE: Need dtype=False, otherwise infers arm_names like # "4_1" should be int 41. + # pyrefly: ignore [no-matching-overload] df = pd.read_json(StringIO(data_sqa.data_json), dtype=False) # Ensure trial_index is int (dtype=False can leave it as string) if "trial_index" in df.columns: diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py index 7ee21b70b8d..84d179287a5 100644 --- a/ax/storage/sqa_store/encoder.py +++ b/ax/storage/sqa_store/encoder.py @@ -1058,7 +1058,9 @@ def generation_strategy_to_sqa( def runner_to_sqa(self, runner: Runner, trial_type: str | None = None) -> SQARunner: """Convert Ax Runner to SQLAlchemy.""" + # pyrefly: ignore [bad-assignment] runner_class = type(runner) + # pyrefly: ignore [bad-argument-type] runner_type = self.config.runner_registry.get(runner_class) if runner_type is None: raise SQAEncodeError( @@ -1067,6 +1069,7 @@ def runner_to_sqa(self, runner: Runner, trial_type: str | None = None) -> SQARun "The runner registry currently contains the following: " f"{','.join(map(str, self.config.runner_registry.keys()))} " ) + # pyrefly: ignore [missing-attribute] properties = runner_class.serialize_init_args(obj=runner) # pyre-fixme: Expected `Base` for 1st...t `typing.Type[Runner]`. runner_class: SQARunner = self.config.class_to_sqa_class[Runner] @@ -1158,6 +1161,7 @@ def experiment_data_to_sqa( return [ self.data_to_sqa( data=Data(df=df), + # pyrefly: ignore [bad-argument-type] trial_index=trial_index, timestamp=0, ) @@ -1204,6 +1208,7 @@ def auxiliary_experiments_by_purpose_to_sqa( # pyre-fixme: Expected `Base` for 1st...ot `typing.Type[AuxiliaryExperiment]`. auxiliary_experiment_class: SQAAuxiliaryExperiment = ( + # pyrefly: ignore [bad-assignment] self.config.class_to_sqa_class[AuxiliaryExperiment] ) diff --git a/ax/storage/sqa_store/json.py b/ax/storage/sqa_store/json.py index 630ece26ff9..df17d90cfee 100644 --- a/ax/storage/sqa_store/json.py +++ b/ax/storage/sqa_store/json.py @@ -24,6 +24,7 @@ class JSONEncodedObject(TypeDecorator): """ + # pyrefly: ignore [bad-override-mutable-attribute] impl: VARCHAR = VARCHAR(JSON_FIELD_LENGTH) cache_ok = True diff --git a/ax/storage/sqa_store/load.py b/ax/storage/sqa_store/load.py index b5cd204cb0d..96c01373def 100644 --- a/ax/storage/sqa_store/load.py +++ b/ax/storage/sqa_store/load.py @@ -340,6 +340,7 @@ def _set_sqa_metric_to_base_type( if sqa_metric.intent in composite_metric_intents: if sqa_metric.properties is None: sqa_metric.properties = {} + # pyrefly: ignore [unsupported-operation] sqa_metric.properties["skip_runners_and_metrics"] = True @@ -708,9 +709,11 @@ def load_analysis_cards_by_experiment_name( # Create query options which will recursively load all children of the # SQAAnalysisCard up to depth 20 + # pyrefly: ignore [bad-argument-type] card_query_options = joinedload(analysis_card_sqa_class.children) for _ in range(19): card_query_options = card_query_options.joinedload( + # pyrefly: ignore [bad-argument-type] analysis_card_sqa_class.children ) @@ -720,6 +723,7 @@ def load_analysis_cards_by_experiment_name( with session_scope() as session: query = ( + # pyrefly: ignore [no-matching-overload] session.query(analysis_card_sqa_class) .join(exp_sqa_class.analysis_cards) .filter(exp_sqa_class.name == experiment_name) diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 5f3c50540b4..64c68eebb17 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -871,11 +871,17 @@ def test_experiment_save_and_load_reduced_state( self.assertNotEqual(loaded_experiment, exp) # Remove all fields that are not part of the reduced state and # check that everything else is equal as expected. + # pyrefly: ignore [missing-attribute] exp.trials.get(1).generator_run._generator_kwargs = None + # pyrefly: ignore [missing-attribute] exp.trials.get(1).generator_run._adapter_kwargs = None + # pyrefly: ignore [missing-attribute] exp.trials.get(1).generator_run._gen_metadata = None + # pyrefly: ignore [missing-attribute] exp.trials.get(1).generator_run._generator_state_after_gen = None + # pyrefly: ignore [missing-attribute] exp.trials.get(1).generator_run._search_space = None + # pyrefly: ignore [missing-attribute] exp.trials.get(1).generator_run._optimization_config = None self.assertEqual(loaded_experiment, exp) delete_experiment(exp_name=exp.name) @@ -1062,6 +1068,7 @@ def test_experiment_save_and_update_trials(self) -> None: def test_save_validation(self) -> None: with self.assertRaises(ValueError): + # pyrefly: ignore [bad-argument-type] save_experiment(self.experiment.trials[0]) experiment = get_experiment_with_batch_trial() @@ -1097,12 +1104,14 @@ def test_encode_decode(self) -> None: with self.assertRaises(RuntimeError): converted_object.evaluation_function(parameterization={}) + # pyrefly: ignore [missing-attribute] original_object.evaluation_function = None converted_object.evaluation_function = None # Experiment to SQA encoder stores the experiment subclass # among its properties; we then remove the subclass when # decoding. Removing subclass from original object here # for parity with the expected decoded (converted) object. + # pyrefly: ignore [missing-attribute] original_object._properties.pop(Keys.SUBCLASS) self.assertEqual( @@ -1934,7 +1943,10 @@ def test_decode_order_parameter_constraint_failure(self) -> None: ) with self.assertRaises(SQADecodeError): self.decoder.parameter_constraint_from_sqa( - sqa_parameter, self.dummy_parameters + # pyrefly: ignore [bad-argument-type] + sqa_parameter, + # pyrefly: ignore [bad-argument-type] + self.dummy_parameters, ) def test_decode_sum_parameter_constraint_failure(self) -> None: @@ -1945,7 +1957,10 @@ def test_decode_sum_parameter_constraint_failure(self) -> None: ) with self.assertRaises(SQADecodeError): self.decoder.parameter_constraint_from_sqa( - sqa_parameter, self.dummy_parameters + # pyrefly: ignore [bad-argument-type] + sqa_parameter, + # pyrefly: ignore [bad-argument-type] + self.dummy_parameters, ) def test_metric_validation(self) -> None: @@ -2137,9 +2152,11 @@ def test_get_properties(self) -> None: def test_registry_additions(self) -> None: class MyRunner(Runner): + # pyrefly: ignore [bad-override] def run(): pass + # pyrefly: ignore [bad-override] def staging_required(): return False @@ -2174,9 +2191,11 @@ class MyMetric(Metric): def test_registry_bundle(self) -> None: class MyRunner(Runner): + # pyrefly: ignore [bad-override] def run(): pass + # pyrefly: ignore [bad-override] def staging_required(): return False @@ -2453,11 +2472,17 @@ def test_encode_decode_generation_strategy_reduced_state_load_experiment( # state along with the generation strategy. self.assertNotEqual(new_generation_strategy.experiment, experiment) # Adjust experiment and GS to reduced state. + # pyrefly: ignore [missing-attribute] experiment.trials.get(0).generator_run._generator_kwargs = None + # pyrefly: ignore [missing-attribute] experiment.trials.get(0).generator_run._adapter_kwargs = None + # pyrefly: ignore [missing-attribute] experiment.trials.get(0).generator_run._gen_metadata = None + # pyrefly: ignore [missing-attribute] experiment.trials.get(0).generator_run._generator_state_after_gen = None + # pyrefly: ignore [missing-attribute] experiment.trials.get(0).generator_run._search_space = None + # pyrefly: ignore [missing-attribute] experiment.trials.get(0).generator_run._optimization_config = None generation_strategy._generator_runs[0]._generator_kwargs = None generation_strategy._generator_runs[0]._adapter_kwargs = None diff --git a/ax/storage/sqa_store/tests/test_with_db_settings_base.py b/ax/storage/sqa_store/tests/test_with_db_settings_base.py index d9895104b75..b7784515e78 100644 --- a/ax/storage/sqa_store/tests/test_with_db_settings_base.py +++ b/ax/storage/sqa_store/tests/test_with_db_settings_base.py @@ -53,12 +53,16 @@ def setUp(self) -> None: ) _save_experiment( self.experiment, + # pyrefly: ignore [missing-attribute] encoder=self.with_db_settings.db_settings.encoder, + # pyrefly: ignore [missing-attribute] decoder=self.with_db_settings.db_settings.decoder, ) _save_generation_strategy( generation_strategy=self.generation_strategy, + # pyrefly: ignore [missing-attribute] encoder=self.with_db_settings.db_settings.encoder, + # pyrefly: ignore [missing-attribute] decoder=self.with_db_settings.db_settings.decoder, ) @@ -90,13 +94,17 @@ def init_experiment_and_generation_strategy( if save_experiment: _save_experiment( experiment, + # pyrefly: ignore [missing-attribute] encoder=self.with_db_settings.db_settings.encoder, + # pyrefly: ignore [missing-attribute] decoder=self.with_db_settings.db_settings.decoder, ) if save_generation_strategy: _save_generation_strategy( generation_strategy=generation_strategy, + # pyrefly: ignore [missing-attribute] encoder=self.with_db_settings.db_settings.encoder, + # pyrefly: ignore [missing-attribute] decoder=self.with_db_settings.db_settings.decoder, ) return experiment, generation_strategy @@ -116,7 +124,10 @@ def test_save_experiment(self) -> None: saved = self.with_db_settings._save_experiment_to_db_if_possible(experiment) self.assertTrue(saved) loaded_experiment = _load_experiment( - experiment.name, decoder=self.with_db_settings.db_settings.decoder + # pyrefly: ignore [missing-attribute] + experiment.name, + # pyrefly: ignore [missing-attribute] + decoder=self.with_db_settings.db_settings.decoder, ) self.assertIsNotNone(loaded_experiment) self.assertEqual(experiment, loaded_experiment) @@ -131,6 +142,7 @@ def test_save_generation_strategy(self) -> None: self.assertTrue(saved) loaded_gs = _load_generation_strategy_by_experiment_name( experiment_name=experiment.name, + # pyrefly: ignore [missing-attribute] decoder=self.with_db_settings.db_settings.decoder, ) self.assertIsNotNone(loaded_gs) @@ -208,7 +220,10 @@ def test_save_new_trial(self) -> None: ) exp = _load_experiment( - experiment.name, decoder=self.with_db_settings.db_settings.decoder + # pyrefly: ignore [missing-attribute] + experiment.name, + # pyrefly: ignore [missing-attribute] + decoder=self.with_db_settings.db_settings.decoder, ) trial = exp.new_trial() saved = self.with_db_settings._save_or_update_trial_in_db_if_possible( @@ -216,7 +231,10 @@ def test_save_new_trial(self) -> None: ) self.assertTrue(saved) exp = _load_experiment( - experiment.name, decoder=self.with_db_settings.db_settings.decoder + # pyrefly: ignore [missing-attribute] + experiment.name, + # pyrefly: ignore [missing-attribute] + decoder=self.with_db_settings.db_settings.decoder, ) self.assertEqual(len(exp.trials), 1) self.assertEqual(exp.trials[0].status, TrialStatus.CANDIDATE) @@ -227,13 +245,18 @@ def test_save_updated_trial(self) -> None: ) exp = _load_experiment( - experiment.name, decoder=self.with_db_settings.db_settings.decoder + # pyrefly: ignore [missing-attribute] + experiment.name, + # pyrefly: ignore [missing-attribute] + decoder=self.with_db_settings.db_settings.decoder, ) trial = exp.new_trial() _save_or_update_trials( experiment=experiment, trials=[trial], + # pyrefly: ignore [missing-attribute] encoder=self.with_db_settings.db_settings.encoder, + # pyrefly: ignore [missing-attribute] decoder=self.with_db_settings.db_settings.decoder, ) self.assertEqual(trial.status, TrialStatus.CANDIDATE) @@ -244,7 +267,10 @@ def test_save_updated_trial(self) -> None: ) self.assertTrue(saved) exp = _load_experiment( - experiment.name, decoder=self.with_db_settings.db_settings.decoder + # pyrefly: ignore [missing-attribute] + experiment.name, + # pyrefly: ignore [missing-attribute] + decoder=self.with_db_settings.db_settings.decoder, ) self.assertEqual(len(exp.trials), 1) self.assertEqual(exp.trials[0].status, TrialStatus.RUNNING) @@ -263,7 +289,10 @@ def test_updated_trials_mini_batch(self) -> None: trials=[trial], ) loaded_experiment = _load_experiment( - experiment.name, decoder=self.with_db_settings.db_settings.decoder + # pyrefly: ignore [missing-attribute] + experiment.name, + # pyrefly: ignore [missing-attribute] + decoder=self.with_db_settings.db_settings.decoder, ) self.assertEqual( loaded_experiment.trials[trial.index].status, @@ -284,7 +313,10 @@ def test_updated_trials_mini_batch(self) -> None: trials=trials, ) loaded_experiment = _load_experiment( - experiment.name, decoder=self.with_db_settings.db_settings.decoder + # pyrefly: ignore [missing-attribute] + experiment.name, + # pyrefly: ignore [missing-attribute] + decoder=self.with_db_settings.db_settings.decoder, ) # All trials except for the one we marked as running should be candidates. for t in trials: @@ -315,7 +347,10 @@ def test_update_reduced_state_generator_runs(self) -> None: ) loaded_experiment = _load_experiment( - experiment.name, decoder=self.with_db_settings.db_settings.decoder + # pyrefly: ignore [missing-attribute] + experiment.name, + # pyrefly: ignore [missing-attribute] + decoder=self.with_db_settings.db_settings.decoder, ) # Only the last trial's generator run should have large model attributes @@ -330,7 +365,10 @@ def test_update_reduced_state_generator_runs(self) -> None: self.assertIsNotNone(getattr(t.generator_run, python_attr_name)) loaded_generation_strategy = _load_generation_strategy_by_experiment_name( - experiment.name, decoder=self.with_db_settings.db_settings.decoder + # pyrefly: ignore [missing-attribute] + experiment.name, + # pyrefly: ignore [missing-attribute] + decoder=self.with_db_settings.db_settings.decoder, ) # Only the last generator run should have large model attributes @@ -352,7 +390,10 @@ def test_update_experiment_properties_in_db(self) -> None: experiment_with_updated_properties=experiment ) loaded_experiment = _load_experiment( - experiment.name, decoder=self.with_db_settings.db_settings.decoder + # pyrefly: ignore [missing-attribute] + experiment.name, + # pyrefly: ignore [missing-attribute] + decoder=self.with_db_settings.db_settings.decoder, ) self.assertEqual( loaded_experiment._properties, @@ -369,6 +410,7 @@ def test_try_load_generation_strategy(self) -> None: ) as lg: output = try_load_generation_strategy( experiment_name=experiment.name, + # pyrefly: ignore [missing-attribute] decoder=self.with_db_settings.db_settings.decoder, experiment=experiment, ) @@ -391,6 +433,7 @@ def test_try_load_generation_strategy(self) -> None: ) as lg: output = try_load_generation_strategy( experiment_name=experiment.name, + # pyrefly: ignore [missing-attribute] decoder=self.with_db_settings.db_settings.decoder, experiment=experiment, ) diff --git a/ax/storage/sqa_store/with_db_settings_base.py b/ax/storage/sqa_store/with_db_settings_base.py index c86589bfea4..14972894816 100644 --- a/ax/storage/sqa_store/with_db_settings_base.py +++ b/ax/storage/sqa_store/with_db_settings_base.py @@ -99,7 +99,10 @@ def __init__( self._suppress_all_errors = suppress_all_errors if self.db_settings_set: init_engine_and_session_factory( - creator=self.db_settings.creator, url=self.db_settings.url + # pyrefly: ignore [missing-attribute] + creator=self.db_settings.creator, + # pyrefly: ignore [missing-attribute] + url=self.db_settings.url, ) logger.setLevel(logging_level) @@ -133,12 +136,18 @@ def _get_experiment_and_generation_strategy_db_id( return None, None exp_id = _get_experiment_id( - experiment_name=experiment_name, config=self.db_settings.decoder.config + # pyrefly: ignore [missing-attribute] + experiment_name=experiment_name, + # pyrefly: ignore [missing-attribute] + config=self.db_settings.decoder.config, ) if not exp_id: return None, None gs_id = get_generation_strategy_id( - experiment_name=experiment_name, decoder=self.db_settings.decoder + # pyrefly: ignore [missing-attribute] + experiment_name=experiment_name, + # pyrefly: ignore [missing-attribute] + decoder=self.db_settings.decoder, ) return exp_id, gs_id @@ -236,6 +245,7 @@ def _load_experiment_and_generation_strategy( start_time = time.time() experiment = _load_experiment( experiment_name, + # pyrefly: ignore [missing-attribute] decoder=self.db_settings.decoder, reduced_state=reduced_state, load_trials_in_batches_of_size=LOADING_MINI_BATCH_SIZE, @@ -251,6 +261,7 @@ def _load_experiment_and_generation_strategy( ) generation_strategy = try_load_generation_strategy( experiment_name=experiment_name, + # pyrefly: ignore [missing-attribute] decoder=self.db_settings.decoder, experiment=experiment, reduced_state=reduced_state, @@ -271,7 +282,9 @@ def _save_experiment_to_db_if_possible(self, experiment: Experiment) -> bool: if self.db_settings_set: _save_experiment_to_db_if_possible( experiment=experiment, + # pyrefly: ignore [missing-attribute] encoder=self.db_settings.encoder, + # pyrefly: ignore [missing-attribute] decoder=self.db_settings.decoder, suppress_all_errors=self._suppress_all_errors, ) @@ -314,6 +327,7 @@ def _save_or_update_trials_and_generation_strategy_if_possible( if experiment.status is not None and self.db_settings_set: update_experiment_status( experiment=experiment, + # pyrefly: ignore [missing-attribute] config=self.db_settings.encoder.config, ) return @@ -359,7 +373,9 @@ def _save_or_update_trials_in_db_if_possible( _save_or_update_trials_in_db_if_possible( experiment=experiment, trials=trials, + # pyrefly: ignore [missing-attribute] encoder=self.db_settings.encoder, + # pyrefly: ignore [missing-attribute] decoder=self.db_settings.decoder, suppress_all_errors=self._suppress_all_errors, reduce_state_generator_runs=reduce_state_generator_runs, @@ -387,7 +403,9 @@ def _save_generation_strategy_to_db_if_possible( # the database because only they make changes locally _save_generation_strategy_to_db_if_possible( generation_strategy=generation_strategy, + # pyrefly: ignore [missing-attribute] encoder=self.db_settings.encoder, + # pyrefly: ignore [missing-attribute] decoder=self.db_settings.decoder, suppress_all_errors=self._suppress_all_errors, ) @@ -421,7 +439,9 @@ def _update_generation_strategy_in_db_if_possible( _update_generation_strategy_in_db_if_possible( generation_strategy=generation_strategy, new_generator_runs=new_generator_runs, + # pyrefly: ignore [missing-attribute] encoder=self.db_settings.encoder, + # pyrefly: ignore [missing-attribute] decoder=self.db_settings.decoder, suppress_all_errors=self._suppress_all_errors, reduce_state_generator_runs=reduce_state_generator_runs, @@ -436,7 +456,9 @@ def _update_runner_on_experiment_in_db_if_possible( _update_runner_on_experiment_in_db_if_possible( experiment=experiment, runner=runner, + # pyrefly: ignore [missing-attribute] encoder=self.db_settings.encoder, + # pyrefly: ignore [missing-attribute] decoder=self.db_settings.decoder, suppress_all_errors=self._suppress_all_errors, ) @@ -452,6 +474,7 @@ def _update_experiment_properties_in_db( if self.db_settings_set: _update_experiment_properties_in_db( experiment_with_updated_properties=exp, + # pyrefly: ignore [missing-attribute] sqa_config=self.db_settings.encoder.config, suppress_all_errors=self._suppress_all_errors, ) @@ -467,6 +490,7 @@ def _save_analysis_card_to_db_if_possible( _save_analysis_card_to_db( experiment=experiment, analysis_card=analysis_card, + # pyrefly: ignore [missing-attribute] sqa_config=self.db_settings.encoder.config, suppress_all_errors=self._suppress_all_errors, ) @@ -478,6 +502,7 @@ def _save_analysis_card_to_db_if_possible( # ------------- Utils for storage that assume `DBSettings` are provided -------- +# pyrefly: ignore [not-callable] @retry_on_exception( retries=3, default_return_on_suppression=False, @@ -493,7 +518,9 @@ def _save_experiment_to_db_if_possible( start_time = time.time() _save_experiment( experiment, + # pyrefly: ignore [bad-argument-type] encoder=encoder, + # pyrefly: ignore [bad-argument-type] decoder=decoder, ) logger.debug( @@ -502,6 +529,7 @@ def _save_experiment_to_db_if_possible( ) +# pyrefly: ignore [not-callable] @retry_on_exception( retries=3, default_return_on_suppression=False, @@ -520,7 +548,9 @@ def _save_or_update_trials_in_db_if_possible( _save_or_update_trials( experiment=experiment, trials=trials, + # pyrefly: ignore [bad-argument-type] encoder=encoder, + # pyrefly: ignore [bad-argument-type] decoder=decoder, batch_size=STORAGE_MINI_BATCH_SIZE, reduce_state_generator_runs=reduce_state_generator_runs, @@ -532,6 +562,7 @@ def _save_or_update_trials_in_db_if_possible( ) +# pyrefly: ignore [not-callable] @retry_on_exception( retries=3, default_return_on_suppression=False, @@ -547,7 +578,9 @@ def _save_generation_strategy_to_db_if_possible( start_time = time.time() _save_generation_strategy( generation_strategy=generation_strategy, + # pyrefly: ignore [bad-argument-type] encoder=encoder, + # pyrefly: ignore [bad-argument-type] decoder=decoder, ) logger.debug( @@ -556,6 +589,7 @@ def _save_generation_strategy_to_db_if_possible( ) +# pyrefly: ignore [not-callable] @retry_on_exception( retries=3, default_return_on_suppression=False, @@ -574,7 +608,9 @@ def _update_generation_strategy_in_db_if_possible( _update_generation_strategy( generation_strategy=generation_strategy, generator_runs=new_generator_runs, + # pyrefly: ignore [bad-argument-type] encoder=encoder, + # pyrefly: ignore [bad-argument-type] decoder=decoder, batch_size=STORAGE_MINI_BATCH_SIZE, reduce_state_generator_runs=reduce_state_generator_runs, @@ -586,6 +622,7 @@ def _update_generation_strategy_in_db_if_possible( ) +# pyrefly: ignore [not-callable] @retry_on_exception( retries=3, default_return_on_suppression=False, @@ -600,10 +637,17 @@ def _update_runner_on_experiment_in_db_if_possible( suppress_all_errors: bool, # Used by the decorator. ) -> None: update_runner_on_experiment( - experiment=experiment, runner=runner, encoder=encoder, decoder=decoder + # pyrefly: ignore [bad-argument-type] + experiment=experiment, + runner=runner, + # pyrefly: ignore [bad-argument-type] + encoder=encoder, + # pyrefly: ignore [bad-argument-type] + decoder=decoder, ) +# pyrefly: ignore [not-callable] @retry_on_exception( retries=3, default_return_on_suppression=False, @@ -621,6 +665,7 @@ def _update_experiment_properties_in_db( ) +# pyrefly: ignore [not-callable] @retry_on_exception( retries=3, default_return_on_suppression=False, @@ -651,6 +696,7 @@ def try_load_generation_strategy( start_time = time.time() generation_strategy = _load_generation_strategy_by_experiment_name( experiment_name=experiment_name, + # pyrefly: ignore [bad-argument-type] decoder=decoder, experiment=experiment, reduced_state=reduced_state, diff --git a/ax/storage/utils.py b/ax/storage/utils.py index fe48eee0767..54df52b90af 100644 --- a/ax/storage/utils.py +++ b/ax/storage/utils.py @@ -82,6 +82,7 @@ def data_to_data_by_trial(data: Data) -> dict[int, OrderedDict[int, Data]]: """ if len(data.full_df) == 0: return {} + # pyrefly: ignore [bad-return] return { trial_index: OrderedDict([(0, Data(df=df))]) for trial_index, df in data.full_df.groupby("trial_index") diff --git a/ax/utils/common/mock.py b/ax/utils/common/mock.py index d2687b64ffb..15b094a5258 100644 --- a/ax/utils/common/mock.py +++ b/ax/utils/common/mock.py @@ -20,6 +20,7 @@ def mock_patch_method_original( mock_path: str, original_method: Callable[..., T], + # pyrefly: ignore [bad-return] ) -> MagicMock: """Context manager for patching a method returning type T on class C, to track calls to it while still executing the original method. There diff --git a/ax/utils/common/random.py b/ax/utils/common/random.py index a9fb6e45329..fff35bfe28c 100644 --- a/ax/utils/common/random.py +++ b/ax/utils/common/random.py @@ -22,6 +22,7 @@ def set_rng_seed(seed: int) -> None: seed: The random number generator seed. """ random.seed(seed) + # pyrefly: ignore [bad-argument-type] np.random.seed(seed) torch.manual_seed(seed) diff --git a/ax/utils/common/tests/test_executils.py b/ax/utils/common/tests/test_executils.py index 738490ddd5a..65ae4a2482e 100644 --- a/ax/utils/common/tests/test_executils.py +++ b/ax/utils/common/tests/test_executils.py @@ -28,6 +28,7 @@ def test_default_return(self) -> None: """ class DecoratorTester: + # pyrefly: ignore [not-callable] @retry_on_exception( suppress_all_errors=True, default_return_on_suppression="SUCCESS" ) @@ -44,6 +45,7 @@ def test_kwarg_passage(self) -> None: """ class DecoratorTester: + # pyrefly: ignore [not-callable] @retry_on_exception(default_return_on_suppression="SUCCESS") def error_throwing_function( self, suppress_all_errors=False, extra_kwarg="1234" @@ -70,6 +72,7 @@ def test_message_checking(self) -> None: logger = getLogger("test_message_checking") class DecoratorTester: + # pyrefly: ignore [not-callable] @retry_on_exception( default_return_on_suppression="SUCCESS", check_message_contains=["Hello", "World"], @@ -95,6 +98,7 @@ def test_empty_exception_type_tuple(self) -> None: logger = getLogger("test_message_checking") class DecoratorTester: + # pyrefly: ignore [not-callable] @retry_on_exception( default_return_on_suppression="SUCCESS", exception_types=(), @@ -117,6 +121,7 @@ def test_message_checking_fail(self) -> None: """ class DecoratorTester: + # pyrefly: ignore [not-callable] @retry_on_exception( default_return_on_suppression="SUCCESS", check_message_contains=["Hello", "World"], @@ -140,6 +145,7 @@ class DecoratorTester: def __init__(self) -> None: self.retries_done = 0 + # pyrefly: ignore [not-callable] @retry_on_exception(retries=4) def error_throwing_function(self) -> None: # The call below will succeed only on the 3rd try @@ -152,6 +158,7 @@ def succeed_on_3rd_try(self) -> None: "This error surfacing means enough retries were not done" ) else: + # pyrefly: ignore [bad-return] return "SUCCESS" decorator_tester = DecoratorTester() @@ -166,11 +173,13 @@ class DecoratorTester: def __init__(self) -> None: self.start_time = time.time() + # pyrefly: ignore [not-callable] @retry_on_exception(retries=4, initial_wait_seconds=1) def error_throwing_function(self) -> None: # The call below will succeed only on the 3rd try return self.succeed_after_five_seconds() + # pyrefly: ignore [not-callable] @retry_on_exception(retries=4) def no_wait_error_throwing_function(self) -> None: # The call below will succeed only on the 3rd try @@ -182,6 +191,7 @@ def succeed_after_five_seconds(self) -> None: "This error surfacing means enough retries were not done" ) else: + # pyrefly: ignore [bad-return] return "SUCCESS" decorator_tester = DecoratorTester() @@ -203,6 +213,7 @@ class DecoratorTester: def __init__(self) -> None: self.xyz = 0 + # pyrefly: ignore [not-callable] @retry_on_exception(retries=2, logger=logger) def error_throwing_function(self) -> None: # The call below will succeed only on the 3rd try @@ -213,6 +224,7 @@ def succeed_on_3rd_try(self) -> None: self.xyz += 1 raise KeyError else: + # pyrefly: ignore [bad-return] return "SUCCESS" decorator_tester = DecoratorTester() @@ -226,6 +238,7 @@ class MyRuntimeError(RuntimeError): class DecoratorTester: error_throwing_function_call_count = 0 + # pyrefly: ignore [not-callable] @retry_on_exception(no_retry_on_exception_types=(MyRuntimeError,)) def error_throwing_function(self) -> None: self.error_throwing_function_call_count += 1 @@ -243,6 +256,7 @@ def error_throwing_function(self) -> None: class DecoratorTester: error_throwing_function_call_count = 0 + # pyrefly: ignore [not-callable] @retry_on_exception( exception_types=(RuntimeError,), no_retry_on_exception_types=(MyRuntimeError,), @@ -264,6 +278,7 @@ def test_on_function_with_wrapper_message(self) -> None: mock: Mock = Mock() + # pyrefly: ignore [not-callable] @retry_on_exception(wrap_error_message_in="Wrapper error message") def error_throwing_function() -> None: mock() diff --git a/ax/utils/common/tests/test_random.py b/ax/utils/common/tests/test_random.py index 8fe4fda46a9..9621269af45 100644 --- a/ax/utils/common/tests/test_random.py +++ b/ax/utils/common/tests/test_random.py @@ -19,6 +19,7 @@ def test_set_rng_seed(self) -> None: # Set the seeds manually & using the helper, and compares the random numbers. seed = 0 random.seed(seed) + # pyrefly: ignore [bad-argument-type] np.random.seed(seed) torch.manual_seed(seed) native_rand = random.random() diff --git a/ax/utils/common/tests/test_result.py b/ax/utils/common/tests/test_result.py index f67a7891d3a..d033d901140 100644 --- a/ax/utils/common/tests/test_result.py +++ b/ax/utils/common/tests/test_result.py @@ -47,14 +47,20 @@ def g(val: str) -> int: def h() -> int: return -1 + # pyrefly: ignore [bad-argument-type] self.assertEqual(self.ok.map(op=f), Ok(1)) self.assertEqual(self.ok.map_err(op=g), Ok(0)) + # pyrefly: ignore [bad-argument-type] self.assertEqual(self.ok.map_or(default="foo", op=f), 1) + # pyrefly: ignore [bad-argument-type] self.assertEqual(self.ok.map_or_else(default_op=h, op=f), 1) + # pyrefly: ignore [bad-argument-type] self.assertEqual(self.err.map(op=f), Err("yikes")) self.assertEqual(self.err.map_err(op=g), Err(5)) + # pyrefly: ignore [bad-argument-type] self.assertEqual(self.err.map_or(default="foo", op=f), "foo") + # pyrefly: ignore [bad-argument-type] self.assertEqual(self.err.map_or_else(default_op=h, op=f), -1) def test_unwrap(self) -> None: @@ -62,6 +68,7 @@ def test_unwrap(self) -> None: with self.assertRaises(RuntimeError): self.ok.unwrap_err() self.assertEqual(self.ok.unwrap_or(1), 0) + # pyrefly: ignore [bad-argument-type] self.assertEqual(self.ok.unwrap_or_else(1), 0) with self.assertRaises(RuntimeError): diff --git a/ax/utils/common/testutils.py b/ax/utils/common/testutils.py index 1384cd92496..07f0e9702f4 100644 --- a/ax/utils/common/testutils.py +++ b/ax/utils/common/testutils.py @@ -343,7 +343,10 @@ def setUp(self) -> None: ) def run( - self, result: unittest.result.TestResult | None = ... + # pyrefly: ignore [bad-function-definition] + self, + # pyrefly: ignore [bad-function-definition] + result: unittest.result.TestResult | None = ..., ) -> unittest.result.TestResult | None: # Arrange for a SIGALRM signal to be delivered to the calling process # in specified number of seconds. diff --git a/ax/utils/common/timeutils.py b/ax/utils/common/timeutils.py index 848f3183da7..fdae9b8b5f5 100644 --- a/ax/utils/common/timeutils.py +++ b/ax/utils/common/timeutils.py @@ -28,6 +28,7 @@ def to_ts(ds: str) -> datetime: def _ts_to_pandas(ts: int) -> pd.Timestamp: """Convert int timestamp into pandas timestamp.""" + # pyrefly: ignore [bad-return] return pd.Timestamp(datetime.fromtimestamp(ts)) diff --git a/ax/utils/measurement/synthetic_functions.py b/ax/utils/measurement/synthetic_functions.py index 76905bc7fb0..c499d78e28f 100644 --- a/ax/utils/measurement/synthetic_functions.py +++ b/ax/utils/measurement/synthetic_functions.py @@ -193,6 +193,7 @@ class Hartmann6(SyntheticFunction): _required_dimensionality = 6 _domain: list[tuple[float, float]] = [(0.0, 1.0) for i in range(6)] _minimums = [(0.20169, 0.150011, 0.476874, 0.275332, 0.311652, 0.6573)] + # pyrefly: ignore [bad-override-mutable-attribute] _fmin: float = -3.32237 _fmax = 0.0 _alpha: npt.NDArray = np.array([1.0, 1.2, 3.0, 3.2]) @@ -230,6 +231,7 @@ class Aug_Hartmann6(Hartmann6): _required_dimensionality = 7 _domain: list[tuple[float, float]] = [(0.0, 1.0) for i in range(7)] + # pyrefly: ignore [bad-override-mutable-attribute] _minimums: list[tuple[float, ...]] = [ (0.20169, 0.150011, 0.476874, 0.275332, 0.311652, 0.6573, 1.0) ] @@ -257,6 +259,7 @@ class Branin(SyntheticFunction): _required_dimensionality = 2 _domain: list[tuple[float, float]] = [(-5.0, 10.0), (0.0, 15.0)] + # pyrefly: ignore [bad-override-mutable-attribute] _minimums: list[tuple[float, float]] = [ (-np.pi, 12.275), (np.pi, 2.275), @@ -282,6 +285,7 @@ class Aug_Branin(SyntheticFunction): _required_dimensionality = 3 _domain: list[tuple[float, float]] = [(-5.0, 10.0), (0.0, 15.0), (0.0, 1.0)] + # pyrefly: ignore [bad-override-mutable-attribute] _minimums: list[tuple[float, float, float]] = [ (-np.pi, 12.275, 1.0), (np.pi, 2.275, 1.0), diff --git a/ax/utils/sensitivity/sobol_measures.py b/ax/utils/sensitivity/sobol_measures.py index 8fd5af8df7f..094924c6009 100644 --- a/ax/utils/sensitivity/sobol_measures.py +++ b/ax/utils/sensitivity/sobol_measures.py @@ -1022,6 +1022,7 @@ def _get_model_per_metric( if metric_model.num_outputs > 1: # subset to relevant output metric_model = metric_model.subset_output([i]) model_list.append(metric_model) + # pyrefly: ignore [bad-return] return model_list diff --git a/ax/utils/sensitivity/tests/test_sensitivity.py b/ax/utils/sensitivity/tests/test_sensitivity.py index 312b7ddb3c0..2d76008c7e4 100644 --- a/ax/utils/sensitivity/tests/test_sensitivity.py +++ b/ax/utils/sensitivity/tests/test_sensitivity.py @@ -131,6 +131,7 @@ def test_DgsmGpMean_batched_gradient_equivalence(self) -> None: input_mc_samples.requires_grad = True posterior = self.model.posterior(input_mc_samples) torch.sum(posterior.mean).backward() + # pyrefly: ignore [missing-attribute] grad_unbatched_raw = input_mc_samples.grad.clone() # Aggregate the same way as gradient_measure(): mean per dimension grad_unbatched = torch.tensor( @@ -155,6 +156,7 @@ def test_DgsmGpMean_batched_gradient_equivalence(self) -> None: posterior_saas = self.saas_model.posterior(input_mc_samples_saas) torch.sum(posterior_saas.mean).backward() grad_unbatched_saas = torch.tensor( + # pyrefly: ignore [unsupported-operation] [torch.mean(input_mc_samples_saas.grad[:, i]) for i in range(2)] ) diff --git a/ax/utils/stats/no_effects.py b/ax/utils/stats/no_effects.py index 15b744f2f8f..fcb10b80cde 100644 --- a/ax/utils/stats/no_effects.py +++ b/ax/utils/stats/no_effects.py @@ -52,6 +52,7 @@ def check_experiment_effects_per_metric( df_tone = pd.DataFrame(columns=cols) + # pyrefly: ignore [not-iterable] for metric_name, trial_index in df_grouped.groups.keys(): dfm = df_grouped.get_group((metric_name, trial_index)) @@ -150,6 +151,7 @@ def check_experiment_effects( effective = np.min(ps) < no_effect_alpha bounds_df = pd.DataFrame(fx_bounds, columns=["metric_name", "min", "max", "p"]) bounds_df.sort_values(by="p", inplace=True) + # pyrefly: ignore [bad-return] return effective, ineffective_on_objectives, bounds_df diff --git a/ax/utils/stats/tests/test_model_fit_stats.py b/ax/utils/stats/tests/test_model_fit_stats.py index 3ec8bf83bdc..6072e3b3435 100644 --- a/ax/utils/stats/tests/test_model_fit_stats.py +++ b/ax/utils/stats/tests/test_model_fit_stats.py @@ -53,6 +53,7 @@ def test_kendall_tau_rank_correlation_perfect_negative(self) -> None: self.assertAlmostEqual(ax_result, -1.0) def test_entropy_of_observations(self) -> None: + # pyrefly: ignore [bad-argument-type] np.random.seed(1234) n = 16 yc = np.ones(n) diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index 926e7704e0d..e28be07d71b 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -255,6 +255,7 @@ def get_experiment_with_custom_runner_and_metric( optimization_config = OptimizationConfig( objective=custom_scalarized_objective, + # pyrefly: ignore [bad-argument-type] outcome_constraints=outcome_constraints, ) else: @@ -557,8 +558,11 @@ def get_branin_experiment_with_timestamp_map_metric( ] # Add objective metrics so the Experiment owns the real metric types. # pyre-ignore[6]: Covariance issue with list[T] + # pyrefly: ignore [bad-argument-type] tracking_metrics.extend( - local_get_map_metric(f"branin_map_{m}") for m in range(num_objectives) + # pyrefly: ignore [bad-argument-type] + local_get_map_metric(f"branin_map_{m}") + for m in range(num_objectives) ) if with_outcome_constraint: tracking_metrics.append( @@ -1890,9 +1894,11 @@ def _get_candidate_metadata(self, arm_name: str) -> dict[str, Any] | None: def _get_candidate_metadata_from_all_generator_runs( self, + # pyrefly: ignore [bad-override] ) -> dict[str, dict[str, Any] | None]: return {"test": None} + # pyrefly: ignore [bad-override] def abandoned_arms(self) -> str: return "test" @@ -1900,13 +1906,18 @@ def abandoned_arms(self) -> str: def arms(self) -> list[Arm]: return self._arms + # pyrefly: ignore [bad-override] @arms.setter def arms(self, val: list[Arm]) -> None: self._arms = val + # pyrefly: ignore [bad-override] + + # pyrefly: ignore [bad-override] def arms_by_name(self) -> str: return "test" + # pyrefly: ignore [bad-override] def generator_runs(self) -> str: return "test" @@ -2386,12 +2397,14 @@ def get_branin_multi_objective_optimization_config( objective_thresholds.append( ObjectiveThreshold( metric=get_branin_metric(name="branin_c"), + # pyrefly: ignore [bad-assignment] bound=5.0, op=ComparisonOp.LEQ, relative=False, ) ) else: + # pyrefly: ignore [bad-assignment] objective_thresholds = None outcome_constraints = [] if with_relative_constraint: @@ -2470,11 +2483,15 @@ def get_arms() -> list[Arm]: return list(get_arm_weights1().keys()) +# pyrefly: ignore [bad-argument-type] + + def get_weights() -> list[float]: return list(get_arm_weights1().values()) def get_branin_arms(n: int, seed: int) -> list[Arm]: + # pyrefly: ignore [bad-argument-type] np.random.seed(seed) x1_raw = np.random.rand(n) x2_raw = np.random.rand(n) @@ -2690,6 +2707,7 @@ def get_branin_data_batch( fill_vals = fill_vals or {} metrics = metrics or ["branin"] for arm in batch.arms: + # pyrefly: ignore [bad-argument-type] params = arm.parameters for k, v in fill_vals.items(): if params.get(k, None) is None: @@ -2698,6 +2716,7 @@ def get_branin_data_batch( means.append(5.0) else: means.append( + # pyrefly: ignore [bad-argument-type] branin( float(none_throws(params["x1"])), float(none_throws(params["x2"])), diff --git a/ax/utils/testing/modeling_stubs.py b/ax/utils/testing/modeling_stubs.py index adba7c17472..09923fead95 100644 --- a/ax/utils/testing/modeling_stubs.py +++ b/ax/utils/testing/modeling_stubs.py @@ -564,6 +564,7 @@ def get_surrogate_generation_step() -> GenerationStep: Note: This is kept for backward compatibility testing. New code should use get_surrogate_generation_node() instead. """ + # pyrefly: ignore [bad-return] return GenerationStep( generator=Generators.BOTORCH_MODULAR, num_trials=-1, diff --git a/ax/utils/testing/preference_stubs.py b/ax/utils/testing/preference_stubs.py index 5f4ae900924..7e07c4c28d1 100644 --- a/ax/utils/testing/preference_stubs.py +++ b/ax/utils/testing/preference_stubs.py @@ -64,6 +64,7 @@ def experimental_metric_eval( ) for metric_name in metric_names } + # pyrefly: ignore [bad-return] return result_dict