Fecalisboa commited on
Commit
299393f
·
verified ·
1 Parent(s): 6db85cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -25
app.py CHANGED
@@ -17,7 +17,7 @@ from langchain_community.llms import HuggingFaceEndpoint
17
  import torch
18
  api_token = os.getenv("HF_TOKEN")
19
 
20
- list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.3"]
21
  list_llm_simple = [os.path.basename(llm) for llm in list_llm]
22
 
23
  # Load PDF document and create doc splits
@@ -64,32 +64,19 @@ def create_db(splits, collection_name, db_type):
64
  return vectordb
65
 
66
  # Initialize langchain LLM chain
67
- # Initialize langchain LLM chain
68
- def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
69
  progress(0.1, desc="Initializing HF tokenizer...")
70
 
71
  progress(0.5, desc="Initializing HF Hub...")
72
 
73
-
74
- if llm_model == "meta-llama/Meta-Llama-3-8B-Instruct":
75
- llm = HuggingFaceEndpoint(
76
- repo_id=llm_model,
77
- huggingfacehub_api_token=api_token,
78
- temperature=temperature,
79
- max_new_tokens=max_tokens,
80
- top_k=top_k,
81
- )
82
-
83
- else:
84
-
85
- llm = HuggingFaceEndpoint(
86
- repo_id=llm_model,
87
- huggingfacehub_api_token=api_token,
88
- temperature=temperature,
89
- max_new_tokens=max_tokens,
90
- top_k=top_k,
91
- )
92
-
93
  progress(0.75, desc="Defining buffer memory...")
94
  memory = ConversationBufferMemory(
95
  memory_key="chat_history",
@@ -101,11 +88,12 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, pr
101
  qa_chain = ConversationalRetrievalChain.from_llm(
102
  llm,
103
  retriever=retriever,
104
- chain_type="stuff",
105
  memory=memory,
106
  return_source_documents=True,
107
  verbose=False,
108
  )
 
109
  progress(0.9, desc="Done!")
110
  return qa_chain
111
 
@@ -266,7 +254,7 @@ def demo():
266
  db_btn.click(initialize_database,
267
  inputs=[document, slider_chunk_size, slider_chunk_overlap, db_type_radio],
268
  outputs=[vector_db, collection_name, db_progress])
269
- set_prompt_btn.click(lambda prompt: gr.State(prompt),
270
  inputs=prompt_input,
271
  outputs=initial_prompt)
272
  qachain_btn.click(initialize_LLM,
 
17
  import torch
18
  api_token = os.getenv("HF_TOKEN")
19
 
20
+ list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.3"]
21
  list_llm_simple = [os.path.basename(llm) for llm in list_llm]
22
 
23
  # Load PDF document and create doc splits
 
64
  return vectordb
65
 
66
  # Initialize langchain LLM chain
67
+ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, initial_prompt, progress=gr.Progress()):
 
68
  progress(0.1, desc="Initializing HF tokenizer...")
69
 
70
  progress(0.5, desc="Initializing HF Hub...")
71
 
72
+ llm = HuggingFaceEndpoint(
73
+ repo_id=llm_model,
74
+ huggingfacehub_api_token=api_token,
75
+ temperature=temperature,
76
+ max_new_tokens=max_tokens,
77
+ top_k=top_k,
78
+ )
79
+
 
 
 
 
 
 
 
 
 
 
 
 
80
  progress(0.75, desc="Defining buffer memory...")
81
  memory = ConversationBufferMemory(
82
  memory_key="chat_history",
 
88
  qa_chain = ConversationalRetrievalChain.from_llm(
89
  llm,
90
  retriever=retriever,
91
+ chain_type="stuff",
92
  memory=memory,
93
  return_source_documents=True,
94
  verbose=False,
95
  )
96
+ qa_chain({"question": initial_prompt}) # Initialize with the initial prompt
97
  progress(0.9, desc="Done!")
98
  return qa_chain
99
 
 
254
  db_btn.click(initialize_database,
255
  inputs=[document, slider_chunk_size, slider_chunk_overlap, db_type_radio],
256
  outputs=[vector_db, collection_name, db_progress])
257
+ set_prompt_btn.click(lambda prompt: prompt,
258
  inputs=prompt_input,
259
  outputs=initial_prompt)
260
  qachain_btn.click(initialize_LLM,