hanzla commited on
Commit
d9640db
·
1 Parent(s): 5313037
Files changed (2) hide show
  1. src/app.py +1 -1
  2. src/pdfchatbot.py +49 -5
src/app.py CHANGED
@@ -16,7 +16,7 @@ with demo:
16
 
17
  # Event handler for submitting text and generating response
18
  submit_button.click(pdf_chatbot.add_text, inputs=[chat_history, txt], outputs=[chat_history], queue=False).\
19
- success(pdf_chatbot.generate_response, inputs=[chat_history, txt, uploaded_pdf], outputs=[chat_history, txt]).\
20
  success(pdf_chatbot.render_file, inputs=[uploaded_pdf], outputs=[show_img])
21
 
22
  if __name__ == "__main__":
 
16
 
17
  # Event handler for submitting text and generating response
18
  submit_button.click(pdf_chatbot.add_text, inputs=[chat_history, txt], outputs=[chat_history], queue=False).\
19
+ success(pdf_chatbot.generate_response, inputs=[chat_history, txt, uploaded_pdf], outputs=[chat_history,txt]).\
20
  success(pdf_chatbot.render_file, inputs=[uploaded_pdf], outputs=[show_img])
21
 
22
  if __name__ == "__main__":
src/pdfchatbot.py CHANGED
@@ -14,7 +14,6 @@ import spaces
14
  from langchain_text_splitters import CharacterTextSplitter
15
 
16
 
17
-
18
  class PDFChatBot:
19
  def __init__(self, config_path="config.yaml"):
20
  """
@@ -37,6 +36,8 @@ class PDFChatBot:
37
  self.pipeline = None
38
  self.chain = None
39
  self.chunk_size = None
 
 
40
  #self.chunk_size_slider = chunk_size_slider
41
 
42
  def load_config(self, file_path):
@@ -128,6 +129,46 @@ class PDFChatBot:
128
  )
129
  self.pipeline = HuggingFacePipeline(pipeline=pipe)
130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  def create_chain(self):
132
  """
133
  Create a Conversational Retrieval Chain
@@ -153,8 +194,8 @@ class PDFChatBot:
153
  self.load_vectordb()
154
  self.load_tokenizer()
155
  self.load_model()
156
- self.create_pipeline()
157
- self.create_chain()
158
  @spaces.GPU
159
  def generate_response(self, history, query, file):
160
  """
@@ -176,11 +217,14 @@ class PDFChatBot:
176
  self.process_file(file)
177
  self.processed = True
178
 
179
- result = self.chain({"question": query, 'chat_history': self.chat_history}, return_only_outputs=True)
180
  self.chat_history.append((query, result["answer"]))
181
  for char in result['answer']:
182
  history[-1][-1] += char
183
- return history, " "
 
 
 
184
 
185
  def render_file(self, file,chunk_size):
186
  """
 
14
  from langchain_text_splitters import CharacterTextSplitter
15
 
16
 
 
17
  class PDFChatBot:
18
  def __init__(self, config_path="config.yaml"):
19
  """
 
36
  self.pipeline = None
37
  self.chain = None
38
  self.chunk_size = None
39
+ self.current_context = None
40
+ self.format_seperator="""\n\n--\n\n"""
41
  #self.chunk_size_slider = chunk_size_slider
42
 
43
  def load_config(self, file_path):
 
129
  )
130
  self.pipeline = HuggingFacePipeline(pipeline=pipe)
131
 
132
+ def create_organic_pipeline(self):
133
+ self.pipeline = pipeline(
134
+ "text-generation",
135
+ model=self.config.get("autoModelForCausalLM"),
136
+ model_kwargs={"torch_dtype": torch.bfloat16},
137
+ device="cuda",
138
+ )
139
+
140
+ def get_organic_context(self, query):
141
+ documents = self.vectordb.similarity_search_with_relevance_scores(query, k=self.k)
142
+ context = self.format_seperator.join([doc.page_content for doc, score in documents])
143
+ self.current_context = context
144
+ print(self.current_context)
145
+
146
+ def create_organic_response(self, history, query):
147
+ self.get_organic_context(query)
148
+ messages = [
149
+ {"role": "system", "content": "From the context given below, answer the user's question \n " + self.current_context},
150
+ {"role": "user", "content": query},
151
+ ]
152
+ prompt = self.pipeline.tokenizer.apply_chat_template(
153
+ messages,
154
+ tokenize=False,
155
+ add_generation_prompt=True
156
+ )
157
+ terminators = [
158
+ self.pipeline.tokenizer.eos_token_id,
159
+ self.pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
160
+ ]
161
+ temp = 0.1
162
+ outputs = pipeline(
163
+ prompt,
164
+ max_new_tokens=1024,
165
+ eos_token_id=terminators,
166
+ do_sample=True,
167
+ temperature=temp,
168
+ top_p=0.9,
169
+ )
170
+ return outputs[0]["generated_text"][len(prompt):]
171
+
172
  def create_chain(self):
173
  """
174
  Create a Conversational Retrieval Chain
 
194
  self.load_vectordb()
195
  self.load_tokenizer()
196
  self.load_model()
197
+ self.create_organic_pipeline()
198
+ #self.create_chain()
199
  @spaces.GPU
200
  def generate_response(self, history, query, file):
201
  """
 
217
  self.process_file(file)
218
  self.processed = True
219
 
220
+ """result = self.chain({"question": query, 'chat_history': self.chat_history}, return_only_outputs=True)
221
  self.chat_history.append((query, result["answer"]))
222
  for char in result['answer']:
223
  history[-1][-1] += char
224
+ return history, " """""
225
+
226
+ result = self.create_organic_response(history="",query=query)
227
+ return result,""
228
 
229
  def render_file(self, file,chunk_size):
230
  """