"""Gradio app for crowd-reviewing the labels of ImageNet-Hard samples.

A reviewer loads a batch of not-yet-reviewed images, sees each query image
next to its claimed label(s) and one ImageNet training sample per label, and
answers Accept / Not Sure! / Reject. Each answer is uploaded as a small JSON
file to the ``taesiri/imagenet_hard_review_data`` dataset repository.
"""

import csv
import json
import os
import pickle
import random
import re
import string
import sys
import time
from glob import glob

import datasets
import gdown
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchvision
from huggingface_hub import HfApi, login, snapshot_download
from matplotlib.figure import Figure
from PIL import Image

# session_token = os.environ.get("SessionToken")
# login(token=session_token)

csv.field_size_limit(sys.maxsize)
np.random.seed(int(time.time()))

# NOTE(security): pickle.load executes arbitrary code from the file; this is
# only acceptable because the file ships with the app. Never point it at
# user-supplied data.
with open("./imagenet_hard_nearest_indices.pkl", "rb") as f:
    knn_results = pickle.load(f)

with open("imagenet-labels.json") as f:
    wnid_to_label = json.load(f)

with open("id_to_label.json", "r") as f:
    id_to_labels = json.load(f)

imagenet_training_samples_path = "imagenet_traning_samples"

# ex2.txt lists one file name per line ("<dataset index>.<ext>"); keep only
# the integer index of every non-empty line.
with open("./ex2.txt", "r") as f:
    bad_items = f.read().split("\n")
bad_items = [x.split(".")[0] for x in bad_items]
bad_items = [int(x) for x in bad_items if x != ""]

# Maximum number of images handed to a reviewer per "Load Samples" click.
NUMBER_OF_IMAGES = 100  # len(bad_items)

REVIEW_REPO_ID = "taesiri/imagenet_hard_review_data"

# download and extract folders
gdown.cached_download(
    url="https://huggingface.co/datasets/taesiri/imagenet_hard_review_samples/resolve/main/data.zip",
    path="./data.zip",
    quiet=False,
    md5="8666a9b361f6eea79878be6c09701def",
)

# EXTRACT if needed
if not os.path.exists("./imagenet_traning_samples") or not os.path.exists(
    "./knn_cache_for_imagenet_hard"
):
    torchvision.datasets.utils.extract_archive(
        from_path="data.zip",
        to_path="./",
        remove_finished=False,
    )

imagenet_hard = datasets.load_dataset("taesiri/imagenet-hard", split="validation")


def update_snapshot(username):
    """Download all submitted decisions and return those made by *username*.

    Returns a DataFrame with columns ``id``, ``user_id``, ``time`` and
    ``decision``; it is empty (but still has those columns) when the user has
    not submitted anything yet.
    """
    output_dir = snapshot_download(
        repo_id=REVIEW_REPO_ID,
        allow_patterns="*.json",
        repo_type="dataset",
    )
    columns = ["id", "user_id", "time", "decision"]

    rows = []
    for file in glob(f"{output_dir}/*.json"):
        with open(file) as f:
            record = json.load(f)
        rows.append([record[x] for x in columns])

    df = pd.DataFrame(rows, columns=columns)
    return df[df["user_id"] == username]


def generate_dataset(username):
    """Build the next review batch for *username*.

    Picks up to ``NUMBER_OF_IMAGES`` items from ``bad_items`` that the user
    has not answered yet. Returns a list of dicts with keys ``id``,
    ``image``, ``correct_label`` and ``original_id``, or ``[]`` when nothing
    is left to review.

    The batch size is derived locally; the module-level ``NUMBER_OF_IMAGES``
    is never modified, because it is shared by every connected session.
    """
    df = update_snapshot(username)
    remaining = list(set(bad_items) - set(df["id"]))

    if not remaining:
        return []

    if len(remaining) < NUMBER_OF_IMAGES:
        chosen = remaining
    else:
        chosen = np.random.choice(remaining, NUMBER_OF_IMAGES, replace=False)

    data = []
    for idx in chosen:
        idx = int(idx)
        # Index the dataset once per item: every row access decodes the image.
        sample = imagenet_hard[idx]
        data.append(
            {
                "id": idx,
                "image": sample["image"],
                "correct_label": sample["english_label"],
                "original_id": idx,
            }
        )
    return data


def string_to_image(text):
    """Render *text* on a blank white canvas and return the matplotlib figure.

    Underscores become spaces, the text is lower-cased, and ``", "``
    separators become line breaks so that each label gets its own line.
    """
    text = text.replace("_", " ").lower().replace(", ", "\n")

    # Create a blank white image as the background.
    img = np.ones((220, 75, 3))

    # Build the figure without pyplot: pyplot keeps a global reference to
    # every figure it creates, which leaks memory in a long-running server
    # that renders a new figure on every click.
    fig = Figure(figsize=(6, 2.25))
    ax = fig.subplots()
    ax.imshow(img, extent=[0, 1, 0, 1])
    ax.text(0.5, 0.75, text, fontsize=18, ha="center", va="center")
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    for spine in ax.spines.values():
        spine.set_visible(False)

    return fig


all_samples = glob("./imagenet_traning_samples/*.JPEG")
# Key: integer prefix of the file name (text before the first "_" or ".");
# get_training_samples looks these up by label id.
qid_to_sample = {
    int(x.split("/")[-1].split(".")[0].split("_")[0]): x for x in all_samples
}


