Commit 7a58a7d · 0 Parent(s)

chore: initialize the app
- .DS_Store +0 -0
- .gitattributes +35 -0
- LICENSE.txt +1 -0
- README.md +45 -0
- SelfExtend.py +199 -0
- app.py +388 -0
- requirements.txt +8 -0
- self_extend_patch/Llama.py +482 -0
- self_extend_patch/__init__.py +1 -0
- self_extend_patch/selfextend_flash_attn.py +199 -0
- self_extend_patch/selfextend_flash_attn_triton.py +278 -0
- self_extend_patch/triton_selfextend_flash_attn.py +250 -0
- style.css +24 -0
.DS_Store
ADDED
Binary file (6.15 kB)
.gitattributes
ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
LICENSE.txt
ADDED
@@ -0,0 +1 @@
~
README.md
ADDED
@@ -0,0 +1,45 @@
---
title: Ghost 8B Beta (128k)
emoji: 👻 / 📚
colorFrom: indigo
colorTo: pink
sdk: gradio
sdk_version: 4.36.1
app_file: app.py
pinned: false
suggested_hardware: a10g-small
language:
- en
- vi
- es
- pt
- de
- it
- fr
- ko
- zh
license: other
license_name: ghost-llms
license_link: https://ghost-x.org/ghost-llms-license
tags:
- ghost
---

# ~

### Notes

The extension source code belongs to: "LLM Maybe LongLM: Self-Extend LLM Context Window Without Tuning".

See source code details [here](https://github.com/datamllab/LongLM).

```
@misc{jin2024llm,
      title={LLM Maybe LongLM: Self-Extend LLM Context Window Without Tuning},
      author={Hongye Jin and Xiaotian Han and Jingfeng Yang and Zhimeng Jiang and Zirui Liu and Chia-Yuan Chang and Huiyuan Chen and Xia Hu},
      year={2024},
      eprint={2401.01325},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```
SelfExtend.py
ADDED
@@ -0,0 +1,199 @@
from types import MethodType
from functools import partial
import self_extend_patch as SE

def modify_method_of_instance(instance, target_class_name, target_method_name, new_method, visited_instances=None):
    """
    This function modifies a method of an instance of a model class.
    It was partly generated with ChatGPT.
    It replaces the target method with the new method.
    Currently, we only use this function to modify the attention method of a model; it has not been tested for anything else.

    instance:
        instance of a model to modify.
    target_class_name:
        name of the attention class to modify. E.g. 'LlamaAttention', 'GPTNeoXAttention', etc.
    new_method: new method to replace the original method. E.g. 'self_extend_forward'.
        It should include a parameter 'self' to be bound to the instance.
    """
    target_found = False
    if visited_instances is None:
        visited_instances = set()
    # Unique identifier for the instance (using id() since an object's id is unique)
    instance_id = id(instance)
    if instance_id in visited_instances:
        target_found = False
        return target_found
    # Add the instance to the already-visited set
    visited_instances.add(instance_id)

    # Check if this instance is of the target class
    if instance.__class__.__name__ == target_class_name:
        bond_method = MethodType(new_method, instance)
        setattr(instance, target_method_name, bond_method)
        target_found = True
        return target_found
    elif hasattr(instance, '__dict__'):
        for attr_name, attr_value in instance.__dict__.items():
            if isinstance(attr_value, object) and not isinstance(attr_value, (list, tuple, dict, set)):
                _found = modify_method_of_instance(attr_value, target_class_name, target_method_name, new_method, visited_instances)
                if _found:
                    target_found = True
            elif isinstance(attr_value, (list, tuple)):
                for item in attr_value:
                    if isinstance(item, object):
                        _found = modify_method_of_instance(item, target_class_name, target_method_name, new_method, visited_instances)
                        if _found:
                            target_found = True
            # If the attribute value is a dictionary, iterate over its values and recurse
            # E.g., for a ModuleList, its modules are stored in a dictionary: ._modules
            elif isinstance(attr_value, dict):
                for key, value in attr_value.items():
                    if isinstance(value, object):
                        _found = modify_method_of_instance(value, target_class_name, target_method_name, new_method, visited_instances)
                        if _found:
                            target_found = True
            # If the attribute value is a set, iterate and recurse
            elif isinstance(attr_value, set):
                for item in attr_value:
                    if isinstance(item, object):
                        _found = modify_method_of_instance(item, target_class_name, target_method_name, new_method, visited_instances)
                        if _found:
                            target_found = True

    return target_found


def apply(loaded_model, group_size, window_size, enable_flash_attention=False, scale_base=-1, flash_attention_impl="triton"):
    '''
    loaded_model:
        model to apply the self-attention extension to.
    group_size:
        group size for the self-attention extension.
    window_size:
        window size for the self-attention extension.
    scale_base:
        base for the scale, equal to the pretraining length.
        e.g. 4096 for Llama, 8192 for Gemma

        Two recommended scale factors:
            yarn: https://arxiv.org/abs/2309.00071
            log: https://arxiv.org/abs/2202.12172 ; https://kexue.fm/archives/8823
        This is helpful while retrieving a long sequence (e.g. a long passkey).
        But on real-world data, the impact is minor (e.g. on LongBench, LEval).

        The reported results in our paper do not use this scale except for long passkey retrieval.
    '''
    arch_name = loaded_model.__class__.__name__
    if 'Llama' in arch_name:
        if enable_flash_attention:
            if flash_attention_impl == "flash_attn":
                self_extend_attention_forward = partial(SE.Llama.flash_self_extend_forward,
                                                        group_size_1=group_size,
                                                        group_size_2=window_size,
                                                        scale_base=scale_base)
                modifed_1 = modify_method_of_instance(loaded_model, "LlamaFlashAttention2", "_flash_attention_forward", SE.selfextend_flash_attn.flash_attention2_forward_with_window_size)
                modifed_2 = modify_method_of_instance(loaded_model, "LlamaFlashAttention2", "forward", self_extend_attention_forward)
                print("Using flash_attn flash self_extend!!")
                if (not modifed_1) or (not modifed_2):
                    raise Exception(f"Failed to modify the attention method of {arch_name}")

            elif flash_attention_impl == "triton":
                self_extend_attention_forward = partial(SE.Llama.flash_self_extend_forward_triton,
                                                        group_size_1=group_size,
                                                        group_size_2=window_size,
                                                        scale_base=scale_base)
                modifed = modify_method_of_instance(loaded_model, "LlamaFlashAttention2", "forward", self_extend_attention_forward)
                print("Using triton flash self_extend!!")
                if (not modifed):
                    raise Exception(f"Failed to modify the attention method of {arch_name}")
            else:
                raise Exception(f"Need to set the flash_attention_impl to 'flash_attn' or 'triton'.")


        else:
            self_extend_attention_forward = partial(SE.Llama.self_extend_forward,
                                                    group_size_1=group_size,
                                                    group_size_2=window_size,
                                                    scale_base=scale_base)
            # Note: in transformers 4.36 the default attention class is LlamaSdpaAttention; before 4.36 and in 4.38 it is LlamaAttention.
            # print("loaded_model", loaded_model)
            modifed_2 = modify_method_of_instance(loaded_model, "LlamaAttention", "forward", self_extend_attention_forward)
            if not modifed_2:
                raise Exception(f"Failed to modify the attention method of {arch_name}")
    elif 'Mistral' in arch_name:
        # Mistral shares the same architecture with Llama, so the implementations should be exchangeable.
        if enable_flash_attention:
            self_extend_attention_forward = partial(SE.Mistral.flash_self_extend_forward,
                                                    group_size_1=group_size,
                                                    group_size_2=window_size,
                                                    scale_base=scale_base)
            modifed_1 = modify_method_of_instance(loaded_model, "MistralFlashAttention2", "_flash_attention_forward", SE.selfextend_flash_attn.flash_attention2_forward_with_window_size)
            modifed_2 = modify_method_of_instance(loaded_model, "MistralFlashAttention2", "forward", self_extend_attention_forward)
            if (not modifed_1) or (not modifed_2):
                raise Exception(f"Failed to modify the attention method of {arch_name}")
        else:
            self_extend_attention_forward = partial(SE.Mistral.self_extend_forward,
                                                    group_size_1=group_size,
                                                    group_size_2=window_size,
                                                    scale_base=scale_base)
            modifed_2 = modify_method_of_instance(loaded_model, "MistralAttention", "forward", self_extend_attention_forward)
            if not modifed_2:
                raise Exception(f"Failed to modify the attention method of {arch_name}")
    elif 'Gemma' in arch_name:
        if enable_flash_attention:
            self_extend_attention_forward = partial(SE.Gemma.flash_self_extend_forward,
                                                    group_size_1=group_size,
                                                    group_size_2=window_size,
                                                    scale_base=scale_base)
            modifed_1 = modify_method_of_instance(loaded_model, "GemmaFlashAttention2", "_flash_attention_forward", SE.selfextend_flash_attn.flash_attention2_forward_with_window_size)
            modifed_2 = modify_method_of_instance(loaded_model, "GemmaFlashAttention2", "forward", self_extend_attention_forward)
            if (not modifed_1) or (not modifed_2):
                raise Exception(f"Failed to modify the attention method of {arch_name}")
        else:
            self_extend_attention_forward = partial(SE.Gemma.self_extend_forward,
                                                    group_size_1=group_size,
                                                    group_size_2=window_size,
                                                    scale_base=scale_base)
            modifed_2 = modify_method_of_instance(loaded_model, "GemmaAttention", "forward", self_extend_attention_forward)
            if not modifed_2:
                raise Exception(f"Failed to modify the attention method of {arch_name}")
    elif 'Qwen2' in arch_name:
        if enable_flash_attention:
            self_extend_attention_forward = partial(SE.Qwen2.flash_self_extend_forward,
                                                    group_size_1=group_size,
                                                    group_size_2=window_size,
                                                    scale_base=scale_base)
            modifed_1 = modify_method_of_instance(loaded_model, "Qwen2FlashAttention2", "_flash_attention_forward", SE.selfextend_flash_attn.flash_attention2_forward_with_window_size)
            modifed_2 = modify_method_of_instance(loaded_model, "Qwen2FlashAttention2", "forward", self_extend_attention_forward)
            if (not modifed_1) or (not modifed_2):
                raise Exception(f"Failed to modify the attention method of {arch_name}")
        else:
            self_extend_attention_forward = partial(SE.Qwen2.self_extend_forward,
                                                    group_size_1=group_size,
                                                    group_size_2=window_size,
                                                    scale_base=scale_base)
            modifed_2 = modify_method_of_instance(loaded_model, "Qwen2Attention", "forward", self_extend_attention_forward)
            if not modifed_2:
                raise Exception(f"Failed to modify the attention method of {arch_name}")
    elif 'Phi' in arch_name:
        if enable_flash_attention:
            self_extend_attention_forward = partial(SE.Phi.flash_self_extend_forward,
                                                    group_size_1=group_size,
                                                    group_size_2=window_size,
                                                    scale_base=scale_base)
            modifed_1 = modify_method_of_instance(loaded_model, "PhiFlashAttention2", "_flash_attention_forward", SE.selfextend_flash_attn.flash_attention2_forward_with_window_size)
            modifed_2 = modify_method_of_instance(loaded_model, "PhiFlashAttention2", "forward", self_extend_attention_forward)
            if (not modifed_1) or (not modifed_2):
                raise Exception(f"Failed to modify the attention method of {arch_name}")
        else:
            self_extend_attention_forward = partial(SE.Phi.self_extend_forward,
                                                    group_size_1=group_size,
                                                    group_size_2=window_size,
                                                    scale_base=scale_base)
            modifed_2 = modify_method_of_instance(loaded_model, "PhiAttention", "forward", self_extend_attention_forward)
            if not modifed_2:
                raise Exception(f"Failed to modify the attention method of {arch_name}")
    else:
        raise NotImplementedError
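For orientation, a minimal usage sketch of the module above. This is not part of the commit, and the checkpoint name is only a placeholder; the actual call made by this Space is in `app.py` below.

```python
# Hedged sketch: patch a Llama-family model with SelfExtend (placeholder model id).
import torch
from transformers import AutoModelForCausalLM

import SelfExtend

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",  # placeholder checkpoint; any Llama-architecture model
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# Replace the attention forward in place: positions within `window_size` keep their
# exact RoPE positions, positions farther away are floor-divided into groups of `group_size`.
SelfExtend.apply(model, group_size=16, window_size=512)
```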
app.py
ADDED
@@ -0,0 +1,388 @@
# pylint: skip-file

import subprocess

subprocess.run(
    f"pip install flash-attn --no-build-isolation",
    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)

import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
import SelfExtend
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer


MAX_MAX_NEW_TOKENS = 4096
DEFAULT_MAX_NEW_TOKENS = 1536
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "123392"))

DESCRIPTION = """\
# Playground with Ghost 8B Beta (p)

**Ghost 8B Beta** is a large language model developed with goals that include excellent multilingual support, superior knowledge capabilities, and cost-effectiveness. The model comes in two context length versions, 8k and 128k, along with multilingual function tools support by default.

The languages supported are 🇺🇸 English, 🇫🇷 French, 🇮🇹 Italian, 🇪🇸 Spanish, 🇵🇹 Portuguese, 🇩🇪 German, 🇻🇳 Vietnamese, 🇰🇷 Korean and 🇨🇳 Chinese.

📋 Note: current model version is "disl-0x5" (10 Jul 2024), context length 128k (123392 tokens) and current status is "moderating / previewing". For detailed information about the model, see [here](https://ghost-x.org/docs/models/ghost-8b-beta/). Try to experience it the way you want!
"""


PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
    <h1 style="font-size: 26px; margin-bottom: 2px; opacity: 0.20;">👻 Ghost 8B Beta</h1>
    <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.10;">Ask and share whatever you want ~</p>
</div>
"""

LICENSE = """
<p/>

---
Ghost 8B Beta may give inaccurate information, including information about people, so please verify Ghost 8B Beta's answers. [Ghost 8B Beta](https://ghost-x.org/docs/models/ghost-8b-beta/) by [Ghost X](https://ghost-x.org).
"""

EXAMPLES = [
    [
        "What is the significance of the Higgs boson in the Standard Model of particle physics?"
    ],
    [
        "Qu'est-ce que l'effet fondateur et comment influence-t-il la diversité génétique d'une population?"
    ],
    ["Qual è il principio di Le Chatelier e come si applica agli equilibri chimici?"],
    [
        "¿Qué es una supernova y cuál es su importancia en la formación de elementos pesados en el universo?"
    ],
    [
        "Qual é a definição formal de uma integral de linha e como é utilizada em física?"
    ],
    [
        "Was versteht man unter dem Moho-Diskontinuität und welche Bedeutung hat sie für das Verständnis der Erdkruste?"
    ],
    [
        "Hiện tượng nhà kính là gì và nó ảnh hưởng như thế nào đến biến đổi khí hậu toàn cầu?"
    ],
    [
        "알고리즘의 시간 복잡도가 중요한 이유는 무엇이며, 시간 복잡도를 어떻게 분석하나요?"
    ],
    ["什么是CRISPR-Cas9基因编辑技术,它在现代生物学研究中的作用是什么?"],
    [
        "Create a Python function that takes a list of integers and returns the list sorted in ascending order without using the built-in sort or sorted functions."
    ],
    [
        "Écrivez une fonction en C++ qui trouve le plus long sous-tableau contigu avec une somme égale à zéro."
    ],
    [
        "Scrivi una funzione in Java che calcola il fattoriale di un numero utilizzando la ricorsione."
    ],
    [
        "Desarrolla una función en JavaScript que determine si una cadena de texto es un palíndromo, ignorando espacios y signos de puntuación."
    ],
    ["Implemente uma função em C# que verifique se uma matriz quadrada é simétrica."],
    [
        "Schreiben Sie eine Funktion in Swift, die eine gegebene Zeichenfolge in umgekehrter Reihenfolge zurückgibt, ohne integrierte Funktionen zu verwenden."
    ],
    [
        "Viết một hàm trong PHP để tìm tất cả các số nguyên tố trong một khoảng cho trước."
    ],
    [
        "파이썬을 사용하여 주어진 이진 트리가 이진 탐색 트리인지 확인하는 함수를 작성하십시오."
    ],
    [
        "用 Go 语言编写一个函数,计算给定字符串中每个字符出现的次数,并返回一个包含字符及其出现次数的映射。"
    ],
    [
        "Can you help me design a detailed project plan for developing a machine learning model for predicting stock prices?"
    ],
    [
        "Pouvez-vous m'aider à organiser un emploi du temps hebdomadaire pour maximiser la productivité de mon équipe de développement logiciel?"
    ],
    [
        "Puoi aiutarmi a creare un piano di sviluppo per un'applicazione mobile che gestisce le prenotazioni di ristoranti?"
    ],
    [
        "¿Podrías ayudarme a elaborar un plan detallado para la implementación de un sistema de gestión de contenido (CMS) en una empresa mediana?"
    ],
    [
        "Você pode me ajudar a planejar uma estratégia de desenvolvimento para um sistema de comércio eletrônico escalável?"
    ],
    [
        "Können Sie mir helfen, einen detaillierten Zeitplan für die Implementierung eines neuen ERP-Systems in unserem Unternehmen zu erstellen?"
    ],
    [
        "Bạn có thể giúp tôi xây dựng một kế hoạch phát triển chi tiết cho dự án xây dựng hệ thống quản lý chuỗi cung ứng không?"
    ],
    [
        "신경망 기반 이미지 인식 모델 개발을 위한 세부 프로젝트 계획을 세우는 데 도움을 줄 수 있나요?"
    ],
    ["你能帮我制定一个详细的开发计划,用于创建一个基于区块链的分布式账本系统吗?"],
    [
        "Prove that the sum of the squares of any two sides of a right triangle is equal to the square of the hypotenuse."
    ],
    [
        "Calculez la force gravitationnelle entre deux masses de 10 kg chacune séparées par une distance de 1 mètre."
    ],
    [
        "Determina la formula molecolare di un composto che contiene il 40% di carbonio, il 6.67% di idrogeno e il 53.33% di ossigeno in massa."
    ],
    [
        "Explica la teoría del ciclo económico de Schumpeter y cómo se aplica a la economía moderna."
    ],
    [
        "Calcule a energia potencial gravitacional de um objeto de 5 kg a uma altura de 10 metros acima do solo (g = 9,8 m/s²)."
    ],
    [
        "Beweisen Sie, dass jede Primzahl der Form 4k+1 als Summe zweier Quadrate geschrieben werden kann."
    ],
    [
        "Tính nồng độ mol của dung dịch H₂SO₄ khi hoà tan 98 gam H₂SO₄ vào nước để được 1 lít dung dịch."
    ],
    ["케인스 경제학의 핵심 개념과 그것이 현대 경제 정책에 미치는 영향을 설명하십시오."],
    ["计算一个质量为2 kg的物体在3米高处的重力势能(g = 9.8 m/s²)。"],
    [
        'Identify the author of a novel that features a dystopian society where "Big Brother" watches over its citizens and the protagonist works for the Ministry of Truth.'
    ],
    [
        "Quel est le seul mammifère capable de voler activement, souvent associé à la nuit et capable d'écholocalisation?"
    ],
    [
        "Qual è l'opera letteraria italiana che narra il viaggio immaginario di un poeta attraverso Inferno, Purgatorio e Paradiso, guidato da Virgilio e Beatrice?"
    ],
    [
        "¿Qué insecto es conocido por su organización social compleja, su capacidad para producir miel y su comunicación mediante la danza?"
    ],
    [
        "Qual é o fenômeno atmosférico que ocorre quando uma massa de ar quente se encontra com uma massa de ar frio, resultando em uma violenta tempestade giratória?"
    ],
    [
        "Welches literarische Werk beschreibt die Geschichte eines jungen Mädchens, das durch einen Kaninchenbau in eine fantastische Welt voller skurriler Charaktere fällt?"
    ],
    [
        "Động vật nào có thể tái sinh toàn bộ cơ thể từ một mảnh nhỏ của chính nó, thường sống dưới nước và có thể có nhiều xúc tu?"
    ],
    [
        "어떤 자연 현상은 태양빛이 대기 중의 물방울에 반사되고 굴절되어 발생하며, 하늘에 나타나는 여러 색깔의 아치 형태를 띠나요?"
    ],
    ["这部文学作品讲述了一位绅士和他的侍从的冒险故事,他们在"],
    [
        "Can you derive the Euler-Lagrange equation from the principle of stationary action in classical mechanics?"
    ],
    [
        "Expliquez la notion de « différence ontologique » chez Martin Heidegger et son importance pour la phénoménologie."
    ],
    [
        "Qual è il significato simbolico del colore blu nei dipinti di Giotto di Bondone durante il Rinascimento?"
    ],
    [
        "¿Cómo afecta el cambio de código a la estructura gramatical en comunidades bilingües de habla español-inglés?"
    ],
    [
        "Qual é o impacto da política monetária não convencional no controle da inflação durante uma crise econômica?"
    ],
    [
        "Erklären Sie den Unterschied zwischen deterministischen und nicht-deterministischen endlichen Automaten und ihre Anwendungsbereiche."
    ],
    [
        "Giải thích cơ chế của quá trình phiên mã ngược (reverse transcription) và tầm quan trọng của nó trong nghiên cứu HIV/AIDS."
    ],
    ["조선시대 성리학이 한국 사회와 문화에 미친 영향을 설명하세요."],
    ["如何解释量子纠缠现象,以及它在量子计算中的潜在应用?"],
    [
        "How can you design a daily schedule that maximizes productivity for a remote worker who has multiple meetings and project deadlines?"
    ],
    [
        "Quels sont les meilleures stratégies pour gérer les conflits au sein d'une équipe multiculturelle travaillant sur un projet commun?"
    ],
    [
        "Quali sono i migliori consigli per mantenere un equilibrio tra vita professionale e vita privata in un ambiente lavorativo stressante?"
    ],
    [
        "¿Cómo se puede elaborar un plan financiero personal efectivo que incluya ahorro para la jubilación, inversión y manejo de deudas?"
    ],
    [
        "Quais são as melhores práticas para implementar metodologias ágeis em uma equipe de desenvolvimento de software?"
    ],
    [
        "Welche Strategien können verwendet werden, um ein starkes berufliches Netzwerk aufzubauen und zu pflegen, insbesondere in der Tech-Branche?"
    ],
    [
        "Những bước nào cần thiết để xây dựng một lộ trình phát triển sự nghiệp bền vững trong lĩnh vực công nghệ thông tin?"
    ],
    ["프로젝트의 범위 변동을 효과적으로 관리하기 위한 최고의 방법은 무엇인가요?"],
    ["在快速变化的职场环境中,如何有效地实现工作与生活的平衡?"],
    [
        "Write an argumentative essay discussing the pros and cons of artificial intelligence in the workplace, including potential ethical concerns."
    ],
    [
        "Analysez les impacts sociaux et économiques de la digitalisation sur les petites entreprises en France."
    ],
    [
        "Scrivi un'email formale al direttore di una rivista per proporre un articolo sulla sostenibilità ambientale nelle città italiane."
    ],
    [
        "Elabora un informe detallado sobre los efectos del cambio climático en la biodiversidad de la región amazónica."
    ],
    [
        "Analise criticamente os principais pontos abordados no relatório anual do Banco Mundial sobre a pobreza global."
    ],
    [
        "Erstellen Sie eine technische Dokumentation für die Implementierung eines neuen Software-Features in einer bestehenden Anwendung."
    ],
    [
        "Viết một bài luận phân tích về tác động của cuộc cách mạng công nghiệp 4.0 đối với thị trường lao động Việt Nam."
    ],
    [
        "인공지능의 윤리적 문제에 대한 연구 논문을 작성하고, 다양한 사례를 통해 그 영향을 분석하세요."
    ],
    ["分析鲁迅的小说《阿Q正传》中反映的中国社会问题和作者的批判态度。"],
]

if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"


if torch.cuda.is_available():
    model_id = "lamhieu/ghost-8b-beta-disl-0x5-8k"
    model_tk = os.getenv("HF_TOKEN", None)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        trust_remote_code=True,
        token=model_tk,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=True,
        token=model_tk,
    )
    SelfExtend.apply(
        model,
        group_size=16,
        window_size=512,
        enable_flash_attention=True,
        flash_attention_impl="flash_attn",
    )
    model.generation_config.max_length = 123392


@spaces.GPU(duration=120)
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = 1536,
    temperature: float = 0.4,
    top_p: float = 0.95,
    top_k: int = 50,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend(
            [
                {"role": "user", "content": user},
                {"role": "assistant", "content": assistant},
            ]
        )
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    )
    input_ids = input_ids.to(model.device)
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(
            f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens."
        )

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


chatbot = gr.Chatbot(height=500, placeholder=PLACEHOLDER, label="Ghost 8B Beta")

chat_interface = gr.ChatInterface(
    fn=generate,
    chatbot=chatbot,
    fill_height=True,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=2.0,
            step=0.1,
            value=0.4,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=100,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.0,
        ),
    ],
    stop_btn="Stop",
    cache_examples=False,
    examples=EXAMPLES,
    examples_per_page=9,
)

with gr.Blocks(fill_height=True, css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()
    gr.Markdown(LICENSE)

if __name__ == "__main__":
    demo.queue(max_size=20).launch(share=True)
    # demo.launch(share=True)
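A note on the `MAX_INPUT_TOKEN_LENGTH` default of 123392 used above: it is consistent with SelfExtend's context-extension formula for the `group_size=16`, `window_size=512` settings, assuming the base model was pretrained with an 8192-token context (an assumption, not stated in this commit):

```python
# Sanity check of the 123392 figure, assuming an 8192-token pretraining context.
# SelfExtend extends the usable context to roughly
# (pretraining_length - window_size) * group_size + window_size.
pretraining_length = 8192
group_size = 16
window_size = 512
print((pretraining_length - window_size) * group_size + window_size)  # -> 123392
```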
requirements.txt
ADDED
@@ -0,0 +1,8 @@
accelerate==0.30.1
bitsandbytes==0.43.1
gradio==4.37.2
scipy==1.13.0
sentencepiece==0.2.0
spaces==0.28.3
torch==2.0.0
transformers==4.41.0
self_extend_patch/Llama.py
ADDED
@@ -0,0 +1,482 @@
import torch
import numpy as np
import torch.nn as nn
import math
import warnings  # needed for the warnings.warn calls below
from typing import Optional, Tuple
import torch.nn.functional as F
from transformers.cache_utils import Cache
from flash_attn import flash_attn_func, flash_attn_varlen_func
from .selfextend_flash_attn import self_extend_flash_forward


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin) if not q is None else None
    k_embed = (k * cos) + (rotate_half(k) * sin) if not k is None else None
    return q_embed, k_embed


def self_extend_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    group_size_1: Optional[float] = 8,
    group_size_2: Optional[float] = 1024,
    scale_base: Optional[int] = -1,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    if "padding_mask" in kwargs:
        warnings.warn(
            "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
        )

    bsz, q_len, _ = hidden_states.size()

    if self.config.pretraining_tp > 1:
        key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
        query_slices = self.q_proj.weight.split(
            (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
        )
        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

        query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
        query_states = torch.cat(query_states, dim=-1)

        key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
        key_states = torch.cat(key_states, dim=-1)

        value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
        value_states = torch.cat(value_states, dim=-1)

    else:
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    if scale_base > 0:
        scaled_query = query_states * ((position_ids + 1)[:, None, :, None].log() / np.log(scale_base)).clip(1).to(query_states.dtype)  # log scale
        # scaled_query = query_states * (((0.1*(((position_ids+1)[:, None, :, None]/scale_base).log())+1)**2).clip(1)).to(query_states.dtype)  # Yarn scale
    else:
        scaled_query = query_states

    past_key_value = getattr(self, "past_key_value", past_key_value)
    if past_key_value is not None:
        # sin and cos are specific to RoPE models; position_ids needed for the static cache
        cache_kwargs = {"cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
    kv_seq_len = key_states.shape[-2]

    query_position = position_ids
    key_position = position_ids if q_len != 1 else torch.arange(kv_seq_len, dtype=position_ids.dtype).to(query_position.device).view(1, kv_seq_len)  # only consider bsz=1 for now.

    neighbor_q_cos, neighbor_q_sin = self.rotary_emb(value_states, query_position)  # , seq_len=None)
    neighbor_k_cos, neighbor_k_sin = self.rotary_emb(value_states, key_position)  # , seq_len=None)

    _re_group_size_2 = 0 if query_position.max() < group_size_2 else group_size_2  # in case the smallest grouped q position, g2 - g2//g1, exceeds the max position
    group_query_position = query_position // group_size_1 + _re_group_size_2 - _re_group_size_2 // group_size_1
    group_key_position = key_position // group_size_1

    group_q_cos, group_q_sin = self.rotary_emb(value_states, group_query_position)  # , seq_len=None)
    group_k_cos, group_k_sin = self.rotary_emb(value_states, group_key_position)  # , seq_len=None)

    neighbor_query_states, _ = apply_rotary_pos_emb(scaled_query, None, neighbor_q_cos, neighbor_q_sin, None)
    _, neighbor_key_states = apply_rotary_pos_emb(None, key_states, neighbor_k_cos, neighbor_k_sin, None)
    group_query_states, _ = apply_rotary_pos_emb(scaled_query, None, group_q_cos, group_q_sin, None)
    _, group_key_states = apply_rotary_pos_emb(None, key_states, group_k_cos, group_k_sin, None)

    neighbor_key_states = repeat_kv(neighbor_key_states, self.num_key_value_groups)
    group_key_states = repeat_kv(group_key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    neighbor_attn_weights = torch.matmul(neighbor_query_states, neighbor_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
    group_attn_weights = torch.matmul(group_query_states, group_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

    if attention_mask is not None:  # no matter the length, we just slice it
        if cache_position is not None:
            causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
        else:
            causal_mask = attention_mask
        group_attn_weights = group_attn_weights + causal_mask
        neighbor_attn_weights = neighbor_attn_weights + causal_mask

    if q_len == 1:
        neighbor_attention_mask = torch.zeros((q_len, kv_seq_len), device=neighbor_attn_weights.device)
        neighbor_attention_mask[:, -group_size_2:] = 1
    elif q_len == kv_seq_len:
        neighbor_attention_mask = torch.ones((q_len, kv_seq_len), device=neighbor_attn_weights.device)
        neighbor_attention_mask = torch.tril(neighbor_attention_mask)
        if q_len - group_size_2 > 0:
            group_attention_mask = torch.tril(torch.ones((q_len - group_size_2, kv_seq_len - group_size_2), device=group_attn_weights.device))
            neighbor_attention_mask[group_size_2:, :-group_size_2] -= group_attention_mask
    else:
        raise ValueError("q_len should be 1 or seq_len.")

    neighbor_attention_mask = neighbor_attention_mask.bool()
    attn_weights = torch.where(neighbor_attention_mask, neighbor_attn_weights, group_attn_weights)

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
    attn_output = torch.matmul(attn_weights, value_states)

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    attn_output = attn_output.transpose(1, 2).contiguous()

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    if self.config.pretraining_tp > 1:
        attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
        o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
        attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
    else:
        attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value

def flash_self_extend_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    group_size_1: Optional[float] = 8,
    group_size_2: Optional[float] = 1024,
    scale_base: Optional[int] = -1,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """
    Requires transformers >= 4.38.2 and flash_attn >= 2.5.6.
    a. Only supports a causal mask.
    b. Does not support attention_mask.
    c. Never tested with batch size > 1.
    d. Only supports q_len = 1 or q_len = seq_len.
    """
    if "padding_mask" in kwargs:
        warnings.warn(
            "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
        )
        attention_mask = kwargs.pop("padding_mask")

    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    if scale_base > 0:
        scaled_query = query_states * ((position_ids + 1)[:, None, :, None].log() / np.log(scale_base)).clip(1).to(query_states.dtype)  # log scale
        # scaled_query = query_states * (((0.1*(((position_ids+1)[:, None, :, None]/scale_base).log())+1)**2).clip(1)).to(query_states.dtype)  # Yarn scale
    else:
        scaled_query = query_states

    past_key_value = getattr(self, "past_key_value", past_key_value)
    if past_key_value is not None:
        # sin and cos are specific to RoPE models; position_ids needed for the static cache
        cache_kwargs = {"cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
    kv_seq_len = key_states.shape[-2]

    query_position = position_ids
    # only consider bsz=1 for now.
    key_position = position_ids if q_len != 1 else torch.arange(kv_seq_len, dtype=position_ids.dtype).to(query_position.device).view(1, kv_seq_len)
    attn_dropout = self.config.attention_dropout if self.training else 0.0
    if q_len == 1:
        # We implement the case q_len == 1 separately, by manipulating positions,
        # because our flash implementation does not work for the decoding stage at release time.

        neighbor_key_position = position_ids[:, -1] - key_position
        _re_group_size_2 = 0 if position_ids.max() < group_size_2 else group_size_2
        group_key_position = position_ids[:, -1] // group_size_1 - key_position // group_size_1 + (_re_group_size_2 - _re_group_size_2 // group_size_1)
        decode_key_position = torch.cat([group_key_position[:, :-group_size_2], neighbor_key_position[:, -group_size_2:]], dim=1)

        decode_k_cos, decode_k_sin = self.rotary_emb(value_states, decode_key_position)  # , seq_len=None)
        # import pdb; pdb.set_trace()
        # neighbor_query_states, _ = apply_rotary_pos_emb(scaled_query, None, cos, sin, query_position_ids)
        decode_query_states = scaled_query.transpose(1, 2).contiguous()  # position 0: cos 0 = 1, sin 0 = 0
        _, decode_key_states = apply_rotary_pos_emb(None, key_states, decode_k_cos, -decode_k_sin, decode_key_position)

        decode_key_states = repeat_kv(decode_key_states, self.num_key_value_groups).transpose(1, 2).contiguous()
        decode_value_states = repeat_kv(value_states, self.num_key_value_groups).transpose(1, 2).contiguous()

        attn_output = flash_attn_func(decode_query_states,
                                      decode_key_states,
                                      decode_value_states,
                                      attn_dropout,
                                      softmax_scale=None,
                                      causal=True)
    elif q_len == kv_seq_len:
        # set correct position_ids & apply RoPE.
        neighbor_q_cos, neighbor_q_sin = self.rotary_emb(value_states, query_position)  # , seq_len=None)
        neighbor_k_cos, neighbor_k_sin = self.rotary_emb(value_states, key_position)  # , seq_len=None)

        _re_group_size_2 = 0 if query_position.max() < group_size_2 else group_size_2  # in case the smallest grouped q position, g2 - g2//g1, exceeds the max position
        group_query_position = query_position // group_size_1 + _re_group_size_2 - _re_group_size_2 // group_size_1
        group_key_position = key_position // group_size_1

        group_q_cos, group_q_sin = self.rotary_emb(value_states, group_query_position)  # , seq_len=None)
        group_k_cos, group_k_sin = self.rotary_emb(value_states, group_key_position)  # , seq_len=None)

        neighbor_query_states, _ = apply_rotary_pos_emb(scaled_query, None, neighbor_q_cos, neighbor_q_sin, None)
        _, neighbor_key_states = apply_rotary_pos_emb(None, key_states, neighbor_k_cos, neighbor_k_sin, None)
        group_query_states, _ = apply_rotary_pos_emb(scaled_query, None, group_q_cos, group_q_sin, None)
        _, group_key_states = apply_rotary_pos_emb(None, key_states, group_k_cos, group_k_sin, None)

        neighbor_query_states = neighbor_query_states.transpose(1, 2).contiguous()
        neighbor_key_states = repeat_kv(neighbor_key_states, self.num_key_value_groups).transpose(1, 2).contiguous()
        group_query_states = group_query_states.transpose(1, 2).contiguous()
        group_key_states = repeat_kv(group_key_states, self.num_key_value_groups).transpose(1, 2).contiguous()
        value_states = repeat_kv(value_states, self.num_key_value_groups).transpose(1, 2).contiguous()

        attn_output = self_extend_flash_forward(self,
                                                query_position,
                                                group_size_2,
                                                neighbor_query_states,
                                                neighbor_key_states,
                                                group_query_states,
                                                group_key_states,
                                                value_states,
                                                attention_mask,
                                                bsz,
                                                q_len,
                                                kv_seq_len,
                                                attn_dropout,
                                                )
    else:
        raise ValueError("q_len should be 1 or seq_len.")

    attn_output = attn_output.contiguous()
    attn_output = attn_output.view(bsz, q_len, -1).contiguous()
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None
    return attn_output, attn_weights, past_key_value



def lm_infinite_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    group_size_1: Optional[float] = 8,
    group_size_2: Optional[float] = 1024,
    initial_num: Optional[int] = 1,
    scale_base: Optional[int] = -1,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    if "padding_mask" in kwargs:
        warnings.warn(
            "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
        )

    bsz, q_len, _ = hidden_states.size()

    if self.config.pretraining_tp > 1:
        key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
        query_slices = self.q_proj.weight.split(
            (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
        )
        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

        query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
        query_states = torch.cat(query_states, dim=-1)

        key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
        key_states = torch.cat(key_states, dim=-1)

+
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
|
| 374 |
+
value_states = torch.cat(value_states, dim=-1)
|
| 375 |
+
|
| 376 |
+
else:
|
| 377 |
+
query_states = self.q_proj(hidden_states)
|
| 378 |
+
key_states = self.k_proj(hidden_states)
|
| 379 |
+
value_states = self.v_proj(hidden_states)
|
| 380 |
+
|
| 381 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 382 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 383 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 384 |
+
|
| 385 |
+
if scale_base > 0:
|
| 386 |
+
scaled_query = query_states * ((position_ids + 1)[:, None, :, None].log() / np.log(scale_base)).clip(1).to(query_states.dtype) # log scale
|
| 387 |
+
#scaled_query = query_states * (((0.1*(((position_ids+1)[:, None, :, None]/scale_base).log())+1)**2).clip(1)).to(query_states.dtype) # Yarn scale
|
| 388 |
+
else:
|
| 389 |
+
scaled_query = query_states
|
| 390 |
+
|
| 391 |
+
past_key_value = getattr(self, "past_key_value", past_key_value)
|
| 392 |
+
if past_key_value is not None:
|
| 393 |
+
# sin and cos are specific to RoPE models; position_ids needed for the static cache
|
| 394 |
+
cache_kwargs = {"cache_position": cache_position}
|
| 395 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 396 |
+
kv_seq_len = key_states.shape[-2]
|
| 397 |
+
|
| 398 |
+
query_position = position_ids
|
| 399 |
+
key_position = position_ids if q_len != 1 else torch.arange(kv_seq_len, dtype=position_ids.dtype).to(query_position.device).view(1, kv_seq_len) # only consider bsz=1 for now.
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
neighbor_q_cos, neighbor_q_sin = self.rotary_emb(value_states, query_position)#, seq_len=None)
|
| 404 |
+
neighbor_k_cos, neighbor_k_sin = self.rotary_emb(value_states, key_position)#, seq_len=None)
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
_re_group_size_2 = 0 if query_position.max() < group_size_2 else group_size_2 # in case that, the smallest q position, g2-g2//g1 exceed the max position
|
| 408 |
+
group_query_position = query_position // group_size_1 + _re_group_size_2 - _re_group_size_2 / group_size_1
|
| 409 |
+
group_key_position = key_position // group_size_1
|
| 410 |
+
|
| 411 |
+
group_q_cos, group_q_sin = self.rotary_emb(value_states, group_query_position)#, seq_len=None)
|
| 412 |
+
group_k_cos, group_k_sin = self.rotary_emb(value_states, group_key_position)#, seq_len=None)
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
neighbor_query_states, _ = apply_rotary_pos_emb(scaled_query, None, neighbor_q_cos, neighbor_q_sin, None)
|
| 417 |
+
_, neighbor_key_states = apply_rotary_pos_emb(None, key_states, neighbor_k_cos, neighbor_k_sin, None)
|
| 418 |
+
group_query_states, _ = apply_rotary_pos_emb(scaled_query, None, group_q_cos, group_q_sin, None)
|
| 419 |
+
_, group_key_states = apply_rotary_pos_emb(None, key_states, group_k_cos, group_k_sin, None)
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
neighbor_key_states = repeat_kv(neighbor_key_states, self.num_key_value_groups)
|
| 424 |
+
group_key_states = repeat_kv(group_key_states, self.num_key_value_groups)
|
| 425 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
neighbor_attn_weights = torch.matmul(neighbor_query_states, neighbor_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
| 430 |
+
group_attn_weights = torch.matmul(group_query_states, group_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
if attention_mask is not None: # no matter the length, we just slice it
|
| 434 |
+
if cache_position is not None:
|
| 435 |
+
causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
|
| 436 |
+
else:
|
| 437 |
+
causal_mask = attention_mask
|
| 438 |
+
group_attn_weights = group_attn_weights + causal_mask
|
| 439 |
+
neighbor_attn_weights = neighbor_attn_weights + causal_mask
|
| 440 |
+
|
| 441 |
+
if q_len == 1:
|
| 442 |
+
neighbor_attention_mask = torch.zeros((q_len, kv_seq_len), device=neighbor_attn_weights.device)
|
| 443 |
+
neighbor_attention_mask[:, -group_size_2:] = 1
|
| 444 |
+
elif q_len == kv_seq_len:
|
| 445 |
+
neighbor_attention_mask = torch.ones((q_len, kv_seq_len), device=neighbor_attn_weights.device)
|
| 446 |
+
neighbor_attention_mask = torch.tril(neighbor_attention_mask)
|
| 447 |
+
if q_len-group_size_2 > 0:
|
| 448 |
+
group_attention_mask = torch.tril(torch.ones((q_len-group_size_2, kv_seq_len-group_size_2), device=group_attn_weights.device))
|
| 449 |
+
neighbor_attention_mask[group_size_2:, :-group_size_2] -= group_attention_mask
|
| 450 |
+
else:
|
| 451 |
+
raise ValueError("q_len should be 1 or seq_len.")
|
| 452 |
+
|
| 453 |
+
neighbor_attention_mask = neighbor_attention_mask.bool()
|
| 454 |
+
attn_weights = torch.where(neighbor_attention_mask, neighbor_attn_weights, group_attn_weights)
|
| 455 |
+
|
| 456 |
+
# upcast attention to fp32
|
| 457 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
| 458 |
+
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
| 459 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 460 |
+
|
| 461 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
| 462 |
+
raise ValueError(
|
| 463 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
| 464 |
+
f" {attn_output.size()}"
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 469 |
+
|
| 470 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
| 471 |
+
|
| 472 |
+
if self.config.pretraining_tp > 1:
|
| 473 |
+
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
|
| 474 |
+
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
|
| 475 |
+
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
|
| 476 |
+
else:
|
| 477 |
+
attn_output = self.o_proj(attn_output)
|
| 478 |
+
|
| 479 |
+
if not output_attentions:
|
| 480 |
+
attn_weights = None
|
| 481 |
+
|
| 482 |
+
return attn_output, attn_weights, past_key_value
|
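As a quick illustration of the position remapping the two forwards above implement (a minimal sketch with assumed sizes, not part of the patch): keys inside the `group_size_2` neighbor window keep their exact relative distance, while everything further away is merged by floor-dividing with `group_size_1` and shifting so the two ranges join up.

# Sketch only: how the decode-branch key positions are formed for one query.
import torch

group_size_1, group_size_2 = 8, 1024      # same defaults as the forwards above
kv_seq_len = 4096
key_position = torch.arange(kv_seq_len)
query_position = torch.tensor(kv_seq_len - 1)

# exact relative distance, used for the nearest group_size_2 keys
neighbor_rel = query_position - key_position

# grouped relative distance, used for all other keys
_re_g2 = 0 if query_position < group_size_2 else group_size_2
group_rel = query_position // group_size_1 - key_position // group_size_1 + (_re_g2 - _re_g2 // group_size_1)

# splice: distant keys use the grouped distance, the last group_size_2 keys the exact one
merged = torch.cat([group_rel[:-group_size_2], neighbor_rel[-group_size_2:]])
print(int(merged.max()))  # 1407 here, far below kv_seq_len, so RoPE stays in its trained range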
self_extend_patch/__init__.py
ADDED
@@ -0,0 +1 @@
from . import Llama
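The package simply re-exports the patched Llama module; how the patch gets attached to a model is not part of this commit. The sketch below shows one plausible way to wire it in; the chosen forward, the partial bindings, and the model id are assumptions, not an API this repo documents.

# Hypothetical wiring sketch: rebind a patched forward onto every attention layer.
import types
from functools import partial

from transformers import AutoModelForCausalLM
from self_extend_patch import Llama

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
# lm_infinite_forward is the eager-mode forward defined in Llama.py above;
# the group sizes passed here are just its existing defaults.
patched_forward = partial(Llama.lm_infinite_forward, group_size_1=8, group_size_2=1024)
for layer in model.model.layers:
    layer.self_attn.forward = types.MethodType(patched_forward, layer.self_attn)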
self_extend_patch/selfextend_flash_attn.py
ADDED
@@ -0,0 +1,199 @@
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import pad_input  # needed to re-pad the unpadded varlen output
import torch

# The original flash forward method must be replaced with the following one first, to enable the window-size feature.
def flash_attention2_forward_with_window_size(
    self,
    query_states,
    key_states,
    value_states,
    attention_mask,
    query_length,
    dropout=0.0,
    softmax_scale=None,
    window_size=[-1, -1],
    return_attn_probs=False,
):
    """
    Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
    first unpad the input, then computes the attention scores and pad the final attention scores.

    Args:
        query_states (`torch.Tensor`):
            Input query states to be passed to Flash Attention API
        key_states (`torch.Tensor`):
            Input key states to be passed to Flash Attention API
        value_states (`torch.Tensor`):
            Input value states to be passed to Flash Attention API
        attention_mask (`torch.Tensor`):
            The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
            position of padding tokens and 1 for the position of non-padding tokens.
        dropout (`int`, *optional*):
            Attention dropout
        softmax_scale (`float`, *optional*):
            The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
        window_size ([Int, Int]):
            The left & right window size for Flash Attention. Defaults to [-1, -1], which means no window size is used.
        return_attn_probs (`bool`, *optional*):
            Whether to return the attention softmax logsumexp and probabilities. Defaults to False.
    """
    if not self._flash_attn_uses_top_left_mask:
        causal = self.is_causal
    else:
        # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
        causal = self.is_causal and query_length != 1

    # Contains at least one padding token in the sequence
    if attention_mask is not None:
        batch_size = query_states.shape[0]
        query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
            query_states, key_states, value_states, attention_mask, query_length
        )

        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
        attn_output_unpad, softmax_lse, S_dmask = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            dropout_p=dropout,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            return_attn_probs=True,
        )
        attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
    else:
        attn_output, softmax_lse, S_dmask = flash_attn_func(
            query_states,
            key_states,
            value_states,
            dropout,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            return_attn_probs=True,
        )

    if return_attn_probs:
        return attn_output, softmax_lse, S_dmask
    else:
        return attn_output

def self_extend_flash_forward(
    model_self,
    query_position,
    group_size_2,
    neighbor_query_states,
    neighbor_key_states,
    group_query_states,
    group_key_states,
    value_states,
    attention_mask,
    bsz,
    q_len,
    kv_seq_len,
    attn_dropout,
):

    if query_position.max() >= group_size_2:
        neighbor_attn_output, neighbor_softmax_lse_right_padded, neighbor_prob = model_self._flash_attention_forward(
            neighbor_query_states,
            neighbor_key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=attn_dropout,
            window_size=[group_size_2 - 1, 0],
            # right dim here does not matter and can be -1, or > 0 due to causal mask
            return_attn_probs=True,
        )

        group_attention_len = (
            kv_seq_len - group_size_2
        )  # here we should use kv_seq_len rather than max_kv_len since we have paddings in qkv and attention_mask

        group_attention_mask = attention_mask[:, :group_attention_len] if not attention_mask is None else None
        group_attn_output, group_softmax_lse_right_padded, group_prob = model_self._flash_attention_forward(
            group_query_states[:, -group_attention_len:, :, :],
            group_key_states[:, :group_attention_len, :, :],
            value_states[:, :group_attention_len, :, :],
            group_attention_mask,
            group_query_states[:, -group_attention_len:, :, :].shape[1],
            dropout=attn_dropout,
            window_size=[-1, -1],
            return_attn_probs=True,
        )  # note that kv and q's indexing are different! also query size could be different from kv length and very small during generation compared to prefilling

        # normalize lse first
        neighbor_seq_length = torch.Tensor([kv_seq_len,]).long().expand(bsz, 1) if attention_mask is None else torch.sum(attention_mask, axis=1, keepdim=True)  # [batch_size, 1]
        group_seq_length = torch.Tensor([group_attention_len,]).long().expand(bsz, 1) if attention_mask is None else torch.sum(attention_mask[:, :group_attention_len], axis=1, keepdim=True)  # [batch_size, 1]

        # convert align left to align right and convert exp(0) to 0
        neighbor_softmax_lse = torch.zeros_like(neighbor_softmax_lse_right_padded)
        group_softmax_lse = torch.zeros_like(group_softmax_lse_right_padded)
        for idx in range(bsz):
            if neighbor_seq_length[idx] > 0:
                neighbor_softmax_lse[idx, :, -neighbor_seq_length[idx] :] = neighbor_softmax_lse_right_padded[
                    idx, :, : neighbor_seq_length[idx]
                ]
            if group_seq_length[idx] > 0:
                group_softmax_lse[idx, :, -group_seq_length[idx] :] = group_softmax_lse_right_padded[
                    idx, :, : group_seq_length[idx]
                ]

        # attn_output size is [batch_size, max_seq_len (not the true one), query_length, dim]
        true_neighbor_seq_max_length = neighbor_softmax_lse.shape[
            -1
        ]  # it could be smaller than query_length due to the attention_mask
        true_group_seq_max_length = group_softmax_lse.shape[
            -1
        ]  # it could be smaller than group_query_layer[:, -group_attention_len:, :, :].shape[1] due to the attention_mask[:, :group_attention_len]

        neighbor_softmax_lse = neighbor_softmax_lse.transpose(1, 2).unsqueeze(
            -1
        )  # [batch_size, true_neighbor_seq_max_length, self.num_heads, 1]
        group_softmax_lse = group_softmax_lse.transpose(1, 2).unsqueeze(
            -1
        )  # [batch_size, true_group_seq_max_length, self.num_heads, 1]

        lse_gap = group_softmax_lse - neighbor_softmax_lse[:, -true_group_seq_max_length:, :, :]
        #if torch.isinf(neighbor_softmax_lse).any() or torch.isnan(neighbor_softmax_lse).any():
        #    import pdb; pdb.set_trace()

        # merge weights: each pass is weighted by its share of the combined softmax denominator,
        # i.e. a sigmoid of the lse gap
        neighbor_softmax_lse[:, -true_group_seq_max_length:, :, :] = 1 / (1 + torch.exp(lse_gap))
        neighbor_softmax_lse[:, :-true_group_seq_max_length, :, :] = 1.
        group_softmax_lse = 1 / (1 + torch.exp(-lse_gap))

        neighbor_attn_output[:, -true_neighbor_seq_max_length:, ...] = (
            neighbor_attn_output[:, -true_neighbor_seq_max_length:, ...] * neighbor_softmax_lse
        )
        group_attn_output[:, -true_group_seq_max_length:, ...] = (
            group_attn_output[:, -true_group_seq_max_length:, ...] * group_softmax_lse
        )
        attn_output = torch.empty_like(neighbor_attn_output).copy_(
            neighbor_attn_output
        )  # might be slightly faster than clone
        #attn_output[:, group_size_2:, ...] += group_attn_output
        attn_output[:, group_size_2 - kv_seq_len:, ...] += group_attn_output
        attn_output = torch.nan_to_num(attn_output, nan=0)

    else:
        attn_output = model_self._flash_attention_forward(
            neighbor_query_states,
            neighbor_key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=attn_dropout,
            window_size=[-1, -1],
        )

    return attn_output
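The function above merges a windowed ("neighbor") flash-attention pass with a "grouped" pass using the log-sum-exp values each pass returns. A small numeric check of that identity (toy sizes, plain PyTorch; not part of the patch): attention computed over two disjoint key sets can be recombined exactly with sigmoid weights of the lse difference.

# Toy verification of the lse-merge identity used by self_extend_flash_forward.
import torch

torch.manual_seed(0)
scores = torch.randn(10)        # attention logits for 10 keys
values = torch.randn(10, 4)

full = torch.softmax(scores, dim=0) @ values

def partial_attn(s, v):
    return torch.softmax(s, dim=0) @ v, torch.logsumexp(s, dim=0)

out_group, lse_group = partial_attn(scores[:6], values[:6])          # "grouped" keys
out_neighbor, lse_neighbor = partial_attn(scores[6:], values[6:])    # "neighbor" keys

w_group = torch.sigmoid(lse_group - lse_neighbor)   # = 1 / (1 + exp(lse_neighbor - lse_group))
merged = w_group * out_group + (1 - w_group) * out_neighbor
assert torch.allclose(full, merged, atol=1e-5)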
self_extend_patch/selfextend_flash_attn_triton.py
ADDED
@@ -0,0 +1,278 @@
import math
import torch

import triton
import triton.language as tl


def self_extend_flash_forward_triton(
    model_self,
    query_position,
    group_size_2,
    neighbor_query_states,
    neighbor_key_states,
    group_query_states,
    group_key_states,
    value_states,
    attention_mask,
    bsz,
    q_len,
    kv_seq_len,
    attn_dropout,
):

    o = _self_extend_flash_forward_triton(q=neighbor_query_states,
                                          k=neighbor_key_states,
                                          q1=group_query_states,
                                          k1=group_key_states,
                                          v=value_states,
                                          causal=(q_len == kv_seq_len),
                                          sm_scale=1. / math.sqrt(neighbor_query_states.shape[-1]),
                                          window=group_size_2)
    o = o.transpose(1, 2).contiguous()
    # print("o", o.shape)
    return o


def _self_extend_flash_forward_triton(q, k, q1, k1, v, causal, sm_scale, window):
    # shape constraints
    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
    assert Lq == Lk and Lk == Lv
    assert Lk in {16, 32, 64, 128}

    device = torch.cuda.device_of(q)
    with torch.cuda.device(device):
        o = torch.empty_like(q)
        BLOCK_M = 128
        BLOCK_N = 32
        grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1])
        L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        _fwd_kernel[grid](
            q,
            k,
            q1,
            k1,
            v,
            sm_scale,
            L,
            o,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
            q.shape[0],
            q.shape[1],
            q.shape[2],
            k.shape[2],
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            BLOCK_DMODEL=Lk,
            IS_CAUSAL=causal,
            WINDOW=window,
            num_warps=8,
            num_stages=2)

    return o


@triton.heuristics(
    {
        "EVEN_M": lambda args: args["Q_CTX"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["KV_CTX"] % args["BLOCK_N"] == 0,
    }
)
@triton.jit
def _fwd_kernel(
    Q,
    K,
    Q1,
    K1,
    V,
    sm_scale,
    L,
    Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vn, stride_vk,
    stride_oz, stride_oh, stride_om, stride_on,
    Z,
    H,
    Q_CTX,
    KV_CTX,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    WINDOW: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr
):
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    # qvk_offset = off_hz * stride_qh
    q_offset = off_hz * stride_qh
    vk_offset = off_hz * stride_kh
    # vk_offset = q_offset

    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(Q_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + vk_offset,
        shape=(KV_CTX, BLOCK_DMODEL),
        strides=(stride_kn, stride_kk),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0)
    )
    Q1_block_ptr = tl.make_block_ptr(
        base=Q1 + q_offset,
        shape=(Q_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    K1_block_ptr = tl.make_block_ptr(
        base=K1 + vk_offset,
        shape=(KV_CTX, BLOCK_DMODEL),
        strides=(stride_kn, stride_kk),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0)
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + vk_offset,
        shape=(KV_CTX, BLOCK_DMODEL),
        strides=(stride_vn, stride_vk),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0)
    )

    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    # scale sm_scale by log_2(e) and use
    # 2^x instead of exp in the loop because CSE and LICM
    # don't work as expected with `exp` in the loop
    qk_scale = sm_scale * 1.4426950408889634

    # load q: it will stay in SRAM throughout
    if EVEN_M:
        q = tl.load(Q_block_ptr)
        q1 = tl.load(Q1_block_ptr)
    else:
        q = tl.load(Q_block_ptr, boundary_check=(1, 0))
        q1 = tl.load(Q1_block_ptr, boundary_check=(1, 0))

    q = (q * qk_scale).to(tl.bfloat16)
    q1 = (q1 * qk_scale).to(tl.bfloat16)

    # Dot-with-identity trick: it converts q and q1 into MMA layout and saves shared memory;
    # a cheaper way to build an identity matrix that avoids casting from bool.
    offs_k = tl.arange(0, BLOCK_DMODEL)
    I = tl.where(offs_k[:, None] == offs_k,
                 tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 1.0, dtype=tl.bfloat16),
                 tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 0.0, dtype=tl.bfloat16))
    q = tl.dot(q, I).to(tl.bfloat16)
    q1 = tl.dot(q1, I).to(tl.bfloat16)

    # loop over k, v and update accumulator
    lo = 0
    if IS_CAUSAL:
        hi = tl.minimum(KV_CTX, (start_m + 1) * BLOCK_M)
    else:
        hi = KV_CTX

    for start_n in range(lo, hi, BLOCK_N):
        # -- load k, v --
        if EVEN_N:
            k = tl.load(K_block_ptr)
            k1 = tl.load(K1_block_ptr)
            v = tl.load(V_block_ptr)
        else:
            k = tl.load(K_block_ptr, boundary_check=(1, 0))
            k1 = tl.load(K1_block_ptr, boundary_check=(1, 0))
            v = tl.load(V_block_ptr, boundary_check=(1, 0))

        # -- compute qk ---
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)

        # Window masking
        mask = (KV_CTX - Q_CTX + offs_m[:, None]) >= (start_n + offs_n[None, :] + WINDOW)
        qk += tl.where(mask, tl.dot(q1, tl.trans(k1)), tl.dot(q, tl.trans(k)))

        # if not EVEN_N:
        #     mask = (start_n + offs_n) < KV_CTX
        #     qk = tl.where(mask, qk, float("-inf"))

        if IS_CAUSAL:
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            qk = tl.where(mask, qk, float("-inf"))
        # qk += tl.dot(q, k)

        # -- compute scaling constant ---
        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
        alpha = tl.math.exp2(m_i - m_i_new)
        p = tl.math.exp2(qk - m_i_new[:, None])

        # -- scale and update acc --
        acc_scale = l_i * 0 + alpha  # workaround some compiler bug
        acc *= acc_scale[:, None]
        acc += tl.dot(p.to(tl.bfloat16), v)

        # -- update m_i and l_i --
        l_i = l_i * alpha + tl.sum(p, 1)
        m_i = m_i_new

        # update pointers
        K_block_ptr = tl.advance(K_block_ptr, (BLOCK_N, 0))
        K1_block_ptr = tl.advance(K1_block_ptr, (BLOCK_N, 0))
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))

    # write back l and m
    acc = acc * (1.0 / l_i[:, None])
    l_ptrs = L + off_hz * Q_CTX + offs_m

    mask_m = offs_m < Q_CTX
    l_i = m_i + tl.math.log2(l_i)
    if EVEN_M:
        tl.store(l_ptrs, l_i)
    else:
        tl.store(l_ptrs, l_i, mask=mask_m)

    # write back O
    O_block_ptr = tl.make_block_ptr(
        base=Out + q_offset,
        shape=(Q_CTX, BLOCK_DMODEL),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    if EVEN_M:
        tl.store(O_block_ptr, acc.to(tl.bfloat16))
    else:
        tl.store(O_block_ptr, acc.to(tl.bfloat16), boundary_check=(1, 0))
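As a readable counterpart to the kernel above, a pure-PyTorch reference of the quantity it computes may help (a sketch for small sizes only; the [batch, heads, seq, head_dim] layout matches the wrapper, everything else is an assumption): keys at distance at least `window` from the query use the grouped scores, closer keys use the neighbor scores, under a causal mask.

# Reference (non-Triton) sketch of the merged neighbor/grouped attention.
import math
import torch

def merged_attention_reference(q, k, q1, k1, v, window, causal=True):
    # q/k: neighbor-RoPE states, q1/k1: grouped-RoPE states, all [B, H, N, D]
    scale = 1.0 / math.sqrt(q.shape[-1])
    neighbor_scores = torch.matmul(q, k.transpose(-1, -2)) * scale
    grouped_scores = torch.matmul(q1, k1.transpose(-1, -2)) * scale

    q_len, kv_len = q.shape[-2], k.shape[-2]
    i = torch.arange(q_len).view(-1, 1) + (kv_len - q_len)  # absolute query positions
    j = torch.arange(kv_len).view(1, -1)                    # key positions
    # far keys take the grouped scores, nearby keys the neighbor scores
    scores = torch.where(i - j >= window, grouped_scores, neighbor_scores)
    if causal:
        scores = scores.masked_fill(i < j, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v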
self_extend_patch/triton_selfextend_flash_attn.py
ADDED
@@ -0,0 +1,250 @@
import torch
import math
import triton
import triton.language as tl


# We don't run auto-tuning every time to keep the tutorial fast. Uncommenting
# the code below and commenting out the equivalent parameters is convenient for
# re-tuning.
#@triton.autotune(
#    configs=[
#        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=4, num_warps=8),
#        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64}, num_stages=3, num_warps=8),
#        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 32}, num_stages=3, num_warps=8),
#        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 32}, num_stages=3, num_warps=4),
#        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=3, num_warps=4),
#        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=4, num_warps=4),
#        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=3, num_warps=4),
#        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=4, num_warps=4),
#        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=3, num_warps=8),
#        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=7, num_warps=8),
#        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=7, num_warps=8),
#        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=6, num_warps=8),
#        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=5, num_warps=8),
#        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=4, num_warps=8),
#        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=6, num_warps=4),
#    ],
#    key=['N_CTX'],
#)
@triton.jit
def _attn_fwd_prefill(Q1, K1, Q2, K2, V, sm_scale, M, Out,  #
                      stride_qz, stride_qh, stride_qm, stride_qk,  #
                      stride_kz, stride_kh, stride_kn, stride_kk,  #
                      stride_vz, stride_vh, stride_vk, stride_vn,  #
                      stride_oz, stride_oh, stride_om, stride_on,  #
                      Z, H,  #
                      Q_CTX: tl.constexpr,  #
                      N_CTX: tl.constexpr,  #
                      WINDOW: tl.constexpr,  #
                      BLOCK_M: tl.constexpr,  #
                      BLOCK_DMODEL: tl.constexpr,  #
                      BLOCK_N: tl.constexpr,  #
                      ):
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_z = off_hz // H
    off_h = off_hz % H
    qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh

    # block pointers
    Q1_block_ptr = tl.make_block_ptr(
        base=Q1 + qvk_offset,
        shape=(Q_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    Q2_block_ptr = tl.make_block_ptr(
        base=Q2 + qvk_offset,
        shape=(Q_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0),
    )
    K1_block_ptr = tl.make_block_ptr(
        base=K1 + qvk_offset,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),
    )
    K2_block_ptr = tl.make_block_ptr(
        base=K2 + qvk_offset,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),
    )
    O_block_ptr = tl.make_block_ptr(
        base=Out + qvk_offset,
        shape=(Q_CTX, BLOCK_DMODEL),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # load scales
    qk_scale = sm_scale
    qk_scale *= 1.442695040888963  # 1.44269504 # 1/log(2)
    # load q: it will stay in SRAM throughout
    #q = tl.load(Q_block_ptr)
    if start_m * BLOCK_M + BLOCK_M > Q_CTX:
        q1 = tl.load(Q1_block_ptr, boundary_check=(0,), padding_option='zero')
        q2 = tl.load(Q2_block_ptr, boundary_check=(0,), padding_option='zero')
    else:
        q1 = tl.load(Q1_block_ptr)
        q2 = tl.load(Q2_block_ptr)
    #q1 = (q1 * qk_scale).to(tl.float16)
    #q2 = (q2 * qk_scale).to(tl.float16)

    lo = 0
    hi = (start_m + 1) * BLOCK_M
    # loop over k, v and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        #qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) #?
        #qk = qk.to(tl.float16)
        # if use condition, qk has to be float32, then convert to float16...
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        if start_n + BLOCK_N - 1 > start_m * BLOCK_M - 1:
            qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, -1.0e6)  # float("-inf"))

        #qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
        # -- compute qk ----
        #k = tl.load(K_block_ptr)
        # case 1: only need group attention: q2, k2
        if BLOCK_N + start_n <= (start_m * BLOCK_M - WINDOW + 1):
            if BLOCK_N + start_n >= N_CTX:
                k2 = tl.load(K2_block_ptr, boundary_check=(1,), padding_option='zero')
                v = tl.load(V_block_ptr, boundary_check=(0,), padding_option='zero')
            else:
                k2 = tl.load(K2_block_ptr)
                v = tl.load(V_block_ptr)
            #qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            #qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float16)
            qk += tl.dot(q2, k2)  #, out_dtype=tl.float16)
        else:
            # case 2: only need neighbor attention: q1, k1
            if start_n >= (start_m + 1) * BLOCK_M - WINDOW:
                if BLOCK_N + start_n >= N_CTX:
                    k1 = tl.load(K1_block_ptr, boundary_check=(1,), padding_option='zero')
                    v = tl.load(V_block_ptr, boundary_check=(0,), padding_option='zero')
                else:
                    k1 = tl.load(K1_block_ptr)
                    v = tl.load(V_block_ptr)
                #qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
                #qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float16)
                qk += tl.dot(q1, k1)  #, out_dtype=tl.float16)
            else:
                # case 3: need both q1, k1 and q2, k2
                if BLOCK_N + start_n >= N_CTX:
                    k1 = tl.load(K1_block_ptr, boundary_check=(1,), padding_option='zero')
                    k2 = tl.load(K2_block_ptr, boundary_check=(1,), padding_option='zero')
                    v = tl.load(V_block_ptr, boundary_check=(0,), padding_option='zero')
                else:
                    k1 = tl.load(K1_block_ptr)
                    k2 = tl.load(K2_block_ptr)
                    v = tl.load(V_block_ptr)
                #qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
                #qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float16)
                qk1 = tl.dot(q1, k1)  #, out_dtype=tl.float16)
                qk2 = tl.dot(q2, k2)  #, out_dtype=tl.float16)
                #merge_mask = tl.abs((offs_m[:, None] - (start_n + offs_n[None, :]))) >= WINDOW
                #qk += tl.where(merge_mask, qk2, qk1)
                qk += tl.where(tl.abs(offs_m[:, None] - (start_n + offs_n[None, :])) < WINDOW, qk1, qk2)

        qk *= qk_scale

        m_ij = tl.maximum(m_i, tl.max(qk, 1))
        qk = qk - m_ij[:, None]
        p = tl.math.exp2(qk)
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        # -- update output accumulator --
        acc = acc * alpha[:, None]
        # update acc
        #v = tl.load(V_block_ptr)
        #v = tl.load(V_block_ptr, boundary_check=(0,), padding_option='zero')
        acc += tl.dot(p.to(tl.float16), v)
        # update m_i and l_i
        m_i = m_ij
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        K1_block_ptr = tl.advance(K1_block_ptr, (0, BLOCK_N))
        K2_block_ptr = tl.advance(K2_block_ptr, (0, BLOCK_N))

    # epilogue
    m_i += tl.math.log2(l_i)
    acc = acc / l_i[:, None]
    m_ptrs = M + off_hz * Q_CTX + offs_m
    if start_m * BLOCK_M + BLOCK_M >= Q_CTX:
        tl.store(m_ptrs, m_i, mask=offs_m < Q_CTX)
        tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0,))
    else:
        tl.store(m_ptrs, m_i)
        tl.store(O_block_ptr, acc.to(Out.type.element_ty))


def prefill_flash_forward(q1, k1, q2, k2, v, q_len, seq_len, window, sm_scale=None):
    # shape constraints
    Lq, Lk, Lv = q1.shape[-1], k1.shape[-1], v.shape[-1]
    assert Lq == Lk and Lk == Lv
    assert Lk in {16, 32, 64, 128}
    assert q_len == seq_len or q_len == 1
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(Lq)  # the default scale factor.
    o = torch.empty_like(q1, device=q1.device)
    block_m = 128
    block_n = 64  # if Lk <= 64 else 32
    num_stages = 4 if Lk <= 64 else 3
    num_warps = 4
    # Tuning for H100
    if torch.cuda.get_device_capability()[0] == 9:
        num_warps = 8
        num_stages = 7 if Lk >= 64 else 3
    grid = (triton.cdiv(q1.shape[2], block_m), q1.shape[0] * q1.shape[1], 1)
    M = torch.empty((q1.shape[0], q1.shape[1], q1.shape[2]), device=q1.device, dtype=torch.float32)
    with torch.cuda.device(v.device.index):
        # https://github.com/Dao-AILab/flash-attention/commit/9795159082f6e6c847db2bf4284fd17326c31fbd
        # to avoid the device issue.
        _attn_fwd_prefill[grid](
            q1, k1, q2, k2, v, sm_scale, M, o,  #
            q1.stride(0), q1.stride(1), q1.stride(2), q1.stride(3),  #
            k1.stride(0), k1.stride(1), k1.stride(2), k1.stride(3),  #
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),  #
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),  #
            q1.shape[0], q1.shape[1],  #
            Q_CTX=q_len,
            N_CTX=seq_len,  #
            BLOCK_M=block_m,  #
            BLOCK_N=block_n,  #
            WINDOW=window,
            BLOCK_DMODEL=Lk,  #
            num_warps=num_warps,  #
            num_stages=num_stages  #
        )

    return o
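A hypothetical call into `prefill_flash_forward` (tensor shapes, dtype, and the CUDA device are assumptions made for this sketch, not documented requirements): q1/k1 carry the neighbor-RoPE states and q2/k2 the grouped ones, matching the kernel's case labels.

# Usage sketch only; requires a CUDA device with Triton installed.
import torch
from self_extend_patch.triton_selfextend_flash_attn import prefill_flash_forward

bsz, heads, seq, head_dim = 1, 32, 8192, 128

def rand_states():
    return torch.randn(bsz, heads, seq, head_dim, device="cuda", dtype=torch.float16)

q_neighbor, k_neighbor = rand_states(), rand_states()   # q1/k1: neighbor-RoPE states
q_group, k_group = rand_states(), rand_states()         # q2/k2: grouped-RoPE states
v = rand_states()

out = prefill_flash_forward(q_neighbor, k_neighbor, q_group, k_group, v,
                            q_len=seq, seq_len=seq, window=1024)
print(out.shape)  # torch.Size([1, 32, 8192, 128])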
style.css
ADDED
@@ -0,0 +1,24 @@
h1 {
  text-align: center;
  display: block;
}

#duplicate-button {
  margin: auto;
  color: white;
  background: #1565c0;
  border-radius: 100vh;
}

.contain {
  max-width: 900px;
  margin: auto;
  padding-top: 1.5rem;
}

.s-pad {
  display: block;
  padding-top: 2rem;
  height: 1px;
  width: 100%;
}