Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -15,6 +15,7 @@ from langchain.chains import ConversationChain
|
|
15 |
from langchain.memory import ConversationBufferMemory
|
16 |
from langchain_community.llms import HuggingFaceEndpoint
|
17 |
import torch
|
|
|
18 |
api_token = os.getenv("HF_TOKEN")
|
19 |
|
20 |
list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.3"]
|
@@ -155,6 +156,33 @@ def upload_file(file_obj):
|
|
155 |
list_file_path.append(file.name)
|
156 |
return list_file_path
|
157 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
def demo():
|
159 |
with gr.Blocks(theme="base") as demo:
|
160 |
vector_db = gr.State()
|
@@ -257,7 +285,7 @@ def demo():
|
|
257 |
set_prompt_btn.click(lambda prompt: gr.update(value=prompt),
|
258 |
inputs=prompt_input,
|
259 |
outputs=initial_prompt)
|
260 |
-
qachain_btn.click(
|
261 |
inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db, initial_prompt],
|
262 |
outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0],
|
263 |
inputs=None,
|
|
|
15 |
from langchain.memory import ConversationBufferMemory
|
16 |
from langchain_community.llms import HuggingFaceEndpoint
|
17 |
import torch
|
18 |
+
|
19 |
api_token = os.getenv("HF_TOKEN")
|
20 |
|
21 |
list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.3"]
|
|
|
156 |
list_file_path.append(file.name)
|
157 |
return list_file_path
|
158 |
|
159 |
+
def initialize_database(list_file_obj, chunk_size, chunk_overlap, db_type, progress=gr.Progress()):
    """Build a vector database from the uploaded documents.

    Steps, with Gradio progress updates along the way: derive a collection
    name from the first file, split the documents into chunks, and index
    the chunks into a vector store of the requested type.

    Args:
        list_file_obj: Uploaded file objects (``None`` entries are skipped).
        chunk_size: Character size of each document split.
        chunk_overlap: Overlap in characters between consecutive splits.
        db_type: Backend identifier forwarded to ``create_db``.
        progress: Gradio progress tracker (injected by Gradio at call time).

    Returns:
        Tuple of (vector database, collection name, status message).
    """
    file_paths = [f.name for f in list_file_obj if f is not None]
    progress(0.1, desc="Creating collection name...")
    collection = create_collection_name(file_paths[0])
    progress(0.25, desc="Loading document...")
    splits = load_doc(file_paths, chunk_size, chunk_overlap)
    progress(0.5, desc="Generating vector database...")
    db = create_db(splits, collection, db_type)
    progress(0.9, desc="Done!")
    return db, collection, "Complete!"
|
169 |
+
|
170 |
+
def create_collection_name(filepath):
    """Derive a vector-store-safe collection name from a file path.

    Takes the file's stem, transliterates it to ASCII, collapses every run
    of non-alphanumeric characters into a single hyphen, and clamps it to
    50 characters. The result is then patched to satisfy typical collection
    naming rules: at least 3 characters long, and starting/ending with an
    alphanumeric character.

    Args:
        filepath: Path to the source document.

    Returns:
        The sanitized collection name string.
    """
    name = Path(filepath).stem
    # Transliterate to plain ASCII after turning spaces into hyphens.
    name = unidecode(name.replace(" ", "-"))
    # Collapse any remaining illegal runs into '-' and cap the length.
    name = re.sub('[^A-Za-z0-9]+', '-', name)[:50]
    # Pad names that are too short to be accepted.
    if len(name) < 3:
        name = name + 'xyz'
    # Force alphanumeric first and last characters.
    if not name[0].isalnum():
        name = 'A' + name[1:]
    if not name[-1].isalnum():
        name = name[:-1] + 'Z'
    print('Filepath: ', filepath)
    print('Collection name: ', name)
    return name
|
185 |
+
|
186 |
def demo():
|
187 |
with gr.Blocks(theme="base") as demo:
|
188 |
vector_db = gr.State()
|
|
|
285 |
set_prompt_btn.click(lambda prompt: gr.update(value=prompt),
|
286 |
inputs=prompt_input,
|
287 |
outputs=initial_prompt)
|
288 |
+
qachain_btn.click(initialize_llmchain,
|
289 |
inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db, initial_prompt],
|
290 |
outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0],
|
291 |
inputs=None,
|