studyonly commited on
Commit
7e2bc05
·
verified ·
1 Parent(s): 445bb59

Doc Summarizer version 1

Browse files
Files changed (9) hide show
  1. README.md +29 -9
  2. aggregate.py +192 -0
  3. app.py +666 -0
  4. gitattributes +31 -0
  5. gitignore +29 -0
  6. pdf2text.py +346 -0
  7. requirements.txt +12 -0
  8. summarize.py +177 -0
  9. utils.py +450 -0
README.md CHANGED
@@ -1,14 +1,34 @@
1
  ---
2
- title: DOC Summarizer Studyonly
3
- emoji: 🐢
4
- colorFrom: indigo
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 5.36.2
8
  app_file: app.py
9
- pinned: false
10
- license: mit
11
- short_description: DOC Summarizer by studyonly
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Document Summarization
3
+ emoji: 🌖
4
+ colorFrom: gray
5
+ colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 5.12.0
8
  app_file: app.py
9
+ pinned: true
10
+ license: apache-2.0
11
+ short_description: text2text models for document summarization
12
  ---
13
 
14
+ Check out the configuration reference at <https://huggingface.co/docs/hub/spaces-config-reference>
15
+
16
+ # README - Document Summarization
17
+
18
+ The original demo/what this repo was built for can be found [here](https://huggingface.co/spaces/pszemraj/document-summarization)
19
+
20
+ ## Usage
21
+
22
+ If you are using this **not** as a gradio demo on hf spaces, you can run it locally with:
23
+
24
+ ```bash
25
+ python app.py --share
26
+ ```
27
+
28
+ To see all the available arguments, run `python app.py --help`.
29
+
30
+ ## Installation
31
+
32
+ ```bash
33
+ pip install -r requirements.txt
34
+ ```
aggregate.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ aggregate.py - module for 'reducing' multiple 'summary chunks' into one
3
+
4
+ an overly complicated class for legacy compatibility reasons, for usage of the
5
+ 2024 map-reduce models see hf.co/pszemraj/bart-large-summary-map-reduce#usage
6
+ """
7
+
8
+ import logging
9
+ import pprint as pp
10
+ import time
11
+
12
+ import torch
13
+ from transformers import GenerationConfig, pipeline
14
+
15
+ # Setting up logging
16
+ logging.basicConfig(
17
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
18
+ )
19
+
20
+
21
+ class BatchAggregator:
22
+ """
23
+ BatchAggregator is a class for aggregating text from multiple sources.
24
+
25
+ Usage:
26
+ from aggregate import BatchAggregator
27
+ aggregator = BatchAggregator()
28
+ agg = aggregator.infer_aggregate(["This is a test", "This is another test"])
29
+ print(agg)
30
+ """
31
+
32
+ GENERIC_CONFIG = GenerationConfig(
33
+ max_new_tokens=512,
34
+ num_beams=4,
35
+ early_stopping=True,
36
+ do_sample=False,
37
+ truncation=True,
38
+ )
39
+
40
+ def __init__(
41
+ self,
42
+ model_name: str = "pszemraj/bart-large-summary-map-reduce",
43
+ force_cpu: bool = False,
44
+ **kwargs,
45
+ ):
46
+ """
47
+ __init__ initializes the BatchAggregator class.
48
+
49
+ :param str model_name: model name to use, default: "pszemraj/bart-large-summary-map-reduce"
50
+ :param bool force_cpu: force the model to run on CPU, default: False
51
+ """
52
+ self.device = None
53
+ self.is_compiled = False
54
+ self.model_name = None
55
+ self.aggregator = None
56
+ self.force_cpu = force_cpu
57
+ self.logger = logging.getLogger(__name__)
58
+ self.init_model(model_name)
59
+
60
+ def init_model(self, model_name: str) -> None:
61
+ """
62
+ Initialize the model.
63
+
64
+ :param model_name: The name of the model to use.
65
+ """
66
+ # Free up memory
67
+ if torch.cuda.is_available():
68
+ torch.cuda.empty_cache()
69
+
70
+ self.logger.info(f"Setting model to {model_name}")
71
+ self.model_name = model_name
72
+ self.aggregator = self._create_pipeline(model_name)
73
+ self._configure_model()
74
+
75
+ def _create_pipeline(
76
+ self, model_name: str = "pszemraj/bart-large-summary-map-reduce"
77
+ ) -> pipeline:
78
+ """
79
+ _create_pipeline creates a pipeline for the model.
80
+
81
+ :param str model_name: model name to use
82
+ :return pipeline: the pipeline for the model
83
+
84
+ :raises Exception: if the pipeline cannot be created
85
+ """
86
+ device_map = (
87
+ "auto" if torch.cuda.is_available() and not self.force_cpu else "cpu"
88
+ )
89
+ try:
90
+ self.logger.info(
91
+ f"Creating pipeline with model {model_name} on device {device_map}"
92
+ )
93
+ return pipeline(
94
+ "text2text-generation",
95
+ model=model_name,
96
+ device_map=device_map,
97
+ torch_dtype=torch.float32,
98
+ )
99
+ except Exception as e:
100
+ self.logger.error(f"Failed to create pipeline: {e}")
101
+ raise
102
+
103
+ def _configure_model(self):
104
+ """
105
+ Configure the model for generation.
106
+ """
107
+ try:
108
+ self.aggregator.model = torch.compile(self.aggregator.model)
109
+ self.is_compiled = True
110
+ except Exception as e:
111
+ self.logger.warning(f"Could not compile model with Torch 2.0: {e}")
112
+
113
+ self._set_default_generation_config()
114
+ self.logger.info(self.aggregator.model.generation_config.to_json_string())
115
+
116
+ def _set_default_generation_config(self):
117
+ """
118
+ Set the default generation configuration for the model.
119
+ """
120
+ self.aggregator.model.generation_config.update(
121
+ **self.GENERIC_CONFIG.to_diff_dict()
122
+ )
123
+
124
+ def update_generation_config(self, **kwargs):
125
+ """
126
+ Update the generation configuration with the specified parameters.
127
+
128
+ Args:
129
+ **kwargs: The parameters to update in the generation configuration.
130
+ """
131
+ self.logger.info(f"Updating generation config with {pp.pformat(kwargs)}")
132
+ self.aggregator.model.generation_config.update(**kwargs)
133
+
134
+ def get_generation_config(self) -> dict:
135
+ """
136
+ Get the current generation configuration.
137
+
138
+ Returns:
139
+ dict: The current generation configuration.
140
+ """
141
+ return self.aggregator.model.generation_config.to_dict()
142
+
143
+ def update_loglevel(self, level: str = "INFO"):
144
+ """
145
+ Update the log level.
146
+
147
+ Args:
148
+ level (str): The log level to set. Defaults to "INFO".
149
+ """
150
+ self.logger.setLevel(level)
151
+
152
+ def infer_aggregate(
153
+ self,
154
+ text_list: list,
155
+ instruction: str = None, # Kept for backward compatibility but not used
156
+ **kwargs,
157
+ ) -> str:
158
+ """
159
+ infer_aggregate - infers a consolidated summary from a list of texts.
160
+
161
+ Args:
162
+ text_list (list): The texts to summarize.
163
+ instruction (str): Not used by this model, kept for compatibility.
164
+ **kwargs: Additional parameters to update in the generation configuration.
165
+
166
+ Returns:
167
+ The generated summary.
168
+ """
169
+ joined_text = "\n\n".join(text_list)
170
+ if kwargs:
171
+ self.update_generation_config(**kwargs)
172
+ st = time.perf_counter()
173
+ self.logger.info(f"inference on {len(text_list)} texts ...")
174
+ result = self.aggregator(
175
+ joined_text,
176
+ generation_config=self.aggregator.model.generation_config,
177
+ )[0]["generated_text"]
178
+ self.logger.info(f"Done. runtime:\t{round(time.perf_counter() - st, 2)}s")
179
+ self.logger.info(
180
+ f"Input tokens:\t{self.count_tokens(joined_text)}. Output tokens:\t{self.count_tokens(result)}"
181
+ )
182
+ self.logger.debug(f"Generated text:\n{result}")
183
+
184
+ return result
185
+
186
+ def count_tokens(self, text: str) -> int:
187
+ """count the number of tokens in a text"""
188
+ return (
189
+ len(self.aggregator.tokenizer.encode(text, truncation=False, padding=False))
190
+ if text
191
+ else 0
192
+ )
app.py ADDED
@@ -0,0 +1,666 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ app.py - the main module for the gradio app for summarization
3
+
4
+ Usage:
5
+ app.py [-h] [--share] [-m MODEL] [-nb ADD_BEAM_OPTION] [-batch TOKEN_BATCH_OPTION]
6
+ [-level {DEBUG,INFO,WARNING,ERROR}]
7
+ Details:
8
+ python app.py --help
9
+
10
+ Environment Variables:
11
+ USE_TORCH (str): whether to use torch (1) or not (0)
12
+ TOKENIZERS_PARALLELISM (str): whether to use parallelism (true) or not (false)
13
+ Optional Environment Variables:
14
+ APP_MAX_WORDS (int): the maximum number of words to use for summarization
15
+ APP_OCR_MAX_PAGES (int): the maximum number of pages to use for OCR
16
+ """
17
+
18
+ import argparse
19
+ import contextlib
20
+ import gc
21
+ import logging
22
+ import os
23
+ import pprint as pp
24
+ import random
25
+ import time
26
+ from pathlib import Path
27
+
28
+ os.environ["USE_TORCH"] = "1"
29
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
30
+
31
+ logging.basicConfig(
32
+ level=logging.INFO,
33
+ format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
34
+ datefmt="%Y-%b-%d %H:%M:%S",
35
+ )
36
+
37
+ import gradio as gr
38
+ import nltk
39
+ import torch
40
+ from cleantext import clean
41
+ from doctr.models import ocr_predictor
42
+
43
+ from aggregate import BatchAggregator
44
+ from pdf2text import convert_PDF_to_Text
45
+ from summarize import load_model_and_tokenizer, summarize_via_tokenbatches
46
+ from utils import (
47
+ contraction_aware_tokenize,
48
+ extract_batches,
49
+ load_example_filenames,
50
+ remove_stagnant_files,
51
+ remove_stopwords,
52
+ saves_summary,
53
+ textlist2html,
54
+ truncate_word_count,
55
+ )
56
+
57
+ _here = Path(__file__).parent
58
+
59
+ nltk.download("punkt", force=True, quiet=True)
60
+ nltk.download("popular", force=True, quiet=True)
61
+
62
+ # Constants & Globals
63
+ MODEL_OPTIONS = [
64
+ "BEE-spoke-data/pegasus-x-base-synthsumm_open-16k",
65
+ "pszemraj/long-t5-tglobal-base-sci-simplify",
66
+ "pszemraj/long-t5-tglobal-base-16384-book-summary",
67
+ "pszemraj/long-t5-tglobal-base-summary-souffle-16384-loD",
68
+ "pszemraj/pegasus-x-large-book_synthsumm",
69
+ "pszemraj/pegasus-x-large-book-summary",
70
+ ] # models users can choose from
71
+ BEAM_OPTIONS = [2, 3, 4] # beam sizes users can choose from
72
+ TOKEN_BATCH_OPTIONS = [
73
+ 1024,
74
+ 1536,
75
+ 2048,
76
+ 2560,
77
+ 3072,
78
+ ] # token batch sizes users can choose from
79
+
80
+ SUMMARY_PLACEHOLDER = "<p><em>Output will appear below:</em></p>"
81
+ AGGREGATE_MODEL = "pszemraj/bart-large-summary-map-reduce" # map-reduce model
82
+
83
+ # if duplicating space: uncomment this line to adjust the max words
84
+ # os.environ["APP_MAX_WORDS"] = str(2048) # set the max words to 2048
85
+ # os.environ["APP_OCR_MAX_PAGES"] = str(40) # set the max pages to 40
86
+ # os.environ["APP_AGG_FORCE_CPU"] = str(1) # force cpu for aggregation
87
+
88
+ aggregator = BatchAggregator(
89
+ AGGREGATE_MODEL, force_cpu=os.environ.get("APP_AGG_FORCE_CPU", False)
90
+ )
91
+
92
+
93
+ def aggregate_text(
94
+ summary_text: str,
95
+ text_file: gr.File = None,
96
+ ) -> str:
97
+ """
98
+ Aggregate the text from the batches.
99
+
100
+ NOTE: you should probably include the BatchAggregator object as a fn arg if using this code
101
+
102
+ :param batches_html: The batches to aggregate, in html format
103
+ :param text_file: The text file to append the aggregate summary to
104
+ :return: The aggregate summary in html format
105
+ """
106
+ if summary_text is None or summary_text == SUMMARY_PLACEHOLDER:
107
+ logging.error("No text provided. Make sure a summary has been generated first.")
108
+ return "Error: No text provided. Make sure a summary has been generated first."
109
+
110
+ try:
111
+ extracted_batches = extract_batches(summary_text)
112
+ except Exception as e:
113
+ logging.info(summary_text)
114
+ logging.info(f"the batches html is: {type(summary_text)}")
115
+ return f"Error: unable to extract batches - check input: {e}"
116
+ if not extracted_batches:
117
+ logging.error("unable to extract batches - check input")
118
+ return "Error: unable to extract batches - check input"
119
+
120
+ out_path = None
121
+ if text_file is not None:
122
+ out_path = text_file.name # assuming name attribute stores the file path
123
+
124
+ content_batches = [batch["content"] for batch in extracted_batches]
125
+ full_summary = aggregator.infer_aggregate(content_batches)
126
+
127
+ # if a path that exists is provided, append the summary with markdown formatting
128
+ if out_path:
129
+ out_path = Path(out_path)
130
+
131
+ try:
132
+ with open(out_path, "a", encoding="utf-8") as f:
133
+ f.write("\n\n## Aggregate Summary\n\n")
134
+ f.write(
135
+ "- This is an instruction-based LLM aggregation of the previous 'summary batches'.\n"
136
+ )
137
+ f.write(f"- Aggregation model: {aggregator.model_name}\n\n")
138
+ f.write(f"{full_summary}\n\n")
139
+ logging.info(f"Updated {out_path} with aggregate summary")
140
+ except Exception as e:
141
+ logging.error(f"unable to update {out_path} with aggregate summary: {e}")
142
+
143
+ full_summary_html = f"""
144
+ <div style="
145
+ margin-bottom: 20px;
146
+ font-size: 18px;
147
+ line-height: 1.5em;
148
+ color: #333;
149
+ ">
150
+ <h2 style="font-size: 22px; color: #555;">Aggregate Summary:</h2>
151
+ <p style="white-space: pre-line;">{full_summary}</p>
152
+ </div>
153
+ """
154
+ return full_summary_html
155
+
156
+
157
+ def predict(
158
+ input_text: str,
159
+ model_name: str,
160
+ token_batch_length: int = 1024,
161
+ empty_cache: bool = True,
162
+ **settings,
163
+ ) -> list:
164
+ """
165
+ predict - helper fn to support multiple models for summarization at once
166
+
167
+ :param str input_text: the input text to summarize
168
+ :param str model_name: model name to use
169
+ :param int token_batch_length: the length of the token batches to use
170
+ :param bool empty_cache: whether to empty the cache before loading a new= model
171
+ :return: list of dicts with keys "summary" and "score"
172
+ """
173
+ if torch.cuda.is_available() and empty_cache:
174
+ torch.cuda.empty_cache()
175
+
176
+ model, tokenizer = load_model_and_tokenizer(model_name)
177
+ summaries = summarize_via_tokenbatches(
178
+ input_text,
179
+ model,
180
+ tokenizer,
181
+ batch_length=token_batch_length,
182
+ **settings,
183
+ )
184
+
185
+ del model
186
+ del tokenizer
187
+ gc.collect()
188
+
189
+ return summaries
190
+
191
+
192
+ def proc_submission(
193
+ input_text: str,
194
+ model_name: str,
195
+ num_beams: int,
196
+ token_batch_length: int,
197
+ length_penalty: float,
198
+ repetition_penalty: float,
199
+ no_repeat_ngram_size: int,
200
+ predrop_stopwords: bool,
201
+ max_input_length: int = 6144,
202
+ ):
203
+ """
204
+ proc_submission - a helper function for the gradio module to process submissions
205
+
206
+ Args:
207
+ input_text (str): the input text to summarize
208
+ model_name (str): the hf model tag of the model to use
209
+ num_beams (int): the number of beams to use
210
+ token_batch_length (int): the length of the token batches to use
211
+ length_penalty (float): the length penalty to use
212
+ repetition_penalty (float): the repetition penalty to use
213
+ no_repeat_ngram_size (int): the no repeat ngram size to use
214
+ predrop_stopwords (bool): whether to pre-drop stopwords before truncating/summarizing
215
+ max_input_length (int, optional): the maximum input length to use. Defaults to 6144.
216
+
217
+ Note:
218
+ the max_input_length is set to 6144 by default, but can be changed by setting the
219
+ environment variable APP_MAX_WORDS to a different value.
220
+
221
+ Returns:
222
+ tuple (4): a tuple containing the following:
223
+ """
224
+
225
+ remove_stagnant_files() # clean up old files
226
+ settings = {
227
+ "length_penalty": float(length_penalty),
228
+ "repetition_penalty": float(repetition_penalty),
229
+ "no_repeat_ngram_size": int(no_repeat_ngram_size),
230
+ "encoder_no_repeat_ngram_size": 4,
231
+ "num_beams": int(num_beams),
232
+ "min_length": 4,
233
+ "max_length": int(token_batch_length // 4),
234
+ "early_stopping": True,
235
+ "do_sample": False,
236
+ }
237
+ max_input_length = int(os.environ.get("APP_MAX_WORDS", max_input_length))
238
+ logging.info(
239
+ f"max_input_length set to: {max_input_length}. pre-drop stopwords: {predrop_stopwords}"
240
+ )
241
+
242
+ st = time.perf_counter()
243
+ history = {}
244
+ cln_text = clean(input_text, lower=False)
245
+ parsed_cln_text = remove_stopwords(cln_text) if predrop_stopwords else cln_text
246
+ logging.info(
247
+ f"pre-truncation word count: {len(contraction_aware_tokenize(parsed_cln_text))}"
248
+ )
249
+ truncation_validated = truncate_word_count(
250
+ parsed_cln_text, max_words=max_input_length
251
+ )
252
+
253
+ if truncation_validated["was_truncated"]:
254
+ model_input_text = truncation_validated["processed_text"]
255
+ # create elaborate HTML warning
256
+ input_wc = len(contraction_aware_tokenize(parsed_cln_text))
257
+ msg = f"""
258
+ <div style="background-color: #FFA500; color: white; padding: 20px;">
259
+ <h3>Warning</h3>
260
+ <p>Input text was truncated to {max_input_length} words. That's about {100*max_input_length/input_wc:.2f}% of the original text.</p>
261
+ <p>Dropping stopwords is set to {predrop_stopwords}. If this is not what you intended, please validate the advanced settings.</p>
262
+ </div>
263
+ """
264
+ logging.warning(msg)
265
+ history["WARNING"] = msg
266
+ else:
267
+ model_input_text = truncation_validated["processed_text"]
268
+ msg = None
269
+
270
+ if len(input_text) < 50:
271
+ # this is essentially a different case from the above
272
+ msg = f"""
273
+ <div style="background-color: #880808; color: white; padding: 20px;">
274
+ <br>
275
+ <img src="https://i.imgflip.com/7kadd9.jpg" alt="no text">
276
+ <br>
277
+ <h3>Error</h3>
278
+ <p>Input text is too short to summarize. Detected {len(input_text)} characters.
279
+ Please load text by selecting an example from the dropdown menu or by pasting text into the text box.</p>
280
+ </div>
281
+ """
282
+ logging.warning(msg)
283
+ logging.warning("RETURNING EMPTY STRING")
284
+ history["WARNING"] = msg
285
+
286
+ return msg, "<strong>No summary generated.</strong>", "", []
287
+
288
+ _summaries = predict(
289
+ input_text=model_input_text,
290
+ model_name=model_name,
291
+ token_batch_length=token_batch_length,
292
+ **settings,
293
+ )
294
+ sum_text = [s["summary"][0].strip() + "\n" for s in _summaries]
295
+ sum_scores = [
296
+ f" - Batch Summary {i}: {round(s['summary_score'],4)}"
297
+ for i, s in enumerate(_summaries)
298
+ ]
299
+
300
+ full_summary = textlist2html(sum_text)
301
+ history["Summary Scores"] = "<br><br>"
302
+ scores_out = "\n".join(sum_scores)
303
+ rt = round((time.perf_counter() - st) / 60, 2)
304
+ logging.info(f"Runtime: {rt} minutes")
305
+ html = ""
306
+ html += f"<p>Runtime: {rt} minutes with model: {model_name}</p>"
307
+ if msg is not None:
308
+ html += msg
309
+
310
+ html += ""
311
+
312
+ settings["remove_stopwords"] = predrop_stopwords
313
+ settings["model_name"] = model_name
314
+ saved_file = saves_summary(summarize_output=_summaries, outpath=None, **settings)
315
+ return html, full_summary, scores_out, saved_file
316
+
317
+
318
+ def load_single_example_text(
319
+ example_path: str or Path,
320
+ max_pages: int = 20,
321
+ ) -> str:
322
+ """
323
+ load_single_example_text - loads a single example text file
324
+
325
+ :param strorPath example_path: name of the example to load
326
+ :param int max_pages: the maximum number of pages to load from a PDF
327
+ :return str: the text of the example
328
+ """
329
+ global name_to_path, ocr_model
330
+ full_ex_path = name_to_path[example_path]
331
+ full_ex_path = Path(full_ex_path)
332
+ if full_ex_path.suffix in [".txt", ".md"]:
333
+ with open(full_ex_path, "r", encoding="utf-8", errors="ignore") as f:
334
+ raw_text = f.read()
335
+ text = clean(raw_text, lower=False)
336
+ elif full_ex_path.suffix == ".pdf":
337
+ logging.info(f"Loading PDF file {full_ex_path}")
338
+ max_pages = int(os.environ.get("APP_OCR_MAX_PAGES", max_pages))
339
+ logging.info(f"max_pages set to: {max_pages}")
340
+ conversion_stats = convert_PDF_to_Text(
341
+ full_ex_path,
342
+ ocr_model=ocr_model,
343
+ max_pages=max_pages,
344
+ )
345
+ text = conversion_stats["converted_text"]
346
+ else:
347
+ logging.error(f"Unknown file type {full_ex_path.suffix}")
348
+ text = "ERROR - check example path"
349
+
350
+ return text
351
+
352
+
353
+ def load_uploaded_file(file_obj, max_pages: int = 20, lower: bool = False) -> str:
354
+ """
355
+ load_uploaded_file - loads a file uploaded by the user
356
+
357
+ :param file_obj (POTENTIALLY list): Gradio file object inside a list
358
+ :param int max_pages: the maximum number of pages to load from a PDF
359
+ :param bool lower: whether to lowercase the text
360
+ :return str: the text of the file
361
+ """
362
+ global ocr_model
363
+ logger = logging.getLogger(__name__)
364
+ # check if mysterious file object is a list
365
+ if isinstance(file_obj, list):
366
+ file_obj = file_obj[0]
367
+ file_path = Path(file_obj.name)
368
+ try:
369
+ logger.info(f"Loading file:\t{file_path}")
370
+ if file_path.suffix in [".txt", ".md"]:
371
+ with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
372
+ raw_text = f.read()
373
+ text = clean(raw_text, lower=lower)
374
+ elif file_path.suffix == ".pdf":
375
+ logger.info(f"loading a PDF file: {file_path.name}")
376
+ max_pages = int(os.environ.get("APP_OCR_MAX_PAGES", max_pages))
377
+ logger.info(f"max_pages is: {max_pages}. Starting conversion...")
378
+ conversion_stats = convert_PDF_to_Text(
379
+ file_path,
380
+ ocr_model=ocr_model,
381
+ max_pages=max_pages,
382
+ )
383
+ text = conversion_stats["converted_text"]
384
+ else:
385
+ logger.error(f"Unknown file type:\t{file_path.suffix}")
386
+ text = "ERROR - check file - unknown file type. PDF, TXT, and MD are supported."
387
+
388
+ return text
389
+ except Exception as e:
390
+ logger.error(f"Trying to load file:\t{file_path},\nerror:\t{e}")
391
+ return f"Error: Could not read file {file_path.name}. Make sure it is a PDF, TXT, or MD file."
392
+
393
+
394
+ def parse_args():
395
+ """arguments for the command line interface"""
396
+ parser = argparse.ArgumentParser(
397
+ description="Document Summarization with Long-Document Transformers - Demo",
398
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
399
+ epilog="Runs a local-only web UI to summarize documents. pass --share for a public link to share.",
400
+ )
401
+
402
+ parser.add_argument(
403
+ "--share",
404
+ dest="share",
405
+ action="store_true",
406
+ help="Create a public link to share",
407
+ )
408
+ parser.add_argument(
409
+ "-m",
410
+ "--model",
411
+ type=str,
412
+ default=None,
413
+ help=f"Add a custom model to the list of models: {pp.pformat(MODEL_OPTIONS, compact=True)}",
414
+ )
415
+ parser.add_argument(
416
+ "-nb",
417
+ "--add_beam_option",
418
+ type=int,
419
+ default=None,
420
+ help=f"Add a beam search option to the demo UI options, default: {pp.pformat(BEAM_OPTIONS, compact=True)}",
421
+ )
422
+ parser.add_argument(
423
+ "-batch",
424
+ "--token_batch_option",
425
+ type=int,
426
+ default=None,
427
+ help=f"Add a token batch size to the demo UI options, default: {pp.pformat(TOKEN_BATCH_OPTIONS, compact=True)}",
428
+ )
429
+ parser.add_argument(
430
+ "-max_agg",
431
+ "-2x",
432
+ "--aggregator_beam_boost",
433
+ dest="aggregator_beam_boost",
434
+ action="store_true",
435
+ help="Double the number of beams for the aggregator during beam search",
436
+ )
437
+ parser.add_argument(
438
+ "-level",
439
+ "--log_level",
440
+ type=str,
441
+ default="INFO",
442
+ choices=["DEBUG", "INFO", "WARNING", "ERROR"],
443
+ help="Set the logging level",
444
+ )
445
+
446
+ return parser.parse_args()
447
+
448
+
449
+ if __name__ == "__main__":
450
+ """main - the main function of the app"""
451
+ logger = logging.getLogger(__name__)
452
+ args = parse_args()
453
+ logger.setLevel(args.log_level)
454
+ logger.info(f"args: {pp.pformat(args.__dict__, compact=True)}")
455
+
456
+ # add any custom options
457
+ if args.model is not None:
458
+ logger.info(f"Adding model {args.model} to the list of models")
459
+ MODEL_OPTIONS.append(args.model)
460
+ if args.add_beam_option is not None:
461
+ logger.info(f"Adding beam search option {args.add_beam_option} to the list")
462
+ BEAM_OPTIONS.append(args.add_beam_option)
463
+ if args.token_batch_option is not None:
464
+ logger.info(f"Adding token batch option {args.token_batch_option} to the list")
465
+ TOKEN_BATCH_OPTIONS.append(args.token_batch_option)
466
+
467
+ if args.aggregator_beam_boost:
468
+ logger.info("Doubling aggregator num_beams")
469
+ _agg_cfg = aggregator.get_generation_config()
470
+ _agg_cfg["num_beams"] = _agg_cfg["num_beams"] * 2
471
+ aggregator.update_generation_config(**_agg_cfg)
472
+
473
+ logger.info("Loading OCR model")
474
+ with contextlib.redirect_stdout(None):
475
+ ocr_model = ocr_predictor(
476
+ "db_resnet50",
477
+ "crnn_mobilenet_v3_large",
478
+ pretrained=True,
479
+ assume_straight_pages=True,
480
+ )
481
+
482
+ # load the examples
483
+ name_to_path = load_example_filenames(_here / "examples")
484
+ logger.info(f"Loaded {len(name_to_path)} examples")
485
+
486
+ demo = gr.Blocks(title="Document Summarization")
487
+ _examples = list(name_to_path.keys())
488
+ logger.info("Starting app instance")
489
+ with demo:
490
+ gr.Markdown(
491
+ """# Document Summarization with Long-Document Transformers
492
+
493
+ An example use case for fine-tuned long document transformers. Model(s) are trained on [book summaries](https://hf.co/datasets/kmfoda/booksum). Architectures [in this demo](https://hf.co/spaces/pszemraj/document-summarization) are [LongT5-base](https://hf.co/pszemraj/long-t5-tglobal-base-16384-book-summary) and [Pegasus-X-Large](https://hf.co/pszemraj/pegasus-x-large-book-summary).
494
+
495
+ **Want more performance?** Run this demo from a free [Google Colab GPU](https://colab.research.google.com/gist/pszemraj/52f67cf7326e780155812a6a1f9bb724/document-summarization-on-gpu.ipynb)
496
+ """
497
+ )
498
+ with gr.Column():
499
+ gr.Markdown(
500
+ """## Load Inputs & Select Parameters
501
+
502
+ Enter/paste text below, or upload a file. Pick a model & adjust params (_optional_), and press **Summarize!**
503
+
504
+ See [the guide doc](https://gist.github.com/pszemraj/722a7ba443aa3a671b02d87038375519) for details.
505
+ """
506
+ )
507
+ with gr.Row():
508
+ with gr.Column(variant="compact"):
509
+ model_name = gr.Dropdown(
510
+ choices=MODEL_OPTIONS,
511
+ value=MODEL_OPTIONS[0],
512
+ label="Model Name",
513
+ )
514
+ num_beams = gr.Radio(
515
+ choices=BEAM_OPTIONS,
516
+ value=BEAM_OPTIONS[len(BEAM_OPTIONS) // 2],
517
+ label="Beam Search: # of Beams",
518
+ )
519
+ load_examples_button = gr.Button(
520
+ "Load Example in Dropdown",
521
+ )
522
+ load_file_button = gr.Button("Upload & Process File")
523
+ with gr.Column(variant="compact"):
524
+ example_name = gr.Dropdown(
525
+ _examples,
526
+ label="Examples",
527
+ value=random.choice(_examples),
528
+ )
529
+ uploaded_file = gr.File(
530
+ label="File Upload",
531
+ file_count="single",
532
+ file_types=[".txt", ".md", ".pdf"],
533
+ type="filepath",
534
+ )
535
+ with gr.Row():
536
+ input_text = gr.Textbox(
537
+ lines=4,
538
+ max_lines=8,
539
+ label="Text to Summarize",
540
+ placeholder="Enter text to summarize, the text will be cleaned and truncated on Spaces. Narrative, academic (both papers and lecture transcription), and article text work well. May take a bit to generate depending on the input text :)",
541
+ )
542
+ with gr.Column():
543
+ gr.Markdown("## Generate Summary")
544
+ with gr.Row():
545
+ summarize_button = gr.Button(
546
+ "Summarize!",
547
+ variant="primary",
548
+ )
549
+ gr.Markdown(
550
+ "_Summarization should take ~1-2 minutes for most settings, but may extend up to 5-10 minutes in some scenarios._"
551
+ )
552
+ output_text = gr.HTML("<em>Output will appear below:</em>")
553
+ with gr.Column():
554
+ gr.Markdown("### Results & Scores")
555
+ with gr.Row():
556
+ with gr.Column(variant="compact"):
557
+ gr.Markdown(
558
+ "Download the summary as a text file, with parameters and scores."
559
+ )
560
+ text_file = gr.File(
561
+ label="Download as Text File",
562
+ file_count="single",
563
+ type="filepath",
564
+ interactive=False,
565
+ )
566
+ with gr.Column(variant="compact"):
567
+ gr.Markdown(
568
+ "Scores **roughly** represent the summary quality as a measure of the model's 'confidence'. less-negative numbers (closer to 0) are better."
569
+ )
570
+ summary_scores = gr.Textbox(
571
+ label="Summary Scores",
572
+ placeholder="Summary scores will appear here",
573
+ )
574
+ with gr.Column(variant="panel"):
575
+ gr.Markdown("### **Summary Output**")
576
+ summary_text = gr.HTML(
577
+ label="Summary",
578
+ value="<i>Summary will appear here!</i>",
579
+ )
580
+ with gr.Column():
581
+ gr.Markdown("### **Aggregate Summary Batches**")
582
+ with gr.Row():
583
+ aggregate_button = gr.Button(
584
+ "Aggregate!",
585
+ variant="primary",
586
+ )
587
+ gr.Markdown(
588
+ f"""Aggregate the above batches into a cohesive summary.
589
+ - A secondary instruct-tuned LM consolidates info
590
+ - Current model: [{AGGREGATE_MODEL}](https://hf.co/{AGGREGATE_MODEL})
591
+ """
592
+ )
593
+ with gr.Column(variant="panel"):
594
+ aggregated_summary = gr.HTML(
595
+ label="Aggregate Summary",
596
+ value="<i>Aggregate summary will appear here!</i>",
597
+ )
598
+
599
+ with gr.Column():
600
+ gr.Markdown(
601
+ """### Advanced Settings
602
+
603
+ Refer to [the guide doc](https://gist.github.com/pszemraj/722a7ba443aa3a671b02d87038375519) for what these are, and how they impact _quality_ and _speed_.
604
+ """
605
+ )
606
+ with gr.Row():
607
+ length_penalty = gr.Slider(
608
+ minimum=0.3,
609
+ maximum=1.1,
610
+ label="length penalty",
611
+ value=0.7,
612
+ step=0.05,
613
+ )
614
+ token_batch_length = gr.Radio(
615
+ choices=TOKEN_BATCH_OPTIONS,
616
+ label="token batch length",
617
+ # select median option
618
+ value=TOKEN_BATCH_OPTIONS[len(TOKEN_BATCH_OPTIONS) // 2],
619
+ )
620
+
621
+ with gr.Row():
622
+ repetition_penalty = gr.Slider(
623
+ minimum=1.0,
624
+ maximum=5.0,
625
+ label="repetition penalty",
626
+ value=1.5,
627
+ step=0.1,
628
+ )
629
+ no_repeat_ngram_size = gr.Radio(
630
+ choices=[2, 3, 4, 5],
631
+ label="no repeat ngram size",
632
+ value=3,
633
+ )
634
+ predrop_stopwords = gr.Checkbox(
635
+ label="Drop Stopwords (Pre-Truncation)",
636
+ value=False,
637
+ )
638
+
639
+ load_examples_button.click(
640
+ fn=load_single_example_text, inputs=[example_name], outputs=[input_text]
641
+ )
642
+
643
+ load_file_button.click(
644
+ fn=load_uploaded_file, inputs=uploaded_file, outputs=[input_text]
645
+ )
646
+
647
+ summarize_button.click(
648
+ fn=proc_submission,
649
+ inputs=[
650
+ input_text,
651
+ model_name,
652
+ num_beams,
653
+ token_batch_length,
654
+ length_penalty,
655
+ repetition_penalty,
656
+ no_repeat_ngram_size,
657
+ predrop_stopwords,
658
+ ],
659
+ outputs=[output_text, summary_text, summary_scores, text_file],
660
+ )
661
+ aggregate_button.click(
662
+ fn=aggregate_text,
663
+ inputs=[summary_text, text_file],
664
+ outputs=[aggregated_summary],
665
+ )
666
+ demo.launch(share=args.share, debug=True)
gitattributes ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.model filter=lfs diff=lfs merge=lfs -text
11
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
12
+ *.npy filter=lfs diff=lfs merge=lfs -text
13
+ *.npz filter=lfs diff=lfs merge=lfs -text
14
+ *.onnx filter=lfs diff=lfs merge=lfs -text
15
+ *.ot filter=lfs diff=lfs merge=lfs -text
16
+ *.parquet filter=lfs diff=lfs merge=lfs -text
17
+ *.pickle filter=lfs diff=lfs merge=lfs -text
18
+ *.pkl filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pt filter=lfs diff=lfs merge=lfs -text
21
+ *.pth filter=lfs diff=lfs merge=lfs -text
22
+ *.rar filter=lfs diff=lfs merge=lfs -text
23
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
24
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
25
+ *.tflite filter=lfs diff=lfs merge=lfs -text
26
+ *.tgz filter=lfs diff=lfs merge=lfs -text
27
+ *.wasm filter=lfs diff=lfs merge=lfs -text
28
+ *.xz filter=lfs diff=lfs merge=lfs -text
29
+ *.zip filter=lfs diff=lfs merge=lfs -text
30
+ *.zst filter=lfs diff=lfs merge=lfs -text
31
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
gitignore ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # logs
2
+ *.log
3
+ *LOGFILE*
4
+
5
+ # output files need to be force-added
6
+ *.csv
7
+ *.png
8
+ *.jpg
9
+ *.jpeg
10
+ *.pkl
11
+ *.xlsx
12
+ *.txt
13
+
14
+ # cache
15
+ *__pycache__/
16
+ *.pyc
17
+
18
+ # reports folder - need to be force-added
19
+ *reports/
20
+
21
+ # scratch files and folders
22
+
23
+ *scratch*
24
+ *scratch/
25
+
26
+ # notebooks
27
+
28
+ *notebooks/
29
+ *.ipynb
pdf2text.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ pdf2text.py - convert pdf files to text files using OCR
4
+ """
5
+ import logging
6
+ import os
7
+ import re
8
+ import shutil
9
+ import time
10
+ from datetime import date
11
+ from os.path import join
12
+ from pathlib import Path
13
+
14
+ logging.basicConfig(
15
+ level=logging.INFO,
16
+ format="%(asctime)s %(levelname)s %(message)s",
17
+ datefmt="%m/%d/%Y %I:%M:%S",
18
+ )
19
+
20
+
21
+ os.environ["USE_TORCH"] = "1"
22
+
23
+ from cleantext import clean
24
+ from doctr.io import DocumentFile
25
+ from doctr.models import ocr_predictor
26
+ from spellchecker import SpellChecker
27
+
28
+
29
+ def simple_rename(filepath, target_ext=".txt"):
30
+ """simple_rename - get a new str to rename a file"""
31
+ _fp = Path(filepath)
32
+ basename = _fp.stem
33
+ return f"OCR_{basename}_{target_ext}"
34
+
35
+
36
+ def rm_local_text_files(name_contains="RESULT_"):
37
+ """
38
+ rm_local_text_files - remove local text files
39
+ """
40
+ files = [
41
+ f
42
+ for f in Path.cwd().iterdir()
43
+ if f.is_file() and f.suffix == ".txt" and name_contains in f.name
44
+ ]
45
+ logging.info(f"removing {len(files)} text files")
46
+ for f in files:
47
+ os.remove(f)
48
+ logging.info("done")
49
+
50
+
51
+ def corr(
52
+ s: str,
53
+ add_space_when_numerics=False,
54
+ exceptions=["e.g.", "i.e.", "etc.", "cf.", "vs.", "p."],
55
+ ) -> str:
56
+ """corrects spacing in a string
57
+
58
+ Args:
59
+ s (str): the string to correct
60
+ add_space_when_numerics (bool, optional): [add a space when a period is between two numbers, example 5.73]. Defaults to False.
61
+ exceptions (list, optional): [do not change these substrings]. Defaults to ['e.g.', 'i.e.', 'etc.', 'cf.', 'vs.', 'p.'].
62
+
63
+ Returns:
64
+ str: the corrected string
65
+ """
66
+ if add_space_when_numerics:
67
+ s = re.sub(r"(\d)\.(\d)", r"\1. \2", s)
68
+
69
+ s = re.sub(r"\s+", " ", s)
70
+ s = re.sub(r'\s([?.!"](?:\s|$))', r"\1", s)
71
+
72
+ # fix space before apostrophe
73
+ s = re.sub(r"\s\'", r"'", s)
74
+ # fix space after apostrophe
75
+ s = re.sub(r"'\s", r"'", s)
76
+ # fix space before comma
77
+ s = re.sub(r"\s,", r",", s)
78
+
79
+ for e in exceptions:
80
+ expected_sub = re.sub(r"\s", "", e)
81
+ s = s.replace(expected_sub, e)
82
+
83
+ return s
84
+
85
+
86
+ def fix_punct_spaces(string: str) -> str:
87
+ """
88
+ fix_punct_spaces - fix spaces around punctuation
89
+
90
+ :param str string: input string
91
+ :return str: string with spaces fixed
92
+ """
93
+
94
+ fix_spaces = re.compile(r"\s*([?!.,]+(?:\s+[?!.,]+)*)\s*")
95
+ string = fix_spaces.sub(lambda x: "{} ".format(x.group(1).replace(" ", "")), string)
96
+ string = string.replace(" ' ", "'")
97
+ string = string.replace(' " ', '"')
98
+ return string.strip()
99
+
100
+
101
+ def clean_OCR(ugly_text: str) -> str:
102
+ """
103
+ clean_OCR - clean up the OCR text
104
+
105
+ :param str ugly_text: input text to be cleaned
106
+ :return str: cleaned text
107
+ """
108
+ # Remove all the newlines.
109
+ cleaned_text = ugly_text.replace("\n", " ")
110
+ # Remove all the tabs.
111
+ cleaned_text = cleaned_text.replace("\t", " ")
112
+ # Remove all the double spaces.
113
+ cleaned_text = cleaned_text.replace(" ", " ")
114
+ # Remove all the spaces at the beginning of the text.
115
+ cleaned_text = cleaned_text.lstrip()
116
+ # remove all instances of "- " and " - "
117
+ cleaned_text = cleaned_text.replace("- ", "")
118
+ cleaned_text = cleaned_text.replace(" -", "")
119
+ return fix_punct_spaces(cleaned_text)
120
+
121
+
122
+ def move2completed(
123
+ from_dir, filename, new_folder: str = "completed", verbose: bool = False
124
+ ):
125
+ """
126
+ move2completed - move a file to a new folder
127
+ """
128
+ old_filepath = join(from_dir, filename)
129
+
130
+ new_filedirectory = join(from_dir, new_folder)
131
+
132
+ if not os.path.isdir(new_filedirectory):
133
+ os.mkdir(new_filedirectory)
134
+ if verbose:
135
+ print("created new directory for files at: \n", new_filedirectory)
136
+ new_filepath = join(new_filedirectory, filename)
137
+
138
+ try:
139
+ shutil.move(old_filepath, new_filepath)
140
+ logging.info("successfully moved the file {} to */completed.".format(filename))
141
+ except:
142
+ logging.info(
143
+ "ERROR! unable to move file to \n{}. Please investigate".format(
144
+ new_filepath
145
+ )
146
+ )
147
+
148
+
149
+ custom_replace_list = {
150
+ "t0": "to",
151
+ "'$": "'s",
152
+ ",,": ", ",
153
+ "_ ": " ",
154
+ " '": "'",
155
+ }
156
+
157
+ replace_corr_exceptions = {
158
+ "i. e.": "i.e.",
159
+ "e. g.": "e.g.",
160
+ "e. g": "e.g.",
161
+ " ,": ",",
162
+ }
163
+
164
+
165
+ spell = SpellChecker()
166
+
167
+
168
+ def check_word_spelling(word: str) -> bool:
169
+ """
170
+ check_word_spelling - check the spelling of a word
171
+
172
+ Args:
173
+ word (str): word to check
174
+
175
+ Returns:
176
+ bool: True if word is spelled correctly, False if not
177
+ """
178
+
179
+ misspelled = spell.unknown([word])
180
+
181
+ return len(misspelled) == 0
182
+
183
+
184
+ def eval_and_replace(text: str, match_token: str = "- ") -> str:
185
+ """
186
+ eval_and_replace - conditionally replace all instances of a substring in a string based on whether the eliminated substring results in a valid word
187
+
188
+ Args:
189
+ text (str): text to evaluate
190
+ match_token (str, optional): token to replace. Defaults to "- ".
191
+
192
+ Returns:
193
+ str: text with replaced tokens
194
+ """
195
+
196
+ if match_token not in text:
197
+ return text
198
+ else:
199
+ while True:
200
+ full_before_text = text.split(match_token, maxsplit=1)[0]
201
+ before_text = [
202
+ char for char in full_before_text.split()[-1] if char.isalpha()
203
+ ]
204
+ before_text = "".join(before_text)
205
+ full_after_text = text.split(match_token, maxsplit=1)[-1]
206
+ after_text = [char for char in full_after_text.split()[0] if char.isalpha()]
207
+ after_text = "".join(after_text)
208
+ full_text = before_text + after_text
209
+ if check_word_spelling(full_text):
210
+ text = full_before_text + full_after_text
211
+ else:
212
+ text = full_before_text + " " + full_after_text
213
+ if match_token not in text:
214
+ break
215
+ return text
216
+
217
+
218
+ def cleantxt_ocr(ugly_text, lower=False, lang: str = "en") -> str:
219
+ """
220
+ cleantxt_ocr - clean text from OCR
221
+
222
+ https://pypi.org/project/clean-text/
223
+ Args:
224
+ ugly_text (str): text to clean
225
+ lower (bool, optional): lowercase text. Defaults to False.
226
+ lang (str, optional): language of text. Defaults to "en".
227
+
228
+ Returns:
229
+ str: cleaned text
230
+ """
231
+
232
+ cleaned_text = clean(
233
+ ugly_text,
234
+ fix_unicode=True, # fix various unicode errors
235
+ to_ascii=True, # transliterate to closest ASCII representation
236
+ lower=lower, # lowercase text
237
+ no_line_breaks=True, # fully strip line breaks as opposed to only normalizing them
238
+ no_urls=True, # replace all URLs with a special token
239
+ no_emails=True, # replace all email addresses with a special token
240
+ no_phone_numbers=True, # replace all phone numbers with a special token
241
+ no_numbers=False, # replace all numbers with a special token
242
+ no_digits=False, # replace all digits with a special token
243
+ no_currency_symbols=False, # replace all currency symbols with a special token
244
+ no_punct=False, # remove punctuations
245
+ replace_with_punct="", # instead of removing punctuations you may replace them
246
+ replace_with_url="this url",
247
+ replace_with_email="this email",
248
+ replace_with_phone_number="this phone number",
249
+ lang=lang, # set to 'de' for German special handling
250
+ )
251
+
252
+ return cleaned_text
253
+
254
+
255
+ def format_ocr_out(OCR_data):
256
+ """format OCR output to text"""
257
+ if isinstance(OCR_data, list):
258
+ text = " ".join(OCR_data)
259
+ else:
260
+ text = str(OCR_data)
261
+ _clean = cleantxt_ocr(text)
262
+ return corr(_clean)
263
+
264
+
265
+ def postprocess(text: str) -> str:
266
+ """to be used after recombining the lines"""
267
+
268
+ proc = corr(cleantxt_ocr(text))
269
+
270
+ for k, v in custom_replace_list.items():
271
+ proc = proc.replace(str(k), str(v))
272
+
273
+ proc = corr(proc)
274
+
275
+ for k, v in replace_corr_exceptions.items():
276
+ proc = proc.replace(str(k), str(v))
277
+
278
+ return eval_and_replace(proc)
279
+
280
+
281
+ def result2text(result, as_text=False) -> str or list:
282
+ """Convert OCR result to text"""
283
+
284
+ full_doc = []
285
+ for i, page in enumerate(result.pages, start=1):
286
+ text = ""
287
+ for block in page.blocks:
288
+ text += "\n\t"
289
+ for line in block.lines:
290
+ for word in line.words:
291
+ # print(dir(word))
292
+ text += word.value + " "
293
+ full_doc.append(text)
294
+
295
+ return "\n".join(full_doc) if as_text else full_doc
296
+
297
+
298
+ def convert_PDF_to_Text(
299
+ PDF_file,
300
+ ocr_model=None,
301
+ max_pages: int = 20,
302
+ ) -> str:
303
+ """
304
+ convert_PDF_to_Text - convert a PDF file to text
305
+
306
+ :param str PDF_file: path to PDF file
307
+ :param ocr_model: model to use for OCR, defaults to None (uses the default model)
308
+ :param int max_pages: maximum number of pages to process, defaults to 20
309
+ :return str: text from PDF
310
+ """
311
+ st = time.perf_counter()
312
+ PDF_file = Path(PDF_file)
313
+ ocr_model = ocr_predictor(pretrained=True) if ocr_model is None else ocr_model
314
+ logging.info(f"starting OCR on {PDF_file.name}")
315
+ doc = DocumentFile.from_pdf(PDF_file)
316
+ truncated = False
317
+ if len(doc) > max_pages:
318
+ logging.warning(
319
+ f"PDF has {len(doc)} pages, which is more than {max_pages}.. truncating"
320
+ )
321
+ doc = doc[:max_pages]
322
+ truncated = True
323
+
324
+ # Analyze
325
+ logging.info(f"running OCR on {len(doc)} pages")
326
+ result = ocr_model(doc)
327
+ raw_text = result2text(result)
328
+ proc_text = [format_ocr_out(r) for r in raw_text]
329
+ fin_text = [postprocess(t) for t in proc_text]
330
+
331
+ ocr_results = "\n\n".join(fin_text)
332
+
333
+ fn_rt = time.perf_counter() - st
334
+
335
+ logging.info("OCR complete")
336
+
337
+ results_dict = {
338
+ "num_pages": len(doc),
339
+ "runtime": round(fn_rt, 2),
340
+ "date": str(date.today()),
341
+ "converted_text": ocr_results,
342
+ "truncated": truncated,
343
+ "length": len(ocr_results),
344
+ }
345
+
346
+ return results_dict
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate
2
+ clean-text[gpl]
3
+ gradio==5.5.0
4
+ natsort
5
+ nltk
6
+ pyspellchecker
7
+ python-doctr[torch]
8
+ rapidfuzz==2.13.7
9
+ sentencepiece
10
+ torch
11
+ tqdm
12
+ transformers==4.46.2
summarize.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ summarize - a module for summarizing text using a model from the Hugging Face model hub
3
+ """
4
+ import logging
5
+ import os
6
+ import pprint as pp
7
+
8
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")
9
+
10
+ import torch
11
+ from tqdm.auto import tqdm
12
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
13
+
14
+ from utils import validate_pytorch2
15
+
16
+
17
+ def load_model_and_tokenizer(model_name: str) -> tuple:
18
+ """
19
+ load_model_and_tokenizer - load a model and tokenizer from a model name/ID on the hub
20
+
21
+ :param str model_name: the model name/ID on the hub
22
+ :return tuple: a tuple containing the model and tokenizer
23
+ """
24
+ device = "cuda" if torch.cuda.is_available() else "cpu"
25
+ model = AutoModelForSeq2SeqLM.from_pretrained(
26
+ model_name,
27
+ use_auth_token=os.environ.get("HF_TOKEN", None),
28
+ ).to(device)
29
+ model = model.eval()
30
+
31
+ tokenizer = AutoTokenizer.from_pretrained(
32
+ model_name,
33
+ use_auth_token=os.environ.get("HF_TOKEN", None),
34
+ )
35
+
36
+ logging.info(f"Loaded model {model_name} to {device}")
37
+
38
+ if validate_pytorch2():
39
+ try:
40
+ logging.info("Compiling model with Torch 2.0")
41
+ model = torch.compile(model)
42
+ except Exception as e:
43
+ logging.warning(f"Could not compile model with Torch 2.0: {e}")
44
+ else:
45
+ logging.info("Torch 2.0 not detected, skipping compilation")
46
+
47
+ return model, tokenizer
48
+
49
+
50
+ def summarize_and_score(
51
+ ids, mask, model, tokenizer, is_general_attention_model=True, **kwargs
52
+ ) -> tuple:
53
+ """
54
+ summarize_and_score - given a batch of ids and a mask, return a summary and a score for the summary
55
+
56
+ Args:
57
+ ids (): the batch of ids
58
+ mask (): the attention mask for the batch
59
+ model (): the model to use for summarization
60
+ tokenizer (): the tokenizer to use for summarization
61
+ is_general_attention_model (bool, optional): whether the model is a general attention model. Defaults to True.
62
+ **kwargs: any additional arguments to pass to the model
63
+ Returns:
64
+ tuple (str, float): the summary, the score for the summary
65
+ """
66
+
67
+ ids = ids[None, :]
68
+ mask = mask[None, :]
69
+
70
+ input_ids = ids.to("cuda") if torch.cuda.is_available() else ids
71
+ attention_mask = mask.to("cuda") if torch.cuda.is_available() else mask
72
+
73
+ global_attention_mask = torch.zeros_like(attention_mask)
74
+ # put global attention on <s> token
75
+ global_attention_mask[:, 0] = 1
76
+
77
+ if is_general_attention_model:
78
+ summary_pred_ids = model.generate(
79
+ input_ids,
80
+ attention_mask=attention_mask,
81
+ output_scores=True,
82
+ return_dict_in_generate=True,
83
+ **kwargs,
84
+ )
85
+ else:
86
+ summary_pred_ids = model.generate(
87
+ input_ids,
88
+ attention_mask=attention_mask,
89
+ global_attention_mask=global_attention_mask,
90
+ output_scores=True,
91
+ return_dict_in_generate=True,
92
+ **kwargs,
93
+ )
94
+ summary = tokenizer.batch_decode(
95
+ summary_pred_ids.sequences,
96
+ skip_special_tokens=True,
97
+ remove_invalid_values=True,
98
+ )
99
+ score = round(summary_pred_ids.sequences_scores.cpu().numpy()[0], 4)
100
+
101
+ return summary, score
102
+
103
+
104
+ def summarize_via_tokenbatches(
105
+ input_text: str,
106
+ model,
107
+ tokenizer,
108
+ batch_length=2048,
109
+ batch_stride=16,
110
+ min_batch_length=512,
111
+ **kwargs,
112
+ ) -> list:
113
+ """
114
+ summarize_via_tokenbatches - summarize a long string via batches of tokens
115
+
116
+ Args:
117
+ input_text (str): the text to summarize
118
+ model (): the model to use for summarization
119
+ tokenizer (): the tokenizer to use for summarization
120
+ batch_length (int, optional): the length of each batch. Defaults to 2048.
121
+ batch_stride (int, optional): the stride of each batch. Defaults to 16. The stride is the number of tokens that overlap between batches.
122
+ min_batch_length (int, optional): the minimum length of each batch. Defaults to 512.
123
+
124
+ **kwargs: any additional arguments to pass to the model for inference
125
+ Returns:
126
+ list: a list of dictionaries containing the input tokens, the summary, and the summary score
127
+ """
128
+
129
+ logger = logging.getLogger(__name__)
130
+ # log all input parameters
131
+ if batch_length < min_batch_length:
132
+ logger.warning(
133
+ f"batch_length must be at least {min_batch_length}. Setting batch_length to {min_batch_length}"
134
+ )
135
+ batch_length = min_batch_length
136
+
137
+ logger.info(f"input parameters:\n{pp.pformat(kwargs)}")
138
+ logger.info(f"batch_length: {batch_length}, batch_stride: {batch_stride}")
139
+
140
+ encoded_input = tokenizer(
141
+ input_text,
142
+ padding="max_length",
143
+ truncation=True,
144
+ max_length=batch_length,
145
+ stride=batch_stride,
146
+ return_overflowing_tokens=True,
147
+ add_special_tokens=False,
148
+ return_tensors="pt",
149
+ )
150
+
151
+ in_id_arr, att_arr = encoded_input.input_ids, encoded_input.attention_mask
152
+ gen_summaries = []
153
+
154
+ pbar = tqdm(total=len(in_id_arr))
155
+
156
+ for _id, _mask in zip(in_id_arr, att_arr):
157
+ result, score = summarize_and_score(
158
+ ids=_id,
159
+ mask=_mask,
160
+ model=model,
161
+ tokenizer=tokenizer,
162
+ **kwargs,
163
+ )
164
+ score = round(float(score), 4)
165
+ _sum = {
166
+ "input_tokens": _id,
167
+ "summary": result,
168
+ "summary_score": score,
169
+ }
170
+ gen_summaries.append(_sum)
171
+ logger.debug(f"Score for batch: {score}. num chars: {len(repr(result))}")
172
+ logger.debug(f"Summary:\n\t{result}")
173
+ pbar.update()
174
+
175
+ pbar.close()
176
+
177
+ return gen_summaries
utils.py ADDED
@@ -0,0 +1,450 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ utils.py - Utility functions for the project.
3
+ """
4
+ import logging
5
+ import re
6
+ import subprocess
7
+ from collections import defaultdict, deque
8
+ from datetime import datetime, timedelta
9
+ from itertools import combinations
10
+ from pathlib import Path
11
+ from typing import List
12
+
13
+ logging.basicConfig(
14
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
15
+ level=logging.INFO,
16
+ )
17
+
18
+ import torch
19
+ from natsort import natsorted
20
+ from nltk.tokenize import WhitespaceTokenizer, sent_tokenize, word_tokenize
21
+ from rapidfuzz import fuzz
22
+
23
+ STOPWORDS = set(
24
+ "a about above after again all also am an and any are aren't as at back be because been before being below between both but by can't cannot could couldn't did didn't do does doesn't doing don't down during each few for from further had hadn't has hasn't have haven't having he'd he'll he's hence her here here's hers herself him himself his how how's however i'd i'll i'm i've if in into is isn't it's its itself just let's me more moreover most mustn't my myself new nor now of off on once only or other ought our ours ourselves out over own really same shan't she'd she'll she's should shouldn't so some such than that that's the their theirs them themselves then there there's therefore these they they'd they'll they're they've this those through thus to too under until up use used using very was wasn't we we'd we'll we're we've were weren't what what's when when's where where's which while who who's whom why why's with won't would wouldn't you'd you'll you're you've your yours yourself yourselves".split()
25
+ )
26
+
27
+
28
+ def contraction_aware_tokenize(text: str) -> List[str]:
29
+ """contraction_aware_tokenize - merges words containing apostrophes as one token."""
30
+
31
+ # Tokenize the text using the WhitespaceTokenizer
32
+ tokenizer = WhitespaceTokenizer()
33
+ tokens = tokenizer.tokenize(text)
34
+
35
+ merged_tokens = []
36
+ merged_token = ""
37
+
38
+ for token in tokens:
39
+ if re.search(r"\w+'\w+", token):
40
+ # Token contains an apostrophe, merge with previous token
41
+ merged_token += token
42
+ else:
43
+ # no apostrophe, add previous merged token (if any) and current
44
+ if merged_token:
45
+ merged_tokens.append(merged_token)
46
+ merged_token = ""
47
+ merged_tokens.append(token)
48
+
49
+ # Add the last merged token (if any)
50
+ if merged_token:
51
+ merged_tokens.append(merged_token)
52
+
53
+ return merged_tokens
54
+
55
+
56
+ def remove_stopwords(
57
+ text: str, stopwords: List[str] = STOPWORDS, contraction_tokenize: bool = True
58
+ ) -> str:
59
+ """
60
+ remove_stopwords - Remove stopwords from text.
61
+
62
+ :param str text: input text
63
+ :param List[str] stopwords: list of stopwords, defaults to STOPWORDS
64
+ :param bool contraction_tokenize: use custom apostrophe tokenizer, defaults to True
65
+ :return str: text with stopwords removed
66
+ """
67
+ lines = text.split("\n")
68
+ filtered_lines = []
69
+
70
+ def fix_commas(text: str) -> str:
71
+ """fixes commas in text to have a space after them"""
72
+ spaced_text = text.replace(",", ", ")
73
+ return spaced_text.replace(" ", " ").strip()
74
+
75
+ for line in lines:
76
+ sentences = sent_tokenize(line)
77
+ filtered_sentences = []
78
+
79
+ for sentence in sentences:
80
+ # Add space around punctuations for the regex to work correctly, only if they are followed by a letter
81
+ sentence_with_spaces = re.sub(r"([.,!?])(\w)", r"\1 \2", sentence[:-1])
82
+
83
+ words = (
84
+ contraction_aware_tokenize(sentence_with_spaces)
85
+ if contraction_tokenize
86
+ else word_tokenize(sentence_with_spaces)
87
+ )
88
+
89
+ filtered_words = []
90
+ for word in words:
91
+ if word.lower() not in stopwords:
92
+ filtered_words.append(word)
93
+
94
+ filtered_sentence = " ".join(filtered_words)
95
+ # Restore original spaces around punctuation marks
96
+ filtered_sentence = re.sub(r"([.,!?])\s*", r"\1", filtered_sentence)
97
+
98
+ filtered_sentences.append(filtered_sentence + sentence[-1])
99
+
100
+ filtered_line = " ".join(filtered_sentences)
101
+
102
+ # Replace multiple consecutive whitespaces with a single space
103
+ filtered_line = re.sub(r"\s+", " ", filtered_line)
104
+ filtered_line = fix_commas(filtered_line.strip())
105
+
106
+ filtered_lines.append(filtered_line)
107
+
108
+ filtered_text = "\n".join(filtered_lines)
109
+
110
+ return filtered_text
111
+
112
+
113
+ def remove_stagnant_files(
114
+ freq: str = "hourly",
115
+ search_path: str = ".",
116
+ substring="DocSumm",
117
+ remove_suffix=".txt",
118
+ ):
119
+ """
120
+ remove_stagnant_files - Remove files that have not been modified in a certain amount of time.
121
+
122
+ :param str freq: frequency of file removal, defaults to "hourly"
123
+ :param str search_path: location to search for files, defaults to "."
124
+ :param str substring: substring to search for in file names, defaults to "DocSumm"
125
+ :param str remove_suffix: suffix of files to remove, defaults to ".txt"
126
+ :raises ValueError: if freq is not one of "hourly", "daily", or "weekly"
127
+ """
128
+ current_time = datetime.now()
129
+ search_path = Path(search_path)
130
+
131
+ if freq == "hourly":
132
+ time_threshold = current_time - timedelta(hours=1)
133
+ elif freq == "daily":
134
+ time_threshold = current_time - timedelta(days=1)
135
+ elif freq == "weekly":
136
+ time_threshold = current_time - timedelta(weeks=1)
137
+ else:
138
+ raise ValueError(
139
+ "Invalid frequency. Supported values are 'hourly', 'daily', and 'weekly'."
140
+ )
141
+
142
+ files_to_remove = []
143
+ potential_files = [
144
+ f for f in search_path.iterdir() if f.is_file() and f.suffix == remove_suffix
145
+ ]
146
+ logging.info(f"Found {len(potential_files)} files.")
147
+ for candidate in potential_files:
148
+ if (
149
+ candidate.is_file()
150
+ and substring in candidate.name
151
+ and candidate.stat().st_mtime < time_threshold.timestamp()
152
+ ):
153
+ files_to_remove.append(candidate)
154
+ logging.debug(f"File {candidate} last modified at {candidate.stat().st_mtime}")
155
+ logging.info(f"Removing {len(files_to_remove)} files.")
156
+ for file_path in files_to_remove:
157
+ file_path.unlink()
158
+ logging.debug(f"Removed files: {files_to_remove}")
159
+
160
+
161
+ def compare_model_size(model_name: str, threshold: int = 500) -> bool:
162
+ """
163
+ compare_model_size - compare string representations of model size to a threshold
164
+
165
+ :param str model_name: the model name to compare
166
+ :param int threshold: the threshold to compare against in millions, defaults to 500
167
+ :return: True if the model size is greater than the threshold, False or None otherwise
168
+ """
169
+ pattern = r"(\d+)(M|G|k|b)?" # param regex
170
+
171
+ matches = re.findall(pattern, model_name)
172
+ if not matches:
173
+ return None
174
+
175
+ # Extract the parameter count and unit
176
+ parameter_count, unit = matches[-1]
177
+ parameter_count = int(parameter_count)
178
+
179
+ # Convert to the standard form (M for million, G for billion, k for thousand)
180
+ if unit == "G" or unit == "b":
181
+ parameter_count *= 1000
182
+ elif unit == "M":
183
+ pass
184
+ elif unit == "k":
185
+ parameter_count /= 1000
186
+ else:
187
+ return None # Unknown
188
+
189
+ return parameter_count > threshold
190
+
191
+
192
+ def validate_pytorch2(torch_version: str = None) -> bool:
193
+ """
194
+ validate_pytorch2 - validate that the PyTorch version is 2.0 or greater
195
+
196
+ :param str torch_version: the PyTorch version to validate, defaults to None
197
+ :return: True if the PyTorch version is 2.0 or greater, False otherwise
198
+ """
199
+
200
+ torch_version = torch.__version__ if torch_version is None else torch_version
201
+
202
+ pattern = r"^2\.\d+(\.\d+)*"
203
+
204
+ return True if re.match(pattern, torch_version) else False
205
+
206
+
207
+ def get_timestamp(detailed=False) -> str:
208
+ """
209
+ get_timestamp - get a timestamp for the current time
210
+ :param bool detailed: whether to include seconds and microseconds, defaults to False
211
+ :return: str, the timestamp
212
+ """
213
+ return (
214
+ datetime.now().strftime("%b%d%Y_%H%M%S%f")
215
+ if detailed
216
+ else datetime.now().strftime("%b%d%Y_%H")
217
+ )
218
+
219
+
220
+ def truncate_word_count(text: str, max_words=1024) -> dict:
221
+ """
222
+ truncate_word_count - truncate a text to a maximum number of words
223
+ :param str text: the text to truncate
224
+ :param int max_words: the maximum number of words to keep, defaults to 1024
225
+ :return: dict, the processed text
226
+ """
227
+ words = contraction_aware_tokenize(str(text))
228
+ processed = {}
229
+ if len(words) > max_words:
230
+ processed["was_truncated"] = True
231
+ processed["processed_text"] = " ".join(words[:max_words])
232
+ else:
233
+ processed["was_truncated"] = False
234
+ processed["processed_text"] = text
235
+ return processed
236
+
237
+
238
+ def load_examples(src, filetypes=[".txt", ".pdf"]):
239
+ """
240
+ load_examples - a helper function for the gradio module to load examples
241
+ :param str src: the path to the examples
242
+ """
243
+ src = Path(src)
244
+ src.mkdir(exist_ok=True)
245
+
246
+ pdf_url = (
247
+ "https://www.dropbox.com/s/y92xy7o5qb88yij/all_you_need_is_attention.pdf?dl=1"
248
+ )
249
+ subprocess.run(["wget", pdf_url, "-O", src / "all_you_need_is_attention.pdf"])
250
+ examples = [f for f in src.iterdir() if f.suffix in filetypes]
251
+ examples = natsorted(examples)
252
+ # load the examples into a list
253
+ text_examples = []
254
+ for example in examples:
255
+ with open(example, "r") as f:
256
+ text = f.read()
257
+ text_examples.append([text, "base", 2, 1024, 0.7, 3.5, 3])
258
+
259
+ return text_examples
260
+
261
+
262
+ def load_example_filenames(example_path: str or Path):
263
+ """
264
+ load_example_filenames - a helper function for the gradio module to load examples
265
+ Returns:
266
+ dict, the examples (filename:full path)
267
+ """
268
+ example_path = Path(example_path)
269
+ # load the examples into a list
270
+ examples = {f.name: f for f in example_path.glob("*.txt")}
271
+ return examples
272
+
273
+
274
+ def textlist2html(text_batches: List[str]) -> str:
275
+ """textlist2html - convert a list of text summaries into a single HTML string"""
276
+ # Step 1: Generate each summary batch as a string of HTML
277
+ formatted_batches = [
278
+ f"""
279
+ <div style="
280
+ margin-bottom: 20px;
281
+ font-size: 18px;
282
+ line-height: 1.5em;
283
+ color: #333;
284
+ ">
285
+ <h2 style="font-size: 22px; color: #555;">Batch {i}:</h2>
286
+ <p style="white-space: pre-line;">{s}</p>
287
+ </div>
288
+ """
289
+ for i, s in enumerate(text_batches, start=1)
290
+ ]
291
+
292
+ # Step 2: Join all the summary batches together into one string
293
+ joined_batches = "".join(formatted_batches)
294
+
295
+ # Step 3: Wrap the summary string in a larger div with background color, border, and padding
296
+ text_html_block = f"""
297
+ <div style="
298
+ border: 1px solid #ddd;
299
+ border-radius: 5px;
300
+ padding: 20px;
301
+ ">
302
+ {joined_batches}
303
+ </div>
304
+ """
305
+
306
+ return text_html_block
307
+
308
+
309
+ def extract_batches(html_string: str, pattern=None, flags=None) -> list:
310
+ """
311
+ Extract batches of text from an HTML string.
312
+
313
+ Args:
314
+ html_string (str): The HTML string to extract batches from.
315
+ pattern (str, optional): The regular expression pattern to use. Defaults to a pattern that matches batches in the format provided.
316
+ flags (int, optional): The flags to use with the regular expression. Defaults to re.DOTALL.
317
+
318
+ Returns:
319
+ list: A list of dictionaries where each dictionary represents a batch and has 'title' and 'content' keys.
320
+ """
321
+ # Set default pattern if none provided
322
+ if pattern is None:
323
+ pattern = r'<h2 style="font-size: 22px; color: #555;">(.*?)</h2>\s*<p style="white-space: pre-line;">(.*?)</p>'
324
+
325
+ # Set default flags if none provided
326
+ if flags is None:
327
+ flags = re.DOTALL
328
+
329
+ try:
330
+ # Find all matches in the string
331
+ matches = re.findall(pattern, html_string, flags)
332
+
333
+ # Convert matches to a list of dictionaries
334
+ batches = [
335
+ {"title": title.strip(), "content": content.strip()}
336
+ for title, content in matches
337
+ ]
338
+
339
+ return batches
340
+ except re.error as e:
341
+ logging.error(f"An error occurred while trying to extract batches: {e}")
342
+ return []
343
+
344
+
345
+ def extract_keywords(
346
+ text: str, num_keywords: int = 3, window_size: int = 5, kw_max_len: int = 20
347
+ ) -> List[str]:
348
+ """
349
+ Extracts keywords from a text using a simplified TextRank algorithm.
350
+
351
+ Args:
352
+ text: The text to extract keywords from.
353
+ num_keywords: The number of keywords to extract. Default: 3
354
+ window_size: The number of words considered for co-occurrence. Default: 5
355
+ kw_max_len: The maximum length of a keyword (truncate longer keywords to max). Default: 20
356
+ Returns:
357
+ A list of strings, where each string is a keyword extracted from the input text.
358
+ """
359
+ logger = logging.getLogger(__name__)
360
+ # Remove stopwords and tokenize the text into words
361
+ words = [
362
+ word
363
+ for word in re.findall(r"\b\w{3,}\b", text.lower())
364
+ if word not in STOPWORDS
365
+ ]
366
+
367
+ # Create a graph of word co-occurrences within a moving window of words
368
+ cooccur = defaultdict(lambda: defaultdict(int))
369
+ deque_words = deque(maxlen=window_size)
370
+ for word in words:
371
+ for w1, w2 in combinations(deque_words, 2):
372
+ cooccur[w1][w2] += 1
373
+ cooccur[w2][w1] += 1
374
+ deque_words.append(word)
375
+
376
+ # Assign scores to words using a simplified TextRank algorithm
377
+ scores = defaultdict(float)
378
+ for _ in range(10):
379
+ new_scores = defaultdict(float)
380
+ for word, co_words in cooccur.items():
381
+ new_scores[word] = 0.15 + 0.85 * sum(
382
+ cooccur[word][other] / sum(cooccur[other].values()) * scores[other]
383
+ for other in co_words
384
+ )
385
+ scores = new_scores
386
+
387
+ # Sort the words by score and return the top num_keywords keywords
388
+ keywords = sorted(scores, key=scores.get, reverse=True)[:num_keywords]
389
+ logger.debug(f"All keywords: {keywords}")
390
+ # Use fuzzy matching to remove similar keywords
391
+ final_keywords = []
392
+ for keyword in keywords:
393
+ if not any(fuzz.ratio(keyword, other) > 70 for other in final_keywords):
394
+ final_keywords.append(keyword[:kw_max_len])
395
+ logger.debug(f"Keywords (max len. {kw_max_len}):\t{final_keywords}")
396
+ return final_keywords
397
+
398
+
399
+ def saves_summary(
400
+ summarize_output, outpath: str or Path = None, add_signature=True, **kwargs
401
+ ) -> Path:
402
+ """
403
+ saves_summary - save the summary generated from summarize_via_tokenbatches() to a text file
404
+
405
+ summarize_output: output from summarize_via_tokenbatches()
406
+ outpath: path to the output file
407
+ add_signature: whether to add a signature to the output file
408
+ kwargs: additional keyword arguments to include in the output file
409
+ """
410
+ logger = logging.getLogger(__name__)
411
+ sum_text = [f"{s['summary'][0]}\n" for s in summarize_output]
412
+ sum_scores = [f"\n - {round(s['summary_score'],4)}" for s in summarize_output]
413
+ scores_text = "\n".join(sum_scores)
414
+ full_summary = "\n".join(sum_text)
415
+
416
+ keywords = "_".join(extract_keywords(full_summary, kw_max_len=4))
417
+ logger.debug(f"kw:\t{keywords}")
418
+ outpath = (
419
+ Path.cwd() / f"DocSumm_{keywords}_{get_timestamp()}.txt"
420
+ if outpath is None
421
+ else Path(outpath)
422
+ )
423
+ logger.info(f"Saving summary to:\t{outpath.name}")
424
+ with open(
425
+ outpath,
426
+ "w",
427
+ encoding="utf-8",
428
+ ) as fo:
429
+ fo.writelines(full_summary)
430
+ fo.write("\n\n")
431
+ if add_signature:
432
+ fo.write("\n\n---\n\n")
433
+ fo.write("Generated with the Document Summarization space :)\n\n")
434
+ fo.write("https://hf.co/spaces/pszemraj/document-summarization\n\n")
435
+ with open(
436
+ outpath,
437
+ "a",
438
+ encoding="utf-8",
439
+ ) as fo:
440
+ fo.write("\n")
441
+ fo.write("## Section Scores:\n\n")
442
+ fo.writelines(scores_text)
443
+ fo.write("\n\n")
444
+ fo.write(f"Date: {get_timestamp()}\n\n")
445
+ if kwargs:
446
+ fo.write("---\n\n")
447
+ fo.write("## Parameters:\n\n")
448
+ for key, value in kwargs.items():
449
+ fo.write(f"{key}: {value}\n")
450
+ return str(outpath.resolve())