zihaoz96 commited on
Commit
f3b2340
·
1 Parent(s): 739d458
Files changed (1) hide show
  1. app.py +20 -6
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import gradio as gr
2
  import numpy as np
 
3
  import os
 
4
 
5
  from hugsvision.inference.TorchVisionClassifierInference import TorchVisionClassifierInference
6
 
@@ -15,20 +17,31 @@ colname = "mobilenet_v2"
15
  radio = gr.inputs.Radio(models_name, default="mobilenet_v2", type="value", label=colname)
16
  print(radio.label)
17
 
18
- def predict_image(image):
19
- image = np.array(image) / 255
20
- image = np.expand_dims(image, axis=0)
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  classifier = TorchVisionClassifierInference(
23
- model_path = "./models/" + colname + "/best_model.pth",
24
  )
25
 
26
- pred = classifier.predict(img=image)
27
 
28
  label2id = json.load(open("./models/" + colname + "/best_model.pth"))["label2id"].keys()
29
  # vec = [100.0 if a.lower() == pred.lower() else 0.00 for a in label2id]
30
  acc = dict((label2id[i], "%.2f" % 100.0 if label2id[i].lower() == pred.lower() else 0.0) for i in range(len(label2id)))
31
-
32
  return acc
33
  # return pred
34
 
@@ -37,6 +50,7 @@ categories = open("categories.txt", "r")
37
  labels = categories.readline().split(";")
38
 
39
  image = gr.inputs.Image(shape=(300, 300), label="Upload Your Image Here")
 
40
  label = gr.outputs.Label(num_top_classes=len(labels))
41
 
42
  samples = ['./samples/basking.jpg', './samples/blacktip.jpg']
 
1
  import gradio as gr
2
  import numpy as np
3
+ from PIL import Image
4
  import os
5
+ import json
6
 
7
  from hugsvision.inference.TorchVisionClassifierInference import TorchVisionClassifierInference
8
 
 
17
  radio = gr.inputs.Radio(models_name, default="mobilenet_v2", type="value", label=colname)
18
  print(radio.label)
19
 
20
+ def predict_image(image, model_name):
21
+
22
+ image = Image.fromarray(np.uint8(image)).convert('RGB')
23
+
24
+ print("======================")
25
+ print(type(image))
26
+ print(type(model_name))
27
+ print("==========")
28
+ print(image)
29
+ print(model_name)
30
+ print("======================")
31
+
32
+ # image = np.array(image) / 255
33
+ # image = np.expand_dims(image, axis=0)
34
 
35
  classifier = TorchVisionClassifierInference(
36
+ model_path = "./models/" + colname,
37
  )
38
 
39
+ pred = classifier.predict_image(img=image)
40
 
41
  label2id = json.load(open("./models/" + colname + "/best_model.pth"))["label2id"].keys()
42
  # vec = [100.0 if a.lower() == pred.lower() else 0.00 for a in label2id]
43
  acc = dict((label2id[i], "%.2f" % 100.0 if label2id[i].lower() == pred.lower() else 0.0) for i in range(len(label2id)))
44
+ print(acc)
45
  return acc
46
  # return pred
47
 
 
50
  labels = categories.readline().split(";")
51
 
52
  image = gr.inputs.Image(shape=(300, 300), label="Upload Your Image Here")
53
+ print(image)
54
  label = gr.outputs.Label(num_top_classes=len(labels))
55
 
56
  samples = ['./samples/basking.jpg', './samples/blacktip.jpg']