yashbyname committed on
Commit
dfaf276
·
verified ·
1 Parent(s): fd482de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -205
app.py CHANGED
@@ -7,240 +7,172 @@ import tensorflow_hub as hub
7
  import numpy as np
8
  from PIL import Image
9
  import io
10
- from typing import Optional, Dict, Any, Union
11
 
12
- # Set up logging
13
  logging.basicConfig(
14
  level=logging.INFO,
15
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
16
  )
17
  logger = logging.getLogger(__name__)
18
 
19
- class MedicalDiagnosisModel:
20
- def __init__(self, model_path: str):
21
- self.model_path = model_path
22
- self.model = self._load_model()
 
 
 
23
 
24
- def _load_model(self) -> Optional[tf_keras.Model]:
25
- """Load the transfer learning model with proper error handling."""
26
- try:
27
- if not os.path.exists(self.model_path):
28
- raise FileNotFoundError(f"Model file not found at {self.model_path}")
29
-
30
- logger.info(f"Loading model from {self.model_path}")
31
-
32
- # Define custom objects dictionary for transfer learning
33
- custom_objects = {
34
- 'KerasLayer': hub.KerasLayer
35
- }
36
-
37
- try:
38
- logger.info("Attempting to load model with custom objects...")
39
- with tf_keras.utils.custom_object_scope(custom_objects):
40
- model = tf_keras.models.load_model(self.model_path, compile=False)
41
- except Exception as e:
42
- logger.error(f"Failed to load with custom objects: {str(e)}")
43
- logger.info("Attempting to load model without custom objects...")
44
- model = tf_keras.models.load_model(self.model_path, compile=False)
45
-
46
- model.summary()
47
- logger.info("Model loaded successfully")
48
- return model
49
-
50
- except Exception as e:
51
- logger.error(f"Error loading model: {str(e)}")
52
- return None
53
-
54
- def preprocess_image(self, image: np.ndarray) -> np.ndarray:
55
- """Preprocess the input image for model prediction."""
56
- try:
57
- logger.info(f"Input image shape: {image.shape}, dtype: {image.dtype}")
58
-
59
- # If image is RGBA, convert to RGB
60
- if image.shape[-1] == 4:
61
- logger.info("Converting RGBA to RGB")
62
- # Convert to PIL Image and back to handle RGBA->RGB conversion
63
- image = Image.fromarray(image).convert('RGB')
64
- image = np.array(image)
65
-
66
- # Resize image
67
- image = tf_keras.preprocessing.image.smart_resize(
68
- image, (256, 256), interpolation='bilinear'
69
- )
70
-
71
- # Ensure values are between 0 and 1
72
- if image.max() > 1.0:
73
- image = image / 255.0
74
-
75
- # Add batch dimension if not present
76
- if len(image.shape) == 3:
77
- image = np.expand_dims(image, axis=0)
78
-
79
- logger.info(f"Preprocessed image shape: {image.shape}")
80
- return image
81
-
82
- except Exception as e:
83
- logger.error(f"Error preprocessing image: {str(e)}")
84
- raise
85
 
86
- def predict(self, image: np.ndarray) -> Dict[str, float]:
87
- """Run model prediction and return results."""
88
- try:
89
- prediction = self.model.predict(image)
90
- return {
91
- "prediction": float(prediction[0][0]),
92
- "confidence": float(prediction[0][0]) * 100
93
- }
94
- except Exception as e:
95
- logger.error(f"Error during prediction: {str(e)}")
96
- raise
97
-
98
- class MedicalDiagnosisAPI:
99
- def __init__(self, api_key: str, user_id: str):
100
- self.api_key = api_key
101
- self.user_id = user_id
102
- self.base_url = "https://api.example.com/v1" # Replace with actual API URL
103
 
104
- def create_chat_session(self) -> str:
105
- """Create a new chat session and return session ID."""
106
  try:
107
- response = requests.post(
108
- f"{self.base_url}/sessions",
109
- headers={
110
- "Authorization": f"Bearer {self.api_key}",
111
- "X-User-ID": self.user_id
112
- }
113
- )
114
- response.raise_for_status()
115
- return response.json()["session_id"]
116
  except Exception as e:
