facetype / model.py
wyyadd's picture
feat: model download
4130e74
import logging
import os
import pickle
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
import pandas
import pandas as pd
import requests
import torch
from deepface import DeepFace
from sklearn.metrics import accuracy_score, recall_score, f1_score
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import torch.nn.functional as F
TARGET_LABELS = ["Male", "Young", "Oval_Face", "High_Cheekbones", "Big_Lips", "Big_Nose"]
def load_df(target_labels: list[str]):
# 1. load CSV file
partition_df = pd.read_csv('./data/list_eval_partition.csv')
labels_df = pd.read_csv('./data/list_attr_celeba.csv')
# 2. merge two tables
df = pd.merge(partition_df, labels_df, on='image_id')
# 3. mapping label: -1 -> 0
for label in target_labels:
df[label] = (df[label] + 1) // 2 # 转成 0/1
# 4. subset
train_df = df[df['partition'] != 2]
test_df = df[df['partition'] == 2]
return train_df, test_df
def ensure_model_downloaded(model_path: str):
os.makedirs(os.path.dirname(model_path), exist_ok=True)
if not os.path.exists(model_path):
logging.warning("Model not found. Downloading from GitHub...")
response = requests.get("https://github.com/wyyadd/facetype/releases/download/1.0.0/classifier.pth")
if response.status_code != 200:
logging.error("Failed to download classifier.pth")
raise RuntimeError("Failed to download model.")
with open(model_path, "wb") as f:
f.write(response.content)
logging.info("Download complete.")
class EmbeddingDataset(Dataset):
def __init__(self, df: pandas.DataFrame, target_labels: list[str]):
self.df = df
self.image_root = Path("./data/img_align_celeba/img_align_celeba/")
self.target_labels = target_labels
self.preprocess()
def preprocess(self):
to_process_images = [image_id for image_id in self.df['image_id'] if
not (self.image_root / f"{image_id}.pkl").exists()]
if len(to_process_images) > 0:
logging.info(f"Preprocessing {len(to_process_images)} images")
else:
return
with ProcessPoolExecutor() as executor:
futures = [executor.submit(self._process_image, image_id) for image_id in to_process_images]
for future in tqdm(as_completed(futures), total=len(futures), desc="Preprocessing"):
try:
future.result()
except Exception as e:
logging.error(f"Error processing image: {e}")
def _process_image(self, image_id: str):
# Get the image path and cache file path
image_path = self.image_root / image_id
cache_file = self.image_root / f"{image_id}.pkl"
# Check if the embedding is already cached
if not cache_file.exists():
# Generate the embedding if it is not cached
embedding_obj = DeepFace.represent(
img_path=str(image_path),
model_name="VGG-Face",
enforce_detection=False
)
embedding = torch.tensor(embedding_obj[0]["embedding"], dtype=torch.float32)
# Save the embedding to a pickle file for future use
with open(cache_file, "wb") as f:
pickle.dump(embedding, f)
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
row = self.df.iloc[idx]
# Get embedding
cache_file = self.image_root / f"{row['image_id']}.pkl"
with open(cache_file, "rb") as f:
embedding = pickle.load(f)
# Get labels
labels = torch.from_numpy(row[self.target_labels].values.astype(int))
return embedding, labels
class MultiLabelClassifier(nn.Module):
def __init__(self, embedding_dim: int, hidden_dim: int):
super().__init__()
self.embedding_dim = embedding_dim
self.hidden_dim = hidden_dim
self.output_dim = len(TARGET_LABELS)
self.dropout = 0.1
self.classifier = nn.Sequential(
nn.Linear(embedding_dim, self.hidden_dim),
nn.ReLU(inplace=True),
nn.Dropout(self.dropout),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.ReLU(inplace=True),
nn.Dropout(self.dropout),
nn.Linear(hidden_dim // 2, len(TARGET_LABELS)),
)
def forward(self, x):
return self.classifier(x)
class FocalLoss(nn.Module):
def __init__(self, alpha=1.0, gamma=2.0, reduction='mean'):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
def forward(self, inputs: torch.Tensor, targets: torch.Tensor):
probs = torch.sigmoid(inputs)
ce_loss = F.binary_cross_entropy(probs, targets.float(), reduction='none')
pt = torch.where(targets == 1, probs, 1 - probs)
focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
if self.reduction == 'mean':
return focal_loss.mean()
elif self.reduction == 'sum':
return focal_loss.sum()
else:
return focal_loss
def main():
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler("train.log"),
logging.StreamHandler() # Also log to the console
]
)
train_df, test_df = load_df(TARGET_LABELS)
# filter df
# train_df, test_df = train_df[train_df.index % 5 == 0], test_df[test_df.index % 5 == 0]
train_dataset = EmbeddingDataset(train_df, TARGET_LABELS)
test_dataset = EmbeddingDataset(test_df, TARGET_LABELS)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32)
logging.info(f"Initializing Dataset, train_loader: {len(train_loader)}, test_loader: {len(test_loader)}")
device = torch.device("mps")
logging.info(f"Using device: {device}")
model = MultiLabelClassifier(embedding_dim=4096, hidden_dim=1024).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# criterion = nn.BCEWithLogitsLoss()
criterion = FocalLoss(alpha=0.5, gamma=2.0)
logging.info("Initializing model, optimizer and criterion")
logging.info("Starting training")
for epoch in range(50):
model.train()
for inputs, targets in tqdm(train_loader, desc=f"Training Epoch {epoch}"):
inputs, targets = inputs.to(device), targets.to(device)
outputs = model(inputs)
loss = criterion(outputs, targets.float())
optimizer.zero_grad()
loss.backward()
optimizer.step()
logging.info(f"Epoch {epoch}, Loss: {loss.item():.4f}")
if epoch % 5 == 0:
model.eval()
test_loss = 0.0
all_preds = []
all_targets = []
with torch.no_grad():
for inputs, targets in tqdm(test_loader, desc=f"Test Epoch {epoch}"):
inputs, targets = inputs.to(device), targets.to(device)
outputs = model(inputs)
loss = criterion(outputs, targets.float())
test_loss += loss.item()
predicted = torch.sigmoid(outputs) > 0.5
all_preds.append(predicted)
all_targets.append(targets)
avg_test_loss = test_loss / len(test_loader)
all_preds = torch.cat(all_preds).cpu().numpy()
all_targets = torch.cat(all_targets).cpu().numpy()
accuracy = accuracy_score(all_targets, all_preds)
recall = recall_score(all_targets, all_preds, average='macro')
f1 = f1_score(all_targets, all_preds, average='macro')
logging.info(
f"Epoch {epoch} - Test Loss: {avg_test_loss:.4f}, Accuracy: {accuracy:.2f}, Recall: {recall:.2f}, F1: {f1:.2f}")
torch.save(model.state_dict(), "data/classifier.pth")
if __name__ == "__main__":
main()