File size: 4,515 Bytes
9c880cb
 
5bdf9aa
b4dff1d
5bdf9aa
30bf3f3
5bdf9aa
b4dff1d
 
30bf3f3
32957d4
 
 
 
 
 
 
30bf3f3
32957d4
 
 
30bf3f3
87dda7a
32957d4
8ae421b
32957d4
30bf3f3
87dda7a
4487f96
87dda7a
 
32957d4
a5db718
871126f
30bf3f3
 
 
871126f
 
 
a5db718
 
b4dff1d
32957d4
b4dff1d
 
871126f
 
30bf3f3
 
871126f
5bdf9aa
30bf3f3
 
 
8ae421b
 
 
30bf3f3
1e06dbb
30bf3f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87dda7a
30bf3f3
 
 
 
 
 
 
87dda7a
30bf3f3
8ae421b
30bf3f3
 
 
87dda7a
30bf3f3
6c72519
30bf3f3
 
 
 
 
 
f181d4c
30bf3f3
 
 
4487f96
30bf3f3
 
 
 
 
 
 
9c880cb
30bf3f3
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import gradio as gr
from huggingface_hub import InferenceClient
import os
from threading import Event

# Hugging Face API Token을 환경 변수로부터 가져옴
hf_token = os.getenv("HF_TOKEN")
stop_event = Event()

# 모델 목록 정의
models = {
    "deepseek-ai/DeepSeek-Coder-V2-Instruct": "(한국회사)DeepSeek-Coder-V2-Instruct",
    "meta-llama/Meta-Llama-3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct",
    "mistralai/Mixtral-8x7B-Instruct-v0.1": "Mixtral-8x7B-Instruct-v0.1",
    "CohereForAI/c4ai-command-r-plus": "Cohere Command-R Plus"
}

# Inference 클라이언트를 반환하는 함수
def get_client(model):
    return InferenceClient(model=model, token=hf_token)

# 메시지 응답 생성 함수
def respond(message, system_message, max_tokens, temperature, top_p, selected_model):
    stop_event.clear()
    client = get_client(selected_model)
    
    # 프롬프트 설정
    messages = [
        {"role": "system", "content": system_message + "\n주어진 입력에만 정확히 답변하세요. 추가 질문을 만들거나 입력을 확장하지 마세요."},
        {"role": "user", "content": message}
    ]
    
    try:
        response = ""
        total_tokens_used = 0  # 사용된 토큰 수 추적
        
        # 모델에서 응답을 청크 단위로 스트리밍
        for chunk in client.text_generation(
            prompt="\n".join([f"{m['role']}: {m['content']}" for m in messages]),
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stream=True
        ):
            if stop_event.is_set():
                break
            if chunk:
                response += chunk
                total_tokens_used += len(chunk.split())  # 청크당 사용된 토큰 수 추산
                yield [(message, response, f"사용된 토큰 수: {total_tokens_used}/{max_tokens}")]
        
    except Exception as e:
        yield [(message, f"오류 발생: {str(e)}", "에러 처리 필요")]

# 이전 응답을 확인하는 함수
def get_last_response(chatbot):
    if chatbot and len(chatbot) > 0:
        return chatbot[-1][1]
    return None

# 프롬프트 비교 및 최적화를 위한 히스토리 기록 추가
class PromptHistory:
    def __init__(self):
        self.history = []

    def add_entry(self, prompt, response, model, settings):
        self.history.append({
            "prompt": prompt,
            "response": response,
            "model": model,
            "settings": settings
        })

    def get_history(self):
        return self.history

# 히스토리 객체 생성
prompt_history = PromptHistory()

# Gradio 인터페이스 함수 정의
def gradio_interface(message, system_message, max_tokens, temperature, top_p, selected_model):
    result = None
    for output in respond(message, system_message, max_tokens, temperature, top_p, selected_model):
        result = output
    
    # 프롬프트와 결과를 히스토리에 추가
    prompt_history.add_entry(
        message,
        result[0][1],  # 모델 응답
        selected_model,
        {"max_tokens": max_tokens, "temperature": temperature, "top_p": top_p}
    )
    
    return result

# 히스토리 확인용 함수
def view_history():
    return prompt_history.get_history()

# Gradio UI 구성
with gr.Blocks() as demo:
    selected_model = gr.Dropdown(choices=list(models.keys()), label="모델 선택")
    message = gr.Textbox(label="사용자 메시지")
    system_message = gr.Textbox(label="시스템 메시지", value="이 메시지를 기준으로 대화 흐름을 설정합니다.")
    max_tokens = gr.Slider(minimum=10, maximum=512, value=128, label="최대 토큰 수")
    temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, label="Temperature")
    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, label="Top-p")
    
    response_output = gr.Textbox(label="모델 응답")
    token_usage = gr.Textbox(label="토큰 사용량")
    history_button = gr.Button("히스토리 보기")
    
    # 버튼을 눌러 응답을 받는 함수 연결
    submit_button = gr.Button("응답 생성")
    submit_button.click(gradio_interface, inputs=[message, system_message, max_tokens, temperature, top_p, selected_model], outputs=[response_output, token_usage])
    
    # 히스토리 보기 기능 연결
    history_output = gr.Textbox(label="히스토리", interactive=False)
    history_button.click(view_history, outputs=history_output)

# UI 실행
demo.launch()