chanfee commited on
Commit
c91d22a
·
verified ·
1 Parent(s): 1eaf9da

Update utils/model.py

Browse files
Files changed (1) hide show
  1. utils/model.py +3 -0
utils/model.py CHANGED
@@ -465,6 +465,9 @@ class OwlViTForClassification(nn.Module):
465
  print(f"im_features sum: {image_feats.sum().item()}, text_embeds sum: {text_embeds.sum().item()}")
466
  # Predict image-level classes (batch_size, num_patches, num_queries)
467
  image_text_logits, pred_logits, part_logits = self.cls_head(image_feats, text_embeds, topk_idxs)
 
 
 
468
 
469
  if self.weight_dict["loss_xclip"] > 0:
470
  targets_cls = torch.tensor([target["targets_cls"] for target in targets]).unsqueeze(1).to(self.device)
 
465
  print(f"im_features sum: {image_feats.sum().item()}, text_embeds sum: {text_embeds.sum().item()}")
466
  # Predict image-level classes (batch_size, num_patches, num_queries)
467
  image_text_logits, pred_logits, part_logits = self.cls_head(image_feats, text_embeds, topk_idxs)
468
+ print(f"topk_idxs: {topk_idxs}")
469
+ print(f"image_text_logits size: {image_text_logits.shape}")
470
+ print(f"image_text_logits sum: {image_text_logits.sum().item()}")
471
 
472
  if self.weight_dict["loss_xclip"] > 0:
473
  targets_cls = torch.tensor([target["targets_cls"] for target in targets]).unsqueeze(1).to(self.device)