mtenn.strategy.ConcatStrategy
- class mtenn.strategy.ConcatStrategy(input_size, extract_key=None, layer_norm=False)
Bases: Strategy

Strategy for combining the complex representation and parts representations in a learned manner, using sum-pooling to ensure permutation-invariance of the parts. For 3 n-dimensional input representations (e.g. complex, protein-only, and ligand-only), this Strategy acts as a function \(\phi: \mathbb{R}^{3n} \rightarrow \mathbb{R}\) that predicts a scalar \(\Delta G\) value.

The input \(\mathrm{\boldsymbol{x}}\) to \(\phi\) is computed in a permutation-invariant manner. For a protein-ligand complex, this looks like:
\[
\begin{aligned}
\mathrm{\boldsymbol{x}_{parts}} &= [\mathrm{\boldsymbol{x}_{protein}}, \mathrm{\boldsymbol{x}_{ligand}}] + [\mathrm{\boldsymbol{x}_{ligand}}, \mathrm{\boldsymbol{x}_{protein}}] \\
\mathrm{\boldsymbol{x}} &= [\mathrm{\boldsymbol{x}_{complex}}, \mathrm{\boldsymbol{x}_{parts}}] \\
\Delta \mathrm{G_{pred}} &= \phi (\mathrm{\boldsymbol{x}})
\end{aligned}
\]

where \([\cdot, \cdot]\) denotes concatenation. In general, the non-complex representations are concatenated in every possible order, these concatenations are summed, and the sum is concatenated to the complex representation.
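The construction of \(\mathrm{\boldsymbol{x}}\) described above can be sketched in a few lines of PyTorch. This is a minimal illustration, not mtenn's code: it assumes each representation is a tensor whose last dimension is the n-dimensional feature axis, `perm_invariant_concat` is a hypothetical helper, and the sketch simply rejects calls with no parts because that case is not described here.

```python
import itertools

import torch


def perm_invariant_concat(comp: torch.Tensor, *parts: torch.Tensor) -> torch.Tensor:
    """Return x = [x_complex, x_parts] (hypothetical helper, for illustration)."""
    if not parts:
        raise ValueError("this sketch expects at least one part representation")
    # Concatenate the parts in every possible order and sum the results, so
    # the output does not depend on the order in which the parts were passed.
    x_parts = sum(torch.cat(perm, dim=-1) for perm in itertools.permutations(parts))
    # Prepend the complex representation.
    return torch.cat([comp, x_parts], dim=-1)


# Protein-ligand example with n = 4: x has length 3n = 12.
n = 4
comp, prot, lig = torch.randn(n), torch.randn(n), torch.randn(n)
x = perm_invariant_concat(comp, prot, lig)
x_swapped = perm_invariant_concat(comp, lig, prot)  # same values as x
```

Because every ordering is enumerated, the number of concatenations grows as k! for k parts, which is negligible for the usual protein/ligand pair.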
In its current iteration, this Strategy does not require you to specify the dimensionality (\(n\)) of each representation. Instead, the first time an instance of this Strategy is used, it will calculate the required input size and initialize a one-layer linear network of the appropriate dimensionality.

Methods
__init__(input_size[, extract_key, layer_norm]) – Set the key to use to access vector representations if dicts are passed to the forward call.
forward(comp, *parts) – Calculate permutation-invariant concatenation of all representations, and pass through a one-layer linear NN.
- __init__(input_size, extract_key=None, layer_norm=False)
Set the key to use to access vector representations if dicts are passed to the forward call.
- Parameters:
input_size (int) – Input size of linear model
extract_key (str, optional) – Key to use to extract representation from a dict
layer_norm (bool, default=False) – Apply a LayerNorm normalization before passing through the linear layer
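`input_size` has to match the length of the concatenated vector \(\mathrm{\boldsymbol{x}}\). Assuming one complex representation and k part representations, each n-dimensional, every permuted concatenation of the parts has length k·n, so \(\mathrm{\boldsymbol{x}}\) has length (k + 1)·n, which gives 3n for complex, protein, and ligand. The helper below, `concat_input_size`, is hypothetical and not part of mtenn.

```python
def concat_input_size(rep_dim: int, num_parts: int) -> int:
    """Length of x = [x_complex, x_parts] (hypothetical helper).

    x_parts concatenates num_parts vectors of length rep_dim, and x
    prepends the rep_dim-long complex representation.
    """
    return rep_dim * (1 + num_parts)


# Complex + protein + ligand, each 128-dimensional: 3n = 384.
protein_ligand_size = concat_input_size(128, 2)
```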
- forward(comp, *parts)
Calculate permutation-invariant concatenation of all representations, and pass through a one-layer linear NN. This network will be initialized based on the input sizes the first time this method is called for a given instance of this class.
- Parameters:
comp (torch.Tensor) – Complex representation
parts (list[torch.Tensor], optional) – Representations for all individual parts of the complex (e.g. ligand and protein separately)
- Returns:
Predicted \(\Delta G\) value
- Return type:
torch.Tensor
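Putting the pieces together, the documented interface can be approximated by a small `torch.nn.Module`. This is a hypothetical sketch, not mtenn's source: the class name `ConcatStrategySketch`, its attribute and helper names, the choice to build the linear layer eagerly from `input_size` (rather than on the first call), and the rule that `extract_key` is applied only to dict inputs are all assumptions.

```python
import itertools
from typing import Optional

import torch


class ConcatStrategySketch(torch.nn.Module):
    """Hypothetical illustration of ConcatStrategy's documented behavior."""

    def __init__(self, input_size: int, extract_key: Optional[str] = None,
                 layer_norm: bool = False):
        super().__init__()
        # Key used to pull the vector out of dict inputs (assumption: only
        # applied when an input is actually a dict).
        self.extract_key = extract_key
        # Optional LayerNorm over the concatenated vector x.
        self.norm = torch.nn.LayerNorm(input_size) if layer_norm else torch.nn.Identity()
        # One-layer linear model: x (input_size) -> scalar Delta G prediction.
        self.linear = torch.nn.Linear(input_size, 1)

    def _extract(self, rep):
        if self.extract_key is not None and isinstance(rep, dict):
            return rep[self.extract_key]
        return rep

    def forward(self, comp, *parts):
        if not parts:
            raise ValueError("this sketch expects at least one part representation")
        comp = self._extract(comp)
        parts = [self._extract(p) for p in parts]
        # Sum the concatenation of every ordering of the parts (permutation-invariant).
        x_parts = sum(torch.cat(perm, dim=-1) for perm in itertools.permutations(parts))
        # x = [x_complex, x_parts], then phi(x) = Linear(Norm(x)).
        x = torch.cat([comp, x_parts], dim=-1)
        return self.linear(self.norm(x))


# Usage: dict inputs keyed by "pred", n = 8, so input_size = 3n = 24.
torch.manual_seed(0)
n = 8
model = ConcatStrategySketch(input_size=3 * n, extract_key="pred", layer_norm=True)
comp, prot, lig = ({"pred": torch.randn(n)} for _ in range(3))
dg = model(comp, prot, lig)          # one-element Delta G prediction
dg_swapped = model(comp, lig, prot)  # same value: parts order does not matter
```

Building the layer from `input_size` up front keeps the parameters registered before an optimizer is created; a lazily initialized layer would need a forward pass first.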