# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Pix2Seq model and criterion classes.
"""
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.profiler import profile, record_function, ProfilerActivity
from transformers import GenerationConfig

from .backbone import build_backbone
from .misc import nested_tensor_from_tensor_list
from .transformer import build_transformer

# Width of the projected image features when a Hugging Face decoder is used.
_HF_HIDDEN_DIM = 256

# Special token ids handed to `GenerationConfig` on the Hugging Face
# inference path.
# NOTE(review): these are hard-coded and presumably have to agree with the
# tokenizer given to `build_transformer` -- confirm.
_HF_BOS_TOKEN_ID = 2002
_HF_EOS_TOKEN_ID = 2092
_HF_PAD_TOKEN_ID = 2001

# Label value written over padded target positions.  -100 is the default
# `ignore_index` of torch's cross-entropy loss.
_IGNORE_LABEL = -100


class Pix2Seq(nn.Module):
    """This is the Pix2Seq module that performs object detection."""

    def __init__(self, backbone, transformer, use_hf=False):
        """Initializes the model.

        Parameters:
            backbone: torch module of the backbone to be used. See backbone.py.
                Must expose `num_channels`.
            transformer: torch module of the transformer architecture. See
                transformer.py. Must expose `d_model` unless `use_hf` is True.
            use_hf: when True, `transformer` is driven through the Hugging
                Face style interface (`inputs_embeds=` / `labels=` /
                `.generate`) and the feature width is fixed to 256.
        """
        super().__init__()
        self.transformer = transformer
        hidden_dim = _HF_HIDDEN_DIM if use_hf else transformer.d_model
        # 1x1 convolution maps backbone channels to the decoder width.
        self.input_proj = nn.Sequential(
            nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=(1, 1)),
            nn.GroupNorm(32, hidden_dim),
        )
        self.backbone = backbone
        self.use_hf = use_hf

    def forward(self, image_tensor, targets=None, max_len=500, cheat=None):
        """Run the backbone and the sequence model.

        Parameters:
            image_tensor: a NestedTensor, or a list / tensor of images that is
                converted with `nested_tensor_from_tensor_list`. The
                NestedTensor consists of:
                - samples.tensor: batched images, of shape
                  [batch_size x 3 x H x W]
                - samples.mask: a binary mask of shape [batch_size x H x W],
                  containing 1 on padded pixels
            targets: optional `(input_seq, input_len)` pair. When given, the
                model runs in teacher-forced (training) mode; when None it
                generates sequences.
            max_len: maximum number of generated tokens (inference only).
            cheat: unused; kept so existing callers keep working.

        Returns:
            The return shape depends on the mode:
            - HF decoder, targets given: `(logits, loss)`.
            - HF decoder, no targets: `(sequences, sequences)`; the result of
              `generate` is returned twice.
            - native decoder, targets given: a single logits tensor with the
              last position dropped. Per the original documentation its shape
              is [batch_size, num_sequence, num_vocal].
            - native decoder, no targets: `(output_seqs, output_scores)`.
        """
        if isinstance(image_tensor, (list, torch.Tensor)):
            image_tensor = nested_tensor_from_tensor_list(image_tensor)

        features, pos = self.backbone(image_tensor)
        # Only the last feature level is used.
        src, mask = features[-1].decompose()
        assert mask is not None
        src = self.input_proj(src)

        if self.use_hf:
            return self._forward_hf(src, targets, max_len)
        return self._forward_native(src, mask, pos, targets, max_len)

    def _forward_hf(self, src, targets, max_len):
        """Drive a Hugging Face style model with flattened image features."""
        # b x c x h x w  ->  b x hw x c
        # NOTE: the backbone positional encoding is not added on this path.
        src = src.flatten(2).permute(0, 2, 1)

        if targets is None:
            generation_config = GenerationConfig(
                max_new_tokens=max_len,
                bos_token_id=_HF_BOS_TOKEN_ID,
                eos_token_id=_HF_EOS_TOKEN_ID,
                pad_token_id=_HF_PAD_TOKEN_ID,
                output_hidden_states=True,
            )
            outputs = self.transformer.generate(
                inputs_embeds=src, generation_config=generation_config
            )
            return outputs, outputs

        input_seq, input_len = targets
        # Drop the first token of every target sequence
        # (presumably the start token -- TODO confirm).
        input_seq = input_seq[:, 1:]
        labels = self._mask_padded_labels(input_seq, input_len, src.device)
        output = self.transformer(inputs_embeds=src, labels=labels)
        return output["logits"], output["loss"]

    @staticmethod
    def _mask_padded_labels(input_seq, input_len, device):
        """Replace positions at or beyond `input_len - 1` with the ignore label."""
        # assumes input_len broadcasts against input_seq, e.g. (batch, 1)
        # -- TODO confirm against the data loader.
        positions = (
            torch.arange(input_seq.size(1), device=device)
            .unsqueeze(0)
            .expand_as(input_seq)
        )
        # `- 1` because the first token was already removed from input_seq.
        padded = positions >= input_len - 1
        return input_seq.masked_fill(padded, _IGNORE_LABEL)

    def _forward_native(self, src, mask, pos, targets, max_len):
        """Drive the project's own transformer implementation."""
        # The backbone's padding mask is replaced by an all-False mask.
        mask = torch.zeros_like(mask).bool()

        if targets is None:
            output_seqs, output_scores = self.transformer(
                src, None, mask, pos[-1], max_len=max_len
            )
            return output_seqs, output_scores

        input_seq, _input_len = targets
        output_logits = self.transformer(src, input_seq[:, 1:], mask, pos[-1])
        # The last position is dropped.
        return output_logits[:, :-1]


def build_pix2seq_model(args, tokenizer):
    """Build a `Pix2Seq` model and optionally restore weights.

    Parameters:
        args: configuration object. Read here: `use_hf_transformer` and
            `pix2seq_ckpt`; it is also forwarded to `build_backbone` and
            `build_transformer`.
        tokenizer: forwarded to `build_transformer`.

    Returns:
        The constructed `Pix2Seq` module, with checkpoint weights loaded when
        `args.pix2seq_ckpt` is not None.
    """
    backbone = build_backbone(args)
    transformer = build_transformer(args, tokenizer)
    model = Pix2Seq(backbone, transformer, use_hf=args.use_hf_transformer)

    if args.pix2seq_ckpt is None:
        return model

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    checkpoint = torch.load(args.pix2seq_ckpt, map_location='cpu')

    if args.use_hf_transformer:
        # Drops the first 6 characters of every key -- presumably a "model."
        # prefix added by the training wrapper; verify against the checkpoint.
        renamed = {
            key[6:]: value for key, value in checkpoint['state_dict'].items()
        }
        # strict=False: missing / unexpected keys are tolerated.
        model.load_state_dict(renamed, strict=False)
    else:
        model.load_state_dict(checkpoint['model'])

    return model