Spaces:
Paused
Paused
| import csv | |
| import json | |
| import os | |
| import pickle | |
| import random | |
| import string | |
| import sys | |
| import time | |
| from glob import glob | |
| import datasets | |
| import gdown | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import torchvision | |
| from huggingface_hub import HfApi, login, snapshot_download | |
| from PIL import Image | |
| # session_token = os.environ.get("SessionToken") | |
| # login(token=session_token) | |
# ---------------------------------------------------------------------------
# Module-level setup: load lookup tables, the list of items to review, the
# review assets archive, and the ImageNet-Hard dataset itself.
# ---------------------------------------------------------------------------

# Raise the csv module's per-field size limit to the platform maximum.
csv.field_size_limit(sys.maxsize)
# Seeds NumPy's global RNG only; the stdlib `random` module is not seeded here.
np.random.seed(int(time.time()))

# NOTE(review): pickle.load can execute arbitrary code while deserialising.
# This file is read from the app's own directory; keep it that way and never
# point this at user-supplied data.
with open("./imagenet_hard_nearest_indices.pkl", "rb") as f:
    knn_results = pickle.load(f)

with open("imagenet-labels.json") as f:
    wnid_to_label = json.load(f)

with open("id_to_label.json", "r") as f:
    id_to_labels = json.load(f)

imagenet_training_samples_path = "imagenet_traning_samples"

# ex2.txt holds one entry per line; the part before the first "." of each
# non-empty line is parsed as an integer item id.
# A context manager is used so the file handle is closed deterministically.
with open("./ex2.txt", "r") as f:
    bad_items = f.read().split("\n")
bad_items = [x.split(".")[0] for x in bad_items]
bad_items = [int(x) for x in bad_items if x != ""]

NUMBER_OF_IMAGES = len(bad_items)

# Download the review assets archive (checked against the given md5).
gdown.cached_download(
    url="https://huggingface.co/datasets/taesiri/imagenet_hard_review_samples/resolve/main/data.zip",
    path="./data.zip",
    quiet=False,
    md5="8666a9b361f6eea79878be6c09701def",
)

# EXTRACT if needed: unpack only when either expected directory is missing.
if not os.path.exists("./imagenet_traning_samples") or not os.path.exists(
    "./knn_cache_for_imagenet_hard"
):
    torchvision.datasets.utils.extract_archive(
        from_path="data.zip",
        to_path="./",
        remove_finished=False,
    )

imagenet_hard = datasets.load_dataset("taesiri/imagenet-hard", split="validation")
def update_snapshot(username):
    """Return the review decisions already recorded for ``username``.

    Downloads every ``*.json`` file from the
    ``taesiri/imagenet_hard_review_data`` dataset repository, reads the
    ``id``, ``user_id``, ``time`` and ``decision`` fields of each one, and
    returns the rows whose ``user_id`` equals ``username``.

    Args:
        username: The reviewer name to filter on.

    Returns:
        A ``pandas.DataFrame`` with columns ``id``, ``user_id``, ``time`` and
        ``decision`` containing only this user's rows (possibly empty).

    Raises:
        KeyError: If a downloaded JSON file lacks one of the four fields.
    """
    output_dir = snapshot_download(
        repo_id="taesiri/imagenet_hard_review_data",
        allow_patterns="*.json",
        repo_type="dataset",
    )

    columns = ["id", "user_id", "time", "decision"]
    rows = []
    for file in glob(f"{output_dir}/*.json"):
        with open(file) as f:
            record = json.load(f)
        rows.append([record[column] for column in columns])

    df = pd.DataFrame(rows, columns=columns)
    return df[df["user_id"] == username]
def generate_dataset(username):
    """Build the shuffled list of items ``username`` has not reviewed yet.

    Side effects: resets the module-level ``NUMBER_OF_IMAGES`` to
    ``len(bad_items)`` and prints both the total and the remaining count.

    Args:
        username: The reviewer name passed on to ``update_snapshot``.

    Returns:
        A list of ``{"id": <item id>}`` dicts in random order, or an empty
        list when there are no items at all.
    """
    global NUMBER_OF_IMAGES

    already_reviewed = set(update_snapshot(username).id)
    remaining = list(set(bad_items) - already_reviewed)
    random.shuffle(remaining)

    NUMBER_OF_IMAGES = len(bad_items)
    print(f"NUMBER_OF_IMAGES: {NUMBER_OF_IMAGES}")
    print(f"Remaining: {len(remaining)}")

    if NUMBER_OF_IMAGES == 0:
        return []

    return [{"id": item_id} for item_id in remaining]
