From 36a52558adc6b8b5b65946670b4216f54a7543f4 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Thu, 18 Jun 2026 12:23:15 -0700 Subject: [PATCH] Add ScaleRBFLinearKernel covar module MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Following the `ScaleMaternKernel` pattern in Ax botorch_modular, this adds `ScaleRBFLinearKernel` — a sum of `ScaleKernel(RBF-ARD)` and `LinearKernel`. It is the kernel-side analog of a linear mean: the `LinearKernel` models linear-in-input correlations (e.g. a scaling-law trend) through the covariance rather than the mean, while the `ScaleKernel`-wrapped ARD RBF captures local flexibility and lets the model learn the relative amplitude of the local vs. global signal. This is the kernel-only counterpart of the `RBFLinearGP` model studied in D108386392, designed to plug into a plain `SingleTaskGP` via `ModelConfig(covar_module_class=...)`. `ScaleRBFLinearKernel` subclasses GPyTorch's `AdditiveKernel` so that `isinstance` checks and storage serialization work cleanly. Its `__init__` mirrors `ScaleMaternKernel`, adding a `variance_prior` for the linear component and an `active_dims` argument. `active_dims` is applied to both component kernels: it is set on the inner `RBFKernel` and on the `LinearKernel`, and the `ScaleKernel` inherits `active_dims` from the kernel it wraps. Because GPyTorch subsets inputs once at the outermost kernel in a call chain (and `ScaleKernel.forward` calls `base_kernel.forward` directly), the input is subset exactly once rather than twice. Threading `active_dims` through the input constructor also enables the `remove_task_features` workflow (excluding a task feature from the kernel for a `SingleTaskGP` on a `MultiTaskDataset`). Changes: - `kernels.py`: new `ScaleRBFLinearKernel(AdditiveKernel)` with `active_dims` support. 
- `input_constructors/covar_modules.py`: rename `_covar_module_argparse_scale_matern` to `_covar_module_argparse_scale_kernel` and register it for both `ScaleMaternKernel` and `ScaleRBFLinearKernel`, reusing `_get_default_ard_num_dims_and_batch_shape` to infer `ard_num_dims`/`batch_shape`/`active_dims` from the dataset and model class. - `botorch_modular_registry.py`: add `ScaleRBFLinearKernel` to `KERNEL_REGISTRY` (this automatically wires up JSON/SQA encode+decode). - Tests in `test_kernels.py` (incl. active-dims single-subsetting checks for `ScaleRBFLinearKernel` and for `ScaleMaternKernel`, which also gains an `active_dims` argument) and `test_covar_modules_argparse.py` (incl. active-dims and remove_task_features cases). Differential Revision: D109044330 --- .../input_constructors/covar_modules.py | 22 +++-- .../torch/botorch_modular/kernels.py | 76 ++++++++++++++- .../tests/test_covar_modules_argparse.py | 92 ++++++++++++++++++- ax/generators/torch/tests/test_kernels.py | 90 +++++++++++++++++- ax/storage/botorch_modular_registry.py | 2 + 5 files changed, 273 insertions(+), 9 deletions(-) diff --git a/ax/generators/torch/botorch_modular/input_constructors/covar_modules.py b/ax/generators/torch/botorch_modular/input_constructors/covar_modules.py index 120a4bb0a78..acb327422e6 100644 --- a/ax/generators/torch/botorch_modular/input_constructors/covar_modules.py +++ b/ax/generators/torch/botorch_modular/input_constructors/covar_modules.py @@ -19,6 +19,7 @@ DefaultMaternKernel, DefaultRBFKernel, ScaleMaternKernel, + ScaleRBFLinearKernel, ) from ax.utils.common.logger import get_logger from ax.utils.common.typeutils import _argparse_type_encoder @@ -64,19 +65,26 @@ def _covar_module_argparse_base( return {**kwargs} -@covar_module_argparse.register(ScaleMaternKernel) -def _covar_module_argparse_scale_matern( - covar_module_class: type[ScaleMaternKernel], +@covar_module_argparse.register((ScaleMaternKernel, ScaleRBFLinearKernel)) +def _covar_module_argparse_scale_kernel( + covar_module_class: type[ScaleMaternKernel] | type[ScaleRBFLinearKernel], botorch_model_class: type[Model], dataset: SupervisedDataset, ard_num_dims: int | _DefaultType = DEFAULT, + active_dims: 
Sequence[int] | None = None, batch_shape: torch.Size | _DefaultType = DEFAULT, lengthscale_prior: Prior | None = None, outputscale_prior: Prior | None = None, remove_task_features: bool = False, **kwargs: Any, ) -> dict[str, Any]: - """Extract the base covar module kwargs form the given arguments. + """Extract the covar module kwargs for ``ScaleMaternKernel`` and + ``ScaleRBFLinearKernel``. + + Both kernels share the same configurable inputs: an ARD lengthscale prior and + an output scale prior for the (scaled) base kernel. For + ``ScaleRBFLinearKernel``, a ``variance_prior`` for the ``LinearKernel`` + summand may be supplied via ``kwargs`` and is passed through unchanged. NOTE: This setup does not allow for setting multi-dimensional priors, with different priors over lengthscales. @@ -87,8 +95,9 @@ def _covar_module_argparse_scale_matern( BoTorch model. dataset: Dataset containing feature matrix and the response. ard_num_dims: Number of lengthscales per feature. + active_dims: The active dimensions of the kernel. batch_shape: The number of lengthscales per batch. - lengthscale_prior: Lenthscale prior. + lengthscale_prior: Lengthscale prior. outputscale_prior: Outputscale prior. remove_task_features: A boolean indicating whether to remove the task features (e.g. when using a SingleTask model on a MultiTaskDataset). 
@@ -102,16 +111,17 @@ def _covar_module_argparse_scale_matern( botorch_model_class=botorch_model_class, dataset=dataset, remove_task_features=remove_task_features, + active_dims=active_dims, ) return _covar_module_argparse_base( covar_module_class=covar_module_class, dataset=dataset, botorch_model_class=botorch_model_class, ard_num_dims=ard_num_dims, + active_dims=active_dims, lengthscale_prior=lengthscale_prior, outputscale_prior=outputscale_prior, batch_shape=batch_shape, - active_dims=active_dims, **kwargs, ) diff --git a/ax/generators/torch/botorch_modular/kernels.py b/ax/generators/torch/botorch_modular/kernels.py index 8df29b26261..3d1398b68c6 100644 --- a/ax/generators/torch/botorch_modular/kernels.py +++ b/ax/generators/torch/botorch_modular/kernels.py @@ -25,7 +25,8 @@ from botorch.models.utils.gpytorch_modules import SQRT2, SQRT3 from gpytorch.constraints import Interval from gpytorch.constraints.constraints import GreaterThan -from gpytorch.kernels import PeriodicKernel +from gpytorch.kernels import AdditiveKernel, PeriodicKernel +from gpytorch.kernels.linear_kernel import LinearKernel from gpytorch.kernels.matern_kernel import MaternKernel from gpytorch.kernels.rbf_kernel import RBFKernel from gpytorch.kernels.scale_kernel import ScaleKernel @@ -36,6 +37,7 @@ class ScaleMaternKernel(ScaleKernel): def __init__( self, ard_num_dims: int | None = None, + active_dims: Sequence[int] | None = None, batch_shape: torch.Size | None = None, lengthscale_prior: Prior | None = None, outputscale_prior: Prior | None = None, @@ -46,6 +48,9 @@ def __init__( r""" Args: ard_num_dims: The number of lengthscales. + active_dims: The active input dimensions. The ``ScaleKernel`` + inherits its ``active_dims`` from the wrapped ``MaternKernel``, + so the input is subset exactly once (at the wrapper level). batch_shape: The batch shape. lengthscale_prior: The prior over the lengthscale parameter. outputscale_prior: The prior over the scaling parameter. 
@@ -57,6 +62,7 @@ def __init__( base_kernel = MaternKernel( nu=2.5, ard_num_dims=ard_num_dims, + active_dims=tuple(active_dims) if active_dims is not None else None, lengthscale_constraint=lengthscale_constraint, lengthscale_prior=lengthscale_prior, batch_shape=batch_shape, @@ -69,6 +75,74 @@ def __init__( ) +class ScaleRBFLinearKernel(AdditiveKernel): + r"""A sum of ``ScaleKernel(RBF-ARD)`` and ``LinearKernel``. + + Combines the local flexibility of an ARD RBF kernel with a global linear + covariance structure. The ``ScaleKernel`` wraps the RBF so the model can + learn the relative amplitude of the local (RBF) vs. global (linear) signal. + + The ``LinearKernel`` models linear-in-input correlations (e.g. a linear + scaling-law trend) through the kernel rather than the mean -- the kernel-side + analog of a linear mean. + """ + + def __init__( + self, + ard_num_dims: int | None = None, + active_dims: Sequence[int] | None = None, + batch_shape: torch.Size | None = None, + lengthscale_prior: Prior | None = None, + outputscale_prior: Prior | None = None, + variance_prior: Prior | None = None, + lengthscale_constraint: Interval | None = None, + outputscale_constraint: Interval | None = None, + **kwargs: Any, + ) -> None: + r""" + Args: + ard_num_dims: The number of lengthscales for the RBF kernel; also + forwarded to the ``LinearKernel``. When ``active_dims`` is provided, + this must equal the number of active dimensions. + active_dims: The active input dimensions. The same subset is applied + to both component kernels. The ``ScaleKernel`` inherits its + ``active_dims`` from the wrapped ``RBFKernel``, so the input is + subset exactly once (at the wrapper level) rather than twice. + batch_shape: The batch shape, shared by both component kernels. + lengthscale_prior: The prior over the RBF lengthscale parameter. + outputscale_prior: The prior over the RBF output scale parameter. + variance_prior: The prior over the linear kernel variance parameter. 
+ lengthscale_constraint: Optionally provide a lengthscale constraint + for the RBF kernel. + outputscale_constraint: Optionally provide an output scale constraint + for the ``ScaleKernel`` wrapping the RBF kernel. + kwargs: Additional keyword arguments passed to the ``LinearKernel``. + + Returns: None + """ + active_dims = tuple(active_dims) if active_dims is not None else None + scale_rbf_kernel = ScaleKernel( + RBFKernel( + ard_num_dims=ard_num_dims, + active_dims=active_dims, + lengthscale_constraint=lengthscale_constraint, + lengthscale_prior=lengthscale_prior, + batch_shape=batch_shape, + ), + outputscale_prior=outputscale_prior, + outputscale_constraint=outputscale_constraint, + batch_shape=batch_shape, + ) + linear_kernel = LinearKernel( + ard_num_dims=ard_num_dims, + active_dims=active_dims, + variance_prior=variance_prior, + batch_shape=batch_shape, + **kwargs, + ) + super().__init__(scale_rbf_kernel, linear_kernel) + + class TemporalKernel(ScaleKernel): """A product kernel of a periodic kernel and a Matern kernel. 
diff --git a/ax/generators/torch/tests/test_covar_modules_argparse.py b/ax/generators/torch/tests/test_covar_modules_argparse.py index 9017cd72036..828bfd09f42 100644 --- a/ax/generators/torch/tests/test_covar_modules_argparse.py +++ b/ax/generators/torch/tests/test_covar_modules_argparse.py @@ -19,11 +19,12 @@ DefaultMaternKernel, DefaultRBFKernel, ScaleMaternKernel, + ScaleRBFLinearKernel, ) from ax.utils.common.testutils import TestCase from botorch.models.gp_regression import SingleTaskGP from botorch.models.multitask import MultiTaskGP -from botorch.utils.datasets import SupervisedDataset +from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset from gpytorch.kernels.kernel import Kernel from gpytorch.kernels.linear_kernel import LinearKernel from gpytorch.priors import GammaPrior @@ -159,6 +160,95 @@ def test_argparse_scalematern_kernel(self) -> None: self.assertEqual(covar_module_kwargs["batch_shape"], torch.Size([])) + def test_argparse_scale_rbf_linear_kernel(self) -> None: + # SingleTaskGP infers ard_num_dims from the 10 features; the + # multi-output dataset (Y has 2 columns) yields a batch shape of [2]. + # MultiTaskGP drops the task feature, so ard_num_dims is 9 and the + # batch shape is empty. + expected = [ + {"ard_num_dims": 10, "batch_shape": torch.Size([2])}, + {"ard_num_dims": 9, "batch_shape": torch.Size([])}, + ] + for i, botorch_model_class in enumerate([SingleTaskGP, MultiTaskGP]): + covar_module_kwargs = covar_module_argparse( + ScaleRBFLinearKernel, + botorch_model_class=botorch_model_class, + dataset=self.dataset, + lengthscale_prior=GammaPrior(6.0, 3.0), + outputscale_prior=GammaPrior(2.0, 0.15), + variance_prior=GammaPrior(1.0, 1.0), + ) + self.assertEqual( + covar_module_kwargs["ard_num_dims"], expected[i]["ard_num_dims"] + ) + self.assertEqual( + covar_module_kwargs["batch_shape"], expected[i]["batch_shape"] + ) + # No active_dims requested, so none is set. 
+ self.assertIsNone(covar_module_kwargs["active_dims"]) + # Priors are passed straight through. + self.assertAlmostEqual( + covar_module_kwargs["lengthscale_prior"].concentration.item(), + 6.0, + places=4, + ) + self.assertAlmostEqual( + covar_module_kwargs["outputscale_prior"].rate.item(), 0.15, places=4 + ) + self.assertAlmostEqual( + covar_module_kwargs["variance_prior"].concentration.item(), + 1.0, + places=4, + ) + # The resulting kwargs can construct the kernel. + kernel = ScaleRBFLinearKernel(**covar_module_kwargs) + self.assertIsInstance(kernel, ScaleRBFLinearKernel) + + def test_argparse_scale_rbf_linear_kernel_active_dims(self) -> None: + # Explicit active_dims is passed through and normalized; ard_num_dims is + # set to the number of active dims. + covar_module_kwargs = covar_module_argparse( + ScaleRBFLinearKernel, + botorch_model_class=SingleTaskGP, + dataset=self.dataset, + active_dims=[0, 2, 4], + ) + self.assertEqual(covar_module_kwargs["active_dims"], [0, 2, 4]) + self.assertEqual(covar_module_kwargs["ard_num_dims"], 3) + kernel = ScaleRBFLinearKernel(**covar_module_kwargs) + self.assertIsInstance(kernel, ScaleRBFLinearKernel) + + def test_argparse_scale_rbf_linear_kernel_remove_task_features(self) -> None: + # A SingleTaskGP on a MultiTaskDataset with remove_task_features=True + # excludes the task feature from the kernel: active_dims drops the task + # column and ard_num_dims shrinks accordingly. 
+ n, d = 10, 5 + task_feature_index = d - 1 + X = torch.cat([torch.randn((n, d - 1)), torch.randint(0, 2, (n, 1))], dim=-1) + Y = torch.randn((n, 1)) + joint_dataset = SupervisedDataset( + X=X, + Y=Y, + feature_names=[f"x{j}" for j in range(d - 1)] + ["task"], + outcome_names=["y"], + ) + dataset = MultiTaskDataset.from_joint_dataset( + dataset=joint_dataset, + task_feature_index=task_feature_index, + target_task_value=1, + ) + covar_module_kwargs = covar_module_argparse( + ScaleRBFLinearKernel, + botorch_model_class=SingleTaskGP, + dataset=dataset, + remove_task_features=True, + ) + # The task feature (last column, index d - 1) is excluded. + self.assertEqual(covar_module_kwargs["active_dims"], list(range(d - 1))) + self.assertEqual(covar_module_kwargs["ard_num_dims"], d - 1) + kernel = ScaleRBFLinearKernel(**covar_module_kwargs) + self.assertIsInstance(kernel, ScaleRBFLinearKernel) + def test_argparse_default(self) -> None: for kernel_class in (DefaultRBFKernel, DefaultMaternKernel): with self.assertRaisesRegex(UserInputError, "Only one of"): diff --git a/ax/generators/torch/tests/test_kernels.py b/ax/generators/torch/tests/test_kernels.py index 0aab3c5a6d0..c2187f66cff 100644 --- a/ax/generators/torch/tests/test_kernels.py +++ b/ax/generators/torch/tests/test_kernels.py @@ -17,13 +17,21 @@ DefaultMaternKernel, DefaultRBFKernel, ScaleMaternKernel, + ScaleRBFLinearKernel, TemporalKernel, ) from ax.utils.common.testutils import TestCase from botorch.models.utils.gpytorch_modules import get_covar_module_with_dim_scaled_prior from gpytorch.constraints import Positive -from gpytorch.kernels import MaternKernel, PeriodicKernel +from gpytorch.kernels import ( + LinearKernel, + MaternKernel, + PeriodicKernel, + RBFKernel, + ScaleKernel, +) from gpytorch.priors import GammaPrior +from pyre_extensions import assert_is_instance class KernelsTest(TestCase): @@ -48,6 +56,86 @@ def test_scalematern_kernel(self) -> None: # `concentration`. 
self.assertEqual(covar.outputscale_prior.concentration, 2.0) self.assertEqual(covar.base_kernel.batch_shape[0], 2) + self.assertIsNone(covar.active_dims) + + def test_scalematern_kernel_active_dims(self) -> None: + active_dims = [0, 2] + covar = ScaleMaternKernel( + ard_num_dims=len(active_dims), active_dims=active_dims + ) + base_kernel = assert_is_instance(covar.base_kernel, MaternKernel) + # active_dims lands on the inner MaternKernel, and the ScaleKernel + # inherits it (so subsetting happens exactly once at the wrapper level). + self.assertEqual(covar.active_dims.tolist(), active_dims) + self.assertEqual(base_kernel.active_dims.tolist(), active_dims) + self.assertEqual(base_kernel.ard_num_dims, len(active_dims)) + # The kernel only consumes the active columns: perturbing an inactive + # column leaves the covariance unchanged. + X = torch.randn(5, 3) + X_perturbed = X.clone() + X_perturbed[:, 1] = torch.randn(5) # column 1 is inactive + self.assertTrue( + torch.allclose(covar(X).to_dense(), covar(X_perturbed).to_dense()) + ) + + def test_scale_rbf_linear_kernel(self) -> None: + covar = ScaleRBFLinearKernel( + ard_num_dims=10, + lengthscale_prior=GammaPrior(6.0, 3.0), + outputscale_prior=GammaPrior(2.0, 0.15), + variance_prior=GammaPrior(1.0, 1.0), + batch_shape=torch.Size([2]), + ) + # The kernel is a sum of a ScaleKernel(RBF) and a LinearKernel. + scale_rbf = assert_is_instance(covar.kernels[0], ScaleKernel) + linear = assert_is_instance(covar.kernels[1], LinearKernel) + rbf = assert_is_instance(scale_rbf.base_kernel, RBFKernel) + # RBF lengthscale prior and ard_num_dims. + self.assertEqual(rbf.ard_num_dims, 10) + lengthscale_prior = assert_is_instance(rbf.lengthscale_prior, GammaPrior) + self.assertEqual(lengthscale_prior.rate, 3.0) + self.assertEqual(lengthscale_prior.concentration, 6.0) + # Outputscale prior on the ScaleKernel. 
+ outputscale_prior = assert_is_instance(scale_rbf.outputscale_prior, GammaPrior) + self.assertEqual(outputscale_prior.rate, 0.15) + self.assertEqual(outputscale_prior.concentration, 2.0) + # Variance prior on the LinearKernel. + variance_prior = assert_is_instance(linear.variance_prior, GammaPrior) + self.assertEqual(variance_prior.rate, 1.0) + self.assertEqual(variance_prior.concentration, 1.0) + # Batch shape is shared by both components. + self.assertEqual(rbf.batch_shape, torch.Size([2])) + self.assertEqual(linear.batch_shape, torch.Size([2])) + # The kernel evaluates to a covariance of the expected batched shape. + X = torch.randn(2, 5, 10) + covar_matrix = covar(X).to_dense() + self.assertEqual(covar_matrix.shape, torch.Size([2, 5, 5])) + + def test_scale_rbf_linear_kernel_active_dims(self) -> None: + active_dims = [0, 2] + covar = ScaleRBFLinearKernel( + ard_num_dims=len(active_dims), + active_dims=active_dims, + ) + scale_rbf = assert_is_instance(covar.kernels[0], ScaleKernel) + linear = assert_is_instance(covar.kernels[1], LinearKernel) + rbf = assert_is_instance(scale_rbf.base_kernel, RBFKernel) + # active_dims lands on both leaves, and the ScaleKernel inherits it + # from the wrapped RBF kernel (so subsetting happens exactly once). + self.assertEqual(scale_rbf.active_dims.tolist(), active_dims) + self.assertEqual(rbf.active_dims.tolist(), active_dims) + self.assertEqual(linear.active_dims.tolist(), active_dims) + # ard_num_dims matches the number of active dims. + self.assertEqual(rbf.ard_num_dims, len(active_dims)) + # The kernel only consumes the active columns: it produces the same + # covariance regardless of the values in the inactive column. 
+ X = torch.randn(5, 3) + X_perturbed = X.clone() + X_perturbed[:, 1] = torch.randn(5) # column 1 is inactive + covar_matrix = covar(X).to_dense() + covar_matrix_perturbed = covar(X_perturbed).to_dense() + self.assertTrue(torch.allclose(covar_matrix, covar_matrix_perturbed)) + self.assertEqual(covar_matrix.shape, torch.Size([5, 5])) def test_temporal_kernel(self) -> None: ls_prior = GammaPrior(6.0, 3.0) diff --git a/ax/storage/botorch_modular_registry.py b/ax/storage/botorch_modular_registry.py index 6328349c8f5..4dba5039f44 100644 --- a/ax/storage/botorch_modular_registry.py +++ b/ax/storage/botorch_modular_registry.py @@ -16,6 +16,7 @@ DefaultMaternKernel, DefaultRBFKernel, ScaleMaternKernel, + ScaleRBFLinearKernel, ) from ax.generators.torch.botorch_modular.multi_acquisition import MultiAcquisition @@ -193,6 +194,7 @@ KERNEL_REGISTRY: dict[type[Kernel], str] = { LinearKernel: "LinearKernel", ScaleMaternKernel: "ScaleMaternKernel", + ScaleRBFLinearKernel: "ScaleRBFLinearKernel", RBFKernel: "RBFKernel", DefaultRBFKernel: "DefaultRBFKernel", DefaultMaternKernel: "DefaultMaternKernel",