yaya36095 commited on
Commit
8ea56b1
·
verified ·
1 Parent(s): 00fa5d2

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +35 -13
handler.py CHANGED
@@ -3,39 +3,61 @@ import torch
3
  from torchvision import transforms
4
  from PIL import Image
5
  from safetensors.torch import load_file
6
- from timm import create_model # timm ضرورى للتعامل مع ViT
7
 
8
 
9
- class EndpointHandler: # اسم الفئة مهم جداً
 
 
10
  def __init__(self, model_dir: str):
 
11
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
 
13
  # تحميل الوزن بصيغة safetensors
14
- weights = load_file(os.path.join(model_dir, "model.safetensors"))
 
 
 
15
  self.model = create_model("vit_base_patch16_224", num_classes=5)
16
  self.model.load_state_dict(weights)
17
  self.model.eval().to(self.device)
18
 
19
- self.transform = transforms.Compose([
20
- transforms.Resize((224, 224), interpolation=Image.BICUBIC),
21
- transforms.ToTensor(),
22
- ])
 
 
 
23
 
24
- self.labels = ['stable_diffusion', 'midjourney', 'dalle', 'real', 'other_ai']
 
 
 
 
 
 
25
 
 
26
  def _prep(self, img: Image.Image):
27
  return self.transform(img.convert("RGB")).unsqueeze(0).to(self.device)
28
 
 
29
  def __call__(self, data):
30
- # يدعم: Widget (PIL) أو REST (base64)
 
 
 
 
31
  img = None
32
  if isinstance(data, Image.Image):
33
  img = data
34
  elif isinstance(data, dict):
35
- b = data.get("inputs") or data.get("image")
36
- if isinstance(b, (str, bytes)):
37
- b = b.encode() if isinstance(b, str) else b
38
- img = Image.open(io.BytesIO(base64.b64decode(b)))
 
39
 
40
  if img is None:
41
  return {"error": "No image provided"}
 
3
  from torchvision import transforms
4
  from PIL import Image
5
  from safetensors.torch import load_file
6
+ from timm import create_model # timm ضروري لتشغيل ViT
7
 
8
 
9
+ class EndpointHandler:
10
+ """Custom pipeline for Hugging Face Inference Endpoints."""
11
+
12
  def __init__(self, model_dir: str):
13
+ # اختَر GPU إذا متاح
14
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
  # تحميل الوزن بصيغة safetensors
17
+ weights_path = os.path.join(model_dir, "model.safetensors")
18
+ weights = load_file(weights_path)
19
+
20
+ # إنشاء نموذج ViT مطابق لِما درّبتَه
21
  self.model = create_model("vit_base_patch16_224", num_classes=5)
22
  self.model.load_state_dict(weights)
23
  self.model.eval().to(self.device)
24
 
25
+ # تحويـلات الصورة
26
+ self.transform = transforms.Compose(
27
+ [
28
+ transforms.Resize((224, 224), interpolation=Image.BICUBIC),
29
+ transforms.ToTensor(),
30
+ ]
31
+ )
32
 
33
+ self.labels = [
34
+ "stable_diffusion",
35
+ "midjourney",
36
+ "dalle",
37
+ "real",
38
+ "other_ai",
39
+ ]
40
 
41
+ # ---------- helpers ----------
42
  def _prep(self, img: Image.Image):
43
  return self.transform(img.convert("RGB")).unsqueeze(0).to(self.device)
44
 
45
+ # ---------- main entry ----------
46
  def __call__(self, data):
47
+ """
48
+ يدعم:
49
+ • Widget: يستلم PIL.Image
50
+ • REST API: يستلم base64 فى data["inputs"] أو data["image"]
51
+ """
52
  img = None
53
  if isinstance(data, Image.Image):
54
  img = data
55
  elif isinstance(data, dict):
56
+ b64 = data.get("inputs") or data.get("image")
57
+ if isinstance(b64, (str, bytes)):
58
+ if isinstance(b64, str):
59
+ b64 = b64.encode()
60
+ img = Image.open(io.BytesIO(base64.b64decode(b64)))
61
 
62
  if img is None:
63
  return {"error": "No image provided"}