mtenn.conversion_utils.schnet.MemoizedRadiusInteractionGraph

class mtenn.conversion_utils.schnet.MemoizedRadiusInteractionGraph(lookup_function: Callable | None = None, cutoff: float = 10.0, max_num_neighbors: int = 32)

Bases: RadiusInteractionGraph

Memoized version of the PyG RadiusInteractionGraph. A lookup function maps the position and batch tensors to a key, and the computed edge index and edge weight tensors are stored under that key. Any function with the appropriate signature can be used. By default, the key is a string built from the number of atoms and the x, y, and z coordinates of the first and last two atoms, formatted to 7 decimal places. Because the default key is built from atoms at fixed positions in the input, it relies on the atoms of a given sample always appearing in the same order.
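The default key described above can be pictured with a minimal pure-Python sketch. It is not mtenn's implementation: the function name `default_style_key` is hypothetical, positions are plain lists of `(x, y, z)` tuples instead of a `torch.Tensor`, the `_` separator is arbitrary, and "first and last two atoms" is assumed to mean the first two and the last two atoms.

```python
from typing import Sequence, Tuple

Coord = Tuple[float, float, float]


def default_style_key(pos: Sequence[Coord]) -> str:
    """Hypothetical re-creation of the documented default lookup key.

    Assumption: "first and last two atoms" means the first two and the
    last two atoms of the input.
    """
    n_atoms = len(pos)
    # Indices of the first two and last two atoms (fewer if the input is tiny)
    idx = sorted({0, 1, n_atoms - 2, n_atoms - 1} & set(range(n_atoms)))
    parts = [str(n_atoms)]
    for i in idx:
        # Format each x, y, z coordinate to 7 decimal places
        parts.extend(f"{c:.7f}" for c in pos[i])
    return "_".join(parts)
```

Because only a few atoms enter a key of this form, two inputs that differ only in their middle atoms map to the same key, while simply reordering the atoms produces a different key. This is why the default depends on a consistent atom order.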

Methods

__init__([lookup_function, cutoff, ...])

Initialize the underlying RadiusInteractionGraph and define the lookup function used to store computed results.

forward(pos, batch)

Perform the forward pass of the RadiusInteractionGraph, first checking whether this calculation has already been done.

__init__(lookup_function: Callable | None = None, cutoff: float = 10.0, max_num_neighbors: int = 32)

Initialize the underlying RadiusInteractionGraph and define the lookup function used to store computed results.

Parameters:
  • lookup_function (Callable, optional) – Function mapping the position and batch tensors to a dict lookup key. If not given, the default key described above is used

  • cutoff (float, default=10.0) – Cutoff distance for interatomic interactions

  • max_num_neighbors (int, default=32) – Maximum number of neighbors to collect for each node within the cutoff distance when using the default interaction graph method
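Any callable with the appropriate signature can be supplied as `lookup_function`. The sketch below assumes the function is called as `lookup_function(pos, batch)` and only needs to return a hashable value. `full_coord_key` is a hypothetical example that keys on every atom's coordinates (rounded to 7 decimal places) and on the batch indices, trading a larger key for fewer accidental collisions. It accepts anything with a `.tolist()` method, such as a `torch.Tensor`, or a plain nested list.

```python
from typing import Any, Hashable


def _as_list(t: Any) -> list:
    # torch.Tensor exposes .tolist(); plain Python sequences are copied as-is
    return t.tolist() if hasattr(t, "tolist") else list(t)


def full_coord_key(pos: Any, batch: Any) -> Hashable:
    """Hypothetical lookup_function keyed on every atom and its batch index.

    Rounding to 7 decimal places mirrors the precision of the documented default.
    """
    coords = tuple(
        tuple(round(float(c), 7) for c in atom) for atom in _as_list(pos)
    )
    # Include batch indices so the same atoms split into different molecules differ
    return (coords, tuple(int(b) for b in _as_list(batch)))
```

Under those assumptions it would be passed as `MemoizedRadiusInteractionGraph(lookup_function=full_coord_key)`. Like the default key, it still depends on the atoms appearing in a consistent order.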

forward(pos: Tensor, batch: Tensor)

Perform the forward pass of the RadiusInteractionGraph, first checking whether this calculation has already been done.

Parameters:
  • pos (torch.Tensor) – Coordinates of each atom

  • batch (torch.Tensor) – Batch index of each atom, assigning it to its molecule within the batch

Returns:

  • torch.Tensor – Edge index tensor

  • torch.Tensor – Edge weight tensor
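The memoized forward pass can be sketched in pure Python as a dictionary lookup placed in front of the graph computation. This is an illustration under stated assumptions, not mtenn's code. `radius_edges` and `MemoizedRadiusGraphSketch` are hypothetical names, and `radius_edges` is a brute-force stand-in for PyG's radius graph whose neighbor selection and boundary handling may differ. Edge weights are taken to be interatomic distances, as in PyG's RadiusInteractionGraph, and edges are returned as a list of `(source, target)` pairs rather than a `2 x E` tensor.

```python
import math
from typing import Callable, Dict, Hashable, List, Sequence, Tuple

Coord = Tuple[float, float, float]
Edges = Tuple[List[Tuple[int, int]], List[float]]


def radius_edges(pos: Sequence[Coord], batch: Sequence[int],
                 cutoff: float, max_num_neighbors: int) -> Edges:
    """Brute-force O(N^2) radius graph, for illustration only."""
    edge_index: List[Tuple[int, int]] = []
    edge_weight: List[float] = []
    for i in range(len(pos)):
        n_found = 0
        for j in range(len(pos)):
            # Only connect distinct atoms that belong to the same molecule
            if i == j or batch[i] != batch[j]:
                continue
            d = math.dist(pos[i], pos[j])
            if d <= cutoff and n_found < max_num_neighbors:
                edge_index.append((j, i))  # (source, target): j neighbors i
                edge_weight.append(d)      # edge weight = interatomic distance
                n_found += 1
    return edge_index, edge_weight


class MemoizedRadiusGraphSketch:
    """Hypothetical pure-Python analogue of the memoized forward pass."""

    def __init__(self, lookup_function: Callable[..., Hashable],
                 cutoff: float = 10.0, max_num_neighbors: int = 32) -> None:
        self.lookup_function = lookup_function
        self.cutoff = cutoff
        self.max_num_neighbors = max_num_neighbors
        self._cache: Dict[Hashable, Edges] = {}
        self.n_computed = 0  # counts cache misses, for demonstration

    def forward(self, pos: Sequence[Coord], batch: Sequence[int]) -> Edges:
        key = self.lookup_function(pos, batch)
        if key not in self._cache:
            # Cache miss: build the graph once and store it under the key
            self._cache[key] = radius_edges(
                pos, batch, self.cutoff, self.max_num_neighbors
            )
            self.n_computed += 1
        return self._cache[key]
```

On a cache hit the stored result is returned without recomputation, so a repeated sample only costs the time needed to build its key. In this sketch the cache is never evicted, so every distinct key keeps its edges in memory.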