# NeuroGenAI / app.py
# Last commit by yashbyname: 3df9cbe "Update app.py" (verified), 5.8 kB.
import logging
import os

import gradio as gr
import numpy as np
import requests
import tensorflow as tf
import tensorflow_hub as hub
from PIL import Image
# Load models
#model_initial = keras.models.load_model(
# "models/initial_model.h5", custom_objects={'KerasLayer': hub.KerasLayer}
#)
#model_tumor = keras.models.load_model(
# "models/model_tumor.h5", custom_objects={'KerasLayer': hub.KerasLayer}
#)
#model_stroke = keras.models.load_model(
# "models/model_stroke.h5", custom_objects={'KerasLayer': hub.KerasLayer}
#)
#model_alzheimer = keras.models.load_model(
# "models/model_alzheimer.h5", custom_objects={'KerasLayer': hub.KerasLayer}
#)
# API key and user ID for the on-demand.io chat API.
# NOTE(review): this key has been committed to source control and should be
# treated as leaked. Rotate it and provide the new value through the
# ON_DEMAND_API_KEY environment variable (e.g. a deployment secret). The
# literal fallback is kept only so existing deployments keep working until then.
api_key = os.environ.get('ON_DEMAND_API_KEY', 'KGSjxB1uptfSk8I8A7ciCuNT9Xa3qWC3')
external_user_id = os.environ.get('ON_DEMAND_EXTERNAL_USER_ID', 'plugin-1717464304')
# Step 1: Create a chat session with the API
def create_chat_session(timeout=30):
    """Open a new on-demand.io chat session and return its id.

    Args:
        timeout: Seconds to wait for the HTTP response. Without a timeout
            ``requests`` can block the calling Gradio worker indefinitely.

    Returns:
        The session id read from ``response_json['data']['id']``.

    Raises:
        requests.HTTPError: if the service answers with a 4xx/5xx status.
        requests.RequestException: on connection failure or timeout.
        KeyError: if the JSON body has no ``data.id`` entry.
    """
    response = requests.post(
        'https://api.on-demand.io/chat/v1/sessions',
        headers={'apikey': api_key},
        json={
            "pluginIds": [],
            "externalUserId": external_user_id,
        },
        timeout=timeout,
    )
    # Surface rejected requests (e.g. a bad API key) as an HTTP error rather
    # than a confusing KeyError from the lookup below.
    response.raise_for_status()
    return response.json()['data']['id']
# Step 2: Submit query to the API
def submit_query(session_id, query, timeout=120):
    """Send *query* to an existing chat session and return the decoded JSON reply.

    The request targets the ``predefined-openai-gpt4o`` endpoint with two
    plugins enabled and asks for ``responseMode`` "sync".

    Args:
        session_id: Id returned by ``create_chat_session``.
        query: Prompt text to send.
        timeout: Seconds to wait for the reply. It is generous because a
            synchronous LLM answer can take a while, but it is finite so a
            stalled connection cannot hang the caller forever.

    Returns:
        Whatever ``response.json()`` yields. HTTP error statuses are not raised
        here; the caller inspects the body and falls back when it is unusable.
    """
    response = requests.post(
        f'https://api.on-demand.io/chat/v1/sessions/{session_id}/query',
        headers={'apikey': api_key},
        json={
            "endpointId": "predefined-openai-gpt4o",
            "query": query,
            "pluginIds": ["plugin-1712327325", "plugin-1713962163"],
            "responseMode": "sync",
        },
        timeout=timeout,
    )
    return response.json()
