Walid-Ahmed commited on
Commit
6fab0c6
·
verified ·
1 Parent(s): 22fa9e8

Create app.y

Browse files
Files changed (1) hide show
  1. app.y +75 -0
app.y ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+ from PIL import Image
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+
7
+ # Load the segmentation pipeline
8
+ pipe = pipeline("image-segmentation", model="mattmdjaga/segformer_b2_clothes")
9
+
10
+ # Save the example image locally
11
+ url = "https://plus.unsplash.com/premium_photo-1673210886161-bfcc40f54d1f?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxzZWFyY2h8MXx8cGVyc29uJTIwc3RhbmRpbmd8ZW58MHx8MHx8&w=1000&q=80"
12
+ image = Image.open(requests.get(url, stream=True).raw)
13
+ image.save("example_image.jpg") # Save the image locally
14
+
15
+ # Your predefined label dictionary
16
+ label_dict = {
17
+ 0: "Background",
18
+ 1: "Hat",
19
+ 2: "Hair",
20
+ 3: "Sunglasses",
21
+ 4: "Upper-clothes",
22
+ 5: "Skirt",
23
+ 6: "Pants",
24
+ 7: "Dress",
25
+ 8: "Belt",
26
+ 9: "Left-shoe",
27
+ 10: "Right-shoe",
28
+ 11: "Face",
29
+ 12: "Left-leg",
30
+ 13: "Right-leg",
31
+ 14: "Left-arm",
32
+ 15: "Right-arm",
33
+ 16: "Bag",
34
+ 17: "Scarf",
35
+ }
36
+
37
+ # Function to process the image and generate the segmentation map
38
+ def segment_image(image):
39
+ # Perform segmentation
40
+ result = pipe(image)
41
+
42
+ # Initialize an empty array for the segmentation map
43
+ image_width, image_height = result[0]["mask"].size
44
+ segmentation_map = np.zeros((image_height, image_width), dtype=np.uint8)
45
+
46
+ # Combine masks into a single segmentation map
47
+ for idx, entry in enumerate(result):
48
+ mask = np.array(entry["mask"]) # Convert the PIL mask to a NumPy array
49
+ segmentation_map[mask > 0] = idx # Assign the class index
50
+
51
+ # Create a matplotlib figure and visualize the segmentation map
52
+ plt.figure(figsize=(8, 8))
53
+ plt.imshow(segmentation_map, cmap="tab20") # Visualize using a colormap
54
+ cbar = plt.colorbar(ticks=range(len(label_dict)), label="Classes")
55
+ cbar.ax.set_yticklabels([label_dict[i] for i in range(len(label_dict))])
56
+ plt.title("Combined Segmentation Map")
57
+ plt.axis("off")
58
+
59
+ # Save the figure as a PIL image for Gradio
60
+ plt.savefig("segmented_output.png", bbox_inches="tight") # Save as a temporary file
61
+ plt.close() # Close the figure to free memory
62
+ return Image.open("segmented_output.png")
63
+
64
+ # Gradio interface
65
+ interface = gr.Interface(
66
+ fn=segment_image,
67
+ inputs=gr.Image(type="pil"), # Input is an image
68
+ outputs=gr.Image(type="pil"), # Output is an image with the colormap
69
+ examples=["example_image.jpg"], # Use the saved image as an example
70
+ title="Image Segmentation with Colormap",
71
+ description="Upload an image, and the segmentation model will produce an output with a colormap applied to the segmented classes."
72
+ )
73
+
74
+ # Launch the app
75
+ interface.launch()