Commit b4c0306 · committed by Peter
1 Parent(s): 8d9ed7d

✨ integrate constrained gen

Signed-off-by: Peter <[email protected]>
- app.py +31 -26
- constrained_generation.py +3 -5
- converse.py +11 -5
app.py
CHANGED
@@ -5,6 +5,9 @@ app.py - the main file for the app. This creates the flask app and handles the r
 
 import argparse
 import logging
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
+
 import os
 import sys
 import time
@@ -16,7 +19,7 @@ import gradio as gr
 import nltk
 import torch
 from cleantext import clean
-from gradio.inputs import Slider, Textbox
+from gradio.inputs import Slider, Textbox, Radio
 from transformers import pipeline
 
 from converse import discussion
@@ -40,13 +43,12 @@ warnings.filterwarnings(action="ignore", message=".*gradient_checkpointing*")
 import transformers
 
 transformers.logging.set_verbosity_error()
-logging.basicConfig()
 cwd = Path.cwd()
 my_cwd = str(cwd.resolve())  # string so it can be passed to os.path() objects
 
 
 def chat(
-    prompt_message, temperature: float = 0.5, top_p: float = 0.95, top_k: int = 20
+    prompt_message, temperature: float = 0.5, top_p: float = 0.95, top_k: int = 20, constrained_generation: str = "False"
 ) -> str:
     """
     chat - the main function for the chatbot. This is the function that is called when the user
@@ -55,6 +57,7 @@ def chat(
     :param float temperature: the temperature value for the model, defaults to 0.6
     :param float top_p: the top_p value for the model, defaults to 0.95
     :param int top_k: the top_k value for the model, defaults to 25
+    :param bool constrained_generation: whether to use constrained generation or not, defaults to False
     :return str: the response from the model
     """
     history = []
@@ -64,6 +67,7 @@
         top_p=top_p,
         top_k=top_k,
         temperature=temperature,
+        constrained_generation="true" in constrained_generation.lower(),
     )
     history = [prompt_message, response]
     html = ""
@@ -85,7 +89,8 @@ def ask_gpt(
     top_p=0.95,
     top_k=25,
     temperature=0.5,
-    constrained_generation=
+    constrained_generation=False,
+    max_input_length=128,
 ) -> str:
     """
     ask_gpt - helper function that asks the GPT model a question and returns the response
@@ -99,19 +104,20 @@
     :param float top_p: the top_p value for the model, defaults to 0.95
     :param int top_k: the top_k value for the model, defaults to 25
     :param float temperature: the temperature value for the model, defaults to 0.6
+    :param bool constrained_generation: whether to use constrained generation or not, defaults to False
     :return str: the response from the model
     """
     st = time.perf_counter()
     prompt = clean(message)  # clean user input
     prompt = prompt.strip()  # get rid of any extra whitespace
     in_len = len(chat_pipe.tokenizer(prompt).input_ids)
-    if in_len >
-        # truncate to last
+    if in_len > max_input_length:
+        # truncate to last max_input_length tokens
         tokens = chat_pipe.tokenizer(prompt).input_ids
-        trunc_tokens = tokens[-
+        trunc_tokens = tokens[-max_input_length:]
         prompt = chat_pipe.tokenizer.decode(trunc_tokens)
         print(f"truncated prompt to {len(trunc_tokens)} tokens, input length: {in_len}")
-
+    logging.info(f"prompt: {prompt}")
     resp = discussion(
         prompt_text=prompt,
         pipeline=chat_pipe,
@@ -122,7 +128,7 @@
         temperature=temperature,
         max_length=max_length,
         min_length=min_length,
-
+        constrained_beam_search = constrained_generation,
     )
     gpt_et = time.perf_counter()
     gpt_rt = round(gpt_et - st, 2)
@@ -134,10 +140,9 @@
     cln_resp = synthesize_grammar(corrector=grammarbot, message=rawtxt)
     bot_resp_a = corr(remove_repeated_words(cln_resp))
     bot_resp = fix_punct_spacing(bot_resp_a)
-    print(f"the prompt was:\n\t{message}\nand the response was:\n\t{bot_resp}\n")
     corr_rt = round(time.perf_counter() - gpt_et, 4)
     print(
-        f"
+        f"{gpt_rt + corr_rt} to respond, {gpt_rt} GPT, {corr_rt} for correction\n"
    )
    return remove_trailing_punctuation(bot_resp)
 
@@ -225,7 +230,7 @@ if __name__ == "__main__":
            Textbox(
                default="Why is everyone here eating chocolate cake?",
                label="prompt_message",
-                placeholder="
+                placeholder="Start a conversation with the bot",
                lines=2,
            ),
            Slider(
@@ -233,20 +238,21 @@ if __name__ == "__main__":
            ),
            Slider(minimum=0.0, maximum=1.0, step=0.01, default=0.95, label="top_p"),
            Slider(minimum=0, maximum=100, step=5, default=20, label="top_k"),
+            Radio(choices=["True", "False"], default="False", label="constrained_generation"),
        ],
        outputs="html",
        examples_per_page=8,
        examples=[
-            ["Point Break or Bad Boys II?", 0.75, 0.95, 50],
-            ["So... you're saying this wasn't an accident?", 0.6, 0.95, 40],
-            ["Hi, my name is Reginald", 0.6, 0.95, 100],
-            ["Happy birthday!", 0.9, 0.95, 50],
-            ["I have a question, can you help me?", 0.6, 0.95, 50],
-            ["Do you know a joke?", 0.8, 0.85, 50],
-            ["Will you marry me?", 0.9, 0.95, 100],
-            ["Are you single?", 0.95, 0.95, 100],
-            ["Do you like people?", 0.7, 0.95, 25],
-            ["You never took a shortcut before?", 0.7, 0.95, 100],
+            ["Point Break or Bad Boys II?", 0.75, 0.95, 50, False],
+            ["So... you're saying this wasn't an accident?", 0.6, 0.95, 40, False],
+            ["Hi, my name is Reginald", 0.6, 0.95, 100, False],
+            ["Happy birthday!", 0.9, 0.95, 50, False],
+            ["I have a question, can you help me?", 0.6, 0.95, 50, False],
+            ["Do you know a joke?", 0.8, 0.85, 50, False],
+            ["Will you marry me?", 0.9, 0.95, 100, False],
+            ["Are you single?", 0.95, 0.95, 100, False],
+            ["Do you like people?", 0.7, 0.95, 25, False],
+            ["You never took a shortcut before?", 0.7, 0.95, 100, False],
        ],
        title=f"GPT Chatbot Demo: {default_model} Model",
        description=f"A Demo of a Chatbot trained for conversation with humans. Size XL= 1.5B parameters.\n\n"
@@ -254,20 +260,19 @@ if __name__ == "__main__":
        "You can find a link to the model card **[here](https://huggingface.co/ethzanalytics/ai-msgbot-gpt2-XL-dialogue)**\n\n"
        "1. responses can take up to 60 seconds to respond sometimes, patience is a virtue.\n"
        "2. the model was trained on several different datasets. fact-check responses instead of regarding as a true statement.\n"
-        "3. Try adjusting the **[generation parameters](https://huggingface.co/blog/how-to-generate)** to get a better understanding of how they work!\n"
+        "3. Try adjusting the **[generation parameters](https://huggingface.co/blog/how-to-generate)** to get a better understanding of how they work!\n"
+        "4. New - try using [constrained beam search](https://huggingface.co/blog/constrained-beam-search) decoding to generate more coherent responses. _(experimental, feedback welcome!)_\n",
        css="""
        .chatbox {display:flex;flex-direction:row}
        .user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
        .user_msg {background-color:cornflowerblue;color:white;align-self:start}
        .resp_msg {background-color:lightgray;align-self:self-end}
        """,
-        allow_screenshot=True,
        allow_flagging="never",
        theme="dark",
    )
 
    # launch the gradio interface and start the server
    iface.launch(
-
-        enable_queue=True,  # also allows for dealing with multiple users simultaneously (per newer gradio version)
+        enable_queue=True,
    )
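For context (not part of the commit): the new Radio input hands the app a string label rather than a bool, so app.py converts it once before anything downstream sees it. A minimal sketch of that flow, where any name not appearing in the diff is illustrative only:

    # sketch of the flag's path through app.py; to_flag() is a hypothetical helper
    def to_flag(choice: str) -> bool:
        # gradio Radio returns the selected label, e.g. "True" or "False"
        return "true" in choice.lower()

    # chat() does this conversion inline:
    #     constrained_generation="true" in constrained_generation.lower()
    # ask_gpt() then forwards the resulting bool to the conversation helper:
    #     discussion(..., constrained_beam_search=constrained_generation)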
constrained_generation.py
CHANGED
@@ -4,6 +4,7 @@
 
 import copy
 import logging
+logging.basicConfig(level=logging.INFO)
 import time
 from pathlib import Path
 
@@ -81,7 +82,7 @@ def create_kw_extractor(
 )
 
 
-def simple_kw(body_text: str, yake_ex=None, max_kw=10, verbose=False):
+def simple_kw(body_text: str, yake_ex=None, max_kw=15, verbose=False):
     """
     simple_kw - extract keywords from a text using yake
 
@@ -96,7 +97,7 @@ def simple_kw(body_text: str, yake_ex=None, max_kw=10, verbose=False):
     """
     yake_ex = yake_ex or create_kw_extractor(
         max_ngram_size=2,
-        ddpt=0.
+        ddpt=0.9,
         windowSize=10,
         deduplication_algo="seqm",
         numOfKeywords=max_kw,
@@ -219,7 +220,6 @@ def constrained_generation(
         if force_flexible is not None
         else None
    )
-
    try:
        logging.info("generating text..")
        result = pipeline(
@@ -236,8 +236,6 @@
            length_penalty=length_penalty,
            repetition_penalty=repetition_penalty,
            return_full_text=full_text,
-            remove_invalid_values=True,
-            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
            early_stopping=True,
            do_sample=False,
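The helpers in this file wrap the constrained beam search feature described in the blog post linked from the app description. Independent of this repo's pipeline-based wrapper, a minimal transformers sketch of the underlying technique looks roughly like this; the model name and forced words are placeholders, not values from the repo:

    # standalone sketch of constrained beam search with transformers
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    force_words = ["chocolate cake"]  # e.g. keywords pulled out by simple_kw()
    force_words_ids = tok(force_words, add_special_tokens=False).input_ids

    inputs = tok("Why is everyone here eating", return_tensors="pt")
    out = model.generate(
        **inputs,
        force_words_ids=force_words_ids,  # tokens that must appear in the output
        num_beams=4,                      # constrained decoding requires beam search
        max_new_tokens=32,
        no_repeat_ngram_size=3,
        early_stopping=True,
    )
    print(tok.decode(out[0], skip_special_tokens=True))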
converse.py
CHANGED
@@ -4,7 +4,8 @@
 https://huggingface.co/docs/transformers/v4.15.0/en/main_classes/model#transformers.generation_utils.GenerationMixin.generate.no_repeat_ngram_size
 """
 
-
+import logging
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
 import pprint as pp
 import time
 
@@ -29,7 +30,7 @@ def discussion(
     num_return_sequences=1,
     device=-1,
     verbose=False,
-
+    constrained_beam_search=False,
 ):
     """
     discussion - a function that takes in a prompt and generates a response. This function is meant to be used in a conversation loop, and is the main function for the bot.
@@ -66,7 +67,8 @@ def discussion(
         pp.pprint(this_prompt, indent=4)
     # call the model
     print("\n... generating...")
-    if
+    if constrained_beam_search:
+        logging.info("using constrained beam search")
         response = constrained_generation(
             prompt=this_prompt,
             pipeline=pipeline,
@@ -75,7 +77,7 @@
             repetition_penalty=1.0,
             num_beams=4,
             timeout=timeout,
-            verbose=
+            verbose=False,
             full_text=full_text,
             speaker_name=speaker,
             responder_name=responder,
@@ -83,12 +85,15 @@
 
         bot_dialogue = consolidate_texts(
             name_resp=responder,
-            model_resp=response
+            model_resp=response.split(
+                "\n"
+            ),
             name_spk=speaker,
             verbose=verbose,
             print_debug=True,
        )
    else:
+        logging.info("using sampling")
        bot_dialogue = gen_response(
            this_prompt,
            pipeline,
@@ -123,6 +128,7 @@ def discussion(
     p_list.append("\n")
 
     print("\nfinished!")
+    logging.info(f"finished generating response:\n\t{bot_resp}")
     # return the bot response and the full conversation
 
     return {"out_text": bot_resp, "full_conv": p_list}
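One detail worth noting: the commit changes model_resp from the raw response to response.split("\n"), which suggests constrained_generation() returns a single string while consolidate_texts() works line by line, as the sampling path already does. A toy illustration of that split; the speaker-turn format shown is a guess for illustration, not taken from the repo:

    # illustrative only: why the constrained path now splits its output
    response = "person beta: I love chocolate cake\nperson alpha: me too"
    model_resp = response.split(
        "\n"
    )  # one entry per line, matching what consolidate_texts() appears to expect
    print(model_resp)  # ['person beta: I love chocolate cake', 'person alpha: me too']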