Spaces:
Runtime error
Runtime error
File size: 3,603 Bytes
f2c6c21 7e156dc eba227e 7e156dc 538bf82 7e156dc eba227e 7e156dc 872f59a 0c5b5d7 68b644b f2c6c21 7e156dc 68b644b 0934f7b 7e156dc 0c5b5d7 7e156dc 0c5b5d7 7e156dc 0c5b5d7 3551e60 538bf82 7e156dc 538bf82 7e156dc 538bf82 eba227e 538bf82 7e156dc 872f59a 6934201 eba227e 7e156dc eba227e 7e156dc 872f59a 3551e60 872f59a eba227e 7e156dc 872f59a 538bf82 7e156dc d7a7630 e8f9fb4 d7a7630 7e156dc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import csv
import os
import sys
import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import onnxruntime as ort
from matplotlib import gridspec
# Class names (one per line of labels.txt) and per-class colour triples (one
# "[a, b, c]" list per line of ade_palette.txt), both read once at import time.
# ``colormap`` is indexed by class id.
ade_palette = []
labels_list = []
# NOTE(review): csv does not appear to be used anywhere else in this file; this
# call looks like leftover boilerplate but is kept to preserve behaviour.
csv.field_size_limit(sys.maxsize)

with open("labels.txt", "r", encoding="utf-8") as fp:
    for line in fp:
        # rstrip rather than line[:-1]: the last line may have no trailing
        # newline (slicing would then drop a real character of the name), and
        # CRLF files would otherwise leave a stray "\r" on every label.
        labels_list.append(line.rstrip("\r\n"))

with open("ade_palette.txt", "r", encoding="utf-8") as fp:
    for line in fp:
        entry = line.strip().strip("[]")
        if not entry:
            # Tolerate blank/trailing lines instead of crashing on int("").
            continue
        # int() ignores the whitespace left around each comma-separated value.
        ade_palette.append([int(value) for value in entry.split(",")])

colormap = np.asarray(ade_palette)
model_filename = "segformer-b5-finetuned-ade-640-640.onnx"

# One ONNX Runtime inference session, created once at import time and run on
# the CPU provider only. Intra-op threads are set to the machine's core count;
# os.cpu_count() returns None when the count cannot be determined, so fall back
# to 0, which asks ONNX Runtime to choose its own default rather than assigning
# None to an integer option.
sess_options = ort.SessionOptions()
sess_options.intra_op_num_threads = os.cpu_count() or 0
sess = ort.InferenceSession(
    model_filename, sess_options, providers=["CPUExecutionProvider"]
)
def label_to_color_image(label):
    """Look up the palette colour of every class id in a 2-D label array.

    Args:
        label: 2-D array of integer class ids, used directly as row indices
            into the module-level ``colormap``.

    Returns:
        ``colormap[label]``, i.e. an array of shape
        ``label.shape + colormap.shape[1:]``.

    Raises:
        ValueError: if ``label`` is not 2-D, or if its largest id is
            ``>= len(colormap)`` (no palette row for it).
    """
    # Rank is validated before anything else so that non-2-D input always
    # reports the rank problem first.
    if label.ndim != 2:
        raise ValueError("Expect 2-D input label")

    palette_size = len(colormap)
    out_of_range = np.max(label) >= palette_size
    if out_of_range:
        raise ValueError("label value too large.")

    colours = colormap[label]
    return colours
def draw_plot(pred_img, seg):
    """Build a figure with the overlay image and a legend of the classes present.

    Args:
        pred_img: image array passed straight to ``imshow`` (the blended
            image + mask produced by ``sepia``).
        seg: array of integer class ids; its unique values choose which
            legend rows are drawn.

    Returns:
        A matplotlib ``Figure`` with two panels side by side (width ratio
        6:1): the image with axes hidden, and one colour swatch per class id
        found in ``seg`` labelled with its name from ``labels_list``.
    """
    # Create the Figure directly rather than through pyplot. pyplot registers
    # every figure it creates in a global manager and keeps it alive until
    # plt.close() is called, which never happens here, so each request would
    # leak a figure. pyplot's implicit "current axes" state is also shared
    # across concurrent requests; explicit Axes objects avoid that.
    fig = plt.Figure(figsize=(20, 15))
    grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1], figure=fig)

    image_ax = fig.add_subplot(grid_spec[0])
    image_ax.imshow(pred_img)
    image_ax.axis("off")

    label_names = np.asarray(labels_list)
    # Column vector of every class id, turned into an (N, 1, 3) colour image.
    full_label_map = np.arange(len(label_names)).reshape(len(label_names), 1)
    full_color_map = label_to_color_image(full_label_map)
    unique_labels = np.unique(seg)

    legend_ax = fig.add_subplot(grid_spec[1])
    # Only the rows for classes that actually occur in ``seg``.
    legend_ax.imshow(
        full_color_map[unique_labels].astype(np.uint8), interpolation="nearest"
    )
    legend_ax.yaxis.tick_right()
    legend_ax.set_yticks(range(len(unique_labels)))
    legend_ax.set_yticklabels(label_names[unique_labels])
    legend_ax.set_xticks([])
    legend_ax.tick_params(width=0.0, labelsize=25)
    return fig
