Liu Yiwen committed · Commit f3718f0 · 0 parent(s)

initial commit
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,21 @@
+ ---
+ title: Lotsa Explorer
+ emoji: 📈
+ colorFrom: indigo
+ colorTo: blue
+ sdk: gradio
+ sdk_version: 4.36.0
+ app_file: app.py
+ pinned: false
+ short_description: A tool for interactive exploration of the lotsa dataset
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+ This project is a modified version of https://huggingface.co/spaces/lhoestq/datasets-explorer
+
+ Run:
+
+ ```bash
+ gradio app.py
+ ```
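
For local testing, a minimal sketch of launching the app directly is shown below (an illustration, not part of the commit; it relies on the `DEV` environment variable that app.py checks to pick the local URL and bind address):

```python
# Hypothetical local launch; app.py mounts the Gradio UI on a FastAPI app and serves it on port 7860.
import os
import subprocess

os.environ["DEV"] = "1"               # app.py: use http://127.0.0.1:7860 and bind to 127.0.0.1
subprocess.run(["python", "app.py"])  # runs uvicorn via the __main__ block of app.py
```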
__pycache__/comm_utils.cpython-311.pyc ADDED
Binary file (1.49 kB).
 
__pycache__/utils.cpython-311.pyc ADDED
Binary file (8.64 kB).
 
app.py ADDED
@@ -0,0 +1,393 @@
+ import copy
+ import os
+ import time
+ from functools import lru_cache, partial
+
+ import gradio as gr
+ import numpy as np
+ import pandas as pd
+ import pyarrow as pa
+ import pyarrow.parquet as pq
+ from tqdm.contrib.concurrent import thread_map
+ from fastapi import FastAPI, Response
+ import uvicorn
+ from hffs.fs import HfFileSystem
+ from datasets import Features, Image, Audio, Sequence
+ from typing import List, Tuple, Callable
+
+ from utils import ndarray_to_base64, clean_up_df, create_statistic, create_plot
+ from comm_utils import save_to_file, send_msg_to_server
+
+ class AppError(RuntimeError):
+     pass
+
+
+ APP_URL = "http://127.0.0.1:7860" if os.getenv("DEV") else "https://Kamarov-lotsa-explorer.hf.space"
+ PAGE_SIZE = 1
+ MAX_CACHED_BLOBS = PAGE_SIZE * 10
+ TIME_PLOTS_NUM = 1
+ _blobs_cache = {}
+
+
+ #####################################################
+ # Define routes for image and audio files
+ #####################################################
+
+ app = FastAPI()
+
+
+ @app.get(
+     "/image",
+     responses={200: {"content": {"image/png": {}}}},
+     response_class=Response,
+ )
+ def image(id: str):
+     blob = get_blob(id)
+     return Response(content=blob, media_type="image/png")
+
+
+ @app.get(
+     "/audio",
+     responses={200: {"content": {"audio/wav": {}}}},
+     response_class=Response,
+ )
+ def audio(id: str):
+     blob = get_blob(id)
+     return Response(content=blob, media_type="audio/wav")
+
+
+ def push_blob(blob: bytes, blob_id: str) -> str:
+     global _blobs_cache
+     if blob_id in _blobs_cache:
+         del _blobs_cache[blob_id]
+     _blobs_cache[blob_id] = blob
+     if len(_blobs_cache) > MAX_CACHED_BLOBS:
+         del _blobs_cache[next(iter(_blobs_cache))]
+     return blob_id
+
+
+ def get_blob(blob_id: str) -> bytes:
+     global _blobs_cache
+     return _blobs_cache[blob_id]
+
+
+ def blobs_to_urls(blobs: List[bytes], type: str, prefix: str) -> List[str]:
+     image_blob_ids = [push_blob(blob, f"{prefix}-{i}") for i, blob in enumerate(blobs)]
+     return [APP_URL + f"/{type}?id={blob_id}" for blob_id in image_blob_ids]
+
+
+ #####################################################
+ # List configs, splits and parquet files
+ #####################################################
+
+
+ @lru_cache(maxsize=128)
+ def get_parquet_fs(dataset: str) -> HfFileSystem:
+     try:
+         fs = HfFileSystem(dataset, repo_type="dataset", revision="refs/convert/parquet")
+         if any(fs.isfile(path) for path in fs.ls("") if not path.startswith(".")):
+             raise AppError(f"Parquet export doesn't exist for '{dataset}'.")
+         return fs
+     except Exception:
+         raise AppError(f"Parquet export doesn't exist for '{dataset}'.")
+
+
+
+ @lru_cache(maxsize=128)
+ def get_parquet_configs(dataset: str) -> List[str]:
+     fs = get_parquet_fs(dataset)
+     return [path for path in fs.ls("") if fs.isdir(path)]
+
+
+ def _sorted_split_key(split: str) -> str:
+     return split if not split.startswith("train") else chr(0) + split  # always "train" first
+
+
+ @lru_cache(maxsize=128)
+ def get_parquet_splits(dataset: str, config: str) -> List[str]:
+     fs = get_parquet_fs(dataset)
+     return [path.split("/")[1] for path in fs.ls(config) if fs.isdir(path)]
+
+
+ #####################################################
+ # Index and query Parquet data
+ #####################################################
+
+
+ RowGroupReaders = List[Callable[[], pa.Table]]
+
+
+ @lru_cache(maxsize=128)
+ def index(dataset: str, config: str, split: str) -> Tuple[np.ndarray, RowGroupReaders, int, Features]:
+     fs = get_parquet_fs(dataset)
+     sources = fs.glob(f"{config}/{split}/*.parquet")
+     if not sources:
+         if config not in get_parquet_configs(dataset):
+             raise AppError(f"Invalid config {config}. Available configs are: {', '.join(get_parquet_configs(dataset))}.")
+         else:
+             raise AppError(f"Invalid split {split}. Available splits are: {', '.join(get_parquet_splits(dataset, config))}.")
+     desc = f"{dataset}/{config}/{split}"
+     all_pf: List[pq.ParquetFile] = thread_map(partial(pq.ParquetFile, filesystem=fs), sources, desc=desc, unit="pq")
+     features = Features.from_arrow_schema(all_pf[0].schema.to_arrow_schema())
+     rg_offsets = np.cumsum([pf.metadata.row_group(i).num_rows for pf in all_pf for i in range(pf.metadata.num_row_groups)])
+     rg_readers = [partial(pf.read_row_group, i) for pf in all_pf for i in range(pf.metadata.num_row_groups)]
+     max_page = 1 + (rg_offsets[-1] - 1) // PAGE_SIZE
+     return rg_offsets, rg_readers, max_page, features
+
+
+ def query(page: int, page_size: int, rg_offsets: np.ndarray, rg_readers: RowGroupReaders) -> pd.DataFrame:
+     start_row, end_row = (page - 1) * page_size, min(page * page_size, rg_offsets[-1] - 1)  # both included
+     # rg_offsets[start_rg - 1] <= start_row < rg_offsets[start_rg]
+     # rg_offsets[end_rg - 1] <= end_row < rg_offsets[end_rg]
+     start_rg, end_rg = np.searchsorted(rg_offsets, [start_row, end_row], side="right")  # both included
+     t = time.time()
+     # TODO: performance bottleneck
+     pa_table = pa.concat_tables([rg_readers[i]() for i in range(start_rg, end_rg + 1)])
+     print(f"concat_tables time: {time.time()-t}")
+     offset = start_row - (rg_offsets[start_rg - 1] if start_rg > 0 else 0)
+     pa_table = pa_table.slice(offset, page_size)
+     return pa_table.to_pandas()
+
+
+ def sanitize_inputs(dataset: str, config: str, split: str, page: str) -> Tuple[str, str, str, int]:
+     try:
+         page = int(page)
+         assert page > 0
+     except Exception:
+         raise AppError(f"Bad page: {page}")
+     if not dataset:
+         raise AppError("Empty dataset name")
+     if not config:
+         raise AppError(f"Empty config. Available configs are: {', '.join(get_parquet_configs(dataset))}.")
+     if not split:
+         raise AppError(f"Empty split. Available splits are: {', '.join(get_parquet_splits(dataset, config))}.")
+     return dataset, config, split, int(page)
+
+
+ @lru_cache(maxsize=128)
+ def get_page_df(dataset: str, config: str, split: str, page: str) -> Tuple[pd.DataFrame, int, Features]:
+     dataset, config, split, page = sanitize_inputs(dataset, config, split, page)
+     rg_offsets, rg_readers, max_page, features = index(dataset, config, split)
+     if page > max_page:
+         raise AppError(f"Page {page} does not exist")
+     df = query(page, PAGE_SIZE, rg_offsets=rg_offsets, rg_readers=rg_readers)
+     return df, max_page, features
+
+
+ #####################################################
+ # Format results
+ #####################################################
+
+ def get_page(dataset: str, config: str, split: str, page: str) -> Tuple[str, int, str]:
+     df_, max_page, features = get_page_df(dataset, config, split, page)
+     df = copy.deepcopy(df_)
+
+     unsupported_columns = []
+     if dataset == 'Salesforce/lotsa_data':
+         # Special handling for the Salesforce/lotsa_data dataset
+         info = "" if not unsupported_columns else f"Some columns are not supported yet: {unsupported_columns}"
+         return df, max_page, info
+
+     elif dataset == 'YY26/TS_DATASETS':
+         # Special handling for the YY26/TS_DATASETS dataset
+         info = "" if not unsupported_columns else f"Some columns are not supported yet: {unsupported_columns}"
+         return df, max_page, info
+     else:
+         # Keep the original logic for other datasets
+         for column, feature in features.items():
+             if isinstance(feature, Image):
+                 blob_type = "image"  # TODO: support audio - right now it seems that the markdown renderer in gradio doesn't support audio and shows nothing
+                 blob_urls = blobs_to_urls([item.get("bytes") if isinstance(item, dict) else None for item in df[column]], blob_type, prefix=f"{dataset}-{config}-{split}-{page}-{column}")
+                 df = df.drop([column], axis=1)
+                 df[column] = [f"![]({url})" for url in blob_urls]
+             elif any(bad_type in str(feature) for bad_type in ["Image(", "Audio(", "'binary'"]):
+                 unsupported_columns.append(column)
+                 df = df.drop([column], axis=1)
+             elif isinstance(feature, Sequence):
+                 if feature.feature.dtype == 'float32':
+                     # Plot the values directly and embed them as Base64-encoded images
+                     base64_srcs = [ndarray_to_base64(vec) for vec in df[column]]
+                     df = df.drop([column], axis=1)
+                     df[column] = [f"![]({src})" for src in base64_srcs]
+         info = "" if not unsupported_columns else f"Some columns are not supported yet: {unsupported_columns}"
+         return df.reset_index().to_markdown(index=False), max_page, info
+
+ #####################################################
+ # Process data
+ #####################################################
+ def process_salesforce_data(dataset: str, config: str, split: str, pages: List[str], sub_targets: List[int|str]) -> Tuple[List[pd.DataFrame], List[str]]:
+     df_list, id_list = [], []
+
+     for page in pages:
+         df, max_page, info = get_page(dataset, config, split, page)
+
+         global tot_samples, tot_targets
+         tot_samples, tot_targets = max_page, len(df['target'][0]) if isinstance(df['target'][0], np.ndarray) and df['target'][0].dtype == 'O' else 1
+         if 'all' in sub_targets:
+             sub_targets = list(range(tot_targets))
+         df = clean_up_df(df, sub_targets)
+         row = df.iloc[0]
+         id_list.append(row['item_id'])
+         # Expand the single-row DataFrame into a new DataFrame
+         df_without_index = row.drop('item_id').to_frame().T
+         df_expanded = df_without_index.apply(pd.Series.explode).reset_index(drop=True).fillna(0)
+         df_list.append(df_expanded)
+     return df_list, id_list
+
+ #####################################################
+ # Gradio app
+ #####################################################
+
+
+ with gr.Blocks() as demo:
+     # Initialize the components
+     gr.Markdown("A tool for interactive exploration of the lotsa dataset, extended from lhoestq/datasets-explorer")
+     cp_dataset = gr.Textbox("YY26/TS_DATASETS", label="Pick a dataset", interactive=False)
+     cp_go = gr.Button("Explore")
+     cp_config = gr.Dropdown(["plain_text"], value="plain_text", label="Config", visible=False)
+     cp_split = gr.Dropdown(["train", "validation"], value="train", label="Split", visible=False)
+     cp_goto_next_page = gr.Button("Next page", visible=False)
+     cp_error = gr.Markdown("", visible=False)
+     cp_info = gr.Markdown("", visible=False)
+     cp_result = gr.Markdown("", visible=False)
+     tot_samples = 0
+     # Initialize the components used to display the Salesforce/lotsa_data dataset
+     # components = []
+     # for _ in range(TIME_PLOTS_NUM):
+     #     with gr.Row():
+     #         with gr.Column(scale=2):
+     #             select_sample_box = gr.Dropdown(choices=["items"], label="Select some items", multiselect=True, interactive=True)
+     #         with gr.Column(scale=2):
+     #             select_subtarget_box = gr.Dropdown(choices=["subtargets"], label="Select some subtargets", multiselect=True, interactive=True)
+     #         with gr.Column(scale=1):
+     #             select_buttom = gr.Button("Show selected items")
+     with gr.Row():
+         with gr.Column(scale=2):
+             statistics_textbox = gr.DataFrame()
+         with gr.Column(scale=3):
+             plot = gr.Plot()
+     with gr.Row():
+         user_input_box = gr.Textbox(placeholder="Enter some text", label="Input", lines=5, interactive=True)
+         user_output_box = gr.Textbox(label="Answer", lines=5, interactive=False)
+         user_io_buttom = gr.Button("Send", interactive=True)
+     # components.append({"select_sample_box": select_sample_box,
+     #                    "statistics_textbox": statistics_textbox,
+     #                    "user_input_box": user_input_box,
+     #                    "plot": plot})
+
+     with gr.Row():
+         cp_page = gr.Textbox("1", label="Page", placeholder="1", visible=False)
+         cp_goto_page = gr.Button("Go to page", visible=False)
+
+     def show_error(message: str) -> dict:
+         return {
+             cp_error: gr.update(visible=True, value=f"## ❌ Error:\n\n{message}"),
+             cp_info: gr.update(visible=False, value=""),
+             cp_result: gr.update(visible=False, value=""),
+         }
+
+
+     def show_dataset_at_config_and_split_and_page(dataset: str, config: str, split: str, page: str|List[str], sub_targets: List[int|str]=['all']) -> dict:
+         try:
+             ret = {}
+             if dataset == 'Salesforce/lotsa_data':
+                 # Special handling for the Salesforce/lotsa_data dataset
+                 if isinstance(page, str):
+                     page = [page]
+                 df_list, id_list = process_salesforce_data(dataset, config, split, page, sub_targets)
+                 ret[statistics_textbox] = gr.update(value=create_statistic(df_list, id_list))
+                 ret[plot] = gr.update(value=create_plot(df_list, id_list))
+             elif dataset == 'YY26/TS_DATASETS':
+                 df, max_page, info = get_page(dataset, config, split, page)
+                 lotsa_config, lotsa_split, lotsa_page = 'traffic_hourly', 'train', eval(df['ts_id'][0])
+                 # lotsa_subtargets = eval(df['target_id'][0])
+                 df_list, id_list = process_salesforce_data('Salesforce/lotsa_data', lotsa_config, lotsa_split, lotsa_page, [1])
+                 ret[statistics_textbox] = gr.update(value=create_statistic(df_list, id_list))
+                 ret[plot] = gr.update(value=create_plot(df_list, id_list))
+                 ret[user_input_box] = gr.update(value=df['question'][0])
+                 ret[user_output_box] = gr.update(value=df['answer'][0])
+             else:
+                 markdown_result, max_page, info = get_page(dataset, config, split, page)
+                 ret[cp_result] = gr.update(visible=True, value=markdown_result)
+             return {
+                 **ret,
+                 cp_info: gr.update(visible=True, value=f"Page {page}/{max_page} {info}"),
+                 cp_error: gr.update(visible=False, value="")
+             }
+         except AppError as err:
+             return show_error(str(err))
+
+     def show_dataset_at_config_and_split_and_next_page(dataset: str, config: str, split: str, page: str) -> dict:
+         try:
+             next_page = str(int(page) + 1)
+             return {
+                 **show_dataset_at_config_and_split_and_page(dataset, config, split, next_page),
+                 cp_page: gr.update(value=next_page, visible=True),
+             }
+         except AppError as err:
+             return show_error(str(err))
+
+     def show_dataset_at_config_and_split(dataset: str, config: str, split: str) -> dict:
+         try:
+             return {
+                 **show_dataset_at_config_and_split_and_page(dataset, config, split, "1", [0]),
+                 # select_sample_box: gr.update(choices=[f"{i+1}" for i in range(tot_samples)], value=["1"]),
+                 # select_subtarget_box: gr.update(choices=[i for i in range(tot_targets)]+['all'], value=[0]),
+                 cp_page: gr.update(value="1", visible=True),
+                 cp_goto_page: gr.update(visible=True),
+                 cp_goto_next_page: gr.update(visible=True),
+             }
+         except AppError as err:
+             return show_error(str(err))
+
+     def show_dataset_at_config(dataset: str, config: str) -> dict:
+         try:
+             splits = get_parquet_splits(dataset, config)
+             if not splits:
+                 raise AppError(f"Dataset {dataset} with config {config} has no splits.")
+             else:
+                 split = splits[0]
+             return {
+                 **show_dataset_at_config_and_split(dataset, config, split),
+                 cp_split: gr.update(value=split, choices=splits, visible=len(splits) > 1),
+             }
+         except AppError as err:
+             return show_error(str(err))
+
+     def show_dataset(dataset: str) -> dict:
+         try:
+             configs = get_parquet_configs(dataset)
+             if not configs:
+                 raise AppError(f"Dataset {dataset} has no configs.")
+             else:
+                 config = configs[0]
+             return {
+                 **show_dataset_at_config(dataset, config),
+                 cp_config: gr.update(value=config, choices=configs, visible=len(configs) > 1),
+             }
+         except AppError as err:
+             return show_error(str(err))
+
+     all_outputs = [cp_config, cp_split,
+                    cp_page, cp_goto_page, cp_goto_next_page,
+                    cp_result, cp_info, cp_error,
+                    # select_sample_box, select_subtarget_box,
+                    # select_buttom,
+                    statistics_textbox, plot,
+                    user_input_box, user_output_box]
+     cp_go.click(show_dataset, inputs=[cp_dataset], outputs=all_outputs)
+     cp_config.change(show_dataset_at_config, inputs=[cp_dataset, cp_config], outputs=all_outputs)
+     cp_split.change(show_dataset_at_config_and_split, inputs=[cp_dataset, cp_config, cp_split], outputs=all_outputs)
+     cp_goto_page.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
+     cp_goto_next_page.click(show_dataset_at_config_and_split_and_next_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
+     user_io_buttom.click(send_msg_to_server, inputs=[user_input_box], outputs=[user_output_box])
+     # select_buttom.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, select_sample_box, select_subtarget_box], outputs=all_outputs)


+ if __name__ == "__main__":
+
+     app = gr.mount_gradio_app(app, demo, path="/")
+     host = "127.0.0.1" if os.getenv("DEV") else "0.0.0.0"
+     # import subprocess
+     # subprocess.Popen(["python", "test_server.py"])
+     uvicorn.run(app, host=host, port=7860)
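
For orientation, a minimal sketch of the paging pipeline above (assuming the dataset's Parquet export exists on the Hub; `traffic_hourly` is the config hard-coded for the YY26/TS_DATASETS lookup, reused here purely as an example):

```python
# Hypothetical usage of the helpers defined in app.py; not part of the commit.
from app import get_page_df

df, max_page, features = get_page_df("Salesforce/lotsa_data", "traffic_hourly", "train", "1")
print(max_page, list(features))  # total number of pages and the column names
print(df.head())                 # one row per page, since PAGE_SIZE = 1
```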
comm_utils.py ADDED
@@ -0,0 +1,25 @@
+ import requests
+
+
+ API_URL = "http://127.0.0.1:5000/api/process"
+
+ def save_to_file(user_input):
+     with open("user_input.txt", "w") as file:
+         file.write(user_input)
+
+
+ def send_msg_to_server(input_text):
+     try:
+         # Build the request payload
+         payload = {"text": input_text}
+         headers = {"Content-Type": "application/json"}
+
+         # Send the request
+         response = requests.post(API_URL, json=payload, headers=headers)
+         response.raise_for_status()  # raise if the request failed
+
+         # Return the result from the response
+         result = response.json()  # the server is expected to return JSON
+         return result.get("processed_text", "No result returned.")
+     except requests.RequestException as e:
+         return f"Request failed: {e}"
gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ datasets==3.2.0
+ fastapi==0.115.6
+ git+https://github.com/huggingface/hffs.git@63298cde9f994a0ab16c3ba89c5f7a9d140f20b2
+ matplotlib==3.8.4
+ numpy==2.2.2
+ pandas==2.2.3
+ plotly==5.22.0
+ pyarrow==19.0.0
+ tqdm==4.67.1
+ uvicorn==0.34.0
+ fsspec[http]
+ tqdm
+ tabulate
+ flask==3.0.3
test_server.py ADDED
@@ -0,0 +1,17 @@
+ import time
+ from flask import Flask, request, jsonify
+
+ app = Flask(__name__)
+
+ @app.route('/api/process', methods=['POST'])
+ def process_text():
+     data = request.get_json()
+     input_text = data.get("text", "")
+
+     time.sleep(1)
+     processed_text = f"{input_text[::-1]}"
+
+     return jsonify({"processed_text": processed_text})
+
+ if __name__ == "__main__":
+     app.run(host="127.0.0.1", port=5000)
user_input.txt ADDED
@@ -0,0 +1 @@
+ 1234
utils.py ADDED
@@ -0,0 +1,157 @@
+ #####################################################
+ # Utils
+ #####################################################
+ # This file contains utility functions for data processing and plotting.
+
+ import base64
+ from io import BytesIO
+ from matplotlib import pyplot as plt
+ import pandas as pd
+ import plotly.graph_objects as go
+ import numpy as np
+
+
+ def ndarray_to_base64(ndarray):
+     """
+     Plot a 1-D np.ndarray and convert the figure to a Base64-encoded string.
+     """
+     # Create the plot
+     plt.figure(figsize=(8, 4))
+     plt.plot(ndarray)
+     plt.title("Vector Plot")
+     plt.xlabel("Index")
+     plt.ylabel("Value")
+     plt.tight_layout()
+
+     # Save the figure to an in-memory byte stream
+     buffer = BytesIO()
+     plt.savefig(buffer, format="png")
+     plt.close()
+     buffer.seek(0)
+
+     # Convert to a Base64 string
+     base64_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
+     return f"data:image/png;base64,{base64_str}"
+
+ def flatten_ndarray_column(df, column_name, rows_to_include):
+     """
+     Flatten a nested np.ndarray column into multiple columns, keeping only the specified rows.
+     """
+     def select_and_flatten(ndarray):
+         if isinstance(ndarray, np.ndarray) and ndarray.dtype == 'O':
+             selected = [ndarray[i] for i in rows_to_include if i < len(ndarray)]
+             return np.concatenate([select_and_flatten(subarray) for subarray in selected])
+         elif isinstance(ndarray, np.ndarray) and ndarray.ndim == 1:
+             return np.expand_dims(ndarray, axis=0)
+         return ndarray
+
+     selected_data = df[column_name].apply(select_and_flatten)
+
+     for i, index in enumerate(rows_to_include):
+         df[f'{column_name}_{index}'] = selected_data.apply(lambda x: x[i])
+
+     return df
+
+ def create_plot(dfs: list[pd.DataFrame], ids: list[str]):
+     """
+     Create a line plot containing all of the given DataFrames.
+     """
+     fig = go.Figure()
+     for df, df_id in zip(dfs, ids):
+         for i, column in enumerate(df.columns[1:]):
+             fig.add_trace(go.Scatter(
+                 x=df[df.columns[0]],
+                 y=df[column],
+                 mode='lines',
+                 name=f"item_{df_id} - {column}",
+                 visible=True if i == 0 else 'legendonly'
+             ))
+
+     # Configure the legend
+     fig.update_layout(
+         legend=dict(
+             title="Variables",
+             orientation="h",
+             yanchor="top",
+             y=-0.2,
+             xanchor="center",
+             x=0.5
+         ),
+         xaxis_title='Time',
+         yaxis_title='Values'
+     )
+     return fig
+
+ def create_statistic(dfs: list[pd.DataFrame], ids: list[str]):
+     """
+     Compute summary statistics for a list of DataFrames.
+     """
+     stats_list = []
+
+     for df, id in zip(dfs, ids):
+         df_values = df.iloc[:, 1:]
+         # Compute the statistics
+         mean_values = df_values.mean().round(2)
+         std_values = df_values.std().round(2)
+         max_values = df_values.max().round(2)
+         min_values = df_values.min().round(2)
+
+         # Combine the statistics into a new DataFrame
+         stats_df = pd.DataFrame({
+             'Variables': [f"{id}_{col}" for col in df_values.columns],
+             'mean': mean_values.values,
+             'std': std_values.values,
+             'max': max_values.values,
+             'min': min_values.values
+         })
+         stats_list.append(stats_df)
+
+     # Concatenate the per-DataFrame statistics
+     combined_stats_df = pd.concat(stats_list, ignore_index=True)
+     return combined_stats_df
+
+ def clean_up_df(df: pd.DataFrame, rows_to_include: list[int]) -> pd.DataFrame:
+     """
+     Clean up the dataset by flattening the nested np.ndarray column into multiple columns.
+     """
+     rows_to_include = sorted(rows_to_include)
+
+     df['timestamp'] = df.apply(lambda row: pd.date_range(
+         start=row['start'],
+         periods=len(row['target'][0]) if isinstance(row['target'][0], np.ndarray) else len(row['target']),
+         freq=row['freq']
+     ).to_pydatetime().tolist(), axis=1)
+     df = flatten_ndarray_column(df, 'target', rows_to_include)
+     # Drop the original start, freq and target columns
+     df.drop(columns=['start', 'freq', 'target'], inplace=True)
+     if 'past_feat_dynamic_real' in df.columns:
+         df.drop(columns=['past_feat_dynamic_real'], inplace=True)
+     return df
+
+ if __name__ == '__main__':
+
+     # Create test data
+     data1 = {
+         'Time': ['2023-01-01', '2023-01-02', '2023-01-03'],
+         'Value1': [10, 15, 20],
+         'Value2': [20, 25, 30]
+     }
+
+     data2 = {
+         'Time': ['2023-01-01', '2023-01-02', '2023-01-03'],
+         'Value3': [5, 10, 15],
+         'Value4': [15, 20, 25]
+     }
+
+     df1 = pd.DataFrame(data1)
+     df2 = pd.DataFrame(data2)
+
+     # Convert the time columns to datetime
+     df1['Time'] = pd.to_datetime(df1['Time'])
+     df2['Time'] = pd.to_datetime(df2['Time'])
+
+     # Create the figure (create_plot expects a list of DataFrames and a list of ids)
+     fig = create_plot([df1, df2], ["df1", "df2"])
+
+     # Show the figure
+     fig.show()
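
For reference, a minimal sketch of how app.py combines these helpers on a single flattened item (hypothetical data; the first column is used as the x-axis and the remaining columns become traces and statistics rows):

```python
# Hypothetical example; not part of the commit.
import pandas as pd
from utils import create_plot, create_statistic

item = pd.DataFrame({
    "timestamp": pd.date_range("2023-01-01", periods=3, freq="D"),
    "target_0": [10.0, 15.0, 20.0],
})
print(create_statistic([item], ["item_0"]))  # mean/std/max/min per value column
create_plot([item], ["item_0"]).show()       # one Plotly trace per value column
```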