Skip to content

Add ScaleRBFLinearKernel covar module#5233

Open
saitcakmak wants to merge 1 commit into
facebook:mainfrom
saitcakmak:export-D109044330
Open

Add ScaleRBFLinearKernel covar module#5233
saitcakmak wants to merge 1 commit into
facebook:mainfrom
saitcakmak:export-D109044330

Conversation

@saitcakmak

Copy link
Copy Markdown
Contributor

Summary:
Following the ScaleMaternKernel pattern in Ax botorch_modular, this adds ScaleRBFLinearKernel — a sum of ScaleKernel(RBF-ARD) and LinearKernel. It is the kernel-side analog of a linear mean: the LinearKernel models linear-in-input correlations (e.g. a scaling-law trend) through the covariance rather than the mean, while the ScaleKernel-wrapped ARD RBF captures local flexibility and lets the model learn the relative amplitude of the local vs. global signal. This is the kernel-only counterpart of the RBFLinearGP model studied in D108386392, designed to plug into a plain SingleTaskGP via ModelConfig(covar_module_class=...).

ScaleRBFLinearKernel subclasses GPyTorch's AdditiveKernel so that isinstance checks and storage serialization work cleanly. Its __init__ mirrors ScaleMaternKernel, adding a variance_prior for the linear component and an active_dims argument.

active_dims is applied to both component kernels: it is set on the inner RBFKernel and on the LinearKernel, and the ScaleKernel inherits active_dims from the kernel it wraps. Because GPyTorch subsets inputs once at the outermost kernel in a call chain (and ScaleKernel.forward calls base_kernel.forward directly), the input is subset exactly once rather than twice. Threading active_dims through the input constructor also enables the remove_task_features workflow (excluding a task feature from the kernel for a SingleTaskGP on a MultiTaskDataset).

Changes:

  • kernels.py: new ScaleRBFLinearKernel(AdditiveKernel) with active_dims support.
  • input_constructors/covar_modules.py: register _covar_module_argparse_scale_rbf_linear, reusing _get_default_ard_num_dims_and_batch_shape to infer ard_num_dims/batch_shape/active_dims from the dataset and model class.
  • botorch_modular_registry.py: add ScaleRBFLinearKernel to KERNEL_REGISTRY (this automatically wires up JSON/SQA encode+decode).
  • Tests in test_kernels.py (incl. an active-dims single-subsetting check) and test_covar_modules_argparse.py (incl. active-dims and remove_task_features cases).

Differential Revision: D109044330

Summary:
Following the `ScaleMaternKernel` pattern in Ax botorch_modular, this adds `ScaleRBFLinearKernel` — a sum of `ScaleKernel(RBF-ARD)` and `LinearKernel`. It is the kernel-side analog of a linear mean: the `LinearKernel` models linear-in-input correlations (e.g. a scaling-law trend) through the covariance rather than the mean, while the `ScaleKernel`-wrapped ARD RBF captures local flexibility and lets the model learn the relative amplitude of the local vs. global signal. This is the kernel-only counterpart of the `RBFLinearGP` model studied in D108386392, designed to plug into a plain `SingleTaskGP` via `ModelConfig(covar_module_class=...)`.

`ScaleRBFLinearKernel` subclasses GPyTorch's `AdditiveKernel` so that `isinstance` checks and storage serialization work cleanly. Its `__init__` mirrors `ScaleMaternKernel`, adding a `variance_prior` for the linear component and an `active_dims` argument.

`active_dims` is applied to both component kernels: it is set on the inner `RBFKernel` and on the `LinearKernel`, and the `ScaleKernel` inherits `active_dims` from the kernel it wraps. Because GPyTorch subsets inputs once at the outermost kernel in a call chain (and `ScaleKernel.forward` calls `base_kernel.forward` directly), the input is subset exactly once rather than twice. Threading `active_dims` through the input constructor also enables the `remove_task_features` workflow (excluding a task feature from the kernel for a `SingleTaskGP` on a `MultiTaskDataset`).

Changes:
- `kernels.py`: new `ScaleRBFLinearKernel(AdditiveKernel)` with `active_dims` support.
- `input_constructors/covar_modules.py`: register `_covar_module_argparse_scale_rbf_linear`, reusing `_get_default_ard_num_dims_and_batch_shape` to infer `ard_num_dims`/`batch_shape`/`active_dims` from the dataset and model class.
- `botorch_modular_registry.py`: add `ScaleRBFLinearKernel` to `KERNEL_REGISTRY` (this automatically wires up JSON/SQA encode+decode).
- Tests in `test_kernels.py` (incl. an active-dims single-subsetting check) and `test_covar_modules_argparse.py` (incl. active-dims and remove_task_features cases).

Differential Revision: D109044330
@meta-cla meta-cla Bot added the CLA Signed Do not delete this pull request or issue due to inactivity. label Jun 18, 2026
@meta-codesync

meta-codesync Bot commented Jun 18, 2026

Copy link
Copy Markdown

@saitcakmak has exported this pull request. If you are a Meta employee, you can view the originating Diff in D109044330.

@codecov-commenter

Copy link
Copy Markdown

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 96.56%. Comparing base (9bd6ec8) to head (36a5255).

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #5233   +/-   ##
=======================================
  Coverage   96.56%   96.56%           
=======================================
  Files         619      619           
  Lines       70168    70255   +87     
=======================================
+ Hits        67756    67843   +87     
  Misses       2412     2412           

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed Do not delete this pull request or issue due to inactivity. meta-exported

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants