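"""Gradio demo for KaLM-Embedding: upload a JSON corpus, embed it with a
SentenceTransformer model, index the embeddings with FAISS, and retrieve the
top-5 documents for a query."""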
import json
import argparse
from pathlib import Path

import gradio as gr
import faiss
import numpy as np
import torch
from tqdm import tqdm
from sentence_transformers import SentenceTransformer


file_example = """Please upload a JSON file with a "text" field (and an optional "title" field). For example:
```JSON
[
    {"title": "", "text": "This is an example text without a title"},
    {"title": "Title A", "text": "This is an example text with a title"},
    {"title": "Title B", "text": "This is an example text with a title"},
]
```
Due to limited computation resources, please test with small-scale data (<1000 documents).
"""


def create_index(embeddings, use_gpu):
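    """Build a FAISS inner-product index over the document embeddings.

    Note: IndexFlatIP ranks by raw inner product, which matches cosine
    similarity only if the embeddings are L2-normalized. With use_gpu=True the
    index is sharded across all visible GPUs.
    """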
    index = faiss.IndexFlatIP(len(embeddings[0]))
    embeddings = np.asarray(embeddings, dtype=np.float32)
    if use_gpu:
        co = faiss.GpuMultipleClonerOptions()
        co.shard = True
        co.useFloat16 = True
        index = faiss.index_cpu_to_all_gpus(index, co=co)
    index.add(embeddings)
    return index


def upload_file_fn(
        file_path: str,
        progress: gr.Progress = gr.Progress(track_tqdm=True)
    ):
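    """Load the uploaded JSON file, embed up to 1000 documents, and build a
    FAISS index over them.

    Returns a state dict holding the raw documents and the index, plus an
    update enabling the query box (or None and a disabling update on failure).
    """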
    try:
        with open(file_path) as f:
            document_data = json.load(f)
        
        gr.Info(f"Uploaded {len(document_data)} documents.")
        if len(document_data) > 1000:
            gr.Info("Truncating to the first 1000 documents due to limited computation resources.")
            document_data = document_data[:1000]

        documents = []
        for obj in document_data:
            text = obj["title"] + "\n" + obj["text"] if obj.get("title") else obj["text"]
            if len(str(text).strip()):
                documents.append(text)
            else:
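                # Empty texts get the tokenizer's EOS token as a placeholder so
                # the embedding list stays aligned with document_data.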
                documents.append(model.tokenizer.eos_token)
    except Exception as e:
        print(e)
        # gr.Error must be raised to be displayed; gr.Warning shows a toast and
        # lets us still return and disable the query box.
        gr.Warning("Failed to read the file. Please check the data format.")
        gr.Warning(str(e))
        return None, gr.update(interactive=False)

    if len(documents) < 5:
        gr.Warning("Please upload at least 5 documents.")
        return None, gr.update(interactive=False)
    
    # Encode in small batches so the tqdm progress bar (tracked by gr.Progress)
    # updates as documents are embedded.
    documents_embeddings = []
    batch_size = 16
    for i in tqdm(range(0, len(documents), batch_size)):
        batch_documents = documents[i: i + batch_size]
        batch_embeddings = model.encode(batch_documents, show_progress_bar=False)
        documents_embeddings.extend(batch_embeddings)

    document_index = create_index(documents_embeddings, use_gpu=False)
    
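    # Free cached GPU memory held after the encoding pass (no-op on CPU-only hosts).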
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    
    document_state = {"document_data": document_data, "document_index": document_index}
    return document_state, gr.update(interactive=True)


def clear_file_fn():
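    """Reset the document state when the uploaded file is cleared; the query
    box stays interactive so retrieve_document_fn can show its warning."""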
    return None, gr.update(interactive=True)


def retrieve_document_fn(question, document_states, instruct):
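    """Embed the instruction-prefixed query and return the texts of the top-5
    documents from the FAISS index, plus the unchanged document state."""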
    num_retrieval_doc = 5

    if document_states is None:
        gr.Warning("Please upload documents first!")
        return None, None, None, None, None, document_states

    document_data, document_index = document_states["document_data"], document_states["document_index"]
    
    question_with_inst = str(instruct) + str(question)
    if len(question_with_inst.strip()) == 0:
        gr.Warning("Please enter a non-empty query.")
        return None, None, None, None, None, document_states
    
    question_embedding = model.encode([question_with_inst])
    # Retrieve a generous candidate pool (up to 150), then keep the top-5 texts.
    batch_scores, batch_indices = document_index.search(question_embedding, k=min(len(document_data), 150))

    answers = [document_data[i]["text"] for i in batch_indices[0][:num_retrieval_doc]]
    return answers[0], answers[1], answers[2], answers[3], answers[4], document_states


def main(args):
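    # The model is kept in a module-level global so the Gradio callbacks
    # defined above can access it.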
    global model

    model = SentenceTransformer(
        args.model_name_or_path, 
        revision=args.revision,
        )

    document_state = gr.State()

    with open(Path(__file__).parent / "resources/head.html") as html_file:
        head = html_file.read().strip()
    with gr.Blocks(theme=gr.themes.Soft(font="sans-serif").set(
                       background_fill_primary="linear-gradient(90deg, #e3ffe7 0%, #d9e7ff 100%)",
                       background_fill_primary_dark="linear-gradient(90deg, #4b6cb7 0%, #182848 100%)",
                   ),
                   head=head,
                   css=Path(__file__).parent / "resources/styles.css",
                   title="KaLM-Embedding",
                   fill_height=True,
                   analytics_enabled=False) as demo:
        gr.Markdown(file_example)
        doc_files_box = gr.File(label="Upload Documents", file_types=[".json"], file_count="single")
        model_selection = gr.Radio(["HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5"], value="HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5", label="Model Selection", interactive=False)
        retrieval_interface = gr.Interface(
            fn=retrieve_document_fn,
            inputs=[gr.Textbox(label="Query"), document_state],
            outputs=[gr.Text(label="Recall-1"), gr.Text(label="Recall-2"), gr.Text(label="Recall-3"), gr.Text(label="Recall-4"), gr.Text(label="Recall-5"), gr.State()],
            additional_inputs=[gr.Textbox("Instruct: Given a query, retrieve documents that answer the query. \n Query: ", label="Instruct of Query", lines=2)],
            concurrency_limit=1,
            allow_flagging="never",
        )
        # retrieval_interface.input_components[0] = gr.update(interactive=False)
        

        doc_files_box.upload(
            upload_file_fn, 
            [doc_files_box], 
            [document_state, retrieval_interface.input_components[0]],
            queue=True,
            trigger_mode="once"
        )
        doc_files_box.clear(
            clear_file_fn,
            None,
            [document_state, retrieval_interface.input_components[0]],
            queue=True,
            trigger_mode="once"
        )
    demo.launch()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str, default="HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5")
    parser.add_argument("--revision", type=str, default=None)

    args = parser.parse_args()
    main(args)
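
# Example launch (the script filename is assumed here):
#   python app.py --model_name_or_path HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5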