import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
MODEL_ID = "rinna/bilingual-gpt-neox-4b-instruction-ppo"
# Create the 8-bit quantization config (bitsandbytes)
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
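# For tighter memory budgets, a 4-bit NF4 config is a possible alternative
# (a sketch only; assumes a recent bitsandbytes install, untested here):
# quantization_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.float16,
# )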
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=quantization_config,
    device_map="auto",  # place layers across available GPUs/CPU automatically
)
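# rinna's model card loads the tokenizer with use_fast=False, since the model
# ships a sentencepiece-based (slow) tokenizer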
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
def generate_response(user_question,
                      chat_history,
                      temperature=0.3,
                      top_p=0.85,
                      max_new_tokens=2048,
                      repetition_penalty=1.05):
    # Fixed instruction turn and acknowledgement, in the ユーザー:/システム:
    # ("User:"/"System:") chat style this script uses
    user_prompt_template = "ユーザー: Hello, you are an assistant that helps me learn Japanese. I am going to ask you a question, so please answer *briefly*."
    system_prompt_template = "システム: Sure, I will answer briefly. What can I do for you?"
    # One-shot example: "What is the highest mountain in Japan?" /
    # "It is Mt. Fuji. Its height is 3776 meters."
    user_sample = "ユーザー: 日本で一番高い山は何ですか?"
    system_sample = "システム: 富士山です。高さは3776メートルです。"
    user_prefix = "ユーザー: "
    system_prefix = "システム: "
    prompt = user_prompt_template + "\n" + system_prompt_template + "\n"
    if len(chat_history) < 1:
        # First turn: include the one-shot example
        prompt += user_sample + "\n" + system_sample + "\n"
    else:
        # Later turns: replay only the most recent exchange to keep the prompt short
        u, s = chat_history[-1]
        prompt += user_prefix + u + "\n" + system_prefix + s + "\n"
    prompt += user_prefix + user_question + "\n" + system_prefix
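    # The assembled prompt on the first turn looks like:
    #   ユーザー: <instruction>
    #   システム: <acknowledgement>
    #   ユーザー: 日本で一番高い山は何ですか?
    #   システム: 富士山です。高さは3776メートルです。
    #   ユーザー: <current question>
    #   システム:
    # leaving an open "システム: " turn for the model to complete.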
    inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
    inputs = inputs.to(model.device)
    with torch.no_grad():
        tokens = model.generate(
            **inputs,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    output = tokenizer.decode(tokens[0], skip_special_tokens=True)
    # Strip the echoed prompt; this assumes decode() reproduces the prompt text verbatim
    return output[len(prompt):]
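
# Quick sanity check outside the UI (hypothetical call; the question mirrors
# the Textbox placeholder below):
# print(generate_response("人工知能とは何ですか?", []))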

import gradio as gr

with gr.Blocks() as demo:
    # Chatbot stores the conversation as a list of (user, assistant) tuples
    chat_history = gr.Chatbot()
    user_message = gr.Textbox(label="Question:", placeholder="人工知能とは何ですか?")  # "What is artificial intelligence?"
    clear = gr.ClearButton([user_message, chat_history])

    def response(user_message, chat_history):
        system_message = generate_response(user_message, chat_history)
        chat_history.append((user_message, system_message))
        # Clear the textbox and push the updated history back to the Chatbot
        return "", chat_history

    user_message.submit(response, inputs=[user_message, chat_history], outputs=[user_message, chat_history])
if __name__ == "__main__":
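    # share=True would additionally expose a temporary public URL (standard Gradio option)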
    demo.launch()