Azure99 committed (verified)
Commit 0aa17c3 · 1 Parent(s): 86be887

Update app.py

Files changed (1)
  1. app.py +52 -28
app.py CHANGED
@@ -5,28 +5,45 @@ import spaces
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
-MAX_NEW_TOKENS = 2048
-MODEL_NAME = "Azure99/Blossom-V6-7B"
+MAX_NEW_TOKENS = 8192
+MODEL_NAME = "Azure99/Blossom-V6.1-8B"
 
-model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16, device_map="auto")
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_NAME, torch_dtype=torch.bfloat16, device_map="auto"
+)
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 
 
 def get_input_ids(inst, history):
     conversation = []
     for user, assistant in history:
-        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+        conversation.extend(
+            [
+                {"role": "user", "content": user},
+                {"role": "assistant", "content": assistant},
+            ]
+        )
     conversation.append({"role": "user", "content": inst})
-    return tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
+    return tokenizer.apply_chat_template(conversation, return_tensors="pt").to(
+        model.device
+    )
 
 
-@spaces.GPU
+@spaces.GPU(duration=120)
 def chat(inst, history, temperature, top_p, repetition_penalty):
-    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+    streamer = TextIteratorStreamer(
+        tokenizer, skip_prompt=True, skip_special_tokens=True
+    )
     input_ids = get_input_ids(inst, history)
-    generation_kwargs = dict(input_ids=input_ids,
-                             streamer=streamer, do_sample=True, max_new_tokens=MAX_NEW_TOKENS,
-                             temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty)
+    generation_kwargs = dict(
+        input_ids=input_ids,
+        streamer=streamer,
+        do_sample=True,
+        max_new_tokens=MAX_NEW_TOKENS,
+        temperature=temperature,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+    )
 
     Thread(target=model.generate, kwargs=generation_kwargs).start()
 
@@ -63,23 +80,30 @@ additional_inputs = [
         step=0.01,
         interactive=True,
         info="Repetition Penalty: Controls how much repetition is penalized.",
-    )
+    ),
 ]
 
-gr.ChatInterface(chat,
-                 chatbot=gr.Chatbot(show_label=False, height=500, show_copy_button=True, render_markdown=True),
-                 textbox=gr.Textbox(placeholder="", container=False, scale=7),
-                 title="Blossom-V6-7B Demo",
-                 description='Hello, I am Blossom, an open source conversational large language model.🌠'
-                             '<a href="https://github.com/Azure99/BlossomLM">GitHub</a>',
-                 theme="soft",
-                 examples=[["Hello"], ["What is MBTI"], ["用Python实现二分查找"],
-                           ["为switch写一篇小红书种草文案,带上emoji"]],
-                 cache_examples=False,
-                 additional_inputs=additional_inputs,
-                 additional_inputs_accordion=gr.Accordion(label="Config", open=True),
-                 clear_btn="🗑️Clear",
-                 undo_btn="↩️Undo",
-                 retry_btn="🔄Retry",
-                 submit_btn="➡️Submit",
-                 ).queue().launch()
+gr.ChatInterface(
+    chat,
+    chatbot=gr.Chatbot(
+        show_label=False, height=500, show_copy_button=True, render_markdown=True
+    ),
+    textbox=gr.Textbox(placeholder="", container=False, scale=7),
+    title="Blossom-V6.1-8B Demo",
+    description="Hello, I am Blossom, an open source conversational large language model.🌠"
+    '<a href="https://github.com/Azure99/BlossomLM">GitHub</a>',
+    theme="soft",
+    examples=[
+        ["Hello"],
+        ["What is MBTI"],
+        ["用Python实现二分查找"],
+        ["为switch写一篇小红书种草文案,带上emoji"],
+    ],
+    cache_examples=False,
+    additional_inputs=additional_inputs,
+    additional_inputs_accordion=gr.Accordion(label="Config", open=True),
+    clear_btn="🗑️Clear",
+    undo_btn="↩️Undo",
+    retry_btn="🔄Retry",
+    submit_btn="➡️Submit",
+).queue().launch()
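
Note: the first hunk ends right after chat() starts model.generate in a background Thread; the part of chat() that consumes the stream is unchanged by this commit and therefore not shown. As a minimal sketch only (the loop below is an assumption based on the standard transformers TextIteratorStreamer / Gradio generator-function pattern, not code from this diff), the tail of such a function usually looks like:

    # Assumed continuation of chat(); not part of this commit's diff.
    outputs = ""
    for new_text in streamer:   # TextIteratorStreamer yields decoded text chunks as generate() runs
        outputs += new_text
        yield outputs           # gr.ChatInterface re-renders the accumulated reply on each yield

Running generate() in a separate Thread is what lets the main thread iterate over the streamer and yield partial output while generation is still in progress.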