"""Streamlit RAG demo: retrieve benchmark contexts with FAISS and answer via the Groq API."""

import os
import pickle
import random

import faiss
import pandas as pd
import streamlit as st
from datasets import load_dataset
from dotenv import load_dotenv
from groq import Groq
from sentence_transformers import SentenceTransformer

# Load environment variables (expects GROQ_API_KEY in .env or the environment).
load_dotenv()

MODEL_NAME = "llama3-70b-8192"  # or try "llama-3-8b-8192" or "llama-3-3b-8192"

# Streamlit UI — set_page_config must be the first Streamlit command in the script.
st.set_page_config(page_title="RAG with Groq", layout="wide")
st.title("📚 RAG App using Groq API")
st.markdown("Ask enterprise, financial, and legal questions using Retrieval-Augmented Generation (RAG).")

# Fail fast with a visible message instead of an opaque Groq auth error at query time.
_api_key = os.getenv("GROQ_API_KEY")
if not _api_key:
    st.error("GROQ_API_KEY is not set. Add it to your environment or a .env file.")
    st.stop()

# Setup Groq client
client = Groq(api_key=_api_key)


# Load dataset from Hugging Face
@st.cache_data
def load_data():
    """Download the llmware RAG benchmark split and return it as a DataFrame.

    Cached by Streamlit so the dataset is fetched only once per session/args.
    """
    dataset = load_dataset("llmware/rag_instruct_benchmark_tester", split="train")
    df = pd.DataFrame(dataset)
    return df


# Build or load FAISS index
@st.cache_resource
def load_embeddings(df):
    """Embed every context passage and build an exact L2 FAISS index over them.

    Returns:
        (index, embeddings, embed_model): the FAISS index, the raw embedding
        matrix, and the SentenceTransformer used to embed queries later.
    """
    embed_model = SentenceTransformer('all-MiniLM-L6-v2')
    context_list = df['context'].tolist()
    # encode() returns a float32 numpy matrix by default, which is what FAISS expects.
    embeddings = embed_model.encode(context_list, show_progress_bar=True)
    index = faiss.IndexFlatL2(embeddings[0].shape[0])
    index.add(embeddings)
    return index, embeddings, embed_model


# Retrieve top-k relevant context
def retrieve_context(query, embed_model, index, df, k=3):
    """Return up to *k* context passages most similar to *query* (L2 distance).

    FAISS pads the result with -1 when fewer than *k* vectors exist; those
    sentinel indices are dropped so df.iloc never selects a wrong row.
    """
    query_embedding = embed_model.encode([query])
    D, I = index.search(query_embedding, k)
    valid_ids = [i for i in I[0] if i >= 0]
    context_passages = df.iloc[valid_ids]['context'].tolist()
    return context_passages


# Ask the Groq LLM
def ask_groq(query, context):
    """Send a single-turn, context-grounded prompt to the Groq chat API and return the answer text."""
    prompt = f"""You are a helpful assistant. Use the context to answer the question. 
Context: {context} Question: {query} Answer:"""
    response = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model=MODEL_NAME
    )
    return response.choices[0].message.content


# Load everything
df = load_data()
index, embeddings, embed_model = load_embeddings(df)

# User input
st.subheader("🔍 Ask your question")
sample_queries = df['query'].dropna().unique().tolist()

col1, col2 = st.columns([3, 1])
with col1:
    # Pre-fill from session state so the random-sample button can inject a query.
    query = st.text_input("Enter your question here:", value=st.session_state.get("query", ""))
with col2:
    if st.button("🎲 Random Sample"):
        st.session_state["query"] = random.choice(sample_queries)
        st.rerun()

# Handle query
if query:
    st.markdown(f"**Your Query:** {query}")

    with st.spinner("🔎 Retrieving relevant context..."):
        contexts = retrieve_context(query, embed_model, index, df)
        combined_context = "\n\n".join(contexts)

    with st.spinner("🤖 Querying Groq LLM..."):
        answer = ask_groq(query, combined_context)

    st.markdown("### 💡 Answer")
    st.write(answer)

    st.markdown("### 📄 Retrieved Context")
    for i, ctx in enumerate(contexts, 1):
        with st.expander(f"Context {i}"):
            st.write(ctx)