def sepia(input_img):
    """Segment an image file with the ONNX SegFormer and plot the result.

    Despite its name this applies no sepia filter; the name is what the
    ``gr.Interface`` in this file is wired to.

    Args:
        input_img: path to an image file (the Gradio input component uses
            ``type="filepath"``).

    Returns:
        The figure from ``draw_plot``: the 640x640 input blended 50/50 with
        the class-colour mask, plus a legend of the classes found.

    Raises:
        ValueError: if OpenCV cannot read ``input_img`` as an image.
    """
    img = cv2.imread(input_img)
    if img is None:
        # cv2.imread reports unreadable/missing files by returning None, which
        # would otherwise surface as an obscure cv2.resize assertion.
        raise ValueError(f"Could not read image file: {input_img!r}")
    img = cv2.resize(img, (640, 640)).astype(np.float32)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # HWC -> NCHW with a batch of one.
    # NOTE(review): pixels go in as raw 0-255 floats with no mean/std
    # normalisation; confirm this is what the exported ONNX graph expects.
    img_batch = np.transpose(img[np.newaxis], (0, 3, 1, 2))
    logits = sess.run(None, {"pixel_values": img_batch})[0]

    # The transpose treats axis 1 of the logits as the class axis.
    logits = np.transpose(logits, (0, 2, 3, 1))
    seg = np.argmax(logits, axis=-1)[0].astype("float32")
    # Class ids must be upsampled with nearest-neighbour: the default bilinear
    # interpolation averages neighbouring ids at region borders and produces
    # labels the model never predicted.
    seg = cv2.resize(seg, (640, 640), interpolation=cv2.INTER_NEAREST).astype(
        "uint8"
    )

    # Paint every pixel with its class colour in one vectorised lookup; ids
    # without a palette entry stay black, as with the previous per-class loop.
    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
    has_color = seg < len(colormap)
    color_seg[has_color] = colormap[seg[has_color]]

    # The mask is deliberately left in palette (RGB) order: ``img`` was
    # converted to RGB above, matplotlib displays arrays as RGB, and the legend
    # in draw_plot uses the palette unflipped. Flipping to BGR here made the
    # overlay colours disagree with the legend.
    pred_img = (img * 0.5 + color_seg * 0.5).astype(np.uint8)
    return draw_plot(pred_img, seg)
title = "SegFormer(ADE20k) in TensorFlow"
description = """
This is a demo of TensorFlow SegFormer from the official 🤗 `transformers` package. The pre-trained model was trained to segment scene-specific images. We are **currently using an ONNX model converted from the TensorFlow-based SegFormer to improve the latency**. The average latency of an inference is **21** and **8** seconds for the TensorFlow and ONNX-converted models respectively (in [Colab](https://github.com/deep-diver/segformer-tf-transformers/blob/main/notebooks/TFSegFormer_ONNX.ipynb)). Check out the [repository](https://github.com/deep-diver/segformer-tf-transformers) to find out how to run inference, fine-tune the model on a custom dataset, and for further information.
"""

# ``gr.inputs`` was deprecated in Gradio 3 and removed in Gradio 4, so prefer
# the top-level ``gr.Image`` component; fall back to the legacy namespace only
# on releases that predate it.
if hasattr(gr, "Image"):
    image_input = gr.Image(type="filepath")
else:
    image_input = gr.inputs.Image(type="filepath")

demo = gr.Interface(
    sepia,
    image_input,
    outputs=["plot"],
    examples=["ADE_val_00000001.jpeg"],
    allow_flagging="never",
    title=title,
    description=description,
)

# Only start the server when run as a script, so importing this module (e.g.
# by tooling that just needs ``demo``) does not block on launch().
if __name__ == "__main__":
    demo.launch()
|