"""Models for storing applied force field parameters."""
import ast
from functools import partial
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union

from openff.toolkit.typing.engines.smirnoff.parameters import ParameterHandler
from openff.utilities.utilities import has_package, requires_package
from pydantic import Field, PrivateAttr, validator

from openff.interchange.exceptions import MissingParametersError
from openff.interchange.models import (
    DefaultModel,
    PotentialKey,
    TopologyKey,
    VirtualSiteKey,
)
from openff.interchange.types import ArrayQuantity, FloatQuantity

if has_package("jax"):
    from jax import numpy
else:
    # Known mypy bug/limitation: https://github.com/python/mypy/issues/1153
    import numpy  # type: ignore[no-redef]

if TYPE_CHECKING:
    from openff.interchange.components.mdtraj import _OFFBioTop
class Potential(DefaultModel):
    """Base class holding the parameters of a single applied potential."""

    parameters: Dict[str, FloatQuantity] = dict()
    map_key: Optional[int] = None

    @validator("parameters")
    def validate_parameters(cls, v):
        """
        Coerce every parameter value into a validated quantity.

        List values go through ``ArrayQuantity.validate_type``; every other
        value goes through ``FloatQuantity.validate_type``. The incoming dict
        is updated in place and then returned.
        """
        for name, raw_value in v.items():
            if not isinstance(raw_value, list):
                v[name] = FloatQuantity.validate_type(raw_value)
            else:
                v[name] = ArrayQuantity.validate_type(raw_value)
        return v

    def __hash__(self):
        # Only the parameter values (in dict order) contribute to the hash;
        # the parameter names and ``map_key`` are ignored.
        parameter_values = tuple(self.parameters.values())
        return hash(parameter_values)
class WrappedPotential(DefaultModel):
    """
    Model storing other Potential model(s) inside inner data.

    Each wrapped potential is paired with a float coefficient. The effective
    parameters (see ``parameters``) are the coefficient-weighted sums of the
    wrapped potentials' parameters.
    """

    class InnerData(DefaultModel):
        """The potentials being wrapped, mapped to their coefficients."""

        data: Dict[Potential, float]

    _inner_data: InnerData = PrivateAttr()

    def __init__(self, data):
        """
        Wrap one or more potentials.

        Parameters
        ----------
        data
            Either a single ``Potential``, which is stored with coefficient
            1.0, or a dict mapping ``Potential`` objects to float coefficients.

        Raises
        ------
        TypeError
            If ``data`` is neither a ``Potential`` nor a ``dict``.
        """
        # Run the pydantic initializer first so the model's per-instance
        # bookkeeping exists before the private attribute is assigned.
        super().__init__()
        if isinstance(data, Potential):
            self._inner_data = self.InnerData(data={data: 1.0})
        elif isinstance(data, dict):
            self._inner_data = self.InnerData(data=data)
        else:
            # Previously this silently produced an object without inner data,
            # failing later with an obscure AttributeError.
            raise TypeError(
                "WrappedPotential expects a Potential or a dict mapping "
                f"Potential objects to coefficients, got {type(data)}"
            )

    @property
    def parameters(self):
        """
        Get the parameters as represented by the stored potentials and coefficients.

        For every parameter name found on any wrapped potential, the value is
        ``sum(coefficient * potential.parameters[name])`` over all wrapped
        potentials. Every wrapped potential must define every such name;
        otherwise a ``KeyError`` is raised.
        """
        wrapped = self._inner_data.data
        names = {name for potential in wrapped for name in potential.parameters}

        params = dict()
        for name in names:
            total = 0.0
            for potential, coefficient in wrapped.items():
                total += coefficient * potential.parameters[name]
            params[name] = total
        return params

    def __repr__(self):
        return str(self._inner_data.data)
