Update libra/model/builder.py
Browse files- libra/model/builder.py +2 -2
libra/model/builder.py
CHANGED
|
@@ -27,7 +27,7 @@ def load_pretrained_model(model_path, model_base, model_name, device="cpu"):
|
|
| 27 |
device_map = {"": device}
|
| 28 |
kwargs = {
|
| 29 |
"device_map": device_map,
|
| 30 |
-
"torch_dtype": torch.
|
| 31 |
}
|
| 32 |
|
| 33 |
|
|
@@ -114,7 +114,7 @@ def load_pretrained_model(model_path, model_base, model_name, device="cpu"):
|
|
| 114 |
vision_tower = model.get_vision_tower()
|
| 115 |
if not vision_tower.is_loaded:
|
| 116 |
vision_tower.load_model()
|
| 117 |
-
vision_tower.to(device=device, dtype=torch.
|
| 118 |
image_processor = vision_tower.image_processor
|
| 119 |
|
| 120 |
if hasattr(model.config, "max_sequence_length"):
|
|
|
|
| 27 |
device_map = {"": device}
|
| 28 |
kwargs = {
|
| 29 |
"device_map": device_map,
|
| 30 |
+
"torch_dtype": torch.bfloat16
|
| 31 |
}
|
| 32 |
|
| 33 |
|
|
|
|
| 114 |
vision_tower = model.get_vision_tower()
|
| 115 |
if not vision_tower.is_loaded:
|
| 116 |
vision_tower.load_model()
|
| 117 |
+
vision_tower.to(device=device, dtype=torch.bfloat16)
|
| 118 |
image_processor = vision_tower.image_processor
|
| 119 |
|
| 120 |
if hasattr(model.config, "max_sequence_length"):
|