Spaces:
Runtime error
Runtime error
fixed issue with test acc on line 249
Browse files
app.py
CHANGED
@@ -239,15 +239,14 @@ def train_and_test(train_model=True):
|
|
239 |
# Train for one epoch and test
|
240 |
train_dataset = MNISTAdversarial_Dataset('./data_mnist',TRAIN_TRANSFORM)
|
241 |
|
242 |
-
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size_test, shuffle=True
|
243 |
-
)
|
244 |
train(n_epochs,network,optimizer,train_loader)
|
245 |
|
246 |
test_metric,test_acc = test()
|
247 |
|
248 |
if os.path.exists(METRIC_PATH):
|
249 |
metric_dict = read_json(METRIC_PATH)
|
250 |
-
metric_dict['all'] = metric_dict['all'] if 'all' in metric_dict else [] + [test_acc]
|
251 |
else:
|
252 |
metric_dict={}
|
253 |
metric_dict['all'] = [test_acc]
|
|
|
239 |
# Train for one epoch and test
|
240 |
train_dataset = MNISTAdversarial_Dataset('./data_mnist',TRAIN_TRANSFORM)
|
241 |
|
242 |
+
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size_test, shuffle=True)
|
|
|
243 |
train(n_epochs,network,optimizer,train_loader)
|
244 |
|
245 |
test_metric,test_acc = test()
|
246 |
|
247 |
if os.path.exists(METRIC_PATH):
|
248 |
metric_dict = read_json(METRIC_PATH)
|
249 |
+
metric_dict['all'] = metric_dict['all']+ [test_acc] if 'all' in metric_dict else [] + [test_acc]
|
250 |
else:
|
251 |
metric_dict={}
|
252 |
metric_dict['all'] = [test_acc]
|