Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 100 additions & 48 deletions snntorch/functional/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,16 +444,51 @@ def __init__(
self.tolerance_fn = self.Tolerance.apply
self.multi_spike = multi_spike

if not self.target_is_time:
self.on_target = on_target
self.off_target = off_target # override this with final step
# Validate and normalise target specifications
if not target_is_time:
if not multi_spike:
self.on_target = self._validate_scalar_target(on_target, "on_target")
self.off_target = self._validate_scalar_target(off_target, "off_target")
else:
# If multi_spike=True, both on_target and off_target must be iterables
self.on_target = self._validate_sequence_target(on_target, "on_target")
self.off_target = self._validate_sequence_target(off_target, "off_target")

self.first_spike_fn = self.MultiSpike.apply if multi_spike else self.FirstSpike.apply


def _validate_scalar_target(self, value, name):
"""Ensure target is a scalar; tensors with more than one element raise an error."""
if isinstance(value, torch.Tensor):
if value.numel() != 1:
raise TypeError(
f"{name} must be a scalar when multi_spike=False; received tensor with shape {value.shape}. "
"Set multi_spike=True to specify multiple spike times.")
return int(value.item())

elif isinstance(value, int):
return value

raise TypeError(
f"{name} must be an int or a 0-D/1-element torch.Tensor when multi_spike=False; "
f"got {type(value).__name__}."
)

# function used to extract the first F spike times. If
# multi_spike=False, F=1.
if self.multi_spike:
self.first_spike_fn = self.MultiSpike.apply
else:
self.first_spike_fn = self.FirstSpike.apply

def _validate_sequence_target(self, value, name):
"""Convert 1-D tensors or iterables into lists of ints for multi-spike mode."""
if isinstance(value, torch.Tensor):
if value.ndim == 0:
return [int(value.item())]
if value.ndim > 1:
raise TypeError(
f"{name} must be one dimensional when multi_spike=True; received tensor with shape {value.shape}.")
return [int(x) for x in value.flatten().tolist()]
try:
return [int(x) for x in value]
except TypeError:
raise TypeError(
f"{name} must be iterable when multi_spike=True; got object of type {type(value).__name__}.")

# spiking output from final layer is a recording: T x B x N
# targets can either be labels or spike times
Expand Down Expand Up @@ -504,45 +539,66 @@ class FirstSpike(torch.autograd.Function):
Linearize df/dS=-1 if spike, 0 if no spike."""

@staticmethod
def forward(ctx, spk_rec, device="cpu"):
def forward(ctx, spk_rec: torch.Tensor, device="cpu"):
"""Convert spk_rec of 1/0s [TxBxN] --> spk_time [TxBxN].
0's indicate no spike --> +1 is first time step.
Transpose accounts for broadcasting along final dimension
(i.e., multiply along T)."""
T = spk_rec.shape[0]

spk_time = (
spk_rec.transpose(0, -1)
* (torch.arange(0, spk_rec.size(0)).detach().to(device) + 1)
).transpose(0, -1)

"""extact first spike time. Will be used to pass into loss
function."""
first_spike_time = torch.zeros_like(spk_time[0])
for step in range(spk_time.size(0)):
first_spike_time += (
spk_time[step] * ~first_spike_time.bool()
) # mask out subsequent spikes

"""override element 0 (no spike) with shadow spike @ final time
step, then offset by -1
s.t. first_spike is at t=0."""
first_spike_time += ~first_spike_time.bool() * (spk_time.size(0))
first_spike_time -= 1 # fix offset
# spk_time is (T, B, N) with 0 or (t+1).
# Create a mask of shape (T, B, N) indicating if this is the first spike for each (B, N).
# For each (b, n), we want the earliest t at which spk_time[t,b,n]!=0.

# We can do a cumulative sum over time to see if we have "ever spiked" up to that point.
spk_cum = (spk_time > 0).cumsum(dim=0) # shape (T, B, N)
# 'spk_cum[t,b,n]' is how many spikes have occurred up to time step t for (b,n).

# The "first spike" for each (t,b,n) is the place where spk_time[t,b,n]!=0 AND spk_cum[t,b,n]==1
first_spike_mask = (spk_time > 0) & (spk_cum == 1) # shape (T, B, N), bool

# Now each (b,n) can have at most one True in the time dimension, or none if no spike at all.
first_spike_time = (first_spike_mask * spk_time).sum(dim=0)
# shape is (B, N), each entry is the earliest (t+1) we found, or 0 if none.

# Next handle the no-spike neurons by replacing 0 with T
no_spike = (first_spike_time == 0)
first_spike_time[no_spike] = T
first_spike_time -= 1

ctx.save_for_backward(first_spike_time, spk_rec)
return first_spike_time

@staticmethod
def backward(ctx, grad_output):
(first_spike_time, spk_rec) = ctx.saved_tensors
spk_time_grad = torch.zeros_like(spk_rec) # T x B x N

"""spike extraction step/indexing @ each step is
non-differentiable.
Apply sign estimator by substituting gradient for -1 ONLY at
first spike time."""
for i in range(first_spike_time.size(0)):
for j in range(first_spike_time.size(1)):
spk_time_grad[first_spike_time[i, j].long(), i, j] = 1.0

# first_spike_time is (B, N)
spk_time_grad = torch.zeros_like(spk_rec) # (T, B, N)

# Flatten out the (B, N) coordinates
b_coords = torch.arange(first_spike_time.size(0), device=first_spike_time.device)
n_coords = torch.arange(first_spike_time.size(1), device=first_spike_time.device)
# Create a meshgrid to get all (b, n) pairs
B_idx, N_idx = torch.meshgrid(b_coords, n_coords, indexing='ij')
# B_idx, N_idx both are (B, N)

# Time coords from first_spike_time
T_idx = first_spike_time.long() # (B, N)

spk_time_grad[T_idx, B_idx, N_idx] = 1.0
grad = -grad_output * spk_time_grad

return grad, None

@staticmethod
Expand Down Expand Up @@ -623,13 +679,8 @@ class Tolerance(torch.autograd.Function):
# TO-DO: remove ctx?
@staticmethod
def forward(ctx, spk_time, target, tolerance):
spk_time_clone = (
spk_time.clone()
) # spk_time_clone: BxN (FxBxN for multi-spike); target: TxBxN
spk_time_clone[torch.abs(spk_time - target) < tolerance] = (
torch.ones_like(spk_time) * target
)[torch.abs(spk_time - target) < tolerance]
return spk_time_clone
mask = (spk_time - target).abs() <= tolerance
return torch.where(mask, target.to(spk_time.dtype), spk_time)

@staticmethod
def backward(ctx, grad_output):
def label_to_single_spike(self, targets, num_outputs):
    """Convert labels from neuron index (dim: B) to first spike times
    (dim: B x N).

    The labelled class neuron is assigned ``self.on_target``; every
    other neuron is assigned ``self.off_target``.
    """
    batch_size = targets.size(0)

    # Start every neuron at the "incorrect" (off) spike time.
    target_spike_time = torch.full(
        (batch_size, num_outputs),
        self.off_target,
        dtype=torch.float32,
        device=targets.device,
    )

    # Overwrite the labelled class with the "correct" (on) spike time.
    # The row index must live on the same device as the target tensor.
    row_idx = torch.arange(batch_size, device=targets.device)
    target_spike_time[row_idx, targets] = self.on_target

    return target_spike_time

def label_to_multi_spike(self, targets, num_outputs):
"""Convert labels from neuron index (dim: B) to multiple spike times
Expand All @@ -677,16 +728,17 @@ def label_to_multi_spike(self, targets, num_outputs):
f"`on_target` (length: {num_spikes_on}) must have the same "
f"length as `off_target` (length: {num_spikes_off}."
)

batch_size = targets.size(0)

# iterate through each spike
targets_rec = []
for step in range(num_spikes_on):
target_step = spikegen.targets_convert(
targets,
num_classes=num_outputs,
on_target=self.on_target[step],
off_target=self.off_target[step],
)
# Initialize the target tensor with the incorrect timesteps
target_step = torch.full((batch_size, num_outputs), self.off_target[step], dtype=torch.float32, device=targets.device)
# Set the correct class latencies to self.on_target
target_step[torch.arange(batch_size), targets] = self.on_target[step]

targets_rec.append(target_step)
targets_rec = torch.stack(targets_rec)

Expand Down