import os
import random

import faiss
import pandas as pd
import streamlit as st
from groq import Groq
from sentence_transformers import SentenceTransformer

# Load environment variables
from dotenv import load_dotenv
load_dotenv()
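# The .env file is expected to contain a line of the form:
#   GROQ_API_KEY=<your-groq-api-key>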

# Setup Groq client
client = Groq(api_key=os.getenv("GROQ_API_KEY"))
MODEL_NAME = "llama3-70b-8192"  # Or the smaller "llama3-8b-8192"

# Load dataset
@st.cache_data
def load_data():
    url = "https://huggingface.co/datasets/llmware/rag_instruct_benchmark_tester/resolve/main/rag_instruct_benchmark_tester.csv"
    df = pd.read_csv(url)
    return df

# Build the FAISS index over the context passages (cached for the session)
@st.cache_resource
def load_embeddings(df):
    embed_model = SentenceTransformer('all-MiniLM-L6-v2')
    context_list = df['context'].tolist()
    # encode() returns a 2-D numpy array; FAISS expects contiguous float32
    embeddings = embed_model.encode(context_list, show_progress_bar=True).astype('float32')

    # Exact (brute-force) L2 index; fine for a dataset of this size
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)

    return index, embeddings, embed_model
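
# Note: IndexFlatL2 ranks by Euclidean distance. Sentence-transformer models
# like all-MiniLM-L6-v2 are often paired with cosine similarity instead; an
# equivalent variant (a sketch, not what this app uses) L2-normalizes the
# vectors and switches to an inner-product index:
#   faiss.normalize_L2(embeddings)                  # in-place, float32 only
#   index = faiss.IndexFlatIP(embeddings.shape[1])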

# Retrieve the top-k most similar context passages for a query
def retrieve_context(query, embed_model, index, df, k=3):
    query_embedding = embed_model.encode([query]).astype('float32')
    D, I = index.search(query_embedding, k)  # D: squared L2 distances, I: row indices into df
    context_passages = df.iloc[I[0]]['context'].tolist()
    return context_passages

# Ask Groq LLM
def ask_groq(query, context):
    prompt = f"""You are a helpful assistant. Use the provided context to answer the question.

Context:
{context}

Question:
{query}

Answer:"""
    response = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model=MODEL_NAME
    )
    return response.choices[0].message.content
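
# The Groq chat API accepts OpenAI-style parameters; for example, passing
# temperature=0 to client.chat.completions.create(...) above would make the
# answers more deterministic (optional; the app uses the API default).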

# Streamlit UI
st.title("📚 RAG App with Groq API")
st.markdown("Ask enterprise, legal, and financial questions over the llmware RAG benchmark dataset with this Retrieval-Augmented Generation app.")

df = load_data()
index, embeddings, embed_model = load_embeddings(df)

sample_queries = df['query'].dropna().unique().tolist()

def use_random_sample():
    # on_click callbacks run before widgets are rendered, so writing to the
    # widget's session-state key here is allowed, and the sampled question
    # shows up in the text box on the next run. (Assigning the key after
    # st.text_input is instantiated would raise an exception, and the old
    # st.experimental_rerun() pattern lost the sampled query entirely.)
    st.session_state["query"] = random.choice(sample_queries)

query = st.text_input("Enter your question:", key="query")
st.button("Use Random Sample", on_click=use_random_sample)

if query:
    st.markdown(f"**Your Query:** {query}")
    with st.spinner("Retrieving relevant context..."):
        contexts = retrieve_context(query, embed_model, index, df)
        combined_context = "\n\n".join(contexts)
    with st.spinner("Getting answer from Groq..."):
        answer = ask_groq(query, combined_context)
    st.markdown("### πŸ’‘ Answer")
    st.write(answer)
    st.markdown("### πŸ“„ Retrieved Context")
    for i, ctx in enumerate(contexts, 1):
        st.markdown(f"**Context {i}:**")
        st.write(ctx)
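
# To run locally (assuming this file is saved as app.py):
#   pip install streamlit pandas faiss-cpu sentence-transformers groq python-dotenv
#   streamlit run app.py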