Spaces:
Running
Running
import gradio as gr | |
import pymongo | |
import certifi | |
from llama_index.core import VectorStoreIndex | |
from llama_index.embeddings.huggingface import HuggingFaceEmbedding | |
from llama_index.llms.groq import Groq | |
from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch | |
from llama_index.core.prompts import PromptTemplate | |
from dotenv import load_dotenv | |
import os | |
import base64 | |
import markdown as md | |
from datetime import datetime | |
# Pull settings such as ATLAS_CONNECTION_STRING from a local .env file into os.environ.
load_dotenv()

# --- Embedding Model ---
# Used to embed user queries for vector search.
# NOTE(review): presumably the same model that embedded the stored Atlas documents — confirm.
EMBED_MODEL_NAME = "intfloat/multilingual-e5-base"
embed_model = HuggingFaceEmbedding(model_name=EMBED_MODEL_NAME)
# --- Prompt Template ---
# Text-QA prompt for the RamayanaGPT tab. It is handed to
# index.as_query_engine(text_qa_template=...), so llama_index substitutes the
# retrieved chunks for {context_str} and the user's message for {query_str}.
# The trailing bullet list asks the model for a three-part answer layout.
ramayana_qa_template = PromptTemplate(
    """You are an expert on the Valmiki Ramayana and a guide who always inspires people with the great Itihasa like the Ramayana.
Below is text from the epic, including shlokas and their explanations:
---------------------
{context_str}
---------------------
Using only this information, answer the following query.
Query: {query_str}
Answer:
- Intro or general description to ```Query```
- Related shloka/shlokas followed by its explanation
- Overview of ```Query```"""
)
# Text-QA prompt for the GitaGPT tab; same placeholder contract as the Ramayana
# prompt ({context_str} = retrieved chunks, {query_str} = user message) with a
# three-part answer layout tailored to the Bhagavad Gita.
gita_qa_template = PromptTemplate(
    """You are an expert on the Bhagavad Gita and a spiritual guide.
Below is text from the scripture, including verses and their explanations:
---------------------
{context_str}
---------------------
Using only this information, answer the following query.
Query: {query_str}
Answer:
- Intro or context about the topic
- Relevant verse(s) with explanation
- Conclusion or reflection"""
)
# --- Connect to MongoDB Atlas (called once per collection at startup) ---
def get_vector_index(db_name, collection_name, vector_index_name):
    """Connect to MongoDB Atlas and wrap one collection as a llama_index index.

    A new ``MongoClient`` is created on every call.

    Args:
        db_name: Atlas database that holds the collection.
        collection_name: Collection containing the embedded text chunks.
        vector_index_name: Name of the Atlas Vector Search index on that collection.

    Returns:
        A ``VectorStoreIndex`` backed by the collection that embeds queries
        with the module-level ``embed_model``.

    Raises:
        RuntimeError: if the ``ATLAS_CONNECTION_STRING`` environment variable
            is unset or empty.
        pymongo.errors.PyMongoError: if the cluster cannot be reached within
            the 30 s server-selection timeout.
    """
    connection_string = os.getenv("ATLAS_CONNECTION_STRING")
    if not connection_string:
        # MongoClient(None) silently falls back to localhost:27017, which turns a
        # missing secret into a confusing 30 s timeout; fail fast with a clear message.
        raise RuntimeError(
            "ATLAS_CONNECTION_STRING is not set; add it to the environment or the .env file."
        )

    mongo_client = pymongo.MongoClient(
        connection_string,
        tlsCAFile=certifi.where(),
        tlsAllowInvalidCertificates=False,
        connectTimeoutMS=30000,
        serverSelectionTimeoutMS=30000,
    )
    try:
        # Force a round-trip now so connectivity problems surface at startup
        # rather than on the first user query.
        mongo_client.server_info()
    except Exception:
        # Release the client's sockets/background monitors before propagating.
        mongo_client.close()
        raise
    # NOTE(review): "β" looks like a mis-encoded emoji (likely ✅) — verify against the original file.
    print(f"β Connected to MongoDB Atlas for collection: {collection_name}")

    vector_store = MongoDBAtlasVectorSearch(
        mongo_client,
        db_name=db_name,
        collection_name=collection_name,
        vector_index_name=vector_index_name,
    )
    return VectorStoreIndex.from_vector_store(vector_store, embed_model=embed_model)
# --- Respond Function (uses API key from state) ---
def chat_with_groq(index, template):
    """Build a ``gr.ChatInterface`` callback that answers questions from *index*.

    Args:
        index: Vector index queried for context on every message.
        template: Prompt template used as the query engine's text-QA prompt.

    Returns:
        ``fn(message, history, groq_key) -> str``. The callback validates the
        Groq key, runs a retrieval-augmented query, and returns the answer
        text (or a user-facing error message).
    """

    def fn(message, history, groq_key):
        # `history` is part of gr.ChatInterface's call signature but unused here:
        # each message is answered on its own from freshly retrieved context.
        key = (groq_key or "").strip()  # tolerate whitespace/newlines from copy-paste
        if not key.startswith("gsk_"):
            # NOTE(review): "β" looks like a mis-encoded emoji (likely ❌) — verify against the original file.
            return "β Invalid Groq API Key. Please enter a valid key."
        try:
            # The key arrives with every call, so the LLM and engine are built per message.
            llm = Groq(model="llama-3.1-8b-instant", api_key=key)
            query_engine = index.as_query_engine(
                llm=llm,
                text_qa_template=template,
                similarity_top_k=5,
                verbose=True,
            )
            response = query_engine.query(message)
        except Exception as exc:
            # UI boundary: log the failure (auth, rate limit, retrieval...) server-side
            # and reply in the chat instead of surfacing a raw exception.
            print(f"\n{datetime.now()}:: {message} --> ERROR: {exc!r}\n")
            return (
                "Sorry, something went wrong while answering your question. "
                "Please check your Groq API key or try again later."
            )
        answer = str(response)
        print(f"\n{datetime.now()}:: {message} --> {answer}\n")
        return answer

    return fn
