import requests
import gradio as gr
import logging
import json
import tf_keras
import tensorflow_hub as hub
import numpy as np
from PIL import Image
import os
from typing import Optional, Dict, Any

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class MedicalDiagnosisModel:
    def __init__(self, model_path: str):
        self.model_path = model_path
        self.model = self._load_model()

    def _load_model(self) -> Optional[tf_keras.Model]:
        """Load the transfer learning model with proper error handling."""
        try:
            if not os.path.exists(self.model_path):
                raise FileNotFoundError(f"Model file not found at {self.model_path}")

            logger.info(f"Loading model from {self.model_path}")

            # Models saved with TensorFlow Hub layers need hub.KerasLayer
            # registered as a custom object at load time.
            custom_objects = {
                'KerasLayer': hub.KerasLayer
            }

            try:
                logger.info("Attempting to load model with custom objects...")
                with tf_keras.utils.custom_object_scope(custom_objects):
                    model = tf_keras.models.load_model(self.model_path, compile=False)
            except Exception as e:
                logger.error(f"Failed to load with custom objects: {str(e)}")
                logger.info("Attempting to load model without custom objects...")
                model = tf_keras.models.load_model(self.model_path, compile=False)

            model.summary()
            logger.info("Model loaded successfully")
            return model
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            return None

    def preprocess_image(self, image: Image.Image) -> np.ndarray:
        """Preprocess the input image for model prediction."""
        try:
            # Convert to RGB and resize to the model's expected input size
            image = image.convert('RGB')
            image = image.resize((256, 256))

            # Convert to numpy array and normalize to [0, 1]
            image_array = np.array(image) / 255.0

            # Add batch dimension
            image_array = np.expand_dims(image_array, axis=0)

            logger.info(f"Preprocessed image shape: {image_array.shape}")
            return image_array
        except Exception as e:
            logger.error(f"Error preprocessing image: {str(e)}")
            raise

    def predict(self, image: np.ndarray) -> Dict[str, float]:
        """Run model prediction and return results."""
        try:
            prediction = self.model.predict(image)
            return {
                "prediction": float(prediction[0][0]),
                "confidence": float(prediction[0][0]) * 100
            }
        except Exception as e:
            logger.error(f"Error during prediction: {str(e)}")
            raise


class MedicalDiagnosisAPI:
    def __init__(self, api_key: str, user_id: str):
        self.api_key = api_key
        self.user_id = user_id
        self.base_url = "https://api.on-demand.io/chat/v1"

    def create_chat_session(self) -> str:
        """Create a new chat session and return its session ID."""
        try:
            response = requests.post(
                f"{self.base_url}/sessions",
                headers={
                    "apikey": self.api_key,
                    "Content-Type": "application/json"
                },
                json={
                    "pluginIds": [],
                    "externalUserId": self.user_id
                }
            )
            response.raise_for_status()
            return response.json()['data']['id']
        except requests.exceptions.RequestException as e:
            logger.error(f"Error creating chat session: {str(e)}")
            raise

    def submit_query(self, session_id: str, patient_info: str,
                     image_analysis: Optional[str] = None) -> Dict[str, Any]:
        """Submit a structured query to the API and return the response."""
        try:
            structured_query = f"""
            Based on the following patient information, provide a detailed medical analysis in JSON format:

            {patient_info}

            Image analysis results: {image_analysis if image_analysis else 'No image provided'}

            Return only valid JSON with these fields:
            - diagnosis_details
            - probable_diagnoses (array)
            - treatment_plans (array)
            - lifestyle_modifications (array)
            - medications (array of objects with name and dosage)
            - additional_tests (array)
            - precautions (array)
            - follow_up (string)
            """

            response = requests.post(
                f"{self.base_url}/sessions/{session_id}/query",
                headers={
                    "apikey": self.api_key,
                    "Content-Type": "application/json"
                },
                json={
                    "endpointId": "predefined-openai-gpt4o",
                    "query": structured_query,
                    "pluginIds": ["plugin-1712327325", "plugin-1713962163"],
                    "responseMode": "sync"
                }
            )
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            logger.error(f"Error submitting query: {str(e)}")
            raise
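
# Illustrative sketch (defined but never called): how MedicalDiagnosisModel and
# MedicalDiagnosisAPI compose outside of Gradio. The image path and patient
# summary below are hypothetical placeholders.
def _example_standalone_usage(image_path: str = 'chest_xray.png') -> None:
    model = MedicalDiagnosisModel('model_epoch_01.h5.keras')
    api = MedicalDiagnosisAPI(
        api_key='KGSjxB1uptfSk8I8A7ciCuNT9Xa3qWC3',
        user_id='plugin-1717464304'
    )

    img = Image.open(image_path)         # any RGB-convertible image
    batch = model.preprocess_image(img)  # -> shape (1, 256, 256, 3)
    result = model.predict(batch)        # -> {'prediction': ..., 'confidence': ...}

    session_id = api.create_chat_session()
    reply = api.submit_query(
        session_id,
        "45-year-old male, persistent cough",  # hypothetical patient summary
        json.dumps(result)
    )
    print(reply['data']['answer'])
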
def extract_json_from_answer(answer: str) -> Dict[str, Any]:
    """Extract and parse JSON from the API response."""
    try:
        # Strip the triple-backtick fence if the answer is wrapped in a code block
        if "```json" in answer:
            json_str = answer.split("```json")[1].split("```")[0].strip()
        else:
            json_str = answer.strip()
        return json.loads(json_str)
    except Exception as e:
        logger.error(f"Error extracting JSON from answer: {str(e)}")
        raise


class MedicalDiagnosisApp:
    def __init__(self, model_path: str, api_key: str, user_id: str):
        self.model = MedicalDiagnosisModel(model_path)
        self.api = MedicalDiagnosisAPI(api_key, user_id)

    def process_request(self, patient_info: str, image: Optional[Image.Image]) -> str:
        """Process a medical diagnosis request."""
        try:
            if self.model.model is None:
                return json.dumps({
                    "error": "Model initialization failed",
                    "status": "error"
                }, indent=2)

            # Process image if provided
            image_analysis = None
            if image is not None:
                processed_image = self.model.preprocess_image(image)
                image_analysis = self.model.predict(processed_image)
                logger.info(f"Image analysis results: {image_analysis}")

            # Create chat session and submit query
            session_id = self.api.create_chat_session()
            llm_response = self.api.submit_query(
                session_id,
                patient_info,
                json.dumps(image_analysis) if image_analysis else None
            )

            if not llm_response or 'data' not in llm_response or 'answer' not in llm_response['data']:
                raise ValueError("Invalid response structure from LLM")

            json_data = extract_json_from_answer(llm_response['data']['answer'])
            return json.dumps(json_data, indent=2)
        except Exception as e:
            logger.error(f"Error processing request: {str(e)}")
            return json.dumps({
                "error": str(e),
                "status": "error",
                "details": "Check the application logs for more information"
            }, indent=2)


def create_gradio_interface() -> gr.Interface:
    """Create and configure the Gradio interface."""
    app = MedicalDiagnosisApp(
        model_path='model_epoch_01.h5.keras',
        api_key='KGSjxB1uptfSk8I8A7ciCuNT9Xa3qWC3',  # Prefer loading secrets from the environment in production
        user_id='plugin-1717464304'
    )

    return gr.Interface(
        fn=app.process_request,
        inputs=[
            gr.Textbox(
                label="Patient Information",
                placeholder="Enter patient details including: symptoms, medical history, current medications, age, gender, and any relevant test results...",
                lines=5,
                max_lines=10
            ),
            # type="pil" so process_request receives a PIL.Image, which
            # preprocess_image expects (convert/resize are PIL methods)
            gr.Image(
                label="Medical Image",
                type="pil",
                interactive=True
            )
        ],
        outputs=gr.Textbox(
            label="Medical Analysis",
            placeholder="JSON analysis will appear here...",
            lines=15
        ),
        title="Medical Diagnosis Assistant",
        description="Enter patient information and optionally upload a medical image for analysis."
    )
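
# For reference: the response shape process_request expects from submit_query,
# based on the llm_response['data']['answer'] access above (values illustrative):
#
#   {
#       "data": {
#           "answer": "```json\n{ ...analysis fields... }\n```"
#       }
#   }
#
# extract_json_from_answer strips the optional ```json fence before parsing.
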
if __name__ == "__main__":
    # Log version information
    logger.info(f"TF-Keras version: {tf_keras.__version__}")
    logger.info(f"TensorFlow Hub version: {hub.__version__}")
    logger.info(f"Gradio version: {gr.__version__}")

    # Create and launch the interface
    iface = create_gradio_interface()
    iface.launch(
        server_name="0.0.0.0",
        debug=True
    )
label="Patient Information", # placeholder="Enter patient details including: symptoms, medical history, current medications, age, gender, and any relevant test results...", # lines=5, # max_lines=10 # ) # ], # outputs=gr.Textbox( # label="Medical Analysis", # placeholder="JSON analysis will appear here...", # lines=15 # ), # title="Medical Diagnosis Assistant", # description="Enter detailed patient information to receive a structured medical analysis in JSON format." # ) # if __name__ == "__main__": # iface.launch()