Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/somd2/config/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,9 @@ def __init__(
the rest of the system. This can either be a single scaling factor, or a list of
scale factors for each lambda window. When a single scaling factor is used, then
the scale factor will be interpolated between a value of 1.0 in the end states,
and the value of 'rest2_scale' in intermediate lambda = 0.5 state.
and the value of 'rest2_scale' in intermediate lambda = 0.5 state. When multiple
values are used, then the number should match the number of lambda windows at which
energies are sampled.

rest2_selection: str
A sire selection string for atoms to include in the REST2 region in
Expand Down
22 changes: 22 additions & 0 deletions src/somd2/runner/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,25 @@ def __init__(self, system, config):
else:
self._lambda_energy = self._lambda_values

# Make sure the lambda values are in the lambda energy list.
is_missing = False
for lambda_value in self._lambda_values:
if lambda_value not in self._lambda_energy:
self._lambda_energy.append(lambda_value)
is_missing = True

# Make sure the lambda_values entries are unique.
if not len(self._lambda_values) == len(set(self._lambda_values)):
msg = "Duplicate entries in 'lambda_values' list."
_logger.error(msg)
raise ValueError(msg)

# Make sure the lambda_energy entries are unique.
if not len(self._lambda_energy) == len(set(self._lambda_energy)):
msg = "Duplicate entries in 'lambda_energy' list."
_logger.error(msg)
raise ValueError(msg)

from math import isclose

# Set the REST2 scale factors.
Expand All @@ -258,6 +277,9 @@ def __init__(self, system, config):
else:
if len(self._config.rest2_scale) != len(self._lambda_energy):
msg = f"Length of 'rest2_scale' must match the number of {_lam_sym} values."
if is_missing:
msg += f"If you have omitted some 'lambda_values` from `lambda_energy`, please "
f"add them to `lambda_energy`, along with the corresponding `rest2_scale` values."
_logger.error(msg)
raise ValueError(msg)
# Make sure the end states are close to 1.0.
Expand Down
9 changes: 4 additions & 5 deletions src/somd2/runner/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,11 @@ def _run(
# Get the lambda value.
lambda_value = self._lambda_values[index]

# Get the index in the lambda_energy array.
nrg_index = self._lambda_energy.index(lambda_value)

# Get the REST2 scaling factor.
rest2_scale = self._rest2_scale_factors[index]
rest2_scale = self._rest2_scale_factors[nrg_index]

# Check for completion if this is a restart.
if is_restart:
Expand Down Expand Up @@ -445,10 +448,6 @@ def generate_lam_vals(lambda_base, increment=0.001):
# Create the array of lambda values for energy sampling.
lambda_energy = self._lambda_energy.copy()

# If missing, add the lambda value.
if lambda_value not in self._lambda_energy:
lambda_energy.append(lambda_value)

# Sort the lambda values.
lambda_energy = sorted(lambda_energy)

Expand Down
12 changes: 6 additions & 6 deletions tests/runner/test_lambda_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ def test_lambda_energy(ethane_methanol):
)

# Make sure the lambda_array in the metadata is correct. This is the
# sampled lambda plus the lambda_energy values in the config.
assert meta["lambda_array"] == [0.0, 0.5]
# sampled lambda_values plus the lambda_energy values in the config.
assert meta["lambda_array"] == [0.0, 0.5, 1.0]

# Make sure the second dimension of the energy trajectory is the correct
# size. This is one for the current lambda value, one for its gradient,
# and one for the length of lambda_energy.
assert energy_traj.shape[1] == 3
# Make sure the second dimension of the energy trajectory is the correct.
# This is the sampled lambda values, i.e. unique entries from lambda_values
# and lambda_energy, plus the gradient for TI.
assert energy_traj.shape[1] == 4
Loading