diff --git a/snntorch/functional/loss.py b/snntorch/functional/loss.py index 3157604b..a64111b6 100644 --- a/snntorch/functional/loss.py +++ b/snntorch/functional/loss.py @@ -444,16 +444,51 @@ def __init__( self.tolerance_fn = self.Tolerance.apply self.multi_spike = multi_spike - if not self.target_is_time: - self.on_target = on_target - self.off_target = off_target # override this with final step + # Validate and normalise target specifications + if not target_is_time: + if not multi_spike: + self.on_target = self._validate_scalar_target(on_target, "on_target") + self.off_target = self._validate_scalar_target(off_target, "off_target") + else: + # If multi_spike=true both on_target and off_targets must be iterables + self.on_target = self._validate_sequence_target(on_target, "on_target") + self.off_target = self._validate_sequence_target(off_target, "off_target") + + self.first_spike_fn = self.MultiSpike.apply if multi_spike else self.FirstSpike.apply + + + def _validate_scalar_target(self, value, name): + """Ensure target is a scalar; tensors with more than one element raise an error.""" + if isinstance(value, torch.Tensor): + if value.numel() != 1: + raise TypeError( + f"{name} must be a scalar when multi_spike=False; received tensor with shape {value.shape}. " + "Set multi_spike=True to specify multiple spike times.") + return int(value.item()) + + elif isinstance(value, int): + return value + + raise TypeError( + f"{name} must be an int or a 0-D/1-element torch.Tensor when multi_spike=False; " + f"got {type(value).__name__}." + ) - # function used to extract the first F spike times. If - # multi_spike=False, F=1. - if self.multi_spike: - self.first_spike_fn = self.MultiSpike.apply - else: - self.first_spike_fn = self.FirstSpike.apply + + def _validate_sequence_target(self, value, name): + """Convert 1-D tensors or iterables into lists of ints for multi-spike mode.""" + if isinstance(value, torch.Tensor): + if value.ndim == 0: + return [int(value.item())] + if value.ndim > 1: + raise TypeError( + f"{name} must be one dimensional when multi_spike=True; received tensor with shape {value.shape}.") + return [int(x) for x in value.flatten().tolist()] + try: + return [int(x) for x in value] + except TypeError: + raise TypeError( + f"{name} must be iterable when multi_spike=True; got object of type {type(value).__name__}.") # spiking output from final layer is a recording: T x B x N # targets can either be labels or spike times @@ -504,45 +539,66 @@ class FirstSpike(torch.autograd.Function): Linearize df/dS=-1 if spike, 0 if no spike.""" @staticmethod - def forward(ctx, spk_rec, device="cpu"): + def forward(ctx, spk_rec: torch.Tensor, device="cpu"): """Convert spk_rec of 1/0s [TxBxN] --> spk_time [TxBxN]. 0's indicate no spike --> +1 is first time step. Transpose accounts for broadcasting along final dimension (i.e., multiply along T).""" + T = spk_rec.shape[0] + spk_time = ( spk_rec.transpose(0, -1) * (torch.arange(0, spk_rec.size(0)).detach().to(device) + 1) ).transpose(0, -1) - """extact first spike time. Will be used to pass into loss - function.""" - first_spike_time = torch.zeros_like(spk_time[0]) - for step in range(spk_time.size(0)): - first_spike_time += ( - spk_time[step] * ~first_spike_time.bool() - ) # mask out subsequent spikes - - """override element 0 (no spike) with shadow spike @ final time - step, then offset by -1 - s.t. first_spike is at t=0.""" - first_spike_time += ~first_spike_time.bool() * (spk_time.size(0)) - first_spike_time -= 1 # fix offset + # spk_time is (T, B, N) with 0 or (t+1). + # Create a mask of shape (T, B, N) indicating if this is the first spike for each (B, N). + # For each (b, n), we want the earliest t at which spk_time[t,b,n]!=0. + + # We can do a cumulative sum over time to see if we have "ever spiked" up to that point. + spk_cum = (spk_time > 0).cumsum(dim=0) # shape (T, B, N) + # 'spk_cum[t,b,n]' is how many spikes have occurred up to time step t for (b,n). + + # The "first spike" for each (t,b,n) is the place where spk_time[t,b,n]!=0 AND spk_cum[t,b,n]==1 + first_spike_mask = (spk_time > 0) & (spk_cum == 1) # shape (T, B, N), bool + + # Now each (b,n) can have at most one True in the time dimension, or none if no spike at all. + first_spike_time = (first_spike_mask * spk_time).sum(dim=0) + # shape is (B, N), each entry is the earliest (t+1) we found, or 0 if none. + + # Next handle the no-spike neurons by replacing 0 with T + no_spike = (first_spike_time == 0) + first_spike_time[no_spike] = T + first_spike_time -= 1 + ctx.save_for_backward(first_spike_time, spk_rec) return first_spike_time @staticmethod def backward(ctx, grad_output): (first_spike_time, spk_rec) = ctx.saved_tensors - spk_time_grad = torch.zeros_like(spk_rec) # T x B x N """spike extraction step/indexing @ each step is non-differentiable. Apply sign estimator by substituting gradient for -1 ONLY at first spike time.""" - for i in range(first_spike_time.size(0)): - for j in range(first_spike_time.size(1)): - spk_time_grad[first_spike_time[i, j].long(), i, j] = 1.0 + + # first_spike_time is (B, N) + spk_time_grad = torch.zeros_like(spk_rec) # (T, B, N) + + # Flatten out the (B, N) coordinates + b_coords = torch.arange(first_spike_time.size(0), device=first_spike_time.device) + n_coords = torch.arange(first_spike_time.size(1), device=first_spike_time.device) + # Create a meshgrid to get all (b, n) pairs + B_idx, N_idx = torch.meshgrid(b_coords, n_coords, indexing='ij') + # B_idx, N_idx both are (B, N) + + # Time coords from first_spike_time + T_idx = first_spike_time.long() # (B, N) + + spk_time_grad[T_idx, B_idx, N_idx] = 1.0 grad = -grad_output * spk_time_grad + return grad, None @staticmethod @@ -623,13 +679,8 @@ class Tolerance(torch.autograd.Function): # TO-DO: remove ctx? @staticmethod def forward(ctx, spk_time, target, tolerance): - spk_time_clone = ( - spk_time.clone() - ) # spk_time_clone: BxN (FxBxN for multi-spike); target: TxBxN - spk_time_clone[torch.abs(spk_time - target) < tolerance] = ( - torch.ones_like(spk_time) * target - )[torch.abs(spk_time - target) < tolerance] - return spk_time_clone + mask = (spk_time - target).abs() <= tolerance + return torch.where(mask, target.to(spk_time.dtype), spk_time) @staticmethod def backward(ctx, grad_output): @@ -653,15 +704,15 @@ def label_to_single_spike(self, targets, num_outputs): """Convert labels from neuron index (dim: B) to first spike time (dim: B x N).""" - # guess: i designed this code with on_target >> off_target in mind - targets = spikegen.targets_convert( - targets, - num_classes=num_outputs, - on_target=self.on_target, - off_target=self.off_target, - ) + batch_size = targets.size(0) - return targets + # Initialize the target tensor with the incorrect timesteps + target_spike_time = torch.full((batch_size, num_outputs), self.off_target, dtype=torch.float32, device=targets.device) + + # Set the correct class latencies to self.on_target + target_spike_time[torch.arange(batch_size), targets] = self.on_target + + return target_spike_time def label_to_multi_spike(self, targets, num_outputs): """Convert labels from neuron index (dim: B) to multiple spike times @@ -677,16 +728,17 @@ def label_to_multi_spike(self, targets, num_outputs): f"`on_target` (length: {num_spikes_on}) must have the same " f"length as `off_target` (length: {num_spikes_off}." ) + + batch_size = targets.size(0) # iterate through each spike targets_rec = [] for step in range(num_spikes_on): - target_step = spikegen.targets_convert( - targets, - num_classes=num_outputs, - on_target=self.on_target[step], - off_target=self.off_target[step], - ) + # Initialize the target tensor with the incorrect timesteps + target_step = torch.full((batch_size, num_outputs), self.off_target[step], dtype=torch.float32, device=targets.device) + # Set the correct class latencies to self.on_target + target_step[torch.arange(batch_size), targets] = self.on_target[step] + targets_rec.append(target_step) targets_rec = torch.stack(targets_rec)