etemkocaaslan commited on
Commit
2ee95b8
·
verified ·
1 Parent(s): f9d05e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -35
app.py CHANGED
@@ -1,21 +1,11 @@
1
  import torch
2
- from torch.nn import functional as F
3
  import torchvision.models as models
4
- from torchvision import transforms
 
 
5
  import requests
6
-
7
- preprocess = transforms.Compose([
8
- transforms.Resize(256),
9
- transforms.CenterCrop(224),
10
- transforms.ToTensor(),
11
- transforms.Normalize(
12
- mean=[0.485, 0.456 , 0.406],
13
- std=[0.229, 0.224, 0.225]
14
- )
15
- ])
16
-
17
- response = requests.get("https://git.io/JJkYN")
18
- labels = response.text.split("\n")
19
 
20
  image_prediction_models = {
21
  'resnet': models.resnet50,
@@ -50,26 +40,89 @@ def get_model_names(models_dict):
50
 
51
  model_list = get_model_names(image_prediction_models)
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  def classify_image(input_image, selected_model):
54
- input_tensor = preprocess(input_image)
55
- input_batch = input_tensor.unsqueeze(0)
56
- model = load_pretrained_model(selected_model)
57
- if torch.cuda.is_available():
58
- input_batch = input_batch.to('cuda')
59
-
60
- with torch.no_grad():
61
- output = model(input_batch)
62
-
63
- probabilities = F.softmax(input = output[0] , dim = 0)
64
- top_prob, top_catid = torch.topk(probabilities, 5)
65
- confidences = {labels[top_catid[i].item()]: top_prob[i].item() for i in range(top_prob.size(0))}
66
- return confidences
67
 
68
- import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
- interface = gr.Interface(
71
- fn=classify_image,
72
- inputs= [gr.Image(type='pil'),
73
- gr.Dropdown(model_list)],
74
- outputs=gr.Label(num_top_classes=5))
75
- interface.launch()
 
1
  import torch
 
2
  import torchvision.models as models
3
+ import torchvision.transforms as transforms
4
+ import torchvision.datasets as datasets
5
+ from torchvision.transforms import Compose
6
  import requests
7
+ import random
8
+ import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  image_prediction_models = {
11
  'resnet': models.resnet50,
 
40
 
41
  model_list = get_model_names(image_prediction_models)
42
 
43
+ normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
44
+
45
+ def preprocess(model_name):
46
+ input_size = 224
47
+ if model_name == 'inception':
48
+ input_size = 299
49
+ return transforms.Compose([
50
+ transforms.Resize(256),
51
+ transforms.CenterCrop(input_size),
52
+ transforms.ToTensor(),
53
+ normalize,
54
+ ])
55
+
56
+ response = requests.get("https://git.io/JJkYN")
57
+ labels = response.text.split("\n")
58
+
59
+ def postprocess_default(output):
60
+ probabilities = torch.nn.functional.softmax(output[0], dim=0)
61
+ top_prob, top_catid = torch.topk(probabilities, 5)
62
+ confidences = {labels[top_catid[i].item()]: top_prob[i].item() for i in range(top_prob.size(0))}
63
+ return confidences
64
+
65
+ def postprocess_inception(output):
66
+ probabilities = torch.nn.functional.softmax(output[1], dim=0)
67
+ top_prob, top_catid = torch.topk(probabilities, 5)
68
+ confidences = {labels[top_catid[i].item()]: top_prob[i].item() for i in range(top_prob.size(0))}
69
+ return confidences
70
+
71
  def classify_image(input_image, selected_model):
72
+ preprocess_input = preprocess(model_name=selected_model)
73
+ input_tensor = preprocess_input(input_image)
74
+ input_batch = input_tensor.unsqueeze(0)
75
+ model = load_pretrained_model(selected_model)
 
 
 
 
 
 
 
 
 
76
 
77
+ if torch.cuda.is_available():
78
+ input_batch = input_batch.to('cuda')
79
+ model.to('cuda')
80
+
81
+ model.eval()
82
+ with torch.no_grad():
83
+ output = model(input_batch)
84
+
85
+ if selected_model.lower() == 'inception':
86
+ return postprocess_inception(output)
87
+ else:
88
+ return postprocess_default(output)
89
+
90
+ def get_random_image():
91
+ cifar10 = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())
92
+ random_idx = random.randint(0, len(cifar10) - 1)
93
+ image, _ = cifar10[random_idx]
94
+ image = transforms.ToPILImage()(image)
95
+ return image
96
+
97
+ def generate_random_image():
98
+ image = get_random_image()
99
+ return image
100
+
101
+ def classify_generated_image(image, model):
102
+ return classify_image(image, model)
103
+
104
+ with gr.Blocks() as demo:
105
+ with gr.Tabs():
106
+ with gr.TabItem("Upload Image"):
107
+ with gr.Row():
108
+ with gr.Column():
109
+ upload_image = gr.Image(type='pil', label="Upload Image")
110
+ model_dropdown_upload = gr.Dropdown(model_list, label="Select Model")
111
+ classify_button_upload = gr.Button("Classify")
112
+ with gr.Column():
113
+ output_label_upload = gr.Label(num_top_classes=5)
114
+ classify_button_upload.click(classify_image, inputs=[upload_image, model_dropdown_upload], outputs=output_label_upload)
115
+
116
+ with gr.TabItem("Generate Random Image"):
117
+ with gr.Row():
118
+ with gr.Column():
119
+ generate_button = gr.Button("Generate Random Image")
120
+ random_image_output = gr.Image(type='pil', label="Random CIFAR-10 Image")
121
+ with gr.Column():
122
+ model_dropdown_random = gr.Dropdown(model_list, label="Select Model")
123
+ classify_button_random = gr.Button("Classify")
124
+ output_label_random = gr.Label(num_top_classes=5)
125
+ generate_button.click(generate_random_image, inputs=[], outputs=random_image_output)
126
+ classify_button_random.click(classify_generated_image, inputs=[random_image_output, model_dropdown_random], outputs=output_label_random)
127
 
128
+ demo.launch()