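"""Gradio chat demo that streams completions from the Hugging Face Inference API.

The app lets the user pick one of several hosted models, streams the reply
chunk by chunk into a Chatbot component, and supports stopping, regenerating,
and continuing the last response. An HF_TOKEN environment variable is read
for authenticated access to the models.
"""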
import gradio as gr
from huggingface_hub import InferenceClient
import os
from threading import Event

hf_token = os.getenv("HF_TOKEN")  # token for the Hugging Face Inference API
stop_event = Event()  # cooperative stop flag checked inside the streaming loops

models = {
    "deepseek-ai/DeepSeek-Coder-V2-Instruct": "(한국회사)DeepSeek-Coder-V2-Instruct",
    "meta-llama/Meta-Llama-3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct",
    "mistralai/Mixtral-8x7B-Instruct-v0.1": "Mixtral-8x7B-Instruct-v0.1",
    "CohereForAI/c4ai-command-r-plus": "Cohere Command-R Plus"
}
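
# Note: these are serverless Inference API model IDs. Some of these checkpoints
# are gated on the Hub, so the HF_TOKEN in use may need to have been granted
# access; availability of hosted endpoints can also change over time.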

def get_client(model):
    # Build a client bound to the selected model on the Inference API.
    return InferenceClient(model=model, token=hf_token)

MAX_HISTORY_LENGTH = 5  # maximum number of (user, assistant) turns kept in the prompt

def truncate_history(history):
    # Keep only the most recent turns so the prompt stays within context limits.
    return history[-MAX_HISTORY_LENGTH:]
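
# For example, with MAX_HISTORY_LENGTH = 5 a seven-turn history keeps turns 3-7:
#   truncate_history([(f"u{i}", f"a{i}") for i in range(1, 8)])
#   -> [("u3", "a3"), ("u4", "a4"), ("u5", "a5"), ("u6", "a6"), ("u7", "a7")]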

def respond(message, chat_history, system_message, max_tokens, temperature, top_p, selected_model):
    stop_event.clear()
    client = get_client(selected_model)
    chat_history = truncate_history(chat_history or [])

    # Rebuild the conversation as a role-prefixed plain-text prompt.
    messages = [{"role": "system", "content": system_message}]
    for user_msg, bot_msg in chat_history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": message})

    try:
        response = ""
        for chunk in client.text_generation(
            prompt="\n".join(f"{m['role']}: {m['content']}" for m in messages),
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stream=True
        ):
            if stop_event.is_set():
                break
            if chunk:
                response += chunk
                # Stream the partial answer into the chat window.
                yield chat_history + [(message, response)]

    except Exception as e:
        yield chat_history + [(message, f"Error: {e}")]

def stop_generation():
    # Signal the streaming loops to stop; the notice lands in the message box.
    stop_event.set()
    return "Generation stopped."

def regenerate(chat_history, system_message, max_tokens, temperature, top_p, selected_model):
    # Drop the last answer and re-run its user message through respond().
    if not chat_history:
        yield []
        return
    last_user_message = chat_history[-1][0]
    yield from respond(last_user_message, chat_history[:-1], system_message, max_tokens, temperature, top_p, selected_model)

def continue_writing(chat_history, system_message, max_tokens, temperature, top_p, selected_model):
    if not chat_history:
        yield []
        return
    stop_event.clear()
    client = get_client(selected_model)

    last_user, last_response = chat_history[-1]
    last_response = last_response or ""
    prompt = f"Please continue the previous response. Previous response: {last_response}"
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": prompt}
    ]

    try:
        response = last_response
        for chunk in client.text_generation(
            prompt="\n".join(f"{m['role']}: {m['content']}" for m in messages),
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stream=True
        ):
            if stop_event.is_set():
                break
            if chunk:
                response += chunk
                # Append the continuation to the existing last turn.
                yield chat_history[:-1] + [(last_user, response)]

    except Exception as e:
        yield chat_history[:-1] + [(last_user, f"Error: {e}")]

# Gradio interface
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Message")

    with gr.Row():
        send = gr.Button("Send")
        continue_btn = gr.Button("Continue")
        regen = gr.Button("🔄 Regenerate")
        stop = gr.Button("🛑 Stop generation")
        clear = gr.Button("🗑️ Clear history")

    with gr.Accordion("Advanced settings", open=True):
        system_message = gr.Textbox(
            value="You are my best assistant.\nAnswer my requests as thoroughly and accurately as possible.\nAlways answer in Korean.",
            label="System message",
            lines=5
        )
        max_tokens = gr.Slider(minimum=1, maximum=2000, value=500, step=100, label="Max new tokens")
        temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.05, label="Temperature")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.90, step=0.05, label="Top-p (nucleus sampling)")
        model = gr.Radio(list(models.keys()), value=list(models.keys())[0], label="Language model", info="Select the model to use")

    # Event handlers
    send.click(respond, inputs=[msg, chatbot, system_message, max_tokens, temperature, top_p, model], outputs=[chatbot])
    msg.submit(respond, inputs=[msg, chatbot, system_message, max_tokens, temperature, top_p, model], outputs=[chatbot])
    continue_btn.click(continue_writing,
                       inputs=[chatbot, system_message, max_tokens, temperature, top_p, model],
                       outputs=[chatbot])
    regen.click(regenerate,
                inputs=[chatbot, system_message, max_tokens, temperature, top_p, model],
                outputs=[chatbot])
    stop.click(stop_generation, outputs=[msg])
    clear.click(lambda: None, outputs=[chatbot])

if __name__ == "__main__":
    if not hf_token:
        print("Warning: the HF_TOKEN environment variable is not set; some models may be inaccessible.")
    demo.launch()
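
# To run locally (hf_xxx is a placeholder token):
#   HF_TOKEN=hf_xxx python app.py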