def string_to_image(text):
    """Render ``text`` as a matplotlib figure on a plain white background.

    Underscores become spaces, the text is lower-cased, and every ``", "``
    separator becomes a line break before drawing.

    Args:
        text: The string to render.

    Returns:
        The matplotlib figure containing the rendered text, with ticks,
        tick labels and spines hidden.
    """
    display_text = text.replace("_", " ")
    display_text = display_text.lower()
    display_text = display_text.replace(", ", "\n")

    # An all-ones RGB array is shown as a white canvas behind the text.
    canvas = np.ones((220, 75, 3))

    fig, ax = plt.subplots(figsize=(6, 2.25))
    ax.imshow(canvas, extent=[0, 1, 0, 1])
    ax.text(0.5, 0.75, display_text, fontsize=18, ha="center", va="center")

    # Clear ticks first, then tick labels, on both axes.
    for clear in (
        ax.set_xticks,
        ax.set_yticks,
        ax.set_xticklabels,
        ax.set_yticklabels,
    ):
        clear([])

    for border in ax.spines.values():
        border.set_visible(False)

    return fig
# Every training-sample JPEG extracted from the review archive.
all_samples = glob("./imagenet_traning_samples/*.JPEG")

# Map the integer prefix of each file name (the text before the first "."
# and the first "_") to that file's path. If several files share a prefix,
# the last one listed wins.
qid_to_sample = {}
for sample_path in all_samples:
    file_name = sample_path.split("/")[-1]
    prefix = file_name.split(".")[0].split("_")[0]
    qid_to_sample[int(prefix)] = sample_path
def get_training_samples(qid):
    """Return the training-sample file paths for dataset item ``qid``.

    Reads the ``label`` field of ``imagenet_hard[int(qid)]`` and looks each
    entry up in ``qid_to_sample``.

    Args:
        qid: Index of the item in ``imagenet_hard`` (anything ``int()`` accepts).

    Returns:
        A list of file paths, one per entry of the item's ``label`` field.

    Raises:
        KeyError: If an entry has no matching key in ``qid_to_sample``.
    """
    label_ids = imagenet_hard[int(qid)]["label"]
    sample_paths = []
    for label_id in label_ids:
        sample_paths.append(qid_to_sample[label_id])
    return sample_paths
def load_sample(data, current_index):
    """Return the query image and its labels for one entry of ``data``.

    Args:
        data: Sequence of dicts, each with an ``"id"`` key holding an index
            into ``imagenet_hard``.
        current_index: Position in ``data`` of the entry to load.

    Returns:
        A ``(image, labels)`` tuple taken from the ``"image"`` and
        ``"english_label"`` fields of the dataset row.
    """
    image_id = data[current_index]["id"]
    # Fetch the row once and read both fields from it, instead of indexing
    # the dataset twice for the same row.
    row = imagenet_hard[int(image_id)]
    return row["image"], row["english_label"]
def preprocessing(data, current_index, history, username):
    """Load the review queue for ``username`` and show its first item.

    The incoming ``data`` is discarded and rebuilt with
    ``generate_dataset(username)``.

    Args:
        data: Ignored; replaced by the freshly generated queue.
        current_index: Returned unchanged when the queue is empty, otherwise
            reset to 0.
        history: Passed through unchanged.
        username: Reviewer name used to build the queue.

    Returns:
        A 7-tuple ``(query image, label figure, current_index, history,
        data, training sample images or None, labeled image count)``.
    """
    data = generate_dataset(username)
    labeled_images = len(bad_items) - len(data)

    if not data:
        # Nothing left for this user: show a notice and a blank image.
        notice = string_to_image("No more images to review")
        blank = Image.new("RGB", (224, 224))
        return (blank, notice, current_index, history, data, None, labeled_images)

    current_index = 0
    qimage, labels = load_sample(data, current_index)

    first_id = data[current_index]["id"]
    training_samples_image = []
    for sample_path in get_training_samples(first_id):
        training_samples_image.append(Image.open(sample_path).convert("RGB"))

    # The labels come as a list; join them into one string for rendering.
    label_plot = string_to_image(", ".join(labels))

    return (
        qimage,
        label_plot,
        current_index,
        history,
        data,
        training_samples_image,
        labeled_images,
    )
def _upload_decision(image_id, username, decision):
    """Save one review decision as JSON and upload it to the results repo."""
    time_stamp = int(time.time())
    decision_record = {
        "id": int(image_id),
        "user_id": username,
        "time": time_stamp,
        "decision": decision,
    }

    # NOTE(review): the file name only has one-second resolution, so two
    # decisions by the same user within the same second share a name.
    temp_filename = f"results_{username}_{time_stamp}.json"
    with open(temp_filename, "w") as f:
        json.dump(decision_record, f)

    try:
        HfApi().upload_file(
            path_or_fileobj=temp_filename,
            path_in_repo=temp_filename,
            repo_id="taesiri/imagenet_hard_review_data",
            repo_type="dataset",
        )
    finally:
        # Remove the temporary file even when the upload fails.
        os.remove(temp_filename)


