Spaces:
Sleeping
Sleeping
File size: 7,334 Bytes
bb95205 fe5ff1b bb95205 5d7eb35 bb95205 7656238 bb95205 8904c8e bb95205 fe5ff1b 8904c8e bb95205 fe5ff1b bb95205 8904c8e bb95205 8904c8e bb95205 fe5ff1b bb95205 8904c8e bb95205 c5a0831 bb95205 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 |
import gradio as gr
import pandas as pd
from datasets import load_dataset
import os
import pathlib
import uuid
# --- Embedding Atlas Imports ---
# We will import the necessary components directly from the library
from embedding_atlas.data_source import DataSource
from embedding_atlas.server import make_server
from embedding_atlas.projection import compute_text_projection
from embedding_atlas.utils import find_column_name, Hasher
# --- Global State ---
# We need to keep track of the mounted app to avoid errors on re-runs.
# A dictionary to store unique app instances for each run.
mounted_apps = {}
def get_atlas_static_path():
    """Return the resolved path of the ``static`` folder inside the embedding_atlas package.

    The location is derived from the installed package's ``__file__``; the
    result is returned as a plain string.
    """
    # Imported here because only the package's on-disk location is needed.
    import embedding_atlas

    package_dir = pathlib.Path(embedding_atlas.__file__).parent
    static_dir = package_dir / "static"
    return str(static_dir.resolve())
def generate_atlas(
    dataset_name: str,
    text_column: str,
    split: str,
    sample_size: int,
    model_name: str,
    umap_neighbors: int,
    umap_min_dist: float,
    progress=gr.Progress(track_tqdm=True),
):
    """Load a dataset, project one text column to 2-D, and serve it in Embedding Atlas.

    Args:
        dataset_name: Name passed to ``datasets.load_dataset``.
        text_column: Column of the loaded data whose text is embedded.
        split: Dataset split passed to ``load_dataset``.
        sample_size: Number of rows to sample; ``0`` (or a value not smaller
            than the dataset) keeps every row.
        model_name: Embedding model name forwarded to ``compute_text_projection``.
        umap_neighbors: UMAP ``n_neighbors`` value.
        umap_min_dist: UMAP ``min_dist`` value.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        A ``gr.HTML`` component containing an iframe that points at the
        freshly mounted Atlas server.

    Raises:
        gr.Error: If the dataset cannot be loaded, the text column is missing,
            the projection fails, or the Gradio server is not running.
    """
    # Slider components may deliver floats; pandas' ``sample(n=...)`` only
    # accepts integers, and UMAP's neighbor count is an integer as well.
    sample_size = int(sample_size)
    umap_neighbors = int(umap_neighbors)

    # --- 1. Load Data ---
    progress(0, desc=f"Loading dataset '{dataset_name}'...")
    try:
        # Load the dataset from Hugging Face
        dataset = load_dataset(dataset_name, split=split)
        df = dataset.to_pandas()
    except Exception as e:
        raise gr.Error(
            f"Failed to load dataset. Please check the name and split. Error: {e}"
        ) from e

    # --- 2. Sample Data (if requested) ---
    if 0 < sample_size < len(df):
        progress(0.1, desc=f"Sampling {sample_size} rows...")
        df = df.sample(n=sample_size, random_state=42)

    # Check if the text column exists
    if text_column not in df.columns:
        raise gr.Error(
            f"Column '{text_column}' not found in the dataset. "
            f"Available columns: {', '.join(df.columns)}"
        )

    # --- 3. Compute Embeddings & UMAP Projection ---
    progress(0.2, desc="Computing embeddings and UMAP projection. This may take a while...")
    # NOTE(review): find_column_name presumably returns a name that does not
    # clash with an existing column -- confirm against embedding_atlas.utils.
    x_column = find_column_name(df.columns, "projection_x")
    y_column = find_column_name(df.columns, "projection_y")
    neighbors_column = find_column_name(df.columns, "__neighbors")
    try:
        # The return value is unused; the projection columns named above are
        # expected to be written into ``df`` by this call.
        compute_text_projection(
            df,
            text_column,
            x=x_column,
            y=y_column,
            neighbors=neighbors_column,
            model=model_name,
            umap_args={
                "n_neighbors": umap_neighbors,
                "min_dist": umap_min_dist,
                "metric": "cosine",
                "random_state": 42,
            },
        )
    except Exception as e:
        raise gr.Error(
            f"Failed to compute embeddings. Check model name or try a smaller sample. Error: {e}"
        ) from e

    # --- 4. Prepare Atlas DataSource ---
    progress(0.8, desc="Preparing Atlas data source...")
    id_column = find_column_name(df.columns, "_row_index")
    df[id_column] = range(df.shape[0])
    metadata = {
        "columns": {
            "id": id_column,
            "text": text_column,
            "embedding": {"x": x_column, "y": y_column},
            "neighbors": neighbors_column,
        },
    }

    # Create an identifier for the dataset from the run's main settings
    hasher = Hasher()
    hasher.update(f"{dataset_name}-{text_column}-{sample_size}-{model_name}")
    identifier = hasher.hexdigest()
    atlas_dataset = DataSource(identifier, df, metadata)
    static_path = get_atlas_static_path()

    # --- 5. Create and Mount the FastAPI App ---
    progress(0.9, desc="Mounting visualization UI...")
    # Generate a unique path for this instance to avoid conflicts on remounting
    mount_path = f"/{uuid.uuid4().hex}"
    # Create the server instance
    atlas_app = make_server(atlas_dataset, static_path=static_path, duckdb_uri="wasm")

    # ``gr.mount_gradio_app`` attaches a Blocks to a FastAPI app, which is the
    # opposite of what is needed here. Instead, mount the Atlas server as a
    # sub-application of the FastAPI instance that Gradio exposes as
    # ``Blocks.app`` once the server has been launched.
    fastapi_app = getattr(app, "app", None)
    if fastapi_app is None:
        raise gr.Error("The Gradio server is not running yet; cannot mount the Atlas UI.")
    fastapi_app.mount(mount_path, atlas_app)
    mounted_apps[mount_path] = atlas_app  # Store it for potential cleanup later

    progress(1.0, desc="Done!")

    # --- 6. Return an IFrame pointing to the mounted path ---
    iframe_html = f"<iframe src='{mount_path}' width='100%' height='800px' frameborder='0'></iframe>"
    return gr.HTML(iframe_html)
