NeuroGenAI / app.py
yashbyname's picture
Update app.py
f3a1e2d verified
raw
history blame
11.2 kB
import io
import json
import logging
import os

import gradio as gr
import numpy as np
import requests
import tensorflow as tf
from PIL import Image
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Credentials for the on-demand.io chat API.
# SECURITY: a live API key is committed below. Supply it through the
# ON_DEMAND_API_KEY environment variable (e.g. a Space secret) instead; the
# literal is kept only as a fallback so existing deployments keep working.
# The committed key should be rotated and the fallback removed.
api_key = os.environ.get("ON_DEMAND_API_KEY") or 'KGSjxB1uptfSk8I8A7ciCuNT9Xa3qWC3'
external_user_id = os.environ.get("ON_DEMAND_EXTERNAL_USER_ID") or 'plugin-1717464304'
# Load the keras model
def load_model(model_path='model_epoch_01.h5.keras'):
    """Load and return the Keras model saved at ``model_path``.

    Args:
        model_path: Path to the saved model file. Defaults to the model file
            this app ships with.

    Returns:
        The object returned by ``tf.keras.models.load_model``.

    Raises:
        Exception: whatever ``tf.keras.models.load_model`` raises (e.g. for a
            missing or unreadable file); it is logged with its traceback and
            then re-raised.
    """
    try:
        model = tf.keras.models.load_model(model_path)
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("Error loading model: %s", e)
        raise
    logger.info("Model loaded successfully")
    return model
# Preprocess image for model
def preprocess_image(image, target_size=(224, 224)):
    """Convert an input image into a resized, scaled, batched array.

    Args:
        image: A ``PIL.Image.Image`` or an array-like image that
            ``tf.image.resize`` accepts, such as the numpy array produced by
            the Gradio ``type="numpy"`` image input.
        target_size: ``(height, width)`` to resize to. Defaults to
            ``(224, 224)``; set it to the model's expected input size.

    Returns:
        A numpy array with a leading batch axis of size 1, with pixel values
        divided by 255.0. For an ``(H, W, C)`` input, the result has shape
        ``(1, height, width, C)``.

    Raises:
        Exception: any conversion/resize error, logged and then re-raised.
    """
    try:
        if isinstance(image, Image.Image):
            image = np.array(image)
        resized = tf.image.resize(image, target_size)
        # Scale to [0, 1]; assumes 0-255 pixel input — TODO confirm this
        # matches the normalization used when the model was trained.
        normalized = resized / 255.0
        # Add the leading batch dimension expected by model.predict.
        return np.expand_dims(normalized, axis=0)
    except Exception as e:
        logger.exception("Error preprocessing image: %s", e)
        raise
