File size: 4,250 Bytes
9c880cb
 
5bdf9aa
b4dff1d
5bdf9aa
 
b4dff1d
 
30bf3f3
32957d4
 
 
 
 
 
 
b08a6f9
32957d4
 
 
4847675
bb9bd92
32957d4
8ae421b
32957d4
4ae262e
87dda7a
90f6ea5
bb9bd92
87dda7a
32957d4
a5db718
871126f
30bf3f3
b08a6f9
871126f
 
 
a5db718
 
b4dff1d
32957d4
b4dff1d
 
871126f
90f6ea5
 
4847675
bb9bd92
5bdf9aa
4847675
1e06dbb
bb9bd92
 
 
8ae421b
bb9bd92
6c72519
bb9bd92
 
 
 
 
 
4847675
bb9bd92
 
b08a6f9
bb9bd92
 
 
 
5fd690f
 
b08a6f9
bb9bd92
90f6ea5
bb9bd92
 
 
 
 
 
 
 
 
 
 
b08a6f9
bb9bd92
 
 
 
 
 
 
 
 
 
9c880cb
30bf3f3
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import gradio as gr
from huggingface_hub import InferenceClient
import os
from threading import Event

hf_token = os.getenv("HF_TOKEN")
stop_event = Event()

# 모델 목록 정의
models = {
    "deepseek-ai/DeepSeek-Coder-V2-Instruct": "(한국회사)DeepSeek-Coder-V2-Instruct",
    "meta-llama/Meta-Llama-3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct",
    "mistralai/Mixtral-8x7B-Instruct-v0.1": "Mixtral-8x7B-Instruct-v0.1",
    "CohereForAI/c4ai-command-r-plus": "Cohere Command-R Plus"
}

# Inference 클라이언트 반환
def get_client(model):
    return InferenceClient(model=model, token=hf_token)

# 응답 생성 함수 (스트리밍 방식, 자문자답 방지)
def respond(prompt, system_message, max_tokens, temperature, top_p, selected_model):
    stop_event.clear()
    client = get_client(selected_model)
    
    # 프롬프트 설정
    messages = [
        {"role": "system", "content": system_message + "\n입력에 대해서만 답변하세요. 추가 질문을 하지 마세요. 입력 내용만 반영하세요."},
        {"role": "user", "content": prompt}
    ]
    
    try:
        response = ""
        
        # 모델에서 응답을 스트리밍
        for chunk in client.text_generation(
            prompt="\n".join([f"{m['role']}: {m['content']}" for m in messages]),
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stream=True
        ):
            if stop_event.is_set():
                break
            if chunk:
                # 모델이 스스로 질문을 하지 않도록 패턴을 체크
                response += chunk
                yield [(prompt, response.strip())]  # 실시간으로 부분적인 응답 반환
    
    except Exception as e:
        yield [(prompt, f"오류 발생: {str(e)}")]

# 응답 중단 함수
def stop_generation():
    stop_event.set()

# Gradio UI 구성
with gr.Blocks() as demo:
    gr.Markdown("# 프롬프트 최적화 Playground")
    
    gr.Markdown("""
    **주의사항:**
    - '전송' 버튼을 클릭하거나 입력 필드에서 Shift+Enter를 눌러 메시지를 전송할 수 있습니다.
    - Enter 키는 줄바꿈으로 작동합니다.
    - 입력한 내용에 대해서만 응답하도록 설정되어 있으며, 모델이 추가 질문을 만들거나 입력을 확장하지 않도록 설정됩니다.
    """)
    
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Accordion("모델 설정", open=True):
                model = gr.Radio(list(models.keys()), value=list(models.keys())[0], label="언어 모델 선택", info="사용할 언어 모델을 선택하세요")
                max_tokens = gr.Slider(minimum=1, maximum=2000, value=500, step=100, label="최대 새 토큰 수")
                temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.05, label="온도")
                top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.90, step=0.05, label="Top-p (핵 샘플링)")
            
            system_message = gr.Textbox(
                value="너는 나의 최고의 비서이다. 정확하게 답변하라. 추가 질문을 하지 말고, 사용자의 입력 내용에 대해서만 답변하라.",
                label="시스템 메시지",
                lines=5
            )
        
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(height=400, label="대화 결과")
            prompt = gr.Textbox(
                label="내용 입력", 
                lines=3,
                placeholder="메시지를 입력하세요. 전송 버튼을 클릭하거나 Shift+Enter를 눌러 전송합니다."
            )
            
            with gr.Row():
                send = gr.Button("전송")
                stop = gr.Button("🛑 생성 중단")
                clear = gr.Button("🗑️ 대화 내역 지우기")

    # Event handlers
    send.click(respond, inputs=[prompt, system_message, max_tokens, temperature, top_p, model], outputs=[chatbot])
    prompt.submit(respond, inputs=[prompt, system_message, max_tokens, temperature, top_p, model], outputs=[chatbot])
    stop.click(stop_generation)
    clear.click(lambda: None, outputs=[chatbot])

# UI 실행
demo.launch()