File size: 4,549 Bytes
8f6035e
 
 
 
 
d74ddfc
b3758b8
163f1eb
8f6035e
 
 
 
 
 
3400476
e71614a
9454c48
ad0d74d
 
 
6723b69
6ac55a5
 
e71614a
8f6035e
 
 
 
 
 
 
 
e71614a
8f6035e
 
27731ec
ad0d74d
8f6035e
 
ad0d74d
8f6035e
 
e71614a
8f6035e
 
 
 
140d9f4
8f6035e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e71614a
8f6035e
 
 
 
 
3400476
8f6035e
3400476
8f6035e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import spaces
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
import gradio as gr
import os

# Markdown shown at the top of the Gradio page (rendered via gr.Markdown in app_interface).
title = """
# 👋🏻Welcome to 🙋🏻‍♂️Tonic's 🐣e5-mistral🛌🏻Embeddings """
description = """
You can use this Space to test out the current model [intfloat/e5-mistral-7b-instruct](https://huggingface.co/intfloat/e5-mistral-7b-instruct). e5mistral has a larger context window, a different prompting/return mechanism and generally better results than other embedding models. 
You can also use 🐣e5-mistral🛌🏻 by cloning this space. 🧬🔬🔍 Simply click here: <a style="display:inline-block" href="https://huggingface.co/spaces/Tonic/e5?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a> 
Join us : 🌟TeamTonic🌟 is always making cool demos! Join our active builder's🛠️community on 👻Discord:  [![Let's build the future of AI together! 🚀🤖](https://discordapp.com/api/guilds/1109943800132010065/widget.png)](https://discord.gg/GWpVpekp) On 🤗Huggingface: [TeamTonic](https://huggingface.co/TeamTonic) & [MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Polytonic](https://github.com/tonic-ai) & contribute to 🌟 [Poly](https://github.com/tonic-ai/poly)
"""

# CUDA allocator tuning. PyTorch reads this variable when its CUDA caching
# allocator initialises, so it is set before the model below touches the GPU.
# NOTE(review): the 30 MB split cap is an undocumented tuning value — confirm
# it is still wanted.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:30'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Single source of truth for the checkpoint used by both tokenizer and model.
MODEL_NAME = 'intfloat/e5-mistral-7b-instruct'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# torch_dtype / device_map already load the weights in fp16 directly onto
# `device`, so no separate .half() / .to(device) calls are needed.
model = AutoModel.from_pretrained(MODEL_NAME, torch_dtype=torch.float16, device_map=device)

def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Pick, for every row in the batch, the hidden state of its last real token.

    Args:
        last_hidden_states: Per-token hidden states indexed as
            ``[row, position, ...]`` (presumably ``(batch, seq, hidden)`` —
            confirm against the model output).
        attention_mask: Mask indexed as ``[row, position]``; assumed to hold
            1 for real tokens and 0 for padding.

    Returns:
        One hidden-state vector per row.
    """
    # If every row has a real token in the final column, the batch is treated
    # as left-padded (or unpadded) and the last column is the answer directly.
    real_tokens_in_last_column = attention_mask[:, -1].sum()
    if real_tokens_in_last_column != attention_mask.shape[0]:
        # Right padding: the final real token sits at (count of real tokens - 1).
        last_positions = attention_mask.sum(dim=1) - 1
        rows = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
        return last_hidden_states[rows, last_positions]
    return last_hidden_states[:, -1]

def get_detailed_instruct(task_description: str, query: str) -> str:
    """Wrap a query in the ``Instruct: ...`` / ``Query: ...`` prompt template.

    Args:
        task_description: One-sentence description of the retrieval task.
        query: The user text to embed.

    Returns:
        The two-line prompt string fed to the tokenizer.
    """
    template = 'Instruct: {}\nQuery: {}'
    return template.format(task_description, query)

    
@spaces.GPU
def compute_embeddings(*input_texts):
    """Embed each input text with e5-mistral and return L2-normalised vectors.

    Every text is wrapped in the web-search retrieval instruction (via
    ``get_detailed_instruct``). It is tokenised to at most ``max_length - 1``
    tokens and terminated with the tokenizer's EOS token. ``last_token_pool``
    then reads the hidden state at that final token.

    Args:
        *input_texts: One string per Gradio textbox.

    Returns:
        A list with one entry per input text; each entry is a list of Python
        floats (the unit-length embedding).
    """
    max_length = 4096
    task = 'Given a web search query, retrieve relevant passages that answer the query'

    processed_texts = [get_detailed_instruct(task, text) for text in input_texts]
    # Tokenise without padding first so EOS can be appended to every row
    # (hence max_length - 1), then pad the whole batch in one go.
    batch_dict = tokenizer(processed_texts, max_length=max_length - 1, return_attention_mask=False, padding=False, truncation=True)
    batch_dict['input_ids'] = [input_ids + [tokenizer.eos_token_id] for input_ids in batch_dict['input_ids']]
    batch_dict = tokenizer.pad(batch_dict, padding=True, return_attention_mask=True, return_tensors='pt')
    batch_dict = {k: v.to(device) for k, v in batch_dict.items()}

    # Inference only: without no_grad, autograd keeps every intermediate
    # activation of the 7B model alive for a backward pass that never runs,
    # which wastes GPU memory and can OOM on long inputs.
    with torch.no_grad():
        outputs = model(**batch_dict)
        embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
        embeddings = F.normalize(embeddings, p=2, dim=1)
    return embeddings.cpu().numpy().tolist()
    
def app_interface(num_inputs: int = 4):
    """Build the Gradio Blocks UI for the embedding demo.

    Args:
        num_inputs: Number of input textboxes to show (default 4). Each box's
            value is passed positionally to ``compute_embeddings``.

    Returns:
        The ``gr.Blocks`` app; the caller is responsible for launching it.
    """
    with gr.Blocks() as demo:
        gr.Markdown(title)
        gr.Markdown(description)

        # Gradio places a component where it is *created*. Merely naming an
        # already-created component inside a Row/Column does not move it, so
        # the components are instantiated inside the layout contexts.
        with gr.Row():
            with gr.Column():
                input_text_boxes = [gr.Textbox(label=f"Input Text {i+1}") for i in range(num_inputs)]
            with gr.Column():
                compute_button = gr.Button("Compute Embeddings")
                # NOTE(review): compute_embeddings returns one row per input and
                # one column per embedding dimension, while `headers` declares a
                # single column — verify how the Dataframe renders this.
                output_display = gr.Dataframe(headers=["Embedding Value"], datatype=["number"])

        compute_button.click(
            fn=compute_embeddings,
            inputs=input_text_boxes,
            outputs=output_display,
        )

    return demo

# Build the UI as a module-level `demo`, but only start the (blocking) server
# when this file is executed as a script, so importing the module does not
# launch it.
demo = app_interface()

if __name__ == "__main__":
    demo.launch()