Ariamehr commited on
Commit
7ada48a
·
verified ·
1 Parent(s): 4a0dd7b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -27
app.py CHANGED
@@ -3,53 +3,57 @@ import gradio as gr
3
  from PIL import Image
4
  import numpy as np
5
 
6
- # Load the model from the local file
7
  model_path = "sapiens_0.3b_render_people_epoch_100_torchscript.pt2"
8
  model = torch.jit.load(model_path, map_location=torch.device('cpu'))
 
9
 
10
- # Define the prediction function
11
  def predict(image):
12
  try:
13
  print("Predict function called")
14
-
15
- # تغییر اندازه تصویر به اندازه‌ای که مدل نیاز دارد
16
- target_size = (3072, 3072) # یا هر اندازه‌ای که مدل نیاز دارد
17
  image = image.resize(target_size)
18
-
19
- # Preprocess image
20
  image = image.convert("RGB")
21
- input_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).unsqueeze(0).float() / 255.0
22
-
23
- print("Image preprocessed")
24
-
25
- # Run the model
 
 
 
 
26
  with torch.no_grad():
27
  output = model(input_tensor)
28
 
29
- print("Model executed")
30
-
31
- # Postprocess the output
32
- output_image = output.squeeze().permute(1, 2, 0).numpy()
 
33
  output_image = (output_image * 255).astype(np.uint8)
 
34
 
35
- print("Output generated")
36
- return Image.fromarray(output_image)
 
37
 
38
  except Exception as e:
39
  print(f"Error during prediction: {str(e)}")
40
  return None
41
 
42
-
43
-
44
- # Gradio Interface
45
  iface = gr.Interface(
46
- fn=predict,
47
- inputs=gr.Image(type="pil"),
48
- outputs=gr.Image(type="pil"),
49
- title="Sapiens Body Part Segmentation",
50
- description="Upload an image to segment body parts using the Sapiens model."
51
  )
52
 
53
- # Launch the interface
54
  if __name__ == "__main__":
55
  iface.launch()
 
3
  from PIL import Image
4
  import numpy as np
5
 
6
+ # بارگذاری مدل
7
  model_path = "sapiens_0.3b_render_people_epoch_100_torchscript.pt2"
8
  model = torch.jit.load(model_path, map_location=torch.device('cpu'))
9
+ model.eval()
10
 
 
11
  def predict(image):
12
  try:
13
  print("Predict function called")
14
+
15
+ # تعیین اندازه ورودی مدل
16
+ target_size = (512, 512) # این اندازه را می‌توانید تغییر دهید
17
  image = image.resize(target_size)
18
+
19
+ # پیش‌پردازش تصویر
20
  image = image.convert("RGB")
21
+ input_tensor = np.array(image)
22
+ input_tensor = input_tensor.transpose(2, 0, 1) # تبدیل از HWC به CHW
23
+ input_tensor = input_tensor[np.newaxis, :] # افزودن بعد batch
24
+ input_tensor = input_tensor / 255.0 # نرمال‌سازی
25
+ input_tensor = torch.from_numpy(input_tensor).float()
26
+
27
+ print(f"Input tensor shape: {input_tensor.shape}")
28
+
29
+ # اجرای مدل
30
  with torch.no_grad():
31
  output = model(input_tensor)
32
 
33
+ print(f"Output tensor shape: {output.shape}")
34
+
35
+ # پس‌پردازش خروجی
36
+ output_image = output.squeeze().cpu().numpy()
37
+ output_image = output_image.transpose(1, 2, 0) # تبدیل از CHW به HWC
38
  output_image = (output_image * 255).astype(np.uint8)
39
+ output_image = Image.fromarray(output_image)
40
 
41
+ print("Output image generated successfully")
42
+
43
+ return output_image
44
 
45
  except Exception as e:
46
  print(f"Error during prediction: {str(e)}")
47
  return None
48
 
49
+ # تعریف رابط Gradio
 
 
50
  iface = gr.Interface(
51
+ fn=predict,
52
+ inputs=gr.Image(type="pil", label="Input Image"),
53
+ outputs=gr.Image(type="pil", label="Output Image"),
54
+ title="Sapiens Model Inference",
55
+ description="Upload an image to process with the Sapiens model."
56
  )
57
 
 
58
  if __name__ == "__main__":
59
  iface.launch()