mtenn.combination.MeanCombination
class mtenn.combination.MeanCombination(*args, **kwargs)
Bases: Combination

Combine a list of predictions by taking the mean. See the docs for MeanCombinationFunc for more details.

Methods
forward(pred_list, grad_dict, param_names, ...) – This function signature should be the same for any Combination subclass implementation.

forward(pred_list, grad_dict, param_names, *model_params)
This function signature should be the same for any Combination subclass implementation. The return values should be:
- torch.Tensor: Scalar-valued tensor giving the final combined prediction
- torch.Tensor: Tensor of shape (n_predictions,) giving the input per-pose predictions. This is necessary for PyTorch to track the gradients of these predictions in the case of, e.g., a cross-entropy loss on the per-pose predictions
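To illustrate the shared signature and the two return values, here is a minimal framework-free sketch. SketchCombination and SketchMeanCombination are hypothetical stand-ins, not mtenn's classes: plain floats and lists replace torch.Tensor objects, and the sketch ignores grad_dict, param_names, and model_params, which in the real class presumably feed the backward pass through MeanCombinationFunc.

```python
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple


class SketchCombination(ABC):
    """Hypothetical stand-in for the Combination base class (not mtenn's code).

    Plain floats and lists replace torch.Tensor objects.
    """

    @abstractmethod
    def forward(
        self,
        pred_list: List[float],
        grad_dict: Dict[str, List[List[float]]],
        param_names: List[str],
        *model_params: List[float],
    ) -> Tuple[float, List[float]]:
        """Return (combined prediction, per-pose predictions)."""


class SketchMeanCombination(SketchCombination):
    def forward(self, pred_list, grad_dict, param_names, *model_params):
        if not pred_list:
            raise ValueError("pred_list must contain at least one prediction")
        # First return value: scalar mean of the per-pose predictions
        combined = sum(pred_list) / len(pred_list)
        # Second return value: the per-pose predictions, shape (n_predictions,)
        return combined, list(pred_list)


comb = SketchMeanCombination()
combined, all_preds = comb.forward([-7.0, -8.0, -9.0], {}, [])
print(combined, all_preds)  # -8.0 [-7.0, -8.0, -9.0]
```

Declaring forward as an abstract method means any subclass that forgets to implement it cannot be instantiated, which is one way to enforce the "same signature for every subclass" convention described above.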
- Parameters:
  - pred_list (List[torch.Tensor]) – List of \(\mathrm{\Delta G}\) predictions to be combined, shape (n_predictions,)
  - grad_dict (dict[str, List[torch.Tensor]]) – Dict mapping from parameter name to list of gradients. Should contain n_model_parameters entries, with each entry mapping to a list of n_predictions tensors. Each of these tensors is a detached gradient, so the shape of each tensor depends on the model parameter it corresponds to, but all tensors within a given entry should have identical shapes
  - param_names (List[str]) – List of parameter names. Should contain n_model_parameters entries, corresponding 1:1 with the keys in grad_dict
  - model_params (List[torch.Tensor]) – Actual parameters that we'll return the gradients for. Each param should be passed directly for the backward pass to work correctly. These tensors should correspond 1:1 with, and be in the same order as, the entries in param_names (i.e., the ith entry in param_names should be the name of the ith model parameter in model_params)
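Because the mean is linear, the gradient of the combined prediction with respect to any model parameter \(\theta\) is the mean of the per-pose gradients: \(\partial \Delta G_{\mathrm{combined}} / \partial \theta = \frac{1}{N} \sum_{i=1}^{N} \partial \Delta G_i / \partial \theta\), with \(N\) = n_predictions. The sketch below shows how grad_dict and param_names fit together under that relationship. It is a framework-free illustration, not MeanCombinationFunc itself: the function name mean_combine_gradients is hypothetical, and each detached gradient tensor is replaced by a flattened list of floats.

```python
from typing import Dict, List


def mean_combine_gradients(
    grad_dict: Dict[str, List[List[float]]],
    param_names: List[str],
) -> List[List[float]]:
    """Hypothetical sketch: average per-pose gradients for each model parameter.

    Each gradient is a flattened list of floats standing in for a detached
    torch.Tensor. Results are returned in param_names order, matching the
    order in which model_params would be passed.
    """
    # param_names must correspond 1:1 with the keys of grad_dict (no duplicates)
    if set(param_names) != set(grad_dict) or len(param_names) != len(grad_dict):
        raise ValueError("param_names must match the keys of grad_dict 1:1")
    combined = []
    for name in param_names:
        grads = grad_dict[name]  # n_predictions gradients for this parameter
        n_pred = len(grads)
        if n_pred == 0:
            raise ValueError(f"no gradients given for parameter {name!r}")
        # All gradients within one entry must share a shape
        if any(len(g) != len(grads[0]) for g in grads):
            raise ValueError(f"gradient shapes differ for parameter {name!r}")
        # Elementwise mean over predictions: d(mean)/d(theta) = mean of dG_i/d(theta)
        combined.append([sum(vals) / n_pred for vals in zip(*grads)])
    return combined


grad_dict = {"w": [[1.0, 2.0], [3.0, 4.0]], "b": [[0.5], [1.5]]}
param_names = ["w", "b"]
print(mean_combine_gradients(grad_dict, param_names))  # [[2.0, 3.0], [1.0]]
```

The checks mirror the constraints listed for the parameters: one grad_dict entry per name in param_names, and identical shapes within each entry.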