Spaces:
Running
Running
File size: 6,287 Bytes
b1f1fd7 5d00a88 b1f1fd7 5d00a88 c9d8253 5d00a88 b1f1fd7 c9d8253 b1f1fd7 5d00a88 b1f1fd7 5d00a88 f04d15c 5d00a88 c9d8253 f04d15c b1f1fd7 f04d15c c9d8253 f04d15c b1f1fd7 5d00a88 f04d15c b1f1fd7 f04d15c b1f1fd7 5d00a88 c9d8253 5d00a88 b1f1fd7 5d00a88 f04d15c 5d00a88 b1f1fd7 c9d8253 f04d15c b1f1fd7 c9d8253 b1f1fd7 f04d15c 2b145d6 b1f1fd7 5d00a88 b1f1fd7 2b145d6 b1f1fd7 f04d15c c9d8253 f04d15c b1f1fd7 c9d8253 b1f1fd7 f04d15c 5d00a88 b1f1fd7 f04d15c b1f1fd7 5d00a88 b1f1fd7 f04d15c b1f1fd7 e6c813c c9d8253 b1f1fd7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
import json
import argparse
from pathlib import Path
from typing import List
import gradio as gr
import faiss
import numpy as np
import torch
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
# Markdown help text passed to gr.Markdown in main(); it describes the expected
# upload format. The embedded sample must itself be valid JSON (no trailing
# comma after the last element) so that users can copy it as a template.
file_example = """Please upload a JSON file with a "text" field (with optional "title" field). For example
```JSON
[
{"title": "", "text": "This is an example text without the title"},
{"title": "Title A", "text": "This is an example text with the title"},
{"title": "Title B", "text": "This is an example text with the title"}
]
```
Due to limited computation resources, please test with small scale data (<1000).
"""
def create_index(embeddings, use_gpu):
    """Build a flat inner-product FAISS index and add ``embeddings`` to it.

    Args:
        embeddings: Non-empty sequence of equal-length vectors; the length of
            the first vector is used as the index dimensionality.
        use_gpu: When true, the index is cloned to all GPUs via
            ``faiss.index_cpu_to_all_gpus`` before the vectors are added.

    Returns:
        The populated FAISS index (CPU or GPU-cloned, depending on ``use_gpu``).
    """
    dimension = len(embeddings[0])
    cpu_index = faiss.IndexFlatIP(dimension)
    vectors = np.asarray(embeddings, dtype=np.float32)

    if not use_gpu:
        cpu_index.add(vectors)
        return cpu_index

    # GPU path: cloner options with the shard and useFloat16 flags enabled.
    cloner_options = faiss.GpuMultipleClonerOptions()
    cloner_options.shard = True
    cloner_options.useFloat16 = True
    gpu_index = faiss.index_cpu_to_all_gpus(cpu_index, co=cloner_options)
    gpu_index.add(vectors)
    return gpu_index
def upload_file_fn(
    file_path: str,
    progress: gr.Progress = gr.Progress(track_tqdm=True)
):
    """Read an uploaded JSON document file, embed it and build a search index.

    Args:
        file_path: Path of the uploaded JSON file. The file is expected to
            hold a list of objects with a "text" field and an optional
            "title" field.
        progress: Gradio progress handle, constructed with ``track_tqdm=True``.
            It is not referenced in the body.

    Returns:
        A ``(document_state, update)`` pair. On success ``document_state`` is
        a dict with the keys "document_data" (the loaded records, cut to at
        most 1000) and "document_index" (the index from ``create_index``), and
        ``update`` is ``gr.update(interactive=True)``. On any failure the pair
        is ``(None, gr.update(interactive=False))``.
    """
    max_documents = 1000
    min_documents = 5  # retrieval always fills five result slots

    # The broad ``except`` is deliberate: this is the UI boundary, and every
    # kind of malformed upload should become a visible message, not a crash.
    try:
        # JSON is UTF-8 by specification; do not rely on the locale default.
        with open(file_path, encoding="utf-8") as f:
            document_data = json.load(f)
        gr.Info(f"Upload {len(document_data)} documents.")
        if len(document_data) > max_documents:
            gr.Info("Cut uploaded documents to 1000 due to the computation resource.")
            document_data = document_data[:max_documents]

        documents = []
        for obj in document_data:
            text = obj["title"] + "\n" + obj["text"] if obj.get("title") else obj["text"]
            if len(str(text).strip()):
                documents.append(text)
            else:
                # Keep one entry per record so index positions stay aligned
                # with ``document_data``; blank records get the EOS token.
                documents.append(model.tokenizer.eos_token)
    except Exception as e:
        print(e)
        # gr.Error is an exception type and shows nothing unless raised;
        # raising would skip the return below, so use gr.Warning instead.
        gr.Warning("Read the file failed. Please check the data format.")
        gr.Warning(str(e))
        return None, gr.update(interactive=False)

    if len(documents) < min_documents:
        gr.Warning(f"Please upload at least {min_documents} documents.")
        return None, gr.update(interactive=False)

    # Encode in small batches inside a tqdm loop rather than in one call.
    documents_embeddings = []
    batch_size = 16
    for i in tqdm(range(0, len(documents), batch_size)):
        batch_documents = documents[i: i + batch_size]
        batch_embeddings = model.encode(batch_documents, show_progress_bar=True)
        documents_embeddings.extend(batch_embeddings)

    document_index = create_index(documents_embeddings, use_gpu=False)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    document_state = {"document_data": document_data, "document_index": document_index}
    return document_state, gr.update(interactive=True)
