Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
DefaultMaternKernel,
DefaultRBFKernel,
ScaleMaternKernel,
ScaleRBFLinearKernel,
)
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import _argparse_type_encoder
Expand Down Expand Up @@ -64,19 +65,26 @@ def _covar_module_argparse_base(
return {**kwargs}


@covar_module_argparse.register(ScaleMaternKernel)
def _covar_module_argparse_scale_matern(
covar_module_class: type[ScaleMaternKernel],
@covar_module_argparse.register((ScaleMaternKernel, ScaleRBFLinearKernel))
def _covar_module_argparse_scale_kernel(
covar_module_class: type[ScaleMaternKernel] | type[ScaleRBFLinearKernel],
botorch_model_class: type[Model],
dataset: SupervisedDataset,
ard_num_dims: int | _DefaultType = DEFAULT,
active_dims: Sequence[int] | None = None,
batch_shape: torch.Size | _DefaultType = DEFAULT,
lengthscale_prior: Prior | None = None,
outputscale_prior: Prior | None = None,
remove_task_features: bool = False,
**kwargs: Any,
) -> dict[str, Any]:
"""Extract the base covar module kwargs form the given arguments.
"""Extract the covar module kwargs for ``ScaleMaternKernel`` and
``ScaleRBFLinearKernel``.

Both kernels share the same configurable inputs: an ARD lengthscale prior and
an output scale prior for the (scaled) base kernel. For
``ScaleRBFLinearKernel``, the ``LinearKernel`` summand uses its default
variance prior.

NOTE: This setup does not allow for setting multi-dimensional priors,
with different priors over lengthscales.
Expand All @@ -87,8 +95,9 @@ def _covar_module_argparse_scale_matern(
BoTorch model.
dataset: Dataset containing feature matrix and the response.
ard_num_dims: Number of lengthscales per feature.
active_dims: The active dimensions of the kernel.
batch_shape: The number of lengthscales per batch.
lengthscale_prior: Lenthscale prior.
lengthscale_prior: Lengthscale prior.
outputscale_prior: Outputscale prior.
remove_task_features: A boolean indicating whether to remove the task
features (e.g. when using a SingleTask model on a MultiTaskDataset).
Expand All @@ -102,16 +111,17 @@ def _covar_module_argparse_scale_matern(
botorch_model_class=botorch_model_class,
dataset=dataset,
remove_task_features=remove_task_features,
active_dims=active_dims,
)
return _covar_module_argparse_base(
covar_module_class=covar_module_class,
dataset=dataset,
botorch_model_class=botorch_model_class,
ard_num_dims=ard_num_dims,
active_dims=active_dims,
lengthscale_prior=lengthscale_prior,
outputscale_prior=outputscale_prior,
batch_shape=batch_shape,
active_dims=active_dims,
**kwargs,
)

Expand Down
76 changes: 75 additions & 1 deletion ax/generators/torch/botorch_modular/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
from botorch.models.utils.gpytorch_modules import SQRT2, SQRT3
from gpytorch.constraints import Interval
from gpytorch.constraints.constraints import GreaterThan
from gpytorch.kernels import PeriodicKernel
from gpytorch.kernels import AdditiveKernel, PeriodicKernel
from gpytorch.kernels.linear_kernel import LinearKernel
from gpytorch.kernels.matern_kernel import MaternKernel
from gpytorch.kernels.rbf_kernel import RBFKernel
from gpytorch.kernels.scale_kernel import ScaleKernel
Expand All @@ -36,6 +37,7 @@ class ScaleMaternKernel(ScaleKernel):
def __init__(
self,
ard_num_dims: int | None = None,
active_dims: Sequence[int] | None = None,
batch_shape: torch.Size | None = None,
lengthscale_prior: Prior | None = None,
outputscale_prior: Prior | None = None,
Expand All @@ -46,6 +48,9 @@ def __init__(
r"""
Args:
ard_num_dims: The number of lengthscales.
active_dims: The active input dimensions. The ``ScaleKernel``
inherits its ``active_dims`` from the wrapped ``MaternKernel``,
so the input is subset exactly once (at the wrapper level).
batch_shape: The batch shape.
lengthscale_prior: The prior over the lengthscale parameter.
outputscale_prior: The prior over the scaling parameter.
Expand All @@ -57,6 +62,7 @@ def __init__(
base_kernel = MaternKernel(
nu=2.5,
ard_num_dims=ard_num_dims,
active_dims=tuple(active_dims) if active_dims is not None else None,
lengthscale_constraint=lengthscale_constraint,
lengthscale_prior=lengthscale_prior,
batch_shape=batch_shape,
Expand All @@ -69,6 +75,74 @@ def __init__(
)


class ScaleRBFLinearKernel(AdditiveKernel):
r"""A sum of ``ScaleKernel(RBF-ARD)`` and ``LinearKernel``.

Combines the local flexibility of an ARD RBF kernel with a global linear
covariance structure. The ``ScaleKernel`` wraps the RBF so the model can
learn the relative amplitude of the local (RBF) vs. global (linear) signal.

The ``LinearKernel`` models linear-in-input correlations (e.g. a linear
scaling-law trend) through the kernel rather than the mean -- the kernel-side
analog of a linear mean.
"""

def __init__(
self,
ard_num_dims: int | None = None,
active_dims: Sequence[int] | None = None,
batch_shape: torch.Size | None = None,
lengthscale_prior: Prior | None = None,
outputscale_prior: Prior | None = None,
variance_prior: Prior | None = None,
lengthscale_constraint: Interval | None = None,
outputscale_constraint: Interval | None = None,
**kwargs: Any,
) -> None:
r"""
Args:
ard_num_dims: The number of lengthscales for the RBF kernel. When
``active_dims`` is provided, this must equal the number of active
dimensions.
active_dims: The active input dimensions. The same subset is applied
to both component kernels. The ``ScaleKernel`` inherits its
``active_dims`` from the wrapped ``RBFKernel``, so the input is
subset exactly once (at the wrapper level) rather than twice.
batch_shape: The batch shape, shared by both component kernels.
lengthscale_prior: The prior over the RBF lengthscale parameter.
outputscale_prior: The prior over the RBF output scale parameter.
variance_prior: The prior over the linear kernel variance parameter.
lengthscale_constraint: Optionally provide a lengthscale constraint
for the RBF kernel.
outputscale_constraint: Optionally provide an output scale constraint
for the ``ScaleKernel`` wrapping the RBF kernel.
kwargs: Additional keyword arguments passed to the ``LinearKernel``.

Returns: None
"""
active_dims = tuple(active_dims) if active_dims is not None else None
scale_rbf_kernel = ScaleKernel(
RBFKernel(
ard_num_dims=ard_num_dims,
active_dims=active_dims,
lengthscale_constraint=lengthscale_constraint,
lengthscale_prior=lengthscale_prior,
batch_shape=batch_shape,
),
outputscale_prior=outputscale_prior,
outputscale_constraint=outputscale_constraint,
batch_shape=batch_shape,
)
linear_kernel = LinearKernel(
ard_num_dims=ard_num_dims,
active_dims=active_dims,
variance_prior=variance_prior,
batch_shape=batch_shape,
**kwargs,
)
super().__init__(scale_rbf_kernel, linear_kernel)


class TemporalKernel(ScaleKernel):
"""A product kernel of a periodic kernel and a Matern kernel.

Expand Down
92 changes: 91 additions & 1 deletion ax/generators/torch/tests/test_covar_modules_argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@
DefaultMaternKernel,
DefaultRBFKernel,
ScaleMaternKernel,
ScaleRBFLinearKernel,
)
from ax.utils.common.testutils import TestCase
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.multitask import MultiTaskGP
from botorch.utils.datasets import SupervisedDataset
from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset
from gpytorch.kernels.kernel import Kernel
from gpytorch.kernels.linear_kernel import LinearKernel
from gpytorch.priors import GammaPrior
Expand Down Expand Up @@ -159,6 +160,95 @@ def test_argparse_scalematern_kernel(self) -> None:

self.assertEqual(covar_module_kwargs["batch_shape"], torch.Size([]))

def test_argparse_scale_rbf_linear_kernel(self) -> None:
# SingleTaskGP infers ard_num_dims from the 10 features; the
# multi-output dataset (Y has 2 columns) yields a batch shape of [2].
# MultiTaskGP drops the task feature, so ard_num_dims is 9 and the
# batch shape is empty.
expected = [
{"ard_num_dims": 10, "batch_shape": torch.Size([2])},
{"ard_num_dims": 9, "batch_shape": torch.Size([])},
]
for i, botorch_model_class in enumerate([SingleTaskGP, MultiTaskGP]):
covar_module_kwargs = covar_module_argparse(
ScaleRBFLinearKernel,
botorch_model_class=botorch_model_class,
dataset=self.dataset,
lengthscale_prior=GammaPrior(6.0, 3.0),
outputscale_prior=GammaPrior(2.0, 0.15),
variance_prior=GammaPrior(1.0, 1.0),
)
self.assertEqual(
covar_module_kwargs["ard_num_dims"], expected[i]["ard_num_dims"]
)
self.assertEqual(
covar_module_kwargs["batch_shape"], expected[i]["batch_shape"]
)
# No active_dims requested, so none is set.
self.assertIsNone(covar_module_kwargs["active_dims"])
# Priors are passed straight through.
self.assertAlmostEqual(
covar_module_kwargs["lengthscale_prior"].concentration.item(),
6.0,
places=4,
)
self.assertAlmostEqual(
covar_module_kwargs["outputscale_prior"].rate.item(), 0.15, places=4
)
self.assertAlmostEqual(
covar_module_kwargs["variance_prior"].concentration.item(),
1.0,
places=4,
)
# The resulting kwargs can construct the kernel.
kernel = ScaleRBFLinearKernel(**covar_module_kwargs)
self.assertIsInstance(kernel, ScaleRBFLinearKernel)

def test_argparse_scale_rbf_linear_kernel_active_dims(self) -> None:
# Explicit active_dims is passed through and normalized; ard_num_dims is
# set to the number of active dims.
covar_module_kwargs = covar_module_argparse(
ScaleRBFLinearKernel,
botorch_model_class=SingleTaskGP,
dataset=self.dataset,
active_dims=[0, 2, 4],
)
self.assertEqual(covar_module_kwargs["active_dims"], [0, 2, 4])
self.assertEqual(covar_module_kwargs["ard_num_dims"], 3)
kernel = ScaleRBFLinearKernel(**covar_module_kwargs)
self.assertIsInstance(kernel, ScaleRBFLinearKernel)

def test_argparse_scale_rbf_linear_kernel_remove_task_features(self) -> None:
# A SingleTaskGP on a MultiTaskDataset with remove_task_features=True
# excludes the task feature from the kernel: active_dims drops the task
# column and ard_num_dims shrinks accordingly.
n, d = 10, 5
task_feature_index = d - 1
X = torch.cat([torch.randn((n, d - 1)), torch.randint(0, 2, (n, 1))], dim=-1)
Y = torch.randn((n, 1))
joint_dataset = SupervisedDataset(
X=X,
Y=Y,
feature_names=[f"x{j}" for j in range(d - 1)] + ["task"],
outcome_names=["y"],
)
dataset = MultiTaskDataset.from_joint_dataset(
dataset=joint_dataset,
task_feature_index=task_feature_index,
target_task_value=1,
)
covar_module_kwargs = covar_module_argparse(
ScaleRBFLinearKernel,
botorch_model_class=SingleTaskGP,
dataset=dataset,
remove_task_features=True,
)
# The task feature (last column, index d - 1) is excluded.
self.assertEqual(covar_module_kwargs["active_dims"], list(range(d - 1)))
self.assertEqual(covar_module_kwargs["ard_num_dims"], d - 1)
kernel = ScaleRBFLinearKernel(**covar_module_kwargs)
self.assertIsInstance(kernel, ScaleRBFLinearKernel)

def test_argparse_default(self) -> None:
for kernel_class in (DefaultRBFKernel, DefaultMaternKernel):
with self.assertRaisesRegex(UserInputError, "Only one of"):
Expand Down
90 changes: 89 additions & 1 deletion ax/generators/torch/tests/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,21 @@
DefaultMaternKernel,
DefaultRBFKernel,
ScaleMaternKernel,
ScaleRBFLinearKernel,
TemporalKernel,
)
from ax.utils.common.testutils import TestCase
from botorch.models.utils.gpytorch_modules import get_covar_module_with_dim_scaled_prior
from gpytorch.constraints import Positive
from gpytorch.kernels import MaternKernel, PeriodicKernel
from gpytorch.kernels import (
LinearKernel,
MaternKernel,
PeriodicKernel,
RBFKernel,
ScaleKernel,
)
from gpytorch.priors import GammaPrior
from pyre_extensions import assert_is_instance


class KernelsTest(TestCase):
Expand All @@ -48,6 +56,86 @@ def test_scalematern_kernel(self) -> None:
# `concentration`.
self.assertEqual(covar.outputscale_prior.concentration, 2.0)
self.assertEqual(covar.base_kernel.batch_shape[0], 2)
self.assertIsNone(covar.active_dims)

def test_scalematern_kernel_active_dims(self) -> None:
active_dims = [0, 2]
covar = ScaleMaternKernel(
ard_num_dims=len(active_dims), active_dims=active_dims
)
base_kernel = assert_is_instance(covar.base_kernel, MaternKernel)
# active_dims lands on the inner MaternKernel, and the ScaleKernel
# inherits it (so subsetting happens exactly once at the wrapper level).
self.assertEqual(covar.active_dims.tolist(), active_dims)
self.assertEqual(base_kernel.active_dims.tolist(), active_dims)
self.assertEqual(base_kernel.ard_num_dims, len(active_dims))
# The kernel only consumes the active columns: perturbing an inactive
# column leaves the covariance unchanged.
X = torch.randn(5, 3)
X_perturbed = X.clone()
X_perturbed[:, 1] = torch.randn(5) # column 1 is inactive
self.assertTrue(
torch.allclose(covar(X).to_dense(), covar(X_perturbed).to_dense())
)

def test_scale_rbf_linear_kernel(self) -> None:
covar = ScaleRBFLinearKernel(
ard_num_dims=10,
lengthscale_prior=GammaPrior(6.0, 3.0),
outputscale_prior=GammaPrior(2.0, 0.15),
variance_prior=GammaPrior(1.0, 1.0),
batch_shape=torch.Size([2]),
)
# The kernel is a sum of a ScaleKernel(RBF) and a LinearKernel.
scale_rbf = assert_is_instance(covar.kernels[0], ScaleKernel)
linear = assert_is_instance(covar.kernels[1], LinearKernel)
rbf = assert_is_instance(scale_rbf.base_kernel, RBFKernel)
# RBF lengthscale prior and ard_num_dims.
self.assertEqual(rbf.ard_num_dims, 10)
lengthscale_prior = assert_is_instance(rbf.lengthscale_prior, GammaPrior)
self.assertEqual(lengthscale_prior.rate, 3.0)
self.assertEqual(lengthscale_prior.concentration, 6.0)
# Outputscale prior on the ScaleKernel.
outputscale_prior = assert_is_instance(scale_rbf.outputscale_prior, GammaPrior)
self.assertEqual(outputscale_prior.rate, 0.15)
self.assertEqual(outputscale_prior.concentration, 2.0)
# Variance prior on the LinearKernel.
variance_prior = assert_is_instance(linear.variance_prior, GammaPrior)
self.assertEqual(variance_prior.rate, 1.0)
self.assertEqual(variance_prior.concentration, 1.0)
# Batch shape is shared by both components.
self.assertEqual(rbf.batch_shape, torch.Size([2]))
self.assertEqual(linear.batch_shape, torch.Size([2]))
# The kernel evaluates and produces a PSD covariance.
X = torch.randn(2, 5, 10)
covar_matrix = covar(X).to_dense()
self.assertEqual(covar_matrix.shape, torch.Size([2, 5, 5]))

def test_scale_rbf_linear_kernel_active_dims(self) -> None:
active_dims = [0, 2]
covar = ScaleRBFLinearKernel(
ard_num_dims=len(active_dims),
active_dims=active_dims,
)
scale_rbf = assert_is_instance(covar.kernels[0], ScaleKernel)
linear = assert_is_instance(covar.kernels[1], LinearKernel)
rbf = assert_is_instance(scale_rbf.base_kernel, RBFKernel)
# active_dims lands on both leaves, and the ScaleKernel inherits it
# from the wrapped RBF kernel (so subsetting happens exactly once).
self.assertEqual(scale_rbf.active_dims.tolist(), active_dims)
self.assertEqual(rbf.active_dims.tolist(), active_dims)
self.assertEqual(linear.active_dims.tolist(), active_dims)
# ard_num_dims matches the number of active dims.
self.assertEqual(rbf.ard_num_dims, len(active_dims))
# The kernel only consumes the active columns: it produces the same
# covariance regardless of the values in the inactive column.
X = torch.randn(5, 3)
X_perturbed = X.clone()
X_perturbed[:, 1] = torch.randn(5) # column 1 is inactive
covar_matrix = covar(X).to_dense()
covar_matrix_perturbed = covar(X_perturbed).to_dense()
self.assertTrue(torch.allclose(covar_matrix, covar_matrix_perturbed))
self.assertEqual(covar_matrix.shape, torch.Size([5, 5]))

def test_temporal_kernel(self) -> None:
ls_prior = GammaPrior(6.0, 3.0)
Expand Down
Loading
Loading