# Source code for mtenn.model

"""
This module contains the actual models that are used for making predictions. More
information on how everything here works is in the :ref:`docs page <model-docs-page>`.
"""

import os
import torch

from mtenn.combination import Combination
from mtenn.representation import Representation
from mtenn.strategy import ComplexOnlyStrategy, Strategy
from mtenn.readout import Readout


class Model(torch.nn.Module):
    """
    Model object containing a ``Representation`` block that will take an input and
    convert it into some representation, and a ``Strategy`` block that will take a
    complex representation and any number of constituent "part" representations,
    and convert to a final scalar value.
    """

    def __init__(self, representation, strategy, readout=None, fix_device=False):
        """
        Build a ``Model``.

        Parameters
        ----------
        representation : Representation
            ``Representation`` block for this model
        strategy : Strategy
            ``Strategy`` block for this model
        readout : Readout, optional
            ``Readout`` block for this model
        fix_device: bool, default=False
            If True, make sure the input is on the same device as the model,
            copying over as necessary.
        """
        super(Model, self).__init__()
        self.representation: Representation = representation
        self.strategy: Strategy = strategy
        self.readout: Readout = readout
        self.fix_device = fix_device

    def get_representation(self, *args, **kwargs):
        """
        Pass a structure through the model's ``Representation`` block. All arguments
        are passed directly to ``self.representation``.
        """
        return self.representation(*args, **kwargs)

    def forward(self, comp, *parts):
        """
        Handles all the logic detailed in the
        :ref:`docs page <single-pose-model-docs>`.

        Parameters
        ----------
        comp : dict
            Complex structure that will be passed to the ``Representation`` block
        *parts : dict, optional
            Structures for all individual parts of the complex (eg ligand and
            protein separately). If none are passed, the constituent parts will be
            automatically parsed from ``comp`` using its ``"lig"`` index

        Returns
        -------
        torch.Tensor
            Final model prediction. If the model has a ``readout``, this value will
            have the ``readout`` applied
        list[torch.Tensor]
            A list containing only the pre-``readout`` model prediction. This value
            is returned mainly to align the signatures for the single- and
            multi-pose models
        """
        # This implementation of the forward function assumes the
        # get_representation function takes a single data object
        tmp_comp = self._fix_device(comp)
        complex_rep = self.get_representation(tmp_comp)

        if isinstance(self.strategy, ComplexOnlyStrategy):
            # Complex-only strategies never look at the parts, so skip computing them
            parts_rep = []
        else:
            if len(parts) == 0:
                parts = Model._split_parts(tmp_comp)
            parts_rep = [self.get_representation(self._fix_device(p)) for p in parts]

        energy_val = self.strategy(complex_rep, *parts_rep)
        # Compare against None explicitly: container modules (eg an empty
        # ``torch.nn.Sequential``) define ``__len__`` and may be falsy
        if self.readout is not None:
            return self.readout(energy_val), [energy_val]
        return energy_val, [energy_val]

    def _fix_device(self, data):
        """
        Make sure that the pose tensors are on the same device as the model before
        attempting to call the model. Note that if ``self.fix_device`` is ``False``,
        this function does nothing.

        Also note that this function uses the torch ``to`` function, which means
        that if a tensor is on the wrong device, a copy of the tensor will be
        returned, whereas if the tensor is already on the correct device the
        original tensor will be returned.

        Parameters
        ----------
        data : dict
            Structure pose. Objects without an ``items`` method (eg a bare tensor)
            are moved directly if they have a ``to`` method, and returned unchanged
            otherwise

        Returns
        -------
        dict
            New dict with all tensors on the appropriate device (or the moved
            object itself for non-dict inputs)
        """
        # We'll call this on everything for uniformity, but if fix_device is
        # False we can just return
        if not self.fix_device:
            return data

        # A model without parameters has no device of its own, so there is
        # nothing to move the data onto
        first_param = next(self.parameters(), None)
        if first_param is None:
            return data
        device = first_param.device

        # Non-dict inputs are moved as a whole when they support ``to``
        if not hasattr(data, "items"):
            try:
                return data.to(device)
            except AttributeError:
                return data

        tmp_data = {}
        for k, v in data.items():
            try:
                tmp_data[k] = v.to(device)
            except AttributeError:
                # Entries without ``to`` (eg strings, ints) are carried over as-is
                tmp_data[k] = v
        return tmp_data

    @staticmethod
    def _split_parts(comp):
        """
        Helper method to split up the complex representation into different parts
        for protein and ligand.

        Parameters
        ----------
        comp : dict[str, object]
            Dictionary representing the complex data object. Must have "lig" as a
            key that contains the index for splitting the data (indexed with
            ``idx`` for the ligand and ``~idx`` for the protein, so presumably a
            boolean mask)

        Returns
        -------
        dict[str, object]
            Protein representation
        dict[str, object]
            Ligand representation

        Raises
        ------
        RuntimeError
            If ``comp`` has no ``"lig"`` key
        """
        try:
            idx = comp["lig"]
        except KeyError:
            raise RuntimeError('Data object has no key "lig".') from None

        prot_rep = {}
        lig_rep = {}
        for k, v in comp.items():
            if isinstance(v, torch.Tensor):
                prot_rep[k] = v[~idx]
                lig_rep[k] = v[idx]
            else:
                # Non-tensor entries are shared by both parts
                prot_rep[k] = v
                lig_rep[k] = v

        return prot_rep, lig_rep