def create_chat_session(*, timeout=30):
    """Open a new on-demand.io chat session and return its id.

    Args:
        timeout: Seconds to wait for the HTTP request. Without a timeout,
            ``requests`` can block indefinitely and tie up the Gradio worker.

    Returns:
        The session id read from ``response.json()['data']['id']``.

    Raises:
        requests.exceptions.RequestException: on connection errors, timeouts
            or non-2xx responses (logged, then re-raised).
    """
    create_session_url = 'https://api.on-demand.io/chat/v1/sessions'
    create_session_headers = {
        'apikey': api_key,
        'Content-Type': 'application/json',
    }
    create_session_body = {
        "pluginIds": [],
        "externalUserId": external_user_id,
    }
    try:
        response = requests.post(
            create_session_url,
            headers=create_session_headers,
            json=create_session_body,
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()['data']['id']
    except requests.exceptions.RequestException as e:
        logger.exception("Error creating chat session: %s", e)
        raise
def submit_query(session_id, query, image_analysis=None, *, timeout=120):
    """Ask the LLM for a structured JSON medical analysis.

    Args:
        session_id: Id returned by ``create_chat_session``.
        query: Free-text patient information.
        image_analysis: Optional string describing the image-model result;
            appended to the query when truthy.
        timeout: Seconds to wait for the synchronous (``responseMode: sync``)
            request. 120 s is a chosen default to bound the wait without
            cutting off normal LLM latency; without any timeout ``requests``
            can block indefinitely.

    Returns:
        The decoded JSON body of the API response.

    Raises:
        requests.exceptions.RequestException: on connection errors, timeouts
            or non-2xx responses (logged, then re-raised).
    """
    submit_query_url = f'https://api.on-demand.io/chat/v1/sessions/{session_id}/query'
    submit_query_headers = {
        'apikey': api_key,
        'Content-Type': 'application/json',
    }
    # Include image analysis in the query if available
    query_with_image = query
    if image_analysis:
        query_with_image += f"\n\nImage Analysis Results: {image_analysis}"
    structured_query = f"""
Based on the following patient information and image analysis, provide a detailed medical analysis in JSON format:
{query_with_image}
Return only valid JSON with these fields:
- diagnosis_details
- probable_diagnoses (array)
- treatment_plans (array)
- lifestyle_modifications (array)
- medications (array of objects with name and dosage)
- additional_tests (array)
- precautions (array)
- follow_up (string)
- image_findings (object with prediction and confidence)
"""
    submit_query_body = {
        "endpointId": "predefined-openai-gpt4o",
        "query": structured_query,
        "pluginIds": ["plugin-1712327325", "plugin-1713962163"],
        "responseMode": "sync",
    }
    try:
        response = requests.post(
            submit_query_url,
            headers=submit_query_headers,
            json=submit_query_body,
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()
    except requests.exceptions.RequestException as e:
        logger.exception("Error submitting query: %s", e)
        raise
def extract_json_from_answer(answer):
    """Extract and clean JSON from the LLM response.

    First tries to parse ``answer`` as-is. If that fails (e.g. the model
    wrapped the JSON in prose or a Markdown code fence), parses the span from
    the first ``{`` to the last ``}``.

    Args:
        answer: Raw text answer returned by the LLM.

    Returns:
        The decoded JSON value.

    Raises:
        json.JSONDecodeError: if neither attempt yields valid JSON, including
            when ``answer`` contains no ``{...}`` span at all. This case used
            to fall through and return ``None``, which the caller serialized
            as a bogus ``"null"`` result.
    """
    try:
        return json.loads(answer)
    except json.JSONDecodeError:
        pass  # fall back to extracting the outermost {...} span below

    start_idx = answer.find('{')
    end_idx = answer.rfind('}') + 1  # 0 when no '}' is present
    try:
        if start_idx == -1 or end_idx == 0:
            raise json.JSONDecodeError("No JSON object found in response", answer, 0)
        return json.loads(answer[start_idx:end_idx])
    except ValueError:  # json.JSONDecodeError is a subclass of ValueError
        logger.error("Failed to parse JSON from response")
        raise
# Initialize the model once at import time so every request reuses it.
# load_model() re-raises on failure, so the app will not start without it.
model = load_model()


def _analyze_image(image):
    """Run the Keras model on ``image`` and summarize its first output value.

    Returns a dict with ``prediction`` (the value at ``prediction[0][0]``)
    and ``confidence`` (that same value multiplied by 100).
    """
    processed_image = preprocess_image(image)
    prediction = model.predict(processed_image)
    # NOTE(review): only the first output unit is read. This fits a single
    # sigmoid-style output; for a multi-class softmax model, "confidence"
    # should come from the predicted class instead — confirm model format.
    score = float(prediction[0][0])
    return {
        "prediction": score,
        "confidence": score * 100,  # Convert to percentage
    }


def gradio_interface(patient_info, image):
    """Gradio callback: analyze patient text and an optional image.

    Runs the Keras model on ``image`` when one is provided, sends the patient
    information plus the image result to the on-demand.io LLM, and parses the
    JSON it returns.

    Args:
        patient_info: Free-text patient description from the textbox.
        image: Image from the Gradio input, or ``None`` if none was uploaded.

    Returns:
        A JSON string: the parsed LLM analysis on success, or
        ``{"error": "<message>"}`` if any step raises. Exceptions are logged
        and returned as JSON rather than propagated to Gradio.
    """
    try:
        image_analysis = _analyze_image(image) if image is not None else None

        # Create chat session and submit query
        session_id = create_chat_session()
        llm_response = submit_query(
            session_id,
            patient_info,
            json.dumps(image_analysis) if image_analysis else None,
        )

        # Validate types, not just key membership: a non-dict "data" would
        # otherwise slip through the `in` check and fail with an opaque error.
        data = llm_response.get('data') if isinstance(llm_response, dict) else None
        if not isinstance(data, dict) or 'answer' not in data:
            raise ValueError("Invalid response structure")

        json_data = extract_json_from_answer(data['answer'])
        # Return clean JSON string
        return json.dumps(json_data)
    except Exception as e:
        logger.exception("Error in gradio_interface: %s", e)
        return json.dumps({"error": str(e)})
# Gradio interface.
# The image input needs no "optional" flag: Gradio passes None when no image
# is uploaded, and gradio_interface handles that case. The former
# `optional=True` keyword was deprecated in Gradio 3 and is not accepted by
# Gradio 4+ components, where it fails at construction time.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(
            label="Patient Information",
            placeholder="Enter patient details including: symptoms, medical history, current medications, age, gender, and any relevant test results...",
            lines=5,
            max_lines=10,
        ),
        gr.Image(
            label="Medical Image",
            type="numpy",
        ),
    ],
    outputs=gr.Textbox(
        label="Medical Analysis",
        placeholder="JSON analysis will appear here...",
        lines=15,
    ),
    title="Medical Diagnosis Assistant",
    description="Enter patient information and optionally upload a medical image for analysis.",
)

