"""Streamlit Retrieval-Augmented Generation (RAG) demo backed by the Groq API.

The app loads a benchmark dataset, embeds every ``context`` passage with a
SentenceTransformer model, indexes the vectors in FAISS, and answers user
questions by sending the nearest passages plus the question to a Groq-hosted
chat model.
"""

import os
import pickle
import random

import faiss
import pandas as pd
import streamlit as st
from datasets import load_dataset
from dotenv import load_dotenv
from groq import Groq
from sentence_transformers import SentenceTransformer

# Load environment variables (e.g. GROQ_API_KEY) from a local .env file.
load_dotenv()

# Setup Groq client
client = Groq(api_key=os.getenv("GROQ_API_KEY"))

# The model can be overridden without touching the code by setting GROQ_MODEL.
# NOTE(review): confirm the default id against Groq's published model list;
# alternatives previously suggested here: "llama-3-8b-8192", "llama-3-3b-8192".
MODEL_NAME = os.getenv("GROQ_MODEL", "llama-3-70b-8192")


@st.cache_data
def load_data() -> pd.DataFrame:
    """Download the benchmark dataset's train split and return it as a DataFrame."""
    dataset = load_dataset("llmware/rag_instruct_benchmark_tester", split="train")
    return pd.DataFrame(dataset)


@st.cache_resource
def load_embeddings(df: pd.DataFrame):
    """Embed every ``context`` passage and build a flat L2 FAISS index.

    Args:
        df: DataFrame with a ``context`` column of passages.

    Returns:
        A ``(index, embeddings, embed_model)`` tuple. Row ``i`` of the index
        corresponds to positional row ``i`` of ``df``.

    Raises:
        ValueError: If ``df`` contains no passages to index.
    """
    embed_model = SentenceTransformer('all-MiniLM-L6-v2')
    # Replace nulls in place (rather than dropping them) so that FAISS row ids
    # stay aligned with the positional rows of ``df``.
    context_list = df['context'].fillna("").astype(str).tolist()
    if not context_list:
        raise ValueError("Cannot build an index: the dataset has no context passages.")

    embeddings = embed_model.encode(context_list, show_progress_bar=True)
    # FAISS operates on float32 matrices; make the dtype explicit.
    embeddings = embeddings.astype("float32", copy=False)

    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index, embeddings, embed_model


def retrieve_context(query: str, embed_model, index, df: pd.DataFrame, k: int = 3) -> list:
    """Return the ``context`` passages of the ``k`` rows nearest to ``query``.

    Args:
        query: The user's question.
        embed_model: Model exposing ``encode`` that was used to build ``index``.
        index: FAISS index whose row ids are positional rows of ``df``.
        df: DataFrame with a ``context`` column.
        k: Maximum number of passages to return.

    Returns:
        Up to ``k`` passages, nearest first; empty if the index is empty or
        ``k`` is not positive.
    """
    # Never ask for more neighbours than the index holds.
    k = min(k, index.ntotal)
    if k <= 0:
        return []

    query_embedding = embed_model.encode([query]).astype("float32", copy=False)
    _distances, neighbor_ids = index.search(query_embedding, k)

    # FAISS pads missing neighbours with -1; without this filter ``iloc[-1]``
    # would silently return the last row of ``df``.
    row_ids = [int(i) for i in neighbor_ids[0] if i >= 0]
    return df.iloc[row_ids]['context'].tolist()


def ask_groq(query: str, context: str) -> str:
    """Ask the Groq chat model to answer ``query`` using ``context``.

    Args:
        query: The user's question.
        context: Retrieved passages, already joined into one string.

    Returns:
        The content of the first completion choice.
    """
    prompt = (
        "You are a helpful assistant. Use the provided context to answer the question.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n"
        "Answer:"
    )
    response = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model=MODEL_NAME,
    )
    return response.choices[0].message.content


def _use_random_sample(candidates) -> None:
    """Button callback: put a random sample question into the text input."""
    # Callbacks run before the script reruns, so the keyed text input below
    # picks this value up; no explicit rerun call is needed.
    st.session_state["query"] = random.choice(candidates)


# Streamlit UI
st.title("📚 RAG App with Groq API")
st.markdown("Use this Retrieval-Augmented Generation app to ask enterprise, legal, and financial questions.")

df = load_data()
index, embeddings, embed_model = load_embeddings(df)

sample_queries = df['query'].dropna().unique().tolist()

# The key ties the widget to st.session_state["query"] so the button callback
# can fill it in.
query = st.text_input("Enter your question:", key="query")

st.button(
    "Use Random Sample",
    on_click=_use_random_sample,
    args=(sample_queries,),
    disabled=not sample_queries,  # random.choice would fail on an empty list
)

query = query.strip()
if query:
    st.markdown(f"**Your Query:** {query}")

    with st.spinner("Retrieving relevant context..."):
        contexts = retrieve_context(query, embed_model, index, df)
        combined_context = "\n\n".join(contexts)

    with st.spinner("Getting answer from Groq..."):
        answer = ask_groq(query, combined_context)

    st.markdown("### 💡 Answer")
    st.write(answer)

    st.markdown("### 📄 Retrieved Context")
    for i, ctx in enumerate(contexts, 1):
        st.markdown(f"**Context {i}:**")
        st.write(ctx)