Epic-Minds / app.py
Sarath0x8f's picture
Upload 2 files
f32ea1e verified
raw
history blame
6.17 kB
import gradio as gr
import pymongo
import certifi
from llama_index.core import VectorStoreIndex
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.groq import Groq
from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch
from llama_index.core.prompts import PromptTemplate
from dotenv import load_dotenv
import os
import base64
import markdown as md
from datetime import datetime
# Load environment variables
# Must run before anything reads os.getenv(); ATLAS_CONNECTION_STRING is
# looked up later in get_vector_index().
load_dotenv()
# --- Embedding Model ---
# Single embedding model instance shared by every VectorStoreIndex created in
# get_vector_index().
# NOTE(review): E5-family models are documented to expect "query: " /
# "passage: " input prefixes. No prefix is configured here — confirm this
# matches how the stored collection vectors were embedded.
embed_model = HuggingFaceEmbedding(model_name="intfloat/multilingual-e5-base")
# --- Prompt Template ---
# QA prompt used for the Ramayana chat tab (passed to create_tab and from
# there to the query engine as text_qa_template).
# Placeholders: {context_str} receives the retrieved text, {query_str} the
# user's question. The prompt restricts the model to the supplied context and
# asks for a three-part answer: intro, related shloka(s) with explanation,
# and an overview.
ramayana_qa_template = PromptTemplate(
"""You are an expert on the Valmiki Ramayana and a guide who always inspires people with the great Itihasa like the Ramayana.
Below is text from the epic, including shlokas and their explanations:
---------------------
{context_str}
---------------------
Using only this information, answer the following query.
Query: {query_str}
Answer:
- Intro or general description to ```Query```
- Related shloka/shlokas followed by its explanation
- Overview of ```Query```"""
)
# QA prompt used for the Bhagavad Gita chat tab (passed to create_tab and
# from there to the query engine as text_qa_template).
# Placeholders: {context_str} receives the retrieved text, {query_str} the
# user's question. The prompt restricts the model to the supplied context and
# asks for a three-part answer: intro/context, relevant verse(s) with
# explanation, and a conclusion or reflection.
gita_qa_template = PromptTemplate(
"""You are an expert on the Bhagavad Gita and a spiritual guide.
Below is text from the scripture, including verses and their explanations:
---------------------
{context_str}
---------------------
Using only this information, answer the following query.
Query: {query_str}
Answer:
- Intro or context about the topic
- Relevant verse(s) with explanation
- Conclusion or reflection"""
)
# --- Connect to MongoDB once at startup ---
def get_vector_index(db_name, collection_name, vector_index_name):
    """Build a vector index backed by a MongoDB Atlas collection.

    Opens a MongoDB client using the ``ATLAS_CONNECTION_STRING`` environment
    variable, verifies the connection with ``server_info()`` so connectivity
    problems surface here rather than on the first query, and wraps the
    collection in a ``MongoDBAtlasVectorSearch`` store.

    Args:
        db_name: Name of the MongoDB database.
        collection_name: Name of the collection holding the embedded documents.
        vector_index_name: Name of the Atlas vector search index to query.

    Returns:
        A ``VectorStoreIndex`` built from the vector store, using the
        module-level ``embed_model``.

    Raises:
        RuntimeError: If ``ATLAS_CONNECTION_STRING`` is unset or empty.
        Exception: Whatever ``server_info()`` raises when the server cannot
            be reached; the client is closed before the error is re-raised.
    """
    connection_string = os.getenv("ATLAS_CONNECTION_STRING")
    if not connection_string:
        # Passing None to MongoClient makes it target the default local host,
        # which then fails only after the 30 s selection timeout with a
        # misleading error. Fail fast with an actionable message instead.
        raise RuntimeError(
            "ATLAS_CONNECTION_STRING is not set; cannot connect to MongoDB Atlas."
        )

    mongo_client = pymongo.MongoClient(
        connection_string,
        tlsCAFile=certifi.where(),
        tlsAllowInvalidCertificates=False,
        connectTimeoutMS=30000,
        serverSelectionTimeoutMS=30000,
    )
    try:
        # Round-trip to the server so a bad URI or network problem fails now.
        mongo_client.server_info()
    except Exception:
        # Do not leave the client's background resources open on failure.
        mongo_client.close()
        raise
    print(f"✅ Connected to MongoDB Atlas for collection: {collection_name}")

    vector_store = MongoDBAtlasVectorSearch(
        mongo_client,
        db_name=db_name,
        collection_name=collection_name,
        vector_index_name=vector_index_name,
    )
    return VectorStoreIndex.from_vector_store(vector_store, embed_model=embed_model)
# --- Respond Function (uses API key from state) ---
def chat_with_groq(index, template):
    """Create the chat callback for one scripture index.

    Args:
        index: Object exposing ``as_query_engine``; queried for every message.
        template: Prompt passed to the query engine as ``text_qa_template``.

    Returns:
        A function ``fn(message, history, groq_key)`` that returns the answer
        text, or an error string when the key is missing or does not start
        with ``"gsk_"``. ``history`` is accepted but not used.
    """
    invalid_key_reply = "❌ Invalid Groq API Key. Please enter a valid key."

    def fn(message, history, groq_key):
        key_looks_valid = bool(groq_key) and groq_key.startswith("gsk_")
        if not key_looks_valid:
            return invalid_key_reply

        # A fresh LLM client and query engine are built per message because
        # the API key is supplied by the caller on every invocation.
        engine = index.as_query_engine(
            llm=Groq(model="llama-3.1-8b-instant", api_key=groq_key),
            text_qa_template=template,
            similarity_top_k=5,
            verbose=True,
        )
        answer = str(engine.query(message))
        print(f"\n{datetime.now()}:: {message} --> {answer}\n")
        return answer

    return fn
