Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,5 +1,8 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
|
|
|
|
|
|
|
|
|
|
| 3 |
import numpy as np
|
| 4 |
from PIL import Image
|
| 5 |
import tensorflow as tf
|
|
@@ -167,6 +170,58 @@ def ade_palette():
|
|
| 167 |
[92, 0, 255],
|
| 168 |
]
|
| 169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
def sepia(input_img):
|
| 171 |
input_img = Image.fromarray(input_img)
|
| 172 |
|
|
@@ -194,8 +249,10 @@ def sepia(input_img):
|
|
| 194 |
# Show image + mask
|
| 195 |
pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
|
| 196 |
pred_img = pred_img.astype(np.uint8)
|
| 197 |
-
return pred_img
|
| 198 |
|
| 199 |
-
|
|
|
|
|
|
|
|
|
|
| 200 |
|
| 201 |
demo.launch()
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
|
| 3 |
+
import pandas as pd
|
| 4 |
+
from matplotlib import gridspec
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
import numpy as np
|
| 7 |
from PIL import Image
|
| 8 |
import tensorflow as tf
|
|
|
|
| 170 |
[92, 0, 255],
|
| 171 |
]
|
| 172 |
|
| 173 |
+
def label_to_color_image(label):
|
| 174 |
+
"""Adds color defined by the dataset colormap to the label.
|
| 175 |
+
|
| 176 |
+
Args:
|
| 177 |
+
label: A 2D array with integer type, storing the segmentation label.
|
| 178 |
+
|
| 179 |
+
Returns:
|
| 180 |
+
result: A 2D array with floating type. The element of the array
|
| 181 |
+
is the color indexed by the corresponding element in the input label
|
| 182 |
+
to the PASCAL color map.
|
| 183 |
+
|
| 184 |
+
Raises:
|
| 185 |
+
ValueError: If label is not of rank 2 or its value is larger than color
|
| 186 |
+
map maximum entry.
|
| 187 |
+
"""
|
| 188 |
+
if label.ndim != 2:
|
| 189 |
+
raise ValueError("Expect 2-D input label")
|
| 190 |
+
|
| 191 |
+
colormap = np.asarray(ade_palette())
|
| 192 |
+
|
| 193 |
+
if np.max(label) >= len(colormap):
|
| 194 |
+
raise ValueError("label value too large.")
|
| 195 |
+
|
| 196 |
+
return colormap[label]
|
| 197 |
+
|
| 198 |
+
def draw_plot(pred_img, seg):
|
| 199 |
+
fig = plt.figure(figsize=(20, 15))
|
| 200 |
+
|
| 201 |
+
grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])
|
| 202 |
+
|
| 203 |
+
plt.subplot(grid_spec[0])
|
| 204 |
+
plt.imshow(pred_img)
|
| 205 |
+
plt.axis('off')
|
| 206 |
+
|
| 207 |
+
ade20k_labels_info = pd.read_csv(
|
| 208 |
+
"https://raw.githubusercontent.com/CSAILVision/sceneparsing/master/objectInfo150.csv"
|
| 209 |
+
)
|
| 210 |
+
labels_list = list(ade20k_labels_info["Name"])
|
| 211 |
+
|
| 212 |
+
LABEL_NAMES = np.asarray(labels_list)
|
| 213 |
+
FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
|
| 214 |
+
FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
|
| 215 |
+
|
| 216 |
+
unique_labels = np.unique(seg.numpy().astype("uint8"))
|
| 217 |
+
ax = plt.subplot(grid_spec[1])
|
| 218 |
+
plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation="nearest")
|
| 219 |
+
ax.yaxis.tick_right()
|
| 220 |
+
plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
|
| 221 |
+
plt.xticks([], [])
|
| 222 |
+
ax.tick_params(width=0.0, labelsize=25)
|
| 223 |
+
return fig
|
| 224 |
+
|
| 225 |
def sepia(input_img):
|
| 226 |
input_img = Image.fromarray(input_img)
|
| 227 |
|
|
|
|
| 249 |
# Show image + mask
|
| 250 |
pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
|
| 251 |
pred_img = pred_img.astype(np.uint8)
|
|
|
|
| 252 |
|
| 253 |
+
fig = draw_plot(pred_img, seg)
|
| 254 |
+
return fig
|
| 255 |
+
|
| 256 |
+
demo = gr.Interface(sepia, gr.Image(shape=(200, 200)), outputs=['plot'], examples=["ADE_val_00000001.jpeg"])
|
| 257 |
|
| 258 |
demo.launch()
|