Fecalisboa committed
Commit f1fd602 · verified · 1 Parent(s): 16b2733

Update app.py

Files changed (1): app.py +23 -85
app.py CHANGED
@@ -14,51 +14,13 @@ from langchain_community.llms import HuggingFacePipeline
 from langchain.chains import ConversationChain
 from langchain.memory import ConversationBufferMemory
 from langchain_community.llms import HuggingFaceEndpoint
-from huggingface_hub import InferenceClient
 import torch
-
 api_token = os.getenv("HF_TOKEN")
 
-client = InferenceClient(
-    "mistralai/Mistral-7B-Instruct-v0.3"
-)
-
 list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.3"]
 list_llm_simple = [os.path.basename(llm) for llm in list_llm]
 
-def format_prompt(message, history):
-    prompt = "<s>"
-    for user_prompt, bot_response in history:
-        prompt += f"[INST] {user_prompt} [/INST]"
-        prompt += f" {bot_response}</s> "
-    prompt += f"[INST] {message} [/INST]"
-    return prompt
-
-def generate(prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
-    temperature = float(temperature)
-    if temperature < 1e-2:
-        temperature = 1e-2
-    top_p = float(top_p)
-
-    generate_kwargs = dict(
-        temperature=temperature,
-        max_new_tokens=max_new_tokens,
-        top_p=top_p,
-        repetition_penalty=repetition_penalty,
-        do_sample=True,
-        seed=42,
-    )
-
-    formatted_prompt = format_prompt(prompt, history)
-
-    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-    output = ""
-
-    for response in stream:
-        output += response.token.text
-        yield output
-    return output
-
+# Load PDF document and create doc splits
 def load_doc(list_file_path, chunk_size, chunk_overlap):
     loaders = [PyPDFLoader(x) for x in list_file_path]
     pages = []
@@ -68,6 +30,7 @@ def load_doc(list_file_path, chunk_size, chunk_overlap):
     doc_splits = text_splitter.split_documents(pages)
     return doc_splits
 
+# Create vector database
 def create_db(splits, collection_name, db_type):
     embedding = HuggingFaceEmbeddings()
 
@@ -100,8 +63,10 @@ def create_db(splits, collection_name, db_type):
 
     return vectordb
 
+# Initialize langchain LLM chain
 def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, initial_prompt, progress=gr.Progress()):
     progress(0.1, desc="Initializing HF tokenizer...")
+
     progress(0.5, desc="Initializing HF Hub...")
 
     llm = HuggingFaceEndpoint(
@@ -264,58 +229,27 @@ def demo():
             clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")
 
         with gr.Tab("Step 6 - Chatbot without document"):
+            with gr.Row():
+                llm_no_doc_btn = gr.Radio(list_llm_simple,
+                    label="LLM models", value=list_llm_simple[0], type="index", info="Choose your LLM model for chatbot without document")
+            with gr.Accordion("Advanced options - LLM model", open=False):
+                with gr.Row():
+                    slider_temperature_no_doc = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
+                with gr.Row():
+                    slider_maxtokens_no_doc = gr.Slider(minimum=224, maximum=4096, value=1024, step=32, label="Max Tokens", info="Model max tokens", interactive=True)
+                with gr.Row():
+                    slider_topk_no_doc = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="top-k samples", info="Model top-k samples", interactive=True)
+            with gr.Row():
+                llm_no_doc_progress = gr.Textbox(value="None", label="LLM initialization for chatbot without document")
+            with gr.Row():
+                llm_no_doc_init_btn = gr.Button("Initialize LLM for Chatbot without document")
             chatbot_no_doc = gr.Chatbot(height=300)
-            additional_inputs=[
-                gr.Slider(
-                    label="Temperature",
-                    value=0.9,
-                    minimum=0.0,
-                    maximum=1.0,
-                    step=0.05,
-                    interactive=True,
-                    info="Higher values produce more diverse outputs",
-                ),
-                gr.Slider(
-                    label="Max new tokens",
-                    value=256,
-                    minimum=0,
-                    maximum=1048,
-                    step=64,
-                    interactive=True,
-                    info="The maximum numbers of new tokens",
-                ),
-                gr.Slider(
-                    label="Top-p (nucleus sampling)",
-                    value=0.90,
-                    minimum=0.0,
-                    maximum=1,
-                    step=0.05,
-                    interactive=True,
-                    info="Higher values sample more low-probability tokens",
-                ),
-                gr.Slider(
-                    label="Repetition penalty",
-                    value=1.2,
-                    minimum=1.0,
-                    maximum=2.0,
-                    step=0.05,
-                    interactive=True,
-                    info="Penalize repeated tokens",
-                )
-            ]
             with gr.Row():
                 msg_no_doc = gr.Textbox(placeholder="Type message to chat with lucIAna", container=True)
             with gr.Row():
                 submit_btn_no_doc = gr.Button("Submit message")
                 clear_btn_no_doc = gr.ClearButton([msg_no_doc, chatbot_no_doc], value="Clear conversation")
 
-        gr.ChatInterface(
-            fn=generate,
-            chatbot=chatbot_no_doc,
-            additional_inputs=additional_inputs,
-            title="Mistral 7B v0.3"
-        )
-
         # Preprocessing events
         db_btn.click(initialize_database,
                      inputs=[document, slider_chunk_size, slider_chunk_overlap, db_type_radio],
@@ -323,7 +257,7 @@ def demo():
         set_prompt_btn.click(lambda prompt: gr.update(value=prompt),
                              inputs=prompt_input,
                              outputs=initial_prompt)
-        qachain_btn.click(initialize_llmchain,
+        qachain_btn.click(initialize_LLM,
                          inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db, initial_prompt],
                          outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0],
                          inputs=None,
@@ -345,6 +279,10 @@ def demo():
                         queue=False)
 
         # Initialize LLM without document for conversation
+        llm_no_doc_init_btn.click(initialize_llm_no_doc,
+                                  inputs=[llm_no_doc_btn, slider_temperature_no_doc, slider_maxtokens_no_doc, slider_topk_no_doc, initial_prompt],
+                                  outputs=[llm_no_doc, llm_no_doc_progress])
+
         submit_btn_no_doc.click(conversation_no_doc,
                                 inputs=[llm_no_doc, msg_no_doc, chatbot_no_doc],
                                 outputs=[llm_no_doc, msg_no_doc, chatbot_no_doc],
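The new qachain_btn handler calls initialize_LLM, which is not defined in the hunks above. As a point of reference only, a minimal sketch consistent with the wired inputs and outputs would be a thin wrapper that maps the Radio index (the component uses type="index") back to a model id and delegates to the existing initialize_llmchain:

# Hypothetical sketch -- the real definition lies outside this diff.
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, initial_prompt, progress=gr.Progress()):
    # Radio returns an index into list_llm_simple; map it back to the full model id
    llm_name = list_llm[llm_option]
    # Delegate to the chain initializer shown earlier in this diff
    return initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, initial_prompt, progress)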
 
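Likewise, initialize_llm_no_doc is wired to llm_no_doc_init_btn but its body is not shown in this diff. A minimal sketch, assuming the same HuggingFaceEndpoint pattern as initialize_llmchain plus a plain ConversationChain for the document-free chat (the exact body in the full app.py may differ):

# Hypothetical sketch -- the real definition lies outside this diff.
def initialize_llm_no_doc(llm_option, llm_temperature, max_tokens, top_k, initial_prompt, progress=gr.Progress()):
    llm_name = list_llm[llm_option]  # Radio uses type="index"
    progress(0.5, desc="Initializing HF Hub...")
    llm = HuggingFaceEndpoint(
        repo_id=llm_name,
        huggingfacehub_api_token=api_token,
        temperature=llm_temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )
    # initial_prompt could seed the buffer memory; its exact use is not shown in the diff
    memory = ConversationBufferMemory(memory_key="history")
    llm_no_doc = ConversationChain(llm=llm, memory=memory, verbose=False)
    progress(1.0, desc="Complete!")
    return llm_no_doc, "Complete!"

The two return values line up with outputs=[llm_no_doc, llm_no_doc_progress] in the click binding above.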