|
import logging |
|
import os |
|
import pickle |
|
from concurrent.futures import ProcessPoolExecutor, as_completed |
|
from pathlib import Path |
|
|
|
import pandas |
|
import pandas as pd |
|
import requests |
|
import torch |
|
from deepface import DeepFace |
|
from sklearn.metrics import accuracy_score, recall_score, f1_score |
|
from torch import nn |
|
from torch.utils.data import Dataset, DataLoader |
|
from tqdm import tqdm |
|
import torch.nn.functional as F |
|
|
|
# CelebA attribute columns (from list_attr_celeba.csv) that the classifier
# predicts. Label tensors and model logits follow this order.
TARGET_LABELS = [
    "Male",
    "Young",
    "Oval_Face",
    "High_Cheekbones",
    "Big_Lips",
    "Big_Nose",
]
|
|
|
|
|
def load_df(target_labels: list[str], data_dir: str = "./data"):
    """Load the CelebA partition and attribute tables and split them.

    Reads ``list_eval_partition.csv`` and ``list_attr_celeba.csv`` from
    *data_dir* and inner-joins them on ``image_id``. It then rescales every
    column in *target_labels* with ``(x + 1) // 2``. That maps -1 -> 0 and
    1 -> 1; it assumes the attributes use CelebA's -1/1 encoding.

    Args:
        target_labels: Attribute columns to convert to 0/1 labels. Other
            attribute columns are left untouched.
        data_dir: Directory (str or path-like) containing both CSV files.
            Defaults to ``./data`` relative to the working directory.

    Returns:
        ``(train_df, test_df)``. ``train_df`` holds rows whose ``partition``
        is anything other than 2; ``test_df`` holds rows whose ``partition``
        is 2. Both are independent copies, so callers may modify them
        without affecting each other.
    """
    root = Path(data_dir)
    partition_df = pd.read_csv(root / "list_eval_partition.csv")
    labels_df = pd.read_csv(root / "list_attr_celeba.csv")

    df = pd.merge(partition_df, labels_df, on="image_id")

    for label in target_labels:
        df[label] = (df[label] + 1) // 2

    # Copy the boolean-mask slices so downstream edits never touch `df`.
    train_df = df[df["partition"] != 2].copy()
    test_df = df[df["partition"] == 2].copy()

    return train_df, test_df
|
|
|
|
|
def ensure_model_downloaded(
    model_path: str,
    url: str = "https://github.com/wyyadd/facetype/releases/download/1.0.0/classifier.pth",
    timeout: float = 60.0,
):
    """Download the classifier weights to *model_path* unless the file exists.

    Args:
        model_path: Destination file. Parent directories are created as needed.
            A bare filename (no directory part) is also accepted.
        url: Location to fetch the weights from. Defaults to the project's
            GitHub release asset.
        timeout: Seconds passed to ``requests.get`` so a stalled connection
            cannot hang the caller indefinitely.

    Raises:
        RuntimeError: If the server responds with a status code other than 200.
        requests.RequestException: Network errors and timeouts raised by
            ``requests.get`` propagate unchanged.
    """
    if os.path.exists(model_path):
        return

    directory = os.path.dirname(model_path)
    # dirname() is "" for a bare filename, and os.makedirs("") raises.
    if directory:
        os.makedirs(directory, exist_ok=True)

    logging.warning("Model not found. Downloading from GitHub...")
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        logging.error("Failed to download classifier.pth")
        raise RuntimeError("Failed to download model.")

    # Write to a sibling temp file, then rename. An interrupted download then
    # never leaves a truncated file that the exists() check would accept.
    tmp_path = f"{model_path}.part"
    with open(tmp_path, "wb") as f:
        f.write(response.content)
    os.replace(tmp_path, model_path)
    logging.info("Download complete.")
