mtenn.model.SplitModel
class mtenn.model.SplitModel(complex_representation, strategy, ligand_representation=None, protein_representation=None, readout=None, fix_device=False)
Bases: Module

Model object containing a Representation block that will take an input and convert it into some representation, and a Strategy block that will take a complex representation and any number of constituent “part” representations and convert them to a final scalar value.

Methods
__init__(complex_representation, strategy[, ...])    Build a Model.
forward(comp[, prot, lig])    Handles all the logic detailed in the docs page.
get_representation(*args, **kwargs)    Pass a structure through the model's Representation block.

__init__(complex_representation, strategy, ligand_representation=None, protein_representation=None, readout=None, fix_device=False)
Build a Model.

Parameters:
  complex_representation (Representation) – Representation block for the complex.
  strategy (Strategy) – Strategy block for this model.
  ligand_representation (Representation, optional) – Representation block for the ligand. Leave unset to use the Representation block in complex_representation.
  protein_representation (Representation, optional) – Representation block for the protein. Leave unset to use the Representation block in complex_representation.
  readout (Readout, optional) – Readout block for this model.
  fix_device (bool, default=False) – If True, make sure the input is on the same device as the model, copying it over as necessary.
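The constructor defaults above can be sketched in plain Python. This is a minimal, framework-free illustration that assumes every block is an ordinary callable; the class name SplitModelSketch and the toy blocks complex_rep and ligand_rep are hypothetical and not part of mtenn, whose blocks are PyTorch modules. It only shows that an unset part representation reuses the complex Representation block and that readout has no default.

```python
from typing import Any, Callable, Optional

# Hypothetical alias: every block is modelled as a plain callable.
Block = Callable[..., Any]


class SplitModelSketch:
    """Hypothetical sketch of SplitModel's constructor defaults (not mtenn code)."""

    def __init__(
        self,
        complex_representation: Block,
        strategy: Block,
        ligand_representation: Optional[Block] = None,
        protein_representation: Optional[Block] = None,
        readout: Optional[Block] = None,
        fix_device: bool = False,
    ) -> None:
        self.complex_representation = complex_representation
        self.strategy = strategy
        # An unset part representation reuses the complex Representation block.
        self.ligand_representation = (
            complex_representation if ligand_representation is None else ligand_representation
        )
        self.protein_representation = (
            complex_representation if protein_representation is None else protein_representation
        )
        # No default Readout: None means predictions are returned without one.
        self.readout = readout
        # Only stored here; per the docs, the real class uses it to move inputs
        # onto the model's device.
        self.fix_device = fix_device


def complex_rep(structure: dict) -> int:
    # Toy representation block (hypothetical).
    return len(structure)


def ligand_rep(structure: dict) -> int:
    # A separate toy block used only for the ligand (hypothetical).
    return 0


# Only the ligand gets its own block; the protein falls back to complex_rep.
model = SplitModelSketch(complex_rep, strategy=sum, ligand_representation=ligand_rep)
```

Passing the same block object for the complex and its parts means the parts are embedded with identical weights unless a dedicated block is supplied.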
get_representation(*args, **kwargs)

Pass a structure through the model’s Representation block. All arguments are passed directly to self.representation.
forward(comp, prot=None, lig=None)

Handles all the logic detailed in the docs page. This class assumes the only data in parts is the protein and ligand, in that order.

Parameters:
  comp (dict) – Complex structure that will be passed to the Representation block.
  prot (dict, optional) – Structure for the protein.
  lig (dict, optional) – Structure for the ligand. If the protein and ligand structures are not passed, the constituent parts will be automatically parsed from comp.
Returns:
  torch.Tensor – Final model prediction. If the model has a readout, this value will have the readout applied.
  list[torch.Tensor] – A list containing only the pre-readout model prediction. This value is returned mainly to align the signatures for the single- and multi-pose models.
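The flow documented for forward can be sketched as a standalone function. This is a hypothetical, framework-free illustration, not mtenn's implementation: it assumes the blocks are plain callables passed in directly, and it stands in for the automatic parsing of parts from comp with a hypothetical helper split_parts that reads "protein" and "ligand" keys from the complex dict. The toy strategy delta (complex value minus the sum of the part values) is one illustrative choice, not necessarily a Strategy that mtenn ships.

```python
from typing import Any, Callable, List, Optional, Tuple


def split_parts(comp: dict) -> Tuple[dict, dict]:
    # Hypothetical parser: assumes the complex dict already holds its parts.
    return comp["protein"], comp["ligand"]


def split_forward(
    comp: dict,
    complex_representation: Callable[[dict], Any],
    protein_representation: Callable[[dict], Any],
    ligand_representation: Callable[[dict], Any],
    strategy: Callable[..., Any],
    readout: Optional[Callable[[Any], Any]] = None,
    prot: Optional[dict] = None,
    lig: Optional[dict] = None,
) -> Tuple[Any, List[Any]]:
    # Assumption: if either part is missing, both are parsed from the complex.
    if prot is None or lig is None:
        prot, lig = split_parts(comp)
    complex_rep = complex_representation(comp)
    # Parts reach the strategy in the documented order: protein, then ligand.
    pred = strategy(complex_rep, protein_representation(prot), ligand_representation(lig))
    # Keep the pre-readout prediction, wrapped in a list as the docs describe.
    pre_readout = [pred]
    if readout is not None:
        pred = readout(pred)
    return pred, pre_readout


def energy(structure: dict) -> float:
    # Toy representation: sum of hypothetical per-atom energies.
    return sum(structure["atom_energies"])


def delta(complex_rep: float, prot_rep: float, lig_rep: float) -> float:
    # Toy strategy: complex value minus the sum of the part values.
    return complex_rep - (prot_rep + lig_rep)


comp = {
    "atom_energies": [-3.0, -2.0, -4.0],
    "protein": {"atom_energies": [-1.0, -2.0]},
    "ligand": {"atom_energies": [-4.0]},
}
pred, pre = split_forward(comp, energy, energy, energy, delta, readout=lambda x: 2 * x)
```

Here the complex scores -9.0 and the parts -3.0 and -4.0, so the strategy yields -2.0, which the toy readout doubles to -4.0 while the list still holds the pre-readout -2.0.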