mtenn.strategy.DeltaStrategy

class mtenn.strategy.DeltaStrategy(energy_func)

Bases: Strategy

Simple strategy that subtracts the sum of the individual component energies from the energy of the complex. This Strategy requires an energy_func \(\phi: \mathbb{R}^n \rightarrow \mathbb{R}\) that maps an n-dimensional vector representation (the output of a Representation block) to a scalar-valued energy prediction.

\[ \begin{aligned}\mathrm{G} &= \phi (\boldsymbol{x}) \\ \Delta \mathrm{G_{pred}} &= \mathrm{G_{complex}} - \sum_i \mathrm{G}_i\end{aligned} \]
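As a quick numeric illustration of the second equation, the delta calculation reduces to subtracting the summed part energies from the complex energy. The function name delta_g and the energy values below are made up for illustration; they are not part of mtenn.

```python
from typing import Sequence


def delta_g(g_complex: float, g_parts: Sequence[float]) -> float:
    """Hypothetical helper: Delta G_pred = G_complex - sum_i G_i."""
    return g_complex - sum(g_parts)


# Made-up energies for a complex and its two parts (e.g. ligand and protein)
print(delta_g(-12.0, [-3.5, -6.0]))  # -2.5
```

With no parts, the sum is zero and the prediction is just the complex energy.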

Methods

__init__(energy_func)

Store the module used to predict an energy from a representation.

forward(comp, *parts)

Make energy predictions for each representation, and then perform the delta calculation.

__init__(energy_func)

Store the module used to predict an energy from a representation.

Parameters:

energy_func (torch.nn.Module) – A torch module that predicts a scalar energy from an n-dimensional vector representation of a structure
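Any torch.nn.Module with the signature \(\phi: \mathbb{R}^n \rightarrow \mathbb{R}\) fits this description. A minimal sketch, assuming a representation size of n = 16 (a hypothetical value; in practice it is set by the Representation block) and a single-element output tensor, is a small multilayer perceptron:

```python
import torch

# Hypothetical representation size; the real value comes from the Representation block
n = 16

# Small MLP mapping an n-dimensional vector to one scalar energy
energy_func = torch.nn.Sequential(
    torch.nn.Linear(n, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 1),
)

x = torch.randn(n)   # one structure's vector representation
g = energy_func(x)   # predicted energy, a single-element tensor
print(g.shape)  # torch.Size([1])
```

Because torch.nn.Linear acts on the last dimension, the same module also accepts a batch of representations with shape (batch, n).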

forward(comp, *parts)

Make energy predictions for each representation, and then perform the delta calculation.

Parameters:
  • comp (torch.Tensor) – Representation of the full complex, which is passed to self.energy_func

  • parts (list[torch.Tensor], optional) – Representations of each individual part of the complex (e.g., the ligand and the protein separately), each of which is passed to self.energy_func

Returns:

Predicted \(\Delta G\) value

Return type:

torch.Tensor
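The documented behavior of forward can be sketched as a standalone module. This is not mtenn's implementation: DeltaStrategySketch is a hypothetical name, and the sketch assumes energy_func returns tensors that can be added and subtracted elementwise.

```python
import torch


class DeltaStrategySketch(torch.nn.Module):
    """Hypothetical stand-in illustrating the documented delta calculation."""

    def __init__(self, energy_func: torch.nn.Module):
        super().__init__()
        # Store the module used to predict an energy from a representation
        self.energy_func = energy_func

    def forward(self, comp: torch.Tensor, *parts: torch.Tensor) -> torch.Tensor:
        # Energy predicted for the full complex
        g_complex = self.energy_func(comp)
        # Sum of the energies predicted for each part (0 if no parts are given)
        g_parts = sum(self.energy_func(p) for p in parts)
        # Delta G_pred = G_complex - sum_i G_i
        return g_complex - g_parts


torch.manual_seed(0)
energy_func = torch.nn.Linear(8, 1)  # toy energy model for 8-dimensional representations
strategy = DeltaStrategySketch(energy_func)

comp, ligand, protein = torch.randn(8), torch.randn(8), torch.randn(8)
parts = [ligand, protein]

# Parts are passed as separate positional arguments, so unpack a list with *
dg = strategy(comp, *parts)
expected = energy_func(comp) - (energy_func(ligand) + energy_func(protein))
print(torch.allclose(dg, expected))  # True
```

Sharing one energy_func across the complex and its parts keeps all energies on the same learned scale, so their difference is meaningful.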