Spaces:
Runtime error
Runtime error
File size: 2,488 Bytes
0f6c099 5a104f4 0f6c099 5a104f4 0f6c099 5a104f4 0f6c099 5a104f4 0f6c099 5a104f4 0f6c099 5a104f4 0f6c099 5a104f4 0f6c099 5a104f4 0f6c099 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import gradio as gr
from PIL import Image
import numpy as np
import tensorflow as tf
from transformers import AutoFeatureExtractor, TFAutoModelForSemanticSegmentation
# Hugging Face checkpoint: SegFormer-B0 fine-tuned on Cityscapes (1024-1024 per
# the checkpoint name), plus its matching preprocessor ("feature extractor").
# Both are loaded once at import time so every request reuses them.
model_name = "nvidia/segformer-b0-finetuned-cityscapes-1024-1024"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = TFAutoModelForSemanticSegmentation.from_pretrained(model_name)
def label_to_color_image(label, colormap):
    """Map a 2-D array of class ids to an RGB image.

    Args:
        label: 2-D array of integer class ids. Either a NumPy array-like or an
            object with a ``.numpy()`` method (e.g. an eager ``tf.Tensor``).
        colormap: sequence of RGB triples; entry ``i`` is the color for class ``i``.

    Returns:
        ``np.uint8`` array of shape ``(label.shape[0], label.shape[1], 3)``.
        Pixels whose id has no colormap entry stay black.
    """
    # Convert once up front rather than calling .numpy() on every loop
    # iteration; plain NumPy arrays (which have no .numpy()) are used as-is.
    label = label.numpy() if hasattr(label, "numpy") else np.asarray(label)
    color_seg = np.zeros((label.shape[0], label.shape[1], 3), dtype=np.uint8)
    for class_id, color in enumerate(colormap):
        color_seg[label == class_id] = color
    return color_seg
def draw_plot(pred_img, seg, colormap, labels_list):
    """Render the blended segmentation image together with a class legend.

    NOTE(review): this file contained only the placeholder comment
    "your existing draw_plot function, unchanged" here. A function with no body
    is a syntax error, so the whole module failed to load. Paste the original
    implementation back in place of the ``raise`` below.

    Args:
        pred_img: uint8 blend of the input image and the class-color mask,
            as built by ``huggingface_model``.
        seg: per-pixel class ids (the ``tf.math.argmax`` result from
            ``huggingface_model``).
        colormap: list of RGB triples indexed by class id.
        labels_list: class names indexed by class id.

    Returns:
        The figure displayed by the Interface's ``"plot"`` output.

    Raises:
        NotImplementedError: always, until the real implementation is restored.
    """
    raise NotImplementedError(
        "draw_plot() body is missing; restore the original implementation "
        "(it must return a figure accepted by Gradio's 'plot' output)."
    )
# Cityscapes palette: RGB color for each of the model's 19 class ids, in id order.
_CITYSCAPES_COLORMAP = [
    [128, 64, 128],
    [244, 35, 232],
    [70, 70, 70],
    [102, 102, 156],
    [190, 153, 153],
    [153, 153, 153],
    [250, 170, 30],
    [220, 220, 0],
    [107, 142, 35],
    [152, 251, 152],
    [0, 130, 180],
    [220, 20, 60],
    [255, 0, 0],
    [0, 0, 142],
    [0, 0, 70],
    [0, 60, 100],
    [0, 80, 100],
    [0, 0, 230],
    [119, 11, 32],
]


def huggingface_model(input_img):
    """Segment an image and return a plot of the color-coded overlay.

    Args:
        input_img: image as a NumPy array (what ``gr.Image`` passes by
            default), or ``None`` when the user submits without an image.

    Returns:
        Whatever ``draw_plot`` returns for the blended image, or ``None``
        when no image was given (Gradio then shows an empty plot).
    """
    if input_img is None:
        return None
    # Force 3 channels: an RGBA or grayscale array would otherwise break the
    # 50/50 blend with the (H, W, 3) color mask below.
    image = Image.fromarray(input_img).convert("RGB")
    inputs = feature_extractor(images=image, return_tensors="tf")
    logits = model(**inputs).logits
    # Move the class axis last (NCHW -> NHWC), the layout tf.image.resize and
    # the argmax over axis=-1 expect.
    logits = tf.transpose(logits, [0, 2, 3, 1])
    # Upsample back to the input resolution. PIL's size is (width, height);
    # tf.image.resize wants (height, width), hence the reversal.
    logits = tf.image.resize(logits, image.size[::-1])
    seg = tf.math.argmax(logits, axis=-1)[0]

    colormap = _CITYSCAPES_COLORMAP
    color_seg = label_to_color_image(seg, colormap)
    # Overlay: average the photo and the class-color mask.
    pred_img = (np.array(image) * 0.5 + color_seg * 0.5).astype(np.uint8)
    return draw_plot(pred_img, seg, colormap, labels_list)
# Class names for the legend, indexed by class id. They are read from the
# checkpoint's config rather than typed in by hand (previously a placeholder to
# be filled from a labels.txt file), so they always match the model's outputs.
labels_list = [model.config.id2label[i] for i in range(len(model.config.id2label))]
demo = gr.Interface(
    fn=huggingface_model,
    # Gradio crops/resizes the upload to 1024x1024 (the size in the checkpoint
    # name). NOTE(review): `shape=` is a Gradio 3.x parameter; it was removed
    # in Gradio 4, so drop it if the Space is upgraded.
    inputs=gr.Image(shape=(1024, 1024)),
    outputs=["plot"],
    examples=["person-1.jpg", "person-2.jpg", "person-3.jpg", "person-4.jpg", "person-5.jpg"],
    allow_flagging='never'
)
demo.launch()