mtenn.combination.MaxCombinationFunc
- class mtenn.combination.MaxCombinationFunc(*args, **kwargs)[source]
Bases: Function
Custom autograd function that will handle the gradient math for us for taking the max/min of the \(\mathrm{\Delta G}\) predictions.
For the forward pass, the final \(\mathrm{\Delta G}\) prediction is calculated according to the following:
\[n = \begin{cases} -1 & \text{negate_preds} \\ \phantom{-}1 & \text{not negate_preds} \end{cases}\]
\[\begin{aligned}t &= \text{pred_scale}\\ \Delta G &= n \frac{1}{t} \ln \sum_{i=1}^N \exp (n t \Delta G_i)\end{aligned}\]
Here \(n\) is the sign factor defined above, \(N\) is the number of input predictions, and \(\Delta G_i\) is the \(i\)-th input prediction. The logic and math behind this scaling approach are detailed here.
See Max Combination for more details on the math.
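The forward-pass formula can be tried out on plain floats. The following is a minimal sketch using only the Python standard library, not the mtenn implementation; the function name `scaled_lse` and the example values are made up for illustration. It shows that the scaled log-sum-exp is a smooth upper bound on the max (or, with `negate_preds`, a smooth lower bound on the min), and that the gap is at most \(\ln(N)/t\).

```python
import math


def scaled_lse(preds, pred_scale=1.0, negate_preds=False):
    """Smooth max (or smooth min if negate_preds) of a list of floats.

    Hypothetical helper mirroring the documented formula:
    n * (1 / t) * ln(sum_i exp(n * t * preds[i])).
    """
    sign = -1.0 if negate_preds else 1.0
    scaled = [sign * pred_scale * p for p in preds]
    # Subtract the largest scaled value before exponentiating so exp() cannot overflow
    shift = max(scaled)
    lse = shift + math.log(sum(math.exp(s - shift) for s in scaled))
    return sign * lse / pred_scale


# Made-up per-pose predictions
preds = [-9.0, -7.5, -8.2]
soft_max = scaled_lse(preds, pred_scale=10.0)
soft_min = scaled_lse(preds, pred_scale=10.0, negate_preds=True)
print(soft_max, soft_min)
```

Increasing `pred_scale` shrinks the \(\ln(N)/t\) gap, which is the sense in which scaling tightens the bounds of the approximation.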
Methods
backward(ctx, comb_grad, pose_grads) – Compute and return gradients for each parameter.
forward(negate_preds, pred_scale, pred_list, ...) – Find the max/min of all input \(\mathrm{\Delta G}\) predictions.
setup_context(ctx, inputs, output) – Store data for backward pass.
- static forward(negate_preds, pred_scale, pred_list, grad_dict, param_names, *model_params)[source]
Find the max/min of all input \(\mathrm{\Delta G}\) predictions.
- Parameters:
negate_preds (bool) – Negate the predictions before calculating the LSE, effectively finding the min. Preds are negated again before being returned.
pred_scale (float) – Fixed positive value to scale predictions by before taking the LSE. This tightens the bounds of the LSE approximation.
pred_list (List[torch.Tensor]) – List of \(\mathrm{\Delta G}\) predictions to be combined, shape of (n_predictions,).
grad_dict (dict[str, List[torch.Tensor]]) – Dict mapping from parameter name to list of gradients. Should contain n_model_parameters entries, with each entry mapping to a list of n_predictions tensors. Each of these tensors is a detached gradient, so the shape of each tensor will depend on the model parameter it corresponds to, but the shapes of each tensor in any given entry should be identical.
param_names (List[str]) – List of parameter names. Should contain n_model_parameters entries, corresponding 1:1 with the keys in grad_dict.
model_params (List[torch.Tensor]) – Actual parameters that we'll return the gradients for. Each param should be passed directly for the backward pass to work right. These tensors should correspond 1:1 with and should be in the same order as the entries in param_names (i.e., the i-th entry in param_names should be the name of the i-th model parameter in model_params).
- Returns:
torch.Tensor – Scalar-value tensor giving the max/min of the input \(\mathrm{\Delta G}\) predictions
torch.Tensor – Tensor of shape (n_predictions,) giving the input per-pose predictions
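The relationship between `pred_list`, `grad_dict`, `param_names`, and `model_params` described above can be illustrated with a small sketch. It uses a hypothetical toy model (a single `torch.nn.Linear` layer) and made-up pose features, and it only builds the inputs in the documented layout; it does not call `MaxCombinationFunc`, and how mtenn itself assembles these inputs is not shown in this page.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy "model": one linear layer scoring each pose (not an mtenn model)
model = torch.nn.Linear(4, 1)
poses = [torch.randn(4) for _ in range(3)]  # 3 made-up pose feature vectors

# Names and parameters in the same order, so entry i of one matches entry i of the other
param_names = [name for name, _ in model.named_parameters()]
model_params = [param for _, param in model.named_parameters()]

pred_list = []
grad_dict = {name: [] for name in param_names}
for pose in poses:
    pred = model(pose).squeeze()  # scalar prediction for this pose
    # Gradient of this pose's prediction with respect to every model parameter
    grads = torch.autograd.grad(pred, model_params)
    pred_list.append(pred.detach())
    for name, grad in zip(param_names, grads):
        grad_dict[name].append(grad.detach())

# grad_dict has n_model_parameters entries, each a list of n_predictions tensors
for name, param in zip(param_names, model_params):
    print(name, len(grad_dict[name]), tuple(param.shape))
```

Every tensor in `grad_dict[name]` has the same shape as the parameter called `name`, which is the "identical shapes within an entry" condition stated for `grad_dict`.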
- static setup_context(ctx, inputs, output)[source]
Store data for backward pass.
- Parameters:
ctx – PyTorch context object used to store data for the backward pass
inputs (List) – List containing all the parameters that will get passed to forward
output (torch.Tensor) – Values returned from forward
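The `forward` / `setup_context` / `backward` split used by this class is the general pattern for custom autograd functions in PyTorch 2.0 and later. The sketch below is a deliberately simplified toy, not the mtenn code: the class name `ScaledLSE` is hypothetical, it takes a single 1-D tensor of predictions rather than a list plus `grad_dict`, and it returns one output instead of two. It relies on the fact that the derivative of the scaled log-sum-exp with respect to each prediction is that prediction's softmax weight.

```python
import torch


class ScaledLSE(torch.autograd.Function):
    """Toy smooth max over a 1-D tensor (illustrative only)."""

    @staticmethod
    def forward(preds, pred_scale):
        # When setup_context is defined separately, forward does not receive ctx
        return torch.logsumexp(pred_scale * preds, dim=0) / pred_scale

    @staticmethod
    def setup_context(ctx, inputs, output):
        # inputs holds the arguments passed to forward, in order
        preds, pred_scale = inputs
        ctx.save_for_backward(preds)
        ctx.pred_scale = pred_scale

    @staticmethod
    def backward(ctx, grad_output):
        (preds,) = ctx.saved_tensors
        # d(LSE)/d(pred_i) is the softmax weight of prediction i
        weights = torch.softmax(ctx.pred_scale * preds, dim=0)
        # One return value per forward input; pred_scale is a plain float, so None
        return grad_output * weights, None


preds = torch.tensor([-9.0, -7.5, -8.2], requires_grad=True)
out = ScaledLSE.apply(preds, 10.0)
out.backward()
print(out.item(), preds.grad)
```

Because the softmax weights sum to one, the gradient of the combined value is a weighted average of the per-prediction gradients, dominated by the largest prediction as `pred_scale` grows.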