#!/usr/bin/env python
# coding: utf-8

# In[37]:

import torch
import torch.nn as nn
from functools import partial
#import clip
from einops import rearrange, repeat
from glob import glob
from PIL import Image
from torchvision import transforms as T
from tqdm import tqdm
import pickle
import numpy as np
import os
from transformers import AutoProcessor, CLIPVisionModelWithProjection, CLIPProcessor, CLIPModel

device = 'cuda:0'

#model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14").to(device)
#processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

class ClipImageEncoder(nn.Module):
    """Frozen CLIP ViT-L/14 vision encoder used to extract style-image embeddings."""

    def __init__(self):
        super().__init__()
        # Expected shape of the local (per-token) features: 1 CLS + 256 patch tokens, hidden dim 1024.
        self.emb_dim = (1, 257, 1024)
        self.model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
        self.model = self.model.eval()
        # Freeze all weights; this module is used for feature extraction only.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        # x: preprocessed pixel values of shape (B, 3, 224, 224).
        # Returns the per-token hidden states and the projected global image embedding.
        ret = self.model(x)
        return ret.last_hidden_state, ret.image_embeds

    def preprocess(self, style_image):
        # if os.path.exists(style_file):
        #     style_image = Image.open(style_file)
        # else:
        #     style_image = Image.fromarray(np.zeros((224,224,3), dtype=np.uint8))
        x = torch.tensor(np.array(self.processor.image_processor(style_image).pixel_values))
        return x

    def postprocess(self, x):  # return numpy
        return x.detach().cpu().squeeze(0).numpy()
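

# --- Illustrative sketch (not part of the original script) ---
# Shows how the pickled embeddings dumped by the __main__ block below could be read
# back, e.g. from a dataset/dataloader. It assumes the same naming convention as the
# dump step: "<image>.p" holds the global (768,) embedding and "<image>_hidden.p" the
# (257, 1024) per-token features. The function name is a placeholder.
def load_style_embeddings(style_file):
    with open(style_file.replace('.jpg', '.p'), 'rb') as f:
        emb_global = pickle.load(f)   # np.ndarray, shape (768,)
    with open(style_file.replace('.jpg', '_hidden.p'), 'rb') as f:
        emb_local = pickle.load(f)    # np.ndarray, shape (257, 1024)
    return torch.from_numpy(emb_local), torch.from_numpy(emb_global)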


if __name__ == '__main__':
    device = 'cuda:1'
    style_files = glob("/home/soon/datasets/deepfashion_inshop/styles_default/**/*.jpg", recursive=True)
    style_files = [x for x in style_files if x.split('/')[-1] != 'background.jpg']
    clip_model = ClipImageEncoder().to(device)
    # Slice resumes from a previously processed index.
    for style_file in tqdm(style_files[24525:]):
        style_image = Image.open(style_file)
        emb_local, emb_global = clip_model(clip_model.preprocess(style_image).to(device))
        emb_local = clip_model.postprocess(emb_local)    # (257, 1024) per-token features
        emb_global = clip_model.postprocess(emb_global)  # (768,) projected image embedding
        #x = torch.tensor(np.array(processor.image_processor(style_image).pixel_values))
        #emb = model(x.to(device)).last_hidden_state
        #emb = emb.detach().cpu().squeeze(0).numpy()
        # Save local features as <name>_hidden.p and the global embedding as <name>.p.
        emb_file = style_file.replace('.jpg', '_hidden.p')
        with open(emb_file, 'wb') as file:
            pickle.dump(emb_local, file)
        emb_file = style_file.replace('.jpg', '.p')
        with open(emb_file, 'wb') as file:
            pickle.dump(emb_global, file)