# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Pix2Seq model and model builder.
"""
# numpy and torch.profiler are only used by the commented-out debugging/profiling
# snippets in forward(); they are kept so that code can be re-enabled easily.
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.profiler import profile, record_function, ProfilerActivity
from transformers import GenerationConfig

from .misc import nested_tensor_from_tensor_list
from .backbone import build_backbone
from .transformer import build_transformer

class Pix2Seq(nn.Module):
    """ This is the Pix2Seq module that performs object detection """

    def __init__(self, backbone, transformer, use_hf=False):
        """ Initializes the model.
        Parameters:
            backbone: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            use_hf: if True, `transformer` is a Hugging Face encoder-decoder model driven
                through `inputs_embeds`/`labels`/`generate`; otherwise it is the custom
                transformer from transformer.py
        """
        super().__init__()
        self.transformer = transformer
        # The HF transformer used here has a fixed hidden size of 256; the custom
        # transformer exposes its width as d_model.
        hidden_dim = 256 if use_hf else transformer.d_model
        # 1x1 conv + GroupNorm projecting backbone features to the transformer width.
        self.input_proj = nn.Sequential(
            nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=(1, 1)),
            nn.GroupNorm(32, hidden_dim))
        self.backbone = backbone
        self.use_hf = use_hf
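
    # Overall data flow in forward(): images -> backbone features (B, C, H, W) -> 1x1
    # projection to the transformer width. The HF branch then flattens the map into H*W
    # visual tokens for inputs_embeds; the custom transformer receives the projected
    # feature map, mask and positional encodings directly.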
    def forward(self, image_tensor, targets=None, max_len=500, cheat=None):
        """
        image_tensor:
            Either a NestedTensor or a list/batch of image tensors (converted to a
            NestedTensor below). A NestedTensor consists of:
               - image_tensor.tensors: batched images, of shape [batch_size x 3 x H x W]
               - image_tensor.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
        targets: optional (input_seq, input_len) pair of target token sequences and their
            lengths; when given, the model is run in teacher-forcing mode.
        max_len: maximum number of tokens to generate when targets is None.
        cheat: unused except by commented-out debugging code.

        Returns:
            - with targets and use_hf: (logits, loss) from the HF model
            - with targets and the custom transformer: logits of shape [batch_size, num_sequence, num_vocal]
            - without targets: (output_seqs, output_scores) from autoregressive decoding
        """
        if isinstance(image_tensor, (list, torch.Tensor)):
            image_tensor = nested_tensor_from_tensor_list(image_tensor)
        features, pos = self.backbone(image_tensor)
        #print(len(features))
        #print(pos.size())
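        # Optional torch.profiler instrumentation for the backbone pass, kept commented out: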
        '''
        with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], with_stack=True, record_shapes=True) as prof:
            with record_function("model_inference"):
                features, pos = self.backbone(image_tensor)
        print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
        prof.export_stacks("/tmp/profiler_stacks_cuda_A6000_16_backbone.txt", "self_cuda_time_total")
        '''
        src, mask = features[-1].decompose()
        assert mask is not None
        # The padding mask is deliberately replaced with an all-False mask, i.e. no
        # positions are treated as padding downstream.
        mask = torch.zeros_like(mask).bool()
        # Project backbone features to the transformer width: (B, C, H, W) -> (B, hidden_dim, H, W).
        src = self.input_proj(src)
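        # Two execution paths below: a Hugging Face encoder-decoder (self.use_hf) driven via
        # inputs_embeds/labels/generate, or the custom transformer from transformer.py.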
        if self.use_hf:
            if targets is not None:
                # Earlier variants, kept commented out for reference:
                '''
                logits = self.transformer(src)
                input_seq, input_len = targets
                logits = logits.reshape(-1, 2094)
                loss = self.loss_fn(logits, input_seq.view(-1))
                return loss, loss
                '''
                '''
                output_logits = self.transformer(src, input_seq[:, 1:], mask, pos[-1])
                return output_logits[:, :-1]
                '''
                #print(input_seq)
                input_seq, input_len = targets
                # Drop the leading start token from the labels; the HF model re-creates the
                # decoder inputs from `labels` internally.
                input_seq = input_seq[:, 1:]
                bs = src.shape[0]
                # (B, C, H, W) -> (B, H*W, C): one "visual token" per spatial location.
                src = src.flatten(2).permute(0, 2, 1)
                pos_embed = pos[-1].flatten(2).permute(0, 2, 1)
                # Mask out positions past each sequence's true length (minus the dropped
                # start token) so they are ignored by the loss.
                seq_len = input_seq.size(1)
                indices = torch.arange(seq_len, device=src.device).unsqueeze(0).expand_as(input_seq)
                mask = indices >= (input_len - 1)
                masked_input_seq = input_seq.masked_fill(mask, -100)
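                # -100 is the default ignore_index of PyTorch's cross-entropy (and the value
                # Hugging Face models use), so these padded positions contribute nothing to the loss.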
#print("input_seq "+str(input_seq))
#print("masked_input "+str(masked_input_seq))
#src = src + pos_embed #unclear if this line is needed...
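                # Earlier experiment that built decoder input embeddings by hand, kept commented out: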
                '''
                decoder_input = torch.cat(
                    [
                        nn.Embedding(1, 256).to(src.device).weight.unsqueeze(0).repeat(bs, 1, 1),
                        nn.Embedding(2092, 256).to(src.device)(input_seq)
                    ], dim=1
                )
                '''
                #decoder_mask = torch.full(decoder_input.shape[:2], False, dtype=torch.bool).to(src.device)
                #decoder_mask[:, 0] = True
                # Flattened image features go in as encoder inputs_embeds; the HF model builds the
                # decoder inputs from `labels` and returns both the logits and the LM loss.
                output = self.transformer(inputs_embeds=src, labels=masked_input_seq)
                #print("output logits " + str(torch.argmax(output["logits"], dim=2)) + " target labels " + str(masked_input_seq))
                #print(output["logits"].shape)
                return output["logits"], output["loss"]
            else:
                # Inference with the HF model: flatten the image features into a sequence of
                # visual tokens and decode autoregressively with generate().
                '''
                logits = self.transformer(src)
                print(logits.shape)
                return self.transformer(src).argmax(dim=1), self.transformer(src).argmax(dim=1)
                '''
                #with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], with_stack=True, record_shapes=True) as prof:
                #    with record_function("model_inference"):
                #print(pos[-1])
                #output_seqs, output_scores = self.transformer(src, None, mask, pos[-1], max_len=max_len)
                '''
                flatten src from B x C x H x W into B x HW x C and pass in as input_embeds
                potentially flatten pos[-1] as well and add to input embeds
                '''
                bs = src.shape[0]
                # (B, C, H, W) -> (B, H*W, C)
                src = src.flatten(2).permute(0, 2, 1)
                # The special-token ids below come from this model's sequence vocabulary.
                generation_config = GenerationConfig(max_new_tokens=max_len, bos_token_id=2002,
                                                     eos_token_id=2092, pad_token_id=2001,
                                                     output_hidden_states=True)
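                # With inputs_embeds only (no input_ids), generate() feeds the visual tokens to the
                # encoder and starts decoding from the configured start/BOS token.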
                # Commented-out variant that also reports per-token transition scores:
                #output = self.transformer.generate(inputs_embeds=src, generation_config=generation_config, return_dict_in_generate=True, output_scores=True)
                #transition_scores = self.transformer.compute_transition_scores(output.sequences, output.scores, normalize_logits=True)
                #for tok, score in zip(output.sequences[0], transition_scores[0]):
                #    print(f"| {tok:5d} | {score.to('cpu').numpy():.3f} | {np.exp(score.to('cpu').numpy()):.2%}")
                #print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
                #prof.export_stacks("/tmp/profiler_stacks_cpu_A6000_16_decoder.txt", "self_cpu_time_total")
                #print("loss "+str(output.loss))
                #encoder_outputs = self.transformer.encoder(inputs_embeds=src)
                # Debugging block that decodes ground-truth prefixes from `cheat`, kept commented out:
                '''
                print(cheat)
                print("own predictions")
                print(cheat['coref'][0][:, :3])
                print(self.transformer.decoder(input_ids=cheat['coref'][0][:, :3].to(src.device), encoder_hidden_states=torch.rand_like(encoder_outputs[0]).to(src.device)).logits.argmax(dim=2))
                print(self.transformer.decoder(input_ids=cheat['coref'][0][:, :4].to(src.device), encoder_hidden_states=torch.rand_like(encoder_outputs[0]).to(src.device)).logits.argmax(dim=2))
                print(self.transformer.decoder(input_ids=cheat['coref'][0][:, :5].to(src.device), encoder_hidden_states=torch.rand_like(encoder_outputs[0]).to(src.device)).logits.argmax(dim=2))
                '''
                #input_seq, input_len = cheat['bbox']
                #input_seq = input_seq[:, 1:]
                #max_len = input_seq.size(1)
                #indices = torch.arange(max_len).unsqueeze(0).expand_as(input_seq).to(src.device)
                #mask = indices >= input_len - torch.ones(input_len.shape).to(src.device)
                #masked_input_seq = input_seq.masked_fill(mask, -100)
                #output = self.transformer(inputs_embeds=src, labels=masked_input_seq)
                #print("output logits " + str(torch.argmax(output["logits"], dim=2)) + " target labels " + str(masked_input_seq))
                outputs = self.transformer.generate(inputs_embeds=src, generation_config=generation_config)
                # Returned twice so callers see the same (sequences, scores) interface as the custom transformer.
                return outputs, outputs
        else:
            if targets is not None:
                # Training with the custom transformer: teacher forcing on the target sequence.
                input_seq, input_len = targets
                output_logits = self.transformer(src, input_seq[:, 1:], mask, pos[-1])
                return output_logits[:, :-1]
            else:
                # Inference with the custom transformer: autoregressive decoding up to max_len tokens.
                output_seqs, output_scores = self.transformer(src, None, mask, pos[-1], max_len=max_len)
                return output_seqs, output_scores


def build_pix2seq_model(args, tokenizer):
    """Build the Pix2Seq model (backbone + transformer) from `args`, optionally
    loading pretrained weights from `args.pix2seq_ckpt`."""
    backbone = build_backbone(args)
    transformer = build_transformer(args, tokenizer)
    model = Pix2Seq(backbone, transformer, use_hf=args.use_hf_transformer)
    if args.pix2seq_ckpt is not None:
        checkpoint = torch.load(args.pix2seq_ckpt, map_location='cpu')
        if args.use_hf_transformer:
            # The HF checkpoint stores weights under 'state_dict' with a 6-character
            # prefix on every key (presumably 'model.' from the training wrapper);
            # strip it and load non-strictly so extra/missing keys are tolerated.
            new_dict = {}
            #print(checkpoint['state_dict'].keys())
            for key in checkpoint['state_dict']:
                new_dict[key[6:]] = checkpoint['state_dict'][key]
            model.load_state_dict(new_dict, strict=False)
        else:
            model.load_state_dict(checkpoint['model'])
    return model
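
# Illustrative usage sketch (not part of the module's API). The `args` attributes shown
# are assumptions based on how they are read above; shapes and token ids depend on the
# tokenizer/vocabulary actually used.
#
#   model = build_pix2seq_model(args, tokenizer)      # needs args.use_hf_transformer, args.pix2seq_ckpt, backbone/transformer args
#   images = [torch.rand(3, 480, 640), torch.rand(3, 512, 512)]  # list of 3xHxW tensors; padded into a NestedTensor internally
#
#   # inference: autoregressive decoding
#   seqs, scores = model(images, targets=None, max_len=500)
#
#   # training: teacher forcing with (input_seq, input_len) targets;
#   # the HF branch returns (logits, loss), the custom transformer returns logits only
#   out = model(images, targets=(input_seq, input_len))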