YanshekWoo committed
Commit f04d15c · verified · 1 Parent(s): e6c813c

Upload folder using huggingface_hub

Files changed (1):
  1. app.py +26 -17

app.py CHANGED
@@ -53,18 +53,24 @@ def upload_file_fn(
         print(e)
         gr.Error("Read the file failed. Please check the data format.")
         gr.Error(str(e))
-        return None
+        return None, gr.update(interactive=False)
 
     if len(documents) < 3:
         gr.Error("Please upload more than 3 documents.")
-        return None
+        return None, gr.update(interactive=False)
 
     gr.Info(f"Upload {len(documents)} documents.")
-    if len(documents) > 2000:
-        gr.Info(f"Cut uploaded documents to 2000.")
-        documents = documents[: 2000]
+    if len(documents) > 1000:
+        gr.Info(f"Cut uploaded documents to 1000 due to the computation resource.")
+        documents = documents[: 1000]
 
-    documents_embeddings = model.encode(documents, show_progress_bar=True)
+    # documents_embeddings = model.encode(documents, show_progress_bar=True)
+    documents_embeddings = []
+    batch_size = 8
+    for i in tqdm(range(0, len(documents), batch_size)):
+        batch_documents = documents[i: i+batch_size]
+        batch_embeddings = model.encode(batch_documents, show_progress_bar=True)
+        documents_embeddings.extend(batch_embeddings)
 
     document_index = create_index(documents_embeddings, use_gpu=False)
 
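Note on the encoding change above: the single model.encode(...) call is replaced by manual mini-batches of 8 so peak memory stays bounded by the batch size rather than the corpus size. A minimal standalone sketch of the same pattern, assuming a sentence-transformers model and a plain list of strings (encode_in_batches is an illustrative name, not from the app):

    from sentence_transformers import SentenceTransformer
    from tqdm import tqdm
    import numpy as np

    def encode_in_batches(model, texts, batch_size=8):
        # Peak memory now scales with batch_size, not len(texts).
        chunks = []
        for i in tqdm(range(0, len(texts), batch_size)):
            chunks.append(model.encode(texts[i:i + batch_size]))
        return np.vstack(chunks)  # shape: (len(texts), embedding_dim)

encode() also accepts its own batch_size argument, so the explicit loop mainly makes progress visible per chunk; note the new code relies on tqdm already being imported in app.py, which this diff does not show.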
@@ -72,13 +78,12 @@ def upload_file_fn(
     torch.cuda.empty_cache()
     torch.cuda.ipc_collect()
 
-    print("upload is OK")
     document_state = {"document_data": document_data, "document_index": document_index}
-    return document_state,
+    return document_state, gr.update(interactive=True)
 
 
 def clear_file_fn():
-    return None
+    return None, gr.update(interactive=True)
 
 
 def retrieve_document_fn(question, document_states, instruct):
@@ -87,12 +92,11 @@ def retrieve_document_fn(question, document_states, instruct):
     if document_states is None:
         gr.Warning("Please upload documents first!")
         return [None for i in range(num_retrieval_doc)] + [None]
-
-    print(document_states)
+
     document_data, document_index = document_states["document_data"], document_states["document_index"]
 
     question_embedding = model.encode([str(instruct) + str(question)])
-    batch_scores, batch_inxs = document_index.search(question_embedding, k=num_retrieval_doc)
+    batch_scores, batch_inxs = document_index.search(question_embedding, k=min(len(document_data), 150))
 
     answers = [document_data[i]["text"] for i in batch_inxs[0][:num_retrieval_doc]]
     return answers[0], answers[1], answers[2], document_states
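Note on the k change above: FAISS-style indexes pad their results with -1 when k exceeds the number of stored vectors, and a -1 index would break the document_data[i] lookup, so clamping k to the corpus size is the safe move. A small sketch of that behavior, assuming create_index wraps a flat FAISS index (faiss is an assumption; the helper's body is not part of this diff):

    import faiss
    import numpy as np

    dim = 8
    index = faiss.IndexFlatIP(dim)
    index.add(np.random.rand(3, dim).astype("float32"))  # only 3 vectors stored
    query = np.random.rand(1, dim).astype("float32")

    scores, inxs = index.search(query, k=5)                     # k > ntotal: inxs padded with -1
    scores, inxs = index.search(query, k=min(index.ntotal, 5))  # clamped: k == 3, no -1 entries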
@@ -101,7 +105,10 @@ def retrieve_document_fn(question, document_states, instruct):
 def main(args):
     global model
 
-    model = SentenceTransformer(args.model_name_or_path)
+    model = SentenceTransformer(
+        args.model_name_or_path,
+        revision=args.revision,
+        backend="openvino")
 
     document_state = gr.State()
 
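Note on the model load above: the backend argument landed in sentence-transformers 3.2, and backend="openvino" runs inference through OpenVINO instead of PyTorch; revision="refs/pr/2" pins a pull-request branch of the model repo, presumably the one carrying the OpenVINO export. A hedged sketch of the equivalent load (the version floor and install extra are assumptions):

    from sentence_transformers import SentenceTransformer

    # Assumes sentence-transformers >= 3.2 with the OpenVINO extra
    # (pip install "sentence-transformers[openvino]"); if the pinned
    # revision has no OpenVINO model, the library can export one on the fly.
    model = SentenceTransformer(
        "HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5",
        revision="refs/pr/2",
        backend="openvino",
    )
    print(model.encode(["hello world"]).shape)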
@@ -117,24 +124,25 @@ def main(args):
     doc_files_box = gr.File(label="Upload Documents", file_types=[".json"], file_count="single")
     retrieval_interface = gr.Interface(
         fn=retrieve_document_fn,
-        inputs=["text", document_state],
+        inputs=[gr.Textbox(label="Query"), document_state],
         outputs=[gr.Text(label="Recall-1"), gr.Text(label="Recall-2"), gr.Text(label="Recall-3"), gr.State()],
-        additional_inputs=[gr.Textbox("Instruct: Given a query, retrieve documents that answer the query. \n Query: ", label="Instruct of Query")],
+        additional_inputs=[gr.Textbox("Instruct: Given a query, retrieve documents that answer the query. \n Query: ", label="Instruct of Query", lines=2)],
         concurrency_limit=1,
     )
+    # retrieval_interface.input_components[0] = gr.update(interactive=False)
 
 
     doc_files_box.upload(
         upload_file_fn,
         [doc_files_box],
-        [document_state],
+        [document_state, retrieval_interface.input_components[0]],
         queue=True,
         trigger_mode="once"
     )
     doc_files_box.clear(
         clear_file_fn,
         None,
-        [document_state],
+        [document_state, retrieval_interface.input_components[0]],
         queue=True,
         trigger_mode="once"
     )
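Note on the event wiring above: upload_file_fn and clear_file_fn now return a second value, a gr.update(...), and the query textbox (retrieval_interface.input_components[0]) is listed as a second output, so a single event both stores the index in state and toggles whether the query box is interactive. A minimal sketch of the pattern outside this app (component names are illustrative):

    import gradio as gr

    def on_upload(path):
        # The second return value is applied to whichever component
        # is listed second in `outputs` (the textbox below).
        if path is None:
            return None, gr.update(interactive=False)
        return {"path": path}, gr.update(interactive=True)

    with gr.Blocks() as demo:
        state = gr.State()
        query = gr.Textbox(label="Query", interactive=True)
        files = gr.File(label="Upload")
        files.upload(on_upload, [files], [state, query])

    demo.launch()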
@@ -145,6 +153,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--model_name_or_path", type=str, default="HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5")
     # parser.add_argument("--model_name_or_path", type=str, default="/raid/hxs/Checkpoints/huggingface_models/bge-base-en-v1.5")
+    parser.add_argument("--revision", type=str, default="refs/pr/2")
 
     args = parser.parse_args()
     main(args)
 