Fecalisboa committed

Commit f0fe5f2 · verified · 1 Parent(s): f1fd602

Update app.py

Files changed (1):
  1. app.py +112 -23
app.py CHANGED
@@ -14,13 +14,51 @@ from langchain_community.llms import HuggingFacePipeline
 from langchain.chains import ConversationChain
 from langchain.memory import ConversationBufferMemory
 from langchain_community.llms import HuggingFaceEndpoint
+from huggingface_hub import InferenceClient
 import torch
+
 api_token = os.getenv("HF_TOKEN")
 
+client = InferenceClient(
+    "mistralai/Mistral-7B-Instruct-v0.3"
+)
+
 list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.3"]
 list_llm_simple = [os.path.basename(llm) for llm in list_llm]
 
-# Load PDF document and create doc splits
+def format_prompt(message, history):
+    prompt = "<s>"
+    for user_prompt, bot_response in history:
+        prompt += f"[INST] {user_prompt} [/INST]"
+        prompt += f" {bot_response}</s> "
+    prompt += f"[INST] {message} [/INST]"
+    return prompt
+
+def generate(prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
+    temperature = float(temperature)
+    if temperature < 1e-2:
+        temperature = 1e-2
+    top_p = float(top_p)
+
+    generate_kwargs = dict(
+        temperature=temperature,
+        max_new_tokens=max_new_tokens,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        do_sample=True,
+        seed=42,
+    )
+
+    formatted_prompt = format_prompt(prompt, history)
+
+    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+    output = ""
+
+    for response in stream:
+        output += response.token.text
+        yield output
+    return output
+
 def load_doc(list_file_path, chunk_size, chunk_overlap):
     loaders = [PyPDFLoader(x) for x in list_file_path]
     pages = []
@@ -30,7 +68,6 @@ def load_doc(list_file_path, chunk_size, chunk_overlap):
     doc_splits = text_splitter.split_documents(pages)
     return doc_splits
 
-# Create vector database
 def create_db(splits, collection_name, db_type):
     embedding = HuggingFaceEmbeddings()
 
@@ -63,10 +100,8 @@ def create_db(splits, collection_name, db_type):
 
     return vectordb
 
-# Initialize langchain LLM chain
 def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, initial_prompt, progress=gr.Progress()):
     progress(0.1, desc="Initializing HF tokenizer...")
-
     progress(0.5, desc="Initializing HF Hub...")
 
     llm = HuggingFaceEndpoint(
@@ -155,6 +190,33 @@ def upload_file(file_obj):
         list_file_path.append(file.name)
     return list_file_path
 
+def initialize_database(list_file_obj, chunk_size, chunk_overlap, db_type, progress=gr.Progress()):
+    list_file_path = [x.name for x in list_file_obj if x is not None]
+    progress(0.1, desc="Creating collection name...")
+    collection_name = create_collection_name(list_file_path[0])
+    progress(0.25, desc="Loading document...")
+    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
+    progress(0.5, desc="Generating vector database...")
+    vector_db = create_db(doc_splits, collection_name, db_type)
+    progress(0.9, desc="Done!")
+    return vector_db, collection_name, "Complete!"
+
+def create_collection_name(filepath):
+    collection_name = Path(filepath).stem
+    collection_name = collection_name.replace(" ", "-")
+    collection_name = unidecode(collection_name)
+    collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
+    collection_name = collection_name[:50]
+    if len(collection_name) < 3:
+        collection_name = collection_name + 'xyz'
+    if not collection_name[0].isalnum():
+        collection_name = 'A' + collection_name[1:]
+    if not collection_name[-1].isalnum():
+        collection_name = collection_name[:-1] + 'Z'
+    print('Filepath: ', filepath)
+    print('Collection name: ', collection_name)
+    return collection_name
+
 def demo():
     with gr.Blocks(theme="base") as demo:
         vector_db = gr.State()
@@ -229,27 +291,58 @@ def demo():
             clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")
 
         with gr.Tab("Step 6 - Chatbot without document"):
-            with gr.Row():
-                llm_no_doc_btn = gr.Radio(list_llm_simple,
-                    label="LLM models", value=list_llm_simple[0], type="index", info="Choose your LLM model for chatbot without document")
-            with gr.Accordion("Advanced options - LLM model", open=False):
-                with gr.Row():
-                    slider_temperature_no_doc = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
-                with gr.Row():
-                    slider_maxtokens_no_doc = gr.Slider(minimum=224, maximum=4096, value=1024, step=32, label="Max Tokens", info="Model max tokens", interactive=True)
-                with gr.Row():
-                    slider_topk_no_doc = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="top-k samples", info="Model top-k samples", interactive=True)
-            with gr.Row():
-                llm_no_doc_progress = gr.Textbox(value="None", label="LLM initialization for chatbot without document")
-            with gr.Row():
-                llm_no_doc_init_btn = gr.Button("Initialize LLM for Chatbot without document")
             chatbot_no_doc = gr.Chatbot(height=300)
+            additional_inputs = [
+                gr.Slider(
+                    label="Temperature",
+                    value=0.9,
+                    minimum=0.0,
+                    maximum=1.0,
+                    step=0.05,
+                    interactive=True,
+                    info="Higher values produce more diverse outputs",
+                ),
+                gr.Slider(
+                    label="Max new tokens",
+                    value=256,
+                    minimum=0,
+                    maximum=1048,
+                    step=64,
+                    interactive=True,
+                    info="The maximum number of new tokens",
+                ),
+                gr.Slider(
+                    label="Top-p (nucleus sampling)",
+                    value=0.90,
+                    minimum=0.0,
+                    maximum=1.0,
+                    step=0.05,
+                    interactive=True,
+                    info="Higher values sample more low-probability tokens",
+                ),
+                gr.Slider(
+                    label="Repetition penalty",
+                    value=1.2,
+                    minimum=1.0,
+                    maximum=2.0,
+                    step=0.05,
+                    interactive=True,
+                    info="Penalize repeated tokens",
+                ),
+            ]
             with gr.Row():
                 msg_no_doc = gr.Textbox(placeholder="Type message to chat with lucIAna", container=True)
             with gr.Row():
                 submit_btn_no_doc = gr.Button("Submit message")
                 clear_btn_no_doc = gr.ClearButton([msg_no_doc, chatbot_no_doc], value="Clear conversation")
 
+            gr.ChatInterface(
+                fn=generate,
+                chatbot=chatbot_no_doc,
+                additional_inputs=additional_inputs,
+                title="Mistral 7B v0.3"
+            )
+
         # Preprocessing events
         db_btn.click(initialize_database,
             inputs=[document, slider_chunk_size, slider_chunk_overlap, db_type_radio],
@@ -257,7 +350,7 @@ def demo():
         set_prompt_btn.click(lambda prompt: gr.update(value=prompt),
             inputs=prompt_input,
             outputs=initial_prompt)
-        qachain_btn.click(initialize_LLM,
+        qachain_btn.click(initialize_llmchain,
            inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db, initial_prompt],
            outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0],
            inputs=None,
@@ -279,10 +372,6 @@ def demo():
            queue=False)
 
        # Initialize LLM without document for conversation
-       llm_no_doc_init_btn.click(initialize_llm_no_doc,
-           inputs=[llm_no_doc_btn, slider_temperature_no_doc, slider_maxtokens_no_doc, slider_topk_no_doc, initial_prompt],
-           outputs=[llm_no_doc, llm_no_doc_progress])
-
        submit_btn_no_doc.click(conversation_no_doc,
            inputs=[llm_no_doc, msg_no_doc, chatbot_no_doc],
            outputs=[llm_no_doc, msg_no_doc, chatbot_no_doc],
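
The new format_prompt helper serializes the chat history into Mistral's [INST] instruction template. As a quick illustration (the history pair below is invented), it produces:

```python
# Illustration only: what format_prompt returns for a one-turn history.
history = [("Hello", "Hi! How can I help?")]
print(format_prompt("Summarize the document", history))
# <s>[INST] Hello [/INST] Hi! How can I help?</s> [INST] Summarize the document [/INST]
```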
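
generate is a generator: each streamed token from InferenceClient.text_generation is appended to output and the accumulated string is yielded, which is the shape gr.ChatInterface expects for incremental rendering. A minimal sketch of driving it outside Gradio, assuming HF_TOKEN is set and the Mistral endpoint is reachable:

```python
# Sketch: consume the streaming generator directly.
# Each yielded value is the full text so far, not a delta.
final = ""
for partial in generate("What is retrieval-augmented generation?", history=[]):
    final = partial
print(final)
```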
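
The document pipeline itself is untouched: load_doc still reads every PDF page and splits the concatenated pages into overlapping chunks. A sketch with a placeholder path and assumed slider defaults:

```python
# Sketch: chunk a single PDF; the path and sizes are placeholders.
splits = load_doc(["docs/report.pdf"], chunk_size=600, chunk_overlap=40)
print(len(splits), "chunks; first chunk starts:", splits[0].page_content[:80])
```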
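
create_collection_name normalizes a file name into a collection identifier (at least 3 characters, at most 50, alphanumeric first and last character), which lines up with ChromaDB's collection-naming rules. Tracing it on an accented, spaced filename:

```python
# Trace: stem -> spaces to hyphens -> unidecode -> regex cleanup.
print(create_collection_name("/tmp/Relatório Anual 2024.pdf"))
# prints: Relatorio-Anual-2024
```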
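
The Step 6 tab now delegates chat to gr.ChatInterface, which calls fn(message, history, *additional_inputs) and re-renders the attached chatbot on every yield. A stripped-down, self-contained sketch of the same wiring, where echo_stream is a hypothetical stand-in for generate (assuming a recent Gradio 4.x):

```python
import gradio as gr

# Stand-in for generate(): streams a fake reply character by character.
def echo_stream(message, history, temperature):
    out = ""
    for ch in f"(T={temperature}) You said: {message}":
        out += ch
        yield out  # ChatInterface re-renders the chatbot on each yield

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(height=300)
    gr.ChatInterface(
        fn=echo_stream,
        chatbot=chatbot,
        additional_inputs=[gr.Slider(0.0, 1.0, value=0.9, label="Temperature")],
    )

demo.launch()
```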