|
|
|
|
|
class EmbeddingDataset(Dataset):
    """Dataset of cached DeepFace VGG-Face embeddings with multi-label targets.

    On construction, ``preprocess`` runs for every ``image_id`` in *df* whose
    ``<image_id>.pkl`` cache is missing under *image_root*. It computes the
    embedding with ``DeepFace.represent`` in a process pool and pickles the
    resulting float32 tensor to that cache. ``__getitem__`` only reads caches.

    Args:
        df: Table with an ``image_id`` column plus one column per target label.
        target_labels: Label columns returned, in this order, by ``__getitem__``.
        image_root: Directory (str or path-like) holding the images and their
            ``.pkl`` caches. Defaults to the CelebA aligned-image folder.
    """

    def __init__(self, df: pandas.DataFrame, target_labels: list[str],
                 image_root: str = "./data/img_align_celeba/img_align_celeba/"):
        self.df = df
        self.image_root = Path(image_root)
        self.target_labels = target_labels
        self.preprocess()

    def preprocess(self):
        """Compute and cache the embedding of every image lacking a cache file.

        A failure is logged per image and does not abort the run. Failed rows
        have no cache, so ``__getitem__`` raises ``FileNotFoundError`` for them.
        """
        pending = [
            image_id
            for image_id in self.df["image_id"]
            if not (self.image_root / f"{image_id}.pkl").exists()
        ]
        if not pending:
            return
        logging.info(f"Preprocessing {len(pending)} images")

        failed = 0
        with ProcessPoolExecutor() as executor:
            # Submit the static helper with plain arguments. Submitting a bound
            # method would pickle `self` (and the whole DataFrame) for every task.
            future_to_id = {
                executor.submit(self._embed_to_cache, self.image_root, image_id): image_id
                for image_id in pending
            }
            for future in tqdm(as_completed(future_to_id), total=len(future_to_id), desc="Preprocessing"):
                try:
                    future.result()
                except Exception as e:
                    failed += 1
                    logging.error(f"Error processing image {future_to_id[future]}: {e}")
        if failed:
            logging.warning(f"{failed} of {len(pending)} images could not be embedded; they have no cache file")

    @staticmethod
    def _embed_to_cache(image_root: Path, image_id: str):
        """Embed ``image_root / image_id`` and pickle the tensor next to it.

        Does nothing if the cache file already exists.
        """
        cache_file = image_root / f"{image_id}.pkl"
        if cache_file.exists():
            return

        embedding_obj = DeepFace.represent(
            img_path=str(image_root / image_id),
            model_name="VGG-Face",
            enforce_detection=False,
        )
        # Only the first returned representation is kept; presumably each
        # aligned crop yields one face. TODO confirm.
        embedding = torch.tensor(embedding_obj[0]["embedding"], dtype=torch.float32)

        # Write to a temp sibling, then rename. A killed worker then never
        # leaves a truncated .pkl that would be mistaken for a valid cache.
        tmp_file = cache_file.with_name(cache_file.name + ".tmp")
        with open(tmp_file, "wb") as f:
            pickle.dump(embedding, f)
        tmp_file.replace(cache_file)

    def _process_image(self, image_id: str):
        """Embed *image_id* into its cache file under this dataset's image root."""
        self._embed_to_cache(self.image_root, image_id)

    def __len__(self):
        """Return the number of rows in the underlying DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(embedding, labels)`` for positional row *idx*.

        ``embedding`` is the unpickled cached tensor. ``labels`` is an integer
        tensor with one entry per column of ``target_labels``.
        """
        row = self.df.iloc[idx]
        # The caches are written by _embed_to_cache. Unpickling can execute
        # code, so never point image_root at untrusted .pkl files.
        with open(self.image_root / f"{row['image_id']}.pkl", "rb") as f:
            embedding = pickle.load(f)
        labels = torch.from_numpy(row[self.target_labels].values.astype(int))
        return embedding, labels
|
|
|
|
|
class MultiLabelClassifier(nn.Module):
    """MLP head mapping an embedding to one logit per target attribute.

    Layers: Linear(embedding_dim, hidden_dim) -> ReLU -> Dropout ->
    Linear(hidden_dim, hidden_dim // 2) -> ReLU -> Dropout ->
    Linear(hidden_dim // 2, output_dim). This order yields the
    ``classifier.0/3/6`` state-dict keys, so existing checkpoints still load.

    Args:
        embedding_dim: Size of the input embedding.
        hidden_dim: Width of the first hidden layer; the second is half of it.
        output_dim: Number of output logits. Defaults to ``len(TARGET_LABELS)``.
        dropout: Dropout probability applied after each hidden layer.
    """

    def __init__(self, embedding_dim: int, hidden_dim: int,
                 output_dim: "int | None" = None, dropout: float = 0.1):
        super().__init__()
        if output_dim is None:
            output_dim = len(TARGET_LABELS)
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.dropout = dropout
        self.classifier = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, output_dim),
        )

    def forward(self, x):
        """Return raw logits (no sigmoid), mapping ``(*, embedding_dim)`` to ``(*, output_dim)``."""
        return self.classifier(x)
|
|
|
|
|
class FocalLoss(nn.Module):
    """Element-wise binary focal loss over independent logits.

    ``loss = alpha * (1 - p_t) ** gamma * BCE(logit, target)``. Here ``p_t`` is
    the sigmoid probability assigned to the true label. ``alpha`` is one scalar
    weight applied to every element, not a positive/negative balance factor.

    Args:
        alpha: Scalar multiplier on the loss.
        gamma: Focusing exponent. ``gamma=0`` gives ``alpha`` times plain BCE.
        reduction: ``'mean'`` or ``'sum'``. Any other value returns the
            unreduced element-wise tensor.
    """

    def __init__(self, alpha=1.0, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor):
        """Compute the focal loss.

        Args:
            inputs: Raw logits.
            targets: 0/1 labels with the same shape as *inputs* (any dtype;
                cast to the logits' dtype).
        """
        targets = targets.to(inputs.dtype)
        # BCE computed from logits stays finite and keeps a gradient when the
        # sigmoid saturates. BCE on sigmoid() outputs clamps log(0), and its
        # gradient becomes exactly 0 for confidently wrong predictions.
        ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        probs = torch.sigmoid(inputs)
        pt = torch.where(targets == 1, probs, 1 - probs)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss
