Source code for mtenn.conversion_utils.visnet

"""
``Representation`` and ``Strategy`` implementations for the ViSNet model architecture.
The underlying model that we use is the implementation in
`PyTorch Geometric <https://pytorch-geometric.readthedocs.io/en/latest/generated/
torch_geometric.nn.models.ViSNet.html#torch_geometric.nn.models.ViSNet>`_.
"""
from copy import deepcopy
import torch
from torch_geometric.utils import scatter
from torch_geometric.nn.models import ViSNet as PygViSNet
from torch_geometric.nn.models.visnet import ViS_MP_Vertex

from mtenn.model import GroupedModel, Model
from mtenn.strategy import ComplexOnlyStrategy, ConcatStrategy, DeltaStrategy


class EquivariantVecToScalar(torch.nn.Module):
    """
    ``Module`` adapter for ``torch_geometric.utils.scatter``, so that the scatter
    reduction can be used anywhere a ``torch.nn.Module`` is expected.
    """

    def __init__(self, mean, reduce_op):
        """
        Store user parameters.

        Parameters
        ----------
        mean : torch.Tensor
            Mean of predicted value, added to the reduced output in ``forward``
        reduce_op : str
            Reduce operation to use in ``torch_geometric.utils.scatter``
        """
        super().__init__()
        self.mean = mean
        self.reduce_op = reduce_op

    def forward(self, x):
        """
        Reduce ``x`` along its first dimension and shift the result by ``mean``.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Output of the ``scatter`` call plus ``self.mean``
        """
        # Every row is assigned to group 0: all atoms are treated as belonging to
        # the same molecule and the same batch
        num_rows = x.shape[0]
        group_index = torch.zeros(num_rows, device=x.device, dtype=torch.int64)
        reduced = scatter(x, group_index, dim=0, reduce=self.reduce_op)
        return reduced + self.mean
