Source code for mtenn.conversion_utils.schnet

"""
``Representation`` and ``Strategy`` implementations for the SchNet model architecture.
The underlying model that we use is the implementation in
`PyTorch Geometric <https://pytorch-geometric.readthedocs.io/en/latest/generated/
torch_geometric.nn.models.SchNet.html#torch_geometric.nn.models.SchNet>`_.
"""
from copy import deepcopy
import torch
from torch_geometric.nn.models import SchNet as PygSchNet
from torch_geometric.nn.models.schnet import RadiusInteractionGraph
from typing import Callable, Optional

from mtenn.model import GroupedModel, Model
from mtenn.strategy import ComplexOnlyStrategy, ConcatStrategy, DeltaStrategy


class SchNet(PygSchNet):
    """
    ``mtenn`` wrapper around the PyTorch Geometric SchNet model. This class handles
    construction of the model and the formatting into ``Representation`` and
    ``Strategy`` blocks.
    """

    def __init__(self, *args, model=None, **kwargs):
        """
        Initialize the underlying ``torch_geometric.nn.models.SchNet`` model. If a
        value is passed for ``model``, builds a new
        ``torch_geometric.nn.models.SchNet`` model based on those hyperparameters,
        and copies over the weights. Otherwise, all ``*args`` and ``**kwargs`` are
        passed directly to the ``torch_geometric.nn.models.SchNet`` constructor.

        Note that when ``model`` is given, ``*args`` and ``**kwargs`` are not used.
        The interaction graph and readout modules of ``model`` are passed to the
        new model as-is (they are not copied).

        Parameters
        ----------
        model : ``torch_geometric.nn.models.SchNet``, optional
            PyTorch Geometric SchNet model to use to construct the underlying model
        """
        # If no model is passed, construct default SchNet model, otherwise copy
        # all parameters and weights over
        if model is None:
            super().__init__(*args, **kwargs)
        else:
            # ``model.atomref`` has no ``weight`` to read when the source model
            # was built without atomic reference values
            try:
                atomref = model.atomref.weight.detach().clone()
            except AttributeError:
                atomref = None

            interaction_graph = model.interaction_graph
            # A custom interaction graph is not required to expose
            # ``max_num_neighbors``. The PyG constructor is expected to consult
            # this value only when it has to build its own radius graph, so the
            # fallback (the PyG default of 32) is just a placeholder here.
            max_num_neighbors = getattr(interaction_graph, "max_num_neighbors", 32)

            # Positional order follows the PyG ``SchNet`` constructor signature
            model_params = (
                model.hidden_channels,
                model.num_filters,
                model.num_interactions,
                model.num_gaussians,
                model.cutoff,
                interaction_graph,
                max_num_neighbors,
                model.readout,
                model.dipole,
                model.mean,
                model.std,
                atomref,
            )
            super().__init__(*model_params)
            self.load_state_dict(model.state_dict())

    def forward(self, data):
        """
        Make a prediction of the target property based on an input structure.
        Only the ``"z"`` and ``"pos"`` entries of ``data`` are passed to the
        underlying model.

        Parameters
        ----------
        data : dict[str, torch.Tensor]
            This dictionary should at minimum contain entries for:

            * ``"pos"``: Atom coordinates

            * ``"z"``: Atomic numbers

        Returns
        -------
        torch.Tensor
            Model prediction
        """
        return super().forward(data["z"], data["pos"])

    @property
    def output_dim(self):
        """
        Number of output features of the ``lin1`` layer, which is the size used
        for the output of the ``Representation`` block.
        """
        return self.lin1.out_features

    @property
    def extract_key(self):
        """
        Always ``None`` for this model.
        """
        return None

    def _get_representation(self):
        """
        Copy the calling model and set the last layer of the copy as the
        ``Identity``.

        Returns
        -------
        mtenn.conversion_utils.schnet.SchNet
            Copied ``SchNet`` model with the last layer replaced by the ``Identity``
        """
        # Copy model so initial model isn't affected
        model_copy = deepcopy(self)
        # Replace final linear layer with an identity module
        model_copy.lin2 = torch.nn.Identity()

        return model_copy

    def _get_energy_func(self, layer_norm=False):
        """
        Return copy of last layer of the calling model.

        Parameters
        ----------
        layer_norm: bool, default=False
            Apply a ``LayerNorm`` normalization before passing through the linear
            layer

        Returns
        -------
        torch.nn.Module
            Copy of last layer, wrapped in a ``torch.nn.Sequential`` behind a
            ``LayerNorm`` if ``layer_norm`` is ``True``
        """
        lin = deepcopy(self.lin2)
        if layer_norm:
            return torch.nn.Sequential(torch.nn.LayerNorm(lin.in_features), lin)

        return lin

    def _get_delta_strategy(self, layer_norm=False):
        """
        Build a :py:class:`DeltaStrategy <mtenn.strategy.DeltaStrategy>` object
        based on the calling model.

        Parameters
        ----------
        layer_norm: bool, default=False
            Apply a ``LayerNorm`` normalization before passing through the linear
            layer

        Returns
        -------
        mtenn.strategy.DeltaStrategy
            ``DeltaStrategy`` built from the model
        """
        return DeltaStrategy(self._get_energy_func(layer_norm))

    def _get_complex_only_strategy(self, layer_norm=False):
        """
        Build a :py:class:`ComplexOnlyStrategy <mtenn.strategy.ComplexOnlyStrategy>`
        object based on the calling model.

        Parameters
        ----------
        layer_norm: bool, default=False
            Apply a ``LayerNorm`` normalization before passing through the linear
            layer

        Returns
        -------
        mtenn.strategy.ComplexOnlyStrategy
            ``ComplexOnlyStrategy`` built from the model
        """
        return ComplexOnlyStrategy(self._get_energy_func(layer_norm))

    def _get_concat_strategy(self, layer_norm=False):
        """
        Build a :py:class:`ConcatStrategy <mtenn.strategy.ConcatStrategy>` object
        with the appropriate ``input_size``.

        Parameters
        ----------
        layer_norm: bool, default=False
            Apply a ``LayerNorm`` normalization before passing through the linear
            layer

        Returns
        -------
        ConcatStrategy
            ``ConcatStrategy`` for the model
        """
        # Calculate input size as 3 * dimensionality of output of Representation
        # (ie lin1 layer)
        input_size = 3 * self.output_dim
        return ConcatStrategy(input_size=input_size, layer_norm=layer_norm)

    @staticmethod
    def get_model(
        model=None,
        grouped=False,
        fix_device=False,
        strategy: str = "delta",
        layer_norm: bool = False,
        combination=None,
        pred_readout=None,
        comb_readout=None,
    ):
        """
        Exposed function to build a :py:class:`Model <mtenn.model.Model>` or
        :py:class:`GroupedModel <mtenn.model.GroupedModel>` from a :py:class:`SchNet
        <mtenn.conversion_utils.schnet.SchNet>` (or args/kwargs). If no ``model`` is
        given, build a default ``SchNet`` model.

        Parameters
        ----------
        model: mtenn.conversion_utils.schnet.SchNet, optional
            ``SchNet`` model to use to build the ``Model`` object. If not given,
            build a default model
        grouped: bool, default=False
            Build a ``GroupedModel``
        fix_device: bool, default=False
            If True, make sure the input is on the same device as the model,
            copying over as necessary
        strategy: str, default='delta'
            ``Strategy`` to use to combine representations of the different parts.
            Options are [``delta``, ``concat``, ``complex``] (case-insensitive)
        layer_norm: bool, default=False
            Apply a ``LayerNorm`` normalization before passing through the linear
            layer
        combination: mtenn.combination.Combination, optional
            ``Combination`` object to use to combine multiple predictions. A value
            must be passed if ``grouped`` is ``True``
        pred_readout : mtenn.readout.Readout, optional
            ``Readout`` object for the individual energy predictions. If a
            ``GroupedModel`` is being built, this ``Readout`` will be applied to each
            individual prediction before the values are passed to the
            ``Combination``. If a ``Model`` is being built, this will be applied to
            the single prediction before it is returned
        comb_readout : mtenn.readout.Readout, optional
            Readout object for the combined multi-pose prediction, in the case that a
            ``GroupedModel`` is being built. Otherwise, this is ignored

        Returns
        -------
        mtenn.model.Model
            ``Model`` or ``GroupedModel`` containing the desired ``Representation``,
            ``Strategy``, and ``Combination`` and ``Readout`` s as desired

        Raises
        ------
        ValueError
            If ``grouped`` is ``True`` but no ``combination`` is given, or if
            ``strategy`` is not one of the known options
        """
        # Validate the arguments up front, before any model is built or copied
        if grouped and (combination is None):
            raise ValueError(
                "Must pass a value for `combination` if `grouped` is `True`."
            )

        if model is None:
            model = SchNet()

        strategy_name = strategy.lower()
        strategy_builders = {
            "delta": model._get_delta_strategy,
            "concat": model._get_concat_strategy,
            "complex": model._get_complex_only_strategy,
        }
        try:
            build_strategy = strategy_builders[strategy_name]
        except KeyError:
            raise ValueError(f"Unknown strategy: {strategy_name}") from None

        # First get representation module
        representation = model._get_representation()
        # Construct strategy module based on the model
        strategy_module = build_strategy(layer_norm)

        if grouped:
            return GroupedModel(
                representation,
                strategy_module,
                combination,
                pred_readout,
                comb_readout,
                fix_device,
            )

        return Model(representation, strategy_module, pred_readout, fix_device)
