Spaces:
Runtime error
Runtime error
0.19 implementing flash_attn
Browse files
- app.py +9 -0
- requirements.txt +2 -1
app.py
CHANGED
|
@@ -4,6 +4,7 @@ import torch
|
|
| 4 |
import gradio as gr
|
| 5 |
import logging
|
| 6 |
from huggingface_hub import login
|
|
|
|
| 7 |
|
| 8 |
import os
|
| 9 |
import traceback
|
|
@@ -66,6 +67,10 @@ def load_model_a(model_id):
|
|
| 66 |
device_map="auto",
|
| 67 |
trust_remote_code=True,
|
| 68 |
).eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
except Exception as e:
|
| 70 |
logging.error(f'{SPACER} Error: {e}, Traceback {traceback.format_exc()}')
|
| 71 |
|
|
@@ -83,6 +88,10 @@ def load_model_b(model_id):
|
|
| 83 |
device_map="auto",
|
| 84 |
trust_remote_code=True,
|
| 85 |
).eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
except Exception as e:
|
| 87 |
logging.error(f'{SPACER} Error: {e}, Traceback {traceback.format_exc()}')
|
| 88 |
return gr.update(label=model_id)
|
|
|
|
| 4 |
import gradio as gr
|
| 5 |
import logging
|
| 6 |
from huggingface_hub import login
|
| 7 |
+
from flash_attn.flash_attention import FlashAttention
|
| 8 |
|
| 9 |
import os
|
| 10 |
import traceback
|
|
|
|
| 67 |
device_map="auto",
|
| 68 |
trust_remote_code=True,
|
| 69 |
).eval()
|
| 70 |
+
for name, module in model_a.named_modules():
|
| 71 |
+
if isinstance(module, torch.nn.MultiheadAttention):
|
| 72 |
+
module.forward = FlashAttention(module.embed_dim)
|
| 73 |
+
logging.debug(f'{SPACER} forwarding module of {model_id_a} to flash_attn')
|
| 74 |
except Exception as e:
|
| 75 |
logging.error(f'{SPACER} Error: {e}, Traceback {traceback.format_exc()}')
|
| 76 |
|
|
|
|
| 88 |
device_map="auto",
|
| 89 |
trust_remote_code=True,
|
| 90 |
).eval()
|
| 91 |
+
for name, module in model_b.named_modules():
|
| 92 |
+
if isinstance(module, torch.nn.MultiheadAttention):
|
| 93 |
+
module.forward = FlashAttention(module.embed_dim)
|
| 94 |
+
logging.debug(f'{SPACER} forwarding module of {model_id_b} to flash_attn')
|
| 95 |
except Exception as e:
|
| 96 |
logging.error(f'{SPACER} Error: {e}, Traceback {traceback.format_exc()}')
|
| 97 |
return gr.update(label=model_id)
|
requirements.txt
CHANGED
|
@@ -5,4 +5,5 @@ accelerate==0.33.0
|
|
| 5 |
sentencepiece==0.2.0
|
| 6 |
spaces==0.29.2
|
| 7 |
gradio==4.39.0
|
| 8 |
-
bitsandbytes==0.43.2
|
|
|
|
|
|
| 5 |
sentencepiece==0.2.0
|
| 6 |
spaces==0.29.2
|
| 7 |
gradio==4.39.0
|
| 8 |
+
bitsandbytes==0.43.2
|
| 9 |
+
flash-attn
|