amasood committed on
Commit
1400248
·
verified ·
1 Parent(s): 9b4c763

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -0
app.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import os
4
+ import faiss
5
+ import pickle
6
+ from sentence_transformers import SentenceTransformer
7
+ from groq import Groq
8
+
9
# Load environment variables (e.g. GROQ_API_KEY) from a local .env file, if present
from dotenv import load_dotenv
load_dotenv()

# Setup Groq client
client = Groq(api_key=os.getenv("GROQ_API_KEY"))

# The previous value "llama-3-70b-8192" does not match the format of Groq's
# model ids, so the chat call would be rejected.
# NOTE(review): confirm the default below against Groq's current model list;
# set GROQ_MODEL to switch models (e.g. "llama-3.1-8b-instant") without a
# code change.
MODEL_NAME = os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile")
16
+
17
# Load dataset
@st.cache_data
def load_data():
    """Read the RAG benchmark CSV from the Hugging Face Hub into a DataFrame.

    Wrapped in ``st.cache_data`` so the download is not repeated on every
    Streamlit rerun.
    """
    return pd.read_csv(
        "https://huggingface.co/datasets/llmware/rag_instruct_benchmark_tester/resolve/main/rag_instruct_benchmark_tester.csv"
    )
23
+
24
# Build the FAISS index
@st.cache_resource
def load_embeddings(df):
    """Embed every ``context`` entry of *df* and index the vectors with FAISS.

    Returns:
        A ``(index, embeddings, embed_model)`` tuple: the ``faiss.IndexFlatL2``
        holding the vectors, the embedding array itself, and the
        ``SentenceTransformer`` that produced it.
    """
    encoder = SentenceTransformer('all-MiniLM-L6-v2')
    vectors = encoder.encode(df['context'].tolist(), show_progress_bar=True)

    # One row per passage; the second axis is the embedding width.
    flat_index = faiss.IndexFlatL2(vectors.shape[1])
    flat_index.add(vectors)

    return flat_index, vectors, encoder
35
+
36
# Retrieve top k similar context passages
def retrieve_context(query, embed_model, index, df, k=3):
    """Return the ``context`` strings of the passages nearest to *query*.

    Args:
        query: Question text to embed.
        embed_model: Object whose ``encode`` accepts a list of strings.
        index: Index exposing ``search(vectors, k)`` and ``ntotal``; the ids it
            returns are used as row positions in *df*.
        df: DataFrame with a ``context`` column.
        k: Maximum number of passages to return.

    Returns:
        Up to *k* context strings in the order the index returned them; an
        empty list when ``k <= 0`` or the index holds no vectors.
    """
    # Never ask for more neighbours than the index contains.
    limit = min(k, index.ntotal)
    if limit <= 0:
        return []

    query_embedding = embed_model.encode([query])
    _distances, neighbours = index.search(query_embedding, limit)

    # FAISS pads missing neighbours with -1; passing that to df.iloc would
    # silently return the LAST row as if it were a real hit.
    positions = [int(i) for i in neighbours[0] if i >= 0]
    return df.iloc[positions]['context'].tolist()
42
+
43
# Ask Groq LLM
def ask_groq(query, context):
    """Ask the Groq chat model to answer *query* using *context*.

    Sends a single user message through the module-level ``client`` with the
    model named by ``MODEL_NAME`` and returns the text of the first choice.
    """
    prompt = (
        "You are a helpful assistant. Use the provided context to answer the question.\n"
        "\n"
        "Context:\n"
        f"{context}\n"
        "\n"
        "Question:\n"
        f"{query}\n"
        "\n"
        "Answer:"
    )
    chat_messages = [{"role": "user", "content": prompt}]
    completion = client.chat.completions.create(
        messages=chat_messages,
        model=MODEL_NAME,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
59
+
60
# Streamlit UI
# Emoji below were mojibake in the previous revision; restored to the intended characters.
st.title("📚 RAG App with Groq API")
st.markdown("Use this Retrieval-Augmented Generation app to ask enterprise, legal, and financial questions.")

df = load_data()
index, embeddings, embed_model = load_embeddings(df)

# Distinct, non-null benchmark questions offered by the "Use Random Sample" button.
sample_queries = df['query'].dropna().unique().tolist()


def _use_random_sample():
    """Button callback: put a random dataset question into the text box."""
    import random

    # Callbacks run before the script reruns, so the keyed text_input below
    # picks this value up. Writing to the key after the widget has been
    # created is not allowed, which is why this is not done inline.
    if sample_queries:
        st.session_state["query"] = random.choice(sample_queries)


# The key ties the widget to st.session_state["query"]; without it the value
# chosen by the button was never shown or used.
query = st.text_input("Enter your question:", key="query")
st.button("Use Random Sample", on_click=_use_random_sample)

if query:
    st.markdown(f"**Your Query:** {query}")
    with st.spinner("Retrieving relevant context..."):
        contexts = retrieve_context(query, embed_model, index, df)
        combined_context = "\n\n".join(contexts)
    with st.spinner("Getting answer from Groq..."):
        answer = ask_groq(query, combined_context)
    st.markdown("### 💡 Answer")
    st.write(answer)
    st.markdown("### 📄 Retrieved Context")
    for i, ctx in enumerate(contexts, 1):
        st.markdown(f"**Context {i}:**")
        st.write(ctx)