JinHyeong99 commited on
Commit
5b982d8
·
1 Parent(s): 5e1d0b4
Files changed (1) hide show
  1. app.py +27 -21
app.py CHANGED
@@ -1,35 +1,41 @@
1
  import gradio as gr
2
  from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
3
  from PIL import Image
 
4
  import torch
5
 
6
  # ๋ชจ๋ธ๊ณผ feature extractor ๋กœ๋“œ
7
- model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b3-finetuned-cityscapes-1024-1024")
8
- feature_extractor = SegformerFeatureExtractor.from_pretrained("nvidia/segformer-b3-finetuned-cityscapes-1024-1024")
 
9
 
10
- # ์ด๋ฏธ์ง€๋ฅผ ์ฒ˜๋ฆฌํ•˜๋Š” ํ•จ์ˆ˜
11
- def predict(image):
12
- # ์ด๋ฏธ์ง€๋ฅผ ๋ชจ๋ธ์— ๋งž๊ฒŒ ๋ณ€ํ™˜
13
- processed_image = image.resize((1024, 1024))
14
- inputs = feature_extractor(images=processed_image, return_tensors="pt")
15
- outputs = model(**inputs)
16
- logits = outputs.logits
17
 
18
- # ๊ฒฐ๊ณผ๋ฅผ ์ด๋ฏธ์ง€๋กœ ๋ณ€ํ™˜ (์˜ˆ: ๊ฐ€์žฅ ๋†’์€ ํ™•๋ฅ ์„ ๊ฐ€์ง„ ํด๋ž˜์Šค ์„ ํƒ)
19
- result = torch.argmax(logits)
20
- result = result.squeeze().cpu().numpy()
 
 
 
21
 
22
- # ์—ฌ๊ธฐ์—์„œ๋Š” ๋‹จ์ˆœํ™”๋ฅผ ์œ„ํ•ด ๊ฒฐ๊ณผ๋ฅผ ๊ทธ๋Œ€๋กœ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.
23
- # ์‹ค์ œ๋กœ๋Š” ๊ฒฐ๊ณผ๋ฅผ ์ ์ ˆํ•œ ํ˜•์‹์œผ๋กœ ๋ณ€ํ™˜ํ•ด์•ผ ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
24
- return result
25
 
26
- # Gradio ์ธํ„ฐํŽ˜์ด์Šค ์ƒ์„ฑ
 
 
 
27
  demo = gr.Interface(
28
- fn=predict,
29
- inputs=gr.inputs.Image(type='pil'),
30
- outputs=gr.outputs.Image(type='pil'),
31
- examples=["image1.jpg", "image2.jpg", "image3.jpg"] # ์„ธ ๊ฐœ์˜ ์˜ˆ์ œ ์ด๋ฏธ์ง€ ๊ฒฝ๋กœ
 
32
  )
33
 
34
  # ์ธํ„ฐํŽ˜์ด์Šค ์‹คํ–‰
35
- demo.launch()
 
1
  import gradio as gr
2
  from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
3
  from PIL import Image
4
+ import numpy as np
5
  import torch
6
 
7
  # ๋ชจ๋ธ๊ณผ feature extractor ๋กœ๋“œ
8
+ model_name = "nvidia/segformer-b0-finetuned-ade-512-512"
9
+ model = SegformerForSemanticSegmentation.from_pretrained(model_name)
10
+ feature_extractor = SegformerFeatureExtractor.from_pretrained(model_name)
11
 
12
+ def segment_image(image):
13
+ # ์ด๋ฏธ์ง€ ์ฒ˜๋ฆฌ
14
+ inputs = feature_extractor(images=image, return_tensors="pt")
15
+ with torch.no_grad():
16
+ outputs = model(**inputs)
 
 
17
 
18
+ # ๋งˆ์Šคํฌ ์ƒ์„ฑ
19
+ upsampled_logits = torch.nn.functional.interpolate(
20
+ outputs.logits, size=image.size[::-1], mode="bilinear", align_corners=False
21
+ )
22
+ upsampled_predictions = upsampled_logits.argmax(dim=1)
23
+ mask = upsampled_predictions.squeeze().numpy()
24
 
25
+ # ๊ฒฐ๊ณผ ๋ฐ˜ํ™˜
26
+ return Image.fromarray(np.uint8(mask * 255))
 
27
 
28
+ # ์˜ˆ์‹œ ์ด๋ฏธ์ง€ ๊ฒฝ๋กœ
29
+ example_images = ["image1.jpg", "image2.jpg", "image3.jpg"]
30
+
31
+ # Gradio ์ธํ„ฐํŽ˜์ด์Šค ์„ค์ •
32
  demo = gr.Interface(
33
+ fn=segment_image,
34
+ inputs=gr.inputs.Image(type="pil"),
35
+ outputs="image",
36
+ title="๋จธ์‹ ๋Ÿฌ๋‹ 7์ฃผ์ฐจ ๊ณผ์ œ_3",
37
+ examples=example_images
38
  )
39
 
40
  # ์ธํ„ฐํŽ˜์ด์Šค ์‹คํ–‰
41
+ demo.launch()