mtenn.strategy.SplitConcatStrategy
- class mtenn.strategy.SplitConcatStrategy(input_size, complex_extract_key=None, ligand_extract_key=None, protein_extract_key=None, layer_norm=False)
Bases: Strategy
Strategy for combining the complex representation and the parts' representations in a learned manner. This implementation does not require the protein and ligand representations to have the same size, and so skips the sum-pooling step. If possible, it’s recommended to use ConcatStrategy instead. For representation sizes of \(d\) for the complex and protein and \(l\) for the ligand, this Strategy acts as a function \(\phi: \mathbb{R}^{2d+l} \rightarrow \mathbb{R}\) that produces a scalar \(\Delta G\) prediction. The representations are concatenated in the order complex, ligand, protein:
\[ \begin{aligned}\boldsymbol{x} &= [\boldsymbol{x}_{\mathrm{complex}}, \boldsymbol{x}_{\mathrm{ligand}}, \boldsymbol{x}_{\mathrm{protein}}]\\ \Delta G_{\mathrm{pred}} &= \phi(\boldsymbol{x})\end{aligned} \]
Methods
__init__(input_size[, complex_extract_key, ...]) – Set the key to use to access vector representations if dicts are passed to the forward call.
forward(comp, ligand, protein) – Concatenate all representations, and pass through a one-layer linear NN.
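The concatenation and one-layer linear map described for this class can be sketched without PyTorch. This is a minimal sketch under stated assumptions: it uses plain Python lists in place of torch.Tensor, treats \(\phi\) as a single linear layer with one output, and the names split_concat_phi, weight, and bias are hypothetical, not part of the mtenn API.

```python
def split_concat_phi(x_complex, x_ligand, x_protein, weight, bias):
    """Hypothetical sketch: concatenate (complex, ligand, protein), then apply a linear map."""
    # Concatenate in the documented order: complex, ligand, protein
    x = list(x_complex) + list(x_ligand) + list(x_protein)
    if len(x) != len(weight):
        raise ValueError(f"expected {len(weight)} features, got {len(x)}")
    # One linear layer with a single output: w . x + b is a scalar Delta G prediction
    return sum(w_i * x_i for w_i, x_i in zip(weight, x)) + bias


# Illustrative sizes: d = 2 for complex and protein, l = 1 for the ligand, so 2d + l = 5
x_complex = [1.0, 2.0]
x_ligand = [3.0]
x_protein = [4.0, 5.0]
weight = [0.5, 0.5, 1.0, -0.5, -0.5]  # would be learned in practice; fixed here
bias = 0.25

delta_g = split_concat_phi(x_complex, x_ligand, x_protein, weight, bias)
print(delta_g)  # 0.25
```

The weights are fixed so the arithmetic can be checked by hand; the length check makes the \(2d + l\) input size explicit.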
- __init__(input_size, complex_extract_key=None, ligand_extract_key=None, protein_extract_key=None, layer_norm=False)
Set the key to use to access vector representations if dicts are passed to the forward call.
- Parameters:
input_size (int) – Input size of the linear model; with the representation sizes above, this is \(2d + l\)
complex_extract_key (str, optional) – Key to use to extract representation from a dict for the complex
ligand_extract_key (str, optional) – Key to use to extract representation from a dict for the ligand
protein_extract_key (str, optional) – Key to use to extract representation from a dict for the protein
layer_norm (bool, default=False) – Apply a LayerNorm normalization before passing through the linear layer
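To illustrate the dict-related parameters above, the sketch below assumes that when a representation is passed as a dict, the matching *_extract_key selects the vector, and that non-dict inputs are used unchanged. The helper extract and the key name "repr" are hypothetical; this is not mtenn's implementation. It also shows how input_size follows from the sizes \(d\) and \(l\).

```python
def extract(rep, key):
    """Hypothetical helper: return rep[key] if rep is a dict, else rep unchanged.

    Assumption: mirrors how the *_extract_key options are described
    (only used when dicts are passed to forward); not mtenn's actual code.
    """
    if isinstance(rep, dict):
        if key is None:
            raise KeyError("a dict was passed but no extract key was set")
        return rep[key]
    return rep


# Illustrative inputs: complex and protein arrive as dicts, the ligand as a plain list
complex_extract_key = "repr"  # hypothetical key name
protein_extract_key = "repr"
comp = {"repr": [1.0, 2.0], "other": [0.0]}
protein = {"repr": [4.0, 5.0]}
ligand = [3.0]

x_complex = extract(comp, complex_extract_key)
x_ligand = extract(ligand, None)  # not a dict, so no key is needed
x_protein = extract(protein, protein_extract_key)

# Complex and protein share size d; the ligand has size l
d, l = len(x_complex), len(x_ligand)
input_size = 2 * d + l
print(input_size)  # 5
```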
- forward(comp, ligand, protein)
Concatenate all representations, and pass through a one-layer linear NN.
- Parameters:
comp (torch.Tensor) – Complex representation
ligand (torch.Tensor) – Ligand representation
protein (torch.Tensor) – Protein representation
- Returns:
Predicted \(\Delta G\) value
- Return type:
torch.Tensor
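When layer_norm=True, the normalization sits between the concatenation and the linear layer. The sketch below assumes the norm is computed over the whole concatenated vector and matches torch.nn.LayerNorm at initialization (unit weight, zero bias, eps=1e-5, biased variance); forward_sketch, layer_norm, and use_layer_norm are hypothetical names, and plain lists stand in for torch.Tensor.

```python
import math


def layer_norm(x, eps=1e-5):
    """Normalize features to zero mean and unit variance.

    Mirrors torch.nn.LayerNorm at initialization (weight = 1, bias = 0):
    biased (population) variance and eps = 1e-5 by default.
    """
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def forward_sketch(x_complex, x_ligand, x_protein, weight, bias, use_layer_norm=False):
    """Hypothetical sketch of forward: concat -> optional LayerNorm -> linear."""
    # Concatenate in the order complex, ligand, protein
    x = list(x_complex) + list(x_ligand) + list(x_protein)
    if use_layer_norm:
        # Assumption: the norm is taken over the full concatenated vector
        x = layer_norm(x)
    if len(x) != len(weight):
        raise ValueError(f"expected {len(weight)} features, got {len(x)}")
    # Single-output linear layer -> scalar Delta G prediction
    return sum(w * v for w, v in zip(weight, x)) + bias


weight = [1.0] * 5  # d = 2, l = 1, so 2d + l = 5
plain = forward_sketch([1.0, 2.0], [3.0], [4.0, 5.0], weight, 0.0)
normed = forward_sketch([1.0, 2.0], [3.0], [4.0, 5.0], weight, 0.0, use_layer_norm=True)
print(plain)  # 15.0
print(normed)
```

With all-ones weights the normalized prediction is close to zero, because LayerNorm removes the mean of the concatenated features before the linear layer sees them.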