mtenn.combination

Implementations for the Combination block in a GroupedModel.

The Combination block is responsible for combining multiple single-pose model predictions into a single multi-pose prediction. For more details on the implementation of these classes, see the Combination docs page and the guide on Implementing a new Combination.

All equations referenced here correspond to those in Math for Implemented Combinations.

Classes

Combination

Abstract base class for the Combination block.
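For illustration only, the abstract interface can be pictured as a class with one abstract method that reduces a list of per-pose predictions to a single value. The sketch below is plain Python. `CombinationSketch`, `HardMinCombination`, and `combine` are hypothetical names and are not mtenn's API; see the Combination docs page for the real interface.

```python
from abc import ABC, abstractmethod
from typing import Sequence


class CombinationSketch(ABC):
    """Hypothetical interface: reduce single-pose predictions to one value."""

    @abstractmethod
    def combine(self, preds: Sequence[float]) -> float:
        """Combine single-pose predictions into one multi-pose prediction."""


class HardMinCombination(CombinationSketch):
    """Hypothetical subclass: keep the lowest prediction (for dG, the most favorable)."""

    def combine(self, preds: Sequence[float]) -> float:
        if not preds:
            raise ValueError("need at least one prediction")
        return min(preds)
```

The hard minimum is not smooth in its inputs. The MaxCombination entry gives smoothness as the reason for using LogSumExp instead.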

MaxCombination

Approximate the max/min of the predictions using the LogSumExp function, which is smooth.
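A minimal plain-Python sketch of the LogSumExp approximation follows. It assumes one scalar prediction per pose and a sharpness parameter named `scale`. That name is hypothetical, and mtenn's own parameterization is described in Math for Implemented Combinations. Larger `scale` values track the true max/min more closely: the smooth max always lies between max(x) and max(x) + log(n)/scale.

```python
import math
from typing import Sequence


def smooth_max(preds: Sequence[float], scale: float = 1.0) -> float:
    """LogSumExp approximation of max(preds); sharper as scale grows."""
    scaled = [scale * p for p in preds]
    shift = max(scaled)  # subtract the max so exp() cannot overflow
    lse = shift + math.log(sum(math.exp(s - shift) for s in scaled))
    return lse / scale


def smooth_min(preds: Sequence[float], scale: float = 1.0) -> float:
    """Smooth min via the identity min(x) = -max(-x)."""
    return -smooth_max([-p for p in preds], scale)
```

Subtracting the largest scaled value before exponentiating is the standard way to keep `exp` from overflowing. It does not change the result.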

MaxCombinationFunc

Custom autograd function that handles the gradient math for taking the max/min of the \(\mathrm{\Delta G}\) predictions.
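The gradient such a function has to supply follows from differentiating LogSumExp. The derivative of the smooth min with respect to each pose prediction is a softmax weight, and the chain rule turns the per-pose parameter gradients into a weighted sum. The sketch below computes this by hand in plain Python rather than with `torch.autograd.Function`. `smooth_min_with_grad` and the `pose_grads` layout (one list of parameter gradients per pose) are hypothetical and are not MaxCombinationFunc's actual signature.

```python
import math
from typing import List, Sequence, Tuple


def smooth_min_with_grad(
    preds: Sequence[float],
    pose_grads: Sequence[Sequence[float]],
    scale: float = 1.0,
) -> Tuple[float, List[float]]:
    """Return the LogSumExp smooth min of preds and its gradient w.r.t. params.

    pose_grads[i][k] is d(preds[i]) / d(param_k) (hypothetical layout).
    """
    neg_scaled = [-scale * p for p in preds]
    shift = max(neg_scaled)  # stabilize exp() as in the forward pass
    exps = [math.exp(s - shift) for s in neg_scaled]
    total = sum(exps)
    value = -(shift + math.log(total)) / scale
    # d(value)/d(preds[i]) is a softmax weight: positive and summing to 1.
    weights = [e / total for e in exps]
    # Chain rule: d(value)/d(param_k) = sum_i weights[i] * d(preds[i])/d(param_k).
    n_params = len(pose_grads[0])
    grad = [
        sum(w * g[k] for w, g in zip(weights, pose_grads))
        for k in range(n_params)
    ]
    return value, grad
```

As `scale` grows, the weights concentrate on the lowest prediction, so the combined gradient approaches that single pose's gradient.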

MeanCombination

Combine a list of predictions by taking the mean.

MeanCombinationFunc

Custom autograd function that handles the gradient math for combining \(\mathrm{\Delta G}\) predictions into their mean.
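By the chain rule, each pose contributes to the mean with weight \(1/n\), so the gradient of the mean with respect to the model parameters is the average of the per-pose gradients. Below is a plain-Python sketch that uses the hypothetical name `mean_with_grad` and an assumed `pose_grads` layout (one list of parameter gradients per pose). It is not MeanCombinationFunc's actual signature.

```python
from typing import List, Sequence, Tuple


def mean_with_grad(
    preds: Sequence[float],
    pose_grads: Sequence[Sequence[float]],
) -> Tuple[float, List[float]]:
    """Return mean(preds) and its gradient w.r.t. params via the chain rule.

    pose_grads[i][k] is d(preds[i]) / d(param_k) (hypothetical layout).
    """
    n = len(preds)
    if n == 0:
        raise ValueError("need at least one prediction")
    value = sum(preds) / n
    # Each pose has weight 1/n, so the parameter gradient is the average
    # of the per-pose gradients.
    n_params = len(pose_grads[0])
    grad = [sum(g[k] for g in pose_grads) / n for k in range(n_params)]
    return value, grad
```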