import os
import logging
from typing import Optional, TypedDict
import gradio as gr
from PIL import Image

# LangChain & LangGraph
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langchain.tools import tool
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Web Search
from duckduckgo_search import DDGS

# Llama GGUF Model Loader
from llama_cpp import Llama

# ------------------------------
# 🔹 Setup Logging
# ------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ------------------------------
# 🔹 Load GGUF Model with llama-cpp-python
# ------------------------------
model_path = "./Bio-Medical-MultiModal-Llama-3-8B-V1.i1-Q6_K.gguf"  # Update with actual GGUF model path
llm = Llama(model_path=model_path, n_ctx=2048, n_gpu_layers=35)  # Optimized for Hugging Face T4 GPU

logger.info("Llama GGUF Model Loaded Successfully.")

# ------------------------------
# 🔹 Define Expert System Prompts
# ------------------------------
GP_PROMPT = "You are a General Practitioner AI Assistant. Answer medical questions with scientifically accurate information."
RADIOLOGY_PROMPT = "You are a Radiology AI expert. Analyze medical images and provide diagnostic insights."
WEBSEARCH_PROMPT = "You are a Web Search AI. Retrieve up-to-date medical information."

# ------------------------------
# 🔹 FAISS Vector Store for RAG
# ------------------------------
_vector_store_cache = None

def load_vectorstore(pdf_path="medical_docs.pdf"):
    """Loads PDF files into a FAISS vector store for RAG."""
    try:
        loader = PyPDFLoader(pdf_path)
        documents = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=50)
        docs = text_splitter.split_documents(documents)
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        vector_store = FAISS.from_documents(docs, embeddings)
        logger.info(f"Vector store loaded with {len(docs)} documents.")
        return vector_store
    except Exception as e:
        logger.error(f"Error loading vector store: {str(e)}")
        return None

def update_vector_store(pdf_path):
    """Rebuilds the FAISS vector store when a new PDF is uploaded."""
    global vector_store
    try:
        new_store = load_vectorstore(pdf_path)
        if new_store is not None:
            vector_store = new_store  # Make the new store visible to the retrieval tool
        return vector_store
    except Exception as e:
        logger.error(f"Error updating vector store: {str(e)}")
        return _vector_store_cache  # Fall back to the cached version

if os.path.exists("medical_docs.pdf"):
    _vector_store_cache = load_vectorstore("medical_docs.pdf")
else:
    _vector_store_cache = None

vector_store = _vector_store_cache

# ------------------------------
# 🔹 Define AI Tools
# ------------------------------
@tool
def analyze_medical_image(image_path: str):
    """Analyzes a medical image and returns a diagnostic explanation."""
    try:
        image = Image.open(image_path)
    except Exception as e:
        logger.error(f"Error opening image: {str(e)}")
        return "Error processing image."

    # NOTE: a plain GGUF text model cannot ingest pixels; without a multimodal
    # projector (mmproj) file we can only describe the image's metadata and ask
    # the model for a cautious, text-only interpretation.
    description = f"format={image.format}, size={image.size}, mode={image.mode}"
    output = llm(
        f"{RADIOLOGY_PROMPT}\nA medical image was uploaded ({description}). "
        "Describe the findings a radiologist would look for and recommend next steps.\nA:",
        max_tokens=512,
    )
    return output["choices"][0]["text"]

@tool
def retrieve_medical_knowledge(query: str):
    """Retrieves medical knowledge from FAISS vector store."""
    if vector_store is None:
        return "No external medical knowledge available."
    
    retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 3})
    docs = retriever.get_relevant_documents(query)
    citations = [f"[{i+1}] {doc.metadata.get('source', 'Unknown Source')}" for i, doc in enumerate(docs)]
    citations_text = "\n".join(citations)
    content = "\n".join([doc.page_content for doc in docs])
    return content + f"\n\n**Citations:**\n{citations_text}"

@tool
def web_search(query: str):
    """Performs a real-time web search using DuckDuckGo."""
    try:
        with DDGS() as ddgs:
            results = list(ddgs.text(query, max_results=3))
        summary = "\n".join(f"{r['title']}: {r['body']} ({r['href']})" for r in results) or "No relevant results found."
        return summary
    except Exception as e:
        logger.error(f"Web search error: {str(e)}")
        return "Error retrieving web search results."

# ------------------------------
# 🔹 Define Multi-Agent Workflow (LangGraph)
# ------------------------------
class AgentState(TypedDict, total=False):
    """Shared state passed between LangGraph nodes."""
    query: str
    response: str
    image_path: Optional[str]
    expert: str  # "GP", "Radiology", or "Web Search"

# Memory checkpointing
checkpointer = MemorySaver()

# Create LangGraph state graph
agent_graph = StateGraph(AgentState)

def route_query(state: AgentState) -> str:
    """Determines which AI expert should handle the query."""
    if state.get("image_path"):
        return "radiology_specialist"
    elif any(word in state["query"].lower() for word in ["latest", "update", "breaking news"]):
        return "web_search_expert"
    else:
        return "general_practitioner"

def general_practitioner(state: AgentState) -> AgentState:
    """GP Expert: Handles medical text queries and retrieves knowledge."""
    query = state["query"]
    retrieved_info = retrieve_medical_knowledge.run(query)
    # llama-cpp-python defaults to max_tokens=16, so set it explicitly
    output = llm(f"{GP_PROMPT}\nQ: {query}\nA:", max_tokens=512)
    return {"response": output["choices"][0]["text"] + "\n\n" + retrieved_info, "expert": "GP"}

def radiology_specialist(state: AgentState) -> AgentState:
    """Radiology Expert: Analyzes medical images."""
    image_analysis = analyze_medical_image.run(state["image_path"])
    return {"response": image_analysis, "expert": "Radiology"}

def web_search_expert(state: AgentState) -> AgentState:
    """Web Search Expert: Retrieves the latest information."""
    search_result = web_search.run(state["query"])
    return {"response": search_result, "expert": "Web Search"}

# Add nodes; "route_query" is a passthrough whose conditional edges pick the expert
agent_graph.add_node("route_query", lambda state: state)
agent_graph.add_node("general_practitioner", general_practitioner)
agent_graph.add_node("radiology_specialist", radiology_specialist)
agent_graph.add_node("web_search_expert", web_search_expert)
agent_graph.add_conditional_edges(
    "route_query",
    route_query,
    {
        "general_practitioner": "general_practitioner",
        "radiology_specialist": "radiology_specialist",
        "web_search_expert": "web_search_expert",
    },
)
agent_graph.add_edge("general_practitioner", END)
agent_graph.add_edge("radiology_specialist", END)
agent_graph.add_edge("web_search_expert", END)
agent_graph.set_entry_point("route_query")

# Compile graph
app = agent_graph.compile(checkpointer=checkpointer)
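
# The Gradio click handler below references chat_with_agent, which the original
# file never defines. A minimal sketch (assuming gr.Image/gr.File deliver file
# paths, and using a fixed thread_id for MemorySaver checkpointing):
def chat_with_agent(user_query, image_path, pdf_file):
    """Bridges Gradio inputs to the compiled LangGraph app."""
    if pdf_file:
        # Older Gradio versions pass a tempfile wrapper; newer ones pass a path
        update_vector_store(getattr(pdf_file, "name", pdf_file))
    state = {"query": user_query or "", "image_path": image_path, "response": "", "expert": ""}
    # MemorySaver requires a thread_id so the conversation can be checkpointed
    result = app.invoke(state, config={"configurable": {"thread_id": "gradio-session"}})
    return f"[{result.get('expert', '?')}] {result.get('response', '')}"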

# ------------------------------
# 🔹 Gradio Interface
# ------------------------------
with gr.Blocks(title="Llama3-Med Multi-Agent AI") as demo:
    gr.Markdown("# ๐Ÿฅ AI Medical Assistant")
    
    with gr.Row():
        user_input = gr.Textbox(label="Your Question")
        image_file = gr.Image(label="Upload Medical Image (Optional)", type="filepath")
        pdf_file = gr.File(label="Upload PDF (Optional)", file_types=[".pdf"])
    submit_btn = gr.Button("Submit")
    output_text = gr.Textbox(label="Assistant's Response", interactive=False)

    submit_btn.click(fn=chat_with_agent, inputs=[user_input, image_file, pdf_file], outputs=output_text)

if __name__ == "__main__":
    demo.launch()