Blaise-g commited on
Commit
e6b6c5d
Β·
1 Parent(s): 5ca1bcb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -16
app.py CHANGED
@@ -24,8 +24,8 @@ def proc_submission(
24
  summary_type: str,
25
  num_beams,
26
  token_batch_length,
27
- length_penalty,
28
- repetition_penalty,
29
  #no_repeat_ngram_size: int = 3,
30
  max_input_length: int = 768,
31
  ):
@@ -48,7 +48,7 @@ def proc_submission(
48
 
49
  settings = {
50
  "length_penalty": float(length_penalty),
51
- "repetition_penalty": float(repetition_penalty),
52
  "no_repeat_ngram_size": 3,
53
  "encoder_no_repeat_ngram_size": 4,
54
  "num_beams": int(num_beams),
@@ -74,10 +74,10 @@ def proc_submission(
74
 
75
  _summaries = summarize_via_tokenbatches(
76
  tr_in,
77
- model_led_det if (model_type == "LED" & summary_type == "detailed") else model_det,
78
- tokenizer_led_det if (model_type == "LED" & summary_type == "detailed") else tokenizer_det,
79
- model_led_tldr if (model_type == "LED" & summary_type == "tldr") else model_tldr,
80
- tokenizer_led_tldr if (model_type == "LED" & summary_type == "tldr") else tokenizer_tldr,
81
  batch_length=token_batch_length,
82
  **settings,
83
  )
@@ -176,7 +176,7 @@ if __name__ == "__main__":
176
  choices=["LongT5", "LED"], label="Model Architecture", value="LongT5"
177
  )
178
  num_beams = gr.Radio(
179
- choices=[2, 3, 4],
180
  label="Beam Search: # of Beams",
181
  value=2,
182
  )
@@ -197,13 +197,13 @@ if __name__ == "__main__":
197
  value=512,
198
  )
199
 
200
- with gr.Row():
201
- repetition_penalty = gr.inputs.Slider(
202
- minimum=1.0,
203
- maximum=5.0,
204
- label="repetition penalty",
205
- default=3.5,
206
- step=0.1,
207
  )
208
  #no_repeat_ngram_size = gr.Radio(
209
  #choices=[2, 3, 4],
@@ -285,7 +285,6 @@ if __name__ == "__main__":
285
  num_beams,
286
  token_batch_length,
287
  length_penalty,
288
- repetition_penalty,
289
  ],
290
  outputs=[output_text, summary_text, compression_rate],
291
  )
 
24
  summary_type: str,
25
  num_beams,
26
  token_batch_length,
27
+ #length_penalty,
28
+ #repetition_penalty,
29
  #no_repeat_ngram_size: int = 3,
30
  max_input_length: int = 768,
31
  ):
 
48
 
49
  settings = {
50
  "length_penalty": float(length_penalty),
51
+ "repetition_penalty": 3.5,#float(repetition_penalty),
52
  "no_repeat_ngram_size": 3,
53
  "encoder_no_repeat_ngram_size": 4,
54
  "num_beams": int(num_beams),
 
74
 
75
  _summaries = summarize_via_tokenbatches(
76
  tr_in,
77
+ model_led_det if (model_type == "LED" and summary_type == "Detailed") else model_det,
78
+ tokenizer_led_det if (model_type == "LED" and summary_type == "Detailed") else tokenizer_det,
79
+ model_led_tldr if (model_type == "LED" and summary_type == "TLDR") else model_tldr,
80
+ tokenizer_led_tldr if (model_type == "LED" and summary_type == "TLDR") else tokenizer_tldr,
81
  batch_length=token_batch_length,
82
  **settings,
83
  )
 
176
  choices=["LongT5", "LED"], label="Model Architecture", value="LongT5"
177
  )
178
  num_beams = gr.Radio(
179
+ choices=[2, 3, 4, 5, 6],
180
  label="Beam Search: # of Beams",
181
  value=2,
182
  )
 
197
  value=512,
198
  )
199
 
200
+ #with gr.Row():
201
+ #repetition_penalty = gr.inputs.Slider(
202
+ #minimum=1.0,
203
+ #maximum=5.0,
204
+ #label="repetition penalty",
205
+ #default=3.5,
206
+ #step=0.1,
207
  )
208
  #no_repeat_ngram_size = gr.Radio(
209
  #choices=[2, 3, 4],
 
285
  num_beams,
286
  token_batch_length,
287
  length_penalty,
 
288
  ],
289
  outputs=[output_text, summary_text, compression_rate],
290
  )