
similarity

Define similarity matching modules.

MultiSimilarityMatching

MultiSimilarityMatching(
    encoders: Sequence[nn.Module],
    out_channels: int,
    tau: float = 0.1,
    tol: float = 0.0,
    max_iterations: int = 40,
    regularization: Union[str, Sequence[str]] = "weight",
    **kwargs
)

Bases: IterationLossModule

Multiple-target similarity matching circuit.

Some of the encoders can be skipped in a forward() call, either by passing fewer inputs than len(encoders) or by passing None in place of an input.
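The skipping rule above can be sketched in plain Python. This is a minimal illustration, not pynsm's forward(): the helper name encode_inputs is hypothetical, plain callables stand in for nn.Module encoders, and raising an error on surplus inputs is an assumption of this sketch.

```python
from typing import Any, Callable, List, Optional, Sequence, Tuple


def encode_inputs(
    encoders: Sequence[Callable[[Any], Any]],
    *inputs: Optional[Any],
) -> List[Tuple[int, Any]]:
    """Hypothetical helper: return (encoder index, encoded input) for used encoders."""
    if len(inputs) > len(encoders):
        # Assumption of this sketch: surplus inputs are treated as an error.
        raise ValueError("more inputs than encoders")
    encoded = []
    # zip() stops at the shorter sequence, so trailing encoders without an
    # input are skipped automatically.
    for idx, (encoder, x) in enumerate(zip(encoders, inputs)):
        if x is None:  # explicitly skipped encoder
            continue
        encoded.append((idx, encoder(x)))
    return encoded


encoders = [lambda x: x + 1, lambda x: 2 * x, lambda x: -x]
print(encode_inputs(encoders, 1, None))  # [(0, 2)]
print(encode_inputs(encoders, 1, 2, 3))  # [(0, 2), (1, 4), (2, -3)]
```

Both skipping mechanisms end up in the same place: an encoder contributes only when it receives a non-None input.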

Parameters:

  • encoders (Sequence[Module]) –

    modules to use for encoding the inputs

  • out_channels (int) –

    number of output channels

  • tau (float, default: 0.1 ) –

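    factor by which to divide the competitor's learning rate (the competitor's gradient is also negated, so its weights are updated by gradient ascent rather than descent)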

  • tol (float, default: 0.0 ) –

    tolerance for the convergence test (disabled by default); the iteration is considered converged when every element of the output changes by less than tol in absolute value during one iteration

  • max_iterations (int, default: 40 ) –

    maximum number of iterations to run in forward()

  • regularization (Union[str, Sequence[str]], default: 'weight' ) –

    type of encoder regularization to use; this can be a single string, applied to every encoder, or a sequence with a separate entry for each encoder. Options are:

      ◦ "weight": regularize the encoder's parameters; a term is added for every tensor returned by encoder.parameters() that is trainable (i.e., has requires_grad set to True)

      ◦ "whiten": use a regularizer that encourages whitening

      ◦ "none": do not regularize the encoder; mostly useful for custom regularization, since without any regularization the dynamics are unstable in many cases

  • **kwargs

    additional keyword arguments passed to IterationLossModule
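The interplay of tol and max_iterations can be illustrated with a toy fixed-point loop. This is a sketch under stated assumptions, not the library's iteration: run_iteration and halve are hypothetical, outputs are plain lists of floats, and the strict "less than tol" comparison follows the parameter description. It shows why tol = 0.0 disables the test: abs(...) < 0.0 is never true, so the loop always runs max_iterations times.

```python
from typing import Callable, List, Tuple

Vector = List[float]


def has_converged(old: Vector, new: Vector, tol: float) -> bool:
    """True if every element changed by strictly less than tol in absolute value.

    With tol == 0.0 this is never True, since abs(...) < 0.0 is always False.
    """
    return all(abs(a - b) < tol for a, b in zip(old, new))


def run_iteration(
    step: Callable[[Vector], Vector],
    y0: Vector,
    tol: float = 0.0,
    max_iterations: int = 40,
) -> Tuple[Vector, int]:
    """Hypothetical driver: apply `step` until convergence or max_iterations.

    Returns the final output and the number of iterations performed.
    """
    y = list(y0)
    for i in range(max_iterations):
        y_new = step(y)
        if has_converged(y, y_new, tol):
            return y_new, i + 1
        y = y_new
    return y, max_iterations


def halve(y: Vector) -> Vector:
    """Toy update that shrinks every element toward zero."""
    return [0.5 * v for v in y]


_, n_default = run_iteration(halve, [1.0])        # tol=0.0: runs all 40
_, n_tol = run_iteration(halve, [1.0], tol=1e-3)  # stops early
print(n_default, n_tol)  # 40 10
```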

Source code in pynsm/arch/similarity.py
def __init__(
    self,
    encoders: Sequence[nn.Module],
    out_channels: int,
    tau: float = 0.1,
    tol: float = 0.0,
    max_iterations: int = 40,
    regularization: Union[str, Sequence[str]] = "weight",
    **kwargs,
):
    super().__init__(max_iterations=max_iterations, **kwargs)

    self.encoders = nn.ModuleList(encoders)
    self.out_channels = out_channels
    self.tau = tau
    self.tol = tol

    if isinstance(regularization, str):
        self.regularization = [regularization] * len(self.encoders)
    else:
        self.regularization = regularization

    for crt_reg in self.regularization:
        if crt_reg not in ["weight", "whiten", "none"]:
            raise ValueError(f"Unknown regularization {crt_reg}")

    self.competitor = nn.Linear(out_channels, out_channels, bias=False)
    torch.nn.init.eye_(self.competitor.weight)

    # make sure we maximize with respect to competitor weight...
    # ...and implement the learning rate ratio
    scaling = -1.0 / tau
    self.competitor.weight.register_hook(lambda g: g * scaling)

    self.y = torch.tensor([])
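The two comments before the register_hook call can be checked on a scalar toy problem. In this sketch (an illustration, not library code), make_competitor_hook mirrors the lambda above, sgd_step is a hypothetical plain gradient-descent update, and we assume the encoders and the competitor share one optimizer with learning rate lr. Multiplying the gradient by -1/tau turns descent into ascent with an effective learning rate of lr / tau.

```python
import math
from typing import Callable, Optional


def make_competitor_hook(tau: float) -> Callable[[float], float]:
    """Mirror of the hook in __init__: negate the gradient and scale it by 1 / tau."""
    scaling = -1.0 / tau
    return lambda g: g * scaling


def sgd_step(w: float, grad: float, lr: float,
             hook: Optional[Callable[[float], float]] = None) -> float:
    """One plain gradient-descent step; the hook (if any) rewrites the gradient first."""
    if hook is not None:
        grad = hook(grad)
    return w - lr * grad


def grad_objective(m: float) -> float:
    """Gradient of the toy objective L(m) = -(m - 3)^2, which we want to maximize."""
    return -2.0 * (m - 3.0)


hook = make_competitor_hook(tau=0.1)

# Without the hook, descent moves away from the maximum at m = 3 ...
print(sgd_step(0.0, grad_objective(0.0), lr=0.01))             # ~ -0.06
# ... with it, the step is ascent with effective learning rate 0.01 / 0.1 = 0.1.
print(sgd_step(0.0, grad_objective(0.0), lr=0.01, hook=hook))  # ~ 0.6

m = 0.0
for _ in range(100):
    m = sgd_step(m, grad_objective(m), lr=0.01, hook=hook)
print(round(m, 6))  # 3.0
```

Using a gradient hook keeps a single optimizer for the whole module while still giving the competitor its own sign and learning-rate ratio.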

iteration_loss

iteration_loss(
    *args: Optional[torch.Tensor],
) -> torch.Tensor

Loss function associated with the iteration.

This is not actually used by the iteration, which instead uses manually calculated gradients (for efficiency).

Source code in pynsm/arch/similarity.py
def iteration_loss(self, *args: Optional[torch.Tensor]) -> torch.Tensor:
    """Loss function associated with the iteration.

    This is not actually used by the iteration, which instead uses manually
    calculated gradients (for efficiency).
    """
    assert self._Wx is not None
    loss = self._loss_no_reg(self._Wx, self.y, "sum")
    return loss / 4

SimilarityMatching

SimilarityMatching(
    encoder: nn.Module, out_channels: int, **kwargs
)

Bases: MultiSimilarityMatching

Single-input similarity matching circuit.

This is a thin wrapper around MultiSimilarityMatching using a single target.

Parameters:

  • encoder (Module) –

    module to use for encoding the inputs

  • out_channels (int) –

    number of output channels

  • **kwargs

    additional keyword arguments are passed to MultiSimilarityMatching

Source code in pynsm/arch/similarity.py
def __init__(self, encoder: nn.Module, out_channels: int, **kwargs):
    super().__init__(encoders=[encoder], out_channels=out_channels, **kwargs)

SupervisedSimilarityMatching

SupervisedSimilarityMatching(
    encoder: nn.Module,
    num_classes: int,
    out_channels: int,
    label_bias: bool = False,
    **kwargs
)

Bases: MultiSimilarityMatching

Supervised similarity matching circuit for classification.

This is a wrapper around MultiSimilarityMatching with two encoders, self.encoders == [encoder, label_encoder]. The label_encoder, an nn.Linear(num_classes, out_channels) layer, is created internally; it maps floating-point one-hot labels (the format forward() expects) to a one-dimensional output of size out_channels. The forward() iteration extends the output of the label_encoder to the same shape as the output of encoder, so that both can be used in the similarity matching objective.
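The label path described above can be sketched with plain Python lists standing in for tensors. This is an illustration under assumptions, not pynsm's code: the helper names are hypothetical, the label encoder is taken to be bias-free (the label_bias=False default), and the encoder is assumed to produce a per-sample output of shape (out_channels, H, W), e.g. from a convolutional encoder; repeating each channel over the spatial grid is our guess at how the output is "extended". Note that a bias-free linear map applied to a one-hot vector simply selects one column of its weight matrix.

```python
from typing import List

Vector = List[float]


def one_hot(label: int, num_classes: int) -> Vector:
    """Floating-point one-hot label, the format forward() expects."""
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec


def linear_no_bias(weight: List[Vector], x: Vector) -> Vector:
    """y = W x, with W of shape (out_channels, num_classes)."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def expand_to_spatial(v: Vector, height: int, width: int) -> List[List[Vector]]:
    """Assumed expansion: repeat each channel over a height x width grid, (C,) -> (C, H, W)."""
    return [[[c] * width for _ in range(height)] for c in v]


# num_classes = 3, out_channels = 2
W = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0]]
z = linear_no_bias(W, one_hot(2, num_classes=3))  # picks column 2 of W
print(z)  # [3.0, 6.0]
print(expand_to_spatial(z, height=2, width=2))
# [[[3.0, 3.0], [3.0, 3.0]], [[6.0, 6.0], [6.0, 6.0]]]
```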

Note that, unless regularization is passed explicitly, "weight" regularization is used for the input encoder and "whiten" regularization for the label encoder.
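How the regularization setting resolves to one entry per encoder can be sketched by combining the default from SupervisedSimilarityMatching.__init__ with the normalization in MultiSimilarityMatching.__init__. The function names below are hypothetical; the logic mirrors the source shown on this page. One consequence worth noting: passing a single string, such as regularization="none", applies it to both encoders, including the label encoder.

```python
from typing import Any, Dict, List, Sequence, Union

VALID_REGULARIZATIONS = ("weight", "whiten", "none")


def resolve_regularization(
    regularization: Union[str, Sequence[str]], n_encoders: int
) -> List[str]:
    """Broadcast a single string to every encoder, then validate each entry."""
    if isinstance(regularization, str):
        regs = [regularization] * n_encoders
    else:
        regs = list(regularization)
    for reg in regs:
        if reg not in VALID_REGULARIZATIONS:
            raise ValueError(f"Unknown regularization {reg}")
    return regs


def supervised_kwargs(**kwargs: Any) -> Dict[str, Any]:
    """Fill in the supervised default: "weight" for the input, "whiten" for labels."""
    if "regularization" not in kwargs:
        kwargs["regularization"] = ("weight", "whiten")
    return kwargs


print(resolve_regularization(supervised_kwargs()["regularization"], 2))
# ['weight', 'whiten']
print(resolve_regularization(supervised_kwargs(regularization="none")["regularization"], 2))
# ['none', 'none']
```

Like the source, this sketch does not check that a sequence has exactly one entry per encoder.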

Parameters:

  • encoder (Module) –

    module to use for encoding the inputs

  • num_classes (int) –

    number of classes for classification

  • out_channels (int) –

    number of output channels

  • label_bias (bool, default: False ) –

    set to True to include a bias term in the label encoder

  • **kwargs

    additional keyword arguments are passed to MultiSimilarityMatching

Source code in pynsm/arch/similarity.py
def __init__(
    self,
    encoder: nn.Module,
    num_classes: int,
    out_channels: int,
    label_bias: bool = False,
    **kwargs,
):
    label_encoder = nn.Linear(num_classes, out_channels, bias=label_bias)
    if "regularization" not in kwargs:
        kwargs["regularization"] = ("weight", "whiten")
    super().__init__(
        out_channels=out_channels, encoders=[encoder, label_encoder], **kwargs
    )