From cae5478e03365289f3222f4c0698e73c1c2569b6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aur=C3=A9lie=20saulquin?= <aurelie.saulq@gmail.com>
Date: Mon, 24 Mar 2025 15:33:18 +0100
Subject: [PATCH] change template

---
 .../neurons/ShiftLif/shiftlif_parallel.vhd    | 15 ++++++++------
 .../modnef_driver/drivers/xstep_driver.py     |  3 ++-
 modneflib/modnef/templates/dataset.py         | 20 +++++++++++++++++++
 modneflib/modnef/templates/evaluation.py      | 16 ++-------------
 modneflib/modnef/templates/train.py           | 18 +----------------
 modneflib/modnef/templates/vhdl_generation.py | 17 ++--------------
 6 files changed, 36 insertions(+), 53 deletions(-)
 create mode 100644 modneflib/modnef/templates/dataset.py

diff --git a/ModNEF_Sources/modules/neurons/ShiftLif/shiftlif_parallel.vhd b/ModNEF_Sources/modules/neurons/ShiftLif/shiftlif_parallel.vhd
index 6c6989f..32a3aca 100644
--- a/ModNEF_Sources/modules/neurons/ShiftLif/shiftlif_parallel.vhd
+++ b/ModNEF_Sources/modules/neurons/ShiftLif/shiftlif_parallel.vhd
@@ -127,7 +127,7 @@ architecture Behavioral of ShiftLif_Parallel is
   end component;
 
   -- type definition
-  type transmission_state_t is (idle, voltage_update, check_arbitration, request, accept, wait_arbitration);
+  type transmission_state_t is (idle, voltage_update, check_arbitration, request, accept, wait_arbitration, arbitration_finish);
   type reception_state_t    is (idle, request, get_data);
 
   -- ram signals
@@ -249,9 +249,8 @@ begin
 
           when check_arbitration =>
             if spikes = no_spike then
-              transmission_state <= idle;
+              transmission_state <= arbitration_finish;
               o_emu_busy <= '0';
-              tr_fsm_en := '0';
             else
               transmission_state <= request;
               arb_spikes <= spikes;
@@ -278,12 +277,15 @@ begin
           when wait_arbitration =>
             start_arb <= '0';
             if arb_busy = '0' then
-              transmission_state <= idle;
-              o_emu_busy <= '0';
-              tr_fsm_en := '0';
+              transmission_state <= arbitration_finish;
             else
               transmission_state <= wait_arbitration;
             end if;  
+              
+          when arbitration_finish =>
+            transmission_state <= idle;
+            o_emu_busy <= '0';
+            tr_fsm_en := '0';
         end case;
       end if;
     end if;
@@ -337,3 +339,4 @@ begin
   end generate neuron_generation;
 
 end Behavioral;
+
diff --git a/modneflib/modnef/modnef_driver/drivers/xstep_driver.py b/modneflib/modnef/modnef_driver/drivers/xstep_driver.py
index 778b95b..6b5b216 100644
--- a/modneflib/modnef/modnef_driver/drivers/xstep_driver.py
+++ b/modneflib/modnef/modnef_driver/drivers/xstep_driver.py
@@ -160,8 +160,9 @@ class XStep_Driver(ModNEF_Driver):
         step_send += 1
 
       if step == len(sample_aer)-1:
+        print(len(data))
         emulation_result = self.rust_driver.data_transmission(data, reset_membrane)
-        
+        print("hi")
         res_step = self._unpack_data(emulation_result)
         for rs in res_step:
           for aer in rs:
diff --git a/modneflib/modnef/templates/dataset.py b/modneflib/modnef/templates/dataset.py
new file mode 100644
index 0000000..e0b86fc
--- /dev/null
+++ b/modneflib/modnef/templates/dataset.py
@@ -0,0 +1,20 @@
+import os
+import tonic
+from torch.utils.data import DataLoader
+
+"""DataSet Definition"""
+dataset_path = f"{os.environ['HOME']}/datasets"
+
+# data set definition, change to your dataset
+sensor_size = tonic.datasets.NMNIST.sensor_size
+frame_transform = tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=5)
+
+train_set = tonic.datasets.NMNIST(save_to=dataset_path, train=True, transform=frame_transform)
+test_set = tonic.datasets.NMNIST(save_to=dataset_path, train=False, transform=frame_transform)
+
+# batch loader
+batch_size = 64
+
+trainLoader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last = True, collate_fn = tonic.collation.PadTensors(batch_first=True))
+testLoader = DataLoader(test_set, batch_size=batch_size, shuffle=True, drop_last = True, collate_fn = tonic.collation.PadTensors(batch_first=True))
+validationLoader = None
\ No newline at end of file
diff --git a/modneflib/modnef/templates/evaluation.py b/modneflib/modnef/templates/evaluation.py
index ebcc8f8..91cba66 100644
--- a/modneflib/modnef/templates/evaluation.py
+++ b/modneflib/modnef/templates/evaluation.py
@@ -7,6 +7,7 @@ import torch
 from run_lib import *
 import sys
 from model import MyModel
