mtenn.combination.MaxCombinationFunc

class mtenn.combination.MaxCombinationFunc(*args, **kwargs)

Bases: Function

Custom autograd function that handles the gradient math for taking the max/min of the \(\mathrm{\Delta G}\) predictions, using a log-sum-exp (LSE) approximation.

For the forward pass, the final \(\mathrm{\Delta G}\) prediction is calculated according to the following:

\[\begin{split}n = \begin{cases} -1 & \text{negate_preds} \\ \phantom{-}1 & \text{not negate_preds} \end{cases}\end{split}\]
\[ \begin{align}\begin{aligned}t &= \text{pred_scale}\\\Delta G &= n \frac{1}{t} \ln \sum_{i=1}^N \exp (n t \Delta G_i)\end{aligned}\end{align} \]

The logic and math behind this scaling approach are detailed here.

See Max Combination for more details on the math.
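As a rough illustration of the forward formula above, the following is a minimal sketch in plain PyTorch. The helper name `lse_combine` is hypothetical and is not part of mtenn; it only reproduces the scaled LSE expression, without any of the custom gradient handling.

```python
import torch


def lse_combine(pred_list, pred_scale=1.0, negate_preds=False):
    """Scaled log-sum-exp over per-pose predictions (illustrative helper only)."""
    # n = -1 turns the smooth max into a smooth min
    n = -1.0 if negate_preds else 1.0
    # Stack single-element predictions into a tensor of shape (n_predictions,)
    preds = torch.stack([p.reshape(()) for p in pred_list])
    # Delta G = n * (1 / t) * ln(sum_i exp(n * t * Delta G_i))
    return n / pred_scale * torch.logsumexp(n * pred_scale * preds, dim=0)


toy_preds = [torch.tensor(-1.0), torch.tensor(-2.0), torch.tensor(-3.0)]
# With a large scale the result is close to the true max (-1.0) ...
print(lse_combine(toy_preds, pred_scale=1000.0))
# ... and with negate_preds=True it is close to the true min (-3.0)
print(lse_combine(toy_preds, pred_scale=1000.0, negate_preds=True))
```

With a small `pred_scale` the LSE overestimates the max (and underestimates the min), which is why a larger scale tightens the approximation.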

Methods

backward(ctx, comb_grad, pose_grads)

Compute and return gradients for each parameter.

forward(negate_preds, pred_scale, pred_list, ...)

Find the max/min of all input \(\mathrm{\Delta G}\) predictions.

setup_context(ctx, inputs, output)

Store data for backward pass.

static forward(negate_preds, pred_scale, pred_list, grad_dict, param_names, *model_params)

Find the max/min of all input \(\mathrm{\Delta G}\) predictions.

Parameters:
  • negate_preds (bool) – Negate the predictions before calculating the LSE, effectively finding the min. Predictions are negated again before being returned

  • pred_scale (float) – Fixed positive value to scale predictions by before taking the LSE. This tightens the bounds of the LSE approximation

  • pred_list (List[torch.Tensor]) – List of \(\mathrm{\Delta G}\) predictions to be combined, shape of (n_predictions,)

  • grad_dict (dict[str, List[torch.Tensor]]) – Dict mapping from parameter name to list of gradients. Should contain n_model_parameters entries, with each entry mapping to a list of n_predictions tensors. Each tensor is a detached gradient, so its shape depends on the model parameter it corresponds to, but all tensors within a given entry should have the same shape

  • param_names (List[str]) – List of parameter names. Should contain n_model_parameters entries, corresponding 1:1 with the keys in grad_dict

  • model_params (List[torch.Tensor]) – Actual parameters for which gradients will be returned. Each parameter should be passed directly for the backward pass to work correctly. These tensors should correspond 1:1 with, and be in the same order as, the entries in param_names (i.e., the i-th entry in param_names should be the name of the i-th model parameter in model_params)

Returns:

  • torch.Tensor – Scalar-value tensor giving the max/min of the input \(\mathrm{\Delta G}\) predictions

  • torch.Tensor – Tensor of shape (n_predictions,) giving the input per-pose predictions
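To make the expected layout of `pred_list`, `grad_dict`, `param_names`, and `model_params` more concrete, here is a sketch of one way these inputs might be assembled in plain PyTorch. The `torch.nn.Linear` model and the random "poses" are hypothetical stand-ins, and detaching the predictions is an assumption; the code mtenn actually uses to build these inputs may differ.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins: a tiny model and three toy "poses"
model = torch.nn.Linear(4, 1)
poses = [torch.randn(4) for _ in range(3)]

pred_list = []
grad_dict = {}
for pose in poses:
    # Scalar prediction for this pose
    pred = model(pose).squeeze()
    # Gradient of this single prediction w.r.t. every model parameter
    grads = torch.autograd.grad(pred, list(model.parameters()))
    for (name, _), grad in zip(model.named_parameters(), grads):
        # One list per parameter name, one detached gradient per prediction
        grad_dict.setdefault(name, []).append(grad.detach())
    # Assumption: predictions are stored detached from the graph
    pred_list.append(pred.detach())

# Names and parameters in the same order, matching the keys of grad_dict
param_names = [name for name, _ in model.named_parameters()]
model_params = [param for _, param in model.named_parameters()]

for name in param_names:
    print(name, len(grad_dict[name]), tuple(grad_dict[name][0].shape))
```

Each entry of `grad_dict` ends up holding `n_predictions` tensors, all shaped like the parameter they belong to, which is the layout described in the parameter list above.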

static setup_context(ctx, inputs, output)

Store data for backward pass.

Parameters:
  • ctx – PyTorch context object

  • inputs (List) – List containing all the parameters that will get passed to forward

  • output (torch.Tensor) – Values returned from forward

static backward(ctx, comb_grad, pose_grads)

Compute and return gradients for each parameter.

Parameters:
  • ctx – PyTorch context object

  • comb_grad (torch.Tensor) – Scalar-value tensor giving the \(\frac{\partial L}{\partial \Delta \text{G}}\) term from (1)

  • pose_grads (torch.Tensor) – Tensor of shape (n_predictions,), giving the \(\frac{\partial L}{\partial \Delta \text{G}_i}\) terms from (2)
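The `forward` / `setup_context` / `backward` split described above follows the standard `torch.autograd.Function` pattern (PyTorch 2.0 or later). The sketch below is a deliberately simplified, hypothetical `ToyLSE` class, not the mtenn implementation: it takes the predictions as a single tensor, returns only the combined value (so `backward` receives one incoming gradient rather than both `comb_grad` and `pose_grads`), and does not propagate gradients to model parameters. It relies on the fact that, for the scaled LSE, the derivative of the combined value with respect to each prediction is the softmax weight of that prediction, since the two factors of n cancel.

```python
import torch


class ToyLSE(torch.autograd.Function):
    """Simplified, hypothetical illustration of the three-method pattern."""

    @staticmethod
    def forward(negate_preds, pred_scale, preds):
        # Scaled LSE over a tensor of shape (n_predictions,)
        n = -1.0 if negate_preds else 1.0
        return n / pred_scale * torch.logsumexp(n * pred_scale * preds, dim=0)

    @staticmethod
    def setup_context(ctx, inputs, output):
        # Store what backward needs: the sign, the scale, and the predictions
        negate_preds, pred_scale, preds = inputs
        ctx.n = -1.0 if negate_preds else 1.0
        ctx.pred_scale = pred_scale
        ctx.save_for_backward(preds)

    @staticmethod
    def backward(ctx, comb_grad):
        (preds,) = ctx.saved_tensors
        # d(combined) / d(pred_i) = softmax(n * t * preds)_i
        weights = torch.softmax(ctx.n * ctx.pred_scale * preds, dim=0)
        # One return value per forward input; non-tensor inputs get None
        return None, None, comb_grad * weights


toy_preds = torch.tensor([-1.0, -2.0, -3.0], dtype=torch.float64, requires_grad=True)
combined = ToyLSE.apply(False, 2.0, toy_preds)
combined.backward()
# The gradient is a set of softmax weights, largest for the max prediction
print(toy_preds.grad)
```

In the full function, the same per-pose weights would be combined with the stored per-parameter gradients from `grad_dict` to produce one gradient per entry in `model_params`; that step is omitted here.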