# ts_explorer / app.py
# Liu Yiwen
# 修改了展示内容 (commit message: "updated the displayed content")
# a25c8f7
import copy
import os
import time
from functools import lru_cache, partial
import gradio as gr
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from tqdm.contrib.concurrent import thread_map
from fastapi import FastAPI, Response
import uvicorn
from hffs.fs import HfFileSystem
from datasets import Features, Image, Audio, Sequence
from typing import List, Tuple, Callable
from utils import ndarray_to_base64, clean_up_df, create_statistic, create_plot, get_question_info
from comm_utils import save_to_file, send_msg_to_server, save_score
from config import *
class AppError(RuntimeError):
    """User-facing error; its message is rendered in the explorer's error panel."""
# Base URL that blob links point to: the local server when DEV is set,
# otherwise the hosted Space.
APP_URL = (
    "http://127.0.0.1:7860"
    if os.getenv("DEV")
    else "https://Kamarov-lotsa-explorer.hf.space"
)
PAGE_SIZE = 1  # rows per page
MAX_CACHED_BLOBS = 10 * PAGE_SIZE  # capacity of the in-memory blob cache
TIME_PLOTS_NUM = 1
_blobs_cache = dict()  # blob id -> blob bytes, in insertion order
#####################################################
# Define routes for image and audio files
#####################################################
# FastAPI application that hosts the /image and /audio blob routes; when this
# file is run as a script, the Gradio UI is mounted onto it at "/".
app = FastAPI()
@app.get(
    "/image",
    responses={200: {"content": {"image/png": {}}}},
    response_class=Response,
)
def image(id: str):
    """Serve a cached image blob by its id (the ``id`` query parameter).

    The blob cache is bounded, so an id can legitimately disappear after it
    was handed out. Answer 404 in that case instead of letting the KeyError
    from ``get_blob`` escape, which FastAPI would report as a 500.
    """
    try:
        blob = get_blob(id)
    except KeyError:
        return Response(status_code=404)
    return Response(content=blob, media_type="image/png")
@app.get(
    "/audio",
    responses={200: {"content": {"audio/wav": {}}}},
    response_class=Response,
)
def audio(id: str):
    """Serve a cached audio blob by its id (the ``id`` query parameter).

    The blob cache is bounded, so an id can legitimately disappear after it
    was handed out. Answer 404 in that case instead of letting the KeyError
    from ``get_blob`` escape, which FastAPI would report as a 500.
    """
    try:
        blob = get_blob(id)
    except KeyError:
        return Response(status_code=404)
    return Response(content=blob, media_type="audio/wav")
def push_blob(blob: bytes, blob_id: str) -> str:
    """Store ``blob`` under ``blob_id`` in the blob cache and return the id.

    Pushing an id that is already cached moves it to the newest position.
    When the cache then holds more than MAX_CACHED_BLOBS entries, the oldest
    entry (first in insertion order) is dropped.
    """
    # Remove-then-insert so the id ends up last in the insertion-ordered dict.
    _blobs_cache.pop(blob_id, None)
    _blobs_cache[blob_id] = blob
    over_capacity = len(_blobs_cache) > MAX_CACHED_BLOBS
    if over_capacity:
        oldest_id = next(iter(_blobs_cache))
        del _blobs_cache[oldest_id]
    return blob_id
def get_blob(blob_id: str) -> bytes:
    """Look up a cached blob; raises KeyError if the id was never pushed or was evicted."""
    cache = _blobs_cache
    return cache[blob_id]
def blobs_to_urls(blobs: List[bytes], type: str, prefix: str) -> List[str]:
    """Cache each blob and return the URLs that serve them.

    Blob ``i`` is stored under the id ``f"{prefix}-{i}"`` and served by the
    ``/{type}`` route of this app (``/image`` or ``/audio``).

    Args:
        blobs: raw blob payloads, in display order.
        type: route name, ``"image"`` or ``"audio"``.
        prefix: id prefix that makes the ids unique per page and column.

    Returns:
        One absolute URL per blob, in the same order as ``blobs``.
    """
    from urllib.parse import quote

    blob_ids = [push_blob(blob, f"{prefix}-{i}") for i, blob in enumerate(blobs)]
    # The id embeds dataset/config/split/column names. Percent-encode it so
    # characters such as spaces, '&' or '#' neither break the query string
    # nor the Markdown image link the URL is later embedded in.
    return [f"{APP_URL}/{type}?id={quote(blob_id, safe='')}" for blob_id in blob_ids]
