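# NOTE(review): the lines below look like file-viewer page residue rather than
# program text; they are kept verbatim as comments so the file parses as Python.
# mishig's picture
# mishig HF staff
# Create app.py
# a342b03
# raw
# history blame
# 2.5 kB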
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
import hnswlib
import gradio as gr
import numpy as np
# Delimiter placed between paragraph chunks inside a single text field.
# (sic) the misspelled name is kept because other code in this file refers to it.
seperator = "-HFSEP-"
# Hugging Face model id used for both the tokenizer and the encoder.
base_name = "intfloat/e5-large-v2"
# Prefer the GPU, but fall back to the CPU so that loading the model does not
# fail on hosts where CUDA is unavailable.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Token limit handed to the tokenizer when encoding texts.
max_length = 512
tokenizer = AutoTokenizer.from_pretrained(base_name)
model = AutoModel.from_pretrained(base_name).to(device)
def get_embeddings(input_texts):
    """Embed texts with the module-level model and return unit-length vectors.

    The texts are tokenised in one batch (padding and truncation enabled,
    limited to ``max_length`` tokens), run through ``model`` without gradient
    tracking, mean-pooled with the attention mask and L2-normalised along
    dim 1.

    Args:
        input_texts: Texts to embed; passed to the tokenizer unchanged.

    Returns:
        The normalised embeddings moved to the CPU and converted to a NumPy
        array.
    """
    encoded = tokenizer(
        input_texts,
        max_length=max_length,
        padding=True,
        truncation=True,
        return_tensors='pt',
    ).to(device)

    with torch.no_grad():
        hidden = model(**encoded).last_hidden_state
        unit_vectors = F.normalize(
            _average_pool(hidden, encoded['attention_mask']),
            p=2,
            dim=1,
        )
        embeddings_np = unit_vectors.cpu().numpy()

    if device == "cuda":
        # The NumPy copy is all the caller needs: drop the GPU tensor and
        # clear PyTorch's CUDA cache before returning.
        del unit_vectors
        torch.cuda.empty_cache()
    return embeddings_np
def _average_pool(
    last_hidden_states,
    attention_mask
):
    """Average the hidden states over dim 1, skipping masked-out positions.

    Positions whose mask value is zero are replaced by 0.0 before summing,
    and the sum is divided by the number of non-zero mask entries per row.

    Args:
        last_hidden_states: Tensor of per-token vectors; tokens lie on dim 1.
        attention_mask: Tensor with one value per token (non-zero = keep).

    Returns:
        Tensor holding one averaged vector per row of the input.
    """
    keep = attention_mask[..., None].bool()
    masked_states = last_hidden_states.masked_fill(~keep, 0.0)
    # A trailing axis on the counts lets the division broadcast over features.
    token_counts = attention_mask.sum(dim=1)[..., None]
    return masked_states.sum(dim=1) / token_counts
def create_hnsw_index(embeddings_np, space='ip', ef_construction=100, M=16):
    """Build an hnswlib index holding every row of ``embeddings_np``.

    Args:
        embeddings_np: 2-D NumPy array of vectors, one per row.
        space: Distance space name forwarded to ``hnswlib.Index``.
        ef_construction: Forwarded to ``init_index``.
        M: Forwarded to ``init_index``.

    Returns:
        The populated index; each vector is labelled with its row position.
    """
    vector_dim = len(embeddings_np[0])
    capacity = len(embeddings_np)

    index = hnswlib.Index(space=space, dim=vector_dim)
    index.init_index(max_elements=capacity, ef_construction=ef_construction, M=M)
    # Labels 0..n-1 mean a search result maps straight back to a row number.
    row_labels = np.arange(embeddings_np.shape[0])
    index.add_items(embeddings_np, row_labels)
    return index
def gradio_function(query, paragraph_chunks, top_k):
    """Find the paragraph chunks closest to a query and report their labels.

    Args:
        query: The search query text.
        paragraph_chunks: All candidate chunks in one string, delimited by
            the module-level ``seperator``.
        top_k: Number of nearest chunks requested. ``None`` (for example an
            empty number field) and values below 1 are treated as 1; values
            above the number of chunks are capped at that number.

    Returns:
        A message string ending with the labels returned by ``knn_query``.
        Labels are positions of chunks in the split ``paragraph_chunks`` list.
    """
    # Chunks are delimited by `seperator` (not commas); trim each one.
    chunks = [piece.strip() for piece in paragraph_chunks.split(seperator)]

    # Resolve k before the expensive steps so a bad value fails fast, and keep
    # it within [1, number of chunks] so the search never gets k <= 0.
    requested = 1 if top_k is None else int(top_k)
    k = max(1, min(requested, len(chunks)))

    print("creating embeddings")
    # NOTE(review): the intfloat/e5 model card asks for "query: " / "passage: "
    # prefixes on inputs; texts are embedded here exactly as received --
    # confirm whether callers already add those prefixes.
    embeddings_np = get_embeddings([query] + chunks)
    query_embedding, chunks_embeddings = embeddings_np[0], embeddings_np[1:]
    print("creating index")
    search_index = create_hnsw_index(chunks_embeddings)
    print("searching index")
    labels, _ = search_index.knn_query(query_embedding, k=k)
    # (sic) message text kept byte-for-byte; consumers may match on it.
    return f"The closes labels are: {labels}"
# Gradio UI: inputs are listed in the positional order gradio_function takes
# them (query, paragraph_chunks, top_k); the result is shown as plain text.
interface = gr.Interface(
    fn=gradio_function,
    inputs=[
        gr.Textbox(placeholder="Enter a user query..."),
        # The handler splits on `seperator`, so the hint names that delimiter
        # instead of the commas the old text suggested.
        gr.Textbox(placeholder=f"Enter strings separated by {seperator}..."),
        gr.Number()
    ],
    outputs="text"
)
interface.launch()