Update models/tag2text.py
models/tag2text.py (+14 -2)
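This change replaces the model's single global tag threshold with a per-class threshold vector: a default of 0.68 applies to most tags, while the frequently over-predicted tags "person" (2701), "man" (2828), and "woman" (1167) are raised to 0.7. It also documents the number-word tags removed via delete_tag_index because they tend to disturb captioning.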
@@ -26,7 +26,14 @@ def read_json(rpath):
     with open(rpath, 'r') as f:
         return json.load(f)
 
+# delete some tags that may disturb captioning
+# 127: "quarter"; 2961: "back"; 3351: "two"; 3265: "three"; 3338: "four"; 3355: "five"; 3359: "one"
 delete_tag_index = [127,2961, 3351, 3265, 3338, 3355, 3359]
+
+# adjust thresholds for some tags
+# default threshold: 0.68
+# 2701: "person"; 2828: "man"; 1167: "woman";
+tag_thrshold = {2701:0.7, 2828: 0.7, 1167: 0.7}
 
 class Tag2Text_Caption(nn.Module):
     def __init__(self,
@@ -36,7 +43,7 @@ class Tag2Text_Caption(nn.Module):
                  vit_grad_ckpt = False,
                  vit_ckpt_layer = 0,
                  prompt = 'a picture of ',
-                 threshold = 0.
+                 threshold = 0.68,
                  ):
         """
         Args:
@@ -105,6 +112,10 @@ class Tag2Text_Caption(nn.Module):
         tie_encoder_decoder_weights(self.tag_encoder,self.vision_multi,'',' ')
         self.tag_array = tra_array
 
+        self.class_threshold = torch.ones(self.num_class) * self.threshold
+        for key,value in tag_thrshold.items():
+            self.class_threshold[key] = value
+
     def del_selfattention(self):
         del self.vision_multi.embeddings
         for layer in self.vision_multi.encoder.layer:
@@ -130,7 +141,8 @@ class Tag2Text_Caption(nn.Module):
 
         logits = self.fc(mlr_tagembedding[0])
 
-        targets = torch.where(torch.sigmoid(logits) > self.threshold , torch.tensor(1.0).to(image.device), torch.zeros(self.num_class).to(image.device))
+        # targets = torch.where(torch.sigmoid(logits) > self.threshold , torch.tensor(1.0).to(image.device), torch.zeros(self.num_class).to(image.device))
+        targets = torch.where(torch.sigmoid(logits) > self.class_threshold.to(image.device) , torch.tensor(1.0).to(image.device), torch.zeros(self.num_class).to(image.device))
 
         tag = targets.cpu().numpy()
         tag[:,delete_tag_index] = 0
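To try the new logic outside the model, here is a minimal, self-contained sketch of the per-class thresholding. The num_class value, batch size, and random logits are stand-ins for what the real model derives from its config and from self.fc(mlr_tagembedding[0]); the boolean comparison followed by .float() is equivalent to the torch.where call in the diff.

# Minimal sketch of the per-class thresholding above (illustrative values).
import torch

num_class = 3429                    # stand-in for the tag vocabulary size
threshold = 0.68                    # default threshold from __init__

# raised thresholds for frequently over-predicted tags: person / man / woman
tag_thrshold = {2701: 0.7, 2828: 0.7, 1167: 0.7}
# number-word tags excluded from captioning
delete_tag_index = [127, 2961, 3351, 3265, 3338, 3355, 3359]

# one threshold per tag, starting from the global default
class_threshold = torch.ones(num_class) * threshold
for key, value in tag_thrshold.items():
    class_threshold[key] = value

# fake logits for a batch of 2 images; the model produces these
# with self.fc(mlr_tagembedding[0])
logits = torch.randn(2, num_class)

# (batch, num_class) > (num_class,) broadcasts row-wise, so every tag is
# compared against its own threshold; .float() matches the torch.where output
targets = (torch.sigmoid(logits) > class_threshold).float()

tag = targets.numpy()
tag[:, delete_tag_index] = 0        # suppress the deleted tags

Broadcasting the (num_class,) threshold vector against the (batch, num_class) score matrix keeps tag decoding to a single elementwise comparison while still allowing individual tags to be made stricter.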