KaLM-Embedding / app.py
YanshekWoo's picture
Upload folder using huggingface_hub
f04d15c verified
raw
history blame
5.87 kB
import json
import argparse
from pathlib import Path
from typing import List
import gradio as gr
import faiss
import numpy as np
import torch
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
# Upload instructions rendered at the top of the demo page via gr.Markdown.
# The embedded example must itself be valid JSON (no trailing comma), because
# users copy it and the upload handler parses files with json.load.
file_example = """Please upload a JSON file with a "text" field (with optional "title" field). For example
```JSON
[
  {"title": "", "text": "This is an example text without the title"},
  {"title": "Title A", "text": "This is an example text with the title"},
  {"title": "Title B", "text": "This is an example text with the title"}
]
```
Due to limited computation resources, please test with small-scale data.
"""
def create_index(embeddings, use_gpu):
    """Build a FAISS inner-product index holding every vector in ``embeddings``.

    Args:
        embeddings: Non-empty sequence of equal-length embedding vectors
            (e.g. a list of 1-D numpy arrays). Row ``i`` of the index is
            ``embeddings[i]``.
        use_gpu: When truthy, clone the index onto all visible GPUs, sharded
            across devices and stored in float16, before adding the vectors.

    Returns:
        The populated index: a CPU ``faiss.IndexFlatIP``, or its GPU clone.
    """
    dim = len(embeddings[0])
    vectors = np.asarray(embeddings, dtype=np.float32)
    flat_index = faiss.IndexFlatIP(dim)

    if not use_gpu:
        flat_index.add(vectors)
        return flat_index

    # Spread the vectors over every GPU and keep them in half precision to
    # reduce device memory use.
    cloner_opts = faiss.GpuMultipleClonerOptions()
    cloner_opts.shard = True
    cloner_opts.useFloat16 = True
    gpu_index = faiss.index_cpu_to_all_gpus(flat_index, co=cloner_opts)
    gpu_index.add(vectors)
    return gpu_index
def _read_documents(file_path):
    """Parse the uploaded JSON file into its records and the texts to embed.

    Args:
        file_path: Path of the uploaded ``.json`` file.

    Returns:
        ``(document_data, documents)``. ``document_data`` is the parsed JSON
        list. ``documents[i]`` is the text embedded for ``document_data[i]``:
        the title, a newline, then the text when a non-empty title is
        present, otherwise just the text.

    Raises:
        Exception: e.g. ``OSError``, ``ValueError`` (bad JSON or non-list top
        level), ``KeyError`` (missing "text"), ``AttributeError`` or
        ``TypeError`` when the file does not follow the expected format.
    """
    # JSON is UTF-8; "utf-8-sig" also accepts a leading BOM, which json.load
    # would otherwise reject.
    with open(file_path, encoding="utf-8-sig") as f:
        document_data = json.load(f)
    if not isinstance(document_data, list):
        raise ValueError('Expected a JSON list of objects with a "text" field.')

    documents = []
    for obj in document_data:
        text = obj["title"] + "\n" + obj["text"] if obj.get("title") else obj["text"]
        if len(str(text).strip()):
            documents.append(text)
        else:
            # Blank records still get an entry so documents[i] stays aligned
            # with document_data[i], i.e. with the FAISS row ids.
            documents.append(model.tokenizer.eos_token)
    return document_data, documents


def _embed_documents(documents, batch_size=8):
    """Encode ``documents`` with the global ``model``, ``batch_size`` at a time.

    The outer ``tqdm`` loop is what ``gr.Progress(track_tqdm=True)`` mirrors
    in the UI.

    Returns:
        A list with one embedding per document, in input order.
    """
    embeddings = []
    for start in tqdm(range(0, len(documents), batch_size)):
        batch = documents[start:start + batch_size]
        embeddings.extend(model.encode(batch, show_progress_bar=True))
    return embeddings


def upload_file_fn(
    file_path: str,
    # Gradio injects the progress tracker through this default argument.
    progress: gr.Progress = gr.Progress(track_tqdm=True),
):
    """Handle a document upload: parse, embed and index the JSON file.

    Args:
        file_path: Path of the uploaded file. The ``gr.File`` component uses
            ``file_count="single"``, so this is one path, not a list.
        progress: Gradio progress tracker (mirrors the embedding ``tqdm`` loop).

    Returns:
        ``(document_state, query_box_update)``. On success the state is
        ``{"document_data": ..., "document_index": ...}`` and the query box is
        enabled. If the file cannot be parsed or holds fewer than 3 documents,
        the state is ``None`` and the query box is disabled.
    """
    max_documents = 1000
    try:
        document_data, documents = _read_documents(file_path)
    except Exception as e:  # UI boundary: report any parse failure to the user.
        print(e)
        # gr.Error only has an effect when *raised*; gr.Warning displays the
        # message and still lets us return the disabled-state outputs.
        gr.Warning("Read the file failed. Please check the data format.")
        gr.Warning(str(e))
        return None, gr.update(interactive=False)

    if len(documents) < 3:
        gr.Warning("Please upload at least 3 documents.")
        return None, gr.update(interactive=False)

    gr.Info(f"Upload {len(documents)} documents.")
    if len(documents) > max_documents:
        gr.Info(f"Cut uploaded documents to {max_documents} due to the computation resource.")
        # Truncate both lists so the stored records match the indexed rows.
        documents = documents[:max_documents]
        document_data = document_data[:max_documents]

    documents_embeddings = _embed_documents(documents)
    document_index = create_index(documents_embeddings, use_gpu=False)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    document_state = {"document_data": document_data, "document_index": document_index}
    return document_state, gr.update(interactive=True)
