mtenn.conversion_utils.e3nn.E3NN
- class mtenn.conversion_utils.e3nn.E3NN(*args, model=None, **kwargs)
Bases: Network
mtenn wrapper around the e3nn model. This class handles construction of the model and the formatting into Representation and Strategy blocks.
Methods
__init__(*args[, model]) – Initialize the underlying e3nn.nn.models.gate_points_2101.Network model.
forward(data) – Make a prediction of the target property based on an input structure.
get_model([model, model_kwargs, grouped, ...]) – Exposed function to build a Model or GroupedModel from an E3NN (or args/kwargs).
- __init__(*args, model=None, **kwargs)
Initialize the underlying e3nn.nn.models.gate_points_2101.Network model. If a value is passed for model, a new e3nn.nn.models.gate_points_2101.Network model is built from that model's hyperparameters and its weights are copied over. Otherwise, all *args and **kwargs are passed directly to the e3nn.nn.models.gate_points_2101.Network constructor.
- Parameters:
model (e3nn.nn.models.gate_points_2101.Network, optional) – e3nn model whose hyperparameters and weights are used to construct the underlying model
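The "rebuild from hyperparameters, then copy weights" behavior described for __init__ can be sketched with plain Python. This is a hypothetical illustration only: ToyNetwork, ToyWrapper, hidden_size, num_layers, and hyperparameters() are invented stand-ins, not e3nn or mtenn APIs. In PyTorch, the weight copy would typically go through state_dict() and load_state_dict() rather than a dict copy.

```python
import copy


class ToyNetwork:
    """Hypothetical stand-in for an e3nn Network: stores hyperparameters and weights."""

    def __init__(self, hidden_size=16, num_layers=3):
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # One scalar "weight" per layer, purely for illustration
        self.weights = {f"layer{i}": 0.0 for i in range(num_layers)}

    def hyperparameters(self):
        # Hypothetical helper: the constructor arguments needed to rebuild this model
        return {"hidden_size": self.hidden_size, "num_layers": self.num_layers}


class ToyWrapper(ToyNetwork):
    """Hypothetical wrapper mirroring the two construction paths of E3NN.__init__."""

    def __init__(self, *args, model=None, **kwargs):
        if model is None:
            # No model given: forward all arguments to the base constructor
            super().__init__(*args, **kwargs)
        else:
            # Build a fresh network from the given model's hyperparameters...
            super().__init__(**model.hyperparameters())
            # ...then copy its weights so both models start out identical
            self.weights = copy.deepcopy(model.weights)


base = ToyNetwork(num_layers=2)
base.weights["layer0"] = 1.5
wrapped = ToyWrapper(model=base)
```

Copying the weights into a freshly constructed object, rather than reusing the original model, keeps the wrapper independent of the model it was built from.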
- forward(data)
Make a prediction of the target property based on an input structure.
- Parameters:
data (dict[str, torch.Tensor]) – This dictionary should, at minimum, contain entries for:
"pos": Atom coordinates
"x": One-hot encoding of atomic numbers
and, optionally:
"z": Node attributes
- Returns:
Model prediction
- Return type:
torch.Tensor
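A minimal sketch of assembling the "pos" and "x" entries for forward from atomic numbers and coordinates. The element vocabulary ELEMENTS and the helpers one_hot and build_input are hypothetical. The one-hot width the real model expects depends on how it was configured, so treat this list as an assumption. The values are plain nested lists here; each would need to be wrapped with torch.tensor(...) before the dict is passed to the model.

```python
# Hypothetical element vocabulary (assumption): H, C, N, O, F, P, S, Cl
ELEMENTS = [1, 6, 7, 8, 9, 15, 16, 17]


def one_hot(atomic_number, vocab=ELEMENTS):
    """Return a one-hot list for an atomic number over a fixed element vocabulary."""
    if atomic_number not in vocab:
        raise ValueError(f"atomic number {atomic_number} not in vocabulary")
    return [1.0 if z == atomic_number else 0.0 for z in vocab]


def build_input(atomic_numbers, coords):
    """Assemble the required "pos" and "x" entries as nested lists.

    Convert each value with torch.tensor(...) before calling the model.
    """
    if len(atomic_numbers) != len(coords):
        raise ValueError("need one coordinate triple per atom")
    return {
        "pos": [[float(c) for c in xyz] for xyz in coords],
        "x": [one_hot(z) for z in atomic_numbers],
    }


# Example: a water molecule (coordinates are illustrative only)
water = build_input([8, 1, 1], [(0.0, 0.0, 0.0), (0.96, 0.0, 0.0), (-0.24, 0.93, 0.0)])
```

The optional "z" entry is left out here because its contents depend on the node attributes the model was built to use.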
- static get_model(model=None, model_kwargs=None, grouped=False, fix_device=False, strategy: str = 'delta', layer_norm: bool = False, combination=None, pred_readout=None, comb_readout=None)
Exposed function to build a Model or GroupedModel from an E3NN (or args/kwargs). If no model is given, the E3NN model is built from model_kwargs.
- Parameters:
model (mtenn.conversion_utils.e3nn.E3NN, optional) – E3NN model to use to build the Model object. If not given, use the passed model_kwargs
model_kwargs (dict, optional) – Dictionary used to initialize the E3NN model if nothing is passed for model
grouped (bool, default=False) – Build a GroupedModel
fix_device (bool, default=False) – If True, make sure the input is on the same device as the model, copying it over as necessary
strategy (str, default='delta') – Strategy to use to combine representations of the different parts. Options are [delta, concat, complex]
layer_norm (bool, default=False) – Apply a LayerNorm normalization before passing through the linear layer
combination (mtenn.combination.Combination, optional) – Combination object to use to combine multiple predictions. A value must be passed if grouped is True
pred_readout (mtenn.readout.Readout, optional) – Readout object for the individual energy predictions. If a GroupedModel is being built, this Readout will be applied to each individual prediction before the values are passed to the Combination. If a Model is being built, this will be applied to the single prediction before it is returned
comb_readout (mtenn.readout.Readout, optional) – Readout object for the combined multi-pose prediction, in the case that a GroupedModel is being built. Otherwise, this is ignored
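The order in which pred_readout, combination, and comb_readout are applied can be sketched with plain floats. This is a hypothetical sketch, not mtenn code: delta_strategy, predict_single, and predict_grouped are invented names. It also assumes that the 'delta' strategy subtracts the summed part predictions (protein, ligand) from the complex prediction, and it uses statistics.mean only as a placeholder for a real Combination object.

```python
from statistics import mean


def delta_strategy(complex_energy, *part_energies):
    """Assumed 'delta' behavior: complex prediction minus the sum of the parts."""
    return complex_energy - sum(part_energies)


def predict_single(pose, pred_readout=None):
    """Model-like path: strategy, then the optional pred_readout on the single prediction."""
    pred = delta_strategy(*pose)
    return pred_readout(pred) if pred_readout is not None else pred


def predict_grouped(poses, combination, pred_readout=None, comb_readout=None):
    """GroupedModel-like path.

    pred_readout is applied to each pose's prediction, the results are merged
    by combination, and comb_readout (if given) is applied to the combined value.
    """
    if combination is None:
        # Mirrors the documented requirement for grouped=True
        raise ValueError("a combination is required when grouped=True")
    preds = [predict_single(pose, pred_readout) for pose in poses]
    combined = combination(preds)
    return comb_readout(combined) if comb_readout is not None else combined


# Each pose: (complex, protein, ligand) predictions; the values are made up
poses = [(-10.0, -4.0, -3.0), (-9.0, -4.0, -3.0)]
print(predict_grouped(poses, combination=mean))  # prints -2.5
```

Here the two poses give per-pose predictions of -3.0 and -2.0, which the placeholder mean combines into -2.5. With grouped=False, only predict_single's path applies, and comb_readout plays no role.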
- Returns:
Model or GroupedModel containing the desired Representation and Strategy, along with the Combination and Readout objects, if given
- Return type: