import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "rinna/bilingual-gpt-neox-4b-instruction-ppo"

# Load the model (without 8-bit quantization)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto"  # 自動でCPU/GPUを選択
)
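# If GPU memory is tight, the model could instead be loaded in 8-bit.
# A sketch, not part of the original code; it assumes the bitsandbytes
# package is installed:
#   model = AutoModelForCausalLM.from_pretrained(
#       MODEL_ID, device_map="auto", load_in_8bit=True
#   )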

# rinna's tokenizers are SentencePiece-based and are loaded with the slow
# tokenizer (use_fast=False).
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)

def generate_response(user_question,
                      chat_history,
                      temperature=0.3,
                      top_p=0.85,
                      max_new_tokens=2048,
                      repetition_penalty=1.05):
    """Build a rinna-style chat prompt and return the model's reply."""
    # Fixed few-shot preamble asking the model to answer briefly.
    user_prompt_template = "ユーザー: Hello, you are an assistant that helps me learn Japanese. I am going to ask you a question, so please answer *briefly*."
    system_prompt_template = "システム: Sure, I will answer briefly. What can I do for you?"

    # One-shot sample Q&A (Japanese): "What is the highest mountain in
    # Japan?" / "It is Mt. Fuji. Its height is 3,776 meters."
    user_sample = "ユーザー: 日本で一番高い山は何ですか?"
    system_sample = "システム: 富士山です。高さは3776メートルです。"

    # "ユーザー" (user) and "システム" (system) are the role prefixes the
    # rinna instruction-tuned models expect in their prompt format.
    user_prefix = "ユーザー: "
    system_prefix = "システム: "

    prompt = user_prompt_template + "\n" + system_prompt_template + "\n"

    # Keep the prompt short: include only the most recent chat turn,
    # falling back to the one-shot sample when the history is empty.
    if len(chat_history) < 1:
        prompt += user_sample + "\n" + system_sample + "\n"
    else:
        u = chat_history[-1][0]
        s = chat_history[-1][1]
        prompt += user_prefix + u + "\n" + system_prefix + s + "\n"

    prompt += user_prefix + user_question + "\n" + system_prefix
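    # The assembled prompt has the shape:
    #   ユーザー: <preamble>
    #   システム: <acknowledgement>
    #   ユーザー: <previous or sample question>
    #   システム: <previous or sample answer>
    #   ユーザー: <current question>
    #   システム:
    # and the model completes the text after the final "システム: ".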

    inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
    inputs = inputs.to(model.device)
    with torch.no_grad():
        tokens = model.generate(
            **inputs,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    # Decode only the newly generated tokens; decoding the full sequence
    # and slicing by len(prompt) can break when the tokenizer does not
    # reproduce the prompt character-for-character.
    new_tokens = tokens[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
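
# A hypothetical direct smoke test, bypassing the Gradio UI:
#   print(generate_response("What is the capital of Japan?", []))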

import gradio as gr

with gr.Blocks() as demo:
    chat_history = gr.Chatbot()
    # Placeholder text (Japanese): "What is artificial intelligence?"
    user_message = gr.Textbox(label="Question:", placeholder="人工知能とは何ですか?")
    clear = gr.ClearButton([user_message, chat_history])

    def response(user_message, chat_history):
        # Generate a reply, append the (question, answer) pair to the
        # chat history, and clear the textbox.
        system_message = generate_response(user_message, chat_history)
        chat_history.append((user_message, system_message))
        return "", chat_history

    user_message.submit(response, inputs=[user_message, chat_history], outputs=[user_message, chat_history])

if __name__ == "__main__":
    demo.launch()
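    # Standard Gradio options (not used in the original): share=True creates
    # a temporary public URL; server_name="0.0.0.0" serves on the local network.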