abobonbobo13 committed · verified
Commit c99ebd4 · 1 Parent(s): f600f55

Update app.py

Files changed (1):
  1. app.py (+141 -17)
app.py CHANGED
@@ -1,52 +1,176 @@
  import torch
  from transformers import AutoModelForCausalLM, AutoTokenizer
- import gradio as gr

+ MODEL_ID = "rinna/bilingual-gpt-neox-4b-instruction-ppo"
  model = AutoModelForCausalLM.from_pretrained(
-     "rinna/bilingual-gpt-neox-4b-instruction-ppo",
+     MODEL_ID,
+     load_in_8bit=True,
+     device_map="auto"
+ )
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)

  device = model.device
-
+ device
+
+ user_prompt_template = "ユーザー: Hello, you are an assistant that helps me learn Japanese. I am going to ask you a question, so please answer *briefly*."
+ system_prompt_template = "システム: Sure, I will answer briefly. What can I do for you?"
+
+ # one-shot
+ user_sample = "ユーザー: 日本で一番高い山は何ですか?"
+ system_sample = "システム: 富士山です。高さは3776メートルです。"
+
+ # Question
+ user_prerix = "ユーザー: "
+ user_question = "人工知能とは何ですか?"
+ system_prefix = "システム: "
+
+ # Format the prompt
+ prompt = user_prompt_template + "\n" + system_prompt_template + "\n"
+ prompt += user_sample + "\n" + system_sample + "\n"
+ prompt += user_prerix + user_question + "\n" + system_prefix
+
+ inputs = tokenizer(
+     prompt,
+     add_special_tokens=False,  # keep extra special tokens out of the prompt
+     return_tensors="pt"
+ )
+ inputs = inputs.to(model.device)
+ with torch.no_grad():
+     tokens = model.generate(
+         **inputs,
+         temperature=0.3,
+         top_p=0.85,
+         max_new_tokens=2048,
+         repetition_penalty=1.05,
+         do_sample=True,
+         pad_token_id=tokenizer.pad_token_id,
+         bos_token_id=tokenizer.bos_token_id,
+         eos_token_id=tokenizer.eos_token_id
+     )
+
+ tokens
+
+ output = tokenizer.decode(
+     tokens[0],
+     skip_special_tokens=True  # keep extra special tokens out of the output
+ )
+ print(output)
+
+ output[len(prompt):]

  def generate(user_question,
               temperature=0.3,
-              system_prompt_template = "システム: もちろんやで!どんどん質問してな!今日も気分ええわ!"
-
-     # one-shot
-
+              top_p=0.85,
+              max_new_tokens=2048,
+              repetition_penalty=1.05
+              ):

+     user_prompt_template = "ユーザー: Hello, you are an assistant that helps me learn Japanese. I am going to ask you a question, so please answer *briefly*."
+     system_prompt_template = "システム: Sure, I will answer briefly. What can I do for you?"

-     user_sample = "ユーザー:日本一の高さの山は? "
-     system_sample = "システム: 富士山や!最高の眺めを拝めるで!!"
-
+     user_sample = "ユーザー: 日本で一番高い山は何ですか?"
+     system_sample = "システム: 富士山です。高さは3776メートルです。"

      user_prerix = "ユーザー: "
      system_prefix = "システム: "

+     prompt = user_prompt_template + "\n" + system_prompt_template + "\n"
+     prompt += user_sample + "\n" + system_sample + "\n"
+     prompt += user_prerix + user_question + "\n" + system_prefix
+
+     inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
+     inputs = inputs.to(model.device)
+     with torch.no_grad():
+         tokens = model.generate(
+             **inputs,
+             temperature=temperature,
+             top_p=top_p,
+             max_new_tokens=max_new_tokens,
+             repetition_penalty=repetition_penalty,
+             do_sample=True,
+             pad_token_id=tokenizer.pad_token_id,
+             bos_token_id=tokenizer.bos_token_id,
+             eos_token_id=tokenizer.eos_token_id
+         )
      output = tokenizer.decode(tokens[0], skip_special_tokens=True)
      return output[len(prompt):]

+ output = generate('人工知能とは何ですか?')
+ output


-
+ import gradio as gr  # conventionally abbreviated as gr

  with gr.Blocks() as demo:
-     chat_history = gr.Chatbot()
-     inputs = gr.Textbox(label="Question:", placeholder="質問を入力してください")
+     inputs = gr.Textbox(label="Question:", placeholder="人工知能とは何ですか?")
      outputs = gr.Textbox(label="Answer:")
      btn = gr.Button("Send")
-     clear = gr.ClearButton([inputs, chat_history])

      # Define what happens when the button is pressed:
-
+     # "pass the value of inputs to the model, and set its return value as the value of outputs"
      btn.click(fn=generate, inputs=inputs, outputs=outputs)

+ if __name__ == "__main__":
+     demo.launch()
+
+ def generate_response(user_question,
+                       chat_history,
+                       temperature=0.3,
+                       top_p=0.85,
+                       max_new_tokens=2048,
+                       repetition_penalty=1.05
+                       ):
+
+     user_prompt_template = "ユーザー: Hello, you are an assistant that helps me learn Japanese. I am going to ask you a question, so please answer *briefly*."
+     system_prompt_template = "システム: Sure, I will answer briefly. What can I do for you?"
+
+     user_sample = "ユーザー: 日本で一番高い山は何ですか?"
+     system_sample = "システム: 富士山です。高さは3776メートルです。"
+
+     user_prerix = "ユーザー: "
+     system_prefix = "システム: "
+
+     prompt = user_prompt_template + "\n" + system_prompt_template + "\n"
+
+     if len(chat_history) < 1:
+         prompt += user_sample + "\n" + system_sample + "\n"
+     else:
+         u = chat_history[-1][0]
+         s = chat_history[-1][1]
+         prompt += user_prerix + u + "\n" + system_prefix + s + "\n"
+
+     prompt += user_prerix + user_question + "\n" + system_prefix
+
+     inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
+     inputs = inputs.to(model.device)
+     with torch.no_grad():
+         tokens = model.generate(
+             **inputs,
+             temperature=temperature,
+             top_p=top_p,
+             max_new_tokens=max_new_tokens,
+             repetition_penalty=repetition_penalty,
+             do_sample=True,
+             pad_token_id=tokenizer.pad_token_id,
+             bos_token_id=tokenizer.bos_token_id,
+             eos_token_id=tokenizer.eos_token_id
+         )
+     output = tokenizer.decode(tokens[0], skip_special_tokens=True)
+     return output[len(prompt):]
+
+
+ with gr.Blocks() as demo:
+     chat_history = gr.Chatbot()
+     user_message = gr.Textbox(label="Question:", placeholder="人工知能とは何ですか?")
+     clear = gr.ClearButton([user_message, chat_history])
+
      def response(user_message, chat_history):
+         system_message = generate_response(user_message, chat_history)
          chat_history.append((user_message, system_message))
          return "", chat_history

-     inputs.submit(response, inputs=[inputs, chat_history], outputs=[inputs, chat_history])
+     user_message.submit(response, inputs=[user_message, chat_history], outputs=[user_message, chat_history])

  if __name__ == "__main__":
-     demo.launch()
+     demo.launch()
+
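
Notes on this change (the sketches below are illustrations, not code from the commit):

1. load_in_8bit=True needs the bitsandbytes and accelerate packages at runtime, and recent transformers releases deprecate the bare flag in favour of an explicit quantization config. A minimal sketch of the equivalent call, assuming a transformers version that ships BitsAndBytesConfig:

    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    MODEL_ID = "rinna/bilingual-gpt-neox-4b-instruction-ppo"

    # Same effect as load_in_8bit=True, stated as an explicit config.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        device_map="auto",  # let accelerate place the weights on the available device(s)
    )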
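2. generate() recovers the completion with output[len(prompt):], which assumes the decoded string reproduces the prompt character for character; skip_special_tokens=True and tokenizer normalization can shift that boundary. A more robust sketch slices the generated token ids before decoding:

    # Inside generate(), after model.generate(...):
    prompt_len = inputs["input_ids"].shape[1]   # number of prompt tokens
    new_tokens = tokens[0][prompt_len:]         # ids produced after the prompt
    output = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return output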
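3. As committed, generate_response() replays only chat_history[-1], so the bot remembers exactly one previous turn. If longer memory is wanted, one sketch (using the same variables the function already defines) folds the whole history into the prompt; the model's fixed context window still caps how much fits:

    if len(chat_history) < 1:
        prompt += user_sample + "\n" + system_sample + "\n"
    else:
        for u, s in chat_history:  # each entry is a (user, system) pair
            prompt += user_prerix + u + "\n" + system_prefix + s + "\n"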
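4. The file keeps its notebook cell structure: bare expressions such as device, tokens, and output[len(prompt):] have no effect in a script, and an if __name__ == "__main__": demo.launch() sits mid-file before the chat UI is defined. Run directly, that first launch() blocks, so the chat demo underneath is only reached after the first server exits; on a Gradio Space, where app.py is imported, both guards are skipped and, if I read the runtime correctly, the last Blocks bound to demo is the one served. The usual arrangement is a single guarded entry point at the bottom:

    # Keep only one launch, after the last demo definition.
    if __name__ == "__main__":
        demo.launch()  # demo is the chat UI, the last Blocks assigned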