mtenn.model.GroupedModel
- class mtenn.model.GroupedModel(representation, strategy, combination, pred_readout=None, comb_readout=None, fix_device=False)
Bases: Model
Subclass of Model for use with multi-pose data, e.g. multiple docked poses of the same molecule with the same protein. In addition to the Representation and Strategy blocks in the Model class, GroupedModel also has a Combination block that dictates how the individual Model predictions for each item in the group of data are combined.
Methods
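- __init__(representation, strategy, combination, pred_readout=None, comb_readout=None, fix_device=False) – The representation, strategy, and pred_readout args will be used to initialize the underlying Model, while the combination and comb_readout args will be applied to the output of the individual pose predictions.
- forward(input_list) – Runs the underlying Model on each entry of input_list, combines the individual predictions with the Combination block, and applies comb_readout if one was given (see the docs page for details).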
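The Combination block is what collapses a group of per-pose predictions into one value. As an illustration only, the plain-Python sketch below shows three plausible combination rules: a mean, a hard maximum, and a log-sum-exp smooth maximum. The functions mean_combination, max_combination, and smooth_max_combination and the temperature parameter t are hypothetical; they are not mtenn's Combination classes, which operate on torch tensors.

```python
import math
from typing import Sequence

# Hypothetical combination rules for illustration; not mtenn's Combination classes.


def mean_combination(preds: Sequence[float]) -> float:
    """Average the per-pose predictions."""
    return sum(preds) / len(preds)


def max_combination(preds: Sequence[float]) -> float:
    """Keep only the highest per-pose prediction."""
    return max(preds)


def smooth_max_combination(preds: Sequence[float], t: float = 1.0) -> float:
    """Log-sum-exp smooth maximum: t * log(sum(exp(p / t))), with t > 0."""
    m = max(preds)  # shift by the max so exp() cannot overflow
    return m + t * math.log(sum(math.exp((p - m) / t) for p in preds))


pose_preds = [1.0, 2.0, 3.0]
for rule in (mean_combination, max_combination, smooth_max_combination):
    print(rule.__name__, round(rule(pose_preds), 4))
```

The smooth maximum is differentiable and always lies between max(preds) and max(preds) + t*log(n), so t controls how strongly the best-scoring pose dominates the group prediction.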
- __init__(representation, strategy, combination, pred_readout=None, comb_readout=None, fix_device=False)
The representation, strategy, and pred_readout args will be used to initialize the underlying Model, while the combination and comb_readout args will be applied to the output of the individual pose predictions.
- Parameters:
representation (Representation) – Representation block for this model
strategy (Strategy) – Strategy block for this model
combination (Combination) – Combination block for this model
pred_readout (Readout, optional) – Readout block for the individual pose predictions
comb_readout (Readout, optional) – Readout block for the combined prediction
fix_device (bool, default=False) – If True, make sure the input is on the same device as the model, copying it over as necessary.
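The representation, strategy, and pred_readout blocks together form the underlying single-pose Model. A minimal pure-Python sketch of that per-pose path follows, assuming each input dict can be reduced to one number and using a hypothetical "delta" strategy (complex representation minus the sum of the part representations). make_pose_model, rep, and delta are illustrative names only; the real blocks are mtenn objects working on torch tensors.

```python
from typing import Callable, Dict, List, Optional, Tuple

# Hypothetical types: one pose is a (complex input, list of part inputs) tuple.
Inp = Dict[str, float]
Pose = Tuple[Inp, List[Inp]]


def make_pose_model(
    representation: Callable[[Inp], float],
    strategy: Callable[[float, List[float]], float],
    pred_readout: Optional[Callable[[float], float]] = None,
) -> Callable[[Pose], float]:
    """Hypothetical sketch: chain representation -> strategy -> optional readout."""

    def predict(pose: Pose) -> float:
        complex_inp, parts_inp = pose
        # The same representation block embeds the complex and each part
        complex_rep = representation(complex_inp)
        parts_rep = [representation(p) for p in parts_inp]
        # The strategy block turns the representations into one prediction
        pred = strategy(complex_rep, parts_rep)
        # pred_readout is optional; without it the raw prediction is returned
        return pred_readout(pred) if pred_readout is not None else pred

    return predict


# Illustrative blocks (assumptions, not mtenn classes)
def rep(d: Inp) -> float:
    return d["score"]


def delta(complex_rep: float, parts_rep: List[float]) -> float:
    return complex_rep - sum(parts_rep)


pose = ({"score": 5.0}, [{"score": 1.0}, {"score": 2.5}])
raw_model = make_pose_model(rep, delta)
read_model = make_pose_model(rep, delta, pred_readout=lambda x: -x)
print(raw_model(pose), read_model(pose))
```

Keeping pred_readout separate from the strategy means the same per-pose model can report either the raw prediction or the read-out value, which matters for GroupedModel because forward reports the per-pose values from before pred_readout.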
- forward(input_list)
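Runs the underlying Model on each entry of input_list, combines the individual predictions with the Combination block, and applies comb_readout if one was given (see the docs page for details).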
- Parameters:
input_list (list[tuple[dict]]) – List of tuples of (complex representation, part representations)
- Returns:
torch.Tensor – Final multi-pose model prediction
list[torch.Tensor] – A list containing the prediction for each entry in input_list before pred_readout is applied
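The forward contract described above (a list of (complex, parts) tuples in; the combined prediction plus the per-pose pre-pred_readout predictions out) can be sketched in plain Python. This assumes pose_model returns the prediction before pred_readout and that the combination is a plain function of the list of predictions; grouped_forward, delta_model, and mean are hypothetical names, and the real Combination block's call signature and tensor handling may differ.

```python
from typing import Callable, Dict, List, Optional, Sequence, Tuple

Inp = Dict[str, float]
Pose = Tuple[Inp, List[Inp]]  # (complex representation, part representations)


def grouped_forward(
    input_list: Sequence[Pose],
    pose_model: Callable[[Pose], float],
    combination: Callable[[List[float]], float],
    comb_readout: Optional[Callable[[float], float]] = None,
) -> Tuple[float, List[float]]:
    """Hypothetical sketch of the GroupedModel.forward return contract."""
    # One pre-pred_readout prediction per entry in input_list
    pose_preds = [pose_model(pose) for pose in input_list]
    # Collapse the group into a single value
    combined = combination(pose_preds)
    # Optional readout applied to the combined value only
    if comb_readout is not None:
        combined = comb_readout(combined)
    return combined, pose_preds


def delta_model(pose: Pose) -> float:
    complex_inp, parts_inp = pose
    return complex_inp["score"] - sum(p["score"] for p in parts_inp)


def mean(preds: List[float]) -> float:
    return sum(preds) / len(preds)


poses = [({"score": s}, [{"score": 1.0}]) for s in (2.0, 4.0, 6.0)]
pred, pose_preds = grouped_forward(poses, delta_model, mean)
scaled, _ = grouped_forward(poses, delta_model, mean, comb_readout=lambda x: 2 * x)
print(pred, pose_preds, scaled)
```

Returning the per-pose predictions alongside the combined value lets a caller see which pose drove the group prediction without re-running the model.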