#####################################################
# List configs, splits and parquet files
#####################################################
@lru_cache(maxsize=128)
def get_parquet_fs(dataset: str) -> HfFileSystem:
    """Return a filesystem over the Parquet export of ``dataset``.

    The export is read from the ``refs/convert/parquet`` revision of the
    dataset repo. Successful lookups are cached per dataset name; failures
    raise and are therefore not cached.

    Raises:
        AppError: if the export cannot be opened, or if its root contains a
            regular (non-hidden) file rather than only config directories.
    """
    message = f"Parquet export doesn't exist for '{dataset}'."
    try:
        fs = HfFileSystem(dataset, repo_type="dataset", revision="refs/convert/parquet")
        # The export is expected to hold only config directories at its root.
        if any(fs.isfile(path) for path in fs.ls("") if not path.startswith(".")):
            raise AppError(message)
    except AppError:
        raise
    except Exception as err:
        # Any failure to reach or list the export (missing repo or revision,
        # network error, ...) is reported with the same user-facing message,
        # keeping the original exception as the cause for debugging.
        raise AppError(message) from err
    return fs
@lru_cache(maxsize=128)
def get_parquet_configs(dataset: str) -> List[str]:
    """List the configs of ``dataset``: the directories at the root of its Parquet export."""
    parquet_fs = get_parquet_fs(dataset)
    configs = []
    for entry in parquet_fs.ls(""):
        if parquet_fs.isdir(entry):
            configs.append(entry)
    return configs
def _sorted_split_key(split: str) -> str:
return split if not split.startswith("train") else chr(0) + split # always "train" first
@lru_cache(maxsize=128)
def get_parquet_splits(dataset: str, config: str) -> List[str]:
    """List the split names of ``config``: its subdirectories, in ``fs.ls`` order."""
    parquet_fs = get_parquet_fs(dataset)
    split_dirs = filter(parquet_fs.isdir, parquet_fs.ls(config))
    # Entries are presumably "<config>/<split>"; keep the second path component.
    return [split_dir.split("/")[1] for split_dir in split_dirs]
#####################################################
# Index and query Parquet data
#####################################################
# One zero-argument callable per Parquet row group; calling it reads that row
# group as a pyarrow Table (`index` builds them as partial(pf.read_row_group, i)).
RowGroupReaders = List[Callable[[], pa.Table]]
@lru_cache(maxsize=128)
def index(dataset: str, config: str, split: str) -> Tuple[np.ndarray, RowGroupReaders, int, Features]:
    """Open every Parquet file of a split and index its row groups.

    Returns:
        rg_offsets: cumulative row counts over all row groups of all files
            (``rg_offsets[k]`` is the number of rows in row groups ``0..k``).
        rg_readers: one zero-argument reader per row group, aligned with
            ``rg_offsets``.
        max_page: number of pages of ``PAGE_SIZE`` rows.
        features: ``datasets`` Features built from the first file's schema.

    Raises:
        AppError: if no Parquet file matches (unknown config or split), or if
            the matched files contain no row group at all.
    """
    fs = get_parquet_fs(dataset)
    # Sort explicitly so row numbering (and thus page numbers) does not depend
    # on the order in which glob happens to return the files.
    sources = sorted(fs.glob(f"{config}/{split}/*.parquet"))
    if not sources:
        if config not in get_parquet_configs(dataset):
            raise AppError(f"Invalid config {config}. Available configs are: {', '.join(get_parquet_configs(dataset))}.")
        raise AppError(f"Invalid split {split}. Available splits are: {', '.join(get_parquet_splits(dataset, config))}.")
    desc = f"{dataset}/{config}/{split}"
    # Open the files concurrently, with a tqdm progress bar labelled `desc`.
    all_pf: List[pq.ParquetFile] = thread_map(partial(pq.ParquetFile, filesystem=fs), sources, desc=desc, unit="pq")
    features = Features.from_arrow_schema(all_pf[0].schema.to_arrow_schema())
    rg_offsets = np.cumsum([pf.metadata.row_group(i).num_rows for pf in all_pf for i in range(pf.metadata.num_row_groups)])
    if rg_offsets.size == 0:
        # Without this guard, rg_offsets[-1] below fails with a bare IndexError.
        raise AppError(f"Split {split} of {dataset}/{config} has no row groups.")
    rg_readers = [partial(pf.read_row_group, i) for pf in all_pf for i in range(pf.metadata.num_row_groups)]
    max_page = 1 + (rg_offsets[-1] - 1) // PAGE_SIZE
    return rg_offsets, rg_readers, max_page, features
