import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "rinna/bilingual-gpt-neox-4b-instruction-ppo"

# Load the model in 8-bit (requires the bitsandbytes package) and let
# device_map="auto" place it on the available device(s).
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    load_in_8bit=True,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
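
# NOTE: newer versions of transformers express 8-bit loading through a
# quantization config instead of the load_in_8bit flag. A sketch, assuming
# transformers >= 4.30:
#
#   from transformers import BitsAndBytesConfig
#   model = AutoModelForCausalLM.from_pretrained(
#       MODEL_ID,
#       quantization_config=BitsAndBytesConfig(load_in_8bit=True),
#       device_map="auto"
#   )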

device = model.device
print(device)

user_prompt_template = "ユーザー: Hello, you are an assistant that helps me learn Japanese. I am going to ask you a question, so please answer *briefly*."
system_prompt_template = "システム: Sure, I will answer briefly. What can I do for you?"

# One-shot example
user_sample = "ユーザー: 日本で一番高い山は何ですか?"
system_sample = "システム: 富士山です。高さは3776メートルです。"

# Question
user_prefix = "ユーザー: "
user_question = "人工知能とは何ですか?"
system_prefix = "システム: "

# Assemble the prompt
prompt = user_prompt_template + "\n" + system_prompt_template + "\n"
prompt += user_sample + "\n" + system_sample + "\n"
prompt += user_prefix + user_question + "\n" + system_prefix

inputs = tokenizer(
    prompt,
    add_special_tokens=False,  # keep extra special tokens from being added to the prompt
    return_tensors="pt"
)
inputs = inputs.to(model.device)

with torch.no_grad():
    tokens = model.generate(
        **inputs,
        temperature=0.3,
        top_p=0.85,
        max_new_tokens=2048,
        repetition_penalty=1.05,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id
    )

output = tokenizer.decode(
    tokens[0],
    skip_special_tokens=True  # keep extra special tokens out of the output
)
print(output)

# The model echoes the prompt, so slice it off to get just the answer.
print(output[len(prompt):])
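
# A sketch of a more robust way to strip the prompt: slice the generated
# token ids by the input length instead of assuming decode() reproduces the
# prompt string exactly (assumes batch size 1):
#
#   prompt_len = inputs["input_ids"].shape[1]
#   answer = tokenizer.decode(tokens[0][prompt_len:], skip_special_tokens=True)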

def generate(user_question,
             temperature=0.3,
             top_p=0.85,
             max_new_tokens=2048,
             repetition_penalty=1.05
             ):
    user_prompt_template = "ユーザー: Hello, you are an assistant that helps me learn Japanese. I am going to ask you a question, so please answer *briefly*."
    system_prompt_template = "システム: Sure, I will answer briefly. What can I do for you?"
    user_sample = "ユーザー: 日本で一番高い山は何ですか?"
    system_sample = "システム: 富士山です。高さは3776メートルです。"
    user_prefix = "ユーザー: "
    system_prefix = "システム: "

    prompt = user_prompt_template + "\n" + system_prompt_template + "\n"
    prompt += user_sample + "\n" + system_sample + "\n"
    prompt += user_prefix + user_question + "\n" + system_prefix

    inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
    inputs = inputs.to(model.device)

    with torch.no_grad():
        tokens = model.generate(
            **inputs,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    output = tokenizer.decode(tokens[0], skip_special_tokens=True)
    return output[len(prompt):]  # strip the echoed prompt

output = generate('人工知能とは何ですか?')
print(output)
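
# The model may keep generating past its own answer and start a new
# "ユーザー:" turn. One simple guard (an assumption, not part of the
# original code) is to truncate at the next turn marker:
#
#   answer = generate('人工知能とは何ですか?').split("\nユーザー:")[0].strip()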

import gradio as gr  # conventionally abbreviated as gr

with gr.Blocks() as demo:
    inputs = gr.Textbox(label="Question:", placeholder="人工知能とは何ですか?")
    outputs = gr.Textbox(label="Answer:")
    btn = gr.Button("Send")
    # Define the behavior when the button is clicked:
    # pass the value in `inputs` to the model and set its return value as the value of `outputs`.
    btn.click(fn=generate, inputs=inputs, outputs=outputs)

if __name__ == "__main__":
    demo.launch()
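
# NOTE: the chat version below rebinds `demo`. When this file runs top to
# bottom as a script, the first demo.launch() blocks until it is closed.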

def generate_response(user_question,
                      chat_history,
                      temperature=0.3,
                      top_p=0.85,
                      max_new_tokens=2048,
                      repetition_penalty=1.05
                      ):
    user_prompt_template = "ユーザー: Hello, you are an assistant that helps me learn Japanese. I am going to ask you a question, so please answer *briefly*."
    system_prompt_template = "システム: Sure, I will answer briefly. What can I do for you?"
    user_sample = "ユーザー: 日本で一番高い山は何ですか?"
    system_sample = "システム: 富士山です。高さは3776メートルです。"
    user_prefix = "ユーザー: "
    system_prefix = "システム: "

    prompt = user_prompt_template + "\n" + system_prompt_template + "\n"
    if len(chat_history) < 1:
        # No history yet: fall back to the one-shot example.
        prompt += user_sample + "\n" + system_sample + "\n"
    else:
        # Only the most recent exchange is included, keeping the prompt short.
        u = chat_history[-1][0]
        s = chat_history[-1][1]
        prompt += user_prefix + u + "\n" + system_prefix + s + "\n"
    prompt += user_prefix + user_question + "\n" + system_prefix

    inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
    inputs = inputs.to(model.device)

    with torch.no_grad():
        tokens = model.generate(
            **inputs,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    output = tokenizer.decode(tokens[0], skip_special_tokens=True)
    return output[len(prompt):]  # strip the echoed prompt

with gr.Blocks() as demo:
    chat_history = gr.Chatbot()
    user_message = gr.Textbox(label="Question:", placeholder="人工知能とは何ですか?")
    clear = gr.ClearButton([user_message, chat_history])

    def response(user_message, chat_history):
        system_message = generate_response(user_message, chat_history)
        chat_history.append((user_message, system_message))
        return "", chat_history  # clear the textbox and show the updated history

    user_message.submit(response, inputs=[user_message, chat_history], outputs=[user_message, chat_history])

if __name__ == "__main__":
    demo.launch()