Extending the library
To extend the library, start by following the developer installation instructions on GitHub.
Adding a similarity-based model
You can add models by creating a new file in the arch folder. Variants of similarity matching can be built by inheriting from the MultiSimilarityMatching class. For instance, canonical correlation analysis (CCA) [1] can be implemented like this:
```python
from torch import nn
from pynsm import MultiSimilarityMatching


class SimilarityMatchingCCA(MultiSimilarityMatching):
    def __init__(self, dim1: int, dim2: int, out_channels: int, **kwargs):
        encoders = [
            nn.Linear(dim1, out_channels), nn.Linear(dim2, out_channels)
        ]
        super().__init__(encoders, regularization="whiten", **kwargs)
```
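When developing a new similarity-matching variant such as this one, it can help to compare what a trained model learns against a closed-form reference. The sketch below computes classical CCA in plain PyTorch, independently of pynsm: it whitens each view with the inverse square root of its covariance and takes an SVD of the whitened cross-covariance. The function names `inv_sqrt` and `cca_closed_form` and the synthetic data are illustrative assumptions, not part of the library.

```python
import torch


def inv_sqrt(cov: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Inverse square root of a symmetric positive-definite matrix."""
    evals, evecs = torch.linalg.eigh(cov)
    evals = evals.clamp_min(eps)  # guard against tiny or negative eigenvalues
    return evecs @ torch.diag(evals.rsqrt()) @ evecs.T


def cca_closed_form(x: torch.Tensor, y: torch.Tensor, k: int):
    """Classical CCA of two views x (n, dim1) and y (n, dim2).

    Returns projections a (dim1, k) and b (dim2, k) and the top-k
    canonical correlations, sorted in decreasing order.
    """
    x = x - x.mean(dim=0)
    y = y - y.mean(dim=0)
    n = x.shape[0]
    cxx = x.T @ x / (n - 1)
    cyy = y.T @ y / (n - 1)
    cxy = x.T @ y / (n - 1)
    wx, wy = inv_sqrt(cxx), inv_sqrt(cyy)
    # Singular values of the whitened cross-covariance are the canonical
    # correlations; singular vectors give the directions in whitened space.
    u, s, vh = torch.linalg.svd(wx @ cxy @ wy, full_matrices=False)
    a = wx @ u[:, :k]
    b = wy @ vh.T[:, :k]
    return a, b, s[:k]


# Synthetic example: both views contain the same 2-d latent signal z,
# mixed with view-specific noise by random linear maps.
torch.manual_seed(0)
n = 5000
z = torch.randn(n, 2)
x = torch.cat([z, torch.randn(n, 3)], dim=1) @ torch.randn(5, 5)
y = torch.cat([z, torch.randn(n, 2)], dim=1) @ torch.randn(4, 4)
a, b, corr = cca_closed_form(x, y, k=2)
print(corr)  # both values should be close to 1
```

Canonical directions are only defined up to sign (and, when correlations are repeated as here, up to a rotation), so a learned model is best compared with this reference at the level of subspaces rather than individual columns.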
After adding a new model, import it in arch/__init__.py and in the top-level __init__.py so that users can access it directly from the package.
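For example, if the CCA model above lived in a new file arch/cca.py (a hypothetical file name), registering it could look like the fragment below. It assumes arch is a subpackage of the top-level package; the exact import style, and whether an `__all__` list also needs updating, should follow what the existing __init__.py files already do.

```python
# In arch/__init__.py (hypothetical module name "cca"):
from .cca import SimilarityMatchingCCA

# In the top-level __init__.py, re-export it so that
# `from pynsm import SimilarityMatchingCCA` works:
from .arch import SimilarityMatchingCCA
```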
Adding a more generic iteration-based model
For models whose forward pass requires iteration but that are not based on similarity matching, you can instead inherit from IterationModule or IterationLossModule. The former requires you to specify the processing performed at each iteration (by overriding the iteration() method, as in the example below), while the latter assumes that the iteration is gradient-based, so only the loss function needs to be specified. In both cases, you also have to define the state variables by overriding the pre_iteration() method and define the return value by overriding the post_iteration() method.
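The division of labor between these hooks can be illustrated with a simplified, hypothetical stand-in written in plain PyTorch. This is not pynsm's implementation: the class names, the `iteration_loss()` hook, the `lr` step size, and the use of plain gradient descent are assumptions made for illustration, so check the pynsm source for the real method names and signatures. In this sketch, `forward()` calls `pre_iteration()` once, `iteration()` `max_iterations` times, and `post_iteration()` once; the loss-based variant implements `iteration()` as a gradient step on a user-supplied loss.

```python
import torch
from torch import nn


class SketchIterationModule(nn.Module):
    """Hypothetical stand-in for an iteration-based module (not pynsm code)."""

    def __init__(self, max_iterations: int = 100):
        super().__init__()
        self.max_iterations = max_iterations

    def pre_iteration(self, *args):
        """Initialize state variables (override in subclasses)."""

    def iteration(self, *args):
        """Perform one step of the dynamics (override in subclasses)."""
        raise NotImplementedError

    def post_iteration(self, *args):
        """Compute the value returned by forward() (override in subclasses)."""
        return None

    def forward(self, *args):
        self.pre_iteration(*args)
        for _ in range(self.max_iterations):
            self.iteration(*args)
        return self.post_iteration(*args)


class SketchIterationLossModule(SketchIterationModule):
    """Hypothetical gradient-based variant: subclasses only supply a loss."""

    def __init__(self, max_iterations: int = 100, lr: float = 0.1):
        super().__init__(max_iterations)
        self.lr = lr

    def iteration_loss(self, *args) -> torch.Tensor:
        """Scalar loss as a function of self.state (hypothetical hook name)."""
        raise NotImplementedError

    def iteration(self, *args):
        # One plain gradient-descent step on self.state; enable_grad() keeps
        # this working even if forward() is called under torch.no_grad().
        with torch.enable_grad():
            state = self.state.detach().requires_grad_(True)
            self.state = state
            loss = self.iteration_loss(*args)
            (grad,) = torch.autograd.grad(loss, state)
        self.state = (state - self.lr * grad).detach()


class ShrinkToward(SketchIterationLossModule):
    """Toy model: minimize 0.5*||z - x||^2 + 0.5*lam*||z||^2 over z."""

    def __init__(self, lam: float = 1.0, **kwargs):
        super().__init__(**kwargs)
        self.lam = lam

    def pre_iteration(self, x: torch.Tensor):
        self.state = torch.zeros_like(x)

    def iteration_loss(self, x: torch.Tensor) -> torch.Tensor:
        z = self.state
        return 0.5 * ((z - x) ** 2).sum() + 0.5 * self.lam * (z**2).sum()

    def post_iteration(self, x: torch.Tensor) -> torch.Tensor:
        return self.state


model = ShrinkToward(lam=1.0, max_iterations=100, lr=0.1)
print(model(torch.tensor([2.0, -4.0])))  # approaches x / (1 + lam) = [1, -2]
```

In the toy subclass with `lam=1.0` and `lr=0.1`, each step maps z to 0.8 z + 0.1 x, so the iteration converges geometrically to x / 2, the minimizer of the loss.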
Here is an example where the forward iteration generates the Mandelbrot set:
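```python
import torch
from pynsm import IterationModule


class Mandelbrot(IterationModule):
    def pre_iteration(self, c: torch.Tensor):
        # State variables: the iterate z (starting at 0) and, for each point
        # c, the number of iterations for which |z| has stayed below 2.
        self.state = torch.zeros_like(c, dtype=torch.complex64)
        self.counts = torch.zeros_like(self.state, dtype=torch.long)

    def iteration(self, c: torch.Tensor):
        # One step of the map z -> z**2 + c.
        self.state = self.state ** 2 + c
        mask = torch.abs(self.state) < 2
        self.counts[mask] += 1

    def post_iteration(self, c: torch.Tensor) -> torch.Tensor:
        # Points that never escaped are treated as members of the set and
        # marked with NaN, which imshow renders as masked (blank) pixels.
        self.counts = self.counts.float()
        self.counts[self.counts == self.max_iterations] = float("nan")
        return self.counts
```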
We can test the code as follows:
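```python
import matplotlib.pyplot as plt
import torch

# Region of the complex plane to sample: (re_min, re_max, im_min, im_max).
extent = (-2, 0.8, -1.4, 1.4)
x = torch.linspace(extent[0], extent[1], 500)
y = torch.linspace(extent[2], extent[3], 500)
# With indexing="xy", rows follow y (imaginary part), columns follow x.
grid_real = torch.meshgrid(x, y, indexing="xy")
grid = torch.complex(*grid_real)

m = Mandelbrot(max_iterations=20)
counts = m(grid)

# origin="lower" places row 0 (imaginary part -1.4) at the bottom,
# consistent with extent.
plt.imshow(counts, extent=extent, origin="lower", cmap="plasma")
plt.show()
```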
This yields a familiar picture of the Mandelbrot set.
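To cross-check the module's output, or to experiment without installing pynsm, the same escape-time computation can be written as a plain loop. The helper name `escape_counts` is hypothetical; it folds the three methods of the Mandelbrot class into one function and spot-checks a few points whose behavior is easy to verify by hand.

```python
import torch


def escape_counts(c: torch.Tensor, max_iterations: int = 20) -> torch.Tensor:
    """Escape-time counts computed without pynsm (hypothetical helper)."""
    z = torch.zeros_like(c, dtype=torch.complex64)
    counts = torch.zeros(c.shape, dtype=torch.long)
    for _ in range(max_iterations):
        z = z**2 + c                   # same update as Mandelbrot.iteration
        counts[torch.abs(z) < 2] += 1  # count steps with |z| < 2
    counts = counts.float()
    counts[counts == max_iterations] = float("nan")  # never escaped
    return counts


c = torch.tensor([0 + 0j, -1 + 0j, 1 + 0j, 2 + 0j], dtype=torch.complex64)
# c = 0 and c = -1 stay bounded (NaN); c = 1 is counted once (z = 1, then 2);
# c = 2 reaches |z| = 2 on the first step, so its count is 0.
print(escape_counts(c))
```

With `grid` and `counts` from the test snippet, `torch.allclose(counts, escape_counts(grid, 20), equal_nan=True)` should hold, assuming IterationModule runs exactly `max_iterations` iterations.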

[1] Lipshutz, D., Bahroun, Y., Golkar, S., Sengupta, A. M., & Chklovskii, D. B. (2021). A biologically plausible neural network for multichannel canonical correlation analysis. Neural Computation, 33(9), 2309–2352. https://doi.org/10.1162/neco_a_01414