mtenn.conversion_utils.schnet.SchNet

class mtenn.conversion_utils.schnet.SchNet(*args, model=None, **kwargs)

Bases: SchNet

mtenn wrapper around the PyTorch Geometric SchNet model. This class handles construction of the model and its packaging into Representation and Strategy blocks.

Methods

__init__(*args[, model])

Initialize the underlying torch_geometric.nn.models.SchNet model.

forward(data)

Make a prediction of the target property based on an input structure.

get_model([model, grouped, fix_device, ...])

Exposed function to build a Model or GroupedModel from a SchNet (or args/kwargs).

__init__(*args, model=None, **kwargs)

Initialize the underlying torch_geometric.nn.models.SchNet model. If a value is passed for model, a new torch_geometric.nn.models.SchNet model is built from that model's hyperparameters, and its weights are copied over. Otherwise, all *args and **kwargs are passed directly to the torch_geometric.nn.models.SchNet constructor.

Parameters:

model (torch_geometric.nn.models.SchNet, optional) – PyTorch Geometric SchNet model to use to construct the underlying model

forward(data)

Make a prediction of the target property based on an input structure.

Parameters:

data (dict[str, torch.Tensor]) – This dictionary should at minimum contain entries for:

  • "pos": Atom coordinates

  • "z": Atomic numbers

Returns:

Model prediction

Return type:

torch.Tensor
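
As a rough illustration of the input layout described above, the sketch below assembles the "pos" and "z" entries for a single water molecule. It uses plain Python lists so that it runs without any dependencies; in practice each entry would be converted to a torch.Tensor before being passed to forward. The helper name build_inputs, the coordinates, and the one-row-per-atom layout for "pos" are illustrative assumptions, not part of mtenn.

```python
# Hypothetical helper: not part of mtenn, shown only to illustrate the dict layout.
def build_inputs(atoms):
    """Split (atomic_number, (x, y, z)) records into "z" and "pos" entries."""
    z = [atomic_number for atomic_number, _ in atoms]
    pos = [list(xyz) for _, xyz in atoms]
    if any(len(xyz) != 3 for xyz in pos):
        raise ValueError("each atom needs exactly three coordinates")
    return {"pos": pos, "z": z}


# Water molecule with approximate, illustrative coordinates in angstroms.
water = [
    (8, (0.000, 0.000, 0.117)),
    (1, (0.000, 0.757, -0.469)),
    (1, (0.000, -0.757, -0.469)),
]
data = build_inputs(water)

# Before calling forward, each entry would be wrapped in a tensor, e.g.
# torch.tensor(data["z"]) and torch.tensor(data["pos"]).
```

Keeping the atomic numbers and coordinates in the same atom order matters, since row i of "pos" is expected to describe the atom whose atomic number is entry i of "z".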

static get_model(model=None, grouped=False, fix_device=False, strategy: str = 'delta', layer_norm: bool = False, combination=None, pred_readout=None, comb_readout=None)

Exposed function to build a Model or GroupedModel from a SchNet (or args/kwargs). If no model is given, build a default SchNet model.

Parameters:
  • model (mtenn.conversion_utils.schnet.SchNet, optional) – SchNet model to use to build the Model object. If not given, build a default model

  • grouped (bool, default=False) – Build a GroupedModel

  • fix_device (bool, default=False) – If True, make sure the input is on the same device as the model, copying over as necessary

  • strategy (str, default='delta') – Strategy to use to combine representations of the different parts. Options are [delta, concat, complex]

  • layer_norm (bool, default=False) – Apply a LayerNorm normalization before passing through the linear layer

  • combination (mtenn.combination.Combination, optional) – Combination object to use to combine multiple predictions. A value must be passed if grouped is True

  • pred_readout (mtenn.readout.Readout, optional) – Readout object for the individual energy predictions. If a GroupedModel is being built, this Readout will be applied to each individual prediction before the values are passed to the Combination. If a Model is being built, this will be applied to the single prediction before it is returned

  • comb_readout (mtenn.readout.Readout, optional) – Readout object for the combined multi-pose prediction, in the case that a GroupedModel is being built. Otherwise, this is ignored

Returns:

Model or GroupedModel containing the desired Representation, Strategy, Combination, and Readouts

Return type:

mtenn.model.Model
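
The constraints among grouped, combination, strategy, and comb_readout described in the parameter list can be summarized in a short dependency-free sketch. check_get_model_args is a hypothetical helper written for this example; it is not part of mtenn, and the actual get_model may validate its arguments differently or raise different exceptions.

```python
# Strategy names listed in the get_model documentation.
VALID_STRATEGIES = ("delta", "concat", "complex")


def check_get_model_args(grouped=False, strategy="delta", combination=None,
                         comb_readout=None):
    """Hypothetical pre-flight check mirroring the documented constraints."""
    if strategy not in VALID_STRATEGIES:
        raise ValueError(
            f"strategy must be one of {VALID_STRATEGIES}, got {strategy!r}"
        )
    if grouped and combination is None:
        # The docs state that a Combination must be passed when grouped is True.
        raise ValueError("combination is required when grouped=True")
    # comb_readout only matters for a GroupedModel; report whether it would be used.
    return {
        "builds": "GroupedModel" if grouped else "Model",
        "comb_readout_used": grouped and comb_readout is not None,
    }


# Default arguments describe a single-pose Model using the delta strategy.
print(check_get_model_args())
```

Running a check like this before calling get_model surfaces a missing Combination early, and makes it explicit that a comb_readout passed alongside grouped=False has no effect.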