ajalisatgi committed on
Commit
2bf38d0
·
verified ·
1 Parent(s): dad68a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -97
app.py CHANGED
@@ -1,109 +1,57 @@
1
  import gradio as gr
2
- import openai
3
- import os
4
- from langchain_community.embeddings import HuggingFaceEmbeddings
5
  from langchain_community.vectorstores import Chroma
6
- from langchain.schema import Document
7
- from sentence_transformers import SentenceTransformer
8
- from datasets import load_dataset
9
- import nltk
10
- import nltk
11
- nltk.download('punkt') # ✅ Correct package
12
- import nltk
13
- import os
14
- nltk.data.path.append("/usr/local/share/nltk_data") # ✅ Set correct path
15
- nltk.download("punkt")
16
-
17
 
18
- # ✅ Load the Sentence Transformer Embedding Model
19
- model_name = "sentence-transformers/all-MiniLM-L6-v2"
 
20
  embedding_model = HuggingFaceEmbeddings(model_name=model_name)
 
21
 
22
- # ✅ Set OpenAI API Key
23
- openai.api_key = os.getenv("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")
24
-
25
- # ✅ Download NLTK Tokenizer
26
- nltk.download('punkt')
27
-
28
- # ✅ Load and Chunk Dataset
29
- def chunk_documents(documents, max_chunk_size=500):
30
- chunks = []
31
- for doc in documents:
32
- sentences = nltk.sent_tokenize(doc)
33
- current_chunk = ""
34
- for sentence in sentences:
35
- if len(current_chunk) + len(sentence) <= max_chunk_size:
36
- current_chunk += sentence + " "
37
- else:
38
- chunks.append(current_chunk.strip())
39
- current_chunk = sentence + " "
40
- if current_chunk:
41
- chunks.append(current_chunk.strip())
42
- return chunks
43
-
44
- # ✅ Load Dataset and Prepare ChromaDB
45
- dataset = load_dataset("rungalileo/ragbench", "techqa") # Example dataset
46
- original_documents = dataset['train']['documents']
47
- chunked_documents = chunk_documents(original_documents)
48
-
49
- persist_directory = "chroma_db_directory"
50
- documents = [Document(page_content=chunk) for chunk in chunked_documents]
51
-
52
- # ✅ Initialize ChromaDB
53
- vectordb = Chroma.from_documents(
54
- documents=documents,
55
- embedding=embedding_model,
56
- persist_directory=persist_directory
57
  )
58
- vectordb.persist()
59
-
60
- # ✅ Function to Retrieve Relevant Documents
61
- def retrieve_documents(question, k=5):
62
- docs = vectordb.similarity_search(question, k=k)
63
- if not docs:
64
- return ["⚠️ No relevant documents found. Try a different query."]
65
- return [doc.page_content for doc in docs]
66
-
67
- # ✅ Function to Generate AI Response
68
- def generate_response(question, context):
69
- if not context or "No relevant documents found." in context:
70
- return "No relevant context available. Try a different query."
71
-
72
- full_prompt = f"Context: {context}\n\nQuestion: {question}"
73
 
74
- try:
75
- response = openai.ChatCompletion.create(
76
- model="gpt-4",
77
- messages=[
78
- {"role": "system", "content": "You are an AI assistant that answers user queries based on the given context."},
79
- {"role": "user", "content": full_prompt}
80
- ],
81
- max_tokens=300,
82
- temperature=0.7
83
- )
84
- return response['choices'][0]['message']['content'].strip()
85
- except Exception as e:
86
- return f"Error generating response: {str(e)}"
87
-
88
- # ✅ Full RAG Pipeline
89
- def rag_pipeline(question):
90
- retrieved_docs = retrieve_documents(question, k=5)
91
- context = " ".join(retrieved_docs)
92
- response = generate_response(question, context)
93
- return response, "\n\n".join(retrieved_docs)
94
-
95
- # ✅ Gradio UI Interface
96
- iface = gr.Interface(
97
- fn=rag_pipeline,
98
- inputs=gr.Textbox(label="Enter your question"),
99
  outputs=[
100
- gr.Textbox(label="Generated Response"),
101
- gr.Textbox(label="Retrieved Documents")
102
  ],
103
- title="RAG-Based Question Answering System",
104
- description="Enter a question and retrieve relevant documents with AI-generated response."
 
 
 
 
 
105
  )
106
 
107
- # ✅ Launch the Gradio App
108
  if __name__ == "__main__":
109
- iface.launch()
 
1
  import gradio as gr
2
+ from langchain.embeddings import HuggingFaceEmbeddings
 
 
3
  from langchain_community.vectorstores import Chroma
4
+ import openai
5
+ import torch
 
 
 
 
 
 
 
 
 
6
 
7
+ # Initialize models and configurations
8
+ model_name = 'intfloat/e5-small'
9
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
  embedding_model = HuggingFaceEmbeddings(model_name=model_name)
11
+ embedding_model.client.to(device)
12
 
13
+ # Initialize Chroma
14
+ vectordb = Chroma(
15
+ persist_directory='./docs/chroma/',
16
+ embedding_function=embedding_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
+ def process_query(query):
20
+ # Get relevant documents
21
+ relevant_docs = vectordb.similarity_search(query, k=30)
22
+ context = " ".join([doc.page_content for doc in relevant_docs])
23
+
24
+ # Generate response using OpenAI
25
+ response = openai.chat.completions.create(
26
+ model="gpt-4",
27
+ messages=[
28
+ {"role": "system", "content": "You are a helpful assistant."},
29
+ {"role": "user", "content": f"Given the document: {context}\n\nGenerate a response to the query: {query}"}
30
+ ],
31
+ max_tokens=300,
32
+ temperature=0.7,
33
+ )
34
+
35
+ return response.choices[0].message.content.strip()
36
+
37
+ # Create Gradio interface
38
+ demo = gr.Interface(
39
+ fn=process_query,
40
+ inputs=[
41
+ gr.Textbox(label="Enter your question", placeholder="Type your question here...")
42
+ ],
 
43
  outputs=[
44
+ gr.Textbox(label="Answer")
 
45
  ],
46
+ title="RAG-Powered Question Answering System",
47
+ description="Ask questions and get answers based on the embedded document knowledge.",
48
+ examples=[
49
+ ["What role does T-cell count play in severe human adenovirus type 55 (HAdV-55) infection?"],
50
+ ["In what school district is Governor John R. Rogers High School located?"],
51
+ ["Is there a functional neural correlate of individual differences in cardiovascular reactivity?"]
52
+ ]
53
  )
54
 
55
+ # Launch the app
56
  if __name__ == "__main__":
57
+ demo.launch()