# Source listing captured from a Hugging Face Space (status: Sleeping).
# File size: 6,634 bytes. Revisions: a454c92, c34122e, a454c92.
import random
import pandas as pd
import gradio as gr
from typing import Dict, Optional
import unibox as ub
# Module-level cache of the most recently loaded dataset. It lives in a
# global dict so the loaded frame persists across Gradio calls.
#   "id" -> dataset id of the cached frame (None until something is loaded)
#   "df" -> the cached frame itself (None until something is loaded)
CURRENT_DATASET = dict(id=None, df=None)
# Single-letter rating codes mapped to their spelled-out rating names.
rating_map = dict(
    g="general",
    s="sensitive",
    q="questionable",
    e="explicit",
)
def load_dataset_if_needed(dataset_id: str):
    """
    Make sure CURRENT_DATASET holds the dataset named by ``dataset_id``.

    If the cached id already equals ``dataset_id`` this is a no-op.
    Otherwise the dataset is fetched via ``ub.loads`` from the
    ``hf://<dataset_id>`` URI, converted with ``.to_pandas()``, and stored
    in CURRENT_DATASET together with its id.
    """
    if CURRENT_DATASET["id"] == dataset_id:
        # Requested dataset is already cached; nothing to reload.
        return
    frame = ub.loads(f"hf://{dataset_id}").to_pandas()
    # The cache is only touched after the load has succeeded.
    CURRENT_DATASET.update(id=dataset_id, df=frame)
def convert_dbr_tag_string(tag_string: str, shuffle: bool = True) -> str:
    """
    Turn a space-separated tag string into a comma-separated one.

    Underscores inside each tag become spaces, and empty pieces produced by
    repeated or trailing spaces are dropped. With ``shuffle=True`` the tags
    are put in random order before joining.

    Example (shuffle=False):
        1girl long_hair blush -> 1girl, long hair, blush
    """
    # Splitting on a single space can yield "" pieces; filter(None, ...) drops them.
    raw_tags = filter(None, tag_string.split(" "))
    readable_tags = [tag.replace("_", " ") for tag in raw_tags]
    if shuffle:
        random.shuffle(readable_tags)
    return ", ".join(readable_tags)
def _is_missing_value(value) -> bool:
    """Return True for None and for float NaN (NaN is the only float != itself)."""
    return value is None or (isinstance(value, float) and value != value)


def get_tags_dict(df_row: pd.Series) -> dict:
    """
    Build the caption components for one dataset row.

    Reads the ``rating``, ``tag_string_artist``, ``tag_string_character``,
    ``tag_string_copyright``, ``tag_string_general``, ``tag_string_meta`` and
    ``score`` entries of ``df_row`` and returns a dict of strings:

    - ``rating_str``:    ``rating_map`` lookup of the rating, "" if unknown.
    - ``artist_str``:    the artist tag string as stored, "" if absent.
    - ``character_str``: character tags via ``convert_dbr_tag_string``.
    - ``copyright_str``: the copyright tag string prefixed with "copyright:".
    - ``general_str``:   general tags via ``convert_dbr_tag_string``.
    - ``meta_str``:      meta tags via ``convert_dbr_tag_string``.
    - ``score``:         the score as a string, "" only if it is missing.

    Missing fields (None or NaN) are treated like empty strings, so they
    never reach ``convert_dbr_tag_string``.
    """
    rating = df_row["rating"]
    artist = df_row["tag_string_artist"]
    character = df_row["tag_string_character"]
    copyright_ = df_row["tag_string_copyright"]
    general = df_row["tag_string_general"]
    meta = df_row["tag_string_meta"]
    score = df_row["score"]

    def present(value) -> bool:
        # NaN is truthy, so a plain truthiness test would let it through and
        # the later str methods would fail on a float.
        return not _is_missing_value(value) and bool(value)

    def converted(value) -> str:
        return convert_dbr_tag_string(value) if present(value) else ""

    return {
        "rating_str": rating_map.get(rating, ""),
        "artist_str": artist if present(artist) else "",
        "character_str": converted(character),
        "copyright_str": f"copyright:{copyright_}" if present(copyright_) else "",
        "general_str": converted(general),
        "meta_str": converted(meta),
        # A score of 0 is a real value: only a missing score maps to "".
        "score": "" if _is_missing_value(score) else str(score),
    }
def build_tags_from_tags_dict(tags_dict: dict, add_artist_tags: bool = True) -> str:
    """
    Join the caption components into one comma-separated string.

    Components appear in the fixed order rating, artist, character,
    copyright, general; empty components are skipped. The artist component
    is prefixed with "artist:" and is only emitted when ``add_artist_tags``
    is true. ``meta_str`` and ``score`` are not part of the result.
    """
    parts = [tags_dict["rating_str"]]
    artist = tags_dict["artist_str"]
    parts.append(f"artist:{artist}" if artist and add_artist_tags else "")
    parts.extend(
        tags_dict[key]
        for key in ("character_str", "copyright_str", "general_str")
    )
    return ", ".join(part for part in parts if part)
