Jimmy0866 commited on
Commit
3bcb0b8
·
verified ·
1 Parent(s): f2628b8

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +666 -0
app.py ADDED
@@ -0,0 +1,666 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ app.py - the main module for the gradio app for summarization
3
+
4
+ Usage:
5
+ app.py [-h] [--share] [-m MODEL] [-nb ADD_BEAM_OPTION] [-batch TOKEN_BATCH_OPTION]
6
+ [-level {DEBUG,INFO,WARNING,ERROR}]
7
+ Details:
8
+ python app.py --help
9
+
10
+ Environment Variables:
11
+ USE_TORCH (str): whether to use torch (1) or not (0)
12
+ TOKENIZERS_PARALLELISM (str): whether to use parallelism (true) or not (false)
13
+ Optional Environment Variables:
14
+ APP_MAX_WORDS (int): the maximum number of words to use for summarization
15
+ APP_OCR_MAX_PAGES (int): the maximum number of pages to use for OCR
16
+ """
17
+
18
+ import argparse
19
+ import contextlib
20
+ import gc
21
+ import logging
22
+ import os
23
+ import pprint as pp
24
+ import random
25
+ import time
26
+ from pathlib import Path
27
+
28
# Backend flags are exported ahead of the third-party imports that follow
# (presumably so those libraries see them at import time - confirm before moving).
os.environ.update(
    {
        "USE_TORCH": "1",
        "TOKENIZERS_PARALLELISM": "false",
    }
)

# Root logging configuration: timestamped records, INFO and above.
logging.basicConfig(
    datefmt="%Y-%b-%d %H:%M:%S",
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
    level=logging.INFO,
)
36
+
37
+ import gradio as gr
38
+ import nltk
39
+ import torch
40
+ from cleantext import clean
41
+ from doctr.models import ocr_predictor
42
+
43
+ from aggregate import BatchAggregator
44
+ from pdf2text import convert_PDF_to_Text
45
+ from summarize import load_model_and_tokenizer, summarize_via_tokenbatches
46
+ from utils import (
47
+ contraction_aware_tokenize,
48
+ extract_batches,
49
+ load_example_filenames,
50
+ remove_stagnant_files,
51
+ remove_stopwords,
52
+ saves_summary,
53
+ textlist2html,
54
+ truncate_word_count,
55
+ )
56
+
57
# Directory containing this file; the bundled examples are resolved relative to it.
_here = Path(__file__).parent

# Tokenizer / corpus data used by the text utilities (re-downloaded on every start).
nltk.download("punkt", force=True, quiet=True)
nltk.download("popular", force=True, quiet=True)

# Constants & Globals
MODEL_OPTIONS = [
    "BEE-spoke-data/pegasus-x-base-synthsumm_open-16k",
    "pszemraj/long-t5-tglobal-base-sci-simplify",
    "pszemraj/long-t5-tglobal-base-16384-book-summary",
    "pszemraj/long-t5-tglobal-base-summary-souffle-16384-loD",
    "pszemraj/pegasus-x-large-book_synthsumm",
    "pszemraj/pegasus-x-large-book-summary",
]  # models users can choose from
BEAM_OPTIONS = [2, 3, 4]  # beam sizes users can choose from
TOKEN_BATCH_OPTIONS = [
    1024,
    1536,
    2048,
    2560,
    3072,
]  # token batch sizes users can choose from

SUMMARY_PLACEHOLDER = "<p><em>Output will appear below:</em></p>"
AGGREGATE_MODEL = "pszemraj/bart-large-summary-map-reduce"  # map-reduce model

# if duplicating space: uncomment this line to adjust the max words
# os.environ["APP_MAX_WORDS"] = str(2048)  # set the max words to 2048
# os.environ["APP_OCR_MAX_PAGES"] = str(40)  # set the max pages to 40
# os.environ["APP_AGG_FORCE_CPU"] = str(1)  # force cpu for aggregation

# Environment values are always strings, so "0" / "false" would count as truthy if
# passed through unchanged. Convert to a real bool: unset, empty, "0", "false",
# "no" and "off" mean "do not force CPU"; any other value forces it.
_force_cpu_env = os.environ.get("APP_AGG_FORCE_CPU", "")
aggregator = BatchAggregator(
    AGGREGATE_MODEL,
    force_cpu=_force_cpu_env.strip().lower() not in {"", "0", "false", "no", "off"},
)
91
+
92
+
93
def aggregate_text(
    summary_text: str,
    text_file: gr.File = None,
) -> str:
    """
    Aggregate the batch summaries into one summary and return it as HTML.

    NOTE: you should probably include the BatchAggregator object as a fn arg if using this code

    :param summary_text: The batch summaries to aggregate, in html format
    :param text_file: The text file to append the aggregate summary to. May be an object
        exposing its path as ``.name`` or a plain path; ``None`` skips the append.
    :return: The aggregate summary in html format, or an "Error: ..." message when there
        is nothing to aggregate
    """
    # local import: only the final rendering step needs it
    from html import escape

    # NOTE(review): the summary box in the UI starts with a different placeholder string
    # than SUMMARY_PLACEHOLDER; that case is expected to be rejected by the
    # empty-batches check further down - confirm.
    if summary_text is None or summary_text == SUMMARY_PLACEHOLDER:
        logging.error("No text provided. Make sure a summary has been generated first.")
        return "Error: No text provided. Make sure a summary has been generated first."

    try:
        extracted_batches = extract_batches(summary_text)
    except Exception as e:
        logging.info(summary_text)
        logging.info(f"the batches html is: {type(summary_text)}")
        return f"Error: unable to extract batches - check input: {e}"
    if not extracted_batches:
        logging.error("unable to extract batches - check input")
        return "Error: unable to extract batches - check input"

    # accept both file-like objects (path stored in .name) and plain path strings
    out_path = None
    if text_file is not None:
        out_path = getattr(text_file, "name", text_file)

    content_batches = [batch["content"] for batch in extracted_batches]
    full_summary = aggregator.infer_aggregate(content_batches)

    # if a path is provided, append the summary with markdown formatting
    if out_path:
        out_path = Path(out_path)

        # best effort: a failed write must not prevent showing the summary in the UI
        try:
            with open(out_path, "a", encoding="utf-8") as f:
                f.write("\n\n## Aggregate Summary\n\n")
                f.write(
                    "- This is an instruction-based LLM aggregation of the previous 'summary batches'.\n"
                )
                f.write(f"- Aggregation model: {aggregator.model_name}\n\n")
                f.write(f"{full_summary}\n\n")
            logging.info(f"Updated {out_path} with aggregate summary")
        except Exception as e:
            logging.error(f"unable to update {out_path} with aggregate summary: {e}")

    # the summary is plain model output: escape it so characters such as "<" or "&"
    # are displayed literally instead of being interpreted as markup
    safe_summary = escape(str(full_summary), quote=False)
    full_summary_html = f"""
    <div style="
        margin-bottom: 20px;
        font-size: 18px;
        line-height: 1.5em;
        color: #333;
    ">
    <h2 style="font-size: 22px; color: #555;">Aggregate Summary:</h2>
    <p style="white-space: pre-line;">{safe_summary}</p>
    </div>
    """
    return full_summary_html
155
+
156
+
157
def predict(
    input_text: str,
    model_name: str,
    token_batch_length: int = 1024,
    empty_cache: bool = True,
    **settings,
) -> list:
    """
    predict - load the requested model, summarize the text in token batches, release the model

    :param str input_text: the input text to summarize
    :param str model_name: model name to use
    :param int token_batch_length: the length of the token batches to use
    :param bool empty_cache: whether to empty the CUDA cache before loading the model
    :param settings: extra keyword arguments forwarded to summarize_via_tokenbatches
    :return: the list returned by summarize_via_tokenbatches (the caller reads the
        "summary" and "summary_score" keys of each entry)
    """
    if torch.cuda.is_available():
        if empty_cache:
            torch.cuda.empty_cache()

    net, tok = load_model_and_tokenizer(model_name)
    batch_summaries = summarize_via_tokenbatches(
        input_text,
        net,
        tok,
        batch_length=token_batch_length,
        **settings,
    )

    # drop the local references and collect so the weights can be reclaimed
    del net, tok
    gc.collect()

    return batch_summaries
190
+
191
+
192
def proc_submission(
    input_text: str,
    model_name: str,
    num_beams: int,
    token_batch_length: int,
    length_penalty: float,
    repetition_penalty: float,
    no_repeat_ngram_size: int,
    predrop_stopwords: bool,
    max_input_length: int = 6144,
):
    """
    proc_submission - a helper function for the gradio module to process submissions

    Args:
        input_text (str): the input text to summarize
        model_name (str): the hf model tag of the model to use
        num_beams (int): the number of beams to use
        token_batch_length (int): the length of the token batches to use
        length_penalty (float): the length penalty to use
        repetition_penalty (float): the repetition penalty to use
        no_repeat_ngram_size (int): the no repeat ngram size to use
        predrop_stopwords (bool): whether to pre-drop stopwords before truncating/summarizing
        max_input_length (int, optional): the maximum input length to use. Defaults to 6144.

    Note:
        the max_input_length is set to 6144 by default, but can be changed by setting the
        environment variable APP_MAX_WORDS to a different value.

    Returns:
        tuple (4): a tuple containing the following:
            - status html: runtime line plus the truncation warning (if any), or the
              error box when the input is too short
            - summary html: the batch summaries rendered by textlist2html
            - scores text: one "Batch Summary i: score" line per batch
            - saved file: the value returned by saves_summary ([] when nothing was summarized)
    """

    remove_stagnant_files()  # clean up old files

    # Reject too-short (or missing) input before any cleaning / tokenizing work is done.
    n_chars = len(input_text) if input_text else 0
    if n_chars < 50:
        msg = f"""
        <div style="background-color: #880808; color: white; padding: 20px;">
        <br>
        <img src="https://i.imgflip.com/7kadd9.jpg" alt="no text">
        <br>
        <h3>Error</h3>
        <p>Input text is too short to summarize. Detected {n_chars} characters.
        Please load text by selecting an example from the dropdown menu or by pasting text into the text box.</p>
        </div>
        """
        logging.warning(msg)
        logging.warning("RETURNING EMPTY STRING")

        return msg, "<strong>No summary generated.</strong>", "", []

    settings = {
        "length_penalty": float(length_penalty),
        "repetition_penalty": float(repetition_penalty),
        "no_repeat_ngram_size": int(no_repeat_ngram_size),
        "encoder_no_repeat_ngram_size": 4,
        "num_beams": int(num_beams),
        "min_length": 4,
        "max_length": int(token_batch_length // 4),
        "early_stopping": True,
        "do_sample": False,
    }
    # the environment variable, when set, overrides the function default
    max_input_length = int(os.environ.get("APP_MAX_WORDS", max_input_length))
    logging.info(
        f"max_input_length set to: {max_input_length}. pre-drop stopwords: {predrop_stopwords}"
    )

    st = time.perf_counter()
    cln_text = clean(input_text, lower=False)
    parsed_cln_text = remove_stopwords(cln_text) if predrop_stopwords else cln_text
    # tokenize once; the count is needed for the log line and for the warning below
    input_wc = len(contraction_aware_tokenize(parsed_cln_text))
    logging.info(f"pre-truncation word count: {input_wc}")
    truncation_validated = truncate_word_count(
        parsed_cln_text, max_words=max_input_length
    )
    model_input_text = truncation_validated["processed_text"]

    msg = None
    if truncation_validated["was_truncated"]:
        # create elaborate HTML warning
        kept_pct = 100 * max_input_length / max(input_wc, 1)
        msg = f"""
        <div style="background-color: #FFA500; color: white; padding: 20px;">
        <h3>Warning</h3>
        <p>Input text was truncated to {max_input_length} words. That's about {kept_pct:.2f}% of the original text.</p>
        <p>Dropping stopwords is set to {predrop_stopwords}. If this is not what you intended, please validate the advanced settings.</p>
        </div>
        """
        logging.warning(msg)

    _summaries = predict(
        input_text=model_input_text,
        model_name=model_name,
        token_batch_length=token_batch_length,
        **settings,
    )
    sum_text = [s["summary"][0].strip() + "\n" for s in _summaries]
    sum_scores = [
        f" - Batch Summary {i}: {round(s['summary_score'],4)}"
        for i, s in enumerate(_summaries)
    ]

    full_summary = textlist2html(sum_text)
    scores_out = "\n".join(sum_scores)
    rt = round((time.perf_counter() - st) / 60, 2)
    logging.info(f"Runtime: {rt} minutes")
    status_html = f"<p>Runtime: {rt} minutes with model: {model_name}</p>"
    if msg is not None:
        status_html += msg

    # record the run parameters alongside the summaries in the saved file
    settings["remove_stopwords"] = predrop_stopwords
    settings["model_name"] = model_name
    saved_file = saves_summary(summarize_output=_summaries, outpath=None, **settings)
    return status_html, full_summary, scores_out, saved_file
316
+
317
+
318
def load_single_example_text(
    example_path: str,
    max_pages: int = 20,
) -> str:
    """
    load_single_example_text - loads a single example text file

    Reads the module-level ``name_to_path`` mapping and ``ocr_model``, which are
    created by the script entry point.

    :param str example_path: name of the example to load (a key of ``name_to_path``)
    :param int max_pages: the maximum number of pages to load from a PDF
        (overridden by the APP_OCR_MAX_PAGES environment variable when set)
    :return str: the text of the example, or an "ERROR ..." string if it cannot be loaded
    """
    full_ex_path = name_to_path.get(example_path)
    if full_ex_path is None:
        # nothing selected, or a name that is not among the loaded examples
        logging.error(f"Unknown example {example_path}")
        return "ERROR - check example path"

    full_ex_path = Path(full_ex_path)
    suffix = full_ex_path.suffix.lower()  # match ".PDF" / ".TXT" as well
    if suffix in [".txt", ".md"]:
        with open(full_ex_path, "r", encoding="utf-8", errors="ignore") as f:
            raw_text = f.read()
        text = clean(raw_text, lower=False)
    elif suffix == ".pdf":
        logging.info(f"Loading PDF file {full_ex_path}")
        max_pages = int(os.environ.get("APP_OCR_MAX_PAGES", max_pages))
        logging.info(f"max_pages set to: {max_pages}")
        conversion_stats = convert_PDF_to_Text(
            full_ex_path,
            ocr_model=ocr_model,
            max_pages=max_pages,
        )
        text = conversion_stats["converted_text"]
    else:
        logging.error(f"Unknown file type {full_ex_path.suffix}")
        text = "ERROR - check example path"

    return text
351
+
352
+
353
def load_uploaded_file(file_obj, max_pages: int = 20, lower: bool = False) -> str:
    """
    load_uploaded_file - loads a file uploaded by the user

    Reads the module-level ``ocr_model`` (created by the script entry point) for PDFs.

    :param file_obj (POTENTIALLY list): the uploaded file, possibly inside a list; either
        an object exposing its path as ``.name`` or a plain path string
    :param int max_pages: the maximum number of pages to load from a PDF
        (overridden by the APP_OCR_MAX_PAGES environment variable when set)
    :param bool lower: whether to lowercase the text
    :return str: the text of the file, or an error message if it cannot be loaded
    """
    logger = logging.getLogger(__name__)
    # check if mysterious file object is a list
    if isinstance(file_obj, (list, tuple)):
        file_obj = file_obj[0] if file_obj else None
    if file_obj is None:
        # the button was pressed without a file being uploaded
        logger.error("No file provided")
        return "Error: No file provided. Upload a PDF, TXT, or MD file first."

    # file-like objects carry their path in .name; plain strings are the path itself
    file_path = Path(getattr(file_obj, "name", file_obj))
    try:
        logger.info(f"Loading file:\t{file_path}")
        suffix = file_path.suffix.lower()  # match ".PDF" / ".TXT" as well
        if suffix in [".txt", ".md"]:
            with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
                raw_text = f.read()
            text = clean(raw_text, lower=lower)
        elif suffix == ".pdf":
            logger.info(f"loading a PDF file: {file_path.name}")
            max_pages = int(os.environ.get("APP_OCR_MAX_PAGES", max_pages))
            logger.info(f"max_pages is: {max_pages}. Starting conversion...")
            conversion_stats = convert_PDF_to_Text(
                file_path,
                ocr_model=ocr_model,
                max_pages=max_pages,
            )
            text = conversion_stats["converted_text"]
        else:
            logger.error(f"Unknown file type:\t{file_path.suffix}")
            text = "ERROR - check file - unknown file type. PDF, TXT, and MD are supported."

        return text
    except Exception as e:
        # boundary with the UI: report the failure in the text box instead of raising
        logger.error(f"Trying to load file:\t{file_path},\nerror:\t{e}")
        return f"Error: Could not read file {file_path.name}. Make sure it is a PDF, TXT, or MD file."
392
+
393
+
394
def parse_args():
    """Build the command line parser for the demo and return the parsed arguments."""
    cli = argparse.ArgumentParser(
        description="Document Summarization with Long-Document Transformers - Demo",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        epilog="Runs a local-only web UI to summarize documents. pass --share for a public link to share.",
    )

    # (flags, keyword arguments) for every option, registered in this order
    option_specs = [
        (
            ("--share",),
            {
                "dest": "share",
                "action": "store_true",
                "help": "Create a public link to share",
            },
        ),
        (
            ("-m", "--model"),
            {
                "type": str,
                "default": None,
                "help": f"Add a custom model to the list of models: {pp.pformat(MODEL_OPTIONS, compact=True)}",
            },
        ),
        (
            ("-nb", "--add_beam_option"),
            {
                "type": int,
                "default": None,
                "help": f"Add a beam search option to the demo UI options, default: {pp.pformat(BEAM_OPTIONS, compact=True)}",
            },
        ),
        (
            ("-batch", "--token_batch_option"),
            {
                "type": int,
                "default": None,
                "help": f"Add a token batch size to the demo UI options, default: {pp.pformat(TOKEN_BATCH_OPTIONS, compact=True)}",
            },
        ),
        (
            ("-max_agg", "-2x", "--aggregator_beam_boost"),
            {
                "dest": "aggregator_beam_boost",
                "action": "store_true",
                "help": "Double the number of beams for the aggregator during beam search",
            },
        ),
        (
            ("-level", "--log_level"),
            {
                "type": str,
                "default": "INFO",
                "choices": ["DEBUG", "INFO", "WARNING", "ERROR"],
                "help": "Set the logging level",
            },
        ),
    ]
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)

    return cli.parse_args()
447
+
448
+
449
if __name__ == "__main__":
    # main - script entry point: parse CLI options, load the OCR model and the
    # examples, build the Gradio UI and launch it.
    logger = logging.getLogger(__name__)
    args = parse_args()
    logger.setLevel(args.log_level)
    # the request handlers log through the root logger (logging.info / logging.error),
    # so the requested level has to be applied there as well
    logging.getLogger().setLevel(args.log_level)
    logger.info(f"args: {pp.pformat(args.__dict__, compact=True)}")

    # add any custom options
    if args.model is not None:
        logger.info(f"Adding model {args.model} to the list of models")
        MODEL_OPTIONS.append(args.model)
    if args.add_beam_option is not None:
        logger.info(f"Adding beam search option {args.add_beam_option} to the list")
        BEAM_OPTIONS.append(args.add_beam_option)
    if args.token_batch_option is not None:
        logger.info(f"Adding token batch option {args.token_batch_option} to the list")
        TOKEN_BATCH_OPTIONS.append(args.token_batch_option)

    if args.aggregator_beam_boost:
        logger.info("Doubling aggregator num_beams")
        _agg_cfg = aggregator.get_generation_config()
        _agg_cfg["num_beams"] = _agg_cfg["num_beams"] * 2
        aggregator.update_generation_config(**_agg_cfg)

    logger.info("Loading OCR model")
    # silence anything the predictor prints to stdout while it loads
    with contextlib.redirect_stdout(None):
        ocr_model = ocr_predictor(
            "db_resnet50",
            "crnn_mobilenet_v3_large",
            pretrained=True,
            assume_straight_pages=True,
        )

    # load the examples
    name_to_path = load_example_filenames(_here / "examples")
    logger.info(f"Loaded {len(name_to_path)} examples")

    demo = gr.Blocks(title="Document Summarization")
    _examples = list(name_to_path.keys())
    logger.info("Starting app instance")
    # NOTE(review): container nesting below follows the order in which the components
    # are declared - confirm against the intended page layout.
    with demo:
        gr.Markdown(
            """# Document Summarization with Long-Document Transformers

            An example use case for fine-tuned long document transformers. Model(s) are trained on [book summaries](https://hf.co/datasets/kmfoda/booksum). Architectures [in this demo](https://hf.co/spaces/pszemraj/document-summarization) are [LongT5-base](https://hf.co/pszemraj/long-t5-tglobal-base-16384-book-summary) and [Pegasus-X-Large](https://hf.co/pszemraj/pegasus-x-large-book-summary).

            **Want more performance?** Run this demo from a free [Google Colab GPU](https://colab.research.google.com/gist/pszemraj/52f67cf7326e780155812a6a1f9bb724/document-summarization-on-gpu.ipynb)
            """
        )
        with gr.Column():
            gr.Markdown(
                """## Load Inputs & Select Parameters

                Enter/paste text below, or upload a file. Pick a model & adjust params (_optional_), and press **Summarize!**

                See [the guide doc](https://gist.github.com/pszemraj/722a7ba443aa3a671b02d87038375519) for details.
                """
            )
            with gr.Row():
                with gr.Column(variant="compact"):
                    model_name = gr.Dropdown(
                        choices=MODEL_OPTIONS,
                        value=MODEL_OPTIONS[0],
                        label="Model Name",
                    )
                    num_beams = gr.Radio(
                        choices=BEAM_OPTIONS,
                        value=BEAM_OPTIONS[len(BEAM_OPTIONS) // 2],
                        label="Beam Search: # of Beams",
                    )
                    load_examples_button = gr.Button(
                        "Load Example in Dropdown",
                    )
                    load_file_button = gr.Button("Upload & Process File")
                with gr.Column(variant="compact"):
                    example_name = gr.Dropdown(
                        _examples,
                        label="Examples",
                        # random.choice raises on an empty list (no example files found)
                        value=random.choice(_examples) if _examples else None,
                    )
                    uploaded_file = gr.File(
                        label="File Upload",
                        file_count="single",
                        file_types=[".txt", ".md", ".pdf"],
                        type="filepath",
                    )
            with gr.Row():
                input_text = gr.Textbox(
                    lines=4,
                    max_lines=8,
                    label="Text to Summarize",
                    placeholder="Enter text to summarize, the text will be cleaned and truncated on Spaces. Narrative, academic (both papers and lecture transcription), and article text work well. May take a bit to generate depending on the input text :)",
                )
        with gr.Column():
            gr.Markdown("## Generate Summary")
            with gr.Row():
                summarize_button = gr.Button(
                    "Summarize!",
                    variant="primary",
                )
                gr.Markdown(
                    "_Summarization should take ~1-2 minutes for most settings, but may extend up to 5-10 minutes in some scenarios._"
                )
            output_text = gr.HTML("<em>Output will appear below:</em>")
        with gr.Column():
            gr.Markdown("### Results & Scores")
            with gr.Row():
                with gr.Column(variant="compact"):
                    gr.Markdown(
                        "Download the summary as a text file, with parameters and scores."
                    )
                    text_file = gr.File(
                        label="Download as Text File",
                        file_count="single",
                        type="filepath",
                        interactive=False,
                    )
                with gr.Column(variant="compact"):
                    gr.Markdown(
                        "Scores **roughly** represent the summary quality as a measure of the model's 'confidence'. less-negative numbers (closer to 0) are better."
                    )
                    summary_scores = gr.Textbox(
                        label="Summary Scores",
                        placeholder="Summary scores will appear here",
                    )
            with gr.Column(variant="panel"):
                gr.Markdown("### **Summary Output**")
                summary_text = gr.HTML(
                    label="Summary",
                    value="<i>Summary will appear here!</i>",
                )
        with gr.Column():
            gr.Markdown("### **Aggregate Summary Batches**")
            with gr.Row():
                aggregate_button = gr.Button(
                    "Aggregate!",
                    variant="primary",
                )
                gr.Markdown(
                    f"""Aggregate the above batches into a cohesive summary.
                    - A secondary instruct-tuned LM consolidates info
                    - Current model: [{AGGREGATE_MODEL}](https://hf.co/{AGGREGATE_MODEL})
                    """
                )
            with gr.Column(variant="panel"):
                aggregated_summary = gr.HTML(
                    label="Aggregate Summary",
                    value="<i>Aggregate summary will appear here!</i>",
                )

        with gr.Column():
            gr.Markdown(
                """### Advanced Settings

                Refer to [the guide doc](https://gist.github.com/pszemraj/722a7ba443aa3a671b02d87038375519) for what these are, and how they impact _quality_ and _speed_.
                """
            )
            with gr.Row():
                length_penalty = gr.Slider(
                    minimum=0.3,
                    maximum=1.1,
                    label="length penalty",
                    value=0.7,
                    step=0.05,
                )
                token_batch_length = gr.Radio(
                    choices=TOKEN_BATCH_OPTIONS,
                    label="token batch length",
                    # select median option
                    value=TOKEN_BATCH_OPTIONS[len(TOKEN_BATCH_OPTIONS) // 2],
                )

            with gr.Row():
                repetition_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=5.0,
                    label="repetition penalty",
                    value=1.5,
                    step=0.1,
                )
                no_repeat_ngram_size = gr.Radio(
                    choices=[2, 3, 4, 5],
                    label="no repeat ngram size",
                    value=3,
                )
                predrop_stopwords = gr.Checkbox(
                    label="Drop Stopwords (Pre-Truncation)",
                    value=False,
                )

        # wire the buttons to their handlers
        load_examples_button.click(
            fn=load_single_example_text, inputs=[example_name], outputs=[input_text]
        )

        load_file_button.click(
            fn=load_uploaded_file, inputs=uploaded_file, outputs=[input_text]
        )

        summarize_button.click(
            fn=proc_submission,
            inputs=[
                input_text,
                model_name,
                num_beams,
                token_batch_length,
                length_penalty,
                repetition_penalty,
                no_repeat_ngram_size,
                predrop_stopwords,
            ],
            outputs=[output_text, summary_text, summary_scores, text_file],
        )
        aggregate_button.click(
            fn=aggregate_text,
            inputs=[summary_text, text_file],
            outputs=[aggregated_summary],
        )
    demo.launch(share=args.share, debug=True)