mtenn.combination.Combination
- class mtenn.combination.Combination(*args, **kwargs)[source]
Bases: Module, ABC
Abstract base class for the Combination block. Any subclass needs to implement the forward method in order to be used.
This class is designed to just be a wrapper around a torch.autograd.Function subclass, as described in the guide.
Methods
forward(pred_list, grad_dict, param_names, ...) – This function signature should be the same for any Combination subclass implementation.
join_grad_dict(grad_dict_keys, grad_dict_tensors) – Helper method used by all Combination classes to reconstruct the grad_dict from keys and grad tensors.
split_grad_dict(grad_dict) – Helper method used by all Combination classes to split up the passed grad_dict for saving by context manager.
- abstract forward(pred_list, grad_dict, param_names, *model_params)[source]
This function signature should be the same for any Combination subclass implementation. The return values should be:
- torch.Tensor: Scalar-value tensor giving the final combined prediction
- torch.Tensor: Tensor of shape (n_predictions,) giving the input per-pose predictions. This is necessary for PyTorch to track the gradients of these predictions in the case of, e.g., a cross-entropy loss on the per-pose predictions
- Parameters:
pred_list (List[torch.Tensor]) – List of \(\mathrm{\Delta G}\) predictions to be combined, shape of (n_predictions,)
grad_dict (Dict[str, List[torch.Tensor]]) – Dict mapping from parameter name to list of gradients. Should contain n_model_parameters entries, with each entry mapping to a list of n_predictions tensors. Each of these tensors is a detached gradient, so its shape depends on the model parameter it corresponds to, but all tensors within any given entry should have the same shape
param_names (List[str]) – List of parameter names. Should contain n_model_parameters entries, corresponding 1:1 with the keys in grad_dict
model_params (List[torch.Tensor]) – Actual parameters that we’ll return the gradients for. Each param should be passed directly for the backward pass to work right. These tensors should correspond 1:1 with, and be in the same order as, the entries in param_names (i.e., the ith entry in param_names should be the name of the ith model parameter in model_params)
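The layout constraints listed above can be checked before calling forward. The following is a minimal sketch of such a check, not part of mtenn: check_forward_inputs and the fake stand-in are hypothetical names introduced here for illustration. Gradient entries only need a .shape attribute, so the same check should work on torch.Tensor objects.

```python
from types import SimpleNamespace
from typing import Any, Dict, List, Sequence


def check_forward_inputs(
    pred_list: Sequence[Any],
    grad_dict: Dict[str, List[Any]],
    param_names: Sequence[str],
    model_params: Sequence[Any],
) -> None:
    """Raise ValueError if the inputs break the documented layout.

    Hypothetical helper (assumption, not an mtenn function).
    """
    n_predictions = len(pred_list)

    # param_names must match the grad_dict keys one-to-one
    if len(param_names) != len(grad_dict) or set(param_names) != set(grad_dict):
        raise ValueError("param_names must correspond 1:1 with grad_dict keys")

    # model_params must line up with param_names, entry for entry
    if len(model_params) != len(param_names):
        raise ValueError("model_params and param_names differ in length")

    for name, grads in grad_dict.items():
        # every parameter needs one gradient per prediction
        if len(grads) != n_predictions:
            raise ValueError(
                f"{name}: expected {n_predictions} gradients, got {len(grads)}"
            )
        # all gradients stored for one parameter must share a shape
        shapes = {tuple(g.shape) for g in grads}
        if len(shapes) > 1:
            raise ValueError(f"{name}: gradient shapes differ: {sorted(shapes)}")


def fake(*shape: int) -> SimpleNamespace:
    """Stand-in object with a .shape attribute so the sketch runs without torch."""
    return SimpleNamespace(shape=shape)


# Two poses, two model parameters ("w" with shape (2, 3), "b" with shape (3,))
preds = [fake(), fake()]
grad_dict = {"w": [fake(2, 3), fake(2, 3)], "b": [fake(3), fake(3)]}
check_forward_inputs(preds, grad_dict, ["w", "b"], [fake(2, 3), fake(3)])
```

The check deliberately covers only the constraints stated in the parameter descriptions. It does not verify that the order of model_params matches param_names, since that cannot be inferred from the tensors alone.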
- static split_grad_dict(grad_dict)[source]
Helper method used by all Combination classes to split up the passed grad_dict for saving by context manager.
- Parameters:
grad_dict (Dict[str, List[torch.Tensor]]) – Dict mapping from parameter name to list of gradients
- Returns:
List[str] – Keys in grad_dict corresponding 1:1 with the gradients
List[torch.Tensor] – Gradients from grad_dict corresponding 1:1 with the keys
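As an illustration of the split described above, here is a minimal sketch that flattens a grad_dict into two parallel lists. It is an assumption about the behaviour based on the documented return values, not the mtenn source; split_grad_dict_sketch is a hypothetical name, and plain floats stand in for gradient tensors so the example runs without torch. A flat list is presumably needed because the context object of a torch.autograd.Function saves tensors passed as individual arguments rather than inside a dict.

```python
from typing import Any, Dict, List, Tuple


def split_grad_dict_sketch(
    grad_dict: Dict[str, List[Any]]
) -> Tuple[List[str], List[Any]]:
    """Flatten grad_dict into parallel lists of keys and gradients.

    Hypothetical sketch (assumption): each key is repeated once per gradient
    so that keys[i] names the parameter that grads[i] belongs to.
    """
    keys: List[str] = []
    grads: List[Any] = []
    for name, grad_list in grad_dict.items():
        for grad in grad_list:
            keys.append(name)
            grads.append(grad)
    return keys, grads


# Floats stand in for gradient tensors here
keys, grads = split_grad_dict_sketch({"w": [1.0, 2.0], "b": [3.0, 4.0]})
```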
- static join_grad_dict(grad_dict_keys, grad_dict_tensors)[source]
Helper method used by all Combination classes to reconstruct the grad_dict from keys and grad tensors.
- Parameters:
grad_dict_keys (List[str]) – Keys in grad_dict corresponding 1:1 with the gradients
grad_dict_tensors (List[torch.Tensor]) – Gradients from grad_dict corresponding 1:1 with the keys
- Returns:
Dict mapping from parameter name to list of gradients
- Return type:
Dict[str, List[torch.Tensor]]
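The reconstruction can be sketched as grouping the gradients by their key while preserving order. As before, this is an assumption drawn from the documented parameters and return type rather than the mtenn source; join_grad_dict_sketch is a hypothetical name, and floats stand in for gradient tensors.

```python
from typing import Any, Dict, List, Sequence


def join_grad_dict_sketch(
    grad_dict_keys: Sequence[str], grad_dict_tensors: Sequence[Any]
) -> Dict[str, List[Any]]:
    """Rebuild a dict of parameter name -> list of gradients.

    Hypothetical sketch (assumption): the two inputs are parallel lists, so
    grad_dict_tensors[i] is appended to the list for grad_dict_keys[i].
    """
    grad_dict: Dict[str, List[Any]] = {}
    for key, grad in zip(grad_dict_keys, grad_dict_tensors):
        grad_dict.setdefault(key, []).append(grad)
    return grad_dict


# Floats stand in for gradient tensors here
rebuilt = join_grad_dict_sketch(["w", "w", "b", "b"], [1.0, 2.0, 3.0, 4.0])
```

Under this sketch, the gradients for each parameter keep the order in which they appear in the flat list, which would keep them aligned with the per-pose predictions.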