# Epic-Minds / app.py
# Sarath0x8f's picture
# Upload 2 files
# 32ca645 verified
import gradio as gr
import pymongo
import certifi
from llama_index.core import VectorStoreIndex
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.gemini import Gemini
from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch
from llama_index.core.prompts import PromptTemplate
from dotenv import load_dotenv
import os
import base64
import markdown as md
from datetime import datetime
# Load environment variables
# Must run before any os.getenv() lookup below (ATLAS_CONNECTION_STRING,
# GOOGLE_API_KEY) so values supplied through a .env file are visible.
load_dotenv()
# --- Embedding Model ---
# Single shared embedding model; get_vector_index passes it to every
# VectorStoreIndex it builds.
# NOTE(review): presumably this must be the same model that produced the
# vectors stored in Atlas — confirm against the ingestion script.
embed_model = HuggingFaceEmbedding(model_name="intfloat/multilingual-e5-base")
# --- Prompt Templates ---
# Question-answering prompt for the Ramayana tab. The text contains two
# placeholders, {context_str} and {query_str}; the template object is handed
# to the query engine as `text_qa_template` in chat(). The answer is requested
# in three parts: an introduction, related shloka(s) with explanation, and an
# overview.
ramayana_qa_template = PromptTemplate(
"""You are an expert on the Valmiki Ramayana and a guide who always inspires people with the great Itihasa like the Ramayana.
Below is text from the epic, including shlokas and their explanations:
---------------------
{context_str}
---------------------
Using only this information, answer the following query.
Query: {query_str}
Answer:
- Intro or general description to ```Query```
- Related sanskrit shloka(s) followed by its explanation
- Overview of ```Query```"""
)
# Question-answering prompt for the Bhagavad Gita tab. Same shape as the
# Ramayana prompt: {context_str} and {query_str} are the two placeholders, and
# the template object is handed to the query engine as `text_qa_template` in
# chat(). The answer is requested as context, relevant verse(s) with
# explanation, and a closing reflection.
gita_qa_template = PromptTemplate(
"""You are an expert on the Bhagavad Gita and a spiritual guide.
Below is text from the scripture, including verses and their explanations:
---------------------
{context_str}
---------------------
Using only this information, answer the following query.
Query: {query_str}
Answer:
- Intro or context about the topic
- Relevant sanskrit verse(s) with explanation
- Conclusion or reflection"""
)
# --- MongoDB Vector Index Loader ---
def get_vector_index(db_name: str, collection_name: str, vector_index_name: str) -> "VectorStoreIndex":
    """Build a vector index backed by a MongoDB Atlas collection.

    Args:
        db_name: Database name passed to the Atlas vector store.
        collection_name: Collection name passed to the Atlas vector store.
        vector_index_name: Name of the vector search index to query.

    Returns:
        A ``VectorStoreIndex`` created from the Atlas vector store, using the
        module-level ``embed_model`` for embeddings.

    Raises:
        ValueError: If the ``ATLAS_CONNECTION_STRING`` environment variable is
            unset or empty.

    Any error raised by pymongo while contacting the server propagates
    unchanged from the ``server_info()`` call.
    """
    connection_string = os.getenv("ATLAS_CONNECTION_STRING")
    if not connection_string:
        # MongoClient(None) would quietly target localhost and only fail later
        # with a server-selection timeout that hides the real cause.
        raise ValueError(
            "ATLAS_CONNECTION_STRING is not set; cannot connect to MongoDB Atlas."
        )

    mongo_client = pymongo.MongoClient(
        connection_string,
        tlsCAFile=certifi.where(),
    )
    # Force a round trip now so connection problems surface at startup rather
    # than on the first user query.
    mongo_client.server_info()
    print(f"✅ Connected to MongoDB Atlas for collection: {collection_name}")

    vector_store = MongoDBAtlasVectorSearch(
        mongo_client,
        db_name=db_name,
        collection_name=collection_name,
        vector_index_name=vector_index_name,
    )
    return VectorStoreIndex.from_vector_store(vector_store, embed_model=embed_model)
# --- Load Indices Once ---
# Built once at import time and reused by the chat tabs. Each call opens its
# own MongoClient and blocks on a server_info() round trip, so importing this
# module requires a reachable Atlas cluster.
ramayana_index = get_vector_index("RAG", "ramayana", "ramayana_vector_index")
gita_index = get_vector_index("RAG", "bhagavad_gita", "gita_vector_index")
# --- Gradio Chat Wrapper with Streaming ---
def chat(index, template):
    """Create a streaming chat handler bound to one index and one prompt.

    Args:
        index: Object exposing ``as_query_engine`` (in this file, an index
            returned by ``get_vector_index``).
        template: Prompt template passed to the query engine as
            ``text_qa_template``.

    Returns:
        A generator function ``fn(message, history)`` that yields the answer
        text accumulated so far each time a new chunk arrives.
    """
    llm = Gemini(
        model="models/gemini-1.5-flash",
        api_key=os.getenv("GOOGLE_API_KEY"),
        streaming=True,
    )
    query_engine = index.as_query_engine(
        llm=llm,
        text_qa_template=template,
        similarity_top_k=5,
        streaming=True,
        verbose=True,
    )

    def fn(message, history):
        """Stream the answer to ``message``; ``history`` is accepted but unused."""
        streaming_response = query_engine.query(message)
        full_response = ""
        for text in streaming_response.response_gen:
            full_response += text
            # Yield the running text so the UI shows the answer growing.
            yield full_response
        # The query is issued exactly once: re-running it after the stream
        # finished doubled retrieval/LLM cost and could replace the streamed
        # answer with a different one than the text logged here.
        print(f"\n{datetime.now()}:: {message} --> {str(full_response)}\n")

    return fn
# --- Encode Logos ---
def encode_image(image_path):
    """Return the bytes of the file at ``image_path`` as a base64 text string."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
# Logos are read once at import time and kept as base64 text; the three values
# are substituted, in this order, into md.footer when the UI is built.
# The paths are relative, so they resolve against the current working directory.
github_logo_encoded = encode_image("Images/github-logo.png")
linkedin_logo_encoded = encode_image("Images/linkedin-logo.png")
website_logo_encoded = encode_image("Images/ai-logo.png")
# --- Gradio UI ---
# Assemble the app: an intro tab, one chat tab per scripture, then a footer.
with gr.Blocks(
    theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Roboto Mono")]),
    css="footer {visibility: hidden}",
) as demo:
    with gr.Tabs():
        with gr.TabItem("Intro"):
            gr.Markdown(md.description)

        def create_tab(tab_title, vector_index, template, intro_md):
            """Add one tab holding a collapsed overview and a chat interface."""
            with gr.TabItem(tab_title):
                with gr.Accordion("==========> Overview & Summary <==========", open=False):
                    gr.Markdown(intro_md)
                # NOTE(review): the chat interface is assumed to sit beside the
                # accordion, not inside it — confirm against the deployed layout.
                chat_handler = chat(vector_index, template)
                gr.ChatInterface(
                    fn=chat_handler,
                    chatbot=gr.Chatbot(height=500),
                    show_progress="full",
                    fill_height=True,
                )

        create_tab("RamayanaGPT🏹", ramayana_index, ramayana_qa_template, md.RamayanaGPT)
        create_tab("GitaGPT🛞", gita_index, gita_qa_template, md.GitaGPT)

    # The footer template receives the three base64-encoded logos positionally.
    gr.HTML(md.footer.format(github_logo_encoded, linkedin_logo_encoded, website_logo_encoded))
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()