# Load vector indices once
# Both calls run at import time. Each call opens its own MongoClient and
# probes the server (server_info() inside get_vector_index), so an
# unreachable cluster makes the module fail while loading.
# Arguments are (database name, collection name, vector search index name).
ramayana_index = get_vector_index("RAG", "ramayana", "ramayana_vector_index")
gita_index = get_vector_index("RAG", "bhagavad_gita", "gita_vector_index")
# Encode logos
def encode_image(image_path):
    """Return the file at ``image_path`` as a base64-encoded text string.

    Args:
        image_path: Path of the file to read in binary mode.

    Returns:
        The base64 encoding of the file's bytes, decoded to ``str``.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
# Base64 text of the three logo images, read at import time. The paths are
# relative, so they resolve against the process's current working directory.
# The three values are later substituted, in this order, into md.footer.
github_logo_encoded = encode_image("Images/github-logo.png")
linkedin_logo_encoded = encode_image("Images/linkedin-logo.png")
website_logo_encoded = encode_image("Images/ai-logo.png")
# --- Gradio UI ---
# Assemble the app: an intro tab plus one chat tab per scripture, then a footer.
# NOTE(review): the nesting below was reconstructed from a source whose
# indentation had been flattened — confirm the placement of the key textbox,
# the start button and the footer against the intended layout.
# `md` is the module imported as `markdown`; it is expected to provide the
# description, groq_api_key, RamayanaGPT, GitaGPT and footer strings
# (presumably a project-local module — verify).
with gr.Blocks(
    theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Roboto Mono")]),
    css='footer {visibility: hidden}',
) as demo:
    with gr.Tabs():
        with gr.TabItem("Intro"):
            gr.Markdown(md.description)

        def create_tab(tab_title, chatbot_title, vector_index, template, intro):
            """Add one chat tab that is unlocked by entering a Groq API key.

            The tab first shows a key-entry form. Once a key starting with
            "gsk_" is submitted, the form is hidden, the key is stored in
            per-session state, and the chat interface is revealed.

            Args:
                tab_title: Label of the tab.
                chatbot_title: Title shown on the chat interface.
                vector_index: Index handed to chat_with_groq for retrieval.
                template: Prompt template handed to chat_with_groq.
                intro: Markdown shown in the "Overview & Summary" accordion.
            """
            with gr.TabItem(tab_title):
                with gr.Column(visible=True) as accordion_container:
                    with gr.Accordion("How to get Groq API KEY", open=False):
                        gr.Markdown(md.groq_api_key)
                    groq_key_box = gr.Textbox(
                        label="Enter Groq API Key",
                        type="password",
                        placeholder="Paste your Groq API key here...",
                    )
                    start_btn = gr.Button("Start Chat")

                # Holds the accepted key; fed to the chat callback as an
                # additional input on every message.
                groq_state = gr.State(value="")

                with gr.Column(visible=False) as chatbot_container:
                    with gr.Accordion("Overview & Summary", open=False):
                        gr.Markdown(intro)
                    gr.ChatInterface(
                        fn=chat_with_groq(vector_index, template),
                        additional_inputs=[groq_state],
                        chatbot=gr.Chatbot(height=500),
                        title=chatbot_title,
                        show_progress="full",
                        fill_height=True,
                    )

                def save_key_and_show_chat(key):
                    """Validate the key and switch from the form to the chat.

                    Returns five values matching the ``outputs`` list of the
                    click handler: the key to store, visibility updates for
                    the textbox, button and form column, and a visibility
                    update for the chat column.
                    """
                    # Pasted keys often carry stray whitespace or a newline;
                    # strip it so a valid key is not rejected.
                    cleaned = (key or "").strip()
                    accepted = cleaned.startswith("gsk_")
                    return (
                        cleaned if accepted else "",
                        gr.update(visible=not accepted),
                        gr.update(visible=not accepted),
                        gr.update(visible=not accepted),
                        gr.update(visible=accepted),
                    )

                start_btn.click(
                    fn=save_key_and_show_chat,
                    inputs=[groq_key_box],
                    outputs=[
                        groq_state,
                        groq_key_box,
                        start_btn,
                        accordion_container,
                        chatbot_container,
                    ],
                )

        create_tab("RamayanaGPT", "🕉️ RamayanaGPT", ramayana_index, ramayana_qa_template, md.RamayanaGPT)
        create_tab("GitaGPT", "🕉️ GitaGPT", gita_index, gita_qa_template, md.GitaGPT)

    gr.HTML(md.footer.format(github_logo_encoded, linkedin_logo_encoded, website_logo_encoded))
# Start the Gradio server only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == "__main__":
    demo.launch()