class GroupedModel(Model):
    """
    Subclass of :py:class:`Model <mtenn.model.Model>` for use with multi-pose data,
    eg multiple docked poses of the same molecule with the same protein. In
    addition to the ``Representation`` and ``Strategy`` blocks in the ``Model``
    class, ``GroupedModel`` also has a ``Combination`` block that dictates how the
    individual ``Model`` predictions for each item in the group of data are
    combined.
    """

    def __init__(
        self,
        representation,
        strategy,
        combination,
        pred_readout=None,
        comb_readout=None,
        fix_device=False,
    ):
        """
        The ``representation``, ``strategy``, and ``pred_readout`` args will be used
        to initialize the underlying :py:class:`Model <mtenn.model.Model>`, while
        the ``combination`` and ``comb_readout`` args will be applied to the output
        of the individual pose predictions.

        Parameters
        ----------
        representation : Representation
            ``Representation`` block for this model
        strategy : Strategy
            ``Strategy`` block for this model
        combination : Combination
            ``Combination`` block for this model
        pred_readout : Readout, optional
            ``Readout`` block for the individual pose predictions
        comb_readout : Readout, optional
            ``Readout`` block for the combined pose
        fix_device: bool, default=False
            If True, make sure the input is on the same device as the model,
            copying over as necessary.
        """
        super(GroupedModel, self).__init__(
            representation, strategy, pred_readout, fix_device
        )
        self.combination: Combination = combination
        self.comb_readout: Readout = comb_readout

    def forward(self, input_list):
        """
        Handles all the logic detailed in the
        :ref:`docs page <multi-pose-model-docs>`.

        Each entry is run through the single-pose ``Model.forward`` (so the parts
        are parsed from each complex automatically). When the prediction requires
        grad, the per-pose parameter gradients are computed and handed to the
        ``Combination`` block. Note that this calls ``zero_grad``, so any gradients
        already stored on the parameters are cleared.

        Parameters
        ----------
        input_list : list[dict]
            List of complex structures, one per pose

        Returns
        -------
        torch.Tensor
            Final multi-pose model prediction, with ``comb_readout`` applied if set
        list[torch.Tensor]
            Second output of the ``Combination`` block. NOTE(review): documented
            upstream as the per-pose pre-readout predictions, but the values passed
            to the combination here are post-``pred_readout``; confirm against the
            ``Combination`` implementation
        """
        # Get predictions for all inputs in the list, and combine them in a
        # tensor (while keeping track of gradients)
        pred_list = []
        grad_dict = {}
        for i, inp in enumerate(input_list):
            if "MTENN_VERBOSE" in os.environ:
                print(f"pose {i}", flush=True)
                print(
                    "size",
                    ", ".join(
                        [
                            f"{k}: {v.shape} ({v.dtype})"
                            for k, v in inp.items()
                            if type(v) is torch.Tensor
                        ]
                    ),
                    sum([len(p.flatten()) for p in self.parameters()]),
                    f"{torch.cuda.memory_allocated():,}",
                    flush=True,
                )
            # First get prediction
            pred, _ = super().forward(inp)
            pred_list.append(pred.detach())

            # Get gradient per sample (only needed if we're doing training)
            self.zero_grad()
            if pred.requires_grad:
                pred.backward()
                for n, p in self.named_parameters():
                    # Parameters that did not contribute to this pose's prediction
                    # (eg those in ``comb_readout``/``combination``, or frozen
                    # ones) have ``grad`` of None; their gradient w.r.t. ``pred``
                    # is zero, so record that explicitly instead of crashing
                    if p.grad is None:
                        grad = torch.zeros_like(p)
                    else:
                        grad = p.grad.detach()
                    grad_dict.setdefault(n, []).append(grad)

            # Zero grads again just to make sure nothing gets accumulated
            self.zero_grad()

        # Separate out param names and params
        param_names, model_params = zip(*self.named_parameters())
        comb_pred, comb_pred_list = self.combination(
            pred_list, grad_dict, param_names, *model_params
        )

        # Explicit None check: container modules may be falsy
        if self.comb_readout is not None:
            return self.comb_readout(comb_pred), comb_pred_list
        return comb_pred, comb_pred_list
