Marco committed
Commit 3af94b3 · 1 Parent(s): e721eb8

Corrected the README to be compliant with the standard HF multimodal (MM) pipeline

Files changed (1)
  1. README.md +6 -12
README.md CHANGED
````diff
@@ -159,20 +159,14 @@ pixel_values = [pixel_values]
 
 # generate output
 with torch.inference_mode():
-    gen_kwargs = dict(
-        max_new_tokens=1024,
-        do_sample=False,
-        top_p=None,
-        top_k=None,
-        temperature=None,
-        repetition_penalty=None,
-        eos_token_id=model.generation_config.eos_token_id,
-        pad_token_id=text_tokenizer.pad_token_id,
-        use_cache=True
-    )
-    output_ids = model.generate(input_ids, pixel_values=pixel_values, attention_mask=attention_mask, **gen_kwargs)[0]
+    if inputs['pixel_values'] is not None:
+        inputs['pixel_values'] = [pix.to(model.dtype).to(model.device) for pix in inputs['pixel_values']]
+    inputs = inputs.to('cuda')
+
+    output_ids = model.generate(inputs=inputs.pop('input_ids'), **inputs)[0]
     output = text_tokenizer.decode(output_ids, skip_special_tokens=True)
     print(f'Output:\n{output}')
+
 ```
 
 <details>
````
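For orientation, here is a minimal self-contained sketch of the generation flow after this commit. Only the `with torch.inference_mode():` block comes from the diff; the checkpoint id, the `AutoProcessor` call, the example image, and the `text_tokenizer = processor.tokenizer` line are illustrative assumptions, since the hunk does not show how `inputs` is built:

```python
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

# Hypothetical checkpoint id -- substitute the repo this README documents.
MODEL_ID = "org/multimodal-model"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype=torch.bfloat16, trust_remote_code=True
).to("cuda")
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
text_tokenizer = processor.tokenizer  # assumed source of the README's text_tokenizer

# Standard HF multimodal preprocessing: a BatchFeature holding input_ids,
# attention_mask and pixel_values (assumed here to be a list of per-image
# tensors, as the new code in the diff implies).
inputs = processor(
    text="Describe this image.",
    images=Image.open("example.jpg"),
    return_tensors="pt",
)

# generate output
with torch.inference_mode():
    # BatchFeature.to() moves only top-level tensors, so a list of per-image
    # tensors is cast and moved by hand first, as the commit does.
    if inputs['pixel_values'] is not None:
        inputs['pixel_values'] = [pix.to(model.dtype).to(model.device) for pix in inputs['pixel_values']]
    inputs = inputs.to('cuda')

    output_ids = model.generate(inputs=inputs.pop('input_ids'), **inputs)[0]
    output = text_tokenizer.decode(output_ids, skip_special_tokens=True)
    print(f'Output:\n{output}')
```

Note the order of operations in the block: `input_ids` is popped out and passed as `generate`'s `inputs` argument, and the remaining keys (`attention_mask`, `pixel_values`) are forwarded via `**inputs`, which is what lines the snippet up with the standard HF `generate` signature.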