File size: 3,470 Bytes
41b1248
 
 
 
 
3f35dda
41b1248
 
 
 
 
3f35dda
 
 
41b1248
cc9c601
2b7f2f4
 
 
 
 
 
cc9c601
 
d676716
 
cc9c601
 
 
 
 
 
 
 
 
 
 
 
798bcec
 
 
 
 
 
cc9c601
2b7f2f4
cc9c601
 
 
 
d676716
 
a3e72a9
d676716
 
 
 
 
 
 
cc9c601
d676716
 
 
 
 
 
 
 
 
 
 
cc9c601
 
 
d676716
cc9c601
 
 
 
 
 
 
 
d676716
cc9c601
d676716
cc9c601
 
 
 
 
 
 
 
 
d676716
 
 
 
 
cc9c601
 
 
 
d676716
 
 
 
 
 
 
cc9c601
d676716
cc9c601
 
 
 
 
00d5978
cc9c601
 
d676716
94d52d9
d676716
 
 
 
cc9c601
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
try:
    import flash_attn
except ImportError:
    # flash-attn is missing: install it at startup, then import it again.
    import os
    import subprocess
    import sys

    print("Installing flash-attn...")
    subprocess.run(
        # Use the running interpreter's pip so the package lands in the same
        # environment this script imports from (no shell needed).
        [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
        # Extend (rather than replace) the environment so PATH, HOME, CUDA
        # variables, etc. remain visible to pip. FLASH_ATTENTION_SKIP_CUDA_BUILD
        # presumably makes flash-attn's setup skip compiling the CUDA kernels
        # from source — confirm against the flash-attn setup.py.
        env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    )
    # If the install failed, this re-import raises ImportError and stops startup.
    import flash_attn

    print("flash-attn installed.")

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    BitsAndBytesConfig,
)
from threading import Thread

import gradio as gr

try:
    import spaces
except ImportError:
    # Outside Hugging Face Spaces the `spaces` package is unavailable; provide a
    # no-op stand-in so `@spaces.GPU` decorators leave functions unchanged.
    class spaces:
        """Minimal local substitute for the `spaces` module."""

        @staticmethod
        def GPU(func=None, duration=None):
            """Return *func* unchanged, or an identity decorator.

            Supports both ``@spaces.GPU`` (the function is passed directly) and
            ``@spaces.GPU(duration=...)`` (returns a decorator). ``duration`` is
            accepted only for signature compatibility and is ignored.
            """
            if callable(func):
                return func
            return lambda fn: fn


# Hugging Face Hub repository of the instruction-tuned model served by this demo.
MODEL_NAME = "hatakeyama-llm-team/Tanuki-8B-Instruct"

# bitsandbytes 4-bit quantization: NF4 weights with nested ("double")
# quantization of the quantization constants, and a bfloat16 compute dtype.
quantization_config = BitsAndBytesConfig(
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    load_in_4bit=True,
)

# device_map="auto" lets transformers decide where each layer is placed.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Log the resulting module-to-device placement.
print(model.hf_device_map)


def _build_messages(message, history, system_message):
    """Turn a Gradio chat history plus the new message into chat-template messages.

    ``history`` entries may be ``(user, assistant)`` pairs, in which case empty
    sides are skipped, or ``{"role": ..., "content": ...}`` dicts, which are
    copied as-is.
    """
    messages = [{"role": "system", "content": system_message}]
    for turn in history:
        if isinstance(turn, dict):
            messages.append({"role": turn["role"], "content": turn["content"]})
            continue
        user_text, assistant_text = turn[0], turn[1]
        if user_text:
            messages.append({"role": "user", "content": user_text})
        if assistant_text:
            messages.append({"role": "assistant", "content": assistant_text})
    messages.append({"role": "user", "content": message})
    return messages


@spaces.GPU(duration=30)
def respond(
    message,
    history: list,
    system_message,
    max_tokens,
    temperature,
    top_p,
    top_k,
):
    """Stream a sampled model reply to *message* given the chat *history*.

    Generation runs in a background thread; this generator yields the
    cumulative reply text after every chunk the streamer produces.

    Args:
        message: The new user message.
        history: Previous turns, as ``(user, assistant)`` pairs or as
            ``{"role", "content"}`` dicts.
        system_message: System prompt placed first in the conversation.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens kept when sampling.

    Yields:
        The reply text generated so far.

    Raises:
        Exception: Whatever ``model.generate`` raised in the worker thread,
            re-raised here once the stream has been closed.
    """
    messages = _build_messages(message, history, system_message)

    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True, return_tensors="pt"
    ).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=input_ids,
        # A single unpadded sequence: every position is a real token. Passing
        # the mask explicitly avoids generate() having to infer it.
        attention_mask=torch.ones_like(input_ids),
        streamer=streamer,
        # UI slider values may arrive as floats; these parameters must be ints.
        max_new_tokens=int(max_tokens),
        do_sample=True,
        temperature=temperature,
        top_k=int(top_k),
        top_p=top_p,
        num_beams=1,
    )

    errors = []

    def _generate():
        # Capture failures so they reach the caller instead of surfacing only
        # as a streamer timeout; end() unblocks the consuming loop below.
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    worker = Thread(target=_generate)
    worker.start()

    partial_message = ""
    for new_text in streamer:
        partial_message += new_text
        yield partial_message

    worker.join()
    if errors:
        raise errors[0]


# Extra controls shown with the chat. Gradio passes their values to `respond`
# after (message, history), in this order: system prompt, max new tokens,
# temperature, top-p, top-k.
_ADDITIONAL_INPUTS = [
    gr.Textbox(
        label="システムプロンプト",
        value="以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。",
    ),
    gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=512),
    gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7),
    gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, step=0.05, value=0.95),
    # NOTE(review): with minimum=1 and step=10 the slider grid is 1, 11, 21, ...,
    # so the default of 200 is not on a step — confirm the intended range/step.
    gr.Slider(label="Top-k", minimum=1, maximum=2000, step=10, value=200),
]

# Example user messages; each example fills in only the message box.
_EXAMPLE_PROMPTS = [
    "たぬきってなんですか?",
    "情けは人の為ならずとはどういう意味ですか?",
    "まどマギで一番可愛いのは誰?",
    "明晰夢とはなんですか?",
    "シュレディンガー方程式とシュレディンガーの猫はどのような関係がありますか?",
]

demo = gr.ChatInterface(
    respond,
    additional_inputs=_ADDITIONAL_INPUTS,
    examples=[[prompt] for prompt in _EXAMPLE_PROMPTS],
    cache_examples=False,
)


if __name__ == "__main__":
    demo.launch()