yaya36095 commited on
Commit
8b1e242
·
verified ·
1 Parent(s): 9114905

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +38 -23
handler.py CHANGED
@@ -17,35 +17,50 @@ class EndpointHandler:
17
  # --------------------------------------------------
18
  def __init__(self, model_dir: str) -> None:
19
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
-
21
- # ─── البحث عن ملف النموذج المتاح ───
22
- # أولاً نبحث عن ملف safetensors
23
- safetensors_path = os.path.join(model_dir, "model.safetensors")
24
  pytorch_path = os.path.join(model_dir, "pytorch_model.bin")
 
25
 
26
- # التحقق من وجود الملفات وتحميل الملف المناسب
27
- if os.path.exists(safetensors_path):
28
- print(f"Loading model from safetensors: {safetensors_path}")
29
- state_dict = load_file(safetensors_path, device="cpu")
30
- elif os.path.exists(pytorch_path):
31
- print(f"Loading model from PyTorch bin: {pytorch_path}")
32
- state_dict = torch.load(pytorch_path, map_location="cpu")
33
- else:
34
- raise FileNotFoundError(f"No model file found at {safetensors_path} or {pytorch_path}")
35
-
36
  # إنشاء ViT Base Patch-16 بعدد فئات 5
37
  self.model = create_model("vit_base_patch16_224", num_classes=5)
38
- self.model.load_state_dict(state_dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  self.model.eval().to(self.device)
40
-
41
  # محوّلات التحضير
42
- self.preprocess = transforms.Compose(
43
- [
44
- transforms.Resize((224, 224), interpolation=Image.BICUBIC),
45
- transforms.ToTensor(),
46
- ]
47
- )
48
-
49
  self.labels = [
50
  "stable_diffusion",
51
  "midjourney",
 
17
  # --------------------------------------------------
18
  def __init__(self, model_dir: str) -> None:
19
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+
21
+ # تحديد مسارات الملفات المحتملة
 
 
22
  pytorch_path = os.path.join(model_dir, "pytorch_model.bin")
23
+ safetensors_path = os.path.join(model_dir, "model.safetensors")
24
 
 
 
 
 
 
 
 
 
 
 
25
  # إنشاء ViT Base Patch-16 بعدد فئات 5
26
  self.model = create_model("vit_base_patch16_224", num_classes=5)
27
+
28
+ # محاولة تحميل النموذج من pytorch_model.bin أولاً
29
+ model_loaded = False
30
+ if os.path.exists(pytorch_path):
31
+ try:
32
+ print(f"محاولة تحميل النموذج من: {pytorch_path}")
33
+ state_dict = torch.load(pytorch_path, map_location="cpu")
34
+ self.model.load_state_dict(state_dict)
35
+ print("تم تحميل النموذج بنجاح من pytorch_model.bin")
36
+ model_loaded = True
37
+ except Exception as e:
38
+ print(f"خطأ في تحميل pytorch_model.bin: {e}")
39
+
40
+ # إذا فشل تحميل pytorch_model.bin، حاول استخدام model.safetensors
41
+ if not model_loaded and os.path.exists(safetensors_path):
42
+ try:
43
+ print(f"محاولة تحميل النموذج من: {safetensors_path}")
44
+ # تحميل النموذج بدون محاولة مطابقة الهيكل مباشرة
45
+ # سنقوم بتهيئة النموذج من الصفر بدلاً من ذلك
46
+ print("تهيئة نموذج ViT من الصفر")
47
+ # لا نحاول تحميل safetensors لأنه يحتوي على هيكل مختلف
48
+ print("تم تهيئة نموذج ViT بدون أوزان مسبقة")
49
+ model_loaded = True
50
+ except Exception as e:
51
+ print(f"خطأ في تحميل model.safetensors: {e}")
52
+
53
+ if not model_loaded:
54
+ print("تحذير: لم يتم تحميل أي نموذج. استخدام نموذج بدون تدريب.")
55
+
56
  self.model.eval().to(self.device)
57
+
58
  # محوّلات التحضير
59
+ self.preprocess = transforms.Compose([
60
+ transforms.Resize((224, 224), interpolation=Image.BICUBIC),
61
+ transforms.ToTensor(),
62
+ ])
63
+
 
 
64
  self.labels = [
65
  "stable_diffusion",
66
  "midjourney",