def query(page: int, page_size: int, rg_offsets: np.ndarray, rg_readers: RowGroupReaders) -> pd.DataFrame:
    """Read the rows of 1-based page ``page`` as a pandas DataFrame.

    Only the row groups overlapping the page are read.

    Args:
        page: 1-based page number; callers must ensure it does not exceed the
            page count (``get_page_df`` checks this).
        page_size: rows per page.
        rg_offsets: cumulative row counts per row group, as built by ``index``.
        rg_readers: one zero-argument reader per row group, aligned with
            ``rg_offsets``.

    Returns:
        Up to ``page_size`` rows (fewer on the last page).
    """
    num_rows = rg_offsets[-1]
    # Inclusive row range of the page. page * page_size is the first row of
    # the NEXT page, hence the "- 1". Without it, a page ending on a row-group
    # boundary also read the following row group for nothing.
    start_row = (page - 1) * page_size
    end_row = min(page * page_size, num_rows) - 1
    # rg_offsets[start_rg - 1] <= start_row < rg_offsets[start_rg]
    # rg_offsets[end_rg - 1] <= end_row < rg_offsets[end_rg]
    start_rg, end_rg = np.searchsorted(rg_offsets, [start_row, end_row], side="right")  # both included
    t = time.time()
    # TODO: performance bottleneck
    pa_table = pa.concat_tables([rg_readers[i]() for i in range(start_rg, end_rg + 1)])
    print(f"concat_tables time: {time.time()-t}")
    # Position of start_row inside the first loaded row group.
    offset = start_row - (rg_offsets[start_rg - 1] if start_rg > 0 else 0)
    pa_table = pa_table.slice(offset, page_size)
    return pa_table.to_pandas()
def sanitize_inputs(dataset: str, config: str, split: str, page: str) -> Tuple[str, str, str, int]:
    """Validate the explorer inputs and convert ``page`` to a positive int.

    Returns:
        ``(dataset, config, split, page)`` with ``page`` as an int >= 1.

    Raises:
        AppError: if ``page`` is not a positive integer, or if the dataset,
            config or split is empty. For an empty config or split, the
            message lists the available choices.
    """
    try:
        page_num = int(page)
    except (TypeError, ValueError, OverflowError):
        raise AppError(f"Bad page: {page}") from None
    # Explicit check rather than `assert`, which `python -O` strips out.
    if page_num <= 0:
        raise AppError(f"Bad page: {page_num}")
    if not dataset:
        raise AppError("Empty dataset name")
    if not config:
        raise AppError(f"Empty config. Available configs are: {', '.join(get_parquet_configs(dataset))}.")
    if not split:
        raise AppError(f"Empty split. Available splits are: {', '.join(get_parquet_splits(dataset, config))}.")
    return dataset, config, split, page_num
@lru_cache(maxsize=128)
def get_page_df(dataset: str, config: str, split: str, page: str) -> Tuple[pd.DataFrame, int, Features]:
    """Return ``(page rows, page count, features)`` for one page of a split.

    Results are cached per argument tuple, so the same DataFrame object is
    handed to every caller; copy it before mutating.

    Raises:
        AppError: for invalid inputs or a page beyond the last one.
    """
    dataset, config, split, page_num = sanitize_inputs(dataset, config, split, page)
    rg_offsets, rg_readers, max_page, features = index(dataset, config, split)
    if page_num > max_page:
        raise AppError(f"Page {page_num} does not exist")
    page_df = query(page_num, PAGE_SIZE, rg_offsets=rg_offsets, rg_readers=rg_readers)
    return page_df, max_page, features
