p1atdev committed on
Commit
1c916f2
·
verified ·
1 Parent(s): bbf7f96

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -114
app.py CHANGED
@@ -5,7 +5,7 @@ except:
5
 
6
  print("Installing flash-attn...")
7
  subprocess.run(
8
- "pip install flash-attn --no-build-isolation",
9
  env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
10
  shell=True,
11
  )
@@ -35,10 +35,7 @@ import spaces
35
  load_dotenv()
36
 
37
  HF_API_KEY = os.getenv("HF_API_KEY")
38
- MODEL_NAME = "hatakeyama-llm-team/Tanuki-8B-Instruct"
39
- PREFERENCE_API_URL = os.getenv("PREFERENCE_API_URL")
40
- assert PREFERENCE_API_URL, "PREFERENCE_API_URL is not set"
41
-
42
 
43
  quantization_config = BitsAndBytesConfig(
44
  load_in_4bit=True,
@@ -56,55 +53,7 @@ model = torch.compile(model)
56
  print("Model compiled.")
57
 
58
 
59
- def send_report(
60
- type: str,
61
- data: dict,
62
- ):
63
- print(f"Sending report: {data}")
64
- try:
65
- res = requests.post(PREFERENCE_API_URL, json={"type": type, **data})
66
- print(f"Report sent: {res.json()}")
67
- except Exception as e:
68
- print(f"Failed to send report: {e}")
69
-
70
-
71
- def send_reply(
72
- reply_id: str,
73
- parent_id: str,
74
- role: str,
75
- body: str,
76
- ):
77
-
78
- send_report(
79
- "conversation",
80
- {
81
- "reply_id": reply_id,
82
- "parent_id": parent_id,
83
- "role": role,
84
- "body": body,
85
- },
86
- )
87
-
88
-
89
- def send_score(
90
- reply_id: str,
91
- score: int,
92
- ):
93
- # print(f"Score: {score}, reply_id: {reply_id}")
94
- send_report(
95
- "score",
96
- {
97
- "reply_id": reply_id,
98
- "score": score,
99
- },
100
- )
101
-
102
-
103
- def generate_unique_id():
104
- return str(uuid.uuid4())
105
-
106
-
107
- @spaces.GPU(duration=45)
108
  def generate(
109
  message: str,
110
  history: list[tuple[str, str]],
@@ -165,16 +114,7 @@ def respond(
165
  temperature: float,
166
  top_p: float,
167
  top_k: int,
168
- reply_ids: list[str],
169
  ):
170
- if len(reply_ids) == 0:
171
- reply_ids = [generate_unique_id()]
172
- last_reply_id = reply_ids[-1]
173
- user_reply_id = generate_unique_id()
174
- assistant_reply_id = generate_unique_id()
175
-
176
- reply_ids.append(user_reply_id)
177
- reply_ids.append(assistant_reply_id)
178
 
179
  for stream in generate(
180
  message,
@@ -185,13 +125,7 @@ def respond(
185
  top_p,
186
  top_k,
187
  ):
188
- yield *stream, reply_ids
189
-
190
- # 記録を取る
191
- if len(reply_ids) == 3:
192
- send_reply(reply_ids[0], "", "system", system_message)
193
- send_reply(user_reply_id, last_reply_id, "user", message)
194
- send_reply(assistant_reply_id, user_reply_id, "assistant", stream[1][-1][1])
195
 
196
 
197
  def retry(
@@ -201,18 +135,12 @@ def retry(
201
  temperature: float,
202
  top_p: float,
203
  top_k: int,
204
- reply_ids: list[str],
205
  ):
206
  # 最後のメッセージを削除
207
  last_conversation = history[-1]
208
  user_message = last_conversation[0]
209
  history = history[:-1]
210
 
211
- user_reply_id = reply_ids[-2]
212
- reply_ids = reply_ids[:-1]
213
- assistant_reply_id = generate_unique_id()
214
- reply_ids.append(assistant_reply_id)
215
-
216
  for stream in generate(
217
  user_message,
218
  history,
@@ -222,17 +150,7 @@ def retry(
222
  top_p,
223
  top_k,
224
  ):
225
- yield *stream, reply_ids
226
-
227
- # 記録を取る
228
- send_reply(assistant_reply_id, user_reply_id, "assistant", stream[1][-1][1])
229
-
230
-
231
- def like_reponse(like_data: gr.LikeData, reply_ids: list[str]):
232
- # print(like_data.index, like_data.value, like_data.liked)
233
- assert isinstance(like_data.index, list)
234
- # 評価を送信
235
- send_score(reply_ids[like_data.index[0] + 1], 1 if like_data.liked else -1)
236
 
237
 
238
  def demo():
@@ -240,18 +158,11 @@ def demo():
240
 
241
  gr.Markdown(
242
  """\
243
- # Tanuki 8B Instruct デモ
244
- モデル: https://huggingface.co/hatakeyama-llm-team/Tanuki-8B-Instruct
245
-
246
- アシスタントの回答が不適切だと思った場合は **低評価ボタンを押して低評価を送信**、同様に、回答が素晴らしいと思った場合は**高評価ボタンを押して高評価を送信**することで、モデルの改善に貢献できます。
247
-
248
- ## 注意点
249
- **本デモに入力されたデータ・会話は匿名で全て記録されます**。これらのデータは Tanuki の学習に利用する可能性があります。そのため、**機密情報・個人情報を入力しないでください**。
250
  """
251
  )
252
 
253
- reply_ids = gr.State(value=[generate_unique_id()])
254
-
255
  chat_history = gr.Chatbot(value=[])
256
 
257
  with gr.Row():
@@ -304,10 +215,6 @@ def demo():
304
  ["たぬきってなんですか?"],
305
  ["情けは人の為ならずとはどういう意味ですか?"],
306
  ["まどマギで一番可愛いのは誰?"],
307
- ["明晰夢とはなんですか?"],
308
- [
309
- "シュレディンガー方程式とシュレディンガーの猫はどのような関係がありますか?"
310
- ],
311
  ],
312
  inputs=[input_text],
313
  cache_examples=False,
@@ -323,9 +230,8 @@ def demo():
323
  temperature_slider,
324
  top_p_slider,
325
  top_k_slider,
326
- reply_ids,
327
  ],
328
- outputs=[input_text, chat_history, reply_ids],
329
  )
330
  input_text.submit(
331
  respond,
@@ -337,9 +243,8 @@ def demo():
337
  temperature_slider,
338
  top_p_slider,
339
  top_k_slider,
340
- reply_ids,
341
  ],
342
- outputs=[input_text, chat_history, reply_ids],
343
  )
344
  retry_btn.click(
345
  retry,
@@ -350,17 +255,8 @@ def demo():
350
  temperature_slider,
351
  top_p_slider,
352
  top_k_slider,
353
- reply_ids,
354
  ],
355
- outputs=[input_text, chat_history, reply_ids],
356
- )
357
-
358
- # 評価されたら
359
- chat_history.like(like_reponse, inputs=[reply_ids], outputs=None)
360
-
361
- clear_btn.click(
362
- lambda: [generate_unique_id()], # system_message用のIDを生成
363
- outputs=[reply_ids],
364
  )
