mtenn.model.Model
- class mtenn.model.Model(representation, strategy, readout=None, fix_device=False)
Bases: Module

Model object containing a Representation block that will take an input and convert it into some representation, and a Strategy block that will take a complex representation and any number of constituent “part” representations, and convert them to a final scalar value. An optional Readout block can then be applied to that value.

Methods
__init__(representation, strategy[, ...]) – Build a Model.
forward(comp, *parts) – Handles all the logic detailed in the docs page.
get_representation(*args, **kwargs) – Pass a structure through the model's Representation block.

- __init__(representation, strategy, readout=None, fix_device=False)
Build a Model.

- Parameters:
  representation (Representation) – Representation block for this model.
  strategy (Strategy) – Strategy block for this model.
  readout (Readout, optional) – Readout block for this model.
  fix_device (bool, default=False) – If True, make sure the input is on the same device as the model, copying over as necessary.
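The fix_device=True behaviour can be pictured as making a copy of each input structure with its tensors moved to the model's device. The sketch below is a minimal illustration, not the mtenn implementation: it assumes structures are dicts whose tensor values expose a `.to(device)` method, as torch.Tensor does. `move_to_device` and `FakeTensor` are hypothetical names, and `FakeTensor` only stands in for torch so the example runs without it. With a real torch model, the target device could be read from `next(model.parameters()).device`.

```python
from typing import Any, Dict


def move_to_device(structure: Dict[str, Any], device: Any) -> Dict[str, Any]:
    """Hypothetical helper (not mtenn API): return a copy of ``structure``
    with every value that has a ``.to`` method moved to ``device``."""
    moved = {}
    for key, value in structure.items():
        if hasattr(value, "to"):
            # torch.Tensor.to(device) returns a tensor on ``device``
            # (the same object if it is already there).
            moved[key] = value.to(device)
        else:
            # Non-tensor entries such as strings pass through unchanged.
            moved[key] = value
    return moved


class FakeTensor:
    """Minimal stand-in for torch.Tensor so the sketch runs without torch."""

    def __init__(self, device: str):
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        return FakeTensor(device)


comp = {"pos": FakeTensor("cpu"), "name": "complex_1"}
moved = move_to_device(comp, "cuda:0")
print(moved["pos"].device)  # cuda:0
print(moved["name"])        # complex_1
```

Returning a new dict rather than mutating the input keeps the caller's original structure on its original device, which matches the "copying over as necessary" wording.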
- get_representation(*args, **kwargs)
Pass a structure through the model’s Representation block. All arguments are passed directly to self.representation.
- forward(comp, *parts)
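Handles all the logic detailed in the docs page: the complex and its parts are passed through the Representation block, the Strategy block combines the resulting representations into a single prediction, and the Readout block is applied if the model has one.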
- Parameters:
  comp (dict) – Complex structure that will be passed to the Representation block.
  parts (list[dict], optional) – Structures for all individual parts of the complex (e.g. ligand and protein separately). If these are not passed, the constituent parts will be automatically parsed from comp.
- Returns:
  torch.Tensor – Final model prediction. If the model has a readout, this value will have the readout applied.
  list[torch.Tensor] – A list containing only the pre-readout model prediction. This value is returned mainly to align the signatures for the single- and multi-pose models.
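The flow described above (representation, then strategy, then optional readout, with parts parsed from the complex when none are given) can be sketched in plain Python. This is a hypothetical stand-in, not the mtenn code: the real Model is a torch.nn.Module operating on tensors, while here every block is an ordinary callable returning a float. The part parser `split_by_mask` and its per-atom boolean `"lig"` key are assumptions made for illustration only, as are the toy blocks in the usage section.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple

Structure = Dict[str, List[Any]]


def split_by_mask(comp: Structure, mask_key: str = "lig") -> List[Structure]:
    """Hypothetical part parser: split a complex into [ligand, protein]
    using a per-atom boolean mask stored under ``mask_key``."""
    mask = comp[mask_key]
    ligand: Structure = {}
    protein: Structure = {}
    for key, values in comp.items():
        if key == mask_key:
            continue
        ligand[key] = [v for v, is_lig in zip(values, mask) if is_lig]
        protein[key] = [v for v, is_lig in zip(values, mask) if not is_lig]
    return [ligand, protein]


class ToyModel:
    """Pure-Python stand-in for the Representation -> Strategy -> Readout flow."""

    def __init__(self, representation: Callable[..., float],
                 strategy: Callable[..., float],
                 readout: Optional[Callable[[float], float]] = None):
        self.representation = representation
        self.strategy = strategy
        self.readout = readout

    def get_representation(self, *args, **kwargs) -> float:
        # All arguments are forwarded unchanged to the representation block.
        return self.representation(*args, **kwargs)

    def forward(self, comp: Structure, *parts: Structure) -> Tuple[float, List[float]]:
        if not parts:
            # No parts given: parse them out of the complex.
            parts = tuple(split_by_mask(comp))
        comp_rep = self.get_representation(comp)
        part_reps = [self.get_representation(p) for p in parts]
        # The strategy reduces complex and part representations to one scalar.
        pred = self.strategy(comp_rep, *part_reps)
        final = self.readout(pred) if self.readout is not None else pred
        # Second element: a list holding only the pre-readout prediction.
        return final, [pred]


# Toy blocks: "representation" = squared atom count, "strategy" = complex
# minus the sum of the parts (delta-style), "readout" = doubling.
def rep(s: Structure) -> float:
    return float(len(s["z"]) ** 2)


def delta(c: float, *ps: float) -> float:
    return c - sum(ps)


def double(x: float) -> float:
    return 2.0 * x


comp = {"z": [6, 8, 7, 6], "lig": [True, True, False, False]}
model = ToyModel(rep, delta, readout=double)
final, pre = model.forward(comp)
print(final, pre)  # 16.0 [8.0]
```

With a real torch model the call would normally be written `model(comp)`, since torch.nn.Module.__call__ dispatches to forward. Wrapping the pre-readout value in a list presumably lets calling code treat this output the same way as a multi-pose model's, as the Returns description notes.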