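"""Kansai-dialect chatbot demo.

Loads rinna/bilingual-gpt-neox-4b-instruction-ppo in 8-bit and serves it
through a minimal Gradio chat UI, using few-shot prompting to steer the
assistant's replies into Kansai dialect.
"""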
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "rinna/bilingual-gpt-neox-4b-instruction-ppo"
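# 8-bit loading requires the bitsandbytes package; note that recent
# transformers versions spell this quantization_config=BitsAndBytesConfig(load_in_8bit=True).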
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    load_in_8bit=True,
    device_map="auto"
)
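# The rinna tokenizer is SentencePiece-based; the model card loads it with use_fast=False.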
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)


def generate_response(user_question,
                      chat_history,
                      temperature=0.3,
                      top_p=0.85,
                      max_new_tokens=2048,
                      repetition_penalty=1.05):
    # Behavior specification: answer questions in Japanese, in Kansai dialect
    user_prompt_template = "ユーザー:あなたは日本語で質問やコメントに対して、回答してくれるアシスタントです。関西弁で回答してください"
    system_prompt_template = "システム: もちろんやで!どんどん質問してな!今日も気分ええわ!"

    # Few-shot samples that demonstrate the desired Kansai-dialect style
    samples = [
        ("ユーザー: 日本一の高さの山は? ", "システム: 富士山や!最高の眺めを拝めるで!!"),
        ("ユーザー: 大阪で有名な食べ物は? ", "システム: たこ焼きやで!!外がカリカリ、中がふわふわや"),
    ]

    
    user_prefix = "ユーザー: "
    system_prefix = "システム: "

    prompt = user_prompt_template + "\n" + system_prompt_template + "\n"

    if len(chat_history) < 1:
        # No history yet: seed the prompt with the few-shot samples
        for u, s in samples:
            prompt += u + "\n" + s + "\n"
    else:
        # Keep the prompt short: include only the most recent exchange
        u = chat_history[-1][0]
        s = chat_history[-1][1]
        prompt += user_prefix + u + "\n" + system_prefix + s + "\n"

    prompt += user_prefix + user_question + "\n" + system_prefix

    inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
    inputs = inputs.to(model.device)
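    # Sampling-based decoding (do_sample=True); no_grad skips autograd bookkeeping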
    with torch.no_grad():
        tokens = model.generate(
            **inputs,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    output = tokenizer.decode(tokens[0], skip_special_tokens=True)
    # Return only the newly generated text, stripping the prompt prefix
    return output[len(prompt):]
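
# Example direct call, bypassing the UI (illustrative question; an empty
# list means no chat history yet):
# print(generate_response("おすすめの観光地は?", []))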

with gr.Blocks() as demo:
    # Minimal chat UI: Chatbot renders the history, Textbox takes the question
    chat_history = gr.Chatbot()
    user_message = gr.Textbox(label="Question:", placeholder="人工知能とは何ですか?")
    clear = gr.ClearButton([user_message, chat_history])

    def response(user_message, chat_history):
        # Generate a reply, append the (user, assistant) pair, and clear the textbox
        system_message = generate_response(user_message, chat_history)
        chat_history.append((user_message, system_message))
        return "", chat_history

    user_message.submit(response, inputs=[user_message, chat_history], outputs=[user_message, chat_history])

if __name__ == "__main__":
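    # demo.launch(share=True) would additionally create a temporary public URL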
    demo.launch()