class MemoizedRadiusInteractionGraph(RadiusInteractionGraph):
    """
    Memoized version of the PyG RadiusInteractionGraph. The lookup function used to
    map the position and batch tensors to the edge index/weight tensors can be any
    function with the appropriate signature, but defaults to a string of the number
    of atoms and the x, y, and z coords of the first and last two atoms to 7 decimal
    points. Note that this of course relies on the atoms for a given sample being in
    the same order each time.

    Computed results are kept in ``lookup_table`` for the lifetime of the object;
    entries are never evicted.
    """

    def __init__(
        self,
        lookup_function: Optional[Callable] = None,
        cutoff: float = 10.0,
        max_num_neighbors: int = 32,
    ):
        """
        Initialize underlying RadiusInteractionGraph and define lookup function for
        storing computed results.

        Parameters
        ----------
        lookup_function: Callable, optional
            Function mapping from position and batch tensors to a dict lookup key.
            The returned key must be hashable
        cutoff: float, default=10.0
            Cutoff distance for interatomic interactions
        max_num_neighbors: int, default=32
            The maximum number of neighbors to collect for each node within the
            ``cutoff`` distance with the default interaction graph method
        """
        super().__init__(cutoff=cutoff, max_num_neighbors=max_num_neighbors)

        if lookup_function is None:
            # Use a class-level function rather than a closure defined here, so
            # that instances using the default can be pickled
            lookup_function = MemoizedRadiusInteractionGraph._default_lookup_key

        self.lookup_function = lookup_function
        self.lookup_table = {}

    @staticmethod
    def _default_lookup_key(pos: torch.Tensor, batch: torch.Tensor):
        """
        Default lookup function: build a string key from the number of atoms and
        the coordinates of the first and last two atoms.

        Parameters
        ----------
        pos: torch.Tensor
            Coordinates of each atom
        batch: torch.Tensor
            Batch indices assigning each atom to a separate molecule. Not used in
            the default key

        Returns
        -------
        str
            Number of atoms followed by each coordinate formatted to 7 decimal
            points, joined by ``"_"``
        """
        # Pull the values out with one ``tolist`` call per slice instead of
        # formatting tensor elements one at a time
        coords = (
            pos[:2, :].detach().flatten().tolist()
            + pos[-2:, :].detach().flatten().tolist()
        )
        # The separator keeps the atom count and the individual coordinates from
        # running together (eg 12 atoms + "34.5..." vs 123 atoms + "4.5...")
        return "_".join([str(len(pos)), *(f"{v:0.7f}" for v in coords)])

    def __repr__(self):
        return (
            f"MemoizedRadiusInteractionGraph(cutoff={self.cutoff}, "
            f"max_num_neighbors={self.max_num_neighbors})"
        )

    def forward(self, pos: torch.Tensor, batch: torch.Tensor):
        """
        Perform forward pass of RadiusInteractionGraph class, first checking to see
        if this calculation has already been done. On a cache hit, the stored
        tensors are returned as-is (not copied).

        Parameters
        ----------
        pos: torch.Tensor
            Coordinates of each atom
        batch: torch.Tensor
            Batch indices assigning each atom to a separate molecule

        Returns
        -------
        torch.Tensor
            Edge index tensor
        torch.Tensor
            Edge weight tensor
        """
        lookup_key = self.lookup_function(pos, batch)

        try:
            return self.lookup_table[lookup_key]
        except KeyError:
            # Not computed yet. Fall through so that any error raised by the
            # parent forward pass isn't chained onto this KeyError
            pass

        edge_index, edge_weight = super().forward(pos, batch)
        self.lookup_table[lookup_key] = (edge_index, edge_weight)

        return edge_index, edge_weight