# taesiri's picture
# backup
# fbfb369
# raw
# history blame
# 10.2 kB
import csv
import json
import os
import pickle
import random
import string
import sys
import time
from glob import glob
import datasets
import gdown
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchvision
from huggingface_hub import HfApi, login, snapshot_download
from PIL import Image
# Authenticate with the Hugging Face Hub using the token from the
# "SessionToken" environment variable (None when the variable is unset).
session_token = os.environ.get("SessionToken")
login(token=session_token)
csv.field_size_limit(sys.maxsize)
# Seed NumPy's global RNG from the wall clock.
np.random.seed(int(time.time()))

# NOTE(review): pickle.load can execute arbitrary code; this is acceptable only
# because the file is bundled with the app. Never point it at external data.
with open("./imagenet_hard_nearest_indices.pkl", "rb") as f:
    knn_results = pickle.load(f)
with open("imagenet-labels.json", encoding="utf-8") as f:
    wnid_to_label = json.load(f)
with open("id_to_label.json", "r", encoding="utf-8") as f:
    id_to_labels = json.load(f)

imagenet_training_samples_path = "imagenet_traning_samples"

# Each non-blank line of ex2.txt contributes the integer written before its
# first "." (used later as an index into the imagenet_hard dataset).
# Strip each line so whitespace-only / "\r"-only lines are skipped instead of
# crashing int(); the `with` block closes the file handle.
with open("./ex2.txt", "r", encoding="utf-8") as f:
    bad_items = [int(line.strip().split(".")[0]) for line in f if line.strip()]

NUMBER_OF_IMAGES = 100  # len(bad_items)
# Fetch the review-sample archive via gdown.cached_download (checked against
# the given md5), then unpack it into the working directory unless both
# extracted folders are already present.
gdown.cached_download(
    url="https://huggingface.co/datasets/taesiri/imagenet_hard_review_samples/resolve/main/data.zip",
    path="./data.zip",
    quiet=False,
    md5="8666a9b361f6eea79878be6c09701def",
)

_required_dirs = ("./imagenet_traning_samples", "./knn_cache_for_imagenet_hard")
if not all(os.path.exists(folder) for folder in _required_dirs):
    torchvision.datasets.utils.extract_archive(
        from_path="data.zip", to_path="./", remove_finished=False
    )

imagenet_hard = datasets.load_dataset("taesiri/imagenet-hard", split="validation")
def update_snapshot(username):
    """Return the review decisions already recorded for *username*.

    Downloads the ``*.json`` files of the ``taesiri/imagenet_hard_review_data``
    dataset repo with ``snapshot_download`` and reads one decision record
    from each file.

    Args:
        username: value compared against each record's ``"user_id"`` field.

    Returns:
        pandas.DataFrame with columns ``id``, ``user_id``, ``time`` and
        ``decision``, restricted to rows whose ``user_id`` equals *username*
        (empty when there are none).

    Raises:
        KeyError: if a record lacks one of the four fields.
    """
    output_dir = snapshot_download(
        repo_id="taesiri/imagenet_hard_review_data",
        allow_patterns="*.json",
        repo_type="dataset",
    )
    columns = ["id", "user_id", "time", "decision"]
    rows = []
    for file in glob(f"{output_dir}/*.json"):
        # Read JSON as UTF-8 rather than the platform's default encoding.
        with open(file, encoding="utf-8") as f:
            record = json.load(f)
        rows.append([record[column] for column in columns])
    df = pd.DataFrame(rows, columns=columns)
    return df[df["user_id"] == username]
# Upper bound on images per review session. It is captured once, at import
# time, because generate_dataset lowers the NUMBER_OF_IMAGES global when few
# images remain. Reading the cap from that global would let one short session
# shrink every later session.
_MAX_IMAGES_PER_SESSION = NUMBER_OF_IMAGES


def generate_dataset(username):
    """Draw the next batch of images for *username* to review.

    Candidates are the ``bad_items`` indices that *username* has not decided
    on yet (according to ``update_snapshot``). Up to
    ``_MAX_IMAGES_PER_SESSION`` of them are drawn at random, without
    replacement.

    Side effect: sets the ``NUMBER_OF_IMAGES`` global to the batch size.

    Returns:
        list of dicts with keys ``"id"`` and ``"original_id"`` (the dataset
        index as ``int``), ``"image"``, and ``"correct_label"`` (the row's
        ``english_label``). Returns ``[]`` when nothing is left to review.
    """
    global NUMBER_OF_IMAGES
    answered = set(update_snapshot(username).id)
    remaining = list(set(bad_items) - answered)
    if not remaining:
        return []

    count = min(_MAX_IMAGES_PER_SESSION, len(remaining))
    # Kept up to date for callers that read the global to detect the last
    # image of a batch.
    # NOTE(review): it is a module-level global shared by all sessions, so
    # concurrent users can still observe each other's batch size through it.
    NUMBER_OF_IMAGES = count

    data = []
    for index in np.random.choice(remaining, count, replace=False):
        index = int(index)
        # Read each dataset row once and take every field from it.
        row = imagenet_hard[index]
        data.append(
            {
                "id": index,
                "image": row["image"],
                "correct_label": row["english_label"],
                "original_id": index,
            }
        )
    return data
