broadfield-dev committed on
Commit
bb95205
·
verified ·
1 Parent(s): 8904c8e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +191 -31
app.py CHANGED
@@ -1,43 +1,203 @@
1
- from datasets import load_dataset
2
  import pandas as pd
 
3
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- # --- Configuration ---
6
 
7
- # 1. Hardcode the name of the Hugging Face dataset
8
- dataset_name = "Trendyol/Trendyol-Cybersecurity-Instruction-Tuning-Dataset"
 
 
9
 
10
- # 2. Define the name for the local file where the data will be saved
11
- local_file_path = "trendyol_cybersecurity_dataset.csv"
12
 
13
- # 3. Define the port for the Embedding Atlas server
14
- port = 7860
 
 
 
 
 
 
15
 
16
- # --- Script Logic ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- # Step 1: Load the dataset from Hugging Face
19
- print(f"Loading dataset '{dataset_name}' from the Hub...")
20
- try:
21
- dataset = load_dataset(dataset_name, split="train")
22
- except Exception as e:
23
- print(f"Failed to load dataset. Error: {e}")
24
- exit()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- # Step 2: Convert the dataset to a Pandas DataFrame
27
- print("Converting dataset to Pandas DataFrame...")
28
- df = dataset.to_pandas()
29
 
30
- # Step 3: Save the DataFrame to a local CSV file
31
- # This is the crucial step. The CLI tool will read from this file.
32
- # We use index=False to avoid saving the pandas index as an extra column.
33
- print(f"Saving DataFrame to a local file: '{local_file_path}'")
34
- df.to_csv(local_file_path, index=False)
35
- print("Save complete.")
36
 
37
- # Step 4: Construct and run the CLI command using the LOCAL file path
38
- command = f"embedding-atlas {local_file_path} --port {port}"
39
- print(f"\nLaunching Embedding Atlas...")
40
- print(f"Running command: {command}")
41
- print(f"Access the UI in your browser at: http://127.0.0.1:{port}")
 
 
 
 
 
 
 
 
42
 
43
- os.system(command)
 
 
1
+ import gradio as gr
2
  import pandas as pd
3
+ from datasets import load_dataset
4
  import os
5
+ import pathlib
6
+ import uuid
7
+
8
+ # --- Embedding Atlas Imports ---
9
+ # We will import the necessary components directly from the library
10
+ from embedding_atlas.data_source import DataSource
11
+ from embedding_atlas.server import make_server
12
+ from embedding_atlas.projection import compute_text_projection
13
+ from embedding_atlas.utils import find_column_name, Hasher
14
+
15
+ # --- Global State ---
16
+ # We need to keep track of the mounted app to avoid errors on re-runs.
17
+ # A dictionary to store unique app instances for each run.
18
+ mounted_apps = {}
19
+
20
def get_atlas_static_path():
    """Return the absolute path of the ``static`` directory inside embedding_atlas.

    The package is imported lazily, the directory is located relative to the
    package's ``__file__``, and the resolved path is returned as a string.
    """
    import embedding_atlas

    package_dir = pathlib.Path(embedding_atlas.__file__).parent
    static_dir = package_dir / "static"
    return str(static_dir.resolve())
