Update app.py
app.py CHANGED
@@ -36,36 +36,54 @@ def create_chat_session():
     session_id = response_data['data']['id']
     return session_id

-# Step 2: Submit a query to the API
 def submit_query(session_id, query):
-
-
-
-
-
-
-
-
-
-
-
-
+    try:
+        submit_query_url = f'https://api.on-demand.io/chat/v1/sessions/{session_id}/query'
+        submit_query_headers = {
+            'apikey': api_key
+        }
+        submit_query_body = {
+            "endpointId": "predefined-openai-gpt4o",
+            "query": query,
+            "pluginIds": ["plugin-1712327325", "plugin-1713962163"],
+            "responseMode": "sync"
+        }
+
+        # Log the query sent to the API
+        logger.info(f"Sending query to LLM API: {query}")
+
+        response = requests.post(submit_query_url, headers=submit_query_headers, json=submit_query_body)
+        logger.info(f"LLM API response status: {response.status_code}")
+
+        if response.status_code != 200:
+            logger.error(f"LLM API call failed with status code {response.status_code}")
+            return None
+
+        # Return the JSON response
+        return response.json()
+
+    except Exception as e:
+        logger.error(f"Error submitting query to LLM API: {str(e)}")
+        raise

 def extract_json_from_answer(answer):
     """Extract and clean JSON from the LLM response"""
     try:
-        #
+        # Try to parse the answer directly as JSON
+        logger.info(f"Trying to parse the LLM answer as JSON: {answer}")
         return json.loads(answer)
     except json.JSONDecodeError:
+        logger.warning("Direct JSON parse failed, trying to extract JSON content")
         try:
-            #
+            # Attempt to find and parse JSON content within the answer
             start_idx = answer.find('{')
             end_idx = answer.rfind('}') + 1
             if start_idx != -1 and end_idx != 0:
                 json_str = answer[start_idx:end_idx]
+                logger.info(f"Extracted JSON string: {json_str}")
                 return json.loads(json_str)
-        except (json.JSONDecodeError, ValueError):
-            logger.error("Failed to parse JSON from response")
+        except (json.JSONDecodeError, ValueError) as parse_error:
+            logger.error(f"Failed to parse JSON from response: {str(parse_error)}")
             raise

 def load_model():
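The updated extract_json_from_answer tries a direct parse first and only then slices out the outermost {...} span, which covers the common case where the model wraps its JSON in prose. A runnable sketch of that fallback path (the sample answer is invented for illustration, not an actual API response):

import json

# Sample LLM reply that wraps JSON in prose; invented for illustration.
answer = 'Here is the analysis you requested:\n{"diagnosis": "Normal", "confidence": 91}'

try:
    result = json.loads(answer)              # direct parse fails on the prose prefix
except json.JSONDecodeError:
    start_idx = answer.find('{')             # slice out the outermost {...} span
    end_idx = answer.rfind('}') + 1
    result = json.loads(answer[start_idx:end_idx])

print(result)  # {'diagnosis': 'Normal', 'confidence': 91}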
@@ -188,41 +206,106 @@ def gradio_interface(patient_info, image):
             "details": "Check the application logs for more information"
         }, indent=2)

-# Gradio interface
-iface = gr.Interface(
-    fn=gradio_interface,
-    inputs=[
-        gr.Textbox(
-            label="Patient Information",
-            placeholder="Enter patient details including: symptoms, medical history, current medications, age, gender, and any relevant test results...",
-            lines=5,
-            max_lines=10
-        ),
-        gr.Image(
-            label="Medical Image",
-            type="pil",
-            interactive=True
-        )
-    ],
-    outputs=gr.Textbox(
-        label="Medical Analysis",
-        placeholder="JSON analysis will appear here...",
-        lines=15
-    ),
-    title="Medical Diagnosis Assistant",
-    description="Enter patient information and optionally upload a medical image for analysis."
-)

-
-
-
-
-
+def gradio_interface(patient_info, image):
+    try:
+        if model is None:
+            logger.error("Model is not initialized")
+            return json.dumps({
+                "error": "Model initialization failed. Please check the logs for details.",
+                "status": "error"
+            }, indent=2)
+
+        classes = ["Alzheimer's", "Stroke", "Tumor", "Normal"]
+        image_analysis = None
+
+        # Process the image if provided
+        if image is not None:
+            logger.info("Processing uploaded image")
+            processed_image = preprocess_image(image)
+
+            logger.info("Running model prediction")
+            prediction = model.predict(processed_image)
+
+            logger.info(f"Raw prediction: {prediction}")
+            # Assume the first prediction corresponds to the class
+            predicted_class_index = int(np.argmax(prediction))
+            predicted_confidence = np.max(prediction) * 100
+
+            image_analysis = {
+                "prediction": classes[predicted_class_index],
+                "confidence": predicted_confidence
+            }
+            logger.info(f"Image analysis results: {image_analysis}")
+
+            # Append image analysis to patient info
+            patient_info += f"\nPrediction based on MRI images: {image_analysis['prediction']}, Confidence: {image_analysis['confidence']}"
+
+        # Log the patient info sent to the API
+        logger.info(f"Submitting the following patient info to LLM: {patient_info}")
+
+        # Create a session and submit the query
+        session_id = create_chat_session()
+        llm_response = submit_query(session_id, patient_info)
+
+        # Log the raw response from the LLM API
+        logger.info(f"LLM API response: {llm_response}")
+
+        if not llm_response or 'data' not in llm_response or 'answer' not in llm_response['data']:
+            logger.error("Invalid response structure from LLM")
+            return json.dumps({
+                "error": "Invalid LLM response structure.",
+                "status": "error"
+            }, indent=2)
+
+        # Extract the answer and parse the JSON from the response
+        json_data = extract_json_from_answer(llm_response['data']['answer'])
+
+        return json.dumps(json_data, indent=2)
+
+    except Exception as e:
+        logger.error(f"Error in gradio_interface: {str(e)}")
+        return json.dumps({
+            "error": str(e),
+            "status": "error",
+            "details": "Check the application logs for more information"
+        }, indent=2)
+
+# # Gradio interface
+# iface = gr.Interface(
+#     fn=gradio_interface,
+#     inputs=[
+#         gr.Textbox(
+#             label="Patient Information",
+#             placeholder="Enter patient details including: symptoms, medical history, current medications, age, gender, and any relevant test results...",
+#             lines=5,
+#             max_lines=10
+#         ),
+#         gr.Image(
+#             label="Medical Image",
+#             type="pil",
+#             interactive=True
+#         )
+#     ],
+#     outputs=gr.Textbox(
+#         label="Medical Analysis",
+#         placeholder="JSON analysis will appear here...",
+#         lines=15
+#     ),
+#     title="Medical Diagnosis Assistant",
+#     description="Enter patient information and optionally upload a medical image for analysis."
+# )
+
+# if __name__ == "__main__":
+#     # Add version information logging
+#     logger.info(f"TensorFlow Keras version: {tf_keras.__version__}")
+#     logger.info(f"TensorFlow Hub version: {hub.__version__}")
+#     logger.info(f"Gradio version: {gr.__version__}")

-
-
-
-
+#     iface.launch(
+#         server_name="0.0.0.0",
+#         debug=True
+#     )
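In gradio_interface, the image branch relies on preprocess_image, model, and np defined elsewhere in app.py. A minimal sketch of the contract the prediction code assumes, namely a PIL image turned into a float batch the classifier can consume (the 224x224 size and [0, 1] scaling are assumptions for illustration, not taken from this diff):

import numpy as np
from PIL import Image

def preprocess_image(image: Image.Image, target_size=(224, 224)) -> np.ndarray:
    # Enforce 3 channels and a fixed input size (the target size is an assumption).
    image = image.convert("RGB").resize(target_size)
    # Scale pixel values to [0, 1] and add a batch dimension: (1, H, W, 3).
    array = np.asarray(image, dtype=np.float32) / 255.0
    return np.expand_dims(array, axis=0)

With a batch like this, np.argmax(model.predict(batch)) indexes into the classes list exactly as the new gradio_interface does.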
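Note that the Gradio wiring itself lands fully commented out, so this commit ships no UI entry point. For reference, a minimal uncommented version of that block; re-enabling it is an assumption about intent, since the commit deliberately leaves it disabled:

import gradio as gr

# Same components as the commented-out block above; gradio_interface is the
# handler defined earlier in app.py.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(label="Patient Information", lines=5, max_lines=10),
        gr.Image(label="Medical Image", type="pil", interactive=True),
    ],
    outputs=gr.Textbox(label="Medical Analysis", lines=15),
    title="Medical Diagnosis Assistant",
    description="Enter patient information and optionally upload a medical image for analysis.",
)

if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", debug=True)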