Ariamehr committed on
Commit
6c7d661
·
verified ·
1 Parent(s): 8e6fd23

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -53
app.py CHANGED
@@ -1,62 +1,22 @@
1
- import torch
2
- import gradio as gr
3
- from PIL import Image
4
- import numpy as np
5
-
6
- # Load the model
7
- model_path = "sapiens_0.3b_render_people_epoch_100_torchscript.pt2"
8
- model = torch.jit.load(model_path, map_location=torch.device('cpu'))
9
- model.eval()
10
-
11
- # Define a function to preprocess images to match the expected input shape
12
-
13
  def preprocess_image(image):
14
- # Resize the image to a fixed size (e.g., 224x224)
15
  image = image.resize((1024, 768))
16
 
17
- # Convert to RGB (without adding extra dimensions)
18
  input_tensor = np.array(image.convert("RGB")) / 255.0
19
 
20
- # Add the batch dimension
21
- input_tensor = input_tensor[np.newaxis, :]
22
-
23
- # Convert the NumPy array to a PyTorch tensor
24
- input_tensor = torch.from_numpy(input_tensor)
25
-
26
- return input_tensor
27
-
28
- def predict(image):
29
- try:
30
- print("Predict function called")
31
 
32
- # Preprocess the image to match the expected input shape
33
- input_tensor = preprocess_image(image)
34
 
35
- print(f"Input tensor shape: {input_tensor.shape}")
36
-
37
- # Run the model
38
- with torch.no_grad():
39
- output = model(input_tensor)
40
-
41
- print(f"Output tensor shape: {output.shape}")
42
-
43
- # Post-process the output (if necessary)
44
- # ...
45
-
46
- return output # Return the output tensor directly
47
-
48
- except Exception as e:
49
- print(f"Error during prediction: {str(e)}")
50
- return None
51
 
52
- # Define the Gradio interface
53
- iface = gr.Interface(
54
- fn=predict,
55
- inputs=gr.Image(type="pil", label="Input Image"),
56
- outputs=gr.Image(type="pil", label="Output Image"),
57
- title="Sapiens Model Inference",
58
- description="Upload an image to process with the Sapiens model."
59
- )
60
 
61
- if __name__ == "__main__":
62
- iface.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
def preprocess_image(image):
    """Turn an image into a batch of flattened 16x16 RGB patches.

    The image is resized to 1024x768 (width x height), converted to RGB,
    scaled to the [0, 1] range, and cut into non-overlapping 16x16 spatial
    tiles. Tiles are ordered row by row (left to right, then top to bottom)
    and each tile is flattened to a vector of 16 * 16 * 3 = 768 values.

    Args:
        image: An object with PIL-style ``resize((width, height))`` and
            ``convert("RGB")`` methods whose result ``np.array`` turns into
            a (768, 1024, 3) array.

    Returns:
        A ``torch.Tensor`` of shape (1, 3072, 768) and dtype float64.

    Raises:
        ValueError: If the converted pixel array cannot be reshaped into
            whole patches (i.e. it is not 768 x 1024 x 3).
    """
    # Imported locally: this revision of the file has no module-level
    # imports, so relying on global `np` / `torch` would raise NameError.
    import numpy as np
    import torch

    patch_size = 16  # Assuming a patch size of 16 based on model information
    width, height = 1024, 768

    # Resize takes (width, height); the resulting array is (height, width, 3).
    image = image.resize((width, height))

    # Convert to RGB and normalize pixel values to [0, 1].
    pixels = np.array(image.convert("RGB")) / 255.0

    patch_rows = height // patch_size
    patch_cols = width // patch_size

    # A plain reshape to (num_patches, 16, 16, 3) would only chop each scan
    # line into runs of 256 pixels. To get real spatial tiles, split both
    # axes into (tile index, offset within tile) and bring the two tile
    # indices to the front before flattening.
    tiles = pixels.reshape(patch_rows, patch_size, patch_cols, patch_size, 3)
    tiles = tiles.transpose(0, 2, 1, 3, 4)

    # Flatten each tile into one vector.
    flat_patches = tiles.reshape(
        patch_rows * patch_cols, patch_size * patch_size * 3
    )

    # Add the batch dimension and convert to a PyTorch tensor.
    # NOTE(review): the tensor is float64; if the model expects float32,
    # the caller should cast with `.float()` — confirm against the model.
    return torch.from_numpy(flat_patches[np.newaxis, :])