File size: 3,453 Bytes
f0dff07 9ab7a40 f0dff07 97befb1 f0dff07 97befb1 f0dff07 245e479 f0dff07 2e5890c f602cdc 97befb1 f602cdc 831dbac 97befb1 f0dff07 97befb1 f0dff07 97befb1 f0dff07 920b6db f0dff07 920b6db f0dff07 97befb1 f0dff07 920b6db f0dff07 920b6db |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
#!/usr/bin/env python
import os
from collections.abc import Iterator
from threading import Thread
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, pipeline
# Markdown/HTML header rendered above the chat UI. When no CUDA device is
# visible, a warning is appended because the model is only loaded on GPU.
DESCRIPTION = "# chat-1" + (
    ""
    if torch.cuda.is_available()
    else "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
)

# Upper bound and initial value of the "Max new tokens" slider.
MAX_MAX_NEW_TOKENS = 2_048
DEFAULT_MAX_NEW_TOKENS = 1_024

# Prompt-length cap, overridable through the environment.
# NOTE(review): not referenced anywhere else in this file.
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "32768"))
if torch.cuda.is_available():
    # Fine-tuned checkpoint to serve. The tokenizer is loaded from the base
    # instruct model instead (presumably the fine-tuned repo does not ship
    # its own tokenizer — TODO confirm).
    model_id = "vericava/llm-jp-3-1.8b-instruct-lora-vericava7-llama"
    base_model_id = "llm-jp/llm-jp-3-1.8b-instruct"
    my_pipeline = pipeline(model=model_id, tokenizer=base_model_id, use_safetensors=True)

    # Replace the tokenizer's chat template with the prompt format used here:
    #   - user turn      -> "\n\n### 前の投稿:\n" (previous post) + content
    #   - system turn    -> a fixed Japanese instruction; the message content is
    #                       NOT rendered. The instruction reads roughly: "These are
    #                       SNS posts. As an SNS post-generation bot, write the
    #                       post that follows. Do not explain; answer with only
    #                       the post text, without 「」 brackets."
    #   - assistant turn -> "\n\n### 次の投稿:\n" (next post) + content + EOS
    # After the last message, when add_generation_prompt is set, the
    # "### 次の投稿" header is emitted so the model writes the next post.
    my_pipeline.tokenizer.chat_template = (
        "{{bos_token}}"
        "{% for message in messages %}"
        "{% if message['role'] == 'user' %}"
        "{{ '\\n\\n### 前の投稿:\\n' + message['content'] + '' }}"
        "{% elif message['role'] == 'system' %}"
        "{{ '以下は、SNS上の投稿です。あなたはSNSの投稿生成botとして、次に続く投稿を考えなさい。説明はせず、投稿の内容のみを鉤括弧をつけずに答えよ。' }}"
        "{% elif message['role'] == 'assistant' %}"
        "{{ '\\n\\n### 次の投稿:\\n' + message['content'] + eos_token }}"
        "{% endif %}"
        "{% if loop.last and add_generation_prompt %}"
        "{{ '\\n\\n### 次の投稿:\\n' }}"
        "{% endif %}"
        "{% endfor %}"
    )
@spaces.GPU
@torch.inference_mode()
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    temperature: float = 0.7,
    top_p: float = 0.95,
    top_k: int = 50,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
    """Generate the SNS post that follows ``message``.

    Used as the ``fn`` of ``gr.ChatInterface``. The trailing five parameters
    receive the values of the "詳細設定" (advanced settings) sliders, in the
    order those sliders are declared.

    Args:
        message: Latest user input, treated as the previous post.
        chat_history: Earlier turns in Gradio "tuples" format. Accepted to
            match the ChatInterface call signature but not sent to the model,
            so each reply depends only on ``message``.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling cumulative-probability cutoff.
        top_k: Number of highest-probability tokens kept when sampling.
        repetition_penalty: Penalty on repeated tokens (1.0 means no penalty).

    Yields:
        The generated reply text, exactly once (output is not streamed).
    """
    # The chat template installed on the tokenizer renders its own fixed
    # instruction for the system role, so this content string is not rendered.
    messages = [
        {"role": "system", "content": "あなたはSNSの投稿生成botで、次に続く投稿を考えてください。"},
        {"role": "user", "content": message},
    ]
    # Forward the UI settings to generation. do_sample=True is required for
    # temperature/top_p/top_k to take effect. The casts guard against sliders
    # delivering floats for integer-valued settings.
    outputs = my_pipeline(
        messages,
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        top_k=int(top_k),
        repetition_penalty=float(repetition_penalty),
    )
    # With chat-style input, "generated_text" holds the whole conversation,
    # with the new assistant message appended last.
    yield outputs[-1]["generated_text"][-1]["content"]
# Chat UI. The accordion labelled "詳細設定" ("advanced settings") holds the
# sliders. Gradio passes their values to generate() after (message, history),
# so the order of the specs below must match generate()'s parameter order:
# max_new_tokens, temperature, top_p, top_k, repetition_penalty.
demo = gr.ChatInterface(
    fn=generate,
    type="tuples",
    additional_inputs_accordion=gr.Accordion(label="詳細設定", open=False),
    additional_inputs=[
        gr.Slider(label=label, minimum=low, maximum=high, step=step, value=initial)
        for label, low, high, step, initial in (
            ("Max new tokens", 1, MAX_MAX_NEW_TOKENS, 1, DEFAULT_MAX_NEW_TOKENS),
            ("Temperature", 0.1, 4.0, 0.1, 0.7),
            ("Top-p (nucleus sampling)", 0.05, 1.0, 0.05, 0.95),
            ("Top-k", 1, 1000, 1, 50),
            ("Repetition penalty", 1.0, 2.0, 0.05, 1.0),
        )
    ],
    stop_btn=None,
    # Each example is a one-element row: just the message text.
    examples=[
        [prompt]
        for prompt in (
            "サマリーを作る男の人,サマリーマン。",
            "やばい場所にクリティカルな配線ができてしまったので掲示した。",
            "にゃん",
            "Wikipedia の情報は入っているのかもしれない",
        )
    ],
    description=DESCRIPTION,
    css_paths="style.css",
    fill_height=True,
)
def _launch_app() -> None:
    """Start the Gradio server for the chat demo."""
    demo.launch()


# Launch only when run as a script, not when the module is imported.
if __name__ == "__main__":
    _launch_app()
|