def string_to_image(text):
    """Render *text* as a caption on a blank figure.

    Underscores become spaces, the text is lower-cased, and every ", "
    separator starts a new line. The caption is centred horizontally at 75%
    of the axes height. Ticks and spines are hidden.

    Returns:
        matplotlib.figure.Figure, suitable for a ``gr.Plot`` output.
    """
    # Build the figure without pyplot. pyplot keeps a reference to every
    # figure it creates until it is closed, and these figures were never
    # closed, so they accumulated for the lifetime of the server process.
    from matplotlib.figure import Figure

    text = text.replace("_", " ").lower().replace(", ", "\n")
    fig = Figure(figsize=(6, 2.25))
    ax = fig.subplots()
    # All-ones (white) 220x75 RGB array stretched over the unit square.
    ax.imshow(np.ones((220, 75, 3)), extent=[0, 1, 0, 1])
    ax.text(0.5, 0.75, text, fontsize=18, ha="center", va="center")
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)
    return fig
all_samples = glob("./imagenet_traning_samples/*.JPEG")
# Map the integer before the first "_" in each file name to that file's path,
# e.g. ".../123_abc.JPEG" -> 123. os.path.basename handles the platform's
# separator; splitting on "/" broke on Windows paths returned by glob.
qid_to_sample = {
    int(os.path.basename(path).split(".")[0].split("_")[0]): path
    for path in all_samples
}
def get_training_samples(qid):
    """Return one training-sample image path per label of dataset row *qid*.

    Each entry of the row's ``"label"`` field is looked up in
    ``qid_to_sample``. Raises ``KeyError`` if a label has no sample file.
    """
    row = imagenet_hard[int(qid)]
    return [qid_to_sample[label_id] for label_id in row["label"]]
def load_sample(data, current_index):
    """Return ``(image, correct_label)`` for the entry at *current_index*.

    Args:
        data: list of sample dicts as built by ``generate_dataset``; only the
            ``"image"`` and ``"correct_label"`` keys are read.
        current_index: position of the sample in *data*.
    """
    sample = data[current_index]
    return sample["image"], sample["correct_label"]
def preprocessing(data, current_index, history, username):
    """Start a review session for *username* and display its first sample.

    Args:
        data: previous session data. It is ignored and replaced by a fresh
            batch from ``generate_dataset(username)``.
        current_index: previous index; reset by this call.
        history: passed through unchanged.
        username: reviewer whose pending images are loaded.

    Returns:
        ``(query_image, label_plot, current_index, history, data,
        training_sample_images)``. When nothing is left to review, returns a
        blank 224x224 image, a message plot, and ``current_index`` = -1.
    """
    data = generate_dataset(username)
    if not data:
        fake_plot = string_to_image("No more images to review")
        empty_image = Image.new("RGB", (224, 224))
        # -1 means "no sample loaded". update_app returns early for it,
        # instead of indexing the empty batch with a stale index.
        return empty_image, fake_plot, -1, history, data, None

    current_index = 0
    qimage, labels = load_sample(data, current_index)
    training_samples_image = []
    for path in get_training_samples(data[current_index]["id"]):
        # Close each file once convert() has copied its pixels.
        with Image.open(path) as img:
            training_samples_image.append(img.convert("RGB"))
    # labels is a list of label strings; show them comma-separated.
    label_plot = string_to_image(", ".join(labels))
    return qimage, label_plot, current_index, history, data, training_samples_image
def _upload_decision(image_id, username, decision):
    """Upload one review decision as a JSON file to the review-data repo."""
    import re

    time_stamp = int(time.time())
    decision_dict = {
        "id": int(image_id),
        "user_id": username,
        "time": time_stamp,
        "decision": decision,
    }
    # The username comes from a free-text box. Keep only safe characters in
    # the repo path so it cannot inject "/" or "..". The raw value is still
    # stored in the "user_id" field.
    safe_user = re.sub(r"[^A-Za-z0-9_-]", "_", str(username))
    # Include the image id so two decisions made within the same second do
    # not collide on the same repo path.
    filename = f"results_{safe_user}_{time_stamp}_{int(image_id)}.json"
    # Upload the bytes directly. No local temp file is written, so nothing
    # is left behind if the upload raises.
    HfApi().upload_file(
        path_or_fileobj=json.dumps(decision_dict).encode("utf-8"),
        path_in_repo=filename,
        repo_id="taesiri/imagenet_hard_review_data",
        repo_type="dataset",
    )


