import os
import logging
from typing import Optional, TypedDict

import gradio as gr
from PIL import Image
# LangChain & LangGraph
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langchain.tools import tool
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
# Web Search
from duckduckgo_search import DDGS
# Llama GGUF Model Loader
from llama_cpp import Llama
# ------------------------------
# 🔹 Setup Logging
# ------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ------------------------------
# 🔹 Load GGUF Model with llama-cpp-python
# ------------------------------
model_path = "./Bio-Medical-MultiModal-Llama-3-8B-V1.i1-Q6_K.gguf" # Update with actual GGUF model path
llm = Llama(model_path=model_path, n_ctx=2048, n_gpu_layers=35) # Optimized for Hugging Face T4 GPU
logger.info("Llama GGUF Model Loaded Successfully.")
# ------------------------------
# 🔹 Define Expert System Prompts
# ------------------------------
GP_PROMPT = "You are a General Practitioner AI Assistant. Answer medical questions with scientifically accurate information."
RADIOLOGY_PROMPT = "You are a Radiology AI expert. Analyze medical images and provide diagnostic insights."
WEBSEARCH_PROMPT = "You are a Web Search AI. Retrieve up-to-date medical information."
# ------------------------------
# 🔹 FAISS Vector Store for RAG
# ------------------------------
vector_store = None  # Populated below; rebuilt when a new PDF is uploaded
def load_vectorstore(pdf_path="medical_docs.pdf"):
"""Loads PDF files into a FAISS vector store for RAG."""
try:
loader = PyPDFLoader(pdf_path)
documents = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=50)
docs = text_splitter.split_documents(documents)
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vector_store = FAISS.from_documents(docs, embeddings)
logger.info(f"Vector store loaded with {len(docs)} documents.")
return vector_store
except Exception as e:
logger.error(f"Error loading vector store: {str(e)}")
return None
def update_vector_store(pdf_path):
    """Rebuilds the FAISS vector store from a newly uploaded PDF.

    Gradio's file components hand us the path of a temp file, so the PDF
    can be loaded directly instead of copying its bytes to disk first.
    """
    global vector_store
    try:
        new_store = load_vectorstore(pdf_path)
        if new_store is not None:
            vector_store = new_store
        return vector_store
    except Exception as e:
        logger.error(f"Error updating vector store: {str(e)}")
        return vector_store  # Fall back to the existing store
# Build the initial vector store from a bundled PDF, if one is present.
if os.path.exists("medical_docs.pdf"):
    vector_store = load_vectorstore("medical_docs.pdf")
# ------------------------------
# 🔹 Define AI Tools
# ------------------------------
@tool
def analyze_medical_image(image_path: str):
    """Analyzes a medical image and returns a diagnostic explanation."""
    try:
        image = Image.open(image_path)
    except Exception as e:
        logger.error(f"Error opening image: {str(e)}")
        return "Error processing image."
    # NOTE: llama-cpp-python only consumes text unless the model is loaded
    # with a multimodal chat handler; interpolating the PIL object into the
    # prompt (as the original did) would only insert its repr() string, so
    # we describe the image instead.
    prompt = (
        f"{RADIOLOGY_PROMPT}\n"
        f"A medical image was uploaded (format: {image.format}, "
        f"size: {image.size[0]}x{image.size[1]}). "
        f"Provide a diagnostic assessment of the likely findings:\n"
    )
    output = llm(prompt, max_tokens=512)
    return output["choices"][0]["text"]
@tool
def retrieve_medical_knowledge(query: str):
"""Retrieves medical knowledge from FAISS vector store."""
if vector_store is None:
return "No external medical knowledge available."
    retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 3})
    docs = retriever.invoke(query)  # get_relevant_documents() is deprecated
citations = [f"[{i+1}] {doc.metadata.get('source', 'Unknown Source')}" for i, doc in enumerate(docs)]
citations_text = "\n".join(citations)
content = "\n".join([doc.page_content for doc in docs])
return content + f"\n\n**Citations:**\n{citations_text}"
@tool
def web_search(query: str):
"""Performs a real-time web search using DuckDuckGo."""
try:
        with DDGS() as ddgs:  # duckduckgo_search removed the old ddg() helper; DDGS is the current API
            results = list(ddgs.text(query, max_results=3))
        summary = "\n".join(f"{r['title']}: {r['body']} ({r['href']})" for r in results) or "No relevant results found."
return summary
except Exception as e:
logger.error(f"Web search error: {str(e)}")
return "Error retrieving web search results."
# ------------------------------
# 🔹 Define Multi-Agent Workflow (LangGraph)
# ------------------------------
class AgentState(TypedDict, total=False):
    """Shared graph state; StateGraph expects a schema such as a TypedDict."""
    query: str
    response: str
    image_path: Optional[str]
    expert: str  # "GP", "Radiology", or "Web Search"
# Memory checkpointing
checkpointer = MemorySaver()
# Create LangGraph state graph
agent_graph = StateGraph(AgentState)
def route_query(state: AgentState):
    """Determines which AI expert should handle the query."""
    if state.get("image_path"):
        return "radiology_specialist"
    elif any(word in state["query"].lower() for word in ["latest", "update", "breaking news"]):
        return "web_search_expert"
    return "general_practitioner"
def general_practitioner(state: AgentState):
    """GP Expert: Handles medical text queries and retrieves knowledge."""
    query = state["query"]
    retrieved_info = retrieve_medical_knowledge.invoke(query)
    output = llm(f"{GP_PROMPT}\nQ: {query}\nA:", max_tokens=512)
    answer = output["choices"][0]["text"]
    return {"response": f"{answer}\n\n{retrieved_info}", "expert": "GP"}

def radiology_specialist(state: AgentState):
    """Radiology Expert: Analyzes medical images."""
    image_analysis = analyze_medical_image.invoke(state["image_path"])
    return {"response": image_analysis, "expert": "Radiology"}

def web_search_expert(state: AgentState):
    """Web Search Expert: Retrieves the latest information."""
    search_result = web_search.invoke(state["query"])
    return {"response": search_result, "expert": "Web Search"}
# Add nodes; "route_query" is a passthrough node whose conditional edge picks the expert
agent_graph.add_node("route_query", lambda state: {})
agent_graph.add_node("general_practitioner", general_practitioner)
agent_graph.add_node("radiology_specialist", radiology_specialist)
agent_graph.add_node("web_search_expert", web_search_expert)
experts = ["general_practitioner", "radiology_specialist", "web_search_expert"]
agent_graph.add_conditional_edges("route_query", route_query, {name: name for name in experts})
for name in experts:
    agent_graph.add_edge(name, END)
agent_graph.set_entry_point("route_query")
# Compile graph
app = agent_graph.compile(checkpointer=checkpointer)
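# ------------------------------
# 🔹 Chat Entry Point
# ------------------------------
# The Gradio click handler below references chat_with_agent, which the
# original file never defines; this is a minimal sketch of the missing
# handler. It assumes the file components deliver temp-file paths and uses
# a fixed thread_id for the MemorySaver checkpointer (a real app would key
# threads per user session).
def chat_with_agent(user_input, image_path, pdf_path):
    """Routes a user query (plus optional files) through the agent graph."""
    if pdf_path:
        update_vector_store(pdf_path)
    state = {"query": user_input or "", "image_path": image_path, "response": "", "expert": ""}
    config = {"configurable": {"thread_id": "gradio-session"}}
    result = app.invoke(state, config=config)
    return f"[{result.get('expert', '?')}] {result.get('response', '')}"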
# ------------------------------
# 🔹 Gradio Interface
# ------------------------------
with gr.Blocks(title="Llama3-Med Multi-Agent AI") as demo:
    gr.Markdown("# 🏥 AI Medical Assistant")
with gr.Row():
user_input = gr.Textbox(label="Your Question")
        image_file = gr.Image(label="Upload Medical Image (Optional)", type="filepath")
pdf_file = gr.File(label="Upload PDF (Optional)", file_types=[".pdf"])
submit_btn = gr.Button("Submit")
output_text = gr.Textbox(label="Assistant's Response", interactive=False)
submit_btn.click(fn=chat_with_agent, inputs=[user_input, image_file, pdf_file], outputs=output_text)
if __name__ == "__main__":
demo.launch()