Linear Regression using snnTorch #122
Replies: 2 comments
Finally got some good results. Though the SNN model is not perfect, it is working. I'm sharing the code in case it helps someone.
#output
Epoch [1000/50000], Loss: 267667.7500
Thanks for sharing!
We've noticed that the atan surrogate gradient function does better than fast_sigmoid, so I pushed that into the latest version of snntorch yesterday. Try updating snntorch, and you might get better performance with that.
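For anyone following along, a minimal sketch of that change, assuming a recent snntorch release where `surrogate.atan()` is available:

```python
import snntorch as snn
from snntorch import surrogate

# Arctangent surrogate gradient (default steepness), instead of fast_sigmoid
spike_grad = surrogate.atan()

# Drop-in replacement for the Leaky layers used in the model shared below
lif = snn.Leaky(beta=0.5, spike_grad=spike_grad)
```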
On Fri, 5 Aug 2022, 3:59 pm, mahmad2005 wrote:
Finally got some good results. Though the SNN model is not perfect, it is working. I'm sharing the code in case it helps someone.
`!pip install snntorch`

```python
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen
from snntorch import surrogate

import numpy as np
import pandas as pd               # data processing, I/O for csv
import seaborn as sns             # visualisation
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Network architecture
num_inputs = 1
num_hidden = 100
num_outputs = 1

# Temporal dynamics
num_steps = 1
beta = 0.5

# Define the network
class Net(nn.Module):
    def __init__(self, beta, spike_grad):
        super().__init__()
        self.beta = beta
        self.spike_grad = spike_grad

        # Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=self.beta, spike_grad=self.spike_grad)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=self.beta, spike_grad=self.spike_grad)

    def forward(self, x):
        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # Record the final layer
        spk2_rec = []
        mem2_rec = []

        for step in range(num_steps):
            cur1 = self.fc1(x)
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)

# Load data
AllData = pd.read_csv(r"../input/xisx-file/xisx2.csv")
Data_x = AllData[['x']]
Data_y = AllData[['y']]
x_train = np.array(Data_x).astype('float32')
y_train = np.array(Data_y).astype('float32')

# Neuron and simulation parameters
spike_grad = surrogate.fast_sigmoid(slope=1)
beta = 0.50

# SNN regression model, loss and optimizer
model = Net(beta, spike_grad).to(device)
criterion = nn.MSELoss()
# optimizer = torch.optim.SGD(model.parameters(), lr=0.00000001)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # lr = 0.001 worked best

# Train the model
num_epochs = 50000
for epoch in range(num_epochs):
    # Convert numpy arrays to torch tensors on the same device as the model
    inputs = torch.from_numpy(x_train).to(device)
    labels = torch.from_numpy(y_train).to(device)

    # Clear the gradients w.r.t. the parameters
    optimizer.zero_grad()

    # Forward pass: the output-layer membrane potential is the regression value
    _, outputs = model(inputs)
    loss = criterion(outputs, labels)

    # Backward pass and parameter update
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 1000 == 0:
        print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, loss.item()))

# model.eval()
# Plot the graph
# predicted = model.forward(torch.from_numpy(x_train)).detach().numpy()
# plt.plot(x_train, y_train, 'ro', label='Original data')
# plt.plot(x_train, predicted, label='Prediction')
# plt.legend()
# plt.show()

# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')
```
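The evaluation and plotting step is commented out above, and as written `model.forward(...)` would return a (spike, membrane) tuple rather than a single tensor. A minimal working sketch of that step, assuming `model`, `x_train`, `y_train`, and `device` from the code above:

```python
# Evaluate the trained model and plot predictions against the training data
model.eval()
with torch.no_grad():
    inputs = torch.from_numpy(x_train).to(device)
    _, mem_out = model(inputs)               # (spikes, membrane potentials)
    predicted = mem_out[-1].cpu().numpy()    # last (and here only) time step, shape (N, 1)

plt.plot(x_train, y_train, 'ro', label='Original data')
plt.plot(x_train, predicted, label='Prediction')
plt.legend()
plt.show()
```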
I was trying to test a simple regression problem using snnTorch where the function is f(x) = x, with x ranging from 0 to 1000.
I've implemented my code following snntorch Tutorial 5:
https://colab.research.google.com/github/jeshraghian/snntorch/blob/master/examples/tutorial_5_FCN.ipynb
I've used MSELoss as the loss function and SGD as the optimizer. However, the model could not improve the training loss. Can anyone help me fix my code? I am new to this area, so please forgive my silly mistakes.
I have searched through many examples but could not find anything that uses a spiking neural network for a regression problem.
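For reference, the working code shared earlier in this thread differs from this setup mainly in two ways: it treats the membrane potential of the output `snn.Leaky` neuron as the regression value, and it uses Adam instead of SGD. A minimal sketch of that training step, with `model`, `inputs`, `labels`, and `num_epochs` assumed to be defined as in the code above:

```python
import torch
import torch.nn as nn

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(num_epochs):
    optimizer.zero_grad()
    _, mem_out = model(inputs)            # (spikes, membrane potentials)
    loss = criterion(mem_out, labels)     # fit the membrane potential, not the spike count
    loss.backward()
    optimizer.step()
```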
#output
tensor([[ 0.],
[ 1.],
[ 2.],
...,
[ 998.],
[ 999.],
[1000.]])
torch.Size([1, 1001, 1])
Training loss: 333320.312
Training loss: 532816064.000
Train set accuracy for a single minibatch: 0.10%
556094336.0
21454432.0
13805051.0
8894782.0
5740483.0
3712347.25
2406832.75
1565298.25
1021911.625
670300.625
442194.8125
293747.125
196772.78125
133134.84375
91147.125
63267.3984375
44618.32421875
32038.07421875
23470.833984375
17575.109375
13471.677734375
10581.3544921875
8520.2470703125
7032.09375
5944.4453125
5140.1611328125
4538.87109375
4084.806640625
3738.813232421875
3473.06396484375
3267.533203125
3107.634765625
2982.609619140625
2884.44384765625
2807.09521484375
2745.977783203125
2697.56884765625
2659.151611328125
2628.61669921875
2604.315673828125
2584.958740234375
2569.524658203125
2557.2119140625
2547.382080078125
2539.530029296875
2533.25732421875
2528.2451171875
2524.2392578125
2521.037841796875
2518.477783203125
202434.71875
122462.5
72880.0078125
42427.015625
23960.92578125
12961.69921875
6577.4453125
3015.725341796875
1155.5867919921875
300.587646484375
21.77994155883789
58.056297302246094
253.0438232421875
515.2521362304688
792.9987182617188
1058.7027587890625
1299.111328125
1509.2578125
1688.7808837890625
1839.698486328125
1965.1041259765625
2068.41748046875
2152.981689453125
2221.85791015625
2277.74267578125
2322.952880859375
2359.44189453125
2388.838623046875
2412.487060546875
2431.489501953125
2446.74658203125
2458.98681640625
2468.801025390625
2476.666748046875
2482.96630859375
2488.01220703125
2492.05322265625
2495.28759765625
2497.878662109375
2499.9501953125
2501.607666015625
2502.936767578125
2503.99755859375
2504.84814453125
2505.52880859375
2506.071533203125
2506.507568359375
2506.855712890625
2507.13330078125
2507.35595703125
202514.6875
122512.25
72910.71875
42445.7578125
23972.197265625
12968.330078125
6581.22412109375
3017.77099609375
1156.5980224609375
300.9961242675781
21.85030174255371
57.948455810546875
252.8524169921875
515.0319213867188
792.7794189453125
1058.5
1298.9315185546875
1509.10302734375
1688.6494140625
1839.588623046875
1965.013427734375
2068.34423828125
2152.922119140625
2221.810791015625
2277.70458984375
2322.922119140625
2359.416748046875
2388.81787109375
2412.470703125
2431.477294921875
2446.7373046875
2458.9775390625
2468.7919921875
2476.658203125
2482.96044921875
2488.00927734375
2492.05029296875
2495.284423828125
2497.875732421875
2499.9501953125
2501.607666015625
2502.936767578125
2503.99755859375
2504.84814453125
2505.52880859375
2506.071533203125
2506.507568359375
2506.855712890625
2507.13330078125
2507.35595703125
[the same block of loss values repeats verbatim several more times in the original log]
Prediction output: tensor([[455.5071]], grad_fn=)