haydpw commited on
Commit
da8cc4f
·
1 Parent(s): c892a5c

use PNG for mask

Browse files
Files changed (2) hide show
  1. main.py +3 -2
  2. utils/helpers.py +19 -3
main.py CHANGED
@@ -10,7 +10,7 @@ import models.face_classifier as classifier
10
  from fastapi.middleware.cors import CORSMiddleware
11
  from PIL import Image
12
  from rembg import remove
13
- from utils.helpers import image_to_base64, calculate_mask_area
14
 
15
 
16
  dotenv.load_dotenv()
@@ -62,7 +62,8 @@ async def predict_image(file: UploadFile = File(...)):
62
  # change the mask to base64 and calculate the score
63
  for i in range(len(results)):
64
  mask_area = calculate_mask_area(results[i]["mask"])
65
- results[i]["mask"] = image_to_base64(results[i]["mask"])
 
66
  if results[i]["label"] == "background":
67
  continue
68
  print(f"{results[i]['label']} area: {mask_area}")
 
10
  from fastapi.middleware.cors import CORSMiddleware
11
  from PIL import Image
12
  from rembg import remove
13
+ from utils.helpers import image_to_base64, calculate_mask_area, process_image
14
 
15
 
16
  dotenv.load_dotenv()
 
62
  # change the mask to base64 and calculate the score
63
  for i in range(len(results)):
64
  mask_area = calculate_mask_area(results[i]["mask"])
65
+ processed_image = process_image(results[i]["mask"])
66
+ results[i]["mask"] = image_to_base64(processed_image, "PNG")
67
  if results[i]["label"] == "background":
68
  continue
69
  print(f"{results[i]['label']} area: {mask_area}")
utils/helpers.py CHANGED
@@ -3,13 +3,29 @@ import io
3
  import base64
4
  import numpy as np
5
 
6
- def image_to_base64(image: Image.Image) -> str:
7
  buffered = io.BytesIO()
8
 
9
- image.save(buffered, format="JPEG")
10
  return base64.b64encode(buffered.getvalue()).decode("utf-8")
11
 
12
  def calculate_mask_area(mask: Image.Image, background=False) -> int:
13
  mask_array = np.array(mask)
14
  non_zero_pixels = np.count_nonzero(mask_array)
15
- return non_zero_pixels
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import base64
4
  import numpy as np
5
 
6
+ def image_to_base64(image: Image.Image,format="JPEG") -> str:
7
  buffered = io.BytesIO()
8
 
9
+ image.save(buffered, format=format)
10
  return base64.b64encode(buffered.getvalue()).decode("utf-8")
11
 
12
  def calculate_mask_area(mask: Image.Image, background=False) -> int:
13
  mask_array = np.array(mask)
14
  non_zero_pixels = np.count_nonzero(mask_array)
15
+ return non_zero_pixels
16
+
17
+ def process_image(input_image: Image.Image) -> Image.Image:
18
+ if input_image.mode != 'RGBA':
19
+ input_image = input_image.convert('RGBA')
20
+
21
+ data = np.array(input_image)
22
+ # Split the image into its component channels
23
+ r, g, b, a = data.T
24
+
25
+ # Create a mask where all pixels that are black (0,0,0) will have 0 alpha
26
+ black_areas = (r == 0) & (g == 0) & (b == 0)
27
+
28
+ # Apply the mask to the alpha channel
29
+ data[..., 3][black_areas] = 0
30
+
31
+ return Image.fromarray(data)