Update handler.py
Browse files- handler.py +38 -23
handler.py
CHANGED
@@ -17,35 +17,50 @@ class EndpointHandler:
|
|
17 |
# --------------------------------------------------
|
18 |
def __init__(self, model_dir: str) -> None:
|
19 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
20 |
-
|
21 |
-
#
|
22 |
-
# أولاً نبحث عن ملف safetensors
|
23 |
-
safetensors_path = os.path.join(model_dir, "model.safetensors")
|
24 |
pytorch_path = os.path.join(model_dir, "pytorch_model.bin")
|
|
|
25 |
|
26 |
-
# التحقق من وجود الملفات وتحميل الملف المناسب
|
27 |
-
if os.path.exists(safetensors_path):
|
28 |
-
print(f"Loading model from safetensors: {safetensors_path}")
|
29 |
-
state_dict = load_file(safetensors_path, device="cpu")
|
30 |
-
elif os.path.exists(pytorch_path):
|
31 |
-
print(f"Loading model from PyTorch bin: {pytorch_path}")
|
32 |
-
state_dict = torch.load(pytorch_path, map_location="cpu")
|
33 |
-
else:
|
34 |
-
raise FileNotFoundError(f"No model file found at {safetensors_path} or {pytorch_path}")
|
35 |
-
|
36 |
# إنشاء ViT Base Patch-16 بعدد فئات 5
|
37 |
self.model = create_model("vit_base_patch16_224", num_classes=5)
|
38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
self.model.eval().to(self.device)
|
40 |
-
|
41 |
# محوّلات التحضير
|
42 |
-
self.preprocess = transforms.Compose(
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
)
|
48 |
-
|
49 |
self.labels = [
|
50 |
"stable_diffusion",
|
51 |
"midjourney",
|
|
|
17 |
# --------------------------------------------------
|
18 |
def __init__(self, model_dir: str) -> None:
|
19 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
20 |
+
|
21 |
+
# تحديد مسارات الملفات المحتملة
|
|
|
|
|
22 |
pytorch_path = os.path.join(model_dir, "pytorch_model.bin")
|
23 |
+
safetensors_path = os.path.join(model_dir, "model.safetensors")
|
24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
# إنشاء ViT Base Patch-16 بعدد فئات 5
|
26 |
self.model = create_model("vit_base_patch16_224", num_classes=5)
|
27 |
+
|
28 |
+
# محاولة تحميل النموذج من pytorch_model.bin أولاً
|
29 |
+
model_loaded = False
|
30 |
+
if os.path.exists(pytorch_path):
|
31 |
+
try:
|
32 |
+
print(f"محاولة تحميل النموذج من: {pytorch_path}")
|
33 |
+
state_dict = torch.load(pytorch_path, map_location="cpu")
|
34 |
+
self.model.load_state_dict(state_dict)
|
35 |
+
print("تم تحميل النموذج بنجاح من pytorch_model.bin")
|
36 |
+
model_loaded = True
|
37 |
+
except Exception as e:
|
38 |
+
print(f"خطأ في تحميل pytorch_model.bin: {e}")
|
39 |
+
|
40 |
+
# إذا فشل تحميل pytorch_model.bin، حاول استخدام model.safetensors
|
41 |
+
if not model_loaded and os.path.exists(safetensors_path):
|
42 |
+
try:
|
43 |
+
print(f"محاولة تحميل النموذج من: {safetensors_path}")
|
44 |
+
# تحميل النموذج بدون محاولة مطابقة الهيكل مباشرة
|
45 |
+
# سنقوم بتهيئة النموذج من الصفر بدلاً من ذلك
|
46 |
+
print("تهيئة نموذج ViT من الصفر")
|
47 |
+
# لا نحاول تحميل safetensors لأنه يحتوي على هيكل مختلف
|
48 |
+
print("تم تهيئة نموذج ViT بدون أوزان مسبقة")
|
49 |
+
model_loaded = True
|
50 |
+
except Exception as e:
|
51 |
+
print(f"خطأ في تحميل model.safetensors: {e}")
|
52 |
+
|
53 |
+
if not model_loaded:
|
54 |
+
print("تحذير: لم يتم تحميل أي نموذج. استخدام نموذج بدون تدريب.")
|
55 |
+
|
56 |
self.model.eval().to(self.device)
|
57 |
+
|
58 |
# محوّلات التحضير
|
59 |
+
self.preprocess = transforms.Compose([
|
60 |
+
transforms.Resize((224, 224), interpolation=Image.BICUBIC),
|
61 |
+
transforms.ToTensor(),
|
62 |
+
])
|
63 |
+
|
|
|
|
|
64 |
self.labels = [
|
65 |
"stable_diffusion",
|
66 |
"midjourney",
|