Spaces:
Sleeping
Sleeping
File size: 4,377 Bytes
92c62ca e1ded21 92c62ca c99ebd4 8dd9d83 e1ded21 d0ea771 c99ebd4 e1ded21 c99ebd4 8dd9d83 059c42c 8dd9d83 c99ebd4 8dd9d83 c99ebd4 8ec9971 02089a7 8ec9971 c99ebd4 8dd9d83 c99ebd4 8dd9d83 c99ebd4 8dd9d83 c99ebd4 86938fd c99ebd4 86938fd c99ebd4 92c62ca 86938fd c99ebd4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
MODEL_ID = "rinna/bilingual-gpt-neox-4b-instruction-ppo"
# モデルをロード(8ビット量子化を使用せず)
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
device_map="auto" # 自動でCPU/GPUを選択
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
def generate_response(user_question,
chat_history,
temperature=0.3,
top_p=0.85,
max_new_tokens=2048,
repetition_penalty=1.05):
user_prompt_template = "ユーザー: こんにちは、あなたは日本語を学ぶ手助けをしてくれるアシスタントやねん。質問するから、簡潔に答えてな。"
system_prompt_template = "システム: うん、簡潔に答えるで。何を教えてほしいん?"
user_sample = [
"ユーザー: 日本で一番高い山は何ですか?",
"ユーザー: 富士山の標高ってどれくらいなん?",
"ユーザー: 日本の一番高い山を教えて?",
"ユーザー: 大阪の有名な観光地はどこなん?",
"ユーザー: 神戸ってどんな街なん?",
"ユーザー: たこ焼きって何やったっけ?",
"ユーザー: お好み焼きって何が入ってるん?",
"ユーザー: あんた、今日はどうなん?",
"ユーザー: 今日の天気はどうやった?",
"ユーザー: 今度、どこか行こうか?"
]
system_sample = [
"システム: 富士山です。高さは3776メートルです。",
"システム: 富士山の標高は3776メートルやで。",
"システム: 富士山や。3776メートルやね。",
"システム: 大阪城やで。歴史ある場所やし、観光スポットとして有名やな。",
"システム: 神戸はおしゃれな街やで。港町として有名やし、異人館とかもあるし、観光するにはええ場所やな。",
"システム: たこ焼きは、小麦粉の生地にタコを入れて焼いたもんやで。ソースと鰹節をかけて食べるんや。",
"システム: お好み焼きは、キャベツと豚肉とか海鮮を混ぜた生地を鉄板で焼いた料理やな。ソースとマヨネーズがたまらんで。",
"システム: うーん、元気やで。あんたはどうなん?",
"システム: 今日は晴れてたけど、ちょっと暑かったな。まあ、夏っぽい天気やったで。",
"システム: そうやな、今度は京都に行ってみようか?お寺とか観光スポットがいっぱいあるで。"
]
user_prefix = "ユーザー: "
system_prefix = "システム: "
prompt = user_prompt_template + "\n" + system_prompt_template + "\n"
if len(chat_history) < 1:
prompt += user_sample + "\n" + system_sample + "\n"
else:
u = chat_history[-1][0]
s = chat_history[-1][1]
prompt += user_prefix + u + "\n" + system_prefix + s + "\n"
prompt += user_prefix + user_question + "\n" + system_prefix
inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
inputs = inputs.to(model.device)
with torch.no_grad():
tokens = model.generate(
**inputs,
temperature=temperature,
top_p=top_p,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
do_sample=True,
pad_token_id=tokenizer.pad_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id
)
output = tokenizer.decode(tokens[0], skip_special_tokens=True)
return output[len(prompt):]
import gradio as gr
with gr.Blocks() as demo:
chat_history = gr.Chatbot()
user_message = gr.Textbox(label="Question:", placeholder="人工知能とは何ですか?")
clear = gr.ClearButton([user_message, chat_history])
def response(user_message, chat_history):
system_message = generate_response(user_message, chat_history)
chat_history.append((user_message, system_message))
return "", chat_history
user_message.submit(response, inputs=[user_message, chat_history], outputs=[user_message, chat_history])
if __name__ == "__main__":
demo.launch()
|