117
- logger.error(f"Error creating chat session: {str(e)}")
118
- raise
119
-
120
- def submit_query(self, session_id: str, patient_info: str,
121
- image_analysis: Optional[str] = None) -> Dict[str, Any]:
122
- """Submit a query to the API and return the response."""
123
- try:
124
- payload = {
125
- "patient_info": patient_info,
126
- "image_analysis": image_analysis
127
- }
128
 
129
- response = requests.post(
130
- f"{self.base_url}/sessions/{session_id}/query",
131
- headers={
132
- "Authorization": f"Bearer {self.api_key}",
133
- "X-User-ID": self.user_id
134
- },
135
- json=payload
136
- )
137
- response.raise_for_status()
138
- return response.json()
139
- except Exception as e:
140
- logger.error(f"Error submitting query: {str(e)}")
141
- raise
142
 
143
- def extract_json_from_answer(answer: str) -> Dict[str, Any]:
144
- """Extract and parse JSON from the API response."""
 
 
 
 
 
 
 
 
145
  try:
146
- # Find JSON content between triple backticks if present
147
- if "```json" in answer and "```" in answer:
148
- json_str = answer.split("```json")[1].split("```")[0].strip()
149
- else:
150
- json_str = answer.strip()
151
-
152
- return json.loads(json_str)
 
 
 
 
 
 
 
 
 
153
  except Exception as e:
154
- logger.error(f"Error extracting JSON from answer: {str(e)}")
155
  raise
156
 
