LeeRuben commited on
Commit
e45de72
·
1 Parent(s): 51b512a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -5
app.py CHANGED
@@ -266,11 +266,44 @@ def sepia(input_img):
266
  fig = draw_plot(pred_img, seg)
267
  return fig
268
 
269
- demo = gr.Interface(fn=sepia,
270
- inputs=gr.Image(shape=(400, 600)),
271
- outputs=['plot'],
272
- examples=["c-1.jpg", "c-2.jpg", "c-3.jpg"],
273
- allow_flagging='never')
 
274
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
 
276
  demo.launch()
 
266
  fig = draw_plot(pred_img, seg)
267
  return fig
268
 
269
+ def custom_interface(input_image):
270
+ # Process input image
271
+ input_img = Image.fromarray(input_image)
272
+ inputs = feature_extractor(images=input_img, return_tensors="tf")
273
+ outputs = model(**inputs)
274
+ logits = outputs.logits
275
 
276
+ logits = tf.transpose(logits, [0, 2, 3, 1])
277
+ logits = tf.image.resize(
278
+ logits, input_img.size[::-1]
279
+ )
280
+ seg = tf.math.argmax(logits, axis=-1)[0]
281
+
282
+ color_seg = np.zeros(
283
+ (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
284
+ )
285
+ for label, color in enumerate(colormap):
286
+ color_seg[seg.numpy() == label, :] = color
287
+
288
+ # Combine original image and mask
289
+ pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
290
+ pred_img = pred_img.astype(np.uint8)
291
+
292
+ # Create a figure with custom style
293
+ fig = plt.figure(figsize=(10, 10))
294
+ plt.imshow(pred_img)
295
+ plt.axis('off')
296
+ plt.title("Segmented Image", fontsize=18, color="#333") # Customize title
297
+ return fig
298
+
299
+ # Create the Gradio interface with the custom style
300
+ demo = gr.Interface(
301
+ fn=custom_interface,
302
+ inputs=gr.Image(shape=(400, 600)),
303
+ outputs=["plot"],
304
+ examples=["c-1.jpg", "c-2.jpg", "c-3.jpg"],
305
+ allow_flagging="never",
306
+ live=False # Disabling live updates for this example
307
+ )
308
 
309
  demo.launch()