File size: 4,271 Bytes
b1f1fd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7feeeff
 
b1f1fd7
 
 
 
 
fb141c6
b1f1fd7
 
 
7feeeff
b1f1fd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7feeeff
 
b1f1fd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import json
import argparse
from pathlib import Path
from typing import List

import gradio as gr
import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer


file_example = """Please upload a JSON file with a "text" field (with optional "title" field). For example
```JSON
[
    {"title": "", "text": "This an example text without the title"},
    {"title": "Title A", "text": "This an example text with the title"},
    {"title": "Title B", "text": "This an example text with the title"},
]
```"""


def create_index(embeddings, use_gpu):
    """Build a flat inner-product faiss index over *embeddings*.

    Args:
        embeddings: 2-D sequence/array of document vectors (all the same length).
        use_gpu: if True, shard the index across all visible GPUs using
            float16 storage before adding the vectors.

    Returns:
        A populated ``faiss.IndexFlatIP`` (or its GPU-cloned counterpart).
    """
    dim = len(embeddings[0])
    index = faiss.IndexFlatIP(dim)
    vectors = np.asarray(embeddings, dtype=np.float32)
    if use_gpu:
        # Clone the CPU index onto every available GPU, sharded, in fp16.
        cloner_opts = faiss.GpuMultipleClonerOptions()
        cloner_opts.shard = True
        cloner_opts.useFloat16 = True
        index = faiss.index_cpu_to_all_gpus(index, co=cloner_opts)
    index.add(vectors)
    return index


def upload_file_fn(
        file_path: str,
        progress: gr.Progress = gr.Progress(track_tqdm=True)
    ):
    """Load an uploaded JSON document file, embed it, and build a faiss index.

    Args:
        file_path: path of the uploaded JSON file. The file must contain a
            list of objects with a "text" field and an optional "title" field.
            (Annotation fixed: the File component uses ``file_count="single"``,
            so a single path string is passed, not a list.)
        progress: Gradio progress tracker (follows tqdm output of encoding).

    Returns:
        ``(document_index, document_data)`` on success, where the index is a
        faiss index over the embedded texts; ``(None, None)`` if the file
        cannot be read or parsed (a warning is shown to the user).
    """
    try:
        with open(file_path, encoding="utf-8") as f:
            document_data = json.load(f)
        documents = []
        for obj in document_data:
            # Prepend the title (when present and non-empty) to the text.
            text = obj["title"] + "\n" + obj["text"] if obj.get("title") else obj["text"]
            documents.append(text)
    except Exception as e:
        print(e)
        gr.Warning("Read the file failed. Please check the data format.")
        return None, None

    documents_embeddings = model.encode(documents)

    document_index = create_index(documents_embeddings, use_gpu=False)

    # Free any GPU memory the embedding model allocated during encode().
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    return document_index, document_data


def clear_file_fn():
    """Reset both document states (index and data) when the upload is cleared."""
    return (None, None)


def retrieve_document_fn(question, instruct, document_states):
    """Embed the query and return the top-3 matching document texts.

    Args:
        question: the user's query string.
        instruct: instruction prefix prepended to the query before encoding.
        document_states: ``(document_data, document_index)`` pair produced by
            ``upload_file_fn`` — the raw JSON objects and the faiss index.

    Returns:
        A tuple of 3 document texts (``None`` entries when no documents are
        loaded or fewer than 3 documents exist).
    """
    document_data, document_index = document_states
    num_retrieval_doc = 3
    if document_index is None or document_data is None:
        gr.Warning("Please upload documents first!")
        return [None for _ in range(num_retrieval_doc)]

    question_embedding = model.encode([instruct + question])
    batch_scores, batch_inxs = document_index.search(question_embedding, k=num_retrieval_doc)

    # BUG FIX: faiss pads results with -1 when the index holds fewer than k
    # vectors; document_data[-1] would silently return the last document.
    answers = [document_data[i]["text"] if i >= 0 else None
               for i in batch_inxs[0][:num_retrieval_doc]]
    return tuple(answers)


def main(args):
    """Build and launch the Gradio retrieval demo.

    Loads the SentenceTransformer named by ``args.model_name_or_path`` into
    the module-global ``model`` (used by the upload/retrieve callbacks),
    wires up the Blocks UI, and starts the server.
    """
    global model

    model = SentenceTransformer(args.model_name_or_path)

    document_index = gr.State()
    document_data = gr.State()

    with open(Path(__file__).parent / "resources/head.html") as html_file:
        head = html_file.read().strip()
    with gr.Blocks(theme=gr.themes.Soft(font="sans-serif").set(background_fill_primary="linear-gradient(90deg, #e3ffe7 0%, #d9e7ff 100%)", background_fill_primary_dark="linear-gradient(90deg, #4b6cb7 0%, #182848 100%)",),
                   head=head, 
                   css=Path(__file__).parent / "resources/styles.css", 
                   title="KaLM-Embedding", 
                   fill_height=True, 
                   analytics_enabled=False) as demo:
        gr.Markdown(file_example)
        doc_files_box = gr.File(label="Upload Documents", file_types=[".json"], file_count="single")
        retrieval_interface = gr.Interface(
            fn=retrieve_document_fn,
            inputs=["text"],
            outputs=["text", "text",  "text"],
            additional_inputs=[gr.Textbox("Instruct: Given a query, retrieve documents that answer the query. \n Query: ", label="Instruct of Query"), [document_data, document_index]],
            concurrency_limit=1,
        )

        doc_files_box.upload(
            upload_file_fn, 
            [doc_files_box], 
            [document_index, document_data],
            queue=True,
            trigger_mode="once"
        )
        # BUG FIX: the clear event was wired to upload_file_fn with no inputs,
        # which raised a TypeError on every clear; it must reset the states.
        doc_files_box.clear(
                clear_file_fn, 
                None, 
                [document_index, document_data],
                queue=True,
                trigger_mode="once"
            )
    demo.launch()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str, default="HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5")

    args = parser.parse_args()
    main(args)