import gradio as gr
import pandas as pd
from datasets import load_dataset
import os
import pathlib
import uuid

# --- Embedding Atlas Imports ---
# We import the necessary components directly from the library.
from embedding_atlas.data_source import DataSource
from embedding_atlas.server import make_server
from embedding_atlas.projection import compute_text_projection
from embedding_atlas.utils import find_column_name, Hasher

# --- Global State ---
# Each generated atlas is mounted under its own unique URL path so re-runs never
# collide with an earlier mount. Maps mount path -> atlas server app.
# NOTE(review): entries (and their mounted routes) are never removed, so every
# run keeps its DataFrame alive for the lifetime of the process.
mounted_apps = {}


def get_atlas_static_path():
    """Return the absolute path of the ``static`` directory inside the installed
    ``embedding_atlas`` package, which holds the frontend files."""
    import embedding_atlas

    return str((pathlib.Path(embedding_atlas.__file__).parent / "static").resolve())


def _gradio_server_app():
    """Return the ASGI app serving this Gradio UI, so sub-apps can be mounted on it.

    ``app`` is the module-level name bound by ``with gr.Blocks(...) as app``.
    If it already supports ``.mount`` (an ASGI app such as FastAPI) it is used
    as-is; otherwise its ``.app`` attribute is used.
    NOTE(review): assumes a launched ``gr.Blocks`` exposes its FastAPI server as
    ``.app`` -- confirm against the pinned Gradio version.

    Raises:
        gr.Error: if no mountable server app is available (e.g. not launched yet).
    """
    server_app = app if hasattr(app, "mount") else getattr(app, "app", None)
    if server_app is None or not hasattr(server_app, "mount"):
        raise gr.Error("The Gradio server is not running; cannot mount the atlas UI.")
    return server_app


def generate_atlas(
    dataset_name: str,
    text_column: str,
    split: str,
    sample_size: int,
    model_name: str,
    umap_neighbors: int,
    umap_min_dist: float,
    progress=gr.Progress(track_tqdm=True),
):
    """Load a dataset, project one text column to 2-D, and serve it in Embedding Atlas.

    Args:
        dataset_name: Hugging Face dataset id passed to ``load_dataset``.
        text_column: Column whose text is embedded and visualized.
        split: Dataset split to load (e.g. ``"train"``).
        sample_size: Number of rows to sample randomly (seed 42). ``0``, or a
            value not smaller than the dataset, keeps every row.
        model_name: Embedding model name forwarded to ``compute_text_projection``.
        umap_neighbors: UMAP ``n_neighbors``.
        umap_min_dist: UMAP ``min_dist``.
        progress: Gradio progress tracker; ``track_tqdm=True`` mirrors any tqdm
            progress bars created while this function runs into the UI.

    Returns:
        A ``gr.HTML`` component holding an iframe that points at the newly
        mounted atlas server.

    Raises:
        gr.Error: if loading, column lookup, embedding, or mounting fails.
    """
    # Slider values may arrive as floats; pandas sampling and UMAP's
    # n_neighbors both expect integers.
    sample_size = int(sample_size)
    umap_neighbors = int(umap_neighbors)

    # --- 1. Load Data ---
    progress(0, desc=f"Loading dataset '{dataset_name}'...")
    try:
        # Load the dataset from Hugging Face
        dataset = load_dataset(dataset_name, split=split)
        df = dataset.to_pandas()
    except Exception as e:
        raise gr.Error(f"Failed to load dataset. Please check the name and split. Error: {e}") from e

    # Validate the column before sampling or embedding so a typo fails fast.
    if text_column not in df.columns:
        available = ", ".join(map(str, df.columns))
        raise gr.Error(f"Column '{text_column}' not found in the dataset. Available columns: {available}")

    # --- 2. Sample Data (if requested) ---
    if 0 < sample_size < len(df):
        progress(0.1, desc=f"Sampling {sample_size} rows...")
        df = df.sample(n=sample_size, random_state=42)

    # --- 3. Compute Embeddings & UMAP Projection ---
    progress(0.2, desc="Computing embeddings and UMAP projection. This may take a while...")
    # presumably find_column_name returns a name derived from the hint that does
    # not clash with existing columns -- confirm in embedding_atlas.utils.
    x_column = find_column_name(df.columns, "projection_x")
    y_column = find_column_name(df.columns, "projection_y")
    neighbors_column = find_column_name(df.columns, "__neighbors")
    try:
        # Writes the x/y/neighbors columns into df.
        compute_text_projection(
            df,
            text_column,
            x=x_column,
            y=y_column,
            neighbors=neighbors_column,
            model=model_name,
            umap_args={
                "n_neighbors": umap_neighbors,
                "min_dist": umap_min_dist,
                "metric": "cosine",
                "random_state": 42,
            },
        )
    except Exception as e:
        raise gr.Error(f"Failed to compute embeddings. Check model name or try a smaller sample. Error: {e}") from e

    # --- 4. Prepare Atlas DataSource ---
    progress(0.8, desc="Preparing Atlas data source...")
    id_column = find_column_name(df.columns, "_row_index")
    df[id_column] = range(df.shape[0])

    metadata = {
        "columns": {
            "id": id_column,
            "text": text_column,
            "embedding": {"x": x_column, "y": y_column},
            "neighbors": neighbors_column,
        },
    }

    # Hash every input that changes the rows or the projection, so two
    # different configurations never share a dataset identifier.
    hasher = Hasher()
    hasher.update(
        f"{dataset_name}-{split}-{text_column}-{sample_size}-{model_name}"
        f"-{umap_neighbors}-{umap_min_dist}"
    )
    identifier = hasher.hexdigest()

    atlas_dataset = DataSource(identifier, df, metadata)
    static_path = get_atlas_static_path()

    # --- 5. Create and Mount the Atlas App ---
    progress(0.9, desc="Mounting visualization UI...")
    # A fresh path per run, so remounting never collides with an earlier atlas.
    mount_path = f"/{uuid.uuid4().hex}"
    atlas_app = make_server(atlas_dataset, static_path=static_path, duckdb_uri="wasm")

    # gr.mount_gradio_app(fastapi_app, blocks, path) mounts a Gradio Blocks onto
    # a FastAPI app; here we need the reverse -- mount a plain ASGI app onto the
    # server that already serves this Gradio UI.
    _gradio_server_app().mount(mount_path, atlas_app)
    mounted_apps[mount_path] = atlas_app

    progress(1.0, desc="Done!")

    # --- 6. Return an IFrame pointing to the mounted path ---
    # Starlette mounts match "<path>/...", so point at the trailing-slash URL
    # directly rather than relying on a redirect.
    iframe_html = (
        f'<iframe src="{mount_path}/" width="100%" height="800px" '
        'style="border: none;"></iframe>'
    )
    return gr.HTML(iframe_html)


