"""Embedding Atlas Explorer.

A Gradio app that loads a text column from a Hugging Face dataset, computes
text embeddings and a UMAP projection for it, and displays the result in an
embedded Embedding Atlas UI.
"""

import gradio as gr
import pandas as pd
from datasets import load_dataset
import os
import pathlib
import uuid

# --- Embedding Atlas Imports ---
# We will import the necessary components directly from the library
from embedding_atlas.data_source import DataSource
from embedding_atlas.server import make_server
from embedding_atlas.projection import compute_text_projection
from embedding_atlas.utils import find_column_name, Hasher

# --- Global State ---
# Registry of the Atlas server apps mounted so far, keyed by the unique URL
# path each one was mounted at. generate_atlas() adds one entry per run;
# nothing in this file removes entries, so it grows for the process lifetime.
mounted_apps = {}

def get_atlas_static_path():
    """Return the absolute path of the `static` directory inside the installed
    `embedding_atlas` package, as a string."""
    # Imported lazily so the package location is looked up at call time.
    import embedding_atlas

    package_dir = pathlib.Path(embedding_atlas.__file__).parent
    static_dir = package_dir / "static"
    return str(static_dir.resolve())

def generate_atlas(
    dataset_name: str,
    text_column: str,
    split: str,
    sample_size: int,
    model_name: str,
    umap_neighbors: int,
    umap_min_dist: float,
    progress=gr.Progress(track_tqdm=True)
):
    """
    Loads data, computes embeddings, and serves the Embedding Atlas UI.

    Args:
        dataset_name: Hugging Face dataset name passed to `load_dataset`.
        text_column: Name of the dataframe column whose text is embedded.
        split: Dataset split passed to `load_dataset`.
        sample_size: Number of rows to sample; 0 (or a value not smaller than
            the dataset) keeps every row.
        model_name: Embedding model name passed to `compute_text_projection`.
        umap_neighbors: UMAP `n_neighbors` value.
        umap_min_dist: UMAP `min_dist` value.
        progress: Gradio progress tracker used to report each stage.

    Returns:
        A `gr.HTML` component containing an iframe that points at the path
        where the Atlas server for this run was mounted.

    Raises:
        gr.Error: If the dataset cannot be loaded, the text column is missing,
            the projection fails, or the Gradio server app is not available
            for mounting.
    """
    # Sliders may deliver numeric values as floats; DataFrame.sample(n=...)
    # and UMAP's n_neighbors need integers.
    sample_size = int(sample_size)
    umap_neighbors = int(umap_neighbors)

    # --- 1. Load Data ---
    progress(0, desc=f"Loading dataset '{dataset_name}'...")
    try:
        # Load the dataset from Hugging Face
        dataset = load_dataset(dataset_name, split=split)
        df = dataset.to_pandas()
    except Exception as e:
        raise gr.Error(f"Failed to load dataset. Please check the name and split. Error: {e}") from e

    # Check that the text column exists before doing any further work
    if text_column not in df.columns:
        raise gr.Error(f"Column '{text_column}' not found in the dataset. Available columns: {', '.join(df.columns)}")

    # --- 2. Sample Data (if requested) ---
    if 0 < sample_size < len(df):
        progress(0.1, desc=f"Sampling {sample_size} rows...")
        # Fixed seed so the same settings pick the same rows on every run.
        df = df.sample(n=sample_size, random_state=42)

    # --- 3. Compute Embeddings & UMAP Projection ---
    progress(0.2, desc="Computing embeddings and UMAP projection. This may take a while...")

    # Column names for the projection outputs, obtained from find_column_name
    # (presumably chosen so they do not clash with existing columns — confirm).
    x_column = find_column_name(df.columns, "projection_x")
    y_column = find_column_name(df.columns, "projection_y")
    neighbors_column = find_column_name(df.columns, "__neighbors")

    try:
        compute_text_projection(
            df,
            text_column,
            x=x_column,
            y=y_column,
            neighbors=neighbors_column,
            model=model_name,
            umap_args={
                "n_neighbors": umap_neighbors,
                "min_dist": umap_min_dist,
                "metric": "cosine",
                "random_state": 42,
            },
        )
    except Exception as e:
        raise gr.Error(f"Failed to compute embeddings. Check model name or try a smaller sample. Error: {e}") from e

    # --- 4. Prepare Atlas DataSource ---
    progress(0.8, desc="Preparing Atlas data source...")
    id_column = find_column_name(df.columns, "_row_index")
    # Sampling leaves the original index in place, so add a dense 0..n-1 id.
    df[id_column] = range(df.shape[0])

    metadata = {
        "columns": {
            "id": id_column,
            "text": text_column,
            "embedding": {"x": x_column, "y": y_column},
            "neighbors": neighbors_column,
        },
    }

    # Create an identifier for the dataset from the settings that define it
    hasher = Hasher()
    hasher.update(f"{dataset_name}-{text_column}-{sample_size}-{model_name}")
    identifier = hasher.hexdigest()

    atlas_dataset = DataSource(identifier, df, metadata)
    static_path = get_atlas_static_path()

    # --- 5. Create and Mount the FastAPI App ---
    progress(0.9, desc="Mounting visualization UI...")

    # Generate a unique path for this instance to avoid conflicts on remounting
    mount_path = f"/{uuid.uuid4().hex}"

    # Create the server instance
    atlas_app = make_server(atlas_dataset, static_path=static_path, duckdb_uri="wasm")

    # `gr.mount_gradio_app` mounts a Blocks onto a FastAPI app, which is the
    # opposite of what is needed here: the Atlas server has to be mounted on
    # the FastAPI app that serves the module-level `app` Blocks.
    # NOTE(review): relies on Blocks exposing the running FastAPI app as
    # `.app` (or `.server_app` on older releases) after launch(), and on no
    # earlier Gradio route shadowing the mount — confirm for the pinned version.
    host_app = getattr(app, "app", None) or getattr(app, "server_app", None)
    if host_app is None:
        raise gr.Error("The Gradio server app is not available, so the Atlas UI could not be mounted.")
    host_app.mount(mount_path, atlas_app)

    mounted_apps[mount_path] = atlas_app # Store it for potential cleanup later

    progress(1.0, desc="Done!")

    # --- 6. Return an IFrame pointing to the mounted path ---
    iframe_html = f"<iframe src='{mount_path}' width='100%' height='800px' frameborder='0'></iframe>"

    return gr.HTML(iframe_html)


