# CYF200127's picture
# Upload 162 files
# 1f516b6 verified
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Pix2Seq model and criterion classes.
"""
import torch
from torch.profiler import profile, record_function, ProfilerActivity
import torch.nn.functional as F
from torch import nn
from .misc import nested_tensor_from_tensor_list
from .backbone import build_backbone
from .transformer import build_transformer
from transformers import GenerationConfig
import numpy as np
class Pix2Seq(nn.Module):
    """Pix2Seq module: an image backbone feeding a sequence transformer.

    The last backbone feature map is projected to the transformer width with a
    1x1 convolution followed by GroupNorm. The projected features are then
    handed either to the project's own transformer (``use_hf=False``) or to a
    Hugging Face style model that is called with ``inputs_embeds=`` /
    ``labels=`` and ``generate`` (``use_hf=True``).
    """

    # Projection width used on the Hugging Face path, where
    # ``transformer.d_model`` is not read.
    HF_HIDDEN_DIM = 256
    # Special token ids handed to ``GenerationConfig`` on the HF inference path.
    HF_BOS_TOKEN_ID = 2002
    HF_EOS_TOKEN_ID = 2092
    HF_PAD_TOKEN_ID = 2001
    # Value written over label positions past the end of each target sequence.
    # -100 is the default ``ignore_index`` of torch's cross-entropy loss;
    # presumably the HF model's loss skips these positions -- confirm there.
    IGNORE_INDEX = -100

    def __init__(self, backbone, transformer, use_hf=False):
        """Initializes the model.

        Parameters:
            backbone: torch module of the backbone to be used. See backbone.py.
                Its ``num_channels`` attribute sizes the input projection.
            transformer: torch module of the transformer architecture. See
                transformer.py. Its ``d_model`` attribute is read unless
                ``use_hf`` is set.
            use_hf: when True, ``transformer`` is driven through the Hugging
                Face style interface and the projection width is fixed to
                ``HF_HIDDEN_DIM``.
        """
        super().__init__()
        self.transformer = transformer
        hidden_dim = self.HF_HIDDEN_DIM if use_hf else transformer.d_model
        self.input_proj = nn.Sequential(
            nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=(1, 1)),
            nn.GroupNorm(32, hidden_dim),
        )
        self.backbone = backbone
        self.use_hf = use_hf

    def forward(self, image_tensor, targets=None, max_len=500, cheat=None):
        """Run the backbone, project its features and call the transformer.

        Parameters:
            image_tensor: a NestedTensor, or a list / tensor of images that is
                converted with ``nested_tensor_from_tensor_list``. A
                NestedTensor consists of:
                - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
                - samples.mask: a binary mask of shape [batch_size x H x W],
                  containing 1 on padded pixels
            targets: optional ``(input_seq, input_len)`` pair. When given, the
                model runs in teacher-forcing mode; otherwise it decodes.
            max_len: decoding length limit, used only when ``targets`` is None.
            cheat: unused; kept so existing callers that pass it keep working.

        Returns:
            - ``use_hf`` with targets: ``(logits, loss)`` taken from the
              transformer's output mapping.
            - ``use_hf`` without targets: ``(outputs, outputs)`` where
              ``outputs`` is the result of ``transformer.generate``.
            - otherwise with targets: the transformer logits with the last
              sequence position dropped.
            - otherwise without targets: ``(output_seqs, output_scores)`` as
              returned by the transformer.
        """
        if isinstance(image_tensor, (list, torch.Tensor)):
            image_tensor = nested_tensor_from_tensor_list(image_tensor)
        features, pos = self.backbone(image_tensor)

        # Only the last feature level is used.
        src, mask = features[-1].decompose()
        assert mask is not None
        # The backbone's padding mask is replaced by an all-False mask of the
        # same shape, so no position is treated as padding downstream.
        mask = torch.zeros_like(mask).bool()
        src = self.input_proj(src)

        if self.use_hf:
            if targets is not None:
                return self._hf_teacher_forcing(src, targets)
            return self._hf_generate(src, max_len)

        if targets is not None:
            input_seq, input_len = targets
            output_logits = self.transformer(src, input_seq[:, 1:], mask, pos[-1])
            return output_logits[:, :-1]

        output_seqs, output_scores = self.transformer(
            src, None, mask, pos[-1], max_len=max_len)
        return output_seqs, output_scores

    def _hf_teacher_forcing(self, src, targets):
        """Compute ``(logits, loss)`` from the HF model for a target batch."""
        input_seq, input_len = targets
        # The first token of every target sequence is dropped before it is
        # used as labels.
        input_seq = input_seq[:, 1:]
        # b x c x h x w  ->  b x hw x c
        src = src.flatten(2).permute(0, 2, 1)

        seq_len = input_seq.size(1)
        positions = torch.arange(seq_len, device=src.device)
        positions = positions.unsqueeze(0).expand_as(input_seq)
        # Positions at index >= input_len - 1 are overwritten; the -1
        # presumably accounts for the token dropped above -- TODO confirm.
        past_end = positions >= input_len.to(src.device) - 1
        labels = input_seq.masked_fill(past_end, self.IGNORE_INDEX)

        output = self.transformer(inputs_embeds=src, labels=labels)
        return output["logits"], output["loss"]

    def _hf_generate(self, src, max_len):
        """Decode with the HF model's ``generate``; returns the result twice."""
        # b x c x h x w  ->  b x hw x c
        src = src.flatten(2).permute(0, 2, 1)
        generation_config = GenerationConfig(
            max_new_tokens=max_len,
            bos_token_id=self.HF_BOS_TOKEN_ID,
            eos_token_id=self.HF_EOS_TOKEN_ID,
            pad_token_id=self.HF_PAD_TOKEN_ID,
            output_hidden_states=True,
        )
        outputs = self.transformer.generate(
            inputs_embeds=src, generation_config=generation_config)
        # The same object fills both slots of the returned pair so the return
        # shape matches the other inference path.
        return outputs, outputs
