Spaces:
Runtime error
Runtime error
import hashlib
import json
import math
import random
import secrets
import string
import warnings

import matplotlib.pyplot as plt
# ---------------------------------------------------------------------------
# User-facing UI text for the MNIST adversarial-data-collection app.
# Several strings contain ``{...}`` placeholders; they are presumably filled in
# with ``str.format`` by the caller (not visible here) -- keep the placeholder
# names unchanged: ``num_samples``, ``TEST_PER_SAMPLE``, ``num_adv_samples``.
# ---------------------------------------------------------------------------
TITLE = "# MNIST Adversarial: Try to fool this MNIST model"
description = """This project is about dynamic adversarial data collection (DADC).
The basic idea is to collect “adversarial data” - the kind of data that is difficult for a model to predict correctly.
This kind of data is presumably the most valuable for a model, so this can be helpful in low-resource settings where data is hard to collect and label.
"""
# Placeholder: {num_samples}
WHAT_TO_DO = """
### What to do:
1. Draw any number from 0-9. The model will automatically try to predict it after drawing.
2. If the model misclassifies it, Flag that example.
3. This will add your (adversarial) example to a dataset on which the model will be trained later.
4. The model will finetune on the adversarial samples after every __{num_samples}__ samples have been generated.
"""
MODEL_IS_WRONG = """
---
### Did the model get it wrong or has a low confidence? Choose the correct prediction below and flag it. When you flag it, the instance is saved [here](https://huggingface.co/datasets/chrisjay/mnist-adversarial-dataset) and the model learns from it periodically.
"""
# Shown before any real evaluation result is available.
DEFAULT_TEST_METRIC = "<html> Current test metric - Avg. loss: 1000, Accuracy: 30/1000 (30%) </html>"
# Placeholder: {TEST_PER_SAMPLE}
DASHBOARD_EXPLANATION = "To test the effect of adversarial training on out-of-distribution data, we track the performance progress of the model on the [MNIST Corrupted test dataset](https://zenodo.org/record/3239543). We are using {TEST_PER_SAMPLE} samples per digit."
DASHBOARD_EXPLANATION_TEST = "Test accuracy on out-of-distribution data for all numbers combined."
# Placeholder: {num_adv_samples}
STATS_EXPLANATION = "Here is the distribution of the __{num_adv_samples}__ adversarial samples we've got. The dataset can be found [here](https://huggingface.co/datasets/chrisjay/mnist-adversarial-dataset)."
def get_unique_name(length=32):
    """Return a random alphanumeric identifier.

    Characters are drawn with ``secrets`` (the OS CSPRNG) instead of
    ``random``. The result therefore does not depend on any
    ``random.seed(...)`` call made elsewhere in the process. Reseeding
    would otherwise make successive runs produce the same "unique" names.

    Args:
        length: Number of characters to generate. The default of 32
            matches the previous fixed length.

    Returns:
        A string of ``length`` characters drawn from ``[A-Za-z0-9]``.
    """
    alphabet = string.ascii_letters + string.digits  # hoisted out of the loop
    return "".join(secrets.choice(alphabet) for _ in range(length))
def read_json(file):
    """Return the parsed contents of the UTF-8 JSON file at *file*.

    The whole document is decoded with ``json.load``. Any error raised
    while opening or decoding the file propagates to the caller unchanged.
    """
    with open(file, mode='r', encoding="utf8") as source:
        parsed = json.load(source)
    return parsed
def read_json_lines(file):
    """Parse a JSON Lines file into a list of decoded records.

    Each non-blank line is decoded with ``json.loads``. Whitespace-only
    lines (for example a doubled trailing newline) are skipped instead of
    aborting the whole read.

    Args:
        file: Path of the UTF-8 encoded JSON Lines file.

    Returns:
        A list with one decoded value per non-blank line. Returns ``None``
        if the file cannot be opened or decoded, or if any line is not
        valid JSON. In that case a warning carrying the error text is
        emitted.
    """
    try:
        with open(file, 'r', encoding="utf8") as f:
            return [json.loads(line) for line in f if line.strip()]
    except (OSError, ValueError) as err:
        # Best-effort reader: report via a warning and return None rather than
        # raising. OSError covers missing/unreadable files; ValueError covers
        # UnicodeDecodeError and json.JSONDecodeError.
        warnings.warn(f"{err}", stacklevel=2)
        return None
def json_dump(thing):
    """Serialise *thing* to a compact, key-sorted JSON string.

    Keys are sorted and all optional whitespace is dropped, so equal
    objects always produce identical text. Non-ASCII characters are
    emitted as-is rather than escaped.
    """
    options = {
        "ensure_ascii": False,
        "sort_keys": True,
        "indent": None,
        "separators": (',', ':'),
    }
    return json.dumps(thing, **options)
def get_hash(thing):  # stable-hashing
    """Return a deterministic MD5 hex digest of *thing*'s canonical JSON.

    ``json_dump`` produces key-sorted, whitespace-free text. Equal
    JSON-serialisable objects therefore hash identically across runs and
    processes, unlike the built-in ``hash`` for strings. MD5 is not
    collision-resistant, so do not rely on this for security.
    """
    canonical = json_dump(thing).encode('utf-8')
    digest = hashlib.md5(canonical)
    return str(digest.hexdigest())
def dump_json(thing, file):
    """Write *thing* to *file* as JSON in UTF-8, replacing existing content.

    Uses ``json.dump`` defaults: non-ASCII characters are escaped, and the
    default separators are used.
    """
    with open(file, 'w+', encoding="utf8") as sink:
        json.dump(thing, sink)
def plot_bar(value, name, x_name, y_name, title, set_yticks=False, set_xticks=False):
    """Draw a horizontal bar chart and return its matplotlib Figure.

    Bars are placed at the positions in ``name`` on the y axis. Each bar
    extends along the x axis by the matching entry of ``value``
    (``ax.barh(name, value)``).

    Args:
        value: Bar lengths, drawn along the x axis.
        name: Bar positions, drawn along the y axis. Must be integers when
            ``set_yticks`` is true.
        x_name: Label for the x axis.
        y_name: Label for the y axis.
        title: Chart title.
        set_yticks: If true, put a tick at every integer from ``min(name)``
            to ``max(name)``.
        set_xticks: If true, put a tick at every integer spanning the range
            of ``value``, the data actually drawn on the x axis.

    Returns:
        The matplotlib Figure containing the chart.
    """
    fig, ax = plt.subplots(tight_layout=True)
    ax.set(xlabel=x_name, ylabel=y_name, title=title)
    # min()/max() raise on empty sequences; an empty chart keeps default ticks.
    if set_yticks and len(name) > 0:
        ax.set_yticks(range(min(name), max(name) + 1, 1))
    if set_xticks and len(value) > 0:
        # barh draws ``value`` along x, so x ticks must span the values, not
        # the bar positions in ``name``. floor/ceil keep range() valid when
        # the values are floats.
        ax.set_xticks(range(math.floor(min(value)), math.ceil(max(value)) + 1, 1))
    ax.barh(name, value)
    # NOTE(review): the figure is created through pyplot and never closed here;
    # in a long-running app the caller should plt.close() it after rendering.
    return fig