229nagibator229 committed
Commit 10c6b12 · 1 Parent(s): 790cfa5

Changed image processor
Files changed (1): theia_model.py (+44 −5)
theia_model.py CHANGED

@@ -296,6 +296,39 @@ class ViTModelReg(ViTModel):
             ).to(module.cls_token.dtype)
 
 
+class TorchImageProcessor:
+    def __init__(self, processor):
+        # converts huggingface image processor to torch processor
+        self.mean = torch.tensor(processor.image_mean, dtype=torch.float32).reshape((1, 3, 1, 1))
+        self.std = torch.tensor(processor.image_std, dtype=torch.float32).reshape((1, 3, 1, 1))
+        self.width = processor.size['width']
+        self.height = processor.size['height']
+
+    def __call__(self, x,
+                 do_resize: bool = True,
+                 do_rescale: bool = True,
+                 do_normalize: bool = True,
+                 device='cuda'):
+        #x = torch.tensor(x, device=device, dtype=torch.float32)
+        if do_resize:
+            #assert x.shape[-1] == self.width
+            #assert x.shape[-2] == self.height
+            x = F.interpolate(
+                x,
+                size=(self.height, self.width),
+                mode='bilinear',
+                align_corners=False
+            )
+
+        # not implemented. If you really need resize on each forward step, use torch.interpolate
+        if do_rescale:
+            x = x / 255.
+        if do_normalize:
+            x = x - self.mean.to(device)
+            x = x / self.std.to(device)
+        return {'pixel_values': x}
+
+
 class DeiT(nn.Module):
     """DeiT model.
 
@@ -326,7 +359,9 @@ class DeiT(nn.Module):
 
         self.model.pooler = nn.Identity()
 
-        self.processor = AutoProcessor.from_pretrained(model_name)
+        #self.processor = AutoProcessor.from_pretrained(model_name)
+        self.processor = AutoProcessor.from_pretrained(model_name, use_fast=True)
+        self.gpu_processor = TorchImageProcessor(self.processor)
 
     def get_feature_size(
         self,
@@ -378,9 +413,13 @@ class DeiT(nn.Module):
         Returns:
             torch.Tensor: model output.
         """
-        input = self.processor(
-            x, return_tensors="pt", do_resize=do_resize, do_rescale=do_rescale, do_normalize=do_normalize
-        ).to(self.model.device)
+        #input = self.processor(
+        #    x, return_tensors="pt", do_resize=do_resize, do_rescale=do_rescale, do_normalize=do_normalize
+        #).to(self.model.device)
+        if x.shape[-1] == 3:
+            x = x.permute(0, 3, 1, 2)
+        input = self.gpu_processor(x, device=self.model.device)
+
         y = self.model(**input, interpolate_pos_encoding=interpolate_pos_encoding)
         return y.last_hidden_state
 
@@ -1492,4 +1531,4 @@ class TheiaModel(PreTrainedModel):
             "mse_losses_per_model": mse_losses_per_model,
             "cos_losses_per_model": cos_losses_per_model,
             "l1_losses_per_model": l1_losses_per_model,
-        }
+        }
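
For reference, the committed preprocessing path reduces to three tensor ops (resize, rescale, normalize) that run wherever the input batch lives, instead of round-tripping through the Hugging Face processor. A minimal standalone sketch of that pipeline, assuming a transformers checkpoint whose processor config provides image_mean, image_std, and a size dict with height/width keys; the checkpoint name below is illustrative, not the one this repo loads:

# Standalone sketch of what TorchImageProcessor computes; the checkpoint
# name is illustrative, not taken from this repo.
import torch
import torch.nn.functional as F
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("facebook/deit-tiny-patch16-224", use_fast=True)
mean = torch.tensor(processor.image_mean, dtype=torch.float32).reshape(1, 3, 1, 1)
std = torch.tensor(processor.image_std, dtype=torch.float32).reshape(1, 3, 1, 1)
target = (processor.size["height"], processor.size["width"])

# fake batch of uint8 RGB frames in NHWC layout, as a dataloader might yield
x = torch.randint(0, 256, (4, 240, 320, 3), dtype=torch.uint8)

x = x.permute(0, 3, 1, 2).float()  # NHWC -> NCHW, as DeiT.forward now does
x = F.interpolate(x, size=target, mode="bilinear", align_corners=False)  # do_resize
x = x / 255.0                      # do_rescale
x = (x - mean) / std               # do_normalize
print(x.shape)                     # (4, 3, H, W) per the processor config

Note that this path reimplements only the resize/rescale/normalize steps; any other steps a checkpoint's processor config enables (e.g. center crop or antialiased resize in the fast processor) are not reproduced, so outputs can differ slightly from the AutoProcessor call the commit replaces.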