JinHyeong99 commited on
Commit
a87c3be
·
1 Parent(s): 91062c2
Files changed (1) hide show
  1. app.py +27 -11
app.py CHANGED
@@ -1,18 +1,34 @@
 
1
  from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
2
  from PIL import Image
 
3
 
 
 
 
4
 
5
- feature_extractor = SegformerFeatureExtractor.from_pretrained("nvidia/segformer-b3-finetuned-cityscapes-1024-1024")
6
- model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b3-finetuned-cityscapes-1024-1024")
 
 
 
 
7
 
8
- # url = "http://images.cocodataset.org/val2017/000000039769.jpg"
9
- # image = Image.open(requests.get(url, stream=True).raw)
 
10
 
11
- image1, image2, image3 = 'image1.jpg', 'image2.jpg', 'image3.jpg'
12
- image1 = Image.open(image1)
13
- image2 = Image.open(image2)
14
- image3 = Image.open(image3)
15
 
16
- inputs = feature_extractor(images=[image1, image2, image3], return_tensors="pt")
17
- outputs = model(**inputs)
18
- logits = outputs.logits # shape (batch_size, num_labels, height/4, width/4)
 
 
 
 
 
 
 
 
import gradio as gr
import torch
from PIL import Image
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation

# Pretrained SegFormer checkpoint (b0, fine-tuned on ADE20K at 512x512).
# Named once so the model and its preprocessing config always match.
_CHECKPOINT = "nvidia/segformer-b0-finetuned-ade-512-512"

# Load the segmentation model and its feature extractor once at startup.
model = SegformerForSemanticSegmentation.from_pretrained(_CHECKPOINT)
feature_extractor = SegformerFeatureExtractor.from_pretrained(_CHECKPOINT)
9
 
10
+ # ์ด๋ฏธ์ง€๋ฅผ ์ฒ˜๋ฆฌํ•˜๋Š” ํ•จ์ˆ˜
11
+ def predict(image):
12
+ # ์ด๋ฏธ์ง€๋ฅผ ๋ชจ๋ธ์— ๋งž๊ฒŒ ๋ณ€ํ™˜
13
+ inputs = feature_extractor(images=image, return_tensors="pt")
14
+ outputs = model(**inputs)
15
+ logits = outputs.logits
16
 
17
+ # ๊ฒฐ๊ณผ๋ฅผ ์ด๋ฏธ์ง€๋กœ ๋ณ€ํ™˜ (์˜ˆ: ๊ฐ€์žฅ ๋†’์€ ํ™•๋ฅ ์„ ๊ฐ€์ง„ ํด๋ž˜์Šค ์„ ํƒ)
18
+ result = torch.argmax(logits, dim=1)
19
+ result = result.squeeze().cpu().numpy()
20
 
21
+ # ์—ฌ๊ธฐ์—์„œ๋Š” ๋‹จ์ˆœํ™”๋ฅผ ์œ„ํ•ด ๊ฒฐ๊ณผ๋ฅผ ๊ทธ๋Œ€๋กœ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.
22
+ # ์‹ค์ œ๋กœ๋Š” ๊ฒฐ๊ณผ๋ฅผ ์ ์ ˆํ•œ ํ˜•์‹์œผ๋กœ ๋ณ€ํ™˜ํ•ด์•ผ ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
23
+ return result
 
24
 
25
+ # Gradio ์ธํ„ฐํŽ˜์ด์Šค ์ƒ์„ฑ
26
+ demo = gr.Interface(
27
+ fn=predict,
28
+ inputs=gr.inputs.Image(shape=(400, 600)),
29
+ outputs=gr.outputs.Image(),
30
+ examples=["image1.jpg", "image2.jpg", "image3.jpg"] # ์„ธ ๊ฐœ์˜ ์˜ˆ์ œ ์ด๋ฏธ์ง€ ๊ฒฝ๋กœ
31
+ )
32
+
33
+ # ์ธํ„ฐํŽ˜์ด์Šค ์‹คํ–‰
34
+ demo.launch()