DDingcheol commited on
Commit
1129909
·
1 Parent(s): 424ddf2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -26
app.py CHANGED
@@ -1,28 +1,30 @@
1
  import gradio as gr
2
-
3
- from matplotlib import gridspec
4
- import matplotlib.pyplot as plt
5
  import numpy as np
6
- from PIL import Image
7
  import tensorflow as tf
8
- from transformers import SegformerFeatureExtractor, TFSegformerForSemanticSegmentation
 
 
 
9
 
10
- feature_extractor = SegformerFeatureExtractor.from_pretrained(
11
- "nvidia/segformer-b0-finetuned-cityscapes-1024-1024"
12
- )
13
- model = TFSegformerForSemanticSegmentation.from_pretrained(
14
- "nvidia/segformer-b0-finetuned-cityscapes-1024-1024"
15
- )
16
 
 
 
 
 
 
 
 
17
  def ade_palette():
18
- """ADE20K palette that maps each class to RGB values."""
19
  return [
20
  [255, 0, 0],
21
  [255, 187, 0],
22
  [255, 228, 0],
23
  [29, 219, 22],
24
  [178, 204, 255],
25
- [1, 0, 255],
26
  [165, 102, 255],
27
  [217, 65, 197],
28
  [116, 116, 116],
@@ -37,30 +39,25 @@ def ade_palette():
37
  [153, 0, 76]
38
  ]
39
 
40
- labels_list = []
41
-
42
- with open(r'labels.txt', 'r') as fp:
43
- for line in fp:
44
- labels_list.append(line[:-1])
45
-
46
  colormap = np.asarray(ade_palette())
47
 
 
48
  def label_to_color_image(label):
49
  if label.ndim != 2:
50
  raise ValueError("Expect 2-D input label")
51
-
52
  if np.max(label) >= len(colormap):
53
  raise ValueError("label value too large.")
54
  return colormap[label]
55
 
 
56
  def draw_plot(pred_img, seg):
57
  fig = plt.figure(figsize=(20, 15))
58
-
59
  grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])
60
 
61
  plt.subplot(grid_spec[0])
62
  plt.imshow(pred_img)
63
  plt.axis('off')
 
64
  LABEL_NAMES = np.asarray(labels_list)
65
  FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
66
  FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
@@ -74,6 +71,7 @@ def draw_plot(pred_img, seg):
74
  ax.tick_params(width=0.0, labelsize=25)
75
  return fig
76
 
 
77
  def sepia(input_img):
78
  input_img = Image.fromarray(input_img)
79
 
@@ -84,27 +82,28 @@ def sepia(input_img):
84
  logits = tf.transpose(logits, [0, 2, 3, 1])
85
  logits = tf.image.resize(
86
  logits, input_img.size[::-1]
87
- ) # We reverse the shape of `image` because `image.size` returns width and height.
88
- seg = tf.math.argmax(logits, axis=-1)[0]
89
 
 
90
  color_seg = np.zeros(
91
  (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
92
- ) # height, width, 3
 
93
  for label, color in enumerate(colormap):
94
  color_seg[seg.numpy() == label, :] = color
95
 
96
- # Show image + mask
97
  pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
98
  pred_img = pred_img.astype(np.uint8)
99
 
100
  fig = draw_plot(pred_img, seg)
101
  return fig
102
 
 
103
  demo = gr.Interface(fn=sepia,
104
  inputs=gr.Image(shape=(800, 1200)),
105
  outputs=['plot'],
106
  examples=["citiscape-1.jpg", "citiscape-2.jpg"],
107
  allow_flagging='never')
108
 
109
-
110
  demo.launch()
 
1
  import gradio as gr
 
 
 
2
  import numpy as np
 
3
  import tensorflow as tf
4
+ from PIL import Image
5
+ from transformers import SegformerImageProcessor, TFSegformerForSemanticSegmentation
6
+ import matplotlib.pyplot as plt
7
+ from matplotlib import gridspec
8
 
9
+ # Load model and feature extractor
10
+ feature_extractor = SegformerImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-cityscapes-1024-1024")
11
+ model = TFSegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-cityscapes-1024-1024")
 
 
 
12
 
13
+ # Load labels
14
+ labels_list = []
15
+ with open(r'labels.txt', 'r') as fp:
16
+ for line in fp:
17
+ labels_list.append(line[:-1])
18
+
19
+ # ADE20K palette
20
  def ade_palette():
 
21
  return [
22
  [255, 0, 0],
23
  [255, 187, 0],
24
  [255, 228, 0],
25
  [29, 219, 22],
26
  [178, 204, 255],
27
+ [1, 0, 255],
28
  [165, 102, 255],
29
  [217, 65, 197],
30
  [116, 116, 116],
 
39
  [153, 0, 76]
40
  ]
41
 
 
 
 
 
 
 
42
  colormap = np.asarray(ade_palette())
43
 
44
+ # Label to color image mapping
45
  def label_to_color_image(label):
46
  if label.ndim != 2:
47
  raise ValueError("Expect 2-D input label")
 
48
  if np.max(label) >= len(colormap):
49
  raise ValueError("label value too large.")
50
  return colormap[label]
51
 
52
+ # Draw segmentation plot
53
  def draw_plot(pred_img, seg):
54
  fig = plt.figure(figsize=(20, 15))
 
55
  grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])
56
 
57
  plt.subplot(grid_spec[0])
58
  plt.imshow(pred_img)
59
  plt.axis('off')
60
+
61
  LABEL_NAMES = np.asarray(labels_list)
62
  FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
63
  FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
 
71
  ax.tick_params(width=0.0, labelsize=25)
72
  return fig
73
 
74
+ # Sepia function
75
  def sepia(input_img):
76
  input_img = Image.fromarray(input_img)
77
 
 
82
  logits = tf.transpose(logits, [0, 2, 3, 1])
83
  logits = tf.image.resize(
84
  logits, input_img.size[::-1]
85
+ )
 
86
 
87
+ seg = tf.math.argmax(logits, axis=-1)[0]
88
  color_seg = np.zeros(
89
  (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
90
+ )
91
+
92
  for label, color in enumerate(colormap):
93
  color_seg[seg.numpy() == label, :] = color
94
 
 
95
  pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
96
  pred_img = pred_img.astype(np.uint8)
97
 
98
  fig = draw_plot(pred_img, seg)
99
  return fig
100
 
101
+ # Gradio Interface
102
  demo = gr.Interface(fn=sepia,
103
  inputs=gr.Image(shape=(800, 1200)),
104
  outputs=['plot'],
105
  examples=["citiscape-1.jpg", "citiscape-2.jpg"],
106
  allow_flagging='never')
107
 
108
+ # Launch the interface
109
  demo.launch()