Ariamehr commited on
Commit
8ca3087
·
verified ·
1 Parent(s): 6502465

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -21
app.py CHANGED
@@ -3,48 +3,53 @@ import gradio as gr
3
  from PIL import Image
4
  import numpy as np
5
 
6
- # بارگذاری مدل
7
  model_path = "sapiens_0.3b_render_people_epoch_100_torchscript.pt2"
8
  model = torch.jit.load(model_path, map_location=torch.device('cpu'))
9
  model.eval()
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def predict(image):
12
  try:
13
  print("Predict function called")
14
 
15
- # تغییر اندازه تصویر به 224x224
16
- image = image.resize((224, 224)) # تغییر اندازه به 224x224
17
-
18
- # پیش‌پردازش تصویر
19
- image = image.convert("RGB")
20
- input_tensor = np.array(image)
21
- input_tensor = input_tensor.transpose(2, 0, 1) # تبدیل از HWC به CHW
22
- input_tensor = input_tensor[np.newaxis, :] # افزودن بعد batch
23
- input_tensor = input_tensor / 255.0 # نرمال‌سازی
24
- input_tensor = torch.from_numpy(input_tensor).float()
25
 
26
  print(f"Input tensor shape: {input_tensor.shape}")
27
 
28
- # اجرای مدل
29
  with torch.no_grad():
30
  output = model(input_tensor)
31
 
32
  print(f"Output tensor shape: {output.shape}")
33
 
34
- # پس‌پردازش خروجی
35
- output_image = output.squeeze().cpu().numpy()
36
- output_image = output_image.transpose(1, 2, 0) # تبدیل از CHW به HWC
37
- output_image = (output_image * 255).astype(np.uint8)
38
- output_image = Image.fromarray(output_image)
39
 
40
- print("Output image generated successfully")
41
- return output_image
42
 
43
  except Exception as e:
44
  print(f"Error during prediction: {str(e)}")
45
  return None
46
 
47
- # تعریف رابط Gradio
48
  iface = gr.Interface(
49
  fn=predict,
50
  inputs=gr.Image(type="pil", label="Input Image"),
@@ -54,4 +59,4 @@ iface = gr.Interface(
54
  )
55
 
56
  if __name__ == "__main__":
57
- iface.launch(share=True)
 
3
  from PIL import Image
4
  import numpy as np
5
 
6
+ # Load the model
7
  model_path = "sapiens_0.3b_render_people_epoch_100_torchscript.pt2"
8
  model = torch.jit.load(model_path, map_location=torch.device('cpu'))
9
  model.eval()
10
 
11
+ # Define a function to preprocess images to match the expected input shape
12
+ def preprocess_image(image):
13
+ # Resize the image to a fixed size (e.g., 224x224)
14
+ image = image.resize((224, 224))
15
+
16
+ # Convert to RGB and normalize pixel values
17
+ image = image.convert("RGB")
18
+ input_tensor = np.array(image) / 255.0
19
+
20
+ # Flatten the image into a 1D vector
21
+ input_tensor = input_tensor.reshape(-1)
22
+
23
+ # Add the batch dimension
24
+ input_tensor = input_tensor[np.newaxis, :]
25
+
26
+ return input_tensor
27
+
28
  def predict(image):
29
  try:
30
  print("Predict function called")
31
 
32
+ # Preprocess the image to match the expected input shape
33
+ input_tensor = preprocess_image(image)
 
 
 
 
 
 
 
 
34
 
35
  print(f"Input tensor shape: {input_tensor.shape}")
36
 
37
+ # Run the model
38
  with torch.no_grad():
39
  output = model(input_tensor)
40
 
41
  print(f"Output tensor shape: {output.shape}")
42
 
43
+ # Post-process the output (if necessary)
44
+ # ...
 
 
 
45
 
46
+ return output # Return the output tensor directly
 
47
 
48
  except Exception as e:
49
  print(f"Error during prediction: {str(e)}")
50
  return None
51
 
52
+ # Define the Gradio interface
53
  iface = gr.Interface(
54
  fn=predict,
55
  inputs=gr.Image(type="pil", label="Input Image"),
 
59
  )
60
 
61
  if __name__ == "__main__":
62
+ iface.launch(share=True)