Commit 10c6b12
Parent(s): 790cfa5
Changed image processor

theia_model.py (+44 -5)
CHANGED
@@ -296,6 +296,39 @@ class ViTModelReg(ViTModel):
         ).to(module.cls_token.dtype)
 
 
+class TorchImageProcessor:
+    def __init__(self, processor):
+        # converts huggingface image processor to torch processor
+        self.mean = torch.tensor(processor.image_mean, dtype=torch.float32).reshape((1, 3, 1, 1))
+        self.std = torch.tensor(processor.image_std, dtype=torch.float32).reshape((1, 3, 1, 1))
+        self.width = processor.size['width']
+        self.height = processor.size['height']
+
+    def __call__(self, x,
+                 do_resize: bool = True,
+                 do_rescale: bool = True,
+                 do_normalize: bool = True,
+                 device='cuda'):
+        #x = torch.tensor(x, device=device, dtype=torch.float32)
+        if do_resize:
+            #assert x.shape[-1] == self.width
+            #assert x.shape[-2] == self.height
+            x = F.interpolate(
+                x,
+                size=(self.height, self.width),
+                mode='bilinear',
+                align_corners=False
+            )
+
+        # not implemented. If you really need resize on each forward step, use torch.interpolate
+        if do_rescale:
+            x = x / 255.
+        if do_normalize:
+            x = x - self.mean.to(device)
+            x = x / self.std.to(device)
+        return {'pixel_values': x}
+
+
 class DeiT(nn.Module):
     """DeiT model.
 
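The class above reimplements the HuggingFace processor's resize → rescale → normalize pipeline as plain tensor ops, so preprocessing can run on the batch on the GPU instead of per-image on the CPU. A minimal parity sketch, assuming this file's module name and a checkpoint whose processor does a plain resize with no center crop; the checkpoint name is an illustration, not one the commit pins:

import torch
from transformers import AutoProcessor
from theia_model import TorchImageProcessor  # assumes this file's module name

# Hypothetical checkpoint; any processor that only resizes/rescales/normalizes works.
processor = AutoProcessor.from_pretrained("google/vit-base-patch16-224", use_fast=True)
gpu_processor = TorchImageProcessor(processor)

# A batch of uint8 images, channels-first, values in [0, 255].
images = torch.randint(0, 256, (2, 3, 256, 256), dtype=torch.uint8)

hf_out = processor(list(images), return_tensors="pt")["pixel_values"]
torch_out = gpu_processor(images.float(), device="cpu")["pixel_values"]

# The fast processor's antialiased resize differs from plain bilinear
# F.interpolate, so expect close but not bit-identical outputs.
print((hf_out - torch_out).abs().max())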
@@ -326,7 +359,9 @@ class DeiT(nn.Module):
 
         self.model.pooler = nn.Identity()
 
-        self.processor = AutoProcessor.from_pretrained(model_name)
+        #self.processor = AutoProcessor.from_pretrained(model_name)
+        self.processor = AutoProcessor.from_pretrained(model_name, use_fast=True)
+        self.gpu_processor = TorchImageProcessor(self.processor)
 
     def get_feature_size(
         self,
@@ -378,9 +413,13 @@ class DeiT(nn.Module):
         Returns:
             torch.Tensor: model output.
         """
-        input = self.processor(
-            x, return_tensors="pt", do_resize=do_resize, do_rescale=do_rescale, do_normalize=do_normalize
-        ).to(self.model.device)
+        #input = self.processor(
+        #    x, return_tensors="pt", do_resize=do_resize, do_rescale=do_rescale, do_normalize=do_normalize
+        #).to(self.model.device)
+        if x.shape[-1] == 3:
+            x = x.permute(0, 3, 1, 2)
+        input = self.gpu_processor(x, device=self.model.device)
+
         y = self.model(**input, interpolate_pos_encoding=interpolate_pos_encoding)
         return y.last_hidden_state
 
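With this hunk the forward path skips the HuggingFace processor entirely: NHWC input is permuted to NCHW (detected by the trailing dimension being 3, a heuristic that would also fire on a 3-pixel-wide NCHW batch), then resized, rescaled, and normalized on the model's device. A usage sketch under the same assumptions as above; the constructor argument is inferred from __init__, not guaranteed:

import torch
from theia_model import DeiT  # assumes this file's module name

# Hypothetical construction; __init__ above only shows that model_name
# feeds AutoProcessor.from_pretrained.
model = DeiT(model_name="google/vit-base-patch16-224").eval()

# NHWC uint8 frames, as most cameras and datasets deliver them.
frames = torch.randint(0, 256, (4, 240, 320, 3), dtype=torch.uint8)

with torch.no_grad():
    # forward() permutes NHWC -> NCHW, then the GPU processor resizes,
    # rescales to [0, 1], and normalizes with the checkpoint's mean/std.
    tokens = model.forward(frames.float())

print(tokens.shape)  # (batch, num_tokens, hidden_dim)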
@@ -1492,4 +1531,4 @@ class TheiaModel(PreTrainedModel):
             "mse_losses_per_model": mse_losses_per_model,
             "cos_losses_per_model": cos_losses_per_model,
             "l1_losses_per_model": l1_losses_per_model,
-        }
+        }