Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion descent/targets/energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@

DATA_SCHEMA = pyarrow.schema(
[
("id", pyarrow.string()),
("smiles", pyarrow.string()),
("coords", pyarrow.list_(pyarrow.float64())),
("box_vectors", pyarrow.list_(pyarrow.float64())),
("energy", pyarrow.list_(pyarrow.float64())),
("forces", pyarrow.list_(pyarrow.float64())),
]
Expand All @@ -22,6 +24,9 @@
class Entry(typing.TypedDict):
"""Represents a set of reference energies and forces."""

id: str | None
"""An optional identifier for the entry (e.g. a run name). Defaults to ``None``."""

smiles: str
"""The indexed SMILES description of the molecule the energies and forces were
computed for."""
Expand All @@ -34,6 +39,10 @@ class Entry(typing.TypedDict):
forces: torch.Tensor
"""The reference forces [kcal/mol/Å] with ``shape=(n_confs, n_particles, 3)``."""

box_vectors: torch.Tensor | None
"""The box vectors [Å] for periodic systems with ``shape=(n_confs, 3, 3)``, or
``None`` for non-periodic systems."""


def create_dataset(entries: list[Entry]) -> datasets.Dataset:
"""Create a dataset from a list of existing entries.
Expand All @@ -48,8 +57,12 @@ def create_dataset(entries: list[Entry]) -> datasets.Dataset:
table = pyarrow.Table.from_pylist(
[
{
"id": entry.get("id"),
"smiles": entry["smiles"],
"coords": torch.tensor(entry["coords"]).flatten().tolist(),
"box_vectors": None
if entry.get("box_vectors") is None
else torch.tensor(entry["box_vectors"]).flatten().tolist(),
"energy": torch.tensor(entry["energy"]).flatten().tolist(),
"forces": torch.tensor(entry["forces"]).flatten().tolist(),
}
Expand Down Expand Up @@ -82,8 +95,12 @@ def create_dataset_from_generator(
def _gen():
    """Yield one serialisable row per entry produced by ``gen_fn``.

    Tensor-valued fields are flattened to plain lists of floats; the optional
    ``id`` and ``box_vectors`` fields are emitted as ``None`` when absent.
    """

    def _flat_list(values):
        # Store every array-like field as a single flat list of numbers.
        return torch.tensor(values).flatten().tolist()

    for item in gen_fn():
        cell = item.get("box_vectors")

        row = {
            "id": item.get("id"),
            "smiles": item["smiles"],
            "coords": _flat_list(item["coords"]),
            # A missing or ``None`` box is kept as ``None`` rather than
            # being converted to a tensor.
            "box_vectors": _flat_list(cell) if cell is not None else None,
            "energy": _flat_list(item["energy"]),
            "forces": _flat_list(item["forces"]),
        }
        yield row
Expand Down Expand Up @@ -149,9 +166,27 @@ def predict(
coords = (
(coords_flat.reshape(len(energy_ref), -1, 3)).detach().requires_grad_(True)
)
box_vectors = entry.get("box_vectors", None)

topology = topologies[smiles]

energy_pred = smee.compute_energy(topology, force_field, coords)
if box_vectors is not None:
# smee does not support batched periodic evaluations,
# so we loop over conformers.
box_vectors = smee.utils.tensor_like(box_vectors, coords_flat).reshape(
len(energy_ref), 3, 3
)
energy_pred = torch.cat(
[
smee.compute_energy(
topology, force_field, coords[i], box_vectors[i]
)
for i in range(len(energy_ref))
]
)
else:
energy_pred = smee.compute_energy(topology, force_field, coords, None)

forces_pred = -torch.autograd.grad(
energy_pred.sum(),
coords,
Expand Down
123 changes: 107 additions & 16 deletions descent/tests/targets/test_energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,34 +41,56 @@ def mock_hoh_entry() -> Entry:
}


def test_create_dataset(mock_meoh_entry):
@pytest.mark.parametrize(
"box_vectors",
[None, torch.eye(3).repeat(2, 1, 1) * 20.0],
ids=["non-periodic", "periodic"],
)
def test_create_dataset(mock_meoh_entry, box_vectors):
entry = {**mock_meoh_entry, "box_vectors": box_vectors}

expected_entries = [
{
"smiles": mock_meoh_entry["smiles"],
"coords": pytest.approx(mock_meoh_entry["coords"].flatten()),
"energy": pytest.approx(mock_meoh_entry["energy"]),
"forces": pytest.approx(mock_meoh_entry["forces"].flatten()),
"id": None,
"smiles": entry["smiles"],
"coords": pytest.approx(entry["coords"].flatten()),
"energy": pytest.approx(entry["energy"]),
"forces": pytest.approx(entry["forces"].flatten()),
"box_vectors": None
if box_vectors is None
else pytest.approx(box_vectors.flatten()),
},
]

dataset = create_dataset([mock_meoh_entry])
dataset = create_dataset([entry])
assert len(dataset) == 1

entries = list(descent.utils.dataset.iter_dataset(dataset))
assert entries == expected_entries


def test_create_dataset_from_generator(mock_meoh_entry):
@pytest.mark.parametrize(
"box_vectors",
[None, torch.eye(3).repeat(2, 1, 1) * 20.0],
ids=["non-periodic", "periodic"],
)
def test_create_dataset_from_generator(mock_meoh_entry, box_vectors):
entry = {**mock_meoh_entry, "box_vectors": box_vectors}

expected_entries = [
{
"smiles": mock_meoh_entry["smiles"],
"coords": pytest.approx(mock_meoh_entry["coords"].flatten()),
"energy": pytest.approx(mock_meoh_entry["energy"]),
"forces": pytest.approx(mock_meoh_entry["forces"].flatten()),
"id": None,
"smiles": entry["smiles"],
"coords": pytest.approx(entry["coords"].flatten()),
"energy": pytest.approx(entry["energy"]),
"forces": pytest.approx(entry["forces"].flatten()),
"box_vectors": None
if box_vectors is None
else pytest.approx(box_vectors.flatten()),
},
]

dataset = create_dataset_from_generator(lambda: iter([mock_meoh_entry]))
dataset = create_dataset_from_generator(lambda: iter([entry]))
assert len(dataset) == 1

entries = list(descent.utils.dataset.iter_dataset(dataset))
Expand All @@ -85,11 +107,12 @@ def test_extract_smiles(mock_meoh_entry, mock_hoh_entry):


@pytest.mark.parametrize(
"reference, normalize,"
"box_vectors, reference, normalize,"
"expected_energy_ref, expected_forces_ref, "
"expected_energy_pred, expected_forces_pred",
[
(
pytest.param(
None,
"mean",
True,
torch.tensor([-0.5, 0.5]) / math.sqrt(2.0),
Expand Down Expand Up @@ -118,8 +141,10 @@ def test_extract_smiles(mock_meoh_entry, mock_hoh_entry):
dtype=torch.float64,
)
/ math.sqrt(6.0 * 3.0),
id="non-periodic-mean-normalized",
),
(
pytest.param(
None,
"min",
False,
torch.tensor([0.0, 1.0]),
Expand All @@ -146,10 +171,73 @@ def test_extract_smiles(mock_meoh_entry, mock_hoh_entry):
],
dtype=torch.float64,
),
id="non-periodic-min",
),
pytest.param(
torch.eye(3).repeat(2, 1, 1) * 30.0,
"mean",
True,
torch.tensor([-0.5, 0.5]) / math.sqrt(2.0),
torch.tensor(
[
[0.0, 1.0, 2.0],
[3.0, 4.0, 5.0],
[6.0, 7.0, 8.0],
[9.0, 10.0, 11.0],
[12.0, 13.0, 14.0],
[15.0, 16.0, 17.0],
],
dtype=torch.float64,
)
/ math.sqrt(6.0 * 3.0),
torch.tensor([5.585737228393555, -5.585737705230713]),
torch.tensor(
[
[0.0, -19.695229476617897, 0.0],
[38.04311560258793, 9.847614738308948, 0.0],
[-38.04311560258793, 9.847614738308948, 0.0],
[0.0, 32.3990898002703, 0.0],
[-24.190123962094730, -16.19954490013515, 0.0],
[24.190123962094730, -16.19954490013515, 0.0],
],
dtype=torch.float64,
),
id="periodic-mean-normalized",
),
pytest.param(
torch.eye(3).repeat(2, 1, 1) * 30.0,
"min",
False,
torch.tensor([0.0, 1.0]),
torch.tensor(
[
[0.0, 1.0, 2.0],
[3.0, 4.0, 5.0],
[6.0, 7.0, 8.0],
[9.0, 10.0, 11.0],
[12.0, 13.0, 14.0],
[15.0, 16.0, 17.0],
],
dtype=torch.float64,
),
torch.tensor([0.0, -15.79885196685791]),
torch.tensor(
[
[0.0, -83.55977630615234, 0.0],
[161.40325927734375, 41.77988815307617, 0.0],
[-161.40325927734375, 41.77988815307617, 0.0],
[0.0, 137.4576873779297, 0.0],
[-102.62999725341797, -68.72884368896484, 0.0],
[102.62999725341797, -68.72884368896484, 0.0],
],
dtype=torch.float64,
),
id="periodic-min",
),
],
)
def test_predict(
box_vectors,
reference,
normalize,
expected_energy_ref,
Expand All @@ -158,7 +246,10 @@ def test_predict(
expected_forces_pred,
mock_hoh_entry,
):
dataset = create_dataset([mock_hoh_entry])
entry = {**mock_hoh_entry}
if box_vectors is not None:
entry["box_vectors"] = box_vectors
dataset = create_dataset([entry])

force_field, [topology] = smee.converters.convert_interchange(
openff.interchange.Interchange.from_smirnoff(
Expand Down
Loading