mtenn.conversion_utils.schnet.SchNet
- class mtenn.conversion_utils.schnet.SchNet(*args, model=None, **kwargs)
Bases: torch_geometric.nn.models.SchNet
mtenn wrapper around the PyTorch Geometric SchNet model. This class handles construction of the model and its formatting into Representation and Strategy blocks.
Methods
__init__(*args[, model]) – Initialize the underlying torch_geometric.nn.models.SchNet model.
forward(data) – Make a prediction of the target property based on an input structure.
get_model([model, grouped, fix_device, ...]) – Exposed function to build a Model or GroupedModel from a SchNet (or args/kwargs).
- __init__(*args, model=None, **kwargs)
Initialize the underlying torch_geometric.nn.models.SchNet model. If a value is passed for model, a new torch_geometric.nn.models.SchNet model is built with that model's hyperparameters and its weights are copied over. Otherwise, all *args and **kwargs are passed directly to the torch_geometric.nn.models.SchNet constructor.
- Parameters:
model (torch_geometric.nn.models.SchNet, optional) – PyTorch Geometric SchNet model used to construct the underlying model.
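The two construction paths can be illustrated with a torch-free mock. This is a minimal sketch of the pattern described above, not mtenn's implementation: BaseSchNet, hparams, state and load_state are hypothetical stand-ins for torch_geometric.nn.models.SchNet, its constructor arguments, state_dict() and load_state_dict().

```python
# Torch-free mock of the __init__ dispatch on `model`.
# BaseSchNet, hparams, state and load_state are hypothetical stand-ins for
# torch_geometric.nn.models.SchNet, its constructor args, state_dict()
# and load_state_dict().


class BaseSchNet:
    """Hypothetical stand-in for the PyTorch Geometric base class."""

    def __init__(self, hidden_channels=128, num_interactions=6):
        # Record the hyperparameters this instance was built with.
        self.hparams = {
            "hidden_channels": hidden_channels,
            "num_interactions": num_interactions,
        }
        # Freshly initialized "weights".
        self.state = {"w": 0.0}

    def load_state(self, state):
        # Copy the weights rather than sharing the source object.
        self.state = dict(state)


class WrappedSchNet(BaseSchNet):
    def __init__(self, *args, model=None, **kwargs):
        if model is None:
            # No model given: forward everything to the base constructor.
            super().__init__(*args, **kwargs)
        else:
            # Model given: rebuild with its hyperparameters, then copy weights.
            super().__init__(**model.hparams)
            self.load_state(model.state)


trained = BaseSchNet(hidden_channels=64)
trained.state = {"w": 1.5}
wrapped = WrappedSchNet(model=trained)
print(wrapped.hparams["hidden_channels"], wrapped.state)  # 64 {'w': 1.5}
```

Copying weights into a freshly built instance (instead of reusing the passed object) keeps the wrapper independent of the original model.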
- forward(data)
Make a prediction of the target property based on an input structure.
- Parameters:
data (dict[str, torch.Tensor]) – This dictionary should at minimum contain entries for:
"pos": Atom coordinates
"z": Atomic numbers
- Returns:
Model prediction
- Return type:
torch.Tensor
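The shape of the forward() input dictionary can be sketched without torch. In this minimal sketch, make_schnet_input and ATOMIC_NUMBERS are hypothetical helpers, and plain lists stand in for tensors. With the real model, each value would instead be a torch.Tensor, presumably floating point for "pos" and integer for "z", since PyTorch Geometric's SchNet embeds atomic numbers.

```python
# Torch-free sketch of the input dictionary expected by forward().
# make_schnet_input and ATOMIC_NUMBERS are hypothetical helpers; the real
# model expects torch.Tensor values rather than Python lists.

ATOMIC_NUMBERS = {"H": 1, "C": 6, "N": 7, "O": 8, "S": 16}


def make_schnet_input(atoms):
    """Build {"pos": ..., "z": ...} from (symbol, x, y, z) tuples."""
    pos = []
    z = []
    for symbol, x, y, zc in atoms:
        # One row of three coordinates per atom.
        pos.append([float(x), float(y), float(zc)])
        # Atomic number for the same atom, in the same order.
        z.append(ATOMIC_NUMBERS[symbol])
    return {"pos": pos, "z": z}


# Water (illustrative coordinates): one oxygen and two hydrogens.
water = [("O", 0.0, 0.0, 0.0), ("H", 0.96, 0.0, 0.0), ("H", -0.24, 0.93, 0.0)]
data = make_schnet_input(water)
print(data["z"])  # [8, 1, 1]
```

The row order of "pos" and "z" must match: entry i of both refers to the same atom.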
- static get_model(model=None, grouped=False, fix_device=False, strategy: str = 'delta', layer_norm: bool = False, combination=None, pred_readout=None, comb_readout=None)
Exposed function to build a Model or GroupedModel from a SchNet (or args/kwargs). If no model is given, build a default SchNet model.
- Parameters:
model (mtenn.conversion_utils.schnet.SchNet, optional) – SchNet model to use to build the Model object. If not given, build a default model.
grouped (bool, default=False) – Build a GroupedModel.
fix_device (bool, default=False) – If True, make sure the input is on the same device as the model, copying it over as necessary.
strategy (str, default='delta') – Strategy to use to combine representations of the different parts. Options are [delta, concat, complex].
layer_norm (bool, default=False) – Apply a LayerNorm normalization before passing through the linear layer.
combination (mtenn.combination.Combination, optional) – Combination object to use to combine multiple predictions. A value must be passed if grouped is True.
pred_readout (mtenn.readout.Readout, optional) – Readout object for the individual energy predictions. If a GroupedModel is being built, this Readout will be applied to each individual prediction before the values are passed to the Combination. If a Model is being built, this will be applied to the single prediction before it is returned.
comb_readout (mtenn.readout.Readout, optional) – Readout object for the combined multi-pose prediction, in the case that a GroupedModel is being built. Otherwise, this is ignored.
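How strategy, pred_readout, combination and comb_readout fit together can be illustrated on plain floats. This is a conceptual sketch, not mtenn code: every function is a hypothetical stand-in, statistics.mean plays the role of a Combination, and the strategy behaviours are assumptions (delta taken as the complex energy minus the summed energies of the parts, complex as the complex energy alone). The concat strategy is left out because it relies on a learned linear layer.

```python
# Conceptual, torch-free sketch of the pieces get_model() assembles.
# All functions are hypothetical stand-ins; energies are plain floats.
from statistics import mean


def delta_strategy(complex_energy, part_energies):
    # Assumed behaviour: complex energy minus the summed part energies.
    return complex_energy - sum(part_energies)


def complex_strategy(complex_energy, part_energies):
    # Assumed behaviour: ignore the parts and use the complex alone.
    return complex_energy


STRATEGIES = {"delta": delta_strategy, "complex": complex_strategy}


def predict_pose(pose, strategy="delta", pred_readout=None):
    """Single-pose prediction, as a Model would produce it."""
    pred = STRATEGIES[strategy](pose["complex"], pose["parts"])
    # pred_readout (if any) is applied to the single prediction.
    return pred_readout(pred) if pred_readout else pred


def predict_grouped(poses, combination, strategy="delta",
                    pred_readout=None, comb_readout=None):
    """Multi-pose prediction, as a GroupedModel would produce it."""
    # pred_readout is applied to each pose before combination.
    preds = [predict_pose(p, strategy, pred_readout) for p in poses]
    combined = combination(preds)
    # comb_readout (if any) is applied only to the combined value.
    return comb_readout(combined) if comb_readout else combined


poses = [
    {"complex": -10.0, "parts": [-6.0, -1.0]},  # delta: -3.0
    {"complex": -9.0, "parts": [-6.0, -1.0]},   # delta: -2.0
]
print(predict_grouped(poses, combination=mean))  # -2.5
```

Making combination a required argument of predict_grouped mirrors the rule above that a Combination must be supplied when grouped is True.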
- Returns:
Model or GroupedModel containing the requested Representation and Strategy, plus the Combination and Readouts if given
- Return type: