Source code for mtenn.combination

"""
Implementations for the ``Combination`` block in a :py:class:`GroupedModel
<mtenn.model.GroupedModel>`.

The ``Combination`` block is responsible for combining multiple single-pose model
predictions into a single multi-pose prediction. For more details on the implementation
of these classes, see the :ref:`comb-docs-page` docs page and the guide on
:ref:`new-combination-guide`.

All equations referenced here correspond to those in :ref:`implemented-combs`.
"""

import abc
import torch


class Combination(torch.nn.Module, abc.ABC):
    """
    Abstract base class for the ``Combination`` block. Any subclass needs to implement
    the ``forward`` method in order to be used.

    This class is designed to just be a wrapper around a ``torch.autograd.Function``
    subclass, as described in :ref:`the guide <new-combination-guide>`.
    """

    # Raw docstring: the LaTeX backslashes would otherwise be invalid escape sequences
    @abc.abstractmethod
    def forward(self, pred_list, grad_dict, param_names, *model_params):
        r"""
        This function signature should be the same for any ``Combination`` subclass
        implementation.

        Parameters
        ----------
        pred_list: List[torch.Tensor]
            List of :math:`\mathrm{\Delta G}` predictions to be combined, shape of
            ``(n_predictions,)``
        grad_dict: dict[str, List[torch.Tensor]]
            Dict mapping from parameter name to list of gradients. Should contain
            ``n_model_parameters`` entries, with each entry mapping to a list of
            ``n_predictions`` tensors. Each of these tensors is a ``detach`` ed gradient
            so the shape of each tensor will depend on the model parameter it
            corresponds to, but the shapes of each tensor in any given entry should be
            identical
        param_names: List[str]
            List of parameter names. Should contain ``n_model_parameters`` entries,
            corresponding 1:1 with the keys in ``grad_dict``
        model_params: List[torch.Tensor]
            Actual parameters that we'll return the gradients for. Each param should be
            passed directly for the backward pass to work right. These tensors should
            correspond 1:1 with and should be in the same order as the entries in
            ``param_names`` (ie the ``i`` th entry in ``param_names`` should be the name
            of the ``i`` th model parameter in ``model_params``)

        Returns
        -------
        torch.Tensor
            Scalar-value tensor giving the final combined prediction
        torch.Tensor
            Tensor of shape ``(n_predictions,)`` giving the input per-pose predictions.
            This is necessary for ``Pytorch`` to track the gradients of these
            predictions in the case of eg a cross-entropy loss on the per-pose
            predictions

        Raises
        ------
        NotImplementedError
            Always, in this abstract base implementation
        """
        raise NotImplementedError("Must implement the `forward` method.")

    @staticmethod
    def split_grad_dict(grad_dict):
        """
        Helper method used by all ``Combination`` classes to flatten the passed
        ``grad_dict`` into two parallel lists. The gradient tensors can then be handed
        to ``ctx.save_for_backward``, which only accepts tensors, and the keys can be
        stored directly on ``ctx``.

        Parameters
        ----------
        grad_dict : Dict[str, List[torch.Tensor]]
            Dict mapping from parameter name to list of gradients

        Returns
        -------
        List[str]
            Key in ``grad_dict`` corresponding 1:1 with the gradients. Each key is
            repeated once per gradient in its list, in dict iteration order
        List[torch.Tensor]
            Gradients from ``grad_dict`` corresponding 1:1 with the keys
        """
        # Repeat each key once per gradient so that keys and tensors line up 1:1
        grad_dict_keys = [
            k for k, grad_list in grad_dict.items() for _ in range(len(grad_list))
        ]
        grad_dict_tensors = [
            grad for grad_list in grad_dict.values() for grad in grad_list
        ]
        return grad_dict_keys, grad_dict_tensors

    @staticmethod
    def join_grad_dict(grad_dict_keys, grad_dict_tensors):
        """
        Helper method used by all ``Combination`` classes to reconstruct the
        ``grad_dict`` from keys and grad tensors (the inverse of
        ``split_grad_dict``).

        Parameters
        ----------
        grad_dict_keys : List[str]
            Key in ``grad_dict`` corresponding 1:1 with the gradients
        grad_dict_tensors : List[torch.Tensor]
            Gradients from ``grad_dict`` corresponding 1:1 with the keys

        Returns
        -------
        Dict[str, List[torch.Tensor]]
            Dict mapping from parameter name to list of gradients, with each list in
            the same order the gradients appeared in ``grad_dict_tensors``
        """
        grad_dict = {}
        for k, grad in zip(grad_dict_keys, grad_dict_tensors):
            # Create the list on first sight of a key, then keep appending to it
            grad_dict.setdefault(k, []).append(grad)
        return grad_dict
class MeanCombination(Combination):
    """
    Combine a list of predictions by taking their mean.

    All of the actual work, including the custom gradient math, lives in
    :py:class:`MeanCombinationFunc <mtenn.combination.MeanCombinationFunc>`; see its
    docs for more details.
    """

    def forward(self, pred_list, grad_dict, param_names, *model_params):
        """
        Forward every argument, unchanged and in order, to
        ``MeanCombinationFunc.apply`` and return whatever it returns.
        """
        func_args = (pred_list, grad_dict, param_names, *model_params)
        return MeanCombinationFunc.apply(*func_args)
class MeanCombinationFunc(torch.autograd.Function):
    r"""
    Custom autograd function that will handle the gradient math for us for combining
    :math:`\mathrm{\Delta G}` predictions to their mean.

    .. math::

        \Delta \text{G}(\theta) = \frac{1}{N} \sum_{i=1}^{N} \Delta \text{G}_i (\theta)

    See :ref:`mean-comb-imp` for more details on the math.
    """

    @staticmethod
    def forward(pred_list, grad_dict, param_names, *model_params):
        r"""
        Take the mean of all input :math:`\mathrm{\Delta G}` predictions.

        Parameters
        ----------
        pred_list: List[torch.Tensor]
            List of :math:`\mathrm{\Delta G}` predictions to be combined, shape of
            ``(n_predictions,)``
        grad_dict: dict[str, List[torch.Tensor]]
            Dict mapping from parameter name to list of gradients. Should contain
            ``n_model_parameters`` entries, with each entry mapping to a list of
            ``n_predictions`` tensors. Each of these tensors is a ``detach`` ed gradient
            so the shape of each tensor will depend on the model parameter it
            corresponds to, but the shapes of each tensor in any given entry should be
            identical
        param_names: List[str]
            List of parameter names. Should contain ``n_model_parameters`` entries,
            corresponding 1:1 with the keys in ``grad_dict``
        model_params: List[torch.Tensor]
            Actual parameters that we'll return the gradients for. Each param should be
            passed directly for the backward pass to work right. These tensors should
            correspond 1:1 with and should be in the same order as the entries in
            ``param_names`` (ie the ``i`` th entry in ``param_names`` should be the name
            of the ``i`` th model parameter in ``model_params``)

        Returns
        -------
        torch.Tensor
            Scalar-value tensor giving the mean of the input :math:`\mathrm{\Delta G}`
            predictions
        torch.Tensor
            Tensor of shape ``(n_predictions,)`` giving the input per-pose predictions
        """
        # Stack and flatten so every prediction becomes one entry of a 1D tensor
        all_preds = torch.stack(pred_list).flatten()
        # Mean over all entries; gradients for it are supplied by ``backward``
        final_pred = all_preds.mean().detach()
        return final_pred, all_preds

    @staticmethod
    def setup_context(ctx, inputs, output):
        """
        Store data for backward pass.

        Parameters
        ----------
        ctx
            PyTorch autograd context object
        inputs : List
            List containing all the parameters that will get passed to ``forward``
        output : Tuple[torch.Tensor, torch.Tensor]
            Values returned from ``forward``
        """
        pred_list, grad_dict, param_names, *model_params = inputs

        grad_dict_keys, grad_dict_tensors = Combination.split_grad_dict(grad_dict)

        # Save non-Tensors for backward
        ctx.grad_dict_keys = grad_dict_keys
        ctx.param_names = param_names

        # Save Tensors for backward
        # Saving:
        #  * Predictions (1 tensor of shape (n_predictions,))
        #  * Grad tensors (N params * M poses tensors)
        #  * Model param tensors (N params tensors; not currently read by backward)
        ctx.save_for_backward(
            torch.stack(pred_list).flatten(),
            *grad_dict_tensors,
            *model_params,
        )

    @staticmethod
    def backward(ctx, comb_grad, pose_grads):
        r"""
        Compute and return gradients for each parameter.

        Parameters
        ----------
        ctx
            PyTorch autograd context object
        comb_grad : torch.Tensor
            Scalar-value tensor giving the
            :math:`\frac{\partial L}{\partial \Delta \text{G}}` term from
            :eq:`comb-grad`
        pose_grads : torch.Tensor
            Tensor of shape ``(n_predictions,)``, giving the
            :math:`\frac{\partial L}{\partial \Delta \text{G}_i}` terms from
            :eq:`pose-grad`

        Returns
        -------
        tuple
            ``None`` for each of ``pred_list``, ``grad_dict`` and ``param_names``,
            followed by the total gradient for each model parameter, in
            ``ctx.param_names`` order

        Raises
        ------
        RuntimeError
            If the number of per-pose loss gradients does not match the number of
            per-pose parameter gradients for some parameter
        """
        # Unpack saved tensors. The first one is the flattened per-pose predictions,
        # which the gradient of a mean does not depend on
        _, *other_tensors = ctx.saved_tensors

        # First section of these tensors are the flattened lists of gradients from each
        # individual pose for each model parameter
        grad_dict_tensors = other_tensors[: len(ctx.grad_dict_keys)]

        # Reconstruct dict mapping from model parameter name to list of gradient tensors
        # The ith entry in each list gives the gradient of the ith pose prediction wrt
        # that model parameter
        grad_dict = Combination.join_grad_dict(ctx.grad_dict_keys, grad_dict_tensors)

        # Calculate final gradients for each parameter
        final_grads = {}
        for n, grad_list in grad_dict.items():
            # Make sure lengths match up before using them (should always be true but
            # just in case)
            if len(pose_grads) != len(grad_list):
                raise RuntimeError("Mismatch in gradient lengths.")

            # Compute the gradient contributions from any combined prediction loss,
            # according to eqns (1), (4)
            cur_final_grad = comb_grad * torch.stack(grad_list, dim=-1).mean(dim=-1)

            # Compute the gradient contributions from any per-pose prediction loss,
            # according to eqn (2)
            for pose_grad, param_grad in zip(pose_grads, grad_list):
                cur_final_grad += pose_grad * param_grad

            # cur_final_grad is a fresh tensor, so it can be stored without copying
            final_grads[n] = cur_final_grad

        # Return gradients for each of the model parameters that were passed in. Also
        # need to return values for the other values that were passed to forward
        # (pred_list, grad_dict, param_names), but these don't get gradients so we just
        # return None
        return_vals = [None] * 3 + [final_grads[n] for n in ctx.param_names]
        return tuple(return_vals)
class MaxCombination(Combination):
    """
    Smoothly approximate the max (or min) of the predictions via the LogSumExp
    function.

    The gradient math lives in
    :py:class:`MaxCombinationFunc <mtenn.combination.MaxCombinationFunc>`; see its
    docs for more details.
    """

    def __init__(self, negate_preds=True, pred_scale=1000.0):
        """
        Parameters
        ----------
        negate_preds : bool, default=True
            Negate the predictions before calculating the LSE, effectively finding the
            min. Preds are negated again before being returned
        pred_scale : float, default=1000.0
            Fixed positive value to scale predictions by before taking the LSE. This
            tightens the bounds of the LSE approximation
        """
        # The base Module must be initialised before any attributes are assigned
        super().__init__()
        self.negate_preds = negate_preds
        self.pred_scale = pred_scale

    def __repr__(self):
        return (
            f"MaxCombination(negate_preds={self.negate_preds}, "
            f"pred_scale={self.pred_scale})"
        )

    def __str__(self):
        return repr(self)

    def forward(self, pred_list, grad_dict, param_names, *model_params):
        """
        Call ``MaxCombinationFunc.apply`` with this module's two settings prepended to
        the standard ``Combination`` arguments.
        """
        settings = (self.negate_preds, self.pred_scale)
        return MaxCombinationFunc.apply(
            *settings, pred_list, grad_dict, param_names, *model_params
        )
class MaxCombinationFunc(torch.autograd.Function):
    r"""
    Custom autograd function that will handle the gradient math for us for taking the
    max/min of the :math:`\mathrm{\Delta G}` predictions.

    For the ``forward`` pass, the final :math:`\mathrm{\Delta G}` prediction is
    calculated according to the following:

    .. math::

        n = \begin{cases}
            -1 & \text{negate_preds} \\
            \phantom{-}1 & \text{not negate_preds}
        \end{cases}

    .. math::

        t &= \text{pred_scale}

        \Delta G &= n \frac{1}{t} \ln \sum_{i=1}^N \exp (n t \Delta G_i)

    The logic and math behind this scaling approach are detailed
    `here <https://en.wikipedia.org/wiki/LogSumExp#Properties>`_.

    See :ref:`max-comb-imp` for more details on the math.
    """

    @staticmethod
    def forward(
        negate_preds, pred_scale, pred_list, grad_dict, param_names, *model_params
    ):
        r"""
        Find the max/min of all input :math:`\mathrm{\Delta G}` predictions.

        Parameters
        ----------
        negate_preds: bool
            Negate the predictions before calculating the LSE, effectively finding the
            min. Preds are negated again before being returned
        pred_scale: float
            Fixed positive value to scale predictions by before taking the LSE. This
            tightens the bounds of the LSE approximation
        pred_list: List[torch.Tensor]
            List of :math:`\mathrm{\Delta G}` predictions to be combined, shape of
            ``(n_predictions,)``
        grad_dict: dict[str, List[torch.Tensor]]
            Dict mapping from parameter name to list of gradients. Should contain
            ``n_model_parameters`` entries, with each entry mapping to a list of
            ``n_predictions`` tensors. Each of these tensors is a ``detach`` ed gradient
            so the shape of each tensor will depend on the model parameter it
            corresponds to, but the shapes of each tensor in any given entry should be
            identical
        param_names: List[str]
            List of parameter names. Should contain ``n_model_parameters`` entries,
            corresponding 1:1 with the keys in ``grad_dict``
        model_params: List[torch.Tensor]
            Actual parameters that we'll return the gradients for. Each param should be
            passed directly for the backward pass to work right. These tensors should
            correspond 1:1 with and should be in the same order as the entries in
            ``param_names`` (ie the ``i`` th entry in ``param_names`` should be the name
            of the ``i`` th model parameter in ``model_params``)

        Returns
        -------
        torch.Tensor
            Scalar-value tensor giving the max/min of the input
            :math:`\mathrm{\Delta G}` predictions
        torch.Tensor
            Tensor of shape ``(n_predictions,)`` giving the input per-pose predictions
        """
        # The value of negate_preds tells us if we are finding the max or min. If True,
        # we are finding the min and need to flip the sign of each individual
        # prediction, as well as the final combined prediction (this is the value n
        # described in the class docstring and associated implementation math section)
        negative_multiplier = -1 if negate_preds else 1

        # Combine all torch tensors so we don't need to keep doing it at each step
        all_preds = torch.stack(pred_list).flatten()

        # adj_preds stores the adjusted per-pose prediction values: negated (if we are
        # finding the min) and multiplied by the scale value. These correspond to the
        # values inside the exponential in eqn (5) (and subsequent equations)
        adj_preds = negative_multiplier * pred_scale * all_preds.detach()

        # Although defining this intermediate value isn't necessary in the forward
        # pass, we do so anyway for consistency with the backward pass, where it is
        # needed for numerical stability. This corresponds to eqn (6)
        Q = torch.logsumexp(adj_preds, dim=0)

        # Undo the adjustments applied to the per-pose predictions, giving us
        # (approximately) the original value of the max/min per-pose prediction
        final_pred = (negative_multiplier * Q / pred_scale).detach()

        return final_pred, all_preds

    @staticmethod
    def setup_context(ctx, inputs, output):
        """
        Store data for backward pass.

        Parameters
        ----------
        ctx
            PyTorch autograd context object
        inputs : List
            List containing all the parameters that will get passed to ``forward``
        output : Tuple[torch.Tensor, torch.Tensor]
            Values returned from ``forward``
        """
        # Unpack the inputs
        (
            negate_preds,
            pred_scale,
            pred_list,
            grad_dict,
            param_names,
            *model_params,
        ) = inputs

        # Break the grad dict up into lists of keys and corresponding lists of gradients
        grad_dict_keys, grad_dict_tensors = Combination.split_grad_dict(grad_dict)

        # Save non-Tensors for backward
        ctx.negate_preds = negate_preds
        ctx.pred_scale = pred_scale
        ctx.grad_dict_keys = grad_dict_keys
        ctx.param_names = param_names

        # Save Tensors for backward
        # Saving:
        #  * Predictions (1 tensor of shape (n_predictions,))
        #  * Grad tensors (N params * M poses tensors, where all gradients corresponding
        #    to a given model parameter are adjacent, ie first M tensors are the
        #    per-pose gradients for the first model parameter, etc)
        #  * Model param tensors (N params tensors; not currently read by backward)
        ctx.save_for_backward(
            torch.stack(pred_list).flatten(),
            *grad_dict_tensors,
            *model_params,
        )

    @staticmethod
    def backward(ctx, comb_grad, pose_grads):
        r"""
        Compute and return gradients for each parameter.

        Parameters
        ----------
        ctx
            PyTorch autograd context object
        comb_grad : torch.Tensor
            Scalar-value tensor giving the
            :math:`\frac{\partial L}{\partial \Delta \text{G}}` term from
            :eq:`comb-grad`
        pose_grads : torch.Tensor
            Tensor of shape ``(n_predictions,)``, giving the
            :math:`\frac{\partial L}{\partial \Delta \text{G}_i}` terms from
            :eq:`pose-grad`

        Returns
        -------
        tuple
            ``None`` for each of ``negate_preds``, ``pred_scale``, ``pred_list``,
            ``grad_dict`` and ``param_names``, followed by the total gradient for each
            model parameter, in ``ctx.param_names`` order

        Raises
        ------
        RuntimeError
            If the number of per-pose loss gradients does not match the number of
            per-pose parameter gradients for some parameter
        """
        # Unpack saved tensors
        preds, *other_tensors = ctx.saved_tensors

        # First section of these tensors are the flattened lists of gradients from each
        # individual pose for each model parameter
        grad_dict_tensors = other_tensors[: len(ctx.grad_dict_keys)]

        # Reconstruct dict mapping from model parameter name to list of gradient tensors
        # The ith entry in each list gives the gradient of the ith pose prediction wrt
        # that model parameter
        grad_dict = Combination.join_grad_dict(ctx.grad_dict_keys, grad_dict_tensors)

        # Set negation multiplier for finding max/min (see docstring and associated
        # implementation math section for more details)
        negative_multiplier = -1 if ctx.negate_preds else 1

        # Adjusted per-pose predictions, matching the values inside the exponential in
        # eqn (5) (and subsequent equations)
        adj_preds = negative_multiplier * ctx.pred_scale * preds.detach()

        # Calculate our normalizing constant (eqn (6))
        Q = torch.logsumexp(adj_preds, dim=0)

        # Per-pose weights exp(adj_pred_i - Q) used in eqn (9). They do not depend on
        # the model parameter, so compute them once instead of once per parameter
        pose_weights = (adj_preds - Q).exp()

        # Calculate final gradients for each parameter
        final_grads = {}
        for n, grad_list in grad_dict.items():
            # Make sure lengths match up before using them (should always be true but
            # just in case)
            if len(pose_grads) != len(grad_list):
                raise RuntimeError("Mismatch in gradient lengths.")

            # Compute the gradient contributions from any combined prediction loss,
            # according to eqns (1), (9)
            weighted_grads = torch.stack(
                [grad * weight for grad, weight in zip(grad_list, pose_weights)],
                dim=-1,
            )
            cur_final_grad = comb_grad * weighted_grads.detach().sum(dim=-1)

            # Compute the gradient contributions from any per-pose prediction loss,
            # according to eqn (2)
            for pose_grad, param_grad in zip(pose_grads, grad_list):
                cur_final_grad += pose_grad * param_grad

            # cur_final_grad is a fresh tensor, so it can be stored without copying
            final_grads[n] = cur_final_grad

        # Return gradients for each of the model parameters that were passed in. Also
        # need to return values for the other values that were passed to forward
        # (negate_preds, pred_scale, pred_list, grad_dict, param_names), but these
        # don't get gradients so we just return None
        return_vals = [None] * 5 + [final_grads[n] for n in ctx.param_names]
        return tuple(return_vals)