class ViSNet(torch.nn.Module):
    """
    ``mtenn`` wrapper around the PyTorch Geometric ViSNet model. This class handles
    construction of the model and the formatting into ``Representation`` and
    ``Strategy`` blocks.
    """

    def __init__(self, *args, model=None, **kwargs):
        """
        Initialize the underlying ``torch_geometric.nn.models.ViSNet`` model. If a
        value is passed for ``model``, builds a new
        ``torch_geometric.nn.models.ViSNet`` model based on those hyperparameters,
        and copies over the weights (``*args`` and ``**kwargs`` are not used in that
        case). Otherwise, all ``*args`` and ``**kwargs`` are passed directly to the
        ``torch_geometric.nn.models.ViSNet`` constructor.

        Parameters
        ----------
        model : ``torch_geometric.nn.models.ViSNet``, optional
            PyTorch Geometric ViSNet model to use to construct the underlying model
        """
        super().__init__()

        # If no model is passed, construct default ViSNet model, otherwise copy
        # all parameters and weights over
        if model is None:
            self.visnet = PygViSNet(*args, **kwargs)
        else:
            rep = model.representation_model
            atomref = model.prior_model.atomref.weight.detach().clone()
            model_params = {
                "lmax": rep.lmax,
                "vecnorm_type": rep.vecnorm_type,
                "trainable_vecnorm": rep.trainable_vecnorm,
                "num_heads": rep.num_heads,
                "num_layers": rep.num_layers,
                "hidden_channels": rep.hidden_channels,
                "num_rbf": rep.num_rbf,
                "trainable_rbf": rep.trainable_rbf,
                "max_z": rep.max_z,
                "cutoff": rep.cutoff,
                "max_num_neighbors": rep.max_num_neighbors,
                # The vertex variant is detected from the type of the first
                # message-passing layer
                "vertex": isinstance(rep.vis_mp_layers[0], ViS_MP_Vertex),
                "reduce_op": model.reduce_op,
                "mean": model.mean,
                "std": model.std,
                # not used. originally calculates "force" from energy
                "derivative": model.derivative,
                "atomref": atomref,
            }
            self.visnet = PygViSNet(**model_params)
            self.visnet.load_state_dict(model.state_dict())

        self.readout = EquivariantVecToScalar(self.visnet.mean, self.visnet.reduce_op)

    def forward(self, data):
        """
        Predict a vector representation of an input structure. The scatter
        reduction and mean shift of the underlying model are not applied here.

        Parameters
        ----------
        data : dict[str, torch.Tensor]
            This dictionary should at minimum contain entries for:

            * ``"pos"``: Atom coordinates

            * ``"z"``: Atomic numbers

        Returns
        -------
        torch.Tensor
            Predicted vector representation of input
        """
        pos = data["pos"]
        z = data["z"]

        # All atoms are processed in one pass as a single molecule. The batch
        # vector is an index vector, so it is created as int64 (matching the one
        # built in ``EquivariantVecToScalar.forward``) rather than the default
        # floating-point dtype.
        batch = torch.zeros(z.shape[0], dtype=torch.int64, device=z.device)

        x, v = self.visnet.representation_model(z, pos, batch)
        x = self.visnet.output_model.pre_reduce(x, v)
        x = x * self.visnet.std
        if self.visnet.prior_model is not None:
            x = self.visnet.prior_model(x, z)

        return x

    def _get_representation(self):
        """
        Copy model.

        Returns
        -------
        mtenn.conversion_utils.visnet.ViSNet
            Copied ``ViSNet`` model
        """
        # Copy model so initial model isn't affected
        return deepcopy(self)

    def _get_energy_func(self):
        """
        Return copy of ``readout`` portion of the model.

        Returns
        -------
        mtenn.conversion_utils.visnet.EquivariantVecToScalar
            Copy of ``self.readout``
        """
        return deepcopy(self.readout)

    def _get_delta_strategy(self):
        """
        Build a :py:class:`DeltaStrategy <mtenn.strategy.DeltaStrategy>` object
        based on the calling model.

        Returns
        -------
        mtenn.strategy.DeltaStrategy
            ``DeltaStrategy`` built from the model
        """
        return DeltaStrategy(self._get_energy_func())

    def _get_complex_only_strategy(self):
        """
        Build a :py:class:`ComplexOnlyStrategy <mtenn.strategy.ComplexOnlyStrategy>`
        object based on the calling model.

        Returns
        -------
        mtenn.strategy.ComplexOnlyStrategy
            ``ComplexOnlyStrategy`` built from the model
        """
        return ComplexOnlyStrategy(self._get_energy_func())

    @staticmethod
    def get_model(
        model=None,
        grouped=False,
        fix_device=False,
        strategy: str = "delta",
        combination=None,
        pred_readout=None,
        comb_readout=None,
    ):
        """
        Exposed function to build a :py:class:`Model <mtenn.model.Model>` or
        :py:class:`GroupedModel <mtenn.model.GroupedModel>` from a
        :py:class:`ViSNet <mtenn.conversion_utils.visnet.ViSNet>` (or args/kwargs).
        If no ``model`` is given, build a default ``ViSNet`` model.

        Parameters
        ----------
        model: mtenn.conversion_utils.visnet.ViSNet, optional
            ``ViSNet`` model to use to build the ``Model`` object. If not given,
            build a default model
        grouped: bool, default=False
            Build a ``GroupedModel``
        fix_device: bool, default=False
            If True, make sure the input is on the same device as the model,
            copying over as necessary
        strategy: str, default='delta'
            ``Strategy`` to use to combine representations of the different parts.
            Options are [``delta``, ``concat``, ``complex``]; matching is
            case-insensitive
        combination: mtenn.combination.Combination, optional
            ``Combination`` object to use to combine multiple predictions. A value
            must be passed if ``grouped`` is ``True``
        pred_readout : mtenn.readout.Readout, optional
            ``Readout`` object for the individual energy predictions. If a
            ``GroupedModel`` is being built, this ``Readout`` will be applied to each
            individual prediction before the values are passed to the
            ``Combination``. If a ``Model`` is being built, this will be applied to
            the single prediction before it is returned
        comb_readout : mtenn.readout.Readout, optional
            Readout object for the combined multi-pose prediction, in the case that a
            ``GroupedModel`` is being built. Otherwise, this is ignored

        Returns
        -------
        mtenn.model.Model
            ``Model`` or ``GroupedModel`` containing the desired ``Representation``,
            ``Strategy``, and ``Combination`` and ``Readout`` s as desired

        Raises
        ------
        ValueError
            If ``strategy`` is not one of the known options, or if ``grouped`` is
            ``True`` and no ``combination`` is given
        """
        if model is None:
            model = ViSNet()

        # First get representation module
        representation = model._get_representation()

        # Construct strategy module based on model and
        # representation (if necessary)
        strategy_name = strategy.lower()
        if strategy_name == "delta":
            strategy_module = model._get_delta_strategy()
        elif strategy_name == "concat":
            strategy_module = ConcatStrategy()
        elif strategy_name == "complex":
            strategy_module = model._get_complex_only_strategy()
        else:
            raise ValueError(f"Unknown strategy: {strategy_name}")

        # Check on `combination`
        if grouped and (combination is None):
            raise ValueError(
                "Must pass a value for `combination` if `grouped` is `True`."
            )

        if grouped:
            return GroupedModel(
                representation,
                strategy_module,
                combination,
                pred_readout,
                comb_readout,
                fix_device,
            )

        return Model(representation, strategy_module, pred_readout, fix_device)