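"""Gradio chatbot demo for DeepSeek models.

Loads several deepseek-ai models from the Hugging Face Hub via gr.load()
and serves them behind a Blocks UI with a model selector and optional
generation parameters.
"""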
import gradio as gr
from functools import lru_cache

# Cache model loading so each model is only constructed once
@lru_cache(maxsize=3)
def load_hf_model(model_name):
    # Use Gradio's built-in Hugging Face loader instead of transformers_gradio.
    # With src="models" the name must be the bare Hub repo id, without a
    # "huggingface/" prefix.
    return gr.load(
        name=f"deepseek-ai/{model_name}",
        src="models",
    )

# Load all models at startup
MODELS = {
    "DeepSeek-R1-Distill-Qwen-32B": load_hf_model("DeepSeek-R1-Distill-Qwen-32B"),
    "DeepSeek-R1": load_hf_model("DeepSeek-R1"),
    "DeepSeek-R1-Zero": load_hf_model("DeepSeek-R1-Zero")
}
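
# Each value in MODELS is a callable Gradio interface: assuming the Hub
# inference endpoint for the model is live, it can be invoked directly with a
# prompt string, e.g. MODELS["DeepSeek-R1"]("Hello") -> generated text.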

# --- Chatbot function ---
def chatbot(input_text, history, model_choice, system_message, max_new_tokens, temperature, top_p):
    history = history or []

    # Get the selected model component
    model_component = MODELS[model_choice]

    # An interface loaded with gr.load(..., src="models") exposes a single
    # text-generation endpoint: it takes one prompt string and returns the
    # generated text. Fold the system message and prior turns into the prompt;
    # the extra sampling parameters cannot be forwarded through this endpoint.
    prompt_parts = [system_message]
    for user_turn, bot_turn in history:
        prompt_parts.append(f"User: {user_turn}\nAssistant: {bot_turn}")
    prompt_parts.append(f"User: {input_text}\nAssistant:")
    prompt = "\n".join(prompt_parts)

    # Run inference using the selected model
    try:
        assistant_response = model_component(prompt)
    except Exception as e:
        assistant_response = f"Error: {str(e)}"

    history.append((input_text, assistant_response))

    # Update the visible chat, persist the history state, and clear the textbox
    return history, history, ""

# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), title="DeepSeek Chatbot") as demo:
    gr.Markdown(
        """
        # DeepSeek Chatbot

        Created by [ruslanmv.com](https://ruslanmv.com/)

        This is a demo of different DeepSeek models. Select a model, type your message, and click "Submit".
        You can also adjust optional parameters like system message, max new tokens, temperature, and top-p.
        """
    )

    with gr.Row():
        with gr.Column():
            chatbot_output = gr.Chatbot(label="DeepSeek Chatbot", height=500)
            msg = gr.Textbox(label="Your Message", placeholder="Type your message here...")
            with gr.Row():
                submit_btn = gr.Button("Submit", variant="primary")
                clear_btn = gr.ClearButton([msg, chatbot_output])

    with gr.Row():
        with gr.Accordion("Options", open=True):
            model_choice = gr.Radio(
                choices=list(MODELS.keys()),
                label="Choose a Model",
                value="DeepSeek-R1"
            )
            with gr.Accordion("Optional Parameters", open=False):
                system_message = gr.Textbox(
                    label="System Message",
                    value="You are a friendly Chatbot created by ruslanmv.com",
                    lines=2,
                )
                max_new_tokens = gr.Slider(
                    minimum=1, maximum=4000, value=200, label="Max New Tokens"
                )
                temperature = gr.Slider(
                    minimum=0.10, maximum=4.00, value=0.70, label="Temperature"
                )
                top_p = gr.Slider(
                    minimum=0.10, maximum=1.00, value=0.90, label="Top-p (nucleus sampling)"
                )

    chat_history = gr.State([])
    # Also reset the stored history when Clear is pressed; otherwise the old
    # conversation reappears on the next submit.
    clear_btn.add(chat_history)

    # Event handling
    submit_btn.click(
        chatbot,
        [msg, chat_history, model_choice, system_message, max_new_tokens, temperature, top_p],
        [chatbot_output, chat_history, msg]
    )
    msg.submit(
        chatbot,
        [msg, chat_history, model_choice, system_message, max_new_tokens, temperature, top_p],
        [chatbot_output, chat_history, msg]
    )

if __name__ == "__main__":
    demo.launch()
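    # Pass share=True to demo.launch() to also serve a temporary public URL.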