mtenn.conversion_utils.e3nn.E3NN

class mtenn.conversion_utils.e3nn.E3NN(*args, model=None, **kwargs)

Bases: Network

mtenn wrapper around the e3nn Network model (e3nn.nn.models.gate_points_2101.Network). This class handles construction of the model and its formatting into Representation and Strategy blocks.

Methods

__init__(*args[, model])

Initialize the underlying e3nn.nn.models.gate_points_2101.Network model.

forward(data)

Make a prediction of the target property based on an input structure.

get_model([model, model_kwargs, grouped, ...])

Exposed function to build a Model or GroupedModel from an E3NN model (or from model_kwargs).

__init__(*args, model=None, **kwargs)

Initialize the underlying e3nn.nn.models.gate_points_2101.Network model. If a value is passed for model, builds a new e3nn.nn.models.gate_points_2101.Network model based on that model's hyperparameters and copies over its weights. Otherwise, all *args and **kwargs are passed directly to the e3nn.nn.models.gate_points_2101.Network constructor.

Parameters:

model (e3nn.nn.models.gate_points_2101.Network, optional) – Existing e3nn model whose hyperparameters and weights are used to construct the underlying model
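When model is given, the documented behavior is to rebuild a fresh Network from that model's hyperparameters and then copy its weights. The usual PyTorch idiom for the weight copy is a state_dict round trip. The sketch below shows that idiom with a small stand-in module; `TinyNet` and `rebuild_with_weights` are hypothetical names, and the actual E3NN.__init__ may read the hyperparameters from the e3nn model differently.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for e3nn's Network that remembers its hyperparameters."""

    def __init__(self, n_in: int, n_out: int):
        super().__init__()
        self.n_in = n_in
        self.n_out = n_out
        self.lin = nn.Linear(n_in, n_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lin(x)


def rebuild_with_weights(old: TinyNet) -> TinyNet:
    """Hypothetical helper: build a new model from old's hyperparameters, then copy weights."""
    new = TinyNet(n_in=old.n_in, n_out=old.n_out)
    # load_state_dict copies the parameter values into the new module's own tensors
    new.load_state_dict(old.state_dict())
    return new


old = TinyNet(4, 2)
new = rebuild_with_weights(old)
x = torch.randn(3, 4)
```

After the copy, `old(x)` and `new(x)` agree, but the two modules hold independent parameter tensors, so training one does not modify the other.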

forward(data)

Make a prediction of the target property based on an input structure.

Parameters:

data (dict[str, torch.Tensor]) – This dictionary should at minimum contain entries for:

  • "pos": Atom coordinates

  • "x": One-hot encoding of atomic numbers

and optionally "z", which stores the node attributes.
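A minimal sketch of assembling the data dictionary for forward, assuming a small molecule given as a list of atomic numbers and Cartesian coordinates. The helper name `make_e3nn_input`, the one-hot width `num_atom_types`, and the mapping of atomic number Z to one-hot column Z are all assumptions; the width must match the input irreps the underlying Network was built with.

```python
import torch
import torch.nn.functional as F


def make_e3nn_input(atomic_numbers, coords, num_atom_types=100):
    """Hypothetical helper: build the minimal {"pos", "x"} dict described above.

    Assumes atomic number Z maps to one-hot column Z, so num_atom_types must exceed max Z.
    """
    numbers = torch.as_tensor(atomic_numbers, dtype=torch.long)
    pos = torch.as_tensor(coords, dtype=torch.float32)
    # One-hot encode the atomic numbers and cast to float for use as node features
    x = F.one_hot(numbers, num_classes=num_atom_types).to(torch.float32)
    return {"pos": pos, "x": x}


# Water: O near the origin and two H atoms (coordinates in angstroms)
data = make_e3nn_input(
    [8, 1, 1],
    [[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]],
)
```

Note that the "z" key holds node attributes, not atomic numbers; if the underlying Network was built to use node attributes, that tensor would be added to the same dictionary with one row per atom.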

Returns:

Model prediction

Return type:

torch.Tensor

static get_model(model=None, model_kwargs=None, grouped=False, fix_device=False, strategy: str = 'delta', layer_norm: bool = False, combination=None, pred_readout=None, comb_readout=None)

Exposed function to build a Model or GroupedModel from an E3NN model (or from model_kwargs). If no model is given, model_kwargs is used to construct one.

Parameters:
  • model (mtenn.conversion_utils.e3nn.E3NN, optional) – E3NN model to use to build the Model object. If not given, use the passed model_kwargs

  • model_kwargs (dict, optional) – Dictionary used to initialize an E3NN model if nothing is passed for model

  • grouped (bool, default=False) – Build a GroupedModel

  • fix_device (bool, default=False) – If True, make sure the input is on the same device as the model, copying over as necessary

  • strategy (str, default='delta') – Strategy to use to combine representations of the different parts. Options are [delta, concat, complex]

  • layer_norm (bool, default=False) – Apply LayerNorm before passing through the linear layer

  • combination (mtenn.combination.Combination, optional) – Combination object to use to combine multiple predictions. A value must be passed if grouped is True

  • pred_readout (mtenn.readout.Readout, optional) – Readout object for the individual energy predictions. If a GroupedModel is being built, this Readout will be applied to each individual prediction before the values are passed to the Combination. If a Model is being built, this will be applied to the single prediction before it is returned

  • comb_readout (mtenn.readout.Readout, optional) – Readout object for the combined multi-pose prediction, in the case that a GroupedModel is being built. Otherwise, this is ignored
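With fix_device=True, the built model makes sure its input sits on the same device as its parameters. One common way to do that in PyTorch is sketched below; `move_to_model_device` is a hypothetical helper, not mtenn's implementation, and it assumes the input is a dict of tensors like the one forward accepts.

```python
import torch
from torch import nn


def move_to_model_device(model: nn.Module, data: dict) -> dict:
    """Hypothetical helper: copy every tensor in data onto the device of model's parameters."""
    device = next(model.parameters()).device
    # Tensor.to returns the tensor unchanged if it is already on the target device
    return {k: v.to(device) if torch.is_tensor(v) else v for k, v in data.items()}


model = nn.Linear(3, 1)  # stand-in for a built mtenn model
data = {"pos": torch.randn(5, 3), "x": torch.zeros(5, 10)}
moved = move_to_model_device(model, data)
```

Non-tensor entries are passed through untouched, so the same helper works for dictionaries that carry extra metadata alongside the tensors.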

Returns:

Model or GroupedModel containing the requested Representation, Strategy, Combination, and Readouts

Return type:

mtenn.model.Model
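For a GroupedModel, the parameter descriptions above imply this order of operations: pred_readout on each pose's prediction, then the Combination across poses, then comb_readout on the combined value. The pure-Python sketch below traces that flow with hypothetical stand-ins (a plain mean in place of an mtenn.combination.Combination, simple functions in place of mtenn.readout.Readout objects); it illustrates the ordering only and is not mtenn's code.

```python
from statistics import mean
from typing import Callable, Optional, Sequence

Readout = Callable[[float], float]


def grouped_prediction(
    pose_preds: Sequence[float],
    combination: Callable[[Sequence[float]], float],
    pred_readout: Optional[Readout] = None,
    comb_readout: Optional[Readout] = None,
) -> float:
    """Hypothetical sketch of the documented GroupedModel data flow."""
    # 1. Apply pred_readout to each individual pose prediction
    if pred_readout is not None:
        pose_preds = [pred_readout(p) for p in pose_preds]
    # 2. Combine the per-pose values into a single prediction
    combined = combination(pose_preds)
    # 3. Apply comb_readout to the combined prediction
    if comb_readout is not None:
        combined = comb_readout(combined)
    return combined


result = grouped_prediction(
    [-1.0, -2.0, -3.0],
    combination=mean,
    pred_readout=lambda e: 2.0 * e,
    comb_readout=lambda e: -e,
)
```

Here each pose value is doubled to -2.0, -4.0, -6.0, averaged to -4.0, and negated to 4.0. For a plain Model there is only one prediction, so only pred_readout applies and comb_readout is ignored.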