File size: 2,663 Bytes
92c62ca
 
 
c99ebd4
d0ea771
c99ebd4
 
 
 
059c42c
c99ebd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8fc194
c99ebd4
e8fc194
c99ebd4
 
 
 
 
86938fd
c99ebd4
86938fd
 
 
c99ebd4
92c62ca
86938fd
c99ebd4
 
e8fc194
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import gradio as gr  # conventionally abbreviated as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint identifier passed to both from_pretrained() calls below.
MODEL_ID = "rinna/bilingual-gpt-neox-4b-instruction-ppo"

# Load the causal LM once at import time. 8-bit loading is requested and
# device placement is delegated to the library via device_map="auto".
# NOTE(review): 8-bit loading presumably needs a supported accelerator and
# quantization backend to be installed - confirm for the deployment target.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", load_in_8bit=True)

# The slow (non-"fast") tokenizer implementation is requested explicitly.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
def _build_prompt(user_question, chat_history):
    """Assemble the few-shot dialogue prompt that ends with an open system turn."""
    user_prefix = "ユーザー: "
    system_prefix = "システム: "

    turns = [
        # Fixed opening exchange that instructs the model to answer briefly.
        "ユーザー: Hello, you are an assistant that helps me learn Japanese. I am going to ask you a question, so please answer *briefly*.",
        "システム: Sure, I will answer briefly. What can I do for you?",
    ]

    if not chat_history:
        # No previous exchange (empty or None): show one canned Q/A pair as
        # an example of the expected answer style.
        turns.append("ユーザー: 日本で一番高い山は何ですか?")
        turns.append("システム: 富士山です。高さは3776メートルです。")
    else:
        # Only the most recent (user, system) exchange is replayed.
        last_exchange = chat_history[-1]
        turns.append(user_prefix + last_exchange[0])
        turns.append(system_prefix + last_exchange[1])

    turns.append(user_prefix + user_question)
    # Leave the system turn open so the model continues from here.
    turns.append(system_prefix)
    return "\n".join(turns)


def generate_response(user_question,
                      chat_history,
                      temperature=0.3,
                      top_p=0.85,
                      max_new_tokens=2048,
                      repetition_penalty=1.05):
    """Generate the assistant's reply to ``user_question``.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        user_question: Text of the user's current question.
        chat_history: Sequence of ``(user_text, system_text)`` pairs. Only the
            last pair is included in the prompt; when it is empty or ``None``
            a built-in sample exchange is used instead.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.
        max_new_tokens: Upper bound on generated tokens, forwarded to
            ``model.generate``.
        repetition_penalty: Repetition penalty forwarded to
            ``model.generate``.

    Returns:
        The decoded text of the newly generated tokens, with special tokens
        removed and without the prompt.
    """
    prompt = _build_prompt(user_question, chat_history)

    inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
    inputs = inputs.to(model.device)
    # Remember how many tokens belong to the prompt so the reply can be cut
    # out of the generated sequence by position.
    prompt_token_count = inputs["input_ids"].shape[1]

    with torch.no_grad():
        tokens = model.generate(
            **inputs,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # The generated sequence starts with the prompt tokens. Slicing by token
    # count avoids relying on decode() reproducing the prompt text
    # character for character, which a string-length slice would require.
    reply_tokens = tokens[0][prompt_token_count:]
    return tokenizer.decode(reply_tokens, skip_special_tokens=True)

    
# Gradio UI: a chat transcript, a question box, and a button that clears both.
with gr.Blocks() as demo:
    chat_history = gr.Chatbot()
    user_message = gr.Textbox(label="Question:", placeholder="人工知能とは何ですか?")
    clear = gr.ClearButton([user_message, chat_history])

    def response(message, history):
        """Answer ``message`` and record the exchange in ``history``.

        Returns an empty string (written back to the textbox) together with
        the history list, which is extended in place with the new
        ``(message, reply)`` pair.
        """
        reply = generate_response(message, history)
        history.append((message, reply))
        return "", history

    # Submitting the textbox runs the callback; its two return values are
    # written to the textbox and the chatbot respectively.
    user_message.submit(
        response,
        inputs=[user_message, chat_history],
        outputs=[user_message, chat_history],
    )

if __name__ == "__main__":
    demo.launch()