# app.py by DDingcheol
# Last commit: 0f6c099 ("Update app.py"), 2.49 kB
import gradio as gr
from PIL import Image
import numpy as np
import tensorflow as tf
from transformers import AutoFeatureExtractor, TFAutoModelForSemanticSegmentation
# Pretrained SegFormer checkpoint (B0 backbone, fine-tuned on Cityscapes) and
# the image preprocessor published with it. Both are loaded once at import time.
model_name = "nvidia/segformer-b0-finetuned-cityscapes-1024-1024"

# Preprocessor: turns a PIL image into the tensors passed to `model(**inputs)`.
feature_extractor = AutoFeatureExtractor.from_pretrained(
    model_name
)

# TensorFlow semantic-segmentation model loaded from the same checkpoint.
model = TFAutoModelForSemanticSegmentation.from_pretrained(
    model_name
)
def label_to_color_image(label, colormap):
    """Paint a 2-D map of class ids with per-class RGB colors.

    Args:
        label: 2-D class-id map of shape (height, width). Any object that
            ``np.asarray`` accepts works, e.g. a NumPy array or an eager
            TensorFlow tensor.
        colormap: Sequence of RGB triples; entry ``i`` is the color of class ``i``.

    Returns:
        ``uint8`` array of shape (height, width, 3). Pixels whose class id has
        no entry in ``colormap`` stay black.
    """
    # Convert once up front rather than calling ``.numpy()`` for every class;
    # this also lets callers pass plain NumPy arrays.
    label_arr = np.asarray(label)
    color_seg = np.zeros((label_arr.shape[0], label_arr.shape[1], 3), dtype=np.uint8)
    for class_id, color in enumerate(colormap):
        color_seg[label_arr == class_id] = color
    return color_seg
def draw_plot(pred_img, seg, colormap, labels_list):
    """Render the overlay image plus a legend of the predicted classes.

    NOTE(review): the real implementation was never pasted into this file.
    The original body was only the comment "your existing draw_plot function,
    unchanged". A function with no statements is a SyntaxError, which stops the
    whole app from importing. Until the implementation is restored, this fails
    loudly at call time instead.

    Args:
        pred_img: ``uint8`` RGB image blending the input with its color mask.
        seg: 2-D map of predicted class ids.
        colormap: RGB triple per class id.
        labels_list: Class names, indexed by class id.

    Returns:
        Expected to return a figure for the Gradio ``"plot"`` output.

    Raises:
        NotImplementedError: always, until the implementation is restored.
    """
    raise NotImplementedError(
        "draw_plot() is a placeholder: paste the original implementation here."
    )
# Cityscapes color palette, one RGB triple per class id, in the cityscapesScripts
# trainId order (road, sidewalk, building, wall, fence, pole, traffic light,
# traffic sign, vegetation, terrain, sky, person, rider, car, truck, bus, train,
# motorcycle, bicycle).
# NOTE(review): assumes the model's class ids follow this order; confirm
# against model.config.id2label.
_CITYSCAPES_COLORMAP = [
    [128, 64, 128],
    [244, 35, 232],
    [70, 70, 70],
    [102, 102, 156],
    [190, 153, 153],
    [153, 153, 153],
    [250, 170, 30],
    [220, 220, 0],
    [107, 142, 35],
    [152, 251, 152],
    [70, 130, 180],  # sky; previously [0, 130, 180], which is not the Cityscapes color
    [220, 20, 60],
    [255, 0, 0],
    [0, 0, 142],
    [0, 0, 70],
    [0, 60, 100],
    [0, 80, 100],
    [0, 0, 230],
    [119, 11, 32],
]


def huggingface_model(input_img):
    """Segment an image with the SegFormer model and return the visualization.

    Args:
        input_img: Image array from the Gradio ``gr.Image`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        The result of ``draw_plot`` for the 50/50 blend of the image and its
        color mask. The Interface shows it in the ``"plot"`` output.

    Raises:
        gr.Error: if no image was provided.
    """
    if input_img is None:
        # Show a readable message in the UI instead of a PIL traceback.
        raise gr.Error("Please upload an image first.")

    # Force 3-channel RGB so grayscale/RGBA uploads can be blended with the
    # (height, width, 3) color mask below.
    image = Image.fromarray(input_img).convert("RGB")

    inputs = feature_extractor(images=image, return_tensors="tf")
    logits = model(**inputs).logits
    # Move the class axis last: tf.image.resize and the argmax below work on
    # channels-last tensors.
    logits = tf.transpose(logits, [0, 2, 3, 1])
    # Upsample to the input resolution. PIL's `size` is (width, height) while
    # tf.image.resize expects (height, width), hence the reversal.
    logits = tf.image.resize(logits, image.size[::-1])
    # Per-pixel class id for the single image in the batch.
    seg = tf.math.argmax(logits, axis=-1)[0]

    color_seg = label_to_color_image(seg, _CITYSCAPES_COLORMAP)

    # Show the image and the mask blended 50/50.
    pred_img = (np.asarray(image) * 0.5 + color_seg * 0.5).astype(np.uint8)

    return draw_plot(pred_img, seg, _CITYSCAPES_COLORMAP, labels_list)
# ์—ฌ๋Ÿฌ๋ถ„์ด ๊ฐ€์ง„ labels.txt ํŒŒ์ผ์˜ ๋‚ด์šฉ์„ labels_list์— ํ• ๋‹นํ•˜์„ธ์š”.
labels_list = ["label1", "label2", ...]
demo = gr.Interface(
fn=huggingface_model,
inputs=gr.Image(shape=(1024, 1024)), # ์ž…๋ ฅ ์ด๋ฏธ์ง€ ํฌ๊ธฐ๋Š” ๋ชจ๋ธ์˜ ์ž…๋ ฅ ํฌ๊ธฐ์— ๋งž๊ฒŒ ์กฐ์ ˆํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.
outputs=["plot"],
examples=["person-1.jpg", "person-2.jpg", "person-3.jpg", "person-4.jpg", "person-5.jpg"],
allow_flagging='never'
)
demo.launch()