class PotentialHandler(DefaultModel):
    """Base class for storing parametrized force field data."""

    type: str = Field(..., description="The type of potentials this handler stores.")
    expression: str = Field(
        ...,
        description="The analytical expression governing the potentials in this handler.",
    )
    slot_map: Dict[Union[TopologyKey, VirtualSiteKey], PotentialKey] = Field(
        dict(),
        description="A mapping between TopologyKey objects and PotentialKey objects.",
    )
    potentials: Dict[PotentialKey, Union[Potential, WrappedPotential]] = Field(
        dict(),
        description="A mapping between PotentialKey objects and Potential objects.",
    )

    @property
    def independent_variables(self) -> Set[str]:
        """
        Return a set of variables found in the expression but not in any potentials.

        Parameter names are read from the first stored potential. If no
        potentials are stored, every name appearing in ``expression`` is
        returned.
        """
        first_potential = next(iter(self.potentials.values()), None)
        vars_in_potentials: Set[str] = (
            set() if first_potential is None else set(first_potential.parameters)
        )
        vars_in_expression = {
            node.id
            for node in ast.walk(ast.parse(self.expression))
            if isinstance(node, ast.Name)
        }
        return vars_in_expression - vars_in_potentials

    def store_matches(
        self,
        parameter_handler: ParameterHandler,
        topology: "_OFFBioTop",
    ) -> None:
        """Populate self.slot_map with key-val pairs of [TopologyKey, PotentialKey]."""
        raise NotImplementedError

    def store_potentials(self, parameter_handler: ParameterHandler) -> None:
        """Populate self.potentials with key-val pairs of [PotentialKey, Potential]."""
        raise NotImplementedError

    def _get_parameters(self, atom_indices: Tuple[int, ...]) -> Dict:
        """
        Return the parameters applied to the given atoms.

        Raises
        ------
        MissingParametersError
            If no key in ``slot_map`` has exactly these ``atom_indices``.
        """
        for topology_key, potential_key in self.slot_map.items():
            if topology_key.atom_indices == atom_indices:
                return self.potentials[potential_key].parameters
        raise MissingParametersError(
            f"Could not find parameter in handler {self.type} "
            f"associated with atoms {atom_indices}"
        )

    def _raise_if_wrapped_potentials(self) -> None:
        """Raise NotImplementedError if any stored potential is a WrappedPotential."""
        # TODO: Handle WrappedPotential in the flattened-parameter methods.
        if any(
            isinstance(potential, WrappedPotential)
            for potential in self.potentials.values()
        ):
            raise NotImplementedError(
                "Flattened parameter access is not implemented for WrappedPotential"
            )

    def get_force_field_parameters(self):
        """
        Return a flattened representation of the force field parameters.

        Row ``i`` holds the parameter magnitudes of the potential that
        ``self.get_mapping()`` assigns to index ``i``. The layout therefore
        matches ``set_force_field_parameters`` and ``get_system_parameters``.
        Potentials not referenced by ``slot_map`` are not included.

        Raises
        ------
        NotImplementedError
            If any stored potential is a ``WrappedPotential``.
        """
        self._raise_if_wrapped_potentials()
        # ``get_mapping`` inserts keys in index order, so iterating it yields
        # the rows in array order.
        return numpy.array(
            [
                [value.m for value in self.potentials[potential_key].parameters.values()]
                for potential_key in self.get_mapping()
            ]
        )

    def set_force_field_parameters(self, new_p):
        """
        Set the force field parameters from a flattened representation.

        ``new_p`` must use the same layout as ``get_force_field_parameters``:
        one row per entry of ``self.get_mapping()``, with one column per
        parameter of that potential. Each value is re-attached to the units
        of the parameter it replaces.

        Raises
        ------
        RuntimeError
            If the number of rows or columns does not match.
        NotImplementedError
            If any stored potential is a ``WrappedPotential``, whose
            parameters are computed and cannot be assigned.
        """
        self._raise_if_wrapped_potentials()
        mapping = self.get_mapping()
        if new_p.shape[0] != len(mapping):
            raise RuntimeError(
                f"Expected {len(mapping)} rows of parameters, got {new_p.shape[0]}"
            )

        for potential_key, potential_index in mapping.items():
            potential = self.potentials[potential_key]
            row = new_p[potential_index, :]
            if len(row) != len(potential.parameters):
                raise RuntimeError(
                    f"Expected {len(potential.parameters)} parameters for "
                    f"{potential_key}, got {len(row)}"
                )
            for parameter_index, parameter_key in enumerate(potential.parameters):
                parameter_units = potential.parameters[parameter_key].units
                potential.parameters[parameter_key] = (
                    row[parameter_index] * parameter_units
                )

    def get_system_parameters(self, p=None, mapping: Optional[Dict] = None):
        """
        Return a flattened representation of system parameters.

        These values are effectively force field parameters as applied to a chemical topology.
        There is one row per entry in ``slot_map``: row ``p[mapping[key]]``
        for each potential key in ``slot_map`` order.

        Parameters
        ----------
        p
            Flattened force field parameters; defaults to
            ``self.get_force_field_parameters()``.
        mapping
            Potential-key to row-index mapping; defaults to
            ``self.get_mapping()``.
        """
        self._raise_if_wrapped_potentials()
        if p is None:
            p = self.get_force_field_parameters()
        if mapping is None:
            mapping = self.get_mapping()

        q: List = [p[mapping[potential_key]] for potential_key in self.slot_map.values()]
        return numpy.array(q)

    def get_mapping(self) -> Dict[PotentialKey, int]:
        """
        Get a mapping between potentials and array indices.

        Indices are assigned in order of each potential key's first
        appearance among the values of ``slot_map``.
        """
        mapping: Dict[PotentialKey, int] = dict()
        for potential_key in self.slot_map.values():
            # len(mapping) is evaluated before insertion, i.e. the next index.
            mapping.setdefault(potential_key, len(mapping))
        return mapping

    def parametrize(self, p=None, mapping: Optional[Dict] = None):
        """
        Return an array of system parameters, given an array of force field parameters.

        ``mapping`` optionally supplies a precomputed ``get_mapping()`` result.
        """
        if p is None:
            p = self.get_force_field_parameters()
        return self.get_system_parameters(p=p, mapping=mapping)

    def parametrize_partial(self):
        """Return a function that will call `self.parametrize()` with arguments specified by `self.mapping`."""
        return partial(
            self.parametrize,
            mapping=self.get_mapping(),
        )

    @requires_package("jax")
    def get_param_matrix(self):
        """
        Get a matrix representing the mapping between force field and system parameters.

        This is the forward-mode Jacobian of ``parametrize``, reshaped to
        ``(-1, number of force field parameter values)``.
        """
        import jax

        p = self.get_force_field_parameters()
        jac_parametrize = jax.jacfwd(self.parametrize_partial())
        jac_res = jac_parametrize(p)
        return jac_res.reshape(-1, p.flatten().shape[0])