# Gradio demo: ADE20k semantic segmentation with an ONNX-converted SegFormer-B5.
import gradio as gr
import numpy as np
import cv2
from matplotlib import gridspec
import matplotlib.pyplot as plt
import onnxruntime as ort

# ADE20k class names and their RGB palette, loaded from the bundled text files.
ade_palette = []
labels_list = []
with open('labels.txt', 'r') as fp:
    for line in fp:
        labels_list.append(line.rstrip('\n'))

with open('ade_palette.txt', 'r') as fp:
    for line in fp:
        # Each line looks like "[120, 120, 120]"; parse it into a list of ints.
        ade_palette.append(list(map(int, line.rstrip('\n').strip('][').split(', '))))

colormap = np.asarray(ade_palette)

# ONNX-converted SegFormer-B5, fine-tuned on ADE20k at 640x640 resolution.
model_filename = 'segformer-b5-finetuned-ade-640-640.onnx'
sess = ort.InferenceSession(model_filename)
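
# Optional sanity check (a sketch, not part of the original demo): the input
# name "pixel_values" used below assumes the ONNX export kept the 🤗
# `transformers` input signature. If unsure, inspect the graph's actual inputs:
# print([inp.name for inp in sess.get_inputs()])  # e.g. ['pixel_values']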

def label_to_color_image(label):
    """Map a 2-D array of class indices to an RGB image via the ADE palette."""
    if label.ndim != 2:
        raise ValueError("Expect 2-D input label")
    if np.max(label) >= len(colormap):
        raise ValueError("label value too large.")
    return colormap[label]
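
# Example (toy input, for illustration only): a 2x2 map of class indices
# becomes a 2x2x3 RGB array drawn from the palette.
#   label_to_color_image(np.array([[0, 1], [2, 3]])).shape  # -> (2, 2, 3)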

def draw_plot(pred_img, seg):
    """Show the blended prediction next to a legend of the classes present."""
    fig = plt.figure(figsize=(20, 15))
    grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])

    # Left panel: the input image blended with the segmentation mask.
    plt.subplot(grid_spec[0])
    plt.imshow(pred_img)
    plt.axis('off')

    # Right panel: a color legend for only the classes that appear in `seg`.
    LABEL_NAMES = np.asarray(labels_list)
    FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
    FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)

    unique_labels = np.unique(seg)
    ax = plt.subplot(grid_spec[1])
    plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation="nearest")
    ax.yaxis.tick_right()
    plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
    plt.xticks([], [])
    ax.tick_params(width=0.0, labelsize=25)
    return fig

def segment(input_img):
    """Run SegFormer on an image file and return a matplotlib figure."""
    # Load with OpenCV (BGR), resize to the model's 640x640 input, convert to RGB.
    img = cv2.imread(input_img)
    img = cv2.resize(img, (640, 640)).astype(np.float32)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Add a batch dimension and go NHWC -> NCHW, as the ONNX graph expects.
    img_batch = np.expand_dims(img, axis=0)
    img_batch = np.transpose(img_batch, (0, 3, 1, 2))

    # Logits come out NCHW at a reduced resolution; take the per-pixel argmax
    # and upsample the class map back to 640x640.
    logits = sess.run(None, {"pixel_values": img_batch})[0]
    logits = np.transpose(logits, (0, 2, 3, 1))
    seg = np.argmax(logits, axis=-1)[0].astype('float32')
    seg = cv2.resize(seg, (640, 640)).astype('uint8')

    # Paint each class with its palette color, kept in RGB so the blended mask
    # matches the legend drawn from the same palette by draw_plot.
    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
    for label, color in enumerate(colormap):
        color_seg[seg == label, :] = color

    # Blend image and mask 50/50 and draw the result.
    pred_img = (img * 0.5 + color_seg * 0.5).astype(np.uint8)
    return draw_plot(pred_img, seg)

title = "SegFormer (ADE20k) in TensorFlow"
description = """
This is a demo of the TensorFlow SegFormer from the official 🤗 `transformers` package. The pre-trained model segments scene-centric images (ADE20k). We are **currently using an ONNX model converted from the TensorFlow-based SegFormer to improve latency**: average inference latency is **21 seconds** with TensorFlow versus **8 seconds** with the ONNX-converted model (measured in [Colab](https://github.com/deep-diver/segformer-tf-transformers/blob/main/notebooks/TFSegFormer_ONNX.ipynb)). Check out the [repository](https://github.com/deep-diver/segformer-tf-transformers) to find out how to run inference, fine-tune the model on a custom dataset, and more.
"""

demo = gr.Interface(segment,
                    gr.Image(type="filepath"),
                    outputs=['plot'],
                    examples=["ADE_val_00000001.jpeg"],
                    allow_flagging='never',
                    title=title,
                    description=description)

demo.launch()