Walid-Ahmed commited on
Commit
50474c7
·
verified ·
1 Parent(s): 53e0369

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -12
app.py CHANGED
@@ -31,33 +31,59 @@ label_dict = {
31
  17: "Scarf",
32
  }
33
 
 
 
 
34
  # Function to process the image and generate the segmentation map
35
  def segment_image(image):
36
  # Perform segmentation
37
  result = pipe(image)
38
-
39
  # Initialize an empty array for the segmentation map
40
  image_width, image_height = result[0]["mask"].size
41
  segmentation_map = np.zeros((image_height, image_width), dtype=np.uint8)
42
-
43
  # Combine masks into a single segmentation map
44
- for idx, entry in enumerate(result):
45
- mask = np.array(entry["mask"]) # Convert the PIL mask to a NumPy array
46
- segmentation_map[mask > 0] = idx # Assign the class index
 
 
 
 
 
 
 
47
 
 
 
 
 
 
 
 
48
  # Create a matplotlib figure and visualize the segmentation map
49
  plt.figure(figsize=(8, 8))
50
  plt.imshow(segmentation_map, cmap="tab20") # Visualize using a colormap
51
- cbar = plt.colorbar(ticks=range(len(label_dict)), label="Classes")
52
- cbar.ax.set_yticklabels([label_dict[i] for i in range(len(label_dict))])
53
- plt.title("Combined Segmentation Map")
 
 
 
 
 
 
 
 
54
  plt.axis("off")
55
-
56
- # Save the figure as a PIL image for Gradio
57
- plt.savefig("segmented_output.png", bbox_inches="tight") # Save as a temporary file
58
- plt.close() # Close the figure to free memory
59
  return Image.open("segmented_output.png")
60
 
 
 
 
61
  # Gradio interface
62
  interface = gr.Interface(
63
  fn=segment_image,
 
31
  17: "Scarf",
32
  }
33
 
34
+ # Function to process the image and generate the segmentation map
35
+
36
+
37
  # Function to process the image and generate the segmentation map
38
  def segment_image(image):
39
  # Perform segmentation
40
  result = pipe(image)
41
+
42
  # Initialize an empty array for the segmentation map
43
  image_width, image_height = result[0]["mask"].size
44
  segmentation_map = np.zeros((image_height, image_width), dtype=np.uint8)
45
+
46
  # Combine masks into a single segmentation map
47
+ for entry in result:
48
+ label = entry["label"] # Get the label (e.g., "Hair", "Upper-clothes")
49
+ mask = np.array(entry["mask"]) # Convert PIL Image to NumPy array
50
+
51
+ # Find the index of the label in the original label dictionary
52
+ class_idx = [key for key, value in label_dict.items() if value == label][0]
53
+
54
+ # Assign the correct class index to the segmentation map
55
+ segmentation_map[mask > 0] = class_idx
56
+
57
 
58
+ # Get the unique class indices in the segmentation map
59
+ unique_classes = np.unique(segmentation_map)
60
+ # Print the names of the detected classes
61
+ print("Detected Classes:")
62
+ for class_idx in unique_classes:
63
+ print(f"- {label_dict[class_idx]}")
64
+
65
  # Create a matplotlib figure and visualize the segmentation map
66
  plt.figure(figsize=(8, 8))
67
  plt.imshow(segmentation_map, cmap="tab20") # Visualize using a colormap
68
+ # Get the unique class indices in the segmentation map
69
+ unique_classes = np.unique(segmentation_map)
70
+
71
+
72
+ # Filter the label dictionary to include only detected classes
73
+ filtered_labels = {idx: label_dict[idx] for idx in unique_classes}
74
+
75
+ # Create a dynamic colorbar with only the detected classes
76
+ cbar = plt.colorbar(ticks=unique_classes)
77
+ cbar.ax.set_yticklabels([filtered_labels[i] for i in unique_classes])
78
+ plt.title("Segmented Image with Detected Classes")
79
  plt.axis("off")
80
+ plt.savefig("segmented_output.png", bbox_inches="tight")
81
+ plt.close()
 
 
82
  return Image.open("segmented_output.png")
83
 
84
+
85
+
86
+
87
  # Gradio interface
88
  interface = gr.Interface(
89
  fn=segment_image,