wyyadd committed on
Commit
2360da5
·
1 Parent(s): b19ada5
Files changed (5) hide show
  1. .gitignore +2 -1
  2. app.py +21 -0
  3. main.py +23 -67
  4. model.py +210 -0
  5. requirements.txt +6 -1
.gitignore CHANGED
@@ -169,4 +169,5 @@ cython_debug/
169
 
170
  # PyPI configuration file
171
  .pypirc
172
- .idea/*
 
 
169
 
170
  # PyPI configuration file
171
  .pypirc
172
+ .idea/*
173
+ data
app.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from main import get_pred_binary
4
+ from model import TARGET_LABELS
5
+
6
+
7
def get_face_type(img):
    """Classify a face image and render the prediction as display text.

    Args:
        img: Image handed over by the Gradio ``"image"`` input; it is passed
            unchanged to ``get_pred_binary``.

    Returns:
        A multi-line string. The first line is ``face_type: <int>``, where the
        integer is obtained by reading the per-label 0/1 predictions as a
        binary number (the first entry of ``TARGET_LABELS`` is the most
        significant bit). Each following line is ``<label>: <True|False>``.
    """
    pred_binary = get_pred_binary(img)
    face_type = int(''.join(map(str, pred_binary)), 2)
    # Build the per-label lines outside the f-string: reusing the enclosing
    # quote character or a backslash inside a replacement field is only legal
    # from Python 3.12 on and is a SyntaxError on older interpreters.
    label_lines = "\n".join(
        f"{label}: {bool(pred)}" for label, pred in zip(TARGET_LABELS, pred_binary)
    )
    return f"face_type: {face_type}\n{label_lines}"
13
+
14
+
15
# Minimal Gradio UI: a single image input routed through get_face_type,
# with the formatted prediction shown as plain text.
demo = gr.Interface(fn=get_face_type, inputs=["image"], outputs=["text"])

# Start the web server as soon as this module is executed.
demo.launch()
main.py CHANGED
@@ -1,15 +1,19 @@
1
  import cv2
2
  import numpy as np
3
  import requests
 
4
  from deepface import DeepFace
5
  from fastapi import FastAPI, HTTPException
6
 
 
 
7
  app = FastAPI()
8
 
9
- np.random.seed(42) # For reproducibility
10
- hyperplanes = np.random.randn(512, 5)
11
- # Optional: Normalize each hyperplane
12
- hyperplanes /= np.linalg.norm(hyperplanes, axis=0)
 
13
 
14
 
15
  @app.get("/face-type")
@@ -23,73 +27,25 @@ def get_face_type(url: str):
23
  except requests.exceptions.RequestException as e:
24
  raise HTTPException(status_code=400, detail=f"Failed to download image from URL: {str(e)}")
25
 
 
 
 
 
 
 
 
 
26
  try:
27
  embedding_objs = DeepFace.represent(
28
  img_path=img,
29
- model_name="Facenet512")
30
  except Exception as e:
31
  raise HTTPException(status_code=500, detail="No face detected.")
 
32
 
33
- ebd = np.array(embedding_objs[0]['embedding'], dtype=np.float32)
34
- # Project vector onto hyperplanes
35
- projections = np.dot(ebd, hyperplanes)
36
- # Binarize (sign function)
37
- bits = (projections >= 0).astype(int)
38
- # Convert bits to integer (LSB first)
39
- face_type = int(''.join(map(str, bits)), 2)
40
-
41
- return {"face_type": face_type}
42
 
43
- # def get_face_type(file):
44
- # try:
45
- # attribute = DeepFace.analyze(
46
- # img_path=file,
47
- # actions=['age', 'gender'],
48
- # )
49
- # gender = attribute[0]['dominant_gender']
50
- # age = attribute[0]['age']
51
- # if gender == 'Man':
52
- # if age < 10:
53
- # face_type = 7
54
- # elif age < 20:
55
- # face_type = 3
56
- # elif age < 30:
57
- # face_type = 12
58
- # elif age < 40:
59
- # face_type = 1
60
- # elif age < 50:
61
- # face_type = 15
62
- # elif age < 60:
63
- # face_type = 5
64
- # elif age < 70:
65
- # face_type = 10
66
- # else:
67
- # face_type = 8
68
- # elif gender == 'Woman':
69
- # if age < 10:
70
- # face_type = 14
71
- # elif age < 20:
72
- # face_type = 0
73
- # elif age < 30:
74
- # face_type = 4
75
- # elif age < 40:
76
- # face_type = 6
77
- # elif age < 50:
78
- # face_type = 13
79
- # elif age < 60:
80
- # face_type = 2
81
- # elif age < 70:
82
- # face_type = 9
83
- # else:
84
- # face_type = 11
85
- # else:
86
- # return "Face could not be detected."
87
- # return f"face type:{face_type}---gender:{gender}---age:{age}"
88
- # except Exception as e:
89
- # print(e)
90
- # return f"Face could not be detected."
91
- #
92
- #
93
- # if __name__ == '__main__':
94
- # demo = gr.Interface(fn=get_new_face_type, inputs="image", outputs="label")
95
- # demo.launch(share=False)
 
1
  import cv2
2
  import numpy as np
3
  import requests
4
+ import torch
5
  from deepface import DeepFace
6
  from fastapi import FastAPI, HTTPException
7
 
8
+ from model import MultiLabelClassifier
9
+
10
  app = FastAPI()
11
 
12
# Load the trained attribute classifier once at import time so every request
# reuses the same weights.
model_path = "data/classifier.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MultiLabelClassifier(embedding_dim=4096, hidden_dim=1024)
# map_location remaps the stored tensors onto the serving device. Without it
# torch.load tries to restore them on the device they were saved from, which
# fails when that device (e.g. the GPU/MPS used for training) is absent here.
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
model.to(device).eval()
17
 
18
 
19
  @app.get("/face-type")
 
27
  except requests.exceptions.RequestException as e:
28
  raise HTTPException(status_code=400, detail=f"Failed to download image from URL: {str(e)}")
29
 
30
+ pred_binary = get_pred_binary(img)
31
+
32
+ face_type = int(''.join(map(str, pred_binary)), 2)
33
+
34
+ return {"face_type": face_type}
35
+
36
+
37
def get_pred_binary(img: np.ndarray):
    """Predict the binary face attributes for one image.

    The image is embedded with DeepFace's ``VGG-Face`` model and the embedding
    of the first returned face is fed to the module-level classifier.

    Args:
        img: Image passed straight to ``DeepFace.represent`` as ``img_path``.

    Returns:
        A NumPy integer array with one 0/1 entry per classifier output; an
        entry is 1 when the sigmoid probability is strictly greater than 0.5.

    Raises:
        HTTPException: With status 500 when ``DeepFace.represent`` fails. The
            original exception is attached as the cause so the real reason
            stays visible in tracebacks and logs.
    """
    try:
        embedding_objs = DeepFace.represent(
            img_path=img,
            model_name="VGG-Face",
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail="No face detected.") from e

    # Only the first detected face is classified.
    ebd = torch.tensor(embedding_objs[0]['embedding'], dtype=torch.float32).to(device)

    with torch.no_grad():
        logits = model(ebd)
        probs = torch.sigmoid(logits).cpu().numpy()

    return (probs > 0.5).astype(int)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import pickle
3
+ from concurrent.futures import ProcessPoolExecutor, as_completed
4
+ from pathlib import Path
5
+
6
+ import pandas
7
+ import pandas as pd
8
+ import torch
9
+ from deepface import DeepFace
10
+ from sklearn.metrics import accuracy_score, recall_score, f1_score
11
+ from torch import nn
12
+ from torch.utils.data import Dataset, DataLoader
13
+ from tqdm import tqdm
14
+ import torch.nn.functional as F
15
+
16
# Attribute columns the classifier predicts. The order is significant: it
# fixes the order of the label vector and of the model outputs.
TARGET_LABELS = [
    "Male",
    "Young",
    "Oval_Face",
    "High_Cheekbones",
    "Big_Lips",
    "Big_Nose",
]
17
+
18
+
19
def load_df(target_labels: list[str], data_dir: "str | Path" = "./data"):
    """Load the partition and attribute tables and split them into train/test.

    Args:
        target_labels: Attribute columns to convert from the -1/1 encoding
            to 0/1. Other attribute columns are left untouched.
        data_dir: Directory containing ``list_eval_partition.csv`` and
            ``list_attr_celeba.csv``. Defaults to ``./data``.

    Returns:
        A ``(train_df, test_df)`` tuple of independent DataFrames. Rows whose
        ``partition`` equals 2 form the test set; every other row goes to the
        training set.
    """
    data_dir = Path(data_dir)

    # 1. Load the CSV files.
    partition_df = pd.read_csv(data_dir / "list_eval_partition.csv")
    labels_df = pd.read_csv(data_dir / "list_attr_celeba.csv")

    # 2. Merge the two tables on the image identifier.
    df = pd.merge(partition_df, labels_df, on="image_id")

    # 3. Map labels from -1/1 to 0/1 for all requested columns at once.
    labels = list(target_labels)
    if labels:
        df[labels] = (df[labels] + 1) // 2

    # 4. Split; copies keep later edits from hitting a view of ``df``.
    is_test = df["partition"] == 2
    train_df = df[~is_test].copy()
    test_df = df[is_test].copy()

    return train_df, test_df
36
+
37
+
38
class EmbeddingDataset(Dataset):
    """Dataset of cached VGG-Face embeddings paired with 0/1 attribute labels.

    On construction every image listed in ``df['image_id']`` that has no
    ``<image_id>.pkl`` cache file next to it is embedded with DeepFace and the
    embedding is written to that cache file. Items are then served from the
    cache only.
    """

    def __init__(self, df: pandas.DataFrame, target_labels: list[str]):
        """Store the table and label columns, then fill the embedding cache.

        Args:
            df: Table with an ``image_id`` column and the label columns.
            target_labels: Label columns returned for each item, in order.
        """
        self.df = df
        self.image_root = Path("./data/img_align_celeba/img_align_celeba/")
        self.target_labels = target_labels
        self.preprocess()

    def _cache_file(self, image_id: str) -> Path:
        """Return the cache file path for ``image_id``."""
        return self.image_root / f"{image_id}.pkl"

    def preprocess(self):
        """Compute and cache the embeddings that are not on disk yet.

        Work is spread over a process pool. A failure for one image is logged
        together with its id and does not stop the others. Rows whose
        embedding still could not be produced are removed from ``self.df`` so
        that ``__getitem__`` never looks for a cache file that does not exist.
        """
        to_process_images = [
            image_id for image_id in self.df['image_id']
            if not self._cache_file(image_id).exists()
        ]
        if not to_process_images:
            return
        logging.info(f"Preprocessing {len(to_process_images)} images")

        with ProcessPoolExecutor() as executor:
            # Submit the static worker with plain arguments. Submitting a
            # bound method would pickle ``self`` (and the whole DataFrame)
            # once per task.
            futures = {
                executor.submit(EmbeddingDataset._embed_to_cache, self.image_root, image_id): image_id
                for image_id in to_process_images
            }
            for future in tqdm(as_completed(futures), total=len(futures), desc="Preprocessing"):
                try:
                    future.result()
                except Exception as e:
                    logging.error(f"Error processing image {futures[future]}: {e}")

        self._drop_uncached_rows()

    def _drop_uncached_rows(self):
        """Remove rows of ``self.df`` that have no cache file on disk."""
        cached = self.df['image_id'].map(
            lambda image_id: self._cache_file(image_id).exists()
        ).astype(bool)
        missing = int((~cached).sum())
        if missing:
            logging.warning(f"Dropping {missing} images without a cached embedding")
            self.df = self.df[cached]

    @staticmethod
    def _embed_to_cache(image_root: Path, image_id: str):
        """Embed one image and pickle the embedding next to it (no-op if cached)."""
        cache_file = image_root / f"{image_id}.pkl"
        if cache_file.exists():
            return

        embedding_obj = DeepFace.represent(
            img_path=str(image_root / image_id),
            model_name="VGG-Face",
            enforce_detection=False,
        )
        embedding = torch.tensor(embedding_obj[0]["embedding"], dtype=torch.float32)

        # Save the embedding to a pickle file for future use.
        with open(cache_file, "wb") as f:
            pickle.dump(embedding, f)

    def _process_image(self, image_id: str):
        """Embed and cache a single image of this dataset."""
        self._embed_to_cache(self.image_root, image_id)

    def __len__(self):
        """Return the number of rows in the table."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(embedding, labels)`` for the row at position ``idx``."""
        row = self.df.iloc[idx]

        # NOTE: pickle.load executes arbitrary code from the file; these cache
        # files must only ever be the ones written by this class.
        with open(self._cache_file(row['image_id']), "rb") as f:
            embedding = pickle.load(f)

        labels = torch.from_numpy(row[self.target_labels].values.astype(int))
        return embedding, labels
93
+
94
+
95
class MultiLabelClassifier(nn.Module):
    """Three-layer MLP that maps an embedding to one logit per label.

    Layout: ``embedding_dim -> hidden_dim -> hidden_dim // 2 -> output_dim``
    with ReLU and dropout after each of the first two linear layers. The
    network returns raw logits; apply a sigmoid to obtain probabilities.
    """

    def __init__(
        self,
        embedding_dim: int,
        hidden_dim: int,
        dropout: float = 0.1,
        output_dim: "int | None" = None,
    ):
        """Build the network.

        Args:
            embedding_dim: Size of the input embedding.
            hidden_dim: Width of the first hidden layer; the second hidden
                layer has ``hidden_dim // 2`` units.
            dropout: Dropout probability used after both hidden layers.
            output_dim: Number of output logits. Defaults to one per entry of
                ``TARGET_LABELS``.
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.output_dim = len(TARGET_LABELS) if output_dim is None else output_dim
        self.dropout = dropout
        # Layer positions inside the Sequential are part of the checkpoint
        # format (state_dict keys such as ``classifier.0.weight``); keep them.
        self.classifier = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, self.output_dim),
        )

    def forward(self, x):
        """Return the logits for ``x`` (last dimension ``embedding_dim``)."""
        return self.classifier(x)
114
+
115
+
116
class FocalLoss(nn.Module):
    """Binary focal loss computed from logits for multi-label targets.

    Each element contributes ``alpha * (1 - p_t) ** gamma * BCE``, where
    ``p_t`` is the predicted probability of the element's true class, so
    well-classified elements are down-weighted.
    """

    def __init__(self, alpha=1.0, gamma=2.0, reduction='mean'):
        """Configure the loss.

        Args:
            alpha: Constant factor applied to every element.
            gamma: Focusing exponent applied to ``1 - p_t``.
            reduction: ``'mean'`` or ``'sum'``; any other value returns the
                unreduced, element-wise loss.
        """
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor):
        """Return the focal loss of raw ``inputs`` logits against 0/1 ``targets``."""
        targets = targets.float()
        # Compute the cross-entropy from the logits directly. Going through
        # sigmoid() first saturates to exactly 0 or 1 for large logits, which
        # makes the gradient of confidently wrong predictions vanish.
        ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        probs = torch.sigmoid(inputs)
        pt = torch.where(targets == 1, probs, 1 - probs)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss
135
+
136
+
137
def main():
    """Train the multi-label attribute classifier and save its weights.

    Steps: configure logging (file ``train.log`` plus console), load the
    train/test tables, build the embedding datasets, train for 50 epochs with
    focal loss and Adam, evaluate on the test split every 5th epoch, and write
    the final weights to ``data/classifier.pth``.
    """
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler("train.log"),
            logging.StreamHandler(),  # Also log to the console
        ],
    )
    train_df, test_df = load_df(TARGET_LABELS)
    # For quick experiments, subsample the tables here, e.g.:
    # train_df, test_df = train_df[train_df.index % 5 == 0], test_df[test_df.index % 5 == 0]
    train_dataset = EmbeddingDataset(train_df, TARGET_LABELS)
    test_dataset = EmbeddingDataset(test_df, TARGET_LABELS)

    # Shuffle the training data each epoch so mini-batches are not always
    # built from the same consecutive rows of the table.
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32)
    logging.info(f"Initializing Dataset, train_loader: {len(train_loader)}, test_loader: {len(test_loader)}")

    # Use an accelerator when one is present instead of assuming Apple MPS.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    logging.info(f"Using device: {device}")

    model = MultiLabelClassifier(embedding_dim=4096, hidden_dim=1024).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # Alternative criterion: nn.BCEWithLogitsLoss()
    criterion = FocalLoss(alpha=0.5, gamma=2.0)
    logging.info("Initializing model, optimizer and criterion")
    logging.info("Starting training")

    for epoch in range(50):
        model.train()
        running_loss = 0.0
        for inputs, targets in tqdm(train_loader, desc=f"Training Epoch {epoch}"):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets.float())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        # Report the epoch average rather than whatever the last batch gave.
        avg_train_loss = running_loss / max(len(train_loader), 1)
        logging.info(f"Epoch {epoch}, Loss: {avg_train_loss:.4f}")

        if epoch % 5 == 0:
            model.eval()
            test_loss = 0.0
            all_preds = []
            all_targets = []

            with torch.no_grad():
                for inputs, targets in tqdm(test_loader, desc=f"Test Epoch {epoch}"):
                    inputs, targets = inputs.to(device), targets.to(device)
                    outputs = model(inputs)
                    loss = criterion(outputs, targets.float())

                    test_loss += loss.item()
                    predicted = (torch.sigmoid(outputs) > 0.5).int()
                    # Move results off the device right away so the
                    # accelerator does not hold the whole test set.
                    all_preds.append(predicted.cpu())
                    all_targets.append(targets.cpu())

            avg_test_loss = test_loss / max(len(test_loader), 1)
            all_preds = torch.cat(all_preds).numpy()
            all_targets = torch.cat(all_targets).numpy()

            # For multi-label input accuracy_score is the exact-match ratio.
            accuracy = accuracy_score(all_targets, all_preds)
            recall = recall_score(all_targets, all_preds, average='macro')
            f1 = f1_score(all_targets, all_preds, average='macro')

            logging.info(
                f"Epoch {epoch} - Test Loss: {avg_test_loss:.4f}, Accuracy: {accuracy:.2f}, Recall: {recall:.2f}, F1: {f1:.2f}")

    # Store CPU tensors so the checkpoint loads on machines that lack the
    # training device.
    cpu_state = {name: tensor.cpu() for name, tensor in model.state_dict().items()}
    torch.save(cpu_state, "data/classifier.pth")


if __name__ == "__main__":
    main()
requirements.txt CHANGED
@@ -3,4 +3,9 @@ numpy
3
  requests
4
  fastapi[standard]
5
  opencv-python
6
- tf-keras
 
 
 
 
 
 
3
  requests
4
  fastapi[standard]
5
  opencv-python
6
+ tf-keras
7
+ pandas
8
+ torch
9
+ scikit-learn
10
+ gradio
11
+ tqdm