def get_captions_for_rows(df, start_idx: int = 0, end_idx: int = 5,
                          tags_front: str = "", tags_back: str = "",
                          add_artist_tags: bool = True) -> list:
    """
    Build one caption per row for the positional slice ``[start_idx:end_idx]``.

    Each caption is the row's tag string (from ``get_tags_dict`` and
    ``build_tags_from_tags_dict``) with ``tags_front`` placed before it and
    ``tags_back`` after it, joined by ", ". Empty pieces are left out.
    """
    def caption_for(row) -> str:
        core = build_tags_from_tags_dict(get_tags_dict(row), add_artist_tags)
        return ", ".join(piece for piece in (tags_front, core, tags_back) if piece)

    window = df.iloc[start_idx:end_idx]
    return [caption_for(row) for _, row in window.iterrows()]
def get_previews_for_rows(df: pd.DataFrame, start_idx: int = 0, end_idx: int = 5) -> list:
    """
    Return the ``large_file_url`` value of every row in the positional
    slice ``[start_idx:end_idx]``, in row order.
    """
    window = df.iloc[start_idx:end_idx]
    urls = []
    for _label, record in window.iterrows():
        urls.append(record["large_file_url"])
    return urls
def gradio_interface(
    dataset_id: str,
    start_idx: int = 0,
    display_count: int = 5,
    tags_front: str = "",
    tags_back: str = "",
    add_artist_tags: bool = True
):
    """
    Gradio callback: produce captions and preview URLs for a window of rows.

    1) Loads the dataset if it is not the one currently cached.
    2) Clamps ``start_idx`` / ``display_count`` to the dataset bounds.
    3) Returns ``(DataFrame, Gallery, InfoMessage)``:
       - a DataFrame with columns "index" and "Captions",
       - a list of preview URLs for the same rows,
       - a Markdown string summarising the inputs that were used.

    On a load failure or an empty dataset the first two outputs are empty
    and the third carries the explanation.
    """
    # 1) Possibly reload. This is the UI boundary, so a load failure is
    #    reported in the info output rather than propagating to Gradio.
    try:
        load_dataset_if_needed(dataset_id)
    except Exception as exc:
        return pd.DataFrame(), [], f"ERROR: Could not load dataset {dataset_id} ({exc})"
    dset_df = CURRENT_DATASET["df"]
    if dset_df is None:
        return pd.DataFrame(), [], f"ERROR: Could not load dataset {dataset_id}"

    # 2) Figure out total length, clamp inputs
    total_len = len(dset_df)
    if total_len == 0:
        return pd.DataFrame(), [], f"Dataset {dataset_id} is empty."

    # UI widgets may hand over floats (or None for a cleared number box);
    # range() and positional slicing both require real ints.
    start_idx = int(start_idx) if start_idx is not None else 0
    display_count = max(int(display_count), 0)

    start_idx = min(max(start_idx, 0), total_len - 1)
    end_idx = min(start_idx + display_count, total_len)

    # 3) Build results
    idxs = range(start_idx, end_idx)
    captions = get_captions_for_rows(
        dset_df, start_idx, end_idx, tags_front, tags_back, add_artist_tags
    )
    previews = get_previews_for_rows(dset_df, start_idx, end_idx)
    df_out = pd.DataFrame({"index": idxs, "Captions": captions})

    # 4) Build info string. Markdown needs two trailing spaces before the
    #    newline to render a hard line break.
    info_msg = (
        f"**Current dataset:** {CURRENT_DATASET['id']}  \n"
        f"**Dataset length:** {total_len}  \n"
        f"**start_idx:** {start_idx}, **display_count:** {display_count}, "
        f"**tags_front:** '{tags_front}', **tags_back:** '{tags_back}', "
        f"**add_artist_tags:** {add_artist_tags}"
    )
    return df_out, previews, info_msg
# UI layout: inputs in the left column, results in the wider right column.
with gr.Blocks() as demo:
    gr.Markdown("## Danbooru2025 Dataset Captions and Previews")
    with gr.Row():
        with gr.Column(scale=1):
            dataset_id_input = gr.Textbox(
                value="dataproc5/test-danbooru2025-tag-balanced-2k",
                label="Dataset ID"
            )
            # precision=0 makes the component deliver an int, which the
            # callback needs for positional slicing.
            start_idx_input = gr.Number(value=500, label="Start Index", precision=0)
            display_count_input = gr.Slider(
                value=5, minimum=1, maximum=50, step=1,
                label="Number of Items"
            )
            tags_front_input = gr.Textbox(value="", label="Tags Front")
            tags_back_input = gr.Textbox(value="", label="Tags Back")
            add_artist_tags_input = gr.Checkbox(label="Add artist tags", value=True)
            run_button = gr.Button("Get Captions & Previews")
        with gr.Column(scale=2):
            captions_df_out = gr.DataFrame(label="Captions")
            previews_gallery_out = gr.Gallery(label="Previews", type="filepath")
            info_textbox_out = gr.Markdown(value="")

    # Input order must match the positional parameters of gradio_interface;
    # output order must match the tuple it returns.
    run_button.click(
        fn=gradio_interface,
        inputs=[
            dataset_id_input,
            start_idx_input,
            display_count_input,
            tags_front_input,
            tags_back_input,
            add_artist_tags_input
        ],
        outputs=[
            captions_df_out,
            previews_gallery_out,
            info_textbox_out
        ]
    )

demo.launch()