class LigandOnlyModel(Model):
    """
    A ligand-only version of the ``Model``. In this case, the ``representation``
    block will hold the entire model, while the ``strategy`` block will simply be
    set as an Identity module.
    """

    def __init__(self, model, readout=None, fix_device=False):
        """
        Build a ``LigandOnlyModel``.

        Parameters
        ----------
        model
            This can be any kind of model that will go from a single input
            representation to a prediction (eg a
            :py:class:`GAT <mtenn.conversion_utils.gat.GAT>` instance)
        readout : Readout, optional
            ``Readout`` block applied to the output of ``model``
        fix_device: bool, default=False
            If True, make sure the input is on the same device as the model,
            copying over as necessary.
        """
        super(LigandOnlyModel, self).__init__(
            representation=model,
            strategy=torch.nn.Identity(),
            readout=readout,
            fix_device=fix_device,
        )

    def forward(self, rep):
        """
        Handles all the logic detailed in the
        :ref:`docs page <ligand-only-model-docs>`.

        Parameters
        ----------
        rep
            Whatever input representation the underlying model takes

        Returns
        -------
        torch.Tensor
            Final model prediction. If the model has a ``readout``, this value will
            have the ``readout`` applied
        list[torch.Tensor]
            A list containing only the pre-``readout`` model prediction. This value
            is returned mainly to align the signatures for the single- and
            multi-pose models
        """
        # This implementation of the forward function assumes the
        # get_representation function takes a single data object
        tmp_rep = self._fix_device(rep)
        pred = self.get_representation(tmp_rep)

        # Explicit None check: container modules may be falsy
        if self.readout is not None:
            return self.readout(pred), [pred]
        return pred, [pred]
