arch
Module definitions.
IterationLossModule
IterationLossModule(
iteration_optimizer: Callable = torch.optim.SGD,
iteration_scheduler: Optional[Callable] = None,
iteration_lr: float = 1.0,
it_optim_kwargs: Optional[Dict[str, Any]] = None,
it_sched_kwargs: Optional[Dict[str, Any]] = None,
iteration_projection: Optional[Callable] = None,
**kwargs: Optional[Callable]
)
Bases: IterationModule
A specialization of IterationModule where the iteration is derived from a loss function.
This creates an optimizer in pre_iteration(), then for each iteration runs backward() on the output from iteration_loss() and steps the optimizer. This is followed by an optional projection step; see iteration_projection below. Note that projection is a very simple way of enforcing constraints, and might not work well with adaptive-step optimizers.
The constructor has options for choosing the optimizer to use, as well as for an optional learning-rate scheduler; see below.
Functions to implement:
- iteration_loss(*args, **kwargs) should return the loss; the output is stored in self.last_iteration_loss as a number (i.e., item() is called on the tensor output from iteration_loss()).
- iteration_parameters() should return a list of parameters to be optimized during the iteration.
Note that typically the iteration_parameters() should not be included in the module's parameters(), but should potentially be saved as part of the state_dict, so it is recommended that they be registered as buffers.
The optimization and scheduling features of the class can be used in combination with manually calculated gradients by overriding iteration_set_gradients(*args, **kwargs).
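To illustrate the pattern, here is a standalone PyTorch sketch (plain torch, no pynsm import; the QuadraticSolver class and all of its details are hypothetical) of the loop that IterationLossModule automates: repeated backward() calls on a loss, an optimizer step, then a projection, with the iteration variable registered as a buffer as recommended above.

```python
import torch

# Hypothetical standalone example mimicking the IterationLossModule loop.
class QuadraticSolver(torch.nn.Module):
    def __init__(self, n: int):
        super().__init__()
        # iteration variable registered as a buffer, not a parameter
        self.register_buffer("z", torch.zeros(n))

    def iteration_loss(self, target: torch.Tensor) -> torch.Tensor:
        return ((self.z - target) ** 2).sum()

    def forward(self, target: torch.Tensor, n_steps: int = 100) -> torch.Tensor:
        # pre_iteration: enable gradients and build the optimizer
        self.z.requires_grad_(True)
        optimizer = torch.optim.SGD([self.z], lr=0.1)
        for _ in range(n_steps):
            optimizer.zero_grad()
            self.iteration_loss(target).backward()
            optimizer.step()
            with torch.no_grad():
                self.z.clamp_(min=0.0)  # projection enforcing non-negativity
        # post_iteration: freeze the iteration variable again
        self.z.requires_grad_(False)
        return self.z.detach().clone()

solver = QuadraticSolver(3)
out = solver(torch.tensor([1.0, -2.0, 3.0]))
```

With non-negativity as the projection, the iterate converges to the target clamped at zero; in pynsm the same effect is obtained by passing a callable such as torch.nn.functional.relu as iteration_projection.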
Module constructor:
Parameters:
- iteration_optimizer (Callable, default: SGD) – optimizer to use for the forward iteration, e.g. torch.optim.SGD
- iteration_scheduler (Optional[Callable], default: None) – scheduler to use (if any) for the forward iteration, e.g. torch.optim.lr_scheduler.StepLR
- iteration_lr (float, default: 1.0) – learning rate for the forward iteration; this is a shortcut that overrides any learning rate set in it_optim_kwargs
- it_optim_kwargs (Optional[Dict[str, Any]], default: None) – dictionary of keyword arguments to pass to the optimizer
- it_sched_kwargs (Optional[Dict[str, Any]], default: None) – dictionary of keyword arguments to pass to the scheduler
- iteration_projection (Optional[Callable], default: None) – optional projection to perform after each optimizer step; this should be a callable that is applied to each element of self.iteration_parameters() (e.g., torch.nn.functional.relu)
- **kwargs – other keyword arguments are passed to IterationModule
Source code in /home/docs/checkouts/readthedocs.org/user_builds/pynsm/envs/latest/lib/python3.11/site-packages/pynsm/arch/base.py
iteration
Run one iteration.
This calculates the gradients using self.iteration_set_gradients(), then steps the optimizer and scheduler (if any), and finally projects the result using self.iteration_projection.
iteration_loss
Loss function used for the iteration.
Abstract method, has to be implemented by descendants. (Alternatively,
iteration_set_gradients()
can be overridden to avoid calling iteration_loss()
altogether.)
iteration_parameters
Return list of iteration parameters.
Abstract method, has to be implemented by descendants.
iteration_set_gradients
Calculate gradients for the iteration.
This uses a backward pass on the result from iteration_loss(). Override as needed to process the gradients before they are used in the optimizer.
post_iteration
Post-iteration processing.
This sets requires_grad
to False
for the iteration_parameters()
and to
whatever it was before the iteration for the parameters()
.
pre_iteration
Pre-iteration processing.
This sets requires_grad
to True
for the iteration_parameters()
and to
False
for the parameters()
. It also generates an optimizer and a scheduler,
if one is requested.
IterationModule
Bases: Module
A module where the forward pass is called iteratively.
The forward()
method calls self.iteration()
iteratively, until either a maximum
number of steps is reached, or self.converged()
is true. Any arguments passed to
forward()
are passed along:
self.iteration(*args, **kwargs)
When done iterating, forward() returns the output from self.post_iteration().
The methods self.pre_iteration()
and self.post_iteration()
can also be used to
perform any necessary pre- and post-processing, as they are called before the first
iteration and after the last, respectively. They are passed the arguments passed to
forward()
:
self.pre_iteration(*args, **kwargs)
self.post_iteration(*args, **kwargs)
By default, these do nothing and return nothing.
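The control flow just described can be sketched in pure Python (illustrative names, not the pynsm source):

```python
# Sketch of the control flow that IterationModule's forward() implements.
class IterationSketch:
    def __init__(self, max_iterations: int = 1000):
        self.max_iterations = max_iterations

    def pre_iteration(self, *args, **kwargs):
        pass  # by default does nothing

    def iteration(self, *args, **kwargs):
        raise NotImplementedError  # abstract: must be overridden

    def converged(self, *args, **kwargs) -> bool:
        return False  # never converges unless overridden

    def post_iteration(self, *args, **kwargs):
        return None  # forward() returns this value

    def forward(self, *args, **kwargs):
        self.pre_iteration(*args, **kwargs)
        for _ in range(self.max_iterations):
            self.iteration(*args, **kwargs)
            if self.converged(*args, **kwargs):
                break
        return self.post_iteration(*args, **kwargs)

# Usage: a toy descendant that "converges" when a counter hits the target.
class Counter(IterationSketch):
    def pre_iteration(self, target):
        self.count = 0

    def iteration(self, target):
        self.count += 1

    def converged(self, target):
        return self.count >= target

    def post_iteration(self, target):
        return self.count

result = Counter().forward(7)  # returns 7: stops as soon as converged() is true
```

Note that all calls receive the same arguments that were passed to forward(), exactly as described above.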
Module constructor:
Parameters:
- max_iterations (int, default: 1000) – maximum number of iteration() calls in one forward pass
converged
Check whether the iteration converged.
Always returns False. Override in descendants as needed.
forward
Forward pass.
This sets up the iteration using self.pre_iteration(), then runs at most self.max_iterations calls to self.iteration, checking self.converged() after every iteration, and finally obtains the return value by calling self.post_iteration().
All positional and keyword arguments are passed to all of the calls.
iteration
Run one iteration.
Abstract method, has to be implemented by descendants.
post_iteration
Finalize the iteration and generate a return value.
Does nothing by default. Override in descendants as needed.
pre_iteration
Set up the iteration.
Does nothing by default. Override in descendants as needed.
register_iteration_hook
Register a hook for the forward iteration.
Hooks can be assigned for the following events:
- "pre": called before pre_iteration(); signature: hook(module)
- "post": called after post_iteration(); signature: hook(module)
- "iteration": called after every call to iteration() (before converged()); signature: hook(module) -> bool; a truthful return value ends the iteration. The iteration index is available in module.iteration_idx.
Multiple hooks can be attached to the same event.
Parameters:
- kind (str) – the kind of hook to register
- hook (Callable) – the function to be called
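The hook semantics above can be sketched in pure Python (the real register_iteration_hook lives in pynsm; the names here are illustrative, and the pre_iteration()/post_iteration() calls are omitted for brevity):

```python
# Sketch of the "iteration" hook mechanism: a truthful return value
# from any registered hook ends the forward iteration early.
class HookedLoop:
    def __init__(self, max_iterations: int = 1000):
        self.max_iterations = max_iterations
        self._hooks = {"pre": [], "post": [], "iteration": []}

    def register_iteration_hook(self, kind: str, hook) -> None:
        self._hooks[kind].append(hook)

    def iteration(self):
        pass  # one step of the forward iteration

    def forward(self):
        for hook in self._hooks["pre"]:
            hook(self)
        for self.iteration_idx in range(self.max_iterations):
            self.iteration()
            # check the "iteration" hooks after every iteration() call
            if any(hook(self) for hook in self._hooks["iteration"]):
                break
        for hook in self._hooks["post"]:
            hook(self)
        return self.iteration_idx

# Usage: stop early once the iteration index reaches 4.
loop = HookedLoop()
loop.register_iteration_hook("iteration", lambda m: m.iteration_idx >= 4)
```

A hook that returns None (e.g., one used purely for logging) never stops the loop, since None is falsy.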
MultiSimilarityMatching
MultiSimilarityMatching(
encoders: Sequence[nn.Module],
out_channels: int,
tau: float = 0.1,
tol: float = 0.0,
max_iterations: int = 40,
regularization: Union[str, Sequence[str]] = "weight",
**kwargs: Union[str, Sequence[str]]
)
Bases: IterationLossModule
Multiple-target similarity matching circuit.
Some of the encoders can be skipped during the forward()
call either by including
fewer arguments than len(encoders)
or by setting some to None
.
Parameters:
- encoders (Sequence[Module]) – modules to use for encoding the inputs
- out_channels (int) – number of output channels
- tau (float, default: 0.1) – factor by which to divide the competitor's learning rate
- tol (float, default: 0.0) – tolerance for the convergence test (disabled by default); if the change in every element of the output after an iteration is smaller than tol in absolute value, the iteration is assumed to have converged
- max_iterations (int, default: 40) – maximum number of iterations to run in forward()
- regularization (Union[str, Sequence[str]], default: 'weight') – type of encoder regularization to use; this can be a single string, or a sequence to have different regularizations for each encoder; options are:
  * "weight": use the encoders' parameters; regularization is added for all the tensors returned by encoder.parameters(), as long as those tensors are trainable (i.e., requires_grad is true)
  * "whiten": use a regularizer that encourages whitening
  * "none": do not use regularization for the encoder; most useful to allow for custom regularization, since lack of regularization leads to unstable dynamics in many cases
- **kwargs – additional keyword arguments passed to IterationLossModule
Source code in /home/docs/checkouts/readthedocs.org/user_builds/pynsm/envs/latest/lib/python3.11/site-packages/pynsm/arch/similarity.py
iteration_loss
Loss function associated with the iteration.
This is not actually used by the iteration, which instead uses manually calculated gradients (for efficiency).
SimilarityMatching
Bases: MultiSimilarityMatching
Single-input similarity matching circuit.
This is a thin wrapper around MultiSimilarityMatching
using a single target.
Parameters:
- encoder (Module) – module to use for encoding the inputs
- out_channels (int) – number of output channels
- **kwargs – additional keyword arguments go to MultiSimilarityMatching
SupervisedSimilarityMatching
SupervisedSimilarityMatching(
encoder: nn.Module,
num_classes: int,
out_channels: int,
label_bias: bool = False,
**kwargs: bool
)
Bases: MultiSimilarityMatching
Supervised similarity matching circuit for classification.
This is a wrapper that uses MultiSimilarityMatching
with two encoders,
self.encoders == [encoder, label_encoder]
. The label_encoder
is generated
internally to map floating-point one-hot labels (which is what forward()
expects) to a one-dimensional output of size out_channels
. The forward()
iteration is adapted to extend the output of the label_encoder
to the same shape
as the output from the encoder()
, so it can be used in the similarity matching
objective.
Note that by default "whiten"
regularization is used for the label encoder and
"weight"
regularization for the input encoder.
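For reference, the floating-point one-hot labels that forward() expects can be built from integer class labels with standard PyTorch calls (the batch and class count below are made up for illustration):

```python
import torch
import torch.nn.functional as F

# Hypothetical batch of integer labels for a 4-class problem.
labels = torch.tensor([0, 2, 1])

# forward() expects floating-point one-hot labels:
one_hot = F.one_hot(labels, num_classes=4).float()
# one_hot is a (3, 4) float tensor with a single 1.0 per row
```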
Parameters:
- encoder (Module) – module to use for encoding the inputs
- num_classes (int) – number of classes for classification
- out_channels (int) – number of output channels
- label_bias (bool, default: False) – set to true to include a bias term in the label encoder
- **kwargs – additional keyword arguments go to MultiSimilarityMatching