#####################################################
# Format results
#####################################################
def get_page(dataset: str, config: str, split: str, page: str) -> Tuple[pd.DataFrame | str, int, str]:
    """Fetch one page of a split and prepare it for display.

    For TARGET_DATASET and BENCHMARK_DATASET, the raw page DataFrame is
    returned. It is a deep copy, so callers may modify it without touching
    the cached page.

    For any other dataset, the page is rendered as a Markdown table:
    - Image columns become Markdown images served from the blob cache.
    - float32 Sequence columns are converted with ``ndarray_to_base64`` and
      embedded as Markdown images.
    - Other image/audio/binary columns are dropped and listed in the info
      message.

    Returns:
        ``(page content, page count, info message)``.

    Raises:
        AppError: propagated from ``get_page_df`` for invalid inputs.
    """
    df_, max_page, features = get_page_df(dataset, config, split, page)
    # get_page_df is cached and returns a shared object: work on a copy.
    df = copy.deepcopy(df_)
    if dataset in (TARGET_DATASET, BENCHMARK_DATASET):
        # These datasets are displayed by dedicated components; no column is
        # filtered out, so there is nothing to report in the info message.
        return df, max_page, ""
    # Other datasets keep the original datasets-explorer rendering.
    unsupported_columns = []
    for column, feature in features.items():
        if isinstance(feature, Image):
            blob_type = "image"  # TODO: support audio - right now it seems that the markdown renderer in gradio doesn't support audio and shows nothing
            blob_urls = blobs_to_urls([item.get("bytes") if isinstance(item, dict) else None for item in df[column]], blob_type, prefix=f"{dataset}-{config}-{split}-{page}-{column}")
            df = df.drop([column], axis=1)
            df[column] = [f"![]({url})" for url in blob_urls]
        elif any(bad_type in str(feature) for bad_type in ["Image(", "Audio(", "'binary'"]):
            unsupported_columns.append(column)
            df = df.drop([column], axis=1)
        elif isinstance(feature, Sequence) and getattr(feature.feature, "dtype", None) == "float32":
            # Nested sequences or dict features have no `dtype` on their inner
            # feature; they are now left as-is instead of raising AttributeError.
            # Render each vector and embed it as a base64 image.
            base64_srcs = [ndarray_to_base64(vec) for vec in df[column]]
            df = df.drop([column], axis=1)
            df[column] = [f"![]({src})" for src in base64_srcs]
    info = "" if not unsupported_columns else f"Some columns are not supported yet: {unsupported_columns}"
    return df.reset_index().to_markdown(index=False), max_page, info
#####################################################
# Process data
#####################################################
def process_salesforce_data(dataset: str, config: str, split: str, page: List[str] | str | int, sub_targets: List[int|str]) -> Tuple[List[pd.DataFrame], List[str]]:
    """Load Salesforce/lotsa_data samples and reshape them for plotting.

    For each requested page, the first row is cleaned with ``clean_up_df``
    (keeping ``sub_targets``). Its ``item_id`` is collected, and the remaining
    list-valued cells are exploded into a long DataFrame, with NaN filled by 0.

    Args:
        page: one page number, or a list of them. A bare string is treated as
            a single page; previously it was iterated character by character,
            so "12" became pages "1" and "2".
        sub_targets: sub-target indices to keep; if it contains ``'all'``,
            every sub-target of the sample is kept.

    Returns:
        ``(expanded DataFrames, item ids)``, one entry per requested page.

    Side effects:
        Sets the module-level ``tot_samples`` (page count of the split) and
        ``tot_targets`` (sub-target count of the last sample loaded).
    """
    global tot_samples, tot_targets
    pages = [page] if isinstance(page, (str, int, np.integer)) else page
    df_list, id_list = [], []
    for page_id in pages:
        df, max_page, _info = get_page(dataset, config, split, page_id)
        target = df['target'][0]
        # An object-dtype ndarray target is counted as one sub-target per
        # element; anything else counts as a single target.
        tot_samples = max_page
        tot_targets = len(target) if isinstance(target, np.ndarray) and target.dtype == 'O' else 1
        if 'all' in sub_targets:
            sub_targets = list(range(tot_targets))
        df = clean_up_df(df, sub_targets, SUBTARGET_MEANING_MAP[config])
        row = df.iloc[0]
        id_list.append(row['item_id'])
        # Expand the single-row DataFrame into a new (long) DataFrame.
        df_without_index = row.drop('item_id').to_frame().T
        df_expanded = df_without_index.apply(pd.Series.explode).reset_index(drop=True).fillna(0)
        df_list.append(df_expanded)
    return df_list, id_list
