AIRider commited on
Commit
30bf3f3
·
verified ·
1 Parent(s): 4487f96

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -80
app.py CHANGED
@@ -3,9 +3,11 @@ from huggingface_hub import InferenceClient
3
  import os
4
  from threading import Event
5
 
 
6
  hf_token = os.getenv("HF_TOKEN")
7
  stop_event = Event()
8
 
 
9
  models = {
10
  "deepseek-ai/DeepSeek-Coder-V2-Instruct": "(한국회사)DeepSeek-Coder-V2-Instruct",
11
  "meta-llama/Meta-Llama-3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct",
@@ -13,13 +15,16 @@ models = {
13
  "CohereForAI/c4ai-command-r-plus": "Cohere Command-R Plus"
14
  }
15
 
 
16
  def get_client(model):
17
  return InferenceClient(model=model, token=hf_token)
18
 
 
19
  def respond(message, system_message, max_tokens, temperature, top_p, selected_model):
20
  stop_event.clear()
21
  client = get_client(selected_model)
22
 
 
23
  messages = [
24
  {"role": "system", "content": system_message + "\n주어진 입력에만 정확히 답변하세요. 추가 질문을 만들거나 입력을 확장하지 마세요."},
25
  {"role": "user", "content": message}
@@ -27,6 +32,9 @@ def respond(message, system_message, max_tokens, temperature, top_p, selected_mo
27
 
28
  try:
29
  response = ""
 
 
 
30
  for chunk in client.text_generation(
31
  prompt="\n".join([f"{m['role']}: {m['content']}" for m in messages]),
32
  max_new_tokens=max_tokens,
@@ -38,97 +46,77 @@ def respond(message, system_message, max_tokens, temperature, top_p, selected_mo
38
  break
39
  if chunk:
40
  response += chunk
41
- yield [(message, response)]
 
42
 
43
  except Exception as e:
44
- yield [(message, f"오류 발생: {str(e)}")]
45
-
 
46
  def get_last_response(chatbot):
47
  if chatbot and len(chatbot) > 0:
48
  return chatbot[-1][1]
49
- return ""
50
 
51
- def continue_writing(chatbot, system_message, max_tokens, temperature, top_p, selected_model):
52
- last_response = get_last_response(chatbot)
53
- stop_event.clear()
54
- client = get_client(selected_model)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
- prompt = f"이전 응답을 이어서 작성해주세요. 이전 응답: {last_response}"
57
- messages = [
58
- {"role": "system", "content": system_message},
59
- {"role": "user", "content": prompt}
60
- ]
 
 
61
 
62
- try:
63
- response = last_response
64
- for chunk in client.text_generation(
65
- prompt="\n".join([f"{m['role']}: {m['content']}" for m in messages]),
66
- max_new_tokens=max_tokens,
67
- temperature=temperature,
68
- top_p=top_p,
69
- stream=True
70
- ):
71
- if stop_event.is_set():
72
- break
73
- if chunk:
74
- response += chunk
75
- yield chatbot + [("계속 작성", response)]
76
-
77
- except Exception as e:
78
- yield chatbot + [("계속 작성", f"오류 발생: {str(e)}")]
79
 
80
- def stop_generation():
81
- stop_event.set()
82
- return "생성이 중단되었습니다."
83
 
 
84
  with gr.Blocks() as demo:
85
- gr.Markdown("# 프롬프트 최적화 Playground")
 
 
 
 
 
86
 
87
- gr.Markdown("""
88
- **주의사항:**
89
- - '전송' 버튼을 클릭하거나 입력 필드에서 Shift+Enter를 눌러 메시지를 전송할 수 있습니다.
90
- - Enter 키는 줄바꿈으로 작동합니다.
91
- - 입력한 내용에 대해서만 응답하도록 설정되어 있지만, 모델이 때때로 예상치 못한 방식으로 응답할 수 있습니다.
92
- """)
93
 
94
- with gr.Row():
95
- with gr.Column(scale=1):
96
- with gr.Accordion("모델 설정", open=True):
97
- model = gr.Radio(list(models.keys()), value=list(models.keys())[0], label="언어 모델 선택", info="사용할 언어 모델을 선택하세요")
98
- max_tokens = gr.Slider(minimum=1, maximum=2000, value=500, step=100, label="최대 새 토큰 수")
99
- temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.05, label="온도")
100
- top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.90, step=0.05, label="Top-p (핵 샘플링)")
101
-
102
- system_message = gr.Textbox(
103
- value="너는 나의 최고의 비서이다.\n내가 요구하는것들을 최대한 자세하고 정확하게 답변하라.\n반드시 한글로 답변할것.\n사용자의 입력 내용에만 직접적으로 답변하고, 추가 질문을 만들거나 입력을 확장하지 마라.",
104
- label="시스템 메시지",
105
- lines=5
106
- )
107
-
108
- with gr.Column(scale=2):
109
- chatbot = gr.Chatbot(height=400, label="대화 결과")
110
- prompt = gr.Textbox(
111
- label="내용 입력",
112
- lines=3,
113
- placeholder="메시지를 입력하세요. 전송 버튼을 클릭하거나 Shift+Enter를 눌러 전송합니다."
114
- )
115
-
116
- with gr.Row():
117
- send = gr.Button("전송")
118
- continue_btn = gr.Button("계속 작성")
119
- stop = gr.Button("🛑 생성 중단")
120
- clear = gr.Button("🗑️ 대화 내역 지우기")
121
-
122
- # Event handlers
123
- send.click(respond, inputs=[prompt, system_message, max_tokens, temperature, top_p, model], outputs=[chatbot])
124
- prompt.submit(respond, inputs=[prompt, system_message, max_tokens, temperature, top_p, model], outputs=[chatbot])
125
- continue_btn.click(continue_writing,
126
- inputs=[chatbot, system_message, max_tokens, temperature, top_p, model],
127
- outputs=[chatbot])
128
- stop.click(stop_generation, outputs=[prompt])
129
- clear.click(lambda: None, outputs=[chatbot])
130
 
131
- if __name__ == "__main__":
132
- if not hf_token:
133
- print("경고: HF_TOKEN 환경 변수가 설정되지 않았습니다. 일부 모델에 접근할 수 없을 수 있습니다.")
134
- demo.launch()
 
3
  import os
4
  from threading import Event
5
 
6
+ # Hugging Face API Token을 환경 변수로부터 가져옴
7
  hf_token = os.getenv("HF_TOKEN")
8
  stop_event = Event()
9
 
10
+ # 모델 목록 정의
11
  models = {
12
  "deepseek-ai/DeepSeek-Coder-V2-Instruct": "(한국회사)DeepSeek-Coder-V2-Instruct",
13
  "meta-llama/Meta-Llama-3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct",
 
15
  "CohereForAI/c4ai-command-r-plus": "Cohere Command-R Plus"
16
  }
17
 
18
+ # Inference 클라이언트를 반환하는 함수
19
  def get_client(model):
20
  return InferenceClient(model=model, token=hf_token)
21
 
22
+ # 메시지 응답 생성 함수
23
  def respond(message, system_message, max_tokens, temperature, top_p, selected_model):
24
  stop_event.clear()
25
  client = get_client(selected_model)
26
 
27
+ # 프롬프트 설정
28
  messages = [
29
  {"role": "system", "content": system_message + "\n주어진 입력에만 정확히 답변하세요. 추가 질문을 만들거나 입력을 확장하지 마세요."},
30
  {"role": "user", "content": message}
 
32
 
33
  try:
34
  response = ""
35
+ total_tokens_used = 0 # 사용된 토큰 수 추적
36
+
37
+ # 모델에서 응답을 청크 단위로 스트리밍
38
  for chunk in client.text_generation(
39
  prompt="\n".join([f"{m['role']}: {m['content']}" for m in messages]),
40
  max_new_tokens=max_tokens,
 
46
  break
47
  if chunk:
48
  response += chunk
49
+ total_tokens_used += len(chunk.split()) # 청크당 사용된 토큰 수 추산
50
+ yield [(message, response, f"사용된 토큰 수: {total_tokens_used}/{max_tokens}")]
51
 
52
  except Exception as e:
53
+ yield [(message, f"오류 발생: {str(e)}", "에러 처리 필요")]
54
+
55
+ # 이전 응답을 확인하는 함수
56
  def get_last_response(chatbot):
57
  if chatbot and len(chatbot) > 0:
58
  return chatbot[-1][1]
59
+ return None
60
 
61
+ # 프롬프트 비교 최적화를 위한 히스토리 기록 추가
62
+ class PromptHistory:
63
+ def __init__(self):
64
+ self.history = []
65
+
66
+ def add_entry(self, prompt, response, model, settings):
67
+ self.history.append({
68
+ "prompt": prompt,
69
+ "response": response,
70
+ "model": model,
71
+ "settings": settings
72
+ })
73
+
74
+ def get_history(self):
75
+ return self.history
76
+
77
+ # 히스토리 객체 생성
78
+ prompt_history = PromptHistory()
79
+
80
+ # Gradio 인터페이스 함수 정의
81
+ def gradio_interface(message, system_message, max_tokens, temperature, top_p, selected_model):
82
+ result = None
83
+ for output in respond(message, system_message, max_tokens, temperature, top_p, selected_model):
84
+ result = output
85
 
86
+ # 프롬프트와 결과를 히스토리에 추가
87
+ prompt_history.add_entry(
88
+ message,
89
+ result[0][1], # 모델 응답
90
+ selected_model,
91
+ {"max_tokens": max_tokens, "temperature": temperature, "top_p": top_p}
92
+ )
93
 
94
+ return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
+ # 히스토리 확인용 함수
97
+ def view_history():
98
+ return prompt_history.get_history()
99
 
100
+ # Gradio UI 구성
101
  with gr.Blocks() as demo:
102
+ selected_model = gr.Dropdown(choices=list(models.keys()), label="모델 선택")
103
+ message = gr.Textbox(label="사용자 메시지")
104
+ system_message = gr.Textbox(label="시스템 메시지", value="이 메시지를 기준으로 대화 흐름을 설정합니다.")
105
+ max_tokens = gr.Slider(minimum=10, maximum=512, value=128, label="최대 토��� 수")
106
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, label="Temperature")
107
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, label="Top-p")
108
 
109
+ response_output = gr.Textbox(label="모델 응답")
110
+ token_usage = gr.Textbox(label="토큰 사용량")
111
+ history_button = gr.Button("히스토리 보기")
 
 
 
112
 
113
+ # 버튼을 눌러 응답을 받는 함수 연결
114
+ submit_button = gr.Button("응답 생성")
115
+ submit_button.click(gradio_interface, inputs=[message, system_message, max_tokens, temperature, top_p, selected_model], outputs=[response_output, token_usage])
116
+
117
+ # 히스토리 보기 기능 연결
118
+ history_output = gr.Textbox(label="히스토리", interactive=False)
119
+ history_button.click(view_history, outputs=history_output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
+ # UI 실행
122
+ demo.launch()