chanfee commited on
Commit
6b9a3f4
·
verified ·
1 Parent(s): c806e96

Update utils/model.py

Browse files
Files changed (1) hide show
  1. utils/model.py +4 -2
utils/model.py CHANGED
@@ -421,9 +421,11 @@ class OwlViTForClassification(nn.Module):
421
  print(f"text_inputs_parts - input_ids: {text_inputs_parts['input_ids'].shape}. attention_mask : {text_inputs_parts['attention_mask'].shape}")
422
  seq_length = text_inputs_parts['input_ids'].shape[-1]
423
  position_ids = self.owlvit.text_model.embeddings.position_ids[:, :seq_length]
 
424
  print(f"position_embedding: {self.owlvit.text_model.embeddings.position_embedding(position_ids).shape}")
425
- print(f"text_embeds: {self.owlvit.text_model.embeddings.token_embedding(text_inputs_parts['input_ids']).shape}")
426
- text_embeds_parts = self.owlvit.text_model.text_model.get_text_features(**text_inputs_parts)
 
427
 
428
  # # Embed images and text queries
429
  query_mask, text_embeds_parts = self._get_text_query_mask(text_inputs_parts, text_embeds_parts, batch_size)
 
421
  print(f"text_inputs_parts - input_ids: {text_inputs_parts['input_ids'].shape}. attention_mask : {text_inputs_parts['attention_mask'].shape}")
422
  seq_length = text_inputs_parts['input_ids'].shape[-1]
423
  position_ids = self.owlvit.text_model.embeddings.position_ids[:, :seq_length]
424
+ txt_embeds = self.owlvit.text_model.embeddings.token_embedding(text_inputs_parts['input_ids'])
425
  print(f"position_embedding: {self.owlvit.text_model.embeddings.position_embedding(position_ids).shape}")
426
+ print(f"text_embeds: {txt_embeds.shape}")
427
+ print(f"pos + emb: {(txt_embeds + position_ids).shape}")
428
+ text_embeds_parts = self.owlvit.text_model.get_text_features(**text_inputs_parts)
429
 
430
  # # Embed images and text queries
431
  query_mask, text_embeds_parts = self._get_text_query_mask(text_inputs_parts, text_embeds_parts, batch_size)