Source code for mtenn.strategy

"""
Implementations for the ``Strategy`` block in a :py:class:`Model
<mtenn.model.Model>` or :py:class:`GroupedModel <mtenn.model.GroupedModel>`.
"""

import abc
from itertools import permutations
import torch


[docs] class Strategy(torch.nn.Module, abc.ABC): """ Abstract base class for the ``Strategy`` block. Any subclass needs to implement the ``forward`` method in order to be used. """
[docs] @abc.abstractmethod def forward(self, comp, *parts): """ For any strategy class, this function should take a complex representation and (optionally) any number of "part" representations, and return a single :math:`\\mathrm{\\Delta G}` prediction. """ raise NotImplementedError("Must implement the `forward` method.")
[docs] class DeltaStrategy(Strategy): """ Simple strategy for subtracting the sum of the individual component energies from the complex energy. This ``Strategy`` requires an ``energy_func`` :math:`\\phi: \\mathbb{R}^n \\rightarrow \\mathbb{R}` that maps from an n-dimensional vector representation (output from a ``Representation`` block) to a scalar-value energy prediction. .. math:: \\mathrm{G} &= \\phi (\\mathrm{\\boldsymbol{x}}) \\Delta \\mathrm{G_{pred}} &= \\mathrm{G_{complex}} - \\sum_n \\mathrm{G}_n """
[docs] def __init__(self, energy_func): """ Store module for predicting an energy from representation. Parameters ---------- energy_func : torch.nn.Module Some torch module that will predict an energy from an n-dimension vector representation of a structure """ super(DeltaStrategy, self).__init__() self.energy_func: torch.nn.Module = energy_func
[docs] def forward(self, comp, *parts): """ Make energy predictions for each representation, and then perform the delta calculation. Parameters ---------- comp : torch.Tensor Complex representation that will be passed to ``self.energy_func`` parts : list[torch.Tensor], optional Representations for all individual parts of the complex (eg ligand and protein separately) that will be passed to ``self.energy_func`` Returns ------- torch.Tensor Predicted :math:`\\Delta G` value """ # Get energy predictions for each representation complex_pred = self.energy_func(comp) parts_preds = [self.energy_func(p) for p in parts] # Replace invalid predictions with 0 parts_preds = [ p if len(p.flatten()) > 0 else torch.zeros_like(complex_pred) for p in parts_preds ] # Calculate delta G dG_pred = complex_pred - sum(parts_preds) return dG_pred
[docs] class SplitDeltaStrategy(Strategy): """ Simple strategy for subtracting the sum of the individual component energies from the complex energy. This ``Strategy`` requires an individual ``energy_func`` (:math:`\\phi: \\mathbb{R}^n \\rightarrow \\mathbb{R}` that maps from an n-dimensional vector representation (output from a ``Representation`` block) to a scalar-value energy prediction) for the complex, ligand, and protein representations. As with :py:class:`SplitModel <mtenn.model.SplitModel>`, the ``complex_energy_func`` is required, and any missing values in the ``ligand_energy_func`` or ``protein_energy_func`` will be filled in with the ``complex_energy_func``. The point of this class is to be able to handle different dimensionalities in the different representations that may be generated by a ``SplitModel``. In the simplest case, one could like ``ligand_energy_func`` and ``protein_energy_func`` blank to achieve the same behavior as the standard :py:class:DeltaStrategy. .. math:: \\mathrm{G_{complex}} &= \\phi_{\\mathrm{complex}} (\\mathrm{\\boldsymbol{x}_{complex}}) \\mathrm{G_{ligand}} &= \\phi_{\\mathrm{ligand}} (\\mathrm{\\boldsymbol{x}_{ligand}}) \\mathrm{G_{protein}} &= \\phi_{\\mathrm{protein}} (\\mathrm{\\boldsymbol{x}_{protein}}) \\Delta \\mathrm{G_{pred}} &= \\mathrm{G_{complex}} - (\\mathrm{G_{ligand}} + \\mathrm{G_{protein}}) """
[docs] def __init__( self, complex_energy_func, ligand_energy_func=None, protein_energy_func=None ): """ Store module for predicting an energy from representation. Parameters ---------- complex_energy_func : torch.nn.Module Some torch module that will predict an energy from an n-dimension vector representation of a structure ligand_energy_func : torch.nn.Module, optional Some torch module that will predict an energy from an n-dimension vector representation of a structure protein_energy_func : torch.nn.Module, optional Some torch module that will predict an energy from an n-dimension vector representation of a structure """ super(SplitDeltaStrategy, self).__init__() self.complex_energy_func: torch.nn.Module = complex_energy_func if not isinstance(complex_energy_func, torch.nn.Module): raise ValueError("Passed complex_energy_func is not a pytorch model.") if ligand_energy_func is None: ligand_energy_func = complex_energy_func if protein_energy_func is None: protein_energy_func = complex_energy_func self.ligand_energy_func: torch.nn.Module = ligand_energy_func self.protein_energy_func: torch.nn.Module = protein_energy_func
[docs] def forward(self, comp, prot, lig): """ Make energy predictions for each representation, and then perform the delta calculation. Parameters ---------- comp : torch.Tensor Complex representation that will be passed to ``self.energy_func`` parts : list[torch.Tensor], optional Representations for all individual parts of the complex (eg ligand and protein separately) that will be passed to ``self.energy_func`` Returns ------- torch.Tensor Predicted :math:`\\Delta G` value """ # Get energy predictions for each representation complex_pred = self.complex_energy_func(comp) lig_pred = self.ligand_energy_func(lig) prot_pred = self.protein_energy_func(prot) # Calculate delta G dG_pred = complex_pred - (lig_pred + prot_pred) return dG_pred
[docs] class ConcatStrategy(Strategy): """ Strategy for combining the complex representation and parts representations in some learned manner, using sum-pooling to ensure permutation-invariance of the parts. For 3 n-dimensional input representations (eg complex, protein-only, and ligand-only), this ``Strategy`` acts as a function :math:`\\phi: \\mathbb{R}^{3n} \\rightarrow \\mathbb{R}` that predicts a scalar-value :math:`\\Delta G` prediction. The input :math:`\\mathrm{\\boldsymbol{x}}` to :math:`\\phi` is computed in a permutation-invariant manner. For a protein-ligand complex, this looks like: .. math:: \\mathrm{\\boldsymbol{x}_{parts}} &= [\\mathrm{\\boldsymbol{x}_{protein}}, \\mathrm{\\boldsymbol{x}_{ligand}}] + [\\mathrm{\\boldsymbol{x}_{ligand}}, \\mathrm{\\boldsymbol{x}_{protein}}] \\mathrm{\\boldsymbol{x}} &= [\\mathrm{\\boldsymbol{x}_{complex}}, \\mathrm{\\boldsymbol{x}_{parts}}] \\Delta \\mathrm{G_{pred}} &= \\phi (\\mathrm{\\boldsymbol{x}}) In general, we will sum every permutation of the non-complex representations, and then this sum will be concatenated to the complex representation. In its current iteration, this ``Strategy`` does not require you to specify the dimensionality (:math:`n`) of each representation. Instead, the first time an instance of this ``Strategy`` is used, it will calculate the required input size and initialize a one-layer linear network of the appropriate dimensionality. """
[docs] def __init__(self, input_size, extract_key=None, layer_norm=False): """ Set the key to use to access vector representations if ``dict`` s are passed to the ``forward`` call. Parameters ---------- input_size : int Input size of linear model extract_key : str, optional Key to use to extract representation from a dict layer_norm: bool, default=False Apply a ``LayerNorm`` normalization before passing through the linear layer """ super(ConcatStrategy, self).__init__() if layer_norm: self.reduce_nn = torch.nn.Sequential( torch.nn.LayerNorm(input_size), torch.nn.Linear(input_size, 1) ) else: self.reduce_nn = torch.nn.Linear(input_size, 1) self.extract_key = extract_key
[docs] def forward(self, comp, *parts): """ Calculate permutation-invariant concatenation of all representations, and pass through a one-layer linear NN. This network will be initialized based on the input sizes the first time this method is called for a given instance of this class. Parameters ---------- comp : torch.Tensor Complex representation parts : list[torch.Tensor], optional Representations for all individual parts of the complex (eg ligand and protein separately) Returns ------- torch.Tensor Predicted :math:`\\Delta G` value """ # Extract representation from dict if self.extract_key and isinstance(comp, dict): comp = comp[self.extract_key] parts = [p[self.extract_key] for p in parts] # Flatten tensors comp = comp.flatten() parts = [p.flatten() for p in parts] # Enumerate all possible permutations of parts and add together parts_size = sum([len(p) for p in parts]) parts_cat = torch.zeros((parts_size), device=comp.device) for idxs in permutations(range(len(parts)), len(parts)): parts_cat += torch.cat([parts[i] for i in idxs]) # Concat comp w/ permut-invariant parts representation full_embedded = torch.cat([comp, parts_cat]) return self.reduce_nn(full_embedded)
[docs] class SplitConcatStrategy(Strategy): """ Strategy for combining the complex representation and parts representations in some learned manner. This implementation assumes different sizes for the protein and ligand representations, and so skips the sum-pooling step. If possible, it's recommended to use :py:class:ConcatStrategy. For representation sizes of :math:`d` for the complex and protein and :math:`l` for the ligand, this ``Strategy`` acts as a function :math:`\\phi: \\mathbb{R}^{2d+l} \\rightarrow \\mathbb{R}` that predicts a scalar-value :math:`\\Delta G` prediction. We concatenate the representations in the order of complex, ligand, protein: .. math:: \\mathrm{\\boldsymbol{x}} &= [\\mathrm{\\boldsymbol{x}_{complex}}, \\mathrm{\\boldsymbol{x}_{ligand}}, \\mathrm{\\boldsymbol{x}_{protein}}] \\Delta \\mathrm{G_{pred}} &= \\phi (\\mathrm{\\boldsymbol{x}}) """
[docs] def __init__( self, input_size, complex_extract_key=None, ligand_extract_key=None, protein_extract_key=None, layer_norm=False, ): """ Set the key to use to access vector representations if ``dict`` s are passed to the ``forward`` call. Parameters ---------- input_size : int Input size of linear model complex_extract_key : str, optional Key to use to extract representation from a dict for the complex ligand_extract_key : str, optional Key to use to extract representation from a dict for the ligand protein_extract_key : str, optional Key to use to extract representation from a dict for the protein layer_norm: bool, default=False Apply a ``LayerNorm`` normalization before passing through the linear layer """ super(ConcatStrategy, self).__init__() if layer_norm: self.reduce_nn = torch.nn.Sequential( torch.nn.LayerNorm(input_size), torch.nn.Linear(input_size, 1) ) else: self.reduce_nn = torch.nn.Linear(input_size, 1) self.extract_keys = [ complex_extract_key, ligand_extract_key, protein_extract_key, ]
[docs] def forward(self, comp, ligand, protein): """ Concatenate all representations, and pass through a one-layer linear NN. Parameters ---------- comp : torch.Tensor Complex representation ligand : torch.Tensor Ligand representation protein : torch.Tensor Protein representation Returns ------- torch.Tensor Predicted :math:`\\Delta G` value """ # Extract representation from dict all_reps = [] for extract_key, rep in zip(self.extract_keys, [comp, ligand, protein]): if (extract_key is not None) and isinstance(rep, dict): all_reps.append(rep[self.extract_key].flatten()) else: all_reps.append(rep.flatten()) # Concat comp w/ permut-invariant parts representation full_embedded = torch.cat(all_reps) return self.reduce_nn(full_embedded)
[docs] class ComplexOnlyStrategy(Strategy): """ Strategy to only predict based on the complex representation. This is useful if you want to make a prediction on just the ligand or just the protein, and essentially just reduces to a standard version of whatever your underlying model is. """
[docs] def __init__(self, energy_func): """ Store module for predicting an energy from representation. Parameters ---------- energy_func : torch.nn.Module Some torch module that will predict an energy from an n-dimension vector representation of a structure """ super().__init__() self.energy_func: torch.nn.Module = energy_func
[docs] def forward(self, comp, *parts): """ Make energy prediction for the complex representation. Parameters ---------- comp : torch.Tensor Complex representation that will be passed to ``self.energy_func`` parts : list[torch.Tensor], optional Ignored, but present just to match the signatures Returns ------- torch.Tensor Predicted value """ complex_pred = self.energy_func(comp) return complex_pred