Embedding-Atlas / app.py
broadfield-dev's picture
Update app.py
bb95205 verified
raw
history blame
7.33 kB
import gradio as gr
import pandas as pd
from datasets import load_dataset
import os
import pathlib
import uuid
# --- Embedding Atlas Imports ---
# We will import the necessary components directly from the library
from embedding_atlas.data_source import DataSource
from embedding_atlas.server import make_server
from embedding_atlas.projection import compute_text_projection
from embedding_atlas.utils import find_column_name, Hasher
# --- Global State ---
# Registry of the Atlas server apps created by generate_atlas(), keyed by the
# unique URL path each one is mounted under. Holding a reference per run keeps
# every instance distinct and leaves room for cleaning mounts up later.
mounted_apps = {}
def get_atlas_static_path():
    """Return the resolved path of the ``static`` directory next to the embedding_atlas package file."""
    # Imported locally so the package location is looked up only when needed.
    import embedding_atlas

    package_dir = pathlib.Path(embedding_atlas.__file__).parent
    static_dir = package_dir / "static"
    return str(static_dir.resolve())
def generate_atlas(
    dataset_name: str,
    text_column: str,
    split: str,
    sample_size: int,
    model_name: str,
    umap_neighbors: int,
    umap_min_dist: float,
    progress=gr.Progress(track_tqdm=True),
):
    """Load a dataset, project one text column to 2-D, and serve it in an Atlas iframe.

    Args:
        dataset_name: Name passed to ``datasets.load_dataset``.
        text_column: Column of the loaded dataset whose text is embedded.
        split: Dataset split passed to ``load_dataset``.
        sample_size: Number of rows to sample; ``0`` (or any value not smaller
            than the dataset) keeps every row.
        model_name: Embedding model name forwarded to ``compute_text_projection``.
        umap_neighbors: UMAP ``n_neighbors`` value.
        umap_min_dist: UMAP ``min_dist`` value.
        progress: Gradio progress tracker, injected by Gradio.

    Returns:
        A ``gr.HTML`` component containing an iframe that points at the
        freshly mounted Atlas server.

    Raises:
        gr.Error: If the dataset cannot be loaded, the column is missing, the
            projection fails, or the Gradio server is not available to mount on.
    """
    # Slider values may arrive as floats; pandas' sample(n=...) and UMAP's
    # n_neighbors both need real integers.
    sample_size = int(sample_size)
    umap_neighbors = int(umap_neighbors)

    # --- 1. Load Data ---
    progress(0, desc=f"Loading dataset '{dataset_name}'...")
    try:
        dataset = load_dataset(dataset_name, split=split)
        df = dataset.to_pandas()
    except Exception as e:
        raise gr.Error(
            f"Failed to load dataset. Please check the name and split. Error: {e}"
        ) from e

    # Fail fast on a bad column name before doing any further work.
    if text_column not in df.columns:
        available = ", ".join(map(str, df.columns))
        raise gr.Error(
            f"Column '{text_column}' not found in the dataset. Available columns: {available}"
        )

    # --- 2. Sample Data (if requested) ---
    if 0 < sample_size < len(df):
        progress(0.1, desc=f"Sampling {sample_size} rows...")
        # Fixed seed so the same settings always pick the same rows.
        df = df.sample(n=sample_size, random_state=42)

    # --- 3. Compute Embeddings & UMAP Projection ---
    progress(0.2, desc="Computing embeddings and UMAP projection. This may take a while...")
    # NOTE(review): find_column_name presumably returns a name that does not
    # collide with an existing column -- confirm against embedding_atlas.utils.
    x_column = find_column_name(df.columns, "projection_x")
    y_column = find_column_name(df.columns, "projection_y")
    neighbors_column = find_column_name(df.columns, "__neighbors")
    try:
        compute_text_projection(
            df,
            text_column,
            x=x_column,
            y=y_column,
            neighbors=neighbors_column,
            model=model_name,
            umap_args={
                "n_neighbors": umap_neighbors,
                "min_dist": umap_min_dist,
                "metric": "cosine",
                "random_state": 42,
            },
        )
    except Exception as e:
        raise gr.Error(
            f"Failed to compute embeddings. Check model name or try a smaller sample. Error: {e}"
        ) from e

    # --- 4. Prepare Atlas DataSource ---
    progress(0.8, desc="Preparing Atlas data source...")
    id_column = find_column_name(df.columns, "_row_index")
    df[id_column] = range(df.shape[0])
    metadata = {
        "columns": {
            "id": id_column,
            "text": text_column,
            "embedding": {"x": x_column, "y": y_column},
            "neighbors": neighbors_column,
        },
    }
    # Derive the data source identifier from the settings of this run.
    hasher = Hasher()
    hasher.update(f"{dataset_name}-{text_column}-{sample_size}-{model_name}")
    identifier = hasher.hexdigest()
    atlas_dataset = DataSource(identifier, df, metadata)
    static_path = get_atlas_static_path()

    # --- 5. Create and Mount the Atlas App ---
    progress(0.9, desc="Mounting visualization UI...")
    # A fresh random path per run avoids clashing with earlier mounts.
    mount_path = f"/{uuid.uuid4().hex}"
    atlas_app = make_server(atlas_dataset, static_path=static_path, duckdb_uri="wasm")

    # `app` is the module-level gr.Blocks. gr.mount_gradio_app() goes the other
    # way (it mounts a Blocks onto a FastAPI app), so the Atlas server is
    # mounted directly on Gradio's own server application instead.
    # NOTE(review): assumes Blocks exposes its FastAPI instance as `.app` once
    # launch() has run -- confirm against the installed Gradio version.
    server_app = getattr(app, "app", None)
    if server_app is None:
        raise gr.Error("The Gradio server is not running, so the Atlas UI cannot be mounted.")
    server_app.mount(mount_path, atlas_app)
    mounted_apps[mount_path] = atlas_app  # Kept for potential cleanup later.
    progress(1.0, desc="Done!")

    # --- 6. Return an IFrame pointing to the mounted path ---
    # The trailing slash targets the mount root directly, without relying on a
    # slash redirect.
    iframe_html = (
        f"<iframe src='{mount_path}/' width='100%' height='800px' frameborder='0'></iframe>"
    )
    return gr.HTML(iframe_html)
