mtenn.strategy.ConcatStrategy

class mtenn.strategy.ConcatStrategy(input_size, extract_key=None, layer_norm=False)

Bases: Strategy

Strategy for combining the complex representation and the parts representations in a learned manner, using sum-pooling to ensure permutation invariance with respect to the parts. For 3 n-dimensional input representations (e.g. complex, protein-only, and ligand-only), this Strategy acts as a function \(\phi: \mathbb{R}^{3n} \rightarrow \mathbb{R}\) that predicts a scalar \(\Delta G\) value.

The input \(\mathrm{\boldsymbol{x}}\) to \(\phi\) is computed in a permutation-invariant manner. For a protein-ligand complex, this looks like:

\[ \begin{align}\begin{aligned}\mathrm{\boldsymbol{x}_{parts}} &= [\mathrm{\boldsymbol{x}_{protein}}, \mathrm{\boldsymbol{x}_{ligand}}] + [\mathrm{\boldsymbol{x}_{ligand}}, \mathrm{\boldsymbol{x}_{protein}}]\\\mathrm{\boldsymbol{x}} &= [\mathrm{\boldsymbol{x}_{complex}}, \mathrm{\boldsymbol{x}_{parts}}]\\\Delta \mathrm{G_{pred}} &= \phi (\mathrm{\boldsymbol{x}})\end{aligned}\end{align} \]

In general, we sum the concatenations of every permutation of the non-complex representations, and then concatenate this sum to the complex representation.
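The construction of \(\mathrm{\boldsymbol{x}}\) can be sketched with PyTorch as below. `permutation_invariant_concat` is a hypothetical helper written for illustration, not part of the mtenn API, and mtenn's own `forward` may differ in details. It checks that swapping the protein and ligand representations leaves the input unchanged, and that two parts give a \(3n\)-dimensional input.

```python
import itertools

import torch


def permutation_invariant_concat(comp, *parts):
    """Concatenate comp with the sum over all orderings of the concatenated parts.

    Hypothetical helper for illustration; not part of the mtenn API.
    """
    if not parts:
        # No parts given: the input is just the complex representation
        return comp
    # Concatenate the parts in every possible order and sum the results;
    # every ordering has the same total length, so the sum is well-defined
    parts_sum = sum(torch.cat(perm, dim=-1) for perm in itertools.permutations(parts))
    # Append the order-independent parts vector to the complex representation
    return torch.cat([comp, parts_sum], dim=-1)


n = 4
comp, prot, lig = torch.randn(n), torch.randn(n), torch.randn(n)
x1 = permutation_invariant_concat(comp, prot, lig)
x2 = permutation_invariant_concat(comp, lig, prot)
print(x1.shape)  # torch.Size([12]), i.e. 3n
print(torch.equal(x1, x2))  # True: swapping the parts does not change x
```

The number of orderings grows as \(k!\) for \(k\) parts, so this is cheap for the protein/ligand case (2 orderings) but grows quickly with more parts.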

In its current iteration, this Strategy does not require you to specify the dimensionality (\(n\)) of each representation. Instead, the first time an instance of this Strategy is used, it will calculate the required input size and initialize a one-layer linear network of the appropriate dimensionality.
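Deferred initialization of this kind can be sketched as follows. `LazyConcatHead` is a hypothetical module, not mtenn's implementation. It builds its one-layer linear network on the first call, sized from the actual concatenated input.

```python
import itertools

import torch


class LazyConcatHead(torch.nn.Module):
    """Hypothetical sketch of deferred (lazy) initialization; not mtenn's code."""

    def __init__(self):
        super().__init__()
        # No linear layer yet: its input size is unknown until the first call
        self.reduce_nn = None

    def forward(self, comp, *parts):
        x = comp
        if parts:
            # Permutation-invariant sum of concatenated parts, appended to comp
            parts_sum = sum(torch.cat(p, dim=-1) for p in itertools.permutations(parts))
            x = torch.cat([comp, parts_sum], dim=-1)
        if self.reduce_nn is None:
            # First call: size the one-layer linear network from the actual input
            self.reduce_nn = torch.nn.Linear(x.shape[-1], 1).to(
                device=x.device, dtype=x.dtype
            )
        return self.reduce_nn(x)


n = 8
head = LazyConcatHead()
out = head(torch.randn(n), torch.randn(n), torch.randn(n))
print(head.reduce_nn.in_features)  # 24, i.e. 3n
print(out.shape)  # torch.Size([1])
```

Because the layer's parameters only exist after the first call, an optimizer should be constructed after a first forward pass. PyTorch's built-in `torch.nn.LazyLinear` provides the same deferred-shape behaviour.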

Methods

__init__(input_size[, extract_key, layer_norm])

Set the key to use to access vector representations if dicts are passed to the forward call.

forward(comp, *parts)

Calculate the permutation-invariant concatenation of all representations and pass it through a one-layer linear NN.

__init__(input_size, extract_key=None, layer_norm=False)

Set the key to use to access vector representations if dicts are passed to the forward call.

Parameters:
  • input_size (int) – Input size of the linear model

  • extract_key (str, optional) – Key used to extract the representation from a dict

  • layer_norm (bool, default=False) – Apply LayerNorm before passing through the linear layer
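How `input_size` and `layer_norm` might shape the one-layer network is sketched below. `build_reduce_nn` is a hypothetical helper, not mtenn's code. It assumes `layer_norm=True` means a `torch.nn.LayerNorm` over the last dimension followed by a `torch.nn.Linear(input_size, 1)`, and it shows that the LayerNorm step re-centres and rescales each input row.

```python
import torch


def build_reduce_nn(input_size, layer_norm=False):
    """Hypothetical helper mirroring the input_size/layer_norm options."""
    linear = torch.nn.Linear(input_size, 1)
    if layer_norm:
        # LayerNorm over the last dimension, then the one-layer linear head
        return torch.nn.Sequential(torch.nn.LayerNorm(input_size), linear)
    return linear


n = 16
head = build_reduce_nn(3 * n, layer_norm=True)
x = 10 * torch.randn(5, 3 * n) + 3  # batch of 5 badly scaled inputs
normed = head[0](x)  # output of the LayerNorm step alone
print(normed.mean(dim=-1))  # approximately 0 for every row
print(head(x).shape)  # torch.Size([5, 1])
```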

forward(comp, *parts)

Calculate the permutation-invariant concatenation of all representations and pass it through a one-layer linear NN. This network is initialized based on the input sizes the first time this method is called for a given instance of this class.

Parameters:
  • comp (torch.Tensor) – Complex representation

  • parts (list[torch.Tensor], optional) – Representations for all individual parts of the complex (e.g. ligand and protein separately)

Returns:

Predicted \(\Delta G\) value

Return type:

torch.Tensor
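An end-to-end sketch of the documented call pattern `forward(comp, *parts)` follows. It assumes that `input_size` equals the concatenated dimension (\(3n\) for a complex plus two parts) and that dict inputs are unpacked with `extract_key` before concatenation. `ConcatStrategySketch` and the `"pred"` key are hypothetical names, and this is not mtenn's source.

```python
import itertools

import torch


class ConcatStrategySketch(torch.nn.Module):
    """Hypothetical re-implementation for illustration; not mtenn's source."""

    def __init__(self, input_size, extract_key=None, layer_norm=False):
        super().__init__()
        self.extract_key = extract_key
        linear = torch.nn.Linear(input_size, 1)
        self.reduce_nn = (
            torch.nn.Sequential(torch.nn.LayerNorm(input_size), linear)
            if layer_norm
            else linear
        )

    def _get(self, rep):
        # Pull the vector out of a dict when an extract_key was given
        if self.extract_key is not None and isinstance(rep, dict):
            return rep[self.extract_key]
        return rep

    def forward(self, comp, *parts):
        comp = self._get(comp)
        parts = [self._get(p) for p in parts]
        x = comp
        if parts:
            # Sum of concatenations over every ordering of the parts
            parts_sum = sum(torch.cat(p, dim=-1) for p in itertools.permutations(parts))
            x = torch.cat([comp, parts_sum], dim=-1)
        # One-layer linear head gives the predicted delta G
        return self.reduce_nn(x)


n = 32
strategy = ConcatStrategySketch(3 * n, extract_key="pred")
comp, prot, lig = [{"pred": torch.randn(n)} for _ in range(3)]
dg = strategy(comp, prot, lig)
print(dg.shape)  # torch.Size([1])
```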