File size: 5,022 Bytes
5acccc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import os
import pandas as pd
import gradio as gr
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_groq import ChatGroq
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain_core.documents import Document

# Groq API key, read from the GROQ_API_KEY environment variable so the secret
# is never stored in source control. The value is None when the variable is
# not set. The constant keeps its original name (spelled "GROK") because other
# code in this file refers to it.
# NOTE(review): the key that used to be committed on this line must be
# treated as leaked and revoked in the Groq console.
GROK_API_KEY = os.environ.get("GROQ_API_KEY")

# Initialize the Groq-hosted chat model
def initialize_llm():
    """Build the ChatGroq chat model client.

    Returns:
        A ``ChatGroq`` instance configured with the module-level API key,
        the ``llama-3.3-70b-versatile`` model and a temperature of 0.7.
    """
    llm_settings = {
        "model_name": "llama-3.3-70b-versatile",
        "groq_api_key": GROK_API_KEY,
        "temperature": 0.7,
    }
    return ChatGroq(**llm_settings)


# Single shared model instance, created when the module is loaded.
llm = initialize_llm()

# Load and prepare the CSV dataset, then create or load FAISS index
def create_or_load_faiss_index():
    """Return the FAISS vector store of medicine names, building it if needed.

    If a saved index exists at ``faiss_index`` it is loaded from disk.
    Otherwise the medicines CSV is read, one document is created for every
    row that has both a name and a composition (the name is the embedded
    text, the composition is kept as metadata), and the resulting index is
    saved so later start-ups can load it directly.

    Returns:
        The loaded or newly built ``FAISS`` vector store.

    Raises:
        FileNotFoundError: If no saved index exists and the CSV dataset is
            missing.
    """
    index_path = "faiss_index"
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

    if os.path.exists(index_path):
        # SECURITY: allow_dangerous_deserialization opts in to pickle-based
        # loading. That is acceptable only because this directory is written
        # by save_local() below; never point index_path at untrusted files.
        return FAISS.load_local(index_path, embeddings, allow_dangerous_deserialization=True)

    csv_path = "A_Z_medicines_dataset_of_India.csv"
    if not os.path.exists(csv_path):
        raise FileNotFoundError(f"Dataset not found at: {csv_path}")

    df = pd.read_csv(csv_path)
    # Drop incomplete rows once and walk the two needed columns directly.
    # iterrows() would build a Series object for every row, which is slow
    # on a large dataset and yields the same name/composition pairs.
    complete_rows = df.dropna(subset=["name", "short_composition1"])
    documents = [
        Document(
            page_content=name,
            metadata={"short_composition1": composition},
        )
        for name, composition in zip(
            complete_rows["name"], complete_rows["short_composition1"]
        )
    ]

    vector_db = FAISS.from_documents(documents, embeddings)
    vector_db.save_local(index_path)
    return vector_db

vector_db = create_or_load_faiss_index()

# Set up QA chain
# Only the single closest dataset entry is retrieved for each query.
retriever = vector_db.as_retriever(search_kwargs={"k": 1})
prompt_template = """You are DrugScan, a medical assistant that explains drug compositions. Provide a detailed explanation of the drug based on its active ingredient and dosage, including its uses, mechanism of action, potential side effects, and any relevant precautions. Be empathetic and clear in your response.

Drug Composition: {context}
User Query: {question}
DrugScan: """
PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
# The indexed documents hold the drug name as page_content and keep the
# composition only in metadata. The "stuff" chain's default document prompt
# renders page_content alone, so without this template the "{context}" slot
# (labelled "Drug Composition" above) would contain just the drug name and
# the model would never see the composition it is asked to explain.
DOCUMENT_PROMPT = PromptTemplate(
    template="{page_content}: {short_composition1}",
    input_variables=["page_content", "short_composition1"],
)
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    chain_type_kwargs={"prompt": PROMPT, "document_prompt": DOCUMENT_PROMPT},
    return_source_documents=True,
)

# Drug names offered to the user as ready-made example queries
suggested_drugs = [
    "Azirox", "Augmentin", "Ascoril LS",
    "Allepra 120", "Amoxycillin",
]

# Function to handle drug query
def query_drug(drug_name, chat_history):
    """Answer a drug query and append the exchange to the chat history.

    Args:
        drug_name: Drug name typed or selected by the user. ``None`` and
            whitespace-only values are treated as empty input.
        chat_history: Existing list of ``[user_message, bot_message]``
            pairs, or ``None`` when there is no history yet.

    Returns:
        A new list holding the previous history followed by one additional
        pair. The list passed in is not modified. Failures of the QA chain
        are reported as a chat message rather than raised.
    """
    # Normalise both inputs first so a missing history or a missing name
    # cannot raise before a reply has been produced.
    history = list(chat_history) if chat_history else []
    query = (drug_name or "").strip()
    if not query:
        return history + [[None, "Please enter a drug name."]]

    try:
        result = qa_chain.invoke({"query": query})
        sources = result["source_documents"]
        if not sources:
            response = "Drug not found in the dataset. Please try another drug name."
        else:
            composition = sources[0].metadata["short_composition1"]
            response = f"{result['result']}\n\n**Drug Composition:** {composition}"
    except Exception as e:
        # UI boundary: turn any failure into a readable chat message instead
        # of letting the exception escape into the interface.
        error_msg = str(e)
        lowered = error_msg.lower()
        if "rate limit" in lowered or "quota" in lowered:
            response = "Error: Rate limit or quota exceeded for the Groq API. Please try again later."
        elif "connection" in lowered or "network" in lowered:
            response = "Error: Network issue while connecting to the Groq API. Please check your internet connection."
        else:
            response = f"Error: An unexpected error occurred: {error_msg}"

    return history + [[query, response]]

# Gradio Interface
def _suggestion_handler(name):
    """Return a click handler that runs ``query_drug`` for a fixed drug name."""
    def handler(chat_history):
        return query_drug(name, chat_history)
    return handler


with gr.Blocks(title="DrugScan") as demo:
    gr.Markdown("# DrugScan")
    gr.Markdown("Enter the name of a drug to learn about its active ingredients, uses, mechanism of action, side effects, and more.")

    # Display logo
    logo_url = "https://i.postimg.cc/gJ9Z0RGS/bc20af1b-8ee6-4e1c-8748-eba44e2780c1-removalai-preview.png"
    gr.Image(logo_url, width=150)

    # Chat interface
    chatbot = gr.Chatbot(label="Results")
    drug_input = gr.Textbox(placeholder="Enter a drug name (e.g., 'Azirox')", label="Drug Name")

    # Suggested drugs buttons
    gr.Markdown("### Try These Drugs")
    with gr.Row():
        for drug in suggested_drugs:
            # Event inputs must be Gradio components, so the plain-string
            # drug name is bound into the handler and only the chatbot
            # history is passed in as an input.
            gr.Button(drug).click(
                fn=_suggestion_handler(drug),
                inputs=[chatbot],
                outputs=chatbot,
            )

    # Run the query when the user submits the textbox
    drug_input.submit(
        fn=query_drug,
        inputs=[drug_input, chatbot],
        outputs=chatbot,
    )

    # Disclaimer
    gr.Markdown("### Important Disclaimer")
    gr.Markdown(
        "DrugScan provides explanations of drug compositions based on available data. It is not a substitute for professional medical advice or diagnosis. Always consult a qualified healthcare provider for personal health concerns."
    )

# Launch the app
demo.launch()