Spaces:
Runtime error
import gradio as gr
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
from PIL import Image
import torch

# Load the model and image processor
# (SegformerImageProcessor is the current name for the deprecated SegformerFeatureExtractor)
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
image_processor = SegformerImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
# Function that runs semantic segmentation on an input image
def predict(image):
    # Preprocess the image into the model's input format
    inputs = image_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits  # shape: (batch, num_labels, height/4, width/4)
    # Pick the most likely class for each pixel
    result = torch.argmax(logits, dim=1)
    result = result.squeeze().cpu().numpy()
    # For simplicity, return the class-index map directly, cast to uint8 so Gradio can render it.
    # In practice you would upsample and colorize it; see the sketch after the listing.
    return result.astype("uint8")
# Create the Gradio interface
# (gr.inputs.Image / gr.outputs.Image were removed in Gradio 4.x; use gr.Image directly)
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(),
    examples=["image1.jpg", "image2.jpg", "image3.jpg"],  # paths to three example images
)

# Launch the interface
demo.launch()
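
The comment in predict notes that the raw class-index map usually needs further conversion before it displays well. A minimal post-processing sketch, assuming the same ADE20K-finetuned checkpoint: upsample the logits back to the input resolution with torch.nn.functional.interpolate, then map each class index to an RGB color. The segmentation_to_color helper and its random palette are illustrative assumptions, not part of the original app and not the official ADE20K colors.

import numpy as np
import torch.nn.functional as F
from PIL import Image

def segmentation_to_color(image, logits):
    # Resize logits from (batch, num_labels, H/4, W/4) to the input size;
    # PIL reports size as (W, H) while interpolate expects (H, W)
    upsampled = F.interpolate(logits, size=image.size[::-1], mode="bilinear", align_corners=False)
    class_map = upsampled.argmax(dim=1)[0].cpu().numpy()  # (H, W) class indices
    # Arbitrary deterministic palette (hypothetical, not the ADE20K reference colors)
    palette = np.random.default_rng(0).integers(0, 256, size=(logits.shape[1], 3), dtype=np.uint8)
    return Image.fromarray(palette[class_map])

Returning segmentation_to_color(image, logits) from predict would give Gradio a full-resolution RGB mask instead of a small, nearly black index map.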