import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
MODEL_ID = "rinna/bilingual-gpt-neox-4b-instruction-ppo"
# Load the model (without 8-bit quantization)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",  # automatically place weights on CPU/GPU
)
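
# A minimal sketch of the 8-bit variant the comment above alludes to, left
# commented out; it assumes the optional `bitsandbytes` package is installed:
# from transformers import BitsAndBytesConfig
# model = AutoModelForCausalLM.from_pretrained(
#     MODEL_ID,
#     device_map="auto",
#     quantization_config=BitsAndBytesConfig(load_in_8bit=True),
# )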
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
def generate_response(user_question,
                      chat_history,
                      temperature=0.3,
                      top_p=0.85,
                      max_new_tokens=2048,
                      repetition_penalty=1.05):
    # Fixed preamble plus a one-shot sample exchange in the
    # "ユーザー:"/"システム:" dialogue format the rinna instruction-tuned
    # models expect
    user_prompt_template = "ユーザー: Hello, you are an assistant that helps me learn Japanese. I am going to ask you a question, so please answer *briefly*."
    system_prompt_template = "システム: Sure, I will answer briefly. What can I do for you?"
    user_sample = "ユーザー: 日本で一番高い山は何ですか?"
    system_sample = "システム: 富士山です。高さは3776メートルです。"
    user_prefix = "ユーザー: "
    system_prefix = "システム: "
    prompt = user_prompt_template + "\n" + system_prompt_template + "\n"
    # On the first turn, use the sample exchange; afterwards, include only
    # the most recent turn of history to keep the prompt short
    if len(chat_history) < 1:
        prompt += user_sample + "\n" + system_sample + "\n"
    else:
        u, s = chat_history[-1]
        prompt += user_prefix + u + "\n" + system_prefix + s + "\n"
    prompt += user_prefix + user_question + "\n" + system_prefix
    inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
    inputs = inputs.to(model.device)
    with torch.no_grad():
        tokens = model.generate(
            **inputs,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    output = tokenizer.decode(tokens[0], skip_special_tokens=True)
    # Return only the newly generated text, stripping the echoed prompt
    return output[len(prompt):]
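
# Quick standalone check of generate_response (illustrative only; history is
# a list of (user, system) tuples and is empty on the first turn):
#   print(generate_response("人工知能とは何ですか?", []))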
import gradio as gr
with gr.Blocks() as demo:
    chat_history = gr.Chatbot()  # holds (user, assistant) message pairs
    user_message = gr.Textbox(label="Question:", placeholder="人工知能とは何ですか?")
    clear = gr.ClearButton([user_message, chat_history])

    def response(user_message, chat_history):
        system_message = generate_response(user_message, chat_history)
        chat_history.append((user_message, system_message))
        # Clear the textbox and push the updated history back to the Chatbot
        return "", chat_history

    user_message.submit(response, inputs=[user_message, chat_history], outputs=[user_message, chat_history])
if __name__ == "__main__":
    demo.launch()