365
 
366
  ui.launch()
 
5
 
6
  print("Installing flash-attn...")
7
  subprocess.run(
8
+ "uv install --system flash-attn --no-build-isolation",
9
  env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
10
  shell=True,
11
  )
 
35
  load_dotenv()
36
 
37
  HF_API_KEY = os.getenv("HF_API_KEY")
38
+ MODEL_NAME = "weblab-GENIAC/Tanuki-8B-dpo-v1.0"
 
 
 
39
 
40
  quantization_config = BitsAndBytesConfig(
41
  load_in_4bit=True,
 
53
  print("Model compiled.")
54
 
55
 
56
+ @spaces.GPU(duration=30)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  def generate(
58
  message: str,
59
  history: list[tuple[str, str]],
 
114
  temperature: float,
115
  top_p: float,
116
  top_k: int,
 
117
  ):
 
 
 
 
 
 
 
 
118
 
119
  for stream in generate(
120
  message,
 
125
  top_p,
126
  top_k,
127
  ):
128
+ yield *stream
 
 
 
 
 
 
129
 
130
 
131
  def retry(
 
135
  temperature: float,
136
  top_p: float,
137
  top_k: int,
 
138
  ):
139
  # 最後のメッセージを削除
140
  last_conversation = history[-1]
141
  user_message = last_conversation[0]
142
  history = history[:-1]
143
 
 
 
 
 
 
144
  for stream in generate(
145
  user_message,
146
  history,
 
150
  top_p,
151
  top_k,
152
  ):
153
+ yield *stream
 
 
 
 
 
 
 
 
 
 
154
 
155
 
156
  def demo():
 
158
 
159
  gr.Markdown(
160
  """\
161
+ # weblab-GENIAC/Tanuki-8B-dpo-v1.0 デモ
162
+ モデル: https://huggingface.co/weblab-GENIAC/Tanuki-8B-dpo-v1.0
 
 
 
 
 
163
  """
164
  )
165
 
 
 
166
  chat_history = gr.Chatbot(value=[])
167
 
168
  with gr.Row():
 
215
  ["たぬきってなんですか?"],
216
  ["情けは人の為ならずとはどういう意味ですか?"],
217
  ["まどマギで一番可愛いのは誰?"],
 
 
 
 
218
  ],
219
  inputs=[input_text],
220
  cache_examples=False,
 
230
  temperature_slider,
231
  top_p_slider,
232
  top_k_slider,
 
233
  ],
234
+ outputs=[input_text, chat_history],
235
  )
236
  input_text.submit(
237
  respond,
 
243
  temperature_slider,
244
  top_p_slider,
245
  top_k_slider,
 
246
  ],
247
+ outputs=[input_text, chat_history],
248
  )
249
  retry_btn.click(
250
  retry,
 
255
  temperature_slider,
256
  top_p_slider,
257
  top_k_slider,
 
258
  ],
259
+ outputs=[input_text, chat_history],
 
 
 
 
 
 
 
 
260
  )
261
 
262
  ui.launch()