# Captured from a Hugging Face Space (status: Running).
import base64
import os
from pathlib import Path

import certifi
import gradio as gr
import pymongo
from dotenv import load_dotenv
from llama_index.core import VectorStoreIndex
from llama_index.core.prompts import PromptTemplate
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.groq import Groq
from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch

# NOTE(review): presumably a project-local markdown.py — the UI reads its
# `description`, `groq_api_key` and `footer` attributes, which the PyPI
# `markdown` package does not provide. This name shadows that package;
# consider renaming the local module.
import markdown as md
# Load environment variables (e.g. ATLAS_CONNECTION_STRING) from a .env file.
load_dotenv()

# --- MongoDB Config ---
# None when the variable is unset; consumed by get_vector_index_once().
ATLAS_CONNECTION_STRING = os.environ.get("ATLAS_CONNECTION_STRING")
DB_NAME, COLLECTION_NAME = "RAG", "ramayana"
# Atlas Vector Search index name, passed alongside COLLECTION_NAME to
# MongoDBAtlasVectorSearch.
VECTOR_INDEX_NAME = "ramayana_vector_index"

# --- Embedding Model ---
# Given to VectorStoreIndex.from_vector_store to embed user queries.
# NOTE(review): presumably the same model used when the collection's vectors
# were created — confirm, since a mismatch silently degrades retrieval.
_EMBEDDING_MODEL_NAME = "intfloat/multilingual-e5-base"
embed_model = HuggingFaceEmbedding(model_name=_EMBEDDING_MODEL_NAME)
# --- Prompt Template ---
# Text-QA prompt handed to the query engine: llama_index substitutes the
# retrieved chunks for {context_str} and the user's question for {query_str}.
_RAMAYANA_QA_PROMPT = """You are an expert on the Valmiki Ramayana and a guide who always inspires people with the great Itihasa like the Ramayana.
Below is text from the epic, including shlokas and their explanations:
---------------------
{context_str}
---------------------
Using only this information, answer the following query.
Query: {query_str}
Answer:
- Intro or general description to ```Query```
- Related shloka/shlokas followed by its explanation
- Overview of ```Query```
"""

ramayana_qa_template = PromptTemplate(_RAMAYANA_QA_PROMPT)
# --- Connect to MongoDB once at startup ---
def get_vector_index_once():
    """Connect to MongoDB Atlas and expose the Ramayana collection as a vector index.

    Returns:
        A ``VectorStoreIndex`` over ``MongoDBAtlasVectorSearch`` for
        ``DB_NAME``/``COLLECTION_NAME`` (Atlas index ``VECTOR_INDEX_NAME``),
        using ``embed_model`` to embed queries.

    Raises:
        RuntimeError: If ``ATLAS_CONNECTION_STRING`` is not configured.
        Exception: Whatever ``server_info()`` raises when the cluster cannot be
            reached within the 30 s timeouts (the client is closed first).
    """
    if not ATLAS_CONNECTION_STRING:
        # pymongo.MongoClient(None) would silently target localhost:27017 and
        # only fail after the 30 s server-selection timeout.
        raise RuntimeError(
            "ATLAS_CONNECTION_STRING is not set; add it to the environment or a .env file."
        )

    mongo_client = pymongo.MongoClient(
        ATLAS_CONNECTION_STRING,
        tlsCAFile=certifi.where(),  # use certifi's CA bundle for TLS verification
        tlsAllowInvalidCertificates=False,
        connectTimeoutMS=30000,
        serverSelectionTimeoutMS=30000,
    )
    try:
        # Force a round-trip so bad credentials/network surface here, at startup,
        # rather than on the first chat request.
        mongo_client.server_info()
    except Exception:
        # Release the client's resources before propagating the failure.
        mongo_client.close()
        raise
    print("✅ Connected to MongoDB Atlas.")

    vector_store = MongoDBAtlasVectorSearch(
        mongo_client,
        db_name=DB_NAME,
        collection_name=COLLECTION_NAME,
        vector_index_name=VECTOR_INDEX_NAME,
    )
    return VectorStoreIndex.from_vector_store(vector_store, embed_model=embed_model)


