JinHyeong99
1
a87c3be
raw
history blame
1.26 kB
import gradio as gr
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
from PIL import Image
import torch
# ๋ชจ๋ธ๊ณผ feature extractor ๋กœ๋“œ
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
feature_extractor = SegformerFeatureExtractor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
# ์ด๋ฏธ์ง€๋ฅผ ์ฒ˜๋ฆฌํ•˜๋Š” ํ•จ์ˆ˜
def predict(image):
# ์ด๋ฏธ์ง€๋ฅผ ๋ชจ๋ธ์— ๋งž๊ฒŒ ๋ณ€ํ™˜
inputs = feature_extractor(images=image, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits
# ๊ฒฐ๊ณผ๋ฅผ ์ด๋ฏธ์ง€๋กœ ๋ณ€ํ™˜ (์˜ˆ: ๊ฐ€์žฅ ๋†’์€ ํ™•๋ฅ ์„ ๊ฐ€์ง„ ํด๋ž˜์Šค ์„ ํƒ)
result = torch.argmax(logits, dim=1)
result = result.squeeze().cpu().numpy()
# ์—ฌ๊ธฐ์—์„œ๋Š” ๋‹จ์ˆœํ™”๋ฅผ ์œ„ํ•ด ๊ฒฐ๊ณผ๋ฅผ ๊ทธ๋Œ€๋กœ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.
# ์‹ค์ œ๋กœ๋Š” ๊ฒฐ๊ณผ๋ฅผ ์ ์ ˆํ•œ ํ˜•์‹์œผ๋กœ ๋ณ€ํ™˜ํ•ด์•ผ ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
return result
# Gradio ์ธํ„ฐํŽ˜์ด์Šค ์ƒ์„ฑ
demo = gr.Interface(
fn=predict,
inputs=gr.inputs.Image(shape=(400, 600)),
outputs=gr.outputs.Image(),
examples=["image1.jpg", "image2.jpg", "image3.jpg"] # ์„ธ ๊ฐœ์˜ ์˜ˆ์ œ ์ด๋ฏธ์ง€ ๊ฒฝ๋กœ
)
# ์ธํ„ฐํŽ˜์ด์Šค ์‹คํ–‰
demo.launch()