damerajee commited on
Commit
e5b4a7d
·
verified ·
1 Parent(s): 98463c7

Update modeling_gpt2vision.py

Browse files
Files changed (1) hide show
  1. modeling_gpt2vision.py +3 -3
modeling_gpt2vision.py CHANGED
@@ -74,9 +74,9 @@ class GPT2Vision(PreTrainedModel):
74
  }
75
 
76
  def preprocess_inputs(self, batch):
77
- pixel_values = batch['pixel_values'].squeeze(1)
78
- input_ids = batch['input_ids'].squeeze(1)
79
- attention_mask = batch['attention_mask'].squeeze(1)
80
  input_ids = input_ids.to(self.device)
81
  attention_mask = attention_mask.to(self.device)
82
  pixel_values = pixel_values.to(self.device)
 
74
  }
75
 
76
  def preprocess_inputs(self, batch):
77
+ pixel_values = batch['pixel_values']
78
+ input_ids = batch['input_ids']
79
+ attention_mask = batch['attention_mask']
80
  input_ids = input_ids.to(self.device)
81
  attention_mask = attention_mask.to(self.device)
82
  pixel_values = pixel_values.to(self.device)