import base64
import copy
from datetime import datetime, timedelta
from io import BytesIO
import random
import gradio as gr
from functools import lru_cache
from hffs.fs import HfFileSystem
from typing import List, Tuple, Callable
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
from functools import partial
from tqdm.contrib.concurrent import thread_map
from datasets import Features, Image, Audio, Sequence
from fastapi import FastAPI, Response
import uvicorn
import os
from gradio_datetimerange import DateTimeRange

class AppError(RuntimeError):
    pass


APP_URL = "http://127.0.0.1:7860" if os.getenv("DEV") else "https://lhoestq-datasets-explorer.hf.space"
PAGE_SIZE = 5
MAX_CACHED_BLOBS = PAGE_SIZE * 10
TIME_PLOTS_NUM = 5

_blobs_cache = {}


#####################################################
# Utils
#####################################################
def ndarray_to_base64(ndarray):
    """
    Plot a 1-D np.ndarray and return it as a Base64-encoded PNG data URI.
    """
    # Create the plot
    plt.figure(figsize=(8, 4))
    plt.plot(ndarray)
    plt.title("Vector Plot")
    plt.xlabel("Index")
    plt.ylabel("Value")
    plt.tight_layout()

    # Save the figure to an in-memory byte buffer
    buffer = BytesIO()
    plt.savefig(buffer, format="png")
    plt.close()
    buffer.seek(0)

    # Convert to a Base64 string
    base64_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return f"data:image/png;base64,{base64_str}"

def flatten_ndarray_column(df, column_name):
    def flatten_ndarray(ndarray):
        if isinstance(ndarray, np.ndarray) and ndarray.dtype == 'O':
            return np.concatenate([flatten_ndarray(subarray) for subarray in ndarray])
        elif isinstance(ndarray, np.ndarray) and ndarray.ndim == 1:
            return np.expand_dims(ndarray, axis=0)
        return ndarray

    flattened_data = df[column_name].apply(flatten_ndarray)
    max_length = max(flattened_data.apply(len))

    for i in range(max_length):
        df[f'{column_name}_{i}'] = flattened_data.apply(lambda x: x[i] if i < len(x) else np.nan)

    return df
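
# Illustrative sketch (not called by the app): how flatten_ndarray_column expands a
# nested object-dtype "target" column into target_0, target_1, ... columns holding one
# series per variable. The toy data below is hypothetical.
def _flatten_ndarray_column_example() -> pd.DataFrame:
    multivariate = np.empty(2, dtype="O")  # 2 variables, 3 time steps each
    multivariate[0], multivariate[1] = np.arange(3.0), np.arange(3.0) + 10.0
    toy = pd.DataFrame({"target": [multivariate, np.arange(3.0)]})  # one multivariate row, one univariate row
    # Adds target_0 (both rows) and target_1 (NaN for the univariate row).
    return flatten_ndarray_column(toy, "target")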

#####################################################
# Define routes for image and audio files
#####################################################

app = FastAPI()


# Register the endpoints that blobs_to_urls links to.
@app.get("/image")
def image(id: str):
    blob = get_blob(id)
    return Response(content=blob, media_type="image/png")


@app.get("/audio")
def audio(id: str):
    blob = get_blob(id)
    return Response(content=blob, media_type="audio/wav")

def push_blob(blob: bytes, blob_id: str) -> str:
    global _blobs_cache
    if blob_id in _blobs_cache:
        del _blobs_cache[blob_id]
    _blobs_cache[blob_id] = blob
    if len(_blobs_cache) > MAX_CACHED_BLOBS:
        del _blobs_cache[next(iter(_blobs_cache))]
    return blob_id


def get_blob(blob_id: str) -> bytes:
    global _blobs_cache
    return _blobs_cache[blob_id]


def blobs_to_urls(blobs: List[bytes], type: str, prefix: str) -> List[str]:
    image_blob_ids = [push_blob(blob, f"{prefix}-{i}") for i, blob in enumerate(blobs)]
    return [APP_URL + f"/{type}?id={blob_id}" for blob_id in image_blob_ids]
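
# Illustrative sketch of the blob cache behaviour (hypothetical ids, not called by the
# app, assuming an otherwise empty cache): a plain dict acts as an insertion-ordered
# cache, re-pushing an id refreshes it, and the oldest entry is dropped once more than
# MAX_CACHED_BLOBS blobs are stored.
def _blob_cache_example() -> None:
    push_blob(b"first", "demo-0")
    push_blob(b"second", "demo-1")
    push_blob(b"first, refreshed", "demo-0")  # deleting + re-inserting moves it to the end
    assert get_blob("demo-0") == b"first, refreshed"
    assert next(iter(_blobs_cache)) == "demo-1"  # "demo-1" is now the oldest entry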

#####################################################
# List configs, splits and parquet files
#####################################################

def get_parquet_fs(dataset: str) -> HfFileSystem:
    try:
        fs = HfFileSystem(dataset, repo_type="dataset", revision="refs/convert/parquet")
        if any(fs.isfile(path) for path in fs.ls("") if not path.startswith(".")):
            raise AppError(f"Parquet export doesn't exist for '{dataset}'.")
        return fs
    except:
        raise AppError(f"Parquet export doesn't exist for '{dataset}'.")


def get_parquet_configs(dataset: str) -> List[str]:
    fs = get_parquet_fs(dataset)
    return [path for path in fs.ls("") if fs.isdir(path)]


def _sorted_split_key(split: str) -> str:
    return split if not split.startswith("train") else chr(0) + split  # always "train" first


def get_parquet_splits(dataset: str, config: str) -> List[str]:
    fs = get_parquet_fs(dataset)
    return [path.split("/")[1] for path in fs.ls(config) if fs.isdir(path)]

#####################################################
# Index and query Parquet data
#####################################################

RowGroupReaders = List[Callable[[], pa.Table]]


def index(dataset: str, config: str, split: str) -> Tuple[np.ndarray, RowGroupReaders, int, Features]:
    fs = get_parquet_fs(dataset)
    sources = fs.glob(f"{config}/{split}/*.parquet")
    if not sources:
        if config not in get_parquet_configs(dataset):
            raise AppError(f"Invalid config {config}. Available configs are: {', '.join(get_parquet_configs(dataset))}.")
        else:
            raise AppError(f"Invalid split {split}. Available splits are: {', '.join(get_parquet_splits(dataset, config))}.")
    desc = f"{dataset}/{config}/{split}"
    all_pf: List[pq.ParquetFile] = thread_map(partial(pq.ParquetFile, filesystem=fs), sources, desc=desc, unit="pq")
    features = Features.from_arrow_schema(all_pf[0].schema.to_arrow_schema())
    rg_offsets = np.cumsum([pf.metadata.row_group(i).num_rows for pf in all_pf for i in range(pf.metadata.num_row_groups)])
    rg_readers = [partial(pf.read_row_group, i) for pf in all_pf for i in range(pf.metadata.num_row_groups)]
    max_page = 1 + (rg_offsets[-1] - 1) // PAGE_SIZE
    return rg_offsets, rg_readers, max_page, features


def query(page: int, page_size: int, rg_offsets: np.ndarray, rg_readers: RowGroupReaders) -> pd.DataFrame:
    start_row, end_row = (page - 1) * page_size, min(page * page_size, rg_offsets[-1] - 1)  # both included
    # rg_offsets[start_rg - 1] <= start_row < rg_offsets[start_rg]
    # rg_offsets[end_rg - 1] <= end_row < rg_offsets[end_rg]
    start_rg, end_rg = np.searchsorted(rg_offsets, [start_row, end_row], side="right")  # both included
    pa_table = pa.concat_tables([rg_readers[i]() for i in range(start_rg, end_rg + 1)])
    offset = start_row - (rg_offsets[start_rg - 1] if start_rg > 0 else 0)
    pa_table = pa_table.slice(offset, page_size)
    return pa_table.to_pandas()
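
# Illustrative sanity check of the row-group arithmetic in query() (the numbers are
# made up, not taken from any real dataset): three row groups of 4 rows each give
# rg_offsets == [4, 8, 12]; page 2 with page_size 5 must cover rows 5-9 of the split.
def _query_offsets_example() -> None:
    rg_offsets = np.cumsum([4, 4, 4])                    # array([ 4,  8, 12])
    page, page_size = 2, 5
    start_row = (page - 1) * page_size                   # 5
    end_row = min(page * page_size, rg_offsets[-1] - 1)  # 10, the "both included" upper bound
    start_rg, end_rg = np.searchsorted(rg_offsets, [start_row, end_row], side="right")
    assert (start_rg, end_rg) == (1, 2)                  # rows 5-10 live in row groups 1 and 2
    offset = start_row - rg_offsets[start_rg - 1]        # 1: position of row 5 inside row group 1
    assert offset == 1
    # query() then concatenates row groups 1-2 and slices (offset, page_size) -> rows 5-9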

def sanitize_inputs(dataset: str, config: str, split: str, page: str) -> Tuple[str, str, str, int]:
    try:
        page = int(page)
        assert page > 0
    except:
        raise AppError(f"Bad page: {page}")
    if not dataset:
        raise AppError("Empty dataset name")
    if not config:
        raise AppError(f"Empty config. Available configs are: {', '.join(get_parquet_configs(dataset))}.")
    if not split:
        raise AppError(f"Empty split. Available splits are: {', '.join(get_parquet_splits(dataset, config))}.")
    return dataset, config, split, int(page)


def get_page_df(dataset: str, config: str, split: str, page: str) -> Tuple[pd.DataFrame, int, Features]:
    dataset, config, split, page = sanitize_inputs(dataset, config, split, page)
    rg_offsets, rg_readers, max_page, features = index(dataset, config, split)
    if page > max_page:
        raise AppError(f"Page {page} does not exist")
    df = query(page, PAGE_SIZE, rg_offsets=rg_offsets, rg_readers=rg_readers)
    return df, max_page, features

#####################################################
# Format results
#####################################################

def get_page(dataset: str, config: str, split: str, page: str) -> Tuple[str, int, str]:
    df_, max_page, features = get_page_df(dataset, config, split, page)
    df = copy.deepcopy(df_)
    unsupported_columns = []
    if dataset != 'Salesforce/lotsa_data':
        for column, feature in features.items():
            if isinstance(feature, Image):
                blob_type = "image"  # TODO: support audio - right now it seems that the markdown renderer in gradio doesn't support audio and shows nothing
                blob_urls = blobs_to_urls([item.get("bytes") if isinstance(item, dict) else None for item in df[column]], blob_type, prefix=f"{dataset}-{config}-{split}-{page}-{column}")
                df = df.drop([column], axis=1)
                df[column] = [f'<img src="{url}"/>' for url in blob_urls]
            elif any(bad_type in str(feature) for bad_type in ["Image(", "Audio(", "'binary'"]):
                unsupported_columns.append(column)
                df = df.drop([column], axis=1)
            elif isinstance(feature, Sequence):
                if feature.feature.dtype == 'float32':
                    # Plot the vector directly and embed it as a Base64-encoded image
                    base64_srcs = [ndarray_to_base64(vec) for vec in df[column]]
                    df = df.drop([column], axis=1)
                    df[column] = [f'<img src="{src}"/>' for src in base64_srcs]
        info = "" if not unsupported_columns else f"Some columns are not supported yet: {unsupported_columns}"
        return df.reset_index().to_markdown(index=False), max_page, info
    else:
        # Other handling (Salesforce/lotsa_data): return the raw DataFrame for the time-series view below
        info = "" if not unsupported_columns else f"Some columns are not supported yet: {unsupported_columns}"
        return df, max_page, info

#####################################################
# Gradio app
#####################################################

with gr.Blocks() as demo:
    gr.Markdown("# 📖 Datasets Explorer\n\nAccess any slice of data of any dataset on the [Hugging Face Dataset Hub](https://huggingface.co/datasets)")
    gr.Markdown("This is the dataset viewer from parquet export demo before the feature was added on the Hugging Face website.")
    cp_dataset = gr.Textbox("Salesforce/lotsa_data", label="Pick a dataset", placeholder="competitions/aiornot")
    cp_go = gr.Button("Explore")
    cp_config = gr.Dropdown(["plain_text"], value="plain_text", label="Config", visible=False)
    cp_split = gr.Dropdown(["train", "validation"], value="train", label="Split", visible=False)
    cp_goto_next_page = gr.Button("Next page", visible=False)
    cp_error = gr.Markdown("", visible=False)
    cp_info = gr.Markdown("", visible=False)
    cp_result = gr.Markdown("", visible=False)

    now = datetime.now()
    df = pd.DataFrame({
        'time': [now - timedelta(minutes=5 * i) for i in range(25)] + [now],
        'price': np.random.randint(100, 1000, 26),
        'origin': [random.choice(["DFW", "DAL", "HOU"]) for _ in range(26)],
        'destination': [random.choice(["JFK", "LGA", "EWR"]) for _ in range(26)],
    })

    components = []
    for _ in range(TIME_PLOTS_NUM):
        with gr.Row():
            textbox = gr.Textbox("Name or description")
            with gr.Column():
                daterange = DateTimeRange(["now - 24h", "now"])
                plot1 = gr.LinePlot(df, x="time", y="price", color="origin")
                # plot2 = gr.LinePlot(df, x="time", y="price", color="origin")
                daterange.bind([plot1,
                                # plot2,
                                ])
        comp = {
            "textbox": textbox,
            "daterange": daterange,
            "plot1": plot1,
            # "plot2": plot2,
        }
        components.append(comp)

    with gr.Row():
        cp_page = gr.Textbox("1", label="Page", placeholder="1", visible=False)
        cp_goto_page = gr.Button("Go to page", visible=False)

    def show_error(message: str) -> dict:
        return {
            cp_error: gr.update(visible=True, value=f"## ❌ Error:\n\n{message}"),
            cp_info: gr.update(visible=False, value=""),
            cp_result: gr.update(visible=False, value=""),
        }

    def show_dataset_at_config_and_split_and_page(dataset: str, config: str, split: str, page: str) -> dict:
        try:
            ret = {}
            if dataset != 'Salesforce/lotsa_data':
                markdown_result, max_page, info = get_page(dataset, config, split, page)
                ret[cp_result] = gr.update(visible=True, value=markdown_result)
            else:
                df, max_page, info = get_page(dataset, config, split, page)
                print(df.columns)
                # TODO: when 'target' is a 1-D array, len(row['target'][0]) raises an error (see the sketch after this function)
                df['timestamp'] = df.apply(lambda row: pd.date_range(start=row['start'], periods=len(row['target'][0]), freq=row['freq']).to_pydatetime().tolist(), axis=1)
                df = flatten_ndarray_column(df, 'target')
                # Drop the original start, freq and target columns
                df.drop(columns=['start', 'freq', 'target'], inplace=True)
                if 'past_feat_dynamic_real' in df.columns:
                    df.drop(columns=['past_feat_dynamic_real'], inplace=True)
                info = f"({info})" if info else ""
                for i, rows in df.iterrows():
                    index = rows['item_id']
                    df_without_index = rows.drop('item_id').to_frame().T
                    df_expanded = df_without_index.apply(pd.Series.explode).reset_index(drop=True).fillna(0)
                    ret.update({
                        components[i]["textbox"]: gr.update(value=f"item_id: {index}"),
                        components[i]["daterange"]: gr.update(value=[df_without_index['timestamp'][i][0], df_without_index['timestamp'][i][-1]]),
                        components[i]["plot1"]: gr.update(value=df_expanded, x="timestamp", y="target_0"),
                    })
            return {
                **ret,
                cp_info: gr.update(visible=True, value=f"Page {page}/{max_page} {info}"),
                cp_error: gr.update(visible=False, value="")
            }
        except AppError as err:
            return show_error(str(err))
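
    # Hedged sketch for the TODO above (not wired into the handler): a helper that
    # returns the number of time steps whether 'target' is a single 1-D series or an
    # object array of per-variable series. The shape assumptions about lotsa_data are
    # mine, not taken from the dataset card.
    def _target_num_steps(target) -> int:
        if isinstance(target, np.ndarray) and target.ndim == 1 and target.dtype != np.dtype("O"):
            return len(target)    # univariate: the array itself is the time series
        return len(target[0])     # multivariate: length of the first variable's series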

    def show_dataset_at_config_and_split_and_next_page(dataset: str, config: str, split: str, page: str) -> dict:
        try:
            next_page = str(int(page) + 1)
            return {
                **show_dataset_at_config_and_split_and_page(dataset, config, split, next_page),
                cp_page: gr.update(value=next_page, visible=True),
            }
        except AppError as err:
            return show_error(str(err))

    def show_dataset_at_config_and_split(dataset: str, config: str, split: str) -> dict:
        try:
            return {
                **show_dataset_at_config_and_split_and_page(dataset, config, split, "1"),
                cp_page: gr.update(value="1", visible=True),
                cp_goto_page: gr.update(visible=True),
                cp_goto_next_page: gr.update(visible=True),
            }
        except AppError as err:
            return show_error(str(err))

    def show_dataset_at_config(dataset: str, config: str) -> dict:
        try:
            splits = get_parquet_splits(dataset, config)
            if not splits:
                raise AppError(f"Dataset {dataset} with config {config} has no splits.")
            else:
                split = splits[0]
            return {
                **show_dataset_at_config_and_split(dataset, config, split),
                cp_split: gr.update(value=split, choices=splits, visible=len(splits) > 1),
            }
        except AppError as err:
            return show_error(str(err))

    def show_dataset(dataset: str) -> dict:
        try:
            configs = get_parquet_configs(dataset)
            if not configs:
                raise AppError(f"Dataset {dataset} has no configs.")
            else:
                config = configs[0]
            return {
                **show_dataset_at_config(dataset, config),
                cp_config: gr.update(value=config, choices=configs, visible=len(configs) > 1),
            }
        except AppError as err:
            return show_error(str(err))
""" | |
动态生成组件时使用gr.LinePlot会有bug,直接卡死在show_dataset部分 | |
""" | |
# @gr.render(triggers=[cp_go.click]) | |
# def create_test(): | |
# now = datetime.now() | |
# df = pd.DataFrame({ | |
# 'time': [now - timedelta(minutes=5*i) for i in range(25)], | |
# 'price': np.random.randint(100, 1000, 25), | |
# 'origin': [random.choice(["DFW", "DAL", "HOU"]) for _ in range(25)], | |
# 'destination': [random.choice(["JFK", "LGA", "EWR"]) for _ in range(25)], | |
# }) | |
# # componets = [] | |
# # daterange = DateTimeRange(["now - 24h", "now"]) | |
# plot1 = gr.LinePlot(df, x="time", y="price") | |
# plot2 = gr.LinePlot(df, x="time", y="price", color="origin") | |
# # # daterange.bind([plot1, plot2]) | |
# # componets.append(plot1) | |
# # componets.append(plot2) | |
# # componets.append(daterange) | |
# # test = gr.Textbox(label="input") | |
# # componets.append(test) | |
# # return componets | |

    all_outputs = [cp_config, cp_split, cp_page, cp_goto_page, cp_goto_next_page, cp_result, cp_info, cp_error]
    for comp in components:
        all_outputs += list(comp.values())
    cp_go.click(show_dataset, inputs=[cp_dataset], outputs=all_outputs)
    cp_config.change(show_dataset_at_config, inputs=[cp_dataset, cp_config], outputs=all_outputs)
    cp_split.change(show_dataset_at_config_and_split, inputs=[cp_dataset, cp_config, cp_split], outputs=all_outputs)
    cp_goto_page.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
    cp_goto_next_page.click(show_dataset_at_config_and_split_and_next_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)

if __name__ == "__main__":
    app = gr.mount_gradio_app(app, demo, path="/")
    uvicorn.run(app, host="127.0.0.1", port=7860)

# Requirements / known issues:
# - Multivariate targets cannot all be drawn on one plot yet. Options: add a selector to show one
#   variable at a time, force a different color per variable, or show each variable in its own
#   panel (dynamically generated panels cannot be positioned).
# - Components cannot be generated dynamically.
# - No aggregation or summary-statistics features yet.
# - Support calling other libraries.
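
# Hedged sketch for the first item above (not used by the app): melt the flattened
# target_0, target_1, ... columns produced by flatten_ndarray_column into long format,
# so a single gr.LinePlot can draw one coloured line per variable. The column names
# follow this file's conventions; the rest is an assumption about how such a view
# could be built.
def _melt_targets_for_plot(df_expanded: pd.DataFrame) -> pd.DataFrame:
    target_cols = [c for c in df_expanded.columns if c.startswith("target_")]
    long_df = df_expanded.melt(
        id_vars=["timestamp"],
        value_vars=target_cols,
        var_name="variable",
        value_name="value",
    )
    # The plot could then be declared as:
    # gr.LinePlot(long_df, x="timestamp", y="value", color="variable")
    return long_df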