정정민 commited on
Commit
2608878
·
1 Parent(s): 0752aa7

Add application file

Browse files
Files changed (1) hide show
  1. app.py +63 -0
app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import requests
3
+ import gradio as gr
4
+ from PIL import Image
5
+
6
+ from transformers import AutoImageProcessor, ResNetForImageClassification
7
+
8
+ target_folder = "JungminChung/India_ResNet"
9
+
10
+ def load_model_and_preprocessor(target_folder):
11
+ model = ResNetForImageClassification.from_pretrained(target_folder)
12
+ image_processor = AutoImageProcessor.from_pretrained(target_folder)
13
+ return model, image_processor
14
+
15
+ def fetch_image(url):
16
+ headers = {
17
+ 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36'
18
+ }
19
+ image_raw = requests.get(url, headers=headers, stream=True).raw
20
+ image = Image.open(image_raw)
21
+
22
+ return image
23
+
24
+ def infer_image(image, model, image_processor, k):
25
+ processed_img = image_processor(images=image.convert("RGB"), return_tensors="pt")
26
+
27
+ with torch.no_grad():
28
+ outputs = model(**processed_img)
29
+ logits = outputs.logits
30
+
31
+ prob = torch.nn.functional.softmax(logits, dim=-1)
32
+ topk_prob, topk_indices = torch.topk(prob, k=k)
33
+
34
+ res = ""
35
+ for idx, (prob, index) in enumerate(zip(topk_prob[0], topk_indices[0])):
36
+ res += f"{idx+1}. {model.config.id2label[index.item()]:<15} ({prob.item()*100:.2f} %) \n"
37
+ return res
38
+
39
+ def infer(url, k, target_folder=target_folder):
40
+ try :
41
+ image = fetch_image(url)
42
+ model, image_processor = load_model_and_preprocessor(target_folder)
43
+ res = infer_image(image, model, image_processor, k)
44
+ except :
45
+ image = Image.new('RGB', (224, 224))
46
+ res = "이미지를 불러오는데 문제가 있나봐요. 다른 이미지 url로 다시 시도해주세요."
47
+ return image, res
48
+
49
+ demo = gr.Interface(
50
+ fn=infer,
51
+ inputs=[
52
+ gr.Textbox(value="https://i.namu.wiki/i/XQznKj51oCpN5HKkUBe6o2R_fRb4TSbU6JTZk52zYJbbjH_1B0BFHM5uYQMfsFzQOLRHG3mhR8xhqPG_UbeA0w.webp",
53
+ label="Image URL"),
54
+ gr.Slider(minimum=0, maximum=20, step=1, value=3, label="상위 몇개까지 보여줄까요?")
55
+ ],
56
+ outputs=[
57
+ gr.Image(type="pil", label="입력 이미지"),
58
+ gr.Textbox(label="종류 (확률)")
59
+ ],
60
+ )
61
+
62
+ # demo.launch()
63
+ demo.launch(share=True)