Fecalisboa committed (verified)
Commit 6db85cc · Parent(s): ff20f9a

Update app.py

Files changed (1): app.py (+24 -12)
app.py CHANGED
```diff
@@ -17,7 +17,7 @@ from langchain_community.llms import HuggingFaceEndpoint
 import torch
 api_token = os.getenv("HF_TOKEN")
 
-list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.3"]
+list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.3"]
 list_llm_simple = [os.path.basename(llm) for llm in list_llm]
 
 # Load PDF document and create doc splits
```
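For context, `list_llm_simple` strips the organization prefix from each repo id with `os.path.basename`, presumably to label the model choices in the UI. A minimal sketch of what it evaluates to:

```python
import os

list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.3"]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]

print(list_llm_simple)
# ['Meta-Llama-3-8B-Instruct', 'Mistral-7B-Instruct-v0.3']
```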
```diff
@@ -64,19 +64,32 @@ def create_db(splits, collection_name, db_type):
     return vectordb
 
 # Initialize langchain LLM chain
-def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, initial_prompt, progress=gr.Progress()):
+# Initialize langchain LLM chain
+def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
     progress(0.1, desc="Initializing HF tokenizer...")
 
     progress(0.5, desc="Initializing HF Hub...")
 
-    llm = HuggingFaceEndpoint(
-        repo_id=llm_model,
-        huggingfacehub_api_token=api_token,
-        temperature=temperature,
-        max_new_tokens=max_tokens,
-        top_k=top_k,
-    )
-
+
+    if llm_model == "meta-llama/Meta-Llama-3-8B-Instruct":
+        llm = HuggingFaceEndpoint(
+            repo_id=llm_model,
+            huggingfacehub_api_token=api_token,
+            temperature=temperature,
+            max_new_tokens=max_tokens,
+            top_k=top_k,
+        )
+
+    else:
+
+        llm = HuggingFaceEndpoint(
+            repo_id=llm_model,
+            huggingfacehub_api_token=api_token,
+            temperature=temperature,
+            max_new_tokens=max_tokens,
+            top_k=top_k,
+        )
+
     progress(0.75, desc="Defining buffer memory...")
     memory = ConversationBufferMemory(
         memory_key="chat_history",
```
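Both branches of the new `if`/`else` construct an identical `HuggingFaceEndpoint`, so the split only matters if the two cases are meant to diverge later (for example, different sampling settings per model). Note also that `initial_prompt` is gone from the signature, so any call site has to drop that argument. A hypothetical invocation under the new signature; the values and the `vector_db` variable are illustrative, not taken from the app:

```python
# Hypothetical call site: vector_db is assumed to come from create_db(...)
qa_chain = initialize_llmchain(
    llm_model="mistralai/Mistral-7B-Instruct-v0.3",
    temperature=0.7,
    max_tokens=1024,
    top_k=3,
    vector_db=vector_db,
)
```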
```diff
@@ -88,12 +101,11 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, initial_prompt, progress=gr.Progress()):
     qa_chain = ConversationalRetrievalChain.from_llm(
         llm,
         retriever=retriever,
-        chain_type="stuff",
+        chain_type="stuff",
         memory=memory,
         return_source_documents=True,
         verbose=False,
     )
-    qa_chain({"question": initial_prompt}) # Initialize with the initial prompt
     progress(0.9, desc="Done!")
     return qa_chain
 
```
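The warm-up call `qa_chain({"question": initial_prompt})` is also removed, so the chain is now first invoked when the user actually asks a question; `chain_type="stuff"` keeps the combine step that stuffs all retrieved chunks into a single prompt. A minimal sketch of that query path, assuming the same dict convention the removed line used:

```python
# Illustrative question; with return_source_documents=True the result dict
# carries both the generated answer and the retrieved chunks.
response = qa_chain({"question": "What is this document about?"})
print(response["answer"])
for doc in response["source_documents"]:
    print(doc.metadata)  # e.g. source file / page of each chunk
```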
 
 