Spaces:
Sleeping
Sleeping
import gradio as gr | |
import pandas as pd | |
from datasets import load_dataset | |
import os | |
import pathlib | |
import uuid | |
# --- Embedding Atlas Imports --- | |
# We will import the necessary components directly from the library | |
from embedding_atlas.data_source import DataSource | |
from embedding_atlas.server import make_server | |
from embedding_atlas.projection import compute_text_projection | |
from embedding_atlas.utils import find_column_name, Hasher | |
# --- Global State ---
# Registry of the Atlas server apps created so far, keyed by the unique URL
# path each one was mounted under. Keeping one entry per run avoids reusing
# a mount path when the generator is invoked again.
mounted_apps: dict = {}
def get_atlas_static_path():
    """Return the resolved path of the embedding-atlas frontend's static files.

    The directory is the ``static`` folder located next to the installed
    ``embedding_atlas`` package's ``__init__`` file.
    """
    # Imported lazily: only the package's on-disk location is needed here.
    import embedding_atlas

    package_dir = pathlib.Path(embedding_atlas.__file__).parent
    static_dir = package_dir.joinpath("static")
    return str(static_dir.resolve())
def generate_atlas(
    dataset_name: str,
    text_column: str,
    split: str,
    sample_size: int,
    model_name: str,
    umap_neighbors: int,
    umap_min_dist: float,
    progress=gr.Progress(track_tqdm=True),
):
    """Load a dataset, compute a 2-D text projection, and embed the Atlas UI.

    Args:
        dataset_name: Name passed to ``load_dataset``.
        text_column: Column of the loaded frame whose text is embedded.
        split: Dataset split passed to ``load_dataset``.
        sample_size: Number of rows to sample. ``0``, or any value that is
            not smaller than the number of rows, keeps every row.
        model_name: Embedding model forwarded to ``compute_text_projection``.
        umap_neighbors: Value used for UMAP's ``n_neighbors``.
        umap_min_dist: Value used for UMAP's ``min_dist``.
        progress: Gradio progress tracker used to report each stage.

    Returns:
        A ``gr.HTML`` component containing an ``<iframe>`` whose ``src`` is
        the freshly generated mount path of the Atlas server.

    Raises:
        gr.Error: If the dataset cannot be loaded, the text column is
            missing, the projection fails, or the Gradio server app is not
            available for mounting.
    """
    # Sliders may hand these back as floats (TODO confirm for the Gradio
    # version in use); pandas' ``sample(n=...)`` and UMAP need integers.
    sample_size = int(sample_size)
    umap_neighbors = int(umap_neighbors)

    # --- 1. Load Data ---
    progress(0, desc=f"Loading dataset '{dataset_name}'...")
    try:
        # Load the dataset from Hugging Face
        dataset = load_dataset(dataset_name, split=split)
        df = dataset.to_pandas()
    except Exception as e:
        raise gr.Error(
            f"Failed to load dataset. Please check the name and split. Error: {e}"
        ) from e

    # --- 2. Sample Data (if requested) ---
    if 0 < sample_size < len(df):
        progress(0.1, desc=f"Sampling {sample_size} rows...")
        df = df.sample(n=sample_size, random_state=42)

    # Check if the text column exists
    if text_column not in df.columns:
        # Column labels are not guaranteed to be strings, so stringify them
        # before joining.
        available = ", ".join(map(str, df.columns))
        raise gr.Error(
            f"Column '{text_column}' not found in the dataset. Available columns: {available}"
        )

    # --- 3. Compute Embeddings & UMAP Projection ---
    progress(0.2, desc="Computing embeddings and UMAP projection. This may take a while...")
    x_column = find_column_name(df.columns, "projection_x")
    y_column = find_column_name(df.columns, "projection_y")
    neighbors_column = find_column_name(df.columns, "__neighbors")
    try:
        compute_text_projection(
            df,
            text_column,
            x=x_column,
            y=y_column,
            neighbors=neighbors_column,
            model=model_name,
            umap_args={
                "n_neighbors": umap_neighbors,
                "min_dist": umap_min_dist,
                "metric": "cosine",
                "random_state": 42,
            },
        )
    except Exception as e:
        raise gr.Error(
            f"Failed to compute embeddings. Check model name or try a smaller sample. Error: {e}"
        ) from e

    # --- 4. Prepare Atlas DataSource ---
    progress(0.8, desc="Preparing Atlas data source...")
    id_column = find_column_name(df.columns, "_row_index")
    df[id_column] = range(df.shape[0])
    metadata = {
        "columns": {
            "id": id_column,
            "text": text_column,
            "embedding": {"x": x_column, "y": y_column},
            "neighbors": neighbors_column,
        },
    }

    # Hash every input that changes the resulting frame, so two different
    # configurations (e.g. another split or UMAP setting) never share an
    # identifier.
    hasher = Hasher()
    hasher.update(
        f"{dataset_name}-{text_column}-{split}-{sample_size}-{model_name}"
        f"-{umap_neighbors}-{umap_min_dist}"
    )
    identifier = hasher.hexdigest()
    atlas_dataset = DataSource(identifier, df, metadata)
    static_path = get_atlas_static_path()

    # --- 5. Create and Mount the FastAPI App ---
    progress(0.9, desc="Mounting visualization UI...")
    # Generate a unique path for this instance to avoid conflicts on remounting
    mount_path = f"/{uuid.uuid4().hex}"
    # Create the server instance
    atlas_app = make_server(atlas_dataset, static_path=static_path, duckdb_uri="wasm")

    # `app` is the module-level gr.Blocks. `gr.mount_gradio_app` attaches a
    # Blocks to a FastAPI app, which is the opposite of what is needed here,
    # so mount the Atlas server directly on the FastAPI app that serves the
    # Blocks. NOTE(review): `Blocks.app` is expected to be populated once the
    # server is running — confirm for the Gradio version in use.
    host_app = getattr(app, "app", None)
    if host_app is None:
        raise gr.Error("The Gradio server is not running; cannot mount the Atlas UI.")
    host_app.mount(mount_path, atlas_app)
    mounted_apps[mount_path] = atlas_app  # Store it for potential cleanup later
    progress(1.0, desc="Done!")

    # --- 6. Return an IFrame pointing to the mounted path ---
    iframe_html = f"<iframe src='{mount_path}' width='100%' height='800px' frameborder='0'></iframe>"
    return gr.HTML(iframe_html)
