wyyadd commited on
Commit
4130e74
·
1 Parent(s): d5e315b

feat: model download

Browse files
Files changed (3) hide show
  1. app.py +1 -1
  2. main.py +3 -2
  3. model.py +15 -0
app.py CHANGED
@@ -8,7 +8,7 @@ def get_face_type(img):
8
  pred_binary = get_pred_binary(img)
9
  result = [f"{label}: {bool(pred)}" for label, pred in zip(TARGET_LABELS, pred_binary)]
10
  face_type = int(''.join(map(str, pred_binary)), 2)
11
- result = f"face_type: {face_type}\n{"\n".join(result)}"
12
  return result
13
 
14
 
 
8
  pred_binary = get_pred_binary(img)
9
  result = [f"{label}: {bool(pred)}" for label, pred in zip(TARGET_LABELS, pred_binary)]
10
  face_type = int(''.join(map(str, pred_binary)), 2)
11
+ result = f"face_type: {face_type}\n{'\n'.join(result)}"
12
  return result
13
 
14
 
main.py CHANGED
@@ -5,11 +5,12 @@ import torch
5
  from deepface import DeepFace
6
  from fastapi import FastAPI, HTTPException
7
 
8
- from model import MultiLabelClassifier
9
 
10
  app = FastAPI()
11
 
12
- model_path = "classifier.pth"
 
13
  model = MultiLabelClassifier(embedding_dim=4096, hidden_dim=1024)
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
  model.load_state_dict(torch.load(model_path, weights_only=True))
 
5
  from deepface import DeepFace
6
  from fastapi import FastAPI, HTTPException
7
 
8
+ from model import MultiLabelClassifier, ensure_model_downloaded
9
 
10
  app = FastAPI()
11
 
12
+ model_path = "data/classifier.pth"
13
+ ensure_model_downloaded(model_path)
14
  model = MultiLabelClassifier(embedding_dim=4096, hidden_dim=1024)
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
  model.load_state_dict(torch.load(model_path, weights_only=True))
model.py CHANGED
@@ -1,10 +1,12 @@
1
  import logging
 
2
  import pickle
3
  from concurrent.futures import ProcessPoolExecutor, as_completed
4
  from pathlib import Path
5
 
6
  import pandas
7
  import pandas as pd
 
8
  import torch
9
  from deepface import DeepFace
10
  from sklearn.metrics import accuracy_score, recall_score, f1_score
@@ -35,6 +37,19 @@ def load_df(target_labels: list[str]):
35
  return train_df, test_df
36
 
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  class EmbeddingDataset(Dataset):
39
  def __init__(self, df: pandas.DataFrame, target_labels: list[str]):
40
  self.df = df
 
1
  import logging
2
+ import os
3
  import pickle
4
  from concurrent.futures import ProcessPoolExecutor, as_completed
5
  from pathlib import Path
6
 
7
  import pandas
8
  import pandas as pd
9
+ import requests
10
  import torch
11
  from deepface import DeepFace
12
  from sklearn.metrics import accuracy_score, recall_score, f1_score
 
37
  return train_df, test_df
38
 
39
 
40
+ def ensure_model_downloaded(model_path: str):
41
+ os.makedirs(os.path.dirname(model_path), exist_ok=True)
42
+ if not os.path.exists(model_path):
43
+ logging.warning("Model not found. Downloading from GitHub...")
44
+ response = requests.get("https://github.com/wyyadd/facetype/releases/download/1.0.0/classifier.pth")
45
+ if response.status_code != 200:
46
+ logging.error("Failed to download classifier.pth")
47
+ raise RuntimeError("Failed to download model.")
48
+ with open(model_path, "wb") as f:
49
+ f.write(response.content)
50
+ logging.info("Download complete.")
51
+
52
+
53
  class EmbeddingDataset(Dataset):
54
  def __init__(self, df: pandas.DataFrame, target_labels: list[str]):
55
  self.df = df