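"""Gradio chat UI for hosted models on the Hugging Face Inference API.

Streams chat completions from a user-selected model and adds
continue-writing, regenerate, and stop controls, plus an
Enter/Shift+Enter key swap in the message box.
"""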
import gradio as gr
from huggingface_hub import InferenceClient
import os
from threading import Event

hf_token = os.getenv("HF_TOKEN")
stop_event = Event()  # set by the Stop button; checked inside the streaming loop

def get_model_response(client, messages, max_tokens, temperature, top_p):
    """Stream a chat completion, yielding the accumulated text after each token."""
    try:
        response = client.chat_completion(
            messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stream=True
        )
        full_response = ""
        for message in response:
            if stop_event.is_set():
                break
            # Streaming chunks normally carry a `delta`; fall back to `text`
            # for completion-style payloads.
            if hasattr(message.choices[0], 'delta'):
                token = message.choices[0].delta.content
            else:
                token = message.choices[0].text
            if token:
                full_response += token
                yield full_response
    except Exception as e:
        yield f"Model inference failed: {str(e)}"

def respond(message, history, system_message, max_tokens, temperature, top_p, selected_model):
    stop_event.clear()
    try:
        client = InferenceClient(model=selected_model, token=hf_token)

        # Rebuild the chat transcript: system prompt, then alternating
        # user/assistant turns from the (user, assistant) history pairs.
        messages = [{"role": "system", "content": system_message}]
        for user_msg, assistant_msg in history:
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
        messages.append({"role": "user", "content": message})

        history.append((message, ""))
        for response in get_model_response(client, messages, max_tokens, temperature, top_p):
            history[-1] = (message, response)
            yield "", history

    except Exception as e:
        history.append((message, f"Error: {str(e)}"))
        yield "", history

def continue_writing(history, system_message, max_tokens, temperature, top_p, model):
    if not history:
        # Generator function: yield (not return) so Gradio still receives outputs.
        yield "", history
        return
    last_user_message = history[-1][0]
    last_assistant_message = history[-1][1]

    prompt = f"Please continue the previous conversation. Previous response: {last_assistant_message}"

    try:
        client = InferenceClient(model=model, token=hf_token)
        messages = [{"role": "system", "content": system_message}]
        for user_msg, assistant_msg in history:
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
        messages.append({"role": "user", "content": prompt})

        for response in get_model_response(client, messages, max_tokens, temperature, top_p):
            # Extend the last assistant message in place with the continuation.
            continued_response = last_assistant_message + " " + response
            history[-1] = (last_user_message, continued_response)
            yield "", history
    except Exception as e:
        history.append(("System", f"Error while continuing: {str(e)}"))
        yield "", history

def stop_generation():
    stop_event.set()
    return "Generation stopped."

# Model id -> display label for the radio selector. These are hosted models
# on the Inference API; availability may change over time.
models = {
    "deepseek-ai/DeepSeek-Coder-V2-Instruct": "DeepSeek-Coder-V2-Instruct",
    "meta-llama/Meta-Llama-3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct",
    "mistralai/Mixtral-8x7B-Instruct-v0.1": "Mixtral-8x7B-Instruct-v0.1",
    "CohereForAI/c4ai-command-r-plus": "Cohere Command-R Plus"
}

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(height=600)

    with gr.Row():
        msg = gr.Textbox(
            scale=4,
            label="Message",
            lines=3,
            placeholder="Type a message. Enter for a new line, Shift+Enter to send",
            elem_id="msg_input"  # stable id for the key-handling script below
        )

    with gr.Row():
        send = gr.Button("Send", scale=1, elem_id="send_btn")
        continue_btn = gr.Button("Continue writing", scale=1)

    with gr.Row():
        regenerate = gr.Button("🔄 Regenerate")
        stop = gr.Button("🛑 Stop generation")
        clear = gr.Button("🗑️ Clear history")

    with gr.Accordion("Advanced settings", open=True):
        system_message = gr.Textbox(
            value="You are my finest assistant.\nAnswer my requests as thoroughly and accurately as possible.\nAlways answer in Korean.",
            label="System message",
            lines=10
        )
        max_tokens = gr.Slider(minimum=1, maximum=2000, value=500, step=100, label="Max new tokens")
        temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.05, label="Temperature")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.90, step=0.05, label="Top-p (nucleus sampling)")
        # (label, value) tuples let the radio show display names while
        # returning the model ids defined above.
        model = gr.Radio(
            choices=[(label, model_id) for model_id, label in models.items()],
            value=list(models.keys())[0],
            label="Language model",
            info="Choose which model to query"
        )

    # `gr.Textbox` has no `javascript` attribute, so the original assignment was
    # a silent no-op. Run the key handler once on page load instead. This
    # assumes Gradio wires its submit handling to keydown events on the
    # textarea, so a capture-phase listener on `document` can intercept first.
    demo.load(
        None,
        None,
        None,
        js="""
        () => {
            const textbox = document.querySelector("#msg_input textarea");
            if (!textbox) return;
            document.addEventListener("keydown", (e) => {
                if (e.target !== textbox || e.key !== "Enter") return;
                if (!e.shiftKey) {
                    // Plain Enter: keep the default newline but stop the event
                    // before Gradio's own submit listener sees it.
                    e.stopPropagation();
                } else {
                    // Shift+Enter: suppress the newline and send the message.
                    e.preventDefault();
                    e.stopPropagation();
                    document.querySelector("#send_btn").click();
                }
            }, true);
        }
        """
    )

    def regenerate_response(history, system_msg, max_tok, temp, top_p_val, selected_model):
        # Replay the last user message against the history minus the last
        # exchange. This must be a generator function (not a lambda returning
        # a generator object) for Gradio to stream the yields.
        if not history:
            yield "", history
            return
        yield from respond(history[-1][0], history[:-1], system_msg, max_tok, temp, top_p_val, selected_model)

    send.click(respond, inputs=[msg, chatbot, system_message, max_tokens, temperature, top_p, model], outputs=[msg, chatbot])
    msg.submit(respond, inputs=[msg, chatbot, system_message, max_tokens, temperature, top_p, model], outputs=[msg, chatbot])
    continue_btn.click(continue_writing, inputs=[chatbot, system_message, max_tokens, temperature, top_p, model], outputs=[msg, chatbot])
    regenerate.click(regenerate_response, inputs=[chatbot, system_message, max_tokens, temperature, top_p, model], outputs=[msg, chatbot])
    stop.click(stop_generation, inputs=[], outputs=[msg])
    clear.click(lambda: (None, None), outputs=[msg, chatbot])

if __name__ == "__main__":
    if not hf_token:
        print("Warning: the HF_TOKEN environment variable is not set. Some models may be inaccessible.")
    demo.launch(share=True)
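
# Example invocation, assuming a personal access token (placeholder value):
#   HF_TOKEN=hf_your_token_here python app.py
# With share=True, launch() also prints a temporary public gradio.live URL.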