fujiki committed on
Commit 733ad93 · 1 Parent(s): 804ec08

Create app.py

Files changed (1)
  1. app.py +262 -0
app.py ADDED
@@ -0,0 +1,262 @@
+ import gradio as gr
+
+ import os
+ import threading
+ import arrow
+ import time
+ import argparse
+ import logging
+ from dataclasses import dataclass
+
+ import torch
+ import sentencepiece as spm
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from transformers.generation.streamers import BaseStreamer
+ from huggingface_hub import hf_hub_download, login
+
+
+ logger = logging.getLogger()
+ logger.setLevel("INFO")
+
+ gr_interface = None
+
+ VERSION = "0.1"
+
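+ # Default settings used when no CLI arguments are available
+ # (e.g. when the app runs on a Hugging Face Space).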
+ @dataclass
+ class DefaultArgs:
+     hf_model_name_or_path: str = "cyberagent/open-calm-1b"
+     spm_model_path: str = None
+     env: str = "dev"
+     port: int = 7860
+     make_public: bool = False
+
+
+ if not os.getenv("RUNNING_ON_HF_SPACE"):
+     parser = argparse.ArgumentParser(description="")
+     parser.add_argument("--hf_model_name_or_path", type=str, default="cyberagent/open-calm-small")  # required=True)
+     parser.add_argument("--env", type=str, default="dev")
+     parser.add_argument("--port", type=int, default=7860)
+     parser.add_argument("--make_public", action='store_true')
+     args = parser.parse_args()
+ else:
+     # On a Space there is no command line, so fall back to the dataclass defaults.
+     args = DefaultArgs()
+
+ def load_model(
+     model_name_or_path,
+ ):
+     model = AutoModelForCausalLM.from_pretrained(model_name_or_path, device_map="auto", torch_dtype=torch.float32)
+     if torch.cuda.is_available():
+         model = model.to("cuda:0")
+     return model
+
+ logging.info("Loading model")
+ model = load_model(args.hf_model_name_or_path)
+ tokenizer = AutoTokenizer.from_pretrained(args.hf_model_name_or_path)
+ logging.info("Finished loading model")
+
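+ # Streamer that accumulates decoded text from model.generate so partial
+ # generations can be polled from another thread.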
+ class Streamer(BaseStreamer):
+     def __init__(self, tokenizer):
+         self.tokenizer = tokenizer
+         self.num_invoked = 0
+         self.prompt = ""
+         self.generated_text = ""
+         self.ended = False
+
+     def put(self, t: torch.Tensor):
+         d = t.dim()
+         if d == 1:
+             pass
+         elif d == 2:
+             t = t[0]
+         else:
+             raise NotImplementedError
+
+         t = [int(x) for x in t.cpu().numpy()]
+
+         text = self.tokenizer.decode(t, skip_special_tokens=True)
+
+         # the first call receives the prompt tokens; skip them
+         if self.num_invoked == 0:
+             self.prompt = text
+             self.num_invoked += 1
+             return
+
+         self.generated_text += text
+         logging.debug(f"[streamer]: {self.generated_text}")
+
+     def end(self):
+         self.ended = True
+
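+ # Generate text in a background thread and yield partial results via the
+ # streamer so the output textbox updates while generation is still running.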
+ def generate(
+     prompt,
+     max_new_tokens,
+     temperature,
+     repetition_penalty,
+
+     do_sample,
+     no_repeat_ngram_size,
+ ):
+     log = dict(locals())
+     logging.debug(log)
+     print(log)
+
+     input_ids = tokenizer(prompt, return_tensors="pt")['input_ids'].to(model.device)
+     # never request more new tokens than the context window allows
+     max_possible_new_tokens = model.config.max_position_embeddings - len(input_ids.squeeze(0))
+     max_possible_new_tokens = min(max_possible_new_tokens, max_new_tokens)
+
+     streamer = Streamer(tokenizer=tokenizer)
+     thr = threading.Thread(target=model.generate, args=(), kwargs=dict(
+         input_ids=input_ids,
+         do_sample=do_sample,
+         temperature=temperature,
+         repetition_penalty=repetition_penalty,
+         no_repeat_ngram_size=no_repeat_ngram_size,
+         max_new_tokens=max_possible_new_tokens,
+         streamer=streamer,
+         # max_length=4096,
+         # top_k=100,
+         # top_p=0.9,
+         # num_return_sequences=2,
+         # num_beams=2,
+     ))
+     thr.start()
+
+     # generation runs in the background thread; poll the streamer and yield
+     # whatever text has accumulated so far
+     while not streamer.ended:
+         time.sleep(0.05)
+         yield streamer.generated_text
+
+     # TODO: optimize for final few tokens
+     gen = streamer.generated_text
+     log.update(dict(
+         generation=gen,
+         version=VERSION,
+         time=str(arrow.now("+09:00"))))
+     logging.info(log)
+     yield gen
+
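+ # Record a user rating together with the prompt, generation, and sampling settings.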
+ def process_feedback(
+     rating,
+     prompt,
+     generation,
+
+     max_new_tokens,
+     temperature,
+     repetition_penalty,
+     do_sample,
+     no_repeat_ngram_size,
+ ):
+     log = dict(locals())
+     log.update(dict(
+         time=str(arrow.now("+09:00")),
+         version=VERSION,
+     ))
+     logging.info(log)
+
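+ # Build the Gradio UI. Close any previously launched interface first
+ # (useful when re-running the script interactively).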
+ if gr_interface:
+     gr_interface.close(verbose=False)
+
+ with gr.Blocks() as gr_interface:
+     with gr.Row():
+         gr.Markdown(f"# open-calm-small Playground ({VERSION})")
+     with gr.Row():
+         gr.Markdown("open-calm-small Playground")
+     with gr.Row():
+
+         # left panel
+         with gr.Column(scale=1):
+
+             # generation params
+             with gr.Box():
+                 gr.Markdown("hyper parameters")
+
+                 # hidden default params
+                 do_sample = gr.Checkbox(True, label="Do Sample", visible=True)
+                 no_repeat_ngram_size = gr.Slider(0, 10, value=5, step=1, label="No Repeat Ngram Size", visible=False)
+
+                 # visible params
+                 max_new_tokens = gr.Slider(
+                     128,
+                     min(512, model.config.max_position_embeddings),
+                     value=128,
+                     step=128,
+                     label="max tokens",
+                 )
+                 temperature = gr.Slider(
+                     0, 1, value=0.7, step=0.05, label="temperature",
+                 )
+                 repetition_penalty = gr.Slider(
+                     1, 1.5, value=1.2, step=0.05, label="frequency penalty",
+                 )
+
+             # grouping params for easier reference
+             gr_params = [
+                 max_new_tokens,
+                 temperature,
+                 repetition_penalty,
+
+                 do_sample,
+                 no_repeat_ngram_size,
+             ]
+
+         # right panel
+         with gr.Column(scale=2):
+             # user input block
+             with gr.Box():
+                 textbox_prompt = gr.Textbox(
+                     label="入力",
+                     placeholder="AIによって私達の暮らしは、",
+                     interactive=True,
+                     lines=5,
+                     value="AIによって私達の暮らしは、"
+                 )
+             with gr.Box():
+                 with gr.Row():
+                     btn_stop = gr.Button(value="キャンセル", variant="secondary")
+                     btn_submit = gr.Button(value="実行", variant="primary")
+
+
+             # model output block
+             with gr.Box():
+                 textbox_generation = gr.Textbox(
+                     label="応答",
+                     lines=5,
+                     value=""
+                 )
+
+             # rating block
+             with gr.Row():
+                 gr.Markdown("この応答に対するあなたの評価は?")
+
+             with gr.Box():
+                 with gr.Row():
+                     rating_options = [
+                         "最悪",
+                         "不合格",
+                         "中立",
+                         "合格",
+                         "最高",
+                     ]
+                     btn_ratings = [gr.Button(value=v) for v in rating_options]
+
+             # TODO: we might not need this for sharing with close groups
+             # with gr.Box():
+             #     gr.Markdown("TODO:For more feedback link for google form")
+
+     # event handling
+     inputs = [textbox_prompt] + gr_params
+     click_event = btn_submit.click(generate, inputs, textbox_generation, queue=True)
+     btn_stop.click(None, None, None, cancels=click_event, queue=False)
+
+     for btn_rating in btn_ratings:
+         btn_rating.click(process_feedback, [btn_rating, textbox_prompt, textbox_generation] + gr_params, queue=False)
+
+
+ gr_interface.queue(max_size=32, concurrency_count=2)
+ gr_interface.launch(server_port=args.port, share=args.make_public)