ajalisatgi committed
Commit c534f6f · verified · 1 Parent(s): e03563e

Update rag_gradio_app.py

Files changed (1)
  1. rag_gradio_app.py +23 -34
rag_gradio_app.py CHANGED
@@ -1,56 +1,45 @@
 import gradio as gr
-import torch
+import openai
 from langchain.embeddings import HuggingFaceEmbeddings
 from langchain_community.vectorstores import Chroma
-from sentence_transformers import SentenceTransformer
-import openai

-# Load pre-trained embedding model
-model_name = 'intfloat/e5-small'
+# Set API Key
+openai.api_key = "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"
+
+# Load embedding model
+model_name = "intfloat/e5-small"
 embedding_model = HuggingFaceEmbeddings(model_name=model_name)
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

 # Load ChromaDB
-persist_directory = './docs/chroma/'
+persist_directory = "./docs/chroma/"
 vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_model)

-# OpenAI API Key
-openai.api_key = 'your-api-key'
-
-def retrieve_documents(question, k=5):
-    """Retrieve top K relevant documents from ChromaDB"""
-    docs = vectordb.similarity_search(question, k=k)
-    return [doc.page_content for doc in docs]
-
-def generate_response(question, context):
-    """Generate response using OpenAI GPT-4"""
-    full_prompt = f"Context: {context}\n\nQuestion: {question}"
+# Define RAG function
+def rag_pipeline(question):
+    """Retrieve relevant documents and generate AI response"""
+    retrieved_docs = vectordb.similarity_search(question, k=5)
+    context = " ".join([doc.page_content for doc in retrieved_docs])
+
+    # Generate AI response
+    full_prompt = f"Context: {context}\n\nQuestion: {question}"
     response = openai.ChatCompletion.create(
         model="gpt-4",
         messages=[{"role": "user", "content": full_prompt}],
         max_tokens=300,
         temperature=0.7
     )
-    return response['choices'][0]['message']['content'].strip()
-
-def rag_pipeline(question):
-    """Full RAG Pipeline - Retrieve Docs & Generate Response"""
-    retrieved_docs = retrieve_documents(question, k=5)
-    context = " ".join(retrieved_docs)
-    response = generate_response(question, context)
-    return response, retrieved_docs
-
-def gradio_interface(question):
-    response, retrieved_docs = rag_pipeline(question)
-    return response, "\n\n".join(retrieved_docs)
+
+    return response['choices'][0]['message']['content'].strip(), retrieved_docs

-# Create Gradio App
+# Gradio UI
 iface = gr.Interface(
-    fn=gradio_interface,
+    fn=rag_pipeline,
     inputs=gr.Textbox(label="Enter your question"),
     outputs=[gr.Textbox(label="Generated Response"), gr.Textbox(label="Retrieved Documents")],
     title="RAG-Based Question Answering System",
-    description="Enter a question and retrieve relevant documents along with the AI-generated response."
+    description="Enter a question and retrieve relevant documents with AI-generated response."
 )

-iface.launch()
+# Launch Gradio app
+if __name__ == "__main__":
+    iface.launch()
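
Note: the updated rag_pipeline still calls the legacy openai.ChatCompletion.create endpoint (removed in openai>=1.0), hardcodes the API key, and returns the raw Document objects to the second gr.Textbox. Below is a minimal sketch of how the same function could look against the v1 client, with the key read from the environment and the retrieved documents joined into a plain string for display. The OPENAI_API_KEY variable name is an assumption, and the sketch reuses the vectordb object defined above; it is not part of the commit.

import os
from openai import OpenAI  # assumes openai>=1.0 is installed

# Assumption: read the key from the environment instead of hardcoding it.
client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

def rag_pipeline(question):
    """Retrieve relevant documents and generate an AI response (v1-client sketch)."""
    # Top-5 similarity search against the existing Chroma store.
    retrieved_docs = vectordb.similarity_search(question, k=5)
    context = " ".join(doc.page_content for doc in retrieved_docs)

    # Same prompt and generation settings as the committed code.
    response = client.chat.completions.create(
        model="gpt-4",
        messages=[{"role": "user", "content": f"Context: {context}\n\nQuestion: {question}"}],
        max_tokens=300,
        temperature=0.7,
    )

    # Join page contents so the second gr.Textbox receives a plain string.
    answer = response.choices[0].message.content.strip()
    return answer, "\n\n".join(doc.page_content for doc in retrieved_docs)

This keeps the Gradio wiring (fn=rag_pipeline, two text outputs) unchanged while avoiding the deprecated endpoint and the committed secret.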