import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "rinna/bilingual-gpt-neox-4b-instruction-ppo"

# Load the model (without 8-bit quantization)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto"  # automatically place weights on CPU/GPU
)
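# Assumption beyond the original script: if GPU memory is tight, the same model
# can instead be loaded with 8-bit quantization (requires the bitsandbytes
# package), e.g.:
#   model = AutoModelForCausalLM.from_pretrained(MODEL_ID, load_in_8bit=True, device_map="auto")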
# use_fast=False loads the slow, SentencePiece-based tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
def generate_response(user_question,
                      chat_history,
                      temperature=0.3,
                      top_p=0.85,
                      max_new_tokens=2048,
                      repetition_penalty=1.05):
    # Fixed preamble asking the assistant to answer concisely in Japanese
    user_prompt_template = "ユーザー: こんにちは、あなたは日本語を学ぶ手助けをしてくれるアシスタントやねん。質問するから、簡潔に答えてな。"
    system_prompt_template = "システム: うん、簡潔に答えるで。何を教えてほしいん?"
    # One-shot sample exchange, used only when there is no history yet
    user_sample = "ユーザー: 富士山の標高ってどれくらいなん?"
    system_sample = "システム: 富士山の標高は3776メートルやで。"
    user_prefix = "ユーザー: "
    system_prefix = "システム: "

    prompt = user_prompt_template + "\n" + system_prompt_template + "\n"
    if len(chat_history) < 1:
        prompt += user_sample + "\n" + system_sample + "\n"
    else:
        # Only the most recent exchange is carried over as context
        u = chat_history[-1][0]
        s = chat_history[-1][1]
        prompt += user_prefix + u + "\n" + system_prefix + s + "\n"
    prompt += user_prefix + user_question + "\n" + system_prefix

    inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
    inputs = inputs.to(model.device)
    with torch.no_grad():
        tokens = model.generate(
            **inputs,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    output = tokenizer.decode(tokens[0], skip_special_tokens=True)
    # Drop the echoed prompt and return only the newly generated text
    return output[len(prompt):]
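
# Illustrative usage (assumption, not part of the original script): the helper
# can be smoke-tested directly with an empty history before building the UI:
#   print(generate_response("富士山の標高ってどれくらいなん?", []))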
import gradio as gr

with gr.Blocks() as demo:
    chat_history = gr.Chatbot()
    user_message = gr.Textbox(label="Question:", placeholder="人工知能とは何ですか?")
    clear = gr.ClearButton([user_message, chat_history])

    def response(user_message, chat_history):
        # Generate a reply, append the exchange, and clear the input box
        system_message = generate_response(user_message, chat_history)
        chat_history.append((user_message, system_message))
        return "", chat_history

    user_message.submit(response, inputs=[user_message, chat_history], outputs=[user_message, chat_history])

if __name__ == "__main__":
    demo.launch()
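
# Note (assumption beyond the original script): on a shared Space, enabling
# Gradio's request queue before launching helps with concurrent users, e.g.:
#   demo.queue().launch()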