Ariamehr commited on
Commit
761ccc2
·
verified ·
1 Parent(s): db988b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -2
app.py CHANGED
@@ -10,23 +10,31 @@ model = torch.jit.load(model_path, map_location=torch.device('cpu'))
10
  # Define the prediction function
11
  def predict(image):
12
  try:
 
 
13
  # Preprocess image
14
  image = image.convert("RGB")
15
  input_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).unsqueeze(0).float() / 255.0
16
 
 
 
17
  # Run the model
18
  with torch.no_grad():
19
  output = model(input_tensor)
 
 
20
 
21
  # Postprocess the output
22
  output_image = output.squeeze().permute(1, 2, 0).numpy()
23
  output_image = (output_image * 255).astype(np.uint8)
 
 
24
  return Image.fromarray(output_image)
25
 
26
  except Exception as e:
27
- # Print the error for debugging
28
  print(f"Error during prediction: {str(e)}")
29
- return None # Return None if there is an error
 
30
 
31
  # Gradio Interface
32
  iface = gr.Interface(
 
10
  # Define the prediction function
11
  def predict(image):
12
  try:
13
+ print("Predict function called") # Check if the function is being called
14
+
15
  # Preprocess image
16
  image = image.convert("RGB")
17
  input_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).unsqueeze(0).float() / 255.0
18
 
19
+ print("Image preprocessed") # Check if preprocessing is successful
20
+
21
  # Run the model
22
  with torch.no_grad():
23
  output = model(input_tensor)
24
+
25
+ print("Model executed") # Check if model execution is successful
26
 
27
  # Postprocess the output
28
  output_image = output.squeeze().permute(1, 2, 0).numpy()
29
  output_image = (output_image * 255).astype(np.uint8)
30
+
31
+ print("Output generated") # Check if postprocessing is successful
32
  return Image.fromarray(output_image)
33
 
34
  except Exception as e:
 
35
  print(f"Error during prediction: {str(e)}")
36
+ return None
37
+
38
 
39
  # Gradio Interface
40
  iface = gr.Interface(