Commit a9074e75 by Hammouda Elbez (parent 3a0222d3)

Norse code added
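# ---------------------------------------------------------------------------
# CIFAR-10 training script: convolutional spiking network (Norse) with
# progressive weight compression (toggled by apply_compression below).
# ---------------------------------------------------------------------------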
import sys
sys.path.append('../../')
import torch
import numpy as np
from compression import ProgressiveCompression
from torch.utils.data import SubsetRandomSampler
from norse.torch import LIFCell, LICell
from norse.torch.module.leaky_integrator import LILinearCell
from norse.torch import LIFParameters
from norse.torch.module import encode, SequentialState
from datetime import datetime
import torchvision
import os
import pickle
import random
import itertools
# Reproducibility
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)
MAXTH = [0.3, 0.4, 0.5, 0.6, 0.7]
ALPHA = [0.005]  # [0.002, 0.004, 0.006, 0.008, 0.01]
REINFORCEMENT = [True] # [False, True]
apply_compression = False
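# Each (maxTh, Alpha, reinforcement) combination gets its own output directory
# and is trained twice; every run writes a timestamped log file, a pickled
# results list, and the saved model.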
for maxTh, Alpha, reinforcement in itertools.product(MAXTH, ALPHA, REINFORCEMENT):
    run_name = "CIFAR10_CONV_maxTh:" + str(maxTh) + "_Alpha:" + str(Alpha) + "_" + "reinforcement:" + str(reinforcement)
    try:
        os.mkdir(run_name)
    except OSError as error:
        print(error)
    for i in range(2):
        before = datetime.now()
        # Per-run log file, tagged with the start timestamp.
        file = open(run_name + "/" + run_name + "_" + str(before), 'w+')
        transform = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )
        train_data = torchvision.datasets.CIFAR10(
            root=".",
            train=True,
            download=True,
            transform=transform,
        )
        # reduce this number if you run out of GPU memory
        BATCH_SIZE = 32
        train_loader = torch.utils.data.DataLoader(
            train_data, batch_size=BATCH_SIZE, shuffle=True
            # , sampler=SubsetRandomSampler(list(range(len(train_data)))[0:1000])
        )
        test_loader = torch.utils.data.DataLoader(
            torchvision.datasets.CIFAR10(
                root=".",
                train=False,
                transform=transform
            ),
            batch_size=BATCH_SIZE
        )
        class Model(torch.nn.Module):
            def __init__(self, encoder, snn, decoder):
                super(Model, self).__init__()
                self.encoder = encoder
                self.snn = snn
                self.decoder = decoder

            def forward(self, x):
                x = self.encoder(x)        # pixels -> spike trains
                x = self.snn(x)            # spiking network -> output voltages
                log_p_y = self.decoder(x)  # voltages -> log-probabilities
                return log_p_y

        class ConvNet(torch.nn.Module):
            def __init__(self, num_channels=3, feature_size=32, method="super", alpha=100):
                super(ConvNet, self).__init__()
                # Spatial size after two (conv 5x5, stride 1) + (max-pool 2x2) stages.
                self.features = int(((feature_size - 4) / 2 - 4) / 2)
                self.conv1_out_channels = 32
                self.conv2_out_channels = 128
                self.fc1_out_channels = 1024
                self.out_channels = 10
                self.conv1 = torch.nn.Conv2d(num_channels, self.conv1_out_channels, 5, 1, bias=False)
                self.conv2 = torch.nn.Conv2d(self.conv1_out_channels, self.conv2_out_channels, 5, 1, bias=False)
                self.fc1 = torch.nn.Linear(self.features**2 * self.conv2_out_channels, self.fc1_out_channels, bias=False)
                self.lif0 = LIFCell(p=LIFParameters(method=method, alpha=alpha, v_th=0.25))
                self.lif1 = LIFCell(p=LIFParameters(method=method, alpha=alpha, v_th=0.25))
                self.lif2 = LIFCell(p=LIFParameters(method=method, alpha=alpha, v_th=0.25))
                self.out = LILinearCell(self.fc1_out_channels, self.out_channels)

            def forward(self, x):
                seq_length = x.shape[0]
                batch_size = x.shape[1]
                # LIF states start at None and are carried across time steps.
                s0 = s1 = s2 = so = None
                voltages = torch.zeros(
                    seq_length, batch_size, self.out_channels, device=x.device, dtype=x.dtype
                )
                for ts in range(seq_length):
                    z = self.conv1(x[ts, :])
                    z, s0 = self.lif0(z, s0)
                    z = torch.nn.functional.max_pool2d(z, 2, 2)
                    z = self.out_channels * self.conv2(z)  # scale conv2 output by 10
                    z, s1 = self.lif1(z, s1)
                    z = torch.nn.functional.max_pool2d(z, 2, 2)
                    z = z.view(-1, self.features**2 * self.conv2_out_channels)
                    z = self.fc1(z)
                    z, s2 = self.lif2(z, s2)
                    v, so = self.out(torch.nn.functional.relu(z), so)
                    voltages[ts, :, :] = v
                return voltages
        def train(model, device, train_loader, optimizer, epoch, max_epochs):
            model.train()
            losses = []
            for (data, target) in train_loader:
                data, target = data.to(device), target.to(device)
                optimizer.zero_grad()
                output = model(data)
                loss = torch.nn.functional.nll_loss(output, target)
                loss.backward()
                optimizer.step()
                losses.append(loss.item())
            mean_loss = np.mean(losses)
            return losses, mean_loss

        def test(model, device, test_loader, epoch):
            model.eval()
            test_loss = 0
            correct = 0
            with torch.no_grad():
                for data, target in test_loader:
                    data, target = data.to(device), target.to(device)
                    output = model(data)
                    test_loss += torch.nn.functional.nll_loss(
                        output, target, reduction="sum"
                    ).item()  # sum up batch loss
                    pred = output.argmax(
                        dim=1, keepdim=True
                    )  # get the index of the max log-probability
                    correct += pred.eq(target.view_as(pred)).sum().item()
            test_loss /= len(test_loader.dataset)
            accuracy = 100.0 * correct / len(test_loader.dataset)
            return test_loss, accuracy

        def decode(x):
            # Max membrane voltage over time, then log-softmax over the classes.
            x, _ = torch.max(x, 0)
            log_p_y = torch.nn.functional.log_softmax(x, dim=1)
            return log_p_y
        torch.autograd.set_detect_anomaly(True)
        T = 35
        LR = 0.001
        EPOCHS = 100  # Increase this for improved accuracy
        if torch.cuda.is_available():
            # GPU selected via the first command-line argument (e.g. "cuda:0").
            DEVICE = torch.device(sys.argv[1])
        else:
            DEVICE = torch.device("cpu")
        model = Model(
            encoder=encode.SpikeLatencyLIFEncoder(T), snn=ConvNet(alpha=80), decoder=decode
        ).to(DEVICE)
        optimizer = torch.optim.Adam(model.parameters(), lr=LR)
        # compression
        if apply_compression:
            progressive_compression = ProgressiveCompression(
                NorseModel=model, maxThreshold=maxTh, alphaP=Alpha, alphaN=-Alpha,
                to_file=True, apply_reinforcement=reinforcement, file=file
            )
        training_losses = []
        mean_losses = []
        test_losses = []
        accuracies = []
        for epoch in range(EPOCHS):
            print(f"Epoch {epoch}")
            training_loss, mean_loss = train(
                model, DEVICE, train_loader, optimizer, epoch, max_epochs=EPOCHS
            )
            test_loss, accuracy = test(model, DEVICE, test_loader, epoch)
            training_losses += training_loss
            mean_losses.append(mean_loss)
            test_losses.append(test_loss)
            accuracies.append(accuracy)
            if apply_compression:
                # Prune (and optionally reinforce) the weights after every epoch.
                progressive_compression.apply()
        print(f"final accuracy: {accuracies[-1]}")
        file.write("final accuracy:" + str(accuracies[-1]) + "\n")
        file.write("time:" + str(datetime.now() - before) + "\n")
        file.close()
        with open(run_name + "/" + run_name + "_" + str(before) + ".pkl", 'wb') as f:
            torch.save(model, run_name + "/" + run_name + "_" + str(before) + ".norse")
            if apply_compression:
                pickle.dump([mean_losses, test_losses, accuracies, progressive_compression.weights,
                             progressive_compression.compressions, progressive_compression.thresholds_p,
                             progressive_compression.thresholds_n], f)
            else:
                pickle.dump([mean_losses, test_losses, accuracies], f)
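# ---------------------------------------------------------------------------
# FaceMotor training script: same convolutional SNN, trained on a local
# ImageFolder dataset (grayscale 250x160 images) with compression enabled.
# ---------------------------------------------------------------------------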
import sys
sys.path.append('../../')
import torch
import numpy as np
from compression import ProgressiveCompression
from norse.torch import LIFCell
from norse.torch.module.leaky_integrator import LILinearCell
from norse.torch import LIFParameters
from norse.torch.module import encode
from tqdm import tqdm, trange
from datetime import datetime
import torchvision
import os
import pickle
import random
import itertools
# Reproducibility
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)
MAXTH = [0.3, 0.4, 0.5, 0.6, 0.7]
ALPHA = [0.005]  # [0.002, 0.004, 0.006, 0.008, 0.01]
REINFORCEMENT = [True] # [False, True]
apply_compression = True
for maxTh, Alpha, reinforcement in itertools.product(MAXTH, ALPHA, REINFORCEMENT):
    run_name = "FACEMOTOR_CONV_maxTh:" + str(maxTh) + "_Alpha:" + str(Alpha) + "_" + "reinforcement:" + str(reinforcement)
    try:
        os.mkdir(run_name)
    except OSError as error:
        print(error)
    for i in range(2):
        before = datetime.now()
        file = open(run_name + "/" + run_name + "_" + str(before), 'w+')
        transform = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Grayscale(),
                torchvision.transforms.Resize((250, 160)),
                torchvision.transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )
        train_data = torchvision.datasets.ImageFolder(
            root="/home/hammouda/Desktop/Work/falez-csnn-simulator/Datasets/FaceMotor/TrainingSet/",
            transform=transform
        )
        # reduce this number if you run out of GPU memory
        BATCH_SIZE = 5
        train_loader = torch.utils.data.DataLoader(
            train_data, batch_size=BATCH_SIZE, shuffle=True
        )
        test_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder(
                root="/home/hammouda/Desktop/Work/falez-csnn-simulator/Datasets/FaceMotor/TestingSet/",
                transform=transform
            ),
            batch_size=BATCH_SIZE,
        )
        class Model(torch.nn.Module):
            def __init__(self, encoder, snn, decoder):
                super(Model, self).__init__()
                self.encoder = encoder
                self.snn = snn
                self.decoder = decoder

            def forward(self, x):
                x = self.encoder(x)
                x = self.snn(x)
                log_p_y = self.decoder(x)
                return log_p_y

        class ConvNet(torch.nn.Module):
            def __init__(self, num_channels=1, feature_size=250, method="super", alpha=100):
                super(ConvNet, self).__init__()
                self.features = 54  # 9 * 6 spatial positions after the conv/pool stages below, for 250x160 inputs
                self.conv1_out_channels = 32
                self.conv2_out_channels = 64
                self.fc1_out_channels = 128
                self.out_channels = 10
                self.conv1 = torch.nn.Conv2d(num_channels, self.conv1_out_channels, 5, 1, 3, bias=False)
                self.conv2 = torch.nn.Conv2d(self.conv1_out_channels, self.conv2_out_channels, 17, 1, 9, bias=False)
                self.fc1 = torch.nn.Linear(self.features * self.conv2_out_channels, self.fc1_out_channels, bias=False)
                self.lif0 = LIFCell(p=LIFParameters(method=method, alpha=alpha, v_th=0.25))
                self.lif1 = LIFCell(p=LIFParameters(method=method, alpha=alpha, v_th=0.25))
                self.lif2 = LIFCell(p=LIFParameters(method=method, alpha=alpha, v_th=0.25))
                self.out = LILinearCell(self.fc1_out_channels, self.out_channels)

            def forward(self, x):
                seq_length = x.shape[0]
                batch_size = x.shape[1]
                # specify the initial states
                s0 = s1 = s2 = so = None
                voltages = torch.zeros(
                    seq_length, batch_size, self.out_channels, device=x.device, dtype=x.dtype
                )
                for ts in range(seq_length):
                    z = self.conv1(x[ts, :])
                    z, s0 = self.lif0(z, s0)
                    z = torch.nn.functional.max_pool2d(z, 7, 6, 3)
                    z = self.out_channels * self.conv2(z)
                    z, s1 = self.lif1(z, s1)
                    z = torch.nn.functional.max_pool2d(z, 5, 5, 2)
                    z = z.view(batch_size, -1)
                    z = self.fc1(z)
                    z, s2 = self.lif2(z, s2)
                    v, so = self.out(torch.nn.functional.relu(z), so)
                    voltages[ts, :, :] = v
                return voltages
        def train(model, device, train_loader, optimizer, epoch, max_epochs):
            model.train()
            losses = []
            for (data, target) in tqdm(train_loader, leave=False):
                data, target = data.to(device), target.to(device)
                optimizer.zero_grad()
                output = model(data)
                loss = torch.nn.functional.nll_loss(output, target)
                loss.backward()
                optimizer.step()
                losses.append(loss.item())
            mean_loss = np.mean(losses)
            return losses, mean_loss

        def test(model, device, test_loader, epoch):
            model.eval()
            test_loss = 0
            correct = 0
            with torch.no_grad():
                for data, target in test_loader:
                    data, target = data.to(device), target.to(device)
                    output = model(data)
                    test_loss += torch.nn.functional.nll_loss(
                        output, target, reduction="sum"
                    ).item()  # sum up batch loss
                    pred = output.argmax(
                        dim=1, keepdim=True
                    )  # get the index of the max log-probability
                    correct += pred.eq(target.view_as(pred)).sum().item()
            test_loss /= len(test_loader.dataset)
            accuracy = 100.0 * correct / len(test_loader.dataset)
            return test_loss, accuracy

        def decode(x):
            x, _ = torch.max(x, 0)
            log_p_y = torch.nn.functional.log_softmax(x, dim=1)
            return log_p_y
        T = 35
        LR = 1e-4
        EPOCHS = 15  # Increase this for improved accuracy
        if torch.cuda.is_available():
            DEVICE = torch.device("cuda")
        else:
            DEVICE = torch.device("cpu")
        model = Model(
            encoder=encode.SpikeLatencyLIFEncoder(T), snn=ConvNet(alpha=80), decoder=decode
        ).to(DEVICE)
        optimizer = torch.optim.Adam(model.parameters(), lr=LR)
        # compression
        if apply_compression:
            progressive_compression = ProgressiveCompression(
                NorseModel=model, maxThreshold=maxTh, alphaP=Alpha, alphaN=-Alpha,
                to_file=True, apply_reinforcement=reinforcement, file=file
            )
        training_losses = []
        mean_losses = []
        test_losses = []
        accuracies = []
        for epoch in trange(EPOCHS):
            print(f"Epoch {epoch}")
            training_loss, mean_loss = train(
                model, DEVICE, train_loader, optimizer, epoch, max_epochs=EPOCHS
            )
            test_loss, accuracy = test(model, DEVICE, test_loader, epoch)
            training_losses += training_loss
            mean_losses.append(mean_loss)
            test_losses.append(test_loss)
            accuracies.append(accuracy)
            if apply_compression:
                progressive_compression.apply()
        print(f"final accuracy: {accuracies[-1]}")
        file.write("final accuracy:" + str(accuracies[-1]) + "\n")
        file.write("time:" + str(datetime.now() - before) + "\n")
        file.close()
        with open(run_name + "/" + run_name + "_" + str(before) + ".pkl", 'wb') as f:
            torch.save(model, run_name + "/" + run_name + "_" + str(before) + ".norse")
            if apply_compression:
                pickle.dump([mean_losses, test_losses, accuracies, progressive_compression.weights,
                             progressive_compression.compressions, progressive_compression.thresholds_p,
                             progressive_compression.thresholds_n], f)
            else:
                pickle.dump([mean_losses, test_losses, accuracies], f)
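# ---------------------------------------------------------------------------
# Fashion-MNIST training script: same convolutional SNN and training
# procedure, with progressive compression enabled.
# ---------------------------------------------------------------------------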
import sys
sys.path.append('../../')
import torch
import numpy as np
from compression import ProgressiveCompression
from norse.torch import LIFCell
from norse.torch.module.leaky_integrator import LILinearCell
from norse.torch import LIFParameters
from norse.torch.module import encode
#from tqdm import tqdm, trange
from datetime import datetime
import torchvision
import os
import pickle
import random
import itertools
# Reproducibility
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)
MAXTH = [0.6, 0.7]  # 0.3, 0.5,
ALPHA = [0.005]  # [0.002, 0.004, 0.006, 0.008, 0.01]
REINFORCEMENT = [True] # [False, True]
for maxTh, Alpha, reinforcement in itertools.product(MAXTH, ALPHA, REINFORCEMENT):
    run_name = "FashionMNIST_CONV_maxTh:" + str(maxTh) + "_Alpha:" + str(Alpha) + "_" + "reinforcement:" + str(reinforcement)
    try:
        os.mkdir(run_name)
    except OSError as error:
        print(error)
    for i in range(2):
        before = datetime.now()
        file = open(run_name + "/" + run_name + "_" + str(before), 'w+')
        apply_compression = True
        transform = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )
        train_data = torchvision.datasets.FashionMNIST(
            root=".",
            train=True,
            download=True,
            transform=transform,
        )
        # reduce this number if you run out of GPU memory
        BATCH_SIZE = 128
        train_loader = torch.utils.data.DataLoader(
            train_data, batch_size=BATCH_SIZE, shuffle=True
        )
        test_loader = torch.utils.data.DataLoader(
            torchvision.datasets.FashionMNIST(
                root=".",
                train=False,
                transform=transform,
            ),
            batch_size=BATCH_SIZE,
        )
        class Model(torch.nn.Module):
            def __init__(self, encoder, snn, decoder):
                super(Model, self).__init__()
                self.encoder = encoder
                self.snn = snn
                self.decoder = decoder

            def forward(self, x):
                x = self.encoder(x)
                x = self.snn(x)
                log_p_y = self.decoder(x)
                return log_p_y

        class ConvNet(torch.nn.Module):
            def __init__(self, num_channels=1, feature_size=28, method="super", alpha=100):
                super(ConvNet, self).__init__()
                self.features = int(((feature_size - 4) / 2 - 4) / 2)
                self.conv1_out_channels = 32
                self.conv2_out_channels = 128
                self.fc1_out_channels = 1024
                self.out_channels = 10
                self.conv1 = torch.nn.Conv2d(num_channels, self.conv1_out_channels, 5, 1, bias=False)
                self.conv2 = torch.nn.Conv2d(self.conv1_out_channels, self.conv2_out_channels, 5, 1, bias=False)
                self.fc1 = torch.nn.Linear(self.features * self.features * self.conv2_out_channels, self.fc1_out_channels, bias=False)
                self.lif0 = LIFCell(p=LIFParameters(method=method, alpha=alpha, v_th=0.25))
                self.lif1 = LIFCell(p=LIFParameters(method=method, alpha=alpha, v_th=0.25))
                self.lif2 = LIFCell(p=LIFParameters(method=method, alpha=alpha, v_th=0.25))
                self.out = LILinearCell(self.fc1_out_channels, self.out_channels)

            def forward(self, x):
                seq_length = x.shape[0]
                batch_size = x.shape[1]
                # specify the initial states
                s0 = s1 = s2 = so = None
                voltages = torch.zeros(
                    seq_length, batch_size, self.out_channels, device=x.device, dtype=x.dtype
                )
                for ts in range(seq_length):
                    z = self.conv1(x[ts, :])
                    z, s0 = self.lif0(z, s0)
                    z = torch.nn.functional.max_pool2d(z, 2, 2)
                    z = self.out_channels * self.conv2(z)
                    z, s1 = self.lif1(z, s1)
                    z = torch.nn.functional.max_pool2d(z, 2, 2)
                    z = z.view(-1, 4**2 * self.conv2_out_channels)
                    z = self.fc1(z)
                    z, s2 = self.lif2(z, s2)
                    v, so = self.out(torch.nn.functional.relu(z), so)
                    voltages[ts, :, :] = v
                return voltages
        def train(model, device, train_loader, optimizer, epoch, max_epochs):
            model.train()
            losses = []
            for (data, target) in train_loader:  # tqdm(train_loader, leave=False)
                data, target = data.to(device), target.to(device)
                optimizer.zero_grad()
                output = model(data)
                loss = torch.nn.functional.nll_loss(output, target)
                loss.backward()
                optimizer.step()
                losses.append(loss.item())
            mean_loss = np.mean(losses)
            return losses, mean_loss

        def test(model, device, test_loader, epoch):
            model.eval()
            test_loss = 0
            correct = 0
            with torch.no_grad():
                for data, target in test_loader:
                    data, target = data.to(device), target.to(device)
                    output = model(data)
                    test_loss += torch.nn.functional.nll_loss(
                        output, target, reduction="sum"
                    ).item()  # sum up batch loss
                    pred = output.argmax(
                        dim=1, keepdim=True
                    )  # get the index of the max log-probability
                    correct += pred.eq(target.view_as(pred)).sum().item()
            test_loss /= len(test_loader.dataset)
            accuracy = 100.0 * correct / len(test_loader.dataset)
            return test_loss, accuracy

        def decode(x):
            x, _ = torch.max(x, 0)
            log_p_y = torch.nn.functional.log_softmax(x, dim=1)
            return log_p_y
        T = 35
        LR = 0.001
        EPOCHS = 100  # Increase this for improved accuracy
        if torch.cuda.is_available():
            DEVICE = torch.device("cuda")
        else:
            DEVICE = torch.device("cpu")
        model = Model(
            encoder=encode.SpikeLatencyLIFEncoder(T), snn=ConvNet(alpha=80), decoder=decode
        ).to(DEVICE)
        optimizer = torch.optim.Adam(model.parameters(), lr=LR)
        # compression
        if apply_compression:
            progressive_compression = ProgressiveCompression(
                NorseModel=model, maxThreshold=maxTh, alphaP=Alpha, alphaN=-Alpha,
                to_file=True, apply_reinforcement=reinforcement, file=file
            )
        training_losses = []
        mean_losses = []
        test_losses = []
        accuracies = []
        for epoch in range(EPOCHS):
            print(f"Epoch {epoch}")
            training_loss, mean_loss = train(
                model, DEVICE, train_loader, optimizer, epoch, max_epochs=EPOCHS
            )
            test_loss, accuracy = test(model, DEVICE, test_loader, epoch)
            training_losses += training_loss
            mean_losses.append(mean_loss)
            test_losses.append(test_loss)
            accuracies.append(accuracy)
            if apply_compression:
                progressive_compression.apply()
        print(f"final accuracy: {accuracies[-1]}")
        file.write("final accuracy:" + str(accuracies[-1]) + "\n")
        file.write("time:" + str(datetime.now() - before) + "\n")
        file.close()
        with open(run_name + "/" + run_name + "_" + str(before) + ".pkl", 'wb') as f:
            torch.save(model, run_name + "/" + run_name + "_" + str(before) + ".norse")
            if apply_compression:
                pickle.dump([mean_losses, test_losses, accuracies, progressive_compression.weights,
                             progressive_compression.compressions, progressive_compression.thresholds_p,
                             progressive_compression.thresholds_n], f)
            else:
                pickle.dump([mean_losses, test_losses, accuracies], f)
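# ---------------------------------------------------------------------------
# MNIST training script: same convolutional SNN and training procedure, with
# progressive compression enabled (runs on cuda:1 when available).
# ---------------------------------------------------------------------------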
import sys
sys.path.append('../../')
import torch
import numpy as np
from compression import ProgressiveCompression
from norse.torch import LIFCell
from norse.torch.module.leaky_integrator import LILinearCell
from norse.torch import LIFParameters
from norse.torch.module import encode
#from tqdm import tqdm, trange
from datetime import datetime
import torchvision
import os
import pickle
import random
import itertools
# Reproducibility
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)
MAXTH = [0.6, 0.7]
ALPHA = [0.005]  # [0.002, 0.004, 0.006, 0.008, 0.01]
REINFORCEMENT = [True] # [False, True]
for maxTh, Alpha, reinforcement in itertools.product(MAXTH, ALPHA, REINFORCEMENT):
    run_name = "MNIST_CONV_maxTh:" + str(maxTh) + "_Alpha:" + str(Alpha) + "_" + "reinforcement:" + str(reinforcement)
    try:
        os.mkdir(run_name)
    except OSError as error:
        print(error)
    for i in range(2):
        before = datetime.now()
        file = open(run_name + "/" + run_name + "_" + str(before), 'w+')
        apply_compression = True
        transform = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )
        train_data = torchvision.datasets.MNIST(
            root=".",
            train=True,
            download=True,
            transform=transform,
        )
        # reduce this number if you run out of GPU memory
        BATCH_SIZE = 512
        train_loader = torch.utils.data.DataLoader(
            train_data, batch_size=BATCH_SIZE, shuffle=True
        )
        test_loader = torch.utils.data.DataLoader(
            torchvision.datasets.MNIST(
                root=".",
                train=False,
                transform=transform,
            ),
            batch_size=BATCH_SIZE,
        )
        class Model(torch.nn.Module):
            def __init__(self, encoder, snn, decoder):
                super(Model, self).__init__()
                self.encoder = encoder
                self.snn = snn
                self.decoder = decoder

            def forward(self, x):
                x = self.encoder(x)
                x = self.snn(x)
                log_p_y = self.decoder(x)
                return log_p_y

        class ConvNet(torch.nn.Module):
            def __init__(self, num_channels=1, feature_size=28, method="super", alpha=100):
                super(ConvNet, self).__init__()
                self.features = int(((feature_size - 4) / 2 - 4) / 2)
                self.conv1_out_channels = 32
                self.conv2_out_channels = 128
                self.fc1_out_channels = 1024
                self.out_channels = 10
                self.conv1 = torch.nn.Conv2d(num_channels, self.conv1_out_channels, 5, 1, bias=False)
                self.conv2 = torch.nn.Conv2d(self.conv1_out_channels, self.conv2_out_channels, 5, 1, bias=False)
                self.fc1 = torch.nn.Linear(self.features * self.features * self.conv2_out_channels, self.fc1_out_channels, bias=False)
                self.lif0 = LIFCell(p=LIFParameters(method=method, alpha=alpha, v_th=0.25))
                self.lif1 = LIFCell(p=LIFParameters(method=method, alpha=alpha, v_th=0.25))
                self.lif2 = LIFCell(p=LIFParameters(method=method, alpha=alpha, v_th=0.25))
                self.out = LILinearCell(self.fc1_out_channels, self.out_channels)

            def forward(self, x):
                seq_length = x.shape[0]
                batch_size = x.shape[1]
                # specify the initial states
                s0 = s1 = s2 = so = None
                voltages = torch.zeros(
                    seq_length, batch_size, self.out_channels, device=x.device, dtype=x.dtype
                )
                for ts in range(seq_length):
                    z = self.conv1(x[ts, :])
                    z, s0 = self.lif0(z, s0)
                    z = torch.nn.functional.max_pool2d(z, 2, 2)
                    z = self.out_channels * self.conv2(z)
                    z, s1 = self.lif1(z, s1)
                    z = torch.nn.functional.max_pool2d(z, 2, 2)
                    z = z.view(-1, 4**2 * self.conv2_out_channels)
                    z = self.fc1(z)
                    z, s2 = self.lif2(z, s2)
                    v, so = self.out(torch.nn.functional.relu(z), so)
                    voltages[ts, :, :] = v
                return voltages
        def train(model, device, train_loader, optimizer, epoch, max_epochs):
            model.train()
            losses = []
            for (data, target) in train_loader:  # tqdm(train_loader, leave=False)
                data, target = data.to(device), target.to(device)
                optimizer.zero_grad()
                output = model(data)
                loss = torch.nn.functional.nll_loss(output, target)
                loss.backward()
                optimizer.step()
                losses.append(loss.item())
            mean_loss = np.mean(losses)
            return losses, mean_loss

        def test(model, device, test_loader, epoch):
            model.eval()
            test_loss = 0
            correct = 0
            with torch.no_grad():
                for data, target in test_loader:
                    data, target = data.to(device), target.to(device)
                    output = model(data)
                    test_loss += torch.nn.functional.nll_loss(
                        output, target, reduction="sum"
                    ).item()  # sum up batch loss
                    pred = output.argmax(
                        dim=1, keepdim=True
                    )  # get the index of the max log-probability
                    correct += pred.eq(target.view_as(pred)).sum().item()
            test_loss /= len(test_loader.dataset)
            accuracy = 100.0 * correct / len(test_loader.dataset)
            return test_loss, accuracy

        def decode(x):
            x, _ = torch.max(x, 0)
            log_p_y = torch.nn.functional.log_softmax(x, dim=1)
            return log_p_y
        T = 35
        LR = 0.001
        EPOCHS = 100  # Increase this for improved accuracy
        if torch.cuda.is_available():
            DEVICE = torch.device("cuda:1")
        else:
            DEVICE = torch.device("cpu")
        model = Model(
            encoder=encode.SpikeLatencyLIFEncoder(T), snn=ConvNet(alpha=80), decoder=decode
        ).to(DEVICE)
        optimizer = torch.optim.Adam(model.parameters(), lr=LR)
        # compression
        if apply_compression:
            progressive_compression = ProgressiveCompression(
                NorseModel=model, maxThreshold=maxTh, alphaP=Alpha, alphaN=-Alpha,
                to_file=True, apply_reinforcement=reinforcement, file=file
            )
        training_losses = []
        mean_losses = []
        test_losses = []
        accuracies = []
        for epoch in range(EPOCHS):
            print(f"Epoch {epoch}")
            training_loss, mean_loss = train(
                model, DEVICE, train_loader, optimizer, epoch, max_epochs=EPOCHS
            )
            test_loss, accuracy = test(model, DEVICE, test_loader, epoch)
            training_losses += training_loss
            mean_losses.append(mean_loss)
            test_losses.append(test_loss)
            accuracies.append(accuracy)
            if apply_compression:
                progressive_compression.apply()
        print(f"final accuracy: {accuracies[-1]}")
        file.write("final accuracy:" + str(accuracies[-1]) + "\n")
        file.write("time:" + str(datetime.now() - before) + "\n")
        file.close()
        with open(run_name + "/" + run_name + "_" + str(before) + ".pkl", 'wb') as f:
            torch.save(model, run_name + "/" + run_name + "_" + str(before) + ".norse")
            if apply_compression:
                pickle.dump([mean_losses, test_losses, accuracies, progressive_compression.weights,
                             progressive_compression.compressions, progressive_compression.thresholds_p,
                             progressive_compression.thresholds_n], f)
            else:
                pickle.dump([mean_losses, test_losses, accuracies], f)
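# ---------------------------------------------------------------------------
# compression module: ProgressiveCompression implements epoch-wise magnitude
# pruning with growing per-layer thresholds and optional weight reinforcement.
# ---------------------------------------------------------------------------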
import numpy as np
import torch
class ProgressiveCompression:
    def __init__(self, NorseModel, maxThreshold=0.3, alphaP=0.005, alphaN=-0.005, betaP=0.01, betaN=-0.01,
                 prune_recurrent=True, apply_reinforcement=False, to_file=False, file=None, layerwise=False):
        self.alpha_p = []
        self.alpha_n = []
        self.max_threshold = maxThreshold
        self.model = NorseModel
        self.prune_recurrent = prune_recurrent
        self.layerwise = layerwise
        self.apply_reinforcement = apply_reinforcement
        self.to_file = to_file
        self.file = file
        self.max_threshold_p = []
        self.max_threshold_n = []
        self.max_w_p = []
        self.max_w_n = []
        self.prune_matrix = []
        self.thresholds_p = []
        self.thresholds_n = []
        self.compressions = []
        self.weights = []
        self.betaP = betaP
        self.betaN = betaN
        i = 0
        for name, param in self.model.named_parameters():
            print(name, param.data.min(), param.data.max())
            self.max_w_p.append(param.data.max())
            self.max_w_n.append(param.data.min())
            if self.layerwise:
                # Deeper layers get progressively larger pruning steps.
                self.alpha_p.append(alphaP + (0.005 * i))
                self.alpha_n.append(alphaN + (-0.005 * i))
            else:
                self.alpha_p.append(alphaP)
                self.alpha_n.append(alphaN)
            # The pruning thresholds may grow up to a fraction of each layer's initial weight extrema.
            self.max_threshold_p.append(param.data.max() * self.max_threshold)
            self.max_threshold_n.append(param.data.min() * self.max_threshold)
            print("max_threshold_n: " + str(param.data.min() * self.max_threshold) + " max_threshold_p: " + str(param.data.max() * self.max_threshold) + "\n")
            if self.to_file:
                self.file.write(str(name) + " " + str(param.data.min()) + " " + str(param.data.max()) + "\n")
                self.file.write("max_threshold_n: " + str(param.data.min() * self.max_threshold) + " max_threshold_p: " + str(param.data.max() * self.max_threshold) + "\n")
            i = i + 1
        for param in self.model.parameters():
            self.prune_matrix.append(torch.zeros_like(param))
            self.thresholds_p.append([None])
            self.thresholds_n.append([None])
            self.compressions.append([0])
    def applyprune(self, name, alpha_p, alpha_n, max_thresholds_p, max_thresholds_n, weights, prune_matrix, threshold_p, threshold_n):
        print(name, " before prune: Min: ", weights.min(), " Max: ", weights.max())
        if (not self.prune_recurrent) and ("recurrent" in name):
            return weights, prune_matrix, threshold_p, threshold_n
        else:
            # Grow the positive threshold, scaled by the fraction of weights still unpruned.
            if threshold_p is None:
                threshold_p = alpha_p
            elif threshold_p < max_thresholds_p:
                threshold_p += alpha_p * (np.count_nonzero(prune_matrix.cpu() == 0) / (np.count_nonzero(prune_matrix.cpu() == 0) + np.count_nonzero(prune_matrix.cpu() == 1)))
                threshold_p = round(threshold_p, 6)
            # Same for the negative threshold (alpha_n is negative, so it grows towards max_thresholds_n).
            if threshold_n is None:
                threshold_n = alpha_n
            elif threshold_n > max_thresholds_n:
                threshold_n += alpha_n * (np.count_nonzero(prune_matrix.cpu() == 0) / (np.count_nonzero(prune_matrix.cpu() == 0) + np.count_nonzero(prune_matrix.cpu() == 1)))
                threshold_n = round(threshold_n, 6)
            # Zero every weight whose magnitude falls inside the current thresholds and mark it in the prune matrix.
            mask = ((weights < threshold_p) & (weights > 0)) | ((weights > threshold_n) & (weights < 0))
            return weights.masked_fill(mask, 0), prune_matrix.masked_fill(mask, 1), threshold_p, threshold_n
    def reinforcement(self, name, weights, beta_p, beta_n, thres_p, thres_n, max_w_p, max_w_n):
        if (not self.prune_recurrent) and ("recurrent" in name):
            return weights
        # Push surviving positive weights up and negative weights down, clamped to the initial extrema.
        weights = torch.where(weights > 0, weights + beta_p * thres_p, weights)
        weights[weights > max_w_p] = max_w_p
        weights = torch.where(weights < 0, weights - beta_n * thres_n, weights)
        weights[weights < max_w_n] = max_w_n
        return weights
    def apply(self):
        i = 0
        for name, param in self.model.named_parameters():
            param.data, self.prune_matrix[i], thres_p, thres_n = self.applyprune(
                name, self.alpha_p[i], self.alpha_n[i], self.max_threshold_p[i], self.max_threshold_n[i],
                param.data, self.prune_matrix[i], self.thresholds_p[i][-1], self.thresholds_n[i][-1]
            )
            n_zeros = float((self.prune_matrix[i]).sum())
            n_total = float(torch.prod(torch.tensor(param.data.shape)))
            ratio = round(n_zeros * 100 / n_total, 3)
            print(name, "zeros:", int(n_zeros), "/", n_total, "(" + str(ratio) + "%)", "threshold_n:", thres_n, "threshold_p:", thres_p)
            if self.to_file:
                self.file.write(str(name) + " zeros: " + str(n_zeros) + " / " + str(n_total) + " (" + str(ratio) + "%) threshold_n: " + str(thres_n) + " threshold_p: " + str(thres_p) + "\n")
            self.compressions[i].append(ratio)
            self.thresholds_p[i].append(thres_p)
            self.thresholds_n[i].append(thres_n)
            if self.apply_reinforcement:
                param.data = self.reinforcement(name, param.data, self.betaP, self.betaN, thres_p, thres_n, self.max_w_p[i], self.max_w_n[i])
            self.weights.append(param.data)
            i += 1
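# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original commit): how the class above
# is typically driven, shown on a toy linear model so it runs without Norse or
# the datasets. The toy model, random data, and hyperparameter values are
# placeholders, not anything taken from this repository.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    toy_model = torch.nn.Linear(20, 10, bias=False)  # stand-in for the Norse model
    optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.01)
    compressor = ProgressiveCompression(NorseModel=toy_model, maxThreshold=0.5,
                                        alphaP=0.005, alphaN=-0.005,
                                        apply_reinforcement=True, to_file=False)
    for epoch in range(3):
        x = torch.randn(8, 20)
        target = torch.randint(0, 10, (8,))
        loss = torch.nn.functional.cross_entropy(toy_model(x), target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        compressor.apply()  # prune a bit more (and reinforce) after every epoch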