|  | import os | 
					
						
						|  | import logging | 
					
						
						|  | from typing import Tuple | 
					
						
						|  | from dotenv import load_dotenv | 
					
						
						|  | import gradio as gr | 
					
						
						|  | import numpy as np | 
					
						
						|  | from PIL import Image | 
					
						
						|  | import random | 
					
						
						|  | from db import compute_elo_scores, get_all_votes, add_vote | 
					
						
						|  | import json | 
					
						
						|  | from pathlib import Path | 
					
						
						|  | from uuid import uuid4 | 
					
						
						|  | import logging | 
					
						
						|  | import threading | 
					
						
						|  | import time | 
					
						
						|  | from datasets import load_dataset | 
					
						
						|  | from huggingface_hub import CommitScheduler | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | dataset = load_dataset("bgsys/background-removal-arena-test", split='train') | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | logging.basicConfig(level=logging.INFO) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | load_dotenv() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | JSON_DATASET_DIR = Path("data/json_dataset") | 
					
						
						|  | JSON_DATASET_DIR.mkdir(parents=True, exist_ok=True) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | scheduler = CommitScheduler( | 
					
						
						|  | repo_id="bgsys/votes_datasets_test2", | 
					
						
						|  | repo_type="dataset", | 
					
						
						|  | folder_path=JSON_DATASET_DIR, | 
					
						
						|  | path_in_repo="data", | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def fetch_elo_scores(): | 
					
						
						|  | """Fetch and log Elo scores.""" | 
					
						
						|  | try: | 
					
						
						|  | elo_scores = compute_elo_scores() | 
					
						
						|  | logging.info("Elo scores successfully computed.") | 
					
						
						|  | return elo_scores | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logging.error("Error computing Elo scores: %s", str(e)) | 
					
						
						|  | return None | 
					
						
						|  |  | 
					
						
						|  | def update_rankings_table(): | 
					
						
						|  | """Update and return the rankings table based on Elo scores.""" | 
					
						
						|  | elo_scores = fetch_elo_scores() | 
					
						
						|  | if elo_scores: | 
					
						
						|  | rankings = [ | 
					
						
						|  | ["Photoroom", int(elo_scores.get("Photoroom", 1000))], | 
					
						
						|  |  | 
					
						
						|  | ["RemoveBG", int(elo_scores.get("RemoveBG", 1000))], | 
					
						
						|  | ["BRIA RMBG 2.0", int(elo_scores.get("BRIA RMBG 2.0", 1000))], | 
					
						
						|  | ] | 
					
						
						|  | rankings.sort(key=lambda x: x[1], reverse=True) | 
					
						
						|  | return rankings | 
					
						
						|  | else: | 
					
						
						|  | return [ | 
					
						
						|  | ["Photoroom", -1], | 
					
						
						|  |  | 
					
						
						|  | ["RemoveBG", -1], | 
					
						
						|  | ["BRIA RMBG 2.0", -1], | 
					
						
						|  | ] | 
					
						
						|  |  | 
					
						
						|  | def select_new_image(): | 
					
						
						|  | """Select a new image and its segmented versions.""" | 
					
						
						|  | max_attempts = 10 | 
					
						
						|  | last_image_index = None | 
					
						
						|  |  | 
					
						
						|  | for _ in range(max_attempts): | 
					
						
						|  | available_indices = [i for i in range(len(dataset)) if i != last_image_index] | 
					
						
						|  |  | 
					
						
						|  | if not available_indices: | 
					
						
						|  | logging.error("No available images to select from.") | 
					
						
						|  | return None | 
					
						
						|  |  | 
					
						
						|  | random_index = random.choice(available_indices) | 
					
						
						|  | sample = dataset[random_index] | 
					
						
						|  | input_image = sample['original_image'] | 
					
						
						|  |  | 
					
						
						|  | segmented_images = [sample['clipdrop_image'], sample['bria_image'], | 
					
						
						|  | sample['photoroom_image'], sample['removebg_image']] | 
					
						
						|  | segmented_sources = ['Clipdrop', 'BRIA RMBG 2.0', 'Photoroom', 'RemoveBG'] | 
					
						
						|  |  | 
					
						
						|  | if segmented_images.count(None) > 2: | 
					
						
						|  | logging.error("Not enough segmented images found for: %s. Resampling another image.", sample['original_filename']) | 
					
						
						|  | last_image_index = random_index | 
					
						
						|  | continue | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  | selected_indices = random.sample([i for i, img in enumerate(segmented_images) if img is not None], 2) | 
					
						
						|  | model_a_index, model_b_index = selected_indices | 
					
						
						|  | model_a_output_image = segmented_images[model_a_index] | 
					
						
						|  | model_b_output_image = segmented_images[model_b_index] | 
					
						
						|  | model_a_name = segmented_sources[model_a_index] | 
					
						
						|  | model_b_name = segmented_sources[model_b_index] | 
					
						
						|  | return sample['original_filename'], input_image, model_a_output_image, model_b_output_image, model_a_name, model_b_name | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logging.error("Error processing images: %s. Resampling another image.", str(e)) | 
					
						
						|  | last_image_index = random_index | 
					
						
						|  |  | 
					
						
						|  | logging.error("Failed to select a new image after %d attempts.", max_attempts) | 
					
						
						|  | return None | 
					
						
						|  |  | 
					
						
						|  | def get_notice_markdown(): | 
					
						
						|  | """Generate the notice markdown with dynamic vote count.""" | 
					
						
						|  | total_votes = len(get_all_votes()) | 
					
						
						|  | return f""" | 
					
						
						|  | # ⚔️  Background Removal Arena: Compare & Test the Best Background Removal Models | 
					
						
						|  |  | 
					
						
						|  | ## 📜 How It Works | 
					
						
						|  | - **Blind Test**: You will see two images with their background removed from two anonymous background removal models (Clipdrop, RemoveBG, Photoroom, BRIA RMBG 2.0). | 
					
						
						|  | - **Vote for the Best**: Choose the best result, if none stand out choose "Tie". | 
					
						
						|  |  | 
					
						
						|  | ## 📊 Stats | 
					
						
						|  | - **Total #votes**: {total_votes} | 
					
						
						|  |  | 
					
						
						|  | ## 👇 Test now! | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | def compute_mask_difference(segmented_a, segmented_b): | 
					
						
						|  | """Compute the absolute difference between two image masks.""" | 
					
						
						|  | mask_a = np.asarray(segmented_a) | 
					
						
						|  | mask_b = np.asarray(segmented_b) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | mask_a_1d = np.where(mask_a[..., 3] != 0, 1, 0) | 
					
						
						|  | mask_b_1d = np.where(mask_b[..., 3] != 0, 1, 0) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | return np.abs(mask_a_1d - mask_b_1d) | 
					
						
						|  |  | 
					
						
						|  | def gradio_interface(): | 
					
						
						|  | """Create and return the Gradio interface.""" | 
					
						
						|  | with gr.Blocks() as demo: | 
					
						
						|  | gr.Markdown("# Background Removal Arena") | 
					
						
						|  |  | 
					
						
						|  | with gr.Tabs() as tabs: | 
					
						
						|  | with gr.Tab("⚔️ Arena (battle)", id=0): | 
					
						
						|  | notice_markdown = gr.Markdown(get_notice_markdown(), elem_id="notice_markdown") | 
					
						
						|  |  | 
					
						
						|  | filname, input_image, segmented_a, segmented_b, a_name, b_name = select_new_image() | 
					
						
						|  | model_a_name = gr.State(a_name) | 
					
						
						|  | model_b_name = gr.State(b_name) | 
					
						
						|  | fpath_input = gr.State(filname) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | mask_difference = compute_mask_difference(segmented_a, segmented_b) | 
					
						
						|  |  | 
					
						
						|  | with gr.Row(): | 
					
						
						|  | image_a_display = gr.Image( | 
					
						
						|  | value=segmented_a, | 
					
						
						|  | type="pil", | 
					
						
						|  | label="Model A", | 
					
						
						|  | width=500, | 
					
						
						|  | height=500 | 
					
						
						|  | ) | 
					
						
						|  | input_image_display = gr.AnnotatedImage( | 
					
						
						|  | value=(input_image, [(mask_difference > 0, "Difference between masks")]), | 
					
						
						|  | label="Input Image", | 
					
						
						|  | width=500, | 
					
						
						|  | height=500 | 
					
						
						|  | ) | 
					
						
						|  | image_b_display = gr.Image( | 
					
						
						|  | value=segmented_b, | 
					
						
						|  | type="pil", | 
					
						
						|  | label="Model B", | 
					
						
						|  | width=500, | 
					
						
						|  | height=500 | 
					
						
						|  | ) | 
					
						
						|  | tie = gr.State("Tie") | 
					
						
						|  | with gr.Row(): | 
					
						
						|  | vote_a_btn = gr.Button("👈  A is better") | 
					
						
						|  | vote_tie = gr.Button("🤝  Tie") | 
					
						
						|  | vote_b_btn = gr.Button("👉  B is better") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | vote_a_btn.click( | 
					
						
						|  | fn=lambda: vote_for_model("model_a", fpath_input, model_a_name, model_b_name), | 
					
						
						|  | outputs=[ | 
					
						
						|  | fpath_input, input_image_display, image_a_display, image_b_display, model_a_name, model_b_name, notice_markdown | 
					
						
						|  | ] | 
					
						
						|  | ) | 
					
						
						|  | vote_b_btn.click( | 
					
						
						|  | fn=lambda: vote_for_model("model_b",fpath_input, model_a_name, model_b_name), | 
					
						
						|  | outputs=[ | 
					
						
						|  | fpath_input, input_image_display, image_a_display, image_b_display, model_a_name, model_b_name, notice_markdown | 
					
						
						|  | ] | 
					
						
						|  | ) | 
					
						
						|  | vote_tie.click( | 
					
						
						|  | fn=lambda: vote_for_model("tie", fpath_input, model_a_name, model_b_name), | 
					
						
						|  | outputs=[ | 
					
						
						|  | fpath_input, input_image_display, image_a_display, image_b_display, model_a_name, model_b_name, notice_markdown | 
					
						
						|  | ] | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | def vote_for_model(choice, original_filename, model_a_name, model_b_name): | 
					
						
						|  | """Submit a vote for a model and return updated images and model names.""" | 
					
						
						|  | logging.info("Voting for model: %s", choice) | 
					
						
						|  | vote_data = { | 
					
						
						|  | "image_id": original_filename.value, | 
					
						
						|  | "model_a": model_a_name.value, | 
					
						
						|  | "model_b": model_b_name.value, | 
					
						
						|  | "winner": choice, | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  | logging.debug("Adding vote data to the database: %s", vote_data) | 
					
						
						|  |  | 
					
						
						|  | result = add_vote(vote_data) | 
					
						
						|  | logging.info("Vote successfully recorded in the database with ID: %s", result["id"]) | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logging.error("Error recording vote: %s", str(e)) | 
					
						
						|  |  | 
					
						
						|  | new_fpath_input, new_input_image, new_segmented_a, new_segmented_b, new_a_name, new_b_name = select_new_image() | 
					
						
						|  | model_a_name.value = new_a_name | 
					
						
						|  | model_b_name.value = new_b_name | 
					
						
						|  | fpath_input.value = new_fpath_input | 
					
						
						|  |  | 
					
						
						|  | mask_difference = compute_mask_difference(new_segmented_a, new_segmented_b) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | new_notice_markdown = get_notice_markdown() | 
					
						
						|  |  | 
					
						
						|  | return (fpath_input.value, (new_input_image, [(mask_difference, "Mask")]), new_segmented_a, | 
					
						
						|  | new_segmented_b, model_a_name.value, model_b_name.value, new_notice_markdown) | 
					
						
						|  |  | 
					
						
						|  | with gr.Tab("🏆 Leaderboard", id=1) as leaderboard_tab: | 
					
						
						|  | rankings_table = gr.Dataframe( | 
					
						
						|  | headers=["Model", "Ranking"], | 
					
						
						|  | value=update_rankings_table(), | 
					
						
						|  | label="Current Model Rankings", | 
					
						
						|  | column_widths=[180, 60], | 
					
						
						|  | row_count=4 | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | leaderboard_tab.select( | 
					
						
						|  | fn=lambda: update_rankings_table(), | 
					
						
						|  | outputs=rankings_table | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | with gr.Tab("📊 Vote Data", id=2) as vote_data_tab: | 
					
						
						|  | def update_vote_data(): | 
					
						
						|  | votes = get_all_votes() | 
					
						
						|  | return [[vote.id, vote.image_id, vote.model_a, vote.model_b, vote.winner, vote.timestamp] for vote in votes] | 
					
						
						|  |  | 
					
						
						|  | vote_table = gr.Dataframe( | 
					
						
						|  | headers=["ID", "Image ID", "Model A", "Model B", "Winner", "Timestamp"], | 
					
						
						|  | value=update_vote_data(), | 
					
						
						|  | label="Vote Data", | 
					
						
						|  | column_widths=[20, 150, 100, 100, 100, 150], | 
					
						
						|  | row_count=0 | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | vote_data_tab.select( | 
					
						
						|  | fn=lambda: update_vote_data(), | 
					
						
						|  | outputs=vote_table | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | return demo | 
					
						
						|  |  | 
					
						
						|  | def dump_database_to_json(): | 
					
						
						|  | """Dump the database to a JSON file and upload it to Hugging Face.""" | 
					
						
						|  | votes = get_all_votes() | 
					
						
						|  | json_data = [ | 
					
						
						|  | { | 
					
						
						|  | "id": vote.id, | 
					
						
						|  | "image_id": vote.image_id, | 
					
						
						|  | "model_a": vote.model_a, | 
					
						
						|  | "model_b": vote.model_b, | 
					
						
						|  | "winner": vote.winner, | 
					
						
						|  | "user_id": vote.user_id, | 
					
						
						|  | "timestamp": vote.timestamp.isoformat() | 
					
						
						|  | } | 
					
						
						|  | for vote in votes | 
					
						
						|  | ] | 
					
						
						|  |  | 
					
						
						|  | json_file_path = JSON_DATASET_DIR / "votes.json" | 
					
						
						|  |  | 
					
						
						|  | with scheduler.lock: | 
					
						
						|  | with json_file_path.open("w") as f: | 
					
						
						|  | json.dump(json_data, f, indent=4) | 
					
						
						|  |  | 
					
						
						|  | logging.info("Database dumped to JSON") | 
					
						
						|  |  | 
					
						
						|  | def schedule_dump_database(interval=60): | 
					
						
						|  | """Schedule the database dump to JSON every specified interval in seconds.""" | 
					
						
						|  | def run(): | 
					
						
						|  | while True: | 
					
						
						|  | logging.info("Starting database dump to JSON.") | 
					
						
						|  | dump_database_to_json() | 
					
						
						|  | logging.info("Database dump completed. Sleeping for %d seconds.", interval) | 
					
						
						|  | time.sleep(interval) | 
					
						
						|  |  | 
					
						
						|  | logging.info("Initializing database dump scheduler with interval: %d seconds.", interval) | 
					
						
						|  | thread = threading.Thread(target=run, daemon=True) | 
					
						
						|  | thread.start() | 
					
						
						|  | logging.info("Database dump scheduler started.") | 
					
						
						|  |  | 
					
						
						|  | if __name__ == "__main__": | 
					
						
						|  | schedule_dump_database() | 
					
						
						|  | demo = gradio_interface() | 
					
						
						|  | demo.launch() |