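"""Gradio chat app that streams LLM replies through LangChain chat engines,
switchable at runtime between DeepSeek and OpenRouter providers."""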
import gradio as gr
from langchain_core.messages import HumanMessage, AIMessage
from llm import DeepSeekLLM, OpenRouterLLM
from config import settings


deep_seek_llm = DeepSeekLLM(api_key=settings.deep_seek_api_key)
open_router_llm = OpenRouterLLM(api_key=settings.open_router_api_key)


def init_chat():
    return deep_seek_llm.get_chat_engine()


def predict(message, history, chat):
    if chat is None:
        chat = init_chat()
    history_messages = []
    for human, assistant in history:
        history_messages.append(HumanMessage(content=human))
        history_messages.append(AIMessage(content=assistant))
    history_messages.append(HumanMessage(content=message.text))

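    # Accumulate streamed chunks, yielding the running text so the
    # ChatInterface renders the reply incrementally.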
    response_message = ''
    for chunk in chat.stream(history_messages):
        response_message = response_message + chunk.content
        yield response_message


def update_chat(_provider: str, _chat, _model: str, _temperature: float, _max_tokens: int):
    if _provider == 'DeepSeek':
        _chat = deep_seek_llm.get_chat_engine(model=_model, temperature=_temperature, max_tokens=_max_tokens)
    elif _provider == 'OpenRouter':
        _chat = open_router_llm.get_chat_engine(model=_model, temperature=_temperature, max_tokens=_max_tokens)
    return _chat


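# Two-tab UI: a chat tab wired to the streaming predict(), and a drawing stub.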
with gr.Blocks() as app:
    with gr.Tab('Chat'):
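        # Session-scoped chat engine, shared with predict() via additional_inputs;
        # update_chat() replaces it whenever the model settings change.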
        chat_engine = gr.State(value=None)
        with gr.Row():
            with gr.Column(scale=2, min_width=600):
                chatbot = gr.ChatInterface(
                    predict,
                    multimodal=True,
                    chatbot=gr.Chatbot(elem_id="chatbot", height=600, show_share_button=False),
                    textbox=gr.MultimodalTextbox(lines=1),
                    additional_inputs=[chat_engine]
                )
            with gr.Column(scale=1, min_width=300):
                with gr.Accordion('Select Model', open=True):
                    with gr.Column():
                        provider = gr.Dropdown(label='Provider', choices=['DeepSeek', 'OpenRouter'], value='DeepSeek')

                    @gr.render(inputs=provider)
                    def show_model_config_panel(_provider):
                        # Both providers expose the same settings, so render one
                        # panel driven by whichever LLM wrapper is selected.
                        llm = deep_seek_llm if _provider == 'DeepSeek' else open_router_llm
                        with gr.Column():
                            model = gr.Dropdown(
                                label='Model',
                                choices=llm.support_models,
                                value=llm.default_model,
                            )
                            temperature = gr.Slider(
                                minimum=0.0,
                                maximum=1.0,
                                step=0.1,
                                value=llm.default_temperature,
                                label="Temperature",
                                key="temperature",
                            )
                            max_tokens = gr.Number(
                                minimum=1024,
                                maximum=1024 * 20,
                                step=128,
                                value=llm.default_max_tokens,
                                label="Max Tokens",
                                key="max_tokens",
                            )
                        # Any change to a setting rebuilds the session's chat engine.
                        for component in (model, temperature, max_tokens):
                            component.change(
                                fn=update_chat,
                                inputs=[provider, chat_engine, model, temperature, max_tokens],
                                outputs=[chat_engine],
                            )

    with gr.Tab('Draw'):
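        # Placeholder tab: these components are not wired to any handler yet.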
        with gr.Row():
            with gr.Column(scale=2, min_width=600):
                gr.Image(label="Input Image")
            with gr.Column(scale=1, min_width=300):
                gr.Textbox(label="LoRA")


if __name__ == '__main__':
    app.launch(debug=settings.debug, show_api=False)