import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
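
# Kansai-dialect chat assistant built around rinna/bilingual-gpt-neox-4b-instruction-ppo,
# running on CPU. The prompt template below instructs the model to reply in Kansai dialect.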

MODEL_ID = "rinna/bilingual-gpt-neox-4b-instruction-ppo"

# Load the model on CPU; pass your own Hugging Face token via use_auth_token if needed.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    use_auth_token="your_huggingface_token",
    device_map="cpu"
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)

device = model.device

def generate_response(user_question,
                      chat_history,
                      temperature=0.3,
                      top_p=0.85,
                      max_new_tokens=2048,
                      repetition_penalty=1.05):
    # Behaviour instruction (Japanese): "You are an assistant that answers questions and
    # comments in Japanese. Please reply in Kansai dialect."
    user_prompt_template = "ユーザー:あなたは日本語で質問やコメントに対して、回答してくれるアシスタントです。関西弁で回答してください"
    # System acknowledgement (Japanese): "Sure thing! Ask away! I'm in a great mood today!"
    system_prompt_template = "システム: もちろんやで!どんどん質問してな!今日も気分ええわ!"

    # One-shot examples to prime the dialect and answer format
    # (only the last assigned pair is inserted into the prompt below).
    user_sample = "ユーザー: 日本一の高さの山は? "   # "What is the tallest mountain in Japan?"
    system_sample = "システム: 富士山や!最高の眺めを拝めるで!!"  # "Mt. Fuji! You get the best view!"

    user_sample = "ユーザー: 大阪で有名な食べ物は? "  # "What food is Osaka famous for?"
    system_sample = "システム: たこ焼きやで!!外がカリカリ、中がふわふわや"  # "Takoyaki! Crispy outside, fluffy inside."

    
    user_prefix = "ユーザー: "
    system_prefix = "システム: "

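    # Prompt layout: behaviour instruction, then either the one-shot example (first turn)
    # or the most recent exchange from chat_history, then the new user question.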
    prompt = user_prompt_template + "\n" + system_prompt_template + "\n"

    if len(chat_history) < 1:
        # First turn: seed the prompt with the one-shot example.
        prompt += user_sample + "\n" + system_sample + "\n"
    else:
        # Otherwise carry over only the most recent exchange to keep the prompt short.
        u = chat_history[-1][0]
        s = chat_history[-1][1]
        prompt += user_prefix + u + "\n" + system_prefix + s + "\n"

    prompt += user_prefix + user_question + "\n" + system_prefix

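    # Tokenize without adding special tokens; the role markers above already give the
    # prompt its structure. Inputs are moved to the model's device before generation.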
    inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
    inputs = inputs.to(model.device)
    with torch.no_grad():
        tokens = model.generate(
            **inputs,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    # Decode only the newly generated tokens so the prompt itself is not echoed back.
    output_ids = tokens[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(output_ids, skip_special_tokens=True)
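

# --- UI wiring (a minimal sketch) ---
# gradio is imported above but no interface is defined in this excerpt, so the block
# below is an illustrative guess at how generate_response could be exposed, not the
# original app code. It assumes gr.ChatInterface's tuple-style history ([user, assistant]
# pairs), which is the format generate_response indexes via chat_history[-1][0] / [-1][1].
demo = gr.ChatInterface(
    fn=generate_response,
    title="Kansai-ben chat assistant (rinna/bilingual-gpt-neox-4b-instruction-ppo)",
)

if __name__ == "__main__":
    demo.launch()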