File size: 2,958 Bytes
cc9c601
2b7f2f4
 
 
 
 
 
cc9c601
 
d676716
 
cc9c601
 
 
 
 
 
 
 
 
 
 
 
2b7f2f4
cc9c601
2b7f2f4
cc9c601
 
 
 
d676716
 
cc9c601
d676716
 
 
 
 
 
 
cc9c601
d676716
 
 
 
 
 
 
 
 
 
 
cc9c601
 
 
d676716
cc9c601
 
 
 
 
 
 
 
d676716
cc9c601
d676716
cc9c601
 
 
 
 
 
 
 
 
d676716
 
 
 
 
cc9c601
 
 
 
d676716
 
 
 
 
 
 
cc9c601
d676716
cc9c601
 
 
 
 
 
 
d676716
 
 
 
 
cc9c601
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    BitsAndBytesConfig,
)
from threading import Thread

import gradio as gr

try:
    import spaces
except Exception:
    # The Hugging Face ``spaces`` package is only available (and importable) on
    # Spaces; importing it can also fail for reasons other than ImportError, so
    # this stays best-effort. ``except Exception`` (rather than a bare
    # ``except:``) still lets KeyboardInterrupt/SystemExit propagate.

    class spaces:
        """Local no-op stand-in for the Hugging Face ``spaces`` module."""

        @staticmethod
        def GPU(func=None, duration=None):
            """No-op replacement for ``spaces.GPU``.

            Works both as ``@spaces.GPU`` and as ``@spaces.GPU(duration=...)``;
            in either form the decorated function is returned unchanged.

            Args:
                func: The function being decorated when used without arguments.
                duration: Ignored; accepted for signature compatibility.

            Returns:
                ``func`` itself, or an identity decorator when called with
                arguments only.
            """
            if callable(func):
                return func
            return lambda fn: fn


# Hugging Face Hub repository of the instruction-tuned checkpoint served by this app.
MODEL_NAME = "hatakeyama-llm-team/Tanuki-8B-Instruct"

# Load the weights quantized to 8-bit and let transformers decide device
# placement automatically (device_map="auto").
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=quantization_config,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Report the chosen module-to-device mapping (handy when debugging placement).
print(model.hf_device_map)


def _build_messages(message, history, system_message):
    """Return the chat-template message list: system prompt, prior turns, new message.

    ``history`` may be Gradio "tuples" format (``(user, assistant)`` pairs) or
    "messages" format (``{"role": ..., "content": ...}`` dicts). Empty turns
    are skipped.
    """
    messages = [{"role": "system", "content": system_message}]
    for turn in history:
        if isinstance(turn, dict):
            # "messages" format: entries already carry a role.
            if turn.get("content"):
                messages.append({"role": turn["role"], "content": turn["content"]})
            continue
        # "tuples" format: index 0 is the user text, index 1 the assistant reply.
        if turn[0]:
            messages.append({"role": "user", "content": turn[0]})
        if turn[1]:
            messages.append({"role": "assistant", "content": turn[1]})
    messages.append({"role": "user", "content": message})
    return messages


# NOTE(review): 10 s of GPU time may be short for up to 2048 new tokens from an
# 8-bit 8B model — confirm the intended budget.
@spaces.GPU(duration=10)
def respond(
    message,
    history: list,
    system_message,
    max_tokens,
    temperature,
    top_p,
    top_k,
):
    """Stream the model's reply to *message*, given the prior conversation.

    Generation runs in a background thread; this generator yields the
    accumulated response text after every decoded chunk so the chat UI can
    render it incrementally.

    Args:
        message: The new user message.
        history: Prior turns, as ``(user, assistant)`` pairs or as
            ``{"role", "content"}`` dicts.
        system_message: System prompt placed first in the conversation.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens kept when sampling.

    Yields:
        The response text generated so far.

    Raises:
        Exception: Any error raised by ``model.generate`` in the worker thread
            is re-raised here. ``queue.Empty`` is raised if no text arrives
            from the streamer for 10 seconds.
    """
    messages = _build_messages(message, history, system_message)

    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True, return_tensors="pt"
    ).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=input_ids,
        # Single unpadded sequence: attend to every position. Passing the mask
        # explicitly avoids transformers guessing it from the pad token.
        attention_mask=torch.ones_like(input_ids),
        streamer=streamer,
        # UI sliders may deliver floats; generate() validates these as ints.
        max_new_tokens=int(max_tokens),
        do_sample=True,
        temperature=temperature,
        top_k=int(top_k),
        top_p=top_p,
        num_beams=1,
    )

    errors = []

    def _generate():
        """Run generation, recording any failure for the consuming thread."""
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:  # re-raised in the caller below
            errors.append(exc)
            # Signal end-of-stream so the consumer stops now instead of
            # waiting for the streamer timeout.
            streamer.end()

    worker = Thread(target=_generate)
    worker.start()

    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        yield partial_message

    worker.join()
    if errors:
        raise errors[0]


# Chat UI. The additional inputs are passed to ``respond`` after
# (message, history), in the order listed here.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        # Default system prompt (Japanese): "The following combines an instruction
        # describing a task with an input that provides context. Write a response
        # that appropriately fulfills the request." Label: "System prompt".
        gr.Textbox(
            value="以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。",
            label="システムプロンプト",
        ),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p",
        ),
        # step=1 so the default (200) lies on the slider grid; with minimum=1 a
        # step of 10 would only permit 1, 11, 21, ..., 191, 201, ...
        gr.Slider(minimum=1, maximum=2000, value=200, step=1, label="Top-k"),
    ],
    # Example prompts (Japanese): "What is a tanuki?", "What does the proverb
    # 'kindness is not for others' sake' mean?", "What is a lucid dream?",
    # "How are the Schrödinger equation and Schrödinger's cat related?"
    examples=[
        ["たぬきってなんですか?"],
        ["情けは人の為ならずとはどういう意味ですか?"],
        ["明晰夢とはなんですか?"],
        ["シュレディンガー方程式とシュレディンガーの猫はどのような関係がありますか?"],
    ],
)


if __name__ == "__main__":
    # Start the Gradio server when executed as a script.
    demo.launch()