Spaces:
Sleeping
Sleeping
# On Hugging Face Spaces the `spaces` package provides the GPU decorator.
# Anywhere else the import fails and a no-op decorator is used instead.
# NOTE(review): `load_xclip` is imported inside the same `try`, so it is only
# bound when `spaces` is importable, and an ImportError from `.load_model`
# also switches `gpu_decorator` to the no-op fallback — confirm this coupling
# is intended.
try:
    import spaces
    gpu_decorator = spaces.GPU
    from .load_model import load_xclip
except ImportError:
    def gpu_decorator(func=None, **_options):
        """No-op stand-in for ``spaces.GPU`` when ``spaces`` is unavailable.

        Works both as a bare decorator (``@gpu_decorator``) and in the
        parameterised form (``@gpu_decorator(duration=60)``); in either case
        the decorated function is returned unchanged and any keyword options
        are ignored.
        """
        if func is None:
            # Called with options only: hand back a decorator that returns
            # the function it is applied to.
            return lambda wrapped: wrapped
        return func
| import PIL | |
| import torch | |
| from .prompts import GetPromptList | |
# Part order used by `xclip_pred`: user-supplied descriptions and masks are
# re-ordered to it, it is passed to the processor as the part text queries,
# and per-part scores are zipped back against it.
ORG_PART_ORDER = ['back', 'beak', 'belly', 'breast', 'crown', 'forehead', 'eyes', 'legs', 'wings', 'nape', 'tail', 'throat']
# The same twelve part names in a different order. Not referenced by the code
# visible in this file — presumably a display order; verify against callers.
ORDERED_PARTS = ['crown', 'forehead', 'nape', 'eyes', 'beak', 'throat', 'breast', 'belly', 'back', 'wings', 'legs', 'tail']
def encode_descs_xclip(owlvit_det_processor: callable, model: callable, descs: list[str], device: str, max_batch_size: int = 512):
    """Encode text descriptions into embeddings, batch by batch.

    Each batch of ``descs`` is tokenised with ``owlvit_det_processor`` and
    passed to ``model.owlvit.get_text_features``. Per-batch results are moved
    to the CPU as float32 while accumulating, then concatenated along
    dimension 0 and moved to ``device``.

    Args:
        owlvit_det_processor: Callable invoked as
            ``processor(text=..., padding="max_length", truncation=True,
            return_tensors="pt")``; its result must support ``.to(device)``
            and ``**`` unpacking.
        model: Object exposing ``owlvit.get_text_features(**tokens)``.
        descs: Non-empty list of descriptions to encode.
        device: Device used for tokenised inputs and for the returned tensor.
        max_batch_size: Maximum number of descriptions per forward pass.

    Returns:
        A tensor on ``device`` with one row per entry of ``descs``, in the
        same order.

    Raises:
        ValueError: If ``descs`` is empty or ``max_batch_size`` is below 1.
    """
    if max_batch_size < 1:
        raise ValueError("max_batch_size must be a positive integer")
    if not descs:
        raise ValueError("descs must contain at least one description")

    text_embeds = []
    with torch.no_grad():
        # Stepping by the batch size never produces an empty trailing batch,
        # unlike `len(descs) // max_batch_size + 1`, which did whenever the
        # length was an exact multiple of the batch size.
        for start in range(0, len(descs), max_batch_size):
            query_descs = descs[start:start + max_batch_size]
            query_tokens = owlvit_det_processor(text=query_descs, padding="max_length", truncation=True, return_tensors="pt").to(device)
            query_embeds = model.owlvit.get_text_features(**query_tokens)
            # Accumulate on the CPU so only one batch at a time occupies the
            # compute device.
            text_embeds.append(query_embeds.cpu().float())
    return torch.cat(text_embeds, dim=0).to(device)
