Sartc commited on
Commit
017b291
·
verified ·
1 Parent(s): 3ba3ce7

Upload face_parsing.py

Browse files
Files changed (1) hide show
  1. face_parsing.py +220 -0
face_parsing.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import json
3
+ import torch
4
+ from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
5
+ from PIL import Image
6
+ from PIL import ImageOps
7
+ import matplotlib.pyplot as plt
8
+
9
+ def segmentation(input_img_path):
10
+ with open('face-parsing/config.json', 'r') as file:
11
+ data = json.load(file)
12
+
13
+ for key, value in data["id2label"].items():
14
+ print(f"{key}: {value}")
15
+
16
+ image_processor = SegformerImageProcessor.from_pretrained("face-parsing")
17
+ model = SegformerForSemanticSegmentation.from_pretrained("face-parsing")
18
+
19
+ # input_img_path = "akshay kumar img.jpeg"
20
+ if isinstance(input_img_path, str): # It's a path
21
+ image = Image.open(input_img_path)
22
+ else: # It's already an image object
23
+ image = input_img_path
24
+
25
+ image = image.convert("RGB")
26
+
27
+ # new_size = (128, 128)
28
+ # image = image.resize(new_size)
29
+ # plt.imshow(image)
30
+ # plt.axis("off")
31
+ # plt.show()
32
+
33
+ inputs = image_processor(images=image, return_tensors="pt")
34
+ outputs = model(**inputs)
35
+
36
+ color_map = np.array([
37
+ [255, 255, 255], # 0 background
38
+ [255, 0, 0], # 1 skin
39
+ [0, 255, 0], # 2 nose
40
+ [0, 0, 255], # 3 eye_g
41
+ [255, 255, 0], # 4 l_eye
42
+ [255, 0, 255], # 5 r_eye
43
+ [0, 255, 255], # 6 l_brow
44
+ [192, 192, 192], # 7 r_brow
45
+ [128, 128, 128], # 8 l_ear
46
+ [128, 0, 0], # 9 r_ear
47
+ [128, 128, 0], # 10 mouth
48
+ [0, 128, 0], # 11 u_lip
49
+ [0, 128, 128], # 12 l_lip
50
+ [0, 0, 128], # 13 hair
51
+ [255, 165, 0], # 14 hat
52
+ [75, 0, 130], # 15 ear_r
53
+ [240, 230, 140], # 16 neck_l
54
+ [255, 20, 147], # 17 neck
55
+ [100, 149, 237] # 18 cloth
56
+ ])
57
+
58
+ predicted_classes = torch.argmax(outputs["logits"], dim=1).squeeze().cpu().numpy()
59
+ segmentation_map = color_map[predicted_classes]
60
+
61
+ outputs["logits"] = outputs["logits"].squeeze()
62
+ outputs["logits"].shape
63
+
64
+ img = np.array(image)
65
+ print(img.shape)
66
+ print(segmentation_map.shape)
67
+
68
+ face_mask = outputs["logits"][1]
69
+ print(face_mask)
70
+
71
+ # plt.figure(figsize=(15, 7))
72
+ # plt.subplot(1, 2, 1)
73
+ # plt.title("Original Image")
74
+ # plt.imshow(image)
75
+ # plt.axis('off')
76
+
77
+ # plt.subplot(1, 2, 2)
78
+ # plt.title("Predicted Segmentation Map")
79
+ # plt.imshow(segmentation_map)
80
+ # plt.axis('off')
81
+
82
+ # plt.show()
83
+
84
+ new_size = (128, 128)
85
+ image = image.resize(new_size)
86
+ original_image_np = np.array(image)
87
+ segmented_image_np = np.array(segmentation_map)
88
+
89
+ skin_color = [255, 0, 0]
90
+ eyeg_color = [0, 0, 255]
91
+ nose_color = [0, 255, 0]
92
+ leye_color = [255, 255, 0]
93
+ reye_color = [255, 0, 255]
94
+ lbrow_color = [0, 255, 255]
95
+ rbrow_color = [192, 192, 192]
96
+ lear_color = [128, 128, 128]
97
+ rear_color = [128, 0, 0]
98
+ mouth_color = [128, 128, 0]
99
+ ulip_color = [0, 128, 0]
100
+ llip_color = [0, 128, 128]
101
+ hair_color = [0, 0, 128]
102
+ hat_color = [255, 165, 0]
103
+ neck_color = [255, 20, 147]
104
+
105
+ skin_mask = np.all(segmented_image_np == skin_color, axis=-1)
106
+ eyeg_mask = np.all(segmented_image_np == eyeg_color, axis=-1)
107
+ nose_mask = np.all(segmented_image_np == nose_color, axis=-1)
108
+ leye_mask = np.all(segmented_image_np == leye_color, axis=-1)
109
+ reye_mask = np.all(segmented_image_np == reye_color, axis=-1)
110
+ lbrow_mask = np.all(segmented_image_np == lbrow_color, axis=-1)
111
+ rbrow_mask = np.all(segmented_image_np == rbrow_color, axis=-1)
112
+ lear_mask = np.all(segmented_image_np == lear_color, axis=-1)
113
+ rear_mask = np.all(segmented_image_np == rear_color, axis=-1)
114
+ mouth_mask = np.all(segmented_image_np == mouth_color, axis=-1)
115
+ ulip_mask = np.all(segmented_image_np == ulip_color, axis=-1)
116
+ llip_mask = np.all(segmented_image_np == llip_color, axis=-1)
117
+ hair_mask = np.all(segmented_image_np == hair_color, axis=-1)
118
+ hat_mask = np.all(segmented_image_np == hat_color, axis=-1)
119
+ neck_mask = np.all(segmented_image_np == neck_color, axis=-1)
120
+
121
+ # combining all the masks
122
+ combined_mask = np.logical_or.reduce((
123
+ skin_mask, eyeg_mask, nose_mask, leye_mask, reye_mask,
124
+ lbrow_mask, rbrow_mask, lear_mask, rear_mask,
125
+ mouth_mask, ulip_mask, llip_mask, hair_mask,
126
+ hat_mask, neck_mask
127
+ ))
128
+
129
+ # applying the combined mask to the original image
130
+ selected_regions = np.full_like(original_image_np, 255)
131
+
132
+ selected_regions[combined_mask] = original_image_np[combined_mask]
133
+
134
+ # # Visualize the results
135
+ # plt.figure(figsize=(15, 5))
136
+
137
+ # # Display the original image
138
+ # plt.subplot(1, 3, 1)
139
+ # plt.imshow(image)
140
+ # plt.title("Original Image")
141
+ # plt.axis("off")
142
+
143
+ # # Display the segmented image
144
+ # plt.subplot(1, 3, 2)
145
+ # plt.imshow(segmentation_map)
146
+ # plt.title("Segmented Image")
147
+ # plt.axis("off")
148
+
149
+ # # Display the extracted regions
150
+ # plt.subplot(1, 3, 3)
151
+ # plt.imshow(Image.fromarray(selected_regions))
152
+ # plt.title("Selected Regions")
153
+ # plt.axis("off")
154
+
155
+ # plt.tight_layout()
156
+ # plt.show()
157
+
158
+ selected_regions = Image.fromarray(selected_regions)
159
+ # selected_regions.save("only_face.jpg")
160
+ return selected_regions
161
+
162
+ """challenges as of now:
163
+ 1. meme face image is not eradicated (manually remove this). (done!)
164
+ 2. note down the width of the neck of meme image and make adjustements accordingly so that the target face gets fixed on the meme image. (pending)
165
+ 3. background in the person's image is black, which is messing up with hair, fix that.
166
+ """
167
+
168
+ def integration_with_meme(input_img_path, face_x, face_y, face_width, face_height):
169
+
170
+ person_image = segmentation(input_img_path)
171
+ # person_image = Image.open('only_face.jpg')
172
+ meme_image = Image.open('chillguy.jpeg')
173
+
174
+ # Convert meme image to RGBA (for transparency handling) and to a NumPy array
175
+ meme_image = meme_image.convert("RGBA")
176
+ meme_data = np.array(meme_image)
177
+
178
+ # Define the coordinates of the face region in the meme image
179
+ # face_x, face_y, face_width, face_height = 0, 40, 180, 110 # Adjust based on meme image
180
+
181
+ # Clamp the face region to ensure it is within bounds
182
+ meme_height, meme_width = meme_data.shape[:2]
183
+ face_width = min(face_width, meme_width - face_x)
184
+ face_height = min(face_height, meme_height - face_y)
185
+
186
+ # Resize the person's image to fit the face region
187
+ person_resized = person_image.resize((face_width, face_height)).convert("RGBA")
188
+
189
+ person_resized = ImageOps.mirror(person_resized)
190
+
191
+ person_data = np.array(person_resized)
192
+
193
+ r, g, b, a = person_data[..., 0], person_data[..., 1], person_data[..., 2], person_data[..., 3]
194
+ white_areas = (r > 230) & (g > 230) & (b > 230)
195
+ person_data[white_areas, 3] = 0
196
+
197
+ person_resized = Image.fromarray(person_data)
198
+
199
+ face_region = meme_data[face_y:face_y+face_height, face_x:face_x+face_width]
200
+
201
+ face_region_resized = Image.fromarray(face_region).resize((face_width, face_height))
202
+
203
+ blended_region = Image.alpha_composite(face_region_resized, person_resized)
204
+
205
+ blended_region_data = np.array(blended_region)
206
+ meme_data[face_y:face_y+face_height, face_x:face_x+face_width] = blended_region_data
207
+
208
+ result_image = Image.fromarray(meme_data)
209
+
210
+ # plt.imshow(result_image)
211
+ # plt.axis("off")
212
+ # plt.show()
213
+
214
+ meme_image = np.array(meme_image)
215
+ print(meme_image.shape)
216
+ result_image = np.array(result_image)
217
+ print(result_image.shape)
218
+ return result_image
219
+
220
+ # integration_with_meme(input_img_path="akshay kumar img.jpeg", face_x=0, face_y=40, face_width=180, face_height=110)