Spaces:
Sleeping
Sleeping
import gradio as gr | |
import pymongo | |
import certifi | |
from llama_index.core import VectorStoreIndex | |
from llama_index.embeddings.huggingface import HuggingFaceEmbedding | |
from llama_index.llms.gemini import Gemini | |
from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch | |
from llama_index.core.prompts import PromptTemplate | |
from dotenv import load_dotenv | |
import os | |
import base64 | |
import markdown as md | |
from datetime import datetime | |
# Load environment variables from a .env file, if one is found.
# ATLAS_CONNECTION_STRING and GOOGLE_API_KEY are read from the environment
# further down.
load_dotenv()

# --- Embedding Model ---
# Embeds user queries for vector search (handed to every VectorStoreIndex
# below). NOTE(review): presumably this must be the same model that produced
# the vectors stored in MongoDB — confirm against the ingestion script.
embed_model = HuggingFaceEmbedding(
    model_name="intfloat/multilingual-e5-base",
)
# --- Prompt Templates ---
# Both QA prompts share one scaffold:
#   - a persona line
#   - a line describing the retrieved text
#   - the fenced context
#   - the user's query
#   - an outline of the expected answer
# `{context_str}` and `{query_str}` are left as literal placeholders for the
# query engine to fill in; chat() passes these templates as `text_qa_template`.
_CONTEXT_FENCE = "---------------------"


def _qa_prompt_text(persona, source_line, answer_outline):
    """Join the variable parts of a QA prompt with the shared scaffold.

    Args:
        persona: Opening line telling the LLM who it is.
        source_line: Line introducing the retrieved context.
        answer_outline: Bullet lines describing the expected answer layout.

    Returns:
        The prompt text, one line per part, separated by single newlines and
        with no trailing newline.
    """
    return "\n".join(
        [
            persona,
            source_line,
            _CONTEXT_FENCE,
            "{context_str}",
            _CONTEXT_FENCE,
            "Using only this information, answer the following query.",
            "Query: {query_str}",
            "Answer:",
            *answer_outline,
        ]
    )


ramayana_qa_template = PromptTemplate(
    _qa_prompt_text(
        "You are an expert on the Valmiki Ramayana and a guide who always inspires people with the great Itihasa like the Ramayana.",
        "Below is text from the epic, including shlokas and their explanations:",
        [
            "- Intro or general description to ```Query```",
            "- Related sanskrit shloka(s) followed by its explanation",
            "- Overview of ```Query```",
        ],
    )
)

gita_qa_template = PromptTemplate(
    _qa_prompt_text(
        "You are an expert on the Bhagavad Gita and a spiritual guide.",
        "Below is text from the scripture, including verses and their explanations:",
        [
            "- Intro or context about the topic",
            "- Relevant sanskrit verse(s) with explanation",
            "- Conclusion or reflection",
        ],
    )
)
# --- MongoDB Vector Index Loader ---
# One MongoClient per connection string, shared by every index loaded from it.
# Each pymongo client owns its own connection pool and monitoring threads, so
# opening a fresh client per collection would duplicate them needlessly.
_mongo_clients = {}


def _get_mongo_client(connection_string):
    """Return a verified MongoClient for *connection_string*, creating it once."""
    client = _mongo_clients.get(connection_string)
    if client is None:
        client = pymongo.MongoClient(connection_string, tlsCAFile=certifi.where())
        # server_info() forces a round trip to the server, so a bad URI or
        # network problem fails here at startup rather than on the first query.
        client.server_info()
        _mongo_clients[connection_string] = client
    return client


def get_vector_index(db_name, collection_name, vector_index_name):
    """Load a VectorStoreIndex backed by a MongoDB Atlas Vector Search collection.

    Args:
        db_name: MongoDB database holding the collection.
        collection_name: Collection that stores the embedded text chunks.
        vector_index_name: Name of the Atlas Vector Search index on that
            collection.

    Returns:
        A VectorStoreIndex over the existing vectors in the collection, using
        the module-level ``embed_model`` to embed queries.

    Raises:
        RuntimeError: If ATLAS_CONNECTION_STRING is unset or empty.
    """
    connection_string = os.getenv("ATLAS_CONNECTION_STRING")
    if not connection_string:
        # pymongo.MongoClient(None) silently falls back to localhost:27017,
        # which would surface only as a server-selection timeout later on.
        raise RuntimeError(
            "ATLAS_CONNECTION_STRING is not set; add it to the environment or .env file."
        )
    mongo_client = _get_mongo_client(connection_string)
    print(f"✅ Connected to MongoDB Atlas for collection: {collection_name}")
    vector_store = MongoDBAtlasVectorSearch(
        mongo_client,
        db_name=db_name,
        collection_name=collection_name,
        vector_index_name=vector_index_name,
    )
    return VectorStoreIndex.from_vector_store(vector_store, embed_model=embed_model)
