|
import os
|
|
import clip
|
|
import torch
|
|
import open_clip
|
|
|
|
import numpy as np
|
|
from torchvision.datasets import CIFAR100
|
|
from tqdm import tqdm
|
|
import torchvision.transforms as transforms
|
|
import warnings
|
|
# Import torchvision with its noisy Future/User warnings silenced.  Both filters must live
# inside the SAME catch_warnings() context: catch_warnings restores the previous filter
# state when its block exits, so a filter installed in a separate, already-closed context
# has no effect on the import.
with warnings.catch_warnings():
    warnings.simplefilter(action='ignore', category=FutureWarning)
    warnings.simplefilter(action='ignore', category=UserWarning)
    import torchvision
|
|
|
|
import pandas as pd
|
|
from pathlib import Path
|
|
from PIL import Image
|
|
from torch.utils.data import Dataset, DataLoader
|
|
import pickle
|
|
|
|
|
|
class FACET(Dataset):
    """Map-style dataset of FACET images, one label per image.

    Sample ``i`` is read from ``os.path.join(root_dir, str(paths[i]) + file_extension)``,
    converted to RGB, passed through ``transform`` (when one is given) and returned
    together with ``labels[i]``.
    """

    def __init__(self, paths, labels, root_dir, file_extension=".jpg", transform=None):
        """
        Arguments:
            paths (sequence): Per-sample file stems (e.g. person ids); the file name of
                sample ``i`` is ``str(paths[i]) + file_extension``.
            labels (sequence or tensor): Per-sample labels, parallel to ``paths``.
            root_dir (string): Directory with all the images.
            file_extension (string): Suffix appended to every stem. Defaults to ".jpg".
            transform (callable, optional): Optional transform applied to the RGB PIL
                image. When None, the PIL image itself is returned.

        Raises:
            ValueError: if ``paths`` and ``labels`` have different lengths.
        """
        if len(paths) != len(labels):
            raise ValueError(
                f"paths and labels must have the same length, "
                f"got {len(paths)} and {len(labels)}"
            )
        self.fpaths = paths
        self.extension = file_extension
        self.labels = labels
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Number of samples (one per entry of ``paths``)."""
        return len(self.fpaths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir, str(self.fpaths[idx]) + self.extension)
        # Context manager releases the file handle once the RGB copy has been made.
        with Image.open(img_name) as img:
            image = img.convert('RGB')
        # The default is transform=None, so only apply it when one was supplied.
        if self.transform is not None:
            image = self.transform(image)
        label = self.labels[idx]

        return image, label
|
|
|
|
# Prompt templates for zero-shot classification: each entry holds one "{}" placeholder
# that is filled with a class name. These 80 strings appear to be the ImageNet prompt
# ensemble published with OpenAI CLIP (verify); odd phrasings such as 'a origami {}.'
# are kept verbatim so results stay comparable with that set.
imagenet_templates = [
    'a bad photo of a {}.',
    'a photo of many {}.',
    'a sculpture of a {}.',
    'a photo of the hard to see {}.',
    'a low resolution photo of the {}.',
    'a rendering of a {}.',
    'graffiti of a {}.',
    'a bad photo of the {}.',
    'a cropped photo of the {}.',
    'a tattoo of a {}.',
    'the embroidered {}.',
    'a photo of a hard to see {}.',
    'a bright photo of a {}.',
    'a photo of a clean {}.',
    'a photo of a dirty {}.',
    'a dark photo of the {}.',
    'a drawing of a {}.',
    'a photo of my {}.',
    'the plastic {}.',
    'a photo of the cool {}.',
    'a close-up photo of a {}.',
    'a black and white photo of the {}.',
    'a painting of the {}.',
    'a painting of a {}.',
    'a pixelated photo of the {}.',
    'a sculpture of the {}.',
    'a bright photo of the {}.',
    'a cropped photo of a {}.',
    'a plastic {}.',
    'a photo of the dirty {}.',
    'a jpeg corrupted photo of a {}.',
    'a blurry photo of the {}.',
    'a photo of the {}.',
    'a good photo of the {}.',
    'a rendering of the {}.',
    'a {} in a video game.',
    'a photo of one {}.',
    'a doodle of a {}.',
    'a close-up photo of the {}.',
    'a photo of a {}.',
    'the origami {}.',
    'the {} in a video game.',
    'a sketch of a {}.',
    'a doodle of the {}.',
    'a origami {}.',
    'a low resolution photo of a {}.',
    'the toy {}.',
    'a rendition of the {}.',
    'a photo of the clean {}.',
    'a photo of a large {}.',
    'a rendition of a {}.',
    'a photo of a nice {}.',
    'a photo of a weird {}.',
    'a blurry photo of a {}.',
    'a cartoon {}.',
    'art of a {}.',
    'a sketch of the {}.',
    'a embroidered {}.',
    'a pixelated photo of a {}.',
    'itap of the {}.',
    'a jpeg corrupted photo of the {}.',
    'a good photo of a {}.',
    'a plushie {}.',
    'a photo of the nice {}.',
    'a photo of the small {}.',
    'a photo of the weird {}.',
    'the cartoon {}.',
    'art of the {}.',
    'a drawing of the {}.',
    'a photo of the large {}.',
    'a black and white photo of a {}.',
    'the plushie {}.',
    'a dark photo of a {}.',
    'itap of a {}.',
    'graffiti of the {}.',
    'a toy {}.',
    'itap of my {}.',
    'a photo of a cool {}.',
    'a photo of a small {}.',
    'a tattoo of the {}.',
]
|
|
|
|
# Models to evaluate, grouped by the hub they are loaded from.
models = (
    # OpenAI CLIP checkpoints, referenced by their clip.load names.
    "ViT-B/16", "ViT-B/32", "ViT-L/14", "RN50", "RN101",
    # OpenCLIP aliases (architecture + LAION pretraining set).
    "vit_b_16_400m", "vit_b_16_2b",
    "vit_l_14_400m", "vit_l_14_2b",
    "vit_b_32_400m", "vit_b_32_2b",
)

# Hub label for each entry of ``models``, in the same order: the first five come from
# the OpenAI hub, the remaining six from the OpenCLIP hub.
weights = ("OpenAI hub",) * 5 + ("OpenCLIP hub",) * 6
|
|
|
|
|
|
# Placeholders to fill in: the FACET annotation CSV and the root directory under which
# the edited image sets live (presumably one sub-directory per experiment — confirm).
facet_annotations_file_path = "INSERT_HERE/annotations.csv"
facet_root = "?????"

# Full annotation table; the 'Unnamed: 0' column (presumably an unnamed index column in
# the CSV) is exposed as 'sample_idx'.
facet = (
    pd.read_csv(facet_annotations_file_path, header=0)
    .rename(columns={'Unnamed: 0': 'sample_idx'})
)

# "cuda" when a GPU is visible, otherwise "cpu".
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Names of the edited-image sets to evaluate.
experiments = [
    "facet_paper_skin_ours",
    "facet_paper_clothes_only",
    "facet_paper_skin_ours_occupation_prompt",
    "facet_paper_clothes_only_occupation_prompt",
    "facet_paper_whole_body",
    "facet_paper_whole_body_occupation_prompt",
]
|
|
|
|
|
|
# FACET occupation classes; a sample's ``class1`` label is encoded as its index here.
OCCUPATIONS = [
    'backpacker', 'ballplayer', 'bartender', 'basketball_player', 'boatman', 'carpenter',
    'cheerleader', 'climber', 'computer_user', 'craftsman', 'dancer', 'disk_jockey',
    'doctor', 'drummer', 'electrician', 'farmer', 'fireman', 'flutist', 'gardener',
    'guard', 'guitarist', 'gymnast', 'hairdresser', 'horseman', 'judge', 'laborer',
    'lawman', 'lifeguard', 'machinist', 'motorcyclist', 'nurse', 'painter', 'patient',
    'prayer', 'referee', 'repairman', 'reporter', 'retailer', 'runner', 'sculptor',
    'seller', 'singer', 'skateboarder', 'soccer_player', 'soldier', 'speaker', 'student',
    'teacher', 'tennis_player', 'trumpeter', 'waiter',
]
OCCUPATION_TO_IDX = {name: idx for idx, name in enumerate(OCCUPATIONS)}

# NOTE(review): only the first 39 occupations ('backpacker' .. 'runner') become zero-shot
# classes, while ground-truth labels index the full 51-entry list, so samples labelled
# with one of the last 12 occupations can never be classified correctly.  Kept as in the
# original protocol -- confirm this is intentional.
N_ZERO_SHOT_CLASSES = 39

# Image variants scored per experiment.  The "only_original_*" settings score the
# un-edited "_original" image restricted to one presented gender.
ATTRIBUTE_VALUES = [
    "only_original_male", "only_original_female", "original",
    "male_to_female", "male_to_male", "female_to_female", "female_to_male",
]

# OpenCLIP aliases used in ``models`` -> (open_clip architecture, pretrained tag).
OPENCLIP_MODELS = {
    "vit_b_16_400m": ("ViT-B-16", "laion400m_e32"),
    "vit_b_16_2b": ("ViT-B-16", "laion2b_s34b_b88k"),
    "vit_b_32_400m": ("ViT-B-32", "laion400m_e32"),
    "vit_b_32_2b": ("ViT-B-32", "laion2b_s34b_b79k"),
    "vit_l_14_400m": ("ViT-L-14", "laion400m_e32"),
    "vit_l_14_2b": ("ViT-L-14", "laion2b_s32b_b82k"),
}

BATCH_SIZE = 512


def load_model(model_name):
    """Return ``(model, preprocess)`` for *model_name*, on the GPU and in eval mode.

    Names in OPENCLIP_MODELS are built with open_clip; any other name containing "ViT"
    or "RN" is passed to ``clip.load``.  Raises NotImplementedError otherwise.
    """
    if model_name in OPENCLIP_MODELS:
        architecture, pretrained = OPENCLIP_MODELS[model_name]
        model, _, preprocess = open_clip.create_model_and_transforms(
            architecture, pretrained=pretrained
        )
    elif "ViT" in model_name or "RN" in model_name:
        model, preprocess = clip.load(model_name, device)
    else:
        raise NotImplementedError(model_name)
    # open_clip models must be switched to eval mode explicitly; harmless for the rest.
    # NOTE: the evaluation below uses .cuda()/.half(), so a GPU is required.
    return model.cuda().eval(), preprocess


def zeroshot_classifier(model, classnames, templates):
    """Return a zero-shot weight matrix whose column ``j`` embeds ``classnames[j]``.

    For each class every template is filled in and encoded; the text embeddings are
    L2-normalised, averaged and re-normalised.
    """
    with torch.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames):
            texts = clip.tokenize([template.format(classname) for template in templates]).cuda()
            class_embeddings = model.encode_text(texts)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        return torch.stack(zeroshot_weights, dim=1).cuda()


def image_variant(attribute_value):
    """File-name variant scored for *attribute_value* ("only_*" -> the original image)."""
    return "original" if "only" in attribute_value else attribute_value


def synthetic_person_ids(fnames, attribute_value):
    """Return the int person ids having a ``<id>_<variant>.<ext>`` file in *fnames*."""
    variant = image_variant(attribute_value)
    person_ids = set()
    for fname in fnames:
        person_id, _, rest = fname.partition("_")
        if rest.split(".")[0] == variant:
            person_ids.add(int(person_id))
    return person_ids


def select_annotations(annotations, person_ids, attribute_value):
    """Rows of *annotations* for *person_ids* with a binary gender presentation.

    "only_original_male" / "only_original_female" keep only masculine / feminine
    presentation; every other setting keeps both.  ``class1`` is replaced by its
    index in OCCUPATIONS (KeyError for an unknown occupation).
    """
    keep = (
        annotations.person_id.isin(person_ids)
        & (annotations.gender_presentation_na != 1)
        & (annotations.gender_presentation_non_binary != 1)
    )
    if attribute_value == "only_original_male":
        keep &= annotations.gender_presentation_masc == 1
    elif attribute_value == "only_original_female":
        keep &= annotations.gender_presentation_fem == 1
    else:
        keep &= (annotations.gender_presentation_masc == 1) | (annotations.gender_presentation_fem == 1)
    selected = annotations[keep]
    # assign() returns a new frame, avoiding chained-assignment on a filtered view.
    return selected.assign(
        class1=selected["class1"].map(OCCUPATION_TO_IDX.__getitem__).astype("int64")
    )


def evaluate(model, dataloader, zeroshot_weights):
    """Return ``(predictions, accuracy)`` of zero-shot classification over *dataloader*.

    ``accuracy`` is the fraction of correctly classified samples over the whole
    dataset (not a mean of per-batch accuracies), or NaN when there are no samples.
    """
    predictions = []
    n_correct = 0
    for imgs, labels in tqdm(dataloader):
        with torch.no_grad(), torch.autocast(device_type="cuda"):
            image_features = model.encode_image(imgs.half().cuda())
            image_features /= image_features.norm(dim=-1, keepdim=True)
            logits = 100. * image_features @ zeroshot_weights
        # softmax is monotonic, so the argmax of the logits is the predicted class.
        preds = logits.argmax(dim=-1).cpu()
        predictions += preds.tolist()
        n_correct += int((preds == labels).sum())
    accuracy = n_correct / len(predictions) if predictions else float("nan")
    return predictions, accuracy


# Model-outer loop: each model is loaded, and its text classifier built, only once.
for model_name, _hub in zip(models, weights):
    model, preprocess = load_model(model_name)
    zeroshot_weights = zeroshot_classifier(
        model, OCCUPATIONS[:N_ZERO_SHOT_CLASSES], imagenet_templates
    )
    file_prefix = model_name.replace("/", "_").replace("-", "_")

    for experiment in experiments:
        print("\n\n", model_name, experiment)
        facet_img_root = os.path.join(facet_root, experiment)
        out_dir = experiment + "_zero_shot"
        os.makedirs(out_dir, exist_ok=True)
        fnames = os.listdir(facet_img_root)

        for attribute_value in ATTRIBUTE_VALUES:
            print(f"----{attribute_value}----")
            person_ids = synthetic_person_ids(fnames, attribute_value)
            # ``facet`` is the full annotation table read from facet_annotations_file_path.
            subset = select_annotations(facet, person_ids, attribute_value)

            dataset = FACET(
                subset.person_id.values,
                torch.tensor(subset.class1.values),
                facet_img_root,
                transform=preprocess,
                file_extension=f"_{image_variant(attribute_value)}.png",
            )
            dataloader = DataLoader(
                dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=6, drop_last=False,
            )
            predictions, accuracy = evaluate(model, dataloader, zeroshot_weights)
            print(model_name, f"acc: {100 * accuracy:.2f}%")

            results = pd.DataFrame({
                "person_id": subset.person_id.values,
                "inpainted_attribute": attribute_value,
                "age_presentation_young": subset.age_presentation_young.values,
                "age_presentation_middle": subset.age_presentation_middle.values,
                "age_presentation_older": subset.age_presentation_older.values,
                "gender_presentation_fem": subset.gender_presentation_fem.values,
                "gender_presentation_masc": subset.gender_presentation_masc.values,
                "gt_class_label": subset.class1.values,
                "class_predictions": predictions,
            })
            results.to_csv(f'{out_dir}/{file_prefix}_{attribute_value}_predictions.csv')

            # Accuracy is stored as a fraction in [0, 1].
            with open(f'{out_dir}/{file_prefix}_{attribute_value}_accuracy.txt', "w") as of:
                of.write(str(accuracy))

            with open(f'{out_dir}/{file_prefix}_{attribute_value}.pkl', "wb") as f:
                pickle.dump(predictions, f)
|
|