mtenn.model.SplitModel

class mtenn.model.SplitModel(complex_representation, strategy, ligand_representation=None, protein_representation=None, readout=None, fix_device=False)

Bases: Module

Model object containing a Representation block, which converts an input structure into a representation, and a Strategy block, which combines the complex representation with any number of constituent "part" representations to produce a final scalar value. If the model has a Readout block, it is applied to that value.
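The data flow described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not mtenn's implementation: the blocks are stand-in callables, structures are plain dicts with a hypothetical "features" key, and `toy_representation`, `delta_strategy`, and `split_model_prediction` are hypothetical names. The "delta" rule (complex minus the sum of the parts) is just one plausible way a Strategy block could reduce the representations to a scalar.

```python
from typing import Callable, Dict, List

# Hypothetical stand-ins: a "structure" is a dict, a "representation" is a list of floats.
Structure = Dict[str, List[float]]
Representation = List[float]


def toy_representation(structure: Structure) -> Representation:
    # Hypothetical Representation block: return the structure's precomputed features.
    return list(structure["features"])


def delta_strategy(complex_rep: Representation, *part_reps: Representation) -> float:
    # Hypothetical Strategy block: sum the complex representation and subtract
    # the sum of every constituent part representation, giving one scalar.
    return sum(complex_rep) - sum(sum(rep) for rep in part_reps)


def split_model_prediction(
    complex_structure: Structure,
    part_structures: List[Structure],
    representation: Callable[[Structure], Representation],
    strategy: Callable[..., float],
) -> float:
    # Represent the complex and each part, then let the strategy reduce them to a scalar.
    complex_rep = representation(complex_structure)
    part_reps = [representation(part) for part in part_structures]
    return strategy(complex_rep, *part_reps)


comp = {"features": [3.0, 4.0]}
prot = {"features": [1.0, 2.0]}
lig = {"features": [0.5]}
print(split_model_prediction(comp, [prot, lig], toy_representation, delta_strategy))  # 3.5
```

Keeping the Representation and Strategy blocks as separate callables is what lets the same representation be reused for the complex and for each part.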

Methods

__init__(complex_representation, strategy[, ...])

Build a Model.

forward(comp[, prot, lig])

Run the full model on a complex and, optionally, its protein and ligand parts.

get_representation(*args, **kwargs)

Pass a structure through the model's Representation block.

__init__(complex_representation, strategy, ligand_representation=None, protein_representation=None, readout=None, fix_device=False)

Build a Model.

Parameters:
  • complex_representation (Representation) – Representation block for the complex

  • strategy (Strategy) – Strategy block for this model

  • ligand_representation (Representation, optional) – Representation block for the ligand. Leave unset to reuse complex_representation for the ligand

  • protein_representation (Representation, optional) – Representation block for the protein. Leave unset to reuse complex_representation for the protein

  • readout (Readout, optional) – Readout block for this model

  • fix_device (bool, default=False) – If True, make sure the input is on the same device as the model, copying it over if necessary.
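The fallback rule for the two optional Representation blocks can be sketched as follows. `SplitModelSketch` is a hypothetical class written only to illustrate the documented behavior (an unset ligand or protein block reuses `complex_representation`); it does not subclass `torch.nn.Module` and does not implement `fix_device`, which is merely stored.

```python
from typing import Any, Callable, Optional


class SplitModelSketch:
    """Hypothetical sketch of how the optional constructor arguments could be resolved."""

    def __init__(
        self,
        complex_representation: Callable[..., Any],
        strategy: Callable[..., Any],
        ligand_representation: Optional[Callable[..., Any]] = None,
        protein_representation: Optional[Callable[..., Any]] = None,
        readout: Optional[Callable[..., Any]] = None,
        fix_device: bool = False,
    ):
        self.complex_representation = complex_representation
        # Fall back to the complex block when no dedicated ligand block is given.
        self.ligand_representation = (
            ligand_representation
            if ligand_representation is not None
            else complex_representation
        )
        # Same fallback for the protein block.
        self.protein_representation = (
            protein_representation
            if protein_representation is not None
            else complex_representation
        )
        self.strategy = strategy
        self.readout = readout
        # Stored only; device handling is out of scope for this sketch.
        self.fix_device = fix_device


def identity_representation(structure):
    # Hypothetical Representation block used only for illustration.
    return structure


def ligand_only_representation(structure):
    # Hypothetical dedicated ligand block.
    return structure


def zero_strategy(*reps):
    # Hypothetical Strategy block used only for illustration.
    return 0.0


shared = SplitModelSketch(identity_representation, zero_strategy)
print(shared.ligand_representation is identity_representation)  # True

custom = SplitModelSketch(
    identity_representation,
    zero_strategy,
    ligand_representation=ligand_only_representation,
)
print(custom.ligand_representation is ligand_only_representation)  # True
print(custom.protein_representation is identity_representation)  # True
```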

get_representation(*args, **kwargs)

Pass a structure through the model’s Representation block. All arguments are passed directly to self.representation.
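The pass-through behavior can be illustrated with a minimal sketch. `RepresentationHolder` and `scale_features` are hypothetical names; only the attribute name `representation` mirrors the text above. The point is that positional and keyword arguments reach the Representation block unchanged.

```python
from typing import Any, Callable


class RepresentationHolder:
    """Hypothetical minimal class showing a get_representation-style pass-through."""

    def __init__(self, representation: Callable[..., Any]):
        # Attribute name mirrors the documented ``self.representation``.
        self.representation = representation

    def get_representation(self, *args: Any, **kwargs: Any) -> Any:
        # Forward every positional and keyword argument unchanged.
        return self.representation(*args, **kwargs)


def scale_features(structure: dict, factor: float = 1.0) -> list:
    # Hypothetical Representation block that accepts an extra keyword argument.
    return [x * factor for x in structure["features"]]


holder = RepresentationHolder(scale_features)
print(holder.get_representation({"features": [1.0, 2.0]}, factor=2.0))  # [2.0, 4.0]
```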

forward(comp, prot=None, lig=None)

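Run the full model, as detailed in the docs page: pass the complex and its parts through the Representation blocks, combine the resulting representations with the Strategy block, and apply the Readout block if one is set. This class assumes the only constituent parts are the protein and the ligand.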

Parameters:
  • comp (dict) – Complex structure that will be passed to the Representation block

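  • prot (dict, optional) – Protein structure. If this is not passed, the constituent parts will be automatically parsed from comp

  • lig (dict, optional) – Ligand structure. If this is not passed, the constituent parts will be automatically parsed from comp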
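Automatic parsing of the parts from the complex might look like the sketch below. The parsing scheme is entirely an assumption: it supposes every value in the complex dict is a per-atom list and that a hypothetical boolean "lig" entry marks ligand atoms. Filling in only the missing part when just one of `prot`/`lig` is supplied is also an assumption. `split_complex` and `resolve_parts` are hypothetical helpers, not mtenn functions.

```python
from typing import Dict, Optional, Tuple

Structure = Dict[str, list]


def split_complex(comp: Structure, mask_key: str = "lig") -> Tuple[Structure, Structure]:
    """Hypothetical parser: split a complex into protein and ligand parts.

    Assumes every value in ``comp`` is a per-atom list and that ``comp[mask_key]``
    is a per-atom boolean list marking ligand atoms. The mask itself is dropped.
    """
    mask = comp[mask_key]
    prot: Structure = {}
    lig: Structure = {}
    for key, values in comp.items():
        if key == mask_key:
            continue
        prot[key] = [v for v, is_lig in zip(values, mask) if not is_lig]
        lig[key] = [v for v, is_lig in zip(values, mask) if is_lig]
    return prot, lig


def resolve_parts(
    comp: Structure, prot: Optional[Structure] = None, lig: Optional[Structure] = None
) -> Tuple[Structure, Structure]:
    # Keep any part that was passed in; parse the missing ones from the complex.
    if prot is None or lig is None:
        parsed_prot, parsed_lig = split_complex(comp)
        prot = parsed_prot if prot is None else prot
        lig = parsed_lig if lig is None else lig
    return prot, lig


comp = {"z": [6, 7, 8, 1], "lig": [False, False, True, True]}
prot, lig = resolve_parts(comp)
print(prot, lig)  # {'z': [6, 7]} {'z': [8, 1]}
```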

Returns:

  • torch.Tensor – Final model prediction. If the model has a readout, this value will have the readout applied

  • list[torch.Tensor] – A list containing only the pre-readout model prediction. This value is returned mainly to align the signatures for the single- and multi-pose models
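The two documented return values can be sketched as a pair, assuming they come back as a tuple that a caller would unpack. `forward_outputs` and `negate` are hypothetical helpers, and plain floats stand in for `torch.Tensor` so the sketch runs without PyTorch.

```python
from typing import Callable, List, Optional, Tuple


def forward_outputs(
    pre_readout_pred: float, readout: Optional[Callable[[float], float]] = None
) -> Tuple[float, List[float]]:
    """Hypothetical sketch of the documented return pair.

    Returns the final prediction (with the readout applied, if one is set) and a
    list holding only the pre-readout prediction.
    """
    final_pred = readout(pre_readout_pred) if readout is not None else pre_readout_pred
    return final_pred, [pre_readout_pred]


def negate(x: float) -> float:
    # Hypothetical readout used only for illustration.
    return -x


print(forward_outputs(2.5))  # (2.5, [2.5])
print(forward_outputs(2.5, negate))  # (-2.5, [2.5])
```

Wrapping the single pre-readout value in a list keeps the return shape the same as a model that produces one prediction per pose, which is the signature alignment described above.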