metastable-void committed
Commit 97befb1 · unverified · 1 parent: 9ab7a40

update to use my config

Files changed (2):
  1. README.md +1 -1
  2. app.py +18 -45
README.md CHANGED
@@ -1,5 +1,5 @@
 ---
-title: CALM2-7B-chat
+title: chat-1
 emoji: ⚡
 colorFrom: red
 colorTo: purple
app.py CHANGED
@@ -7,9 +7,9 @@ from threading import Thread
 import gradio as gr
 import spaces
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, pipeline

-DESCRIPTION = "# CALM2-7B-chat"
+DESCRIPTION = "# chat-1"

 if not torch.cuda.is_available():
     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
@@ -20,14 +20,10 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "32768"))

 if torch.cuda.is_available():
     model_id = "cyberagent/calm2-7b-chat"
-    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-
-
-def apply_chat_template(conversation: list[dict[str, str]]) -> str:
-    prompt = "\n".join([f"{c['role']}: {c['content']}" for c in conversation])
-    return f"{prompt}\nASSISTANT: "
-
+    my_pipeline=pipeline(
+        model=model_id,
+    )
+    my_pipeline.tokenizer.chat_template = "{{bos_token}}{% for message in messages %}{% if message['role'] == 'user' %}{{ '\\n\\n### 前の投稿:\\n' + message['content'] + '' }}{% elif message['role'] == 'system' %}{{ '以下は、SNS上の投稿です。あなたはSNSの投稿生成botとして、次に続く投稿を考えなさい。説明はせず、投稿の内容のみを鉤括弧をつけずに答えよ。' }}{% elif message['role'] == 'assistant' %}{{ '\\n\\n### 次の投稿:\\n' + message['content'] + eos_token }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '\\n\\n### 次の投稿:\\n' }}{% endif %}{% endfor %}"

 @spaces.GPU
 @torch.inference_mode()
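Note on the hunk above: the assigned chat_template replaces CALM2's default USER:/ASSISTANT: prompt with an SNS-post-continuation format, and it discards the content of any system message, always emitting the same fixed instruction instead. Below is a minimal sketch (not part of the commit) of how that template renders a conversation; the SNS_TEMPLATE name is illustrative, and the string is the same Jinja template split across lines for readability:

from transformers import AutoTokenizer

# Same Jinja template as set on my_pipeline.tokenizer in the hunk above.
SNS_TEMPLATE = (
    "{{bos_token}}{% for message in messages %}"
    "{% if message['role'] == 'user' %}{{ '\\n\\n### 前の投稿:\\n' + message['content'] + '' }}"
    "{% elif message['role'] == 'system' %}{{ '以下は、SNS上の投稿です。あなたはSNSの投稿生成botとして、次に続く投稿を考えなさい。説明はせず、投稿の内容のみを鉤括弧をつけずに答えよ。' }}"
    "{% elif message['role'] == 'assistant' %}{{ '\\n\\n### 次の投稿:\\n' + message['content'] + eos_token }}"
    "{% endif %}"
    "{% if loop.last and add_generation_prompt %}{{ '\\n\\n### 次の投稿:\\n' }}{% endif %}"
    "{% endfor %}"
)

tokenizer = AutoTokenizer.from_pretrained("cyberagent/calm2-7b-chat")
tokenizer.chat_template = SNS_TEMPLATE

messages = [
    {"role": "system", "content": "あなたはSNSの投稿生成botで、次に続く投稿を考えてください。"},
    {"role": "user", "content": "にゃん"},
]

# add_generation_prompt=True triggers the template's trailing "### 次の投稿:" header,
# which is the part the model is asked to continue.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)

The rendered prompt should consist of the tokenizer's BOS token, the fixed instruction, the user text under "### 前の投稿:", and an empty "### 次の投稿:" header for the model to fill in.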
 
@@ -40,38 +36,15 @@ def generate(
     top_k: int = 50,
     repetition_penalty: float = 1.0,
 ) -> Iterator[str]:
-    conversation = []
-    for user, assistant in chat_history:
-        conversation.extend([{"role": "USER", "content": user}, {"role": "ASSISTANT", "content": assistant}])
-    conversation.append({"role": "USER", "content": message})
+    messages = [
+        {"role": "system", "content": "あなたはSNSの投稿生成botで、次に続く投稿を考えてください。"},
+        {"role": "user", "content": message},
+    ]

-    prompt = apply_chat_template(conversation)
-    input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
-    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-    input_ids = input_ids.to(model.device)
-
-    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-    generate_kwargs = dict(
-        {"input_ids": input_ids},
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
-        num_beams=1,
-        repetition_penalty=repetition_penalty,
-    )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    outputs = []
-    for text in streamer:
-        outputs.append(text)
-        yield "".join(outputs)
-
+    output = my_pipeline(
+        messages,
+    )[-1]["generated_text"][-1]["content"]
+    yield output

 demo = gr.ChatInterface(
     fn=generate,
@@ -116,10 +89,10 @@ demo = gr.ChatInterface(
     ],
     stop_btn=None,
     examples=[
-        ["東京の観光名所を教えて。"],
-        ["落武者って何?"],  # noqa: RUF001
-        ["暴れん坊将軍って誰のこと?"],  # noqa: RUF001
-        ["人がヘリを食べるのにかかる時間は?"],  # noqa: RUF001
+        ["サマリーを作る男の人,サマリーマン。"],
+        ["やばい場所にクリティカルな配線ができてしまったので掲示した。"],
+        ["にゃん"],
+        ["Wikipedia の情報は入っているのかもしれない"],
     ],
     description=DESCRIPTION,
     css_paths="style.css",