Update modeling_qwen2_rm.py
modeling_qwen2_rm.py  (+2 -2)
```diff
@@ -48,8 +48,8 @@ from transformers.utils import (
 from .configuration_qwen2_rm import Qwen2RMConfig as Qwen2Config
 
 
-
-
+if is_flash_attn_2_available():
+    from transformers.modeling_flash_attention_utils import _flash_attention_forward
 
 
 logger = logging.get_logger(__name__)
```
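The change replaces two blank lines with a guarded import: `_flash_attention_forward` is only pulled in when `is_flash_attn_2_available()` reports that the `flash_attn` package is installed, so the modeling file still imports cleanly on machines without FlashAttention 2. A minimal sketch of the pattern follows; the guard and import line come from the diff above, while the `else` fallback is an illustrative assumption, not part of this commit:

```python
# Guarded-import pattern applied in this commit: the flash-attn helper is
# imported only when the optional dependency is actually available.
from transformers.utils import is_flash_attn_2_available

if is_flash_attn_2_available():
    # Resolved only when the flash-attn package is installed, so importing
    # this modeling file never fails on CPU-only or non-CUDA environments.
    from transformers.modeling_flash_attention_utils import _flash_attention_forward
else:
    # Hypothetical fallback (not in the commit): attention code would check
    # for None and route to the eager/SDPA implementations instead.
    _flash_attention_forward = None
```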