import sys
sys.path.append('../../')
from data.clip_dataloader.flickr import FlickrDataModule
import pytorch_lightning as pl
import numpy as np
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import torch.nn.functional as F
import math
import copy
import argparse
from transformers import CLIPModel, BertForSequenceClassification
class CLIPLightning(pl.LightningModule):
    def __init__(self, model_name='ViT-B/32', minibatch_size=2):
        """A lightning wrapper for a CLIP model as specified in the paper.

        Args:
            model_name (str): A case-sensitive visual model name.
            minibatch_size (int): Size of the sub-batches the full batch is split into.
        """
        super().__init__()
        self.prepare_data_per_node = True
        self.model_name = model_name
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")  # NOTE: load from OpenAI
        self.text_encoder = BertForSequenceClassification.from_pretrained("IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese")
        self.minibatch_size = minibatch_size
        self.isViT = 'ViT' in self.model_name
        self.automatic_optimization = False

    # Training loss: https://github.com/openai/CLIP/issues/83
    # Mini-batching thanks to https://github.com/crowsonkb / https://twitter.com/RiversHaveWings
    # Multi-GPU support: https://github.com/MicPie/clasp
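    # The training step below uses a two-pass scheme so the contrastive loss can be
    # computed over a batch larger than what fits in GPU memory at once:
    #   1) embed every mini-batch under torch.no_grad() and all_gather the embeddings;
    #   2) re-embed one mini-batch at a time with gradients, splice it into the cached
    #      embedding matrix, and backpropagate the loss for that slice.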
    def training_step(self, train_batch, idx):
        # get optimizers and scheduler
        optimizer = self.optimizers()

        image, text, labels = train_batch
        n = math.ceil(len(image) / self.minibatch_size)  # number of sub-batches
        image_mbs = torch.chunk(image, n)
        text_mbs = torch.chunk(text, n)

        # First pass: embed everything without gradients to build the full similarity matrix.
        with torch.no_grad():
            ims = [F.normalize(self.clip_model.get_image_features(im), dim=1) for im in image_mbs]
            txt = [F.normalize(self.text_encoder(t).logits, dim=1) for t in text_mbs]
            # gather from all GPUs: the loss has to be computed over the embeddings from every GPU
            ims = self.all_gather(torch.cat(ims))
            txt = self.all_gather(torch.cat(txt))

            if len(ims.shape) == 3:
                ims = list(ims)
                txt = list(txt)
            else:
                ims = [ims]
                txt = [txt]

            image_logits = torch.cat(ims) @ torch.cat(txt).t() * self.clip_model.logit_scale.exp()
            ground_truth = torch.arange(len(image_logits)).long().to(image_logits.device)
            loss = (F.cross_entropy(image_logits, ground_truth) +
                    F.cross_entropy(image_logits.t(), ground_truth)).div(2)
            acc_i = (torch.argmax(image_logits, 1) == ground_truth).sum()
            acc_t = (torch.argmax(image_logits, 0) == ground_truth).sum()
            self.log_dict({'loss': loss / len(ims), 'acc': (acc_i + acc_t) / 2 / len(image) / len(ims)}, prog_bar=True)

        if isinstance(optimizer, list):
            optimizer = optimizer[0]
        optimizer.zero_grad()
        # image loss: recompute each image mini-batch with gradients and backprop its slice
        for j, mb in enumerate(image_mbs[:-1]):
            # the last mini-batch is dropped (workaround for an alignment bug)
            images_tmp = copy.deepcopy(ims)
            images_tmp[self.global_rank][j * self.minibatch_size:(j + 1) * self.minibatch_size] = \
                F.normalize(self.clip_model.get_image_features(mb), dim=1)
            image_logits = torch.cat(images_tmp) @ torch.cat(txt).t() * self.clip_model.logit_scale.exp()
            ground_truth = torch.arange(len(image_logits)).long().to(image_logits.device)
            loss = (F.cross_entropy(image_logits, ground_truth) + F.cross_entropy(image_logits.t(), ground_truth)) / 2
            self.manual_backward(loss)

        # text loss: same procedure for each text mini-batch
        for j, mb in enumerate(text_mbs[:-1]):
            text_tmp = copy.deepcopy(txt)
            text_tmp[self.global_rank][j * self.minibatch_size:(j + 1) * self.minibatch_size] = \
                F.normalize(self.text_encoder(mb).logits, dim=1)
            image_logits = torch.cat(ims) @ torch.cat(text_tmp).t() * self.clip_model.logit_scale.exp()
            loss = (F.cross_entropy(image_logits, ground_truth) + F.cross_entropy(image_logits.t(), ground_truth)) / 2
            self.manual_backward(loss)

        optimizer.step()
        lr_scheduler = self.lr_schedulers()
        lr_scheduler.step()
        self.clip_model.logit_scale.data.clamp_(-np.log(100), np.log(100))
    def validation_step(self, val_batch, idx):
        image, text, labels = val_batch
        img_embed = self.clip_model.get_image_features(image)
        txt_embed = self.text_encoder(text).logits
        image_norm = F.normalize(img_embed, dim=1)
        text_norm = F.normalize(txt_embed, dim=1)
        image_logits = image_norm @ text_norm.t() * self.clip_model.logit_scale.exp()
        text_logits = text_norm @ image_norm.t() * self.clip_model.logit_scale.exp()
        ground_truth = torch.arange(len(image_logits)).long().to(image_logits.device)
        loss = (F.cross_entropy(image_logits, ground_truth) + F.cross_entropy(text_logits, ground_truth)).div(2)
        self.log('val_loss', loss, prog_bar=True)
        return [image_norm, text_norm, labels]

    def validation_epoch_end(self, outputs):
        image_features = torch.cat([x[0] for x in outputs])
        text_features = torch.cat([x[1] for x in outputs])
        labels = [label for x in outputs for label in x[2]]
        print(image_features.shape, text_features.shape, len(labels))
        self.get_metrics(image_features, text_features, labels, 100)

    def test_step(self, test_batch, idx):
        image, text, labels = test_batch
        image_features = self.clip_model.get_image_features(image)
        text_features = self.text_encoder(text).logits
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)
        return [image_features, text_features, labels]

    def test_epoch_end(self, outputs):
        image_features = torch.cat([x[0] for x in outputs])
        text_features = torch.cat([x[1] for x in outputs])
        labels = [label for x in outputs for label in x[2]]
        print(image_features.shape, text_features.shape, len(labels))
        self.get_metrics(image_features, text_features, labels, 100)
    def get_metrics(self, image_features, text_features, labels, logit_scale):
        # Compute similarities; supports samples with multiple positives
        # (e.g. one image paired with several captions).
        # This matters for image-to-text retrieval, where one image can match several texts;
        # text-to-image retrieval does not need it (a text normally has exactly one matching image).
        logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu()
        logits_per_text = logits_per_image.t().detach().cpu()
        logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text}

        label2idx = {}  # map each label to the indices at which it occurs
        repeat_id = []
        for i, label in enumerate(labels):
            if label not in label2idx:
                label2idx[label] = [i]
            else:
                # this label has appeared before; record the index so its score can be
                # suppressed later when computing the text-to-image scores
                label2idx[label].append(i)
                repeat_id.append(i)

        ground_truth = [label2idx[label] for label in labels]
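        # Illustrative example (hypothetical labels): with labels = ['img1', 'img1', 'img2'],
        # i.e. two captions belonging to the same image, the loop above yields
        #   label2idx = {'img1': [0, 1], 'img2': [2]}, repeat_id = [1],
        #   ground_truth = [[0, 1], [0, 1], [2]]
        # so either caption index counts as a correct hit for queries of that image.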
        for name, logit in logits.items():
            if name == 'text_to_image':
                logit[:, repeat_id] -= 1e8  # suppress duplicate images so they are ignored in the ranking
            r1_stat, r5_stat, r10_stat = [], [], []
            ranking = torch.argsort(logit, descending=True)  # indices from the largest score to the smallest
            for i, each_query in enumerate(ranking[:, :10]):
                for j, q in enumerate(each_query):
                    if q in ground_truth[i]:
                        if j == 0:
                            r1_stat.append(1)
                            r5_stat.append(1)
                            r10_stat.append(1)
                            break
                        if j < 5:
                            r5_stat.append(1)
                            r10_stat.append(1)
                            break
                        if j < 10:
                            r10_stat.append(1)
                            break
            print(f'{name} r1:{sum(r1_stat)/len(logit)}, r5:{sum(r5_stat)/len(logit)}, r10:{sum(r10_stat)/len(logit)}')
    def configure_optimizers(self):
        lr = {
            "RN50": 5e-4,
            "RN101": 5e-4,
            "RN50x4": 5e-4,
            "RN50x16": 4e-4,
            "RN50x64": 3.6e-4,
            "ViT-B/32": 5e-4,
            "ViT-B/16": 5e-4,
            "ViT-L/14": 4e-4,
            "ViT-L/14-336px": 2e-5
        }[self.model_name]

        optimizer = torch.optim.AdamW(
            [{'params': self.clip_model.parameters()}, {'params': self.text_encoder.parameters()}],
            lr=lr,
            betas=(
                0.9,
                0.98 if self.isViT else 0.999
            ),
            eps=1e-6 if self.isViT else 1e-8,
            weight_decay=0.2
        )

        # Source: https://github.com/openai/CLIP/issues/107
        # For a warmup-enabled CosineAnnealingWarmupRestarts scheduler, use
        # pip install 'git+https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup';
        # here PyTorch's built-in CosineAnnealingWarmRestarts is used instead.
        lr_scheduler = CosineAnnealingWarmRestarts(
            optimizer,
            T_0=2000
        )
        return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler}
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # model name
    parser.add_argument('--model', type=str,
                        default="ViT-B/32",
                        help='model definition')
    # experiment settings
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--num_epoches', type=int, default=1)
    parser.add_argument('--num_gpus', type=int, default=2)
    # dataset
    parser.add_argument('--train_filename', type=str,
                        help='dir or csv file')
    parser.add_argument('--train_root', type=str,
                        help='image root path')
    parser.add_argument('--val_filename', type=str,
                        help='dir or csv file')
    parser.add_argument('--val_root', type=str,
                        help='image root path')
    parser.add_argument('--test_filename', type=str,
                        help='dir or csv file')
    parser.add_argument('--test_root', type=str,
                        help='image root path')
    parser.add_argument('--num_workers', type=int, default=0)
    # huggingface pretrained model definition
    parser.add_argument('--pretrain_model', type=str,
                        default="openai/clip-vit-base-patch32",
                        help='default: load from openai')  # NOTE: "wf-genius/TaiYi-CLIP-ViT-B-32" is the checkpoint I trained
    args = parser.parse_args()

    dm = FlickrDataModule(args)
    model = CLIPLightning(model_name=args.model, minibatch_size=args.batch_size // 2)
    trainer = pl.Trainer(gpus=args.num_gpus, precision=16, max_epochs=args.num_epoches)
    trainer.test(model, dm)  # zero-shot test
    trainer.fit(model, dm)   # finetune on the train set
    trainer.test(model, dm)  # test again
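
# Example invocation (script name and data paths below are placeholders for illustration):
#   python finetune_flickr.py \
#       --train_filename /path/to/train.csv --train_root /path/to/train_images \
#       --val_filename /path/to/val.csv --val_root /path/to/val_images \
#       --test_filename /path/to/test.csv --test_root /path/to/test_images \
#       --batch_size 128 --num_gpus 2 --num_epoches 1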