Spaces:
Sleeping
Sleeping
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | |
""" | |
Pix2Seq model and criterion classes. | |
""" | |
import torch | |
from torch.profiler import profile, record_function, ProfilerActivity | |
import torch.nn.functional as F | |
from torch import nn | |
from .misc import nested_tensor_from_tensor_list | |
from .backbone import build_backbone | |
from .transformer import build_transformer | |
from transformers import GenerationConfig | |
import numpy as np | |
class Pix2Seq(nn.Module):
    """ This is the Pix2Seq module that performs object detection.

    A backbone produces image features; the last feature level is projected
    to the transformer width by a 1x1 convolution followed by GroupNorm and
    then handed to the transformer. Two transformer flavours are supported:

    - ``use_hf=False``: the transformer is called positionally as
      ``transformer(src, input_seq, mask, pos, max_len=...)``.
    - ``use_hf=True``: the transformer is called with ``inputs_embeds`` /
      ``labels`` keyword arguments and must provide ``generate``.
    """

    # Projection width used when ``use_hf`` is True (``transformer.d_model``
    # is not consulted on that path).
    HF_HIDDEN_DIM = 256

    # Special token ids handed to ``generate`` on the ``use_hf`` inference path.
    # NOTE(review): presumably these must agree with the tokenizer vocabulary
    # used to train the transformer -- confirm before changing.
    BOS_TOKEN_ID = 2002
    EOS_TOKEN_ID = 2092
    PAD_TOKEN_ID = 2001

    # Value written into padded label positions on the ``use_hf`` training path.
    # NOTE(review): presumably the ignore index of the transformer's loss -- confirm.
    IGNORE_INDEX = -100

    def __init__(self, backbone, transformer, use_hf = False):
        """ Initializes the model.
        Parameters:
            backbone: torch module of the backbone to be used. See backbone.py.
                Must expose ``num_channels``.
            transformer: torch module of the transformer architecture. See transformer.py.
                Must expose ``d_model`` when ``use_hf`` is False.
            use_hf: selects the keyword-argument (``inputs_embeds`` / ``labels`` /
                ``generate``) calling convention and a fixed projection width
                of ``HF_HIDDEN_DIM``.
        """
        super().__init__()
        # Attribute assignment order is kept (transformer, input_proj, backbone)
        # so submodule registration order does not change.
        self.transformer = transformer
        hidden_dim = self.HF_HIDDEN_DIM if use_hf else transformer.d_model
        self.input_proj = nn.Sequential(
            nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=(1, 1)),
            nn.GroupNorm(32, hidden_dim),
        )
        self.backbone = backbone
        self.use_hf = use_hf

    def forward(self, image_tensor, targets=None, max_len=500, cheat = None):
        """ Run the backbone, project its last feature map and call the transformer.

        Parameters:
            image_tensor: a NestedTensor, which consists of:
                - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
                - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
              A list of tensors or a single tensor is converted with
              ``nested_tensor_from_tensor_list`` first.
            targets: optional ``(input_seq, input_len)`` pair. When given, the
                training path is taken; otherwise sequences are generated.
            max_len: generation length limit for the inference paths.
            cheat: unused; kept so existing callers that pass it keep working.

        Returns:
            - ``use_hf=False`` with targets: the transformer output with the
              last sequence position dropped (``output[:, :-1]``).
            - ``use_hf=False`` without targets: ``(output_seqs, output_scores)``
              exactly as returned by the transformer.
            - ``use_hf=True`` with targets: ``(output["logits"], output["loss"])``.
            - ``use_hf=True`` without targets: the result of
              ``transformer.generate`` returned twice as a 2-tuple.
        """
        if isinstance(image_tensor, (list, torch.Tensor)):
            image_tensor = nested_tensor_from_tensor_list(image_tensor)
        features, pos = self.backbone(image_tensor)

        # Only the last feature level is used.
        src, mask = features[-1].decompose()
        assert mask is not None, "backbone features must carry a padding mask"
        # The backbone's padding mask is replaced by an all-False mask of the
        # same shape, i.e. no spatial position is marked as padded.
        mask = torch.zeros_like(mask).bool()
        src = self.input_proj(src)

        if self.use_hf:
            return self._forward_hf(src, targets, max_len)

        if targets is not None:
            input_seq, input_len = targets
            # The first token of every target sequence is dropped before decoding.
            output_logits = self.transformer(src, input_seq[:, 1:], mask, pos[-1])
            return output_logits[:, :-1]

        output_seqs, output_scores = self.transformer(src, None, mask, pos[-1], max_len=max_len)
        return output_seqs, output_scores

    def _forward_hf(self, src, targets, max_len):
        """ Training / generation for a transformer using the keyword-argument convention. """
        # b x c x h x w -> b x hw x c
        # NOTE: positional embeddings are not added to the embeddings here; the
        # original author left open whether they should be.
        embeds = src.flatten(2).permute(0, 2, 1)

        if targets is not None:
            input_seq, input_len = targets
            # The first token of every target sequence is dropped.
            input_seq = input_seq[:, 1:]
            positions = torch.arange(input_seq.size(1), device=embeds.device)
            positions = positions.unsqueeze(0).expand_as(input_seq)
            # Positions at or beyond ``input_len - 1`` are treated as padding
            # (one less than ``input_len`` because the first token was dropped).
            pad_mask = positions >= input_len.to(embeds.device) - 1
            labels = input_seq.masked_fill(pad_mask, self.IGNORE_INDEX)
            output = self.transformer(inputs_embeds=embeds, labels=labels)
            return output["logits"], output["loss"]

        generation_config = GenerationConfig(
            max_new_tokens=max_len,
            bos_token_id=self.BOS_TOKEN_ID,
            eos_token_id=self.EOS_TOKEN_ID,
            pad_token_id=self.PAD_TOKEN_ID,
            output_hidden_states=True,
        )
        outputs = self.transformer.generate(inputs_embeds=embeds, generation_config=generation_config)
        # The same object is returned twice so callers can unpack two values,
        # as on the other inference path.
        return outputs, outputs
def build_pix2seq_model(args, tokenizer):
    """Build a Pix2Seq model from ``args`` and optionally restore a checkpoint.

    The backbone and transformer are created with ``build_backbone(args)`` and
    ``build_transformer(args, tokenizer)``. When ``args.pix2seq_ckpt`` is not
    None the checkpoint is loaded onto the CPU and applied to the model:

    - ``args.use_hf_transformer`` truthy: weights are read from
      ``checkpoint['state_dict']``, the first six characters of every key are
      removed, and loading is non-strict (missing / unexpected keys are
      tolerated).
    - otherwise: weights are read from ``checkpoint['model']`` and loaded
      strictly.
    """
    model = Pix2Seq(
        build_backbone(args),
        build_transformer(args, tokenizer),
        use_hf=args.use_hf_transformer,
    )

    if args.pix2seq_ckpt is None:
        return model

    # NOTE(review): depending on the torch version / ``weights_only`` default,
    # torch.load may unpickle arbitrary objects -- only load trusted checkpoints.
    checkpoint = torch.load(args.pix2seq_ckpt, map_location='cpu')

    if not args.use_hf_transformer:
        model.load_state_dict(checkpoint['model'])
        return model

    # Every key loses its first six characters unconditionally.
    # NOTE(review): presumably this strips a wrapper prefix such as "model."
    # added by the training framework -- confirm against the checkpoint.
    renamed_weights = {
        name[6:]: tensor for name, tensor in checkpoint['state_dict'].items()
    }
    model.load_state_dict(renamed_weights, strict=False)
    return model