# Hugging Face Space status: Runtime error
import gradio as gr | |
from PIL import Image | |
import numpy as np | |
import tensorflow as tf | |
from transformers import AutoFeatureExtractor, TFAutoModelForSemanticSegmentation | |
# Hugging Face SegFormer checkpoint plus its matching feature extractor
# (image preprocessor), both loaded from the same model name.
model_name = "nvidia/segformer-b0-finetuned-cityscapes-1024-1024"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = TFAutoModelForSemanticSegmentation.from_pretrained(model_name)
def label_to_color_image(label, colormap):
    """Convert a 2-D map of class indices into an RGB color image.

    Args:
        label: 2-D map of integer class indices. Either a TensorFlow eager
            tensor (anything exposing ``.numpy()``) or an array-like such as
            ``np.ndarray``.
        colormap: Sequence of RGB triples; entry ``i`` is painted on every
            pixel whose class index equals ``i``.

    Returns:
        ``np.ndarray`` of shape ``(height, width, 3)`` and dtype ``uint8``.
        Pixels whose index has no colormap entry stay black ``(0, 0, 0)``.
    """
    # Convert to NumPy once up front instead of re-converting the whole
    # tensor for every class inside the loop.
    labels = label.numpy() if hasattr(label, "numpy") else np.asarray(label)
    color_seg = np.zeros((labels.shape[0], labels.shape[1], 3), dtype=np.uint8)
    for class_index, color in enumerate(colormap):
        color_seg[labels == class_index, :] = color
    return color_seg
def draw_plot(pred_img, seg, colormap, labels_list):
    """Render the segmentation overlay and a class legend as a plot figure.

    NOTE(review): this function's body was never filled in. The file only
    held the placeholder comment "your existing draw_plot function,
    unchanged", and a comment alone is not a valid function body: Python
    raises IndentationError when the module is imported. Paste the real
    implementation here.

    Args:
        pred_img: ``uint8`` image blending the input with the color mask, as
            built by ``huggingface_model``.
        seg: 2-D tensor of predicted class indices.
        colormap: RGB triple per class index.
        labels_list: Human-readable name per class index.

    Returns:
        Once implemented, a figure object for the Gradio ``"plot"`` output.

    Raises:
        NotImplementedError: Always, until the real implementation is added.
    """
    raise NotImplementedError(
        "draw_plot is a placeholder: paste your existing draw_plot "
        "implementation here."
    )
# RGB color for each Cityscapes class index (19 classes), used to paint the
# predicted segmentation mask.
_CITYSCAPES_COLORMAP = [
    [128, 64, 128],
    [244, 35, 232],
    [70, 70, 70],
    [102, 102, 156],
    [190, 153, 153],
    [153, 153, 153],
    [250, 170, 30],
    [220, 220, 0],
    [107, 142, 35],
    [152, 251, 152],
    [0, 130, 180],
    [220, 20, 60],
    [255, 0, 0],
    [0, 0, 142],
    [0, 0, 70],
    [0, 60, 100],
    [0, 80, 100],
    [0, 0, 230],
    [119, 11, 32],
]


def huggingface_model(input_img):
    """Segment an image with the SegFormer model and plot the result.

    Args:
        input_img: Image as a NumPy array, as delivered by the Gradio
            ``gr.Image`` input component.

    Returns:
        Whatever ``draw_plot`` returns for the blended overlay; it is shown
        by the Interface's ``"plot"`` output.
    """
    # Force 3 channels: the color mask below is (H, W, 3), so a grayscale or
    # RGBA input would otherwise fail to broadcast in the 50/50 blend.
    image = Image.fromarray(input_img).convert("RGB")

    inputs = feature_extractor(images=image, return_tensors="tf")
    logits = model(**inputs).logits
    # Move axis 1 (the per-class scores) last so tf.image.resize treats it
    # as channels and argmax below can reduce over it.
    logits = tf.transpose(logits, [0, 2, 3, 1])
    # PIL's image.size is (width, height); tf.image.resize wants (height, width).
    logits = tf.image.resize(logits, image.size[::-1])
    seg = tf.math.argmax(logits, axis=-1)[0]

    color_seg = label_to_color_image(seg, _CITYSCAPES_COLORMAP)

    # Show image + mask as an equal-weight blend.
    pred_img = (np.array(image) * 0.5 + color_seg * 0.5).astype(np.uint8)

    return draw_plot(pred_img, seg, _CITYSCAPES_COLORMAP, labels_list)
# Human-readable class names, one per class index the model predicts.
# They are read from the checkpoint's own config instead of a hand-written
# placeholder (which contained dummy names and a literal Ellipsis), so each
# name lines up with the index argmax produces. To use a custom labels.txt
# instead, assign its lines to labels_list here.
labels_list = [model.config.id2label[i] for i in range(model.config.num_labels)]
# Gradio UI: one image in, the segmentation plot out.
demo = gr.Interface(
    fn=huggingface_model,
    # No fixed `shape=`: huggingface_model resizes the predicted mask back to
    # the input's own size, so any resolution works. (`shape=` was also
    # removed from gr.Image in Gradio 4 and raises a TypeError there.)
    inputs=gr.Image(),
    outputs=["plot"],
    examples=["person-1.jpg", "person-2.jpg", "person-3.jpg", "person-4.jpg", "person-5.jpg"],
    allow_flagging="never",
)
demo.launch()