24
+
25
def generate_atlas(
    dataset_name: str,
    text_column: str,
    split: str,
    sample_size: int,
    model_name: str,
    umap_neighbors: int,
    umap_min_dist: float,
    progress=gr.Progress(track_tqdm=True),
):
    """Load a dataset, project one text column to 2-D, and serve it in an Atlas UI.

    Args:
        dataset_name: Hugging Face dataset identifier passed to ``load_dataset``.
        text_column: Name of the DataFrame column handed to
            ``compute_text_projection``.
        split: Dataset split to load (for example ``"train"``).
        sample_size: Number of rows to sample; ``0`` (or any value not smaller
            than the dataset) keeps every row.
        model_name: Embedding model name forwarded to ``compute_text_projection``.
        umap_neighbors: Value used for UMAP's ``n_neighbors``.
        umap_min_dist: Value used for UMAP's ``min_dist``.
        progress: Gradio progress tracker used to report each stage.

    Returns:
        A ``gr.HTML`` component holding an ``<iframe>`` that points at the
        freshly mounted Atlas server.

    Raises:
        gr.Error: If the dataset cannot be loaded, the column is missing, the
            projection fails, or the Gradio server is not available to mount on.
    """
    # The numeric inputs come from UI sliders; coerce defensively because
    # ``DataFrame.sample(n=...)`` and UMAP's ``n_neighbors`` expect integers.
    sample_size = int(sample_size)
    umap_neighbors = int(umap_neighbors)

    # --- 1. Load Data ---
    progress(0, desc=f"Loading dataset '{dataset_name}'...")
    try:
        dataset = load_dataset(dataset_name, split=split)
        df = dataset.to_pandas()
    except Exception as e:
        raise gr.Error(
            f"Failed to load dataset. Please check the name and split. Error: {e}"
        ) from e

    # Validate the requested column before doing any expensive work.
    if text_column not in df.columns:
        available = ", ".join(map(str, df.columns))
        raise gr.Error(
            f"Column '{text_column}' not found in the dataset. Available columns: {available}"
        )

    # --- 2. Sample Data (if requested) ---
    if 0 < sample_size < len(df):
        progress(0.1, desc=f"Sampling {sample_size} rows...")
        # Fixed seed so the same settings always pick the same rows.
        df = df.sample(n=sample_size, random_state=42)

    # --- 3. Compute Embeddings & UMAP Projection ---
    progress(0.2, desc="Computing embeddings and UMAP projection. This may take a while...")

    # Ask the library for output column names based on these defaults.
    x_column = find_column_name(df.columns, "projection_x")
    y_column = find_column_name(df.columns, "projection_y")
    neighbors_column = find_column_name(df.columns, "__neighbors")

    try:
        compute_text_projection(
            df,
            text_column,
            x=x_column,
            y=y_column,
            neighbors=neighbors_column,
            model=model_name,
            umap_args={
                "n_neighbors": umap_neighbors,
                "min_dist": umap_min_dist,
                "metric": "cosine",
                "random_state": 42,
            },
        )
    except Exception as e:
        raise gr.Error(
            f"Failed to compute embeddings. Check model name or try a smaller sample. Error: {e}"
        ) from e

    # --- 4. Prepare Atlas DataSource ---
    progress(0.8, desc="Preparing Atlas data source...")
    id_column = find_column_name(df.columns, "_row_index")
    df[id_column] = range(df.shape[0])

    metadata = {
        "columns": {
            "id": id_column,
            "text": text_column,
            "embedding": {"x": x_column, "y": y_column},
            "neighbors": neighbors_column,
        },
    }

    # Hash every setting that changes the resulting data so that two different
    # configurations never share the same dataset identifier.
    hasher = Hasher()
    hasher.update(
        f"{dataset_name}-{split}-{text_column}-{sample_size}-{model_name}"
        f"-{umap_neighbors}-{umap_min_dist}"
    )
    identifier = hasher.hexdigest()

    atlas_dataset = DataSource(identifier, df, metadata)
    static_path = get_atlas_static_path()

    # --- 5. Create and Mount the FastAPI App ---
    progress(0.9, desc="Mounting visualization UI...")

    # A fresh random path per run avoids clashing with earlier mounts.
    mount_path = f"/{uuid.uuid4().hex}"
    atlas_app = make_server(atlas_dataset, static_path=static_path, duckdb_uri="wasm")

    # ``gr.mount_gradio_app`` attaches a Blocks UI to a FastAPI app, not the
    # other way round, so the Atlas server is mounted on the FastAPI
    # application backing the running Gradio UI instead.
    # NOTE(review): assumes ``app.app`` is the FastAPI instance Gradio creates
    # in ``launch()`` - confirm against the installed Gradio version.
    host_app = getattr(app, "app", None)
    if host_app is None:
        raise gr.Error("The Gradio server is not running yet; cannot mount the Atlas UI.")
    host_app.mount(mount_path, atlas_app)

    mounted_apps[mount_path] = atlas_app  # Kept for potential cleanup later

    progress(1.0, desc="Done!")

    # --- 6. Return an IFrame pointing to the mounted path ---
    # The trailing slash targets the mounted sub-app's root directly instead of
    # relying on a redirect from the bare mount path.
    iframe_html = (
        f"<iframe src='{mount_path}/' width='100%' height='800px' frameborder='0'></iframe>"
    )
    return gr.HTML(iframe_html)
