#!/usr/bin/env python
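"""Gradio chat demo for vericava/llm-jp-3-1.8b-instruct-lora-vericava7-llama:
given a thread of SNS posts, the model writes the post that comes next."""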

import os
from collections.abc import Iterator

import gradio as gr
import spaces
import torch
from transformers import pipeline

DESCRIPTION = "# chat-1"

if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
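# Read from the environment for configurability; not currently enforced during generation.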
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "32768"))


if torch.cuda.is_available():
    model_id = "vericava/llm-jp-3-1.8b-instruct-lora-vericava7-llama"
    my_pipeline = pipeline(
        model=model_id,
    )
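    # Custom Jinja chat template: a system turn renders as a fixed Japanese instruction
    # ("The following are SNS posts. As an SNS post-generation bot, think of the post
    # that comes next; do not explain, answer with only the post text, without corner
    # brackets."), each user turn goes under a "### 前の投稿:" ("previous post") heading,
    # each assistant turn under "### 次の投稿:" ("next post") followed by eos_token, and
    # "\n\n### 次の投稿:\n" is appended as the final generation prompt.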
    my_pipeline.tokenizer.chat_template = "{{bos_token}}{% for message in messages %}{% if message['role'] == 'user' %}{{ '\\n\\n### 前の投稿:\\n' + message['content'] + '' }}{% elif message['role'] == 'system' %}{{ '以下は、SNS上の投稿です。あなたはSNSの投稿生成botとして、次に続く投稿を考えなさい。説明はせず、投稿の内容のみを鉤括弧をつけずに答えよ。' }}{% elif message['role'] == 'assistant' %}{{ '\\n\\n### 次の投稿:\\n' + message['content'] + eos_token }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '\\n\\n### 次の投稿:\\n' }}{% endif %}{% endfor %}"
else:
    my_pipeline = None  # no GPU: generate() checks for this and raises a user-facing error

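# @spaces.GPU requests a GPU slice on Hugging Face ZeroGPU Spaces for each call;
# torch.inference_mode() disables autograd bookkeeping during generation.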
@spaces.GPU
@torch.inference_mode()
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    temperature: float = 0.7,
    top_p: float = 0.95,
    top_k: int = 50,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
    if my_pipeline is None:
        raise gr.Error("This demo requires a GPU; it does not run on CPU.")

    # System content translates to "You are an SNS post-generation bot; think of the
    # post that follows." (The chat template renders its own fixed system instruction.)
    messages = [
        {"role": "system", "content": "あなたはSNSの投稿生成botで、次に続く投稿を考えてください。"},
    ]
    # Replay earlier turns so the model sees the whole thread, not just the last message.
    for user_msg, assistant_msg in chat_history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    # Forward the UI sampling controls to generation.
    output = my_pipeline(
        messages,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )[-1]["generated_text"][-1]["content"]
    yield output

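# The sliders below map positionally to generate()'s extra parameters
# (max_new_tokens, temperature, top_p, top_k, repetition_penalty).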
demo = gr.ChatInterface(
    fn=generate,
    type="tuples",
    additional_inputs_accordion=gr.Accordion(label="詳細設定", open=False),  # "詳細設定" = "Advanced settings"
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.7,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.0,
        ),
    ],
    stop_btn=None,
    examples=[  # sample input posts (Japanese)
        ["サマリーを作る男の人,サマリーマン。"],
        ["やばい場所にクリティカルな配線ができてしまったので掲示した。"],
        ["にゃん"],
        ["Wikipedia の情報は入っているのかもしれない"],
    ],
    description=DESCRIPTION,
    css_paths="style.css",
    fill_height=True,
)

if __name__ == "__main__":
    demo.launch()