# Build both retrieval indices once at import time; the chat tabs below reuse them.
ramayana_index = get_vector_index(
    db_name="RAG",
    collection_name="ramayana",
    vector_index_name="ramayana_vector_index",
)
gita_index = get_vector_index(
    db_name="RAG",
    collection_name="bhagavad_gita",
    vector_index_name="gita_vector_index",
)
# Encode logos
def encode_image(image_path):
    """Return the raw bytes of *image_path* as a base64-encoded text string.

    The result is interpolated into the footer HTML, presumably as an inline
    ``data:`` image source.
    """
    with open(image_path, mode="rb") as handle:
        payload = handle.read()
    encoded = base64.b64encode(payload)
    return encoded.decode("utf-8")
# Base64 versions of the footer logos, in the order md.footer expects them.
github_logo_encoded, linkedin_logo_encoded, website_logo_encoded = (
    encode_image(logo_path)
    for logo_path in (
        "Images/github-logo.png",
        "Images/linkedin-logo.png",
        "Images/ai-logo.png",
    )
)
# --- Gradio UI ---
with gr.Blocks(
    theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Roboto Mono")]),
    css='footer {visibility: hidden}',
) as demo:
    with gr.Tabs():
        with gr.TabItem("Intro"):
            gr.Markdown(md.description)

        def create_tab(tab_title, chatbot_title, vector_index, template, intro):
            """Add one chat tab: a Groq-key form that is swapped for a RAG chat once a key is accepted.

            Args:
                tab_title: Label of the tab.
                chatbot_title: Title shown above the chat window.
                vector_index: Index the chat callback retrieves context from.
                template: Prompt template used by the chat callback.
                intro: Markdown shown in the "Overview & Summary" accordion.
            """
            with gr.TabItem(tab_title):
                # Key-entry view: visible until a key is accepted.
                with gr.Column(visible=True) as accordion_container:
                    with gr.Accordion("How to get Groq API KEY", open=False):
                        gr.Markdown(md.groq_api_key)
                groq_key_box = gr.Textbox(
                    label="Enter Groq API Key",
                    type="password",
                    placeholder="Paste your Groq API key here...",
                )
                start_btn = gr.Button("Start Chat")
                # Session-scoped holder for the accepted key; passed to the chat callback.
                groq_state = gr.State(value="")

                # Chat view: hidden until a key is accepted.
                with gr.Column(visible=False) as chatbot_container:
                    with gr.Accordion("Overview & Summary", open=False):
                        gr.Markdown(intro)
                    gr.ChatInterface(
                        fn=chat_with_groq(vector_index, template),
                        additional_inputs=[groq_state],
                        chatbot=gr.Chatbot(height=500),
                        title=chatbot_title,
                        show_progress="full",
                        fill_height=True,
                    )

                def save_key_and_show_chat(key):
                    """Validate the pasted key; on success store it and swap the form for the chat.

                    Returns updates for, in order: groq_state, groq_key_box,
                    start_btn, accordion_container, chatbot_container.
                    """
                    key = (key or "").strip()  # pasted keys often carry stray whitespace/newlines
                    accepted = key.startswith("gsk_")
                    if not accepted:
                        # Without this the button silently does nothing on a bad key.
                        gr.Warning("Invalid Groq API Key. Please enter a key that starts with 'gsk_'.")
                    return (
                        key if accepted else "",
                        gr.update(visible=not accepted),
                        gr.update(visible=not accepted),
                        gr.update(visible=not accepted),
                        gr.update(visible=accepted),
                    )

                start_btn.click(
                    fn=save_key_and_show_chat,
                    inputs=[groq_key_box],
                    outputs=[groq_state, groq_key_box, start_btn, accordion_container, chatbot_container],
                )

        # NOTE(review): "ποΈ" looks like a mis-encoded emoji — verify against the original file.
        create_tab("RamayanaGPT", "ποΈ RamayanaGPT", ramayana_index, ramayana_qa_template, md.RamayanaGPT)
        create_tab("GitaGPT", "ποΈ GitaGPT", gita_index, gita_qa_template, md.GitaGPT)

    gr.HTML(md.footer.format(github_logo_encoded, linkedin_logo_encoded, website_logo_encoded))
def main():
    """Start the Gradio web server for the app defined above."""
    demo.launch()


# Launch only when executed as a script, not when imported.
if __name__ == "__main__":
    main()