Ariamehr commited on
Commit
15872da
·
verified ·
1 Parent(s): 6c7d661

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -16
app.py CHANGED
@@ -1,22 +1,57 @@
1
- def preprocess_image(image):
2
- # Resize the image to a fixed size (e.g., 1024x768)
3
- image = image.resize((1024, 768))
 
4
 
5
- # Convert to RGB and normalize pixel values
6
- input_tensor = np.array(image.convert("RGB")) / 255.0
 
 
7
 
8
- # Divide the image into patches (adjust patch size as needed)
9
- patch_size = 16 # Assuming a patch size of 16 based on model information
10
- num_patches = (1024 // patch_size) * (768 // patch_size)
11
- input_tensor = input_tensor.reshape((num_patches, patch_size, patch_size, 3))
12
 
13
- # Flatten the patches
14
- input_tensor = input_tensor.reshape(-1, patch_size * patch_size * 3)
15
 
16
- # Add batch dimension
17
- input_tensor = input_tensor[np.newaxis, :]
 
 
 
 
 
18
 
19
- # Convert to PyTorch tensor
20
- input_tensor = torch.from_numpy(input_tensor)
21
 
22
- return input_tensor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from PIL import Image
4
+ import numpy as np
5
 
6
+ # بارگذاری مدل
7
+ model_path = "sapiens_0.3b_render_people_epoch_100_torchscript.pt2"
8
+ model = torch.jit.load(model_path, map_location=torch.device('cpu'))
9
+ model.eval()
10
 
11
+ def predict(image):
12
+     try:
13
+         print("Predict function called")
 
14
 
15
+         # تغییر اندازه تصویر به 224x224
16
+         image = image.resize((224, 224))  # تغییر اندازه به 224x224
17
 
18
+         # پیش‌پردازش تصویر
19
+         image = image.convert("RGB")
20
+         input_tensor = np.array(image)
21
+         input_tensor = input_tensor.transpose(2, 0, 1)  # تبدیل از HWC به CHW
22
+         input_tensor = input_tensor[np.newaxis, :]  # افزودن بعد batch
23
+         input_tensor = input_tensor / 255.0  # نرمال‌سازی
24
+         input_tensor = torch.from_numpy(input_tensor).float()
25
 
26
+         print(f"Input tensor shape: {input_tensor.shape}")
 
27
 
28
+         # اجرای مدل
29
+         with torch.no_grad():
30
+             output = model(input_tensor)
31
+
32
+         print(f"Output tensor shape: {output.shape}")
33
+
34
+         # پس‌پردازش خروجی
35
+         output_image = output.squeeze().cpu().numpy()
36
+         output_image = output_image.transpose(1, 2, 0)  # تبدیل از CHW به HWC
37
+         output_image = (output_image * 255).astype(np.uint8)
38
+         output_image = Image.fromarray(output_image)
39
+
40
+         print("Output image generated successfully")
41
+         return output_image
42
+
43
+     except Exception as e:
44
+         print(f"Error during prediction: {str(e)}")
45
+         return None
46
+
47
+ # تعریف رابط Gradio
48
+ iface = gr.Interface(
49
+     fn=predict,
50
+     inputs=gr.Image(type="pil", label="Input Image"),
51
+     outputs=gr.Image(type="pil", label="Output Image"),
52
+     title="Sapiens Model Inference",
53
+     description="Upload an image to process with the Sapiens model."
54
+ )
55
+
56
+ if __name__ == "__main__":
57
+     iface.launch(share=True)