126
 
 
 
127
 
128
+ # --- Gradio UI Definition ---
129
# Literal values used by the UI, hoisted so the layout below reads as structure.
_DEFAULT_DATASET = "Trendyol/Trendyol-Cybersecurity-Instruction-Tuning-Dataset"
_EMBEDDING_MODELS = ["all-MiniLM-L6-v2", "all-mpnet-base-v2", "multi-qa-MiniLM-L6-cos-v1"]
_INTRO_TEXT = (
    "Visualize any text column from a Hugging Face dataset. This app loads the data, "
    "computes embeddings using Sentence Transformers, reduces dimensionality with UMAP, "
    "and displays the result in an interactive Embedding Atlas."
)
_PLACEHOLDER_HTML = (
    "<div style='display:flex; justify-content:center; align-items:center; height:800px; "
    "border: 1px solid #ddd; border-radius: 5px;'>"
    "<p>Atlas will be displayed here after generation.</p></div>"
)

with gr.Blocks(theme=gr.themes.Soft(), title="Embedding Atlas Explorer") as app:
    gr.Markdown("# Embedding Atlas Explorer")
    gr.Markdown(_INTRO_TEXT)

    with gr.Row():
        # Left column: every input that feeds generate_atlas, plus the button.
        with gr.Column(scale=1):
            gr.Markdown("### 1. Configuration")
            dataset_input = gr.Textbox(
                label="Hugging Face Dataset Name",
                value=_DEFAULT_DATASET,
            )
            text_column_input = gr.Textbox(
                label="Text Column to Visualize",
                value="instruction",
            )
            split_input = gr.Textbox(label="Dataset Split", value="train")
            sample_size_input = gr.Slider(
                label="Number of Samples (0 for all)",
                minimum=0,
                maximum=5000,
                value=2000,
                step=100,
            )

            with gr.Accordion("Advanced Settings", open=False):
                model_input = gr.Dropdown(
                    label="Embedding Model",
                    choices=_EMBEDDING_MODELS,
                    value="all-MiniLM-L6-v2",
                )
                umap_neighbors_input = gr.Slider(
                    label="UMAP Neighbors",
                    minimum=2,
                    maximum=100,
                    value=15,
                    step=1,
                    info="Controls how UMAP balances local vs. global structure.",
                )
                umap_min_dist_input = gr.Slider(
                    label="UMAP Min Distance",
                    minimum=0.0,
                    maximum=0.99,
                    value=0.1,
                    step=0.01,
                    info="Controls how tightly UMAP packs points together.",
                )

            generate_button = gr.Button("Generate Atlas", variant="primary")

        # Right column: placeholder that the click handler's output replaces.
        with gr.Column(scale=3):
            gr.Markdown("### 2. Visualization")
            output_html = gr.HTML(_PLACEHOLDER_HTML)

    # Inputs are listed in the same order as generate_atlas's parameters.
    atlas_inputs = [
        dataset_input,
        text_column_input,
        split_input,
        sample_size_input,
        model_input,
        umap_neighbors_input,
        umap_min_dist_input,
    ]
    generate_button.click(
        fn=generate_atlas,
        inputs=atlas_inputs,
        outputs=[output_html],
    )

if __name__ == "__main__":
    app.launch()