# --- Gradio UI Definition ---
# Static copy and option lists, kept apart from the layout code below.
_INTRO_TEXT = (
    "Visualize any text column from a Hugging Face dataset. This app loads the data, "
    "computes embeddings using Sentence Transformers, reduces dimensionality with UMAP, "
    "and displays the result in an interactive Embedding Atlas."
)
_PLACEHOLDER_HTML = "<div style='display:flex; justify-content:center; align-items:center; height:800px; border: 1px solid #ddd; border-radius: 5px;'><p>Atlas will be displayed here after generation.</p></div>"
_EMBEDDING_MODELS = ["all-MiniLM-L6-v2", "all-mpnet-base-v2", "multi-qa-MiniLM-L6-cos-v1"]

with gr.Blocks(title="Embedding Atlas Explorer", theme=gr.themes.Soft()) as app:
    # Page header and short description.
    gr.Markdown("# Embedding Atlas Explorer")
    gr.Markdown(_INTRO_TEXT)

    with gr.Row():
        # Narrow column: every input that generate_atlas() consumes.
        with gr.Column(scale=1):
            gr.Markdown("### 1. Configuration")
            dataset_input = gr.Textbox(
                value="Trendyol/Trendyol-Cybersecurity-Instruction-Tuning-Dataset",
                label="Hugging Face Dataset Name",
            )
            text_column_input = gr.Textbox(
                value="instruction",
                label="Text Column to Visualize",
            )
            split_input = gr.Textbox(value="train", label="Dataset Split")
            sample_size_input = gr.Slider(
                minimum=0,
                maximum=5000,
                step=100,
                value=2000,
                label="Number of Samples (0 for all)",
            )

            # Model and UMAP tuning knobs stay collapsed by default.
            with gr.Accordion("Advanced Settings", open=False):
                model_input = gr.Dropdown(
                    choices=_EMBEDDING_MODELS,
                    value="all-MiniLM-L6-v2",
                    label="Embedding Model",
                )
                umap_neighbors_input = gr.Slider(
                    minimum=2,
                    maximum=100,
                    step=1,
                    value=15,
                    label="UMAP Neighbors",
                    info="Controls how UMAP balances local vs. global structure.",
                )
                umap_min_dist_input = gr.Slider(
                    minimum=0.0,
                    maximum=0.99,
                    step=0.01,
                    value=0.1,
                    label="UMAP Min Distance",
                    info="Controls how tightly UMAP packs points together.",
                )

            generate_button = gr.Button("Generate Atlas", variant="primary")

        # Wide column: shows a placeholder until an atlas has been generated.
        with gr.Column(scale=3):
            gr.Markdown("### 2. Visualization")
            output_html = gr.HTML(_PLACEHOLDER_HTML)

    # Listed in the same order as generate_atlas()'s positional parameters.
    atlas_inputs = [
        dataset_input,
        text_column_input,
        split_input,
        sample_size_input,
        model_input,
        umap_neighbors_input,
        umap_min_dist_input,
    ]
    generate_button.click(fn=generate_atlas, inputs=atlas_inputs, outputs=[output_html])

# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    app.launch()