File size: 3,546 Bytes
872f59a
 
f2c6c21
 
eba227e
 
538bf82
 
eba227e
872f59a
0c5b5d7
68b644b
 
f2c6c21
 
68b644b
 
 
0934f7b
0c5b5d7
 
 
 
 
 
 
 
 
 
 
3551e60
538bf82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eba227e
538bf82
 
 
 
 
 
 
 
872f59a
6934201
 
eba227e
 
 
 
 
 
 
 
 
 
872f59a
 
 
 
3551e60
872f59a
 
 
 
 
 
eba227e
872f59a
 
538bf82
 
 
d7a7630
 
 
e8f9fb4
d7a7630
 
 
a93cc9f
6934201
a93cc9f
f2c6c21
d7a7630
 
f0c7c16
872f59a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import gradio as gr

import sys
import csv
import numpy as np
import cv2
from matplotlib import gridspec
import matplotlib.pyplot as plt
import onnxruntime as ort

# Module-level lookup tables shared by the functions in this file:
#   labels_list[i] -> class name taken from line i of labels.txt
#   ade_palette[i] -> list of ints parsed from line i of ade_palette.txt
ade_palette = []
labels_list = []

# Raise the csv module's per-field size cap as far as it will go.
# NOTE(review): nothing in the visible code uses csv — confirm whether this
# call is still needed.
csv.field_size_limit(sys.maxsize)

with open(r'labels.txt', 'r', encoding='utf-8') as fp:
    for line in fp:
        # Remove only the line terminator. Blindly slicing off the last
        # character would truncate the final label whenever the file does not
        # end with a newline.
        labels_list.append(line.rstrip('\n'))

with open(r'ade_palette.txt', 'r', encoding='utf-8') as fp:
    for line in fp:
        # Each line looks like "[a, b, c]". Strip surrounding whitespace and
        # the brackets, then split on bare commas; int() ignores the padding
        # around each number, so "1,2,3" and "1, 2, 3" both parse.
        entry = line.strip().strip('][')
        ade_palette.append([int(token) for token in entry.split(',')])

# Array form of the palette so it can be indexed with whole label arrays.
colormap = np.asarray(ade_palette)

# The inference session is created once at import time and reused per request.
model_filename = 'segformer-b5-finetuned-ade-640-640.onnx'
sess = ort.InferenceSession(model_filename)


def label_to_color_image(label):
    """Translate a 2-D array of label indices into palette entries.

    Args:
        label: 2-D array whose values index rows of the module-level
            ``colormap``.

    Returns:
        ``colormap`` indexed by ``label`` (one palette row per input element).

    Raises:
        ValueError: If ``label`` is not 2-D, or if its largest value is not a
            valid row index of ``colormap``.
    """
    # Shape is validated first, before the palette is consulted at all.
    if label.ndim != 2:
        raise ValueError("Expect 2-D input label")

    palette_size = len(colormap)
    if np.max(label) >= palette_size:
        raise ValueError("label value too large.")

    return colormap[label]

def draw_plot(pred_img, seg):
    """Build a figure with the blended prediction and a legend of its classes.

    The left panel shows ``pred_img``. The right panel is a colour strip with
    one row per distinct class index found in ``seg``, each row labelled with
    the matching name from ``labels_list``.

    Args:
        pred_img: Image array passed to ``imshow`` for the left panel.
        seg: Array of class indices. Its unique values are used to index both
            the label names and the colour map, so it needs an integer dtype.

    Returns:
        The matplotlib figure holding both panels. It has already been removed
        from pyplot's figure registry, so the caller owns the only reference.

    Raises:
        ValueError: Propagated from ``label_to_color_image`` when there are
            more label names than colour map rows.
    """
    # Do all the lookups that can fail before any figure is created.
    label_names = np.asarray(labels_list)
    full_label_map = np.arange(len(label_names)).reshape(len(label_names), 1)
    full_color_map = label_to_color_image(full_label_map)
    unique_labels = np.unique(seg)

    fig = plt.figure(figsize=(20, 15))
    try:
        grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])

        # Draw through explicit axes objects instead of pyplot's implicit
        # "current figure/axes" state, which is shared process-wide and could
        # be switched by another request running at the same time.
        image_ax = fig.add_subplot(grid_spec[0])
        image_ax.imshow(pred_img)
        image_ax.axis('off')

        legend_ax = fig.add_subplot(grid_spec[1])
        legend_ax.imshow(
            full_color_map[unique_labels].astype(np.uint8),
            interpolation="nearest",
        )
        legend_ax.yaxis.tick_right()
        legend_ax.set_yticks(range(len(unique_labels)))
        legend_ax.set_yticklabels(label_names[unique_labels])
        legend_ax.set_xticks([])
        legend_ax.set_xticklabels([])
        legend_ax.tick_params(width=0.0, labelsize=25)
    finally:
        # pyplot keeps every figure it creates alive until it is closed, so a
        # long-running app would accumulate one figure per call. Closing only
        # drops pyplot's reference; the Figure object stays usable.
        plt.close(fig)
    return fig