# --- Gradio UI Definition ---
# Long literal strings are kept out of the layout code for readability.
_INTRO_MARKDOWN = (
    "Visualize any text column from a Hugging Face dataset. This app loads the data, "
    "computes embeddings using Sentence Transformers, reduces dimensionality with UMAP, "
    "and displays the result in an interactive Embedding Atlas."
)
_PLACEHOLDER_HTML = "<div style='display:flex; justify-content:center; align-items:center; height:800px; border: 1px solid #ddd; border-radius: 5px;'><p>Atlas will be displayed here after generation.</p></div>"
_EMBEDDING_MODEL_CHOICES = ["all-MiniLM-L6-v2", "all-mpnet-base-v2", "multi-qa-MiniLM-L6-cos-v1"]

with gr.Blocks(theme=gr.themes.Soft(), title="Embedding Atlas Explorer") as app:
    gr.Markdown("# Embedding Atlas Explorer")
    gr.Markdown(_INTRO_MARKDOWN)

    with gr.Row():
        # Left column: every setting that feeds generate_atlas.
        with gr.Column(scale=1):
            gr.Markdown("### 1. Configuration")
            dataset_input = gr.Textbox(
                value="Trendyol/Trendyol-Cybersecurity-Instruction-Tuning-Dataset",
                label="Hugging Face Dataset Name",
            )
            text_column_input = gr.Textbox(
                value="instruction",
                label="Text Column to Visualize",
            )
            split_input = gr.Textbox(value="train", label="Dataset Split")
            sample_size_input = gr.Slider(
                minimum=0,
                maximum=5000,
                step=100,
                value=2000,
                label="Number of Samples (0 for all)",
            )

            with gr.Accordion("Advanced Settings", open=False):
                model_input = gr.Dropdown(
                    choices=_EMBEDDING_MODEL_CHOICES,
                    value="all-MiniLM-L6-v2",
                    label="Embedding Model",
                )
                umap_neighbors_input = gr.Slider(
                    minimum=2,
                    maximum=100,
                    step=1,
                    value=15,
                    label="UMAP Neighbors",
                    info="Controls how UMAP balances local vs. global structure.",
                )
                umap_min_dist_input = gr.Slider(
                    minimum=0.0,
                    maximum=0.99,
                    step=0.01,
                    value=0.1,
                    label="UMAP Min Distance",
                    info="Controls how tightly UMAP packs points together.",
                )

            generate_button = gr.Button("Generate Atlas", variant="primary")

        # Right column: placeholder that is replaced by the Atlas iframe.
        with gr.Column(scale=3):
            gr.Markdown("### 2. Visualization")
            output_html = gr.HTML(_PLACEHOLDER_HTML)

    # Order matches the positional parameters of generate_atlas.
    _atlas_inputs = [
        dataset_input,
        text_column_input,
        split_input,
        sample_size_input,
        model_input,
        umap_neighbors_input,
        umap_min_dist_input,
    ]
    generate_button.click(
        fn=generate_atlas,
        inputs=_atlas_inputs,
        outputs=[output_html],
    )
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script.
    app.launch()