From 13366acdae43e5cf83570eef4b2855baa0c5303b Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Tue, 13 May 2025 11:22:06 +0100 Subject: [PATCH] Fix handling of REST2 scale factors when subsampling. [closes #69] --- src/somd2/config/_config.py | 4 +++- src/somd2/runner/_base.py | 22 ++++++++++++++++++++++ src/somd2/runner/_runner.py | 9 ++++----- tests/runner/test_lambda_values.py | 12 ++++++------ 4 files changed, 35 insertions(+), 12 deletions(-) diff --git a/src/somd2/config/_config.py b/src/somd2/config/_config.py index 5d039404..0978ee82 100644 --- a/src/somd2/config/_config.py +++ b/src/somd2/config/_config.py @@ -305,7 +305,9 @@ def __init__( the rest of the system. This can either be a single scaling factor, or a list of scale factors for each lambda window. When a single scaling factor is used, then the scale factor will be interpolated between a value of 1.0 in the end states, - and the value of 'rest2_scale' in intermediate lambda = 0.5 state. + and the value of 'rest2_scale' in intermediate lambda = 0.5 state. When multiple + values are used, then the number should match the number of lambda windows at which + energies are sampled. rest2_selection: str A sire selection string for atoms to include in the REST2 region in diff --git a/src/somd2/runner/_base.py b/src/somd2/runner/_base.py index 5f4d724a..911587a7 100644 --- a/src/somd2/runner/_base.py +++ b/src/somd2/runner/_base.py @@ -240,6 +240,25 @@ def __init__(self, system, config): else: self._lambda_energy = self._lambda_values + # Make sure the lambda values are in the lambda energy list. + is_missing = False + for lambda_value in self._lambda_values: + if lambda_value not in self._lambda_energy: + self._lambda_energy.append(lambda_value) + is_missing = True + + # Make sure the lambda_values entries are unique. + if not len(self._lambda_values) == len(set(self._lambda_values)): + msg = "Duplicate entries in 'lambda_values' list." + _logger.error(msg) + raise ValueError(msg) + + # Make sure the lambda_energy entries are unique. + if not len(self._lambda_energy) == len(set(self._lambda_energy)): + msg = "Duplicate entries in 'lambda_energy' list." + _logger.error(msg) + raise ValueError(msg) + from math import isclose # Set the REST2 scale factors. @@ -258,6 +277,9 @@ def __init__(self, system, config): else: if len(self._config.rest2_scale) != len(self._lambda_energy): msg = f"Length of 'rest2_scale' must match the number of {_lam_sym} values." + if is_missing: + msg += f"If you have omitted some 'lambda_values` from `lambda_energy`, please " + f"add them to `lambda_energy`, along with the corresponding `rest2_scale` values." _logger.error(msg) raise ValueError(msg) # Make sure the end states are close to 1.0. diff --git a/src/somd2/runner/_runner.py b/src/somd2/runner/_runner.py index 13e59791..902e5a8d 100644 --- a/src/somd2/runner/_runner.py +++ b/src/somd2/runner/_runner.py @@ -317,8 +317,11 @@ def _run( # Get the lambda value. lambda_value = self._lambda_values[index] + # Get the index in the lambda_energy array. + nrg_index = self._lambda_energy.index(lambda_value) + # Get the REST2 scaling factor. - rest2_scale = self._rest2_scale_factors[index] + rest2_scale = self._rest2_scale_factors[nrg_index] # Check for completion if this is a restart. if is_restart: @@ -445,10 +448,6 @@ def generate_lam_vals(lambda_base, increment=0.001): # Create the array of lambda values for energy sampling. lambda_energy = self._lambda_energy.copy() - # If missing, add the lambda value. - if lambda_value not in self._lambda_energy: - lambda_energy.append(lambda_value) - # Sort the lambda values. lambda_energy = sorted(lambda_energy) diff --git a/tests/runner/test_lambda_values.py b/tests/runner/test_lambda_values.py index daadb78c..c19f5b28 100644 --- a/tests/runner/test_lambda_values.py +++ b/tests/runner/test_lambda_values.py @@ -85,10 +85,10 @@ def test_lambda_energy(ethane_methanol): ) # Make sure the lambda_array in the metadata is correct. This is the - # sampled lambda plus the lambda_energy values in the config. - assert meta["lambda_array"] == [0.0, 0.5] + # sampled lambda_values plus the lambda_energy values in the config. + assert meta["lambda_array"] == [0.0, 0.5, 1.0] - # Make sure the second dimension of the energy trajectory is the correct - # size. This is one for the current lambda value, one for its gradient, - # and one for the length of lambda_energy. - assert energy_traj.shape[1] == 3 + # Make sure the second dimension of the energy trajectory is the correct. + # This is the sampled lambda values, i.e. unique entries from lambda_values + # and lambda_energy, plus the gradient for TI. + assert energy_traj.shape[1] == 4