feat: model download
Browse files
app.py
CHANGED
@@ -8,7 +8,7 @@ def get_face_type(img):
|
|
8 |
pred_binary = get_pred_binary(img)
|
9 |
result = [f"{label}: {bool(pred)}" for label, pred in zip(TARGET_LABELS, pred_binary)]
|
10 |
face_type = int(''.join(map(str, pred_binary)), 2)
|
11 |
-
result = f"face_type: {face_type}\n{
|
12 |
return result
|
13 |
|
14 |
|
|
|
8 |
pred_binary = get_pred_binary(img)
|
9 |
result = [f"{label}: {bool(pred)}" for label, pred in zip(TARGET_LABELS, pred_binary)]
|
10 |
face_type = int(''.join(map(str, pred_binary)), 2)
|
11 |
+
result = f"face_type: {face_type}\n{'\n'.join(result)}"
|
12 |
return result
|
13 |
|
14 |
|
main.py
CHANGED
@@ -5,11 +5,12 @@ import torch
|
|
5 |
from deepface import DeepFace
|
6 |
from fastapi import FastAPI, HTTPException
|
7 |
|
8 |
-
from model import MultiLabelClassifier
|
9 |
|
10 |
app = FastAPI()
|
11 |
|
12 |
-
model_path = "classifier.pth"
|
|
|
13 |
model = MultiLabelClassifier(embedding_dim=4096, hidden_dim=1024)
|
14 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
15 |
model.load_state_dict(torch.load(model_path, weights_only=True))
|
|
|
5 |
from deepface import DeepFace
|
6 |
from fastapi import FastAPI, HTTPException
|
7 |
|
8 |
+
from model import MultiLabelClassifier, ensure_model_downloaded
|
9 |
|
10 |
app = FastAPI()
|
11 |
|
12 |
+
model_path = "data/classifier.pth"
|
13 |
+
ensure_model_downloaded(model_path)
|
14 |
model = MultiLabelClassifier(embedding_dim=4096, hidden_dim=1024)
|
15 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
16 |
model.load_state_dict(torch.load(model_path, weights_only=True))
|
model.py
CHANGED
@@ -1,10 +1,12 @@
|
|
1 |
import logging
|
|
|
2 |
import pickle
|
3 |
from concurrent.futures import ProcessPoolExecutor, as_completed
|
4 |
from pathlib import Path
|
5 |
|
6 |
import pandas
|
7 |
import pandas as pd
|
|
|
8 |
import torch
|
9 |
from deepface import DeepFace
|
10 |
from sklearn.metrics import accuracy_score, recall_score, f1_score
|
@@ -35,6 +37,19 @@ def load_df(target_labels: list[str]):
|
|
35 |
return train_df, test_df
|
36 |
|
37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
class EmbeddingDataset(Dataset):
|
39 |
def __init__(self, df: pandas.DataFrame, target_labels: list[str]):
|
40 |
self.df = df
|
|
|
1 |
import logging
|
2 |
+
import os
|
3 |
import pickle
|
4 |
from concurrent.futures import ProcessPoolExecutor, as_completed
|
5 |
from pathlib import Path
|
6 |
|
7 |
import pandas
|
8 |
import pandas as pd
|
9 |
+
import requests
|
10 |
import torch
|
11 |
from deepface import DeepFace
|
12 |
from sklearn.metrics import accuracy_score, recall_score, f1_score
|
|
|
37 |
return train_df, test_df
|
38 |
|
39 |
|
40 |
+
def ensure_model_downloaded(model_path: str):
|
41 |
+
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
42 |
+
if not os.path.exists(model_path):
|
43 |
+
logging.warning("Model not found. Downloading from GitHub...")
|
44 |
+
response = requests.get("https://github.com/wyyadd/facetype/releases/download/1.0.0/classifier.pth")
|
45 |
+
if response.status_code != 200:
|
46 |
+
logging.error("Failed to download classifier.pth")
|
47 |
+
raise RuntimeError("Failed to download model.")
|
48 |
+
with open(model_path, "wb") as f:
|
49 |
+
f.write(response.content)
|
50 |
+
logging.info("Download complete.")
|
51 |
+
|
52 |
+
|
53 |
class EmbeddingDataset(Dataset):
|
54 |
def __init__(self, df: pandas.DataFrame, target_labels: list[str]):
|
55 |
self.df = df
|