mtenn.model.Model

class mtenn.model.Model(representation, strategy, readout=None, fix_device=False)

Bases: torch.nn.Module

Model object containing a Representation block, which takes an input and converts it into some representation, and a Strategy block, which takes a complex representation and any number of constituent “part” representations and converts them into a final scalar value. An optional Readout block can then be applied to that value.
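The data flow described above can be sketched in plain Python. This is a hypothetical, framework-free illustration, not mtenn's implementation: the blocks are ordinary callables standing in for the Representation, Strategy, and Readout modules, and the toy functions (`toy_representation`, `toy_strategy`, `toy_readout`, `predict`) are invented for the example. The difference-style strategy is chosen only for illustration.

```python
from typing import Callable, Optional, Sequence


# Hypothetical stand-ins for the three blocks; in mtenn these are torch modules.
def toy_representation(structure: dict) -> float:
    # "Represent" a structure as the sum of its per-atom values.
    return float(sum(structure["values"]))


def toy_strategy(comp_rep: float, *part_reps: float) -> float:
    # Difference-style strategy: complex representation minus the sum of the parts.
    return comp_rep - sum(part_reps)


def toy_readout(pred: float) -> float:
    # Example readout: a simple affine transform of the prediction.
    return 2.0 * pred + 1.0


def predict(
    comp: dict,
    parts: Sequence[dict],
    representation: Callable[[dict], float],
    strategy: Callable[..., float],
    readout: Optional[Callable[[float], float]] = None,
) -> float:
    comp_rep = representation(comp)                 # complex representation
    part_reps = [representation(p) for p in parts]  # one representation per part
    pred = strategy(comp_rep, *part_reps)           # reduce to a single scalar
    # Apply the readout only if one was supplied.
    return readout(pred) if readout is not None else pred
```

For a complex with values `[1, 2, 3, 4]` and parts `[1, 2]` and `[3]`, the strategy yields `10 - (3 + 3) = 4`, and the readout maps that to `9`.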

Methods

  • __init__(representation, strategy[, ...]) – Build a Model.

  • forward(comp, *parts) – Run the full forward pass: Representation, then Strategy, then the optional Readout.

  • get_representation(*args, **kwargs) – Pass a structure through the model’s Representation block.

__init__(representation, strategy, readout=None, fix_device=False)

Build a Model.

Parameters:
  • representation (Representation) – Representation block for this model

  • strategy (Strategy) – Strategy block for this model

  • readout (Readout, optional) – Readout block for this model

  • fix_device (bool, default=False) – If True, make sure the input is on the same device as the model, copying over as necessary.

get_representation(*args, **kwargs)

Pass a structure through the model’s Representation block. All arguments are passed directly to self.representation.
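Because all arguments are forwarded unchanged, get_representation is a thin delegation wrapper. The sketch below shows that pattern with a hypothetical `ToyModel` class and a hypothetical `recording_representation` callable that simply echoes what it receives; neither is part of mtenn.

```python
from typing import Any, Callable


class ToyModel:
    """Hypothetical stand-in showing only the delegation behaviour of get_representation."""

    def __init__(self, representation: Callable[..., Any]) -> None:
        self.representation = representation

    def get_representation(self, *args: Any, **kwargs: Any) -> Any:
        # Forward every positional and keyword argument unchanged.
        return self.representation(*args, **kwargs)


def recording_representation(*args: Any, **kwargs: Any) -> tuple:
    # Hypothetical representation that returns exactly what it was called with.
    return args, kwargs
```

Calling `ToyModel(recording_representation).get_representation({"pos": [0]}, scale=2)` returns `(({"pos": [0]},), {"scale": 2})`, confirming nothing is added or dropped on the way through.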

forward(comp, *parts)

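Run the full forward pass: pass the complex and each part through the Representation block, combine the resulting representations with the Strategy block, and apply the Readout block if one is present.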

Parameters:
  • comp (dict) – Complex structure that will be passed to the Representation block

  • parts (list[dict], optional) – Structures for all individual parts of the complex (e.g. ligand and protein separately), passed as additional positional arguments. If none are passed, the constituent parts will be automatically parsed from comp.
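This page does not spell out how parts are parsed from comp. The sketch below assumes, purely for illustration, that every value in the structure dict is a per-atom list of the same length and that a boolean per-atom `"lig"` mask marks ligand atoms. The `split_parts` helper is hypothetical and should not be read as mtenn's own parsing logic.

```python
from typing import Dict, List, Tuple


def split_parts(comp: Dict[str, list]) -> Tuple[Dict[str, list], Dict[str, list]]:
    """Hypothetical helper: split a complex into (ligand, protein) structures.

    Assumes every value in ``comp`` is a per-atom list of the same length and
    that ``comp["lig"]`` is a per-atom boolean mask that is True for ligand atoms.
    """
    mask: List[bool] = comp["lig"]
    # Keep ligand atoms (mask True) for the first part, the rest for the second.
    ligand = {k: [v for v, m in zip(vals, mask) if m] for k, vals in comp.items()}
    protein = {k: [v for v, m in zip(vals, mask) if not m] for k, vals in comp.items()}
    return ligand, protein
```

Under these assumptions, a forward pass would only fall back to something like `parts = split_parts(comp)` when no parts were supplied by the caller.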

Returns:

  • torch.Tensor – Final model prediction. If the model has a readout, this value will have the readout applied.

  • list[torch.Tensor] – A list containing only the pre-readout model prediction. This value is returned mainly to align the signatures for the single- and multi-pose models.
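The two return values let single- and multi-pose models be unpacked the same way. The sketch below uses plain floats in place of tensors; `single_pose_output` and `multi_pose_output` are hypothetical helpers, and averaging the per-pose predictions is an arbitrary choice made only for illustration, not how mtenn's multi-pose models necessarily combine poses.

```python
from typing import Callable, List, Optional, Tuple


def single_pose_output(
    pre_readout: float, readout: Optional[Callable[[float], float]] = None
) -> Tuple[float, List[float]]:
    # Apply the readout (if any) to get the final prediction, and return the
    # pre-readout value wrapped in a one-element list.
    final = readout(pre_readout) if readout is not None else pre_readout
    return final, [pre_readout]


def multi_pose_output(
    pose_preds: List[float], readout: Optional[Callable[[float], float]] = None
) -> Tuple[float, List[float]]:
    # Hypothetical multi-pose counterpart: combine per-pose predictions (here by
    # taking the mean) and return every pre-readout per-pose value alongside it.
    combined = sum(pose_preds) / len(pose_preds)
    final = readout(combined) if readout is not None else combined
    return final, list(pose_preds)
```

Because both helpers return `(final, list_of_pre_readout_values)`, calling code can use `pred, pose_preds = ...` without checking which kind of model produced the output.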