Ariamehr commited on
Commit
389db8f
·
verified ·
1 Parent(s): 99c6564

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -22
app.py CHANGED
@@ -1,35 +1,37 @@
1
  import torch
2
- from transformers import AutoImageProcessor, AutoModelForImageSegmentation
3
  from PIL import Image
4
  import requests
5
- import gradio as gr
6
 
7
- # لود پردازشگر تصویر و مدل
8
- processor = AutoImageProcessor.from_pretrained("facebook/sapiens/tree/main/sapiens_lite_host/torchscript/seg/checkpoints/sapiens_0.3b")
9
- model = AutoModelForImageSegmentation.from_pretrained("facebook/sapiens/tree/main/sapiens_lite_host/torchscript/seg/checkpoints/sapiens_0.3b")
10
 
11
- # تابع برای پردازش تصویر ورودی و اعمال مدل بر روی آن
12
- def segment_image(image):
13
- # پردازش تصویر
14
- inputs = processor(images=image, return_tensors="pt")
 
15
 
16
- # اجرای مدل روی تصویر پردازش شده
17
  with torch.no_grad():
18
- outputs = model(**inputs)
19
-
20
- # فرض می‌کنیم خروجی یک ماسک است (مثل کلاس‌بندی پیکسل‌ها)
21
- segmentation = outputs.logits.argmax(dim=1).detach().cpu().numpy()[0]
22
 
23
- # بازگرداندن نتیجه به صورت تصویر
24
- return Image.fromarray(segmentation)
 
 
25
 
26
- # ایجاد رابط Gradio برای بارگذاری تصویر و نمایش نتیجه
27
- interface = gr.Interface(
28
- fn=segment_image,
29
  inputs=gr.Image(type="pil"),
30
  outputs=gr.Image(type="pil"),
31
- title="Sapiens Body Part Segmentation"
 
32
  )
33
 
34
- # اجرای برنامه
35
- interface.launch()
 
 
1
  import torch
2
+ import gradio as gr
3
  from PIL import Image
4
  import requests
5
+ from io import BytesIO
6
 
7
+ # Load the model
8
+ model_url = "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/normal/checkpoints/sapiens_0.3b/sapiens_0.3b.pt"
9
+ model = torch.jit.load(model_url, map_location=torch.device('cpu'))
10
 
11
+ # Define inference function
12
+ def predict(image):
13
+ # Preprocess image
14
+ image = image.convert("RGB")
15
+ input_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).unsqueeze(0).float() / 255.0
16
 
17
+ # Run model
18
  with torch.no_grad():
19
+ output = model(input_tensor)
 
 
 
20
 
21
+ # Postprocess output
22
+ output_image = output.squeeze().permute(1, 2, 0).numpy()
23
+ output_image = (output_image * 255).astype(np.uint8)
24
+ return Image.fromarray(output_image)
25
 
26
+ # Gradio Interface
27
+ iface = gr.Interface(
28
+ fn=predict,
29
  inputs=gr.Image(type="pil"),
30
  outputs=gr.Image(type="pil"),
31
+ title="Sapiens Body Part Segmentation",
32
+ description="Upload an image to segment body parts using the Sapiens model."
33
  )
34
 
35
+ # Launch the interface
36
+ if __name__ == "__main__":
37
+ iface.launch()