Spaces:
Runtime error
Runtime error
enabled reloading of trained weights + using best weights
Browse files- app.py +50 -13
- best_weights/mnist_model.pth +3 -0
- best_weights/optimizer.pth +3 -0
- data_mnist +1 -1
- utils.py +1 -1
app.py
CHANGED
@@ -24,6 +24,8 @@ momentum = 0.5
|
|
24 |
log_interval = 10
|
25 |
random_seed = 1
|
26 |
TRAIN_CUTOFF = 10
|
|
|
|
|
27 |
WHAT_TO_DO=WHAT_TO_DO.format(num_samples=TRAIN_CUTOFF)
|
28 |
MODEL_PATH = 'model'
|
29 |
METRIC_PATH = os.path.join(MODEL_PATH,'metrics.json')
|
@@ -86,7 +88,7 @@ class MNISTAdversarial_Dataset(Dataset):
|
|
86 |
return img, label
|
87 |
|
88 |
class MNISTCorrupted_By_Digit(Dataset):
|
89 |
-
def __init__(self,transform,digit,limit=
|
90 |
self.transform = transform
|
91 |
self.digit = digit
|
92 |
corrupted_dir="./mnist_c"
|
@@ -127,8 +129,8 @@ class MNISTCorrupted(Dataset):
|
|
127 |
self.transform = transform
|
128 |
corrupted_dir="./mnist_c"
|
129 |
files = [f.name for f in os.scandir(corrupted_dir)]
|
130 |
-
images = [np.load(os.path.join(os.path.join(corrupted_dir,f),'test_images.npy'))[:
|
131 |
-
labels = [np.load(os.path.join(os.path.join(corrupted_dir,f),'test_labels.npy'))[:
|
132 |
self.data = np.vstack(images)
|
133 |
self.labels = np.hstack(labels)
|
134 |
|
@@ -283,24 +285,40 @@ if os.path.exists(model_state_dict) and os.path.exists(optimizer_state_dict):
|
|
283 |
optimizer.load_state_dict(optimizer_state_dict)
|
284 |
|
285 |
else:
|
286 |
-
#
|
287 |
-
|
288 |
-
|
|
|
|
|
289 |
_ = train_and_test(False)
|
290 |
|
291 |
|
292 |
-
# Train
|
293 |
-
#train(n_epochs,network,optimizer)
|
294 |
-
|
295 |
-
|
296 |
def image_classifier(inp):
|
297 |
"""
|
298 |
-
It
|
299 |
-
|
300 |
|
301 |
:param inp: the image to be classified
|
302 |
-
:return: A dictionary of the
|
303 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
304 |
input_image = torchvision.transforms.ToTensor()(inp).unsqueeze(0)
|
305 |
with torch.no_grad():
|
306 |
|
@@ -314,6 +332,19 @@ def image_classifier(inp):
|
|
314 |
|
315 |
|
316 |
def flag(input_image,correct_result,adversarial_number):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
317 |
|
318 |
adversarial_number = 0 if None else adversarial_number
|
319 |
|
@@ -375,6 +406,12 @@ def get_number_dict(DATA_DIR):
|
|
375 |
|
376 |
|
377 |
def get_statistics():
|
|
|
|
|
|
|
|
|
|
|
|
|
378 |
model_repo.git_pull()
|
379 |
model_state_dict = MODEL_WEIGHTS_PATH
|
380 |
optimizer_state_dict = OPTIMIZER_PATH
|
|
|
24 |
log_interval = 10
|
25 |
random_seed = 1
|
26 |
TRAIN_CUTOFF = 10
|
27 |
+
TEST_PER_SAMPLE = 1500
|
28 |
+
DASHBOARD_EXPLANATION = DASHBOARD_EXPLANATION.format(TEST_PER_SAMPLE=TEST_PER_SAMPLE)
|
29 |
WHAT_TO_DO=WHAT_TO_DO.format(num_samples=TRAIN_CUTOFF)
|
30 |
MODEL_PATH = 'model'
|
31 |
METRIC_PATH = os.path.join(MODEL_PATH,'metrics.json')
|
|
|
88 |
return img, label
|
89 |
|
90 |
class MNISTCorrupted_By_Digit(Dataset):
|
91 |
+
def __init__(self,transform,digit,limit=TEST_PER_SAMPLE):
|
92 |
self.transform = transform
|
93 |
self.digit = digit
|
94 |
corrupted_dir="./mnist_c"
|
|
|
129 |
self.transform = transform
|
130 |
corrupted_dir="./mnist_c"
|
131 |
files = [f.name for f in os.scandir(corrupted_dir)]
|
132 |
+
images = [np.load(os.path.join(os.path.join(corrupted_dir,f),'test_images.npy'))[:TEST_PER_SAMPLE] for f in files]
|
133 |
+
labels = [np.load(os.path.join(os.path.join(corrupted_dir,f),'test_labels.npy'))[:TEST_PER_SAMPLE] for f in files]
|
134 |
self.data = np.vstack(images)
|
135 |
self.labels = np.hstack(labels)
|
136 |
|
|
|
285 |
optimizer.load_state_dict(optimizer_state_dict)
|
286 |
|
287 |
else:
|
288 |
+
# Use best weights
|
289 |
+
BEST_WEIGHTS_MODEL = "best_weights/mnist_model.pth"
|
290 |
+
BEST_WEIGHTS_OPTIMIZER = "best_weights/optimizer.pth"
|
291 |
+
torch.save(network.state_dict(), BEST_WEIGHTS_MODEL)
|
292 |
+
torch.save(optimizer.state_dict(), BEST_WEIGHTS_OPTIMIZER)
|
293 |
_ = train_and_test(False)
|
294 |
|
295 |
|
|
|
|
|
|
|
|
|
296 |
def image_classifier(inp):
|
297 |
"""
|
298 |
+
It loads the latest model weights from the model repository, and then uses those weights to make a
|
299 |
+
prediction on the input image.
|
300 |
|
301 |
:param inp: the image to be classified
|
302 |
+
:return: A dictionary of the form {class_number: confidence}
|
303 |
"""
|
304 |
+
|
305 |
+
# Get latest model weights ----------------
|
306 |
+
model_repo.git_pull()
|
307 |
+
model_state_dict = MODEL_WEIGHTS_PATH
|
308 |
+
optimizer_state_dict = OPTIMIZER_PATH
|
309 |
+
|
310 |
+
if os.path.exists(model_state_dict) and os.path.exists(optimizer_state_dict):
|
311 |
+
network_state_dict = torch.load(model_state_dict)
|
312 |
+
network.load_state_dict(network_state_dict)
|
313 |
+
optimizer_state_dict = torch.load(optimizer_state_dict)
|
314 |
+
optimizer.load_state_dict(optimizer_state_dict)
|
315 |
+
else:
|
316 |
+
# Use best weights
|
317 |
+
BEST_WEIGHTS_MODEL = "best_weights/mnist_model.pth"
|
318 |
+
BEST_WEIGHTS_OPTIMIZER = "best_weights/optimizer.pth"
|
319 |
+
network.load_state_dict(torch.load(BEST_WEIGHTS_MODEL))
|
320 |
+
optimizer.load_state_dict(torch.load(BEST_WEIGHTS_OPTIMIZER))
|
321 |
+
|
322 |
input_image = torchvision.transforms.ToTensor()(inp).unsqueeze(0)
|
323 |
with torch.no_grad():
|
324 |
|
|
|
332 |
|
333 |
|
334 |
def flag(input_image,correct_result,adversarial_number):
|
335 |
+
"""
|
336 |
+
It takes in an image, the correct result, and the number of adversarial images that have been
|
337 |
+
uploaded so far. It saves the image and metadata to a local directory, uploads the image and
|
338 |
+
metadata to the hub, and then pulls the data from the hub to the local directory. If the number of
|
339 |
+
images in the local directory is divisible by the TRAIN_CUTOFF, then it trains the model on the
|
340 |
+
adversarial data
|
341 |
+
|
342 |
+
:param input_image: The adversarial image that you want to save
|
343 |
+
:param correct_result: The correct number that the image represents
|
344 |
+
:param adversarial_number: This is the number of adversarial examples that have been uploaded to the
|
345 |
+
dataset
|
346 |
+
:return: The output is the output of the flag function.
|
347 |
+
"""
|
348 |
|
349 |
adversarial_number = 0 if None else adversarial_number
|
350 |
|
|
|
406 |
|
407 |
|
408 |
def get_statistics():
|
409 |
+
"""
|
410 |
+
It loads the model and optimizer state dicts, pulls the latest data from the repo, gets the number
|
411 |
+
of adversarial samples per digit, plots the distribution of adversarial samples per digit, plots the
|
412 |
+
test accuracy per digit per train step, and plots the test accuracy for all digits per train step
|
413 |
+
:return: the following:
|
414 |
+
"""
|
415 |
model_repo.git_pull()
|
416 |
model_state_dict = MODEL_WEIGHTS_PATH
|
417 |
optimizer_state_dict = OPTIMIZER_PATH
|
best_weights/mnist_model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ba8d282674beb300db53069e4972cfed358f8c7c627cf449215e44b365fcdc54
|
3 |
+
size 89871
|
best_weights/optimizer.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fe255c0ca501d01ae3c2083ea760ea95759fcfe9075e39fba299c57a9907bf1b
|
3 |
+
size 623
|
data_mnist
CHANGED
@@ -1 +1 @@
|
|
1 |
-
Subproject commit
|
|
|
1 |
+
Subproject commit ed62a26e764902f519ff43df850842e07dfe2cc0
|
utils.py
CHANGED
@@ -24,7 +24,7 @@ MODEL_IS_WRONG = """
|
|
24 |
"""
|
25 |
DEFAULT_TEST_METRIC = "<html> Current test metric - Avg. loss: 1000, Accuracy: 30/1000 (30%) </html>"
|
26 |
|
27 |
-
DASHBOARD_EXPLANATION="To test the effect of adversarial training on out-of-distribution data, we track the performance progress of the model on the [MNIST Corrupted test dataset](https://zenodo.org/record/3239543)."
|
28 |
DASHBOARD_EXPLANATION_TEST="Test accuracy on out-of-distribution data for all numbers."
|
29 |
|
30 |
STATS_EXPLANATION = "Here is the distribution of the __{num_adv_samples}__ adversarial samples we've got. The dataset can be found [here](https://huggingface.co/datasets/chrisjay/mnist-adversarial-dataset)."
|
|
|
24 |
"""
|
25 |
DEFAULT_TEST_METRIC = "<html> Current test metric - Avg. loss: 1000, Accuracy: 30/1000 (30%) </html>"
|
26 |
|
27 |
+
DASHBOARD_EXPLANATION="To test the effect of adversarial training on out-of-distribution data, we track the performance progress of the model on the [MNIST Corrupted test dataset](https://zenodo.org/record/3239543). We are using {TEST_PER_SAMPLE} samples per digit."
|
28 |
DASHBOARD_EXPLANATION_TEST="Test accuracy on out-of-distribution data for all numbers."
|
29 |
|
30 |
STATS_EXPLANATION = "Here is the distribution of the __{num_adv_samples}__ adversarial samples we've got. The dataset can be found [here](https://huggingface.co/datasets/chrisjay/mnist-adversarial-dataset)."
|