sna89 commited on
Commit
7f36774
·
1 Parent(s): a3abfb8

Add application file

Browse files
Files changed (1) hide show
  1. app.py +47 -0
app.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import (
2
+ MaskFormerImageProcessor,
3
+ AutoImageProcessor,
4
+ MaskFormerForInstanceSegmentation,
5
+ )
6
+ import torch
7
+ from torchvision import transforms
8
+ import matplotlib.pyplot as plt
9
+ import gradio as gr
10
+ import numpy as np
11
+
12
+ processor = AutoImageProcessor.from_pretrained("facebook/maskformer-swin-base-coco")
13
+ model = MaskFormerForInstanceSegmentation.from_pretrained(
14
+ "sna89/segmentation_model"
15
+ )
16
+
17
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
+ model = model.to(device)
19
+
20
+ def segment_image(img):
21
+ img = processor(img, return_tensors="pt")
22
+ img = img.to(device)
23
+ with torch.no_grad():
24
+ outputs = model(**img)
25
+
26
+ predicted_semantic_map = processor.post_process_semantic_segmentation(
27
+ outputs, target_sizes=[image.size[::-1]]
28
+ )[0]
29
+
30
+ fig, ax = plt.subplots(figsize=(5, 5))
31
+ plt.axis('off')
32
+ plt.imshow(predicted_semantic_map.to("cpu"))
33
+ fig.canvas.draw() # Render the figure
34
+ image_array = np.array(fig.canvas.renderer.buffer_rgba())
35
+ return image_array
36
+ # return predicted_semantic_map.to("cpu").numpy()
37
+
38
+ demo = gr.Interface(
39
+ fn=segment_image,
40
+ inputs=gr.Image(type="pil"),
41
+ outputs=gr.Image(type="pil"),
42
+ title="Semantic segmentation for sidewalk dataset",
43
+ examples=[["image.jpg"], ["image (1).jpg"]],
44
+ live=True
45
+ )
46
+
47
+ demo.launch(share=True)