File size: 3,125 Bytes
9c880cb
 
f9c7426
5bdf9aa
cfab2e6
 
 
 
 
 
32957d4
 
cfab2e6
f9c7426
d345b65
 
 
 
cfab2e6
 
 
a41f6e0
cfab2e6
 
 
 
a41f6e0
cfab2e6
f9c7426
 
 
c14f735
f9c7426
cfab2e6
 
a41f6e0
 
 
cfab2e6
 
 
c14f735
 
 
 
 
 
 
 
 
 
cfab2e6
c14f735
 
cfab2e6
f9c7426
c14f735
f9c7426
a41f6e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f9c7426
 
 
a41f6e0
 
f9c7426
 
cfab2e6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import gradio as gr
from huggingface_hub import InferenceClient
import os

# (display label, Hugging Face Hub repo id) pairs. Tuple order is kept by
# dict(), and the model selector lists its choices in this order.
_MODEL_CHOICES = (
    ("Zephyr 7B Beta", "HuggingFaceH4/zephyr-7b-beta"),
    ("DeepSeek Coder V2", "deepseek-ai/DeepSeek-Coder-V2-Instruct"),
    ("Meta Llama 3.1 8B", "meta-llama/Meta-Llama-3.1-8B-Instruct"),
    ("Mixtral 8x7B", "mistralai/Mixtral-8x7B-Instruct-v0.1"),
    ("Cohere Command R+", "CohereForAI/c4ai-command-r-plus"),
)

# Lookup table: display label -> repo id passed to InferenceClient.
MODELS = dict(_MODEL_CHOICES)

def get_client(model_name):
    """Return an InferenceClient bound to the model shown as *model_name*.

    Raises:
        KeyError: *model_name* is not a key of ``MODELS``.
        ValueError: the ``HF_TOKEN`` environment variable is unset or empty.
    """
    # Resolve the label first so an unknown model fails before the token check.
    repo_id = MODELS[model_name]
    token = os.environ.get("HF_TOKEN")
    if token:
        return InferenceClient(repo_id, token=token)
    raise ValueError("HF_TOKEN environment variable is required")

def _build_messages(message, history, system_message):
    """Turn Chatbot pair-history plus the new message into chat-completion messages."""
    messages = []
    # A blank system prompt carries no instruction; don't send an empty turn.
    if system_message and system_message.strip():
        messages.append({"role": "system", "content": system_message})
    for human, assistant in history:
        # Skip a missing half of a turn (e.g. a None reply) instead of sending null content.
        if human is not None:
            messages.append({"role": "user", "content": human})
        if assistant is not None:
            messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": message})
    return messages


def respond(
    message,
    chat_history,
    model_name,
    max_tokens,
    temperature,
    top_p,
    system_message,
):
    """Send *message* and the prior turns to the selected model; return the new history.

    Args:
        message: The user's new message from the input textbox.
        chat_history: Prior turns as ``[user, assistant]`` pairs, as held by the
            ``gr.Chatbot`` component. ``None`` is treated as an empty history.
        model_name: Display label; a key of ``MODELS``.
        max_tokens: Upper bound on generated tokens (slider value, coerced to int).
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        system_message: System prompt; omitted from the request when blank.

    Returns:
        A NEW list: the prior history plus ``[message, reply]``. The input list
        is not mutated. A missing token or an API failure is reported as the
        reply text rather than raised. A blank *message* returns the history
        unchanged without contacting the model.
    """
    history = list(chat_history or [])

    # Nothing to send: avoid an empty user turn and a wasted API call.
    if not message or not message.strip():
        return history

    try:
        client = get_client(model_name)
    except ValueError as e:
        return history + [[message, str(e)]]

    messages = _build_messages(message, history, system_message)

    try:
        response = client.chat_completion(
            messages,
            # Slider values may arrive as floats; the API expects an integer count.
            max_tokens=int(max_tokens),
            temperature=temperature,
            top_p=top_p,
        )
        assistant_message = response.choices[0].message.content
    except Exception as e:  # UI boundary: show any API failure in the chat instead of crashing.
        assistant_message = f"An error occurred: {str(e)}"

    # Same [user, assistant] pair shape as the error path above.
    return history + [[message, assistant_message]]

def clear_conversation():
    """Return a fresh, empty history; the Clear button writes it to the Chatbot."""
    empty_history = list()
    return empty_history

with gr.Blocks() as demo:
    gr.Markdown("# Advanced AI Chatbot")
    gr.Markdown("Chat with different language models and customize your experience!")

    with gr.Row():
        # Left column: model choice and generation settings.
        with gr.Column(scale=1):
            model_name = gr.Radio(
                choices=list(MODELS.keys()),
                label="Language Model",
                value="Zephyr 7B Beta"
            )
            max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max New Tokens")
            temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
            system_message = gr.Textbox(
                value="You are a friendly and helpful AI assistant.",
                label="System Message",
                lines=3
            )

        # Right column: conversation view and input controls.
        with gr.Column(scale=2):
            chatbot = gr.Chatbot()
            msg = gr.Textbox(label="Your message")
            with gr.Row():
                submit_button = gr.Button("Submit")
                clear_button = gr.Button("Clear")

    # Components feeding respond(), in its positional parameter order.
    chat_inputs = [msg, chatbot, model_name, max_tokens, temperature, top_p, system_message]

    # Enter in the textbox and the Submit button behave the same. Once the reply
    # is rendered, empty the textbox so the same message isn't resent by accident.
    for trigger in (msg.submit, submit_button.click):
        trigger(respond, chat_inputs, chatbot).then(lambda: "", None, msg)
    clear_button.click(clear_conversation, outputs=chatbot, queue=False)

def main():
    """Launch the ``demo`` Gradio app."""
    demo.launch()


if __name__ == "__main__":
    main()