# user-e3z5b
def get_training_samples(qid):
    """Return one training-sample path per label of dataset item *qid*.

    Raises ``KeyError`` if a label has no entry in ``qid_to_sample``.
    """
    labels_id = imagenet_hard[int(qid)]["label"]
    samples = [qid_to_sample[x] for x in labels_id]
    return samples


def load_sample(data, current_index):
    """Return ``(query_image, labels)`` for ``data[current_index]``."""
    entry = data[current_index]
    return entry["image"], entry["correct_label"]


def _load_training_images(image_id):
    """Load the training samples of *image_id* as RGB PIL images."""
    images = []
    for path in get_training_samples(image_id):
        # convert() returns a fully loaded copy, so the file can be closed.
        with Image.open(path) as img:
            images.append(img.convert("RGB"))
    return images


def _placeholder_outputs(message, current_index, history, data):
    """Return the six UI outputs for a state with no image to show."""
    return (
        Image.new("RGB", (224, 224)),
        string_to_image(message),
        current_index,
        history,
        data,
        None,
    )


def _sample_outputs(data, current_index, history):
    """Return the six UI outputs that display ``data[current_index]``."""
    qimage, labels = load_sample(data, current_index)
    training_samples_image = _load_training_images(data[current_index]["id"])
    # labels is a list of labels; join it into a single string for rendering.
    label_plot = string_to_image(", ".join(labels))
    return qimage, label_plot, current_index, history, data, training_samples_image


def _upload_decision(image_id, username, decision):
    """Upload one review decision as a JSON file to the review repository."""
    time_stamp = int(time.time())
    decision_dict = {
        "id": int(image_id),
        "user_id": username,
        "time": time_stamp,
        "decision": decision,
    }
    # The username comes straight from a text box: keep path separators and
    # other unexpected characters out of the file name. The unmodified
    # username is still stored inside the JSON, which is what
    # update_snapshot filters on.
    safe_name = re.sub(r"[^A-Za-z0-9_.-]", "_", str(username))
    filename = f"results_{safe_name}_{time_stamp}.json"

    # upload_file accepts raw bytes, so no temporary file is needed (and none
    # can be left behind when the upload fails).
    HfApi().upload_file(
        path_or_fileobj=json.dumps(decision_dict).encode("utf-8"),
        path_in_repo=filename,
        repo_id=REVIEW_REPO_ID,
        repo_type="dataset",
    )


def preprocessing(data, current_index, history, username):
    """Load a fresh batch for *username* and show its first image.

    Returns the six UI outputs ``(query image, label plot, current index,
    history, data, training sample images)``. When nothing is left to review
    a placeholder is shown and the index is reset to -1 so that the decision
    buttons do nothing.
    """
    data = generate_dataset(username)

    if not data:
        return _placeholder_outputs("No more images to review", -1, history, data)

    return _sample_outputs(data, 0, history)


def update_app(decision, data, current_index, history, username):
    """Record *decision* for the current image and advance to the next one.

    Returns the same six UI outputs as ``preprocessing``. The batch length
    is taken from *data* (the per-session state), not from a global, so
    concurrent sessions cannot affect each other.
    """
    total = len(data) if data else 0

    # No batch loaded (initial state, or nothing left to review).
    if current_index < 0 or total == 0:
        return _placeholder_outputs(
            "Please load samples first", -1, history, data
        )

    # Batch already finished: do not upload the last decision again.
    if current_index >= total:
        return _placeholder_outputs(
            "Thank you for your time!", current_index, history, data
        )

    _upload_decision(data[current_index]["id"], username, decision)

    # Load the next image, or finish when that was the last one.
    current_index += 1
    if current_index >= total:
        return _placeholder_outputs(
            "Thank you for your time!", current_index, history, data
        )
    return _sample_outputs(data, current_index, history)


newcss = """
#query_image{
    height: auto !important;
}

#nn_gallery {
    height: auto !important;
}

#sample_gallery {
    height: auto !important;
}
"""

with gr.Blocks(css=newcss) as demo:
    # Per-session state: the loaded batch, the position in it, and a history
    # dict that the callbacks pass through unchanged.
    data_gr = gr.State({})
    current_index = gr.State(-1)
    history = gr.State({})

    gr.Markdown("# Cleaning ImageNet-Hard!")

    random_str = "".join(
        random.choice(string.ascii_lowercase + string.digits) for _ in range(5)
    )

    with gr.Row():
        username = gr.Textbox(label="Username", value=f"user-{random_str}")
        prepare_btn = gr.Button(value="Load Samples")

    with gr.Column():
        with gr.Row():
            accept_btn = gr.Button(value="Accept")
            myabe_btn = gr.Button(value="Not Sure!")
            reject_btn = gr.Button(value="Reject")
        with gr.Row():
            query_image = gr.Image(type="pil", label="Query", elem_id="query_image")
            with gr.Column():
                label_plot = gr.Plot(
                    label="Is this a correct label for this image?", type="fig"
                )
                training_samples = gr.Gallery(
                    type="pil", label="Training samples", elem_id="sample_gallery"
                )

    ui_outputs = [
        query_image,
        label_plot,
        current_index,
        history,
        data_gr,
        training_samples,
    ]

    # Each decision button passes its own value (its caption) as the decision.
    for decision_btn in (accept_btn, myabe_btn, reject_btn):
        decision_btn.click(
            update_app,
            inputs=[decision_btn, data_gr, current_index, history, username],
            outputs=ui_outputs,
        )

    prepare_btn.click(
        preprocessing,
        inputs=[data_gr, current_index, history, username],
        outputs=ui_outputs,
    )

demo.launch()