Update modeling_ovis.py
Browse files- modeling_ovis.py +4 -4
modeling_ovis.py
CHANGED
@@ -288,10 +288,10 @@ class Ovis(OvisPreTrainedModel):
|
|
288 |
super().__init__(config, *inputs, **kwargs)
|
289 |
attn_kwargs = dict()
|
290 |
if self.config.llm_attn_implementation:
|
291 |
-
if self.config.llm_attn_implementation == "flash_attention_2":
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
attn_kwargs["attn_implementation"] = self.config.llm_attn_implementation
|
296 |
self.llm = AutoModelForCausalLM.from_config(self.config.llm_config, **attn_kwargs)
|
297 |
assert self.config.hidden_size == self.llm.config.hidden_size, "hidden size mismatch"
|
|
|
288 |
super().__init__(config, *inputs, **kwargs)
|
289 |
attn_kwargs = dict()
|
290 |
if self.config.llm_attn_implementation:
|
291 |
+
# if self.config.llm_attn_implementation == "flash_attention_2":
|
292 |
+
# assert (is_flash_attn_2_available() and
|
293 |
+
# version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.6.3")), \
|
294 |
+
# "Using `flash_attention_2` requires having `flash_attn>=2.6.3` installed."
|
295 |
attn_kwargs["attn_implementation"] = self.config.llm_attn_implementation
|
296 |
self.llm = AutoModelForCausalLM.from_config(self.config.llm_config, **attn_kwargs)
|
297 |
assert self.config.hidden_size == self.llm.config.hidden_size, "hidden size mismatch"
|