import gradio as gr
from huggingface_hub import InferenceClient
import os

MODELS = {
    "Zephyr 7B Beta": "HuggingFaceH4/zephyr-7b-beta",
    "DeepSeek Coder V2": "deepseek-ai/DeepSeek-Coder-V2-Instruct",
    "Meta Llama 3.1 8B": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "Mixtral 8x7B": "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "Cohere Command R+": "CohereForAI/c4ai-command-r-plus",
}
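
# Note: some of these checkpoints (e.g. Llama 3.1 and Command R+) are gated
# on the Hugging Face Hub, so the account behind HF_TOKEN may need approved
# access before inference requests against them succeed.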

def get_client(model_name):
    """Build an InferenceClient for the selected model; requires HF_TOKEN."""
    model_id = MODELS[model_name]
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN environment variable is required")
    return InferenceClient(model_id, token=hf_token)
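
# Example: get_client("Zephyr 7B Beta") returns a client bound to
# HuggingFaceH4/zephyr-7b-beta, authenticated with the HF_TOKEN value.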

def respond(
    message,
    chat_history,
    model_name,
    max_tokens,
    temperature,
    top_p,
    system_message,
):
    try:
        client = get_client(model_name)
    except ValueError as e:
        # respond is a generator (it yields below), so a bare
        # `return chat_history` would be discarded by Gradio; yield the
        # error message first so it actually reaches the UI.
        chat_history.append((message, str(e)))
        yield chat_history
        return

    # Rebuild the full conversation in chat-completion message format.
    messages = [{"role": "system", "content": system_message}]
    for human, assistant in chat_history:
        messages.append({"role": "user", "content": human})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": message})

    try:
        # Stream the completion and update the chat incrementally.
        stream = client.chat_completion(
            messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stream=True,
        )
        partial_message = ""
        for response in stream:
            delta = response.choices[0].delta.content
            if delta is not None:
                partial_message += delta
                # Yield the existing history plus the in-progress reply.
                # Rebinding chat_history itself on every chunk (as before)
                # appended a duplicate (message, partial) pair per token.
                yield chat_history + [(message, partial_message)]
    except Exception as e:
        chat_history.append((message, f"An error occurred: {e}"))
        yield chat_history

def clear_conversation():
    return []

with gr.Blocks() as demo:
    gr.Markdown("# Prompting AI Chatbot")
    gr.Markdown("언어모델별 프롬프트 테스트 챗봇입니다.")

    with gr.Row():
        with gr.Column(scale=1):
            model_name = gr.Radio(
                choices=list(MODELS.keys()),
                label="Language Model",
                value="Zephyr 7B Beta"
            )
            max_tokens = gr.Slider(minimum=100, maximum=2000, value=500, step=100, label="Max Tokens")
            temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.05, label="Temperature")
            top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p")
            system_message = gr.Textbox(
                value="You are a friendly and helpful AI assistant.",
                label="System Message",
                lines=3
            )

        with gr.Column(scale=2):
            chatbot = gr.Chatbot()
            msg = gr.Textbox(label="Enter your message")
            with gr.Row():
                submit_button = gr.Button("Send")
                clear_button = gr.Button("Clear Conversation")

    msg.submit(respond, [msg, chatbot, model_name, max_tokens, temperature, top_p, system_message], chatbot)
    submit_button.click(respond, [msg, chatbot, model_name, max_tokens, temperature, top_p, system_message], chatbot)
    clear_button.click(clear_conversation, outputs=chatbot, queue=False)

if __name__ == "__main__":
    demo.launch()
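
# A minimal sketch for running this app locally (assumes the gradio and
# huggingface_hub packages are installed and HF_TOKEN holds a valid token):
#
#   export HF_TOKEN=hf_...
#   python app.py
#
# By default Gradio serves the UI at http://127.0.0.1:7860.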