# NOTE(review): the UI definition below is still collapsed onto one line (its
# newlines and indentation were lost) and continues past this point; it was left
# untouched and must be re-indented before this module can run.
# --- Gradio UI Definition --- with gr.Blocks(theme=gr.themes.Soft(), title="Embedding Atlas Explorer") as app: gr.Markdown("# Embedding Atlas Explorer") gr.Markdown( "Visualize any text column from a Hugging Face dataset. This app loads the data, " "computes embeddings using Sentence Transformers, reduces dimensionality with UMAP, " "and displays the result in an interactive Embedding Atlas." ) with gr.Row(): with gr.Column(scale=1): gr.Markdown("### 1. Configuration") dataset_input = gr.Textbox( label="Hugging Face Dataset Name", value="Trendyol/Trendyol-Cybersecurity-Instruction-Tuning-Dataset" ) text_column_input = gr.Textbox( label="Text Column to Visualize", value="instruction" ) split_input = gr.Textbox(label="Dataset Split", value="train") sample_size_input = gr.Slider( label="Number of Samples (0 for all)", minimum=0, maximum=5000, value=2000, step=100 ) with gr.Accordion("Advanced Settings", open=False): model_input = gr.Dropdown( label="Embedding Model", choices=["all-MiniLM-L6-v2", "all-mpnet-base-v2", "multi-qa-MiniLM-L6-cos-v1"], value="all-MiniLM-L6-v2", ) umap_neighbors_input = gr.Slider( label="UMAP Neighbors", minimum=2, maximum=100, value=15, step=1, info="Controls how UMAP balances local vs. global structure." ) umap_min_dist_input = gr.Slider( label="UMAP Min Distance", minimum=0.0, maximum=0.99, value=0.1, step=0.01, info="Controls how tightly UMAP packs points together." ) generate_button = gr.Button("Generate Atlas", variant="primary") with gr.Column(scale=3): gr.Markdown("### 2. Visualization") output_html = gr.HTML( "
Atlas will be displayed here after generation.