# Load rinna/bilingual-gpt-neox-4b-instruction-ppo in 8-bit and chat with it
# through a simple Gradio UI.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "rinna/bilingual-gpt-neox-4b-instruction-ppo"
# Load the model with 8-bit weights and let device_map="auto" place it on the
# available device(s).
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    load_in_8bit=True,
    device_map="auto"
)
# use_fast=False: use the SentencePiece-based slow tokenizer for this model
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
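
# Note: load_in_8bit requires the bitsandbytes package. On newer transformers
# releases the same setup is usually expressed through a quantization config;
# a sketch, assuming recent transformers and bitsandbytes versions:
#
#   from transformers import BitsAndBytesConfig
#   model = AutoModelForCausalLM.from_pretrained(
#       MODEL_ID,
#       quantization_config=BitsAndBytesConfig(load_in_8bit=True),
#       device_map="auto"
#   )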

# Check which device the model ended up on
device = model.device
device

user_prompt_template = "ユーザー: Hello, you are an assistant that helps me learn Japanese. I am going to ask you a question, so please answer *briefly*."
system_prompt_template = "システム: Sure, I will answer briefly. What can I do for you?"

# one-shot example
user_sample = "ユーザー: 日本で一番高い山は何ですか?"
system_sample = "システム: 富士山です。高さは3776メートルです。"

# question to ask
user_prefix = "ユーザー: "
user_question = "人工知能とは何ですか?"
system_prefix = "システム: "

# Assemble the prompt
prompt = user_prompt_template + "\n" + system_prompt_template + "\n"
prompt += user_sample + "\n" + system_sample + "\n"
prompt += user_prefix + user_question + "\n" + system_prefix

inputs = tokenizer(
    prompt,
    add_special_tokens=False, # keep extra special tokens out of the prompt
    return_tensors="pt"
)
inputs = inputs.to(model.device)
with torch.no_grad():
    tokens = model.generate(
        **inputs,
        temperature=0.3,          # low temperature -> fairly focused answers
        top_p=0.85,               # nucleus sampling threshold
        max_new_tokens=2048,
        repetition_penalty=1.05,  # mildly discourage repeating itself
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id
    )

tokens  # raw output IDs: the prompt followed by the generated continuation

output = tokenizer.decode(
    tokens[0],
    skip_special_tokens=True # keep special tokens out of the decoded text
)
print(output)

output[len(prompt):]  # drop the prompt so only the model's answer remains
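
# The character-based slice above assumes decode() reproduces the prompt text
# exactly. A sketch of a token-based alternative: slice off the prompt tokens
# before decoding (generate() returns the prompt followed by the new tokens).
prompt_len = inputs["input_ids"].shape[1]
tokenizer.decode(tokens[0][prompt_len:], skip_special_tokens=True)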

def generate(user_question,
             temperature=0.3,
             top_p=0.85,
             max_new_tokens=2048,
             repetition_penalty=1.05
            ):

    user_prompt_template = "ユーザー: Hello, you are an assistant that helps me learn Japanese. I am going to ask you a question, so please answer *briefly*."
    system_prompt_template = "システム: Sure, I will answer briefly. What can I do for you?"

    user_sample = "ユーザー: 日本で一番高い山は何ですか?"
    system_sample = "システム: 富士山です。高さは3776メートルです。"

    user_prefix = "ユーザー: "
    system_prefix = "システム: "

    prompt = user_prompt_template + "\n" + system_prompt_template + "\n"
    prompt += user_sample + "\n" + system_sample + "\n"
    prompt += user_prefix + user_question + "\n" + system_prefix

    inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
    inputs = inputs.to(model.device)
    with torch.no_grad():
        tokens = model.generate(
            **inputs,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    output = tokenizer.decode(tokens[0], skip_special_tokens=True)
    return output[len(prompt):]

output = generate('人工知能とは何ですか?')
output
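
# The same helper with explicit sampling settings (the values here are
# arbitrary, just to show the keyword arguments):
generate('人工知能とは何ですか?', temperature=0.7, max_new_tokens=256)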


import gradio as gr # conventionally abbreviated as gr

with gr.Blocks() as demo:
    inputs = gr.Textbox(label="Question:", placeholder="人工知能とは何ですか?")
    outputs = gr.Textbox(label="Answer:")
    btn = gr.Button("Send")

    # Define what happens when the button is clicked:
    # pass the value of `inputs` to the model and write its return value into `outputs`.
    btn.click(fn=generate, inputs=inputs, outputs=outputs)

if __name__ == "__main__":
    demo.launch()
    
def generate_response(user_question,
             chat_history,
             temperature=0.3,
             top_p=0.85,
             max_new_tokens=2048,
             repetition_penalty=1.05
            ):

    user_prompt_template = "ユーザー: Hello, you are an assistant that helps me learn Japanese. I am going to ask you a question, so please answer *briefly*."
    system_prompt_template = "システム: Sure, I will answer briefly. What can I do for you?"

    user_sample = "ユーザー: 日本で一番高い山は何ですか?"
    system_sample = "システム: 富士山です。高さは3776メートルです。"

    user_prefix = "ユーザー: "
    system_prefix = "システム: "

    prompt = user_prompt_template + "\n" + system_prompt_template + "\n"

    # With no history yet, fall back to the one-shot example;
    # otherwise reuse the most recent exchange as context.
    if len(chat_history) < 1:
        prompt += user_sample + "\n" + system_sample + "\n"
    else:
        u = chat_history[-1][0]
        s = chat_history[-1][1]
        prompt += user_prefix + u + "\n" + system_prefix + s + "\n"

    prompt += user_prefix + user_question + "\n" + system_prefix

    inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
    inputs = inputs.to(model.device)
    with torch.no_grad():
        tokens = model.generate(
            **inputs,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    output = tokenizer.decode(tokens[0], skip_special_tokens=True)
    return output[len(prompt):]
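
# generate_response() only replays the most recent exchange. A minimal sketch
# of a prompt builder that replays the whole history instead (assumes
# chat_history is the list of (user, system) tuples kept by gr.Chatbot; keep an
# eye on the model's context window if the history grows long):
def build_prompt_with_full_history(user_question, chat_history):
    prompt = user_prompt_template + "\n" + system_prompt_template + "\n"
    if len(chat_history) < 1:
        prompt += user_sample + "\n" + system_sample + "\n"
    for u, s in chat_history:
        prompt += user_prefix + u + "\n" + system_prefix + s + "\n"
    prompt += user_prefix + user_question + "\n" + system_prefix
    return prompt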


with gr.Blocks() as demo:
    chat_history = gr.Chatbot()
    user_message = gr.Textbox(label="Question:", placeholder="人工知能とは何ですか?")
    clear = gr.ClearButton([user_message, chat_history])

    def response(user_message, chat_history):
        system_message = generate_response(user_message, chat_history)
        chat_history.append((user_message, system_message))
        return "", chat_history

    user_message.submit(response, inputs=[user_message, chat_history], outputs=[user_message, chat_history])

if __name__ == "__main__":
    demo.launch()
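
# On a hosted notebook (e.g. Colab) the UI can also be exposed through a
# temporary public URL; a sketch, assuming a standard Gradio install:
# demo.launch(share=True)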