hypll.manifolds.poincare_ball.curvature

Classes

Curvature

Class representing curvature of a manifold.

class hypll.manifolds.poincare_ball.curvature.Curvature

Class representing curvature of a manifold.

value

Learnable parameter determining the curvature of the manifold. The actual curvature is computed as constraining_strategy(value).

constraining_strategy

Function applied to value in order to constrain the curvature of the manifold. Defaults to softplus, which guarantees that the curvature is positive.
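To illustrate why softplus is a convenient default, the sketch below compares three candidate strategies on a few raw values. It works on plain Python floats rather than tensors so that it stays self-contained; the exp and identity alternatives are hypothetical choices for illustration, not strategies shipped with hypll. With tensors, a callable such as torch.exp would play the same role.

```python
import math
from typing import Callable, Dict, List


def softplus(x: float) -> float:
    """Numerically stable softplus, log(1 + exp(x)); mathematically > 0 for every real x."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


# Hypothetical candidate strategies, written for floats instead of tensors.
strategies: Dict[str, Callable[[float], float]] = {
    "softplus": softplus,     # the documented default
    "exp": math.exp,          # also strictly positive, but grows much faster
    "identity": lambda x: x,  # no constraint: the sign follows the raw value
}

raw_values: List[float] = [-5.0, -1.0, 0.0, 1.0, 5.0]
for name, strategy in strategies.items():
    outputs = [strategy(v) for v in raw_values]
    # Report whether every raw value was mapped to a positive curvature.
    print(name, all(out > 0 for out in outputs))
```

Both softplus and exp keep the curvature positive for any raw value, while the identity does not; softplus additionally grows only linearly for large inputs.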

requires_grad

Whether the curvature requires gradients, i.e. whether value is learnable. False by default.
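When requires_grad is True, an optimizer updates value, and gradients reach it through the constraining strategy. For softplus the derivative is the logistic sigmoid, which lies strictly between 0 and 1. The sketch below hand-codes that chain rule in plain Python (autograd would do this for tensors) and runs gradient descent on a hypothetical loss (c - target)^2; the target, learning rate, and step count are arbitrary illustration values, not hypll defaults.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable softplus, log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sigmoid(x: float) -> float:
    """Derivative of softplus with respect to its input."""
    return 1.0 / (1.0 + math.exp(-x))


target = 0.1         # hypothetical desired curvature
value = 1.0          # raw, unconstrained parameter (documented default)
learning_rate = 1.0  # illustration value
min_curvature = softplus(value)

for step in range(2000):
    c = softplus(value)
    min_curvature = min(min_curvature, c)
    # Chain rule: dL/dvalue = dL/dc * dc/dvalue = 2 * (c - target) * sigmoid(value)
    grad = 2.0 * (c - target) * sigmoid(value)
    value -= learning_rate * grad

print(f"value={value:.4f}, curvature={softplus(value):.4f}")
```

Although value itself ends up negative, the curvature stays positive at every step, which is the point of optimizing a constrained parameter instead of the curvature directly.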

__init__(value: float = 1.0, constraining_strategy: Callable[[Tensor], Tensor] = softplus, requires_grad: bool = False)

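Initialize the curvature module. value sets the initial unconstrained parameter, constraining_strategy maps it to the curvature returned by forward(), and requires_grad controls whether value is learnable.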
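With the default softplus strategy, the initial curvature is softplus(value), so value is not itself the starting curvature. To start from a chosen curvature c > 0, value can be set to the inverse of softplus, log(exp(c) - 1). The helper inverse_softplus below is a hypothetical illustration, not part of hypll.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable softplus, log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def inverse_softplus(c: float) -> float:
    """Hypothetical helper: raw value v such that softplus(v) == c, for c > 0."""
    if c <= 0:
        raise ValueError("softplus only produces positive values")
    # log(exp(c) - 1), written with expm1 for accuracy at small c.
    return math.log(math.expm1(c))


target_curvature = 1.0
value = inverse_softplus(target_curvature)
print(f"value={value:.4f}")                      # value=0.5413
print(f"softplus(value)={softplus(value):.4f}")  # softplus(value)=1.0000
```

The resulting number would then be passed as value when constructing Curvature, leaving constraining_strategy at its default.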

forward() → Tensor

Returns the curvature, computed as constraining_strategy(value).
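The documented behaviour of forward() can be mirrored with plain floats. CurvatureSketch below is a hypothetical, dependency-free stand-in: the real Curvature is a torch module that operates on tensors and supports autograd, whereas here requires_grad is only stored. Calling the object dispatches to forward(), as calling an nn.Module does.

```python
import math
from typing import Callable


def softplus(x: float) -> float:
    """Numerically stable softplus, log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


class CurvatureSketch:
    """Hypothetical float-based stand-in for the documented Curvature module."""

    def __init__(
        self,
        value: float = 1.0,
        constraining_strategy: Callable[[float], float] = softplus,
        requires_grad: bool = False,
    ) -> None:
        self.value = value  # unconstrained raw parameter
        self.constraining_strategy = constraining_strategy
        self.requires_grad = requires_grad  # only recorded; no autograd in this sketch

    def forward(self) -> float:
        # The curvature actually used is the constrained value.
        return self.constraining_strategy(self.value)

    def __call__(self) -> float:
        # Mirrors nn.Module, where calling the module dispatches to forward().
        return self.forward()


default = CurvatureSketch()
custom = CurvatureSketch(value=-3.0, constraining_strategy=math.exp)
print(round(default(), 4))  # 1.3133
print(round(custom(), 4))   # 0.0498
```

Note that the default value=1.0 yields a curvature of about 1.3133 rather than 1.0, because softplus is applied before the value is used.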