if __name__ == "__main__":
    iface.launch()
# import requests
# import gradio as gr
# import logging
# import json
# # Set up logging
# logging.basicConfig(level=logging.INFO)
# logger = logging.getLogger(__name__)
# # API key and user ID for on-demand
# api_key = 'KGSjxB1uptfSk8I8A7ciCuNT9Xa3qWC3'
# external_user_id = 'plugin-1717464304'
# def create_chat_session():
# try:
# create_session_url = 'https://api.on-demand.io/chat/v1/sessions'
# create_session_headers = {
# 'apikey': api_key,
# 'Content-Type': 'application/json'
# }
# create_session_body = {
# "pluginIds": [],
# "externalUserId": external_user_id
# }
# response = requests.post(create_session_url, headers=create_session_headers, json=create_session_body)
# response.raise_for_status()
# return response.json()['data']['id']
# except requests.exceptions.RequestException as e:
# logger.error(f"Error creating chat session: {str(e)}")
# raise
# def submit_query(session_id, query):
# try:
# submit_query_url = f'https://api.on-demand.io/chat/v1/sessions/{session_id}/query'
# submit_query_headers = {
# 'apikey': api_key,
# 'Content-Type': 'application/json'
# }
# structured_query = f"""
# Based on the following patient information, provide a detailed medical analysis in JSON format:
# {query}
# Return only valid JSON with these fields:
# - diagnosis_details
# - probable_diagnoses (array)
# - treatment_plans (array)
# - lifestyle_modifications (array)
# - medications (array of objects with name and dosage)
# - additional_tests (array)
# - precautions (array)
# - follow_up (string)
# """
# submit_query_body = {
# "endpointId": "predefined-openai-gpt4o",
# "query": structured_query,
# "pluginIds": ["plugin-1712327325", "plugin-1713962163"],
# "responseMode": "sync"
# }
# response = requests.post(submit_query_url, headers=submit_query_headers, json=submit_query_body)
# response.raise_for_status()
# return response.json()
# except requests.exceptions.RequestException as e:
# logger.error(f"Error submitting query: {str(e)}")
# raise
# def extract_json_from_answer(answer):
# """Extract and clean JSON from the LLM response"""
# try:
# # First try to parse the answer directly
# return json.loads(answer)
# except json.JSONDecodeError:
# try:
# # If that fails, try to find JSON content and parse it
# start_idx = answer.find('{')
# end_idx = answer.rfind('}') + 1
# if start_idx != -1 and end_idx != 0:
# json_str = answer[start_idx:end_idx]
# return json.loads(json_str)
# except (json.JSONDecodeError, ValueError):
# logger.error("Failed to parse JSON from response")
# raise
# def gradio_interface(patient_info):
# try:
# session_id = create_chat_session()
# llm_response = submit_query(session_id, patient_info)
# if not llm_response or 'data' not in llm_response or 'answer' not in llm_response['data']:
# raise ValueError("Invalid response structure")
# # Extract and clean JSON from the response
# json_data = extract_json_from_answer(llm_response['data']['answer'])
# # Return clean JSON string without extra formatting
# return json.dumps(json_data)
# except Exception as e:
# logger.error(f"Error in gradio_interface: {str(e)}")
# return json.dumps({"error": str(e)})
# # Gradio interface
# iface = gr.Interface(
# fn=gradio_interface,
# inputs=[
# gr.Textbox(
# label="Patient Information",
# placeholder="Enter patient details including: symptoms, medical history, current medications, age, gender, and any relevant test results...",
# lines=5,
# max_lines=10
# )
# ],
# outputs=gr.Textbox(
# label="Medical Analysis",
# placeholder="JSON analysis will appear here...",
# lines=15
# ),
# title="Medical Diagnosis Assistant",
# description="Enter detailed patient information to receive a structured medical analysis in JSON format."
# )
# if __name__ == "__main__":
# iface.launch()