Model

This document describes the classes in mtenn.model, their constituent parts, and the format of their inputs and outputs.

Model Blocks

Each class in mtenn.model is composed of at least one of the following blocks:

Representation

The Representation block (mtenn.representation) takes an input structure and produces a learned n-dimensional vector embedding. In practice, this block is a thin wrapper around an existing model architecture, sometimes with some ad hoc manipulation to ensure that the output is a vector rather than a single scalar value.
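The scalar-to-vector manipulation described above can be sketched as follows. This is a minimal, framework-agnostic illustration in plain Python, not mtenn's implementation: the class name `VectorRepresentation`, the toy backbone, and the `"pos"` input key are all hypothetical.

```python
from typing import Callable, List, Sequence, Union

Number = Union[int, float]


class VectorRepresentation:
    """Hypothetical thin wrapper around an existing model callable.

    `backbone` stands in for any existing architecture; it may return either
    a single scalar or a sequence of numbers.
    """

    def __init__(self, backbone: Callable[[dict], Union[Number, Sequence[Number]]]):
        self.backbone = backbone

    def __call__(self, structure: dict) -> List[float]:
        out = self.backbone(structure)
        # Ad hoc manipulation: promote a scalar output to a length-1 vector so
        # downstream Strategy blocks always receive a vector.
        if isinstance(out, (int, float)):
            return [float(out)]
        return [float(x) for x in out]


# Usage: a toy backbone that returns a single scalar (sum of x coordinates).
def toy_scalar_backbone(structure: dict) -> float:
    return sum(p[0] for p in structure["pos"])


rep = VectorRepresentation(toy_scalar_backbone)
print(rep({"pos": [(1.0, 0.0, 0.0), (2.0, 0.0, 0.0)]}))  # [3.0]
```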

For more information on the models that are currently implemented in mtenn, see the Currently Implemented Models section.

For information on adding new models into the mtenn framework, see the guide on adding a new installable model.

Strategy

The Strategy block (mtenn.strategy) takes any number of vectors, each output by a Representation block, and combines them into a \(\Delta g_{\mathrm{bind}}\) prediction in kT units.
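To make the Strategy's role concrete, here is a sketch of one plausible way to turn the three vectors produced in a Model (complex, protein, ligand) into a single \(\Delta g_{\mathrm{bind}}\): score each vector with a shared linear energy function and take complex minus (protein + ligand). The class name, the linear energy function, and the fixed weights are assumptions for illustration; in a real model the weights would be learned.

```python
from typing import List, Sequence


class DeltaStrategySketch:
    """Hypothetical Strategy block for the three Model representations.

    A shared linear energy function scores each vector, and the prediction
    is E(complex) - (E(protein) + E(ligand)), in implicit kT units.
    """

    def __init__(self, weights: Sequence[float], bias: float = 0.0):
        # Fixed here for illustration; these would be learned parameters.
        self.weights: List[float] = list(weights)
        self.bias = bias

    def energy(self, vec: Sequence[float]) -> float:
        if len(vec) != len(self.weights):
            raise ValueError("representation length does not match weights")
        return sum(w * x for w, x in zip(self.weights, vec)) + self.bias

    def __call__(self, complex_rep: Sequence[float], protein_rep: Sequence[float],
                 ligand_rep: Sequence[float]) -> float:
        return self.energy(complex_rep) - (
            self.energy(protein_rep) + self.energy(ligand_rep)
        )


strategy = DeltaStrategySketch(weights=[1.0, 0.5])
print(strategy([2.0, 2.0], [1.0, 0.0], [0.5, 1.0]))  # 1.0
```

Other combination rules (for example, concatenating the vectors and passing them through a learned layer) fit the same interface: vectors in, one kT-unit scalar out.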

Currently, the following Strategy blocks are implemented in mtenn:

Readout

The Readout block (mtenn.readout) converts the \(\Delta g_{\mathrm{bind}}\) prediction output by a Strategy block from kT units into another unit. Importantly, unlike the two previous blocks, this block has no learned parameters. This makes models more portable: the same model can be trained on different data types by swapping out only this last layer.

We currently have implementations for pIC50 and pKi.
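A parameter-free readout is just a fixed formula. The sketch below assumes \(\Delta g_{\mathrm{bind}} = \ln(K_i / 1\,\mathrm{M})\) in kT units, which gives \(\mathrm{p}K_i = -\Delta g_{\mathrm{bind}} / \ln 10\). For pIC50 it optionally applies the Cheng-Prusoff relation \(\mathrm{IC}_{50} = K_i (1 + [S]/K_m)\); whether mtenn's pIC50 readout uses this correction is an assumption here, and the function and parameter names (`pki_readout`, `pic50_readout`, `substrate`, `km`) are hypothetical.

```python
import math
from typing import Optional


def pki_readout(delta_g_kt: float) -> float:
    """Convert a Delta g_bind prediction (kT units) to pKi.

    Assumes Delta g = ln(Ki / 1 M), so pKi = -log10(Ki) = -Delta g / ln(10).
    No learned parameters are involved.
    """
    return -delta_g_kt / math.log(10)


def pic50_readout(delta_g_kt: float, substrate: Optional[float] = None,
                  km: Optional[float] = None) -> float:
    """Convert to pIC50, optionally applying Cheng-Prusoff:
    IC50 = Ki * (1 + [S] / Km). `substrate` and `km` must share units."""
    pic50 = pki_readout(delta_g_kt)
    if substrate is not None and km is not None:
        pic50 -= math.log10(1 + substrate / km)
    return pic50


# A prediction of -7 ln(10) kT corresponds to pKi = 7.
print(round(pki_readout(-7 * math.log(10)), 6))  # 7.0
```

Because neither function holds state, swapping one for the other changes the target units without touching any trained weights.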

Combination

The Combination block (mtenn.combination) combines multiple model predictions for the same compound into a single prediction. The internal workings of these blocks are somewhat complex; a more in-depth explanation is given on the Combination page.
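At the level of arithmetic only, a Combination maps a list of per-pose \(\Delta g\) values to one value. The sketch below shows two illustrative choices: a plain mean, and a Boltzmann-style soft minimum \(\Delta g = -\ln \sum_i e^{-\Delta g_i}\) computed with a numerically stable log-sum-exp. The function names are hypothetical, and the sketch ignores whatever makes the actual mtenn blocks more involved (see the Combination page).

```python
import math
from typing import Sequence


def mean_combination(preds: Sequence[float]) -> float:
    """Average the per-pose Delta g predictions."""
    if not preds:
        raise ValueError("need at least one prediction")
    return sum(preds) / len(preds)


def boltzmann_combination(preds: Sequence[float]) -> float:
    """Soft minimum over poses in kT units:
    Delta g = -ln(sum_i exp(-Delta g_i)).

    Shifting by the most favorable (lowest) value keeps every exponent <= 0,
    which avoids overflow.
    """
    if not preds:
        raise ValueError("need at least one prediction")
    m = min(preds)
    return m - math.log(sum(math.exp(-(p - m)) for p in preds))


poses = [-5.0, -3.0, -4.0]
print(mean_combination(poses))                  # -4.0
print(round(boltzmann_combination(poses), 2))   # -5.41
```

The soft minimum is always at or below the best single pose, so it rewards compounds with at least one strongly favorable pose, whereas the mean treats all poses equally.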

Currently, the following Combination blocks are implemented in mtenn:

Single-Pose Models

This section describes the mtenn.model.Model class (referred to as Model from here on), which makes a prediction from a single input conformation. The general data flow through a Model object is depicted in the diagram below:

[Figure: data flow through a Model object]

In text form this is:

  1. The protein-ligand complex structure is passed to the Model

  2. Internally, the Model breaks the structure into 3 sub-structures: the full complex, just the protein, and just the ligand

  3. Each of these sub-structures is individually passed to the Representation block to generate a total of 3 vector representations

  4. All 3 representations are passed to the Strategy block, where they are combined into a \(\Delta g_{\mathrm{bind}}\) prediction in implicit kT units

  5. (optional) The \(\Delta g_{\mathrm{bind}}\) prediction is passed to the Readout block, where it is converted into the desired final units
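The five steps above can be sketched in plain Python. Everything here is hypothetical: the class name `ModelSketch`, the toy blocks, and the input keys `"pos"`, `"z"`, and `"lig"` (a per-atom boolean ligand mask) are assumptions for illustration, not mtenn's actual input format. The return value follows the two-value output convention described under Output Data.

```python
from typing import Callable, List, Optional, Sequence, Tuple


def split_complex(pose: dict) -> Tuple[dict, dict, dict]:
    """Split a pose into (complex, protein, ligand) sub-structures.

    Assumes hypothetical keys: "pos" (per-atom coordinates), "z" (atomic
    numbers), and "lig" (per-atom booleans, True for ligand atoms).
    """

    def subset(keep: Sequence[bool]) -> dict:
        return {
            "pos": [p for p, k in zip(pose["pos"], keep) if k],
            "z": [z for z, k in zip(pose["z"], keep) if k],
        }

    lig = pose["lig"]
    return subset([True] * len(lig)), subset([not k for k in lig]), subset(lig)


class ModelSketch:
    """Hypothetical single-pose model mirroring the steps above."""

    def __init__(self, representation: Callable, strategy: Callable,
                 readout: Optional[Callable[[float], float]] = None):
        self.representation = representation
        self.strategy = strategy
        self.readout = readout

    def __call__(self, pose: dict) -> Tuple[float, List[float]]:
        # Steps 1-3: split the complex and embed each sub-structure with the
        # same Representation block.
        reps = [self.representation(s) for s in split_complex(pose)]
        # Step 4: combine the three vectors into Delta g_bind (kT units).
        dg = self.strategy(*reps)
        # Step 5 (optional): convert from kT to the final units.
        pred = self.readout(dg) if self.readout is not None else dg
        return pred, [dg]


# Usage with toy blocks: the "embedding" is the squared atom count.
model = ModelSketch(
    representation=lambda s: [float(len(s["z"]) ** 2)],
    strategy=lambda c, p, l: c[0] - (p[0] + l[0]),
)
pose = {"pos": [[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [0.0, 1.5, 0.0]],
        "z": [6, 7, 8], "lig": [False, False, True]}
print(model(pose))  # (4.0, [4.0])
```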

Multi-Pose Models

This section describes the mtenn.model.GroupedModel class (GroupedModel from here on), which makes a prediction from multiple input conformations. The general data flow through a GroupedModel object is depicted in the diagram below:

[Figure: data flow through a GroupedModel object]

In text form this is:

  1. Each input conformation is passed through the same Model object to get a prediction for each individual conformation

  2. All predictions are passed through a Combination block to get an overall \(\Delta g_{\mathrm{bind}}\) prediction for the group of input poses

  3. (optional) The overall \(\Delta g_{\mathrm{bind}}\) prediction is passed to the Readout block, where it is converted into the desired final units
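The three GroupedModel steps above can be sketched the same way. The class name `GroupedModelSketch`, the toy single-pose model, and the `"dg"` input key are hypothetical; the sketch assumes the wrapped single-pose model has no Readout of its own, so the per-pose predictions stay in kT units.

```python
from typing import Callable, List, Optional, Sequence, Tuple


class GroupedModelSketch:
    """Hypothetical multi-pose model mirroring the three steps above."""

    def __init__(self, model: Callable[[dict], Tuple[float, List[float]]],
                 combination: Callable[[Sequence[float]], float],
                 readout: Optional[Callable[[float], float]] = None):
        # `model` returns (prediction, [pre_readout]) for a single pose.
        self.model = model
        self.combination = combination
        self.readout = readout

    def __call__(self, poses: Sequence[dict]) -> Tuple[float, List[float]]:
        # Step 1: run the same Model on every pose.
        pose_preds = [self.model(pose)[0] for pose in poses]
        # Step 2: combine the per-pose predictions into one Delta g_bind.
        dg = self.combination(pose_preds)
        # Step 3 (optional): convert from kT to the final units.
        pred = self.readout(dg) if self.readout is not None else dg
        return pred, pose_preds


# Usage with toy blocks; "dg" is a hypothetical key holding a precomputed value.
grouped = GroupedModelSketch(
    model=lambda pose: (pose["dg"], [pose["dg"]]),
    combination=lambda preds: sum(preds) / len(preds),
)
print(grouped([{"dg": -2.0}, {"dg": -4.0}]))  # (-3.0, [-2.0, -4.0])
```

Sharing one Model instance across poses means every pose is embedded and scored with the same parameters; only the Combination step sees the group as a whole.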

Ligand-Only Models

This section describes the mtenn.model.LigandOnlyModel class (LigandOnlyModel from here on), which makes a prediction based only on a ligand representation. This class is mainly useful for building 2D baseline models to compare the structure-based models against (e.g., ligand-only GNNs or fingerprint-based models). The general data flow through a LigandOnlyModel object is the same as for a Model, except that the Representation block generates the energy prediction directly from the input, and the Strategy block is simply the identity function.
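A sketch of the ligand-only data flow just described, assuming a representation that maps the input straight to a scalar energy and an identity Strategy. The class name `LigandOnlyModelSketch` and the `"fp"` fingerprint key are hypothetical.

```python
from typing import Callable, List, Optional, Tuple


class LigandOnlyModelSketch:
    """Hypothetical ligand-only model: Representation -> identity Strategy -> Readout."""

    def __init__(self, representation: Callable[[dict], float],
                 readout: Optional[Callable[[float], float]] = None):
        # Unlike in Model, the representation outputs an energy directly.
        self.representation = representation
        self.readout = readout

    @staticmethod
    def strategy(energy: float) -> float:
        # The Strategy block is the identity function.
        return energy

    def __call__(self, ligand: dict) -> Tuple[float, List[float]]:
        dg = self.strategy(self.representation(ligand))
        pred = self.readout(dg) if self.readout is not None else dg
        return pred, [dg]


# Usage: a toy fingerprint "model" under a hypothetical "fp" key.
baseline = LigandOnlyModelSketch(representation=lambda lig: -0.5 * sum(lig["fp"]))
print(baseline({"fp": [1, 0, 1, 1]}))  # (-1.5, [-1.5])
```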

Currently Implemented Models

Data Model

Input Data

Currently, all of the single-pose models in mtenn (Model and LigandOnlyModel) expect a dict as input. GroupedModel expects a list of these dicts, one per input pose. The keys each model expects in the dict are determined by that model's implementation in the conversion_utils module. For more details on the data expected by each model, see that model's docs page.
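The container shapes described above (one dict for single-pose models, a list of dicts for GroupedModel) can be checked with a small helper. The helper `validate_input` and the keys `"pos"`, `"z"`, and `"lig"` are hypothetical stand-ins; the real required keys come from each model's conversion_utils implementation.

```python
from typing import Sequence


def validate_input(inp, required_keys: Sequence[str], grouped: bool = False) -> None:
    """Raise if `inp` does not have the container shape the model type expects.

    Single-pose models take one dict; GroupedModel takes a list of dicts.
    """
    if grouped:
        if not isinstance(inp, list):
            raise TypeError("GroupedModel input must be a list of dicts")
        poses = inp
    else:
        poses = [inp]
    for i, pose in enumerate(poses):
        if not isinstance(pose, dict):
            raise TypeError(f"pose {i} is not a dict")
        missing = [k for k in required_keys if k not in pose]
        if missing:
            raise KeyError(f"pose {i} is missing keys {missing}")


# Hypothetical keys; check the relevant model's docs page for the real ones.
required = ["pos", "z", "lig"]
pose_a = {"pos": [[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]], "z": [6, 8], "lig": [False, True]}
pose_b = {"pos": [[0.1, 0.0, 0.0], [1.4, 0.1, 0.0]], "z": [6, 8], "lig": [False, True]}

validate_input(pose_a, required)                          # Model / LigandOnlyModel
validate_input([pose_a, pose_b], required, grouped=True)  # GroupedModel
```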

Output Data

To unify the outputs of all model types, all 3 models (Model, GroupedModel, and LigandOnlyModel) return two values: a scalar representing the model's final prediction, and a list of the pre-Readout predictions for each input pose. For the single-pose models, this list has exactly one element; for the multi-pose model, it has one element per input pose.
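Because every model type returns the same two values, downstream code can unpack them uniformly. The sketch below checks the per-pose list length and then uses the pre-Readout values to pick the most favorable pose (lowest \(\Delta g_{\mathrm{bind}}\)). The helper name and the output values are hypothetical, chosen only to illustrate the shape of the outputs.

```python
from typing import List, Sequence, Tuple


def unpack_output(output: Tuple[float, Sequence[float]],
                  n_poses: int) -> Tuple[float, List[float]]:
    """Split a model output into (final prediction, per-pose pre-Readout predictions)."""
    pred, pose_preds = output
    if len(pose_preds) != n_poses:
        raise ValueError(
            f"expected {n_poses} per-pose predictions, got {len(pose_preds)}"
        )
    return pred, list(pose_preds)


# Hypothetical outputs: a final prediction plus per-pose kT-unit values.
single_output = (6.5, [-15.0])                 # Model or LigandOnlyModel
grouped_output = (6.0, [-12.0, -15.0, -13.5])  # GroupedModel with 3 poses

unpack_output(single_output, n_poses=1)
pred, per_pose = unpack_output(grouped_output, n_poses=3)
# Per-pose values are pre-Readout Delta g_bind, so lower means more favorable.
best_pose = min(range(len(per_pose)), key=per_pose.__getitem__)
print(pred, best_pose)  # 6.0 1
```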