import gradio as gr import pandas as pd from datasets import load_dataset, get_dataset_split_names from huggingface_hub import HfApi import os import pathlib import uuid # --- Embedding Atlas Imports --- from embedding_atlas.data_source import DataSource from embedding_atlas.server import make_server from embedding_atlas.projection import compute_text_projection from embedding_atlas.utils import Hasher # --- Helper function from embedding_atlas/cli.py --- def find_column_name(existing_names, candidate): """Finds a unique column name, appending '_1', '_2', etc. if the candidate name already exists.""" if candidate not in existing_names: return candidate else: index = 1 while True: s = f"{candidate}_{index}" if s not in existing_names: return s index += 1 # --- Hugging Face API Helpers for Dynamic UI --- hf_api = HfApi() def get_user_datasets(username: str): """Fetches all public datasets for a given username or organization.""" if not username: return gr.Dropdown.update(choices=[], value=None, interactive=False) try: # --- THIS IS THE FIX --- # Replace deprecated 'cardData=True' with 'full=True' datasets = hf_api.list_datasets(author=username, full=True) dataset_ids = [d.id for d in datasets if not d.private] return gr.Dropdown.update(choices=sorted(dataset_ids), value=None, interactive=True) except Exception as e: gr.Warning(f"Could not fetch datasets for user '{username}'. Error: {e}") return gr.update(choices=[], value=None, interactive=False) def get_dataset_splits(dataset_id: str): """Gets all available splits for a selected dataset.""" if not dataset_id: return gr.Dropdown.update(choices=[], value=None, interactive=False) try: splits = get_dataset_split_names(dataset_id) return gr.Dropdown.update(choices=splits, value=splits[0] if splits else None, interactive=True) except Exception as e: gr.Warning(f"Could not fetch splits for dataset '{dataset_id}'. Error: {e}") return gr.update(choices=[], value=None, interactive=False) def get_split_columns(dataset_id: str, split: str): """Gets all columns for a selected split by loading one row.""" if not dataset_id or not split: return gr.update(choices=[], value=None, interactive=False) try: # Stream one row to get column names without downloading the whole dataset dataset_sample = load_dataset(dataset_id, split=split, streaming=True) first_row = next(iter(dataset_sample)) columns = list(first_row.keys()) # Heuristically find the best text column preferred_cols = ['text', 'content', 'instruction', 'question', 'document', 'prompt'] best_col = next((col for col in preferred_cols if col in columns), columns[0] if columns else None) return gr.update(choices=columns, value=best_col, interactive=True) except Exception as e: gr.Warning(f"Could not fetch columns for split '{split}'. Error: {e}") return gr.update(choices=[], value=None, interactive=False) # --- Main Atlas Generation Logic --- def generate_atlas( dataset_name: str, split: str, text_column: str, sample_size: int, model_name: str, umap_neighbors: int, umap_min_dist: float, progress=gr.Progress(track_tqdm=True) ): """ Loads data, computes embeddings, and serves the Embedding Atlas UI. """ if not all([dataset_name, split, text_column]): raise gr.Error("Please ensure a Dataset, Split, and Text Column are selected.") progress(0, desc=f"Loading dataset '{dataset_name}' [{split}]...") try: dataset = load_dataset(dataset_name, split=split) df = dataset.to_pandas() except Exception as e: raise gr.Error(f"Failed to load data. Error: {e}") if sample_size > 0 and sample_size < len(df): progress(0.1, desc=f"Sampling {sample_size} rows...") df = df.sample(n=sample_size, random_state=42).reset_index(drop=True) if text_column not in df.columns: raise gr.Error(f"Column '{text_column}' not found. Please select a valid column.") progress(0.2, desc="Computing embeddings and UMAP. This may take a while...") x_col = find_column_name(df.columns, "projection_x") y_col = find_column_name(df.columns, "projection_y") neighbors_col = find_column_name(df.columns, "__neighbors") try: compute_text_projection( df, text_column, x=x_col, y=y_col, neighbors=neighbors_col, model=model_name, umap_args={"n_neighbors": umap_neighbors, "min_dist": umap_min_dist, "metric": "cosine", "random_state": 42}, ) except Exception as e: raise gr.Error(f"Failed to compute embeddings. Check model name or try a smaller sample. Error: {e}") progress(0.8, desc="Preparing Atlas data source...") id_col = find_column_name(df.columns, "_row_index") df[id_col] = range(df.shape[0]) metadata = { "columns": {"id": id_col, "text": text_column, "embedding": {"x": x_col, "y": y_col}, "neighbors": neighbors_col}, } hasher = Hasher() hasher.update(f"{dataset_name}-{split}-{text_column}-{sample_size}-{model_name}") identifier = hasher.hexdigest() atlas_dataset = DataSource(identifier, df, metadata) progress(0.9, desc="Mounting visualization UI...") static_path = str((pathlib.Path(__import__('embedding_atlas').__file__).parent / "static").resolve()) mount_path = f"/{uuid.uuid4().hex}" atlas_app = make_server(atlas_dataset, static_path=static_path, duckdb_uri="wasm") app.mount_gradio_app(atlas_app, path=mount_path) progress(1.0, desc="Done!") iframe_html = f"" return gr.HTML(iframe_html) # --- Gradio UI Definition --- with gr.Blocks(theme=gr.themes.Soft(), title="Embedding Atlas Explorer") as app: gr.Markdown("# Embedding Atlas Explorer") gr.Markdown( "Interactively select and visualize any text-based dataset from the Hugging Face Hub. " "The app computes embeddings and projects them into a 2D map for exploration." ) with gr.Row(): with gr.Column(scale=1): gr.Markdown("### 1. Select Data") hf_user_input = gr.Textbox(label="Hugging Face User or Org Name", value="Trendyol", placeholder="e.g., 'gradio' or 'google'") dataset_input = gr.Dropdown(label="Select a Dataset", interactive=False) split_input = gr.Dropdown(label="Select a Split", interactive=False) text_column_input = gr.Dropdown(label="Select a Text Column", interactive=False) gr.Markdown("### 2. Configure Visualization") sample_size_input = gr.Slider(label="Number of Samples", minimum=0, maximum=10000, value=2000, step=100) with gr.Accordion("Advanced Settings", open=False): model_input = gr.Dropdown(label="Embedding Model", choices=["all-MiniLM-L6-v2", "all-mpnet-base-v2", "multi-qa-MiniLM-L6-cos-v1"], value="all-MiniLM-L6-v2") umap_neighbors_input = gr.Slider(label="UMAP Neighbors", minimum=2, maximum=100, value=15, step=1, info="Controls local vs. global structure.") umap_min_dist_input = gr.Slider(label="UMAP Min Distance", minimum=0.0, maximum=0.99, value=0.1, step=0.01, info="Controls how tightly points are packed.") generate_button = gr.Button("Generate Atlas", variant="primary") with gr.Column(scale=3): gr.Markdown("### 3. Explore Atlas") output_html = gr.HTML("

Atlas will be displayed here after generation.

") # --- Chained Event Listeners for Dynamic UI --- hf_user_input.submit( fn=get_user_datasets, inputs=[hf_user_input], outputs=[dataset_input] ) dataset_input.select( fn=get_dataset_splits, inputs=[dataset_input], outputs=[split_input] ) split_input.select( fn=get_split_columns, inputs=[dataset_input, split_input], outputs=[text_column_input] ) # --- Button Click Event --- generate_button.click( fn=generate_atlas, inputs=[ dataset_input, split_input, text_column_input, sample_size_input, model_input, umap_neighbors_input, umap_min_dist_input ], outputs=[output_html], ) # Load initial example data on app load app.load(fn=get_user_datasets, inputs=[hf_user_input], outputs=[dataset_input]) if __name__ == "__main__": app.launch()