import streamlit as st
import pandas as pd
import os
import faiss
import random
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from groq import Groq
from dotenv import load_dotenv
# Load environment variables
load_dotenv()
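# Fail fast if the key is missing (an added guard, not part of the original script);
# otherwise Groq() raises a less obvious error at client construction.
if not os.getenv("GROQ_API_KEY"):
    raise RuntimeError("GROQ_API_KEY is not set; add it to .env or your environment.")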
# Setup Groq client
client = Groq(api_key=os.getenv("GROQ_API_KEY"))
MODEL_NAME = "llama3-70b-8192"
# or try "llama3-8b-8192" or "mixtral-8x7b-32768"
# Streamlit UI
st.set_page_config(page_title="RAG with Groq", layout="wide")
st.title("π RAG App using Groq API")
st.markdown("Ask enterprise, financial, and legal questions using Retrieval-Augmented Generation (RAG).")
# Load dataset from Hugging Face
@st.cache_data
def load_data():
    dataset = load_dataset("llmware/rag_instruct_benchmark_tester", split="train")
    df = pd.DataFrame(dataset)
    return df
# Build or load FAISS index
@st.cache_resource
def load_embeddings(df):
    embed_model = SentenceTransformer('all-MiniLM-L6-v2')
    context_list = df['context'].tolist()
    embeddings = embed_model.encode(context_list, show_progress_bar=True)
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index, embeddings, embed_model
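# Optional persistence (a sketch, unused by the app): the comment above load_embeddings
# says "build or load", but the function always rebuilds. FAISS can cache the index on
# disk; the file name "rag_index.faiss" here is an arbitrary assumption, not part of
# the original code.
def load_or_build_index(embeddings, path="rag_index.faiss"):
    if os.path.exists(path):
        return faiss.read_index(path)  # reuse the index saved by a previous run
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    faiss.write_index(index, path)  # cache for the next run
    return index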
# Retrieve top-k relevant context
def retrieve_context(query, embed_model, index, df, k=3):
    query_embedding = embed_model.encode([query])
    D, I = index.search(query_embedding, k)  # D: distances, I: row indices into df
    context_passages = df.iloc[I[0]]['context'].tolist()
    return context_passages
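# Alternative scoring (a sketch, unused): IndexFlatL2 ranks by Euclidean distance,
# but sentence-transformer embeddings are commonly compared by cosine similarity,
# which FAISS expresses as inner product over L2-normalized vectors. Query vectors
# would need the same faiss.normalize_L2 treatment before calling index.search.
def build_cosine_index(embeddings):
    vectors = embeddings.astype("float32")  # FAISS requires float32; astype copies
    faiss.normalize_L2(vectors)  # in-place; cosine == dot product afterwards
    index = faiss.IndexFlatIP(vectors.shape[1])
    index.add(vectors)
    return index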
# Ask the Groq LLM
def ask_groq(query, context):
prompt = f"""You are a helpful assistant. Use the context to answer the question.
Context:
{context}
Question:
{query}
Answer:"""
response = client.chat.completions.create(
messages=[{"role": "user", "content": prompt}],
model=MODEL_NAME
)
return response.choices[0].message.content
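# Variant (a sketch, unused): Groq's chat endpoint is OpenAI-compatible, so the usual
# sampling parameters apply; pinning temperature and max_tokens makes answers more
# reproducible. The values below are illustrative assumptions, not tuned.
def ask_groq_pinned(query, context, temperature=0.2, max_tokens=512):
    prompt = (
        "You are a helpful assistant. Use the context to answer the question.\n"
        f"Context:\n{context}\nQuestion:\n{query}\nAnswer:"
    )
    response = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model=MODEL_NAME,
        temperature=temperature,  # lower = less sampling randomness
        max_tokens=max_tokens,  # hard cap on answer length
    )
    return response.choices[0].message.content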
# Load everything
df = load_data()
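# Defensive schema check (an added guard, not in the original app): everything below
# assumes the benchmark dataframe has "query" and "context" columns.
missing_cols = {"query", "context"} - set(df.columns)
if missing_cols:
    st.error(f"Dataset is missing expected columns: {missing_cols}")
    st.stop()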
index, embeddings, embed_model = load_embeddings(df)
# User input
st.subheader("π Ask your question")
sample_queries = df['query'].dropna().unique().tolist()
col1, col2 = st.columns([3, 1])
with col1:
    query = st.text_input("Enter your question here:", value=st.session_state.get("query", ""))
with col2:
    if st.button("🎲 Random Sample"):
        st.session_state["query"] = random.choice(sample_queries)
        st.rerun()
# Handle query
if query:
st.markdown(f"**Your Query:** {query}")
with st.spinner("π Retrieving relevant context..."):
contexts = retrieve_context(query, embed_model, index, df)
combined_context = "\n\n".join(contexts)
with st.spinner("π€ Querying Groq LLM..."):
answer = ask_groq(query, combined_context)
st.markdown("### π‘ Answer")
st.write(answer)
st.markdown("### π Retrieved Context")
for i, ctx in enumerate(contexts, 1):
with st.expander(f"Context {i}"):
st.write(ctx)
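# To run (the filename is a placeholder):  streamlit run app.py
# Dependencies (pip): streamlit pandas datasets sentence-transformers faiss-cpu groq python-dotenv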