#####################################################
# Gradio app
#####################################################
with gr.Blocks() as demo:
    # Initialise the components
    gr.Markdown("A tool for interactive observation of lotsa dataset, extended from lhoestq/datasets-explorer")
    cp_dataset = gr.Textbox(BENCHMARK_DATASET, label="Pick a dataset", interactive=False)
    cp_go = gr.Button("Explore")
    cp_config = gr.Dropdown(["plain_text"], value="plain_text", label="Config", visible=False)
    cp_split = gr.Dropdown(["train", "validation"], value="train", label="Split", visible=False)
    cp_goto_next_page = gr.Button("Next page", visible=False)
    cp_error = gr.Markdown("", visible=False)
    cp_info = gr.Markdown("", visible=False)
    cp_result = gr.Markdown("", visible=False)
    # Hidden holder for the id of the benchmark question on screen; passed to save_score.
    qusetion_id_box = gr.Textbox(visible=False)
    # Page count of the last Salesforce/lotsa_data split shown; overwritten
    # (together with tot_targets) by process_salesforce_data through `global`.
    tot_samples = 0
    # Components used to display the Salesforce/lotsa_data dataset
    # componets = []
    # for _ in range(TIME_PLOTS_NUM):
    #     with gr.Row():
    #         with gr.Column(scale=2):
    #             select_sample_box = gr.Dropdown(choices=["items"], label="Select some items", multiselect=True, interactive=True)
    #         with gr.Column(scale=2):
    #             select_subtarget_box = gr.Dropdown(choices=["subtargets"], label="Select some subtargets", multiselect=True, interactive=True)
    #         with gr.Column(scale=1):
    #             select_buttom = gr.Button("Show selected items")
    with gr.Row():
        with gr.Column(scale=2):
            statistics_textbox = gr.DataFrame()
            hr_line = gr.HTML('<hr style="border: 1px solid black;">')
            question_info_textbox_p1 = gr.DataFrame()
            question_info_textbox_p2 = gr.DataFrame()
        with gr.Column(scale=3):
            plot = gr.Plot()
    with gr.Row():
        user_input_box = gr.Textbox(label="question", interactive=False)
        user_output_box = gr.Textbox(label="answer", interactive=False)
    # componets.append({"select_sample_box": select_sample_box,
    #                   "statistics_textbox": statistics_textbox,
    #                   "user_input_box": user_input_box,
    #                   "plot": plot})
    hr_line_ = gr.HTML('<hr style="border: 2px dashed black;">')
    with gr.Row():
        with gr.Column(scale=1):
            choose_retain = gr.Dropdown(["delete", "retain", "modify"], label="Choose to retain or delete or modify", interactive=True)
        with gr.Column(scale=2):
            choose_retain_reason_box = gr.Textbox(label="Reason", placeholder="Enter your reason", interactive=True)
    score_slider = gr.Slider(1, 5, 1, step=1, label="Score for answer", interactive=True)
    with gr.Row():
        with gr.Column(scale=2):
            user_name_box = gr.Textbox(label="user_name", placeholder="Enter your name firstly", interactive=True)
            user_submit_button = gr.Button("submit", interactive=True)
        with gr.Column(scale=1):
            submit_info_box = gr.Textbox(label="submit_info", interactive=False)
    with gr.Row():
        cp_page = gr.Textbox("1", label="Page", placeholder="1", visible=False)
        cp_goto_page = gr.Button("Go to page", visible=False)

    def show_error(message: str) -> dict:
        """Show ``message`` in the error panel and hide the info/result panels."""
        return {
            cp_error: gr.update(visible=True, value=f"## ❌ Error:\n\n{message}"),
            cp_info: gr.update(visible=False, value=""),
            cp_result: gr.update(visible=False, value=""),
        }

    def show_dataset_at_config_and_split_and_page(dataset: str, config: str, split: str, page: str|List[str], sub_targets: List[int|str] | None = None) -> dict:
        """Render one page of ``dataset`` into the UI components.

        Behaviour depends on the dataset:
        - TARGET_DATASET: plot and summarise the sample(s) at ``page`` (one
          page or a list of pages), restricted to ``sub_targets``.
        - BENCHMARK_DATASET: show the question/answer at ``page`` and plot the
          Salesforce/lotsa_data series it points to.
        - Any other dataset: show the page as a Markdown table.

        ``sub_targets`` defaults to ``['all']`` (every sub-target).

        Returns:
            A ``{component: gr.update(...)}`` dict. An AppError is reported
            through ``show_error`` instead of being raised.
        """
        # None instead of a mutable list default shared between calls.
        if sub_targets is None:
            sub_targets = ['all']
        try:
            ret = {}
            if dataset == TARGET_DATASET:
                pages = [page] if isinstance(page, str) else page
                df_list, id_list = process_salesforce_data(dataset, config, split, pages, sub_targets)
                ret[statistics_textbox] = gr.update(value=create_statistic(df_list, id_list))
                ret[plot] = gr.update(value=create_plot(df_list, id_list))
                # This branch never calls get_page itself, so use the page count
                # recorded by process_salesforce_data. Previously max_page and
                # info were unbound here, and building cp_info below raised
                # UnboundLocalError.
                max_page, info = tot_samples, ""
            elif dataset == BENCHMARK_DATASET:
                df, max_page, info = get_page(dataset, config, split, page)
                question_info_p1 = get_question_info(df, [COLUMN_DOMAIN, COLUMN_SOURCE])
                question_info_p2 = get_question_info(df, [COLUMN_QA_TYPE, COLUMN_TASK_TYPE])
                ret[qusetion_id_box] = gr.update(value=df[COLUMN_ID][0])
                # The last path component of the source column names the lotsa config.
                lotsa_config = str(df[COLUMN_SOURCE][0]).split('/')[-1]
                # NOTE(security): eval() executes text read from the dataset. If these
                # columns only ever hold Python literals (e.g. "[3]"), ast.literal_eval
                # is a safe drop-in replacement — confirm and switch.
                lotsa_page = eval(df[COLUMN_TS_ID][0])
                # TODO: handle partial-train splits; for now the config's first split is used.
                lotsa_split = get_parquet_splits(TARGET_DATASET, lotsa_config)[0]
                start_index, end_index = df[COLUMN_START_INDEX][0], df[COLUMN_END_INDEX][0]
                # A missing bound means "no interval". pd.isna accepts None as well
                # as NaN, whereas np.isnan raises TypeError on None.
                interval = None if pd.isna(start_index) or pd.isna(end_index) else [start_index, end_index]
                lotsa_subtargets = eval(df[COLUMN_TARGET_ID][0])  # NOTE(security): see the eval note above
                df_list, id_list = process_salesforce_data(TARGET_DATASET, lotsa_config, lotsa_split, lotsa_page, lotsa_subtargets)
                ret[question_info_textbox_p1] = gr.update(value=question_info_p1)
                ret[question_info_textbox_p2] = gr.update(value=question_info_p2)
                ret[statistics_textbox] = gr.update(value=create_statistic(df_list, id_list, interval=interval))
                ret[plot] = gr.update(value=create_plot(df_list, id_list, interval=interval))
                ret[user_input_box] = gr.update(value=df[COLUMN_QUESTION][0])
                ret[user_output_box] = gr.update(value=df[COLUMN_ANSWER][0])
                ret[submit_info_box] = gr.update(value="")
            else:
                markdown_result, max_page, info = get_page(dataset, config, split, page)
                ret[cp_result] = gr.update(visible=True, value=markdown_result)
            return {
                **ret,
                cp_info: gr.update(visible=True, value=f"Page {page}/{max_page} {info}"),
                cp_error: gr.update(visible=False, value=""),
            }
        except AppError as err:
            return show_error(str(err))

    def show_dataset_at_config_and_split_and_next_page(dataset: str, config: str, split: str, page: str) -> dict:
        """Show the page after ``page`` and update the page textbox."""
        try:
            next_page = str(int(page) + 1)
        except (TypeError, ValueError):
            # A non-numeric page textbox used to raise straight past the
            # AppError handler below.
            return show_error(f"Bad page: {page}")
        try:
            return {
                **show_dataset_at_config_and_split_and_page(dataset, config, split, next_page),
                cp_page: gr.update(value=next_page, visible=True),
            }
        except AppError as err:
            return show_error(str(err))

    def show_dataset_at_config_and_split(dataset: str, config: str, split: str) -> dict:
        """Show page 1 of a split (first sub-target only) and reveal the paging controls."""
        try:
            return {
                **show_dataset_at_config_and_split_and_page(dataset, config, split, "1", [0]),
                # select_sample_box: gr.update(choices=[f"{i+1}" for i in range(tot_samples)], value=["1"]),
                # select_subtarget_box: gr.update(choices=[i for i in range(tot_targets)]+['all'], value=[0]),
                cp_page: gr.update(value="1", visible=True),
                cp_goto_page: gr.update(visible=True),
                cp_goto_next_page: gr.update(visible=True),
            }
        except AppError as err:
            return show_error(str(err))

    def show_dataset_at_config(dataset: str, config: str) -> dict:
        """Show the first split of ``config``; the split selector is shown only if there are several."""
        try:
            splits = get_parquet_splits(dataset, config)
            if not splits:
                raise AppError(f"Dataset {dataset} with config {config} has no splits.")
            split = splits[0]
            return {
                **show_dataset_at_config_and_split(dataset, config, split),
                cp_split: gr.update(value=split, choices=splits, visible=len(splits) > 1),
            }
        except AppError as err:
            return show_error(str(err))

    def show_dataset(dataset: str) -> dict:
        """Show the first config of ``dataset``; the config selector is shown only if there are several."""
        try:
            configs = get_parquet_configs(dataset)
            if not configs:
                raise AppError(f"Dataset {dataset} has no configs.")
            config = configs[0]
            return {
                **show_dataset_at_config(dataset, config),
                cp_config: gr.update(value=config, choices=configs, visible=len(configs) > 1),
            }
        except AppError as err:
            return show_error(str(err))

    # Every component any handler above may update.
    all_outputs = [cp_config, cp_split,
                   cp_page, cp_goto_page, cp_goto_next_page,
                   cp_result, cp_info, cp_error,
                   # select_sample_box, select_subtarget_box,
                   # select_buttom,
                   statistics_textbox, plot,
                   qusetion_id_box,
                   user_input_box, user_output_box,
                   submit_info_box,
                   question_info_textbox_p1, question_info_textbox_p2]
    cp_go.click(show_dataset, inputs=[cp_dataset], outputs=all_outputs)
    cp_config.change(show_dataset_at_config, inputs=[cp_dataset, cp_config], outputs=all_outputs)
    cp_split.change(show_dataset_at_config_and_split, inputs=[cp_dataset, cp_config, cp_split], outputs=all_outputs)
    cp_goto_page.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
    cp_goto_next_page.click(show_dataset_at_config_and_split_and_next_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
    user_submit_button.click(save_score, inputs=[user_name_box, cp_config, qusetion_id_box, score_slider, choose_retain, choose_retain_reason_box], outputs=[submit_info_box])
    # select_buttom.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, select_sample_box, select_subtarget_box], outputs=all_outputs)
if __name__ == "__main__":
    # Listen on every interface. Binding to 127.0.0.1 when DEV is set, and
    # launching test_server.py alongside, are currently disabled.
    host = "0.0.0.0"
    # Serve the Gradio UI at "/" on the same FastAPI app that exposes the
    # /image and /audio blob routes.
    app = gr.mount_gradio_app(app, demo, path="/")
    uvicorn.run(app, port=7860, host=host)
#// 对一下数据 --
#// 部署到服务器上
#// 测试一下功能 --
#// 加一个选择文本框【删除、保留、修改】,加一个意见的文本框 --
#// 横坐标增加一个代表index的轴 -
#// 加一个物理含义的映射 -