mtenn.strategy.SplitConcatStrategy

class mtenn.strategy.SplitConcatStrategy(input_size, complex_extract_key=None, ligand_extract_key=None, protein_extract_key=None, layer_norm=False)

Bases: Strategy

Strategy for combining the complex representation and the parts (ligand and protein) representations in a learned manner. This implementation allows the protein and ligand representations to have different sizes, and so skips the sum-pooling step, which needs equally sized inputs. Where possible, ConcatStrategy is recommended instead. For representation sizes of \(d\) for the complex and protein and \(l\) for the ligand, this Strategy acts as a function \(\phi: \mathbb{R}^{2d+l} \rightarrow \mathbb{R}\) that predicts a scalar \(\Delta G\) value.

We concatenate the representations in the order complex, ligand, protein:

\[ \begin{aligned} \boldsymbol{x} &= [\boldsymbol{x}_{\mathrm{complex}}, \boldsymbol{x}_{\mathrm{ligand}}, \boldsymbol{x}_{\mathrm{protein}}] \\ \Delta G_{\mathrm{pred}} &= \phi(\boldsymbol{x}) \end{aligned} \]
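The concatenation order and the resulting input size \(2d + l\) can be checked with a small sketch. Plain Python lists stand in for the representation vectors here, and the helper concat_reps and the sizes d and l are hypothetical illustrations, not part of mtenn.

```python
# Minimal sketch (not mtenn code): concatenation order and resulting
# input size, using plain Python lists as stand-in vectors.


def concat_reps(x_complex, x_ligand, x_protein):
    """Hypothetical helper: concatenate in the order complex, ligand, protein."""
    return list(x_complex) + list(x_ligand) + list(x_protein)


d, l = 4, 3  # hypothetical sizes: complex/protein have d entries, ligand has l
x_complex = [1.0] * d
x_ligand = [2.0] * l
x_protein = [3.0] * d

x = concat_reps(x_complex, x_ligand, x_protein)
print(len(x))      # 2 * d + l = 11
print(x[d:d + l])  # the ligand slice: [2.0, 2.0, 2.0]
```

Because the ligand block sits between the two \(d\)-sized blocks, its position in \(\boldsymbol{x}\) is fixed at indices \(d\) through \(d + l - 1\), which is what lets the learned weights tell the three parts apart.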

Methods

__init__(input_size[, complex_extract_key, ...])

Set the keys used to access vector representations if dicts are passed to the forward call.

forward(comp, ligand, protein)

Concatenate all representations and pass them through a one-layer linear NN.

__init__(input_size, complex_extract_key=None, ligand_extract_key=None, protein_extract_key=None, layer_norm=False)

Set the keys used to access vector representations if dicts are passed to the forward call.

Parameters:
  • input_size (int) – Input size of the linear model; for this Strategy, the concatenated size \(2d + l\)

  • complex_extract_key (str, optional) – Key used to extract the complex representation if a dict is passed

  • ligand_extract_key (str, optional) – Key used to extract the ligand representation if a dict is passed

  • protein_extract_key (str, optional) – Key used to extract the protein representation if a dict is passed

  • layer_norm (bool, default=False) – Apply LayerNorm before passing the input through the linear layer
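The extract keys matter when an upstream model returns a dict rather than a bare vector. The sketch below shows one plausible way such a lookup could work; the helper extract_rep, its fallback behavior when no key is set, and the key name "repr" are all hypothetical assumptions, not mtenn's actual code.

```python
from typing import Any, Mapping, Optional


def extract_rep(rep: Any, key: Optional[str]) -> Any:
    """Hypothetical helper: pull the vector out of a dict if a key is set."""
    if key is not None and isinstance(rep, Mapping):
        return rep[key]
    # Assumption: with no key configured, the input is used as-is.
    return rep


# Hypothetical upstream outputs: one dict, one plain vector.
complex_out = {"repr": [0.1, 0.2], "aux": None}
ligand_out = [0.5]

print(extract_rep(complex_out, "repr"))  # [0.1, 0.2]
print(extract_rep(ligand_out, None))     # [0.5]
```

With the real class, the matching constructor call would look like `SplitConcatStrategy(input_size=2 * d + l, complex_extract_key="repr")`, where `"repr"` is a placeholder for whatever key the upstream model actually emits.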

forward(comp, ligand, protein)

Concatenate all representations and pass them through a one-layer linear NN, optionally applying LayerNorm first.

Parameters:
  • comp (torch.Tensor) – Complex representation

  • ligand (torch.Tensor) – Ligand representation

  • protein (torch.Tensor) – Protein representation

Returns:

Predicted \(\Delta G\) value

Return type:

torch.Tensor
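To make the forward pass concrete, here is a minimal PyTorch sketch of a module with the same shape contract. It is an illustrative re-implementation under stated assumptions, not the mtenn source: the class name SplitConcatSketch and the attribute names are hypothetical, dict extraction is omitted, and it assumes LayerNorm is applied to the full concatenated vector.

```python
import torch
from torch import nn


class SplitConcatSketch(nn.Module):
    """Hypothetical re-implementation for illustration; not mtenn code."""

    def __init__(self, input_size: int, layer_norm: bool = False):
        super().__init__()
        # Assumption: LayerNorm acts on the concatenated (2d + l) vector.
        self.norm = nn.LayerNorm(input_size) if layer_norm else nn.Identity()
        # One linear layer mapping R^{2d+l} -> R.
        self.linear = nn.Linear(input_size, 1)

    def forward(self, comp, ligand, protein):
        # Concatenate along the feature (last) dim: complex, ligand, protein.
        x = torch.cat([comp, ligand, protein], dim=-1)
        return self.linear(self.norm(x))


d, l = 8, 5  # hypothetical representation sizes
model = SplitConcatSketch(input_size=2 * d + l, layer_norm=True)
pred = model(torch.randn(d), torch.randn(l), torch.randn(d))
print(pred.shape)  # torch.Size([1])
```

Concatenating on the last dimension means the same module also accepts batched inputs of shape (batch, d), (batch, l), (batch, d) and returns one prediction per row.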