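# FLUX.1-dev quantization challenge (Gradio Space): generate the same prompt with the
# original BF16 pipeline and a bitsandbytes-quantized variant, shuffle the two images,
# and let the user guess which one is quantized. Guesses feed JSON-backed leaderboards
# and are optionally logged to a Hugging Face dataset.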
import torch
import gradio as gr
from diffusers import FluxPipeline, FluxTransformer2DModel
import gc
import random
import glob
from pathlib import Path
from PIL import Image
import os
import time
import json
from fasteners import InterProcessLock
import spaces
from datasets import Dataset, Image as HFImage, load_dataset
from datasets import Features, Value
from datasets import concatenate_datasets
from datetime import datetime
AGG_FILE = Path(__file__).parent / "agg_stats.json"
LOCK_FILE = AGG_FILE.with_suffix(".lock")
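
# Aggregate guess statistics (per quantization method) are persisted to a JSON file
# next to this script; an inter-process lock guards concurrent writes.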
def _load_agg_stats() -> dict:
    if AGG_FILE.exists():
        with open(AGG_FILE, "r") as f:
            try:
                return json.load(f)
            except json.JSONDecodeError:
                print(f"Warning: {AGG_FILE} is corrupted. Starting with empty stats.")
                return {"8-bit bnb": {"attempts": 0, "correct": 0},
                        "4-bit bnb": {"attempts": 0, "correct": 0}}
    return {"8-bit bnb": {"attempts": 0, "correct": 0},
            "4-bit bnb": {"attempts": 0, "correct": 0}}
def _save_agg_stats(stats: dict) -> None:
    with InterProcessLock(str(LOCK_FILE)):
        with open(AGG_FILE, "w") as f:
            json.dump(stats, f, indent=2)
USER_STATS_FILE = Path(__file__).parent / "user_stats.json"
USER_STATS_LOCK_FILE = USER_STATS_FILE.with_suffix(".lock")
def _load_user_stats() -> dict:
    if USER_STATS_FILE.exists():
        with open(USER_STATS_FILE, "r") as f:
            try:
                return json.load(f)
            except json.JSONDecodeError:
                print(f"Warning: {USER_STATS_FILE} is corrupted. Starting with empty user stats.")
                return {}
    return {}
def _save_user_stats(stats: dict) -> None:
    with InterProcessLock(str(USER_STATS_LOCK_FILE)):
        with open(USER_STATS_FILE, "w") as f:
            json.dump(stats, f, indent=2)
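
# Device selection, generation defaults, and Hugging Face Hub configuration.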
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

DEFAULT_HEIGHT = 1024
DEFAULT_WIDTH = 1024
DEFAULT_GUIDANCE_SCALE = 3.5
DEFAULT_NUM_INFERENCE_STEPS = 15
DEFAULT_MAX_SEQUENCE_LENGTH = 512

HF_TOKEN = os.environ.get("HF_ACCESS_TOKEN")
HF_DATASET_REPO_ID = "derekl35/flux-quant-challenge-submissions"

CACHED_PIPES = {}
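
# Pipeline loaders: each loads its checkpoint once, moves it to the target device,
# and caches it in CACHED_PIPES so later rounds reuse the same pipeline.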
def load_bf16_pipeline():
    print("Loading BF16 pipeline...")
    MODEL_ID = "black-forest-labs/FLUX.1-dev"
    if MODEL_ID in CACHED_PIPES:
        return CACHED_PIPES[MODEL_ID]
    start_time = time.time()
    try:
        pipe = FluxPipeline.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16,
            token=HF_TOKEN
        )
        pipe.to(DEVICE)
        end_time = time.time()
        mem_reserved = torch.cuda.memory_reserved(0) / 1024**3 if DEVICE == "cuda" else 0
        print(f"BF16 pipeline loaded in {end_time - start_time:.2f}s. Memory reserved: {mem_reserved:.2f} GB")
        CACHED_PIPES[MODEL_ID] = pipe
        return pipe
    except Exception as e:
        print(f"Error loading BF16 pipeline: {e}")
        raise
def load_bnb_8bit_pipeline():
    print("Loading 8-bit BNB pipeline...")
    MODEL_ID = "derekl35/FLUX.1-dev-bnb-8bit"
    if MODEL_ID in CACHED_PIPES:
        return CACHED_PIPES[MODEL_ID]
    start_time = time.time()
    try:
        pipe = FluxPipeline.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16
        )
        pipe.to(DEVICE)
        # pipe.enable_model_cpu_offload()
        end_time = time.time()
        mem_reserved = torch.cuda.memory_reserved(0) / 1024**3 if DEVICE == "cuda" else 0
        print(f"8-bit BNB pipeline loaded in {end_time - start_time:.2f}s. Memory reserved: {mem_reserved:.2f} GB")
        CACHED_PIPES[MODEL_ID] = pipe
        return pipe
    except Exception as e:
        print(f"Error loading 8-bit BNB pipeline: {e}")
        raise
def load_bnb_4bit_pipeline():
    print("Loading 4-bit BNB pipeline...")
    MODEL_ID = "derekl35/FLUX.1-dev-nf4"
    if MODEL_ID in CACHED_PIPES:
        return CACHED_PIPES[MODEL_ID]
    start_time = time.time()
    try:
        pipe = FluxPipeline.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16
        )
        pipe.to(DEVICE)
        # pipe.enable_model_cpu_offload()
        end_time = time.time()
        mem_reserved = torch.cuda.memory_reserved(0) / 1024**3 if DEVICE == "cuda" else 0
        print(f"4-bit BNB pipeline loaded in {end_time - start_time:.2f}s. Memory reserved: {mem_reserved:.2f} GB")
        CACHED_PIPES[MODEL_ID] = pipe
        return pipe
    except Exception as e:
        print(f"Error loading 4-bit BNB pipeline: {e}")
        raise
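
# Generates one image with the original BF16 pipeline and one with the selected
# quantized pipeline using a shared random seed, then shuffles the pair so the
# gallery cannot reveal which is which. Returns values for
# [output_gallery, correct_mapping_state, prompt_state, seed_state, results_state, feedback_box].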
def generate_images(prompt, quantization_choice, progress=gr.Progress(track_tqdm=True)):
    if not prompt:
        return None, {}, "", None, [], "Please enter a prompt."
    if not quantization_choice:
        return None, {}, "", None, [], "Please select a quantization method."

    if quantization_choice == "8-bit bnb":
        quantized_load_func = load_bnb_8bit_pipeline
        quantized_label = "Quantized (8-bit bnb)"
    elif quantization_choice == "4-bit bnb":
        quantized_load_func = load_bnb_4bit_pipeline
        quantized_label = "Quantized (4-bit bnb)"
    else:
        return None, {}, "", None, [], "Invalid quantization choice."

    model_configs = [
        ("Original", load_bf16_pipeline),
        (quantized_label, quantized_load_func),
    ]

    results = []
    pipe_kwargs = {
        "prompt": prompt,
        "height": DEFAULT_HEIGHT,
        "width": DEFAULT_WIDTH,
        "guidance_scale": DEFAULT_GUIDANCE_SCALE,
        "num_inference_steps": DEFAULT_NUM_INFERENCE_STEPS,
        "max_sequence_length": DEFAULT_MAX_SEQUENCE_LENGTH,
    }

    seed = random.getrandbits(64)
    print(f"Using seed: {seed}")

    for i, (label, load_func) in enumerate(model_configs):
        progress(i / len(model_configs), desc=f"Loading {label} model...")
        print(f"\n--- Loading {label} Model ---")
        load_start_time = time.time()
        try:
            current_pipe = load_func()
            load_end_time = time.time()
            print(f"{label} model loaded in {load_end_time - load_start_time:.2f} seconds.")
            progress((i + 0.5) / len(model_configs), desc=f"Generating with {label} model...")
            print(f"--- Generating with {label} Model ---")
            gen_start_time = time.time()
            image_list = current_pipe(**pipe_kwargs, generator=torch.manual_seed(seed)).images
            image = image_list[0]
            gen_end_time = time.time()
            results.append({"label": label, "image": image})
            print(f"--- Finished Generation with {label} Model in {gen_end_time - gen_start_time:.2f} seconds ---")
            mem_reserved = torch.cuda.memory_reserved(0) / 1024**3 if DEVICE == "cuda" else 0
            print(f"Memory reserved: {mem_reserved:.2f} GB")
        except Exception as e:
            print(f"Error during {label} model processing: {e}")
            return None, {}, "", None, [], f"Error processing {label} model: {e}"

    if len(results) != len(model_configs):
        return None, {}, "", None, [], "Failed to generate images for all model types."

    shuffled_results = results.copy()
    random.shuffle(shuffled_results)
    shuffled_data_for_gallery = [(res["image"], f"Image {i+1}") for i, res in enumerate(shuffled_results)]
    correct_mapping = {i: res["label"] for i, res in enumerate(shuffled_results)}
    # print("Correct mapping (hidden):", correct_mapping)
    return shuffled_data_for_gallery, correct_mapping, prompt, seed, results, "Generation complete! Make your guess."
def check_guess(user_guess, correct_mapping_state):
    if not isinstance(correct_mapping_state, dict) or not correct_mapping_state:
        return "Please generate images first (state is empty or invalid)."
    if user_guess is None:
        return "Please select which image you think is quantized."

    quantized_image_index = -1
    quantized_label_actual = ""
    for index, label in correct_mapping_state.items():
        if "Quantized" in label:
            quantized_image_index = index
            quantized_label_actual = label
            break

    if quantized_image_index == -1:
        return "Error: Could not find the quantized image in the mapping data."

    correct_guess_label = f"Image {quantized_image_index + 1}"
    if user_guess == correct_guess_label:
        feedback = f"Correct! {correct_guess_label} used the {quantized_label_actual} model."
    else:
        feedback = f"Incorrect. The quantized image ({quantized_label_actual}) was {correct_guess_label}."
    return feedback
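
# Precomputed example image pairs shipped with the Space, so visitors can try the
# challenge without waiting for a fresh generation.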
EXAMPLE_DIR = Path(__file__).parent / "examples"
EXAMPLES = [
    {
        "prompt": "A photorealistic portrait of an astronaut on Mars",
        "files": ["astronauts_seed_6456306350371904162.png", "astronauts_bnb_8bit.png"],
        "quantized_idx": 1,
        "quant_method": "8-bit bnb",
    },
    {
        "prompt": "Water-color painting of a cat wearing sunglasses",
        "files": ["watercolor_cat_bnb_8bit.png", "watercolor_cat_seed_14269059182221286790.png"],
        "quantized_idx": 0,
        "quant_method": "8-bit bnb",
    },
    # {
    #     "prompt": "Neo-tokyo cyberpunk cityscape at night, rain-soaked streets, 8-K",
    #     "files": ["cyber_city_q.jpg", "cyber_city.jpg"],
    #     "quantized_idx": 0,
    # },
]
def load_example(idx):
    ex = EXAMPLES[idx]
    imgs = [Image.open(EXAMPLE_DIR / f) for f in ex["files"]]
    gallery_items = [(img, f"Image {i+1}") for i, img in enumerate(imgs)]
    mapping = {i: (f"Quantized ({ex['quant_method']})" if i == ex["quantized_idx"] else "Original")
               for i in range(2)}
    return gallery_items, mapping, ex["prompt"]
def _accuracy_string(correct: int, attempts: int) -> tuple[str, float]:
    if attempts:
        pct = 100 * correct / attempts
        return f"{pct:.1f}%", pct
    return "N/A", -1.0
def _add_medals(user_rows):
    MEDALS = {0: "🥇 ", 1: "🥈 ", 2: "🥉 "}
    return [
        [MEDALS.get(i, "") + row[0], *row[1:]]
        for i, row in enumerate(user_rows)
    ]
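
# Builds leaderboard rows: per-method detectability, overall per-user accuracy,
# and per-method per-user accuracy (with medals for the top three users).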
def update_leaderboards_data():
    agg = _load_agg_stats()
    quant_rows = []
    for method, stats in agg.items():
        acc_str, acc_val = _accuracy_string(stats["correct"], stats["attempts"])
        quant_rows.append([
            method,
            stats["correct"],
            stats["attempts"],
            acc_str
        ])
    quant_rows.sort(key=lambda r: r[1] / r[2] if r[2] != 0 else 1e9)

    user_stats_all = _load_user_stats()
    overall_user_rows = []
    for user, per_method_stats_dict in user_stats_all.items():
        user_total_correct = 0
        user_total_attempts = 0
        for method_stats in per_method_stats_dict.values():
            user_total_correct += method_stats.get("correct", 0)
            user_total_attempts += method_stats.get("attempts", 0)
        if user_total_attempts >= 1:
            acc_str, _ = _accuracy_string(user_total_correct, user_total_attempts)
            overall_user_rows.append([user, user_total_correct, user_total_attempts, acc_str])
    overall_user_rows.sort(key=lambda r: (-float(r[3].rstrip('%')) if r[3] != "N/A" else float('-inf'), -r[2]))
    overall_user_rows_medaled = _add_medals(overall_user_rows)

    user_leaderboards_per_method = {}
    quant_method_names = list(agg.keys())
    for method_name in quant_method_names:
        method_specific_user_rows = []
        for user, per_user_method_stats_dict in user_stats_all.items():
            if method_name in per_user_method_stats_dict:
                st = per_user_method_stats_dict[method_name]
                if st.get("attempts", 0) >= 1:  # Only include users who have attempted this method
                    acc_str, _ = _accuracy_string(st["correct"], st["attempts"])
                    method_specific_user_rows.append([user, st["correct"], st["attempts"], acc_str])
        method_specific_user_rows.sort(key=lambda r: (-float(r[3].rstrip('%')) if r[3] != "N/A" else float('-inf'), -r[2]))
        method_specific_user_rows_medaled = _add_medals(method_specific_user_rows)
        user_leaderboards_per_method[method_name] = method_specific_user_rows_medaled

    return quant_rows, overall_user_rows_medaled, user_leaderboards_per_method
quant_df = gr.DataFrame(
    headers=["Method", "Correct Guesses", "Total Attempts", "Detectability %"],
    interactive=False, col_count=(4, "fixed")
)
user_df = gr.DataFrame(
    headers=["User", "Correct Guesses", "Total Attempts", "Accuracy %"],
    interactive=False, col_count=(4, "fixed")
)
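
# Gradio UI: a "Challenge" tab for generating and guessing, and a "Leaderboard" tab
# showing method detectability and per-user accuracy.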
with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# FLUX Model Quantization Challenge")
    with gr.Tabs():
        with gr.TabItem("Challenge"):
            gr.Markdown(
                "Compare the original FLUX.1-dev (BF16) model against a quantized version (4-bit or 8-bit bnb). "
                "Enter a prompt, choose the quantization method, and generate two images. "
                "The images will be shuffled; can you spot which one was quantized?"
            )
gr.Markdown("### Examples") | |
ex_selector = gr.Radio( | |
choices=[f"Example {i+1}" for i in range(len(EXAMPLES))], | |
label="Choose an example prompt", | |
interactive=True, | |
) | |
gr.Markdown("### …or create your own comparison") | |
with gr.Row(): | |
prompt_input = gr.Textbox(label="Enter Prompt", scale=3) | |
quantization_choice_radio = gr.Radio( | |
choices=["8-bit bnb", "4-bit bnb"], | |
label="Select Quantization", | |
value="8-bit bnb", | |
scale=1 | |
) | |
generate_button = gr.Button("Generate & Compare", variant="primary", scale=1) | |
output_gallery = gr.Gallery( | |
label="Generated Images", | |
columns=2, | |
height=606, | |
object_fit="contain", | |
allow_preview=True, | |
show_label=True, | |
) | |
gr.Markdown("### Which image used the selected quantization method?") | |
with gr.Row(): | |
image1_btn = gr.Button("Image 1") | |
image2_btn = gr.Button("Image 2") | |
feedback_box = gr.Textbox(label="Feedback", interactive=False, lines=1) | |
with gr.Row(): | |
session_score_box = gr.Textbox(label="Your accuracy this session", interactive=False) | |
with gr.Row(equal_height=False): | |
username_input = gr.Textbox( | |
label="Enter Your Name for Leaderboard", | |
placeholder="YourName", | |
visible=False, | |
interactive=True, | |
scale=2 | |
) | |
add_score_button = gr.Button( | |
"Add My Score to Leaderboard", | |
visible=False, | |
variant="secondary", | |
scale=1 | |
) | |
add_score_feedback = gr.Textbox( | |
label="Leaderboard Update", | |
visible=False, | |
interactive=False, | |
lines=1 | |
) | |
correct_mapping_state = gr.State({}) | |
session_stats_state = gr.State( | |
{"8-bit bnb": {"attempts": 0, "correct": 0}, | |
"4-bit bnb": {"attempts": 0, "correct": 0}} | |
) | |
is_example_state = gr.State(False) | |
has_added_score_state = gr.State(False) | |
prompt_state = gr.State("") | |
seed_state = gr.State(None) | |
results_state = gr.State([]) | |
            def _load_example_and_update_dfs(sel):
                idx = int(sel.split()[-1]) - 1
                gallery_items, mapping, prompt = load_example(idx)
                quant_data, overall_user_data, _ = update_leaderboards_data()
                return gallery_items, mapping, prompt, True, quant_data, overall_user_data, "", None, []

            ex_selector.change(
                fn=_load_example_and_update_dfs,
                inputs=ex_selector,
                outputs=[output_gallery, correct_mapping_state, prompt_input, is_example_state, quant_df, user_df,
                         prompt_state, seed_state, results_state],
            ).then(
                lambda: (gr.update(interactive=True), gr.update(interactive=True)),
                outputs=[image1_btn, image2_btn],
            )

            generate_button.click(
                fn=generate_images,
                inputs=[prompt_input, quantization_choice_radio],
                outputs=[output_gallery, correct_mapping_state, prompt_state, seed_state, results_state,
                         feedback_box]  # , quantization_choice_radio, generate_button, prompt_input]
            ).then(
                lambda: (False,  # is_example_state
                         False,  # has_added_score_state
                         gr.update(visible=False, value="", interactive=True),  # username_input reset
                         gr.update(visible=False),  # add_score_button reset
                         gr.update(visible=False, value="")),  # add_score_feedback reset
                outputs=[is_example_state,
                         has_added_score_state,
                         username_input,
                         add_score_button,
                         add_score_feedback]
            ).then(
                lambda: (gr.update(interactive=True),
                         gr.update(interactive=True),
                         ""),
                outputs=[image1_btn, image2_btn, feedback_box],
            )
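
            # `choose` scores a guess, updates session and aggregate stats (skipping example
            # images and sessions whose score was already submitted), optionally logs the
            # image pair to the HF dataset, and refreshes the leaderboard/score widgets.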
            def choose(choice_string, mapping, session_stats, is_example, has_added_score_curr,
                       prompt, seed, results, username):
                feedback = check_guess(choice_string, mapping)
                if not mapping:
                    return feedback, gr.update(), gr.update(), "", session_stats, [], [], gr.update(), gr.update(), gr.update()

                quant_label_from_mapping = next((label for label in mapping.values() if "Quantized" in label), None)
                if not quant_label_from_mapping:
                    print("Error: Could not determine quantization label from mapping:", mapping)
                    return ("Internal Error: Could not process results.", gr.update(interactive=False), gr.update(interactive=False),
                            "", session_stats, [], [], gr.update(), gr.update(), gr.update())

                quant_key = "8-bit bnb" if "8-bit bnb" in quant_label_from_mapping else "4-bit bnb"
                got_it_right = "Correct!" in feedback
                sess = session_stats.copy()

                should_log_and_update_stats = not is_example and not has_added_score_curr
                if should_log_and_update_stats:
                    sess[quant_key]["attempts"] += 1
                    if got_it_right:
                        sess[quant_key]["correct"] += 1
                    session_stats = sess

                    AGG_STATS = _load_agg_stats()
                    AGG_STATS[quant_key]["attempts"] += 1
                    if got_it_right:
                        AGG_STATS[quant_key]["correct"] += 1
                    _save_agg_stats(AGG_STATS)

                    if not HF_TOKEN:
                        print("Warning: HF_TOKEN not set. Skipping dataset logging.")
                    elif not results:
                        print("Warning: Results state is empty. Skipping dataset logging.")
                    else:
                        print(f"Logging guess to HF Dataset: {HF_DATASET_REPO_ID}")
                        original_image = None
                        quantized_image = None
                        quantized_image_pos = -1
                        for shuffled_idx, original_label in mapping.items():
                            if "Quantized" in original_label:
                                quantized_image_pos = shuffled_idx
                                break
                        original_image = next((res["image"] for res in results if "Original" in res["label"]), None)
                        quantized_image = next((res["image"] for res in results if "Quantized" in res["label"]), None)

                        if original_image and quantized_image:
                            expected_features = Features({
                                "timestamp": Value("string"),
                                "prompt": Value("string"),
                                "quantization_method": Value("string"),
                                "seed": Value("string"),
                                "image_original": HFImage(),
                                "image_quantized": HFImage(),
                                "quantized_image_displayed_position": Value("string"),
                                "user_guess_displayed_position": Value("string"),
                                "correct_guess": Value("bool"),
                                "username": Value("string"),  # Handles None
                            })
                            new_data_dict_of_lists = {
                                "timestamp": [datetime.now().isoformat()],
                                "prompt": [prompt],
                                "quantization_method": [quant_key],
                                "seed": [str(seed)],
                                "image_original": [original_image],
                                "image_quantized": [quantized_image],
                                "quantized_image_displayed_position": [f"Image {quantized_image_pos + 1}"],
                                "user_guess_displayed_position": [choice_string],
                                "correct_guess": [got_it_right],
                                "username": [username.strip() if username else None],
                            }
                            try:
                                # Attempt to load existing dataset
                                existing_ds = load_dataset(
                                    HF_DATASET_REPO_ID,
                                    split="train",
                                    token=HF_TOKEN,
                                    features=expected_features,
                                    # verification_mode="no_checks"  # Consider removing or using default
                                    # download_mode="force_redownload"  # For debugging cache issues
                                )
                                # Create a new dataset from the new item, casting to the expected features
                                new_row_ds = Dataset.from_dict(new_data_dict_of_lists, features=expected_features)
                                # Concatenate
                                combined_ds = concatenate_datasets([existing_ds, new_row_ds])
                                # Push the combined dataset
                                combined_ds.push_to_hub(HF_DATASET_REPO_ID, token=HF_TOKEN, split="train")
                                print(f"Successfully appended guess to {HF_DATASET_REPO_ID} (train split)")
                            except Exception as e:
                                print(f"Could not load or append to existing dataset/split. Creating 'train' split with the new item. Error: {e}")
                                # Create dataset from only the new item, with explicit features
                                ds_new = Dataset.from_dict(new_data_dict_of_lists, features=expected_features)
                                # Push this new dataset as the 'train' split
                                ds_new.push_to_hub(HF_DATASET_REPO_ID, token=HF_TOKEN, split="train")
                                print(f"Successfully created and logged new 'train' split to {HF_DATASET_REPO_ID}")
                        else:
                            print("Error: Could not find original or quantized image in results state for logging.")

                def _fmt(d):
                    a, c = d["attempts"], d["correct"]
                    pct = 100 * c / a if a else 0
                    return f"{c} / {a} ({pct:.1f}%)"

                session_msg = ", ".join(
                    f"{k}: {_fmt(v)}" for k, v in sess.items()
                )

                current_agg_stats = _load_agg_stats()

                username_input_update = gr.update(visible=False, interactive=True)
                add_score_button_update = gr.update(visible=False)
                current_feedback_text = add_score_feedback.value if hasattr(add_score_feedback, 'value') and add_score_feedback.value else ""
                add_score_feedback_update = gr.update(visible=has_added_score_curr, value=current_feedback_text)

                session_total_attempts = sum(stats["attempts"] for stats in sess.values())
                if not is_example and not has_added_score_curr:
                    if session_total_attempts >= 1:
                        username_input_update = gr.update(visible=True, interactive=True)
                        add_score_button_update = gr.update(visible=True, interactive=True)
                        add_score_feedback_update = gr.update(visible=False, value="")
                    else:
                        username_input_update = gr.update(visible=False, value=username_input.value if hasattr(username_input, 'value') else "")
                        add_score_button_update = gr.update(visible=False)
                        add_score_feedback_update = gr.update(visible=False, value="")
                elif has_added_score_curr:
                    username_input_update = gr.update(visible=True, interactive=False, value=username_input.value if hasattr(username_input, 'value') else "")
                    add_score_button_update = gr.update(visible=True, interactive=False)
                    add_score_feedback_update = gr.update(visible=True)

                quant_data, overall_user_data, _ = update_leaderboards_data()

                return (feedback,
                        gr.update(interactive=False),
                        gr.update(interactive=False),
                        session_msg,
                        session_stats,
                        quant_data,
                        overall_user_data,
                        username_input_update,
                        add_score_button_update,
                        add_score_feedback_update)
            image1_btn.click(
                fn=lambda mapping, sess, is_ex, has_added, p, s, r, uname: choose("Image 1", mapping, sess, is_ex, has_added, p, s, r, uname),
                inputs=[correct_mapping_state, session_stats_state, is_example_state, has_added_score_state,
                        prompt_state, seed_state, results_state, username_input],
                outputs=[feedback_box, image1_btn, image2_btn,
                         session_score_box, session_stats_state,
                         quant_df, user_df,
                         username_input, add_score_button, add_score_feedback],
            )
            image2_btn.click(
                fn=lambda mapping, sess, is_ex, has_added, p, s, r, uname: choose("Image 2", mapping, sess, is_ex, has_added, p, s, r, uname),
                inputs=[correct_mapping_state, session_stats_state, is_example_state, has_added_score_state,
                        prompt_state, seed_state, results_state, username_input],
                outputs=[feedback_box, image1_btn, image2_btn,
                         session_score_box, session_stats_state,
                         quant_df, user_df,
                         username_input, add_score_button, add_score_feedback],
            )
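
            # Merges the current session's per-method results into the persistent
            # per-user leaderboard stats.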
            def handle_add_score_to_leaderboard(username_str, current_session_stats_dict):
                if not username_str or not username_str.strip():
                    return ("Username is required.",
                            gr.update(interactive=True),
                            gr.update(interactive=True),
                            False,
                            None, None)
                user_stats = _load_user_stats()
                user_key = username_str.strip()

                session_total_session_attempts = sum(stats["attempts"] for stats in current_session_stats_dict.values())
                if session_total_session_attempts == 0:
                    return ("No attempts made in this session to add to leaderboard.",
                            gr.update(interactive=True),
                            gr.update(interactive=True),
                            False, None, None)

                if user_key not in user_stats:
                    user_stats[user_key] = {}
                for method, stats in current_session_stats_dict.items():
                    session_method_correct = stats["correct"]
                    session_method_attempts = stats["attempts"]
                    if session_method_attempts == 0:
                        continue
                    if method not in user_stats[user_key]:
                        user_stats[user_key][method] = {"correct": 0, "attempts": 0}
                    user_stats[user_key][method]["correct"] += session_method_correct
                    user_stats[user_key][method]["attempts"] += session_method_attempts
                _save_user_stats(user_stats)

                new_quant_data, new_overall_user_data, _ = update_leaderboards_data()
                feedback_msg = f"Score for '{user_key}' submitted to leaderboard!"
                return (feedback_msg,
                        gr.update(interactive=False),
                        gr.update(interactive=False),
                        True,
                        new_quant_data,
                        new_overall_user_data)

            add_score_button.click(
                fn=handle_add_score_to_leaderboard,
                inputs=[username_input, session_stats_state],
                outputs=[add_score_feedback, username_input, add_score_button, has_added_score_state, quant_df, user_df]
            )
with gr.TabItem("Leaderboard"): | |
gr.Markdown("## Quantization Method Leaderboard *(Lower % ⇒ harder to detect)*") | |
leaderboard_tab_quant_df = gr.DataFrame( | |
headers=["Method", "Correct Guesses", "Total Attempts", "Detectability %"], | |
interactive=False, col_count=(4, "fixed"), label="Quantization Method Leaderboard" | |
) | |
gr.Markdown("---") | |
leaderboard_tab_user_df_8bit = gr.DataFrame( | |
headers=["User", "Correct Guesses", "Total Attempts", "Accuracy %"], | |
interactive=False, col_count=(4, "fixed"), label="8-bit bnb User Leaderboard" | |
) | |
leaderboard_tab_user_df_4bit = gr.DataFrame( | |
headers=["User", "Correct Guesses", "Total Attempts", "Accuracy %"], | |
interactive=False, col_count=(4, "fixed"), label="4-bit bnb User Leaderboard" | |
) | |
def update_all_leaderboards_for_tab(): | |
q_rows, _, per_method_u_dict = update_leaderboards_data() | |
user_rows_8bit = per_method_u_dict.get("8-bit bnb", []) | |
user_rows_4bit = per_method_u_dict.get("4-bit bnb", []) | |
return q_rows, user_rows_8bit, user_rows_4bit | |
demo.load(update_all_leaderboards_for_tab, outputs=[ | |
leaderboard_tab_quant_df, | |
leaderboard_tab_user_df_8bit, | |
leaderboard_tab_user_df_4bit | |
]) | |
if __name__ == "__main__":
    demo.launch(share=True)