File size: 4,377 Bytes
92c62ca
e1ded21
92c62ca
c99ebd4
8dd9d83
e1ded21
d0ea771
c99ebd4
e1ded21
c99ebd4
8dd9d83
059c42c
8dd9d83
c99ebd4
8dd9d83
 
 
 
 
c99ebd4
8ec9971
 
 
 
02089a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ec9971
 
 
 
c99ebd4
 
 
 
 
 
 
 
8dd9d83
c99ebd4
8dd9d83
c99ebd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8dd9d83
c99ebd4
 
 
 
 
 
86938fd
c99ebd4
86938fd
 
 
c99ebd4
92c62ca
86938fd
c99ebd4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "rinna/bilingual-gpt-neox-4b-instruction-ppo"

# モデルをロード(8ビット量子化を使用せず)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto"  # 自動でCPU/GPUを選択
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)

def generate_response(user_question,
                      chat_history,
                      temperature=0.3,
                      top_p=0.85,
                      max_new_tokens=2048,
                      repetition_penalty=1.05):

user_prompt_template = "ユーザー: こんにちは、あなたは日本語を学ぶ手助けをしてくれるアシスタントやねん。質問するから、簡潔に答えてな。"
system_prompt_template = "システム: うん、簡潔に答えるで。何を教えてほしいん?"

user_sample = [
    "ユーザー: 日本で一番高い山は何ですか?",
    "ユーザー: 富士山の標高ってどれくらいなん?",
    "ユーザー: 日本の一番高い山を教えて?",
    "ユーザー: 大阪の有名な観光地はどこなん?",
    "ユーザー: 神戸ってどんな街なん?",
    "ユーザー: たこ焼きって何やったっけ?",
    "ユーザー: お好み焼きって何が入ってるん?",
    "ユーザー: あんた、今日はどうなん?",
    "ユーザー: 今日の天気はどうやった?",
    "ユーザー: 今度、どこか行こうか?"
]

system_sample = [
    "システム: 富士山です。高さは3776メートルです。",
    "システム: 富士山の標高は3776メートルやで。",
    "システム: 富士山や。3776メートルやね。",
    "システム: 大阪城やで。歴史ある場所やし、観光スポットとして有名やな。",
    "システム: 神戸はおしゃれな街やで。港町として有名やし、異人館とかもあるし、観光するにはええ場所やな。",
    "システム: たこ焼きは、小麦粉の生地にタコを入れて焼いたもんやで。ソースと鰹節をかけて食べるんや。",
    "システム: お好み焼きは、キャベツと豚肉とか海鮮を混ぜた生地を鉄板で焼いた料理やな。ソースとマヨネーズがたまらんで。",
    "システム: うーん、元気やで。あんたはどうなん?",
    "システム: 今日は晴れてたけど、ちょっと暑かったな。まあ、夏っぽい天気やったで。",
    "システム: そうやな、今度は京都に行ってみようか?お寺とか観光スポットがいっぱいあるで。"
]

user_prefix = "ユーザー: "
system_prefix = "システム: "

    prompt = user_prompt_template + "\n" + system_prompt_template + "\n"

    if len(chat_history) < 1:
        prompt += user_sample + "\n" + system_sample + "\n"
    else:
        u = chat_history[-1][0]
        s = chat_history[-1][1]
        prompt += user_prefix + u + "\n" + system_prefix + s + "\n"

    prompt += user_prefix + user_question + "\n" + system_prefix

    inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
    inputs = inputs.to(model.device)
    with torch.no_grad():
        tokens = model.generate(
            **inputs,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    output = tokenizer.decode(tokens[0], skip_special_tokens=True)
    return output[len(prompt):]

import gradio as gr

with gr.Blocks() as demo:
    chat_history = gr.Chatbot()
    user_message = gr.Textbox(label="Question:", placeholder="人工知能とは何ですか?")
    clear = gr.ClearButton([user_message, chat_history])

    def response(user_message, chat_history):
        system_message = generate_response(user_message, chat_history)
        chat_history.append((user_message, system_message))
        return "", chat_history

    user_message.submit(response, inputs=[user_message, chat_history], outputs=[user_message, chat_history])

if __name__ == "__main__":
    demo.launch()