ARCQUB commited on
Commit
d6daf6b
·
verified ·
1 Parent(s): 6c0c37c

Update models/qwen.py

Browse files
Files changed (1) hide show
  1. models/qwen.py +2 -2
models/qwen.py CHANGED
@@ -10,8 +10,8 @@ from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
10
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
11
  "Qwen/Qwen2.5-VL-7B-Instruct",
12
  torch_dtype=torch.bfloat16,
13
- device_map="cuda",
14
- attn_implementation="flash_attention_2"
15
  )
16
 
17
  min_pixels = 256 * 28 * 28
 
10
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
11
  "Qwen/Qwen2.5-VL-7B-Instruct",
12
  torch_dtype=torch.bfloat16,
13
+ device_map="cuda"
14
+ #attn_implementation="flash_attention_2"
15
  )
16
 
17
  min_pixels = 256 * 28 * 28