curry tang commited on
Commit
0f77540
·
1 Parent(s): 95d0aed
Files changed (3) hide show
  1. app.py +37 -21
  2. requirements-dev.lock +2 -4
  3. requirements.lock +2 -4
app.py CHANGED
@@ -4,12 +4,11 @@ from langchain_core.messages import HumanMessage, AIMessage
4
  from llm import DeepSeekLLM
5
  from config import settings
6
 
 
7
  deep_seek_llm = DeepSeekLLM(api_key=settings.deep_seek_api_key)
8
- chat = ChatOpenAI(model=deep_seek_llm.default_model, api_key=deep_seek_llm.api_key, base_url=deep_seek_llm.base_url)
9
 
10
 
11
- def predict(message, history, model: str, temperature: float, max_tokens: int):
12
- print('???model', model, temperature, max_tokens)
13
  history_messages = []
14
  for human, assistant in history:
15
  history_messages.append(HumanMessage(content=human))
@@ -22,38 +21,52 @@ def predict(message, history, model: str, temperature: float, max_tokens: int):
22
  yield response_message
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
25
  with gr.Blocks() as app:
26
  with gr.Tab('聊天'):
 
 
 
 
 
 
 
27
  with gr.Row():
28
  with gr.Column(scale=2, min_width=600):
29
  chatbot = gr.ChatInterface(
30
  predict,
31
  multimodal=True,
32
  chatbot=gr.Chatbot(elem_id="chatbot", height=600),
33
- textbox=gr.MultimodalTextbox(),
34
- additional_inputs=[
35
- gr.Dropdown(choices=deep_seek_llm.support_models, label='模型'),
36
- gr.Slider(
 
 
 
 
 
 
 
 
37
  minimum=0.0,
38
  maximum=1.0,
39
  step=0.1,
 
40
  label="Temperature",
41
  key="temperature",
42
- ),
43
- gr.Number(
44
- minimum=1024,
45
- maximum=1024 * 20,
46
- step=128,
47
- value=4096,
48
- label="Max Tokens",
49
- key="max_tokens",
50
  )
51
- ],
52
- )
53
- with gr.Column(scale=1, min_width=300):
54
- with gr.Accordion('Select Model', open=True):
55
- with gr.Column():
56
- gr.Number(
57
  minimum=1024,
58
  maximum=1024 * 20,
59
  step=128,
@@ -61,6 +74,9 @@ with gr.Blocks() as app:
61
  label="Max Tokens",
62
  key="max_tokens",
63
  )
 
 
 
64
 
65
  with gr.Tab('画图'):
66
  with gr.Row():
 
4
  from llm import DeepSeekLLM
5
  from config import settings
6
 
7
+
8
  deep_seek_llm = DeepSeekLLM(api_key=settings.deep_seek_api_key)
 
9
 
10
 
11
+ def predict(message, history, chat):
 
12
  history_messages = []
13
  for human, assistant in history:
14
  history_messages.append(HumanMessage(content=human))
 
21
  yield response_message
22
 
23
 
24
+ def update_chat(_chat, _model: str, _temperature: float, _max_tokens: int):
25
+ _chat = ChatOpenAI(
26
+ model=_model,
27
+ api_key=deep_seek_llm.api_key,
28
+ base_url=deep_seek_llm.base_url,
29
+ temperature=_temperature,
30
+ max_tokens=_max_tokens,
31
+ )
32
+ return _chat
33
+
34
+
35
  with gr.Blocks() as app:
36
  with gr.Tab('聊天'):
37
+ chat_engine = gr.State(
38
+ value=ChatOpenAI(
39
+ model=deep_seek_llm.default_model,
40
+ api_key=deep_seek_llm.api_key,
41
+ base_url=deep_seek_llm.base_url,
42
+ )
43
+ )
44
  with gr.Row():
45
  with gr.Column(scale=2, min_width=600):
46
  chatbot = gr.ChatInterface(
47
  predict,
48
  multimodal=True,
49
  chatbot=gr.Chatbot(elem_id="chatbot", height=600),
50
+ textbox=gr.MultimodalTextbox(lines=1),
51
+ additional_inputs=[chat_engine]
52
+ )
53
+ with gr.Column(scale=1, min_width=300):
54
+ with gr.Accordion('Select Model', open=True):
55
+ with gr.Column():
56
+ model = gr.Dropdown(
57
+ label='模型',
58
+ choices=deep_seek_llm.support_models,
59
+ value=deep_seek_llm.default_model
60
+ )
61
+ temperature = gr.Slider(
62
  minimum=0.0,
63
  maximum=1.0,
64
  step=0.1,
65
+ value=0.5,
66
  label="Temperature",
67
  key="temperature",
 
 
 
 
 
 
 
 
68
  )
69
+ max_tokens = gr.Number(
 
 
 
 
 
70
  minimum=1024,
71
  maximum=1024 * 20,
72
  step=128,
 
74
  label="Max Tokens",
75
  key="max_tokens",
76
  )
77
+ model.change(fn=update_chat, inputs=[chat_engine, model, temperature, max_tokens], outputs=[chat_engine])
78
+ temperature.change(fn=update_chat, inputs=[chat_engine, model, temperature, max_tokens], outputs=[chat_engine])
79
+ max_tokens.change(fn=update_chat, inputs=[chat_engine, model, temperature, max_tokens], outputs=[chat_engine])
80
 
81
  with gr.Tab('画图'):
82
  with gr.Row():
requirements-dev.lock CHANGED
@@ -46,10 +46,6 @@ click==8.1.7
46
  # via nltk
47
  # via typer
48
  # via uvicorn
49
- colorama==0.4.6
50
- # via click
51
- # via tqdm
52
- # via uvicorn
53
  contourpy==1.2.1
54
  # via matplotlib
55
  cycler==0.12.1
@@ -404,6 +400,8 @@ urllib3==2.2.2
404
  uvicorn==0.30.1
405
  # via fastapi
406
  # via gradio
 
 
407
  watchfiles==0.22.0
408
  # via uvicorn
409
  websockets==11.0.3
 
46
  # via nltk
47
  # via typer
48
  # via uvicorn
 
 
 
 
49
  contourpy==1.2.1
50
  # via matplotlib
51
  cycler==0.12.1
 
400
  uvicorn==0.30.1
401
  # via fastapi
402
  # via gradio
403
+ uvloop==0.19.0
404
+ # via uvicorn
405
  watchfiles==0.22.0
406
  # via uvicorn
407
  websockets==11.0.3
requirements.lock CHANGED
@@ -46,10 +46,6 @@ click==8.1.7
46
  # via nltk
47
  # via typer
48
  # via uvicorn
49
- colorama==0.4.6
50
- # via click
51
- # via tqdm
52
- # via uvicorn
53
  contourpy==1.2.1
54
  # via matplotlib
55
  cycler==0.12.1
@@ -404,6 +400,8 @@ urllib3==2.2.2
404
  uvicorn==0.30.1
405
  # via fastapi
406
  # via gradio
 
 
407
  watchfiles==0.22.0
408
  # via uvicorn
409
  websockets==11.0.3
 
46
  # via nltk
47
  # via typer
48
  # via uvicorn
 
 
 
 
49
  contourpy==1.2.1
50
  # via matplotlib
51
  cycler==0.12.1
 
400
  uvicorn==0.30.1
401
  # via fastapi
402
  # via gradio
403
+ uvloop==0.19.0
404
+ # via uvicorn
405
  watchfiles==0.22.0
406
  # via uvicorn
407
  websockets==11.0.3