yashbyname commited on
Commit
fd482de
·
verified ·
1 Parent(s): ae01e5c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -13
app.py CHANGED
@@ -6,7 +6,7 @@ import tf_keras
6
  import tensorflow_hub as hub
7
  import numpy as np
8
  from PIL import Image
9
- import os
10
  from typing import Optional, Dict, Any, Union
11
 
12
  # Set up logging
@@ -51,22 +51,33 @@ class MedicalDiagnosisModel:
51
  logger.error(f"Error loading model: {str(e)}")
52
  return None
53
 
54
- def preprocess_image(self, image: Image.Image):
55
  """Preprocess the input image for model prediction."""
56
  try:
57
- # Convert to RGB and resize
58
- image = image.convert('RGB')
59
- image = image.resize((256, 256))
60
 
61
- # Convert to numpy array and normalize
62
- image_array = np.array(image)
63
- image_array = image_array / 255.0
 
 
 
64
 
65
- # Add batch dimension
66
- image_array = np.expand_dims(image_array, axis=0)
67
- logger.info(f"Preprocessed image shape: {image_array.shape}")
 
68
 
69
- return image_array
 
 
 
 
 
 
 
 
 
70
 
71
  except Exception as e:
72
  logger.error(f"Error preprocessing image: {str(e)}")
@@ -149,7 +160,7 @@ class MedicalDiagnosisApp:
149
  self.api = MedicalDiagnosisAPI(api_key, user_id)
150
 
151
  def process_request(self, patient_info: str,
152
- image: Optional[Image.Image]) -> str:
153
  """Process a medical diagnosis request."""
154
  try:
155
  if self.model.model is None:
@@ -229,6 +240,7 @@ if __name__ == "__main__":
229
  iface = create_gradio_interface()
230
  iface.launch(
231
  server_name="0.0.0.0",
 
232
  debug=True
233
  )
234
 
 
6
  import tensorflow_hub as hub
7
  import numpy as np
8
  from PIL import Image
9
+ import io
10
  from typing import Optional, Dict, Any, Union
11
 
12
  # Set up logging
 
51
  logger.error(f"Error loading model: {str(e)}")
52
  return None
53
 
54
+ def preprocess_image(self, image: np.ndarray) -> np.ndarray:
55
  """Preprocess the input image for model prediction."""
56
  try:
57
+ logger.info(f"Input image shape: {image.shape}, dtype: {image.dtype}")
 
 
58
 
59
+ # If image is RGBA, convert to RGB
60
+ if image.shape[-1] == 4:
61
+ logger.info("Converting RGBA to RGB")
62
+ # Convert to PIL Image and back to handle RGBA->RGB conversion
63
+ image = Image.fromarray(image).convert('RGB')
64
+ image = np.array(image)
65
 
66
+ # Resize image
67
+ image = tf_keras.preprocessing.image.smart_resize(
68
+ image, (256, 256), interpolation='bilinear'
69
+ )
70
 
71
+ # Ensure values are between 0 and 1
72
+ if image.max() > 1.0:
73
+ image = image / 255.0
74
+
75
+ # Add batch dimension if not present
76
+ if len(image.shape) == 3:
77
+ image = np.expand_dims(image, axis=0)
78
+
79
+ logger.info(f"Preprocessed image shape: {image.shape}")
80
+ return image
81
 
82
  except Exception as e:
83
  logger.error(f"Error preprocessing image: {str(e)}")
 
160
  self.api = MedicalDiagnosisAPI(api_key, user_id)
161
 
162
  def process_request(self, patient_info: str,
163
+ image: Optional[np.ndarray]) -> str:
164
  """Process a medical diagnosis request."""
165
  try:
166
  if self.model.model is None:
 
240
  iface = create_gradio_interface()
241
  iface.launch(
242
  server_name="0.0.0.0",
243
+ share=True, # Enable public link
244
  debug=True
245
  )
246