mtenn.combination.MeanCombinationFunc
- class mtenn.combination.MeanCombinationFunc(*args, **kwargs)
Bases: torch.autograd.Function
Custom autograd function that handles the gradient math for combining \(\mathrm{\Delta G}\) predictions into their mean.
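\[\Delta \text{G}(\theta) = \frac{1}{N} \sum_{i=1}^{N} \Delta \text{G}_i (\theta)\]
where \(N\) is the number of predictions being combined and \(\theta\) denotes the model parameters. See Mean Combination for more details on the math.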
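Because the mean is linear in the per-pose predictions, its gradient with respect to \(\theta\) is the mean of the per-pose gradients. By the chain rule, the part of a loss \(L\)'s gradient that flows back through the combined prediction is that mean scaled by the upstream gradient \(\partial L / \partial \Delta \text{G}\). Following PyTorch's convention of one incoming gradient per forward output, this upstream gradient should correspond to the comb_grad argument of backward. Any contribution arriving through pose_grads is not covered here:

\[\frac{\partial \Delta \text{G}(\theta)}{\partial \theta} = \frac{1}{N} \sum_{i=1}^{N} \frac{\partial \Delta \text{G}_i(\theta)}{\partial \theta}, \qquad \frac{\partial L}{\partial \theta}\bigg|_{\text{via } \Delta \text{G}} = \frac{\partial L}{\partial \Delta \text{G}} \cdot \frac{1}{N} \sum_{i=1}^{N} \frac{\partial \Delta \text{G}_i(\theta)}{\partial \theta}\]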
Methods
backward(ctx, comb_grad, pose_grads) – Compute and return gradients for each parameter.
forward(pred_list, grad_dict, param_names, ...) – Take the mean of all input \(\mathrm{\Delta G}\) predictions.
setup_context(ctx, inputs, output) – Store data for backward pass.
- static forward(pred_list, grad_dict, param_names, *model_params)
Take the mean of all input \(\mathrm{\Delta G}\) predictions.
- Parameters:
pred_list (List[torch.Tensor]) – List of \(\mathrm{\Delta G}\) predictions to be combined, with shape (n_predictions,)
grad_dict (dict[str, List[torch.Tensor]]) – Dict mapping from parameter name to list of gradients. Should contain n_model_parameters entries, with each entry mapping to a list of n_predictions tensors. Each of these tensors is a detached gradient, so its shape depends on the model parameter it corresponds to, but all tensors within a given entry should have identical shapes.
param_names (List[str]) – List of parameter names. Should contain n_model_parameters entries, corresponding 1:1 with the keys in grad_dict.
model_params (List[torch.Tensor]) – Actual parameters that we’ll return the gradients for. Each parameter should be passed directly (not detached or copied) for the backward pass to work correctly. These tensors should correspond 1:1 with, and be in the same order as, the entries in param_names (i.e. the i-th entry in param_names should be the name of the i-th model parameter in model_params).
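The bookkeeping that grad_dict, param_names, and model_params must satisfy, together with the mean-of-gradients rule, can be sketched in plain Python. The helper mean_param_grads is hypothetical and not part of mtenn. It assumes each gradient tensor has been flattened to a list of floats and that the upstream gradient comb_grad is a scalar. It only covers the gradient flowing through the combined prediction, so treat it as an illustration of the math rather than a reproduction of the real backward.

```python
from typing import Dict, List

# Stand-in for a tensor: a flat list of floats (assumption; the real
# function works on torch.Tensor objects of arbitrary shape).
FlatTensor = List[float]


def mean_param_grads(
    grad_dict: Dict[str, List[FlatTensor]],
    param_names: List[str],
    comb_grad: float = 1.0,
) -> List[FlatTensor]:
    """Hypothetical helper: per-parameter gradient of the mean combination.

    Returns one gradient per name in param_names, in that order, so the
    i-th result lines up with the i-th entry of model_params.
    """
    # param_names must match the keys of grad_dict 1:1
    if sorted(param_names) != sorted(grad_dict):
        raise ValueError("param_names must match the keys of grad_dict 1:1")

    out: List[FlatTensor] = []
    for name in param_names:
        per_pred = grad_dict[name]  # n_predictions gradients for this parameter
        if not per_pred:
            raise ValueError(f"no gradients given for {name!r}")
        n_pred = len(per_pred)
        size = len(per_pred[0])
        # All gradients within one entry must share the same shape
        if any(len(g) != size for g in per_pred):
            raise ValueError(f"gradients for {name!r} must all have the same shape")
        # d(mean)/d(theta) = (1/N) * sum_i dG_i/d(theta), scaled by the upstream grad
        out.append(
            [comb_grad * sum(g[k] for g in per_pred) / n_pred for k in range(size)]
        )
    return out


# Two predictions; "w" has 2 elements, "b" has 1 (illustrative values only)
grad_dict = {
    "w": [[1.0, 2.0], [3.0, 4.0]],
    "b": [[0.5], [1.5]],
}
grads = mean_param_grads(grad_dict, ["w", "b"])
print(grads)  # [[2.0, 3.0], [1.0]]
```

Returning the gradients as a list ordered by param_names, rather than as a dict, mirrors the requirement that the i-th name matches the i-th entry of model_params, since an autograd backward returns gradients positionally.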
- Returns:
torch.Tensor – Scalar-value tensor giving the mean of the input \(\mathrm{\Delta G}\) predictions
torch.Tensor – Tensor of shape (n_predictions,) giving the input per-pose predictions
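The two return values, and the claim that the gradient of the combined prediction equals the mean of the per-pose gradients, can be checked numerically on a toy problem. Everything below is hypothetical: a single scalar parameter theta, and made-up per-pose models of the form a_i * theta**2 + b_i * theta. The mean_combination function only mirrors forward's two outputs and does not call mtenn.

```python
from typing import List, Tuple

# Hypothetical toy model: pose i predicts dG_i(theta) = a_i * theta**2 + b_i * theta
A = [1.0, -0.5, 2.0]
B = [0.3, 1.2, -0.7]


def pose_preds(theta: float) -> List[float]:
    """Per-pose predictions dG_i(theta) for the toy model."""
    return [a * theta**2 + b * theta for a, b in zip(A, B)]


def mean_combination(preds: List[float]) -> Tuple[float, List[float]]:
    """Mirror of forward's two outputs: the mean, plus the per-pose values."""
    return sum(preds) / len(preds), list(preds)


def analytic_grad(theta: float) -> float:
    # Mean of the per-pose gradients d(dG_i)/d(theta) = 2 * a_i * theta + b_i
    per_pose = [2 * a * theta + b for a, b in zip(A, B)]
    return sum(per_pose) / len(per_pose)


def finite_diff_grad(theta: float, h: float = 1e-5) -> float:
    # Central finite difference of the combined (mean) prediction
    plus, _ = mean_combination(pose_preds(theta + h))
    minus, _ = mean_combination(pose_preds(theta - h))
    return (plus - minus) / (2 * h)


theta = 0.8
combined, per_pose = mean_combination(pose_preds(theta))
print(combined, per_pose)
print(analytic_grad(theta), finite_diff_grad(theta))
```

For these quadratic toy models the central difference is exact up to floating-point rounding, so both printed gradients should agree closely (about 1.6 here).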
- static setup_context(ctx, inputs, output)
Store data for backward pass.
- Parameters:
ctx – PyTorch autograd context object
inputs (List) – List containing all the arguments that were passed to forward
output (torch.Tensor) – Value returned from forward
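PyTorch passes setup_context an inputs collection holding every argument given to forward. Because forward takes *model_params, the parameters arrive flattened after pred_list, grad_dict, and param_names. The sketch below unpacks such a tuple with a hypothetical helper (unpack_inputs, not part of mtenn), using placeholder floats and strings in place of tensors. It also checks the 1:1 correspondence between param_names and model_params. It does not show what mtenn actually stores on ctx.

```python
from typing import Any, List, Tuple


def unpack_inputs(inputs: Tuple[Any, ...]) -> Tuple[Any, Any, List[str], List[Any]]:
    """Hypothetical helper: split setup_context's inputs the way
    forward(pred_list, grad_dict, param_names, *model_params) received them."""
    pred_list, grad_dict, param_names, *model_params = inputs
    # One model parameter is expected per name in param_names
    if len(model_params) != len(param_names):
        raise ValueError("expected one model parameter per name in param_names")
    return pred_list, grad_dict, param_names, model_params


# Plain-Python stand-ins for tensors (illustrative values only)
inputs = (
    [-7.1, -6.8],                                # pred_list
    {"w": [[1.0], [2.0]], "b": [[0.1], [0.2]]},  # grad_dict
    ["w", "b"],                                  # param_names
    "w_param",                                   # model_params[0]
    "b_param",                                   # model_params[1]
)
pred_list, grad_dict, param_names, model_params = unpack_inputs(inputs)
print(dict(zip(param_names, model_params)))  # {'w': 'w_param', 'b': 'b_param'}
```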