Update app.py
Browse files
app.py
CHANGED
@@ -31,33 +31,59 @@ label_dict = {
|
|
31 |
17: "Scarf",
|
32 |
}
|
33 |
|
|
|
|
|
|
|
34 |
# Function to process the image and generate the segmentation map
|
35 |
def segment_image(image):
|
36 |
# Perform segmentation
|
37 |
result = pipe(image)
|
38 |
-
|
39 |
# Initialize an empty array for the segmentation map
|
40 |
image_width, image_height = result[0]["mask"].size
|
41 |
segmentation_map = np.zeros((image_height, image_width), dtype=np.uint8)
|
42 |
-
|
43 |
# Combine masks into a single segmentation map
|
44 |
-
for
|
45 |
-
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
# Create a matplotlib figure and visualize the segmentation map
|
49 |
plt.figure(figsize=(8, 8))
|
50 |
plt.imshow(segmentation_map, cmap="tab20") # Visualize using a colormap
|
51 |
-
|
52 |
-
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
plt.axis("off")
|
55 |
-
|
56 |
-
|
57 |
-
plt.savefig("segmented_output.png", bbox_inches="tight") # Save as a temporary file
|
58 |
-
plt.close() # Close the figure to free memory
|
59 |
return Image.open("segmented_output.png")
|
60 |
|
|
|
|
|
|
|
61 |
# Gradio interface
|
62 |
interface = gr.Interface(
|
63 |
fn=segment_image,
|
|
|
31 |
17: "Scarf",
|
32 |
}
|
33 |
|
34 |
+
# Function to process the image and generate the segmentation map
|
35 |
+
|
36 |
+
|
37 |
# Function to process the image and generate the segmentation map
|
38 |
def segment_image(image):
|
39 |
# Perform segmentation
|
40 |
result = pipe(image)
|
41 |
+
|
42 |
# Initialize an empty array for the segmentation map
|
43 |
image_width, image_height = result[0]["mask"].size
|
44 |
segmentation_map = np.zeros((image_height, image_width), dtype=np.uint8)
|
45 |
+
|
46 |
# Combine masks into a single segmentation map
|
47 |
+
for entry in result:
|
48 |
+
label = entry["label"] # Get the label (e.g., "Hair", "Upper-clothes")
|
49 |
+
mask = np.array(entry["mask"]) # Convert PIL Image to NumPy array
|
50 |
+
|
51 |
+
# Find the index of the label in the original label dictionary
|
52 |
+
class_idx = [key for key, value in label_dict.items() if value == label][0]
|
53 |
+
|
54 |
+
# Assign the correct class index to the segmentation map
|
55 |
+
segmentation_map[mask > 0] = class_idx
|
56 |
+
|
57 |
|
58 |
+
# Get the unique class indices in the segmentation map
|
59 |
+
unique_classes = np.unique(segmentation_map)
|
60 |
+
# Print the names of the detected classes
|
61 |
+
print("Detected Classes:")
|
62 |
+
for class_idx in unique_classes:
|
63 |
+
print(f"- {label_dict[class_idx]}")
|
64 |
+
|
65 |
# Create a matplotlib figure and visualize the segmentation map
|
66 |
plt.figure(figsize=(8, 8))
|
67 |
plt.imshow(segmentation_map, cmap="tab20") # Visualize using a colormap
|
68 |
+
# Get the unique class indices in the segmentation map
|
69 |
+
unique_classes = np.unique(segmentation_map)
|
70 |
+
|
71 |
+
|
72 |
+
# Filter the label dictionary to include only detected classes
|
73 |
+
filtered_labels = {idx: label_dict[idx] for idx in unique_classes}
|
74 |
+
|
75 |
+
# Create a dynamic colorbar with only the detected classes
|
76 |
+
cbar = plt.colorbar(ticks=unique_classes)
|
77 |
+
cbar.ax.set_yticklabels([filtered_labels[i] for i in unique_classes])
|
78 |
+
plt.title("Segmented Image with Detected Classes")
|
79 |
plt.axis("off")
|
80 |
+
plt.savefig("segmented_output.png", bbox_inches="tight")
|
81 |
+
plt.close()
|
|
|
|
|
82 |
return Image.open("segmented_output.png")
|
83 |
|
84 |
+
|
85 |
+
|
86 |
+
|
87 |
# Gradio interface
|
88 |
interface = gr.Interface(
|
89 |
fn=segment_image,
|