+from dataset import *
 
 if __name__ == "__main__":
 
@@ -53,20 +54,7 @@ if __name__ == "__main__":
   conf_matrix_file = "confusion_matrix.png"
   conf_matrix_classes = [str(i) for i in range(10)]
   
-  """DataSet Definition"""
-  dataset_path = f"{os.environ['HOME']}/datasets"
-
-  # data set definition, change to your dataset
-  sensor_size = tonic.datasets.NMNIST.sensor_size
-  frame_transform = tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=10)
-
-  test_set = tonic.datasets.NMNIST(save_to=dataset_path, train=False, transform=frame_transform)
-
-  # batch loader
-  batch_size = 64
   
-  testLoader = DataLoader(test_set, batch_size=batch_size, shuffle=True, drop_last = True, collate_fn = tonic.collation.PadTensors(batch_first=True))
-
   if kind == "eval":
     acc, y_pred, y_true = evaluation(
       model=model, 
@@ -99,7 +87,7 @@ if __name__ == "__main__":
     exit(-1)
 
   if save_conf_matrix:
-    confusion_matrix(
+    conf_matrix(
       y_true=y_true,
       y_pred=y_pred,
       file_name=conf_matrix_file,
diff --git a/modneflib/modnef/templates/train.py b/modneflib/modnef/templates/train.py
index b67de45..0f48125 100644
--- a/modneflib/modnef/templates/train.py
+++ b/modneflib/modnef/templates/train.py
@@ -6,6 +6,7 @@ from snntorch.surrogate import fast_sigmoid
 from run_lib import *
 import torch
 from model import MyModel
+from dataset import *
 
 if __name__ == "__main__":
 
@@ -31,23 +32,6 @@ if __name__ == "__main__":
   save_plot = False
   save_history = False
   output_path = "."
-  
-  """DataSet Definition"""
-  dataset_path = f"{os.environ['HOME']}/datasets"
-
-  # data set definition, change to your dataset
-  sensor_size = tonic.datasets.NMNIST.sensor_size
-  frame_transform = tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=5)
-
-  train_set = tonic.datasets.NMNIST(save_to=dataset_path, train=True, transform=frame_transform)
-  test_set = tonic.datasets.NMNIST(save_to=dataset_path, train=False, transform=frame_transform)
-
-  # batch loader
-  batch_size = 64
-  
-  trainLoader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last = True, collate_fn = tonic.collation.PadTensors(batch_first=True))
-  testLoader = DataLoader(test_set, batch_size=batch_size, shuffle=True, drop_last = True, collate_fn = tonic.collation.PadTensors(batch_first=True))
-  validationLoader = None
 
   train(
     model=model, 
diff --git a/modneflib/modnef/templates/vhdl_generation.py b/modneflib/modnef/templates/vhdl_generation.py
index fcca214..563d1a1 100644
--- a/modneflib/modnef/templates/vhdl_generation.py
+++ b/modneflib/modnef/templates/vhdl_generation.py
@@ -6,6 +6,7 @@ from snntorch.surrogate import fast_sigmoid
 from run_lib import *
 import torch
 from model import MyModel
+from dataset import *
 
 if __name__ == "__main__":
 
@@ -33,21 +34,7 @@ if __name__ == "__main__":
   """VHDL file definition"""
   output_path = "."
   file_name = "template_vhdl_model.vhd"
-  driver_config_path = "driver_config"
-  
-  """DataSet Definition"""
-  dataset_path = f"{os.environ['HOME']}/datasets"
-
-  # data set definition, change to your dataset
-  sensor_size = tonic.datasets.NMNIST.sensor_size
-  frame_transform = tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=10)
-
-  test_set = tonic.datasets.NMNIST(save_to=dataset_path, train=False, transform=frame_transform)
-
-  # batch loader
-  batch_size = 64
-  
-  testLoader = DataLoader(test_set, batch_size=batch_size, shuffle=True, drop_last = True, collate_fn = tonic.collation.PadTensors(batch_first=True))
+  driver_config_path = "driver_config.yml"
 
   acc, y_pred, y_true = hardware_estimation(
     model=model, 
-- 
GitLab