X-iZhang commited on
Commit
99ff09f
·
verified ·
1 Parent(s): 4e94887

Update libra/model/builder.py

Browse files
Files changed (1) hide show
  1. libra/model/builder.py +2 -2
libra/model/builder.py CHANGED
@@ -27,7 +27,7 @@ def load_pretrained_model(model_path, model_base, model_name, device="cpu"):
27
  device_map = {"": device}
28
  kwargs = {
29
  "device_map": device_map,
30
- "torch_dtype": torch.float32 # 对于 CPU,建议使用 float32 或 bfloat16
31
  }
32
 
33
 
@@ -114,7 +114,7 @@ def load_pretrained_model(model_path, model_base, model_name, device="cpu"):
114
  vision_tower = model.get_vision_tower()
115
  if not vision_tower.is_loaded:
116
  vision_tower.load_model()
117
- vision_tower.to(device=device, dtype=torch.float32)
118
  image_processor = vision_tower.image_processor
119
 
120
  if hasattr(model.config, "max_sequence_length"):
 
27
  device_map = {"": device}
28
  kwargs = {
29
  "device_map": device_map,
30
+ "torch_dtype": torch.bfloat16
31
  }
32
 
33
 
 
114
  vision_tower = model.get_vision_tower()
115
  if not vision_tower.is_loaded:
116
  vision_tower.load_model()
117
+ vision_tower.to(device=device, dtype=torch.bfloat16)
118
  image_processor = vision_tower.image_processor
119
 
120
  if hasattr(model.config, "max_sequence_length"):