mtenn.conversion_utils.visnet.EquivariantVecToScalar

class mtenn.conversion_utils.visnet.EquivariantVecToScalar(mean, reduce_op)

Bases: Module

Wrapper that exposes torch_geometric.utils.scatter as a torch.nn.Module, so the scatter (reduction) step can be used as a layer in a model.
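The reduction that torch_geometric.utils.scatter performs along dim=0 can be sketched in plain Python. This is an illustrative stand-in, not the library's implementation: the function name scatter_1d is hypothetical, values are plain floats rather than tensors, and groups that receive no values are filled with 0.

```python
from collections import defaultdict


def scatter_1d(values, index, reduce="sum"):
    """Hypothetical sketch: reduce values that share the same index (dim 0)."""
    # Collect every value under the index (e.g. molecule id) it belongs to.
    groups = defaultdict(list)
    for v, i in zip(values, index):
        groups[i].append(v)

    # Output length is max(index) + 1, one slot per group.
    size = max(index) + 1
    out = []
    for i in range(size):
        g = groups.get(i, [])
        if not g:
            out.append(0.0)  # empty groups are left at 0
        elif reduce in ("sum", "add"):
            out.append(float(sum(g)))
        elif reduce == "mean":
            out.append(sum(g) / len(g))
        elif reduce == "max":
            out.append(float(max(g)))
        elif reduce == "min":
            out.append(float(min(g)))
        else:
            raise ValueError(f"unsupported reduce: {reduce}")
    return out


# Four per-atom values split across two molecules (index 0 and 1).
print(scatter_1d([1.0, 2.0, 3.0, 4.0], [0, 0, 1, 1]))          # [3.0, 7.0]
print(scatter_1d([1.0, 2.0, 3.0, 4.0], [0, 0, 1, 1], "mean"))  # [1.5, 3.5]
```

The choice of reduce_op therefore decides whether per-atom contributions are added up (an extensive quantity) or averaged.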

Methods

__init__(mean, reduce_op)

Store the user-supplied parameters.

forward(x)

Perform the scatter operation.

__init__(mean, reduce_op)

Store the user-supplied parameters.

Parameters:
  • mean (torch.Tensor) – Mean of the predicted value

  • reduce_op (str) – Reduction operation passed to torch_geometric.utils.scatter (for example, "sum" or "mean")

forward(x)

Perform the scatter operation.

Parameters:

x (torch.Tensor) – Input tensor

Returns:

Output of the scatter call

Return type:

torch.Tensor
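The module as a whole can be sketched the same way. This is a hypothetical plain-Python stand-in (EquivariantVecToScalarSketch is not part of mtenn), with x modeled as a list of per-atom floats instead of a torch.Tensor. It rests on two assumptions not stated above: all rows of x belong to a single molecule, so the scatter collapses them into one value, and the stored mean is added to the reduced value afterwards, presumably mirroring the reduce-then-add-mean step at the end of PyG's ViSNet.forward.

```python
class EquivariantVecToScalarSketch:
    """Hypothetical plain-Python stand-in for the torch Module above."""

    def __init__(self, mean, reduce_op):
        # Store the user-supplied parameters, as __init__ does.
        self.mean = mean
        self.reduce_op = reduce_op

    def forward(self, x):
        # Assumption: every atom belongs to molecule 0, so scatter over a
        # batch index of all zeros reduces the whole list to one value.
        if self.reduce_op in ("sum", "add"):
            y = sum(x)
        elif self.reduce_op == "mean":
            y = sum(x) / len(x)
        else:
            raise ValueError(f"unsupported reduce_op: {self.reduce_op}")
        # Assumption: shift the reduced value by the stored mean.
        return y + self.mean

    # Allow calling the object directly, like a torch.nn.Module.
    __call__ = forward


head = EquivariantVecToScalarSketch(mean=10.0, reduce_op="sum")
print(head([1.0, 2.0, 3.0]))  # 16.0
```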