# --- Gradio UI Definition ---
with gr.Blocks(theme=gr.themes.Soft(), title="Embedding Atlas Explorer") as app:
    # Page header and short description.
    gr.Markdown("# Embedding Atlas Explorer")
    gr.Markdown(
        "Visualize any text column from a Hugging Face dataset. This app loads the data, "
        "computes embeddings using Sentence Transformers, reduces dimensionality with UMAP, "
        "and displays the result in an interactive Embedding Atlas."
    )

    with gr.Row():
        # Left column: every control that feeds generate_atlas().
        with gr.Column(scale=1):
            gr.Markdown("### 1. Configuration")
            dataset_input = gr.Textbox(
                value="Trendyol/Trendyol-Cybersecurity-Instruction-Tuning-Dataset",
                label="Hugging Face Dataset Name",
            )
            text_column_input = gr.Textbox(value="instruction", label="Text Column to Visualize")
            split_input = gr.Textbox(value="train", label="Dataset Split")
            sample_size_input = gr.Slider(
                minimum=0,
                maximum=5000,
                step=100,
                value=2000,
                label="Number of Samples (0 for all)",
            )

            # Model and UMAP tuning knobs, collapsed by default.
            with gr.Accordion("Advanced Settings", open=False):
                model_input = gr.Dropdown(
                    choices=["all-MiniLM-L6-v2", "all-mpnet-base-v2", "multi-qa-MiniLM-L6-cos-v1"],
                    value="all-MiniLM-L6-v2",
                    label="Embedding Model",
                )
                umap_neighbors_input = gr.Slider(
                    minimum=2,
                    maximum=100,
                    step=1,
                    value=15,
                    label="UMAP Neighbors",
                    info="Controls how UMAP balances local vs. global structure.",
                )
                umap_min_dist_input = gr.Slider(
                    minimum=0.0,
                    maximum=0.99,
                    step=0.01,
                    value=0.1,
                    label="UMAP Min Distance",
                    info="Controls how tightly UMAP packs points together.",
                )

            generate_button = gr.Button("Generate Atlas", variant="primary")

        # Right column: placeholder that the generated iframe replaces.
        with gr.Column(scale=3):
            gr.Markdown("### 2. Visualization")
            output_html = gr.HTML(
                "<div style='display:flex; justify-content:center; align-items:center; height:800px; border: 1px solid #ddd; border-radius: 5px;'><p>Atlas will be displayed here after generation.</p></div>"
            )

    # Input order must match the positional parameters of generate_atlas().
    generate_button.click(
        fn=generate_atlas,
        inputs=[
            dataset_input,
            text_column_input,
            split_input,
            sample_size_input,
            model_input,
            umap_neighbors_input,
            umap_min_dist_input,
        ],
        outputs=[output_html],
    )

# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    app.launch()