YanshekWoo committed on
Commit
5d00a88
·
verified ·
1 Parent(s): 5dc11b3

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +44 -23
app.py CHANGED
@@ -7,6 +7,7 @@ import gradio as gr
7
  import faiss
8
  import numpy as np
9
  import torch
 
10
  from sentence_transformers import SentenceTransformer
11
 
12
 
@@ -17,7 +18,9 @@ file_example = """Please upload a JSON file with a "text" field (with optional "
17
  {"title": "Title A", "text": "This an example text with the title"},
18
  {"title": "Title B", "text": "This an example text with the title"},
19
  ]
20
- ```"""
 
 
21
 
22
 
23
  def create_index(embeddings, use_gpu):
@@ -42,13 +45,26 @@ def upload_file_fn(
42
  documents = []
43
  for obj in document_data:
44
  text = obj["title"] + "\n" + obj["text"] if obj.get("title") else obj["text"]
45
- documents.append(text)
 
 
 
46
  except Exception as e:
47
  print(e)
48
- gr.Warning("Read the file failed. Please check the data format.")
49
- return None, None
 
 
 
 
 
 
 
 
 
 
50
 
51
- documents_embeddings = model.encode(documents)
52
 
53
  document_index = create_index(documents_embeddings, use_gpu=False)
54
 
@@ -56,25 +72,30 @@ def upload_file_fn(
56
  torch.cuda.empty_cache()
57
  torch.cuda.ipc_collect()
58
 
59
- return document_index, document_data
 
 
60
 
61
 
62
  def clear_file_fn():
63
- return None, None
64
 
65
 
66
- def retrieve_document_fn(question, instruct, document_states):
67
- document_data, document_index = document_states
68
  num_retrieval_doc = 3
69
- if document_index is None or document_data is None:
 
70
  gr.Warning("Please upload documents first!")
71
- return [None for i in range(num_retrieval_doc)]
 
 
 
72
 
73
- question_embedding = model.encode([instruct + question])
74
  batch_scores, batch_inxs = document_index.search(question_embedding, k=num_retrieval_doc)
75
 
76
  answers = [document_data[i]["text"] for i in batch_inxs[0][:num_retrieval_doc]]
77
- return tuple(answers)
78
 
79
 
80
  def main(args):
@@ -82,10 +103,8 @@ def main(args):
82
 
83
  model = SentenceTransformer(args.model_name_or_path)
84
 
85
- document_index = gr.State()
86
- document_data = gr.State()
87
 
88
-
89
  with open(Path(__file__).parent / "resources/head.html") as html_file:
90
  head = html_file.read().strip()
91
  with gr.Blocks(theme=gr.themes.Soft(font="sans-serif").set(background_fill_primary="linear-gradient(90deg, #e3ffe7 0%, #d9e7ff 100%)", background_fill_primary_dark="linear-gradient(90deg, #4b6cb7 0%, #182848 100%)",),
@@ -98,23 +117,24 @@ def main(args):
98
  doc_files_box = gr.File(label="Upload Documents", file_types=[".json"], file_count="single")
99
  retrieval_interface = gr.Interface(
100
  fn=retrieve_document_fn,
101
- inputs=["text"],
102
- outputs=["text", "text", "text"],
103
- additional_inputs=[gr.Textbox("Instruct: Given a query, retrieve documents that answer the query. \n Query: ", label="Instruct of Query"), gr.State(value=[document_data, document_index])],
104
  concurrency_limit=1,
105
  )
 
106
 
107
  doc_files_box.upload(
108
  upload_file_fn,
109
  [doc_files_box],
110
- [document_index, document_data],
111
  queue=True,
112
  trigger_mode="once"
113
  )
114
  doc_files_box.clear(
115
- upload_file_fn,
116
  None,
117
- [document_index, document_data],
118
  queue=True,
119
  trigger_mode="once"
120
  )
@@ -123,7 +143,8 @@ def main(args):
123
 
124
  if __name__ == "__main__":
125
  parser = argparse.ArgumentParser()
126
- parser.add_argument("--model_name_or_path", type=str, default="HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5")
 
127
 
128
  args = parser.parse_args()
129
  main(args)
 
7
  import faiss
8
  import numpy as np
9
  import torch
10
+ from tqdm import tqdm
11
  from sentence_transformers import SentenceTransformer
12
 
13
 
 
18
  {"title": "Title A", "text": "This an example text with the title"},
19
  {"title": "Title B", "text": "This an example text with the title"},
20
  ]
21
+ ```
22
+ Due to the computation resources, please test with small scale data.
23
+ """
24
 
25
 
26
  def create_index(embeddings, use_gpu):
 
45
  documents = []
46
  for obj in document_data:
47
  text = obj["title"] + "\n" + obj["text"] if obj.get("title") else obj["text"]
48
+ if len(str(text).strip()):
49
+ documents.append(text)
50
+ else:
51
+ documents.append(model.tokenizer.eos_token)
52
  except Exception as e:
53
  print(e)
54
+ gr.Error("Read the file failed. Please check the data format.")
55
+ gr.Error(str(e))
56
+ return None
57
+
58
+ if len(documents) < 3:
59
+ gr.Error("Please upload more than 3 documents.")
60
+ return None
61
+
62
+ gr.Info(f"Upload {len(documents)} documents.")
63
+ if len(documents) > 2000:
64
+ gr.Info(f"Cut uploaded documents to 2000.")
65
+ documents = documents[: 2000]
66
 
67
+ documents_embeddings = model.encode(documents, show_progress_bar=True)
68
 
69
  document_index = create_index(documents_embeddings, use_gpu=False)
70
 
 
72
  torch.cuda.empty_cache()
73
  torch.cuda.ipc_collect()
74
 
75
+ print("upload is OK")
76
+ document_state = {"document_data": document_data, "document_index": document_index}
77
+ return document_state,
78
 
79
 
80
  def clear_file_fn():
81
+ return None
82
 
83
 
84
+ def retrieve_document_fn(question, document_states, instruct):
 
85
  num_retrieval_doc = 3
86
+
87
+ if document_states is None:
88
  gr.Warning("Please upload documents first!")
89
+ return [None for i in range(num_retrieval_doc)] + [None]
90
+
91
+ print(document_states)
92
+ document_data, document_index = document_states["document_data"], document_states["document_index"]
93
 
94
+ question_embedding = model.encode([str(instruct) + str(question)])
95
  batch_scores, batch_inxs = document_index.search(question_embedding, k=num_retrieval_doc)
96
 
97
  answers = [document_data[i]["text"] for i in batch_inxs[0][:num_retrieval_doc]]
98
+ return answers[0], answers[1], answers[2], document_states
99
 
100
 
101
  def main(args):
 
103
 
104
  model = SentenceTransformer(args.model_name_or_path)
105
 
106
+ document_state = gr.State()
 
107
 
 
108
  with open(Path(__file__).parent / "resources/head.html") as html_file:
109
  head = html_file.read().strip()
110
  with gr.Blocks(theme=gr.themes.Soft(font="sans-serif").set(background_fill_primary="linear-gradient(90deg, #e3ffe7 0%, #d9e7ff 100%)", background_fill_primary_dark="linear-gradient(90deg, #4b6cb7 0%, #182848 100%)",),
 
117
  doc_files_box = gr.File(label="Upload Documents", file_types=[".json"], file_count="single")
118
  retrieval_interface = gr.Interface(
119
  fn=retrieve_document_fn,
120
+ inputs=["text", document_state],
121
+ outputs=[gr.Text(label="Recall-1"), gr.Text(label="Recall-2"), gr.Text(label="Recall-3"), gr.State()],
122
+ additional_inputs=[gr.Textbox("Instruct: Given a query, retrieve documents that answer the query. \n Query: ", label="Instruct of Query")],
123
  concurrency_limit=1,
124
  )
125
+
126
 
127
  doc_files_box.upload(
128
  upload_file_fn,
129
  [doc_files_box],
130
+ [document_state],
131
  queue=True,
132
  trigger_mode="once"
133
  )
134
  doc_files_box.clear(
135
+ clear_file_fn,
136
  None,
137
+ [document_state],
138
  queue=True,
139
  trigger_mode="once"
140
  )
 
143
 
144
  if __name__ == "__main__":
145
  parser = argparse.ArgumentParser()
146
+ # parser.add_argument("--model_name_or_path", type=str, default="HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5")
147
+ parser.add_argument("--model_name_or_path", type=str, default="/raid/hxs/Checkpoints/huggingface_models/bge-base-en-v1.5")
148
 
149
  args = parser.parse_args()
150
  main(args)