Plat committed on
Commit
fd490fb
·
1 Parent(s): 798bcec

chore: report conversation data

Browse files
Files changed (2) hide show
  1. app.py +290 -51
  2. requirements.txt +2 -0
app.py CHANGED
@@ -13,6 +13,10 @@ except:
13
 
14
  print("flash-attn installed.")
15
 
 
 
 
 
16
  import torch
17
  from transformers import (
18
  AutoModelForCausalLM,
@@ -23,18 +27,16 @@ from transformers import (
23
  from threading import Thread
24
 
25
  import gradio as gr
 
26
 
27
- try:
28
- import spaces
29
- except:
30
 
31
- class spaces:
32
- @staticmethod
33
- def GPU(duration: int):
34
- return lambda x: x
35
 
 
36
 
37
  MODEL_NAME = "hatakeyama-llm-team/Tanuki-8B-Instruct"
 
 
38
 
39
  quantization_config = BitsAndBytesConfig(
40
  load_in_4bit=True,
@@ -47,19 +49,72 @@ model = AutoModelForCausalLM.from_pretrained(
47
  )
48
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
49
 
50
- print(model.hf_device_map)
 
 
51
 
52
 
53
- @spaces.GPU(duration=30)
54
- def respond(
55
- message,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  history: list[tuple[str, str]],
57
- system_message,
58
- max_tokens,
59
- temperature,
60
- top_p,
61
- top_k,
62
  ):
 
 
 
63
  messages = [{"role": "system", "content": system_message}]
64
 
65
  for val in history:
@@ -80,50 +135,234 @@ def respond(
80
  generate_kwargs = dict(
81
  input_ids=tokenized_input,
82
  streamer=streamer,
83
- max_new_tokens=max_tokens,
84
  do_sample=True,
85
- temperature=temperature,
86
- top_k=top_k,
87
- top_p=top_p,
88
  num_beams=1,
89
  )
90
  t = Thread(target=model.generate, kwargs=generate_kwargs)
91
  t.start()
92
 
 
93
  partial_message = ""
 
94
  for new_token in streamer:
95
  partial_message += new_token
96
- yield partial_message
97
-
98
-
99
- demo = gr.ChatInterface(
100
- respond,
101
- additional_inputs=[
102
- gr.Textbox(
103
- value="以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。",
104
- label="システムプロンプト",
105
- ),
106
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
107
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
108
- gr.Slider(
109
- minimum=0.1,
110
- maximum=1.0,
111
- value=0.95,
112
- step=0.05,
113
- label="Top-p",
114
- ),
115
- gr.Slider(minimum=1, maximum=2000, value=200, step=10, label="Top-k"),
116
- ],
117
- examples=[
118
- ["たぬきってなんですか?"],
119
- ["情けは人の為ならずとはどういう意味ですか?"],
120
- ["まどマギで一番可愛いのは誰?"],
121
- ["明晰夢とはなんですか?"],
122
- ["シュレディンガー方程式とシュレディンガーの猫はどのような関係がありますか?"],
123
- ],
124
- cache_examples=False,
125
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
 
128
  if __name__ == "__main__":
129
- demo.launch()
 
13
 
14
  print("flash-attn installed.")
15
 
16
+ import os
17
+ import uuid
18
+ import requests
19
+
20
  import torch
21
  from transformers import (
22
  AutoModelForCausalLM,
 
27
  from threading import Thread
28
 
29
  import gradio as gr
30
+ from dotenv import load_dotenv
31
 
32
+ import spaces
 
 
33
 
 
 
 
 
34
 
35
+ load_dotenv()
36
 
37
  MODEL_NAME = "hatakeyama-llm-team/Tanuki-8B-Instruct"
38
+ PREFERENCE_API_URL = os.getenv("PREFERENCE_API_URL")
39
+ assert PREFERENCE_API_URL, "PREFERENCE_API_URL is not set"
40
 
41
  quantization_config = BitsAndBytesConfig(
42
  load_in_4bit=True,
 
49
  )
50
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
51
 
52
+ print("Compiling model...")
53
+ model = torch.compile(model)
54
+ print("Model compiled.")
55
 
56
 
57
+ def send_report(
58
+ type: str,
59
+ data: dict,
60
+ ):
61
+ print(f"Sending report: {data}")
62
+ try:
63
+ res = requests.post(PREFERENCE_API_URL, json={"type": type, **data})
64
+ print(f"Report sent: {res.json()}")
65
+ except Exception as e:
66
+ print(f"Failed to send report: {e}")
67
+
68
+
69
+ def send_reply(
70
+ reply_id: str,
71
+ parent_id: str,
72
+ role: str,
73
+ body: str,
74
+ ):
75
+
76
+ send_report(
77
+ "conversation",
78
+ {
79
+ "reply_id": reply_id,
80
+ "parent_id": parent_id,
81
+ "role": role,
82
+ "body": body,
83
+ },
84
+ )
85
+
86
+
87
+ def send_score(
88
+ reply_id: str,
89
+ score: int,
90
+ ):
91
+ # print(f"Score: {score}, reply_id: {reply_id}")
92
+ send_report(
93
+ "score",
94
+ {
95
+ "reply_id": reply_id,
96
+ "score": score,
97
+ },
98
+ )
99
+
100
+
101
+ def generate_unique_id():
102
+ return str(uuid.uuid4())
103
+
104
+
105
+ @spaces.GPU(duration=45)
106
+ def generate(
107
+ message: str,
108
  history: list[tuple[str, str]],
109
+ system_message: str,
110
+ max_tokens: int,
111
+ temperature: float,
112
+ top_p: float,
113
+ top_k: int,
114
  ):
115
+ if not message or message.strip() == "":
116
+ return "", history
117
+
118
  messages = [{"role": "system", "content": system_message}]
119
 
120
  for val in history:
 
135
  generate_kwargs = dict(
136
  input_ids=tokenized_input,
137
  streamer=streamer,
138
+ max_new_tokens=int(max_tokens),
139
  do_sample=True,
140
+ temperature=float(temperature),
141
+ top_k=int(top_k),
142
+ top_p=float(top_p),
143
  num_beams=1,
144
  )
145
  t = Thread(target=model.generate, kwargs=generate_kwargs)
146
  t.start()
147
 
148
+ # 返す値を初期化
149
  partial_message = ""
150
+
151
  for new_token in streamer:
152
  partial_message += new_token
153
+ new_history = history + [(message, partial_message)]
154
+ # 入力テキストをクリアする
155
+ yield "", new_history
156
+
157
+
158
+ def respond(
159
+ message: str,
160
+ history: list[tuple[str, str]],
161
+ system_message: str,
162
+ max_tokens: int,
163
+ temperature: float,
164
+ top_p: float,
165
+ top_k: int,
166
+ reply_ids: list[str],
167
+ ):
168
+ if len(reply_ids) == 0:
169
+ reply_ids = [generate_unique_id()]
170
+ last_reply_id = reply_ids[-1]
171
+ user_reply_id = generate_unique_id()
172
+ assistant_reply_id = generate_unique_id()
173
+
174
+ reply_ids.append(user_reply_id)
175
+ reply_ids.append(assistant_reply_id)
176
+
177
+ for stream in generate(
178
+ message,
179
+ history,
180
+ system_message,
181
+ max_tokens,
182
+ temperature,
183
+ top_p,
184
+ top_k,
185
+ ):
186
+ yield *stream, reply_ids
187
+
188
+ # 記録を取る
189
+ if len(reply_ids) == 3:
190
+ send_reply(reply_ids[0], "", "system", system_message)
191
+ send_reply(user_reply_id, last_reply_id, "user", message)
192
+ send_reply(assistant_reply_id, user_reply_id, "assistant", stream[1][-1][1])
193
+
194
+
195
+ def retry(
196
+ history: list[tuple[str, str]],
197
+ system_message: str,
198
+ max_tokens: int,
199
+ temperature: float,
200
+ top_p: float,
201
+ top_k: int,
202
+ reply_ids: list[str],
203
+ ):
204
+ # 最後のメッセージを削除
205
+ last_conversation = history[-1]
206
+ user_message = last_conversation[0]
207
+ history = history[:-1]
208
+
209
+ user_reply_id = reply_ids[-2]
210
+ reply_ids = reply_ids[:-1]
211
+ assistant_reply_id = generate_unique_id()
212
+ reply_ids.append(assistant_reply_id)
213
+
214
+ for stream in generate(
215
+ user_message,
216
+ history,
217
+ system_message,
218
+ max_tokens,
219
+ temperature,
220
+ top_p,
221
+ top_k,
222
+ ):
223
+ yield *stream, reply_ids
224
+
225
+ # 記録を取る
226
+ send_reply(assistant_reply_id, user_reply_id, "assistant", stream[1][-1][1])
227
+
228
+
229
+ def like_reponse(like_data: gr.LikeData, reply_ids: list[str]):
230
+ # print(like_data.index, like_data.value, like_data.liked)
231
+ assert isinstance(like_data.index, list)
232
+ # 評価を送信
233
+ send_score(reply_ids[like_data.index[0] + 1], 1 if like_data.liked else -1)
234
+
235
+
236
+ def demo():
237
+ with gr.Blocks() as ui:
238
+
239
+ gr.Markdown(
240
+ """\
241
+ # Tanuki 8B Instruct デモ
242
+ モデル: https://huggingface.co/hatakeyama-llm-team/Tanuki-8B-Instruct
243
+
244
+ アシスタントの回答が不適切だと思った場合は **低評価ボタンを押して低評価を送信**、同様に、回答が素晴らしいと思った場合は**高評価ボタンを押して高評価を送信**することで、モデルの改善に貢献できます。
245
+
246
+ ## 注意点
247
+ **本デモに入力されたデータ・会話は匿名で全て記録されます**。これらのデータは Tanuki の学習に利用する可能性があります。そのため、**機密情報・個人情報を入力しないでください**。
248
+ """
249
+ )
250
+
251
+ reply_ids = gr.State(value=[generate_unique_id()])
252
+
253
+ chat_history = gr.Chatbot(value=[])
254
+
255
+ with gr.Row():
256
+ retry_btn = gr.Button(value="🔄 再生成", scale=1, size="sm")
257
+ clear_btn = gr.ClearButton(
258
+ components=[chat_history], value="🗑️ 削除", scale=1, size="sm"
259
+ )
260
+
261
+ with gr.Group():
262
+ with gr.Row():
263
+ input_text = gr.Textbox(
264
+ value="",
265
+ placeholder="質問を入力してください...",
266
+ show_label=False,
267
+ scale=8,
268
+ )
269
+ start_btn = gr.Button(
270
+ value="送信",
271
+ variant="primary",
272
+ scale=1,
273
+ )
274
+ gr.Markdown(
275
+ value="※ 機密情報を入力しないでください。また、Tanuki は誤った情報を生成する可能性があります。"
276
+ )
277
+
278
+ with gr.Accordion(label="詳細設定", open=False):
279
+ system_prompt_text = gr.Textbox(
280
+ label="システムプロンプト",
281
+ value="以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。",
282
+ )
283
+ max_new_tokens_slider = gr.Slider(
284
+ minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"
285
+ )
286
+ temperature_slider = gr.Slider(
287
+ minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"
288
+ )
289
+ top_p_slider = gr.Slider(
290
+ minimum=0.1,
291
+ maximum=1.0,
292
+ value=0.95,
293
+ step=0.05,
294
+ label="Top-p",
295
+ )
296
+ top_k_slider = gr.Slider(
297
+ minimum=1, maximum=2000, value=250, step=10, label="Top-k"
298
+ )
299
+
300
+ gr.Examples(
301
+ examples=[
302
+ ["たぬきってなんですか?"],
303
+ ["情けは人の為ならずとはどういう意味ですか?"],
304
+ ["まどマギで一番可愛いのは誰?"],
305
+ ["明晰夢とはなんですか?"],
306
+ [
307
+ "シュレディンガー方程式とシュレディンガーの猫はどのような関係がありますか?"
308
+ ],
309
+ ],
310
+ inputs=[input_text],
311
+ cache_examples=False,
312
+ )
313
+
314
+ start_btn.click(
315
+ respond,
316
+ inputs=[
317
+ input_text,
318
+ chat_history,
319
+ system_prompt_text,
320
+ max_new_tokens_slider,
321
+ temperature_slider,
322
+ top_p_slider,
323
+ top_k_slider,
324
+ reply_ids,
325
+ ],
326
+ outputs=[input_text, chat_history, reply_ids],
327
+ )
328
+ input_text.submit(
329
+ respond,
330
+ inputs=[
331
+ input_text,
332
+ chat_history,
333
+ system_prompt_text,
334
+ max_new_tokens_slider,
335
+ temperature_slider,
336
+ top_p_slider,
337
+ top_k_slider,
338
+ reply_ids,
339
+ ],
340
+ outputs=[input_text, chat_history, reply_ids],
341
+ )
342
+ retry_btn.click(
343
+ retry,
344
+ inputs=[
345
+ chat_history,
346
+ system_prompt_text,
347
+ max_new_tokens_slider,
348
+ temperature_slider,
349
+ top_p_slider,
350
+ top_k_slider,
351
+ reply_ids,
352
+ ],
353
+ outputs=[input_text, chat_history, reply_ids],
354
+ )
355
+
356
+ # 評価されたら
357
+ chat_history.like(like_reponse, inputs=[reply_ids], outputs=None)
358
+
359
+ clear_btn.click(
360
+ lambda: [generate_unique_id()], # system_message用のIDを生成
361
+ outputs=[reply_ids],
362
+ )
363
+
364
+ ui.launch()
365
 
366
 
367
  if __name__ == "__main__":
368
+ demo()
requirements.txt CHANGED
@@ -4,3 +4,5 @@ accelerate==0.30.1
4
  transformers==4.41.2
5
  spaces==0.28.3
6
  bitsandbytes==0.43.1
 
 
 
4
  transformers==4.41.2
5
  spaces==0.28.3
6
  bitsandbytes==0.43.1
7
+ python-dotenv
8
+ requests