yashbyname committed on
Commit
ef3808c
·
verified ·
1 Parent(s): 43f872f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +193 -137
app.py CHANGED
@@ -6,171 +6,227 @@ import tf_keras
6
  import tensorflow_hub as hub
7
  import numpy as np
8
  from PIL import Image
9
- import io
10
  import os
 
11
 
12
- # Set up logging with more detailed format
13
  logging.basicConfig(
14
  level=logging.INFO,
15
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
16
  )
17
  logger = logging.getLogger(__name__)
18
 
19
- # API key and user ID for on-demand
20
- api_key = 'KGSjxB1uptfSk8I8A7ciCuNT9Xa3qWC3'
21
- external_user_id = 'plugin-1717464304'
22
-
23
- def load_model():
24
- try:
25
- model_path = 'model_epoch_01.h5.keras'
26
 
27
- # Check if model file exists
28
- if not os.path.exists(model_path):
29
- raise FileNotFoundError(f"Model file not found at {model_path}")
 
 
 
 
30
 
31
- logger.info(f"Attempting to load model from {model_path}")
32
-
33
- # Define custom objects dictionary
34
- custom_objects = {
35
- 'KerasLayer': hub.KerasLayer
36
- # Add more custom objects if needed
37
- }
38
-
39
- # Try loading with different configurations
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  try:
41
- logger.info("Attempting to load model with custom objects...")
42
- with tf.keras.utils.custom_object_scope(custom_objects):
43
- model = tf_keras.models.load_model(model_path, compile=False)
 
 
 
 
 
 
 
 
 
 
 
44
  except Exception as e:
45
- logger.error(f"Failed to load with custom objects: {str(e)}")
46
- logger.info("Attempting to load model without custom objects...")
47
- model = tf_keras.models.load_model(model_path, compile=False)
48
-
49
- # Verify model loaded correctly
50
- if model is None:
51
- raise ValueError("Model loading returned None")
52
 
53
- # Print model summary for debugging
54
- model.summary()
55
- logger.info("Model loaded successfully")
56
- return model
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
- except Exception as e:
59
- logger.error(f"Error loading model: {str(e)}")
60
- logger.error(f"Model loading failed with exception type: {type(e)}")
61
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- # Initialize the model globally
64
- try:
65
- logger.info("Initializing model...")
66
- model = load_model()
67
- logger.info("Model initialization completed")
68
- except Exception as e:
69
- logger.error(f"Failed to initialize model: {str(e)}")
70
- model = None
71
-
72
- def preprocess_image(image):
73
  try:
74
- # Log image shape and type for debugging
75
- logger.info(f"Input image shape: {image.shape}, dtype: {image.dtype}")
76
-
77
- image = image.convert('rgb')
78
- image = image.resize((256, 256, 3))
79
- image = np.array(image)
80
-
81
- # Normalize pixel values
82
- image = image / 255.0
83
-
84
- # Add batch dimension
85
- image = np.expand_dims(image, axis=0)
86
- logger.info(f"Final preprocessed image shape: {image.shape}")
87
-
88
- return image
89
-
90
  except Exception as e:
91
- logger.error(f"Error preprocessing image: {str(e)}")
92
  raise
93
 
