mtenn.combination.MeanCombination

class mtenn.combination.MeanCombination(*args, **kwargs)

Bases: Combination

Combine a list of predictions by taking the mean. See the docs for MeanCombinationFunc for more details.
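Because the mean is linear, the gradient of the combined prediction with respect to any model parameter \(\theta\) is the mean of the per-pose gradients. This follows from the definition of the mean alone. Whether MeanCombinationFunc evaluates exactly this expression in its backward pass is an assumption here, so its docs remain the authority. With \(N\) = n_predictions:

```latex
\mathrm{\Delta G}_{\mathrm{comb}} = \frac{1}{N} \sum_{i=1}^{N} \mathrm{\Delta G}_i,
\qquad
\frac{\partial \mathrm{\Delta G}_{\mathrm{comb}}}{\partial \theta}
  = \frac{1}{N} \sum_{i=1}^{N} \frac{\partial \mathrm{\Delta G}_i}{\partial \theta}
```

This is presumably why the detached per-pose gradients are supplied in grad_dict: the gradient of the combined prediction can be assembled from them directly.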

Methods

forward(pred_list, grad_dict, param_names, ...)

This function signature should be the same for any Combination subclass implementation.

forward(pred_list, grad_dict, param_names, *model_params)

This function signature should be the same for any Combination subclass implementation. The return values should be:

  • torch.Tensor: Scalar-valued tensor giving the final combined prediction

  • torch.Tensor: Tensor of shape (n_predictions,) giving the input per-pose predictions. Returning these is necessary for PyTorch to track the gradients of the per-pose predictions, e.g. when a cross-entropy loss is computed on them
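As a hedged illustration of why the per-pose predictions are returned alongside the combined value, the sketch below uses plain PyTorch tensors in place of forward()'s two return values. The \(\mathrm{\Delta G}\) values, the experimental target of -8.0, the choice of pose 1 as the correct pose, and the use of \(-\mathrm{\Delta G}\) as logits are all hypothetical assumptions, not mtenn behavior.

```python
import torch
import torch.nn.functional as F

# Hypothetical per-pose Delta G predictions for n_predictions = 4 poses,
# standing in for the second value returned by forward()
per_pose = torch.tensor([-7.1, -8.4, -6.9, -7.5], requires_grad=True)

# Stand-in for the first return value: the scalar combined (mean) prediction
combined = per_pose.mean()

# Assumption: lower Delta G means a more favorable pose, so -Delta G serves as
# the logits, and pose index 1 is treated as the (hypothetical) correct pose
pose_loss = F.cross_entropy(-per_pose.unsqueeze(0), torch.tensor([1]))

# Hypothetical experimental Delta G of -8.0 for the combined prediction
affinity_loss = (combined - (-8.0)) ** 2

# Gradients from both losses flow back into the per-pose predictions
(affinity_loss + pose_loss).backward()
print(per_pose.grad)
```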

Parameters:
  • pred_list (List[torch.Tensor]) – List of \(\mathrm{\Delta G}\) predictions to be combined, with shape (n_predictions,)

  • grad_dict (dict[str, List[torch.Tensor]]) – Dict mapping from parameter name to list of gradients. Should contain n_model_parameters entries, with each entry mapping to a list of n_predictions tensors. Each of these tensors is a detached gradient, so the shape of each tensor depends on the model parameter it corresponds to, but all tensors within a given entry should have identical shapes

  • param_names (List[str]) – List of parameter names. Should contain n_model_parameters entries, corresponding 1:1 with the keys in grad_dict

  • model_params (List[torch.Tensor]) – Actual parameters that we’ll return the gradients for. Each parameter should be passed directly for the backward pass to work correctly. These tensors should correspond 1:1 with, and be in the same order as, the entries in param_names (i.e. the i-th entry in param_names should be the name of the i-th model parameter in model_params)
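The parameter contract above can be exercised end to end with a custom torch.autograd.Function. The sketch below is hypothetical: the class name MeanCombinationSketch, the toy torch.nn.Linear model, and the random pose inputs are assumptions, and this is not mtenn's implementation. It shows one way to build pred_list, grad_dict, param_names, and model_params with torch.autograd.grad, and to use the detached gradients in backward. Passing *model_params directly to apply is what lets autograd route the returned gradients into each parameter's .grad.

```python
import torch


class MeanCombinationSketch(torch.autograd.Function):
    """Hypothetical stand-in for a mean Combination; not mtenn's implementation."""

    @staticmethod
    def forward(ctx, pred_list, grad_dict, param_names, *model_params):
        # Stack the scalar per-pose predictions into shape (n_predictions,)
        preds = torch.stack([p.detach().reshape(()) for p in pred_list])
        # Keep the (already detached) per-pose gradients for backward()
        ctx.grad_dict = grad_dict
        ctx.param_names = param_names
        return preds.mean(), preds

    @staticmethod
    def backward(ctx, grad_combined, grad_preds):
        # If no loss used the per-pose output, PyTorch passes zeros here by default
        param_grads = []
        for name in ctx.param_names:
            # Shape (n_predictions, *param.shape)
            g = torch.stack(ctx.grad_dict[name])
            # d(mean)/d(theta) is the mean of the per-pose gradients
            total = grad_combined * g.mean(dim=0)
            # Add the contribution of any loss on the per-pose predictions
            total = total + torch.tensordot(grad_preds, g, dims=1)
            param_grads.append(total)
        # No gradients for pred_list, grad_dict, or param_names
        return (None, None, None, *param_grads)


torch.manual_seed(0)
model = torch.nn.Linear(3, 1)  # toy model standing in for a real model
poses = [torch.randn(3) for _ in range(4)]  # four hypothetical pose inputs

param_names = [name for name, _ in model.named_parameters()]
model_params = [p for _, p in model.named_parameters()]
grad_dict = {name: [] for name in param_names}
pred_list = []
for x in poses:
    pred = model(x).squeeze()
    # Per-pose gradient of this prediction w.r.t. every model parameter
    grads = torch.autograd.grad(pred, model_params)
    for name, g in zip(param_names, grads):
        grad_dict[name].append(g.detach())
    pred_list.append(pred.detach())

combined, per_pose = MeanCombinationSketch.apply(
    pred_list, grad_dict, param_names, *model_params
)
combined.backward()  # fills model.weight.grad and model.bias.grad
```

Storing grad_dict on ctx instead of using save_for_backward is a simplification for this sketch; the stored tensors are not inputs or outputs of the Function and are already detached.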