# --- Gradio UI Definition ---
# Layout: a narrow configuration column on the left, a wide visualization
# column on the right, and a single button that triggers `generate_atlas`.
with gr.Blocks(title="Embedding Atlas Explorer", theme=gr.themes.Soft()) as app:
    gr.Markdown("# Embedding Atlas Explorer")
    gr.Markdown(
        "Visualize any text column from a Hugging Face dataset. This app loads the data, "
        "computes embeddings using Sentence Transformers, reduces dimensionality with UMAP, "
        "and displays the result in an interactive Embedding Atlas."
    )

    with gr.Row():
        # Left column: dataset selection and projection settings.
        with gr.Column(scale=1):
            gr.Markdown("### 1. Configuration")
            dataset_input = gr.Textbox(
                value="Trendyol/Trendyol-Cybersecurity-Instruction-Tuning-Dataset",
                label="Hugging Face Dataset Name",
            )
            text_column_input = gr.Textbox(value="instruction", label="Text Column to Visualize")
            split_input = gr.Textbox(value="train", label="Dataset Split")
            sample_size_input = gr.Slider(
                minimum=0,
                maximum=5000,
                step=100,
                value=2000,
                label="Number of Samples (0 for all)",
            )

            with gr.Accordion("Advanced Settings", open=False):
                model_input = gr.Dropdown(
                    choices=["all-MiniLM-L6-v2", "all-mpnet-base-v2", "multi-qa-MiniLM-L6-cos-v1"],
                    value="all-MiniLM-L6-v2",
                    label="Embedding Model",
                )
                umap_neighbors_input = gr.Slider(
                    minimum=2,
                    maximum=100,
                    step=1,
                    value=15,
                    label="UMAP Neighbors",
                    info="Controls how UMAP balances local vs. global structure.",
                )
                umap_min_dist_input = gr.Slider(
                    minimum=0.0,
                    maximum=0.99,
                    step=0.01,
                    value=0.1,
                    label="UMAP Min Distance",
                    info="Controls how tightly UMAP packs points together.",
                )

            generate_button = gr.Button("Generate Atlas", variant="primary")

        # Right column: placeholder that is replaced by the Atlas iframe.
        with gr.Column(scale=3):
            gr.Markdown("### 2. Visualization")
            output_html = gr.HTML(
                "<div style='display:flex; justify-content:center; align-items:center; height:800px; border: 1px solid #ddd; border-radius: 5px;'><p>Atlas will be displayed here after generation.</p></div>"
            )

    # The input order must match the positional parameters of generate_atlas.
    generate_button.click(
        generate_atlas,
        inputs=[
            dataset_input,
            text_column_input,
            split_input,
            sample_size_input,
            model_input,
            umap_neighbors_input,
            umap_min_dist_input,
        ],
        outputs=[output_html],
    )
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script.
    app.launch()