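"""Gradio chat demo for hatakeyama-llm-team/Tanuki-8B-Instruct.

Loads the model with 8-bit quantization and streams replies through a
gr.ChatInterface. Runs on Hugging Face ZeroGPU Spaces via `spaces.GPU`,
with a no-op fallback so it also works locally.
"""
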
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    BitsAndBytesConfig,
)
from threading import Thread

import gradio as gr

# `spaces` is only available on Hugging Face (ZeroGPU) Spaces; fall back to a
# stub whose GPU decorator is a no-op so the app also runs locally.
try:
    import spaces
except ImportError:

    class spaces:
        @staticmethod
        def GPU(duration: int):
            return lambda x: x


MODEL_NAME = "hatakeyama-llm-team/Tanuki-8B-Instruct"

# Load the weights in 8-bit via bitsandbytes to reduce GPU memory use;
# device_map="auto" lets accelerate place layers across the available devices.
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, quantization_config=quantization_config, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Log how the layers were placed across devices.
print(model.hf_device_map)


# Request a GPU slot for up to 30 seconds per call when running on Spaces.
@spaces.GPU(duration=30)
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    top_k,
):
    messages = [{"role": "system", "content": system_message}]

    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})

    messages.append({"role": "user", "content": message})

    # Apply the model's chat template and tokenize in one step.
    tokenized_input = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True, return_tensors="pt"
    ).to(model.device)

    # Stream decoded tokens as they are produced; skip_prompt drops the echoed input.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=tokenized_input,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        num_beams=1,
    )
    # Run generation on a background thread so this function can consume the streamer.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Accumulate tokens and yield the growing message for Gradio to render.
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        yield partial_message


demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            # Default system prompt (Japanese): "Below is a combination of an
            # instruction describing a task and input that provides context.
            # Write a response that appropriately satisfies the request."
            value="以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。",
            label="システムプロンプト",  # "System prompt"
        ),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p",
        ),
        gr.Slider(minimum=1, maximum=2000, value=200, step=10, label="Top-k"),
    ],
    examples=[
        ["たぬきってなんですか?"],  # "What is a tanuki?"
        ["情けは人の為ならずとはどういう意味ですか?"],  # "What does the proverb 'kindness is not for others' sake' mean?"
        ["まどマギで一番可愛いのは誰?"],  # "Who is the cutest character in Madoka Magica?"
        ["明晰夢とはなんですか?"],  # "What is a lucid dream?"
        ["シュレディンガー方程式とシュレディンガーの猫はどのような関係がありますか?"],  # "How are the Schrödinger equation and Schrödinger's cat related?"
    ],
)


if __name__ == "__main__":
    demo.launch()
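
# A minimal sketch of a local run, assuming a CUDA GPU with enough memory for
# the 8-bit weights (dependencies inferred from the imports above; exact
# versions may vary):
#   pip install torch transformers accelerate bitsandbytes gradio
#   python app.py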