def clear_file_fn():
    """Reset the app after the uploaded file is removed from the file box.

    Returns:
        ``(document_state, query_box_update)``: the stored document state is
        dropped (``None``) and the query textbox is made interactive.
    """
    cleared_state = None
    query_box_update = gr.update(interactive=True)
    return cleared_state, query_box_update
def retrieve_document_fn(question, document_states, instruct):
    """Return the three best-matching uploaded documents for ``question``.

    Args:
        question: Query text entered in the UI.
        document_states: State from ``upload_file_fn``: a dict with
            ``"document_data"`` (parsed records) and ``"document_index"``
            (FAISS index), or ``None`` if nothing has been uploaded.
        instruct: Instruction prefix prepended to the query before encoding.

    Returns:
        The ``"text"`` of the top-3 documents by inner-product score, followed
        by ``document_states``. Missing hits are ``None``. All four values are
        ``None`` when no documents are loaded.
    """
    num_retrieval_doc = 3
    if document_states is None:
        gr.Warning("Please upload documents first!")
        return [None] * num_retrieval_doc + [None]

    document_data = document_states["document_data"]
    document_index = document_states["document_index"]
    question_embedding = model.encode([str(instruct) + str(question)])
    # FAISS search expects a C-contiguous float32 (n_queries, dim) matrix.
    question_embedding = np.ascontiguousarray(question_embedding, dtype=np.float32)

    # Only the top `num_retrieval_doc` hits are displayed, so don't request
    # more. Never ask for more than the index holds: FAISS pads missing hits
    # with -1, which would silently select document_data[-1].
    k = min(num_retrieval_doc, document_index.ntotal)
    _, hit_ids = document_index.search(question_embedding, k=k)
    answers = [document_data[i]["text"] for i in hit_ids[0] if i >= 0]
    answers += [None] * (num_retrieval_doc - len(answers))
    return answers[0], answers[1], answers[2], document_states
def main(args):
    """Load the embedding model and launch the Gradio retrieval demo.

    Args:
        args: Parsed CLI namespace providing ``model_name_or_path`` and
            ``revision``.
    """
    # The Gradio callbacks (upload/retrieve) read the encoder via this global.
    global model
    model = SentenceTransformer(
        args.model_name_or_path,
        revision=args.revision,
        backend="openvino")
    document_state = gr.State()

    resources_dir = Path(__file__).parent / "resources"
    # Explicit encoding so reading does not depend on the platform locale.
    head = (resources_dir / "head.html").read_text(encoding="utf-8").strip()
    theme = gr.themes.Soft(font="sans-serif").set(
        background_fill_primary="linear-gradient(90deg, #e3ffe7 0%, #d9e7ff 100%)",
        background_fill_primary_dark="linear-gradient(90deg, #4b6cb7 0%, #182848 100%)",
    )
    default_instruct = "Instruct: Given a query, retrieve documents that answer the query. \n Query: "

    with gr.Blocks(theme=theme,
                   head=head,
                   css=resources_dir / "styles.css",
                   title="KaLM-Embedding",
                   fill_height=True,
                   analytics_enabled=False) as demo:
        gr.Markdown(file_example)
        doc_files_box = gr.File(label="Upload Documents", file_types=[".json"], file_count="single")
        retrieval_interface = gr.Interface(
            fn=retrieve_document_fn,
            inputs=[gr.Textbox(label="Query"), document_state],
            outputs=[gr.Text(label="Recall-1"), gr.Text(label="Recall-2"), gr.Text(label="Recall-3"), gr.State()],
            additional_inputs=[gr.Textbox(default_instruct, label="Instruct of Query", lines=2)],
            concurrency_limit=1,
        )
        # The query textbox is enabled/disabled depending on whether a valid
        # document set is loaded.
        query_box = retrieval_interface.input_components[0]
        doc_files_box.upload(
            upload_file_fn,
            [doc_files_box],
            [document_state, query_box],
            queue=True,
            trigger_mode="once"
        )
        doc_files_box.clear(
            clear_file_fn,
            None,
            [document_state, query_box],
            queue=True,
            trigger_mode="once"
        )
    demo.launch()
if __name__ == "__main__":
    # CLI options: the embedding checkpoint to serve and the revision passed
    # to SentenceTransformer(revision=...) when loading it.
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5",
    )
    cli_parser.add_argument("--revision", type=str, default="refs/pr/2")
    main(cli_parser.parse_args())