# NOTE(review): the lines that were here ("Spaces:" / "Runtime error" x2) were
# artifacts scraped from the hosting page, not source code; converted to this
# comment so the file parses.
import gradio as gr
from langchain_core.messages import HumanMessage, AIMessage
from llm import DeepSeekLLM, OpenRouterLLM
from config import settings

# Module-level LLM clients, one per supported provider.
# API keys are read from the project settings object.
deep_seek_llm = DeepSeekLLM(api_key=settings.deep_seek_api_key)
open_router_llm = OpenRouterLLM(api_key=settings.open_router_api_key)
def init_chat():
    """Build and return the default chat engine (DeepSeek, default settings)."""
    engine = deep_seek_llm.get_chat_engine()
    return engine
def predict(message, history, chat):
    """Stream an assistant reply for *message*, given the prior chat *history*.

    Args:
        message: The incoming user message. With ``gr.MultimodalTextbox`` this
            is a dict like ``{"text": ..., "files": [...]}`` in current Gradio
            versions, or an object exposing a ``.text`` attribute in older ones.
        history: List of ``(human, assistant)`` string pairs from the UI.
        chat: A LangChain chat engine, or ``None`` to lazily create the default.

    Yields:
        The accumulated assistant response after each streamed chunk, so the
        UI updates token-by-token.
    """
    if chat is None:
        chat = init_chat()

    # Rebuild the LangChain message list from the UI history pairs.
    history_messages = []
    for human, assistant in history:
        history_messages.append(HumanMessage(content=human))
        history_messages.append(AIMessage(content=assistant))

    # BUG FIX: gr.MultimodalTextbox delivers a dict payload in current Gradio
    # releases, so `message.text` raised AttributeError at runtime. Support
    # both the dict payload and the legacy attribute-style payload.
    text = message["text"] if isinstance(message, dict) else message.text
    history_messages.append(HumanMessage(content=text))

    response_message = ''
    for chunk in chat.stream(history_messages):
        response_message += chunk.content
        yield response_message
def update_chat(_provider: str, _chat, _model: str, _temperature: float, _max_tokens: int):
    """Rebuild the chat engine when the provider or a model setting changes.

    Args:
        _provider: Provider name, 'DeepSeek' or 'OpenRouter'.
        _chat: The current chat engine; returned unchanged if the provider
            name is not recognized.
        _model: Model identifier for the provider.
        _temperature: Sampling temperature.
        _max_tokens: Maximum number of tokens to generate.

    Returns:
        A freshly configured chat engine for the selected provider (or the
        previous engine for an unknown provider).
    """
    # NOTE: removed a leftover debug print('?????', ...) statement.
    if _provider == 'DeepSeek':
        # The two branches are mutually exclusive, so use elif instead of a
        # second independent if.
        _chat = deep_seek_llm.get_chat_engine(model=_model, temperature=_temperature, max_tokens=_max_tokens)
    elif _provider == 'OpenRouter':
        _chat = open_router_llm.get_chat_engine(model=_model, temperature=_temperature, max_tokens=_max_tokens)
    return _chat
with gr.Blocks() as app:
    with gr.Tab('聊天'):
        # Per-session chat engine. None until first use: predict() lazily
        # creates the default engine, and update_chat() replaces it whenever
        # the model settings change.
        chat_engine = gr.State(value=None)
        with gr.Row():
            with gr.Column(scale=2, min_width=600):
                # Main chat UI; multimodal textbox allows file attachments.
                # chat_engine is threaded into predict() via additional_inputs.
                chatbot = gr.ChatInterface(
                    predict,
                    multimodal=True,
                    chatbot=gr.Chatbot(elem_id="chatbot", height=600, show_share_button=False),
                    textbox=gr.MultimodalTextbox(lines=1),
                    additional_inputs=[chat_engine]
                )
            with gr.Column(scale=1, min_width=300):
                with gr.Accordion('Select Model', open=True):
                    with gr.Column():
                        # Provider selector; drives the model settings panel below.
                        provider = gr.Dropdown(label='Provider', choices=['DeepSeek', 'OpenRouter'], value='DeepSeek')
def show_model_config_panel(_provider): | |
if _provider == 'DeepSeek': | |
with gr.Column(): | |
model = gr.Dropdown( | |
label='模型', | |
choices=deep_seek_llm.support_models, | |
value=deep_seek_llm.default_model | |
) | |
temperature = gr.Slider( | |
minimum=0.0, | |
maximum=1.0, | |
step=0.1, | |
value=deep_seek_llm.default_temperature, | |
label="Temperature", | |
key="temperature", | |
) | |
max_tokens = gr.Number( | |
minimum=1024, | |
maximum=1024 * 20, | |
step=128, | |
value=deep_seek_llm.default_max_tokens, | |
label="Max Tokens", | |
key="max_tokens", | |
) | |
model.change( | |
fn=update_chat, | |
inputs=[provider, chat_engine, model, temperature, max_tokens], | |
outputs=[chat_engine], | |
) | |
temperature.change( | |
fn=update_chat, | |
inputs=[provider, chat_engine, model, temperature, max_tokens], | |
outputs=[chat_engine], | |
) | |
max_tokens.change( | |
fn=update_chat, | |
inputs=[provider, chat_engine, model, temperature, max_tokens], | |
outputs=[chat_engine], | |
) | |
if _provider == 'OpenRouter': | |
with gr.Column(): | |
model = gr.Dropdown( | |
label='模型', | |
choices=open_router_llm.support_models, | |
value=open_router_llm.default_model | |
) | |
temperature = gr.Slider( | |
minimum=0.0, | |
maximum=1.0, | |
step=0.1, | |
value=open_router_llm.default_temperature, | |
label="Temperature", | |
key="temperature", | |
) | |
max_tokens = gr.Number( | |
minimum=1024, | |
maximum=1024 * 20, | |
step=128, | |
value=open_router_llm.default_max_tokens, | |
label="Max Tokens", | |
key="max_tokens", | |
) | |
model.change( | |
fn=update_chat, | |
inputs=[provider, chat_engine, model, temperature, max_tokens], | |
outputs=[chat_engine], | |
) | |
temperature.change( | |
fn=update_chat, | |
inputs=[provider, chat_engine, model, temperature, max_tokens], | |
outputs=[chat_engine], | |
) | |
max_tokens.change( | |
fn=update_chat, | |
inputs=[provider, chat_engine, model, temperature, max_tokens], | |
outputs=[chat_engine], | |
) | |
    with gr.Tab('画图'):
        with gr.Row():
            with gr.Column(scale=2, min_width=600):
                # Placeholder image panel; not wired to any handler yet.
                gr.Image(label="Input Image")
            with gr.Column(scale=1, min_width=300):
                # Placeholder LoRA textbox; not wired to any handler yet.
                gr.Textbox(label="LoRA")

# Launch the UI; debug mode comes from settings, API docs page disabled.
app.launch(debug=settings.debug, show_api=False)