import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

MODEL_ID = "rinna/bilingual-gpt-neox-4b-instruction-ppo"

# Create the 8-bit quantization configuration (requires the bitsandbytes package)
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

# Load the model in 8-bit and let the weights be placed across available devices automatically
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=quantization_config,
    device_map="auto"
)

# use_fast=False: the model ships a SentencePiece tokenizer that needs the slow tokenizer class
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)

def generate_response(user_question,
                      chat_history,
                      temperature=0.3,
                      top_p=0.85,
                      max_new_tokens=2048,
                      repetition_penalty=1.05):
    """Build a prompt in the model's ユーザー/システム chat format and return only the newly generated reply."""

    # Fixed opening exchange asking the assistant to keep its answers brief
    user_prompt_template = "ユーザー: Hello, you are an assistant that helps me learn Japanese. I am going to ask you a question, so please answer *briefly*."
    system_prompt_template = "システム: Sure, I will answer briefly. What can I do for you?"

    # One-shot example (Japanese): "What is the highest mountain in Japan?" / "Mt. Fuji. It is 3,776 meters tall."
    user_sample = "ユーザー: 日本で一番高い山は何ですか?"
    system_sample = "システム: 富士山です。高さは3776メートルです。"

    # Speaker prefixes used by the model's chat format ("ユーザー" = user, "システム" = system)
    user_prefix = "ユーザー: "
    system_prefix = "システム: "

    prompt = user_prompt_template + "\n" + system_prompt_template + "\n"

    # Use the canned example on the first turn; otherwise replay only the most recent exchange
    if len(chat_history) < 1:
        prompt += user_sample + "\n" + system_sample + "\n"
    else:
        u = chat_history[-1][0]
        s = chat_history[-1][1]
        prompt += user_prefix + u + "\n" + system_prefix + s + "\n"

    # End the prompt with the system prefix so the model continues as the assistant
    prompt += user_prefix + user_question + "\n" + system_prefix

    # Tokenize the prompt (no special tokens) and move the tensors to the model's device
    inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
    inputs = inputs.to(model.device)
    with torch.no_grad():
        tokens = model.generate(
            **inputs,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    # Decode the full sequence and strip the prompt text, keeping only the generated reply
    output = tokenizer.decode(tokens[0], skip_special_tokens=True)
    return output[len(prompt):]
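
# A quick way to sanity-check generate_response outside the Gradio UI (a sketch, not part of
# the app; uncomment to run once the model has loaded). An empty history triggers the
# built-in one-shot example above.
# print(generate_response("人工知能とは何ですか?", []))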

# Minimal Gradio chat UI
with gr.Blocks() as demo:
    chat_history = gr.Chatbot()
    # Placeholder question (Japanese): "What is artificial intelligence?"
    user_message = gr.Textbox(label="Question:", placeholder="人工知能とは何ですか?")
    clear = gr.ClearButton([user_message, chat_history])

    # Generate a reply, append the new turn to the chat history, and clear the textbox
    def response(user_message, chat_history):
        system_message = generate_response(user_message, chat_history)
        chat_history.append((user_message, system_message))
        return "", chat_history

    user_message.submit(response, inputs=[user_message, chat_history], outputs=[user_message, chat_history])

if __name__ == "__main__":
    demo.launch()