Fecalisboa committed on
Commit
802f2b8
·
verified ·
1 Parent(s): 90d40e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -1
app.py CHANGED
@@ -15,6 +15,7 @@ from langchain.chains import ConversationChain
15
  from langchain.memory import ConversationBufferMemory
16
  from langchain_community.llms import HuggingFaceEndpoint
17
  import torch
 
18
  api_token = os.getenv("HF_TOKEN")
19
 
20
  list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.3"]
@@ -155,6 +156,33 @@ def upload_file(file_obj):
155
  list_file_path.append(file.name)
156
  return list_file_path
157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  def demo():
159
  with gr.Blocks(theme="base") as demo:
160
  vector_db = gr.State()
@@ -257,7 +285,7 @@ def demo():
257
  set_prompt_btn.click(lambda prompt: gr.update(value=prompt),
258
  inputs=prompt_input,
259
  outputs=initial_prompt)
260
- qachain_btn.click(initialize_LLM,
261
  inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db, initial_prompt],
262
  outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0],
263
  inputs=None,
 
15
  from langchain.memory import ConversationBufferMemory
16
  from langchain_community.llms import HuggingFaceEndpoint
17
  import torch
18
+
19
  api_token = os.getenv("HF_TOKEN")
20
 
21
  list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.3"]
 
156
  list_file_path.append(file.name)
157
  return list_file_path
158
 
159
def initialize_database(list_file_obj, chunk_size, chunk_overlap, db_type, progress=gr.Progress()):
    """Build a vector database from the uploaded files.

    Collects the temp-file paths from the Gradio upload objects, derives a
    collection name from the first file, splits the documents into chunks,
    and creates the vector store, reporting progress along the way.

    Returns a (vector_db, collection_name, status_message) tuple.
    """
    # Gradio hands us file objects; keep only the ones that were actually set.
    file_paths = [f.name for f in list_file_obj if f is not None]
    progress(0.1, desc="Creating collection name...")
    # NOTE(review): assumes at least one file was uploaded — file_paths[0]
    # raises IndexError otherwise; presumably the UI enforces this upstream.
    collection = create_collection_name(file_paths[0])
    progress(0.25, desc="Loading document...")
    splits = load_doc(file_paths, chunk_size, chunk_overlap)
    progress(0.5, desc="Generating vector database...")
    db = create_db(splits, collection, db_type)
    progress(0.9, desc="Done!")
    return db, collection, "Complete!"
169
+
170
def create_collection_name(filepath):
    """Derive a vector-store-safe collection name from a file path.

    Slugifies the file's base name: spaces become hyphens, accented
    characters are transliterated to ASCII, and any remaining runs of
    non-alphanumeric characters collapse to a single hyphen. The result
    is then truncated to 50 characters, padded to at least 3, and forced
    to start and end with an alphanumeric character (ChromaDB-style
    naming constraints).
    """
    name = unidecode(Path(filepath).stem.replace(" ", "-"))
    name = re.sub('[^A-Za-z0-9]+', '-', name)[:50]
    # Pad very short (or empty) results up to the 3-character minimum.
    if len(name) < 3:
        name += 'xyz'
    # Force alphanumeric first/last characters by substitution.
    if not name[0].isalnum():
        name = 'A' + name[1:]
    if not name[-1].isalnum():
        name = name[:-1] + 'Z'
    print('Filepath: ', filepath)
    print('Collection name: ', name)
    return name
185
+
186
  def demo():
187
  with gr.Blocks(theme="base") as demo:
188
  vector_db = gr.State()
 
285
  set_prompt_btn.click(lambda prompt: gr.update(value=prompt),
286
  inputs=prompt_input,
287
  outputs=initial_prompt)
288
+ qachain_btn.click(initialize_llmchain,
289
  inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db, initial_prompt],
290
  outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0],
291
  inputs=None,