mtenn.conversion_utils.visnet.ViSNet
- class mtenn.conversion_utils.visnet.ViSNet(*args, model=None, **kwargs)
Bases: Module

mtenn wrapper around the PyTorch Geometric ViSNet model. This class handles construction of the model and the formatting into Representation and Strategy blocks.

Methods
  __init__(*args[, model]): Initialize the underlying torch_geometric.nn.models.ViSNet model.
  forward(data): Predict a vector representation of an input structure.
  get_model([model, grouped, fix_device, ...]): Exposed function to build a Model or GroupedModel from a ViSNet (or args/kwargs).

- __init__(*args, model=None, **kwargs)
Initialize the underlying torch_geometric.nn.models.ViSNet model. If a value is passed for model, builds a new torch_geometric.nn.models.ViSNet model from that model's hyperparameters and copies over its weights. Otherwise, all *args and **kwargs are passed directly to the torch_geometric.nn.models.ViSNet constructor.
- Parameters:
model (
torch_geometric.nn.models.ViSNet, optional) – PyTorch Geometric ViSNet model to use to construct the underlying model
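The copy-from-an-existing-model path of __init__ follows a common pattern: construct a fresh model from the source model's hyperparameters, then copy the learned weights across. The sketch below shows that pattern in plain Python. ToyModel and rebuild_from are hypothetical stand-ins, and the two hyperparameters are illustrative only; in PyTorch the weight copy would typically be new_model.load_state_dict(source.state_dict()), but this is not a claim about mtenn's exact implementation.

```python
import copy
from dataclasses import dataclass, field


@dataclass
class ToyModel:
    # Hypothetical stand-in for a torch module: a couple of
    # hyperparameters plus a dict of learned weights.
    hidden_channels: int = 128
    num_layers: int = 6
    weights: dict = field(default_factory=dict)


def rebuild_from(source: ToyModel) -> ToyModel:
    """Hypothetical helper: rebuild a model from another model's
    hyperparameters, then copy its weights across."""
    # 1. Construct a fresh model with the same hyperparameters
    new = ToyModel(
        hidden_channels=source.hidden_channels,
        num_layers=source.num_layers,
    )
    # 2. Copy the learned weights (deep copy, so the two models
    #    do not share mutable state)
    new.weights = copy.deepcopy(source.weights)
    return new


src = ToyModel(hidden_channels=64, num_layers=4, weights={"layer0": [0.1, 0.2]})
dst = rebuild_from(src)
```

Deep-copying the weights means later training of the rebuilt model cannot silently modify the model it was built from.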
- forward(data)
Predict a vector representation of an input structure.
- Parameters:
data (dict[str, torch.Tensor]) – This dictionary should at minimum contain entries for:
  "pos": Atom coordinates
  "z": Atomic numbers
- Returns:
Predicted vector representation of the input structure
- Return type:
torch.Tensor
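The minimum input that forward expects can be illustrated with a small validation sketch. check_structure_input is a hypothetical helper, and plain lists stand in for torch.Tensor; the assumed shapes ("pos" as one (x, y, z) triple per atom, "z" as one atomic number per atom) follow the convention of PyTorch Geometric's ViSNet, which takes pos with shape [num_atoms, 3] and z with shape [num_atoms].

```python
from typing import Dict, Sequence


def check_structure_input(data: Dict[str, Sequence]) -> int:
    """Hypothetical check of the minimum input described for forward().

    Returns the number of atoms. Plain lists stand in for torch.Tensor.
    """
    missing = [key for key in ("pos", "z") if key not in data]
    if missing:
        raise KeyError(f"missing required entries: {missing}")
    pos, z = data["pos"], data["z"]
    # "pos" holds one (x, y, z) coordinate triple per atom
    if any(len(xyz) != 3 for xyz in pos):
        raise ValueError('"pos" must have shape (n_atoms, 3)')
    # "z" holds one atomic number per atom
    if len(z) != len(pos):
        raise ValueError('"pos" and "z" must describe the same number of atoms')
    return len(z)


# Water: one oxygen (Z=8) and two hydrogens (Z=1); coordinates are illustrative
water = {
    "pos": [[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]],
    "z": [8, 1, 1],
}
n_atoms = check_structure_input(water)
```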
- static get_model(model=None, grouped=False, fix_device=False, strategy: str = 'delta', combination=None, pred_readout=None, comb_readout=None)
Exposed function to build a Model or GroupedModel from a ViSNet (or args/kwargs). If no model is given, build a default ViSNet model.
- Parameters:
model (mtenn.conversion_utils.visnet.ViSNet, optional) – ViSNet model to use to build the Model object. If not given, build a default model
grouped (bool, default=False) – If True, build a GroupedModel instead of a Model
fix_device (bool, default=False) – If True, make sure the input is on the same device as the model, copying it over as necessary
strategy (str, default='delta') – Strategy to use to combine representations of the different parts. Options are [delta, concat, complex]
combination (mtenn.combination.Combination, optional) – Combination object to use to combine multiple predictions. A value must be passed if grouped is True
pred_readout (mtenn.readout.Readout, optional) – Readout object for the individual energy predictions. If a GroupedModel is being built, this Readout will be applied to each individual prediction before the values are passed to the Combination. If a Model is being built, this will be applied to the single prediction before it is returned
comb_readout (mtenn.readout.Readout, optional) – Readout object for the combined multi-pose prediction, in the case that a GroupedModel is being built. Otherwise, this is ignored
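The order in which pred_readout, the Combination, and comb_readout are applied can be sketched in plain Python. This is a hypothetical illustration: grouped_prediction, single_prediction, mean, and the lambdas are stand-ins rather than mtenn's API; in mtenn these roles are played by Readout and Combination objects operating on tensors.

```python
from typing import Callable, Optional, Sequence

# Hypothetical stand-ins: a readout maps one value to another, and a
# combination reduces several per-pose values to one.
Readout = Callable[[float], float]
Combination = Callable[[Sequence[float]], float]


def grouped_prediction(
    pose_preds: Sequence[float],
    combination: Combination,
    pred_readout: Optional[Readout] = None,
    comb_readout: Optional[Readout] = None,
) -> float:
    # pred_readout is applied to each individual (per-pose) prediction
    if pred_readout is not None:
        pose_preds = [pred_readout(p) for p in pose_preds]
    # The combination merges the per-pose values into one value
    combined = combination(pose_preds)
    # comb_readout is applied once, to the combined value
    if comb_readout is not None:
        combined = comb_readout(combined)
    return combined


def single_prediction(pred: float, pred_readout: Optional[Readout] = None) -> float:
    # For a plain Model, pred_readout is applied to the single prediction;
    # there is no combination step, so comb_readout has nothing to act on
    return pred_readout(pred) if pred_readout is not None else pred


def mean(values: Sequence[float]) -> float:
    return sum(values) / len(values)


result = grouped_prediction(
    [1.0, 2.0, 3.0],
    combination=mean,
    pred_readout=lambda x: 2 * x,
    comb_readout=lambda x: x + 1,
)
```

Because the two readouts sit on either side of the combination, a transform in pred_readout changes the values that get combined, whereas a transform in comb_readout only changes the final output.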
- Returns:
Model or GroupedModel containing the desired Representation and Strategy, plus the Combination and Readouts if provided
- Return type: