Update app.py
app.py CHANGED
@@ -17,6 +17,7 @@ from typing import Iterator, List, Optional, Tuple
 import filelock
 import glob
 import json
+import time
 
 from gradio_client.documentation import document, set_documentation_group
 
@@ -51,6 +52,23 @@ HF_TOKEN = os.environ.get("HF_TOKEN", None)
 # ! path where the model is downloaded, either on ./ or persistent disc
 MODEL_PATH = os.environ.get("MODEL_PATH", "./seal-13b-chat-a")
 
+# ! log path
+LOG_PATH = os.environ.get("LOG_PATH", "").strip()
+LOG_FILE = None
+SAVE_LOGS = LOG_PATH is not None and LOG_PATH != ''
+if SAVE_LOGS:
+    if os.path.exists(LOG_PATH):
+        print(f'LOG_PATH exist: {LOG_PATH}')
+    else:
+        LOG_DIR = os.path.dirname(LOG_PATH)
+        os.makedirs(LOG_DIR, exist_ok=True)
+
+# ! get LOG_PATH as aggregated outputs in log
+GET_LOG_CMD = os.environ.get("GET_LOG_CMD", "").strip()
+
+print(f'SAVE_LOGS: {SAVE_LOGS} | {LOG_PATH}')
+print(f'GET_LOG_CMD: {GET_LOG_CMD}')
+
 # ! !! Whether to delete the folder, ONLY SET THIS IF YOU WANT TO DELETE SAVED MODEL ON PERSISTENT DISC
 DELETE_FOLDER = os.environ.get("DELETE_FOLDER", "")
 IS_DELETE_FOLDER = DELETE_FOLDER is not None and os.path.exists(DELETE_FOLDER)
@@ -86,13 +104,18 @@ Internal instructions of how to configure the DEMO
 
 1. Upload SFT model as a model to huggingface: hugginface/models/seal_13b_a
 2. If the model weights is private, set HF_TOKEN=<your private hf token> in https://huggingface.co/spaces/????/?????/settings
-3. space config env: `HF_MODEL_NAME=
+3. space config env: `HF_MODEL_NAME=SeaLLMs/seal-13b-chat-a` or the underlining model
 4. If enable persistent storage: set
     HF_HOME=/data/.huggingface
     MODEL_PATH=/data/.huggingface/seal-13b-chat-a
     if not:
     MODEL_PATH=./seal-13b-chat-a
 
+
+    HF_HOME=/data/.huggingface
+    MODEL_PATH=/data/ckpt/seal-13b-chat-a
+    DELETE_FOLDER=/data/
+
 """
 
 # ==============================
@@ -127,6 +150,7 @@ EOS_TOKEN = '</s>'
 B_INST, E_INST = "[INST]", "[/INST]"
 B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
 
+# TODO: should Hide the system prompt
 SYSTEM_PROMPT_1 = """You are a multilingual, helpful, respectful and honest assistant. Your name is SeaLLM and you are built by DAMO Academy, Alibaba Group. \
 Please always answer as helpfully as possible, while being safe. Your \
 answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure \
@@ -168,12 +192,12 @@ MODEL_TITLE = """
 </div>
 </div>
 """
+# <a href=''><img src='https://img.shields.io/badge/Paper-PDF-red'></a>
 MODEL_DESC = """
 <div style='display:flex; gap: 0.25rem; '>
-<a href=''><img src='https://img.shields.io/badge/Github-Code-success'></a>
+<a href='https://github.com/SeaLLMs/SeaLLMs'><img src='https://img.shields.io/badge/Github-Code-success'></a>
 <a href='https://huggingface.co/spaces/SeaLLMs/SeaLLM-Chat-13b'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
 <a href='https://huggingface.co/SeaLLMs/SeaLLM-Chat-13b'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a>
-<a href=''><img src='https://img.shields.io/badge/Paper-PDF-red'></a>
 </div>
 <span style="font-size: larger">
 This is <a href="https://huggingface.co/SeaLLMs/SeaLLM-Chat-13b" target="_blank">SeaLLM-13B-Chat</a> - a chatbot assistant optimized for Southeast Asian Languages. It produces helpful responses in English 🇬🇧, Vietnamese 🇻🇳, Indonesian 🇮🇩 and Thai 🇹🇭.
@@ -182,7 +206,7 @@ Explore <a href="https://huggingface.co/SeaLLMs/SeaLLM-Chat-13b" target="_blank"
 <br>
 <span >
 NOTE: The chatbot may produce inaccurate and harmful information about people, places, or facts.
-<
+<span style="color: red">By using our service, you are required to agree to our <a href="https://huggingface.co/SeaLLMs/SeaLLM-Chat-13b/blob/main/LICENSE" target="_blank" style="color: red">SeaLLM Terms Of Use</a>, which include:</span><br>
 <ul>
 <li >
 You must not use our service to generate any harmful, unethical or illegal content that violates locally applicable and international laws or regulations,
@@ -725,6 +749,7 @@ from gradio.events import Dependency, EventListenerMethod
 def _setup_stop_events(
     self, event_triggers: list[EventListenerMethod], event_to_cancel: Dependency
 ) -> None:
+    from gradio.components import State
     event_triggers = event_triggers if isinstance(event_triggers, (list, tuple)) else [event_triggers]
     if self.stop_btn and self.is_generator:
         if self.submit_btn:
@@ -799,6 +824,7 @@ def _setup_stop_events(
 
 # TODO: reconfigure clear button as stop and clear button
 def _setup_events(self) -> None:
+    from gradio.components import State
     has_on = False
     try:
         from gradio.events import Dependency, EventListenerMethod, on
@@ -807,6 +833,14 @@ def _setup_events(self) -> None:
         has_on = False
     submit_fn = self._stream_fn if self.is_generator else self._submit_fn
 
+    def update_time(c_time, chatbot_state):
+        # if chatbot_state is empty, register a new conversaion with the current timestamp
+        assert len(chatbot_state) > 0, f'empty chatbot state'
+        if len(chatbot_state) == 1:
+            assert chatbot_state[-1][-1] is None, f'invalid [[message, None]] , got {chatbot_state}'
+            return gr.Number(value=time.time(), label='current_time', visible=False), chatbot_state
+        else:
+            return c_time, chatbot_state
 
     if has_on:
         # new version
@@ -831,6 +865,13 @@ def _setup_events(self) -> None:
                 api_name=False,
                 queue=False,
             )
+            .then(
+                update_time,
+                [self.additional_inputs[-1], self.chatbot_state],
+                [self.additional_inputs[-1], self.chatbot_state],
+                api_name=False,
+                queue=False,
+            )
             .then(
                 submit_fn,
                 [self.saved_input, self.chatbot_state] + self.additional_inputs,
@@ -912,6 +953,7 @@ def vllm_abort(self: Any):
             continue
         scheduler.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
 
+
 def _vllm_run_engine(self: Any, use_tqdm: bool = False) -> Dict[str, Any]:
     from vllm.outputs import RequestOutput
     # Initialize tqdm.
@@ -1027,10 +1069,6 @@ def block_lang(
     return False
 
 
-def log_responses(history, message, response):
-    pass
-
-
 def safety_check(text, history=None, ) -> Optional[str]:
     """
     Despite our effort in safety tuning and red teaming, our models may still generate harmful or illegal content.
@@ -1052,8 +1090,10 @@ def chat_response_stream_multiturn(
     temperature: float,
     max_tokens: int,
     frequency_penalty: float,
+    current_time: Optional[float] = None,
     system_prompt: Optional[str] = SYSTEM_PROMPT_1
 ) -> str:
+    global LOG_FILE, LOG_PATH
     from vllm import LLM, SamplingParams
     """Build multi turn
     <bos>[INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer <eos>
@@ -1075,6 +1115,12 @@ def chat_response_stream_multiturn(
     max_tokens = int(max_tokens)
 
     message = message.strip()
+
+    if message.strip() == GET_LOG_CMD:
+        print_log_file()
+        yield "Finish printed log. Please clear the chatbox now."
+        return
+
     if len(message) == 0:
         raise gr.Error("The message cannot be empty!")
 
@@ -1114,8 +1160,12 @@ def chat_response_stream_multiturn(
         assert len(gen) == 1, f'{gen}'
         item = next(iter(gen.values()))
         cur_out = item.outputs[0].text
-
-
+
+    # TODO: use current_time to register conversations, accoriding history and cur_out
+    history_str = format_conversation(history + [[message, cur_out]])
+    print(f'@@@@@@@@@@\n{history_str}\n##########\n')
+
+    maybe_log_conv_file(current_time, history, message, cur_out, temperature=temperature, frequency_penalty=frequency_penalty)
 
     if cur_out is not None and "\\n" in cur_out:
         print(f'double slash-n in cur_out:\n{cur_out}')
@@ -1128,11 +1178,51 @@ def chat_response_stream_multiturn(
     if message_safety is not None:
         yield message_safety
         return
-
-    if LOG_RESPONSE:
-        log_responses(history, message, cur_out)
 
 
+def maybe_log_conv_file(current_time, history, message, response, **kwargs):
+    global LOG_FILE
+    if LOG_FILE is not None:
+        my_history = history + [[message, response]]
+        obj = {
+            'key': str(current_time),
+            'history': my_history
+        }
+        for k, v in kwargs.items():
+            obj[k] = v
+        log_ = json.dumps(obj, ensure_ascii=False)
+        LOG_FILE.write(log_ + "\n")
+        LOG_FILE.flush()
+        print(f'Wrote {obj["key"]} to {LOG_PATH}')
+
+
+def format_conversation(history):
+    _str = '\n'.join([
+        (
+            f'<<<User>>> {h[0]}\n'
+            f'<<<Asst>>> {h[1]}'
+        )
+        for h in history
+    ])
+    return _str
+
+
+def print_log_file():
+    global LOG_FILE, LOG_PATH
+    if SAVE_LOGS and os.path.exists(LOG_PATH):
+        with open(LOG_PATH, 'r', encoding='utf-8') as f:
+            convos = {}
+            for l in f:
+                if l:
+                    item = json.loads(l)
+                    convos[item['key']] = item
+        print(f'Printing log from {LOG_PATH}')
+        for k, v in convos.items():
+            history = v.pop('history')
+            print(f'######--{v}--##')
+            _str = format_conversation(history)
+            print(_str)
+
 
 def debug_chat_response_echo(
     message: str,
@@ -1140,11 +1230,23 @@ def debug_chat_response_echo(
     temperature: float = 0.0,
     max_tokens: int = 4096,
     frequency_penalty: float = 0.4,
+    current_time: Optional[float] = None,
     system_prompt: str = SYSTEM_PROMPT_1,
 ) -> str:
+    global LOG_FILE
    import time
     time.sleep(0.5)
-
+
+    if message.strip() == GET_LOG_CMD:
+        print_log_file()
+        yield "Finish printed log."
+        return
+
+    for i in range(len(message)):
+        yield f"repeat: {current_time} {message[:i + 1]}"
+
+    cur_out = f"repeat: {current_time} {message}"
+    maybe_log_conv_file(current_time, history, message, cur_out, temperature=temperature, frequency_penalty=frequency_penalty)
 
 
 def check_model_path(model_path) -> str:
@@ -1162,7 +1264,6 @@ def check_model_path(model_path) -> str:
     return ckpt_info
 
 
-
 def maybe_delete_folder():
     if IS_DELETE_FOLDER and DOWNLOAD_SNAPSHOT:
         import shutil
@@ -1184,7 +1285,7 @@ async () => {
 """
 
 def launch():
-    global demo, llm, DEBUG
+    global demo, llm, DEBUG, LOG_FILE
     model_desc = MODEL_DESC
     model_path = MODEL_PATH
     model_title = MODEL_TITLE
@@ -1199,8 +1300,11 @@ def launch():
     ckpt_info = "None"
 
     print(
-        f'Launch config:
+        f'Launch config: '
         f'\n| model_title=`{model_title}` '
+        f'\n| max_tokens={max_tokens} '
+        f'\n| dtype={dtype} '
+        f'\n| tensor_parallel={tensor_parallel} '
         f'\n| BLOCK_LANGS={BLOCK_LANGS} '
         f'\n| IS_DELETE_FOLDER={IS_DELETE_FOLDER} '
         f'\n| STREAM_YIELD_MULTIPLE={STREAM_YIELD_MULTIPLE} '
@@ -1214,6 +1318,8 @@ def launch():
         f'\n| DOWNLOAD_SNAPSHOT={DOWNLOAD_SNAPSHOT} '
        f'\n| gpu_memory_utilization={gpu_memory_utilization} '
         f'\n| KEYWORDS={KEYWORDS} '
+        f'\n| LOG_PATH={LOG_PATH} | SAVE_LOGS={SAVE_LOGS} '
+        f'\n| GET_LOG_CMD={GET_LOG_CMD} '
         f'\n| Sys={SYSTEM_PROMPT_1}'
         f'\n| Desc={model_desc}'
     )
@@ -1222,6 +1328,8 @@ def launch():
         model_desc += "\n<br>!!!!! This is in debug mode, responses will copy original"
         response_fn = debug_chat_response_echo
         print(f'Creating in DEBUG MODE')
+        if SAVE_LOGS:
+            LOG_FILE = open(LOG_PATH, 'a', encoding='utf-8')
     else:
         # ! load the model
         maybe_delete_folder()
@@ -1265,6 +1373,9 @@ def launch():
         response_fn = chat_response_stream_multiturn
         print(F'respond: {response_fn}')
 
+        if SAVE_LOGS:
+            LOG_FILE = open(LOG_PATH, 'a', encoding='utf-8')
+
     demo = gr.ChatInterface(
         response_fn,
         chatbot=ChatBot(
@@ -1286,6 +1397,7 @@ def launch():
             gr.Number(value=temperature, label='Temperature (higher -> more random)'),
             gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
             gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens)'),
+            gr.Number(value=0, label='current_time', visible=False),
            # ! Remove the system prompt textbox to avoid jailbreaking
             # gr.Textbox(value=sys_prompt, label='System prompt', lines=8)
         ],
@@ -1310,5 +1422,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
-
+    main()
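For reference, when LOG_PATH is set the Space appends one JSON object per completed response to that file (see maybe_log_conv_file above): each record carries 'key' (the hidden current_time value that identifies a conversation), the full 'history' including the new turn, and the sampling parameters passed as keyword arguments. Below is a minimal offline sketch of reading such a log back, mirroring what print_log_file does inside the Space; the file path and the script itself are illustrative and not part of this commit.

# read_seal_log.py -- illustrative reader for the JSONL conversation log (not part of app.py)
import json

LOG_PATH = "/data/seal_logs/conversations.jsonl"  # assumption: point this at the Space's LOG_PATH

convos = {}
with open(LOG_PATH, "r", encoding="utf-8") as f:
    for line in f:
        if line.strip():
            record = json.loads(line)
            # later records with the same key hold a longer history of the same conversation,
            # so keeping only the last record per key yields the full transcript
            convos[record["key"]] = record

for key, record in convos.items():
    history = record.pop("history")
    print(f"--- conversation {key} | params: {record} ---")
    for user_msg, bot_msg in history:
        print(f"<<<User>>> {user_msg}")
        print(f"<<<Asst>>> {bot_msg}")

Keeping the last record per key is the same de-duplication behaviour print_log_file relies on when it overwrites convos[item['key']] while scanning the file.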