File size: 2,409 Bytes
7103ccc
 
748826b
 
7103ccc
714a27c
748826b
416fea8
748826b
2371338
 
 
 
748826b
d590a55
 
e856ebd
d590a55
e856ebd
748826b
 
e856ebd
748826b
416fea8
 
748826b
1c026a2
e856ebd
 
 
 
 
 
1c026a2
e856ebd
 
 
1c026a2
e856ebd
 
 
 
 
 
 
1c026a2
e856ebd
 
 
 
 
5e78e4f
9006e63
 
 
1c026a2
9006e63
 
 
5cd07f9
 
9006e63
 
e856ebd
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModel
import plotly.graph_objects as go

# Hugging Face Hub checkpoint used for both the tokenizer and the model.
model_name = "mistralai/Mistral-7B-v0.1"
# The tokenizer is loaded eagerly, at import time.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# The model starts unset and is loaded lazily by get_embedding() on its first
# call (it rebinds this global), so startup does not block on loading 7B weights.
model = None

# Set pad token to eos token if not defined
# get_embedding() tokenizes with padding=True, and padding requires a pad
# token; reuse EOS when the tokenizer does not define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def get_embedding(text):
    """Return a mean-pooled embedding of ``text`` from the model's last layer.

    On the first call the model is loaded (fp16, ``device_map="auto"``) and
    cached in the module-level ``model`` global; later calls reuse it.

    Args:
        text: The string to embed. A list of strings is tokenized as one
            padded batch, yielding one pooled row per string.

    Returns:
        A CPU ``torch.Tensor`` in the model's dtype: shape ``(hidden_size,)``
        for a single string, ``(batch, hidden_size)`` for a list of strings.
    """
    global model
    if model is None:
        model = AutoModel.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)

    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
    with torch.no_grad():
        outputs = model(**inputs)

    hidden = outputs.last_hidden_state  # (batch, seq, hidden)

    # Average only over real tokens: when the tokenizer pads a batch, pad
    # positions must not dilute the mean. Fall back to all-ones if no mask.
    mask = inputs.get("attention_mask")
    if mask is None:
        mask = torch.ones(hidden.shape[:2], device=hidden.device)
    # With device_map="auto" the final hidden states can live on a different
    # device than the inputs, so move the mask to them. Accumulate in float32
    # to avoid fp16 precision loss when summing many positions.
    mask = mask.unsqueeze(-1).to(device=hidden.device, dtype=torch.float32)
    summed = (hidden.to(torch.float32) * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1.0)
    pooled = (summed / counts).to(hidden.dtype)

    # squeeze(0) removes only the batch axis of a single-string input.
    return pooled.squeeze(0).cpu()

def reduce_to_3d(embedding):
    """Return the leading three components of ``embedding``.

    This is a plain truncation, not a dimensionality reduction such as PCA:
    every component after index 2 is discarded, so a plotted vector reflects
    only dimensions 0, 1 and 2 of the embedding.

    Args:
        embedding: Any sliceable sequence, e.g. a 1-D tensor.

    Returns:
        The first three items (fewer if the input is shorter), with the same
        type the slice of ``embedding`` produces.
    """
    leading = slice(None, 3)
    return embedding[leading]

def compare_embeddings(text_input):
    """Plot each input line's embedding as a vector from the origin in 3D.

    Args:
        text_input: Multi-line string from the Gradio textbox. Each line that
            is non-blank after stripping is embedded separately. ``None`` is
            treated as empty.

    Returns:
        A ``plotly.graph_objects.Figure`` containing a black origin marker and
        one line-plus-marker trace per text. Each trace uses the components
        returned by ``reduce_to_3d`` as its x, y, z endpoint.

    Raises:
        gr.Error: If embedding or plotting fails. The output component is a
            ``gr.Plot``, which cannot display a plain error string, so the
            message is surfaced through Gradio's error popup instead.
    """
    try:
        # splitlines() also separates texts on lone '\r' endings; strip()
        # removes any leftover whitespace, and blank lines are skipped.
        texts = [line.strip() for line in (text_input or "").splitlines() if line.strip()]
        embeddings = [get_embedding(text) for text in texts]
        embeddings_3d = [reduce_to_3d(emb) for emb in embeddings]

        fig = go.Figure()

        # Add origin point (black)
        fig.add_trace(go.Scatter3d(x=[0], y=[0], z=[0], mode='markers', name='Origin',
                                   marker=dict(size=5, color='black')))

        # Add lines and points for each text embedding; colours cycle when
        # there are more texts than palette entries.
        colors = ['red', 'blue', 'green', 'purple', 'orange', 'cyan', 'magenta', 'yellow']
        for i, emb in enumerate(embeddings_3d):
            color = colors[i % len(colors)]
            fig.add_trace(go.Scatter3d(x=[0, emb[0].item()], y=[0, emb[1].item()], z=[0, emb[2].item()],
                                       mode='lines+markers', name=f'Text {i+1}',
                                       line=dict(color=color), marker=dict(color=color)))

        fig.update_layout(scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'))

        return fig
    except Exception as e:
        # UI boundary: returning the message string (as before) cannot be
        # rendered by gr.Plot; gr.Error shows it to the user and the chained
        # cause keeps the original exception for debugging.
        raise gr.Error(f"An error occurred: {str(e)}") from e

# Gradio UI: one multi-line textbox in, one Plotly figure out, wired to
# compare_embeddings. Flagging is disabled.
iface = gr.Interface(
    fn=compare_embeddings,
    inputs=[
        gr.Textbox(label="Input Texts", lines=5, placeholder="Enter multiple texts, each on a new line")
    ],
    outputs=gr.Plot(),
    title="3D Embedding Comparison",
    description="Compare the embeddings of multiple strings visualized in 3D space using Mistral 7B.",
    # NOTE(review): newer Gradio releases deprecate `allow_flagging` in favour
    # of `flagging_mode`; confirm against the Gradio version this app pins.
    allow_flagging="never"
)

# Start the Gradio server only when this file is run as a script, not on import.
if __name__ == "__main__":
    iface.launch()