import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModel
import plotly.graph_objects as go
model_name = "mistralai/Mistral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = None
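# Loaded lazily on the first request (inside get_embedding) so the app starts quickly;
# in float16 the 7B checkpoint needs roughly 14-15 GB of accelerator memory, and
# device_map="auto" can offload layers to CPU if that much is not available.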
# Set pad token to eos token if not defined
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
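# Compute a single embedding vector for `text`: the model's last hidden state is
# mean-pooled over the token dimension and moved to CPU. The model is loaded here,
# on first use, rather than at import time.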
def get_embedding(text):
    global model
    if model is None:
        model = AutoModel.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.last_hidden_state.mean(dim=1).squeeze().cpu()
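# Reduce an embedding to 3 dimensions for plotting. This simply keeps the first three
# components of the hidden state rather than using a learned projection such as PCA,
# so the plot is only a rough visual comparison.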
def reduce_to_3d(embedding):
    return embedding[:3]
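# Build the 3D figure: one text per input line, each embedding drawn as a line
# from the origin to its reduced 3D coordinates.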
def compare_embeddings(text_input):
    try:
        texts = [t.strip() for t in text_input.split('\n') if t.strip()]
        embeddings = [get_embedding(text) for text in texts]
        embeddings_3d = [reduce_to_3d(emb) for emb in embeddings]
        fig = go.Figure()
        # Add origin point (black)
        fig.add_trace(go.Scatter3d(x=[0], y=[0], z=[0], mode='markers', name='Origin',
                                   marker=dict(size=5, color='black')))
        # Add lines and points for each text embedding
        colors = ['red', 'blue', 'green', 'purple', 'orange', 'cyan', 'magenta', 'yellow']
        for i, emb in enumerate(embeddings_3d):
            color = colors[i % len(colors)]
            fig.add_trace(go.Scatter3d(x=[0, emb[0].item()], y=[0, emb[1].item()], z=[0, emb[2].item()],
                                       mode='lines+markers', name=f'Text {i+1}',
                                       line=dict(color=color), marker=dict(color=color)))
        fig.update_layout(scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'))
        return fig
    except Exception as e:
        # Surface the error in the UI; returning a plain string would not render in a gr.Plot output
        raise gr.Error(f"An error occurred: {e}")
iface = gr.Interface(
    fn=compare_embeddings,
    inputs=[
        gr.Textbox(label="Input Texts", lines=5, placeholder="Enter multiple texts, each on a new line")
    ],
    outputs=gr.Plot(),
    title="3D Embedding Comparison",
    description="Compare the embeddings of multiple strings visualized in 3D space using Mistral 7B.",
    allow_flagging="never"
)
if __name__ == "__main__":
    iface.launch()