import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "rinna/bilingual-gpt-neox-4b-instruction-ppo"

# Load the model (without using 8-bit quantization)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto"  # automatically select CPU/GPU
)
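
# A minimal alternative sketch (assumption: the bitsandbytes package is
# installed): loading the weights in 8-bit reduces GPU memory use.
# model = AutoModelForCausalLM.from_pretrained(
#     MODEL_ID,
#     device_map="auto",
#     load_in_8bit=True
# )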

# Use the slow (SentencePiece) tokenizer, as recommended for rinna GPT-NeoX models
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)

def generate_response(user_question,
                      chat_history,
                      temperature=0.3,
                      top_p=0.85,
                      max_new_tokens=2048,
                      repetition_penalty=1.05):

    # Fixed opening exchange that sets the persona, in Kansai-dialect Japanese:
    # the user asks for a concise Japanese-learning assistant; the system agrees.
    user_prompt_template = "ユーザー: こんにちは、あなたは日本語を学ぶ手助けをしてくれるアシスタントやねん。質問するから、簡潔に答えてな。"
    system_prompt_template = "システム: うん、簡潔に答えるで。何を教えてほしいん?"

    # One-shot sample turn used when the chat history is empty
    # (Q: "How tall is Mt. Fuji?"  A: "Mt. Fuji is 3776 meters tall.")
    user_sample = "ユーザー: 富士山の標高ってどれくらいなん?"
    system_sample = "システム: 富士山の標高は3776メートルやで。"

    # Turn prefixes expected by rinna's instruction-tuned prompt format
    user_prefix = "ユーザー: "
    system_prefix = "システム: "

    prompt = user_prompt_template + "\n" + system_prompt_template + "\n"

    # Include the most recent turn of history, or the sample turn if none exists
    if len(chat_history) < 1:
        prompt += user_sample + "\n" + system_sample + "\n"
    else:
        u, s = chat_history[-1]
        prompt += user_prefix + u + "\n" + system_prefix + s + "\n"

    prompt += user_prefix + user_question + "\n" + system_prefix

    inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
    inputs = inputs.to(model.device)
    # Sample a continuation (no gradients needed at inference time)
    with torch.no_grad():
        tokens = model.generate(
            **inputs,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    output = tokenizer.decode(tokens[0], skip_special_tokens=True)
    # The decoded output echoes the prompt; return only the newly generated text
    return output[len(prompt):]
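
# Example direct call (sketch): with an empty history, the sample turn above is used.
# print(generate_response("人工知能とは何ですか?", chat_history=[]))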


with gr.Blocks() as demo:
    chat_history = gr.Chatbot()
    user_message = gr.Textbox(label="Question:", placeholder="人工知能とは何ですか?")  # placeholder: "What is artificial intelligence?"
    clear = gr.ClearButton([user_message, chat_history])

    def response(user_message, chat_history):
        # Generate a reply, append the turn to the history, and clear the textbox
        system_message = generate_response(user_message, chat_history)
        chat_history.append((user_message, system_message))
        return "", chat_history

    user_message.submit(response, inputs=[user_message, chat_history], outputs=[user_message, chat_history])

if __name__ == "__main__":
    demo.launch()
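
# Optionally, demo.launch(share=True) also serves a temporary public URL.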