# Connect once at import time; every chat request reuses this index.
vector_index = get_vector_index_once()
# --- Respond Function (uses API key from state) ---
def chat_with_groq(message, history, groq_key):
    """Answer one chat message with RAG over the Ramayana vector index.

    A Groq LLM and a query engine are built on every call, since the API key
    comes from per-session ``gr.State`` (ChatInterface ``additional_inputs``).

    Args:
        message: The user's latest message; used verbatim as the query.
        history: Earlier chat turns passed by ``gr.ChatInterface``. Unused —
            each question is answered independently of the conversation.
        groq_key: The user's Groq API key; may be empty or None if no key
            was saved.

    Returns:
        The query engine's answer as a string, or a short instruction when
        the API key or the question is missing.
    """
    # Strip stray whitespace/newlines from pasted keys; otherwise the raw
    # value is sent to Groq as-is.
    groq_key = (groq_key or "").strip()
    if not groq_key:
        return 'Please enter your Groq API key and click "Start Chat" first.'
    if not (message or "").strip():
        # Avoid a paid LLM call (and a meaningless retrieval) for a blank query.
        return "Please type a question about the Ramayana."

    llm = Groq(model="llama-3.1-8b-instant", api_key=groq_key)
    query_engine = vector_index.as_query_engine(
        llm=llm,
        text_qa_template=ramayana_qa_template,
        similarity_top_k=5,  # retrieve the 5 most similar chunks as context
        verbose=True,
    )
    response = query_engine.query(message)
    return str(response)
def encode_image(image_path):
    """Read the file at *image_path* and return its bytes as base64 text.

    Args:
        image_path: Path of the file to read (opened in binary mode).

    Returns:
        The base64 encoding of the file's contents, decoded to ``str``.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    return base64.b64encode(raw_bytes).decode('utf-8')
# Encode the footer logos as base64 text; they are substituted into md.footer
# by the UI's gr.HTML footer.
# Resolve the Images folder relative to this file rather than the current
# working directory, so startup does not depend on where the app is launched.
_IMAGES_DIR = Path(__file__).resolve().parent / "Images"
github_logo_encoded = encode_image(_IMAGES_DIR / "github-logo.png")
linkedin_logo_encoded = encode_image(_IMAGES_DIR / "linkedin-logo.png")
website_logo_encoded = encode_image(_IMAGES_DIR / "ai-logo.png")
# --- Gradio UI ---
with gr.Blocks(
    theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Roboto Mono")]),
    css='footer {visibility: hidden}',  # hide Gradio's default page footer
) as demo:
    with gr.Tabs():
        with gr.TabItem("Intro"):
            gr.Markdown(md.description)

        with gr.TabItem("GPT"):
            # Key-entry area: shown until a non-empty Groq key is accepted.
            with gr.Column(visible=True) as accordion_container:
                with gr.Accordion("How to get Groq API KEY", open=False):
                    gr.Markdown(md.groq_api_key)
                groq_key_box = gr.Textbox(
                    label="Enter Groq API Key",
                    type="password",
                    placeholder="Paste your Groq API key here...",
                )
                start_btn = gr.Button("Start Chat")

            # Per-session holder for the key; forwarded to chat_with_groq as its
            # third argument through ChatInterface.additional_inputs.
            groq_state = gr.State(value="")

            # Chat container, initially hidden
            with gr.Column(visible=False) as chatbot_container:
                chatbot = gr.ChatInterface(
                    # chat_with_groq already takes (message, history, groq_key),
                    # so no wrapper lambda is needed.
                    fn=chat_with_groq,
                    additional_inputs=[groq_state],
                    chatbot=gr.Chatbot(height=500),
                    # NOTE(review): the leading characters look like a mis-encoded
                    # emoji (UTF-8 read as ISO-8859-7); restore the intended glyph.
                    title="ποΈ RamayanaGPT",
                    show_progress="full",
                    fill_height=True,
                    # description="Ask questions from the Valmiki Ramayana. Powered by RAG + MongoDB + LlamaIndex.",
                )

            def save_key_and_show_chat(key):
                """Save the entered key to session state and swap the key form for the chat.

                Raises:
                    gr.Error: If the key is empty or whitespace-only. The form
                        stays visible so the user can try again, instead of being
                        hidden with no way to re-enter a key.
                """
                key = (key or "").strip()
                if not key:
                    raise gr.Error("Please paste your Groq API key before starting the chat.")
                # Order matches outputs=[groq_state, groq_key_box, start_btn,
                # accordion_container, chatbot_container].
                return (
                    key,
                    gr.update(visible=False),
                    gr.update(visible=False),
                    gr.update(visible=False),
                    gr.update(visible=True),
                )

            start_btn.click(
                fn=save_key_and_show_chat,
                inputs=[groq_key_box],
                outputs=[groq_state, groq_key_box, start_btn, accordion_container, chatbot_container],
            )

    # Custom footer: md.footer is formatted with the three base64-encoded logos.
    gr.HTML(md.footer.format(github_logo_encoded, linkedin_logo_encoded, website_logo_encoded))

if __name__ == "__main__":
    demo.launch()