import gradio as gr
import pandas as pd
from datasets import load_dataset, get_dataset_split_names
from huggingface_hub import HfApi
import os
import pathlib
import uuid
import logging

# --- Setup Logging ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# --- Embedding Atlas Imports ---
from embedding_atlas.data_source import DataSource
from embedding_atlas.server import make_server
from embedding_atlas.projection import compute_text_projection
from embedding_atlas.utils import Hasher
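# Dependency note (a sketch, not taken from the Space's files): these imports assume the
# environment provides roughly `gradio`, `pandas`, `datasets`, `huggingface_hub`, and
# `embedding-atlas`; the embedding/UMAP stack (likely sentence-transformers and umap-learn)
# is assumed to come in as dependencies of `embedding-atlas`.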
# --- Helper function from embedding_atlas/cli.py ---
def find_column_name(existing_names, candidate):
    """Finds a unique column name, appending '_1', '_2', etc. if the candidate name already exists."""
    if candidate not in existing_names:
        return candidate
    else:
        index = 1
        while True:
            s = f"{candidate}_{index}"
            if s not in existing_names:
                return s
            index += 1
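# For illustration: if the DataFrame already has columns "projection_x" and "projection_x_1",
# find_column_name(df.columns, "projection_x") returns "projection_x_2".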
# --- Hugging Face API Helpers for Dynamic UI ---
hf_api = HfApi()

def get_user_datasets(username: str):
    logging.info(f"Fetching datasets for user: {username}")
    if not username:
        return gr.update(choices=[], value=None, interactive=False)
    try:
        datasets = hf_api.list_datasets(author=username, full=True)
        dataset_ids = [d.id for d in datasets if not d.private]
        logging.info(f"Found {len(dataset_ids)} datasets for {username}.")
        return gr.update(choices=sorted(dataset_ids), value=None, interactive=True)
    except Exception as e:
        logging.error(f"Failed to fetch datasets for {username}: {e}")
        gr.Warning(f"Could not fetch datasets for user '{username}'.")
        return gr.update(choices=[], value=None, interactive=False)
def get_dataset_splits(dataset_id: str):
    logging.info(f"Fetching splits for dataset: {dataset_id}")
    if not dataset_id:
        return gr.update(choices=[], value=None, interactive=False)
    try:
        splits = get_dataset_split_names(dataset_id)
        logging.info(f"Found splits for {dataset_id}: {splits}")
        return gr.update(choices=splits, value=splits[0] if splits else None, interactive=True)
    except Exception as e:
        logging.error(f"Failed to fetch splits for {dataset_id}: {e}")
        gr.Warning(f"Could not fetch splits for dataset '{dataset_id}'.")
        return gr.update(choices=[], value=None, interactive=False)
def get_split_columns(dataset_id: str, split: str):
    logging.info(f"Fetching columns for: {dataset_id}, split: {split}")
    if not dataset_id or not split:
        return gr.update(choices=[], value=None, interactive=False)
    try:
        # Stream a single row so the schema can be inspected without downloading the whole split.
        dataset_sample = load_dataset(dataset_id, split=split, streaming=True)
        first_row = next(iter(dataset_sample))
        columns = list(first_row.keys())
        logging.info(f"Found columns: {columns}")
        # Prefer common text-bearing column names as the default selection.
        preferred_cols = ['text', 'content', 'instruction', 'question', 'document', 'prompt']
        best_col = next((col for col in preferred_cols if col in columns), columns[0] if columns else None)
        logging.info(f"Best default column chosen: {best_col}")
        return gr.update(choices=columns, value=best_col, interactive=True)
    except Exception as e:
        logging.error(f"Failed to get columns for {dataset_id}/{split}: {e}", exc_info=True)
        gr.Warning(f"Could not fetch columns for split '{split}'. Error: {e}")
        return gr.update(choices=[], value=None, interactive=False)
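# Each helper above returns a gr.update(...) patch: Gradio applies it to the target dropdown in
# place (new choices, new default value, interactivity) instead of rebuilding the component.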
# --- Main Atlas Generation Logic ---
def generate_atlas(
    dataset_name: str,
    split: str,
    text_column: str,
    sample_size: int,
    model_name: str,
    umap_neighbors: int,
    umap_min_dist: float,
    request: gr.Request,       # <<< STEP 1: ADD THE REQUEST OBJECT TO THE FUNCTION SIGNATURE
    progress=gr.Progress(),    # Gradio injects a progress tracker via this default instance.
):
    """
    Loads data, computes embeddings, and serves the Embedding Atlas UI.
    """
    if not all([dataset_name, split, text_column]):
        raise gr.Error("Please ensure a Dataset, Split, and Text Column are selected.")

    progress(0, desc=f"Loading dataset '{dataset_name}' [{split}]...")
    try:
        dataset = load_dataset(dataset_name, split=split)
        df = dataset.to_pandas()
    except Exception as e:
        raise gr.Error(f"Failed to load data. Error: {e}")

    if sample_size > 0 and sample_size < len(df):
        progress(0.1, desc=f"Sampling {sample_size} rows...")
        df = df.sample(n=sample_size, random_state=42).reset_index(drop=True)

    if text_column not in df.columns:
        raise gr.Error(f"Column '{text_column}' not found. Please select a valid column.")

    progress(0.2, desc="Computing embeddings and UMAP...")
    # Pick output column names that do not collide with existing dataset columns.
    x_col = find_column_name(df.columns, "projection_x")
    y_col = find_column_name(df.columns, "projection_y")
    neighbors_col = find_column_name(df.columns, "__neighbors")
    try:
        compute_text_projection(
            df, text_column, x=x_col, y=y_col, neighbors=neighbors_col, model=model_name,
            umap_args={"n_neighbors": umap_neighbors, "min_dist": umap_min_dist, "metric": "cosine", "random_state": 42},
        )
    except Exception as e:
        raise gr.Error(f"Failed to compute embeddings. Check model name or sample size. Error: {e}")

    progress(0.8, desc="Preparing Atlas data source...")
    id_col = find_column_name(df.columns, "_row_index")
    df[id_col] = range(df.shape[0])
    metadata = {"columns": {"id": id_col, "text": text_column, "embedding": {"x": x_col, "y": y_col}, "neighbors": neighbors_col}}
    # Derive a stable identifier for this dataset/split/column/model combination.
    hasher = Hasher()
    hasher.update(f"{dataset_name}-{split}-{text_column}-{sample_size}-{model_name}")
    identifier = hasher.hexdigest()
    atlas_dataset = DataSource(identifier, df, metadata)

    progress(0.9, desc="Mounting visualization UI...")
    static_path = str((pathlib.Path(__import__('embedding_atlas').__file__).parent / "static").resolve())
    mount_path = f"/{uuid.uuid4().hex}"
    atlas_app = make_server(atlas_dataset, static_path=static_path, duckdb_uri="wasm")

    # --- STEP 2: USE THE CORRECT MOUNT METHOD VIA THE REQUEST OBJECT ---
    logging.info(f"Mounting FastAPI app at path: {mount_path}")
    request.app.mount(mount_path, atlas_app)

    progress(1.0, desc="Done!")
    iframe_html = f"<iframe src='{mount_path}' width='100%' height='800px' frameborder='0'></iframe>"
    return gr.HTML(iframe_html)
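# Note: every successful click mounts a new sub-application at a fresh /<uuid> path and nothing is
# ever unmounted here, so a long-running Space will presumably accumulate mounted atlases (and
# their in-memory DataFrames) until it restarts.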
# --- Gradio UI Definition ---
with gr.Blocks(theme=gr.themes.Soft(), title="Embedding Atlas Explorer") as app:
    gr.Markdown("# Embedding Atlas Explorer")
    gr.Markdown("Interactively select and visualize any text-based dataset from the Hugging Face Hub.")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 1. Select Data")
            hf_user_input = gr.Textbox(label="Hugging Face User or Org Name", value="Trendyol", placeholder="e.g., 'gradio' or 'google'")
            dataset_input = gr.Dropdown(label="Select a Dataset", interactive=False)
            split_input = gr.Dropdown(label="Select a Split", interactive=False)
            text_column_input = gr.Dropdown(label="Select a Text Column", interactive=False)

            gr.Markdown("### 2. Configure Visualization")
            sample_size_input = gr.Slider(label="Number of Samples", minimum=0, maximum=10000, value=2000, step=100)
            with gr.Accordion("Advanced Settings", open=False):
                model_input = gr.Dropdown(label="Embedding Model", choices=["all-MiniLM-L6-v2", "all-mpnet-base-v2", "multi-qa-MiniLM-L6-cos-v1"], value="all-MiniLM-L6-v2")
                umap_neighbors_input = gr.Slider(label="UMAP Neighbors", minimum=2, maximum=100, value=15, step=1, info="Controls local vs. global structure.")
                umap_min_dist_input = gr.Slider(label="UMAP Min Distance", minimum=0.0, maximum=0.99, value=0.1, step=0.01, info="Controls how tightly points are packed.")

            generate_button = gr.Button("Generate Atlas", variant="primary")

        with gr.Column(scale=3):
            gr.Markdown("### 3. Explore Atlas")
            output_html = gr.HTML("<div style='display:flex; justify-content:center; align-items:center; height:800px; border: 1px solid #ddd; border-radius: 5px;'><p>Atlas will be displayed here after generation.</p></div>")
    # --- Chained Event Listeners for Dynamic UI ---
    hf_user_input.submit(fn=get_user_datasets, inputs=hf_user_input, outputs=dataset_input)
    dataset_input.change(fn=get_dataset_splits, inputs=dataset_input, outputs=split_input)
    split_input.change(fn=get_split_columns, inputs=[dataset_input, split_input], outputs=text_column_input)
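    # The listeners cascade: submitting a username (pressing Enter in the textbox) repopulates the
    # dataset dropdown, choosing a dataset repopulates the splits, and choosing a split repopulates
    # the column choices. The app.load call at the bottom runs the first step once on page load.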
    # --- Button Click Event ---
    # Note: neither `request` nor `progress` is added to the inputs list; Gradio injects both
    # automatically (the request via its type annotation, the progress tracker via its default value).
    generate_button.click(
        fn=generate_atlas,
        inputs=[dataset_input, split_input, text_column_input, sample_size_input, model_input, umap_neighbors_input, umap_min_dist_input],
        outputs=[output_html],
    )

    app.load(fn=get_user_datasets, inputs=hf_user_input, outputs=dataset_input)
| if __name__ == "__main__": | |
| app.launch(debug=True) |