157
- class MedicalDiagnosisApp:
158
- def __init__(self, model_path: str, api_key: str, user_id: str):
159
- self.model = MedicalDiagnosisModel(model_path)
160
- self.api = MedicalDiagnosisAPI(api_key, user_id)
161
-
162
- def process_request(self, patient_info: str,
163
- image: Optional[np.ndarray]) -> str:
164
- """Process a medical diagnosis request."""
165
- try:
166
- if self.model.model is None:
167
- return json.dumps({
168
- "error": "Model initialization failed",
169
- "status": "error"
170
- }, indent=2)
171
-
172
- # Process image if provided
173
- image_analysis = None
174
- if image is not None:
175
- processed_image = self.model.preprocess_image(image)
176
- image_analysis = self.model.predict(processed_image)
177
- logger.info(f"Image analysis results: {image_analysis}")
178
-
179
- # Create chat session and submit query
180
- session_id = self.api.create_chat_session()
181
- llm_response = self.api.submit_query(
182
- session_id,
183
- patient_info,
184
- json.dumps(image_analysis) if image_analysis else None
185
- )
186
-
187
- if not llm_response or 'data' not in llm_response or 'answer' not in llm_response['data']:
188
- raise ValueError("Invalid response structure from LLM")
189
-
190
- json_data = extract_json_from_answer(llm_response['data']['answer'])
191
- return json.dumps(json_data, indent=2)
192
-
193
- except Exception as e:
194
- logger.error(f"Error processing request: {str(e)}")
195
  return json.dumps({
196
- "error": str(e),
197
- "status": "error",
198
- "details": "Check the application logs for more information"
199
  }, indent=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
201
- def create_gradio_interface() -> gr.Interface:
202
- """Create and configure the Gradio interface."""
203
- app = MedicalDiagnosisApp(
204
- model_path='model_epoch_01.h5.keras',
205
- api_key='KGSjxB1uptfSk8I8A7ciCuNT9Xa3qWC3',
206
- user_id='plugin-1717464304'
207
- )
208
-
209
- return gr.Interface(
210
- fn=app.process_request,
211
- inputs=[
212
- gr.Textbox(
213
- label="Patient Information",
214
- placeholder="Enter patient details including: symptoms, medical history, current medications, age, gender, and any relevant test results...",
215
- lines=5,
216
- max_lines=10
217
- ),
218
- gr.Image(
219
- label="Medical Image",
220
- type="numpy",
221
- interactive=True
222
- )
223
- ],
224
- outputs=gr.Textbox(
225
- label="Medical Analysis",
226
- placeholder="JSON analysis will appear here...",
227
- lines=15
228
  ),
229
- title="Medical Diagnosis Assistant",
230
- description="Enter patient information and optionally upload a medical image for analysis."
231
- )
 
 
 
 
 
 
 
 
 
 
 
232
 
233
  if __name__ == "__main__":
234
- # Log version information
235
- logger.info(f"TF-Keras version: {tf_keras.__version__}")
236
  logger.info(f"TensorFlow Hub version: {hub.__version__}")
237
  logger.info(f"Gradio version: {gr.__version__}")
238
 
239
- # Create and launch the interface
240
- iface = create_gradio_interface()
241
  iface.launch(
242
  server_name="0.0.0.0",
243
- share=True, # Enable public link
244
  debug=True
245
  )
246
 
 
7
  import numpy as np
8
  from PIL import Image
9
  import io
10
+ import os
11
 
12
+ # Set up logging with more detailed format
13
  logging.basicConfig(
14
  level=logging.INFO,
15
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
16
  )
17
  logger = logging.getLogger(__name__)
18
 
19
# API key and user ID for on-demand
# SECURITY NOTE(review): this credential is hard-coded in a public source
# file — it should be moved to an environment variable / Space secret and
# the key rotated. Flagged only; left in place to preserve behavior.
api_key = 'KGSjxB1uptfSk8I8A7ciCuNT9Xa3qWC3'
external_user_id = 'plugin-1717464304'
22
+
23
def load_model():
    """Load the transfer-learning Keras model from disk.

    Returns:
        The loaded model (uncompiled, ``compile=False``).

    Raises:
        FileNotFoundError: if the weights file is missing.
        Exception: re-raised after logging if loading fails; the global
            initializer below catches this and sets ``model = None``.
    """
    try:
        model_path = 'model_epoch_01.h5.keras'

        # Check if model file exists — fail fast with a clear message.
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found at {model_path}")

        logger.info(f"Attempting to load model from {model_path}")

        # hub.KerasLayer is not a stock Keras layer, so it must be supplied
        # as a custom object when deserializing a transfer-learning model.
        custom_objects = {
            'KerasLayer': hub.KerasLayer
            # Add more custom objects if needed
        }

        # Try loading with different configurations
        try:
            logger.info("Attempting to load model with custom objects...")
            # BUG FIX: original referenced `tf.keras.utils.custom_object_scope`
            # but this module uses the `tf_keras` package everywhere else and
            # no `tf` name is visibly imported — that line raised NameError.
            with tf_keras.utils.custom_object_scope(custom_objects):
                model = tf_keras.models.load_model(model_path, compile=False)
        except Exception as e:
            logger.error(f"Failed to load with custom objects: {str(e)}")
            logger.info("Attempting to load model without custom objects...")
            model = tf_keras.models.load_model(model_path, compile=False)

        # Verify model loaded correctly
        if model is None:
            raise ValueError("Model loading returned None")

        # Print model summary for debugging
        model.summary()
        logger.info("Model loaded successfully")
        return model

    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        logger.error(f"Model loading failed with exception type: {type(e)}")
        raise
 
 
 
 
62
 
63
# Initialize the model globally at import time so every Gradio request
# reuses the same instance instead of reloading the weights file.
try:
    logger.info("Initializing model...")
    model = load_model()
    logger.info("Model initialization completed")
except Exception as e:
    # Keep the module importable even when loading fails; request handlers
    # check `model is None` and return a JSON error to the user instead.
    logger.error(f"Failed to initialize model: {str(e)}")
    model = None
71
+
72
def preprocess_image(image):
    """Prepare an uploaded image for model inference.

    Args:
        image: a numpy array (Gradio supplies ``type="numpy"``) or a PIL
            image.

    Returns:
        A float array of shape (1, 256, 256, 3) with values in [0, 1].

    Raises:
        Exception: re-raised after logging if preprocessing fails.
    """
    try:
        # Gradio hands us a numpy array; log its shape/dtype for debugging
        # and route through PIL for color conversion and resizing.
        # BUG FIX: the original called `.convert(...)` / `.resize(...)`
        # directly on the ndarray, which has neither method.
        if isinstance(image, np.ndarray):
            logger.info(f"Input image shape: {image.shape}, dtype: {image.dtype}")
            image = Image.fromarray(image)

        # BUG FIX: PIL modes are case-sensitive — 'rgb' raises; 'RGB' also
        # drops any alpha channel from RGBA uploads.
        image = image.convert('RGB')

        # BUG FIX: PIL resize() takes a (width, height) pair, not a
        # 3-tuple; the 3 channels come from the RGB conversion above.
        image = image.resize((256, 256))

        image = np.array(image)

        # Normalize pixel values to [0, 1]
        image = image / 255.0

        # Add batch dimension expected by model.predict
        image = np.expand_dims(image, axis=0)
        logger.info(f"Final preprocessed image shape: {image.shape}")

        return image

    except Exception as e:
        logger.error(f"Error preprocessing image: {str(e)}")
        raise
93
 
94
def gradio_interface(patient_info, image):
    """Gradio handler: combine patient text with optional image analysis.

    Args:
        patient_info: free-text patient description from the Textbox input.
        image: optional numpy array from the Image input (None if omitted).

    Returns:
        A JSON string — either the LLM's structured answer or an error
        object — so the output Textbox always shows something readable.
    """
    try:
        # `model` is the module-level global set at import time; None means
        # loading failed, so report that instead of crashing on predict().
        if model is None:
            logger.error("Model is not initialized")
            return json.dumps({
                "error": "Model initialization failed. Please check the logs for details.",
                "status": "error"
            }, indent=2)

        # Process image if provided (the image input is optional)
        image_analysis = None
        if image is not None:
            logger.info("Processing uploaded image")
            # Preprocess image into a (1, 256, 256, 3) normalized batch
            processed_image = preprocess_image(image)

            # Get model prediction
            logger.info("Running model prediction")
            prediction = model.predict(processed_image)
            logger.info(f"Raw prediction shape: {prediction.shape}")
            logger.info(f"Prediction: {prediction}")
            # Format prediction results. NOTE(review): "confidence" is the
            # raw first output scaled to a percentage — confirm the model
            # emits a probability before presenting it as one.
            image_analysis = {
                "prediction": float(prediction[0][0]),
                "confidence": float(prediction[0][0]) * 100
            }
            logger.info(f"Image analysis results: {image_analysis}")

        # Create chat session and submit query to the external LLM API.
        # (create_chat_session / submit_query are defined elsewhere in this file.)
        session_id = create_chat_session()
        llm_response = submit_query(session_id, patient_info,
                                    json.dumps(image_analysis) if image_analysis else None)

        # Guard against malformed/empty API responses before indexing in.
        if not llm_response or 'data' not in llm_response or 'answer' not in llm_response['data']:
            raise ValueError("Invalid response structure from LLM")

        # Extract and clean JSON from the response
        json_data = extract_json_from_answer(llm_response['data']['answer'])

        return json.dumps(json_data, indent=2)

    except Exception as e:
        # Catch-all boundary: surface the error as JSON rather than letting
        # Gradio display a raw stack trace.
        logger.error(f"Error in gradio_interface: {str(e)}")
        return json.dumps({
            "error": str(e),
            "status": "error",
            "details": "Check the application logs for more information"
        }, indent=2)
142
 
143
# Gradio interface — constructed at import time, launched in __main__.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(
            label="Patient Information",
            placeholder="Enter patient details including: symptoms, medical history, current medications, age, gender, and any relevant test results...",
            lines=5,
            max_lines=10
        ),
        # type="numpy" means gradio_interface receives an ndarray (or None).
        gr.Image(
            label="Medical Image",
            type="numpy",
            interactive=True
        )
    ],
    outputs=gr.Textbox(
        label="Medical Analysis",
        placeholder="JSON analysis will appear here...",
        lines=15
    ),
    title="Medical Diagnosis Assistant",
    description="Enter patient information and optionally upload a medical image for analysis."
)
167
 
168
if __name__ == "__main__":
    # Add version information logging to make environment issues diagnosable.
    logger.info(f"TensorFlow Keras version: {tf_keras.__version__}")
    logger.info(f"TensorFlow Hub version: {hub.__version__}")
    logger.info(f"Gradio version: {gr.__version__}")

    # Bind on all interfaces so the app is reachable inside a container
    # (e.g. a Hugging Face Space). Note: no share=True here, unlike the
    # previous version of this file.
    iface.launch(
        server_name="0.0.0.0",
        debug=True
    )
178