SakuraD commited on
Commit
e3e6c1e
·
1 Parent(s): d44f456
Files changed (1) hide show
  1. app.py +16 -2
app.py CHANGED
@@ -48,7 +48,20 @@ def inference(img):
48
  prediction = model(image)
49
  prediction = F.softmax(prediction, dim=1).flatten()
50
 
51
- return {imagenet_id_to_classname[str(i)]: float(prediction[i]) for i in range(1000)}
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
 
54
  demo = gr.Blocks()
@@ -79,8 +92,9 @@ with demo:
79
  )
80
 
81
  submit_button.click(fn=inference, inputs=input_image, outputs=label)
 
82
 
83
- demo.launch(enable_queue=True)
84
 
85
 
86
 
 
48
  prediction = model(image)
49
  prediction = F.softmax(prediction, dim=1).flatten()
50
 
51
+ # return {imagenet_id_to_classname[str(i)]: float(prediction[i]) for i in range(1000)}
52
+
53
+ pred_classes = prediction.topk(k=5).indices
54
+ pred_class_names = [imagenet_id_to_classname[str(i.item())] for i in pred_classes[0]]
55
+ pred_class_probs = [prediction[0][i.item()].item() * 100 for i in pred_classes[0]]
56
+ res = "Top 5 predicted labels:\n"
57
+ for name, prob in zip(pred_class_names, pred_class_probs):
58
+ res += f"[{prob:2.2f}%]\t{name}\n"
59
+
60
+ return res
61
+
62
+
63
+ def set_example_image(example: list) -> dict:
64
+ return gr.Image.update(value=example[0])
65
 
66
 
67
  demo = gr.Blocks()
 
92
  )
93
 
94
  submit_button.click(fn=inference, inputs=input_image, outputs=label)
95
+ example_images.click(fn=set_example_image, inputs=example_images, outputs=example_images.components)
96
 
97
+ demo.launch(enable_queue=True, cache_examples=True)
98
 
99
 
100