p1atdev committed (verified)
Commit cb7c8fd · 1 Parent(s): 756bd11

Update app.py

Files changed (1)
1. app.py (+39 −30)
app.py CHANGED
@@ -33,7 +33,14 @@ import spaces
 load_dotenv()
 
 HF_API_KEY = os.getenv("HF_API_KEY")
-MODEL_NAME = "weblab-GENIAC/Tanuki-8B-dpo-v1.0"
+MODEL_NAME_MAP = {
+    "150m-instruct3": "llm-jp/llm-jp-3-150m-instruct3",
+    "440m-instruct3": "llm-jp/llm-jp-3-440m-instruct3",
+    "980m-instruct3": "llm-jp/llm-jp-3-980m-instruct3",
+    "1.8b-instruct3": "llm-jp/llm-jp-3-1.8b-instruct3",
+    "3.7b-instruct3": "llm-jp/llm-jp-3-3.7b-instruct3",
+    "13b-instruct3": "llm-jp/llm-jp-3-13b-instruct3",
+}
 
 quantization_config = BitsAndBytesConfig(
     load_in_4bit=True,
@@ -41,18 +48,25 @@ quantization_config = BitsAndBytesConfig(
     bnb_4bit_quant_type="nf4",
     bnb_4bit_use_double_quant=True,
 )
-model = AutoModelForCausalLM.from_pretrained(
-    MODEL_NAME, quantization_config=quantization_config, device_map="auto", token=HF_API_KEY
-)
-tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_API_KEY)
+MODELS = {
+    key: AutoModelForCausalLM.from_pretrained(
+        value, quantization_config=quantization_config, device_map="auto"
+    ) for key, value in MODEL_NAME_MAP.items()
+}
+TOKENIZERS = {
+    key: AutoTokenizer.from_pretrained(value)
+    for key, value in MODEL_NAME_MAP.items()
+}
 
 print("Compiling model...")
-model = torch.compile(model)
+for key, model in MODELS.items():
+    MODELS[key] = torch.compile(model)
 print("Model compiled.")
 
 
-@spaces.GPU(duration=30)
+@spaces.GPU(duration=45)
 def generate(
+    model_name: str,
     message: str,
     history: list[tuple[str, str]],
     system_message: str,
@@ -74,12 +88,12 @@ def generate(
 
     messages.append({"role": "user", "content": message})
 
-    tokenized_input = tokenizer.apply_chat_template(
+    tokenized_input = TOKENIZERS[model_name].apply_chat_template(
         messages, add_generation_prompt=True, tokenize=True, return_tensors="pt"
-    ).to(model.device)
+    ).to(MODELS[model_name].device)
 
     streamer = TextIteratorStreamer(
-        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+        TOKENIZERS[model_name], timeout=10.0, skip_prompt=True, skip_special_tokens=True
     )
     generate_kwargs = dict(
         input_ids=tokenized_input,
@@ -91,7 +105,7 @@ def generate(
         top_p=float(top_p),
         num_beams=1,
     )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t = Thread(target=MODELS[model_name].generate, kwargs=generate_kwargs)
     t.start()
 
     # Initialize the values to return
@@ -105,6 +119,7 @@ def generate(
 
 
 def respond(
+    model_name: str,
     message: str,
     history: list[tuple[str, str]],
     system_message: str,
@@ -115,6 +130,7 @@ def respond(
 ):
 
     for stream in generate(
+        model_name,
         message,
         history,
         system_message,
@@ -127,6 +143,7 @@ def respond(
 
 
 def retry(
+    model_name: str,
     history: list[tuple[str, str]],
     system_message: str,
     max_tokens: int,
@@ -140,6 +157,7 @@ def retry(
     history = history[:-1]
 
     for stream in generate(
+        model_name,
         user_message,
         history,
         system_message,
@@ -156,11 +174,13 @@ def demo():
 
     gr.Markdown(
         """\
-# weblab-GENIAC/Tanuki-8B-dpo-v1.0 デモ
-モデル: https://huggingface.co/weblab-GENIAC/Tanuki-8B-dpo-v1.0
+# llm-jp/llm-jp-3 instruct3 モデルデモ
+コレクション: https://huggingface.co/collections/llm-jp/llm-jp-3-fine-tuned-models-672c621db852a01eae939731
 """
     )
 
+    model_name_dropdown = gr.Dropdown(label="モデル", choices=list(MODELS.keys()), value=list(MODELS.keys())[0])
+
     chat_history = gr.Chatbot(value=[])
 
     with gr.Row():
@@ -183,7 +203,7 @@ def demo():
         scale=2,
     )
     gr.Markdown(
-        value="※ Tanuki は誤った情報を生成する可能性があります。"
+        value="※ 誤った情報を生成する可能性があります。"
     )
 
     with gr.Accordion(label="詳細設定", open=False):
@@ -195,7 +215,7 @@ def demo():
         minimum=1, maximum=2048, value=256, step=1, label="Max new tokens"
     )
     temperature_slider = gr.Slider(
-        minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"
+        minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"
     )
     top_p_slider = gr.Slider(
         minimum=0.1,
@@ -210,7 +230,6 @@ def demo():
 
     gr.Examples(
         examples=[
-            ["たぬきってなんですか?"],
             ["情けは人の為ならずとはどういう意味ですか?"],
             ["まどマギで一番可愛いのは誰?"],
         ],
@@ -218,22 +237,11 @@ def demo():
         cache_examples=False,
     )
 
-    start_btn.click(
-        respond,
-        inputs=[
-            input_text,
-            chat_history,
-            system_prompt_text,
-            max_new_tokens_slider,
-            temperature_slider,
-            top_p_slider,
-            top_k_slider,
-        ],
-        outputs=[input_text, chat_history],
-    )
-    input_text.submit(
-        respond,
+    gr.on(
+        triggers=[start_btn.click, input_text.submit],
+        fn=respond,
         inputs=[
+            model_name_dropdown,
             input_text,
             chat_history,
             system_prompt_text,
@@ -247,6 +255,7 @@ def demo():
     retry_btn.click(
         retry,
         inputs=[
+            model_name_dropdown,
             chat_history,
             system_prompt_text,
             max_new_tokens_slider,
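
The `generate` function keeps the streaming pattern from before the commit: `model.generate` runs on a background thread while the caller drains a `TextIteratorStreamer`. A minimal sketch of that pattern, assuming `streamer` is passed in `generate_kwargs` (that line is outside the hunks shown) and using one of the checkpoints from MODEL_NAME_MAP:

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "llm-jp/llm-jp-3-150m-instruct3"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

messages = [{"role": "user", "content": "こんにちは"}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True, return_tensors="pt"
).to(model.device)

# The streamer yields decoded text fragments as generate() produces tokens.
streamer = TextIteratorStreamer(
    tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
)

# generate() blocks until completion, so it runs on a worker thread while
# the main thread consumes the streamer and can yield partial output to the UI.
thread = Thread(
    target=model.generate,
    kwargs=dict(input_ids=input_ids, streamer=streamer, max_new_tokens=256),
)
thread.start()

output = ""
for piece in streamer:
    output += piece
    print(output)
```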
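On the UI side, the commit replaces the duplicated `start_btn.click(...)` / `input_text.submit(...)` registrations with a single `gr.on` call that binds one handler to both triggers. A self-contained sketch of that wiring; the `echo` handler and widget labels here are hypothetical stand-ins for the app's `respond` and its dropdown/slider inputs:

```python
import gradio as gr

def echo(message: str, history: list[tuple[str, str]]):
    # Hypothetical stand-in for respond(): echo the message, then clear the textbox.
    return "", history + [(message, message)]

with gr.Blocks() as demo:
    chat_history = gr.Chatbot(value=[])
    input_text = gr.Textbox(label="Message")
    start_btn = gr.Button("Send")

    # One registration covers both triggers: clicking the button and
    # pressing Enter in the textbox.
    gr.on(
        triggers=[start_btn.click, input_text.submit],
        fn=echo,
        inputs=[input_text, chat_history],
        outputs=[input_text, chat_history],
    )

demo.launch()
```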