def clear_file_fn():
    """Handle removal of the uploaded file.

    Returns:
        A ``(document_state, update)`` pair: ``None`` to drop any stored
        document state, and a Gradio update with ``interactive=True``.
    """
    cleared_state = None
    query_box_update = gr.update(interactive=True)
    return cleared_state, query_box_update
def retrieve_document_fn(question, document_states, instruct):
    """Return the texts of the top-ranked documents for a query.

    Args:
        question: The user query.
        document_states: ``None`` when nothing has been uploaded, otherwise a
            dict with the keys "document_data" and "document_index".
        instruct: Instruction prefix that is concatenated in front of the
            query before encoding.

    Returns:
        A tuple of ``num_retrieval_doc`` result slots followed by the document
        state. Slots hold the "text" field of the retrieved records, best
        match first; unused slots are ``None``. When no documents are loaded
        every slot and the state are ``None``.
    """
    num_retrieval_doc = 5
    empty_slots = [None] * num_retrieval_doc

    if document_states is None:
        gr.Warning("Please upload documents first!")
        return (*empty_slots, None)

    document_data = document_states["document_data"]
    document_index = document_states["document_index"]

    question_with_inst = str(instruct) + str(question)
    if not question_with_inst.strip():
        gr.Warning("Please enter a non-empty query.")
        return (*empty_slots, document_states)

    question_embedding = model.encode([question_with_inst])
    # Only the displayed hits are requested; asking for more would just be
    # sliced away again.
    top_k = min(len(document_data), num_retrieval_doc)
    _, batch_inxs = document_index.search(question_embedding, k=top_k)

    # Skip negative ids: FAISS uses -1 as a "no result" placeholder, and a
    # negative list index would silently pick a record from the end.
    answers = [document_data[i]["text"] for i in batch_inxs[0] if i >= 0]
    # Pad so the number of returned values always matches the output slots.
    answers += [None] * (num_retrieval_doc - len(answers))
    return (*answers, document_states)
def main(args):
    """Load the embedding model, build the Gradio demo and launch it.

    Args:
        args: Parsed command-line namespace; ``model_name_or_path`` and
            ``revision`` are read from it.

    Side effects:
        Binds the module-level ``model`` global, which the upload and
        retrieval handlers read.
    """
    global model
    model = SentenceTransformer(args.model_name_or_path, revision=args.revision)

    # NOTE(review): this State is created before the Blocks context is
    # entered and is then passed to gr.Interface as an input. Keep it here
    # unless the Gradio rendering rules have been checked.
    document_state = gr.State()

    app_dir = Path(__file__).parent
    with open(app_dir / "resources/head.html") as html_file:
        head = html_file.read().strip()

    theme = gr.themes.Soft(font="sans-serif").set(
        background_fill_primary="linear-gradient(90deg, #e3ffe7 0%, #d9e7ff 100%)",
        background_fill_primary_dark="linear-gradient(90deg, #4b6cb7 0%, #182848 100%)",
    )

    with gr.Blocks(
        theme=theme,
        head=head,
        css=app_dir / "resources/styles.css",
        title="KaLM-Embedding",
        fill_height=True,
        analytics_enabled=False,
    ) as demo:
        gr.Markdown(file_example)
        doc_files_box = gr.File(label="Upload Documents", file_types=[".json"], file_count="single")

        # Single-choice, non-interactive radio that only displays the model id.
        model_id = "HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5"
        gr.Radio([model_id], value=model_id, label="Model Selection", interactive=False)

        # Components are built in the same order as before: query input,
        # five result boxes plus a State, then the instruction box.
        interface_inputs = [gr.Textbox(label="Query"), document_state]
        interface_outputs = [gr.Text(label=f"Recall-{rank}") for rank in range(1, 6)]
        interface_outputs.append(gr.State())
        instruct_box = gr.Textbox(
            "Instruct: Given a query, retrieve documents that answer the query. \n Query: ",
            label="Instruct of Query",
            lines=2,
        )
        retrieval_interface = gr.Interface(
            fn=retrieve_document_fn,
            inputs=interface_inputs,
            outputs=interface_outputs,
            additional_inputs=[instruct_box],
            concurrency_limit=1,
            allow_flagging="never",
        )

        # Both file events write the document state and update the query box.
        query_box = retrieval_interface.input_components[0]
        doc_files_box.upload(
            upload_file_fn,
            [doc_files_box],
            [document_state, query_box],
            queue=True,
            trigger_mode="once",
        )
        doc_files_box.clear(
            clear_file_fn,
            None,
            [document_state, query_box],
            queue=True,
            trigger_mode="once",
        )

    demo.launch()
def build_arg_parser():
    """Return the command-line parser for the demo.

    Options:
        --model_name_or_path: model identifier handed to SentenceTransformer.
        --revision: optional model revision; defaults to ``None``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5",
        help="Model name or local path passed to SentenceTransformer.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        help="Optional model revision passed to SentenceTransformer.",
    )
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    main(args)