import os
import clip
import torch
import open_clip
import numpy as np
from torchvision.datasets import CIFAR100
from tqdm import tqdm
import torchvision.transforms as transforms
import warnings
# Silence noisy FutureWarning/UserWarning messages for the imports below;
# filterwarnings (unlike a catch_warnings context) keeps the filters active
# past this point.
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
import torchvision
import pandas as pd
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import pickle
class FACET(Dataset):
"""Face Landmarks dataset."""
def __init__(self, paths, labels, root_dir, file_extension=".jpg", transform=None):
"""
Arguments:
csv_file (string): Path to the csv file with annotations.
root_dir (string): Directory with all the images.
transform (callable, optional): Optional transform to be applied
on a sample.
"""
self.fpaths = paths
self.extension = file_extension
self.labels = labels
self.root_dir = root_dir
self.transform = transform
def __len__(self):
return len(self.fpaths)
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
img_name = os.path.join(self.root_dir,
str(self.fpaths[idx])+self.extension)
image = self.transform(Image.open(img_name).convert('RGB'))
label = self.labels[idx]
return image, label
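# Minimal usage sketch for the dataset above (hypothetical ids and root dir;
# assumes files like <root>/12345_original.png exist on disk):
#   ds = FACET([12345, 67890], torch.tensor([0, 1]), "/data/facet_crops",
#              file_extension="_original.png", transform=preprocess)
#   img, label = ds[0]  # preprocessed image tensor and its integer class label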
imagenet_templates = [
'a bad photo of a {}.',
'a photo of many {}.',
'a sculpture of a {}.',
'a photo of the hard to see {}.',
'a low resolution photo of the {}.',
'a rendering of a {}.',
'graffiti of a {}.',
'a bad photo of the {}.',
'a cropped photo of the {}.',
'a tattoo of a {}.',
'the embroidered {}.',
'a photo of a hard to see {}.',
'a bright photo of a {}.',
'a photo of a clean {}.',
'a photo of a dirty {}.',
'a dark photo of the {}.',
'a drawing of a {}.',
'a photo of my {}.',
'the plastic {}.',
'a photo of the cool {}.',
'a close-up photo of a {}.',
'a black and white photo of the {}.',
'a painting of the {}.',
'a painting of a {}.',
'a pixelated photo of the {}.',
'a sculpture of the {}.',
'a bright photo of the {}.',
'a cropped photo of a {}.',
'a plastic {}.',
'a photo of the dirty {}.',
'a jpeg corrupted photo of a {}.',
'a blurry photo of the {}.',
'a photo of the {}.',
'a good photo of the {}.',
'a rendering of the {}.',
'a {} in a video game.',
'a photo of one {}.',
'a doodle of a {}.',
'a close-up photo of the {}.',
'a photo of a {}.',
'the origami {}.',
'the {} in a video game.',
'a sketch of a {}.',
'a doodle of the {}.',
'a origami {}.',
'a low resolution photo of a {}.',
'the toy {}.',
'a rendition of the {}.',
'a photo of the clean {}.',
'a photo of a large {}.',
'a rendition of a {}.',
'a photo of a nice {}.',
'a photo of a weird {}.',
'a blurry photo of a {}.',
'a cartoon {}.',
'art of a {}.',
'a sketch of the {}.',
'a embroidered {}.',
'a pixelated photo of a {}.',
'itap of the {}.',
'a jpeg corrupted photo of the {}.',
'a good photo of a {}.',
'a plushie {}.',
'a photo of the nice {}.',
'a photo of the small {}.',
'a photo of the weird {}.',
'the cartoon {}.',
'art of the {}.',
'a drawing of the {}.',
'a photo of the large {}.',
'a black and white photo of a {}.',
'the plushie {}.',
'a dark photo of a {}.',
'itap of a {}.',
'graffiti of the {}.',
'a toy {}.',
'itap of my {}.',
'a photo of a cool {}.',
'a photo of a small {}.',
'a tattoo of the {}.',
]
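# The 80 prompt templates from OpenAI's CLIP ImageNet zero-shot evaluation;
# zeroshot_classifier() below averages each class embedding over all of them.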
models = (
# CLIP OpenAI
"ViT-B/16",
"ViT-B/32",
"ViT-L/14",
"RN50",
"RN101",
# CLIP OpenCLIP
"vit_b_16_400m",
"vit_b_16_2b",
"vit_l_14_400m",
"vit_l_14_2b",
"vit_b_32_400m",
"vit_b_32_2b",
)
weights = (
# CLIP OpenAI
"OpenAI hub",
"OpenAI hub",
"OpenAI hub",
"OpenAI hub",
"OpenAI hub",
# CLIP OpenCLIP
"OpenCLIP hub",
"OpenCLIP hub",
"OpenCLIP hub",
"OpenCLIP hub",
"OpenCLIP hub",
"OpenCLIP hub",
)
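# models and weights are paired index-by-index via zip() in the main loop;
# the OpenCLIP entries map to the LAION-400M / LAION-2B checkpoints requested
# from open_clip.create_model_and_transforms below.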
facet_annotations_file_path = "INSERT_HERE/annotations.csv"
facet_root = "?????" # where the in-painted images are stored, the following structure is expected:
# facet_root/
# facet_paper_skin_ours/
# facet_paper_clothes_only/
# facet_paper_skin_ours_occupation_prompt/
# facet_paper_clothes_only_occupation_prompt/
# facet_paper_whole_body/
# facet_paper_whole_body_occupation_prompt/
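# Each experiment directory is expected to hold one image per (person, edit)
# pair named "<person_id>_<edit>.png", e.g. "12345_male_to_female.png"
# (hypothetical id); FACET builds paths from person_id + file_extension.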
facet = pd.read_csv(facet_annotations_file_path, header=0).rename(columns={'Unnamed: 0': 'sample_idx'})  # FACET annotations; reloaded per split inside the loop below
device = "cuda" if torch.cuda.is_available() else "cpu"
experiments = ["facet_paper_skin_ours", "facet_paper_clothes_only", "facet_paper_skin_ours_occupation_prompt", "facet_paper_clothes_only_occupation_prompt",
"facet_paper_whole_body", "facet_paper_whole_body_occupation_prompt"
]
for experiment in experiments:
for model_name, weight in zip(models, weights):
print( "\n\n",model_name, experiment)
preprocess = None
if model_name == "vit_b_16_400m":
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16', pretrained='laion400m_e32')
clip_src = "OpenCLIP"
elif model_name == "vit_b_16_2b":
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16', pretrained='laion2b_s34b_b88k')
clip_src = "OpenCLIP"
elif model_name == "vit_b_32_400m":
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion400m_e32')
clip_src = "OpenCLIP"
elif model_name == "vit_b_32_2b":
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
clip_src = "OpenCLIP"
elif model_name == "vit_l_14_400m":
model, _, preprocess = open_clip.create_model_and_transforms('ViT-L-14', pretrained='laion400m_e32')
clip_src = "OpenCLIP"
elif model_name == "vit_l_14_2b":
model, _, preprocess = open_clip.create_model_and_transforms('ViT-L-14', pretrained='laion2b_s32b_b82k')
clip_src = "OpenCLIP"
elif "ViT" in model_name:
model, preprocess = clip.load(model_name, device)
clip_src = "OpenAI"
elif "RN" in model_name:
model, preprocess = clip.load(model_name, device)
clip_src = "OpenAI"
else:
raise NotImplementedError
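        # At this point `model` and `preprocess` come either from clip.load
        # (OpenAI checkpoints, fp16 on GPU) or from open_clip (LAION
        # checkpoints); both return the matching image preprocessing transform.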
        model.cuda()
        model.eval()  # inference only; open_clip models are not guaranteed to be in eval mode after creation
occupations = ['backpacker', 'ballplayer', 'bartender', 'basketball_player', 'boatman', 'carpenter', 'cheerleader', 'climber', 'computer_user', 'craftsman', 'dancer', 'disk_jockey', 'doctor', 'drummer', 'electrician', 'farmer', 'fireman', 'flutist', 'gardener', 'guard', 'guitarist', 'gymnast', 'hairdresser', 'horseman', 'judge', 'laborer', 'lawman', 'lifeguard', 'machinist', 'motorcyclist', 'nurse', 'painter', 'patient', 'prayer', 'referee', 'repairman', 'reporter', 'retailer', 'runner', 'sculptor', 'seller', 'singer', 'skateboarder', 'soccer_player', 'soldier', 'speaker', 'student', 'teacher', 'tennis_player', 'trumpeter', 'waiter']
        tokens_occupations = clip.tokenize([f"A photo of a {occupation}" for occupation in occupations]).cuda()  # note: unused; classification uses the template ensemble instead
        facet_img_root = facet_root + "/" + experiment + "/"
out_dir = experiment + "_zero_shot"
        os.makedirs(out_dir, exist_ok=True)
fnames = list(os.listdir(facet_img_root))
for attribute_value in ["only_original_male", "only_original_female", "original", "male_to_female", "male_to_male", "female_to_female", "female_to_male"]:
print(f"----{attribute_value}----")
facet = pd.read_csv("../../datasets/facet/annotations/annotations.csv", header=0).rename(columns={'Unnamed: 0': 'sample_idx'}) # Bounding boxes
extension = ".png"
processed_synthetic_samples = set()
            for fname in fnames:
                # file names look like "<person_id>_<edit>.png"
                bbid, target_attr = int(fname.split("_")[0]), "_".join(fname.split("_")[1:]).split(".")[0]
                if "only" in attribute_value:
                    if target_attr == "original" and bbid not in processed_synthetic_samples:
                        processed_synthetic_samples.add(bbid)
                elif target_attr == attribute_value and bbid not in processed_synthetic_samples:
                    processed_synthetic_samples.add(bbid)
            # Keep only annotated people with a processed image, drop rows with
            # unknown or non-binary gender presentation, then restrict to the
            # split's gender presentation.
            facet = facet[facet.person_id.isin(processed_synthetic_samples)]
            facet = facet[facet.gender_presentation_na != 1]
            facet = facet[facet.gender_presentation_non_binary != 1]
            if attribute_value == "only_original_male":
                facet = facet[facet.gender_presentation_masc == 1]
            elif attribute_value == "only_original_female":
                facet = facet[facet.gender_presentation_fem == 1]
            else:
                facet = facet[(facet.gender_presentation_masc == 1) | (facet.gender_presentation_fem == 1)]
facet["class1"] = facet["class1"].apply(lambda val: int(occupations.index(val)))
bsize = 512
predictions = []
            acc = 0
n_batches = 0
            def zeroshot_classifier(classnames, templates):
                # Standard CLIP prompt ensembling (closes over the current model):
                # embed every templated prompt for a class, average the normalized
                # text embeddings, and renormalize to one weight vector per class.
                with torch.no_grad():
                    zeroshot_weights = []
                    for classname in tqdm(classnames):
                        texts = [template.format(classname) for template in templates]  # format with class
                        texts = clip.tokenize(texts).cuda()  # tokenize (the LAION ViT checkpoints used here share CLIP's default BPE tokenizer)
                        class_embeddings = model.encode_text(texts)  # embed with text encoder
                        class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
                        class_embedding = class_embeddings.mean(dim=0)
                        class_embedding /= class_embedding.norm()
                        zeroshot_weights.append(class_embedding)
                    zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
                    return zeroshot_weights
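            # Shape check: with d-dim embeddings and C classnames the returned
            # matrix is (d, C), so normalized image features of shape (B, d)
            # yield logits of shape (B, C) via a single matrix multiply below.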
if "only" in attribute_value:
dataset = FACET(facet.person_id.values, torch.tensor(facet.class1.values), facet_img_root, transform=preprocess, file_extension=f"_original.png")
else:
dataset = FACET(facet.person_id.values, torch.tensor(facet.class1.values), facet_img_root, transform=preprocess, file_extension=f"_{attribute_value}.png")
dataloader = DataLoader(dataset, batch_size=bsize, shuffle=False, num_workers=6, drop_last=False,)
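            # shuffle=False is load-bearing: predictions are later written next
            # to facet.person_id.values in row order in the results DataFrame.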
            zeroshot_weights = zeroshot_classifier(occupations[:39], imagenet_templates)  # note: only the first 39 occupations form the prediction space
            for imgs, labels in tqdm(dataloader):
                with torch.no_grad(), torch.cuda.amp.autocast():
                    # same math for OpenAI CLIP and OpenCLIP: cosine similarity
                    # between normalized image features and the class weights
                    image_features = model.encode_image(imgs.half().cuda())
                    image_features /= image_features.norm(dim=-1, keepdim=True)
                    logits = 100. * image_features @ zeroshot_weights
                    probs = logits.softmax(dim=-1).cpu().numpy()
                preds_batch = np.argmax(probs, axis=-1)
                predictions += preds_batch.tolist()
                acc += torch.sum(torch.tensor(preds_batch).cuda() == labels.cuda()) / preds_batch.shape[0]
                n_batches += 1
            # mean of per-batch accuracies; samples in a smaller final batch
            # weigh slightly more since drop_last=False
            print(model_name, "acc:", (100. * acc / n_batches).item(), "%")
results = pd.DataFrame({"person_id": facet.person_id.values,
"inpainted_attribute": attribute_value,
"age_presentation_young": facet.age_presentation_young.values,
"age_presentation_middle": facet.age_presentation_middle.values,
"age_presentation_older": facet.age_presentation_older.values,
"gender_presentation_fem": facet.gender_presentation_fem.values,
"gender_presentation_masc": facet.gender_presentation_masc.values,
"gt_class_label": facet.class1.values,
"class_predictions": predictions
})
            model_tag = model_name.replace("/", "_").replace("-", "_")
            results.to_csv(f'{out_dir}/{model_tag}_{attribute_value}_predictions.csv')
            with open(f'{out_dir}/{model_tag}_{attribute_value}_accuracy.txt', "w") as of:
                of.write(str((acc / n_batches).item()))
            with open(f'{out_dir}/{model_tag}_{attribute_value}.pkl', "wb") as f:
                pickle.dump(predictions, f)
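# Outputs can be loaded back for analysis, e.g. (hypothetical run matching the
# naming pattern above):
#   preds = pd.read_csv("facet_paper_skin_ours_zero_shot/ViT_B_16_original_predictions.csv", index_col=0)
#   with open("facet_paper_skin_ours_zero_shot/ViT_B_16_original.pkl", "rb") as f:
#       raw_preds = pickle.load(f)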