mtenn.conversion_utils.visnet.ViSNet

class mtenn.conversion_utils.visnet.ViSNet(*args, model=None, **kwargs)[source]

Bases: Module

mtenn wrapper around the PyTorch Geometric ViSNet model. This class handles construction of the model and formats it into Representation and Strategy blocks.

Methods

__init__(*args[, model])

Initialize the underlying torch_geometric.nn.models.ViSNet model.

forward(data)

Predict a vector representation of an input structure.

get_model([model, grouped, fix_device, ...])

Exposed function to build a Model or GroupedModel from a ViSNet (or args/kwargs).

__init__(*args, model=None, **kwargs)[source]

Initialize the underlying torch_geometric.nn.models.ViSNet model. If a value is passed for model, builds a new torch_geometric.nn.models.ViSNet model with the same hyperparameters as the passed model and copies over its weights. Otherwise, all *args and **kwargs are passed directly to the torch_geometric.nn.models.ViSNet constructor.

Parameters:

model (torch_geometric.nn.models.ViSNet, optional) – PyTorch Geometric ViSNet model to use to construct the underlying model

forward(data)[source]

Predict a vector representation of an input structure.

Parameters:

data (dict[str, torch.Tensor]) – This dictionary should at minimum contain entries for:

  • "pos": Atom coordinates

  • "z": Atomic numbers

Returns:

Predicted vector representation of input

Return type:

torch.Tensor
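The minimum input contract described above can be sketched as a small check. This is only an illustration: missing_forward_keys and REQUIRED_KEYS are hypothetical names, not part of mtenn, and plain Python lists stand in for the torch.Tensor values that forward actually expects.

```python
# Keys that the documentation above lists as the minimum for forward(data).
# REQUIRED_KEYS and missing_forward_keys are hypothetical helpers, not mtenn API.
REQUIRED_KEYS = ("pos", "z")


def missing_forward_keys(data):
    """Return the required keys absent from `data`, in documented order."""
    return [key for key in REQUIRED_KEYS if key not in data]


# Water molecule: one oxygen and two hydrogens, with illustrative coordinates.
# In real use both values would be torch.Tensor objects (for example built
# with torch.tensor), as the parameter type above requires.
water = {
    "pos": [[0.000, 0.000, 0.117], [0.000, 0.757, -0.469], [0.000, -0.757, -0.469]],
    "z": [8, 1, 1],
}

print(missing_forward_keys(water))       # []
print(missing_forward_keys({"z": [8]}))  # ['pos']
```

Each row of "pos" corresponds to the entry of "z" at the same index, so the two entries should have the same length along the atom dimension.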

static get_model(model=None, grouped=False, fix_device=False, strategy: str = 'delta', combination=None, pred_readout=None, comb_readout=None)[source]

Exposed function to build a Model or GroupedModel from a ViSNet (or args/kwargs). If no model is given, build a default ViSNet model.

Parameters:
  • model (mtenn.conversion_utils.visnet.ViSNet, optional) – ViSNet model to use to build the Model object. If not given, build a default model

  • grouped (bool, default=False) – Build a GroupedModel

  • fix_device (bool, default=False) – If True, make sure the input is on the same device as the model, copying over as necessary

  • strategy (str, default='delta') – Strategy to use to combine representations of the different parts. Options are [delta, concat, complex]

  • combination (mtenn.combination.Combination, optional) – Combination object to use to combine multiple predictions. A value must be passed if grouped is True

  • pred_readout (mtenn.readout.Readout, optional) – Readout object for the individual energy predictions. If a GroupedModel is being built, this Readout will be applied to each individual prediction before the values are passed to the Combination. If a Model is being built, this will be applied to the single prediction before it is returned

  • comb_readout (mtenn.readout.Readout, optional) – Readout object for the combined multi-pose prediction, in the case that a GroupedModel is being built. Otherwise, this is ignored

Returns:

Model or GroupedModel containing the desired Representation, Strategy, Combination, and Readouts

Return type:

mtenn.model.Model
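The argument rules documented for get_model can be summarized in a short sketch. It assumes nothing about how mtenn validates its inputs internally: check_get_model_args and STRATEGIES are hypothetical names introduced here, and the function only restates the rules from the parameter list (the three strategy options, the combination requirement for grouped models, and comb_readout being ignored otherwise).

```python
# Strategy options as listed in the parameter description above.
# STRATEGIES and check_get_model_args are hypothetical, not part of mtenn.
STRATEGIES = ("delta", "concat", "complex")


def check_get_model_args(grouped=False, strategy="delta", combination=None,
                         comb_readout=None):
    """Check arguments against the documented rules and summarize the result."""
    if strategy not in STRATEGIES:
        raise ValueError(f"strategy must be one of {STRATEGIES}, got {strategy!r}")
    if grouped and combination is None:
        # Documented rule: a Combination must be passed if grouped is True.
        raise ValueError("combination is required when grouped=True")
    return {
        "builds": "GroupedModel" if grouped else "Model",
        # comb_readout only has an effect when a GroupedModel is built.
        "comb_readout_used": grouped and comb_readout is not None,
    }


print(check_get_model_args())
# {'builds': 'Model', 'comb_readout_used': False}
```

With the defaults, the real call would be ViSNet.get_model(), which per the description above builds a default ViSNet and wraps it in a Model using the delta strategy. Passing grouped=True without a combination is the case this sketch rejects.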