Update app.py
Browse files
app.py
CHANGED
@@ -24,8 +24,8 @@ def proc_submission(
|
|
24 |
summary_type: str,
|
25 |
num_beams,
|
26 |
token_batch_length,
|
27 |
-
length_penalty,
|
28 |
-
repetition_penalty,
|
29 |
#no_repeat_ngram_size: int = 3,
|
30 |
max_input_length: int = 768,
|
31 |
):
|
@@ -48,7 +48,7 @@ def proc_submission(
|
|
48 |
|
49 |
settings = {
|
50 |
"length_penalty": float(length_penalty),
|
51 |
-
"repetition_penalty": float(repetition_penalty),
|
52 |
"no_repeat_ngram_size": 3,
|
53 |
"encoder_no_repeat_ngram_size": 4,
|
54 |
"num_beams": int(num_beams),
|
@@ -74,10 +74,10 @@ def proc_submission(
|
|
74 |
|
75 |
_summaries = summarize_via_tokenbatches(
|
76 |
tr_in,
|
77 |
-
model_led_det if (model_type == "LED"
|
78 |
-
tokenizer_led_det if (model_type == "LED"
|
79 |
-
model_led_tldr if (model_type == "LED"
|
80 |
-
tokenizer_led_tldr if (model_type == "LED"
|
81 |
batch_length=token_batch_length,
|
82 |
**settings,
|
83 |
)
|
@@ -176,7 +176,7 @@ if __name__ == "__main__":
|
|
176 |
choices=["LongT5", "LED"], label="Model Architecture", value="LongT5"
|
177 |
)
|
178 |
num_beams = gr.Radio(
|
179 |
-
choices=[2, 3, 4],
|
180 |
label="Beam Search: # of Beams",
|
181 |
value=2,
|
182 |
)
|
@@ -197,13 +197,13 @@ if __name__ == "__main__":
|
|
197 |
value=512,
|
198 |
)
|
199 |
|
200 |
-
with gr.Row():
|
201 |
-
repetition_penalty = gr.inputs.Slider(
|
202 |
-
minimum=1.0,
|
203 |
-
maximum=5.0,
|
204 |
-
label="repetition penalty",
|
205 |
-
default=3.5,
|
206 |
-
step=0.1,
|
207 |
)
|
208 |
#no_repeat_ngram_size = gr.Radio(
|
209 |
#choices=[2, 3, 4],
|
@@ -285,7 +285,6 @@ if __name__ == "__main__":
|
|
285 |
num_beams,
|
286 |
token_batch_length,
|
287 |
length_penalty,
|
288 |
-
repetition_penalty,
|
289 |
],
|
290 |
outputs=[output_text, summary_text, compression_rate],
|
291 |
)
|
|
|
24 |
summary_type: str,
|
25 |
num_beams,
|
26 |
token_batch_length,
|
27 |
+
#length_penalty,
|
28 |
+
#repetition_penalty,
|
29 |
#no_repeat_ngram_size: int = 3,
|
30 |
max_input_length: int = 768,
|
31 |
):
|
|
|
48 |
|
49 |
settings = {
|
50 |
"length_penalty": float(length_penalty),
|
51 |
+
"repetition_penalty": 3.5,#float(repetition_penalty),
|
52 |
"no_repeat_ngram_size": 3,
|
53 |
"encoder_no_repeat_ngram_size": 4,
|
54 |
"num_beams": int(num_beams),
|
|
|
74 |
|
75 |
_summaries = summarize_via_tokenbatches(
|
76 |
tr_in,
|
77 |
+
model_led_det if (model_type == "LED" and summary_type == "Detailed") else model_det,
|
78 |
+
tokenizer_led_det if (model_type == "LED" and summary_type == "Detailed") else tokenizer_det,
|
79 |
+
model_led_tldr if (model_type == "LED" and summary_type == "TLDR") else model_tldr,
|
80 |
+
tokenizer_led_tldr if (model_type == "LED" and summary_type == "TLDR") else tokenizer_tldr,
|
81 |
batch_length=token_batch_length,
|
82 |
**settings,
|
83 |
)
|
|
|
176 |
choices=["LongT5", "LED"], label="Model Architecture", value="LongT5"
|
177 |
)
|
178 |
num_beams = gr.Radio(
|
179 |
+
choices=[2, 3, 4, 5, 6],
|
180 |
label="Beam Search: # of Beams",
|
181 |
value=2,
|
182 |
)
|
|
|
197 |
value=512,
|
198 |
)
|
199 |
|
200 |
+
#with gr.Row():
|
201 |
+
#repetition_penalty = gr.inputs.Slider(
|
202 |
+
#minimum=1.0,
|
203 |
+
#maximum=5.0,
|
204 |
+
#label="repetition penalty",
|
205 |
+
#default=3.5,
|
206 |
+
#step=0.1,
|
207 |
)
|
208 |
#no_repeat_ngram_size = gr.Radio(
|
209 |
#choices=[2, 3, 4],
|
|
|
285 |
num_beams,
|
286 |
token_batch_length,
|
287 |
length_penalty,
|
|
|
288 |
],
|
289 |
outputs=[output_text, summary_text, compression_rate],
|
290 |
)
|