| # def encode_descs_clip(model: callable, descs: list[str], device: str, max_batch_size: int = 512): | |
| # total_num_batches = len(descs) // max_batch_size + 1 | |
| # with torch.no_grad(): | |
| # text_embeds = [] | |
| # for batch_idx in range(total_num_batches): | |
| # desc = descs[batch_idx*max_batch_size:(batch_idx+1)*max_batch_size] | |
| # query_tokens = clip.tokenize(desc).to(device) | |
| # text_embeds.append(model.encode_text(query_tokens).cpu().float()) | |
| # text_embeds = torch.cat(text_embeds, dim=0) | |
| # text_embeds = torch.nn.functional.normalize(text_embeds, dim=-1) | |
| # return text_embeds.to(device) | |
def xclip_pred(new_desc: dict,
               new_part_mask: dict,
               new_class: str,
               org_desc: str,
               image: "PIL.Image.Image",
               model: callable,
               owlvit_processor: callable,
               device: str,
               return_img_embeds: bool = False,
               use_precompute_embeddings = True,
               image_name: str = None,
               cub_embeds: torch.Tensor = None,
               cub_idx2name: dict = None,
               descriptors: dict = None):
    """Classify an image by part descriptions, optionally with a custom class.

    Two paths exist for obtaining the class text embeddings:

    * ``cub_embeds is None``: class descriptions are loaded through
      ``GetPromptList(org_desc)``, ``new_class`` (if given) is added or
      replaced there, and every description is encoded with
      ``encode_descs_xclip``.
    * ``cub_embeds`` given: the precomputed embeddings are used as-is and,
      if ``new_class`` is given, its descriptions are encoded and appended
      as class index 200 (this path hard-codes 200 base classes).

    Per-part logits returned by ``model`` are multiplied by the part mask
    and summed per class; the argmax over those sums is the prediction.

    Args:
        new_desc: Maps each name in ``ORG_PART_ORDER`` to a description;
            read only when ``new_class`` is not None.
        new_part_mask: Maps each name in ``ORG_PART_ORDER`` to a multiplier
            applied to that part's logits; read only when ``new_class`` is
            not None (otherwise every part has weight 1).
        new_class: Name of the class to add or modify, or None.
        org_desc: Argument handed to ``GetPromptList`` in the first path.
        image: Input image; used only when ``use_precompute_embeddings`` is
            false.
        model: Model exposing ``to``, ``owlvit.get_text_features``,
            ``image_embedder``, ``cls_head.num_classes`` and ``__call__``.
        owlvit_processor: Processor used for text and image inputs.
        device: Currently ignored: the function first tries to move the
            model to ``'cuda'`` and otherwise picks the current CUDA device
            or ``'cpu'`` itself.
        return_img_embeds: When true, also return the image embeddings.
        use_precompute_embeddings: When true, load image embeddings from
            ``data/image_embeddings/{image_name}.pt`` instead of running the
            image embedder.
        image_name: File stem of the precomputed embedding file.
        cub_embeds: Precomputed class text embeddings, or None.
        cub_idx2name: Class index -> class name for ``cub_embeds``.
        descriptors: Class name -> list of part descriptions. Never modified
            by this function. When None in the first path, the descriptions
            held by ``GetPromptList`` are used instead.

    Returns:
        A dict with keys ``pred_class``, ``pred_score``,
        ``pred_desc_scores``, ``descriptions``, ``modified_class``,
        ``modified_score``, ``modified_desc_scores`` and
        ``modified_descriptions``; or ``(that dict, image_embeds)`` when
        ``return_img_embeds`` is true. The ``modified_*`` score entries are
        None when ``new_class`` is None, and a ``descriptions`` entry is
        None when no description is known for the class.

    Raises:
        KeyError: If ``new_class`` is given and ``new_desc`` or
            ``new_part_mask`` lacks a name from ``ORG_PART_ORDER``.
    """
    # On a Hugging Face Space the model can be moved to CUDA here; elsewhere
    # the move may fail, in which case a device is picked without moving it.
    try:
        model.to('cuda')
        device = 'cuda'
    except Exception:
        device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'

    # Bring the user-supplied descriptions and mask into the canonical order.
    if new_class is not None:
        new_desc_list = [new_desc[part] for part in ORG_PART_ORDER]
        desc_mask = [new_part_mask[part] for part in ORG_PART_ORDER]
    else:
        new_desc_list = None
        desc_mask = [1] * len(ORG_PART_ORDER)

    if cub_embeds is None:
        # Replace the description if the class already exists, otherwise
        # register it as a new class at the next free index.
        getprompt = GetPromptList(org_desc)
        if new_class is not None:
            if new_class not in getprompt.desc:
                getprompt.name2idx[new_class] = len(getprompt.name2idx)
            getprompt.desc[new_class] = new_desc_list
        idx2name = {idx: name for name, idx in getprompt.name2idx.items()}
        modified_class_idx = getprompt.name2idx[new_class] if new_class is not None else None
        n_classes = len(getprompt.name2idx)
        descs, _class_idxs, _class_mapping, _org_desc_mapper, _class_list = getprompt('chatgpt-no-template', max_len=12, pad=True)
        query_embeds = encode_descs_xclip(owlvit_processor, model, descs, device)
        if descriptors is None:
            # Without caller-supplied descriptors the prompt list is the only
            # source for the description texts reported in the result.
            descriptors = getprompt.desc
    else:
        cub_embeds = cub_embeds.to(device)
        if descriptors is None:
            descriptors = {}
        if new_class is not None:
            # Avoid clashing with an existing class of the same name.
            if new_class in cub_idx2name.values():
                new_class = f"{new_class}_custom"
            idx2name = cub_idx2name | {200: new_class}
            # Build a new dict so the caller's `descriptors` is not mutated
            # (an in-place update would leak custom classes across calls).
            descriptors = descriptors | {new_class: new_desc_list}
            n_classes = 201
            with torch.no_grad():
                query_tokens = owlvit_processor(text=new_desc_list, padding="max_length", truncation=True, return_tensors="pt").to(device)
                new_class_embed = model.owlvit.get_text_features(**query_tokens)
            query_embeds = torch.cat([cub_embeds, new_class_embed], dim=0)
            modified_class_idx = 200
        else:
            n_classes = 200
            query_embeds = cub_embeds
            idx2name = cub_idx2name
            modified_class_idx = None

    # The classification head is told how many classes the query embeddings hold.
    model.cls_head.num_classes = n_classes

    with torch.no_grad():
        part_embeds = owlvit_processor(text=[ORG_PART_ORDER], return_tensors="pt").to(device)
        if use_precompute_embeddings:
            image_embeds = torch.load(f'data/image_embeddings/{image_name}.pt').to(device)
        else:
            image_input = owlvit_processor(images=image, return_tensors='pt').to(device)
            image_embeds, _ = model.image_embedder(pixel_values=image_input['pixel_values'])

        # Only the per-part logits are used; class logits are recomputed
        # below from the masked part logits.
        _, part_logits, _ = model(image_embeds, part_embeds, query_embeds, None)

        # Shape (1, 1, n_parts) broadcasts over the first two dimensions of
        # part_logits — assumes part_logits is (batch, classes, parts).
        mask = torch.tensor(desc_mask, dtype=float).view(1, 1, -1).to(device)
        part_logits = part_logits * mask
        pred_logits = torch.sum(part_logits, dim=-1)

        pred_class_idx = torch.argmax(pred_logits, dim=-1).cpu()
        pred_class_name = idx2name[pred_class_idx.item()]

        softmax_scores = torch.softmax(pred_logits, dim=-1).cpu()
        softmax_score_top1 = torch.topk(softmax_scores, k=1, dim=-1)[0].squeeze(-1).item()

        part_scores = part_logits[0, pred_class_idx].cpu().squeeze(0)
        part_scores_dict = dict(zip(ORG_PART_ORDER, part_scores.tolist()))

        if modified_class_idx is not None:
            modified_score = softmax_scores[0, modified_class_idx].item()
            modified_part_scores = part_logits[0, modified_class_idx].cpu().squeeze(0)
            modified_part_scores_dict = dict(zip(ORG_PART_ORDER, modified_part_scores.tolist()))
        else:
            modified_score = None
            modified_part_scores_dict = None

    result = {"pred_class": pred_class_name,
              "pred_score": softmax_score_top1,
              "pred_desc_scores": part_scores_dict,
              "descriptions": descriptors.get(pred_class_name),
              "modified_class": new_class,
              "modified_score": modified_score,
              "modified_desc_scores": modified_part_scores_dict,
              "modified_descriptions": descriptors.get(new_class),
              }
    return (result, image_embeds) if return_img_embeds else result