|
|
|
|
|
def _select_device() -> torch.device:
    """Return the MPS device if available, else CUDA if available, else CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def _train_one_epoch(model, loader, optimizer, criterion, device, epoch) -> float:
    """Run one optimisation pass over *loader* and return the mean batch loss."""
    model.train()
    # Sum on-device and read once at the end. A per-batch .item() would force
    # a host/device synchronisation on every step.
    total_loss = torch.zeros((), device=device)
    for inputs, targets in tqdm(loader, desc=f"Training Epoch {epoch}"):
        inputs, targets = inputs.to(device), targets.to(device)
        loss = criterion(model(inputs), targets.float())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.detach()
    return total_loss.item() / max(len(loader), 1)


def _evaluate(model, loader, criterion, device, epoch):
    """Evaluate *model* on *loader* without gradient tracking.

    Predictions threshold the sigmoid of each logit at 0.5.

    Returns:
        ``(avg_loss, accuracy, recall, f1)``. ``accuracy`` is scikit-learn's
        ``accuracy_score`` on the multilabel indicator arrays (exact match of
        all labels per sample). ``recall`` and ``f1`` are macro-averaged over
        labels.
    """
    model.eval()
    test_loss = 0.0
    all_preds, all_targets = [], []
    with torch.no_grad():
        for inputs, targets in tqdm(loader, desc=f"Test Epoch {epoch}"):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            test_loss += criterion(outputs, targets.float()).item()
            all_preds.append(torch.sigmoid(outputs) > 0.5)
            all_targets.append(targets)

    avg_loss = test_loss / len(loader)
    preds = torch.cat(all_preds).cpu().numpy()
    truth = torch.cat(all_targets).cpu().numpy()
    return (
        avg_loss,
        accuracy_score(truth, preds),
        recall_score(truth, preds, average='macro'),
        f1_score(truth, preds, average='macro'),
    )


def main():
    """Train the attribute classifier on cached embeddings and save its weights.

    Steps:
      * Loads the splits with ``load_df`` and builds ``EmbeddingDataset``s,
        which computes any missing embedding caches.
      * Trains ``MultiLabelClassifier`` with ``FocalLoss`` and Adam for 50 epochs.
      * Evaluates on the test split every 5 epochs and after the final epoch.
      * Writes the ``state_dict`` to ``data/classifier.pth``.

    Logs go to ``train.log`` and to a stream handler.
    """
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler("train.log"),
            logging.StreamHandler(),
        ],
    )
    num_epochs = 50
    eval_every = 5

    train_df, test_df = load_df(TARGET_LABELS)
    train_dataset = EmbeddingDataset(train_df, TARGET_LABELS)
    test_dataset = EmbeddingDataset(test_df, TARGET_LABELS)

    # Reshuffle every epoch so mini-batches are not drawn in the fixed CSV row order.
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32)
    logging.info(f"Initializing Dataset, train_loader: {len(train_loader)}, test_loader: {len(test_loader)}")

    # Fall back to CUDA/CPU instead of crashing on machines without MPS.
    device = _select_device()
    logging.info(f"Using device: {device}")

    model = MultiLabelClassifier(embedding_dim=4096, hidden_dim=1024).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = FocalLoss(alpha=0.5, gamma=2.0)
    logging.info("Initializing model, optimizer and criterion")
    logging.info("Starting training")

    for epoch in range(num_epochs):
        train_loss = _train_one_epoch(model, train_loader, optimizer, criterion, device, epoch)
        logging.info(f"Epoch {epoch}, Loss: {train_loss:.4f}")

        # Also evaluate after the last epoch, which the periodic check would skip.
        if epoch % eval_every == 0 or epoch == num_epochs - 1:
            avg_test_loss, accuracy, recall, f1 = _evaluate(model, test_loader, criterion, device, epoch)
            logging.info(
                f"Epoch {epoch} - Test Loss: {avg_test_loss:.4f}, Accuracy: {accuracy:.2f}, Recall: {recall:.2f}, F1: {f1:.2f}")

    torch.save(model.state_dict(), "data/classifier.pth")
|
|
|
|
|
if __name__ == "__main__":
    # Keep the entry point guarded. EmbeddingDataset.preprocess uses a
    # ProcessPoolExecutor, and workers started with the "spawn" method (the
    # default on macOS and Windows) re-import this module. Unguarded, they
    # would run main() again.
    main()
|
|