94
- def gradio_interface(patient_info, image):
95
- try:
96
- if model is None:
97
- logger.error("Model is not initialized")
98
- return json.dumps({
99
- "error": "Model initialization failed. Please check the logs for details.",
100
- "status": "error"
101
- }, indent=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
- # Process image if provided
104
- image_analysis = None
105
- if image is not None:
106
- logger.info("Processing uploaded image")
107
- # Preprocess image
108
- processed_image = preprocess_image(image)
109
 
110
- # Get model prediction
111
- logger.info("Running model prediction")
112
- prediction = model.predict(processed_image)
113
- logger.info(f"Raw prediction shape: {prediction.shape}")
114
- logger.info(f"Prediction: {prediction}")
115
- # Format prediction results
116
- image_analysis = {
117
- "prediction": float(prediction[0][0]),
118
- "confidence": float(prediction[0][0]) * 100
119
- }
120
- logger.info(f"Image analysis results: {image_analysis}")
121
-
122
- # Create chat session and submit query
123
- session_id = create_chat_session()
124
- llm_response = submit_query(session_id, patient_info,
125
- json.dumps(image_analysis) if image_analysis else None)
126
-
127
- if not llm_response or 'data' not in llm_response or 'answer' not in llm_response['data']:
128
- raise ValueError("Invalid response structure from LLM")
129
-
130
- # Extract and clean JSON from the response
131
- json_data = extract_json_from_answer(llm_response['data']['answer'])
132
-
133
- return json.dumps(json_data, indent=2)
134
-
135
- except Exception as e:
136
- logger.error(f"Error in gradio_interface: {str(e)}")
137
- return json.dumps({
138
- "error": str(e),
139
- "status": "error",
140
- "details": "Check the application logs for more information"
141
- }, indent=2)
142
 
143
- # Gradio interface
144
- iface = gr.Interface(
145
- fn=gradio_interface,
146
- inputs=[
147
- gr.Textbox(
148
- label="Patient Information",
149
- placeholder="Enter patient details including: symptoms, medical history, current medications, age, gender, and any relevant test results...",
150
- lines=5,
151
- max_lines=10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  ),
153
- gr.Image(
154
- label="Medical Image",
155
- type="numpy",
156
- interactive=True
157
- )
158
- ],
159
- outputs=gr.Textbox(
160
- label="Medical Analysis",
161
- placeholder="JSON analysis will appear here...",
162
- lines=15
163
- ),
164
- title="Medical Diagnosis Assistant",
165
- description="Enter patient information and optionally upload a medical image for analysis."
166
- )
167
 
168
  if __name__ == "__main__":
169
- # Add version information logging
170
- logger.info(f"TensorFlow Keras version: {tf_keras.__version__}")
171
  logger.info(f"TensorFlow Hub version: {hub.__version__}")
172
  logger.info(f"Gradio version: {gr.__version__}")
173
 
 
 
174
  iface.launch(
175
  server_name="0.0.0.0",
176
  debug=True
 
6
  import tensorflow_hub as hub
7
  import numpy as np
8
  from PIL import Image
 
9
  import os
10
+ from typing import Optional, Dict, Any, Union
11
 
12
+ # Set up logging
13
  logging.basicConfig(
14
  level=logging.INFO,
15
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
16
  )
17
  logger = logging.getLogger(__name__)
18
 
19
+ class MedicalDiagnosisModel:
20
+ def __init__(self, model_path: str):
21
+ self.model_path = model_path
22
+ self.model = self._load_model()
 
 
 
23
 
24
+ def _load_model(self) -> Optional[tf_keras.Model]:
25
+ """Load the transfer learning model with proper error handling."""
26
+ try:
27
+ if not os.path.exists(self.model_path):
28
+ raise FileNotFoundError(f"Model file not found at {self.model_path}")
29
+
30
+ logger.info(f"Loading model from {self.model_path}")
31
 
32
+ # Define custom objects dictionary for transfer learning
33
+ custom_objects = {
34
+ 'KerasLayer': hub.KerasLayer
35
+ }
36
+
37
+ try:
38
+ logger.info("Attempting to load model with custom objects...")
39
+ with tf_keras.utils.custom_object_scope(custom_objects):
40
+ model = tf_keras.models.load_model(self.model_path, compile=False)
41
+ except Exception as e:
42
+ logger.error(f"Failed to load with custom objects: {str(e)}")
43
+ logger.info("Attempting to load model without custom objects...")
44
+ model = tf_keras.models.load_model(self.model_path, compile=False)
45
+
46
+ model.summary()
47
+ logger.info("Model loaded successfully")
48
+ return model
49
+
50
+ except Exception as e:
51
+ logger.error(f"Error loading model: {str(e)}")
52
+ return None
53
+
54
+ def preprocess_image(self, image: Image.Image) -> np.ndarray:
55
+ """Preprocess the input image for model prediction."""
56
  try:
57
+ # Convert to RGB and resize
58
+ image = image.convert('RGB')
59
+ image = image.resize((256, 256))
60
+
61
+ # Convert to numpy array and normalize
62
+ image_array = np.array(image)
63
+ image_array = image_array / 255.0
64
+
65
+ # Add batch dimension
66
+ image_array = np.expand_dims(image_array, axis=0)
67
+ logger.info(f"Preprocessed image shape: {image_array.shape}")
68
+
69
+ return image_array
70
+
71
  except Exception as e:
72
+ logger.error(f"Error preprocessing image: {str(e)}")
73
+ raise
 
 
 
 
 
74
 
75
+ def predict(self, image: np.ndarray) -> Dict[str, float]:
76
+ """Run model prediction and return results."""
77
+ try:
78
+ prediction = self.model.predict(image)
79
+ return {
80
+ "prediction": float(prediction[0][0]),
81
+ "confidence": float(prediction[0][0]) * 100
82
+ }
83
+ except Exception as e:
84
+ logger.error(f"Error during prediction: {str(e)}")
85
+ raise
86
+
87
+ class MedicalDiagnosisAPI:
88
+ def __init__(self, api_key: str, user_id: str):
89
+ self.api_key = api_key
90
+ self.user_id = user_id
91
+ self.base_url = "https://api.example.com/v1" # Replace with actual API URL
92
 
93
+ def create_chat_session(self) -> str:
94
+ """Create a new chat session and return session ID."""
95
+ try:
96
+ response = requests.post(
97
+ f"{self.base_url}/sessions",
98
+ headers={
99
+ "Authorization": f"Bearer {self.api_key}",
100
+ "X-User-ID": self.user_id
101
+ }
102
+ )
103
+ response.raise_for_status()
104
+ return response.json()["session_id"]
105
+ except Exception as e:
106
+ logger.error(f"Error creating chat session: {str(e)}")
107
+ raise
108
+
109
+ def submit_query(self, session_id: str, patient_info: str,
110
+ image_analysis: Optional[str] = None) -> Dict[str, Any]:
111
+ """Submit a query to the API and return the response."""
112
+ try:
113
+ payload = {
114
+ "patient_info": patient_info,
115
+ "image_analysis": image_analysis
116
+ }
117
+
118
+ response = requests.post(
119
+ f"{self.base_url}/sessions/{session_id}/query",
120
+ headers={
121
+ "Authorization": f"Bearer {self.api_key}",
122
+ "X-User-ID": self.user_id
123
+ },
124
+ json=payload
125
+ )
126
+ response.raise_for_status()
127
+ return response.json()
128
+ except Exception as e:
129
+ logger.error(f"Error submitting query: {str(e)}")
130
+ raise
131
 
132
+ def extract_json_from_answer(answer: str) -> Dict[str, Any]:
133
+ """Extract and parse JSON from the API response."""
 
 
 
 
 
 
 
 
134
  try:
135
+ # Find JSON content between triple backticks if present
136
+ if "```json" in answer and "```" in answer:
137
+ json_str = answer.split("```json")[1].split("```")[0].strip()
138
+ else:
139
+ json_str = answer.strip()
140
+
141
+ return json.loads(json_str)
 
 
 
 
 
 
 
 
 
142
  except Exception as e:
143
+ logger.error(f"Error extracting JSON from answer: {str(e)}")
144
  raise
145
 
146
+ class MedicalDiagnosisApp:
147
+ def __init__(self, model_path: str, api_key: str, user_id: str):
148
+ self.model = MedicalDiagnosisModel(model_path)
149
+ self.api = MedicalDiagnosisAPI(api_key, user_id)
150
+
151
+ def process_request(self, patient_info: str,
152
+ image: Optional[Image.Image]) -> str:
153
+ """Process a medical diagnosis request."""
154
+ try:
155
+ if self.model.model is None:
156
+ return json.dumps({
157
+ "error": "Model initialization failed",
158
+ "status": "error"
159
+ }, indent=2)
160
+
161
+ # Process image if provided
162
+ image_analysis = None
163
+ if image is not None:
164
+ processed_image = self.model.preprocess_image(image)
165
+ image_analysis = self.model.predict(processed_image)
166
+ logger.info(f"Image analysis results: {image_analysis}")
167
+
168
+ # Create chat session and submit query
169
+ session_id = self.api.create_chat_session()
170
+ llm_response = self.api.submit_query(
171
+ session_id,
172
+ patient_info,
173
+ json.dumps(image_analysis) if image_analysis else None
174
+ )
175
 
176
+ if not llm_response or 'data' not in llm_response or 'answer' not in llm_response['data']:
177
+ raise ValueError("Invalid response structure from LLM")
178
+
179
+ json_data = extract_json_from_answer(llm_response['data']['answer'])
180
+ return json.dumps(json_data, indent=2)
 
181
 
182
+ except Exception as e:
183
+ logger.error(f"Error processing request: {str(e)}")
184
+ return json.dumps({
185
+ "error": str(e),
186
+ "status": "error",
187
+ "details": "Check the application logs for more information"
188
+ }, indent=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
+ def create_gradio_interface() -> gr.Interface:
191
+ """Create and configure the Gradio interface."""
192
+ app = MedicalDiagnosisApp(
193
+ model_path='model_epoch_01.h5.keras',
194
+ api_key='KGSjxB1uptfSk8I8A7ciCuNT9Xa3qWC3',
195
+ user_id='plugin-1717464304'
196
+ )
197
+
198
+ return gr.Interface(
199
+ fn=app.process_request,
200
+ inputs=[
201
+ gr.Textbox(
202
+ label="Patient Information",
203
+ placeholder="Enter patient details including: symptoms, medical history, current medications, age, gender, and any relevant test results...",
204
+ lines=5,
205
+ max_lines=10
206
+ ),
207
+ gr.Image(
208
+ label="Medical Image",
209
+ type="numpy",
210
+ interactive=True
211
+ )
212
+ ],
213
+ outputs=gr.Textbox(
214
+ label="Medical Analysis",
215
+ placeholder="JSON analysis will appear here...",
216
+ lines=15
217
  ),
218
+ title="Medical Diagnosis Assistant",
219
+ description="Enter patient information and optionally upload a medical image for analysis."
220
+ )
 
 
 
 
 
 
 
 
 
 
 
221
 
222
  if __name__ == "__main__":
223
+ # Log version information
224
+ logger.info(f"TF-Keras version: {tf_keras.__version__}")
225
  logger.info(f"TensorFlow Hub version: {hub.__version__}")
226
  logger.info(f"Gradio version: {gr.__version__}")
227
 
228
+ # Create and launch the interface
229
+ iface = create_gradio_interface()
230
  iface.launch(
231
  server_name="0.0.0.0",
232
  debug=True