# --- Load Indices Once ---
# Built at import time; create_tab() later wraps each one in a single query
# engine, so all chat sessions in a tab reuse the same index.
ramayana_index, gita_index = (
    get_vector_index("RAG", "ramayana", "ramayana_vector_index"),
    get_vector_index("RAG", "bhagavad_gita", "gita_vector_index"),
)
# --- Gradio Chat Wrapper with Streaming ---
def _stream_query(query_engine, message):
    """Run one streaming query and yield the answer accumulated so far.

    Each yielded value is the full text received up to that point, not just
    the newest chunk, which is what a streaming gr.ChatInterface callback is
    expected to yield.

    Args:
        query_engine: Query engine built with ``streaming=True``; the object
            returned by its ``query`` must expose a ``response_gen`` iterator
            of text chunks.
        message: The user's question.

    Yields:
        The cumulative response text after each streamed chunk.
    """
    # The query runs exactly once. Re-querying afterwards to obtain a final
    # string would repeat retrieval and the LLM call, and could replace the
    # streamed answer with a different generation.
    streaming_response = query_engine.query(message)
    full_response = ""
    for chunk in streaming_response.response_gen:
        full_response += chunk
        yield full_response
    print(f"\n{datetime.now()}:: {message} --> {full_response}\n")


def chat(index, template):
    """Build a Gradio chat callback that answers questions from *index*.

    Args:
        index: Vector index to retrieve context from (see get_vector_index).
        template: PromptTemplate used as the text-QA prompt.

    Returns:
        A generator function ``fn(message, history)`` for gr.ChatInterface
        that streams the Gemini answer as it is generated.
    """
    llm = Gemini(
        model="models/gemini-1.5-flash",
        api_key=os.getenv("GOOGLE_API_KEY"),
        streaming=True,
    )
    query_engine = index.as_query_engine(
        llm=llm,
        text_qa_template=template,
        similarity_top_k=5,  # number of retrieved chunks passed as context
        streaming=True,
        verbose=True,
    )

    def fn(message, history):
        # `history` is required by the ChatInterface callback signature but is
        # unused: every question is answered independently.
        yield from _stream_query(query_engine, message)

    return fn
# --- Encode Logos ---
def encode_image(image_path):
    """Read the file at *image_path* and return its contents as base64 text.

    Args:
        image_path: Path of the file to encode (the callers pass PNG logos).

    Returns:
        The base64 encoding of the file's bytes, decoded to a str so it can be
        formatted into the footer HTML.
    """
    with open(image_path, "rb") as handle:
        raw = handle.read()
    encoded = base64.b64encode(raw)
    return encoded.decode("utf-8")
# Footer icons, read once at import time and kept as base64 text for the
# footer HTML template (presumably embedded there as data URIs — the template
# lives in `md` and is not visible here).
github_logo_encoded, linkedin_logo_encoded, website_logo_encoded = map(
    encode_image,
    (
        "Images/github-logo.png",
        "Images/linkedin-logo.png",
        "Images/ai-logo.png",
    ),
)
# --- Gradio UI ---
def create_tab(tab_title, vector_index, template, intro_md):
    """Add one chat tab to the Gradio layout currently being built.

    Intended to be called inside the ``gr.Tabs()`` context below. The tab holds
    a collapsed overview accordion followed by a streaming chat interface.

    Args:
        tab_title: Label shown on the tab.
        vector_index: Index the chat retrieves context from.
        template: Prompt template used to answer questions.
        intro_md: Markdown shown inside the overview accordion.
    """
    with gr.TabItem(tab_title):
        with gr.Accordion("==========> Overview & Summary <==========", open=False):
            gr.Markdown(intro_md)
        gr.ChatInterface(
            fn=chat(vector_index, template),
            chatbot=gr.Chatbot(height=500),
            show_progress="full",
            fill_height=True,
        )


# (tab title, vector index, QA prompt, overview markdown) for each chat tab.
# NOTE(review): `md` is imported as `markdown` but exposes `description`,
# `RamayanaGPT`, `GitaGPT` and `footer`, so it is presumably a local module —
# confirm.
_CHAT_TABS = (
    ("RamayanaGPT🏹", ramayana_index, ramayana_qa_template, md.RamayanaGPT),
    ("GitaGPT🛞", gita_index, gita_qa_template, md.GitaGPT),
)

# The CSS hides Gradio's built-in page footer; a custom footer is added below.
with gr.Blocks(
    theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Roboto Mono")]),
    css="footer {visibility: hidden}",
) as demo:
    with gr.Tabs():
        with gr.TabItem("Intro"):
            gr.Markdown(md.description)
        for tab_spec in _CHAT_TABS:
            create_tab(*tab_spec)
    gr.HTML(md.footer.format(github_logo_encoded, linkedin_logo_encoded, website_logo_encoded))

if __name__ == "__main__":
    demo.launch()