mtenn.combination.MeanCombinationFunc

class mtenn.combination.MeanCombinationFunc(*args, **kwargs)

Bases: torch.autograd.Function

Custom autograd function that handles the gradient math for combining multiple \(\mathrm{\Delta G}\) predictions by taking their mean.

\[\Delta \text{G}(\theta) = \frac{1}{N} \sum_{i=1}^{N} \Delta \text{G}_i (\theta)\]

See Mean Combination for more details on the math.
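Because the mean is linear, its derivative with respect to the model parameters \(\theta\) is the mean of the per-prediction derivatives. Applying the chain rule for a scalar loss \(L\), under the simplifying assumption that \(L\) depends on \(\theta\) only through the combined \(\Delta \text{G}\), gives the following short derivation from the equation above:

\[\frac{\partial \Delta \text{G}}{\partial \theta} = \frac{1}{N} \sum_{i=1}^{N} \frac{\partial \Delta \text{G}_i}{\partial \theta} \qquad \Longrightarrow \qquad \frac{\partial L}{\partial \theta} = \frac{\partial L}{\partial \Delta \text{G}} \cdot \frac{1}{N} \sum_{i=1}^{N} \frac{\partial \Delta \text{G}_i}{\partial \theta}\]

This suggests why the per-prediction gradients \(\partial \Delta \text{G}_i / \partial \theta\) are needed as inputs to the combination step.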

Methods

  • backward(ctx, comb_grad, pose_grads) – Compute and return gradients for each parameter.

  • forward(pred_list, grad_dict, param_names, ...) – Take the mean of all input \(\mathrm{\Delta G}\) predictions.

  • setup_context(ctx, inputs, output) – Store data for the backward pass.

static forward(pred_list, grad_dict, param_names, *model_params)

Take the mean of all input \(\mathrm{\Delta G}\) predictions.

Parameters:
  • pred_list (List[torch.Tensor]) – List of n_predictions \(\mathrm{\Delta G}\) predictions to be combined

  • grad_dict (dict[str, List[torch.Tensor]]) – Dict mapping from parameter name to list of gradients. Should contain n_model_parameters entries, with each entry mapping to a list of n_predictions tensors. Each of these tensors is a detached gradient, so the shape of each tensor depends on the model parameter it corresponds to, but all tensors within a given entry should have the same shape

  • param_names (List[str]) – List of parameter names. Should contain n_model_parameters entries, corresponding 1:1 with the keys in grad_dict

  • model_params (List[torch.Tensor]) – Actual parameters that gradients will be returned for. Each parameter should be passed directly (not as a copy) for the backward pass to work correctly. These tensors should correspond 1:1 with, and be in the same order as, the entries in param_names (i.e., the i-th entry in param_names should be the name of the i-th model parameter in model_params)

Returns:

  • torch.Tensor – Scalar-value tensor giving the mean of the input \(\mathrm{\Delta G}\) predictions

  • torch.Tensor – Tensor of shape (n_predictions,) giving the input per-pose predictions
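A minimal sketch of how the forward inputs might be assembled, assuming a toy torch.nn.Linear model as a hypothetical stand-in for a per-pose model and random feature vectors as hypothetical poses. It builds pred_list, grad_dict, param_names and model_params in the layout described above, then computes the two documented return values with plain tensor operations rather than calling the mtenn class itself.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in for a per-pose model: maps a 4-feature "pose" to a scalar dG.
model = torch.nn.Linear(4, 1)
poses = [torch.randn(4) for _ in range(3)]  # n_predictions = 3 hypothetical poses

# param_names and model_params must line up 1:1 and in the same order.
param_names = [name for name, _ in model.named_parameters()]
model_params = [p for _, p in model.named_parameters()]

pred_list = []
grad_dict = {name: [] for name in param_names}
for pose in poses:
    pred = model(pose).squeeze()  # scalar dG_i prediction
    # Per-pose gradients d(dG_i)/d(theta), one tensor per model parameter.
    grads = torch.autograd.grad(pred, model_params)
    for name, g in zip(param_names, grads):
        grad_dict[name].append(g.detach())  # detached, same shape as its parameter
    pred_list.append(pred.detach())

# The two values forward is documented to return.
all_preds = torch.stack(pred_list).flatten()  # per-pose predictions, shape (n_predictions,)
mean_pred = all_preds.mean()                  # scalar mean prediction

print(all_preds.shape, mean_pred.shape)  # torch.Size([3]) torch.Size([])
```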

static setup_context(ctx, inputs, output)

Store data for the backward pass.

Parameters:
  • ctx – PyTorch context object

  • inputs (List) – All the arguments that were passed to forward

  • output (torch.Tensor) – Value returned from forward (the mean prediction and the per-pose predictions)

static backward(ctx, comb_grad, pose_grads)

Compute and return gradients for each parameter.

Parameters:
  • ctx – PyTorch context object

  • comb_grad (torch.Tensor) – Scalar tensor giving the \(\frac{\partial L}{\partial \Delta \text{G}}\) term from equation (1) in Mean Combination

  • pose_grads (torch.Tensor) – Tensor of shape (n_predictions,) giving the \(\frac{\partial L}{\partial \Delta \text{G}_i}\) terms from equation (2) in Mean Combination
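Taken together, the three static methods could look roughly like the following minimal sketch. It is not mtenn's implementation: the class name MeanCombinationSketch and the toy torch.nn.Linear "pose" model are hypothetical, and backward here assumes the loss depends on the parameters only through the combined prediction, so it scales the mean of the stored per-pose gradients by comb_grad and ignores pose_grads. It needs PyTorch 2.0 or later, where forward can omit ctx and a separate setup_context is supported. The final check compares the result against ordinary autograd through the mean.

```python
import torch


class MeanCombinationSketch(torch.autograd.Function):
    """Hypothetical re-creation of the documented interface, for illustration only."""

    @staticmethod
    def forward(pred_list, grad_dict, param_names, *model_params):
        all_preds = torch.stack(pred_list).flatten()  # (n_predictions,)
        return all_preds.mean(), all_preds

    @staticmethod
    def setup_context(ctx, inputs, output):
        _, grad_dict, param_names, *_ = inputs
        # save_for_backward only takes tensors, so the dict and list go on ctx as attributes.
        ctx.grad_dict = grad_dict
        ctx.param_names = param_names

    @staticmethod
    def backward(ctx, comb_grad, pose_grads):
        # Chain rule on the mean: dL/dtheta = dL/dG * (1/N) * sum_i dG_i/dtheta.
        # Assumption of this sketch: pose_grads is ignored.
        param_grads = [
            comb_grad * torch.stack(ctx.grad_dict[n]).mean(dim=0)
            for n in ctx.param_names
        ]
        # No gradients for pred_list, grad_dict, or param_names.
        return (None, None, None, *param_grads)


torch.manual_seed(0)
model = torch.nn.Linear(4, 1)  # hypothetical per-pose model
poses = [torch.randn(4) for _ in range(3)]
param_names = [n for n, _ in model.named_parameters()]
model_params = [p for _, p in model.named_parameters()]

pred_list, grad_dict = [], {n: [] for n in param_names}
for pose in poses:
    pred = model(pose).squeeze()
    for n, g in zip(param_names, torch.autograd.grad(pred, model_params)):
        grad_dict[n].append(g.detach())
    pred_list.append(pred.detach())

mean_pred, all_preds = MeanCombinationSketch.apply(
    pred_list, grad_dict, param_names, *model_params
)
mean_pred.backward()
custom_grads = [p.grad.clone() for p in model_params]

# Reference: plain autograd through the mean of fresh per-pose predictions.
for p in model_params:
    p.grad = None
torch.stack([model(x).squeeze() for x in poses]).mean().backward()
ref_grads = [p.grad.clone() for p in model_params]

print(all(torch.allclose(c, r, atol=1e-5) for c, r in zip(custom_grads, ref_grads)))  # True
```

Passing each parameter directly to apply is what lets autograd route the returned gradients back onto the model's own .grad fields.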