def _render_sample(data, index):
    """Return ``(query_image, label_plot, training_sample_images)`` for data[index]."""
    qimage, labels = load_sample(data, index)
    samples = []
    for path in get_training_samples(data[index]["id"]):
        with Image.open(path) as img:
            samples.append(img.convert("RGB"))
    # labels is a list of label strings; show them comma-separated.
    return qimage, string_to_image(", ".join(labels)), samples


def update_app(decision, data, current_index, history, username):
    """Record *decision* for the current sample and advance to the next one.

    Args:
        decision: the clicked button's value, stored as the decision.
        data: the session's sample list (see ``generate_dataset``).
        current_index: index of the sample being judged; -1 = none loaded.
        history: passed through unchanged.
        username: reviewer id stored with the decision.

    Returns:
        ``(query_image, label_plot, current_index, history, data,
        training_sample_images)``.
        - Clicks with no valid current sample change nothing.
        - After the last sample, shows a thank-you plot and sets
          ``current_index`` to -1, so further clicks do not re-upload the
          last decision.
    """
    # Use this session's own batch size. The NUMBER_OF_IMAGES global is shared
    # by all sessions and can be changed by another user's batch.
    total = len(data) if data else 0
    if not 0 <= current_index < total:
        return gr.update(), gr.update(), current_index, history, data, gr.update()

    _upload_decision(data[current_index]["id"], username, decision)

    if current_index == total - 1:
        fake_plot = string_to_image("Thank you for your time!")
        empty_image = Image.new("RGB", (224, 224))
        return empty_image, fake_plot, -1, history, data, None

    current_index += 1
    qimage, label_plot, training_samples_image = _render_sample(data, current_index)
    return qimage, label_plot, current_index, history, data, training_samples_image
# Extra CSS passed to gr.Blocks. It forces `height: auto` on the elements with
# these ids. "query_image" and "sample_gallery" are elem_ids of the query image
# and the training-sample gallery in the UI below.
# NOTE(review): no component in this file appears to use "nn_gallery"; it may
# be left over from an earlier layout. Confirm before removing.
newcss = """
#query_image{
height: auto !important;
}
#nn_gallery {
height: auto !important;
}
#sample_gallery {
height: auto !important;
}
"""
def _random_username():
    """Return a default username: "user-" followed by 5 random [a-z0-9] chars."""
    suffix = "".join(
        random.choice(string.ascii_lowercase + string.digits) for _ in range(5)
    )
    return f"user-{suffix}"


with gr.Blocks(css=newcss) as demo:
    # Per-session state, passed into and out of every event handler.
    data_gr = gr.State({})
    current_index = gr.State(-1)  # -1 = nothing loaded yet
    history = gr.State({})

    gr.Markdown("# Cleaning ImageNet-Hard!")

    with gr.Row():
        # Pass a callable: Gradio calls it on each page load, so every visitor
        # gets their own default username. A string built here would be fixed
        # once at startup and shared by all visitors, mixing their decisions.
        username = gr.Textbox(label="Username", value=_random_username)
        prepare_btn = gr.Button(value="Load Samples")

    with gr.Column():
        with gr.Row():
            accept_btn = gr.Button(value="Accept")
            myabe_btn = gr.Button(value="Not Sure!")
            reject_btn = gr.Button(value="Reject")
        with gr.Row():
            query_image = gr.Image(type="pil", label="Query", elem_id="query_image")
            with gr.Column():
                label_plot = gr.Plot(
                    label="Is this a correct label for this image?", type="fig"
                )
                training_samples = gr.Gallery(
                    type="pil", label="Training samples", elem_id="sample_gallery"
                )

    # Every handler returns values for these six components, in this order.
    session_outputs = [
        query_image,
        label_plot,
        current_index,
        history,
        data_gr,
        training_samples,
    ]
    # The clicked button is itself the first input, so its value becomes the
    # `decision` argument of update_app.
    for decision_btn in (accept_btn, myabe_btn, reject_btn):
        decision_btn.click(
            update_app,
            inputs=[decision_btn, data_gr, current_index, history, username],
            outputs=session_outputs,
        )
    prepare_btn.click(
        preprocessing,
        inputs=[data_gr, current_index, history, username],
        outputs=session_outputs,
    )

demo.launch()