class SplitModel(torch.nn.Module):
    """
    Model object with separate ``Representation`` blocks for the complex, the
    ligand, and the protein, and a ``Strategy`` block that takes the complex,
    protein, and ligand representations and converts them to a final scalar
    value.
    """

    def __init__(
        self,
        complex_representation,
        strategy,
        ligand_representation=None,
        protein_representation=None,
        readout=None,
        fix_device=False,
    ):
        """
        Build a ``SplitModel``.

        Parameters
        ----------
        complex_representation : Representation
            ``Representation`` block for the complex
        strategy : Strategy
            ``Strategy`` block for this model
        ligand_representation : Representation, optional
            ``Representation`` block for the ligand. Leave unset to use the
            ``Representation`` block in ``complex_representation``
        protein_representation : Representation, optional
            ``Representation`` block for the protein. Leave unset to use the
            ``Representation`` block in ``complex_representation``
        readout : Readout, optional
            ``Readout`` block for this model
        fix_device: bool, default=False
            If True, make sure the input is on the same device as the model,
            copying over as necessary.
        """
        super(SplitModel, self).__init__()
        self.complex_representation: Representation = complex_representation

        # Unset part representations fall back to (share) the complex block
        if ligand_representation is None:
            ligand_representation = complex_representation
        if protein_representation is None:
            protein_representation = complex_representation
        self.ligand_representation: Representation = ligand_representation
        self.protein_representation: Representation = protein_representation

        self.strategy: Strategy = strategy
        self.readout: Readout = readout
        self.fix_device = fix_device

    def get_representation(self, *args, **kwargs):
        """
        Pass a structure through the model's complex ``Representation`` block. All
        arguments are passed directly to ``self.complex_representation`` (this
        class has no single ``self.representation`` attribute).
        """
        return self.complex_representation(*args, **kwargs)

    def forward(self, comp, prot=None, lig=None):
        """
        Handles all the logic detailed in the
        :ref:`docs page <single-pose-model-docs>`.

        Parameters
        ----------
        comp : dict
            Complex structure that will be passed to the complex
            ``Representation`` block
        prot : dict, optional
            Protein structure. If not passed, it is parsed from ``comp``
        lig : dict, optional
            Ligand structure. If not passed, it is parsed from ``comp``

        Returns
        -------
        torch.Tensor
            Final model prediction. If the model has a ``readout``, this value will
            have the ``readout`` applied
        list[torch.Tensor]
            A list containing only the pre-``readout`` model prediction. This value
            is returned mainly to align the signatures for the single- and
            multi-pose models
        """
        # This implementation of the forward function assumes the
        # representation blocks take a single data object
        tmp_comp = self._fix_device(comp)
        complex_rep = self.complex_representation(tmp_comp)

        if isinstance(self.strategy, ComplexOnlyStrategy):
            # Complex-only strategies never look at the parts
            lig_rep = None
            prot_rep = None
        else:
            # Fill in whichever part(s) the caller did not supply
            if (prot is None) or (lig is None):
                split_prot, split_lig = self._split_parts(tmp_comp)
                if prot is None:
                    prot = split_prot
                if lig is None:
                    lig = split_lig
            lig_rep = self.ligand_representation(self._fix_device(lig))
            prot_rep = self.protein_representation(self._fix_device(prot))

        energy_val = self.strategy(comp=complex_rep, prot=prot_rep, lig=lig_rep)
        # Explicit None check: container modules may be falsy
        if self.readout is not None:
            return self.readout(energy_val), [energy_val]
        return energy_val, [energy_val]

    def _fix_device(self, data):
        """
        Make sure that the pose tensors are on the same device as the model before
        attempting to call the model. Note that if ``self.fix_device`` is ``False``,
        this function does nothing.

        Also note that this function uses the torch ``to`` function, which means
        that if a tensor is on the wrong device, a copy of the tensor will be
        returned, whereas if the tensor is already on the correct device the
        original tensor will be returned.

        Parameters
        ----------
        data : dict
            Structure pose. Objects without an ``items`` method (eg a bare tensor)
            are moved directly if they have a ``to`` method, and returned unchanged
            otherwise

        Returns
        -------
        dict
            New dict with all tensors on the appropriate device (or the moved
            object itself for non-dict inputs)
        """
        # We'll call this on everything for uniformity, but if fix_device is
        # False we can just return
        if not self.fix_device:
            return data

        # A model without parameters has no device of its own, so there is
        # nothing to move the data onto
        first_param = next(self.parameters(), None)
        if first_param is None:
            return data
        device = first_param.device

        # Non-dict inputs are moved as a whole when they support ``to``
        if not hasattr(data, "items"):
            try:
                return data.to(device)
            except AttributeError:
                return data

        tmp_data = {}
        for k, v in data.items():
            try:
                tmp_data[k] = v.to(device)
            except AttributeError:
                # Entries without ``to`` (eg strings, ints) are carried over as-is
                tmp_data[k] = v
        return tmp_data

    @staticmethod
    def _split_parts(comp):
        """
        Helper method to split up the complex representation into different parts
        for protein and ligand.

        Parameters
        ----------
        comp : dict[str, object]
            Dictionary representing the complex data object. Must have "lig" as a
            key that contains the index for splitting the data (indexed with
            ``idx`` for the ligand and ``~idx`` for the protein, so presumably a
            boolean mask)

        Returns
        -------
        dict[str, object]
            Protein representation
        dict[str, object]
            Ligand representation

        Raises
        ------
        RuntimeError
            If ``comp`` has no ``"lig"`` key
        """
        try:
            idx = comp["lig"]
        except KeyError:
            raise RuntimeError('Data object has no key "lig".') from None

        prot_rep = {}
        lig_rep = {}
        for k, v in comp.items():
            if isinstance(v, torch.Tensor):
                prot_rep[k] = v[~idx]
                lig_rep[k] = v[idx]
            else:
                # Non-tensor entries are shared by both parts
                prot_rep[k] = v
                lig_rep[k] = v

        return prot_rep, lig_rep