def build_pix2seq_model(args, tokenizer):
    """Build a ``Pix2Seq`` model and optionally load a checkpoint into it.

    Parameters:
        args: configuration object. Reads ``use_hf_transformer`` and
            ``pix2seq_ckpt`` here and is also passed to ``build_backbone``
            and ``build_transformer``.
        tokenizer: passed through to ``build_transformer``.

    Returns:
        The constructed ``Pix2Seq`` model. When ``args.pix2seq_ckpt`` is not
        None its weights are loaded first.
    """
    import warnings

    backbone = build_backbone(args)
    transformer = build_transformer(args, tokenizer)
    model = Pix2Seq(backbone, transformer, use_hf=args.use_hf_transformer)

    if args.pix2seq_ckpt is None:
        return model

    # NOTE(security): torch.load unpickles the file, so only checkpoints from
    # a trusted source should be passed in ``args.pix2seq_ckpt``.
    checkpoint = torch.load(args.pix2seq_ckpt, map_location='cpu')

    if args.use_hf_transformer:
        # Every key loses its first six characters before loading. This
        # presumably strips a "model." wrapper prefix -- verify against the
        # code that wrote the checkpoint.
        renamed = {
            key[6:]: value for key, value in checkpoint['state_dict'].items()
        }
        # strict=False tolerates key mismatches; report them instead of
        # dropping them silently so a wrong checkpoint is noticeable.
        result = model.load_state_dict(renamed, strict=False)
        if result.missing_keys or result.unexpected_keys:
            warnings.warn(
                "Checkpoint loaded non-strictly: "
                f"{len(result.missing_keys)} missing key(s) "
                f"{list(result.missing_keys)}, "
                f"{len(result.unexpected_keys)} unexpected key(s) "
                f"{list(result.unexpected_keys)}",
                stacklevel=2,
            )
    else:
        model.load_state_dict(checkpoint['model'])
    return model