devendergarg14 commited on
Commit
6ca5592
·
verified ·
1 Parent(s): a11173b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -8
app.py CHANGED
@@ -10,9 +10,9 @@ from io import BytesIO
10
  def get_segmentation_mask(image_url):
11
  client = Client("facebook/sapiens-seg")
12
  result = client.predict(image=handle_file(image_url), model_name="1b", api_name="/process_image")
13
- return np.load(result[1]) # Result[2] contains the .npy mask
14
 
15
- def process_image(image, category_to_hide):
16
  # Convert uploaded image to a PIL Image
17
  image = Image.open(image.name).convert("RGB")
18
 
@@ -33,9 +33,10 @@ def process_image(image, category_to_hide):
33
  image_array = np.array(image)
34
  masked_image = image_array.copy()
35
 
36
- # Black out selected category
37
- for idx in grouped_mapping[category_to_hide]:
38
- masked_image[mask_data == idx] = [0, 0, 0]
 
39
 
40
  # Convert back to PIL Image
41
  result_image = Image.fromarray(masked_image)
@@ -47,13 +48,13 @@ demo = gr.Interface(
47
  fn=process_image,
48
  inputs=[
49
  gr.File(label="Upload an Image"),
50
- gr.Radio([
51
  "Background", "Clothes", "Face", "Hair", "Skin (Hands, Feet, Body)"
52
- ], label="Select Category to Hide")
53
  ],
54
  outputs=gr.Image(label="Masked Image"),
55
  title="Segmentation Mask Editor",
56
- description="Upload an image, generate a segmentation mask, and select a category to black out."
57
  )
58
 
59
  if __name__ == "__main__":
 
10
  def get_segmentation_mask(image_url):
11
  client = Client("facebook/sapiens-seg")
12
  result = client.predict(image=handle_file(image_url), model_name="1b", api_name="/process_image")
13
+ return np.load(result[2]) # Result[2] contains the .npy mask
14
 
15
+ def process_image(image, categories_to_hide):
16
  # Convert uploaded image to a PIL Image
17
  image = Image.open(image.name).convert("RGB")
18
 
 
33
  image_array = np.array(image)
34
  masked_image = image_array.copy()
35
 
36
+ # Black out selected categories
37
+ for category in categories_to_hide:
38
+ for idx in grouped_mapping.get(category, []):
39
+ masked_image[mask_data == idx] = [0, 0, 0]
40
 
41
  # Convert back to PIL Image
42
  result_image = Image.fromarray(masked_image)
 
48
  fn=process_image,
49
  inputs=[
50
  gr.File(label="Upload an Image"),
51
+ gr.CheckboxGroup([
52
  "Background", "Clothes", "Face", "Hair", "Skin (Hands, Feet, Body)"
53
+ ], label="Select Categories to Hide")
54
  ],
55
  outputs=gr.Image(label="Masked Image"),
56
  title="Segmentation Mask Editor",
57
+ description="Upload an image, generate a segmentation mask, and select categories to black out."
58
  )
59
 
60
  if __name__ == "__main__":