from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
import hnswlib
import gradio as gr
import numpy as np

seperator = "-HFSEP-"
base_name = "intfloat/e5-large-v2"
device = "cuda"
max_length = 512

tokenizer = AutoTokenizer.from_pretrained(base_name)
model = AutoModel.from_pretrained(base_name).to(device)
def get_embeddings(input_texts):
    # Tokenize the inputs, run the model, mean-pool the last hidden states,
    # and return L2-normalized embeddings as a NumPy array.
    batch_dict = tokenizer(
        input_texts,
        max_length=max_length,
        padding=True,
        truncation=True,
        return_tensors='pt'
    ).to(device)

    with torch.no_grad():
        outputs = model(**batch_dict)

    embeddings = _average_pool(
        outputs.last_hidden_state, batch_dict['attention_mask']
    )
    embeddings = F.normalize(embeddings, p=2, dim=1)
    embeddings_np = embeddings.cpu().numpy()

    if device == "cuda":
        # Free the GPU copy once the embeddings have been moved to host memory.
        del embeddings
        torch.cuda.empty_cache()

    return embeddings_np
def _average_pool(last_hidden_states, attention_mask):
    # Mean-pool the token embeddings, masking out padding positions.
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
def create_hnsw_index(embeddings_np, space='ip', ef_construction=100, M=16):
    # Build an HNSW index over the chunk embeddings; inner product on
    # L2-normalized vectors is equivalent to cosine similarity.
    index = hnswlib.Index(space=space, dim=len(embeddings_np[0]))
    index.init_index(max_elements=len(embeddings_np), ef_construction=ef_construction, M=M)
    ids = np.arange(embeddings_np.shape[0])
    index.add_items(embeddings_np, ids)
    return index
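# A minimal standalone sketch of how the two helpers above fit together
# (hypothetical strings, left as comments so the app's behavior is unchanged):
#   chunk_embs = get_embeddings(["passage one", "passage two", "passage three"])
#   index = create_hnsw_index(chunk_embs)
#   labels, distances = index.knn_query(get_embeddings(["my query"])[0], k=2)
# `labels` holds the positions of the two nearest chunks and `distances` the
# corresponding distances reported by hnswlib (1 - inner product for the 'ip' space).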
def gradio_function(query, paragraph_chunks, top_k):
    paragraph_chunks = paragraph_chunks.split(seperator)  # Split the pasted text on the -HFSEP- separator
    paragraph_chunks = [item.strip() for item in paragraph_chunks]  # Trim whitespace from each chunk

    print("creating embeddings")
    embeddings_np = get_embeddings([query] + paragraph_chunks)
    query_embedding, chunks_embeddings = embeddings_np[0], embeddings_np[1:]

    print("creating index")
    search_index = create_hnsw_index(chunks_embeddings)

    print("searching index")
    labels, _ = search_index.knn_query(query_embedding, k=min(int(top_k), len(chunks_embeddings)))
    return f"The closest labels are: {labels}"
interface = gr.Interface(
    fn=gradio_function,
    inputs=[
        gr.Textbox(placeholder="Enter a user query..."),
        gr.Textbox(placeholder="Enter paragraph chunks separated by -HFSEP-..."),
        gr.Number()
    ],
    outputs="text"
)

interface.launch()
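# Illustrative input for the second textbox (example values are made up):
# paragraph chunks joined by the separator, e.g.
#   "Cats are small felines.-HFSEP-Dogs are loyal companions.-HFSEP-Penguins cannot fly."
# For a query such as "Which animal cannot fly?", the app returns the indices of
# the top_k chunks whose embeddings are closest to the query embedding.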