def sepia(input_img):
    """Segment the image at ``input_img`` and return the result as a figure.

    The image is read from disk, resized to 640x640, run through the ONNX
    session, and the per-pixel arg-max class map is blended 50/50 with the
    resized image. The blended image and the class map are then handed to
    ``draw_plot``.

    Args:
        input_img: Path of the image file to segment.

    Returns:
        The figure produced by ``draw_plot``.

    Raises:
        ValueError: If the file at ``input_img`` cannot be read as an image.
    """
    img = cv2.imread(input_img)
    if img is None:
        # imread reports an unreadable or missing file by returning None
        # rather than raising; fail here with a clear message.
        raise ValueError(f"Could not read image: {input_img!r}")

    img = cv2.resize(img, (640, 640)).astype(np.float32)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # NOTE(review): the pixels are passed to the model without any rescaling
    # or mean/std normalisation — confirm the exported model expects that.
    img_batch = np.expand_dims(img, axis=0)
    img_batch = np.transpose(img_batch, (0, 3, 1, 2))  # NHWC -> NCHW

    logits = sess.run(None, {"pixel_values": img_batch})[0]

    logits = np.transpose(logits, (0, 2, 3, 1))  # NCHW -> NHWC
    seg = np.argmax(logits, axis=-1)[0].astype(np.uint8)
    # Class indices are categorical, so upsample with nearest-neighbour.
    # Interpolating between two indices would produce unrelated classes along
    # every region boundary.
    seg = cv2.resize(seg, (640, 640), interpolation=cv2.INTER_NEAREST)

    color_seg = np.zeros(
        (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
    )  # height, width, 3

    for label, color in enumerate(colormap):
        color_seg[seg == label, :] = color

    # The mask keeps the colormap's channel order: `img` is already RGB and
    # draw_plot's legend shows the colormap unflipped, so reversing the
    # channels here would make overlay colours disagree with the legend.
    pred_img = img * 0.5 + color_seg * 0.5
    pred_img = pred_img.astype(np.uint8)

    return draw_plot(pred_img, seg)

# Header text and markdown blurb displayed on the demo page.
title = "SegFormer(ADE20k) in TensorFlow"
description = """

This is demo TensorFlow SegFormer from 🤗 `transformers` official package. The pre-trained model was trained to segment scene specific images. We are **currently using ONNX model converted from the TensorFlow based SegFormer to improve the latency**. The average latency of an inference is **21** and **8** seconds for TensorFlow and ONNX converted models respectively (in [Colab](https://github.com/deep-diver/segformer-tf-transformers/blob/main/notebooks/TFSegFormer_ONNX.ipynb)). Check out the [repository](https://github.com/deep-diver/segformer-tf-transformers) to find out how to make inference, finetune the model with custom dataset, and further information.

"""

# One image in (handed to `sepia` as a file path), one plot out.
demo = gr.Interface(
    fn=sepia,
    inputs=gr.inputs.Image(type="filepath"),
    outputs=["plot"],
    examples=["ADE_val_00000001.jpeg"],
    allow_flagging="never",
    title=title,
    description=description,
)

# Start serving as soon as the script is executed.
demo.launch()