mtenn.combination.MaxCombination
- class mtenn.combination.MaxCombination(negate_preds=True, pred_scale=1000.0)
Bases: Combination
Approximate max/min of the predictions using the LogSumExp function for smoothness. See the docs for MaxCombinationFunc for more details.
Methods
__init__([negate_preds, pred_scale])
forward(pred_list, grad_dict, param_names, ...): This function signature should be the same for any Combination subclass implementation.
- __init__(negate_preds=True, pred_scale=1000.0)
- Parameters:
negate_preds (bool, default=True) – Negate the predictions before calculating the LSE, effectively finding the min. The predictions are negated again before being returned.
pred_scale (float, default=1000.0) – Fixed positive value to scale predictions by before taking the LSE. This tightens the bounds of the LSE approximation.
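The effect of the two parameters can be illustrated with a minimal plain-Python sketch. It is not the mtenn implementation: the helper name `smooth_extreme` is hypothetical, and dividing the scale back out after the LSE is an assumption made here so that the result stays in the units of the inputs. For n predictions and scale t, the standard bound is max(x) <= LSE(t * x) / t <= max(x) + ln(n) / t, so a larger scale gives a tighter approximation.

```python
import math


def smooth_extreme(preds, negate_preds=True, pred_scale=1000.0):
    """Hypothetical helper: smooth min (negate_preds=True) or max of preds."""
    sign = -1.0 if negate_preds else 1.0
    # Negate (optionally) and scale before taking the LSE.
    scaled = [sign * pred_scale * p for p in preds]
    # Numerically stable LogSumExp: factor out the largest term.
    m = max(scaled)
    lse = m + math.log(sum(math.exp(s - m) for s in scaled))
    # Assumption: undo the scaling and the negation before returning.
    return sign * lse / pred_scale


preds = [-9.2, -7.5, -8.1]  # illustrative per-pose values, not real data
smooth_min = smooth_extreme(preds)                       # close to min(preds)
smooth_max = smooth_extreme(preds, negate_preds=False)   # close to max(preds)
loose_min = smooth_extreme(preds, pred_scale=1.0)        # looser approximation
```

With a small `pred_scale` the result drifts away from the true extreme by up to ln(n) / pred_scale, which is why a large fixed scale is a reasonable default.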
- forward(pred_list, grad_dict, param_names, *model_params)
This function signature should be the same for any Combination subclass implementation. The return values should be:
- torch.Tensor: Scalar-value tensor giving the final combined prediction
- torch.Tensor: Tensor of shape (n_predictions,) giving the input per-pose predictions. This is necessary for PyTorch to track the gradients of these predictions in the case of, e.g., a cross-entropy loss on the per-pose predictions
- Parameters:
pred_list (List[torch.Tensor]) – List of \(\mathrm{\Delta G}\) predictions to be combined, shape of (n_predictions,)
grad_dict (dict[str, List[torch.Tensor]]) – Dict mapping from parameter name to list of gradients. Should contain n_model_parameters entries, with each entry mapping to a list of n_predictions tensors. Each of these tensors is a detached gradient, so the shape of each tensor will depend on the model parameter it corresponds to, but the shapes of all tensors in any given entry should be identical
param_names (List[str]) – List of parameter names. Should contain n_model_parameters entries, corresponding 1:1 with the keys in grad_dict
model_params (List[torch.Tensor]) – Actual parameters that we’ll return the gradients for. Each param should be passed directly for the backward pass to work right. These tensors should correspond 1:1 with and should be in the same order as the entries in param_names (i.e., the i-th entry in param_names should be the name of the i-th model parameter in model_params)
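The layout of these arguments can be sketched with plain PyTorch. This is only an illustration of the data structures described above, not how mtenn builds them internally: the `torch.nn.Linear` model and the random "poses" are hypothetical stand-ins, and detaching the predictions is an assumption.

```python
import torch

model = torch.nn.Linear(3, 1)               # hypothetical per-pose model
poses = [torch.randn(3) for _ in range(4)]  # hypothetical featurized poses

# Names and parameters in the same order, so they correspond 1:1.
param_names = [name for name, _ in model.named_parameters()]
model_params = [param for _, param in model.named_parameters()]

pred_list = []
grad_dict = {name: [] for name in param_names}
for pose in poses:
    pred = model(pose).squeeze()  # scalar prediction for this pose
    # Gradient of this prediction with respect to every model parameter.
    grads = torch.autograd.grad(pred, model_params)
    for name, grad in zip(param_names, grads):
        grad_dict[name].append(grad.detach())
    pred_list.append(pred.detach())  # assumption: predictions are detached

# A Combination subclass would then receive these as
# forward(pred_list, grad_dict, param_names, *model_params).
```

Each entry of `grad_dict` holds one gradient per pose, and every gradient in an entry has the shape of the parameter it belongs to.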