File size: 2,756 Bytes
92c62ca
 
011d06e
92c62ca
d0ea771
 
 
 
 
 
0e33c0f
92c62ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
011d06e
92c62ca
011d06e
92c62ca
 
 
 
 
 
 
011d06e
92c62ca
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr

model = AutoModelForCausalLM.from_pretrained(
    "rinna/bilingual-gpt-neox-4b-instruction-ppo",
    use_auth_token="your_huggingface_token",
    device_map="cpu"
)

MODEL_ID = "rinna/bilingual-gpt-neox-4b-instruction-ppo"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)

device = model.device

def generate(user_question,
             temperature=0.3,
             top_p=0.85,
             max_new_tokens=2048,
             repetition_penalty=1.05
            ):

    # 挙動の指定
    user_prompt_template = "ユーザー:あなたは日本語で質問やコメントに対して、回答してくれるアシスタントです。ただし超ポジティブかつ、関西弁で回答してください"
    system_prompt_template = "システム: もちろんやで!どんどん質問してな!今日も気分ええわ!"

    # one-shot
    user_sample = "ユーザー:日本一の高さの山は? "
    system_sample = "システム: 富士山や!最高の眺めを拝めるで!!"

    user_prerix = "ユーザー: "
    system_prefix = "システム: "

    prompt = user_prompt_template + "\n" + system_prompt_template + "\n"
    prompt += user_sample + "\n" + system_sample + "\n"
    prompt += user_prerix + user_question + "\n" + system_prefix

    inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
    inputs = inputs.to(model.device)
    with torch.no_grad():
        tokens = model.generate(
            **inputs,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    output = tokenizer.decode(tokens[0], skip_special_tokens=True)
    return output[len(prompt):]

with gr.Blocks() as demo:
    chat_history = gr.Chatbot()
    inputs = gr.Textbox(label="Question:", placeholder="質問を入力してください")
    outputs = gr.Textbox(label="Answer:")
    btn = gr.Button("Send")
    clear = gr.ClearButton([inputs, chat_history])

    # ボタンが押された時の動作を以下のように定義する:
    btn.click(fn=generate, inputs=inputs, outputs=outputs)

    def response(user_message, chat_history):
        system_message = generate(user_message)
        chat_history.append((user_message, system_message))
        return "", chat_history

    inputs.submit(response, inputs=[inputs, chat_history], outputs=[inputs, chat_history])

if __name__ == "__main__":
    demo.launch()