phi committed: update
Commit: 3709b60 · Parent(s): 6ded56f
app.py CHANGED
@@ -57,68 +57,29 @@ TODO:
 need to upload the model as hugginface/models/seal_13b_a
 # https://huggingface.co/docs/hub/spaces-overview#managing-secrets
 set
-
+HF_TOKEN=???
 
+TRANSFORMERS_CACHE=/data/.huggingface
 # if persistent, then export the following
+
 HF_HOME=/data/.huggingface
-TRANSFORMERS_CACHE=/data/.huggingface
 MODEL_PATH=/data/.huggingface/seal-13b-chat-a
 HF_MODEL_NAME=DAMO-NLP-SG/seal-13b-chat-a
 # if not persistent
 MODEL_PATH=./seal-13b-chat-a
 HF_MODEL_NAME=DAMO-NLP-SG/seal-13b-chat-a
 
-
-
-# download will auto detect and get the most updated one
-if DOWNLOAD_SNAPSHOT:
-    print(f'Download from HF_MODEL_NAME={HF_MODEL_NAME} -> {MODEL_PATH}')
-    snapshot_download(HF_MODEL_NAME, local_dir=MODEL_PATH)
-elif not DEBUG:
-    assert os.path.exists(MODEL_PATH), f'{MODEL_PATH} not found and no snapshot download'
-
 """
 
 
-
-
 # ==============================
 print(f'DEBUG mode: {DEBUG}')
 
-if DTYPE == "bfloat16" and not DEBUG:
-    try:
-        compute_capability = torch.cuda.get_device_capability()
-        if compute_capability[0] < 8:
-            gpu_name = torch.cuda.get_device_name()
-            print(
-                "Bfloat16 is only supported on GPUs with compute capability "
-                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
-                f"{compute_capability[0]}.{compute_capability[1]}. --> Move to FLOAT16")
-            DTYPE = "float16"
-    except Exception as e:
-        print(f'Unable to obtain compute_capability: {e}')
 
 
-# @@ constants ================
-if not DEBUG:
-
-    # vllm import
-    from vllm import LLM, SamplingParams
-    from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
-    from vllm.engine.arg_utils import EngineArgs
-    from vllm.engine.llm_engine import LLMEngine
-    from vllm.outputs import RequestOutput
-    from vllm.sampling_params import SamplingParams
-    from vllm.utils import Counter
-    from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
-                               SequenceGroupMetadata, SequenceOutputs,
-                               SequenceStatus)
-    # ! reconfigure vllm to faster llama
-    from vllm.model_executor.model_loader import _MODEL_REGISTRY
-    from vllm.model_executor.models import LlamaForCausalLM
 
+# @@ constants ================
 
-    _MODEL_REGISTRY['FasterLlamaForCausalLM'] = LlamaForCausalLM
 
 
 def _detect_lang(text):
@@ -390,7 +351,6 @@ def llama_load_weights(
                 intermediate_size + shard_size * tensor_model_parallel_rank,
                 intermediate_size + shard_size * (tensor_model_parallel_rank + 1)
             )
-            # print(f'{name} {param.size()} | {g_offsets} | {u_offsets}')
             _loaded_weight = torch.cat(
                 [
                     loaded_weight[g_offsets[0]:g_offsets[1]],
@@ -420,7 +380,33 @@ def llama_load_weights(
 
 # Reassign LlamaForCausalLM.load_weights with llama_load_weights
 if not DEBUG:
-
+
+    # vllm import
+    # from vllm import LLM, SamplingParams
+    # ! reconfigure vllm to faster llama
+    try:
+        import vllm
+        from vllm.model_executor.model_loader import _MODEL_REGISTRY
+        from vllm.model_executor.models import LlamaForCausalLM
+
+        _MODEL_REGISTRY['FasterLlamaForCausalLM'] = LlamaForCausalLM
+        LlamaForCausalLM.load_weights = llama_load_weights
+
+        if DTYPE == "bfloat16":
+            try:
+                compute_capability = torch.cuda.get_device_capability()
+                if compute_capability[0] < 8:
+                    gpu_name = torch.cuda.get_device_name()
+                    print(
+                        "Bfloat16 is only supported on GPUs with compute capability "
+                        f"of at least 8.0. Your {gpu_name} GPU has compute capability "
+                        f"{compute_capability[0]}.{compute_capability[1]}. --> Move to FLOAT16")
+                    DTYPE = "float16"
+            except Exception as e:
+                print(f'Unable to obtain compute_capability: {e}')
+    except Exception as e:
+        print(f'Failing import and reconfigure VLLM: {str(e)}')
+
 
 # ! ==================================================================
 
@@ -501,11 +487,11 @@ class ChatBot(gr.Chatbot):
         return x
 
 
-# gr.ChatInterface
 from gradio.components import Button
 from gradio.events import Dependency, EventListenerMethod
 
-
+# replace events so that submit button is disabled during generation, if stop_btn not found
+# this prevent weird behavior
 def _setup_stop_events(
         self, event_triggers: list[EventListenerMethod], event_to_cancel: Dependency
 ) -> None:
@@ -571,13 +557,12 @@ def _setup_stop_events(
                 queue=False,
             )
 
-
-
 gr.ChatInterface._setup_stop_events = _setup_stop_events
 
 def chat_response(message, history, temperature: float, max_tokens: int, system_prompt: str = '') -> str:
     global llm
     assert llm is not None
+    from vllm import LLM, SamplingParams
     temperature = float(temperature)
     max_tokens = int(max_tokens)
     if system_prompt.strip() != '':
@@ -594,6 +579,7 @@ def chat_response(message, history, temperature: float, max_tokens: int, system_
 
 
 def vllm_abort(self: Any):
+    from vllm.sequence import SequenceStatus
     scheduler = self.llm_engine.scheduler
     for state_queue in [scheduler.waiting, scheduler.running, scheduler.swapped]:
         for seq_group in state_queue:
@@ -607,6 +593,7 @@ def vllm_abort(self: Any):
 
 # def _vllm_run_engine(self: LLM, use_tqdm: bool = False) -> Dict[str, RequestOutput]:
 def _vllm_run_engine(self: Any, use_tqdm: bool = False) -> Dict[str, Any]:
+    from vllm.outputs import RequestOutput
     # Initialize tqdm.
     if use_tqdm:
         num_requests = self.llm_engine.get_num_unfinished_requests()
@@ -654,6 +641,7 @@ def vllm_generate_stream(
         A list of `RequestOutput` objects containing the generated
         completions in the same order as the input prompts.
     """
+    from vllm import LLM, SamplingParams
     if prompts is None and prompt_token_ids is None:
         raise ValueError("Either prompts or prompt_token_ids must be "
                          "provided.")
@@ -750,6 +738,7 @@ def chat_response_stream_multiturn(
         frequency_penalty: float,
         system_prompt: Optional[str] = SYSTEM_PROMPT_1
 ) -> str:
+    from vllm import LLM, SamplingParams
     """Build multi turn
     <bos>[INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer <eos>
     <bos>[INST] Prompt [/INST] Answer <eos>
@@ -837,7 +826,7 @@ This is a DAMO SeaL-13B chatbot assistant built by DAMO Academy, Alibaba Group.
 
 
 cite_markdown = """
-
+## Citation
 If you find our project useful, hope you can star our repo and cite our paper as follows:
 ```
 @article{damonlpsg2023seallm,
@@ -849,9 +838,8 @@ If you find our project useful, hope you can star our repo and cite our paper as
 """
 
 warning_markdown = """
-
+## Warning:
 <span style="color: red">The chatbot may produce inaccurate and harmful information about people, places, or facts.</span>
-
 <span style="color: red">We strongly advise against misuse of the chatbot to knowingly generate harmful or unethical content, \
 or content that violates locally applicable and international laws or regulations, including hate speech, violence, pornography, deception, etc!</span>
 """
@@ -893,11 +881,12 @@ def launch():
         ckpt_info = "None"
 
     print(
-        f'Launch config: {
+        f'Launch config: {model_title=} / {tensor_parallel=} / {dtype=} / {max_tokens} | {BLOCK_ZH=} '
         f'\n| STREAM_YIELD_MULTIPLE={STREAM_YIELD_MULTIPLE} '
         f'\n| frequence_penalty={frequence_penalty} '
         f'\n| temperature={temperature} '
         f'\n| hf_model_name={hf_model_name} '
+        f'\n| model_path={model_path} '
        f'\n| DOWNLOAD_SNAPSHOT={DOWNLOAD_SNAPSHOT} '
        f'\nsys={SYSTEM_PROMPT_1}'
        f'\ndesc={model_desc}'
@@ -910,6 +899,8 @@ def launch():
     else:
         # ! load the model
        import vllm
+        from vllm import LLM, SamplingParams
+
        print(F'VLLM: {vllm.__version__}')
 
        if DOWNLOAD_SNAPSHOT:
@@ -962,7 +953,6 @@ def launch():
 
 def main():
 
-    # launch(parser.parse_args())
    launch()
 
 
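The TODO block at the top of this diff spells out the intended startup flow for the Space: export `HF_HOME`/`TRANSFORMERS_CACHE`/`MODEL_PATH`/`HF_MODEL_NAME` (pointing at persistent `/data/.huggingface` storage or a local fallback) and, when a snapshot download is requested, pull the checkpoint from the Hub before loading it. Below is a minimal sketch of that flow, assuming `huggingface_hub` is installed; the `prepare_model_dir` helper and the default values are illustrative only and are not code from `app.py`.

```python
# Sketch of the startup flow described in the TODO block; prepare_model_dir is a
# hypothetical helper, not part of the Space's actual code.
import os

from huggingface_hub import snapshot_download

# Environment variables named in the diff; the defaults here are only illustrative.
HF_MODEL_NAME = os.environ.get("HF_MODEL_NAME", "DAMO-NLP-SG/seal-13b-chat-a")
MODEL_PATH = os.environ.get("MODEL_PATH", "./seal-13b-chat-a")
DOWNLOAD_SNAPSHOT = os.environ.get("DOWNLOAD_SNAPSHOT", "0") == "1"
DEBUG = os.environ.get("DEBUG", "0") == "1"


def prepare_model_dir() -> str:
    """Download the checkpoint if requested, otherwise require it to exist locally."""
    if DOWNLOAD_SNAPSHOT:
        print(f"Download from HF_MODEL_NAME={HF_MODEL_NAME} -> {MODEL_PATH}")
        # snapshot_download resolves the latest revision of the repo and reuses any
        # cached files; a token (e.g. the HF_TOKEN secret) can be supplied for
        # gated or private repos.
        snapshot_download(HF_MODEL_NAME, local_dir=MODEL_PATH)
    elif not DEBUG:
        assert os.path.exists(MODEL_PATH), f"{MODEL_PATH} not found and no snapshot download"
    return MODEL_PATH
```

Downloading into a path under `/data` only survives restarts if the Space has persistent storage enabled, which is why the comments distinguish the persistent and non-persistent `MODEL_PATH` settings.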