Spaces:
Runtime error
Runtime error
enabled reloading of trained weights + using best weights
Browse files
- app.py +50 -13
- best_weights/mnist_model.pth +3 -0
- best_weights/optimizer.pth +3 -0
- data_mnist +1 -1
- utils.py +1 -1
app.py
CHANGED
|
@@ -24,6 +24,8 @@ momentum = 0.5
|
|
| 24 |
log_interval = 10
|
| 25 |
random_seed = 1
|
| 26 |
TRAIN_CUTOFF = 10
|
|
|
|
|
|
|
| 27 |
WHAT_TO_DO=WHAT_TO_DO.format(num_samples=TRAIN_CUTOFF)
|
| 28 |
MODEL_PATH = 'model'
|
| 29 |
METRIC_PATH = os.path.join(MODEL_PATH,'metrics.json')
|
|
@@ -86,7 +88,7 @@ class MNISTAdversarial_Dataset(Dataset):
|
|
| 86 |
return img, label
|
| 87 |
|
| 88 |
class MNISTCorrupted_By_Digit(Dataset):
|
| 89 |
-
def __init__(self,transform,digit,limit=
|
| 90 |
self.transform = transform
|
| 91 |
self.digit = digit
|
| 92 |
corrupted_dir="./mnist_c"
|
|
@@ -127,8 +129,8 @@ class MNISTCorrupted(Dataset):
|
|
| 127 |
self.transform = transform
|
| 128 |
corrupted_dir="./mnist_c"
|
| 129 |
files = [f.name for f in os.scandir(corrupted_dir)]
|
| 130 |
-
images = [np.load(os.path.join(os.path.join(corrupted_dir,f),'test_images.npy'))[:
|
| 131 |
-
labels = [np.load(os.path.join(os.path.join(corrupted_dir,f),'test_labels.npy'))[:
|
| 132 |
self.data = np.vstack(images)
|
| 133 |
self.labels = np.hstack(labels)
|
| 134 |
|
|
@@ -283,24 +285,40 @@ if os.path.exists(model_state_dict) and os.path.exists(optimizer_state_dict):
|
|
| 283 |
optimizer.load_state_dict(optimizer_state_dict)
|
| 284 |
|
| 285 |
else:
|
| 286 |
-
#
|
| 287 |
-
|
| 288 |
-
|
|
|
|
|
|
|
| 289 |
_ = train_and_test(False)
|
| 290 |
|
| 291 |
|
| 292 |
-
# Train
|
| 293 |
-
#train(n_epochs,network,optimizer)
|
| 294 |
-
|
| 295 |
-
|
| 296 |
def image_classifier(inp):
|
| 297 |
"""
|
| 298 |
-
It
|
| 299 |
-
|
| 300 |
|
| 301 |
:param inp: the image to be classified
|
| 302 |
-
:return: A dictionary of the
|
| 303 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
input_image = torchvision.transforms.ToTensor()(inp).unsqueeze(0)
|
| 305 |
with torch.no_grad():
|
| 306 |
|
|
@@ -314,6 +332,19 @@ def image_classifier(inp):
|
|
| 314 |
|
| 315 |
|
| 316 |
def flag(input_image,correct_result,adversarial_number):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 317 |
|
| 318 |
adversarial_number = 0 if None else adversarial_number
|
| 319 |
|
|
@@ -375,6 +406,12 @@ def get_number_dict(DATA_DIR):
|
|
| 375 |
|
| 376 |
|
| 377 |
def get_statistics():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 378 |
model_repo.git_pull()
|
| 379 |
model_state_dict = MODEL_WEIGHTS_PATH
|
| 380 |
optimizer_state_dict = OPTIMIZER_PATH
|
|
|
|
| 24 |
log_interval = 10
|
| 25 |
random_seed = 1
|
| 26 |
TRAIN_CUTOFF = 10
|
| 27 |
+
TEST_PER_SAMPLE = 1500
|
| 28 |
+
DASHBOARD_EXPLANATION = DASHBOARD_EXPLANATION.format(TEST_PER_SAMPLE=TEST_PER_SAMPLE)
|
| 29 |
WHAT_TO_DO=WHAT_TO_DO.format(num_samples=TRAIN_CUTOFF)
|
| 30 |
MODEL_PATH = 'model'
|
| 31 |
METRIC_PATH = os.path.join(MODEL_PATH,'metrics.json')
|
|
|
|
| 88 |
return img, label
|
| 89 |
|
| 90 |
class MNISTCorrupted_By_Digit(Dataset):
|
| 91 |
+
def __init__(self,transform,digit,limit=TEST_PER_SAMPLE):
|
| 92 |
self.transform = transform
|
| 93 |
self.digit = digit
|
| 94 |
corrupted_dir="./mnist_c"
|
|
|
|
| 129 |
self.transform = transform
|
| 130 |
corrupted_dir="./mnist_c"
|
| 131 |
files = [f.name for f in os.scandir(corrupted_dir)]
|
| 132 |
+
images = [np.load(os.path.join(os.path.join(corrupted_dir,f),'test_images.npy'))[:TEST_PER_SAMPLE] for f in files]
|
| 133 |
+
labels = [np.load(os.path.join(os.path.join(corrupted_dir,f),'test_labels.npy'))[:TEST_PER_SAMPLE] for f in files]
|
| 134 |
self.data = np.vstack(images)
|
| 135 |
self.labels = np.hstack(labels)
|
| 136 |
|
|
|
|
| 285 |
optimizer.load_state_dict(optimizer_state_dict)
|
| 286 |
|
| 287 |
else:
|
| 288 |
+
# Use best weights
|
| 289 |
+
BEST_WEIGHTS_MODEL = "best_weights/mnist_model.pth"
|
| 290 |
+
BEST_WEIGHTS_OPTIMIZER = "best_weights/optimizer.pth"
|
| 291 |
+
torch.save(network.state_dict(), BEST_WEIGHTS_MODEL)
|
| 292 |
+
torch.save(optimizer.state_dict(), BEST_WEIGHTS_OPTIMIZER)
|
| 293 |
_ = train_and_test(False)
|
| 294 |
|
| 295 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 296 |
def image_classifier(inp):
|
| 297 |
"""
|
| 298 |
+
It loads the latest model weights from the model repository, and then uses those weights to make a
|
| 299 |
+
prediction on the input image.
|
| 300 |
|
| 301 |
:param inp: the image to be classified
|
| 302 |
+
:return: A dictionary of the form {class_number: confidence}
|
| 303 |
"""
|
| 304 |
+
|
| 305 |
+
# Get latest model weights ----------------
|
| 306 |
+
model_repo.git_pull()
|
| 307 |
+
model_state_dict = MODEL_WEIGHTS_PATH
|
| 308 |
+
optimizer_state_dict = OPTIMIZER_PATH
|
| 309 |
+
|
| 310 |
+
if os.path.exists(model_state_dict) and os.path.exists(optimizer_state_dict):
|
| 311 |
+
network_state_dict = torch.load(model_state_dict)
|
| 312 |
+
network.load_state_dict(network_state_dict)
|
| 313 |
+
optimizer_state_dict = torch.load(optimizer_state_dict)
|
| 314 |
+
optimizer.load_state_dict(optimizer_state_dict)
|
| 315 |
+
else:
|
| 316 |
+
# Use best weights
|
| 317 |
+
BEST_WEIGHTS_MODEL = "best_weights/mnist_model.pth"
|
| 318 |
+
BEST_WEIGHTS_OPTIMIZER = "best_weights/optimizer.pth"
|
| 319 |
+
network.load_state_dict(torch.load(BEST_WEIGHTS_MODEL))
|
| 320 |
+
optimizer.load_state_dict(torch.load(BEST_WEIGHTS_OPTIMIZER))
|
| 321 |
+
|
| 322 |
input_image = torchvision.transforms.ToTensor()(inp).unsqueeze(0)
|
| 323 |
with torch.no_grad():
|
| 324 |
|
|
|
|
| 332 |
|
| 333 |
|
| 334 |
def flag(input_image,correct_result,adversarial_number):
|
| 335 |
+
"""
|
| 336 |
+
It takes in an image, the correct result, and the number of adversarial images that have been
|
| 337 |
+
uploaded so far. It saves the image and metadata to a local directory, uploads the image and
|
| 338 |
+
metadata to the hub, and then pulls the data from the hub to the local directory. If the number of
|
| 339 |
+
images in the local directory is divisible by the TRAIN_CUTOFF, then it trains the model on the
|
| 340 |
+
adversarial data
|
| 341 |
+
|
| 342 |
+
:param input_image: The adversarial image that you want to save
|
| 343 |
+
:param correct_result: The correct number that the image represents
|
| 344 |
+
:param adversarial_number: This is the number of adversarial examples that have been uploaded to the
|
| 345 |
+
dataset
|
| 346 |
+
:return: The output is the output of the flag function.
|
| 347 |
+
"""
|
| 348 |
|
| 349 |
adversarial_number = 0 if None else adversarial_number
|
| 350 |
|
|
|
|
| 406 |
|
| 407 |
|
| 408 |
def get_statistics():
|
| 409 |
+
"""
|
| 410 |
+
It loads the model and optimizer state dicts, pulls the latest data from the repo, gets the number
|
| 411 |
+
of adversarial samples per digit, plots the distribution of adversarial samples per digit, plots the
|
| 412 |
+
test accuracy per digit per train step, and plots the test accuracy for all digits per train step
|
| 413 |
+
:return: the following:
|
| 414 |
+
"""
|
| 415 |
model_repo.git_pull()
|
| 416 |
model_state_dict = MODEL_WEIGHTS_PATH
|
| 417 |
optimizer_state_dict = OPTIMIZER_PATH
|
best_weights/mnist_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ba8d282674beb300db53069e4972cfed358f8c7c627cf449215e44b365fcdc54
|
| 3 |
+
size 89871
|
best_weights/optimizer.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fe255c0ca501d01ae3c2083ea760ea95759fcfe9075e39fba299c57a9907bf1b
|
| 3 |
+
size 623
|
data_mnist
CHANGED
|
@@ -1 +1 @@
|
|
| 1 |
-
Subproject commit
|
|
|
|
| 1 |
+
Subproject commit ed62a26e764902f519ff43df850842e07dfe2cc0
|
utils.py
CHANGED
|
@@ -24,7 +24,7 @@ MODEL_IS_WRONG = """
|
|
| 24 |
"""
|
| 25 |
DEFAULT_TEST_METRIC = "<html> Current test metric - Avg. loss: 1000, Accuracy: 30/1000 (30%) </html>"
|
| 26 |
|
| 27 |
-
DASHBOARD_EXPLANATION="To test the effect of adversarial training on out-of-distribution data, we track the performance progress of the model on the [MNIST Corrupted test dataset](https://zenodo.org/record/3239543)."
|
| 28 |
DASHBOARD_EXPLANATION_TEST="Test accuracy on out-of-distribution data for all numbers."
|
| 29 |
|
| 30 |
STATS_EXPLANATION = "Here is the distribution of the __{num_adv_samples}__ adversarial samples we've got. The dataset can be found [here](https://huggingface.co/datasets/chrisjay/mnist-adversarial-dataset)."
|
|
|
|
| 24 |
"""
|
| 25 |
DEFAULT_TEST_METRIC = "<html> Current test metric - Avg. loss: 1000, Accuracy: 30/1000 (30%) </html>"
|
| 26 |
|
| 27 |
+
DASHBOARD_EXPLANATION="To test the effect of adversarial training on out-of-distribution data, we track the performance progress of the model on the [MNIST Corrupted test dataset](https://zenodo.org/record/3239543). We are using {TEST_PER_SAMPLE} samples per digit."
|
| 28 |
DASHBOARD_EXPLANATION_TEST="Test accuracy on out-of-distribution data for all numbers."
|
| 29 |
|
| 30 |
STATS_EXPLANATION = "Here is the distribution of the __{num_adv_samples}__ adversarial samples we've got. The dataset can be found [here](https://huggingface.co/datasets/chrisjay/mnist-adversarial-dataset)."
|