Spaces:
Runtime error
Runtime error
work on training and dashboard statistics 2
Browse files- .gitignore +2 -1
- .gitmodules +3 -0
- app.py +11 -13
- data_mnist +1 -1
- metrics.json +1 -0
- model.pth +1 -1
- optimizer.pth +1 -1
- utils.py +2 -2
.gitignore
CHANGED
@@ -1,3 +1,4 @@
|
|
1 |
__pycache__/*
|
2 |
data_local/*
|
3 |
-
flagged/*
|
|
|
|
1 |
__pycache__/*
|
2 |
data_local/*
|
3 |
+
flagged/*
|
4 |
+
data_mnist/*
|
.gitmodules
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
[submodule "data_mnist"]
|
2 |
+
path = data_mnist
|
3 |
+
url = https://huggingface.co/datasets/chrisjay/mnist-adversarial-dataset
|
app.py
CHANGED
@@ -22,7 +22,7 @@ learning_rate = 0.01
|
|
22 |
momentum = 0.5
|
23 |
log_interval = 10
|
24 |
random_seed = 1
|
25 |
-
TRAIN_CUTOFF =
|
26 |
WHAT_TO_DO=WHAT_TO_DO.format(num_samples=TRAIN_CUTOFF)
|
27 |
METRIC_PATH = './metrics.json'
|
28 |
REPOSITORY_DIR = "data"
|
@@ -216,8 +216,8 @@ def test():
|
|
216 |
test_losses.append(test_loss)
|
217 |
acc = 100. * correct / len(test_loader.dataset)
|
218 |
acc = acc.item()
|
219 |
-
test_metric = '〽Current test metric
|
220 |
-
test_loss,
|
221 |
return test_metric,acc
|
222 |
|
223 |
|
@@ -349,7 +349,7 @@ def flag(input_image,correct_result,adversarial_number):
|
|
349 |
test_metric_ = train_and_test()
|
350 |
test_metric = f"<html> {test_metric_} </html>"
|
351 |
output = f'<div> ✔ ({adversarial_number}) Successfully saved your adversarial data and trained the model on adversarial data! </div>'
|
352 |
-
return output,
|
353 |
|
354 |
def get_number_dict(DATA_DIR):
|
355 |
files = [f.name for f in os.scandir(DATA_DIR)]
|
@@ -376,10 +376,11 @@ def get_statistics():
|
|
376 |
DATA_DIR = './data_mnist/data'
|
377 |
numbers_count_keys,numbers_count_values = get_number_dict(DATA_DIR)
|
378 |
|
|
|
379 |
|
380 |
plt_digits = plot_bar(numbers_count_values,numbers_count_keys,'Number of adversarial samples',"Digit",f"Distribution of adversarial samples over digits")
|
381 |
|
382 |
-
fig_d, ax_d = plt.subplots(
|
383 |
|
384 |
if os.path.exists(METRIC_PATH):
|
385 |
metric_dict = read_json(METRIC_PATH)
|
@@ -401,7 +402,7 @@ def get_statistics():
|
|
401 |
</div>
|
402 |
"""
|
403 |
|
404 |
-
return plt_digits,fig_d,done_html
|
405 |
|
406 |
|
407 |
|
@@ -417,7 +418,7 @@ def main():
|
|
417 |
with gr.Tabs():
|
418 |
with gr.TabItem('MNIST'):
|
419 |
gr.Markdown(WHAT_TO_DO)
|
420 |
-
test_metric = gr.outputs.HTML(
|
421 |
with gr.Row():
|
422 |
|
423 |
|
@@ -435,23 +436,20 @@ def main():
|
|
435 |
|
436 |
|
437 |
submit.click(image_classifier,inputs = [image_input],outputs=[label_output])
|
438 |
-
flag_btn.click(flag,inputs=[image_input,number_dropdown,adversarial_number],outputs=[output_result,
|
439 |
|
440 |
with gr.TabItem('Dashboard') as dashboard:
|
441 |
notification = gr.HTML("""<div style="color: green">
|
442 |
<p> ⌛ Creating statistics... </p>
|
443 |
</div>
|
444 |
""")
|
445 |
-
_,numbers_count_values_ = get_number_dict('./data_mnist/data')
|
446 |
|
447 |
-
|
448 |
-
|
449 |
-
gr.Markdown(STATS_EXPLANATION_)
|
450 |
stat_adv_image =gr.Plot(type="matplotlib")
|
451 |
gr.Markdown(DASHBOARD_EXPLANATION)
|
452 |
test_results=gr.Plot(type="matplotlib")
|
453 |
|
454 |
-
dashboard.select(get_statistics,inputs=[],outputs=[stat_adv_image,test_results,notification])
|
455 |
|
456 |
|
457 |
|
|
|
22 |
momentum = 0.5
|
23 |
log_interval = 10
|
24 |
random_seed = 1
|
25 |
+
TRAIN_CUTOFF = 10
|
26 |
WHAT_TO_DO=WHAT_TO_DO.format(num_samples=TRAIN_CUTOFF)
|
27 |
METRIC_PATH = './metrics.json'
|
28 |
REPOSITORY_DIR = "data"
|
|
|
216 |
test_losses.append(test_loss)
|
217 |
acc = 100. * correct / len(test_loader.dataset)
|
218 |
acc = acc.item()
|
219 |
+
test_metric = '〽Current test metric -> Avg. loss: `{:.4f}`, Accuracy: `{:.0f}%`\n'.format(
|
220 |
+
test_loss,acc)
|
221 |
return test_metric,acc
|
222 |
|
223 |
|
|
|
349 |
test_metric_ = train_and_test()
|
350 |
test_metric = f"<html> {test_metric_} </html>"
|
351 |
output = f'<div> ✔ ({adversarial_number}) Successfully saved your adversarial data and trained the model on adversarial data! </div>'
|
352 |
+
return output,adversarial_number
|
353 |
|
354 |
def get_number_dict(DATA_DIR):
|
355 |
files = [f.name for f in os.scandir(DATA_DIR)]
|
|
|
376 |
DATA_DIR = './data_mnist/data'
|
377 |
numbers_count_keys,numbers_count_values = get_number_dict(DATA_DIR)
|
378 |
|
379 |
+
STATS_EXPLANATION_ = STATS_EXPLANATION.format(num_adv_samples = sum(numbers_count_values))
|
380 |
|
381 |
plt_digits = plot_bar(numbers_count_values,numbers_count_keys,'Number of adversarial samples',"Digit",f"Distribution of adversarial samples over digits")
|
382 |
|
383 |
+
fig_d, ax_d = plt.subplots(tight_layout=True)
|
384 |
|
385 |
if os.path.exists(METRIC_PATH):
|
386 |
metric_dict = read_json(METRIC_PATH)
|
|
|
402 |
</div>
|
403 |
"""
|
404 |
|
405 |
+
return plt_digits,fig_d,done_html,STATS_EXPLANATION_
|
406 |
|
407 |
|
408 |
|
|
|
418 |
with gr.Tabs():
|
419 |
with gr.TabItem('MNIST'):
|
420 |
gr.Markdown(WHAT_TO_DO)
|
421 |
+
#test_metric = gr.outputs.HTML("")
|
422 |
with gr.Row():
|
423 |
|
424 |
|
|
|
436 |
|
437 |
|
438 |
submit.click(image_classifier,inputs = [image_input],outputs=[label_output])
|
439 |
+
flag_btn.click(flag,inputs=[image_input,number_dropdown,adversarial_number],outputs=[output_result,adversarial_number])
|
440 |
|
441 |
with gr.TabItem('Dashboard') as dashboard:
|
442 |
notification = gr.HTML("""<div style="color: green">
|
443 |
<p> ⌛ Creating statistics... </p>
|
444 |
</div>
|
445 |
""")
|
|
|
446 |
|
447 |
+
stats = gr.Markdown()
|
|
|
|
|
448 |
stat_adv_image =gr.Plot(type="matplotlib")
|
449 |
gr.Markdown(DASHBOARD_EXPLANATION)
|
450 |
test_results=gr.Plot(type="matplotlib")
|
451 |
|
452 |
+
dashboard.select(get_statistics,inputs=[],outputs=[stat_adv_image,test_results,notification,stats])
|
453 |
|
454 |
|
455 |
|
data_mnist
CHANGED
@@ -1 +1 @@
|
|
1 |
-
Subproject commit
|
|
|
1 |
+
Subproject commit c6d1292ac6318c7c44131ca2fb18d37535ae1383
|
metrics.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"all": [10.55875015258789], "0": [0.0], "1": [0.0], "2": [0.0], "3": [43.33333206176758], "4": [86.66666412353516], "5": [0.0], "6": [0.0], "7": [0.0], "8": [0.0], "9": [0.0]}
|
model.pth
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 89871
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0615455222f0654123d29490ed6fa00db335abb7bc856a817ed8069c03cfaf42
|
3 |
size 89871
|
optimizer.pth
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 89807
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9cfa224990c352a3ad53a41d439d5dd790358bb1e0acb9d3d63379f5c9d0ba7e
|
3 |
size 89807
|
utils.py
CHANGED
@@ -22,7 +22,7 @@ WHAT_TO_DO="""
|
|
22 |
MODEL_IS_WRONG = """
|
23 |
---
|
24 |
|
25 |
-
|
26 |
"""
|
27 |
DEFAULT_TEST_METRIC = "<html> Current test metric - Avg. loss: 1000, Accuracy: 30/1000 (30%) </html>"
|
28 |
|
@@ -65,7 +65,7 @@ def dump_json(thing,file):
|
|
65 |
|
66 |
|
67 |
def plot_bar(value,name,x_name,y_name,title):
|
68 |
-
fig, ax = plt.subplots(
|
69 |
|
70 |
ax.set(xlabel=x_name, ylabel=y_name,title=title)
|
71 |
|
|
|
22 |
MODEL_IS_WRONG = """
|
23 |
---
|
24 |
|
25 |
+
### Did the model get it wrong? Choose the correct prediction below and flag it. When you flag it, the instance is saved to our dataset and the model is trained on it.
|
26 |
"""
|
27 |
DEFAULT_TEST_METRIC = "<html> Current test metric - Avg. loss: 1000, Accuracy: 30/1000 (30%) </html>"
|
28 |
|
|
|
65 |
|
66 |
|
67 |
def plot_bar(value,name,x_name,y_name,title):
|
68 |
+
fig, ax = plt.subplots(tight_layout=True)
|
69 |
|
70 |
ax.set(xlabel=x_name, ylabel=y_name,title=title)
|
71 |
|