Spaces:
Tuchuanhuhuhu committed
Commit · a2dfe6a · 1 parent: b0a1d94
改进减少token逻辑 (Improve token reduction logic)
Files changed:
- ChuanhuChatbot.py  +1 -1
- chat_func.py  +15 -12
- utils.py  +46 -26
ChuanhuChatbot.py
CHANGED
@@ -359,7 +359,7 @@ with gr.Blocks(
             token_count,
             top_p,
             temperature,
-
+            gr.State(0),
             model_select_dropdown,
         ],
         [chatbot, history, status_display, token_count],
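The one-line change above adds gr.State(0) as an input: Gradio passes a State component's current value (the constant 0 here) to the event handler as an extra positional argument, presumably feeding the new max_token_count / max_token parameter introduced in chat_func.py below. A minimal, self-contained sketch of that pattern (component and handler names are illustrative, not taken from this repo):

import gradio as gr

def handler(top_p, temperature, max_token_count, model_name):
    # gr.State(0) below supplies max_token_count=0 on every click
    return f"top_p={top_p}, temperature={temperature}, max_token_count={max_token_count}, model={model_name}"

with gr.Blocks() as demo:
    top_p = gr.Slider(0, 1, value=1.0, label="top_p")
    temperature = gr.Slider(0, 2, value=1.0, label="temperature")
    model_dropdown = gr.Dropdown(["gpt-3.5-turbo"], value="gpt-3.5-turbo")
    out = gr.Textbox()
    gr.Button("Run").click(
        handler,
        [top_p, temperature, gr.State(0), model_dropdown],
        [out],
    )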
chat_func.py
CHANGED
@@ -371,9 +371,8 @@ def predict(
             all_token_counts,
             top_p,
             temperature,
-
+            max_token//2,
             selected_model=selected_model,
-            hidden=True,
         )
         for chatbot, history, status_text, all_token_counts in iter:
             status_text = f"Token 达到上限,已自动降低Token计数至 {status_text}"
@@ -410,9 +409,10 @@ def retry(
         stream=stream,
         selected_model=selected_model,
     )
-    logging.info("
+    logging.info("重试中……")
     for x in iter:
         yield x
+    logging.info("重试完毕")
 
 
 def reduce_token_size(
@@ -423,9 +423,8 @@ def reduce_token_size(
     token_count,
     top_p,
     temperature,
-
+    max_token_count,
     selected_model=MODELS[0],
-    hidden=False,
 ):
     logging.info("开始减少token数量……")
     iter = predict(
@@ -437,17 +436,21 @@
         token_count,
         top_p,
         temperature,
-        stream=stream,
         selected_model=selected_model,
         should_check_token_count=False,
     )
     logging.info(f"chatbot: {chatbot}")
+    flag = False
     for chatbot, history, status_text, previous_token_count in iter:
-
-
-
-
-
-
+        num_chat = find_n(previous_token_count, max_token_count)
+        if flag:
+            chatbot = chatbot[:-1]
+        flag = True
+        history = history[-2*num_chat:] if num_chat > 0 else []
+        token_count = previous_token_count[-num_chat:] if num_chat > 0 else []
+        msg = f"保留了最近{num_chat}轮对话"
+        yield chatbot, history, msg + "," + construct_token_message(
+            sum(token_count) if len(token_count) > 0 else 0,
         ), token_count
+        logging.info(msg)
     logging.info("减少token数量完毕")
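The rewritten loop keeps only the most recent rounds: find_n (added in utils.py below) decides how many per-round token counts still fit under max_token_count, then history is cut to the last 2*num_chat messages (one user/assistant pair per round) and token_count to the last num_chat entries. A small standalone sketch of that trimming step with made-up numbers; num_chat is written out as a constant here rather than computed (2 is what find_n would return for these values):

# Made-up per-round token counts, oldest first, and a hypothetical limit.
max_token_count = 2000
previous_token_count = [900, 1200, 800, 600]          # 4 rounds so far
history = ["q1", "a1", "q2", "a2", "q3", "a3", "q4", "a4"]

num_chat = 2                                          # find_n(previous_token_count, max_token_count)
history = history[-2 * num_chat:] if num_chat > 0 else []
token_count = previous_token_count[-num_chat:] if num_chat > 0 else []

print(history)       # ['q3', 'a3', 'q4', 'a4']  -- last two rounds kept
print(token_count)   # [800, 600]                -- their token counts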
utils.py
CHANGED
@@ -37,9 +37,10 @@ def count_token(message):
     length = len(encoding.encode(input_str))
     return length
 
+
 def markdown_to_html_with_syntax_highlight(md_str):
     def replacer(match):
-        lang = match.group(1) or
+        lang = match.group(1) or "text"
         code = match.group(2)
 
         try:
@@ -50,60 +51,65 @@ def markdown_to_html_with_syntax_highlight(md_str):
         formatter = HtmlFormatter()
         highlighted_code = highlight(code, lexer, formatter)
 
-        return f
+        return f'<pre><code class="{lang}">{highlighted_code}</code></pre>'
 
-    code_block_pattern = r
+    code_block_pattern = r"```(\w+)?\n([\s\S]+?)\n```"
     md_str = re.sub(code_block_pattern, replacer, md_str, flags=re.MULTILINE)
 
     html_str = markdown(md_str)
     return html_str
 
+
 def normalize_markdown(md_text: str) -> str:
-    lines = md_text.split(
+    lines = md_text.split("\n")
     normalized_lines = []
     inside_list = False
 
     for i, line in enumerate(lines):
-        if re.match(r
-        if not inside_list and i > 0 and lines[i - 1].strip() !=
-            normalized_lines.append(
+        if re.match(r"^(\d+\.|-|\*|\+)\s", line.strip()):
+            if not inside_list and i > 0 and lines[i - 1].strip() != "":
+                normalized_lines.append("")
             inside_list = True
             normalized_lines.append(line)
-        elif inside_list and line.strip() ==
-            if i < len(lines) - 1 and not re.match(
+        elif inside_list and line.strip() == "":
+            if i < len(lines) - 1 and not re.match(
+                r"^(\d+\.|-|\*|\+)\s", lines[i + 1].strip()
+            ):
                 normalized_lines.append(line)
                 continue
         else:
             inside_list = False
             normalized_lines.append(line)
 
-    return
+    return "\n".join(normalized_lines)
+
 
 def convert_mdtext(md_text):
-    code_block_pattern = re.compile(r
+    code_block_pattern = re.compile(r"```(.*?)(?:```|$)", re.DOTALL)
     code_blocks = code_block_pattern.findall(md_text)
     non_code_parts = code_block_pattern.split(md_text)[::2]
 
     result = []
-    for non_code, code in zip(non_code_parts, code_blocks + [
+    for non_code, code in zip(non_code_parts, code_blocks + [""]):
         if non_code.strip():
             non_code = normalize_markdown(non_code)
-            result.append(mdtex2html.convert(non_code, extensions=[
+            result.append(mdtex2html.convert(non_code, extensions=["tables"]))
         if code.strip():
-            _, code = detect_language(code)
+            _, code = detect_language(code)  # 暂时去除代码高亮功能,因为在大段代码的情况下会出现问题
             code = f"```{code}\n\n```"
             code = markdown_to_html_with_syntax_highlight(code)
             result.append(code)
     result = "".join(result)
     return result
 
+
 def detect_language(code):
     if code.startswith("\n"):
         first_line = ""
     else:
-        first_line = code.strip().split(
-    language = first_line.lower() if first_line else
-    code_without_language = code[len(first_line):].lstrip() if first_line else code
+        first_line = code.strip().split("\n", 1)[0]
+    language = first_line.lower() if first_line else ""
+    code_without_language = code[len(first_line) :].lstrip() if first_line else code
     return language, code_without_language
 
 
@@ -336,26 +342,40 @@ def replace_today(prompt):
     today = datetime.datetime.today().strftime("%Y-%m-%d")
     return prompt.replace("{current_date}", today)
 
+
 def get_geoip():
-    response = requests.get(
+    response = requests.get("https://ipapi.co/json/", timeout=5)
     try:
         data = response.json()
     except:
-        data = {
-            "error": True,
-            "reason" : "连接ipapi失败"
-        }
+        data = {"error": True, "reason": "连接ipapi失败"}
     if "error" in data.keys():
         logging.warning(f"无法获取IP地址信息。\n{data}")
-        if data[
-            return
+        if data["reason"] == "RateLimited":
+            return (
+                f"获取IP地理位置失败,因为达到了检测IP的速率限制。聊天功能可能仍然可用,但请注意,如果您的IP地址在不受支持的地区,您可能会遇到问题。"
+            )
         else:
             return f"获取IP地理位置失败。原因:{data['reason']}。你仍然可以使用聊天功能。"
     else:
-        country = data[
+        country = data["country_name"]
         if country == "China":
             text = "**您的IP区域:中国。请立即检查代理设置,在不受支持的地区使用API可能导致账号被封禁。**"
         else:
            text = f"您的IP区域:{country}。"
         logging.info(text)
-        return text
+    return text
+
+
+def find_n(lst, max_num):
+    n = len(lst)
+    total = sum(lst)
+
+    if total < max_num:
+        return n
+
+    for i in range(len(lst)):
+        if total - lst[i] < max_num:
+            return n - i -1
+        total = total - lst[i]
+    return 1
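For reference, the new find_n walks the per-round token counts oldest-first and drops rounds until the remaining total is below max_num, returning how many of the newest rounds to keep: all of them if the total already fits, and 0 if even the newest round alone exceeds the limit. A quick usage check, assuming utils.py is importable from the working directory:

from utils import find_n   # the helper added above

# total 3500 tokens: dropping the oldest two rounds (900 + 1200) leaves 1400 < 2000
assert find_n([900, 1200, 800, 600], 2000) == 2

# everything already fits under the limit, so keep all rounds
assert find_n([500, 300], 2000) == 2

# even the newest round alone exceeds the limit, so nothing can be kept
assert find_n([2800, 3000], 1000) == 0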