| # def sachit_pred(new_desc: list, | |
| # new_class: str, | |
| # org_desc: str, | |
| # image: PIL.Image, | |
| # model: callable, | |
| # preprocess: callable, | |
| # device: str, | |
| # ): | |
| # # replace the description if the new class is in the description, otherwise add a new class | |
| # getprompt = GetPromptList(org_desc) | |
| # if new_class not in getprompt.desc: | |
| # getprompt.name2idx[new_class] = len(getprompt.name2idx) | |
| # getprompt.desc[new_class] = new_desc | |
| # idx2name = dict(zip(getprompt.name2idx.values(), getprompt.name2idx.keys())) | |
| # modified_class_idx = getprompt.name2idx[new_class] | |
| # descs, class_idxs, class_mapping, org_desc_mapper, class_list = getprompt('Sachit-descriptors', max_len=12, pad=True) | |
| # text_embeds = encode_descs_clip(model, descs, device) | |
| # with torch.no_grad(): | |
| # image_embed = model.encode_image(preprocess(image).unsqueeze(0).to(device)) | |
| # desc_mask = torch.tensor(class_idxs) | |
| # desc_mask = torch.where(desc_mask == -1, 0, 1).unsqueeze(0).to(device) | |
| # sim = torch.matmul(image_embed.float(), text_embeds.T) | |
| # sim = (sim * desc_mask).view(1, -1, 12) | |
| # pred_scores = torch.sum(sim, dim=-1) | |
| # pred_class_idx = torch.argmax(pred_scores, dim=-1).cpu() | |
| # pred_class = idx2name[pred_class_idx.item()] | |
| # softmax_scores = torch.nn.functional.softmax(pred_scores, dim=-1).cpu() | |
| # top1_score = torch.topk(softmax_scores, k=1, dim=-1)[0].squeeze(-1).item() | |
| # modified_score = softmax_scores[0, modified_class_idx].item() | |
| # pred_desc_scores = sim[0, pred_class_idx].cpu().squeeze(0) | |
| # modified_class_scores = sim[0, modified_class_idx].cpu().squeeze(0) | |
| # output_dict = {"pred_class": pred_class, | |
| # "pred_score": top1_score, | |
| # "pred_desc_scores": pred_desc_scores.tolist(), | |
| # "descriptions": getprompt.desc[pred_class], | |
| # "modified_class": new_class, | |
| # "modified_score": modified_score, | |
| # "modified_desc_scores": modified_class_scores.tolist(), | |
| # "modified_descriptions": getprompt.desc[new_class], | |
| # } | |
| # return output_dict |