Clustering using NSM
This notebook demonstrates clustering with non-negative similarity matching (NSM).
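For reference, similarity matching learns outputs whose pairwise similarities track those of the inputs. One standard formulation of the non-negative objective (a refresher; the exact objective optimized by pynsm may differ in its parametrization) is

$$\min_{\mathbf{Y} \ge 0} \; \frac{1}{T^2} \left\lVert \mathbf{X}^\top \mathbf{X} - \mathbf{Y}^\top \mathbf{Y} \right\rVert_F^2,$$

where the columns of $\mathbf{X}$ are the $T$ input samples and the columns of $\mathbf{Y}$ are the corresponding non-negative outputs.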
In [1]:
import time
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch
from torch import nn
from pynsm import SimilarityMatching
Generate a dataset using a mixture of Gaussians
First generate samples for each cluster.
In [2]:
torch.manual_seed(1)

input_dim = 50
n_clusters = 5
n_samples_each = 20_000

samples0 = []
sample_ids = []
for i in range(n_clusters):
    mu = 2.5 * torch.randn(input_dim)
    gauss = torch.distributions.multivariate_normal.MultivariateNormal(
        mu, torch.eye(input_dim)
    )
    crt_samples = gauss.sample(sample_shape=(n_samples_each,))

    samples0.append(crt_samples)
    sample_ids.extend(n_samples_each * [i])

n_samples = n_clusters * n_samples_each
sample_ids = torch.tensor(sample_ids)
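As a quick sanity check (not part of the original notebook), the generated data should consist of n_clusters batches of n_samples_each samples, each of dimension input_dim:

assert len(samples0) == n_clusters
assert samples0[0].shape == (n_samples_each, input_dim)
assert sample_ids.shape == (n_samples,)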
Shuffle the samples so that those belonging to the same cluster are no longer grouped together.
In [3]:
perm = torch.randperm(n_samples)
samples = torch.vstack(samples0)[perm]
sample_ids = sample_ids[perm]
Make a loader for training.
In [4]:
batch_size = 100
dataset = torch.utils.data.TensorDataset(samples, sample_ids)
train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
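To confirm the loader yields batches of the expected shape, you can peek at one batch (a quick check, not in the original notebook):

xb, yb = next(iter(train_loader))
print(xb.shape, yb.shape)  # torch.Size([100, 50]) torch.Size([100])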
Visualize the dataset.
In [5]:
_, ax = plt.subplots()
ax.scatter(
    samples[:, 0].numpy(),
    samples[:, 1].numpy(),
    s=2,
    c=sample_ids.numpy(),
    alpha=0.1,
)
ax.set_xlabel("dim 0")
ax.set_ylabel("dim 1")
Out[5]:
Text(0, 0.5, 'dim 1')
Train NSM
In [6]:
torch.manual_seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on {device}.")

encoder = nn.Linear(input_dim, n_clusters)
model = SimilarityMatching(
    encoder,
    out_channels=n_clusters,
    tau=0.5,
    iteration_projection=nn.ReLU(),
    max_iterations=200,
    iteration_lr=0.05,
    tol=1e-4,
).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

n_batches = n_samples // batch_size
assert n_samples % batch_size == 0

# decay the learning rate as 1 / (1 + 20 * sqrt(i / n_batches)),
# floored at 1% of the base rate
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda i: max(1.0 / (1.0 + 20 * (i / n_batches) ** 0.5), 0.01)
)

t0 = time.time()
running_loss = []
inputs = []
outputs = []
targets = []
pbar = tqdm(train_loader)
for sample, sample_id in pbar:
    sample = sample.to(device)

    # compute the model output (iterated to convergence, subject to the
    # max_iterations and tol settings above), then step on the loss
    output = model(sample)
    loss = model.loss(sample, output)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()

    # keep per-batch records for the plots below
    running_loss.append(loss.item())
    inputs.append(sample.detach().cpu())
    outputs.append(output.detach().cpu())
    targets.append(sample_id.detach().cpu())

    pbar.set_postfix({"loss": running_loss[-1]}, refresh=False)

t1 = time.time()
print(f"Training took {t1 - t0:.2f} seconds.")
Running on cpu.
100%|██████████| 1000/1000 [00:04<00:00, 218.91it/s, loss=-2890.5]
Training took 4.58 seconds.
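Once trained, the model can be applied to fresh inputs just as during training; cluster assignments are read off as the argmax over output channels (a minimal sketch, not part of the original run):

fresh = samples[:5].to(device)
fresh_output = model(fresh).detach()
print(fresh_output.argmax(dim=1))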
Show the learning curve.
In [7]:
fig, ax = plt.subplots()
ax.plot(running_loss, lw=1.0)
ax.set_xlabel("batch")
ax.set_ylabel("loss")
sns.despine(ax=ax, offset=10)
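The per-batch loss is noisy; a simple moving average (an optional sketch, not in the original notebook) makes the trend easier to see:

window = 25
smoothed = np.convolve(running_loss, np.ones(window) / window, mode="valid")
plt.plot(smoothed)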
Visualize the clustering output.
In [8]:
inputs_v = np.vstack(inputs)
outputs_v = np.vstack(outputs)
cluster_assignments = np.argmax(outputs_v, axis=1)
_, ax = plt.subplots()
ax.scatter(inputs_v[:, 0], inputs_v[:, 1], s=2, c=cluster_assignments, alpha=0.1)
ax.set_xlabel("dim 0")
ax.set_ylabel("dim 1")
ax.set_title("Colored by inferred clusters")
sns.despine(ax=ax, offset=10)
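It can also be useful to check that all n_clusters output channels are actually used, e.g. by counting how many samples land in each inferred cluster (a quick check, not in the original notebook):

print(np.bincount(cluster_assignments, minlength=n_clusters))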
Show a confusion matrix to check clustering accuracy. Since the model's cluster labels are arbitrary, each ground-truth cluster can map to any output channel; a good result shows a single high entry per row.
In [9]:
targets_v = np.hstack(targets)

confusion = np.zeros((n_clusters, n_clusters))
for i in range(n_clusters):
    selection_gt = targets_v == i
    n_gt = np.sum(selection_gt)
    for j in range(n_clusters):
        selection_assigned = cluster_assignments == j
        confusion[i, j] = np.sum(selection_gt & selection_assigned) / n_gt

_, ax = plt.subplots()
h = ax.imshow(confusion, vmin=0, vmax=1, cmap="Greens")
ax.set_xlabel("inferred")
ax.set_ylabel("ground truth")
plt.colorbar(h)
Out[9]:
<matplotlib.colorbar.Colorbar at 0x16e0926d0>
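Because the learned labels are arbitrary, a single accuracy number is best computed after matching inferred clusters to ground-truth ones, e.g. with the Hungarian algorithm (a sketch assuming scipy is available; not part of the original notebook):

from scipy.optimize import linear_sum_assignment

# each row of `confusion` sums to 1, so the matched entries are per-cluster accuracies
row, col = linear_sum_assignment(confusion, maximize=True)
print(f"Mean per-cluster accuracy: {confusion[row, col].mean():.3f}")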