mtenn.model.GroupedModel

class mtenn.model.GroupedModel(representation, strategy, combination, pred_readout=None, comb_readout=None, fix_device=False)

Bases: Model

Subclass of Model for use with multi-pose data, e.g., multiple docked poses of the same molecule with the same protein. In addition to the Representation and Strategy blocks of the Model class, GroupedModel has a Combination block that dictates how the individual Model predictions for each item in the group are combined.
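The grouping idea can be sketched in plain Python, with floats in place of tensors. This is only an illustration: predict_pose, mean_combination, max_combination, and grouped_predict are hypothetical names, not part of mtenn, and the real blocks operate on torch tensors.

```python
from statistics import fmean


def predict_pose(pose):
    # Hypothetical stand-in for the underlying Model applied to one pose.
    # Here a "pose" is just a dict carrying a precomputed toy score.
    return pose["score"]


def mean_combination(preds):
    # One possible way to combine per-pose predictions: average them.
    return fmean(preds)


def max_combination(preds):
    # Another possible way: keep the largest per-pose prediction.
    return max(preds)


def grouped_predict(poses, combination):
    # Predict each pose independently, then reduce with the combination.
    per_pose = [predict_pose(p) for p in poses]
    return combination(per_pose), per_pose


poses = [{"score": 1.0}, {"score": 2.0}, {"score": 6.0}]
mean_pred, per_pose = grouped_predict(poses, mean_combination)
max_pred, _ = grouped_predict(poses, max_combination)
```

Swapping the combination function changes the group-level prediction without touching the per-pose predictions, which is the role the Combination block plays here.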

Methods

__init__(representation, strategy, combination)

The representation, strategy, and pred_readout args will be used to initialize the underlying Model, while the combination and comb_readout args will be applied to the output of the individual pose predictions.

forward(input_list)

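Runs the underlying Model on each entry in input_list, then combines the individual pose predictions with the Combination block. See the docs page for the full logic.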

__init__(representation, strategy, combination, pred_readout=None, comb_readout=None, fix_device=False)

The representation, strategy, and pred_readout args will be used to initialize the underlying Model, while the combination and comb_readout args will be applied to the output of the individual pose predictions.

Parameters:
  • representation (Representation) – Representation block for this model

  • strategy (Strategy) – Strategy block for this model

  • combination (Combination) – Combination block for this model

  • pred_readout (Readout, optional) – Readout block for the individual pose predictions

  • comb_readout (Readout, optional) – Readout block for the combined prediction

  • fix_device (bool, default=False) – If True, make sure the input is on the same device as the model, copying over as necessary.

forward(input_list)

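Runs the underlying Model on each entry in input_list, then combines the individual pose predictions with the Combination block. See the docs page for the full logic.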

Parameters:
  • input_list (list[tuple[dict]]) – List of tuples of (complex representation, part representations)

Returns:

  • torch.Tensor – Final multi-pose model prediction

  • list[torch.Tensor] – A list containing the pre-pred_readout prediction for each entry in input_list
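The input and return shapes described above can be mimicked with a small pure-Python stand-in. Everything in it is an assumption made for illustration: ToyGroupedModel and toy_predict are hypothetical, floats replace tensors, each entry is assumed to be a tuple whose first dict is the complex and whose remaining dicts are the parts, and the order in which the readouts are applied relative to the combination is a guess rather than mtenn's documented behavior.

```python
class ToyGroupedModel:
    """Hypothetical stand-in that mimics the forward() contract, not mtenn code."""

    def __init__(self, predict, combination, pred_readout=None, comb_readout=None):
        self.predict = predict
        self.combination = combination
        self.pred_readout = pred_readout
        self.comb_readout = comb_readout

    def forward(self, input_list):
        raw_preds = []   # pre-pred_readout prediction for each entry
        pose_preds = []  # per-pose predictions passed to the combination
        for complex_rep, *part_reps in input_list:
            raw = self.predict(complex_rep, part_reps)
            raw_preds.append(raw)
            # Assumption: pred_readout is applied per pose, before combining.
            pose_preds.append(self.pred_readout(raw) if self.pred_readout else raw)
        combined = self.combination(pose_preds)
        # Assumption: comb_readout is applied to the combined value.
        if self.comb_readout is not None:
            combined = self.comb_readout(combined)
        return combined, raw_preds


def toy_predict(complex_rep, part_reps):
    # Toy per-pose prediction: complex value minus the sum of the part values.
    return complex_rep["e"] - sum(p["e"] for p in part_reps)


# Two poses, each given as (complex representation, part representations).
input_list = [
    ({"e": -10.0}, {"e": -4.0}, {"e": -3.0}),
    ({"e": -12.0}, {"e": -4.0}, {"e": -3.0}),
]

model = ToyGroupedModel(toy_predict, combination=max, pred_readout=lambda x: -x)
final_pred, raw_preds = model.forward(input_list)
```

The second return value keeps the values from before pred_readout was applied, so in this toy the sign flip shows up in final_pred but not in raw_preds.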