Update app.py

--- a/app.py
+++ b/app.py
@@ -11,6 +11,7 @@ model_path = hf_hub_download(repo_id="BlinkDL/rwkv-5-world", filename=f"{title}.
 model = RWKV(model=model_path, strategy='cpu bf16')
 pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
 
+
 def generate_prompt(instruction, input=None, history=None):
     # parse the chat history into a string of user and assistant messages
     history_str = ""
@@ -32,6 +33,7 @@ Response:"""
 
 Assistant:"""
 
+
 examples = [
     ["東京で訪れるべき素晴らしい場所とその紹介をいくつか挙げてください。", "", 300, 1.2, 0.5, 0.5, 0.5],
     ["Écrivez un programme Python pour miner 1 Bitcoin, avec des commentaires.", "", 300, 1.2, 0.5, 0.5, 0.5],
@@ -42,7 +44,7 @@ examples = [
     ["You have $100, and your goal is to turn that into as much money as possible with AI and Machine Learning. Please respond with detailed plan.", "", 300, 1.2, 0.5, 0.5, 0.5],
 ]
 
-def evaluate(
+def respond(
     instruction,
     input=None,
     token_count=333,
@@ -59,9 +61,6 @@ def evaluate(
                     token_stop = [0]) # stop generation whenever you see any token here
 
     instruction = re.sub(r'\n{2,}', '\n', instruction).strip().replace('\r\n','\n')
-    no_history = (history is None)
-    if no_history:
-        input = re.sub(r'\n{2,}', '\n', input).strip().replace('\r\n','\n')
     ctx = generate_prompt(instruction, input, history)
     print(ctx + "\n")
 
@@ -89,8 +88,6 @@ def evaluate(
         tmp = pipeline.decode(all_tokens[out_last:])
         if '\ufffd' not in tmp:
             out_str += tmp
-            if no_history:
-                yield out_str.strip()
             out_last = i + 1
         if '\n\n' in out_str:
             break
@@ -98,11 +95,61 @@ def evaluate(
     del out
     del state
     gc.collect()
-
-
-
-
-
+    return out_str.strip()
+
+def generator(
+    instruction,
+    input=None,
+    token_count=333,
+    temperature=1.0,
+    top_p=0.5,
+    presencePenalty = 0.5,
+    countPenalty = 0.5
+):
+    args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
+                     alpha_frequency = countPenalty,
+                     alpha_presence = presencePenalty,
+                     token_ban = [], # ban the generation of some tokens
+                     token_stop = [0]) # stop generation whenever you see any token here
+
+    instruction = re.sub(r'\n{2,}', '\n', instruction).strip().replace('\r\n','\n')
+    input = re.sub(r'\n{2,}', '\n', input).strip().replace('\r\n','\n')
+    ctx = generate_prompt(instruction, input)  # generator() takes no chat history argument
+    print(ctx + "\n")
+
+    all_tokens = []
+    out_last = 0
+    out_str = ''
+    occurrence = {}
+    state = None
+    for i in range(int(token_count)):
+        out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
+        for n in occurrence:
+            out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
+
+        token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
+        if token in args.token_stop:
+            break
+        all_tokens += [token]
+        for xxx in occurrence:
+            occurrence[xxx] *= 0.996
+        if token not in occurrence:
+            occurrence[token] = 1
+        else:
+            occurrence[token] += 1
+
+        tmp = pipeline.decode(all_tokens[out_last:])
+        if '\ufffd' not in tmp:
+            out_str += tmp
+            yield out_str.strip()
+            out_last = i + 1
+        if '\n\n' in out_str:
+            break
+
+    del out
+    del state
+    gc.collect()
+    yield out_str.strip()
 
 def user(message, chatbot):
     chatbot = chatbot or []
@@ -153,7 +200,7 @@ with gr.Blocks(title=title) as demo:
         presence_penalty = presence_penalty_chat.value
         count_penalty = count_penalty_chat.value
 
-        response =
+        response = respond(instruction, None, token_count, temperature, top_p, presence_penalty, count_penalty, history)
 
         history[-1][1] = response
         return history
@@ -179,7 +226,7 @@ with gr.Blocks(title=title) as demo:
        clear = gr.Button("Clear", variant="secondary")
        output = gr.Textbox(label="Output", lines=5)
        data = gr.Dataset(components=[instruction, input_instruct, token_count_instruct, temperature_instruct, top_p_instruct, presence_penalty_instruct, count_penalty_instruct], samples=examples, label="Example Instructions", headers=["Instruction", "Input", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
-       submit.click(
+       submit.click(generator, [instruction, input_instruct, token_count_instruct, temperature_instruct, top_p_instruct, presence_penalty_instruct, count_penalty_instruct], [output])
        clear.click(lambda: None, [], [output])
        data.click(lambda x: x, [data], [instruction, input_instruct, token_count_instruct, temperature_instruct, top_p_instruct, presence_penalty_instruct, count_penalty_instruct])
 
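
This commit splits the old streaming evaluate() into two entry points: respond(), which runs the same sampling loop but returns one complete string (called by the chat handler in the hunk above), and generator(), which yields partial output and is wired to the instruct tab's Submit button. The split matters because Gradio dispatches the two callback styles differently: a plain function produces a single output update, while a generator function streams one update per yield. Below is a minimal sketch of that distinction; reply_once and reply_streaming are hypothetical stand-ins for respond()/generator(), and it assumes a Gradio 3.x-style Blocks API where queueing must be enabled for generator callbacks.

    import time
    import gradio as gr

    # Hypothetical stand-ins for app.py's respond()/generator().
    def reply_once(prompt):
        # Plain function: Gradio renders a single, final update.
        return f"echo: {prompt}"

    def reply_streaming(prompt):
        # Generator: Gradio re-renders the output on every yield,
        # which is how submit.click(generator, ...) streams text.
        out = ""
        for ch in f"echo: {prompt}":
            out += ch
            time.sleep(0.01)
            yield out

    with gr.Blocks() as demo:
        box = gr.Textbox(label="Prompt")
        out = gr.Textbox(label="Output")
        send = gr.Button("Send")
        send.click(reply_streaming, [box], [out])  # swap in reply_once for a one-shot reply

    demo.queue().launch()  # queue() is what enables streaming generator callbacks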
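
Both respond() and generator() share the same decaying repetition penalty: every sampled token increments occurrence[token], all counts are multiplied by 0.996 each step, and a token's logit is reduced by alpha_presence + count * alpha_frequency before sampling. A self-contained sketch of just that bookkeeping follows; the 0.5 / 0.5 / 0.996 constants mirror the app's defaults, and penalize() is a hypothetical helper, not part of app.py.

    # Sketch of the decaying repetition penalty used in the generation loop.
    alpha_presence = 0.5   # flat penalty once a token has appeared
    alpha_frequency = 0.5  # extra penalty per (decayed) occurrence
    decay = 0.996          # per-step decay, so old repetitions fade out

    occurrence = {}

    def penalize(token_id):
        """Update the counts for one sampled token and return the logit penalty."""
        # decay all existing counts first, as the loop does each step
        for t in occurrence:
            occurrence[t] *= decay
        occurrence[token_id] = occurrence.get(token_id, 0) + 1
        # the loop subtracts this amount from the token's logit next step
        return alpha_presence + occurrence[token_id] * alpha_frequency

    print(penalize(42))  # 1.0   -> 0.5 + 1 * 0.5
    print(penalize(42))  # 1.498 -> 0.5 + (1 * 0.996 + 1) * 0.5

The decay means a token repeated long ago is penalized less than one repeated just now, which curbs degenerate loops without permanently banning common words.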