diff --git a/backpack.py b/backpack.py
index 3a5ccbadd092208d4ed9ddf2043486fb558c1731..d741fad2388a9cbba37d26e8ac811f0f81d0032b 100644
--- a/backpack.py
+++ b/backpack.py
@@ -8,67 +8,75 @@ from double_tanh import double_compact_tanh, double_tanh
 
 
 class BackPack(nn.Module):
-    def __init__(self, c_coeffs, rho_0, entropy, N, nets, params, ecdf_list):
+    def __init__(self, image_size, QF, folder_model, c_coeffs, rho_0, entropy, N, nets, ecdf_list, attack):
         super(BackPack, self).__init__()
-        self.net_decompress = IDCT8_Net(params)
+        self.net_decompress = IDCT8_Net(image_size, QF, folder_model)
         self.N = N
         self.entropy = torch.tensor(entropy)
-        self.im_size = params.image_size
-        self.rho_vec = torch.nn.Parameter(data=torch.tensor(rho_0), requires_grad=True)
+        self.im_size = image_size
+        self.rho_vec = torch.nn.Parameter(
+            data=torch.tensor(rho_0), requires_grad=True)
         self.c_coeffs = torch.tensor(c_coeffs).double()
-        self.spatial_cover = self.net_decompress.forward(torch.reshape(self.c_coeffs,(1,1,self.im_size, self.im_size)))/255
+        self.spatial_cover = self.net_decompress.forward(
+            torch.reshape(self.c_coeffs, (1, 1, self.im_size, self.im_size)))/255
         self.nets = nets
-        proba_cover = [torch.nn.Softmax(dim=1)(net.forward(self.spatial_cover.cuda().float()))[0,1] \
-                                for net in nets]
-        self.proba_cover = np.array([ecdf(x.cpu().detach().numpy()) for x,ecdf in zip(proba_cover,ecdf_list)])        
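+        # Each detector's softmax output on the cover is mapped through its
+        # empirical CDF (ecdf_list) so that the scores of the different nets
+        # are on a comparable, calibrated scale.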
+        proba_cover = [torch.nn.Softmax(dim=1)(net.forward(self.spatial_cover.cuda().float()))[0, 1]
+                       for net in nets]
+        self.proba_cover = np.array(
+            [ecdf(x.cpu().detach().numpy()) for x, ecdf in zip(proba_cover, ecdf_list)])
         self.find_lambda_fn = find_lambda.apply
 
-        self.attack=params.attack
-        if(self.attack=='SGE'):
-            self.mod = torch.reshape(torch.tensor([-1,0,1]), (3,1,1))
-            self.smoothing_fn = lambda u, p, tau : \
-                    torch.sum(torch.nn.Softmax(dim=1)((-torch.log(-torch.log(u))+torch.log(p+1e-30))/tau)*self.mod,axis=1) if tau>0 \
-                    else torch.argmax(-torch.log(-torch.log(u))+torch.log(p+1e-30),axis=1) - 1
-        elif(self.attack=='DoCoTanh'):
+        self.attack = attack
+        if(self.attack == 'SGE'):
+            self.mod = torch.reshape(torch.tensor([-1, 0, 1]), (3, 1, 1))
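+            # SGE: Gumbel-softmax ("Concrete") relaxation over the ternary
+            # changes {-1, 0, +1}; for tau > 0 the draw is a soft mixture,
+            # for tau == 0 it degenerates to a hard Gumbel-max sample.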
+            self.smoothing_fn = lambda u, p, tau: \
+                torch.sum(torch.nn.Softmax(dim=1)((-torch.log(-torch.log(u))+torch.log(p+1e-30))/tau)*self.mod, axis=1) if tau > 0 \
+                else torch.argmax(-torch.log(-torch.log(u))+torch.log(p+1e-30), axis=1) - 1
+        elif(self.attack == 'DoCoTanh'):
             self.smoothing_fn = double_compact_tanh.apply
-        elif(self.attack=='DoTanh'):
-            self.smoothing_fn = double_tanh          
+        elif(self.attack == 'DoTanh'):
+            self.smoothing_fn = double_tanh
         self.ecdf_list = ecdf_list
-    
+
     def forward(self, tau):
-        if(self.attack=='SGE'):
+        if(self.attack == 'SGE'):
             u = torch.rand(size=(self.N, 3, self.im_size, self.im_size))
         else:
             u = torch.rand(size=(self.N, self.im_size, self.im_size))
-        
+
         # Compute modifications for soft stego
         lbd = self.find_lambda_fn(self.entropy, self.rho_vec)
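+        # probas: per-coefficient probabilities of the {-1, 0, +1} changes,
+        # a softmax over the costs rho scaled by the multiplier lambda;
+        # b below is the corresponding tau-relaxed modification map.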
-        probas = torch.nn.Softmax(dim=0)(-lbd*(self.rho_vec-torch.min(self.rho_vec,dim=0)[0]))
+        probas = torch.nn.Softmax(
+            dim=0)(-lbd*(self.rho_vec-torch.min(self.rho_vec, dim=0)[0]))
         b = self.smoothing_fn(u, probas, tau)
-        stego_soft = torch.reshape(self.c_coeffs+b, (self.N,1,self.im_size,self.im_size))
+        stego_soft = torch.reshape(
+            self.c_coeffs+b, (self.N, 1, self.im_size, self.im_size))
 
         spatial_image_soft = self.net_decompress.forward(stego_soft)/255
-        logits = [torch.reshape(net.forward(spatial_image_soft.cuda().float()),(1,self.N,2)) \
+        logits = [torch.reshape(net.forward(spatial_image_soft.cuda().float()), (1, self.N, 2))
                   for net in self.nets]
         logits = torch.cat(logits)
-        probas_soft = torch.nn.Softmax(dim=2)(logits)[:,:,-1]
-        mean_probas_soft = torch.mean(probas_soft,axis=-1)
+        probas_soft = torch.nn.Softmax(dim=2)(logits)[:, :, -1]
+        mean_probas_soft = torch.mean(probas_soft, axis=-1)
         mean_probas_soft_cpu = mean_probas_soft.clone().cpu().detach().numpy()
-        mean_probas_soft_cpu = np.array([ecdf(x) for ecdf,x in zip(self.ecdf_list, mean_probas_soft_cpu)])
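+        # Calibrate each classifier's mean soft-stego score with its ECDF and
+        # keep the score of the classifier with the highest calibrated
+        # response, i.e. the worst case for the attacker.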
+        mean_probas_soft_cpu = np.array(
+            [ecdf(x) for ecdf, x in zip(self.ecdf_list, mean_probas_soft_cpu)])
         argmax = np.argmax(mean_probas_soft_cpu)
         best_probas_soft = mean_probas_soft[argmax]
-        
+
         # Compute modifications for real/hard stego
         with torch.no_grad():
             b_hard = self.smoothing_fn(u, probas, 0)
-            stego_hard = torch.reshape(self.c_coeffs+b_hard, (self.N, 1, self.im_size, self.im_size))
+            stego_hard = torch.reshape(
+                self.c_coeffs+b_hard, (self.N, 1, self.im_size, self.im_size))
             spatial_image_hard = self.net_decompress.forward(stego_hard)/255
-            logits = [torch.reshape(net.forward(spatial_image_hard.cuda().float()),(1,self.N,2)) \
+            logits = [torch.reshape(net.forward(spatial_image_hard.cuda().float()), (1, self.N, 2))
                       for net in self.nets]
             logits = torch.cat(logits)
-            probas_hard = torch.nn.Softmax(dim=2)(logits)[:,:,-1]
-            mean_probas_hard = torch.mean(probas_hard,axis=-1).cpu().detach().numpy()
-            mean_probas_hard_cpu = np.array([ecdf(x) for ecdf,x in zip(self.ecdf_list, mean_probas_hard)])
-        
-        return best_probas_soft, mean_probas_soft_cpu, mean_probas_hard_cpu, stego_hard
+            probas_hard = torch.nn.Softmax(dim=2)(logits)[:, :, -1]
+            mean_probas_hard = torch.mean(
+                probas_hard, axis=-1).cpu().detach().numpy()
+            mean_probas_hard_cpu = np.array(
+                [ecdf(x) for ecdf, x in zip(self.ecdf_list, mean_probas_hard)])
 
+        return best_probas_soft, mean_probas_soft_cpu, mean_probas_hard_cpu, stego_hard
diff --git a/data_loader.py b/data_loader.py
index 9dcdb2a51e1481a080ec6ea1dae2e1bcda2315bd..2cf8864e73184f35e22bd8fa40e4a43a1e1832a2 100644
--- a/data_loader.py
+++ b/data_loader.py
@@ -15,33 +15,37 @@ import numpy as np
 import albumentations as A
 import matplotlib.pyplot as plt
 from albumentations.pytorch.transforms import ToTensorV2
-from torch.utils.data import Dataset,DataLoader
+from torch.utils.data import Dataset, DataLoader
 from torch.utils.data.sampler import SequentialSampler, RandomSampler
 import sklearn
 from PIL import Image
 
-from os import path,mkdir,makedirs
+from os import path, mkdir, makedirs
 from tools_stegano import compute_spatial_from_jpeg, compute_proba, HILL
 
-    
+
 # TRANSFORMS
 def get_train_transforms():
     return A.Compose([
-            A.HorizontalFlip(p=0.5),
-            A.VerticalFlip(p=0.5),
-            #A.ToGray(always_apply=True, p=1.0),
-            #A.Resize(height=512, width=512, p=1.0),
-            ToTensorV2(p=1.0), # Si le V2 n'exite pas chez toi, tu peux utiliser ToTensor également
-        ], p=1.0)
+        A.HorizontalFlip(p=0.5),
+        A.VerticalFlip(p=0.5),
+        #A.ToGray(always_apply=True, p=1.0),
+        #A.Resize(height=512, width=512, p=1.0),
+        # If ToTensorV2 is not available in your version, you can use ToTensor instead
+        ToTensorV2(p=1.0),
+    ], p=1.0)
+
 
 def get_valid_transforms():
     return A.Compose([
-            #A.Resize(height=512, width=512, p=1.0),
-            #A.ToGray(always_apply=True, p=1.0),
-            ToTensorV2(p=1.0),
-        ], p=1.0)
+        #A.Resize(height=512, width=512, p=1.0),
+        #A.ToGray(always_apply=True, p=1.0),
+        ToTensorV2(p=1.0),
+    ], p=1.0)
 
 # DATASET CLASS
+
+
 def onehot(size, target):
     vec = torch.zeros(size, dtype=torch.float32)
     vec[target] = 1.
@@ -53,17 +57,18 @@ def onehot(size, target):
     vec[target] = 1.
     return vec
 
+
 class DatasetRetriever(Dataset):
 
-    def __init__(self, image_names, folder_model, QF, emb_rate, image_size, \
-            data_dir_cover, data_dir_stego_0, cost_dir, data_dir_prot, \
-            H1_filter, L1_filter, L2_filter, indexs_db, train_on_cost_map=False, \
-            labels=None,  transforms=None, pair_training=False, spatial=False):
+    def __init__(self, image_names, folder_model, QF, emb_rate, image_size,
+                 data_dir_cover, data_dir_stego_0, cost_dir, data_dir_prot,
+                 H1_filter, L1_filter, L2_filter, indexs_db, train_on_cost_map=False,
+                 labels=None,  transforms=None, pair_training=False, spatial=False):
 
         super().__init__()
         self.image_names = image_names
         self.indexs_db = indexs_db
-        self.labels = labels 
+        self.labels = labels
         self.transforms = transforms
         self.c_quant = np.load(folder_model + 'c_quant_'+str(QF)+'.npy')
         self.WET_COST = 10 ** 13
@@ -72,149 +77,168 @@ class DatasetRetriever(Dataset):
         self.cover_path = data_dir_cover
         self.data_dir_stego_0 = data_dir_stego_0
         self.cost_dir = cost_dir
-        self.pair_training=pair_training
+        self.pair_training = pair_training
         self.data_dir_prot = data_dir_prot
-        self.spatial=spatial
-        self.train_on_cost_map=train_on_cost_map
+        self.spatial = spatial
+        self.train_on_cost_map = train_on_cost_map
         if self.spatial:
             if(H1_filter is None):
-                self.H1_filter = 4 * np.array([[-0.25, 0.5, -0.25], 
-                         [0.5, -1, 0.5], 
-                         [-0.25, 0.5, -0.25]])
+                self.H1_filter = 4 * np.array([[-0.25, 0.5, -0.25],
+                                               [0.5, -1, 0.5],
+                                               [-0.25, 0.5, -0.25]])
                 self.L1_filter = (1.0/9.0)*np.ones((3, 3))
                 self.L2_filter = (1.0/225.0)*np.ones((15, 15))
             else:
-                self.H1_filter = np.load(H1_filter).reshape((3,3))
-                self.L1_filter = np.load(L1_filter).reshape((3,3))
-                self.L2_filter = np.load(L2_filter).reshape((15,15))
-
+                self.H1_filter = np.load(H1_filter).reshape((3, 3))
+                self.L1_filter = np.load(L1_filter).reshape((3, 3))
+                self.L2_filter = np.load(L2_filter).reshape((15, 15))
 
     def __getitem__(self, index: int):
-        
+
+        # IF PAIR TRAINING: return a (cover, stego) pair of images
         if self.pair_training:
             image_name = self.image_names[index]
 
             if(self.spatial):
-                cover = np.asarray(Image.open(path.join(self.cover_path, image_name[:-3]+'pgm')),dtype=np.float32)
-                message_length= cover.size*self.emb_rate
+                cover = np.asarray(Image.open(
+                    path.join(self.cover_path, image_name[:-3]+'pgm')), dtype=np.float32)
+                message_length = cover.size*self.emb_rate
             else:
-                cover = np.load(path.join(self.cover_path, image_name[:-3]+'npy')).astype(np.float32)
-                nz_AC = np.sum(cover!=0)-np.sum(cover[::8,::8]!=0)
+                cover = np.load(path.join(self.cover_path,
+                                          image_name[:-3]+'npy')).astype(np.float32)
+                nz_AC = np.sum(cover != 0)-np.sum(cover[::8, ::8] != 0)
                 message_length = nz_AC*self.emb_rate
             index_db = self.indexs_db[index]
 
-            if self.train_on_cost_map: # Load the cost map and generate new stego
+            if self.train_on_cost_map:  # Load the cost map and generate new stego
                 if(self.spatial):
-                    rho = HILL(cover, self.H1_filter, self.L1_filter, self.L2_filter)
+                    rho = HILL(cover, self.H1_filter,
+                               self.L1_filter, self.L2_filter)
                 else:
                     try:
-                        cost_dir = self.data_dir_prot + 'data_adv_'+ str(index_db) +'/adv_cost/'
+                        cost_dir = self.data_dir_prot + \
+                            'data_adv_' + str(index_db) + '/adv_cost/'
                         rho = np.load(cost_dir+image_name[:-3]+'npy')
                     except:
-                        cost_dir = self.cost_dir 
+                        cost_dir = self.cost_dir
                         rho = np.load(cost_dir+image_name[:-3]+'npy')
-                if(rho.shape==(self.im_size, self.im_size)):
-                    rho = np.reshape(rho,(1,self.im_size, self.im_size))
-                    rho = np.concatenate((np.copy(rho), np.zeros_like(rho), np.copy(rho)),axis=0)
+                if(rho.shape == (self.im_size, self.im_size)):
+                    rho = np.reshape(rho, (1, self.im_size, self.im_size))
+                    rho = np.concatenate(
+                        (np.copy(rho), np.zeros_like(rho), np.copy(rho)), axis=0)
                 if(self.spatial):
-                   rho[0,cover <= 0] = self.WET_COST
-                   rho[2,cover >= 255] = self.WET_COST
-                if not self.spatial:    
-                    rho[0,cover < -1023] = self.WET_COST
-                    rho[2,cover > 1023] = self.WET_COST
-                
+                    rho[0, cover <= 0] = self.WET_COST
+                    rho[2, cover >= 255] = self.WET_COST
+                if not self.spatial:
+                    rho[0, cover < -1023] = self.WET_COST
+                    rho[2, cover > 1023] = self.WET_COST
+
                 p = compute_proba(rho, message_length)
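+                # Simulated embedding: p gives the per-coefficient
+                # probabilities of the -1/0/+1 changes for this payload; the
+                # stego is then sampled from p with the Gumbel-max trick.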
-                u = np.random.uniform(0,1,(3, self.im_size, self.im_size))
-                stego = (cover + np.argmax(-np.log(-np.log(u))+np.log(p+1e-30),axis=0) - 1).astype(np.float32)
+                u = np.random.uniform(0, 1, (3, self.im_size, self.im_size))
+                stego = (cover + np.argmax(-np.log(-np.log(u)) +
+                                           np.log(p+1e-30), axis=0) - 1).astype(np.float32)
 
-            else: # Load the stego
+            else:  # Load the stego
                 try:
-                    dir_stego = self.data_dir_prot + 'data_adv_'+ str(index_db) +'/adv_final/'
+                    dir_stego = self.data_dir_prot + \
+                        'data_adv_' + str(index_db) + '/adv_final/'
                     stego = np.load(dir_stego+image_name[:-3]+'npy')
                 except:
-                    dir_stego = self.data_dir_stego_0 
+                    dir_stego = self.data_dir_stego_0
                     stego = np.load(dir_stego+image_name[:-3]+'npy')
-            
+
             if not self.spatial:
                 cover = compute_spatial_from_jpeg(cover, self.c_quant)/255
                 stego = compute_spatial_from_jpeg(stego, self.c_quant)/255
-                        
+
             if self.transforms:
                 # To have the same transformation on cover and stego
-                seed = np.random.randint(2147483647) 
+                seed = np.random.randint(2147483647)
                 random.seed(seed)
                 cover = self.transforms(image=cover)['image']
                 random.seed(seed)
                 stego = self.transforms(image=stego)['image']
             else:
-                cover = torch.tensor(cover.reshape((1, self.im_size, self.im_size)))
-                stego = torch.tensor(stego.reshape((1, self.im_size, self.im_size)))
-            cover = cover[0:1,:,:]
-            stego = stego[0:1,:,:]
+                cover = torch.tensor(cover.reshape(
+                    (1, self.im_size, self.im_size)))
+                stego = torch.tensor(stego.reshape(
+                    (1, self.im_size, self.im_size)))
+            cover = cover[0:1, :, :]
+            stego = stego[0:1, :, :]
             target_cover = onehot(2, 0)
             target_stego = onehot(2, 1)
-            return((cover,stego), (target_cover, target_stego))
-            
+            return((cover, stego), (target_cover, target_stego))
+
+        # IF NOT PAIR TRAINING: return a single cover or stego image, depending on the value of labels[index]
+        # (0 for cover and 1 for stego)
         else:
             image_name, label = self.image_names[index], self.labels[index]
             if(self.spatial):
-                cover = np.asarray(Image.open(path.join(self.cover_path, image_name[:-3]+'pgm')),dtype=np.float32)
-                message_length= cover.size*self.emb_rate
+                cover = np.asarray(Image.open(
+                    path.join(self.cover_path, image_name[:-3]+'pgm')), dtype=np.float32)
+                message_length = cover.size*self.emb_rate
             else:
-                cover = np.load(path.join(self.cover_path, image_name[:-3]+'npy')).astype(np.float32)
-                nz_AC = np.sum(cover!=0)-np.sum(cover[::8,::8]!=0)
+                cover = np.load(path.join(self.cover_path,
+                                          image_name[:-3]+'npy')).astype(np.float32)
+                nz_AC = np.sum(cover != 0)-np.sum(cover[::8, ::8] != 0)
                 message_length = nz_AC*self.emb_rate
-        
+
             if label == 0:
                 image = cover
             elif label == 1:
                 index_db = self.indexs_db[index]
 
-                if self.train_on_cost_map: # Load the cost map and generate new stego
+                if self.train_on_cost_map:  # Load the cost map and generate new stego
                     if(self.spatial):
-                        rho = HILL(cover, self.H1_filter, self.L1_filter, self.L2_filter)
+                        rho = HILL(cover, self.H1_filter,
+                                   self.L1_filter, self.L2_filter)
                     else:
                         try:
-                            cost_dir = self.data_dir_prot + 'data_adv_'+ str(index_db) +'/adv_cost/'
+                            cost_dir = self.data_dir_prot + \
+                                'data_adv_' + str(index_db) + '/adv_cost/'
                             rho = np.load(cost_dir+image_name[:-3]+'npy')
                         except:
-                            cost_dir = self.cost_dir 
+                            cost_dir = self.cost_dir
                             rho = np.load(cost_dir+image_name[:-3]+'npy')
-                    if(rho.shape==(self.im_size, self.im_size)):
-                        rho = np.reshape(rho,(1,self.im_size, self.im_size))
-                        rho = np.concatenate((np.copy(rho), np.zeros_like(rho), np.copy(rho)),axis=0)
+                    if(rho.shape == (self.im_size, self.im_size)):
+                        rho = np.reshape(rho, (1, self.im_size, self.im_size))
+                        rho = np.concatenate(
+                            (np.copy(rho), np.zeros_like(rho), np.copy(rho)), axis=0)
                     if(self.spatial):
-                       rho[0,cover <= 0] = self.WET_COST
-                       rho[2,cover >= 255] = self.WET_COST
-                    if not self.spatial:    
-                        rho[0,cover < -1023] = self.WET_COST
-                        rho[2,cover > 1023] = self.WET_COST
+                        rho[0, cover <= 0] = self.WET_COST
+                        rho[2, cover >= 255] = self.WET_COST
+                    if not self.spatial:
+                        rho[0, cover < -1023] = self.WET_COST
+                        rho[2, cover > 1023] = self.WET_COST
                     p = compute_proba(rho, message_length)
-                    u = np.random.uniform(0,1,(3, self.im_size, self.im_size))
-                    stego = cover + np.argmax(-np.log(-np.log(u))+np.log(p+1e-30),axis=0) - 1
-                else: # Load the stego
+                    u = np.random.uniform(
+                        0, 1, (3, self.im_size, self.im_size))
+                    stego = cover + \
+                        np.argmax(-np.log(-np.log(u)) +
+                                  np.log(p+1e-30), axis=0) - 1
+                else:  # Load the stego
                     try:
-                        dir_stego = self.data_dir_prot + 'data_adv_'+ str(index_db) +'/adv_final/'
+                        dir_stego = self.data_dir_prot + \
+                            'data_adv_' + str(index_db) + '/adv_final/'
                         stego = np.load(dir_stego+image_name[:-3]+'npy')
                     except:
-                        dir_stego = self.data_dir_stego_0 
+                        dir_stego = self.data_dir_stego_0
                         stego = np.load(dir_stego+image_name[:-3]+'npy')
 
-                image = stego  
+                image = stego
 
             image = image.astype(np.float32)
 
             if not self.spatial:
                 image = compute_spatial_from_jpeg(image, self.c_quant)/255.
-            
+
             if self.transforms:
                 sample = {'image': image}
                 sample = self.transforms(**sample)
                 image = sample['image']
             else:
                 image = image.reshape((1, self.im_size, self.im_size))
-            image = image[0:1,:,:]
-
+            image = image[0:1, :, :]
 
             target = onehot(2, label)
             return image, target
@@ -226,61 +250,37 @@ class DatasetRetriever(Dataset):
         return list(self.labels)
 
 
-
-# LABEL SMOOTHING
-class LabelSmoothing(nn.Module):
-    def __init__(self, smoothing = 0.05):
-        super(LabelSmoothing, self).__init__()
-        self.confidence = 1.0 - smoothing
-        self.smoothing = smoothing
-
-    def forward(self, x, target):
-        if self.training:
-            x = x.float()
-            target = target.float()
-            logprobs = torch.nn.functional.log_softmax(x, dim = -1)
-
-            nll_loss = -logprobs * target
-            nll_loss = nll_loss.sum(-1)
-    
-            smooth_loss = -logprobs.mean(dim=-1)
-
-            loss = self.confidence * nll_loss + self.smoothing * smooth_loss
-
-            return loss.mean()
-        else:
-            return torch.nn.functional.cross_entropy(x, target)
-
-
-def load_dataset(iteration_step, permutation_files, train_size, valid_size, test_size, \
-    data_dir_prot, pair_training):
+def load_dataset(iteration_step, permutation_files, train_size, valid_size, test_size,
+                 data_dir_prot, pair_training):
 
     dataset = []
     im_list = np.load(permutation_files)
-    folds = np.zeros(train_size + valid_size + test_size,dtype=np.int8)
-    folds[train_size:]+=1
-    folds[train_size+valid_size:]+=1
-
-    if(iteration_step>0):
-        indexs_db = np.load(data_dir_prot +'data_train_'+str(iteration_step)+'/index.npy')
+    folds = np.zeros(train_size + valid_size + test_size, dtype=np.int8)
+    folds[train_size:] += 1
+    folds[train_size+valid_size:] += 1
+    n_images = train_size+valid_size + test_size
+
+    if(iteration_step > 0):
+        indexs_db = np.load(data_dir_prot + 'data_train_' +
+                            str(iteration_step)+'/index.npy')
     else:
-        indexs_db = np.zeros(train_size+valid_size + test_size, dtype=np.int8)
-    
+        indexs_db = np.zeros(n_images, dtype=np.int8)
+
     if pair_training:
-        for im,fold,ind in zip(im_list, folds, indexs_db):
+        for im, fold, ind in zip(im_list, folds, indexs_db):
             dataset.append({
                 'image_name': im,
                 'fold':  fold,
                 'indexs_db':  ind
             })
-        
+
     else:
-        for label in range(2): # 0 for cover and 1 for stego
-            for im,fold,ind in zip(im_list, folds,indexs_db):
-                if(label==0):
+        for label in range(2):  # 0 for cover and 1 for stego
+            for im, fold, ind in zip(im_list, folds, indexs_db):
+                if(label == 0):
                     index = -1
                 else:
-                    index=ind
+                    index = ind
                 dataset.append({
                     'image_name': im,
                     'label': label,
@@ -290,5 +290,3 @@ def load_dataset(iteration_step, permutation_files, train_size, valid_size, test
 
     dataset = pd.DataFrame(dataset)
     return(dataset)
-
-
diff --git a/double_tanh.py b/double_tanh.py
index 4a459edda51d59e50a4234e5deca662007254d8d..3655effcda3c33c36c0dd90c0bd52e9b2956882f 100644
--- a/double_tanh.py
+++ b/double_tanh.py
@@ -7,31 +7,37 @@ from tools_stegano import IDCT8_Net, find_lambda
 
 
 def logit(y, eps=1e-20):
-    return -1.0 * torch.log((1.0 - torch.min(y,torch.tensor(1.-eps).double())) / torch.max(y, torch.tensor(eps).double()))
+    return -1.0 * torch.log((1.0 - torch.min(y, torch.tensor(1.-eps).double())) / torch.max(y, torch.tensor(eps).double()))
+
 
 def compact_tanh(u, tau):
-    return(torch.tanh(logit(u)/tau))    
+    return(torch.tanh(logit(u)/tau))
+
 
 def g_warp(u, l, c, r):
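+    # Monotone power warp of [l, r] onto [0, 1] that maps the pivot c to 0.5,
+    # used to centre compact_tanh on the cumulative change probabilities.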
     a = (u-l)/(r-l)
-    a = torch.clamp((u-l)/(r-l),0,1)
+    a = torch.clamp((u-l)/(r-l), 0, 1)
     b = np.log(0.5)/(torch.log((c-l)/(r-l)))
     e = b*torch.log(a)
     e = torch.exp(e)
     return(e)
 
 # Double Tanh
+
+
 def double_tanh(u, p, tau):
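+    # tau > 0: smooth double step (sum of two tanh) approximating the inverse
+    # CDF of the ternary change; tau == 0: exact inverse-CDF sampling of
+    # {-1, 0, +1} from the cumulative probabilities.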
-    if(tau>0):
+    if(tau > 0):
         b = -0.5*(torch.tanh((p[0]-u)/tau)+torch.tanh((p[0]+p[1]-u)/tau))
     else:
-        cumsum = torch.cumsum(p,axis=0)
+        cumsum = torch.cumsum(p, axis=0)
         b = torch.full_like(u, torch.tensor(-1).double())
-        b[u>cumsum[0,None]]+=1
-        b[u>cumsum[1,None]]+=1
+        b[u > cumsum[0, None]] += 1
+        b[u > cumsum[1, None]] += 1
     return(b)
 
 # Double Compact Tanh
+
+
 class double_compact_tanh(torch.autograd.Function):
     """
     We can implement our own custom autograd Functions by subclassing
@@ -41,35 +47,37 @@ class double_compact_tanh(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, u, p, tau):
-        if tau>0:
+        if tau > 0:
             with torch.enable_grad():
                 gamma = (p[0]+1-p[-1])/2
                 step1 = torch.ones_like(u).double()
                 step2 = -torch.ones_like(u).double()
-                p = (p[:,None]).repeat((1,len(u),1,1))
-                gamma = (gamma[None]).repeat((len(u),1,1))
+                p = (p[:, None]).repeat((1, len(u), 1, 1))
+                gamma = (gamma[None]).repeat((len(u), 1, 1))
 
-                # Handle 0 probabilities of -1 or +1 
+                # Handle 0 probabilities of -1 or +1
                 # (probability of modification 0 is assumed to be never equal to 0)
-                bool1 = (u<gamma)&(p[1]<1.)
-                bool2 = (u>gamma)&(p[1]<1.)
+                bool1 = (u < gamma) & (p[1] < 1.)
+                bool2 = (u > gamma) & (p[1] < 1.)
 
-                step1[bool1] = compact_tanh(g_warp(u[bool1], 0, p[0,bool1], gamma[bool1]), tau)
-                step2[bool2] = compact_tanh(g_warp(u[bool2], gamma[bool2], (p[0]+p[1])[bool2], 1.), tau)
+                step1[bool1] = compact_tanh(
+                    g_warp(u[bool1], 0, p[0, bool1], gamma[bool1]), tau)
+                step2[bool2] = compact_tanh(
+                    g_warp(u[bool2], gamma[bool2], (p[0]+p[1])[bool2], 1.), tau)
 
                 b = 0.5*(step1 + step2)
 
                 # Save data
-                ctx.tau=tau
-                ctx.gamma=gamma
+                ctx.tau = tau
+                ctx.gamma = gamma
                 ctx.u = u
                 ctx.p = p
 
         else:
-            cumsum = torch.cumsum(p,axis=0)
+            cumsum = torch.cumsum(p, axis=0)
             b = torch.full_like(u, torch.tensor(-1).double())
-            b[u>cumsum[0,None]]+=1
-            b[u>cumsum[1,None]]+=1
+            b[u > cumsum[0, None]] += 1
+            b[u > cumsum[1, None]] += 1
 
         return(b)
 
@@ -79,32 +87,35 @@ class double_compact_tanh(torch.autograd.Function):
 
         # Same computation than in forward, to retrieve the gradient
         with torch.enable_grad():
-            
+
             gamma = (ctx.p[0]+1-ctx.p[-1])/2
-            
-            bool1 = (ctx.u<gamma)&(ctx.p[1]<1.)
-            bool2 = (ctx.u>gamma)&(ctx.p[1]<1.)
-                
-            p1 = ctx.p[:,bool1]
-            p2 = ctx.p[:,bool2]
+
+            bool1 = (ctx.u < gamma) & (ctx.p[1] < 1.)
+            bool2 = (ctx.u > gamma) & (ctx.p[1] < 1.)
+
+            p1 = ctx.p[:, bool1]
+            p2 = ctx.p[:, bool2]
             gam1 = ctx.gamma[bool1]
             gam2 = ctx.gamma[bool2]
             u1 = ctx.u[bool1]
             u2 = ctx.u[bool2]
-            
+
             step1 = compact_tanh(g_warp(u1, 0, p1[0], gam1), ctx.tau)
             step2 = compact_tanh(g_warp(u2, gam2, p2[0]+p2[1], 1.), ctx.tau)
-            
+
             gr = torch.zeros_like(ctx.p)
-            gr1 = 0.5*torch.autograd.grad(step1,p1,grad_outputs=grad_output[bool1],retain_graph=True)[0]
-            gr2 = 0.5*torch.autograd.grad(step2,p2,grad_outputs=grad_output[bool2],retain_graph=True)[0] 
-        
-            gr[:,bool1]=gr1
-            gr[:,bool2]=gr2
-            gr = gr[:,0]
-
-            # Avoid NaNs 
-            gr[torch.isnan(gr)]=0
-            
-        return(None, gr, None)
+            gr1 = 0.5 * \
+                torch.autograd.grad(
+                    step1, p1, grad_outputs=grad_output[bool1], retain_graph=True)[0]
+            gr2 = 0.5 * \
+                torch.autograd.grad(
+                    step2, p2, grad_outputs=grad_output[bool2], retain_graph=True)[0]
 
+            gr[:, bool1] = gr1
+            gr[:, bool2] = gr2
+            gr = gr[:, 0]
+
+            # Avoid NaNs
+            gr[torch.isnan(gr)] = 0
+
+        return(None, gr, None)
diff --git a/eval_classifier.py b/eval_classifier.py
index ad90b7839ee4dd5101e17d6e98c7d2763654f12c..ba71f39f6c37abea79b26dbc85debc8825b199da 100644
--- a/eval_classifier.py
+++ b/eval_classifier.py
@@ -8,86 +8,91 @@ import argparse
 from data_loader import compute_spatial_from_jpeg
 from efficientnet import get_net
 
+
 def softmax(array):
-    exp = np.exp(array-np.max(array,axis=1, keepdims=True))
+    exp = np.exp(array-np.max(array, axis=1, keepdims=True))
     return(exp/np.sum(exp, axis=1, keepdims=True))
 
+
 class cover_stego_loader(object):
 
-    def __init__(self, params, iteration, mode): # mode = stego or cover
+    def __init__(self, params, iteration, mode):  # mode = stego or cover
         self.params = params
-        n_images = params.train_size + params.valid_size + params.test_size 
-        self.files = np.load(params.folder_model + 'permutation_files.npy')[:n_images]
+        n_images = params.train_size + params.valid_size + params.test_size
+        self.files = np.load(params.folder_model +
+                             'permutation_files.npy')[:n_images]
         self.train_counter = 0
         self.train_data_size = len(self.files)
-        self.train_num_batches = int(np.ceil(1.0 * self.train_data_size / params.batch_size_eval))
+        self.train_num_batches = int(
+            np.ceil(1.0 * self.train_data_size / params.batch_size_eval))
         self.iteration_step = iteration
         self.mode = mode
-        self.c_quant = np.load(params.folder_model + 'c_quant_'+str(params.QF)+'.npy')
-    
+        self.c_quant = np.load(params.folder_model +
+                               'c_quant_'+str(params.QF)+'.npy')
+
     def next_batch(self):
 
-        borne_sup = min(self.train_counter + self.params.batch_size_eval, len(self.files))
+        borne_sup = min(self.train_counter +
+                        self.params.batch_size_eval, len(self.files))
         n_images = borne_sup-self.train_counter
 
-        next_batch_X = np.zeros((n_images,self.params.image_size,self.params.image_size),dtype=np.float32)
+        next_batch_X = np.zeros(
+            (n_images, self.params.image_size, self.params.image_size), dtype=np.float32)
 
-        for i,file in enumerate(self.files[self.train_counter:borne_sup]):
-            if(self.mode=='stego'):
-                if(self.iteration_step>0):
+        for i, file in enumerate(self.files[self.train_counter:borne_sup]):
+            if(self.mode == 'stego'):
+                if(self.iteration_step > 0):
                     try:
-                        image = np.load(self.params.data_dir_prot+'data_adv_'+str(self.iteration_step)+'/adv_final/'+file[:-4]+'.npy')
+                        image = np.load(self.params.data_dir_prot+'data_adv_' +
+                                        str(self.iteration_step)+'/adv_final/'+file[:-4]+'.npy')
                     except:
-                        image = np.load(self.params.data_dir_stego_0 + file[:-4]+'.npy')
+                        image = np.load(
+                            self.params.data_dir_stego_0 + file[:-4]+'.npy')
                 else:
-                    image = np.load(self.params.data_dir_stego_0 + file[:-4]+'.npy')
-            elif(self.mode=='cover'):
-                image = np.load(self.params.data_dir_cover + file[:-4] + '.npy')
+                    image = np.load(
+                        self.params.data_dir_stego_0 + file[:-4]+'.npy')
+            elif(self.mode == 'cover'):
+                image = np.load(self.params.data_dir_cover +
+                                file[:-4] + '.npy')
 
             spat_image = compute_spatial_from_jpeg(image, self.c_quant)
-            next_batch_X[i,:,:]=spat_image
-    
-        next_batch_X = np.reshape(next_batch_X,(next_batch_X.shape[0], 1, next_batch_X.shape[1],next_batch_X.shape[2]))
+            next_batch_X[i, :, :] = spat_image
+
+        next_batch_X = np.reshape(
+            next_batch_X, (next_batch_X.shape[0], 1, next_batch_X.shape[1], next_batch_X.shape[2]))
+
+        # Keep the file names of the current batch before the counter moves on.
+        batch_files = self.files[self.train_counter:borne_sup]
+        self.train_counter = (
+            self.train_counter + self.params.batch_size_eval) % self.train_data_size
+        return(next_batch_X, batch_files)
 
-        self.train_counter = (self.train_counter + self.params.batch_size_eval) % self.train_data_size
-        return(next_batch_X, self.files[self.train_counter:borne_sup])  
-    
     def reset_counter(self):
         self.train_counter = 0
 
 
-def evaluate_step_i(params, iteration_f, iteration_adv): # if iteration_adv == -1 : cover
-    
+def evaluate_step_i(params, iteration_f, iteration_adv):  # if iteration_adv == -1 : cover
+
     net = get_net().cuda()
     path = params.save_path + "last-checkpoint.bin"
     checkpoint = torch.load(path)
     net.load_state_dict(checkpoint['model_state_dict'])
-   
 
    # Create directory
-    if(iteration_adv==-1):
+    if(iteration_adv == -1):
         directory = params.data_dir_prot+'cover/eval_f'+str(iteration_f)+'/'
-        dataloader = cover_stego_loader(params,iteration_adv,'cover')
+        dataloader = cover_stego_loader(params, iteration_adv, 'cover')
     else:
-        directory = params.data_dir_prot+'data_adv_'+str(iteration_adv)+'/eval_f'+str(iteration_f)+'/'
-        dataloader = cover_stego_loader(params,iteration_adv,'stego')
+        directory = params.data_dir_prot+'data_adv_' + \
+            str(iteration_adv)+'/eval_f'+str(iteration_f)+'/'
+        dataloader = cover_stego_loader(params, iteration_adv, 'stego')
 
-    dataloader = cover_stego_loader(params, iteration_adv, 'stego')
-    result_fi = np.empty((0,2))
+    result_fi = np.empty((0, 2))
     dataloader.reset_counter()
     for batch in range(dataloader.train_num_batches):
         batch_x, images_path = dataloader.next_batch()
         with torch.no_grad():
             l = net.forward(torch.tensor(batch_x).cuda())
-        result_fi = np.concatenate((result_fi,l.cpu().detach().numpy()))
-    np.save(directory+'probas',softmax(result_fi)[:,1])
-    np.save(directory+'logits',result_fi)
-    return(result_fi, softmax(result_fi)[:,1])
-
-        
-
-        
-
-
-
-   
+        result_fi = np.concatenate((result_fi, l.cpu().detach().numpy()))
+    np.save(directory+'probas', softmax(result_fi)[:, 1])
+    np.save(directory+'logits', result_fi)
+    return(result_fi, softmax(result_fi)[:, 1])
diff --git a/generate_train_db.py b/generate_train_db.py
index e3b116f12762d4d872c4a543622ff8b1fffb495d..495392380b559e109531b93e920b5a8f8c5303af 100644
--- a/generate_train_db.py
+++ b/generate_train_db.py
@@ -8,31 +8,36 @@ from time import time
 from statsmodels.distributions.empirical_distribution import ECDF
 
 
-def generate_train_db(params, iteration_step, strategy):
-    files = np.load(params.folder_model + 'permutation_files.npy')
-    if strategy=='minmax':
+def generate_train_db(iteration_step, strategy, models, permutation_files, data_dir_prot):
+    files = np.load(permutation_files)
+    if strategy == 'minmax':
         # Probas = Probas(image==stego==1)
-        probas = np.zeros((iteration_step+1,iteration_step*len(params.models),len(files))) # nb_data_bases * nb_classifs * n_images
+        # nb_data_bases * nb_classifs * n_images
+        probas = np.zeros(
+            (iteration_step+1, iteration_step*len(models), len(files)))
         # rows * columns * depth
-        #CALIBRATION
-        ecdf_list=[]
+        # CALIBRATION
+        ecdf_list = []
         for i in range(iteration_step):
-            for model in params.models:
-                ecdf = ECDF(np.load(params.data_dir_prot+'cover/eval_'+model+'_'+str(i)+'/probas.npy'))
-                probas[0,i]=ecdf(np.load(params.data_dir_prot+'data_adv_0/eval_'+model+'_'+str(i)+'/probas.npy'))
+            for model in models:
+                ecdf = ECDF(np.load(data_dir_prot+'cover/eval_' +
+                                    model+'_'+str(i)+'/probas.npy'))
+                probas[0, i] = ecdf(
+                    np.load(data_dir_prot+'data_adv_0/eval_'+model+'_'+str(i)+'/probas.npy'))
                 ecdf_list.append(ecdf)
-        for j in range(1,iteration_step+1): 
+        for j in range(1, iteration_step+1):
             for i in range(iteration_step):
-                for k,model in enumerate(params.models):
-                    ecdf = ecdf_list[i*len(params.models)+k]
-                    probas[j,i*len(params.models)+k]=ecdf(np.load(params.data_dir_prot+'data_adv_'+str(j)+'/eval_'+model+'_'+str(i)+'/probas.npy'))
-        index = np.argmin(np.max(probas,axis=1),axis=0)
-    elif strategy=='random':
+                for k, model in enumerate(models):
+                    ecdf = ecdf_list[i*len(models)+k]
+                    probas[j, i*len(models)+k] = ecdf(np.load(data_dir_prot +
+                                                              'data_adv_'+str(j)+'/eval_'+model+'_'+str(i)+'/probas.npy'))
+        index = np.argmin(np.max(probas, axis=1), axis=0)
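+        # index[i]: the adversarial iteration whose stego for image i has the
+        # smallest worst-case calibrated score over all trained classifiers.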
+    elif strategy == 'random':
         index = np.random.randint(iteration_step+1, size=len(files))
-    elif strategy=='lastit':
-        index = np.zeros(len(files),dtype=int)+iteration_step
-    if not os.path.exists(params.data_dir_prot+'data_train_'+str(iteration_step)+'/'):
-        os.makedirs(params.data_dir_prot+'data_train_'+str(iteration_step)+'/')
-    np.save(params.data_dir_prot+'data_train_'+str(iteration_step)+'/index.npy',index)
+    elif strategy == 'lastit':
+        index = np.zeros(len(files), dtype=int)+iteration_step
+    if not os.path.exists(data_dir_prot+'data_train_'+str(iteration_step)+'/'):
+        os.makedirs(data_dir_prot+'data_train_'+str(iteration_step)+'/')
+    np.save(data_dir_prot+'data_train_' +
+            str(iteration_step)+'/index.npy', index)
     return(index)
-
diff --git a/main.py b/main.py
index 0a6bee3b6c442d9ac81c37e7475b8ccbcfee60a9..fffb76f41fc1f42297ca29c6f85d7999e02c3ef4 100644
--- a/main.py
+++ b/main.py
@@ -17,37 +17,43 @@ from write_jobs import create_training_dictionnary, wait, write_command, run_job
 
 
 def p_error(iteration_f, model, train_size, valid_size, test_size, data_dir_prot):
-    if(iteration_f>0):
-        indexs = np.load(data_dir_prot + 'data_train_'+str(iteration_f)+'/index.npy')
+    if(iteration_f > 0):
+        indexs = np.load(data_dir_prot + 'data_train_' +
+                         str(iteration_f)+'/index.npy')
     else:
         indexs = np.zeros(train_size+valid_size+test_size, dtype=np.int8)
-    probas_stego_z = np.array([np.load(data_dir_prot+'data_adv_'+str(i)+'/eval_'+model+'_'+str(iteration_f)+'/probas.npy') \
-                        for i in range(iteration_f+1)])
-    predictions_cover = np.load(data_dir_prot+'cover/eval_'+model+'_'+str(iteration_f)+'/probas.npy')
-    predictions_stego = np.array([probas_stego_z[ind,i] for i,ind in enumerate(indexs)])
-    
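+    # Scores of the current classifier: cover scores on one side and, for each
+    # image, the stego score of the adversarial iteration it was assigned to
+    # in the training database (indexs).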
+    probas_stego_z = np.array([np.load(data_dir_prot+'data_adv_'+str(i)+'/eval_'+model+'_'+str(iteration_f)+'/probas.npy')
+                               for i in range(iteration_f+1)])
+    predictions_cover = np.load(
+        data_dir_prot+'cover/eval_'+model+'_'+str(iteration_f)+'/probas.npy')
+    predictions_stego = np.array([probas_stego_z[ind, i]
+                                  for i, ind in enumerate(indexs)])
+
     # Choose best threshold
-    thresholds = np.linspace(0,1,51)[1:-1]
+    thresholds = np.linspace(0, 1, 51)[1:-1]
     PE_train, PE_valid, PE_test = [], [], []
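+    # P_E = (P_FA + P_MD) / 2 for each threshold; the threshold minimising
+    # P_E on the validation split is then used to report all three errors.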
 
     for t in thresholds:
-        pred_cover, pred_stego = predictions_cover[:train_size], predictions_stego[:train_size]
-        fa = np.sum(pred_cover>t)/train_size
-        md = np.sum(pred_stego<=t)/train_size
+        pred_cover, pred_stego = predictions_cover[:
+                                                   train_size], predictions_stego[:train_size]
+        fa = np.sum(pred_cover > t)/train_size
+        md = np.sum(pred_stego <= t)/train_size
         pe_train = (fa+md)/2
-        pred_cover, pred_stego = predictions_cover[train_size:train_size+valid_size], predictions_stego[train_size:train_size+valid_size]
-        fa = np.sum(pred_cover>t)/valid_size
-        md = np.sum(pred_stego<=t)/valid_size
+        pred_cover, pred_stego = predictions_cover[train_size:train_size +
+                                                   valid_size], predictions_stego[train_size:train_size+valid_size]
+        fa = np.sum(pred_cover > t)/valid_size
+        md = np.sum(pred_stego <= t)/valid_size
         pe_valid = (fa+md)/2
-        pred_cover, pred_stego = predictions_cover[train_size+valid_size:], predictions_stego[train_size+valid_size:]
-        fa = np.sum(pred_cover>t)/test_size
-        md = np.sum(pred_stego<=t)/test_size
+        pred_cover, pred_stego = predictions_cover[train_size +
+                                                   valid_size:], predictions_stego[train_size+valid_size:]
+        fa = np.sum(pred_cover > t)/test_size
+        md = np.sum(pred_stego <= t)/test_size
         pe_test = (fa+md)/2
         PE_train.append(pe_train)
         PE_valid.append(pe_valid)
         PE_test.append(pe_test)
     PE_valid = np.asarray(PE_valid)
-    mini = np.argmin(PE_valid) # choose best thresholds on validation set
+    mini = np.argmin(PE_valid)  # choose best threshold on validation set
     dir_log = data_dir_prot+'train_'+model+'_'+str(iteration_f)
     return(PE_train[mini], PE_valid[mini], PE_test[mini])
 
@@ -58,119 +64,130 @@ def create_folder(path, list_dir):
             os.makedirs(path+d+'/')
 
 
-def run_iteration(iteration_step, label, data_dir_prot, data_dir_cover, data_dir_stego_0, cost_dir, \
-            image_size, QF, folder_model, permutation_files, version_eff, stride, n_loops, \
-            model, train_size, valid_size, test_size, attack, attack_last, emb_rate, \
-            batch_adv, n_iter_max_backpack, N_samples, tau_0, precision, lr, \
-            num_of_threads, training_dictionnary, spatial, strategy):
-    
+def run_iteration(iteration_step, label, data_dir_prot, data_dir_cover, data_dir_stego_0, cost_dir,
+                  image_size, QF, folder_model, permutation_files, version_eff, stride, n_loops,
+                  model, train_size, valid_size, test_size, attack, attack_last, emb_rate,
+                  batch_adv, n_iter_max_backpack, N_samples, tau_0, precision, lr,
+                  num_of_threads, training_dictionnary, spatial, strategy):
+
     n_images = train_size+valid_size+test_size
     models = model.split(',')
 
     def custom_command(mode, iteration_step, my_model):
-        return(write_command(mode, iteration_step, my_model, data_dir_prot, data_dir_cover, data_dir_stego_0, cost_dir, \
-            image_size, QF, folder_model, permutation_files, version_eff, stride, n_loops, \
-            train_size, valid_size, test_size, attack, attack_last, emb_rate, \
-            batch_adv, n_iter_max_backpack, N_samples, tau_0, precision, lr, \
-            num_of_threads, training_dictionnary, spatial))
+        return(write_command(mode, iteration_step, my_model, data_dir_prot, data_dir_cover, data_dir_stego_0, cost_dir,
+                             image_size, QF, folder_model, permutation_files, version_eff, stride, n_loops,
+                             train_size, valid_size, test_size, attack, attack_last, emb_rate,
+                             batch_adv, n_iter_max_backpack, N_samples, tau_0, precision, lr,
+                             num_of_threads, training_dictionnary, spatial))
 
-    if(iteration_step>0):
+    if(iteration_step > 0):
 
-        # GENERATE ADV DATA BASE OF THE BEST LAST CLASSIFIER 
+        # GENERATE ADV DATA BASE OF THE BEST LAST CLASSIFIER
         directory_adv = data_dir_prot+'data_adv_'+str(iteration_step)+'/'
-        create_folder(directory_adv,['adv_final', 'adv_cost',])
-        
+        create_folder(directory_adv, ['adv_final', 'adv_cost', ])
+
         print('Generating adv step ' + str(iteration_step))
         num_batch = n_images // batch_adv
-        command = 'script_attack.py ' + custom_command('attack',iteration_step, model)
-        run_job('attack', label, command, iteration_step, gpu=True, num_batch=num_batch)
+        command = 'script_attack.py ' + \
+            custom_command('attack', iteration_step, model)
+        run_job('attack', label, command, iteration_step,
+                gpu=True, num_batch=num_batch)
         wait(label)
 
-
-        #EVALUATION OF ALL THE CLASSIFIERS ON THE NEW ADV DATA BASE
-        for i in range(iteration_step): 
+        # EVALUATION OF ALL THE CLASSIFIERS ON THE NEW ADV DATA BASE
+        for i in range(iteration_step):
             for my_model in models:
-                if(i)==-1:
-                    directory = data_dir_prot+'cover/eval_'+my_model+'_'+str(i)+'/'
+                if i == -1:
+                    directory = data_dir_prot + \
+                        'cover/eval_'+my_model+'_'+str(i)+'/'
                 else:
-                    directory = data_dir_prot+'data_adv_'+str(iteration_step)+'/eval_'+my_model+'_'+str(i)+'/'
+                    directory = data_dir_prot+'data_adv_' + \
+                        str(iteration_step)+'/eval_'+my_model+'_'+str(i)+'/'
                 if not os.path.exists(directory):
                     os.makedirs(directory)
-                print('Evaluation of classifier '+my_model+ ' ' + str(i)+' on database ' + str(iteration_step))
+                print('Evaluation of classifier '+my_model + ' ' +
+                      str(i)+' on database ' + str(iteration_step))
                 command = 'script_evaluate_classif.py' + custom_command('classif', iteration_step, my_model) \
-                     +  ' --iteration_f='+str(i)+' --iteration_adv='+str(iteration_step)
-                run_job('eval_'+str(my_model), label, command, iteration_step, gpu=True)
+                    + ' --iteration_f=' + \
+                    str(i)+' --iteration_adv='+str(iteration_step)
+                run_job('eval_'+str(my_model), label,
+                        command, iteration_step, gpu=True)
 
         wait(label)
 
         # GENERATION OF THE TRAIN DATA BASE
-        generate_train_db(iteration_step, strategy, models, permutation_files, data_dir_prot)
+        generate_train_db(iteration_step, strategy, models,
+                          permutation_files, data_dir_prot)
 
-    
     # TRAINING NEW CLASSIFIER FOR EACH MODEL
     for my_model in models:
         print('Training '+my_model+' at iteration ' + str(iteration_step))
-        create_folder(data_dir_prot,['train_'+my_model+'_'+str(iteration_step)])
-        command = 'script_train.py' + command_main('train', iteration_step, my_model)
-        run_job('train_'+my_model, label, command, iteration_step, gpu=True)
+        create_folder(data_dir_prot, [
+                      'train_'+my_model+'_'+str(iteration_step)])
+        command = 'script_train.py' + \
+            custom_command('train', iteration_step, my_model)
+        run_job('train_'+my_model, label, command, iteration_step,
+                num_of_threads=num_of_threads, gpu=True)
 
     wait(label)
 
     # EVALUATE NEW CLASSIFIERS ON ALL STEGO DATA BASES AND ON COVER
-    for i in range(-1, iteration_step+1): # -1 for cover
+    for i in range(-1, iteration_step+1):  # -1 for cover
         for my_model in models:
-            if(i)==-1:
-                directory = data_dir_prot+'cover/'
+            if i == -1:
+                directory = data_dir_prot + 'cover/'
             else:
                 directory = data_dir_prot+'data_adv_'+str(i)+'/'
-            create_folder(directory, ['eval_'+my_model+'_'+str(iteration_step)])
-            print('Evaluation of classifier '+my_model+ ' ' + str(iteration_step)+' on database ' + str(i))
-            command = 'script_evaluate_classif.py' + command_main( 'classif', iteration_step, my_model) \
-                + ' --iteration_f='+str(iteration_step)+' --iteration_adv='+str(i)
-            run_job('eval_'+str(model), label, command, iteration_step, gpu=True)
+            create_folder(
+                directory, ['eval_'+my_model+'_'+str(iteration_step)])
+            print('Evaluation of classifier '+my_model + ' ' +
+                  str(iteration_step)+' on database ' + str(i))
+            command = 'script_evaluate_classif.py' + custom_command('classif', iteration_step, my_model) \
+                + ' --iteration_f='+str(iteration_step) + \
+                ' --iteration_adv='+str(i)
+            run_job('eval_'+str(model), label,
+                    command, iteration_step, gpu=True)
 
     wait(label)
 
     for my_model in models:
-        print(my_model, p_error(iteration_f, my_model, train_size, valid_size, test_size, data_dir_prot))
+        print(my_model, p_error(iteration_step, my_model,
+                                train_size, valid_size, test_size, data_dir_prot))
 
     return(True)
 
 
+def run_protocol(begin_step, number_steps, label, data_dir_prot, data_dir_cover, data_dir_stego_0, cost_dir,
+                 image_size, QF, folder_model, permutation_files, version_eff, stride, n_loops,
+                 model, train_size, valid_size, test_size, attack, attack_last, emb_rate,
+                 batch_adv, n_iter_max_backpack, N_samples, tau_0, precision, lr, strategy,
+                 num_of_threads, spatial, batch_size_classif_xu, batch_size_eval_xu, epoch_num_xu,
+                 CL_xu, start_emb_rate_xu, pair_training_xu, batch_size_classif_sr, batch_size_eval_sr, epoch_num_sr,
+                 CL_sr, start_emb_rate_sr, pair_training_sr, batch_size_classif_ef, batch_size_eval_ef, epoch_num_ef,
+                 CL_ef, start_emb_rate_ef, pair_training_ef, train_on_cost_map):
 
-def run_protocol(begin_step, number_steps, label, data_dir_prot, data_dir_cover, data_dir_stego_0, cost_dir, \
-            image_size, QF, folder_model, permutation_files, version_eff, stride, n_loops, \
-            model, train_size, valid_size, test_size, attack, attack_last, emb_rate, \
-            batch_adv, n_iter_max_backpack, N_samples, tau_0, precision, lr, strategy,\
-            num_of_threads, spatial, batch_size_classif_xu, batch_size_eval_xu, epoch_num_xu, \
-            CL_xu, start_emb_rate_xu, pair_training_xu, batch_size_classif_sr, batch_size_eval_sr, epoch_num_sr, \
-            CL_sr, start_emb_rate_sr, pair_training_sr, batch_size_classif_ef, batch_size_eval_ef, epoch_num_ef, \
-            CL_ef, start_emb_rate_ef, pair_training_ef, train_on_cost_map):
+    training_dictionnary = create_training_dictionnary(model, batch_size_classif_xu, batch_size_eval_xu, epoch_num_xu,
+                                                       CL_xu, start_emb_rate_xu, pair_training_xu, batch_size_classif_sr, batch_size_eval_sr, epoch_num_sr,
+                                                       CL_sr, start_emb_rate_sr, pair_training_sr, batch_size_classif_ef, batch_size_eval_ef, epoch_num_ef,
+                                                       CL_ef, start_emb_rate_ef, pair_training_ef, train_on_cost_map)
 
-    training_dictionnary = create_training_dictionnary(model, batch_size_classif_xu, batch_size_eval_xu, epoch_num_xu, \
-        CL_xu, start_emb_rate_xu, pair_training_xu, batch_size_classif_sr, batch_size_eval_sr, epoch_num_sr, \
-        CL_sr, start_emb_rate_sr, pair_training_sr, batch_size_classif_ef, batch_size_eval_ef, epoch_num_ef, \
-        CL_ef, start_emb_rate_ef, pair_training_ef, train_on_cost_map)
-    
-    iteration_step=begin_step    
+    iteration_step = begin_step
 
-    describe_exp_txt(begin_step, number_steps, num_of_threads, \
-        QF, image_size, emb_rate, data_dir_cover, data_dir_stego_0, cost_dir, \
-        strategy, model, version_eff, stride, n_loops, \
-        train_size, valid_size, test_size, permutation_files, training_dictionnary,\
-        attack, n_iter_max_backpack, N_samples, tau_0, precision, data_dir_prot)
+    describe_exp_txt(begin_step, number_steps, num_of_threads,
+                     QF, image_size, emb_rate, data_dir_cover, data_dir_stego_0, cost_dir,
+                     strategy, model, version_eff, stride, n_loops,
+                     train_size, valid_size, test_size, permutation_files, training_dictionnary,
+                     attack, n_iter_max_backpack, N_samples, tau_0, precision, data_dir_prot)
 
-    while iteration_step<begin_step+number_steps+1:
+    while iteration_step < begin_step+number_steps+1:
 
-        run_iteration(iteration_step, label, data_dir_prot, data_dir_cover, data_dir_stego_0, cost_dir, \
-            image_size, QF, folder_model, permutation_files, version_eff, stride, n_loops, \
-            model, train_size, valid_size, test_size, attack, attack_last, emb_rate, \
-            batch_adv, n_iter_max_backpack, N_samples, tau_0, precision, lr, \
-            num_of_threads, training_dictionnary, spatial, strategy)
-        
-        iteration_step+=1
+        run_iteration(iteration_step, label, data_dir_prot, data_dir_cover, data_dir_stego_0, cost_dir,
+                      image_size, QF, folder_model, permutation_files, version_eff, stride, n_loops,
+                      model, train_size, valid_size, test_size, attack, attack_last, emb_rate,
+                      batch_adv, n_iter_max_backpack, N_samples, tau_0, precision, lr,
+                      num_of_threads, training_dictionnary, spatial, strategy)
 
-    
+        iteration_step += 1
 
 
 if __name__ == '__main__':
@@ -178,23 +195,29 @@ if __name__ == '__main__':
     argparser = argparse.ArgumentParser(sys.argv[0])
     argparser.add_argument('--begin_step', type=int)
     argparser.add_argument('--number_steps', type=int, default=10)
-    argparser.add_argument('--folder_model', type=str, help='The path to the folder where the architecture of models are saved')
-    argparser.add_argument('--permutation_files', type=str, help='The path to the permutation of indexes of images (to shuffle for train, valid and test set).')
-    argparser.add_argument('--num_of_threads', type=int, default=10, help='number of cpus') 
+    argparser.add_argument('--folder_model', type=str,
+                           help='Path to the folder where the model architectures are saved')
+    argparser.add_argument('--permutation_files', type=str,
+                           help='Path to the permutation of image indexes (used to shuffle the train, valid and test sets).')
+    argparser.add_argument('--num_of_threads', type=int,
+                           default=10, help='number of cpus')
 
     # MODEL ARCHITECTURE
     argparser.add_argument('--model', type=str, help='model in the protocol : efnet, xunet or srnet, \
-                    or multiple models separated by commas like "xunet,srnet"') 
+                    or multiple models separated by commas like "xunet,srnet"')
     # For Efficient net
-    argparser.add_argument('--version_eff',type=str,default='b3', help='Version of efficient-net, from b0 to b7')
-    argparser.add_argument('--stride',type=int,default=1, help='Stride at the beginning. Values=1 or 2')
+    argparser.add_argument('--version_eff', type=str, default='b3',
+                           help='Version of efficient-net, from b0 to b7')
+    argparser.add_argument('--stride', type=int, default=1,
+                           help='Stride at the beginning. Values=1 or 2')
     # for xunet
-    argparser.add_argument('--n_loops',type=int, default=5, help='Number of loops in th xunet architecture')
+    argparser.add_argument('--n_loops', type=int, default=5,
+                           help='Number of loops in the xunet architecture')
 
     argparser.add_argument('--data_dir_prot', type=str)
     argparser.add_argument('--data_dir_cover', type=str)
     argparser.add_argument('--data_dir_stego_0', type=str)
-    argparser.add_argument('--cost_dir',type=str)
+    argparser.add_argument('--cost_dir', type=str)
 
     argparser.add_argument('--image_size', type=int)
     argparser.add_argument('--QF', type=int)
@@ -205,13 +228,14 @@ if __name__ == '__main__':
     argparser.add_argument('--label', type=str)
 
     # FOR DATA TRAIN
-    argparser.add_argument('--strategy', type=str, default='minmax') # minmax, random or lastit
+    # minmax, random or lastit
+    argparser.add_argument('--strategy', type=str, default='minmax')
 
     # FOR ADVERSARIAL COST MAP
     argparser.add_argument('--attack', type=str)
     argparser.add_argument('--attack_last', type=str)
     argparser.add_argument('--lr', type=float)
-    argparser.add_argument('--batch_adv',type=int,default=100)
+    argparser.add_argument('--batch_adv', type=int, default=100)
     argparser.add_argument('--n_iter_max_backpack', type=int)
     argparser.add_argument('--N_samples', type=int)
     argparser.add_argument('--tau_0', type=float)
@@ -221,32 +245,36 @@ if __name__ == '__main__':
     argparser.add_argument('--train_size', type=int, default=4000)
     argparser.add_argument('--valid_size', type=int, default=1000)
     argparser.add_argument('--test_size', type=int, default=5000)
-    argparser.add_argument('--train_on_cost_map', type=str, default='yes', help='yes or no. If yes stegos are created from the cost maps at during the training, elif no classifiers are trained directly with the stegos.') 
-    ## FOR TRAINING XU
+    argparser.add_argument('--train_on_cost_map', type=str, default='yes',
+                           help='yes or no. If yes, stegos are generated from the cost maps during training; if no, classifiers are trained directly on the stegos.')
+    # FOR TRAINING XU
     argparser.add_argument('--batch_size_classif_xu', type=int, default=16)
-    argparser.add_argument('--batch_size_eval_xu', type=int, default=16) 
+    argparser.add_argument('--batch_size_eval_xu', type=int, default=16)
     argparser.add_argument('--epoch_num_xu', type=int, default=30)
-    argparser.add_argument('--CL_xu',type=str, help='yes or no. If yes starts from emb_rate of start_emb_rate and decrease until reach emb_rate')
-    argparser.add_argument('--start_emb_rate_xu',type=float, default=0.7,help='Is CL=yes, is the starting emb_rate')
-    argparser.add_argument('--pair_training_xu',type=str, help='yes or no')
-    ## FOR TRAINING SR
+    argparser.add_argument(
+        '--CL_xu', type=str, help='yes or no. If yes, training starts at start_emb_rate and decreases the embedding rate until it reaches emb_rate')
+    argparser.add_argument('--start_emb_rate_xu', type=float,
+                           default=0.7, help='If CL=yes, the starting emb_rate')
+    argparser.add_argument('--pair_training_xu', type=str, help='yes or no')
+    # FOR TRAINING SR
     argparser.add_argument('--batch_size_classif_sr', type=int, default=16)
-    argparser.add_argument('--batch_size_eval_sr', type=int, default=16) 
+    argparser.add_argument('--batch_size_eval_sr', type=int, default=16)
     argparser.add_argument('--epoch_num_sr', type=int, default=30)
-    argparser.add_argument('--CL_sr',type=str, help='yes or no. If yes starts from emb_rate of start_emb_rate and decrease until reach emb_rate')
-    argparser.add_argument('--start_emb_rate_sr',type=float, default=0.7,help='Is CL=yes, is the starting emb_rate')
-    argparser.add_argument('--pair_training_sr',type=str, help='yes or no')
-    ## FOR TRAINING EF
+    argparser.add_argument(
+        '--CL_sr', type=str, help='yes or no. If yes, training starts at start_emb_rate and decreases the embedding rate until it reaches emb_rate')
+    argparser.add_argument('--start_emb_rate_sr', type=float,
+                           default=0.7, help='If CL=yes, the starting emb_rate')
+    argparser.add_argument('--pair_training_sr', type=str, help='yes or no')
+    # FOR TRAINING EF
     argparser.add_argument('--batch_size_classif_ef', type=int, default=16)
-    argparser.add_argument('--batch_size_eval_ef', type=int, default=16) 
+    argparser.add_argument('--batch_size_eval_ef', type=int, default=16)
     argparser.add_argument('--epoch_num_ef', type=int, default=30)
-    argparser.add_argument('--CL_ef',type=str, help='yes or no. If yes starts from emb_rate of start_emb_rate and decrease until reach emb_rate')
-    argparser.add_argument('--start_emb_rate_ef',type=float, default=0.7,help='Is CL=yes, is the starting emb_rate')
-    argparser.add_argument('--pair_training_ef',type=str, help='yes or no')
-   
-    params = argparser.parse_args()
-    
-    run_protocol(**vars(params))
-
+    argparser.add_argument(
+        '--CL_ef', type=str, help='yes or no. If yes, training starts at start_emb_rate and decreases the embedding rate until it reaches emb_rate')
+    argparser.add_argument('--start_emb_rate_ef', type=float,
+                           default=0.7, help='If CL=yes, the starting emb_rate')
+    argparser.add_argument('--pair_training_ef', type=str, help='yes or no')
 
+    params = argparser.parse_args()
 
+    run_protocol(**vars(params))
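Each entry script in this patch dispatches its parsed arguments the same way: the argparse destinations are named exactly like the driver function's keyword parameters, so the namespace can be splatted with **vars(params). A minimal, self-contained sketch of the pattern (toy flags and a stand-in driver, not the protocol's full argument list):

    import argparse

    def run(begin_step, number_steps):
        # stand-in for run_protocol: just report the planned range of iterations
        print('running steps', begin_step, 'to', begin_step + number_steps)

    parser = argparse.ArgumentParser()
    parser.add_argument('--begin_step', type=int, default=0)
    parser.add_argument('--number_steps', type=int, default=10)

    params = parser.parse_args([])   # [] keeps the sketch runnable; the real scripts parse sys.argv
    run(**vars(params))              # dest names map one-to-one onto keyword arguments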
diff --git a/models/srnet.py b/models/srnet.py
index 8a8fc6cc2e5cfb3dfab4318754846f242ab45f4a..8b4d184031814ac02666938e884dcb1998b33f37 100644
--- a/models/srnet.py
+++ b/models/srnet.py
@@ -8,7 +8,7 @@ class get_net(nn.Module):
 
     def __init__(self, image_size):
         super(get_net, self).__init__()
-        self.im_size = params.image_size
+        self.im_size = image_size
         
         def _conv2d(in_channels, out_channels, stride=1, kernel_size=3, padding=1):
             return nn.Conv2d(in_channels=in_channels,\
diff --git a/script_attack.py b/script_attack.py
index aff834447fc7608b2fdb525d0de32c8b898dafab..e5b22165bb1ce1d9697ed96bd746f73b1ebf82b3 100644
--- a/script_attack.py
+++ b/script_attack.py
@@ -1,56 +1,58 @@
 # TOOLS
-import os,argparse,sys
+from statsmodels.distributions.empirical_distribution import ECDF
+import numpy as np
+from backpack import BackPack
+from srnet import TrainGlobalConfig as TrainGlobalConfig_sr
+from srnet import get_net as get_net_sr
+from xunet import TrainGlobalConfig as TrainGlobalConfig_xu
+from xunet import get_net as get_net_xu
+from efficientnet import TrainGlobalConfig as TrainGlobalConfig_ef
+from efficientnet import get_net as get_net_ef
+import os
+import argparse
+import sys
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 sys.path.append('models/')
-from efficientnet import get_net as get_net_ef
-from efficientnet import TrainGlobalConfig as TrainGlobalConfig_ef
 
-from xunet import get_net as get_net_xu
-from xunet import TrainGlobalConfig as TrainGlobalConfig_xu
 
-from srnet import get_net as get_net_sr
-from srnet import TrainGlobalConfig as TrainGlobalConfig_sr
+def backpack_attack(data_dir_cover, cost_dir, image_size, QF, folder_model, emb_rate,
+                    N_samples, nets, image_path, tau_0, ecdf_list, attack, lr, precision, n_iter_max_backpack):
 
-from backpack import BackPack
-
-import numpy as np
-from statsmodels.distributions.empirical_distribution import ECDF
-
-
-def backpack_attack(data_dir_cover, cost_dir, image_size, QF, folder_model, emb_rate, \
-        N_samples, nets, image_path, tau_0, ecdf_list, attack, lr, precision, n_iter_max_backpack):
-    
     c_coeffs = np.load(data_dir_cover + image_path)
-    nz_AC = np.sum(c_coeffs!=0)-np.sum(c_coeffs[::8,::8]!=0)
+    nz_AC = np.sum(c_coeffs != 0)-np.sum(c_coeffs[::8, ::8] != 0)
     rho = np.load(cost_dir+image_path)
-    if(rho.shape==(image_size, image_size)):
-        rho = np.reshape(rho,(1,image_size, image_size))
-        rho = np.concatenate((np.copy(rho), np.zeros_like(rho), np.copy(rho)),axis=0)
-        rho[0,c_coeffs < -1023] = 10e30
-        rho[2,c_coeffs > 1023] = 10e30
+    if(rho.shape == (image_size, image_size)):
+        rho = np.reshape(rho, (1, image_size, image_size))
+        rho = np.concatenate(
+            (np.copy(rho), np.zeros_like(rho), np.copy(rho)), axis=0)
+        rho[0, c_coeffs < -1023] = 10e30
+        rho[2, c_coeffs > 1023] = 10e30
     entropy = emb_rate*nz_AC
 
     tau = tau_0
-    backpack = BackPack(image_size, QF, folder_model, c_coeffs, rho, entropy, N_samples, nets, ecdf_list, attack)
+    backpack = BackPack(image_size, QF, folder_model, c_coeffs,
+                        rho, entropy, N_samples, nets, ecdf_list, attack)
     optimizer = torch.optim.Adam([backpack.rho_vec], lr=lr)
     proba_cover = backpack.proba_cover
-    best_probas_soft, mean_probas_soft, mean_probas_hard, stego_hard = backpack(tau)
+    best_probas_soft, mean_probas_soft, mean_probas_hard, stego_hard = backpack(
+        tau)
     mean_p_soft = mean_probas_soft
     mean_p_hard = mean_probas_hard
 
-    i=0
+    i = 0
     print(proba_cover, mean_p_soft, mean_p_hard)
 
     while True:
 
-        while((np.max(mean_p_soft-proba_cover)>precision) and (np.max(mean_p_hard-proba_cover)>precision) and (i<n_iter_max_backpack)):
+        while((np.max(mean_p_soft-proba_cover) > precision) and (np.max(mean_p_hard-proba_cover) > precision) and (i < n_iter_max_backpack)):
 
             optimizer.zero_grad()
-            best_probas_soft, mean_probas_soft, mean_probas_hard, stego_hard = backpack(tau)
+            best_probas_soft, mean_probas_soft, mean_probas_hard, stego_hard = backpack(
+                tau)
             best_probas_soft.backward()
 
             # g = torch.autograd.grad(best_probas_soft, backpack.rho_vec)[0]
@@ -62,51 +64,57 @@ def backpack_attack(data_dir_cover, cost_dir, image_size, QF, folder_model, emb_
             mean_p_hard = mean_probas_hard
 
             lr = optimizer.param_groups[0]['lr']
-            print(i, np.round(mean_p_soft,2), np.round(mean_p_hard,2), np.round(tau,4), np.round(lr,4))
+            print(i, np.round(mean_p_soft, 2), np.round(
+                mean_p_hard, 2), np.round(tau, 4), np.round(lr, 4))
 
             #tab.append([mean_p_soft, mean_p_hard, tau])
 
-            i+=1
+            i += 1
 
-            if(torch.isnan(backpack.rho_vec).sum()>0):
+            if(torch.isnan(backpack.rho_vec).sum() > 0):
                 print('Nan in cost map')
                 return(None, None)
 
-        if((np.max(mean_p_hard-proba_cover)<=precision) or (i>=n_iter_max_backpack)):
-            return(backpack.rho_vec.cpu().detach().numpy(), stego_hard.cpu().detach().numpy()[0,0])
+        if((np.max(mean_p_hard-proba_cover) <= precision) or (i >= n_iter_max_backpack)):
+            return(backpack.rho_vec.cpu().detach().numpy(), stego_hard.cpu().detach().numpy()[0, 0])
         else:
-            tau/=2
-            best_probas_soft, mean_probas_soft, mean_probas_hard, stego_hard = backpack(tau)
+            tau /= 2
+            best_probas_soft, mean_probas_soft, mean_probas_hard, stego_hard = backpack(
+                tau)
             mean_p_soft = mean_probas_soft
             mean_p_hard = mean_probas_hard
 
-def run_attack(iteration_step, folder_model, data_dir_prot, data_dir_cover, cost_dir, \
-        image_size, QF, emb_rate,  model, version_eff, stride, n_loops, attack_last, \
-        permutation_files, attack, tau_0, lr, precision, n_iter_max_backpack, N_samples, idx_start, batch_adv):
+
+def run_attack(iteration_step, folder_model, data_dir_prot, data_dir_cover, cost_dir,
+               image_size, QF, emb_rate,  model, version_eff, stride, n_loops, attack_last,
+               permutation_files, attack, tau_0, lr, precision, n_iter_max_backpack, N_samples, idx_start, batch_adv):
 
     models = model.split(',')
-    attack_last = attack_last=='yes'
+    attack_last = attack_last == 'yes'
 
     nets = []
-    
-    ecdf_list=[] # For calibration
+
+    ecdf_list = []  # For calibration
     if(attack_last):
         n_classifier_min = max(iteration_step-3, 0)
     else:
         n_classifier_min = 0
     for i in range(n_classifier_min, iteration_step):
         for model in models:
-            if(model=='efnet'):
+            if(model == 'efnet'):
                 net = get_net_ef(version_eff, stride).cuda()
-            elif(model=='xunet'):
+            elif(model == 'xunet'):
                 net = get_net_xu(folder_model, n_loops, image_size).cuda()
-            elif(model=='srnet'):
+            elif(model == 'srnet'):
                 net = get_net_sr(image_size).cuda()
 
-            paths =  os.listdir(data_dir_prot+'train_'+model+'_'+str(i)+'/')
-            paths = [int(x.split('-')[-1][:-9]) for x in paths if 'best-checkpoint' in x]
+            paths = os.listdir(data_dir_prot+'train_'+model+'_'+str(i)+'/')
+            paths = [int(x.split('-')[-1][:-9])
+                     for x in paths if 'best-checkpoint' in x]
             best_epoch = str(max(paths))
-            path = data_dir_prot+'train_'+model+'_'+str(i)+'/best-checkpoint-'+'0'*(3-len(best_epoch))+best_epoch+'epoch.bin'
+            path = data_dir_prot+'train_'+model+'_' + \
+                str(i)+'/best-checkpoint-'+'0' * \
+                (3-len(best_epoch))+best_epoch+'epoch.bin'
 
             checkpoint = torch.load(path)
             net.load_state_dict(checkpoint['model_state_dict'])
@@ -114,7 +122,8 @@ def run_attack(iteration_step, folder_model, data_dir_prot, data_dir_cover, cost
             for x in net.parameters():
                 x.requires_grad = False
             nets.append(net)
-            ecdf = ECDF(np.load(data_dir_prot+'cover/eval_'+model+'_'+str(i)+'/probas.npy'))
+            ecdf = ECDF(np.load(data_dir_prot+'cover/eval_' +
+                                model+'_'+str(i)+'/probas.npy'))
             ecdf_list.append(ecdf)
 
     files = np.load(permutation_files)
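The loop above calibrates every detector with an ECDF fitted on its scores for cover images, so that outputs of differently trained classifiers become comparable percentiles. A minimal sketch of that calibration, assuming numpy and statsmodels are available and using synthetic scores in place of probas.npy:

    import numpy as np
    from statsmodels.distributions.empirical_distribution import ECDF

    # synthetic detector scores on cover images (stands in for cover/eval_*/probas.npy)
    cover_scores = np.random.beta(2, 5, size=1000)
    ecdf = ECDF(cover_scores)

    # raw detector output on a candidate stego image
    raw_score = 0.8
    calibrated = ecdf(raw_score)   # fraction of cover images scoring below raw_score
    print(raw_score, '->', calibrated)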
@@ -126,13 +135,13 @@ def run_attack(iteration_step, folder_model, data_dir_prot, data_dir_cover, cost
 
     for image_path in files:
         if not os.path.exists(path_save+'adv_cost/'+image_path[:-4]+'.npy'):
-            rho,stego = backpack_attack(data_dir_cover, cost_dir, image_size, QF, \
-                folder_model, emb_rate, N_samples, nets, image_path[:-4]+'.npy', tau_0, ecdf_list, \
-                attack, lr, precision, n_iter_max_backpack)
+            rho, stego = backpack_attack(data_dir_cover, cost_dir, image_size, QF,
+                                         folder_model, emb_rate, N_samples, nets, image_path[
+                                             :-4]+'.npy', tau_0, ecdf_list,
+                                         attack, lr, precision, n_iter_max_backpack)
             np.save(path_save+'adv_cost/'+image_path[:-4]+'.npy', rho)
             np.save(path_save+'adv_final/'+image_path[:-4]+'.npy', stego)
 
-            
 
 if __name__ == '__main__':
     argparser = argparse.ArgumentParser(sys.argv[0])
@@ -140,43 +149,36 @@ if __name__ == '__main__':
 
     argparser.add_argument('--data_dir_prot', type=str)
     argparser.add_argument('--data_dir_cover', type=str)
-    argparser.add_argument('--cost_dir',type=str)
+    argparser.add_argument('--cost_dir', type=str)
     argparser.add_argument('--permutation_files', type=str)
     argparser.add_argument('--folder_model', type=str)
     argparser.add_argument('--image_size', type=int)
     argparser.add_argument('--emb_rate', type=float)
     argparser.add_argument('--QF', type=int)
 
-    # Model parameters 
-    argparser.add_argument('--model', type=str, help='Model : efnet, xunet or srnet or multiple models separated by comma') 
+    # Model parameters
+    argparser.add_argument(
+        '--model', type=str, help='Model: efnet, xunet or srnet, or multiple models separated by commas')
     # for efnet
-    argparser.add_argument('--version_eff',type=str, help='Version of efficient-net, from b0 to b7')
-    argparser.add_argument('--stride',type=int, help='Stride at the beginning. Values=1 or 2')
+    argparser.add_argument('--version_eff', type=str,
+                           help='Version of efficient-net, from b0 to b7')
+    argparser.add_argument('--stride', type=int,
+                           help='Stride at the beginning. Values=1 or 2')
     # for xunet
-    argparser.add_argument('--n_loops',type=int, help='Number of loops in the xunet architecture')
-    
+    argparser.add_argument('--n_loops', type=int,
+                           help='Number of loops in the xunet architecture')
+
     # For SGD
     argparser.add_argument('--attack', type=str)
-    argparser.add_argument('--attack_last', type=str,default='no')
+    argparser.add_argument('--attack_last', type=str, default='no')
     argparser.add_argument('--lr', type=float)
     argparser.add_argument('--n_iter_max_backpack', type=int)
     argparser.add_argument('--N_samples', type=int)
     argparser.add_argument('--tau_0', type=float)
     argparser.add_argument('--precision', type=float)
 
-    argparser.add_argument('--idx_start',type=int)
-    argparser.add_argument('--batch_adv',type=int)
+    argparser.add_argument('--idx_start', type=int)
+    argparser.add_argument('--batch_adv', type=int)
     params = argparser.parse_args()
 
-    
     run_attack(**vars(params))
-
-            
-        
-
-
-    
-
-
-
-
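backpack_attack above turns a single-channel cost map into a ternary one (costs of the -1, 0 and +1 changes) and assigns a wet cost to changes that would push a quantized DCT coefficient outside its dynamic range. A standalone numpy sketch of that expansion, with toy sizes and random data:

    import numpy as np

    image_size = 8                                   # toy size; the protocol uses the real image_size
    c_coeffs = np.random.randint(-1024, 1025, size=(image_size, image_size))
    rho = np.random.rand(image_size, image_size)     # single-channel embedding costs

    # channel 0: cost of a -1 change, channel 1: no change (zero cost), channel 2: +1 change
    rho3 = np.stack((rho, np.zeros_like(rho), rho), axis=0)

    WET_COST = 10e30                       # change effectively forbidden
    rho3[0, c_coeffs < -1023] = WET_COST   # a saturated coefficient cannot be decreased
    rho3[2, c_coeffs > 1023] = WET_COST    # a saturated coefficient cannot be increased

    print(rho3.shape)                      # (3, image_size, image_size)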
diff --git a/script_evaluate_classif.py b/script_evaluate_classif.py
index 848dfe892a440dc8f6d5e622064b5364091903de..d8080547172bb68863be25a4b216d93e20204c4d 100644
--- a/script_evaluate_classif.py
+++ b/script_evaluate_classif.py
@@ -1,3 +1,9 @@
+from srnet import TrainGlobalConfig as TrainGlobalConfig_sr
+from srnet import get_net as get_net_sr
+from xunet import TrainGlobalConfig as TrainGlobalConfig_xu
+from xunet import get_net as get_net_xu
+from efficientnet import TrainGlobalConfig as TrainGlobalConfig_ef
+from efficientnet import get_net as get_net_ef
 import os
 import numpy as np
 from scipy.fftpack import dct, idct
@@ -9,30 +15,24 @@ from data_loader import compute_spatial_from_jpeg
 
 import sys
 sys.path.append('models/')
-from efficientnet import get_net as get_net_ef
-from efficientnet import TrainGlobalConfig as TrainGlobalConfig_ef
-
-from xunet import get_net as get_net_xu
-from xunet import TrainGlobalConfig as TrainGlobalConfig_xu
-
-from srnet import get_net as get_net_sr
-from srnet import TrainGlobalConfig as TrainGlobalConfig_sr
 
 
 def softmax(array):
-    exp = np.exp(array-np.max(array,axis=1, keepdims=True))
+    exp = np.exp(array-np.max(array, axis=1, keepdims=True))
     return(exp/np.sum(exp, axis=1, keepdims=True))
 
+
 class cover_stego_loader(object):
 
-    def __init__(self, iteration_step, mode, train_size, valid_size, test_size, \
-                        batch_size_eval, QF, image_size, folder_model, \
-                        data_dir_prot, data_dir_cover, data_dir_stego_0, permutation_files): # mode = stego or cover
-        n_images = train_size + valid_size + test_size 
+    def __init__(self, iteration_step, mode, train_size, valid_size, test_size,
+                 batch_size_eval, QF, image_size, folder_model,
+                 data_dir_prot, data_dir_cover, data_dir_stego_0, permutation_files):  # mode = stego or cover
+        n_images = train_size + valid_size + test_size
         self.files = np.load(permutation_files)[:n_images]
         self.train_counter = 0
         self.train_data_size = len(self.files)
-        self.train_num_batches = int(np.ceil(1.0 * self.train_data_size / batch_size_eval))
+        self.train_num_batches = int(
+            np.ceil(1.0 * self.train_data_size / batch_size_eval))
         self.iteration_step = iteration_step
         self.mode = mode
         self.c_quant = np.load(folder_model + 'c_quant_'+str(QF)+'.npy')
@@ -45,80 +45,89 @@ class cover_stego_loader(object):
 
     def next_batch(self):
 
-        borne_sup = min(self.train_counter + self.batch_size_eval, len(self.files))
+        borne_sup = min(self.train_counter +
+                        self.batch_size_eval, len(self.files))
         n_images = borne_sup-self.train_counter
 
-        next_batch_X = np.zeros((n_images,self.image_size,self.image_size),dtype=np.float32)
+        next_batch_X = np.zeros(
+            (n_images, self.image_size, self.image_size), dtype=np.float32)
 
         files_batch = self.files[self.train_counter:borne_sup]
-        for i,file in enumerate(files_batch):
-            if(self.mode=='stego'):
-                if(self.iteration_step>0):
+        for i, file in enumerate(files_batch):
+            if(self.mode == 'stego'):
+                if(self.iteration_step > 0):
                     try:
-                        image = np.load(self.data_dir_prot+'data_adv_'+str(self.iteration_step)+'/adv_final/'+file[:-4]+'.npy')
+                        image = np.load(
+                            self.data_dir_prot+'data_adv_'+str(self.iteration_step)+'/adv_final/'+file[:-4]+'.npy')
                     except:
-                        image = np.load(self.data_dir_stego_0 + file[:-4]+'.npy')
+                        image = np.load(
+                            self.data_dir_stego_0 + file[:-4]+'.npy')
                 else:
                     image = np.load(self.data_dir_stego_0 + file[:-4]+'.npy')
-            elif(self.mode=='cover'):
+            elif(self.mode == 'cover'):
                 image = np.load(self.data_dir_cover + file[:-4] + '.npy')
 
             spat_image = compute_spatial_from_jpeg(image, self.c_quant)
             spat_image /= 255.0
-            next_batch_X[i,:,:]=spat_image
-    
-        next_batch_X = np.reshape(next_batch_X,(next_batch_X.shape[0], 1, next_batch_X.shape[1],next_batch_X.shape[2]))
+            next_batch_X[i, :, :] = spat_image
+
+        next_batch_X = np.reshape(
+            next_batch_X, (next_batch_X.shape[0], 1, next_batch_X.shape[1], next_batch_X.shape[2]))
+
+        self.train_counter = (self.train_counter +
+                              self.batch_size_eval) % self.train_data_size
+        return(next_batch_X, files_batch)
 
-        self.train_counter = (self.train_counter + self.batch_size_eval) % self.train_data_size
-        return(next_batch_X, files_batch)  
-    
     def reset_counter(self):
         self.train_counter = 0
 
 
-def evaluate_step_i(iteration_f, iteration_adv, model, data_dir_prot, train_size, valid_size, test_size, \
-                        batch_size_eval, QF, image_size, folder_model, \
-                        data_dir_cover, data_dir_stego_0, permutation_files,\
-                        version_eff=None, stride=None, n_loops=None): # if iteration_adv == -1 : cover
-    
-    if(model=='efnet'):
+def evaluate_step_i(iteration_f, iteration_adv, model, data_dir_prot, train_size, valid_size, test_size,
+                    batch_size_eval, QF, image_size, folder_model,
+                    data_dir_cover, data_dir_stego_0, permutation_files,
+                    version_eff=None, stride=None, n_loops=None):  # if iteration_adv == -1 : cover
+
+    if(model == 'efnet'):
         net = get_net_ef(version_eff, stride).cuda()
-    elif(model=='xunet'):
+    elif(model == 'xunet'):
         net = get_net_xu(folder_model, n_loops, image_size).cuda()
-    elif(model=='srnet'):
+    elif(model == 'srnet'):
         net = get_net_sr(image_size).cuda()
 
-    best_epoch = [int(x.split('-')[-1][:3]) \
-        for x in os.listdir(data_dir_prot+'train_'+model+'_'+str(iteration_f)+'/') \
-        if 'best' in x]
+    best_epoch = [int(x.split('-')[-1][:3])
+                  for x in os.listdir(data_dir_prot+'train_'+model+'_'+str(iteration_f)+'/')
+                  if 'best' in x]
     best_epoch.sort()
     best_epoch = str(best_epoch[-1])
-    path = data_dir_prot+'train_'+model+'_'+str(iteration_f)+'/best-checkpoint-'+'0'*(3-len(best_epoch))+best_epoch+'epoch.bin'
+    path = data_dir_prot+'train_'+model+'_' + \
+        str(iteration_f)+'/best-checkpoint-'+'0' * \
+        (3-len(best_epoch))+best_epoch+'epoch.bin'
     checkpoint = torch.load(path)
     net.load_state_dict(checkpoint['model_state_dict'])
     net.eval()
-   
+
     # Create directory
-    if(iteration_adv==-1):
+    if(iteration_adv == -1):
         directory = data_dir_prot+'cover/eval_'+model+'_'+str(iteration_f)+'/'
         mode = 'cover'
     else:
-        directory = data_dir_prot+'data_adv_'+str(iteration_adv)+'/eval_'+model+'_'+str(iteration_f)+'/'
+        directory = data_dir_prot+'data_adv_' + \
+            str(iteration_adv)+'/eval_'+model+'_'+str(iteration_f)+'/'
         mode = 'stego'
 
-    dataloader = cover_stego_loader(iteration_adv, mode, train_size, valid_size, test_size, \
-                        batch_size_eval, QF, image_size, folder_model, \
-                        data_dir_prot, data_dir_cover, data_dir_stego_0, permutation_files)
+    dataloader = cover_stego_loader(iteration_adv, mode, train_size, valid_size, test_size,
+                                    batch_size_eval, QF, image_size, folder_model,
+                                    data_dir_prot, data_dir_cover, data_dir_stego_0, permutation_files)
 
-    result_fi = np.empty((0,2))
+    result_fi = np.empty((0, 2))
     dataloader.reset_counter()
     for batch in range(dataloader.train_num_batches):
         batch_x, images_path = dataloader.next_batch()
         with torch.no_grad():
             l = net.forward(torch.tensor(batch_x).cuda())
-        result_fi = np.concatenate((result_fi,l.cpu().detach().numpy()))
-    np.save(directory+'probas',softmax(result_fi)[:,1])
-    np.save(directory+'logits',result_fi)
+        result_fi = np.concatenate((result_fi, l.cpu().detach().numpy()))
+    np.save(directory+'probas', softmax(result_fi)[:, 1])
+    np.save(directory+'logits', result_fi)
 
 
 if __name__ == '__main__':
@@ -141,17 +150,21 @@ if __name__ == '__main__':
     argparser.add_argument('--QF', type=int)
     argparser.add_argument('--batch_size_eval', type=int)
 
-    # Model parameters 
-    argparser.add_argument('--model', type=str, help='Model : efnet, xunet or srnet') 
+    # Model parameters
+    argparser.add_argument('--model', type=str,
+                           help='Model: efnet, xunet or srnet')
     # for efnet
-    argparser.add_argument('--version_eff',type=str, help='Version of efficient-net, from b0 to b7')
-    argparser.add_argument('--stride',type=int, help='Stride at the beginning. Values=1 or 2')
+    argparser.add_argument('--version_eff', type=str,
+                           help='Version of efficient-net, from b0 to b7')
+    argparser.add_argument('--stride', type=int,
+                           help='Stride at the beginning. Values=1 or 2')
     # for xunet
-    argparser.add_argument('--n_loops',type=int, help='Number of loops in the xunet architecture')
+    argparser.add_argument('--n_loops', type=int,
+                           help='Number of loops in the xunet architecture')
 
     params = argparser.parse_args()
 
-    print('Evaluate model '+params.model+' at iteration ' +str(params.iteration_f)+ ' on the adv images generated at '+ str(params.iteration_adv))
-    
-    evaluate_step_i(**vars(params))
+    print('Evaluate model '+params.model+' at iteration ' + str(params.iteration_f) +
+          ' on the adv images generated at ' + str(params.iteration_adv))
 
+    evaluate_step_i(**vars(params))
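evaluate_step_i above saves both the raw two-class logits and, via the softmax helper defined in this file, the probability of the stego class (column 1). A small sketch of that post-processing on toy logits:

    import numpy as np

    def softmax(array):
        exp = np.exp(array - np.max(array, axis=1, keepdims=True))
        return exp / np.sum(exp, axis=1, keepdims=True)

    # toy logits for 4 images, columns = (cover, stego)
    logits = np.array([[2.0, -1.0],
                       [0.1,  0.3],
                       [-3.0, 4.0],
                       [0.0,  0.0]])

    probas_stego = softmax(logits)[:, 1]
    print(np.round(probas_stego, 3))   # approximately [0.047 0.55 0.999 0.5]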
diff --git a/script_train.py b/script_train.py
index 87b7c88bc24a26a7afef5cae6c61fbd671c8abd7..233fe7fd86048609f57b98898fc55559e58ecb73 100644
--- a/script_train.py
+++ b/script_train.py
@@ -1,25 +1,23 @@
+from torch.utils.data import Dataset, DataLoader
+import argparse
+import sklearn
+from torch.utils.data.sampler import SequentialSampler, RandomSampler
+import torch
+import numpy as np
+from srnet import TrainGlobalConfig as TrainGlobalConfig_sr
+from srnet import get_net as get_net_sr
+from xunet import TrainGlobalConfig as TrainGlobalConfig_xu
+from xunet import get_net as get_net_xu
+from efficientnet import TrainGlobalConfig as TrainGlobalConfig_ef
+from efficientnet import get_net as get_net_ef
 from catalyst.data.sampler import BalanceClassSampler
-from data_loader import load_dataset, DatasetRetriever, get_train_transforms, get_valid_transforms, LabelSmoothing
+from data_loader import load_dataset, DatasetRetriever, get_train_transforms, get_valid_transforms
 from train import Fitter
 
-import sys, os
+import sys
+import os
 sys.path.append('models/')
 
-from efficientnet import get_net as get_net_ef
-from efficientnet import TrainGlobalConfig as TrainGlobalConfig_ef
-
-from xunet import get_net as get_net_xu
-from xunet import TrainGlobalConfig as TrainGlobalConfig_xu
-
-from srnet import get_net as get_net_sr
-from srnet import TrainGlobalConfig as TrainGlobalConfig_sr
-
-import argparse, sys
-import numpy as np
-import torch
-from torch.utils.data import Dataset,DataLoader
-from torch.utils.data.sampler import SequentialSampler, RandomSampler
-import sklearn
 
 def my_collate(batch, pair_training=False):
     if pair_training:
@@ -27,15 +25,15 @@ def my_collate(batch, pair_training=False):
         stego = torch.stack([x[0][1] for x in batch])
         label_cover = torch.stack([x[1][0] for x in batch])
         label_stego = torch.stack([x[1][1] for x in batch])
-        r=torch.randperm(len(cover)*2)
-        return(torch.cat((cover,stego))[r],torch.cat((label_cover,label_stego))[r])
+        r = torch.randperm(len(cover)*2)
+        return(torch.cat((cover, stego))[r], torch.cat((label_cover, label_stego))[r])
 
     else:
         return torch.utils.data.dataloader.default_collate(batch)
 
 
-def run_training(iteration_step, model, net, trainGlobalConfig, data_dir_prot, start_emb_rate, emb_rate, pair_training, \
-        version_eff, load_checkpoint, train_dataset, validation_dataset, test_dataset):
+def run_training(iteration_step, model, net, trainGlobalConfig, data_dir_prot, start_emb_rate, emb_rate, pair_training,
+                 version_eff, load_checkpoint, train_dataset, validation_dataset, test_dataset, folder_model, train_on_cost_map):
 
     device = torch.device('cuda:0')
 
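With pair_training enabled, my_collate above keeps each cover/stego pair inside the same mini-batch and applies one shared permutation to images and labels, so the pairing is shuffled but never broken. A toy sketch of the resulting batch layout, assuming only torch:

    import torch

    # a batch of 3 (cover, stego) pairs; each sample is a 1x4x4 tensor with a scalar label 0/1
    batch = [((torch.zeros(1, 4, 4), torch.ones(1, 4, 4)),
              (torch.tensor(0.), torch.tensor(1.))) for _ in range(3)]

    cover = torch.stack([x[0][0] for x in batch])
    stego = torch.stack([x[0][1] for x in batch])
    label_cover = torch.stack([x[1][0] for x in batch])
    label_stego = torch.stack([x[1][1] for x in batch])

    r = torch.randperm(len(cover) * 2)                   # one permutation shared by images and labels
    images = torch.cat((cover, stego))[r]                # shape (6, 1, 4, 4)
    labels = torch.cat((label_cover, label_stego))[r]    # shape (6,); zeros pair with cover images
    print(images.shape, labels)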
@@ -47,55 +45,57 @@ def run_training(iteration_step, model, net, trainGlobalConfig, data_dir_prot, s
         drop_last=True,
         shuffle=True,
         num_workers=trainGlobalConfig.num_workers,
-        collate_fn=lambda x: my_collate(x,pair_training=pair_training)
+        collate_fn=lambda x: my_collate(x, pair_training=pair_training)
     )
     val_loader = torch.utils.data.DataLoader(
-        validation_dataset, 
+        validation_dataset,
         batch_size=trainGlobalConfig.batch_size,
         num_workers=trainGlobalConfig.num_workers,
         shuffle=False,
-        #sampler=SequentialSampler(validation_dataset),
+        # sampler=SequentialSampler(validation_dataset),
         pin_memory=True,
-        collate_fn=lambda x: my_collate(x,pair_training=pair_training)
+        collate_fn=lambda x: my_collate(x, pair_training=pair_training)
     )
 
     test_loader = torch.utils.data.DataLoader(
-        test_dataset, 
+        test_dataset,
         batch_size=trainGlobalConfig.batch_size,
         num_workers=trainGlobalConfig.num_workers,
         shuffle=False,
-        #sampler=SequentialSampler(test_dataset),
+        # sampler=SequentialSampler(test_dataset),
         pin_memory=True,
-        collate_fn=lambda x: my_collate(x,pair_training=pair_training)
+        collate_fn=lambda x: my_collate(x, pair_training=pair_training)
     )
 
-    save_path = data_dir_prot + 'train_'+model+'_'+str(iteration_step) +'/'
+    save_path = data_dir_prot + 'train_'+model+'_'+str(iteration_step) + '/'
     #save_path = data_dir_prot + 'train_'+model+'_stego_'+str(iteration_step) +'/'
     #save_path = data_dir_prot + 'train_'+model+'_'+str(iteration_step) +'_'+str(emb_rate)+'/'
     if not os.path.exists(save_path):
         os.makedirs(save_path)
-    
-    fitter = Fitter(model=net, device=device, config=trainGlobalConfig, save_path=save_path, model_str=model)
+
+    fitter = Fitter(model=net, device=device, config=trainGlobalConfig,
+                    save_path=save_path, model_str=model, train_on_cost_map=train_on_cost_map)
     print(f'{fitter.base_dir}')
-    
-    if(model=='efnet'):
+
+    if(model == 'efnet'):
         # Load the pretrained model
         fitter.load(folder_model+version_eff+"-imagenet")
-    
+
     if(load_checkpoint is not None):
         save_path_load = save_path
         #save_path_load = data_dir_prot + 'train_'+model+'_'+str(iteration_step) +'/'
-        if(load_checkpoint=='best'):
-            best_epoch = [int(x.split('-')[-1][:3]) \
-                for x in os.listdir(save_path_load) \
-                if 'best' in x]
+        if(load_checkpoint == 'best'):
+            best_epoch = [int(x.split('-')[-1][:3])
+                          for x in os.listdir(save_path_load)
+                          if 'best' in x]
             best_epoch.sort()
             best_epoch = str(best_epoch[-1])
-            path = save_path_load + 'best-checkpoint-'+'0'*(3-len(best_epoch))+best_epoch+'epoch.bin'
+            path = save_path_load + 'best-checkpoint-'+'0' * \
+                (3-len(best_epoch))+best_epoch+'epoch.bin'
             fitter.load(path)
         else:
             fitter.load(save_path+load_checkpoint)
-    
+
     fitter.fit(train_loader, val_loader, emb_rate)
 
     train_loader.dataset.emb_rate = emb_rate
@@ -107,92 +107,116 @@ def run_training(iteration_step, model, net, trainGlobalConfig, data_dir_prot, s
     _, test = fitter.validation(test_loader)
     print('Pe_err test :', test.avg)
 
-    np.save(save_path+'error_rate.npy', np.array([train.avg, val.avg, test.avg]))
+    np.save(save_path+'error_rate.npy',
+            np.array([train.avg, val.avg, test.avg]))
 
     return(train.avg, val.avg, test.avg)
 
-    
-def train(iteration_step, model, folder_model, data_dir_prot, permutation_files, num_of_threads, \
-        data_dir_cover, data_dir_stego_0, cost_dir, train_size, valid_size, test_size, \
-        image_size, batch_size_classif, batch_size_eval, epoch_num, QF, emb_rate, \
-        spatial, train_on_cost_map, H1_filter, L1_filter, L2_filter, \
-        version_eff, stride, n_loops, CL, start_emb_rate, pair_training, load_model):
-    
-    pair_training = pair_training=='yes'
-    spatial = spatial=='yes'
-    train_on_cost_map = train_on_cost_map=='yes'
 
-    dataset = load_dataset(iteration_step, permutation_files, train_size, valid_size, test_size, \
-        data_dir_prot, pair_training=pair_training)
+def train(iteration_step, model, folder_model, data_dir_prot, permutation_files, num_of_threads,
+          data_dir_cover, data_dir_stego_0, cost_dir, train_size, valid_size, test_size,
+          image_size, batch_size_classif, batch_size_eval, epoch_num, QF, emb_rate,
+          spatial, train_on_cost_map, H1_filter, L1_filter, L2_filter,
+          version_eff, stride, n_loops, CL, start_emb_rate, pair_training, load_model):
+
+    pair_training = pair_training == 'yes'
+    spatial = spatial == 'yes'
+    train_on_cost_map = train_on_cost_map == 'yes'
 
+    dataset = load_dataset(iteration_step, permutation_files, train_size, valid_size, test_size,
+                           data_dir_prot, pair_training=pair_training)
 
-    train_valid_test_names = [dataset[dataset['fold'] == i].image_name.values for i in range(3)]
-    train_valid_test_indexs_db = [dataset[dataset['fold'] == i].indexs_db.values for i in range(3)]
+    train_valid_test_names = [dataset[dataset['fold']
+                                      == i].image_name.values for i in range(3)]
+    train_valid_test_indexs_db = [
+        dataset[dataset['fold'] == i].indexs_db.values for i in range(3)]
 
     if pair_training:
         train_valid_test_labels = [None, None, None]
     else:
-        train_valid_test_labels = [dataset[dataset['fold'] == i].labels.values for i in range(3)]
-
-    train_valid_test_transforms = [get_train_transforms(), get_valid_transforms(), None]   
-
-    datasets = [DatasetRetriever(image_names, folder_model, QF, emb_rate, image_size, \
-            data_dir_cover, data_dir_stego_0, cost_dir, data_dir_prot, \
-            H1_filter, L1_filter, L2_filter, indexs_db, train_on_cost_map, \
-            labels,  transforms, pair_training, spatial) \
-            for image_names, indexs_db, labels, transforms in \
-                zip(train_valid_test_names, train_valid_test_indexs_db, \
+        train_valid_test_labels = [
+            dataset[dataset['fold'] == i].label.values for i in range(3)]
+
+    train_valid_test_transforms = [
+        get_train_transforms(), get_valid_transforms(), None]
+
+    datasets = [DatasetRetriever(image_names, folder_model, QF, emb_rate, image_size,
+                                 data_dir_cover, data_dir_stego_0, cost_dir, data_dir_prot,
+                                 H1_filter, L1_filter, L2_filter, indexs_db, train_on_cost_map=True,
+                                 labels=labels,  transforms=transforms, pair_training=pair_training, spatial=spatial)
+                for image_names, indexs_db, labels, transforms in
+                zip(train_valid_test_names, train_valid_test_indexs_db,
                     train_valid_test_labels, train_valid_test_transforms)]
-
     train_dataset, validation_dataset, test_dataset = datasets[0], datasets[1], datasets[2]
 
-    if(model=='efnet'):
+    if(model == 'efnet'):
         net = get_net_ef(version_eff, stride).cuda()
-        trainGlobalConfig = TrainGlobalConfig_ef(num_of_threads, batch_size_classif, epoch_num)
-    elif(model=='xunet'):
+        trainGlobalConfig = TrainGlobalConfig_ef(
+            num_of_threads, batch_size_classif, epoch_num)
+    elif(model == 'xunet'):
         net = get_net_xu(folder_model, n_loops, image_size).cuda()
-        trainGlobalConfig = TrainGlobalConfig_xu(num_of_threads, batch_size_classif, epoch_num)
-    elif(model=='srnet'):
+        trainGlobalConfig = TrainGlobalConfig_xu(
+            num_of_threads, batch_size_classif, epoch_num)
+    elif(model == 'srnet'):
         net = get_net_sr(image_size).cuda()
         net.init()
-        trainGlobalConfig = TrainGlobalConfig_sr(num_of_threads, batch_size_classif, epoch_num)
-
-    
+        trainGlobalConfig = TrainGlobalConfig_sr(
+            num_of_threads, batch_size_classif, epoch_num)
 
-    if(CL=='yes'):
+    if(CL == 'yes'):
         start_emb_rate = start_emb_rate
     else:
         start_emb_rate = emb_rate
 
-
     # Train first with cost map
-    run_training(iteration_step, model, net, trainGlobalConfig, data_dir_prot, start_emb_rate, \
-        emb_rate, pair_training, version_eff, load_model, \
-        train_dataset, validation_dataset, test_dataset)
+    train, val, test = run_training(iteration_step, model, net, trainGlobalConfig, data_dir_prot, start_emb_rate,
+                                    emb_rate, pair_training, version_eff, load_model,
+                                    train_dataset, validation_dataset, test_dataset, folder_model, train_on_cost_map=True)
+    print(train, val, test)
+
+    datasets = [DatasetRetriever(image_names, folder_model, QF, emb_rate, image_size,
+                                 data_dir_cover, data_dir_stego_0, cost_dir, data_dir_prot,
+                                 H1_filter, L1_filter, L2_filter, indexs_db, train_on_cost_map=False,
+                                 labels=labels,  transforms=transforms, pair_training=pair_training, spatial=spatial)
+                for image_names, indexs_db, labels, transforms in
+                zip(train_valid_test_names, train_valid_test_indexs_db,
+                    train_valid_test_labels, train_valid_test_transforms)]
+    train_dataset, validation_dataset, test_dataset = datasets[0], datasets[1], datasets[2]
 
     # Train then with the stegos from the best iteration obtained from the training with cost maps
     # (set load_model to 'best')
-    train, val, test = run_training(iteration_step, model, net, trainGlobalConfig, data_dir_prot, start_emb_rate, \
-            emb_rate, pair_training, version_eff, 'best', \
-            train_dataset, validation_dataset, test_dataset)
-    
-    return(train, val, test)
+    train, val, test = run_training(iteration_step, model, net, trainGlobalConfig, data_dir_prot, start_emb_rate,
+                                    emb_rate, pair_training, version_eff, 'best',
+                                    train_dataset, validation_dataset, test_dataset, folder_model, train_on_cost_map=False)
+    print(train, val, test)
+
+    # train, val, test = run_training(iteration_step, model, net, trainGlobalConfig, data_dir_prot, start_emb_rate, \
+    #     emb_rate, pair_training, version_eff, load_model, \
+    #     train_dataset, validation_dataset, test_dataset)
 
+    return(train, val, test)
 
 
 if __name__ == '__main__':
 
     argparser = argparse.ArgumentParser(sys.argv[0])
-    argparser.add_argument('--iteration_step', type=int, help='Iteration step') 
-    argparser.add_argument('--folder_model', type=str, help='The path to the folder where the architecture of models are saved')
-    argparser.add_argument('--data_dir_prot', type=str, help='Path of the protocol') 
-    argparser.add_argument('--permutation_files', type=str, help='List of files') 
-    
-    argparser.add_argument('--num_of_threads',type=int, help='Number of CPUs available')
-
-    argparser.add_argument('--data_dir_cover', type=str, help='Where are the .npy files of cover') 
-    argparser.add_argument('--data_dir_stego_0', type=str, help='Where are the inital stegos') 
-    argparser.add_argument('--cost_dir', type=str, help='Where are the .npy of costs') 
+    argparser.add_argument('--iteration_step', type=int, help='Iteration step')
+    argparser.add_argument('--folder_model', type=str,
+                           help='The path to the folder where the model architectures are saved')
+    argparser.add_argument('--data_dir_prot', type=str,
+                           help='Path to the protocol data directory')
+    argparser.add_argument('--permutation_files',
+                           type=str, help='Path to the .npy file with the permuted list of image files')
+
+    argparser.add_argument('--num_of_threads', type=int,
+                           help='Number of CPUs available')
+
+    argparser.add_argument('--data_dir_cover', type=str,
+                           help='Directory containing the cover .npy files')
+    argparser.add_argument('--data_dir_stego_0', type=str,
+                           help='Directory containing the initial stegos')
+    argparser.add_argument('--cost_dir', type=str,
+                           help='Directory containing the cost .npy files')
 
     argparser.add_argument('--train_size', type=int, help='Size of train set')
     argparser.add_argument('--valid_size', type=int, help='Size of valid set')
@@ -202,36 +226,46 @@ if __name__ == '__main__':
     argparser.add_argument('--batch_size_classif', type=int)
     argparser.add_argument('--batch_size_eval', type=int)
     argparser.add_argument('--epoch_num', type=int, help='Number of epochs')
-    argparser.add_argument('--QF', type=int, help='Quality factor, values accepted are 75, 95 and 100')
-    argparser.add_argument('--emb_rate', type=float, help= 'Float between 0 and 1')
-    argparser.add_argument('--spatial', type=str, default='no', help= 'yes if for spatial steganography, no for jpeg steganography ')
-    argparser.add_argument('--train_on_cost_map', type=str, default='yes', help='yes or no. If yes stegos are created from the cost maps at during the training, elif no classifiers are trained directly with the stegos.') 
-    
+    argparser.add_argument(
+        '--QF', type=int, help='Quality factor, values accepted are 75, 95 and 100')
+    argparser.add_argument('--emb_rate', type=float,
+                           help='Float between 0 and 1')
+    argparser.add_argument('--spatial', type=str, default='no',
+                           help='yes for spatial steganography, no for JPEG steganography')
+    argparser.add_argument('--train_on_cost_map', type=str, default='yes',
+                           help='yes or no. If yes, stegos are generated from the cost maps during training; if no, classifiers are trained directly on the stegos.')
+
     # FOR CUSTOM FILTERS FOR HILL
-    argparser.add_argument('--H1_filter',type=str, default=None, help='Path to the saved H1 filter')
-    argparser.add_argument('--L1_filter',type=str, default=None, help='Path to the saved L1 filter')
-    argparser.add_argument('--L2_filter',type=str, default=None, help='Path to the saved L2 filter')
-    
-    # Model parameters 
-    argparser.add_argument('--model', type=str, help='Model : efnet, xunet or srnet') 
+    argparser.add_argument('--H1_filter', type=str,
+                           default=None, help='Path to the saved H1 filter')
+    argparser.add_argument('--L1_filter', type=str,
+                           default=None, help='Path to the saved L1 filter')
+    argparser.add_argument('--L2_filter', type=str,
+                           default=None, help='Path to the saved L2 filter')
+
+    # Model parameters
+    argparser.add_argument('--model', type=str,
+                           help='Model: efnet, xunet or srnet')
     # for efnet
-    argparser.add_argument('--version_eff',type=str, help='Version of efficient-net, from b0 to b7')
-    argparser.add_argument('--stride',type=int, help='Stride at the beginning. Values=1 or 2')
+    argparser.add_argument('--version_eff', type=str,
+                           help='Version of efficient-net, from b0 to b7')
+    argparser.add_argument('--stride', type=int,
+                           help='Stride at the beginning. Values=1 or 2')
     # for xunet
-    argparser.add_argument('--n_loops',type=int, default=5, help='Number of loops in th xunet architecture')
+    argparser.add_argument('--n_loops', type=int, default=5,
+                           help='Number of loops in the xunet architecture')
 
     # Training parameters
-    argparser.add_argument('--CL',type=str, help='yes or no. If yes starts from emb_rate of start_emb_rate and decrease until reach emb_rate')
-    argparser.add_argument('--start_emb_rate',type=float, default=0.7, help='Is CL=yes, is the starting emb_rate')
-    argparser.add_argument('--pair_training',type=str, help='Yes or no')
-   
-    argparser.add_argument('--load_model',type=str, default=None, help='Path to the saved efficient model')
+    argparser.add_argument(
+        '--CL', type=str, help='yes or no. If yes, training starts at start_emb_rate and decreases the embedding rate until it reaches emb_rate')
+    argparser.add_argument('--start_emb_rate', type=float,
+                           default=0.7, help='If CL=yes, the starting emb_rate')
+    argparser.add_argument('--pair_training', type=str, help='Yes or no')
 
+    argparser.add_argument('--load_model', type=str,
+                           default=None, help='Path to the saved efficient model')
 
     params = argparser.parse_args()
-    
+
     train, val, test = train(**vars(params))
     print(train, val, test)
-    
-
-
diff --git a/tools_stegano.py b/tools_stegano.py
index 0f6357e9683a3917c953ed80b827c55c455fb27e..d834c935e283a7716e2c18cac7121ecbb1f91849 100644
--- a/tools_stegano.py
+++ b/tools_stegano.py
@@ -7,40 +7,47 @@ from scipy.optimize import root_scalar
 from scipy.signal import fftconvolve, convolve2d
 
 
-def softmax(x):
-    y = np.exp((x-np.max(x,axis=0,keepdims=True)))
-    return(y/np.sum(y,axis=0,keepdims=True))
+def softmax(x, axis=0):
+    y = np.exp((x-np.max(x, axis=axis, keepdims=True)))
+    return(y/np.sum(y, axis=axis, keepdims=True))
+
 
 def gibsprobability(rho, lbd):
     return(softmax(-lbd*rho))
 
+
 def H(rho, lbd):
     return(ternary_entropy(gibsprobability(rho, lbd)))
 
+
 def compute_nz_AC(c_coeffs):
-    s = np.sum(c_coeffs!=0)-np.sum(c_coeffs[::8,::8]!=0)
+    s = np.sum(c_coeffs != 0)-np.sum(c_coeffs[::8, ::8] != 0)
     return(s)
 
+
 def ternary_entropy(p):
     P = np.copy(p)
-    P[P==0]=1
+    P[P == 0] = 1
     H = -((P) * np.log2(P))
     Ht = np.sum(H)
     return Ht
 
+
 def calc_lambda(rho, ml):
     lbd_r = 1e-30
     lbd_l = lbd_r
-    while lbd_r<1e14:
+    while lbd_r < 1e14:
         if H(rho, lbd_r) < ml:
             break
         lbd_l = lbd_r
         #lbd_r = (lbd_r+1e-30)*10
         lbd_r *= 10
     #print(lbd_l, lbd_r, H(rho, lbd_l), H(rho, lbd_r))
-    r = root_scalar(lambda x: ml - H(rho,x), bracket=[lbd_l,lbd_r], xtol=1e-10)
+    r = root_scalar(lambda x: ml - H(rho, x),
+                    bracket=[lbd_l, lbd_r], xtol=1e-10)
     return(r.root)
 
+
 def compute_proba(rho, message_length):
     """
     Embedding simulator simulates the embedding made by the best possible 
@@ -49,29 +56,32 @@ def compute_proba(rho, message_length):
     that are asymptotically approaching the nzbound
     """
     lbd = calc_lambda(rho, message_length)
-    a = np.exp(-lbd*rho)
-    p = a/(np.sum(a,axis=0))
+    p = gibsprobability(rho, lbd)
     return(p)
 
+
 def dct2(x):
     return dct(dct(x, norm='ortho').T, norm='ortho').T
 
+
 def idct2(x):
     return idct(idct(x, norm='ortho').T, norm='ortho').T
 
-def compute_spatial_from_jpeg(jpeg_im,c_quant):
+
+def compute_spatial_from_jpeg(jpeg_im, c_quant):
     """
     Compute the spatial image from the JPEG coefficients (blockwise 8x8 inverse DCT)
     """
-    w,h = jpeg_im.shape
-    spatial_im = np.zeros((w,h))
+    w, h = jpeg_im.shape
+    spatial_im = np.zeros((w, h))
     for bind_i in range(int(w//8)):
         for bind_j in range(int(h//8)):
-            block = idct2(jpeg_im[bind_i*8:(bind_i+1)*8,bind_j*8:(bind_j+1)*8]*(c_quant))+128
-            #block[block>255]=255
-            #block[block<0]=0
-            spatial_im[bind_i*8:(bind_i+1)*8,bind_j*8:(bind_j+1)*8] = block
-    spatial_im = spatial_im.astype(np.float32)       
+            block = idct2(jpeg_im[bind_i*8:(bind_i+1)*8,
+                                  bind_j*8:(bind_j+1)*8]*(c_quant))+128
+            # block[block>255]=255
+            # block[block<0]=0
+            spatial_im[bind_i*8:(bind_i+1)*8, bind_j*8:(bind_j+1)*8] = block
+    spatial_im = spatial_im.astype(np.float32)
     return(spatial_im)
 
 
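compute_proba above relies on calc_lambda to find the Lagrange multiplier at which the entropy of the Gibbs distribution over the three changes {-1, 0, +1} equals the requested message length. A short usage sketch, assuming it is run from the repository root with the repository's dependencies installed:

    import numpy as np
    from tools_stegano import compute_proba, ternary_entropy

    # ternary costs for 16 coefficients: row 0 = -1 change, row 1 = no change, row 2 = +1 change
    rho = np.stack((np.full(16, 2.0), np.zeros(16), np.full(16, 2.0)))

    message_length = 8.0                       # target payload in bits
    p = compute_proba(rho, message_length)     # change probabilities, shape (3, 16)
    print(round(ternary_entropy(p), 3))        # close to 8.0: the entropy matches the payload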
@@ -91,82 +101,85 @@ class find_lambda(torch.autograd.Function):
         ctx.lbd = lbd
         ctx.rho = rho_vec
 
-        lbd = torch.tensor(lbd, requires_grad=True) 
+        lbd = torch.tensor(lbd, requires_grad=True)
         return(lbd)
 
     @staticmethod
     def backward(ctx, grad_output):
         with torch.enable_grad():
             lbd = torch.tensor(ctx.lbd, requires_grad=True)
-            
+
             # Computation of probas and entropy
             softmax_fn = torch.nn.Softmax(dim=0)
-            probas = softmax_fn(-lbd*(ctx.rho-torch.min(ctx.rho,dim=0)[0]))
-            H_result = -torch.sum(torch.mul(probas, torch.log(probas+1e-30)))/np.log(2.)
-            
+            probas = softmax_fn(-lbd*(ctx.rho-torch.min(ctx.rho, dim=0)[0]))
+            H_result = - \
+                torch.sum(torch.mul(probas, torch.log(probas+1e-30)))/np.log(2.)
+
             # Gradients
-            grad_H_rho = torch.autograd.grad(H_result, ctx.rho, retain_graph=True)[0]
+            grad_H_rho = torch.autograd.grad(
+                H_result, ctx.rho, retain_graph=True)[0]
             grad_H_lbd = torch.autograd.grad(H_result, lbd)[0]
-            
+
             # Implicit gradient
             g = -grad_H_rho/grad_H_lbd
-        return(None,g)
+        return(None, g)
 
 
-def get_inv_perm(params):
-    x = np.arange(params.image_size**2).reshape((params.image_size,params.image_size))
-    X = np.zeros((8,8,params.image_size//8,params.image_size//8))
-    for i in range(params.image_size//8):
-        for j in range(params.image_size//8):
-            X[:,:,i,j]=x[i*8:(i+1)*8,j*8:(j+1)*8]
+def get_inv_perm(image_size):
+    x = np.arange(image_size**2).reshape((image_size, image_size))
+    X = np.zeros((8, 8, image_size//8, image_size//8))
+    for i in range(image_size//8):
+        for j in range(image_size//8):
+            X[:, :, i, j] = x[i*8:(i+1)*8, j*8:(j+1)*8]
     inv_perm = np.argsort(X.flatten())
     return(inv_perm)
 
 
 class IDCT8_Net(nn.Module):
 
-    def __init__(self, params):
-        super(IDCT8_Net, self).__init__()      
-        n = params.image_size//8
-        c_quant = np.load(params.folder_model+'c_quant_'+str(params.QF)+'.npy') # 8*8 quantification matrix
-        self.c_quant = torch.tensor(np.tile(c_quant,(n,n)).reshape((params.image_size,params.image_size))) 
-        self.inv_perm = get_inv_perm(params)
-        self.IDCT8_kernel = torch.tensor(np.load(params.folder_model+'DCT_8.npy')\
-                                         .reshape((8,8,64,1)).transpose((2,3,0,1)))
-        self.im_size = params.image_size
+    def __init__(self, image_size, QF, folder_model):
+        super(IDCT8_Net, self).__init__()
+        n = image_size//8
+        c_quant = np.load(folder_model+'c_quant_'+str(QF) +
+                          '.npy')  # 8x8 quantization matrix
+        self.c_quant = torch.tensor(
+            np.tile(c_quant, (n, n)).reshape((image_size, image_size)))
+        self.inv_perm = get_inv_perm(image_size)
+        self.IDCT8_kernel = torch.tensor(np.load(folder_model+'DCT_8.npy')
+                                         .reshape((8, 8, 64, 1)).transpose((2, 3, 0, 1)))
+        self.im_size = image_size
 
     def forward(self, x):
-        x = torch.mul(x, self.c_quant) # dequantization
-        x = F.conv2d(x, self.IDCT8_kernel, stride=(8,8), padding=0) # 2D-IDCT
+        x = torch.mul(x, self.c_quant)  # dequantization
+        x = F.conv2d(x, self.IDCT8_kernel, stride=(8, 8), padding=0)  # 2D-IDCT
         # Reorder coefficients and reshape
-        x = torch.reshape(x,(x.size()[0],self.im_size**2))
-        x = x[:,self.inv_perm]
-        x = torch.reshape(x,(x.size()[0], 1, self.im_size, self.im_size))
+        x = torch.reshape(x, (x.size()[0], self.im_size**2))
+        x = x[:, self.inv_perm]
+        x = torch.reshape(x, (x.size()[0], 1, self.im_size, self.im_size))
         x += 128
         return x
 
 
-def HILL(cover, H1=None, L1=None, L2=None) :
+def HILL(cover, H1=None, L1=None, L2=None):
 
     if H1 is None:
-        H1 = 4 * np.array([[-0.25, 0.5, -0.25], 
-                 [0.5, -1, 0.5], 
-                 [-0.25, 0.5, -0.25]])
+        H1 = 4 * np.array([[-0.25, 0.5, -0.25],
+                           [0.5, -1, 0.5],
+                           [-0.25, 0.5, -0.25]])
         L1 = (1.0/9.0)*np.ones((3, 3))
         L2 = (1.0/225.0)*np.ones((15, 15))
-    
+
     # High pass filter H1
-    R = convolve2d(cover, H1.reshape((3,3)), mode = 'same', boundary = 'symm')
+    R = convolve2d(cover, H1.reshape((3, 3)), mode='same', boundary='symm')
 
     # Low pass filter L1
-    xi = convolve2d(abs(R), L1.reshape((3,3)), mode = 'same', boundary = 'symm')
+    xi = convolve2d(abs(R), L1.reshape((3, 3)), mode='same', boundary='symm')
     inv_xi = 1/(xi+1e-20)
-    
+
     # Low pass filter L2
-    rho = convolve2d(inv_xi, L2.reshape((15,15)), mode = 'same', boundary = 'symm')
+    rho = convolve2d(inv_xi, L2.reshape((15, 15)),
+                     mode='same', boundary='symm')
     # adjust embedding costs
-    rho[rho > WET_COST] = WET_COST # threshold on the costs
-    rho[np.isnan(rho)] = WET_COST # Check if all elements are numbers
+    # rho[rho > WET_COST] = WET_COST # threshold on the costs
+    # rho[np.isnan(rho)] = WET_COST # Check if all elements are numbers
     return rho
-
-
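A possible end-to-end use of HILL together with compute_proba, turning spatial costs for a toy cover into ternary change probabilities at 0.4 bit per pixel (a sketch only, again assuming the repository's dependencies are installed; the 64x64 random cover is made up):

    import numpy as np
    from tools_stegano import HILL, compute_proba

    cover = np.random.randint(0, 256, size=(64, 64)).astype(np.float64)   # toy 8-bit cover
    rho = HILL(cover)                                                     # single-channel costs

    # ternary extension: same cost for +1 and -1, zero cost for keeping the pixel
    rho3 = np.stack((rho, np.zeros_like(rho), rho))

    payload = 0.4 * cover.size                 # 0.4 bit per pixel
    p = compute_proba(rho3, payload)           # change probabilities, shape (3, 64, 64)
    print(p.shape, p.sum(axis=0).min())        # per-pixel probabilities sum to 1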
diff --git a/train.py b/train.py
index f4ff5fafb2f4ba57c1f6d9339692fdd0d3a2e37d..bf73c10e6bf7e46589a1f38d0825f10901c8a223 100644
--- a/train.py
+++ b/train.py
@@ -1,3 +1,4 @@
+from sklearn import metrics
 from glob import glob
 from sklearn.model_selection import GroupKFold
 import cv2
@@ -11,25 +12,28 @@ import random
 import cv2
 import pandas as pd
 import numpy as np
+import albumentations as A
 import matplotlib.pyplot as plt
-from torch.utils.data import Dataset,DataLoader
+from albumentations.pytorch.transforms import ToTensorV2
+from torch.utils.data import Dataset, DataLoader
 from torch.utils.data.sampler import SequentialSampler, RandomSampler
 import sklearn
-from data_loader import load_dataset, DatasetRetriever, get_train_transforms, get_valid_transforms, LabelSmoothing
+from data_loader import load_dataset, DatasetRetriever, get_train_transforms, get_valid_transforms
 
-from os import path,mkdir,makedirs
+from os import path, mkdir, makedirs
 
 
 # FITTER
 import warnings
 warnings.filterwarnings("ignore")
 
+
 class Fitter:
-    
-    def __init__(self, model, device, config, save_path, model_str):
+
+    def __init__(self, model, device, config, save_path, model_str, train_on_cost_map):
         self.config = config
         self.epoch = 0
-        
+
         self.base_dir = save_path
         self.log_path = f'{self.base_dir}/log.txt'
         self.best_summary_loss = 10**5
@@ -37,32 +41,43 @@ class Fitter:
         self.model = model
         self.device = device
         self.model_str = model_str
+        self.train_on_cost_map = train_on_cost_map
 
         param_optimizer = list(self.model.named_parameters())
-        if(model_str=='efnet'):
+        if(model_str == 'efnet'):
             no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
             optimizer_grouped_parameters = [
-                {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
-                {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
-            ] 
+                {'params': [p for n, p in param_optimizer if not any(
+                    nd in n for nd in no_decay)], 'weight_decay': 0.001},
+                {'params': [p for n, p in param_optimizer if any(
+                    nd in n for nd in no_decay)], 'weight_decay': 0.0}
+            ]
             optimizer_grouped_parameters = self.model.parameters()
-        elif(model_str=='xunet'):
-            fl = {n for n, m in self.model.named_modules() if isinstance(m, torch.nn.Linear)}
-            fl_param_names = {n for n, _ in self.model.named_parameters() if n.rsplit('.', 1)[0] in fl}
+        elif(model_str == 'xunet'):
+            fl = {n for n, m in self.model.named_modules(
+            ) if isinstance(m, torch.nn.Linear)}
+            fl_param_names = {n for n, _ in self.model.named_parameters() if n.rsplit('.', 1)[
+                0] in fl}
             optimizer_grouped_parameters = [
-                {'params': [p for n, p in self.model.named_parameters() if n not in fl_param_names], 'weight_decay': 0.0},
-                {'params': [p for n, p in self.model.named_parameters() if n in fl_param_names], 'weight_decay': 0.0005}
-                ]
-        elif(model_str=='srnet'):
+                {'params': [p for n, p in self.model.named_parameters(
+                ) if n not in fl_param_names], 'weight_decay': 0.0},
+                {'params': [p for n, p in self.model.named_parameters(
+                ) if n in fl_param_names], 'weight_decay': 0.0005}
+            ]
+        elif(model_str == 'srnet'):
             no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
             optimizer_grouped_parameters = [
-                {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 2e-4},
-                {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
+                {'params': [p for n, p in param_optimizer if not any(
+                    nd in n for nd in no_decay)], 'weight_decay': 2e-4},
+                {'params': [p for n, p in param_optimizer if any(
+                    nd in n for nd in no_decay)], 'weight_decay': 0.0}
             ]
 
-        self.optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=config.lr)
-        self.scheduler = config.SchedulerClass(self.optimizer, **config.scheduler_params)
-        self.criterion = LabelSmoothing().to(self.device)
+        self.optimizer = torch.optim.AdamW(
+            optimizer_grouped_parameters, lr=config.lr)
+        self.scheduler = config.SchedulerClass(
+            self.optimizer, **config.scheduler_params)
+        self.criterion = torch.nn.CrossEntropyLoss()
         self.log(f'Fitter prepared. Device is {self.device}')
 
     def fit(self, train_loader, validation_loader, emb_rate):
@@ -71,13 +86,15 @@ class Fitter:
                 lr = self.optimizer.param_groups[0]['lr']
                 timestamp = datetime.utcnow().isoformat()
                 self.log(f'\n{timestamp}\nLR: {lr}')
-                self.log(f'Emb_rate: {train_loader.dataset.emb_rate}')
+                if(self.train_on_cost_map):
+                    self.log(f'Emb_rate: {train_loader.dataset.emb_rate}')
 
             t = time.time()
             summary_loss, final_scores = self.train_one_epoch(train_loader)
 
-            if(e%2==0 and train_loader.dataset.emb_rate>emb_rate):
-                udpate_emb_rate = max(emb_rate, train_loader.dataset.emb_rate - 0.1)
+            if(e % 2 == 0 and train_loader.dataset.emb_rate > emb_rate):
+                udpate_emb_rate = max(
+                    emb_rate, train_loader.dataset.emb_rate*0.9)
                 train_loader.dataset.emb_rate = udpate_emb_rate
 
             self.log(f'[RESULT]: Train. Epoch: {self.epoch}, summary_loss: {summary_loss.avg:.5f}, final_score: {final_scores.avg:.5f}, time: {(time.time() - t):.5f}')
@@ -91,10 +108,10 @@ class Fitter:
                 self.best_summary_loss = summary_loss.avg
                 self.model.eval()
                 self.save(f'{self.base_dir}/best-checkpoint-{str(self.epoch).zfill(3)}epoch.bin')
-                for path in sorted(glob(f'{self.base_dir}/best-checkpoint-*epoch.bin'))[:-3]:
-                    os.remove(path)
+                # for path in sorted(glob(f'{self.base_dir}/best-checkpoint-*epoch.bin'))[:-3]:
+                #    os.remove(path)
 
-            if (self.config.validation_scheduler and train_loader.dataset.emb_rate==emb_rate):
+            if (self.config.validation_scheduler and train_loader.dataset.emb_rate == emb_rate):
                 self.scheduler.step(metrics=summary_loss.avg)
 
             self.epoch += 1
@@ -108,16 +125,17 @@ class Fitter:
             if self.config.verbose:
                 if step % self.config.verbose_step == 0:
                     print(
-                        f'Val Step {step}/{len(val_loader)}, ' + \
-                        f'summary_loss: {summary_loss.avg:.5f}, final_score: {final_scores.avg:.5f}, ' + \
+                        f'Val Step {step}/{len(val_loader)}, ' +
+                        f'summary_loss: {summary_loss.avg:.5f}, final_score: {final_scores.avg:.5f}, ' +
                         f'time: {(time.time() - t):.5f}', end='\r'
                     )
             with torch.no_grad():
-                targets = targets.to(self.device).float()
+                targets = targets.to(self.device).long()
                 batch_size = images.shape[0]
                 images = images.to(self.device).float()
                 outputs = self.model(images)
-                loss = self.criterion(outputs, targets)
+                loss = self.criterion(outputs, torch.argmax(targets, dim=1))
+                #loss = torch.nn.functional.cross_entropy(outputs, targets)
                 final_scores.update(targets, outputs)
                 summary_loss.update(loss.detach().item(), batch_size)
 
@@ -132,20 +150,20 @@ class Fitter:
             if self.config.verbose:
                 if step % self.config.verbose_step == 0:
                     print(
-                        f'Train Step {step}/{len(train_loader)}, ' + \
-                        f'summary_loss: {summary_loss.avg:.5f}, final_score: {final_scores.avg:.5f}, ' + \
+                        f'Train Step {step}/{len(train_loader)}, ' +
+                        f'summary_loss: {summary_loss.avg:.5f}, final_score: {final_scores.avg:.5f}, ' +
                         f'time: {(time.time() - t):.5f}', end='\r'
                     )
-            
-            targets = targets.to(self.device).float()
+            targets = targets.to(self.device).long()
             images = images.to(self.device).float()
             batch_size = images.shape[0]
 
             self.optimizer.zero_grad()
             outputs = self.model(images)
-            loss = self.criterion(outputs, targets)
+            loss = self.criterion(outputs, torch.argmax(targets, dim=1))
+            #loss = torch.nn.functional.cross_entropy(outputs, targets)
             loss.backward()
-            
+
             final_scores.update(targets, outputs)
             summary_loss.update(loss.detach().item(), batch_size)
 
@@ -155,7 +173,7 @@ class Fitter:
                 self.scheduler.step()
 
         return summary_loss, final_scores
-    
+
     def save(self, path):
         self.model.eval()
         torch.save({
@@ -174,23 +192,24 @@ class Fitter:
             self.optimizer.param_groups[0]['lr'] = lr
         self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
         self.best_summary_loss = checkpoint['best_summary_loss']
-        self.epoch = checkpoint['epoch'] + 1
-        
+        self.epoch = checkpoint['epoch'] + 1
+
     def transfer(self, path):
         checkpoint = torch.load(path)
         self.model.load_state_dict(checkpoint['model_state_dict'])
-        #self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
-        #self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+        # self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+        # self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
         #self.best_summary_loss = checkpoint['best_summary_loss']
         #self.epoch = checkpoint['epoch'] + 1
+
     def transfer_with_epoch(self, path):
         checkpoint = torch.load(path)
         self.model.load_state_dict(checkpoint['model_state_dict'])
-        #self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
-        #self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+        # self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+        # self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
         self.best_summary_loss = checkpoint['best_summary_loss']
         self.epoch = checkpoint['epoch'] + 1
-        
+
     def log(self, message):
         if self.config.verbose:
             print(message)
@@ -198,12 +217,12 @@ class Fitter:
             logger.write(f'{message}\n')
 
 
-
 # METRICS
-from sklearn import metrics
+
 
 class AverageMeter(object):
     """Computes and stores the average and current value"""
+
     def __init__(self):
         self.reset()
 
@@ -219,7 +238,7 @@ class AverageMeter(object):
         self.count += n
         self.avg = self.sum / self.count
 
-        
+
 def alaska_auc(y_true, y_valid):
     """
     https://www.kaggle.com/anokas/weighted-auc-metric-updated
@@ -230,6 +249,7 @@ def alaska_auc(y_true, y_valid):
     fpr, tpr, thresholds = metrics.roc_curve(y_true, y_valid, pos_label=1)
     return metrics.auc(fpr, tpr)
 
+
 def alaskaPE(y_true, y_valid):
     """
     P_E: the minimum over ROC thresholds of (false-alarm rate + missed-detection rate) / 2.
@@ -238,8 +258,9 @@ def alaskaPE(y_true, y_valid):
     weights = [2, 1]
 
     fpr, tpr, thresholds = metrics.roc_curve(y_true, y_valid, pos_label=1)
-    return np.min(fpr +(1-tpr))/2
-        
+    return np.min(fpr + (1-tpr))/2
+
+
 def alaska_weighted_auc(y_true, y_valid):
     """
     https://www.kaggle.com/anokas/weighted-auc-metric-updated
@@ -261,7 +282,7 @@ def alaska_weighted_auc(y_true, y_valid):
         y_max = tpr_thresholds[idx + 1]
         mask = (y_min < tpr) & (tpr < y_max)
         # pdb.set_trace()
-        
+
         try:
             x_padding = np.linspace(fpr[mask][-1], 1, 100)
         except:
@@ -275,26 +296,26 @@ def alaska_weighted_auc(y_true, y_valid):
         competition_metric += submetric
 
     return competition_metric / normalization
-        
+
+
 class RocAucMeter(object):
     def __init__(self):
         self.reset()
 
     def reset(self):
-        self.y_true = np.array([0,1])
-        self.y_pred = np.array([0.5,0.5])
+        self.y_true = np.array([0, 1])
+        self.y_pred = np.array([0.5, 0.5])
         self.score = 0
 
     def update(self, y_true, y_pred):
         y_true = y_true.cpu().numpy().argmax(axis=1).clip(min=0, max=1).astype(int)
-        y_pred = 1 - nn.functional.softmax(y_pred, dim=1).data.cpu().numpy()[:,0]
+        y_pred = 1 - \
+            nn.functional.softmax(y_pred, dim=1).data.cpu().numpy()[:, 0]
         self.y_true = np.hstack((self.y_true, y_true))
         self.y_pred = np.hstack((self.y_pred, y_pred))
-        self.score = alaskaPE(self.y_true, self.y_pred)#alaska_auc(self.y_true, self.y_pred)
-    
+        # alaska_auc(self.y_true, self.y_pred)
+        self.score = alaskaPE(self.y_true, self.y_pred)
+
     @property
     def avg(self):
         return self.score
-    
-
-    
\ No newline at end of file
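Switching the criterion from LabelSmoothing to torch.nn.CrossEntropyLoss changes what the loss expects: class indices rather than one-hot (or smoothed) target vectors, hence the .long() cast and the torch.argmax(targets, dim=1) calls above. A minimal sketch of the same pattern, assuming the data loader yields one-hot targets:

import torch

criterion = torch.nn.CrossEntropyLoss()
outputs = torch.randn(4, 2)                                      # logits for a batch of 4
targets_onehot = torch.tensor([[1, 0], [0, 1], [1, 0], [0, 1]])  # cover / stego labels
loss = criterion(outputs, torch.argmax(targets_onehot, dim=1))   # indices, not one-hot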
diff --git a/write_description.py b/write_description.py
index 5fc17f1aeb2d4f4067d636c7a2e427f4c8a363d1..42a6375648f2066246e6034efb57b47c6776f86f 100644
--- a/write_description.py
+++ b/write_description.py
@@ -1,21 +1,23 @@
 from datetime import date
 
-def describe_exp_txt(begin_step, number_steps, num_of_threads, \
-        QF, image_size, emb_rate, data_dir_cover, data_dir_stego_0, cost_dir, \
-        strategy, model, version_eff, stride, n_loops, \
-        train_size, valid_size, test_size, permutation_files, training_dictionnary,\
-        attack, n_iter_max_backpack, N_samples, tau_0, precision, data_dir_prot):
+
+def describe_exp_txt(begin_step, number_steps, num_of_threads,
+                     QF, image_size, emb_rate, data_dir_cover, data_dir_stego_0, cost_dir,
+                     strategy, model, version_eff, stride, n_loops,
+                     train_size, valid_size, test_size, permutation_files, training_dictionnary,
+                     attack, n_iter_max_backpack, N_samples, tau_0, precision, data_dir_prot):
 
     n_images = train_size + valid_size + test_size
     models = model.split(',')
 
     tab = []
     tab.append(date.today().strftime("%b-%d-%Y"))
-    tab.append('Launch of the protocol, starting from iteration ' + str(begin_step) + ' to ' +  str(begin_step+number_steps) + '\n')
+    tab.append('Launch of the protocol, starting from iteration ' +
+               str(begin_step) + ' to ' + str(begin_step+number_steps) + '\n')
     tab.append('Number of CPUs called : ' + str(num_of_threads) + '\n \n')
-    
+
     tab.append('PARAMETERS \n')
-    
+
     tab.append('Image characteristics \n')
     tab.append('- QF = ' + str(QF) + '\n')
     tab.append('- Image size = ' + str(image_size) + '\n')
@@ -23,55 +25,70 @@ def describe_exp_txt(begin_step, number_steps, num_of_threads, \
     tab.append('- Cover images are taken in folder ' + data_dir_cover + '\n')
     tab.append('- Stego images are taken in folder ' + data_dir_stego_0 + '\n')
     tab.append('- Cost maps are taken in folder ' + cost_dir + '\n \n')
-    
+
     tab.append('Protocol setup \n')
-    tab.append('- Strategy =' + strategy +' \n \n')
+    tab.append('- Strategy =' + strategy + ' \n \n')
 
     tab.append('Model description \n')
-    if(len(models)==1):
-        tab.append('- Model architecture is ' + model + ' with the following setup :\n')
-        if model=='efnet':
-            tab.append('     - Efficient-net version is ' +str(version_eff) +' pretrained on image-net \n')
-            tab.append('     - First conv stem is with stride = ' +str(stride) +' \n \n')
-        elif model=='xunet':
-            tab.append('     - XuNet architecture is composed with '+str(n_loops) +' big blocks \n \n')
+    if(len(models) == 1):
+        tab.append('- Model architecture is ' + model +
+                   ' with the following setup:\n')
+        if model == 'efnet':
+            tab.append('     - Efficient-net version is ' +
+                       str(version_eff) + ' pretrained on ImageNet \n')
+            tab.append('     - First conv stem has stride = ' +
+                       str(stride) + ' \n \n')
+        elif model == 'xunet':
+            tab.append('     - XuNet architecture is composed of ' +
+                       str(n_loops) + ' big blocks \n \n')
     else:
-        tab.append('- The '+str(len(models)) +' model architectures are ' + model + ' with the following setup :\n')
+        tab.append('- The '+str(len(models)) + ' model architectures are ' +
+                   model + ' with the following setup:\n')
         if 'efnet' in models:
-            tab.append('     - Efficient-net version is ' +str(version_eff) +' pretrained on image-net \n')
-            tab.append('     - First conv stem is with stride = ' +str(stride) +' \n \n')
+            tab.append('     - Efficient-net version is ' +
+                       str(version_eff) + ' pretrained on ImageNet \n')
+            tab.append('     - First conv stem has stride = ' +
+                       str(stride) + ' \n \n')
         if 'xunet' in models:
-            tab.append('     - XuNet architecture is composed with '+str(n_loops) +' big blocks \n \n')
-
+            tab.append('     - XuNet architecture is composed with ' +
+                       str(n_loops) + ' big blocks \n \n')
 
     tab.append('Training setup \n')
-    tab.append('- Train size = '+str(train_size) +'\n')
-    tab.append('- Valid size = '+str(valid_size) +'\n')
-    tab.append('- Test size = '+str(test_size) +'\n')
-    tab.append('- Files permutation, which order determines train, valid and test sets is '+ permutation_files +'\n')
+    tab.append('- Train size = '+str(train_size) + '\n')
+    tab.append('- Valid size = '+str(valid_size) + '\n')
+    tab.append('- Test size = '+str(test_size) + '\n')
+    tab.append('- File permutation, whose order determines the train, valid and test sets, is ' +
+               permutation_files + '\n')
 
     for model in models:
-        tab.append('- Model '+model+' is trained during '+str(training_dictionnary[model]['epoch_num']) + ' epochs \n')
-        if(training_dictionnary[model]['pair_training']=='no'):
+        tab.append('- Model '+model+' is trained for ' +
+                   str(training_dictionnary[model]['epoch_num']) + ' epochs \n')
+        if(training_dictionnary[model]['pair_training'] == 'no'):
             tab.append('- Pair training is not used \n')
-            tab.append('- Batch size is ' + str(training_dictionnary[model]['batch_size_classif']) + ' \n')
+            tab.append('- Batch size is ' +
+                       str(training_dictionnary[model]['batch_size_classif']) + ' \n')
         else:
             tab.append('- Pair training is used \n')
-            tab.append('- Batch size is 2*' + str(training_dictionnary[model]['batch_size_classif']) + ' \n')
-        if(training_dictionnary[model]['CL']=='yes'):
-            tab.append('- Curriculum is used : the embedding rate starts from '+ str(training_dictionnary[model]['start_emb_rate']) +' and decreases every two epochs by factor 0.9 to reach target embedding rate '+str(emb_rate)+'\n \n')
+            tab.append('- Batch size is 2*' +
+                       str(training_dictionnary[model]['batch_size_classif']) + ' \n')
+        if(training_dictionnary[model]['CL'] == 'yes'):
+            tab.append('- Curriculum is used: the embedding rate starts from ' + str(
+                training_dictionnary[model]['start_emb_rate']) + ' and decreases every two epochs by a factor of 0.9 to reach the target embedding rate '+str(emb_rate)+'\n \n')
         else:
-            tab.append('- Curriculum is not used : the embedding rate = '+str(emb_rate)+' is constant during training \n \n')
-
+            tab.append('- Curriculum is not used: the embedding rate = ' +
+                       str(emb_rate)+' is kept constant during training \n \n')
 
     tab.append('Attack setup \n')
-    tab.append('- The smoothing function is ' +str(attack) +' \n')
-    tab.append('- Maximum number of steps is ' +str(n_iter_max_backpack) +' \n')
-    tab.append('- Number of samples is ' +str(N_samples) +' \n')
-    tab.append('- Tau is initialized with value ' +str(tau_0) +' and decreases by factor 0.5 when needed\n')
-    tab.append('- The exit condition is required to be respected with precision = '+str(precision)+'\n \n')
+    tab.append('- The smoothing function is ' + str(attack) + ' \n')
+    tab.append('- Maximum number of steps is ' +
+               str(n_iter_max_backpack) + ' \n')
+    tab.append('- Number of samples is ' + str(N_samples) + ' \n')
+    tab.append('- Tau is initialized with value ' + str(tau_0) +
+               ' and decreases by a factor of 0.5 when needed\n')
+    tab.append(
+        '- The exit condition must be satisfied with precision = '+str(precision)+'\n \n')
 
-    file = open(data_dir_prot+"description.txt","w")  
+    file = open(data_dir_prot+"description.txt", "w")
     file.writelines(tab)
     file.close()
 
@@ -96,4 +113,4 @@ def update_exp_txt(params, lines_to_append):
             else:
                 appendEOL = True
             # Append element at the end of file
-            file_object.write(line)
\ No newline at end of file
+            file_object.write(line)
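The curriculum described in this report mirrors the update performed in Fitter.fit: every two epochs the current embedding rate is multiplied by 0.9 and floored at the target rate. A small sketch of the resulting schedule, with placeholder values:

def curriculum_schedule(start_emb_rate, emb_rate, epoch_num):
    # Embedding rate used at each epoch; multiplied by 0.9 every two epochs
    # until it reaches the target emb_rate.
    rate, rates = start_emb_rate, []
    for e in range(epoch_num):
        rates.append(rate)
        if e % 2 == 0 and rate > emb_rate:
            rate = max(emb_rate, rate * 0.9)
    return rates

# e.g. curriculum_schedule(1.4, 0.4, 30) gives 1.4, 1.26, 1.26, 1.134, ... down to 0.4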
diff --git a/write_jobs.py b/write_jobs.py
index 2b3b0dbeb384fd5bfe631b55afd672dd1b3bbf09..15c9a017a45ac3b8a27ec352ad5a1f141828ce8d 100644
--- a/write_jobs.py
+++ b/write_jobs.py
@@ -6,155 +6,158 @@ def is_finished(label):
     liste = os.popen("squeue --user=ulc18or").read()
     liste = liste.split('\n')[1:]
     for x in liste:
-        if (label  + '_' in x):
+        if (label + '_' in x):
             return(False)
     return(True)
 
+
 def wait(label):
     while not is_finished(label):
         sleep(60)
 
 
-def create_training_dictionnary(model, batch_size_classif_xu, batch_size_eval_xu, epoch_num_xu, \
-    CL_xu, start_emb_rate_xu, pair_training_xu, batch_size_classif_sr, batch_size_eval_sr, epoch_num_sr, \
-    CL_sr, start_emb_rate_sr, pair_training_sr, batch_size_classif_ef, batch_size_eval_ef, epoch_num_ef, \
-    CL_ef, start_emb_rate_ef, pair_training_ef, train_on_cost_map):
+def create_training_dictionnary(model, batch_size_classif_xu, batch_size_eval_xu, epoch_num_xu,
+                                CL_xu, start_emb_rate_xu, pair_training_xu, batch_size_classif_sr, batch_size_eval_sr, epoch_num_sr,
+                                CL_sr, start_emb_rate_sr, pair_training_sr, batch_size_classif_ef, batch_size_eval_ef, epoch_num_ef,
+                                CL_ef, start_emb_rate_ef, pair_training_ef, train_on_cost_map):
 
     training_dictionnary = {}
     models = model.split(',')
 
     for model in models:
-        if(model=='xunet'):
-            training_dictionnary[model]={'batch_size_classif':batch_size_classif_xu, 
-                'batch_size_eval':batch_size_eval_xu,
-                'epoch_num':epoch_num_xu,
-                'CL':CL_xu,
-                'start_emb_rate':start_emb_rate_xu,
-                'pair_training':pair_training_xu,
-                'train_on_cost_map':train_on_cost_map}
-        elif(model=='srnet'):
-            training_dictionnary[model]={'batch_size_classif':batch_size_classif_sr, 
-                'batch_size_eval':batch_size_eval_sr,
-                'epoch_num':epoch_num_sr,
-                'CL':CL_sr,
-                'start_emb_rate':start_emb_rate_sr,
-                'pair_training':pair_training_sr,
-                'train_on_cost_map':train_on_cost_map}
-        elif(model=='efnet'):
-            training_dictionnary[model]={'batch_size_classif':batch_size_classif_ef, 
-                'batch_size_eval':batch_size_eval_ef,
-                'epoch_num':epoch_num_ef,
-                'CL':CL_ef,
-                'start_emb_rate':start_emb_rate_ef,
-                'pair_training':pair_training_ef,
-                'train_on_cost_map':train_on_cost_map}
+        if(model == 'xunet'):
+            training_dictionnary[model] = {'batch_size_classif': batch_size_classif_xu,
+                                           'batch_size_eval': batch_size_eval_xu,
+                                           'epoch_num': epoch_num_xu,
+                                           'CL': CL_xu,
+                                           'start_emb_rate': start_emb_rate_xu,
+                                           'pair_training': pair_training_xu,
+                                           'train_on_cost_map': train_on_cost_map}
+        elif(model == 'srnet'):
+            training_dictionnary[model] = {'batch_size_classif': batch_size_classif_sr,
+                                           'batch_size_eval': batch_size_eval_sr,
+                                           'epoch_num': epoch_num_sr,
+                                           'CL': CL_sr,
+                                           'start_emb_rate': start_emb_rate_sr,
+                                           'pair_training': pair_training_sr,
+                                           'train_on_cost_map': train_on_cost_map}
+        elif(model == 'efnet'):
+            training_dictionnary[model] = {'batch_size_classif': batch_size_classif_ef,
+                                           'batch_size_eval': batch_size_eval_ef,
+                                           'epoch_num': epoch_num_ef,
+                                           'CL': CL_ef,
+                                           'start_emb_rate': start_emb_rate_ef,
+                                           'pair_training': pair_training_ef,
+                                           'train_on_cost_map': train_on_cost_map}
 
     return(training_dictionnary)
 
 
-def write_command(mode, iteration_step, model, data_dir_prot, data_dir_cover, data_dir_stego_0, cost_dir, \
-    image_size, QF, folder_model, permutation_files, version_eff, stride, n_loops, \
-    train_size, valid_size, test_size, attack, attack_last, emb_rate, \
-    batch_adv, n_iter_max_backpack, N_samples, tau_0, precision, lr, \
-    num_of_threads, training_dictionnary, spatial):
+def write_command(mode, iteration_step, model, data_dir_prot, data_dir_cover, data_dir_stego_0, cost_dir,
+                  image_size, QF, folder_model, permutation_files, version_eff, stride, n_loops,
+                  train_size, valid_size, test_size, attack, attack_last, emb_rate,
+                  batch_adv, n_iter_max_backpack, N_samples, tau_0, precision, lr,
+                  num_of_threads, training_dictionnary, spatial):
 
     com = ' --data_dir_prot='+data_dir_prot
-    com+= ' --data_dir_cover='+data_dir_cover
-    com+= ' --image_size='+str(image_size)
-    com+= ' --QF=' + str(QF)
-    com+= ' --folder_model='+folder_model
-    com+= ' --permutation_files=' + permutation_files 
-    com+= ' --version_eff=' + version_eff
-    com+= ' --stride=' + str(stride)
-    com+= ' --n_loops=' + str(n_loops)
-
-    if(mode=='classif'):
-        com+= ' --batch_size_eval='+str(training_dictionnary[model]['batch_size_eval'])
-        com+= ' --data_dir_stego_0='+data_dir_stego_0
-        com+= ' --train_size='+str(train_size)
-        com+= ' --valid_size='+str(valid_size)
-        com+= ' --test_size='+str(test_size)
-        com+= ' --model='+model
-
-    elif (mode=='attack'):
-        com+= ' --iteration_step='+str(iteration_step)
-        com+= ' --attack=' + str(attack)
-        com+= ' --attack_last=' + str(attack_last)
-        com+= ' --emb_rate=' + str(emb_rate)
-        com+= ' --cost_dir=' + str(cost_dir)
-        com+= ' --idx_start=$SLURM_ARRAY_TASK_ID'
-        com+= ' --batch_adv=' + str(batch_adv)
-        com+= ' --n_iter_max_backpack=' + str(n_iter_max_backpack)
-        com+= ' --N_samples=' + str(N_samples)
-        com+= ' --tau_0=' + str(tau_0)
-        com+= ' --precision=' + str(precision)
-        com+= ' --lr=' + str(lr)
-        com+= ' --model=' + model
-
-    elif (mode=='train'):
-        com+= ' --iteration_step='+str(iteration_step)
-        com+= ' --train_size='+str(train_size)
-        com+= ' --valid_size='+str(valid_size)
-        com+= ' --test_size='+str(test_size)
-        com+= ' --model='+model
-        com+= ' --num_of_threads='+str(num_of_threads)
-        com+= ' --emb_rate=' + str(emb_rate)   
-        com+= ' --cost_dir=' + str(cost_dir)
-        com+= ' --data_dir_stego_0=' + str(data_dir_stego_0)
-        com+= ' --spatial=' + str(spatial)
-        
-        com+= ' --batch_size_classif='+str(training_dictionnary[model]['batch_size_classif'])
-        com+= ' --batch_size_eval='+str(training_dictionnary[model]['batch_size_eval'])
-        com+= ' --epoch_num='+str(training_dictionnary[model]['epoch_num'])
-        com+= ' --CL=' + str(training_dictionnary[model]['CL'])   
-        com+= ' --start_emb_rate=' + str(training_dictionnary[model]['start_emb_rate'])   
-        com+= ' --pair_training=' + str(training_dictionnary[model]['pair_training']) 
-        com+= ' --train_on_cost_map=' + str(training_dictionnary[model]['train_on_cost_map']) 
-
+    com += ' --data_dir_cover='+data_dir_cover
+    com += ' --image_size='+str(image_size)
+    com += ' --QF=' + str(QF)
+    com += ' --folder_model='+folder_model
+    com += ' --permutation_files=' + permutation_files
+    com += ' --version_eff=' + version_eff
+    com += ' --stride=' + str(stride)
+    com += ' --n_loops=' + str(n_loops)
+
+    if(mode == 'classif'):
+        com += ' --batch_size_eval=' + \
+            str(training_dictionnary[model]['batch_size_eval'])
+        com += ' --data_dir_stego_0='+data_dir_stego_0
+        com += ' --train_size='+str(train_size)
+        com += ' --valid_size='+str(valid_size)
+        com += ' --test_size='+str(test_size)
+        com += ' --model='+model
+
+    elif (mode == 'attack'):
+        com += ' --iteration_step='+str(iteration_step)
+        com += ' --attack=' + str(attack)
+        com += ' --attack_last=' + str(attack_last)
+        com += ' --emb_rate=' + str(emb_rate)
+        com += ' --cost_dir=' + str(cost_dir)
+        com += ' --idx_start=$SLURM_ARRAY_TASK_ID'
+        com += ' --batch_adv=' + str(batch_adv)
+        com += ' --n_iter_max_backpack=' + str(n_iter_max_backpack)
+        com += ' --N_samples=' + str(N_samples)
+        com += ' --tau_0=' + str(tau_0)
+        com += ' --precision=' + str(precision)
+        com += ' --lr=' + str(lr)
+        com += ' --model=' + model
+
+    elif (mode == 'train'):
+        com += ' --iteration_step='+str(iteration_step)
+        com += ' --train_size='+str(train_size)
+        com += ' --valid_size='+str(valid_size)
+        com += ' --test_size='+str(test_size)
+        com += ' --model='+model
+        com += ' --num_of_threads='+str(num_of_threads)
+        com += ' --emb_rate=' + str(emb_rate)
+        com += ' --cost_dir=' + str(cost_dir)
+        com += ' --data_dir_stego_0=' + str(data_dir_stego_0)
+        com += ' --spatial=' + str(spatial)
+
+        com += ' --batch_size_classif=' + \
+            str(training_dictionnary[model]['batch_size_classif'])
+        com += ' --batch_size_eval=' + \
+            str(training_dictionnary[model]['batch_size_eval'])
+        com += ' --epoch_num='+str(training_dictionnary[model]['epoch_num'])
+        com += ' --CL=' + str(training_dictionnary[model]['CL'])
+        com += ' --start_emb_rate=' + \
+            str(training_dictionnary[model]['start_emb_rate'])
+        com += ' --pair_training=' + \
+            str(training_dictionnary[model]['pair_training'])
+        com += ' --train_on_cost_map=' + \
+            str(training_dictionnary[model]['train_on_cost_map'])
 
     return(com)
 
 
+def run_job(mode, label, command, iteration, num_of_threads=None,
+            num_batch=None, gpu=True):
 
-def run_job(mode, label, command, iteration, \
-        num_batch=None, num_of_threads = None, gpu=True):
-
-    name = label + '_' + str(iteration) +'_'+ mode 
+    name = label + '_' + str(iteration) + '_' + mode
     job_file = "./%s.job" % name
 
-    with open(job_file, 'w+') as fh: 
+    with open(job_file, 'w+') as fh:
         fh.writelines("#!/bin/bash\n")
         fh.writelines("#SBATCH --nodes=1\n")
         fh.writelines("#SBATCH --job-name=%s\n" % name)
 
         if(gpu):
-            fh.writelines("#SBATCH --account=%s\n" % 'srp@gpu')
             fh.writelines("#SBATCH --gres=gpu:1\n")
             fh.writelines("#SBATCH --ntasks=1\n")
             fh.writelines("#SBATCH --hint=nomultithread\n")
-            fh.writelines("#SBATCH --time=10:00:00 \n")
-            
+            fh.writelines("#SBATCH --time=15:00:00 \n")
+
         else:
-            fh.writelines("#SBATCH --account=%s\n" % 'kun@cpu')
-            fh.writelines("#SBATCH --time=2:00:00 \n")            
+            fh.writelines("#SBATCH --time=2:00:00 \n")
 
-        if(mode=='attack'):
+        if(mode == 'attack'):
             fh.writelines("#SBATCH -C v100-32g\n")
             fh.writelines("#SBATCH --array="+str(0)+'-'+str(num_batch)+" \n")
             fh.writelines("module purge\n")
-            fh.writelines("module load python/3.7.5\n")
-            fh.writelines("module load pytorch-gpu/py3/1.7.1\n")
-            
+            # WRITE LOAD MODULES
             fh.writelines("python -u " + command)
 
         elif('train' in mode):
             fh.writelines("#SBATCH -C v100-32g\n")
             fh.writelines("#SBATCH --cpus-per-task="+str(num_of_threads)+"\n")
+            fh.writelines("module purge\n")
             fh.writelines("module load pytorch-gpu/py3/1.7.1\n")
             fh.writelines("time python -u " + command)
-        
+
         else:
+            fh.writelines("module purge\n")
             fh.writelines("module load pytorch-gpu/py3/1.7.1\n")
             fh.writelines("time python -u " + command)
 
-    os.system("sbatch %s" %job_file)
+    os.system("sbatch %s" % job_file)