devendergarg14 commited on
Commit
d9dc7e6
·
verified ·
1 Parent(s): 1b891e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -36
app.py CHANGED
@@ -2,10 +2,23 @@ import gradio as gr
2
  import numpy as np
3
  import matplotlib.pyplot as plt
4
  import matplotlib.colors as mcolors
 
 
 
 
5
 
6
- def process_mask(file, category_to_hide):
7
- # Load the .npy file
8
- data = np.load(file.name)
 
 
 
 
 
 
 
 
 
9
 
10
  # Define grouped categories
11
  grouped_mapping = {
@@ -16,51 +29,31 @@ def process_mask(file, category_to_hide):
16
  "Skin (Hands, Feet, Body)": [4, 5, 6, 7, 10, 11, 13, 14, 15, 16, 19, 20, 21] # Hands, Feet, Arms, Legs, Torso
17
  }
18
 
19
- # Assign colors for the categories
20
- group_colors = {
21
- "Background": "black",
22
- "Clothes": "magenta",
23
- "Face": "orange",
24
- "Hair": "brown",
25
- "Skin (Hands, Feet, Body)": "cyan"
26
- }
27
 
28
- # Create a new mask with grouped categories
29
- grouped_mask = np.zeros((*data.shape, 3), dtype=np.uint8)
30
-
31
- for category, indices in grouped_mapping.items():
32
- if category == category_to_hide:
33
- continue # Skip applying colors for the selected category to hide
34
- for idx in indices:
35
- mask = data == idx
36
- rgb = mcolors.to_rgb(group_colors[category]) # Convert color to RGB
37
- grouped_mask[mask] = [int(c * 255) for c in rgb]
38
-
39
- # Save the mask image
40
- fig, ax = plt.subplots(figsize=(6, 6))
41
- ax.imshow(grouped_mask)
42
- ax.axis("off")
43
- plt.tight_layout()
44
 
45
- # Save to file for Gradio output
46
- output_path = "output_mask.png"
47
- plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
48
- plt.close()
49
 
50
- return output_path
51
 
52
  # Define Gradio Interface
53
  demo = gr.Interface(
54
- fn=process_mask,
55
  inputs=[
56
- gr.File(label="Upload .npy Segmentation File"),
57
  gr.Radio([
58
  "Background", "Clothes", "Face", "Hair", "Skin (Hands, Feet, Body)"
59
  ], label="Select Category to Hide")
60
  ],
61
- outputs=gr.Image(label="Modified Segmentation Mask"),
62
  title="Segmentation Mask Editor",
63
- description="Upload a .npy segmentation file and select a category to mask (hide with black)."
64
  )
65
 
66
  if __name__ == "__main__":
 
2
  import numpy as np
3
  import matplotlib.pyplot as plt
4
  import matplotlib.colors as mcolors
5
+ from gradio_client import Client, handle_file
6
+ from PIL import Image
7
+ import requests
8
+ from io import BytesIO
9
 
10
+ def get_segmentation_mask(image_url):
11
+ client = Client("facebook/sapiens-seg")
12
+ result = client.predict(image=handle_file(image_url), model_name="1b", api_name="/process_image")
13
+ return np.load(result[2]) # Result[2] contains the .npy mask
14
+
15
+ def process_image(image, category_to_hide):
16
+ # Convert uploaded image to a PIL Image
17
+ image = Image.open(image.name).convert("RGB")
18
+
19
+ # Save temporarily and get the mask
20
+ image.save("temp_image.png")
21
+ mask_data = get_segmentation_mask("temp_image.png")
22
 
23
  # Define grouped categories
24
  grouped_mapping = {
 
29
  "Skin (Hands, Feet, Body)": [4, 5, 6, 7, 10, 11, 13, 14, 15, 16, 19, 20, 21] # Hands, Feet, Arms, Legs, Torso
30
  }
31
 
32
+ # Apply the mask over the original image
33
+ image_array = np.array(image)
34
+ masked_image = image_array.copy()
 
 
 
 
 
35
 
36
+ # Black out selected category
37
+ for idx in grouped_mapping[category_to_hide]:
38
+ masked_image[mask_data == idx] = [0, 0, 0]
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
+ # Convert back to PIL Image
41
+ result_image = Image.fromarray(masked_image)
 
 
42
 
43
+ return result_image
44
 
45
  # Define Gradio Interface
46
  demo = gr.Interface(
47
+ fn=process_image,
48
  inputs=[
49
+ gr.File(label="Upload an Image"),
50
  gr.Radio([
51
  "Background", "Clothes", "Face", "Hair", "Skin (Hands, Feet, Body)"
52
  ], label="Select Category to Hide")
53
  ],
54
+ outputs=gr.Image(label="Masked Image"),
55
  title="Segmentation Mask Editor",
56
+ description="Upload an image, generate a segmentation mask, and select a category to black out."
57
  )
58
 
59
  if __name__ == "__main__":