Spaces:
Sleeping
Sleeping
Tuchuanhuhuhu
commited on
Commit
·
2c5812c
1
Parent(s):
3fe8fc4
加入中止回答的功能
Browse files- ChuanhuChatbot.py +58 -43
- custom.css → assets/custom.css +0 -0
- chat_func.py → modules/chat_func.py +22 -12
- llama_func.py → modules/llama_func.py +2 -2
- overwrites.py → modules/overwrites.py +3 -3
- presets.py → modules/presets.py +2 -1
- modules/shared.py +24 -0
- utils.py → modules/utils.py +27 -13
ChuanhuChatbot.py
CHANGED
|
@@ -5,10 +5,10 @@ import sys
|
|
| 5 |
|
| 6 |
import gradio as gr
|
| 7 |
|
| 8 |
-
from utils import *
|
| 9 |
-
from presets import *
|
| 10 |
-
from overwrites import *
|
| 11 |
-
from chat_func import *
|
| 12 |
|
| 13 |
logging.basicConfig(
|
| 14 |
level=logging.DEBUG,
|
|
@@ -54,7 +54,7 @@ else:
|
|
| 54 |
gr.Chatbot.postprocess = postprocess
|
| 55 |
PromptHelper.compact_text_chunks = compact_text_chunks
|
| 56 |
|
| 57 |
-
with open("custom.css", "r", encoding="utf-8") as f:
|
| 58 |
customCSS = f.read()
|
| 59 |
|
| 60 |
with gr.Blocks(
|
|
@@ -124,8 +124,7 @@ with gr.Blocks(
|
|
| 124 |
token_count = gr.State([])
|
| 125 |
promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
|
| 126 |
user_api_key = gr.State(my_api_key)
|
| 127 |
-
|
| 128 |
-
FALSECONSTANT = gr.State(False)
|
| 129 |
topic = gr.State("未命名对话历史记录")
|
| 130 |
|
| 131 |
with gr.Row():
|
|
@@ -275,12 +274,9 @@ with gr.Blocks(
|
|
| 275 |
|
| 276 |
gr.Markdown(description)
|
| 277 |
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
user_input.submit(
|
| 282 |
-
predict,
|
| 283 |
-
[
|
| 284 |
user_api_key,
|
| 285 |
systemPromptTxt,
|
| 286 |
history,
|
|
@@ -294,40 +290,45 @@ with gr.Blocks(
|
|
| 294 |
use_websearch_checkbox,
|
| 295 |
index_files,
|
| 296 |
],
|
| 297 |
-
[chatbot, history, status_display, token_count],
|
| 298 |
show_progress=True,
|
| 299 |
)
|
| 300 |
-
user_input.submit(reset_textbox, [], [user_input])
|
| 301 |
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
)
|
| 322 |
-
submitBtn.click(reset_textbox, [], [user_input])
|
| 323 |
|
| 324 |
emptyBtn.click(
|
| 325 |
reset_state,
|
| 326 |
outputs=[chatbot, history, token_count, status_display],
|
| 327 |
show_progress=True,
|
| 328 |
-
)
|
| 329 |
|
| 330 |
-
retryBtn.click(
|
| 331 |
retry,
|
| 332 |
[
|
| 333 |
user_api_key,
|
|
@@ -342,7 +343,7 @@ with gr.Blocks(
|
|
| 342 |
],
|
| 343 |
[chatbot, history, status_display, token_count],
|
| 344 |
show_progress=True,
|
| 345 |
-
)
|
| 346 |
|
| 347 |
delLastBtn.click(
|
| 348 |
delete_last_conversation,
|
|
@@ -441,17 +442,31 @@ if __name__ == "__main__":
|
|
| 441 |
if dockerflag:
|
| 442 |
if authflag:
|
| 443 |
demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
|
| 444 |
-
server_name="0.0.0.0",
|
| 445 |
-
|
|
|
|
|
|
|
| 446 |
)
|
| 447 |
else:
|
| 448 |
-
demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 449 |
# if not running in Docker
|
| 450 |
else:
|
| 451 |
if authflag:
|
| 452 |
-
demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 453 |
else:
|
| 454 |
-
demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
|
|
|
|
|
|
|
| 455 |
# demo.queue(concurrency_count=CONCURRENT_COUNT).launch(server_name="0.0.0.0", server_port=7860, share=False) # 可自定义端口
|
| 456 |
# demo.queue(concurrency_count=CONCURRENT_COUNT).launch(server_name="0.0.0.0", server_port=7860,auth=("在这里填写用户名", "在这里填写密码")) # 可设置用户名与密码
|
| 457 |
# demo.queue(concurrency_count=CONCURRENT_COUNT).launch(auth=("在这里填写用户名", "在这里填写密码")) # 适合Nginx反向代理
|
|
|
|
| 5 |
|
| 6 |
import gradio as gr
|
| 7 |
|
| 8 |
+
from modules.utils import *
|
| 9 |
+
from modules.presets import *
|
| 10 |
+
from modules.overwrites import *
|
| 11 |
+
from modules.chat_func import *
|
| 12 |
|
| 13 |
logging.basicConfig(
|
| 14 |
level=logging.DEBUG,
|
|
|
|
| 54 |
gr.Chatbot.postprocess = postprocess
|
| 55 |
PromptHelper.compact_text_chunks = compact_text_chunks
|
| 56 |
|
| 57 |
+
with open("assets/custom.css", "r", encoding="utf-8") as f:
|
| 58 |
customCSS = f.read()
|
| 59 |
|
| 60 |
with gr.Blocks(
|
|
|
|
| 124 |
token_count = gr.State([])
|
| 125 |
promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
|
| 126 |
user_api_key = gr.State(my_api_key)
|
| 127 |
+
outputing = gr.State(False)
|
|
|
|
| 128 |
topic = gr.State("未命名对话历史记录")
|
| 129 |
|
| 130 |
with gr.Row():
|
|
|
|
| 274 |
|
| 275 |
gr.Markdown(description)
|
| 276 |
|
| 277 |
+
chatgpt_predict_args = dict(
|
| 278 |
+
fn=predict,
|
| 279 |
+
inputs=[
|
|
|
|
|
|
|
|
|
|
| 280 |
user_api_key,
|
| 281 |
systemPromptTxt,
|
| 282 |
history,
|
|
|
|
| 290 |
use_websearch_checkbox,
|
| 291 |
index_files,
|
| 292 |
],
|
| 293 |
+
outputs=[chatbot, history, status_display, token_count],
|
| 294 |
show_progress=True,
|
| 295 |
)
|
|
|
|
| 296 |
|
| 297 |
+
start_outputing_args = dict(
|
| 298 |
+
fn=start_outputing, inputs=[], outputs=[submitBtn, cancelBtn], show_progress=True
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
end_outputing_args = dict(
|
| 302 |
+
fn=end_outputing, inputs=[], outputs=[submitBtn, cancelBtn]
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
reset_textbox_args = dict(
|
| 306 |
+
fn=reset_textbox, inputs=[], outputs=[user_input], show_progress=True
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
keyTxt.submit(submit_key, keyTxt, [user_api_key, status_display])
|
| 310 |
+
keyTxt.change(submit_key, keyTxt, [user_api_key, status_display])
|
| 311 |
+
# Chatbot
|
| 312 |
+
cancelBtn.click(cancel_outputing, [], [])
|
| 313 |
+
|
| 314 |
+
user_input.submit(**start_outputing_args).then(
|
| 315 |
+
**chatgpt_predict_args
|
| 316 |
+
).then(**reset_textbox_args).then(
|
| 317 |
+
**end_outputing_args
|
| 318 |
+
)
|
| 319 |
+
submitBtn.click(**start_outputing_args).then(
|
| 320 |
+
**chatgpt_predict_args
|
| 321 |
+
).then(**reset_textbox_args).then(
|
| 322 |
+
**end_outputing_args
|
| 323 |
)
|
|
|
|
| 324 |
|
| 325 |
emptyBtn.click(
|
| 326 |
reset_state,
|
| 327 |
outputs=[chatbot, history, token_count, status_display],
|
| 328 |
show_progress=True,
|
| 329 |
+
).then(**reset_textbox_args)
|
| 330 |
|
| 331 |
+
retryBtn.click(**start_outputing_args).then(
|
| 332 |
retry,
|
| 333 |
[
|
| 334 |
user_api_key,
|
|
|
|
| 343 |
],
|
| 344 |
[chatbot, history, status_display, token_count],
|
| 345 |
show_progress=True,
|
| 346 |
+
).then(**end_outputing_args)
|
| 347 |
|
| 348 |
delLastBtn.click(
|
| 349 |
delete_last_conversation,
|
|
|
|
| 442 |
if dockerflag:
|
| 443 |
if authflag:
|
| 444 |
demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
|
| 445 |
+
server_name="0.0.0.0",
|
| 446 |
+
server_port=7860,
|
| 447 |
+
auth=(username, password),
|
| 448 |
+
favicon_path="./assets/favicon.png",
|
| 449 |
)
|
| 450 |
else:
|
| 451 |
+
demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
|
| 452 |
+
server_name="0.0.0.0",
|
| 453 |
+
server_port=7860,
|
| 454 |
+
share=False,
|
| 455 |
+
favicon_path="./assets/favicon.png",
|
| 456 |
+
)
|
| 457 |
# if not running in Docker
|
| 458 |
else:
|
| 459 |
if authflag:
|
| 460 |
+
demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
|
| 461 |
+
share=False,
|
| 462 |
+
auth=(username, password),
|
| 463 |
+
favicon_path="./assets/favicon.png",
|
| 464 |
+
inbrowser=True,
|
| 465 |
+
)
|
| 466 |
else:
|
| 467 |
+
demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
|
| 468 |
+
share=False, favicon_path="./assets/favicon.ico", inbrowser=True
|
| 469 |
+
) # 改为 share=True 可以创建公开分享链接
|
| 470 |
# demo.queue(concurrency_count=CONCURRENT_COUNT).launch(server_name="0.0.0.0", server_port=7860, share=False) # 可自定义端口
|
| 471 |
# demo.queue(concurrency_count=CONCURRENT_COUNT).launch(server_name="0.0.0.0", server_port=7860,auth=("在这里填写用户名", "在这里填写密码")) # 可设置用户名与密码
|
| 472 |
# demo.queue(concurrency_count=CONCURRENT_COUNT).launch(auth=("在这里填写用户名", "在这里填写密码")) # 适合Nginx反向代理
|
custom.css → assets/custom.css
RENAMED
|
File without changes
|
chat_func.py → modules/chat_func.py
RENAMED
|
@@ -14,9 +14,10 @@ from duckduckgo_search import ddg
|
|
| 14 |
import asyncio
|
| 15 |
import aiohttp
|
| 16 |
|
| 17 |
-
from presets import *
|
| 18 |
-
from llama_func import *
|
| 19 |
-
from utils import *
|
|
|
|
| 20 |
|
| 21 |
# logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s")
|
| 22 |
|
|
@@ -29,7 +30,6 @@ if TYPE_CHECKING:
|
|
| 29 |
|
| 30 |
|
| 31 |
initial_prompt = "You are a helpful assistant."
|
| 32 |
-
API_URL = "https://api.openai.com/v1/chat/completions"
|
| 33 |
HISTORY_DIR = "history"
|
| 34 |
TEMPLATES_DIR = "templates"
|
| 35 |
|
|
@@ -65,16 +65,18 @@ def get_response(
|
|
| 65 |
# 如果存在代理设置,使用它们
|
| 66 |
proxies = {}
|
| 67 |
if http_proxy:
|
| 68 |
-
logging.info(f"
|
| 69 |
proxies["http"] = http_proxy
|
| 70 |
if https_proxy:
|
| 71 |
-
logging.info(f"
|
| 72 |
proxies["https"] = https_proxy
|
| 73 |
|
| 74 |
# 如果有代理,使用代理发送请求,否则使用默认设置发送请求
|
|
|
|
|
|
|
| 75 |
if proxies:
|
| 76 |
response = requests.post(
|
| 77 |
-
|
| 78 |
headers=headers,
|
| 79 |
json=payload,
|
| 80 |
stream=True,
|
|
@@ -83,7 +85,7 @@ def get_response(
|
|
| 83 |
)
|
| 84 |
else:
|
| 85 |
response = requests.post(
|
| 86 |
-
|
| 87 |
headers=headers,
|
| 88 |
json=payload,
|
| 89 |
stream=True,
|
|
@@ -268,10 +270,10 @@ def predict(
|
|
| 268 |
if files:
|
| 269 |
msg = "构建索引中……(这可能需要比较久的时间)"
|
| 270 |
logging.info(msg)
|
| 271 |
-
yield chatbot, history, msg, all_token_counts
|
| 272 |
index = construct_index(openai_api_key, file_src=files)
|
| 273 |
msg = "索引构建完成,获取回答中……"
|
| 274 |
-
yield chatbot, history, msg, all_token_counts
|
| 275 |
history, chatbot, status_text = chat_ai(openai_api_key, index, inputs, history, chatbot)
|
| 276 |
yield chatbot, history, status_text, all_token_counts
|
| 277 |
return
|
|
@@ -306,10 +308,15 @@ def predict(
|
|
| 306 |
all_token_counts.append(0)
|
| 307 |
else:
|
| 308 |
history[-2] = construct_user(inputs)
|
| 309 |
-
yield chatbot, history, status_text, all_token_counts
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
return
|
| 311 |
|
| 312 |
-
yield chatbot, history, "开始生成回答……", all_token_counts
|
| 313 |
|
| 314 |
if stream:
|
| 315 |
logging.info("使用流式传输")
|
|
@@ -327,6 +334,9 @@ def predict(
|
|
| 327 |
display_append=link_references
|
| 328 |
)
|
| 329 |
for chatbot, history, status_text, all_token_counts in iter:
|
|
|
|
|
|
|
|
|
|
| 330 |
yield chatbot, history, status_text, all_token_counts
|
| 331 |
else:
|
| 332 |
logging.info("不使用流式传输")
|
|
|
|
| 14 |
import asyncio
|
| 15 |
import aiohttp
|
| 16 |
|
| 17 |
+
from modules.presets import *
|
| 18 |
+
from modules.llama_func import *
|
| 19 |
+
from modules.utils import *
|
| 20 |
+
import modules.shared as shared
|
| 21 |
|
| 22 |
# logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s")
|
| 23 |
|
|
|
|
| 30 |
|
| 31 |
|
| 32 |
initial_prompt = "You are a helpful assistant."
|
|
|
|
| 33 |
HISTORY_DIR = "history"
|
| 34 |
TEMPLATES_DIR = "templates"
|
| 35 |
|
|
|
|
| 65 |
# 如果存在代理设置,使用它们
|
| 66 |
proxies = {}
|
| 67 |
if http_proxy:
|
| 68 |
+
logging.info(f"使用 HTTP 代理: {http_proxy}")
|
| 69 |
proxies["http"] = http_proxy
|
| 70 |
if https_proxy:
|
| 71 |
+
logging.info(f"使用 HTTPS 代理: {https_proxy}")
|
| 72 |
proxies["https"] = https_proxy
|
| 73 |
|
| 74 |
# 如果有代理,使用代理发送请求,否则使用默认设置发送请求
|
| 75 |
+
if shared.state.api_url != API_URL:
|
| 76 |
+
logging.info(f"使用自定义API URL: {shared.state.api_url}")
|
| 77 |
if proxies:
|
| 78 |
response = requests.post(
|
| 79 |
+
shared.state.api_url,
|
| 80 |
headers=headers,
|
| 81 |
json=payload,
|
| 82 |
stream=True,
|
|
|
|
| 85 |
)
|
| 86 |
else:
|
| 87 |
response = requests.post(
|
| 88 |
+
shared.state.api_url,
|
| 89 |
headers=headers,
|
| 90 |
json=payload,
|
| 91 |
stream=True,
|
|
|
|
| 270 |
if files:
|
| 271 |
msg = "构建索引中……(这可能需要比较久的时间)"
|
| 272 |
logging.info(msg)
|
| 273 |
+
yield chatbot+[(inputs, "")], history, msg, all_token_counts
|
| 274 |
index = construct_index(openai_api_key, file_src=files)
|
| 275 |
msg = "索引构建完成,获取回答中……"
|
| 276 |
+
yield chatbot+[(inputs, "")], history, msg, all_token_counts
|
| 277 |
history, chatbot, status_text = chat_ai(openai_api_key, index, inputs, history, chatbot)
|
| 278 |
yield chatbot, history, status_text, all_token_counts
|
| 279 |
return
|
|
|
|
| 308 |
all_token_counts.append(0)
|
| 309 |
else:
|
| 310 |
history[-2] = construct_user(inputs)
|
| 311 |
+
yield chatbot+[(inputs, "")], history, status_text, all_token_counts
|
| 312 |
+
return
|
| 313 |
+
elif len(inputs.strip()) == 0:
|
| 314 |
+
status_text = standard_error_msg + no_input_msg
|
| 315 |
+
logging.info(status_text)
|
| 316 |
+
yield chatbot+[(inputs, "")], history, status_text, all_token_counts
|
| 317 |
return
|
| 318 |
|
| 319 |
+
yield chatbot+[(inputs, "")], history, "开始生成回答……", all_token_counts
|
| 320 |
|
| 321 |
if stream:
|
| 322 |
logging.info("使用流式传输")
|
|
|
|
| 334 |
display_append=link_references
|
| 335 |
)
|
| 336 |
for chatbot, history, status_text, all_token_counts in iter:
|
| 337 |
+
if shared.state.interrupted:
|
| 338 |
+
shared.state.recover()
|
| 339 |
+
return
|
| 340 |
yield chatbot, history, status_text, all_token_counts
|
| 341 |
else:
|
| 342 |
logging.info("不使用流式传输")
|
llama_func.py → modules/llama_func.py
RENAMED
|
@@ -14,8 +14,8 @@ from langchain.llms import OpenAI
|
|
| 14 |
import colorama
|
| 15 |
|
| 16 |
|
| 17 |
-
from presets import *
|
| 18 |
-
from utils import *
|
| 19 |
|
| 20 |
|
| 21 |
def get_documents(file_src):
|
|
|
|
| 14 |
import colorama
|
| 15 |
|
| 16 |
|
| 17 |
+
from modules.presets import *
|
| 18 |
+
from modules.utils import *
|
| 19 |
|
| 20 |
|
| 21 |
def get_documents(file_src):
|
overwrites.py → modules/overwrites.py
RENAMED
|
@@ -5,8 +5,8 @@ from llama_index import Prompt
|
|
| 5 |
from typing import List, Tuple
|
| 6 |
import mdtex2html
|
| 7 |
|
| 8 |
-
from presets import *
|
| 9 |
-
from llama_func import *
|
| 10 |
|
| 11 |
|
| 12 |
def compact_text_chunks(self, prompt: Prompt, text_chunks: List[str]) -> List[str]:
|
|
@@ -51,5 +51,5 @@ def reload_javascript():
|
|
| 51 |
return res
|
| 52 |
|
| 53 |
gr.routes.templates.TemplateResponse = template_response
|
| 54 |
-
|
| 55 |
GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse
|
|
|
|
| 5 |
from typing import List, Tuple
|
| 6 |
import mdtex2html
|
| 7 |
|
| 8 |
+
from modules.presets import *
|
| 9 |
+
from modules.llama_func import *
|
| 10 |
|
| 11 |
|
| 12 |
def compact_text_chunks(self, prompt: Prompt, text_chunks: List[str]) -> List[str]:
|
|
|
|
| 51 |
return res
|
| 52 |
|
| 53 |
gr.routes.templates.TemplateResponse = template_response
|
| 54 |
+
|
| 55 |
GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse
|
presets.py → modules/presets.py
RENAMED
|
@@ -14,9 +14,10 @@ read_timeout_prompt = "读取超时,无法获取对话。" # 读取超时
|
|
| 14 |
proxy_error_prompt = "代理错误,无法获取对话。" # 代理错误
|
| 15 |
ssl_error_prompt = "SSL错误,无法获取对话。" # SSL 错误
|
| 16 |
no_apikey_msg = "API key长度不是51位,请检查是否输入正确。" # API key 长度不足 51 位
|
|
|
|
| 17 |
|
| 18 |
max_token_streaming = 3500 # 流式对话时的最大 token 数
|
| 19 |
-
timeout_streaming =
|
| 20 |
max_token_all = 3500 # 非流式对话时的最大 token 数
|
| 21 |
timeout_all = 200 # 非流式对话时的超时时间
|
| 22 |
enable_streaming_option = True # 是否启用选择选择是否实时显示回答的勾选框
|
|
|
|
| 14 |
proxy_error_prompt = "代理错误,无法获取对话。" # 代理错误
|
| 15 |
ssl_error_prompt = "SSL错误,无法获取对话。" # SSL 错误
|
| 16 |
no_apikey_msg = "API key长度不是51位,请检查是否输入正确。" # API key 长度不足 51 位
|
| 17 |
+
no_input_msg = "请输入对话内容。" # 未输入对话内容
|
| 18 |
|
| 19 |
max_token_streaming = 3500 # 流式对话时的最大 token 数
|
| 20 |
+
timeout_streaming = 5 # 流式对话时的超时时间
|
| 21 |
max_token_all = 3500 # 非流式对话时的最大 token 数
|
| 22 |
timeout_all = 200 # 非流式对话时的超时时间
|
| 23 |
enable_streaming_option = True # 是否启用选择选择是否实时显示回答的勾选框
|
modules/shared.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from modules.presets import API_URL
|
| 2 |
+
|
| 3 |
+
class State:
|
| 4 |
+
interrupted = False
|
| 5 |
+
api_url = API_URL
|
| 6 |
+
|
| 7 |
+
def interrupt(self):
|
| 8 |
+
self.interrupted = True
|
| 9 |
+
|
| 10 |
+
def recover(self):
|
| 11 |
+
self.interrupted = False
|
| 12 |
+
|
| 13 |
+
def set_api_url(self, api_url):
|
| 14 |
+
self.api_url = api_url
|
| 15 |
+
|
| 16 |
+
def reset_api_url(self):
|
| 17 |
+
self.api_url = API_URL
|
| 18 |
+
return self.api_url
|
| 19 |
+
|
| 20 |
+
def reset_all(self):
|
| 21 |
+
self.interrupted = False
|
| 22 |
+
self.api_url = API_URL
|
| 23 |
+
|
| 24 |
+
state = State()
|
utils.py → modules/utils.py
RENAMED
|
@@ -19,9 +19,13 @@ from pygments import highlight
|
|
| 19 |
from pygments.lexers import get_lexer_by_name
|
| 20 |
from pygments.formatters import HtmlFormatter
|
| 21 |
|
| 22 |
-
from presets import *
|
|
|
|
| 23 |
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
if TYPE_CHECKING:
|
| 27 |
from typing import TypedDict
|
|
@@ -107,10 +111,12 @@ def convert_mdtext(md_text):
|
|
| 107 |
result = "".join(result)
|
| 108 |
return result
|
| 109 |
|
|
|
|
| 110 |
def convert_user(userinput):
|
| 111 |
userinput = userinput.replace("\n", "<br>")
|
| 112 |
return f"<pre>{userinput}</pre>"
|
| 113 |
|
|
|
|
| 114 |
def detect_language(code):
|
| 115 |
if code.startswith("\n"):
|
| 116 |
first_line = ""
|
|
@@ -297,20 +303,19 @@ def reset_state():
|
|
| 297 |
|
| 298 |
|
| 299 |
def reset_textbox():
|
|
|
|
| 300 |
return gr.update(value="")
|
| 301 |
|
| 302 |
|
| 303 |
def reset_default():
|
| 304 |
-
|
| 305 |
-
API_URL = "https://api.openai.com/v1/chat/completions"
|
| 306 |
os.environ.pop("HTTPS_PROXY", None)
|
| 307 |
os.environ.pop("https_proxy", None)
|
| 308 |
-
return gr.update(value=
|
| 309 |
|
| 310 |
|
| 311 |
def change_api_url(url):
|
| 312 |
-
|
| 313 |
-
API_URL = url
|
| 314 |
msg = f"API地址更改为了{url}"
|
| 315 |
logging.info(msg)
|
| 316 |
return msg
|
|
@@ -384,13 +389,22 @@ def find_n(lst, max_num):
|
|
| 384 |
|
| 385 |
for i in range(len(lst)):
|
| 386 |
if total - lst[i] < max_num:
|
| 387 |
-
return n - i -1
|
| 388 |
total = total - lst[i]
|
| 389 |
return 1
|
| 390 |
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
), gr.Button.update(
|
| 395 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 396 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
from pygments.lexers import get_lexer_by_name
|
| 20 |
from pygments.formatters import HtmlFormatter
|
| 21 |
|
| 22 |
+
from modules.presets import *
|
| 23 |
+
import modules.shared as shared
|
| 24 |
|
| 25 |
+
logging.basicConfig(
|
| 26 |
+
level=logging.INFO,
|
| 27 |
+
format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
|
| 28 |
+
)
|
| 29 |
|
| 30 |
if TYPE_CHECKING:
|
| 31 |
from typing import TypedDict
|
|
|
|
| 111 |
result = "".join(result)
|
| 112 |
return result
|
| 113 |
|
| 114 |
+
|
| 115 |
def convert_user(userinput):
|
| 116 |
userinput = userinput.replace("\n", "<br>")
|
| 117 |
return f"<pre>{userinput}</pre>"
|
| 118 |
|
| 119 |
+
|
| 120 |
def detect_language(code):
|
| 121 |
if code.startswith("\n"):
|
| 122 |
first_line = ""
|
|
|
|
| 303 |
|
| 304 |
|
| 305 |
def reset_textbox():
|
| 306 |
+
logging.debug("重置文本框")
|
| 307 |
return gr.update(value="")
|
| 308 |
|
| 309 |
|
| 310 |
def reset_default():
|
| 311 |
+
newurl = shared.state.reset_all()
|
|
|
|
| 312 |
os.environ.pop("HTTPS_PROXY", None)
|
| 313 |
os.environ.pop("https_proxy", None)
|
| 314 |
+
return gr.update(value=newurl), gr.update(value=""), "API URL 和代理已重置"
|
| 315 |
|
| 316 |
|
| 317 |
def change_api_url(url):
|
| 318 |
+
shared.state.set_api_url(url)
|
|
|
|
| 319 |
msg = f"API地址更改为了{url}"
|
| 320 |
logging.info(msg)
|
| 321 |
return msg
|
|
|
|
| 389 |
|
| 390 |
for i in range(len(lst)):
|
| 391 |
if total - lst[i] < max_num:
|
| 392 |
+
return n - i - 1
|
| 393 |
total = total - lst[i]
|
| 394 |
return 1
|
| 395 |
|
| 396 |
+
|
| 397 |
+
def start_outputing():
|
| 398 |
+
logging.debug("显示取消按钮,隐藏发送按钮")
|
| 399 |
+
return gr.Button.update(visible=False), gr.Button.update(visible=True)
|
| 400 |
+
|
| 401 |
+
def end_outputing():
|
| 402 |
+
return (
|
| 403 |
+
gr.Button.update(visible=True),
|
| 404 |
+
gr.Button.update(visible=False),
|
| 405 |
)
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
def cancel_outputing():
|
| 409 |
+
logging.info("中止输出……")
|
| 410 |
+
shared.state.interrupt()
|