Upload models
Browse files- modeling_internimage.py +4 -4
modeling_internimage.py
CHANGED
|
@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel):
|
|
| 853 |
remove_center=config.remove_center, # for InternImage-H/G
|
| 854 |
)
|
| 855 |
|
| 856 |
-
def forward(self,
|
| 857 |
-
return self.model.forward_features(
|
| 858 |
|
| 859 |
|
| 860 |
class InternImageModelForImageClassification(PreTrainedModel):
|
|
@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel):
|
|
| 888 |
remove_center=config.remove_center, # for InternImage-H/G
|
| 889 |
)
|
| 890 |
|
| 891 |
-
def forward(self,
|
| 892 |
-
outputs = self.model.forward(
|
| 893 |
|
| 894 |
if labels is not None:
|
| 895 |
logits = outputs['logits']
|
|
|
|
| 853 |
remove_center=config.remove_center, # for InternImage-H/G
|
| 854 |
)
|
| 855 |
|
| 856 |
+
def forward(self, pixel_values):
|
| 857 |
+
return self.model.forward_features(pixel_values)
|
| 858 |
|
| 859 |
|
| 860 |
class InternImageModelForImageClassification(PreTrainedModel):
|
|
|
|
| 888 |
remove_center=config.remove_center, # for InternImage-H/G
|
| 889 |
)
|
| 890 |
|
| 891 |
+
def forward(self, pixel_values, labels=None):
|
| 892 |
+
outputs = self.model.forward(pixel_values)
|
| 893 |
|
| 894 |
if labels is not None:
|
| 895 |
logits = outputs['logits']
|