diff --git a/HelperFunctions.ipynb b/HelperFunctions.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..8ebb3a75a50bbbd3a673437e235b043791f564f7 --- /dev/null +++ b/HelperFunctions.ipynb @@ -0,0 +1,139 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Here you can find some useful functions for working with the SpiNNaker ;) " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "~~~\n", + " ExtractConnectionFile(Input_File, Connections_File,Non_Zero, Connections_shape, n_samples_to_plot)\n", + "~~~\n", + "\n", + " Using the CSNN simulator, by exporting the weights as binary files, using ***ExtractConnectionFile*** we can extract the weights as connection files to use as projections between the Populations on the SpiNNaker" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pathlib\n", + "import math\n", + "from mpl_toolkits.axes_grid1 import ImageGrid\n", + "import matplotlib.pyplot as plt\n", + "\n", + "def plot_and_save_grid(input_data,nrows_ncols,lable='Sample Plost',save=False,file_name=None):\n", + " fig = plt.figure(figsize=(10, 10))\n", + " grid = ImageGrid(fig, 111, # similar to subplot(111)\n", + " nrows_ncols, # creates nrow x ncol grid of axes\n", + " axes_pad=0.1, # pad between axes in inch.\n", + " )\n", + "\n", + " for ax, inx in zip(grid, list(range(nrows_ncols[0]*nrows_ncols[1]))):\n", + " ax.imshow(input_data[inx],cmap=\"jet\")\n", + " ax.set_axis_off()\n", + " \n", + "\n", + " if save:\n", + " plt.savefig(file_name)\n", + " return\n", + "\n", + "def Plot_extracted_weights(Input_File, Connections_File,Connections_shape={'x':28, 'y':28, 'z':400}, n_samples_to_plot=0):\n", + " w = np.fromfile(Input_File, dtype=np.dtype(np.float32))#[28, 28, 1, 400]\n", + " print(f'Input file shape = {w.shape}')\n", + " w=w.reshape(Connections_shape['x'], Connections_shape['y'], Connections_shape['z'])#x, y, z\n", + " w=np.array([w[:, :, i].transpose() for i in range(Connections_shape['z']) ])\n", + " w=w.reshape(Connections_shape['z'], Connections_shape['x']*Connections_shape['y']).transpose()# 400rows of 28*28 to 400 rows of 784 and then transposed==> 784*400\n", + " if n_samples_to_plot:\n", + " w_to_plot=[w.reshape(Connections_shape['x'], Connections_shape['y'],-1)[:,:,i] for i in range(n_samples_to_plot)]\n", + " plot_and_save_grid(w_to_plot,(int(math.sqrt(n_samples_to_plot)),int(math.sqrt(n_samples_to_plot))))\n", + " return\n", + "\n", + "def ExtractConnectionFile(Input_File, Connections_File,Non_Zero=False, Connections_shape={'x':28, 'y':28, 'z':400}, n_samples_to_plot=0):\n", + " extracted_weights=[]\n", + " w = np.fromfile(Input_File, dtype=np.dtype(np.float32))#[28, 28, 1, 400]\n", + " print(f'Input file shape = {w.shape}')\n", + "# w=w.reshape(Connections_shape['x'], Connections_shape['y'], Connections_shape['z'])#x, y, z\n", + "# w=np.array([w[:, :, i].transpose() for i in range(Connections_shape['z']) ])\n", + " f=open(Connections_File, 'w')\n", + " f.write('# columns = [\"i\", \"j\", \"weight\", \"delay\"]\\n')\n", + " w=w.reshape(Connections_shape['x']*Connections_shape['y'], Connections_shape['z'])# 400rows of 28*28 to 400 rows of 784 and then transposed==> 784*400\n", + " X,Y=w.shape\n", + " for x in range(X):\n", + " for y in range(Y):\n", + " if Non_Zero:\n", + " if w[x][y]:\n", + " line=f'{x} {y} {w[x][y]} 1\\n'\n", + " f.write(line)\n", + " extracted_weights.append(w[x][y])\n", + " else:\n", + " line=f'{x} {y} {w[x][y]} 1\\n' \n", + " f.write(line)\n", + " extracted_weights.append(w[x][y])\n", + " \n", + " f.close()\n", + " if n_samples_to_plot:\n", + " w_to_plot=[w.reshape(Connections_shape['x'], Connections_shape['y'],-1)[:,:,i] for i in range(n_samples_to_plot)]\n", + " plot_and_save_grid(w_to_plot,(int(math.sqrt(n_samples_to_plot)),int(math.sqrt(n_samples_to_plot))))\n", + "\n", + " print(f'Output file shape = {w.shape}')\n", + " print(f'Connections File Path = {pathlib.Path(Connections_File).parent.resolve()}/{Connections_File}')\n", + " return pathlib.Path(Connections_File).parent.resolve(), extracted_weights" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## HOw to use:\n", + "Just put the binary files in an array as below and call the function for all the files:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "files_name=['weights_fc1_0_0_25_0.7', 'weights_fc2_0_0_25_0.7', 'weights_fc1_1_1_25_0.7', 'weights_fc2_1_1_25_0.7']\n", + "weights=[[],[],[],[]]\n", + "shape=[{'x':28, 'y':28, 'z':400}, {'x':20, 'y':20, 'z':1600}]\n", + "Non_Zero=False\n", + "for i, file_name in enumerate(files_name):\n", + " _, weights[i]=ExtractConnectionFile(Input_File=file_name, \n", + " Connections_File=f'{file_name}_Zero_Removed_{str(Non_Zero)}.txt',\n", + " Non_Zero=Non_Zero,\n", + " Connections_shape=shape[i%2],\n", + " n_samples_to_plot=0\n", + " )\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.8.10 ('snnToolBox')", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.8.10" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "20a24a3a4007955ef5ee2dc9fb7716c1a03a2c0fc0e1cbee0d7177cafa6993b3" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}