# --- Legacy prototype (kept for reference, commented out): CNN models + LLM ---
# import requests
# import numpy as np
# import tensorflow as tf
# from tensorflow import keras  # required by the keras.models.load_model calls below
# import tensorflow_hub as hub
# import gradio as gr
# from PIL import Image

# # Load models
# #model_initial = keras.models.load_model(
# #    "models/initial_model.h5", custom_objects={'KerasLayer': hub.KerasLayer}
# #)
# #model_tumor = keras.models.load_model(
# #    "models/model_tumor.h5", custom_objects={'KerasLayer': hub.KerasLayer}
# #)
# #model_stroke = keras.models.load_model(
# #    "models/model_stroke.h5", custom_objects={'KerasLayer': hub.KerasLayer}
# #)
# #model_alzheimer = keras.models.load_model(
# #    "models/model_alzheimer.h5", custom_objects={'KerasLayer': hub.KerasLayer}
# #)

# # API key and user ID for on-demand
# api_key = 'KGSjxB1uptfSk8I8A7ciCuNT9Xa3qWC3'
# external_user_id = 'plugin-1717464304'

# # Step 1: Create a chat session with the API
# def create_chat_session():
#     create_session_url = 'https://api.on-demand.io/chat/v1/sessions'
#     create_session_headers = {
#         'apikey': api_key
#     }
#     create_session_body = {
#         "pluginIds": [],
#         "externalUserId": external_user_id
#     }
#     response = requests.post(create_session_url, headers=create_session_headers, json=create_session_body)
#     response_data = response.json()
#     session_id = response_data['data']['id']
#     return session_id

# # Step 2: Submit query to the API
# def submit_query(session_id, query):
#     submit_query_url = f'https://api.on-demand.io/chat/v1/sessions/{session_id}/query'
#     submit_query_headers = {
#         'apikey': api_key
#     }
#     submit_query_body = {
#         "endpointId": "predefined-openai-gpt4o",
#         "query": query,
#         "pluginIds": ["plugin-1712327325", "plugin-1713962163"],
#         "responseMode": "sync"
#     }
#     response = requests.post(submit_query_url, headers=submit_query_headers, json=submit_query_body)
#     return response.json()

# # Combined disease model (placeholder)
# class CombinedDiseaseModel(tf.keras.Model):
#     def __init__(self, model_initial, model_alzheimer, model_tumor, model_stroke):
#         super(CombinedDiseaseModel, self).__init__()
#         self.model_initial = model_initial
#         self.model_alzheimer = model_alzheimer
#         self.model_tumor = model_tumor
#         self.model_stroke = model_stroke
#         self.disease_labels = ["Alzheimer's", 'No Disease', 'Stroke', 'Tumor']
#         self.sub_models = {
#             "Alzheimer's": model_alzheimer,
#             'Tumor': model_tumor,
#             'Stroke': model_stroke
#         }

#     def call(self, inputs):
#         initial_probs = self.model_initial(inputs, training=False)
#         main_disease_idx = tf.argmax(initial_probs, axis=1)
#         main_disease = self.disease_labels[main_disease_idx[0].numpy()]
#         main_disease_prob = initial_probs[0, main_disease_idx[0]].numpy()
#         if main_disease == 'No Disease':
#             sub_category = "No Disease"
#             sub_category_prob = main_disease_prob
#         else:
#             sub_model = self.sub_models[main_disease]
#             sub_category_pred = sub_model(inputs, training=False)
#             sub_category = tf.argmax(sub_category_pred, axis=1).numpy()[0]
#             sub_category_prob = sub_category_pred[0, sub_category].numpy()
#             if main_disease == "Alzheimer's":
#                 sub_category_label = ['Very Mild', 'Mild', 'Moderate']
#             elif main_disease == 'Tumor':
#                 sub_category_label = ['Glioma', 'Meningioma', 'Pituitary']
#             elif main_disease == 'Stroke':
#                 sub_category_label = ['Ischemic', 'Hemorrhagic']
#             sub_category = sub_category_label[sub_category]
#         return f"The MRI image shows {main_disease} with a probability of {main_disease_prob*100:.2f}%.\n" \
#                f"The subcategory of {main_disease} is {sub_category} with a probability of {sub_category_prob*100:.2f}%."

# # Placeholder function to process images
# def process_image(image):
#     image = image.resize((256, 256))
#     image = image.convert("RGB")  # convert() returns a new image; assign the result
#     image_array = np.array(image) / 255.0
#     image_array = np.expand_dims(image_array, axis=0)
#     # Prediction logic here
#     # predictions = cnn_model(image_array)
#     return "Mock prediction: Disease identified with a probability of 85%."

# # Function to handle patient info, query, and image processing
# def gradio_interface(patient_info, query_type, image):
#     if image is not None:
#         image_response = process_image(image)

#         # Call LLM with patient info and query
#         session_id = create_chat_session()
#         query = f"Patient Info: {patient_info}\nQuery Type: {query_type}"
#         llm_response = submit_query(session_id, query)

#         # Debug: print the full response to inspect it
#         print("LLM Response:", llm_response)

#         # Safely handle 'message' if it exists
#         message = llm_response.get('data', {}).get('message', 'No message returned from LLM')

#         # If no message came back, print the complete response for inspection
#         if message == 'No message returned from LLM':
#             print("Full LLM Response Data:", llm_response)

#         response = f"Patient Info: {patient_info}\nQuery Type: {query_type}\n\n{image_response}\n\nLLM Response:\n{message}"
#         return response
#     else:
#         return "Please upload an image."

# # Gradio interface
# iface = gr.Interface(
#     fn=gradio_interface,
#     inputs=[
#         gr.Textbox(
#             label="Patient Information",
#             placeholder="Enter patient details here...",
#             lines=5,
#             max_lines=10
#         ),
#         gr.Textbox(
#             label="Query Type",
#             placeholder="Describe the type of diagnosis or information needed..."
#         ),
#         gr.Image(
#             type="pil",
#             label="Upload an MRI Image",
#         )
#     ],
#     outputs=gr.Textbox(label="Response", placeholder="The response will appear here..."),
#     title="Medical Diagnosis with MRI and LLM",
#     description="Upload MRI images and provide patient information for a combined CNN model and LLM analysis."
# )

# iface.launch()

# --- Active implementation: LLM-only pipeline (no CNN models) ---
import requests
import gradio as gr
import logging
import json
import os

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# API key and user ID for the on-demand.io chat API.
# NOTE: prefer environment variables / Space secrets over hardcoded values; the
# variable names below are chosen here (not mandated by the API), and the
# hardcoded fallbacks should be rotated and removed.
api_key = os.environ.get('ONDEMAND_API_KEY', 'KGSjxB1uptfSk8I8A7ciCuNT9Xa3qWC3')
external_user_id = os.environ.get('ONDEMAND_EXTERNAL_USER_ID', 'plugin-1717464304')


def create_chat_session():
    """Create a chat session and return its id (response_data['data']['id'])."""
    try:
        create_session_url = 'https://api.on-demand.io/chat/v1/sessions'
        create_session_headers = {
            'apikey': api_key,
            'Content-Type': 'application/json'
        }
        create_session_body = {
            "pluginIds": [],
            "externalUserId": external_user_id
        }

        logger.info("Creating chat session...")
        response = requests.post(create_session_url, headers=create_session_headers, json=create_session_body)
        response.raise_for_status()  # Raise an exception for bad status codes

        response_data = response.json()
        logger.info(f"Session created successfully: {json.dumps(response_data, indent=2)}")

        session_id = response_data['data']['id']
        return session_id

    except requests.exceptions.RequestException as e:
        logger.error(f"Error creating chat session: {str(e)}")
        # e.response is None for connection-level failures, so guard before reading it
        if e.response is not None:
            logger.error(f"Response content: {e.response.text}")
        raise
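
# For reference, the session-creation exchange looks roughly like this. The
# shape is inferred from the parsing above (only data.id is actually used);
# fields beyond data.id are assumptions, not documented API guarantees:
#   POST https://api.on-demand.io/chat/v1/sessions
#   body -> {"pluginIds": [], "externalUserId": "<external-user-id>"}
#   2xx  -> {"data": {"id": "<session-id>", ...}}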

def submit_query(session_id, query):
    """Send a synchronous query to an existing session and return the parsed JSON."""
    try:
        submit_query_url = f'https://api.on-demand.io/chat/v1/sessions/{session_id}/query'
        submit_query_headers = {
            'apikey': api_key,
            'Content-Type': 'application/json'
        }
        submit_query_body = {
            "endpointId": "predefined-openai-gpt4o",
            "query": query,
            "pluginIds": ["plugin-1712327325", "plugin-1713962163"],
            "responseMode": "sync"
        }

        logger.info(f"Submitting query for session {session_id}")
        logger.info(f"Query content: {query}")
        response = requests.post(submit_query_url, headers=submit_query_headers, json=submit_query_body)
        response.raise_for_status()

        response_data = response.json()
        logger.info(f"Query response received: {json.dumps(response_data, indent=2)}")
        return response_data

    except requests.exceptions.RequestException as e:
        logger.error(f"Error submitting query: {str(e)}")
        # e.response is None for connection-level failures, so guard before reading it
        if e.response is not None:
            logger.error(f"Response content: {e.response.text}")
        raise
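
# Minimal smoke test (a sketch, not wired into the app): exercises both helpers
# end-to-end outside Gradio. The helper name `_smoke_test` and the assumed
# response shape {"data": {"message": ...}} are inferred from the parsing in
# gradio_interface below, not from official API documentation. Call it manually
# from a REPL if needed.
def _smoke_test():
    session_id = create_chat_session()
    result = submit_query(session_id, "Patient Info: test\nQuery Type: sanity check")
    print(result.get('data', {}).get('message'))  # assumed location of the LLM text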

def gradio_interface(patient_info, query_type):
    try:
        # Create session
        session_id = create_chat_session()

        # Construct query
        query = f"Patient Info: {patient_info}\nQuery Type: {query_type}"

        # Submit query and get response
        llm_response = submit_query(session_id, query)

        # Enhanced response handling
        if not llm_response:
            logger.error("Empty response received from LLM")
            return "Error: No response received from the LLM service"

        # Navigate the response structure with detailed logging
        logger.info(f"Processing LLM response: {json.dumps(llm_response, indent=2)}")

        if 'data' not in llm_response:
            logger.error("Response missing 'data' field")
            return f"Error: Unexpected response structure\nFull response: {json.dumps(llm_response, indent=2)}"

        message = llm_response.get('data', {}).get('message')
        if not message:
            logger.error("No message found in response data")
            return f"Error: No message in response\nFull response: {json.dumps(llm_response, indent=2)}"

        response = f"Patient Info: {patient_info}\nQuery Type: {query_type}\n\nLLM Response:\n{message}"
        return response

    except Exception as e:
        logger.error(f"Error in gradio_interface: {str(e)}", exc_info=True)
        return f"Error processing request: {str(e)}"

# Gradio interface
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(
            label="Patient Information",
            placeholder="Enter patient details here...",
            lines=5,
            max_lines=10
        ),
        gr.Textbox(
            label="Query Type",
            placeholder="Describe the type of diagnosis or information needed..."
        ),
    ],
    outputs=gr.Textbox(label="Response", placeholder="The response will appear here..."),
    title="Medical Diagnosis with LLM",
    description="Provide patient information and a query type for analysis by the LLM."
)
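
# Launch note: gr.Interface.launch also accepts standard Gradio options such as
# share=True (temporary public URL) and server_name/server_port (bind address);
# these are optional here and mentioned only as a deployment pointer.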

if __name__ == "__main__":
    iface.launch()