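"""Gradio chat app that streams replies from Hugging Face Inference API models.

Users pick a model, tune sampling parameters, and chat; replies stream into
the Chatbot component token by token. Requires an HF_TOKEN environment variable.
"""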
import gradio as gr
from huggingface_hub import InferenceClient
import os

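# Friendly display names mapped to Hugging Face Hub model IDs.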
MODELS = {
    "Zephyr 7B Beta": "HuggingFaceH4/zephyr-7b-beta",
    "DeepSeek Coder V2": "deepseek-ai/DeepSeek-Coder-V2-Instruct",
    "Meta Llama 3.1 8B": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "Mixtral 8x7B": "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "Cohere Command R+": "CohereForAI/c4ai-command-r-plus",
}

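# Build an InferenceClient for the selected model; fail fast if no token is set.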
def get_client(model_name):
    model_id = MODELS[model_name]
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN environment variable is required")
    return InferenceClient(model_id, token=hf_token)

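# Generator: rebuilds the OpenAI-style message list from the chat history and
# streams the model's reply back to the Chatbot component chunk by chunk.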
def respond(
    message,
    chat_history,
    model_name,
    max_tokens,
    temperature,
    top_p,
    system_message,
):
    try:
        client = get_client(model_name)
    except ValueError as e:
        # respond is a generator, so the error must be yielded, not returned
        chat_history.append((message, str(e)))
        yield chat_history
        return

    messages = [{"role": "system", "content": system_message}]
    for human, assistant in chat_history:
        messages.append({"role": "user", "content": human})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": message})

    # Show the user's message right away with an empty reply, then fill it in
    # as tokens stream back from the model.
    chat_history = chat_history + [(message, "")]
    try:
        stream = client.chat_completion(
            messages,
            max_tokens=max_tokens,  # changed from 'max_new_tokens' to 'max_tokens'
            temperature=temperature,
            top_p=top_p,
            stream=True,
        )
        partial_message = ""
        for response in stream:
            if response.choices[0].delta.content is not None:
                partial_message += response.choices[0].delta.content
                # Replace the placeholder turn with the reply streamed so far
                chat_history = chat_history[:-1] + [(message, partial_message)]
                yield chat_history
    except Exception as e:
        # Surface inference errors in the chat instead of crashing the app
        yield chat_history[:-1] + [(message, f"Error: {e}")]

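# Reset the Chatbot component to an empty history.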
def clear_conversation():
    return []

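# UI layout: model and sampling controls on the left, chat panel on the right.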
with gr.Blocks() as demo:
    gr.Markdown("# Advanced AI Chatbot")
    gr.Markdown("Chat with different language models and customize your experience!")

    with gr.Row():
        with gr.Column(scale=1):
            model_name = gr.Radio(
                choices=list(MODELS.keys()),
                label="Language Model",
                value="Zephyr 7B Beta"
            )
            max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max New Tokens")
            temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
            system_message = gr.Textbox(
                value="You are a friendly and helpful AI assistant.",
                label="System Message",
                lines=3
            )

        with gr.Column(scale=2):
            chatbot = gr.Chatbot()
            msg = gr.Textbox(label="Your message")
            with gr.Row():
                submit_button = gr.Button("Submit")
                clear_button = gr.Button("Clear")

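    # Wire both Enter-to-submit and the Submit button to the streaming handler.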
    msg.submit(respond, [msg, chatbot, model_name, max_tokens, temperature, top_p, system_message], chatbot)
    submit_button.click(respond, [msg, chatbot, model_name, max_tokens, temperature, top_p, system_message], chatbot)
    clear_button.click(clear_conversation, outputs=chatbot, queue=False)

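# Launch the app when run directly; pass share=True to demo.launch() if a
# public link is needed.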
if __name__ == "__main__":
    demo.launch()