# Combined disease model (placeholder)
class CombinedDiseaseModel(tf.keras.Model):
    """Two-stage classifier that reports a disease and its sub-category.

    ``model_initial`` picks one of ``disease_labels``. For any disease other
    than 'No Disease', the matching sub-model then picks a sub-category.
    The result is returned as a human-readable string rather than a tensor.
    """

    # Output column order of each sub-model, keyed like ``sub_models`` so every
    # disease that reaches the second stage has a label list.
    # NOTE(review): order assumed to match the sub-models' training labels.
    SUB_CATEGORY_LABELS = {
        "Alzheimer's": ('Very Mild', 'Mild', 'Moderate'),
        'Tumor': ('Glioma', 'Meningioma', 'Pituitary'),
        'Stroke': ('Ischemic', 'Hemorrhagic'),
    }

    def __init__(self, model_initial, model_alzheimer, model_tumor, model_stroke):
        super().__init__()
        self.model_initial = model_initial
        self.model_alzheimer = model_alzheimer
        self.model_tumor = model_tumor
        self.model_stroke = model_stroke
        # Output column i of model_initial is reported as disease_labels[i].
        self.disease_labels = ["Alzheimer's", 'No Disease', 'Stroke', 'Tumor']
        # Second-stage model for each disease that has sub-categories.
        self.sub_models = {
            "Alzheimer's": model_alzheimer,
            'Tumor': model_tumor,
            'Stroke': model_stroke,
        }

    def call(self, inputs):
        """Classify *inputs* and return a two-line summary string.

        Only the first element of the batch is reported. ``.numpy()`` is called
        on the predictions, so this has to run eagerly (not inside tf.function).
        """
        initial_probs = self.model_initial(inputs, training=False)
        main_idx = int(tf.argmax(initial_probs, axis=1)[0])
        main_disease = self.disease_labels[main_idx]
        main_disease_prob = initial_probs[0, main_idx].numpy()

        if main_disease == 'No Disease':
            sub_category = "No Disease"
            sub_category_prob = main_disease_prob
        else:
            sub_category_pred = self.sub_models[main_disease](inputs, training=False)
            sub_idx = int(tf.argmax(sub_category_pred, axis=1)[0])
            sub_category = self.SUB_CATEGORY_LABELS[main_disease][sub_idx]
            sub_category_prob = sub_category_pred[0, sub_idx].numpy()

        return (
            f"The MRI image shows {main_disease} with a probability of {main_disease_prob*100:.2f}%.\n"
            f"The subcategory of {main_disease} is {sub_category} with a probability of {sub_category_prob*100:.2f}%."
        )
# Placeholder function to process images
def process_image(image):
image = image.resize((256, 256))
image.convert("RGB")
image_array = np.array(image) / 255.0
image_array = np.expand_dims(image_array, axis=0)
# Prediction logic here
# predictions = cnn_model(image_array)
return "Mock prediction: Disease identified with a probability of 85%."
# Function to handle patient info, query, and image processing
def gradio_interface(patient_info, query_type, image):
if image is not None:
image_response = process_image(image)
# Call LLM with patient info and query
session_id = create_chat_session()
query = f"Patient Info: {patient_info}\nQuery Type: {query_type}"
llm_response = submit_query(session_id, query)
# Debug: Print the full response to inspect it
print("LLM Response:", llm_response)
# Handle missing 'message' field safely
message = llm_response.get('data', {}).get('message', 'No message returned from LLM')
response = f"Patient Info: {patient_info}\nQuery Type: {query_type}\n\n{image_response}\n\nLLM Response:\n{message}"
return response
else:
return "Please upload an image."
# Gradio front end: patient notes, query type and an MRI upload go into
# gradio_interface, whose text result is shown in a single output box.
patient_info_box = gr.Textbox(
    label="Patient Information",
    placeholder="Enter patient details here...",
    lines=5,
    max_lines=10,
)
query_type_box = gr.Textbox(
    label="Query Type",
    placeholder="Describe the type of diagnosis or information needed...",
)
mri_upload = gr.Image(type="pil", label="Upload an MRI Image")
response_box = gr.Textbox(label="Response", placeholder="The response will appear here...")

iface = gr.Interface(
    fn=gradio_interface,
    inputs=[patient_info_box, query_type_box, mri_upload],
    outputs=response_box,
    title="Medical Diagnosis with MRI and LLM",
    description="Upload MRI images and provide patient information for a combined CNN model and LLM analysis.",
)
iface.launch()