Commit 0b3d061
Peter committed · 1 parent: 235585a

🐛 fix input len bug

Signed-off-by: Peter <[email protected]>
- app.py +8 -5
- converse.py +12 -8
- grammar_improve.py +5 -3
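
All three files fix the same bug: generation-length budgets were computed from len(text), a character count, where the models need token counts. A quick illustration of the gap, using a gpt2 tokenizer as a stand-in (the diff does not name the Space's actual checkpoint):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in checkpoint (assumption)

text = "hello world " * 100
print(len(text))                       # 1200 characters
print(len(tokenizer(text).input_ids))  # roughly 200 tokens, several times fewer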
app.py
CHANGED
@@ -101,11 +101,13 @@ def ask_gpt(
     st = time.perf_counter()
     prompt = clean(message)  # clean user input
     prompt = prompt.strip()  # get rid of any extra whitespace
-    in_len = len(prompt)
+    in_len = len(chat_pipe.tokenizer(prompt).input_ids)
     if in_len > 512:
-        …
-        …
-        …
+        # truncate to last 512 tokens
+        tokens = chat_pipe.tokenizer(prompt).input_ids
+        trunc_tokens = tokens[-512:]
+        prompt = chat_pipe.tokenizer.decode(trunc_tokens)
+        print(f"truncated prompt to {len(trunc_tokens)} tokens, input length: {in_len}")

     resp = discussion(
         prompt_text=prompt,
@@ -115,7 +117,8 @@ def ask_gpt(
         top_p=top_p,
         top_k=top_k,
         temperature=temperature,
-        max_length=
+        max_length=max_length,
+        min_length=min_length,
     )
     gpt_et = time.perf_counter()
     gpt_rt = round(gpt_et - st, 2)
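The app.py change measures the prompt in tokens and keeps only the newest 512 of them. A minimal sketch of that round-trip, with a standalone gpt2 tokenizer standing in for chat_pipe.tokenizer (an assumption; the chat model is not named in this diff):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in for chat_pipe.tokenizer

def truncate_to_last_n_tokens(text: str, n: int = 512) -> str:
    """Keep only the most recent n tokens of text, preserving the end of the chat."""
    token_ids = tokenizer(text).input_ids
    if len(token_ids) <= n:
        return text
    # Slice from the end so the newest context survives, then decode back to a string.
    return tokenizer.decode(token_ids[-n:])

prompt = truncate_to_last_n_tokens("hello world " * 600)
print(len(tokenizer(prompt).input_ids))  # about 512; re-tokenizing decoded text can shift the count slightly

Slicing from the end, tokens[-512:], matters for chat: it drops the oldest turns rather than the user's latest message.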
converse.py
CHANGED
@@ -17,7 +17,8 @@ def discussion(
     responder: str,
     pipeline,
     timeout=45,
-    …
+    min_length=4,
+    max_length=64,
     top_p=0.95,
     top_k=50,
     temperature=0.7,
@@ -104,7 +105,8 @@ def gen_response(
     speaker: str,
     responder: str,
     timeout=45,
-    …
+    min_length=4,
+    max_length=64,
     top_p=0.95,
     top_k=50,
     temperature=0.7,
@@ -125,7 +127,8 @@ def gen_response(
     responder : str, the name of the person who is responding to the prompt
     pipeline : transformers.Pipeline, the pipeline to use for generating the response
     timeout : int, optional, the number of seconds to wait before timing out, by default 45
-    …
+    min_length : int, optional, the minimum number of tokens to generate, defaults to 4
+    max_length : int, optional, the maximum number of tokens to generate, defaults to 64
     top_p : float, optional, the top probability to use for sampling, defaults to 0.95
     top_k : int, optional, the top k to use for sampling, defaults to 50
     temperature : float, optional, the temperature to use for sampling, defaults to 0.7
@@ -139,15 +142,16 @@ def gen_response(
         str, the generated text

     """
-    …
-    if max_length > 1024:
-        max_length = 1024
-        print("max_length
+    input_len = len(pipeline.tokenizer(query).input_ids)
+    if max_length + input_len > 1024:
+        max_length = max(1024 - input_len, 8)
+        print(f"max_length too large, setting to {max_length}")
     st = time.perf_counter()

     response = pipeline(
         query,
-        …
+        min_length=min_length + input_len,
+        max_length=max_length + input_len,
         temperature=temperature,
         top_k=top_k,
         top_p=top_p,
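The converse.py change adds min_length and max_length to both signatures and offsets each by the prompt's token count before calling the pipeline. The offset is needed because, for decoder-only text-generation models, max_length (and min_length) count prompt tokens plus generated tokens. A sketch of the same logic, assuming a gpt2 stand-in model, whose 1024-token context window matches the cap in the diff; the query string is illustrative only:

from transformers import pipeline

pipe = pipeline("text-generation", model="gpt2")  # stand-in model (assumption)

query = "Person Alpha: how was your day?\nPerson Beta:"
input_len = len(pipe.tokenizer(query).input_ids)

min_length, max_length = 4, 64
if max_length + input_len > 1024:  # keep prompt + output inside the context window
    max_length = max(1024 - input_len, 8)

response = pipe(
    query,
    min_length=min_length + input_len,  # at least 4 newly generated tokens
    max_length=max_length + input_len,  # at most 64 newly generated tokens
    do_sample=True,
    top_p=0.95,
    top_k=50,
    temperature=0.7,
)
print(response[0]["generated_text"])

Recent transformers versions accept max_new_tokens / min_new_tokens, which express the same budget without the manual offset.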
grammar_improve.py
CHANGED
@@ -137,10 +137,11 @@ def synthesize_grammar(
     """
     st = time.perf_counter()
     input_text = clean(message, lower=False)
+    input_len = len(corrector.tokenizer(input_text).input_ids)
     results = corrector(
         input_text,
-        max_length=int(1.1 * len(input_text)),
-        min_length=2 if len(input_text) < 64 else int(0.2 * len(input_text)),
+        max_length=int(1.1 * input_len),
+        min_length=2 if input_len < 64 else int(0.2 * input_len),
         num_beams=num_beams,
         repetition_penalty=repetition_penalty,
         length_penalty=length_penalty,
@@ -479,7 +480,8 @@ def correct_grammar(
     """
     st = time.perf_counter()

-    if len(input_text) < 4:
+    if len(tokenizer(input_text).input_ids) < 4:
+        print(f"input text of {input_text} is too short to be corrected")
         return input_text
     max_length = min(int(math.ceil(len(input_text) * 1.2)), 128)
     batch = tokenizer(
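Both grammar_improve.py hunks make the same substitution: budgets and guards move from len(input_text) characters to tokenizer token counts. A sketch of the corrected synthesize_grammar call, assuming a public T5-based grammar model (vennify/t5-base-grammar-correction and its "grammar: " input prefix are assumptions; the Space's checkpoint is not named here):

from transformers import pipeline

# Hypothetical checkpoint; swap in the Space's actual corrector model.
corrector = pipeline("text2text-generation", model="vennify/t5-base-grammar-correction")

input_text = "grammar: Their going too the store, because they're car is broke."
input_len = len(corrector.tokenizer(input_text).input_ids)

results = corrector(
    input_text,
    max_length=int(1.1 * input_len),  # output budget now tracks token length, not characters
    min_length=2 if input_len < 64 else int(0.2 * input_len),
    num_beams=4,
)
print(results[0]["generated_text"])

Unlike the decoder-only chat model in converse.py, a T5 corrector is an encoder-decoder, so max_length here counts only output tokens and needs no input_len offset, which is presumably why this file adds none.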