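"""Train a multi-label facial-attribute classifier on CelebA.

Images are embedded once with DeepFace's VGG-Face model, the embeddings are
cached on disk, and a small MLP head is trained on top of them with a focal
loss to predict the attributes in TARGET_LABELS.
"""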
import logging
import os
import pickle
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
import pandas as pd
import requests
import torch
from deepface import DeepFace
from sklearn.metrics import accuracy_score, recall_score, f1_score
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import torch.nn.functional as F
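# The six CelebA attributes to predict; stored as -1/1 in the CSV and remapped to 0/1 below.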
TARGET_LABELS = ["Male", "Young", "Oval_Face", "High_Cheekbones", "Big_Lips", "Big_Nose"]
def load_df(target_labels: list[str]):
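    """Load the CelebA partition and attribute CSVs, merge them on image_id,
    remap the requested attribute columns to 0/1, and split into train/test."""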
    # 1. load the partition and attribute CSV files
    partition_df = pd.read_csv('./data/list_eval_partition.csv')
    labels_df = pd.read_csv('./data/list_attr_celeba.csv')
    # 2. merge the two tables on image_id
    df = pd.merge(partition_df, labels_df, on='image_id')
    # 3. remap labels: CelebA stores attributes as -1/1, convert to 0/1
    for label in target_labels:
        df[label] = (df[label] + 1) // 2  # -1 -> 0, 1 -> 1
    # 4. split: partition 2 is the held-out test set
    train_df = df[df['partition'] != 2]
    test_df = df[df['partition'] == 2]
return train_df, test_df
def ensure_model_downloaded(model_path: str):
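    """Download the pretrained classifier weights from the facetype GitHub
    release if they are not already present at model_path."""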
os.makedirs(os.path.dirname(model_path), exist_ok=True)
if not os.path.exists(model_path):
logging.warning("Model not found. Downloading from GitHub...")
        response = requests.get(
            "https://github.com/wyyadd/facetype/releases/download/1.0.0/classifier.pth",
            timeout=60,
        )
if response.status_code != 200:
logging.error("Failed to download classifier.pth")
raise RuntimeError("Failed to download model.")
with open(model_path, "wb") as f:
f.write(response.content)
logging.info("Download complete.")
class EmbeddingDataset(Dataset):
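    """Dataset of (VGG-Face embedding, attribute labels) pairs.

    Embeddings are computed once with DeepFace and cached as .pkl files next
    to the images, so subsequent epochs only pay for unpickling.
    """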
    def __init__(self, df: pd.DataFrame, target_labels: list[str]):
self.df = df
self.image_root = Path("./data/img_align_celeba/img_align_celeba/")
self.target_labels = target_labels
self.preprocess()
    def preprocess(self):
        # Embed only the images that do not yet have a cached .pkl file.
        to_process_images = [image_id for image_id in self.df['image_id'] if
                             not (self.image_root / f"{image_id}.pkl").exists()]
        if not to_process_images:
            return
        logging.info(f"Preprocessing {len(to_process_images)} images")
with ProcessPoolExecutor() as executor:
futures = [executor.submit(self._process_image, image_id) for image_id in to_process_images]
for future in tqdm(as_completed(futures), total=len(futures), desc="Preprocessing"):
try:
future.result()
except Exception as e:
logging.error(f"Error processing image: {e}")
def _process_image(self, image_id: str):
# Get the image path and cache file path
image_path = self.image_root / image_id
cache_file = self.image_root / f"{image_id}.pkl"
# Check if the embedding is already cached
if not cache_file.exists():
# Generate the embedding if it is not cached
embedding_obj = DeepFace.represent(
img_path=str(image_path),
model_name="VGG-Face",
enforce_detection=False
)
embedding = torch.tensor(embedding_obj[0]["embedding"], dtype=torch.float32)
# Save the embedding to a pickle file for future use
with open(cache_file, "wb") as f:
pickle.dump(embedding, f)
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
row = self.df.iloc[idx]
# Get embedding
cache_file = self.image_root / f"{row['image_id']}.pkl"
with open(cache_file, "rb") as f:
embedding = pickle.load(f)
# Get labels
labels = torch.from_numpy(row[self.target_labels].values.astype(int))
return embedding, labels
class MultiLabelClassifier(nn.Module):
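    """Small MLP head mapping a face embedding to one logit per target attribute."""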
def __init__(self, embedding_dim: int, hidden_dim: int):
super().__init__()
self.embedding_dim = embedding_dim
self.hidden_dim = hidden_dim
self.output_dim = len(TARGET_LABELS)
self.dropout = 0.1
        self.classifier = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(self.dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(inplace=True),
            nn.Dropout(self.dropout),
            nn.Linear(hidden_dim // 2, self.output_dim),
        )
def forward(self, x):
return self.classifier(x)
class FocalLoss(nn.Module):
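    """Binary focal loss (Lin et al., 2017).

    FL(p_t) = -alpha * (1 - p_t)**gamma * log(p_t), where p_t is the model's
    probability for the true class. The (1 - p_t)**gamma factor down-weights
    easy examples so training concentrates on hard, misclassified ones.
    Expects raw logits as `inputs`.
    """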
def __init__(self, alpha=1.0, gamma=2.0, reduction='mean'):
        super().__init__()
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
    def forward(self, inputs: torch.Tensor, targets: torch.Tensor):
        targets = targets.float()
        # Compute BCE from logits for numerical stability; equivalent to
        # binary_cross_entropy(sigmoid(inputs), targets) but safe for extreme logits.
        ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        probs = torch.sigmoid(inputs)
        pt = torch.where(targets == 1, probs, 1 - probs)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
if self.reduction == 'mean':
return focal_loss.mean()
elif self.reduction == 'sum':
return focal_loss.sum()
else:
return focal_loss
def main():
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler("train.log"),
logging.StreamHandler() # Also log to the console
]
)
train_df, test_df = load_df(TARGET_LABELS)
    # Optional: subsample every 5th row for quick experiments
    # train_df, test_df = train_df[train_df.index % 5 == 0], test_df[test_df.index % 5 == 0]
train_dataset = EmbeddingDataset(train_df, TARGET_LABELS)
test_dataset = EmbeddingDataset(test_df, TARGET_LABELS)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)  # shuffle training samples each epoch
test_loader = DataLoader(test_dataset, batch_size=32)
logging.info(f"Initializing Dataset, train_loader: {len(train_loader)}, test_loader: {len(test_loader)}")
device = torch.device("mps")
logging.info(f"Using device: {device}")
model = MultiLabelClassifier(embedding_dim=4096, hidden_dim=1024).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# criterion = nn.BCEWithLogitsLoss()
criterion = FocalLoss(alpha=0.5, gamma=2.0)
logging.info("Initializing model, optimizer and criterion")
logging.info("Starting training")
    for epoch in range(50):
        model.train()
        epoch_loss = 0.0
        for inputs, targets in tqdm(train_loader, desc=f"Training Epoch {epoch}"):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets.float())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        # Report the mean loss over the epoch rather than the last batch only.
        logging.info(f"Epoch {epoch}, Loss: {epoch_loss / len(train_loader):.4f}")
if epoch % 5 == 0:
model.eval()
test_loss = 0.0
all_preds = []
all_targets = []
with torch.no_grad():
for inputs, targets in tqdm(test_loader, desc=f"Test Epoch {epoch}"):
inputs, targets = inputs.to(device), targets.to(device)
outputs = model(inputs)
loss = criterion(outputs, targets.float())
test_loss += loss.item()
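                    # Threshold each attribute's sigmoid probability at 0.5 for binary predictions.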
predicted = torch.sigmoid(outputs) > 0.5
all_preds.append(predicted)
all_targets.append(targets)
avg_test_loss = test_loss / len(test_loader)
all_preds = torch.cat(all_preds).cpu().numpy()
all_targets = torch.cat(all_targets).cpu().numpy()
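            # On multi-label arrays, accuracy_score reports exact-match (subset) accuracy.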
accuracy = accuracy_score(all_targets, all_preds)
recall = recall_score(all_targets, all_preds, average='macro')
f1 = f1_score(all_targets, all_preds, average='macro')
logging.info(
f"Epoch {epoch} - Test Loss: {avg_test_loss:.4f}, Accuracy: {accuracy:.2f}, Recall: {recall:.2f}, F1: {f1:.2f}")
torch.save(model.state_dict(), "data/classifier.pth")
if __name__ == "__main__":
main()