Source code for mtenn.conversion_utils.e3nn

"""
``Representation`` and ``Strategy`` implementations for an E(3)-equivariant model using
`e3nn <https://docs.e3nn.org/en/latest/index.html>`_. The underlying model that we use
is specifically the `January 2021 Network
<https://docs.e3nn.org/en/latest/api/nn/models/gate_points_2101.html>`_ model.
"""
from copy import deepcopy
import torch
from e3nn.nn.models.gate_points_2101 import Network

from mtenn.model import GroupedModel, Model
from mtenn.strategy import ComplexOnlyStrategy, ConcatStrategy, DeltaStrategy


class E3NN(Network):
    """
    ``mtenn`` wrapper around the e3nn model. This class handles construction of the
    model and the formatting into ``Representation`` and ``Strategy`` blocks.
    """

    def __init__(self, *args, model=None, **kwargs):
        """
        Initialize the underlying ``e3nn.nn.models.gate_points_2101.Network`` model.
        If a value is passed for ``model``, builds a new
        ``e3nn.nn.models.gate_points_2101.Network`` model based on those
        hyperparameters, and copies over the weights. Otherwise, all ``*args`` and
        ``**kwargs`` are passed directly to the
        ``e3nn.nn.models.gate_points_2101.Network`` constructor.

        The hyperparameters used for construction are stored in
        ``self.model_parameters``. When no ``model`` is passed, only the keyword
        arguments are recorded there (positional ``*args`` are not).

        Parameters
        ----------
        model : ``e3nn.nn.models.gate_points_2101.Network``, optional
            e3nn model to use to construct the underlying model
        """
        # If no model is passed, construct e3nn model with model_kwargs,
        # otherwise copy all parameters and weights over
        if model is None:
            super().__init__(*args, **kwargs)
            self.model_parameters = kwargs
        else:
            model_kwargs = {
                "irreps_in": model.irreps_in,
                "irreps_hidden": model.irreps_hidden,
                "irreps_out": model.irreps_out,
                "irreps_node_attr": model.irreps_node_attr,
                "irreps_edge_attr": model.irreps_edge_attr,
                "layers": len(model.layers) - 1,
                "max_radius": model.max_radius,
                "number_of_basis": model.number_of_basis,
                "num_nodes": model.num_nodes,
                "reduce_output": model.reduce_output,
            }
            # These need a bit of work to get
            # Use last layer bc guaranteed to be present and is just a Convolution
            # (no Gate)
            conv = model.layers[-1]
            model_kwargs["radial_layers"] = len(conv.fc.hs) - 2
            model_kwargs["radial_neurons"] = conv.fc.hs[1]
            model_kwargs["num_neighbors"] = conv.num_neighbors

            super().__init__(**model_kwargs)
            self.model_parameters = model_kwargs
            self.load_state_dict(model.state_dict())

    def forward(self, data):
        """
        Make a prediction of the target property based on an input structure.

        Parameters
        ----------
        data : dict[str, torch.Tensor]
            This dictionary should at minimum contain entries for:

            * ``"pos"``: Atom coordinates

            * ``"x"``: One-hot encoding of atomic numbers

            And optionally ``"z"``, which stores the node attributes.

        Returns
        -------
        dict[str, torch.Tensor]
            Deep copy of ``data`` in which the ``"x"`` entry has been replaced by a
            clone of the model prediction
        """
        prediction = super().forward(data)
        # Work on a deep copy so the caller's dict is not modified
        output = deepcopy(data)
        output["x"] = torch.clone(prediction)
        return output

    @property
    def output_dim(self):
        """Output irreps of the model (``self.irreps_out``)."""
        return self.irreps_out

    @property
    def extract_key(self):
        """Key under which ``forward`` stores the prediction in the data dict."""
        return "x"

    def _get_representation(self, reduce_output=None):
        """
        Copy model and remove last layer.

        Parameters
        ----------
        reduce_output : bool, optional
            If given, overrides the ``reduce_output`` attribute on the copied model.
            If ``None`` (default), the copy keeps the value of the calling model

        Returns
        -------
        mtenn.conversion_utils.e3nn.E3NN
            Copied e3nn model with the last layer removed
        """
        # Copy model so initial model isn't affected
        model_copy = deepcopy(self)
        # Remove last layer
        model_copy.layers = model_copy.layers[:-1]
        if reduce_output is not None:
            model_copy.reduce_output = reduce_output
        # The new last layer now determines the output irreps
        model_copy.irreps_out = model_copy.layers[-1].irreps_out

        return model_copy

    def _get_energy_func(self, layer_norm=False):
        """
        Return copy of last layer of the model.

        Parameters
        ----------
        layer_norm: bool, default=False
            Apply a ``LayerNorm`` normalization before passing through the linear
            layer

        Returns
        -------
        e3nn.nn.models.gate_points_2101.Network
            New ``Network`` built from ``self.model_parameters`` with ``layers=0``,
            whose layers are a copy of the last layer of this model (preceded by an
            ``E3NNLayerNorm`` if ``layer_norm`` is ``True``)
        """
        # Only the last layer is needed, so copy just that rather than the full model
        final_conv = deepcopy(self.layers[-1])

        conv_kwargs = deepcopy(self.model_parameters)
        conv_kwargs["layers"] = 0
        new_model = Network(**conv_kwargs)

        if layer_norm:
            new_model.layers[0] = E3NNLayerNorm(final_conv.irreps_in.dim)
            new_model.layers.append(final_conv)
        else:
            new_model.layers[0] = final_conv

        return new_model

    def _get_delta_strategy(self, layer_norm=False):
        """
        Build a :py:class:`DeltaStrategy <mtenn.strategy.DeltaStrategy>` object based
        on the calling model.

        Parameters
        ----------
        layer_norm: bool, default=False
            Apply a ``LayerNorm`` normalization before passing through the linear
            layer

        Returns
        -------
        mtenn.strategy.DeltaStrategy
            ``DeltaStrategy`` built from the model
        """
        return DeltaStrategy(self._get_energy_func(layer_norm))

    def _get_complex_only_strategy(self, layer_norm=False):
        """
        Build a :py:class:`ComplexOnlyStrategy <mtenn.strategy.ComplexOnlyStrategy>`
        object based on the calling model.

        Parameters
        ----------
        layer_norm: bool, default=False
            Apply a ``LayerNorm`` normalization before passing through the linear
            layer

        Returns
        -------
        mtenn.strategy.ComplexOnlyStrategy
            ``ComplexOnlyStrategy`` built from the model
        """
        return ComplexOnlyStrategy(self._get_energy_func(layer_norm))

    def _get_concat_strategy(self, layer_norm=False):
        """
        Build a :py:class:`ConcatStrategy <mtenn.strategy.ConcatStrategy>` object
        using the key ``"x"`` to extract the tensor representation from the data dict.

        Parameters
        ----------
        layer_norm: bool, default=False
            Apply a ``LayerNorm`` normalization before passing through the linear
            layer

        Returns
        -------
        ConcatStrategy
            ``ConcatStrategy`` for the model
        """
        # Calculate input size as 3 * dimensionality of output of Representation
        # (last layer in Representation is 2nd to last in original model)
        input_size = 3 * self.layers[-2].irreps_out.dim
        return ConcatStrategy(
            input_size=input_size, extract_key="x", layer_norm=layer_norm
        )

    @staticmethod
    def get_model(
        model=None,
        model_kwargs=None,
        grouped=False,
        fix_device=False,
        strategy: str = "delta",
        layer_norm: bool = False,
        combination=None,
        pred_readout=None,
        comb_readout=None,
    ):
        """
        Exposed function to build a :py:class:`Model <mtenn.model.Model>` or
        :py:class:`GroupedModel <mtenn.model.GroupedModel>` from an :py:class:`E3NN
        <mtenn.conversion_utils.e3nn.E3NN>` (or args/kwargs). If no ``model`` is
        given, use the ``model_kwargs``.

        Parameters
        ----------
        model: mtenn.conversion_utils.e3nn.E3NN, optional
            ``E3NN`` model to use to build the ``Model`` object. If not given, use
            the passed ``model_kwargs``
        model_kwargs: dict, optional
            Dictionary used to initialize ``E3NN`` model if nothing is passed for
            ``model``
        grouped: bool, default=False
            Build a ``GroupedModel``
        fix_device: bool, default=False
            If True, make sure the input is on the same device as the model,
            copying over as necessary
        strategy: str, default='delta'
            ``Strategy`` to use to combine representations of the different parts.
            Options are [``delta``, ``concat``, ``complex``] (case-insensitive)
        layer_norm: bool, default=False
            Apply a ``LayerNorm`` normalization before passing through the linear
            layer
        combination: mtenn.combination.Combination, optional
            ``Combination`` object to use to combine multiple predictions. A value
            must be passed if ``grouped`` is ``True``
        pred_readout : mtenn.readout.Readout, optional
            ``Readout`` object for the individual energy predictions. If a
            ``GroupedModel`` is being built, this ``Readout`` will be applied to each
            individual prediction before the values are passed to the
            ``Combination``. If a ``Model`` is being built, this will be applied to
            the single prediction before it is returned
        comb_readout : mtenn.readout.Readout, optional
            Readout object for the combined multi-pose prediction, in the case that a
            ``GroupedModel`` is being built. Otherwise, this is ignored

        Returns
        -------
        mtenn.model.Model
            ``Model`` or ``GroupedModel`` containing the desired ``Representation``,
            ``Strategy``, and ``Combination`` and ``Readout`` s as desired

        Raises
        ------
        ValueError
            If ``strategy`` is not one of the known options, or if ``grouped`` is
            ``True`` and no ``combination`` is passed
        """
        if model is None:
            # Unpack the hyperparameters as keyword arguments for the constructor
            model = E3NN(**(model_kwargs or {}))

        # Normalize the name up front so every comparison below is case-insensitive
        strategy_name = strategy.lower()

        # Get representation module; output is reduced across nodes only for the
        # concat strategy
        representation = model._get_representation(
            reduce_output=strategy_name == "concat"
        )

        # Construct strategy module based on model and
        # representation (if necessary)
        if strategy_name == "delta":
            strategy_module = model._get_delta_strategy(layer_norm)
        elif strategy_name == "concat":
            strategy_module = model._get_concat_strategy(layer_norm)
        elif strategy_name == "complex":
            strategy_module = model._get_complex_only_strategy(layer_norm)
        else:
            raise ValueError(f"Unknown strategy: {strategy_name}")

        # Check on `combination`
        if grouped and (combination is None):
            raise ValueError(
                "Must pass a value for `combination` if `grouped` is `True`."
            )

        if grouped:
            return GroupedModel(
                representation,
                strategy_module,
                combination,
                pred_readout,
                comb_readout,
                fix_device,
            )
        else:
            return Model(representation, strategy_module, pred_readout, fix_device)
class E3NNLayerNorm(torch.nn.LayerNorm):
    """
    Thin adapter over ``torch.nn.LayerNorm`` whose ``forward`` tolerates the extra
    arguments that the e3nn model passes to each of its layers.
    """

    def forward(self, x, *args, **kwargs):
        """
        Apply layer normalization to ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Tensor to normalize
        *args, **kwargs
            Accepted only for signature compatibility; they are not used

        Returns
        -------
        torch.Tensor
            Result of ``torch.nn.LayerNorm.forward`` applied to ``x``
        """
        # Any additional positional/keyword arguments are deliberately dropped
        del args, kwargs
        normalized = super().forward(x)
        return normalized