mtenn.conversion_utils.schnet.MemoizedRadiusInteractionGraph
- class mtenn.conversion_utils.schnet.MemoizedRadiusInteractionGraph(lookup_function: Callable | None = None, cutoff: float = 10.0, max_num_neighbors: int = 32)
Bases: RadiusInteractionGraph
Memoized version of the PyG RadiusInteractionGraph. The lookup function maps the position and batch tensors to a key under which the computed edge index/weight tensors are stored, and can be any function with the appropriate signature. By default, the key is a string built from the number of atoms and the x, y, and z coordinates of the first and last two atoms, formatted to 7 decimal places. This relies on the atoms for a given sample being in the same order each time.
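The kind of key described above can be sketched in plain Python. This is an illustration only, not the library's default: the function name `make_lookup_key`, the `_` separator, and the reading of "first and last two atoms" as the first two plus the last two atoms are all assumptions, and plain tuples stand in for the position and batch tensors.

```python
def make_lookup_key(pos, batch, n_decimals=7):
    """Build a string key from the atom count and a few boundary coordinates.

    Hypothetical helper: `pos` is a sequence of (x, y, z) tuples standing in
    for the position tensor; `batch` is accepted only to match the documented
    (pos, batch) signature and is not used here.
    """
    # Assumption: use the first two and the last two atoms.
    boundary = list(pos[:2]) + list(pos[-2:])
    # Format every coordinate to a fixed number of decimal places.
    coords = [f"{c:.{n_decimals}f}" for atom in boundary for c in atom]
    return "_".join([str(len(pos))] + coords)


pos = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.5, 0.0), (0.0, 0.0, 2.25)]
batch = [0, 0, 0, 0]

key = make_lookup_key(pos, batch)
# The same atoms in a different order give a different key, so a reordered
# sample would not hit the entry stored for the original ordering.
reordered_key = make_lookup_key(pos[::-1], batch)
```

Because the key depends on atom order, a caller that cannot guarantee stable ordering would likely want to pass a custom `lookup_function` instead.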
Methods
__init__([lookup_function, cutoff, ...]) – Initialize underlying RadiusInteractionGraph and define lookup function for storing computed results.
forward(pos, batch) – Perform forward pass of RadiusInteractionGraph class, first checking to see if this calculation has already been done.
- __init__(lookup_function: Callable | None = None, cutoff: float = 10.0, max_num_neighbors: int = 32)
Initialize underlying RadiusInteractionGraph and define lookup function for storing computed results.
- Parameters:
lookup_function (Callable, optional) – Function mapping from position and batch tensors to a dict lookup key
cutoff (float, default=10.0) – Cutoff distance for interatomic interactions
max_num_neighbors (int, default=32) – The maximum number of neighbors to collect for each node within the cutoff distance with the default interaction graph method
- forward(pos: Tensor, batch: Tensor)
Perform forward pass of RadiusInteractionGraph class, first checking to see if this calculation has already been done.
- Parameters:
pos (torch.Tensor) – Coordinates of each atom
batch (torch.Tensor) – Batch indices assigning each atom to a separate molecule
- Returns:
torch.Tensor – Edge index tensor
torch.Tensor – Edge weight tensor
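The check-then-compute pattern that forward describes can be sketched without PyTorch as follows. This is a simplified stand-in under stated assumptions, not the mtenn implementation: the class `MemoizedNeighborSearch`, its `_compute` method, the `n_computed` counter, and the `simple_key` function are all hypothetical, and the brute-force neighbor search ignores `max_num_neighbors`.

```python
import math


class MemoizedNeighborSearch:
    """Hypothetical stand-in that caches neighbor-search results by key."""

    def __init__(self, lookup_function, cutoff=10.0):
        self.lookup_function = lookup_function
        self.cutoff = cutoff
        self._cache = {}
        self.n_computed = 0  # counts how often the search actually runs

    def _compute(self, pos, batch):
        # Brute-force search: connect atoms of the same molecule that lie
        # within the cutoff distance of each other.
        self.n_computed += 1
        edges, weights = [], []
        for i, pi in enumerate(pos):
            for j, pj in enumerate(pos):
                if i == j or batch[i] != batch[j]:
                    continue
                d = math.dist(pi, pj)
                if d <= self.cutoff:
                    edges.append((i, j))
                    weights.append(d)
        return edges, weights

    def forward(self, pos, batch):
        # Look up the key first; only compute on a cache miss.
        key = self.lookup_function(pos, batch)
        if key not in self._cache:
            self._cache[key] = self._compute(pos, batch)
        return self._cache[key]


def simple_key(pos, batch):
    # Hypothetical key function using the full inputs.
    return repr((len(pos), pos, batch))


pos = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 5.0, 0.0)]
batch = [0, 0, 1]

graph = MemoizedNeighborSearch(simple_key, cutoff=2.0)
first = graph.forward(pos, batch)   # computed and stored
second = graph.forward(pos, batch)  # served from the cache
```

The second call returns the stored result without rerunning the search, which is where the saving comes from when the same sample is passed through the model repeatedly.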