import os
import logging
from typing import Optional, TypedDict

import gradio as gr
from PIL import Image

# LangChain & LangGraph
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langchain.tools import tool
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Web Search
from duckduckgo_search import DDGS

# Llama GGUF Model Loader
from llama_cpp import Llama

# ------------------------------
# 🔹 Setup Logging
# ------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ------------------------------
# 🔹 Load GGUF Model with llama-cpp-python
# ------------------------------
model_path = "./Bio-Medical-MultiModal-Llama-3-8B-V1.i1-Q6_K.gguf"  # Update with actual GGUF model path
llm = Llama(model_path=model_path, n_ctx=2048, n_gpu_layers=35)  # 35 offloaded layers, sized for a Hugging Face T4 GPU
logger.info("Llama GGUF Model Loaded Successfully.")

# ------------------------------
# 🔹 Define Expert System Prompts
# ------------------------------
GP_PROMPT = "You are a General Practitioner AI Assistant. Answer medical questions with scientifically accurate information."
RADIOLOGY_PROMPT = "You are a Radiology AI expert. Analyze medical images and provide diagnostic insights."
WEBSEARCH_PROMPT = "You are a Web Search AI. Retrieve up-to-date medical information."

# ------------------------------
# 🔹 FAISS Vector Store for RAG
# ------------------------------
_vector_store_cache = None

def load_vectorstore(pdf_path="medical_docs.pdf"):
    """Loads a PDF file into a FAISS vector store for RAG."""
    try:
        loader = PyPDFLoader(pdf_path)
        documents = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=50)
        docs = text_splitter.split_documents(documents)
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        vector_store = FAISS.from_documents(docs, embeddings)
        logger.info(f"Vector store loaded with {len(docs)} document chunks.")
        return vector_store
    except Exception as e:
        logger.error(f"Error loading vector store: {str(e)}")
        return None

def update_vector_store(pdf_path):
    """Rebuilds the FAISS vector store when a new PDF is uploaded.

    Gradio's File component hands us a temp-file path, so the PDF can be
    loaded directly; Gradio cleans up its own temp files.
    """
    global vector_store
    try:
        new_store = load_vectorstore(pdf_path)
        if new_store is not None:
            vector_store = new_store
        return vector_store
    except Exception as e:
        logger.error(f"Error updating vector store: {str(e)}")
        return _vector_store_cache  # Fall back to the cached version

if os.path.exists("medical_docs.pdf"):
    _vector_store_cache = load_vectorstore("medical_docs.pdf")
else:
    _vector_store_cache = None
vector_store = _vector_store_cache

# ------------------------------
# 🔹 Define AI Tools
# ------------------------------
@tool
def analyze_medical_image(image_path: str):
    """Analyzes a medical image and returns a diagnostic explanation."""
    try:
        image = Image.open(image_path)
    except Exception as e:
        logger.error(f"Error opening image: {str(e)}")
        return "Error processing image."
    # NOTE: a plain text-completion call cannot see pixel data; interpolating the
    # PIL object would only embed its repr. True multimodal inference with
    # llama-cpp-python requires loading the model with a vision chat handler
    # (e.g. llama_cpp.llama_chat_format.Llava15ChatHandler plus the model's CLIP
    # projector) and sending the image through the chat API.
    output = llm(
        f"Analyze this medical image ({image.size[0]}x{image.size[1]}, mode {image.mode}) "
        "and provide a diagnosis:\n"
    )
    return output["choices"][0]["text"]

@tool
def retrieve_medical_knowledge(query: str):
    """Retrieves medical knowledge from the FAISS vector store."""
    if vector_store is None:
        return "No external medical knowledge available."
    retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 3})
    docs = retriever.invoke(query)  # .get_relevant_documents() is deprecated
    citations = [f"[{i+1}] {doc.metadata.get('source', 'Unknown Source')}" for i, doc in enumerate(docs)]
    citations_text = "\n".join(citations)
    content = "\n".join(doc.page_content for doc in docs)
    return content + f"\n\n**Citations:**\n{citations_text}"

@tool
def web_search(query: str):
    """Performs a real-time web search using DuckDuckGo."""
    try:
        # The old ddg() helper was removed from duckduckgo_search; DDGS().text() is the current API
        results = DDGS().text(query, max_results=3)
        summary = "\n".join(f"{r['title']}: {r['body']} ({r['href']})" for r in results) or "No relevant results found."
        return summary
    except Exception as e:
        logger.error(f"Web search error: {str(e)}")
        return "Error retrieving web search results."

# ------------------------------
# 🔹 Define Multi-Agent Workflow (LangGraph)
# ------------------------------
class AgentState(TypedDict):
    """Graph state schema; StateGraph needs a TypedDict (or similar), not a plain class."""
    query: str
    response: str
    image_path: Optional[str]
    expert: str  # "GP", "Radiology", or "Web Search"

# Memory checkpointing
checkpointer = MemorySaver()

# Create LangGraph state graph
agent_graph = StateGraph(AgentState)

def route_query(state: AgentState):
    """Determines which AI expert should handle the query."""
    if state.get("image_path"):
        return "radiology_specialist"
    elif any(word in state["query"].lower() for word in ["latest", "update", "breaking news"]):
        return "web_search_expert"
    else:
        return "general_practitioner"

def general_practitioner(state: AgentState):
    """GP Expert: Handles medical text queries and retrieves knowledge."""
    query = state["query"]
    retrieved_info = retrieve_medical_knowledge.invoke(query)  # .run() is deprecated
    output = llm(f"{GP_PROMPT}\nQ: {query}\nA:")
    return {"response": output["choices"][0]["text"] + "\n\n" + retrieved_info, "expert": "GP"}

def radiology_specialist(state: AgentState):
    """Radiology Expert: Analyzes medical images."""
    image_analysis = analyze_medical_image.invoke(state["image_path"])
    return {"response": image_analysis, "expert": "Radiology"}

def web_search_expert(state: AgentState):
    """Web Search Expert: Retrieves the latest information."""
    search_result = web_search.invoke(state["query"])
    return {"response": search_result, "expert": "Web Search"}

# Add nodes; the router must itself be a node before conditional edges can leave it
agent_graph.add_node("route_query", lambda state: state)  # Pass-through router node
agent_graph.add_node("general_practitioner", general_practitioner)
agent_graph.add_node("radiology_specialist", radiology_specialist)
agent_graph.add_node("web_search_expert", web_search_expert)
# Conditional edges take a mapping from router return values to node names, not a set
agent_graph.add_conditional_edges(
    "route_query",
    route_query,
    {
        "general_practitioner": "general_practitioner",
        "radiology_specialist": "radiology_specialist",
        "web_search_expert": "web_search_expert",
    },
)
agent_graph.add_edge("general_practitioner", END)
agent_graph.add_edge("radiology_specialist", END)
agent_graph.add_edge("web_search_expert", END)
agent_graph.set_entry_point("route_query")

# Compile graph
app = agent_graph.compile(checkpointer=checkpointer)
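# ------------------------------
# 🔹 Gradio Callback
# ------------------------------
# The original script wires submit_btn.click to chat_with_agent but never defines it.
# A minimal sketch follows, under two assumptions: gr.Image/gr.File deliver temp-file
# paths (their "filepath" types), and a thread_id must be supplied because the graph
# was compiled with a MemorySaver checkpointer.
def chat_with_agent(user_query, image_path, pdf_path):
    """Routes the user's input through the compiled LangGraph app."""
    if pdf_path:  # Rebuild the RAG index if the user attached a PDF
        update_vector_store(pdf_path)
    state = app.invoke(
        {"query": user_query or "", "response": "", "image_path": image_path, "expert": ""},
        config={"configurable": {"thread_id": "gradio-session"}},  # Required by MemorySaver
    )
    return f"[{state['expert']}]\n{state['response']}"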
# ------------------------------
# 🔹 Gradio Interface
# ------------------------------
with gr.Blocks(title="Llama3-Med Multi-Agent AI") as demo:
    gr.Markdown("# 🏥 AI Medical Assistant")
    with gr.Row():
        user_input = gr.Textbox(label="Your Question")
        image_file = gr.Image(label="Upload Medical Image (Optional)", type="filepath")  # type="file" is no longer supported
        pdf_file = gr.File(label="Upload PDF (Optional)", file_types=[".pdf"], type="filepath")
    submit_btn = gr.Button("Submit")
    output_text = gr.Textbox(label="Assistant's Response", interactive=False)
    submit_btn.click(fn=chat_with_agent, inputs=[user_input, image_file, pdf_file], outputs=output_text)

if __name__ == "__main__":
    demo.launch()
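# Example (assumption, for smoke-testing without the UI): the compiled graph can be
# invoked directly; the thread_id is required because of the MemorySaver checkpointer.
#
#   result = app.invoke(
#       {"query": "What are first-line treatments for hypertension?",
#        "response": "", "image_path": None, "expert": ""},
#       config={"configurable": {"thread_id": "smoke-test"}},
#   )
#   print(result["response"])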