def update_app(decision, data, current_index, history, username):
    """Record ``decision`` for the current item and advance to the next one.

    Args:
        decision: The reviewer's verdict, stored as-is in the uploaded record.
        data: The review queue: a list of dicts with an ``"id"`` key.
        current_index: Position of the item being judged; ``-1`` means no
            queue has been loaded yet.
        history: Passed through unchanged.
        username: Reviewer name stored in the uploaded record.

    Returns:
        A 7-tuple ``(query image, label figure, current_index, history,
        data, training sample images or None, labeled image count)``.
    """
    # No queue loaded, or the index does not point into the queue: ask the
    # user to load samples instead of indexing out of range.
    if current_index == -1 or not data or not 0 <= current_index < len(data):
        fake_plot = string_to_image("Please Enter your username and load samples")
        empty_image = Image.new("RGB", (224, 224))
        return empty_image, fake_plot, current_index, history, data, None, 0

    _upload_decision(data[current_index]["id"], username, decision)

    # Items reviewed before this queue was built.
    previously_labeled = len(bad_items) - len(data)

    # The end of the queue is defined by the queue itself. Comparing against
    # the total number of items would overrun `data` for a user who had
    # already reviewed part of the set.
    if current_index == len(data) - 1:
        fake_plot = string_to_image("Thank you for your time!")
        empty_image = Image.new("RGB", (224, 224))
        return (
            empty_image,
            fake_plot,
            current_index,
            history,
            data,
            None,
            previously_labeled + current_index + 1,
        )

    # Load the next image.
    current_index += 1
    qimage, labels = load_sample(data, current_index)
    image_id = data[current_index]["id"]
    training_samples_image = [
        Image.open(path).convert("RGB") for path in get_training_samples(image_id)
    ]

    # The labels come as a list; join them into one string for rendering.
    label_plot = string_to_image(", ".join(labels))

    return (
        qimage,
        label_plot,
        current_index,
        history,
        data,
        training_samples_image,
        previously_labeled + current_index,
    )
# CSS overrides letting the query image and galleries size to their content.
newcss = """
#query_image{
height: auto !important;
}
#nn_gallery {
height: auto !important;
}
#sample_gallery {
height: auto !important;
}
"""

# NOTE(review): the nesting of the layout containers below was reconstructed
# from a source whose indentation was lost; confirm against the running app.
with gr.Blocks(css=newcss, theme=gr.themes.Soft()) as demo:
    # Per-session state: review queue, position in it, and a history dict.
    data_gr = gr.State({})
    current_index = gr.State(-1)
    history = gr.State({})

    gr.Markdown("# Help Us to Clean `ImageNet-Hard`!")
    gr.Markdown("## Instructions")
    gr.Markdown(
        "Please enter your username and press `Load Samples`. The loading process might take up to a minute. Once the loading is done, you can start reviewing the samples."
    )
    gr.Markdown(
        """For each image, please select one of the following options: `Accept`, `Not Sure!`, `Reject`.
- If you think any of the labels are correct, please select `Accept`.
- If you think none of the labels matching the image, please select `Reject`.
- If you are not sure about the label, please select `Not Sure!`.
You can refer to `Training samples` if you are not sure about the target label.
"""
    )

    # Default username suffix: five random lowercase letters / digits.
    random_str = "".join(
        random.choice(string.ascii_lowercase + string.digits) for _ in range(5)
    )

    with gr.Column():
        with gr.Row():
            username = gr.Textbox(label="Username", value=f"user-{random_str}")
            labeled_images = gr.Textbox(label="Labeled Images", value="0")
            total_images = gr.Textbox(label="Total Images", value=len(bad_items))
        prepare_btn = gr.Button(value="Load Samples")

    with gr.Column():
        with gr.Row():
            accept_btn = gr.Button(value="Accept")
            myabe_btn = gr.Button(value="Not Sure!")
            reject_btn = gr.Button(value="Reject")
        with gr.Row():
            query_image = gr.Image(type="pil", label="Query", elem_id="query_image")
            with gr.Column():
                label_plot = gr.Plot(
                    label="Is this a correct label for this image?", type="fig"
                )
                training_samples = gr.Gallery(
                    type="pil", label="Training samples", elem_id="sample_gallery"
                )

    # Every handler writes to the same seven components, in this order.
    shared_outputs = [
        query_image,
        label_plot,
        current_index,
        history,
        data_gr,
        training_samples,
        labeled_images,
    ]

    # Each decision button passes itself as the first input, so its value
    # arrives in update_app as `decision`.
    for decision_btn in (accept_btn, myabe_btn, reject_btn):
        decision_btn.click(
            update_app,
            inputs=[decision_btn, data_gr, current_index, history, username],
            outputs=shared_outputs,
        )

    prepare_btn.click(
        preprocessing,
        inputs=[data_gr, current_index, history, username],
        outputs=shared_outputs,
    )

demo.launch(debug=True)