Doc Summarizer version 1
- README.md +29 -9
- aggregate.py +192 -0
- app.py +666 -0
- gitattributes +31 -0
- gitignore +29 -0
- pdf2text.py +346 -0
- requirements.txt +12 -0
- summarize.py +177 -0
- utils.py +450 -0
README.md
CHANGED
@@ -1,14 +1,34 @@
 ---
-title:
+title: Document Summarization
-emoji:
+emoji: 🌖
-colorFrom:
+colorFrom: gray
-colorTo:
+colorTo: indigo
 sdk: gradio
-sdk_version: 5.
+sdk_version: 5.12.0
 app_file: app.py
-pinned:
+pinned: true
-license:
+license: apache-2.0
-short_description:
+short_description: text2text models for document summarization
 ---

-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at <https://huggingface.co/docs/hub/spaces-config-reference>
+
+# README - Document Summarization
+
+The original demo/what this repo was built for can be found [here](https://huggingface.co/spaces/pszemraj/document-summarization)
+
+## Usage
+
+If you are using this **not** as a gradio demo on hf spaces, you can run it locally with:
+
+```bash
+python app.py --share
+```
+
+To see all the available arguments, run `python app.py --help`.
+
+## Installation
+
+```bash
+pip install -r requirements.txt
+```
aggregate.py
ADDED
@@ -0,0 +1,192 @@
"""
aggregate.py - module for 'reducing' multiple 'summary chunks' into one

an overly complicated class for legacy compatibility reasons, for usage of the
2024 map-reduce models see hf.co/pszemraj/bart-large-summary-map-reduce#usage
"""

import logging
import pprint as pp
import time

import torch
from transformers import GenerationConfig, pipeline

# Setting up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


class BatchAggregator:
    """
    BatchAggregator is a class for aggregating text from multiple sources.

    Usage:
        from aggregate import BatchAggregator
        aggregator = BatchAggregator()
        agg = aggregator.infer_aggregate(["This is a test", "This is another test"])
        print(agg)
    """

    GENERIC_CONFIG = GenerationConfig(
        max_new_tokens=512,
        num_beams=4,
        early_stopping=True,
        do_sample=False,
        truncation=True,
    )

    def __init__(
        self,
        model_name: str = "pszemraj/bart-large-summary-map-reduce",
        force_cpu: bool = False,
        **kwargs,
    ):
        """
        __init__ initializes the BatchAggregator class.

        :param str model_name: model name to use, default: "pszemraj/bart-large-summary-map-reduce"
        :param bool force_cpu: force the model to run on CPU, default: False
        """
        self.device = None
        self.is_compiled = False
        self.model_name = None
        self.aggregator = None
        self.force_cpu = force_cpu
        self.logger = logging.getLogger(__name__)
        self.init_model(model_name)

    def init_model(self, model_name: str) -> None:
        """
        Initialize the model.

        :param model_name: The name of the model to use.
        """
        # Free up memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self.logger.info(f"Setting model to {model_name}")
        self.model_name = model_name
        self.aggregator = self._create_pipeline(model_name)
        self._configure_model()

    def _create_pipeline(
        self, model_name: str = "pszemraj/bart-large-summary-map-reduce"
    ) -> pipeline:
        """
        _create_pipeline creates a pipeline for the model.

        :param str model_name: model name to use
        :return pipeline: the pipeline for the model

        :raises Exception: if the pipeline cannot be created
        """
        device_map = (
            "auto" if torch.cuda.is_available() and not self.force_cpu else "cpu"
        )
        try:
            self.logger.info(
                f"Creating pipeline with model {model_name} on device {device_map}"
            )
            return pipeline(
                "text2text-generation",
                model=model_name,
                device_map=device_map,
                torch_dtype=torch.float32,
            )
        except Exception as e:
            self.logger.error(f"Failed to create pipeline: {e}")
            raise

    def _configure_model(self):
        """
        Configure the model for generation.
        """
        try:
            self.aggregator.model = torch.compile(self.aggregator.model)
            self.is_compiled = True
        except Exception as e:
            self.logger.warning(f"Could not compile model with Torch 2.0: {e}")

        self._set_default_generation_config()
        self.logger.info(self.aggregator.model.generation_config.to_json_string())

    def _set_default_generation_config(self):
        """
        Set the default generation configuration for the model.
        """
        self.aggregator.model.generation_config.update(
            **self.GENERIC_CONFIG.to_diff_dict()
        )

    def update_generation_config(self, **kwargs):
        """
        Update the generation configuration with the specified parameters.

        Args:
            **kwargs: The parameters to update in the generation configuration.
        """
        self.logger.info(f"Updating generation config with {pp.pformat(kwargs)}")
        self.aggregator.model.generation_config.update(**kwargs)

    def get_generation_config(self) -> dict:
        """
        Get the current generation configuration.

        Returns:
            dict: The current generation configuration.
        """
        return self.aggregator.model.generation_config.to_dict()

    def update_loglevel(self, level: str = "INFO"):
        """
        Update the log level.

        Args:
            level (str): The log level to set. Defaults to "INFO".
        """
        self.logger.setLevel(level)

    def infer_aggregate(
        self,
        text_list: list,
        instruction: str = None,  # kept for backward compatibility but not used
        **kwargs,
    ) -> str:
        """
        infer_aggregate - infers a consolidated summary from a list of texts.

        Args:
            text_list (list): The texts to summarize.
            instruction (str): Not used by this model, kept for compatibility.
            **kwargs: Additional parameters to update in the generation configuration.

        Returns:
            The generated summary.
        """
        joined_text = "\n\n".join(text_list)
        if kwargs:
            self.update_generation_config(**kwargs)
        st = time.perf_counter()
        self.logger.info(f"inference on {len(text_list)} texts ...")
        result = self.aggregator(
            joined_text,
            generation_config=self.aggregator.model.generation_config,
        )[0]["generated_text"]
        self.logger.info(f"Done. runtime:\t{round(time.perf_counter() - st, 2)}s")
        self.logger.info(
            f"Input tokens:\t{self.count_tokens(joined_text)}. Output tokens:\t{self.count_tokens(result)}"
        )
        self.logger.debug(f"Generated text:\n{result}")

        return result

    def count_tokens(self, text: str) -> int:
        """count the number of tokens in a text"""
        return (
            len(self.aggregator.tokenizer.encode(text, truncation=False, padding=False))
            if text
            else 0
        )
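
As a quick illustration, here is a minimal sketch of driving `BatchAggregator` outside the app. The input strings are placeholders, `force_cpu=True` is just one option (it defaults to False), and the first call downloads the map-reduce model from the Hub:

```python
# minimal sketch: aggregate two summary chunks with the class defined above
from aggregate import BatchAggregator

aggregator = BatchAggregator(force_cpu=True)  # CPU keeps the sketch portable
aggregator.update_generation_config(num_beams=8)  # optional: override a default
combined = aggregator.infer_aggregate(
    ["Summary of chapter one ...", "Summary of chapter two ..."]  # placeholder chunks
)
print(combined)
```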
app.py
ADDED
@@ -0,0 +1,666 @@
"""
app.py - the main module for the gradio app for summarization

Usage:
    app.py [-h] [--share] [-m MODEL] [-nb ADD_BEAM_OPTION] [-batch TOKEN_BATCH_OPTION]
        [-level {DEBUG,INFO,WARNING,ERROR}]
Details:
    python app.py --help

Environment Variables:
    USE_TORCH (str): whether to use torch (1) or not (0)
    TOKENIZERS_PARALLELISM (str): whether to use parallelism (true) or not (false)
Optional Environment Variables:
    APP_MAX_WORDS (int): the maximum number of words to use for summarization
    APP_OCR_MAX_PAGES (int): the maximum number of pages to use for OCR
"""

import argparse
import contextlib
import gc
import logging
import os
import pprint as pp
import random
import time
from pathlib import Path

os.environ["USE_TORCH"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
    datefmt="%Y-%b-%d %H:%M:%S",
)

import gradio as gr
import nltk
import torch
from cleantext import clean
from doctr.models import ocr_predictor

from aggregate import BatchAggregator
from pdf2text import convert_PDF_to_Text
from summarize import load_model_and_tokenizer, summarize_via_tokenbatches
from utils import (
    contraction_aware_tokenize,
    extract_batches,
    load_example_filenames,
    remove_stagnant_files,
    remove_stopwords,
    saves_summary,
    textlist2html,
    truncate_word_count,
)

_here = Path(__file__).parent

nltk.download("punkt", force=True, quiet=True)
nltk.download("popular", force=True, quiet=True)

# Constants & Globals
MODEL_OPTIONS = [
    "BEE-spoke-data/pegasus-x-base-synthsumm_open-16k",
    "pszemraj/long-t5-tglobal-base-sci-simplify",
    "pszemraj/long-t5-tglobal-base-16384-book-summary",
    "pszemraj/long-t5-tglobal-base-summary-souffle-16384-loD",
    "pszemraj/pegasus-x-large-book_synthsumm",
    "pszemraj/pegasus-x-large-book-summary",
]  # models users can choose from
BEAM_OPTIONS = [2, 3, 4]  # beam sizes users can choose from
TOKEN_BATCH_OPTIONS = [
    1024,
    1536,
    2048,
    2560,
    3072,
]  # token batch sizes users can choose from

SUMMARY_PLACEHOLDER = "<p><em>Output will appear below:</em></p>"
AGGREGATE_MODEL = "pszemraj/bart-large-summary-map-reduce"  # map-reduce model

# if duplicating space: uncomment these lines to adjust the limits
# os.environ["APP_MAX_WORDS"] = str(2048)  # set the max words to 2048
# os.environ["APP_OCR_MAX_PAGES"] = str(40)  # set the max pages to 40
# os.environ["APP_AGG_FORCE_CPU"] = str(1)  # force cpu for aggregation

aggregator = BatchAggregator(
    AGGREGATE_MODEL, force_cpu=os.environ.get("APP_AGG_FORCE_CPU", False)
)


def aggregate_text(
    summary_text: str,
    text_file: gr.File = None,
) -> str:
    """
    Aggregate the text from the batches.

    NOTE: you should probably include the BatchAggregator object as a fn arg if using this code

    :param summary_text: The batches to aggregate, in html format
    :param text_file: The text file to append the aggregate summary to
    :return: The aggregate summary in html format
    """
    if summary_text is None or summary_text == SUMMARY_PLACEHOLDER:
        logging.error("No text provided. Make sure a summary has been generated first.")
        return "Error: No text provided. Make sure a summary has been generated first."

    try:
        extracted_batches = extract_batches(summary_text)
    except Exception as e:
        logging.info(summary_text)
        logging.info(f"the batches html is: {type(summary_text)}")
        return f"Error: unable to extract batches - check input: {e}"
    if not extracted_batches:
        logging.error("unable to extract batches - check input")
        return "Error: unable to extract batches - check input"

    out_path = None
    if text_file is not None:
        out_path = text_file.name  # assuming name attribute stores the file path

    content_batches = [batch["content"] for batch in extracted_batches]
    full_summary = aggregator.infer_aggregate(content_batches)

    # if a path that exists is provided, append the summary with markdown formatting
    if out_path:
        out_path = Path(out_path)

        try:
            with open(out_path, "a", encoding="utf-8") as f:
                f.write("\n\n## Aggregate Summary\n\n")
                f.write(
                    "- This is an instruction-based LLM aggregation of the previous 'summary batches'.\n"
                )
                f.write(f"- Aggregation model: {aggregator.model_name}\n\n")
                f.write(f"{full_summary}\n\n")
            logging.info(f"Updated {out_path} with aggregate summary")
        except Exception as e:
            logging.error(f"unable to update {out_path} with aggregate summary: {e}")

    full_summary_html = f"""
    <div style="
        margin-bottom: 20px;
        font-size: 18px;
        line-height: 1.5em;
        color: #333;
    ">
    <h2 style="font-size: 22px; color: #555;">Aggregate Summary:</h2>
    <p style="white-space: pre-line;">{full_summary}</p>
    </div>
    """
    return full_summary_html


def predict(
    input_text: str,
    model_name: str,
    token_batch_length: int = 1024,
    empty_cache: bool = True,
    **settings,
) -> list:
    """
    predict - helper fn to support multiple models for summarization at once

    :param str input_text: the input text to summarize
    :param str model_name: model name to use
    :param int token_batch_length: the length of the token batches to use
    :param bool empty_cache: whether to empty the cache before loading a new model
    :return: list of dicts with keys "summary" and "score"
    """
    if torch.cuda.is_available() and empty_cache:
        torch.cuda.empty_cache()

    model, tokenizer = load_model_and_tokenizer(model_name)
    summaries = summarize_via_tokenbatches(
        input_text,
        model,
        tokenizer,
        batch_length=token_batch_length,
        **settings,
    )

    del model
    del tokenizer
    gc.collect()

    return summaries


def proc_submission(
    input_text: str,
    model_name: str,
    num_beams: int,
    token_batch_length: int,
    length_penalty: float,
    repetition_penalty: float,
    no_repeat_ngram_size: int,
    predrop_stopwords: bool,
    max_input_length: int = 6144,
):
    """
    proc_submission - a helper function for the gradio module to process submissions

    Args:
        input_text (str): the input text to summarize
        model_name (str): the hf model tag of the model to use
        num_beams (int): the number of beams to use
        token_batch_length (int): the length of the token batches to use
        length_penalty (float): the length penalty to use
        repetition_penalty (float): the repetition penalty to use
        no_repeat_ngram_size (int): the no repeat ngram size to use
        predrop_stopwords (bool): whether to pre-drop stopwords before truncating/summarizing
        max_input_length (int, optional): the maximum input length to use. Defaults to 6144.

    Note:
        the max_input_length is set to 6144 by default, but can be changed by setting the
        environment variable APP_MAX_WORDS to a different value.

    Returns:
        tuple (4): the html status message, the summary, the scores text, and the saved summary file
    """

    remove_stagnant_files()  # clean up old files
    settings = {
        "length_penalty": float(length_penalty),
        "repetition_penalty": float(repetition_penalty),
        "no_repeat_ngram_size": int(no_repeat_ngram_size),
        "encoder_no_repeat_ngram_size": 4,
        "num_beams": int(num_beams),
        "min_length": 4,
        "max_length": int(token_batch_length // 4),
        "early_stopping": True,
        "do_sample": False,
    }
    max_input_length = int(os.environ.get("APP_MAX_WORDS", max_input_length))
    logging.info(
        f"max_input_length set to: {max_input_length}. pre-drop stopwords: {predrop_stopwords}"
    )

    st = time.perf_counter()
    history = {}
    cln_text = clean(input_text, lower=False)
    parsed_cln_text = remove_stopwords(cln_text) if predrop_stopwords else cln_text
    logging.info(
        f"pre-truncation word count: {len(contraction_aware_tokenize(parsed_cln_text))}"
    )
    truncation_validated = truncate_word_count(
        parsed_cln_text, max_words=max_input_length
    )

    if truncation_validated["was_truncated"]:
        model_input_text = truncation_validated["processed_text"]
        # create elaborate HTML warning
        input_wc = len(contraction_aware_tokenize(parsed_cln_text))
        msg = f"""
        <div style="background-color: #FFA500; color: white; padding: 20px;">
        <h3>Warning</h3>
        <p>Input text was truncated to {max_input_length} words. That's about {100*max_input_length/input_wc:.2f}% of the original text.</p>
        <p>Dropping stopwords is set to {predrop_stopwords}. If this is not what you intended, please validate the advanced settings.</p>
        </div>
        """
        logging.warning(msg)
        history["WARNING"] = msg
    else:
        model_input_text = truncation_validated["processed_text"]
        msg = None

    if len(input_text) < 50:
        # this is essentially a different case from the above
        msg = f"""
        <div style="background-color: #880808; color: white; padding: 20px;">
        <br>
        <img src="https://i.imgflip.com/7kadd9.jpg" alt="no text">
        <br>
        <h3>Error</h3>
        <p>Input text is too short to summarize. Detected {len(input_text)} characters.
        Please load text by selecting an example from the dropdown menu or by pasting text into the text box.</p>
        </div>
        """
        logging.warning(msg)
        logging.warning("RETURNING EMPTY STRING")
        history["WARNING"] = msg

        return msg, "<strong>No summary generated.</strong>", "", []

    _summaries = predict(
        input_text=model_input_text,
        model_name=model_name,
        token_batch_length=token_batch_length,
        **settings,
    )
    sum_text = [s["summary"][0].strip() + "\n" for s in _summaries]
    sum_scores = [
        f" - Batch Summary {i}: {round(s['summary_score'], 4)}"
        for i, s in enumerate(_summaries)
    ]

    full_summary = textlist2html(sum_text)
    history["Summary Scores"] = "<br><br>"
    scores_out = "\n".join(sum_scores)
    rt = round((time.perf_counter() - st) / 60, 2)
    logging.info(f"Runtime: {rt} minutes")
    html = ""
    html += f"<p>Runtime: {rt} minutes with model: {model_name}</p>"
    if msg is not None:
        html += msg

    html += ""

    settings["remove_stopwords"] = predrop_stopwords
    settings["model_name"] = model_name
    saved_file = saves_summary(summarize_output=_summaries, outpath=None, **settings)
    return html, full_summary, scores_out, saved_file


def load_single_example_text(
    example_path: str or Path,
    max_pages: int = 20,
) -> str:
    """
    load_single_example_text - loads a single example text file

    :param str or Path example_path: name of the example to load
    :param int max_pages: the maximum number of pages to load from a PDF
    :return str: the text of the example
    """
    global name_to_path, ocr_model
    full_ex_path = name_to_path[example_path]
    full_ex_path = Path(full_ex_path)
    if full_ex_path.suffix in [".txt", ".md"]:
        with open(full_ex_path, "r", encoding="utf-8", errors="ignore") as f:
            raw_text = f.read()
        text = clean(raw_text, lower=False)
    elif full_ex_path.suffix == ".pdf":
        logging.info(f"Loading PDF file {full_ex_path}")
        max_pages = int(os.environ.get("APP_OCR_MAX_PAGES", max_pages))
        logging.info(f"max_pages set to: {max_pages}")
        conversion_stats = convert_PDF_to_Text(
            full_ex_path,
            ocr_model=ocr_model,
            max_pages=max_pages,
        )
        text = conversion_stats["converted_text"]
    else:
        logging.error(f"Unknown file type {full_ex_path.suffix}")
        text = "ERROR - check example path"

    return text


def load_uploaded_file(file_obj, max_pages: int = 20, lower: bool = False) -> str:
    """
    load_uploaded_file - loads a file uploaded by the user

    :param file_obj (POTENTIALLY list): Gradio file object inside a list
    :param int max_pages: the maximum number of pages to load from a PDF
    :param bool lower: whether to lowercase the text
    :return str: the text of the file
    """
    global ocr_model
    logger = logging.getLogger(__name__)
    # check if mysterious file object is a list
    if isinstance(file_obj, list):
        file_obj = file_obj[0]
    file_path = Path(file_obj.name)
    try:
        logger.info(f"Loading file:\t{file_path}")
        if file_path.suffix in [".txt", ".md"]:
            with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
                raw_text = f.read()
            text = clean(raw_text, lower=lower)
        elif file_path.suffix == ".pdf":
            logger.info(f"loading a PDF file: {file_path.name}")
            max_pages = int(os.environ.get("APP_OCR_MAX_PAGES", max_pages))
            logger.info(f"max_pages is: {max_pages}. Starting conversion...")
            conversion_stats = convert_PDF_to_Text(
                file_path,
                ocr_model=ocr_model,
                max_pages=max_pages,
            )
            text = conversion_stats["converted_text"]
        else:
            logger.error(f"Unknown file type:\t{file_path.suffix}")
            text = "ERROR - check file - unknown file type. PDF, TXT, and MD are supported."

        return text
    except Exception as e:
        logger.error(f"Trying to load file:\t{file_path},\nerror:\t{e}")
        return f"Error: Could not read file {file_path.name}. Make sure it is a PDF, TXT, or MD file."


def parse_args():
    """arguments for the command line interface"""
    parser = argparse.ArgumentParser(
        description="Document Summarization with Long-Document Transformers - Demo",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        epilog="Runs a local-only web UI to summarize documents. Pass --share for a public link to share.",
    )

    parser.add_argument(
        "--share",
        dest="share",
        action="store_true",
        help="Create a public link to share",
    )
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        default=None,
        help=f"Add a custom model to the list of models: {pp.pformat(MODEL_OPTIONS, compact=True)}",
    )
    parser.add_argument(
        "-nb",
        "--add_beam_option",
        type=int,
        default=None,
        help=f"Add a beam search option to the demo UI options, default: {pp.pformat(BEAM_OPTIONS, compact=True)}",
    )
    parser.add_argument(
        "-batch",
        "--token_batch_option",
        type=int,
        default=None,
        help=f"Add a token batch size to the demo UI options, default: {pp.pformat(TOKEN_BATCH_OPTIONS, compact=True)}",
    )
    parser.add_argument(
        "-max_agg",
        "-2x",
        "--aggregator_beam_boost",
        dest="aggregator_beam_boost",
        action="store_true",
        help="Double the number of beams for the aggregator during beam search",
    )
    parser.add_argument(
        "-level",
        "--log_level",
        type=str,
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Set the logging level",
    )

    return parser.parse_args()


if __name__ == "__main__":
    """main - the main function of the app"""
    logger = logging.getLogger(__name__)
    args = parse_args()
    logger.setLevel(args.log_level)
    logger.info(f"args: {pp.pformat(args.__dict__, compact=True)}")

    # add any custom options
    if args.model is not None:
        logger.info(f"Adding model {args.model} to the list of models")
        MODEL_OPTIONS.append(args.model)
    if args.add_beam_option is not None:
        logger.info(f"Adding beam search option {args.add_beam_option} to the list")
        BEAM_OPTIONS.append(args.add_beam_option)
    if args.token_batch_option is not None:
        logger.info(f"Adding token batch option {args.token_batch_option} to the list")
        TOKEN_BATCH_OPTIONS.append(args.token_batch_option)

    if args.aggregator_beam_boost:
        logger.info("Doubling aggregator num_beams")
        _agg_cfg = aggregator.get_generation_config()
        _agg_cfg["num_beams"] = _agg_cfg["num_beams"] * 2
        aggregator.update_generation_config(**_agg_cfg)

    logger.info("Loading OCR model")
    with contextlib.redirect_stdout(None):
        ocr_model = ocr_predictor(
            "db_resnet50",
            "crnn_mobilenet_v3_large",
            pretrained=True,
            assume_straight_pages=True,
        )

    # load the examples
    name_to_path = load_example_filenames(_here / "examples")
    logger.info(f"Loaded {len(name_to_path)} examples")

    demo = gr.Blocks(title="Document Summarization")
    _examples = list(name_to_path.keys())
    logger.info("Starting app instance")
    with demo:
        gr.Markdown(
            """# Document Summarization with Long-Document Transformers

            An example use case for fine-tuned long document transformers. Model(s) are trained on [book summaries](https://hf.co/datasets/kmfoda/booksum). Architectures [in this demo](https://hf.co/spaces/pszemraj/document-summarization) are [LongT5-base](https://hf.co/pszemraj/long-t5-tglobal-base-16384-book-summary) and [Pegasus-X-Large](https://hf.co/pszemraj/pegasus-x-large-book-summary).

            **Want more performance?** Run this demo from a free [Google Colab GPU](https://colab.research.google.com/gist/pszemraj/52f67cf7326e780155812a6a1f9bb724/document-summarization-on-gpu.ipynb)
            """
        )
        with gr.Column():
            gr.Markdown(
                """## Load Inputs & Select Parameters

                Enter/paste text below, or upload a file. Pick a model & adjust params (_optional_), and press **Summarize!**

                See [the guide doc](https://gist.github.com/pszemraj/722a7ba443aa3a671b02d87038375519) for details.
                """
            )
            with gr.Row():
                with gr.Column(variant="compact"):
                    model_name = gr.Dropdown(
                        choices=MODEL_OPTIONS,
                        value=MODEL_OPTIONS[0],
                        label="Model Name",
                    )
                    num_beams = gr.Radio(
                        choices=BEAM_OPTIONS,
                        value=BEAM_OPTIONS[len(BEAM_OPTIONS) // 2],
                        label="Beam Search: # of Beams",
                    )
                    load_examples_button = gr.Button(
                        "Load Example in Dropdown",
                    )
                    load_file_button = gr.Button("Upload & Process File")
                with gr.Column(variant="compact"):
                    example_name = gr.Dropdown(
                        _examples,
                        label="Examples",
                        value=random.choice(_examples),
                    )
                    uploaded_file = gr.File(
                        label="File Upload",
                        file_count="single",
                        file_types=[".txt", ".md", ".pdf"],
                        type="filepath",
                    )
            with gr.Row():
                input_text = gr.Textbox(
                    lines=4,
                    max_lines=8,
                    label="Text to Summarize",
                    placeholder="Enter text to summarize, the text will be cleaned and truncated on Spaces. Narrative, academic (both papers and lecture transcription), and article text work well. May take a bit to generate depending on the input text :)",
                )
        with gr.Column():
            gr.Markdown("## Generate Summary")
            with gr.Row():
                summarize_button = gr.Button(
                    "Summarize!",
                    variant="primary",
                )
                gr.Markdown(
                    "_Summarization should take ~1-2 minutes for most settings, but may extend up to 5-10 minutes in some scenarios._"
                )
            output_text = gr.HTML("<em>Output will appear below:</em>")
        with gr.Column():
            gr.Markdown("### Results & Scores")
            with gr.Row():
                with gr.Column(variant="compact"):
                    gr.Markdown(
                        "Download the summary as a text file, with parameters and scores."
                    )
                    text_file = gr.File(
                        label="Download as Text File",
                        file_count="single",
                        type="filepath",
                        interactive=False,
                    )
                with gr.Column(variant="compact"):
                    gr.Markdown(
                        "Scores **roughly** represent the summary quality as a measure of the model's 'confidence'. Less-negative numbers (closer to 0) are better."
                    )
                    summary_scores = gr.Textbox(
                        label="Summary Scores",
                        placeholder="Summary scores will appear here",
                    )
        with gr.Column(variant="panel"):
            gr.Markdown("### **Summary Output**")
            summary_text = gr.HTML(
                label="Summary",
                value="<i>Summary will appear here!</i>",
            )
        with gr.Column():
            gr.Markdown("### **Aggregate Summary Batches**")
            with gr.Row():
                aggregate_button = gr.Button(
                    "Aggregate!",
                    variant="primary",
                )
                gr.Markdown(
                    f"""Aggregate the above batches into a cohesive summary.
                    - A secondary instruct-tuned LM consolidates info
                    - Current model: [{AGGREGATE_MODEL}](https://hf.co/{AGGREGATE_MODEL})
                    """
                )
            with gr.Column(variant="panel"):
                aggregated_summary = gr.HTML(
                    label="Aggregate Summary",
                    value="<i>Aggregate summary will appear here!</i>",
                )

        with gr.Column():
            gr.Markdown(
                """### Advanced Settings

                Refer to [the guide doc](https://gist.github.com/pszemraj/722a7ba443aa3a671b02d87038375519) for what these are, and how they impact _quality_ and _speed_.
                """
            )
            with gr.Row():
                length_penalty = gr.Slider(
                    minimum=0.3,
                    maximum=1.1,
                    label="length penalty",
                    value=0.7,
                    step=0.05,
                )
                token_batch_length = gr.Radio(
                    choices=TOKEN_BATCH_OPTIONS,
                    label="token batch length",
                    # select median option
                    value=TOKEN_BATCH_OPTIONS[len(TOKEN_BATCH_OPTIONS) // 2],
                )

            with gr.Row():
                repetition_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=5.0,
                    label="repetition penalty",
                    value=1.5,
                    step=0.1,
                )
                no_repeat_ngram_size = gr.Radio(
                    choices=[2, 3, 4, 5],
                    label="no repeat ngram size",
                    value=3,
                )
                predrop_stopwords = gr.Checkbox(
                    label="Drop Stopwords (Pre-Truncation)",
                    value=False,
                )

        load_examples_button.click(
            fn=load_single_example_text, inputs=[example_name], outputs=[input_text]
        )

        load_file_button.click(
            fn=load_uploaded_file, inputs=uploaded_file, outputs=[input_text]
        )

        summarize_button.click(
            fn=proc_submission,
            inputs=[
                input_text,
                model_name,
                num_beams,
                token_batch_length,
                length_penalty,
                repetition_penalty,
                no_repeat_ngram_size,
                predrop_stopwords,
            ],
            outputs=[output_text, summary_text, summary_scores, text_file],
        )
        aggregate_button.click(
            fn=aggregate_text,
            inputs=[summary_text, text_file],
            outputs=[aggregated_summary],
        )
    demo.launch(share=args.share, debug=True)
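
For reference, here is a minimal sketch of the same summarization path `predict()` uses, without the Gradio UI. The model tag is one of the `MODEL_OPTIONS` above, the input string is a placeholder, and the result keys (`summary`, `summary_score`) are assumed to match how `proc_submission()` reads them:

```python
# minimal sketch of the predict() path above, without the Gradio UI
from summarize import load_model_and_tokenizer, summarize_via_tokenbatches

model, tokenizer = load_model_and_tokenizer(
    "pszemraj/long-t5-tglobal-base-16384-book-summary"  # one of MODEL_OPTIONS
)
results = summarize_via_tokenbatches(
    "Some long document text ...",  # placeholder input
    model,
    tokenizer,
    batch_length=1024,
    num_beams=3,
    no_repeat_ngram_size=3,
)
for r in results:
    print(r["summary"][0], r["summary_score"])  # keys as read in proc_submission()
```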
gitattributes
ADDED
@@ -0,0 +1,31 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
gitignore
ADDED
@@ -0,0 +1,29 @@
# logs
*.log
*LOGFILE*

# output files need to be force-added
*.csv
*.png
*.jpg
*.jpeg
*.pkl
*.xlsx
*.txt

# cache
*__pycache__/
*.pyc

# reports folder - needs to be force-added
*reports/

# scratch files and folders

*scratch*
*scratch/

# notebooks

*notebooks/
*.ipynb
pdf2text.py
ADDED
@@ -0,0 +1,346 @@
# -*- coding: utf-8 -*-
"""
pdf2text.py - convert pdf files to text files using OCR
"""
import logging
import os
import re
import shutil
import time
from datetime import date
from os.path import join
from pathlib import Path

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
    datefmt="%m/%d/%Y %I:%M:%S",
)


os.environ["USE_TORCH"] = "1"

from cleantext import clean
from doctr.io import DocumentFile
from doctr.models import ocr_predictor
from spellchecker import SpellChecker


def simple_rename(filepath, target_ext=".txt"):
    """simple_rename - get a new str to rename a file"""
    _fp = Path(filepath)
    basename = _fp.stem
    return f"OCR_{basename}_{target_ext}"


def rm_local_text_files(name_contains="RESULT_"):
    """
    rm_local_text_files - remove local text files
    """
    files = [
        f
        for f in Path.cwd().iterdir()
        if f.is_file() and f.suffix == ".txt" and name_contains in f.name
    ]
    logging.info(f"removing {len(files)} text files")
    for f in files:
        os.remove(f)
    logging.info("done")


def corr(
    s: str,
    add_space_when_numerics=False,
    exceptions=["e.g.", "i.e.", "etc.", "cf.", "vs.", "p."],
) -> str:
    """corrects spacing in a string

    Args:
        s (str): the string to correct
        add_space_when_numerics (bool, optional): add a space when a period is between two numbers, e.g. 5.73. Defaults to False.
        exceptions (list, optional): do not change these substrings. Defaults to ['e.g.', 'i.e.', 'etc.', 'cf.', 'vs.', 'p.'].

    Returns:
        str: the corrected string
    """
    if add_space_when_numerics:
        s = re.sub(r"(\d)\.(\d)", r"\1. \2", s)

    s = re.sub(r"\s+", " ", s)
    s = re.sub(r'\s([?.!"](?:\s|$))', r"\1", s)

    # fix space before apostrophe
    s = re.sub(r"\s\'", r"'", s)
    # fix space after apostrophe
    s = re.sub(r"'\s", r"'", s)
    # fix space before comma
    s = re.sub(r"\s,", r",", s)

    for e in exceptions:
        expected_sub = re.sub(r"\s", "", e)
        s = s.replace(expected_sub, e)

    return s


def fix_punct_spaces(string: str) -> str:
    """
    fix_punct_spaces - fix spaces around punctuation

    :param str string: input string
    :return str: string with spaces fixed
    """

    fix_spaces = re.compile(r"\s*([?!.,]+(?:\s+[?!.,]+)*)\s*")
    string = fix_spaces.sub(lambda x: "{} ".format(x.group(1).replace(" ", "")), string)
    string = string.replace(" ' ", "'")
    string = string.replace(' " ', '"')
    return string.strip()


def clean_OCR(ugly_text: str) -> str:
    """
    clean_OCR - clean up the OCR text

    :param str ugly_text: input text to be cleaned
    :return str: cleaned text
    """
    # Remove all the newlines.
    cleaned_text = ugly_text.replace("\n", " ")
    # Remove all the tabs.
    cleaned_text = cleaned_text.replace("\t", " ")
    # Remove all the double spaces.
    cleaned_text = cleaned_text.replace("  ", " ")
    # Remove all the spaces at the beginning of the text.
    cleaned_text = cleaned_text.lstrip()
    # remove all instances of "- " and " - "
    cleaned_text = cleaned_text.replace("- ", "")
    cleaned_text = cleaned_text.replace(" -", "")
    return fix_punct_spaces(cleaned_text)


def move2completed(
    from_dir, filename, new_folder: str = "completed", verbose: bool = False
):
    """
    move2completed - move a file to a new folder
    """
    old_filepath = join(from_dir, filename)

    new_filedirectory = join(from_dir, new_folder)

    if not os.path.isdir(new_filedirectory):
        os.mkdir(new_filedirectory)
        if verbose:
            print("created new directory for files at: \n", new_filedirectory)
    new_filepath = join(new_filedirectory, filename)

    try:
        shutil.move(old_filepath, new_filepath)
        logging.info("successfully moved the file {} to */completed.".format(filename))
    except Exception:
        logging.info(
            "ERROR! unable to move file to \n{}. Please investigate".format(
                new_filepath
            )
        )


custom_replace_list = {
    "t0": "to",
    "'$": "'s",
    ",,": ", ",
    "_ ": " ",
    " '": "'",
}

replace_corr_exceptions = {
    "i. e.": "i.e.",
    "e. g.": "e.g.",
    "e. g": "e.g.",
    " ,": ",",
}


spell = SpellChecker()


def check_word_spelling(word: str) -> bool:
    """
    check_word_spelling - check the spelling of a word

    Args:
        word (str): word to check

    Returns:
        bool: True if word is spelled correctly, False if not
    """

    misspelled = spell.unknown([word])

    return len(misspelled) == 0


def eval_and_replace(text: str, match_token: str = "- ") -> str:
    """
    eval_and_replace - conditionally replace all instances of a substring in a string based on whether the eliminated substring results in a valid word

    Args:
        text (str): text to evaluate
        match_token (str, optional): token to replace. Defaults to "- ".

    Returns:
        str: text with replaced tokens
    """

    if match_token not in text:
        return text
    else:
        while True:
            full_before_text = text.split(match_token, maxsplit=1)[0]
            before_text = [
                char for char in full_before_text.split()[-1] if char.isalpha()
            ]
            before_text = "".join(before_text)
            full_after_text = text.split(match_token, maxsplit=1)[-1]
            after_text = [char for char in full_after_text.split()[0] if char.isalpha()]
            after_text = "".join(after_text)
            full_text = before_text + after_text
            if check_word_spelling(full_text):
                text = full_before_text + full_after_text
            else:
                text = full_before_text + " " + full_after_text
            if match_token not in text:
                break
    return text


def cleantxt_ocr(ugly_text, lower=False, lang: str = "en") -> str:
    """
    cleantxt_ocr - clean text from OCR

    https://pypi.org/project/clean-text/
    Args:
        ugly_text (str): text to clean
        lower (bool, optional): lowercase text. Defaults to False.
        lang (str, optional): language of text. Defaults to "en".

    Returns:
        str: cleaned text
    """

    cleaned_text = clean(
        ugly_text,
        fix_unicode=True,  # fix various unicode errors
        to_ascii=True,  # transliterate to closest ASCII representation
        lower=lower,  # lowercase text
        no_line_breaks=True,  # fully strip line breaks as opposed to only normalizing them
        no_urls=True,  # replace all URLs with a special token
        no_emails=True,  # replace all email addresses with a special token
        no_phone_numbers=True,  # replace all phone numbers with a special token
        no_numbers=False,  # replace all numbers with a special token
        no_digits=False,  # replace all digits with a special token
        no_currency_symbols=False,  # replace all currency symbols with a special token
        no_punct=False,  # remove punctuations
        replace_with_punct="",  # instead of removing punctuations you may replace them
        replace_with_url="this url",
        replace_with_email="this email",
        replace_with_phone_number="this phone number",
        lang=lang,  # set to 'de' for German special handling
    )

    return cleaned_text


def format_ocr_out(OCR_data):
    """format OCR output to text"""
    if isinstance(OCR_data, list):
        text = " ".join(OCR_data)
    else:
        text = str(OCR_data)
    _clean = cleantxt_ocr(text)
    return corr(_clean)


def postprocess(text: str) -> str:
    """to be used after recombining the lines"""

    proc = corr(cleantxt_ocr(text))

    for k, v in custom_replace_list.items():
        proc = proc.replace(str(k), str(v))

    proc = corr(proc)

    for k, v in replace_corr_exceptions.items():
        proc = proc.replace(str(k), str(v))

    return eval_and_replace(proc)


def result2text(result, as_text=False) -> str or list:
    """Convert OCR result to text"""

    full_doc = []
    for i, page in enumerate(result.pages, start=1):
        text = ""
        for block in page.blocks:
            text += "\n\t"
            for line in block.lines:
                for word in line.words:
                    # print(dir(word))
                    text += word.value + " "
        full_doc.append(text)

    return "\n".join(full_doc) if as_text else full_doc


def convert_PDF_to_Text(
    PDF_file,
    ocr_model=None,
    max_pages: int = 20,
) -> dict:
    """
    convert_PDF_to_Text - convert a PDF file to text

    :param str PDF_file: path to PDF file
    :param ocr_model: model to use for OCR, defaults to None (uses the default model)
    :param int max_pages: maximum number of pages to process, defaults to 20
    :return dict: OCR results and metadata, including the converted text
    """
    st = time.perf_counter()
    PDF_file = Path(PDF_file)
    ocr_model = ocr_predictor(pretrained=True) if ocr_model is None else ocr_model
    logging.info(f"starting OCR on {PDF_file.name}")
    doc = DocumentFile.from_pdf(PDF_file)
    truncated = False
    if len(doc) > max_pages:
        logging.warning(
            f"PDF has {len(doc)} pages, which is more than {max_pages}.. truncating"
        )
        doc = doc[:max_pages]
        truncated = True

    # Analyze
    logging.info(f"running OCR on {len(doc)} pages")
    result = ocr_model(doc)
    raw_text = result2text(result)
    proc_text = [format_ocr_out(r) for r in raw_text]
    fin_text = [postprocess(t) for t in proc_text]

    ocr_results = "\n\n".join(fin_text)

    fn_rt = time.perf_counter() - st

    logging.info("OCR complete")

    results_dict = {
        "num_pages": len(doc),
        "runtime": round(fn_rt, 2),
        "date": str(date.today()),
        "converted_text": ocr_results,
        "truncated": truncated,
        "length": len(ocr_results),
    }

    return results_dict
requirements.txt
ADDED
@@ -0,0 +1,12 @@
accelerate
clean-text[gpl]
gradio==5.5.0
natsort
nltk
pyspellchecker
python-doctr[torch]
rapidfuzz==2.13.7
sentencepiece
torch
tqdm
transformers==4.46.2
summarize.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
summarize - a module for summarizing text using a model from the Hugging Face model hub
|
3 |
+
"""
|
4 |
+
import logging
|
5 |
+
import os
|
6 |
+
import pprint as pp
|
7 |
+
|
8 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from tqdm.auto import tqdm
|
12 |
+
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
13 |
+
|
14 |
+
from utils import validate_pytorch2
|
15 |
+
|
16 |
+
|
17 |
+
def load_model_and_tokenizer(model_name: str) -> tuple:
|
18 |
+
"""
|
19 |
+
load_model_and_tokenizer - load a model and tokenizer from a model name/ID on the hub
|
20 |
+
|
21 |
+
:param str model_name: the model name/ID on the hub
|
22 |
+
:return tuple: a tuple containing the model and tokenizer
|
23 |
+
"""
|
24 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
25 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(
|
26 |
+
model_name,
|
27 |
+
use_auth_token=os.environ.get("HF_TOKEN", None),
|
28 |
+
).to(device)
|
29 |
+
model = model.eval()
|
30 |
+
|
31 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
32 |
+
model_name,
|
33 |
+
use_auth_token=os.environ.get("HF_TOKEN", None),
|
34 |
+
)
|
35 |
+
|
36 |
+
logging.info(f"Loaded model {model_name} to {device}")
|
37 |
+
|
38 |
+
if validate_pytorch2():
|
39 |
+
try:
|
40 |
+
logging.info("Compiling model with Torch 2.0")
|
41 |
+
model = torch.compile(model)
|
42 |
+
except Exception as e:
|
43 |
+
logging.warning(f"Could not compile model with Torch 2.0: {e}")
|
44 |
+
else:
|
45 |
+
logging.info("Torch 2.0 not detected, skipping compilation")
|
46 |
+
|
47 |
+
return model, tokenizer
|
48 |
+
|
49 |
+
|


def summarize_and_score(
    ids, mask, model, tokenizer, is_general_attention_model=True, **kwargs
) -> tuple:
    """
    summarize_and_score - given a batch of ids and a mask, return a summary and a score for the summary

    Args:
        ids (): the batch of ids
        mask (): the attention mask for the batch
        model (): the model to use for summarization
        tokenizer (): the tokenizer to use for summarization
        is_general_attention_model (bool, optional): whether the model is a general attention model. Defaults to True.
        **kwargs: any additional arguments to pass to the model
    Returns:
        tuple (str, float): the summary, the score for the summary
    """

    ids = ids[None, :]
    mask = mask[None, :]

    input_ids = ids.to("cuda") if torch.cuda.is_available() else ids
    attention_mask = mask.to("cuda") if torch.cuda.is_available() else mask

    global_attention_mask = torch.zeros_like(attention_mask)
    # put global attention on <s> token
    global_attention_mask[:, 0] = 1

    if is_general_attention_model:
        summary_pred_ids = model.generate(
            input_ids,
            attention_mask=attention_mask,
            output_scores=True,
            return_dict_in_generate=True,
            **kwargs,
        )
    else:
        summary_pred_ids = model.generate(
            input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
            output_scores=True,
            return_dict_in_generate=True,
            **kwargs,
        )
    summary = tokenizer.batch_decode(
        summary_pred_ids.sequences,
        skip_special_tokens=True,
        remove_invalid_values=True,
    )
    score = round(summary_pred_ids.sequences_scores.cpu().numpy()[0], 4)

    return summary, score
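
# Note: transformers only returns `sequences_scores` when generation uses beam
# search, so callers should pass num_beams > 1 via **kwargs.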


def summarize_via_tokenbatches(
    input_text: str,
    model,
    tokenizer,
    batch_length=2048,
    batch_stride=16,
    min_batch_length=512,
    **kwargs,
) -> list:
    """
    summarize_via_tokenbatches - summarize a long string via batches of tokens

    Args:
        input_text (str): the text to summarize
        model (): the model to use for summarization
        tokenizer (): the tokenizer to use for summarization
        batch_length (int, optional): the length of each batch. Defaults to 2048.
        batch_stride (int, optional): the stride of each batch. Defaults to 16. The stride is the number of tokens that overlap between batches.
        min_batch_length (int, optional): the minimum length of each batch. Defaults to 512.

        **kwargs: any additional arguments to pass to the model for inference
    Returns:
        list: a list of dictionaries containing the input tokens, the summary, and the summary score
    """

    logger = logging.getLogger(__name__)
    # log all input parameters
    if batch_length < min_batch_length:
        logger.warning(
            f"batch_length must be at least {min_batch_length}. Setting batch_length to {min_batch_length}"
        )
        batch_length = min_batch_length

    logger.info(f"input parameters:\n{pp.pformat(kwargs)}")
    logger.info(f"batch_length: {batch_length}, batch_stride: {batch_stride}")

    encoded_input = tokenizer(
        input_text,
        padding="max_length",
        truncation=True,
        max_length=batch_length,
        stride=batch_stride,
        return_overflowing_tokens=True,
        add_special_tokens=False,
        return_tensors="pt",
    )

    in_id_arr, att_arr = encoded_input.input_ids, encoded_input.attention_mask
    gen_summaries = []

    pbar = tqdm(total=len(in_id_arr))

    for _id, _mask in zip(in_id_arr, att_arr):
        result, score = summarize_and_score(
            ids=_id,
            mask=_mask,
            model=model,
            tokenizer=tokenizer,
            **kwargs,
        )
        score = round(float(score), 4)
        _sum = {
            "input_tokens": _id,
            "summary": result,
            "summary_score": score,
        }
        gen_summaries.append(_sum)
        logger.debug(f"Score for batch: {score}. num chars: {len(repr(result))}")
        logger.debug(f"Summary:\n\t{result}")
        pbar.update()

    pbar.close()

    return gen_summaries
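
Taken together, a minimal end-to-end sketch of this module (the checkpoint ID and `long_text` are placeholders for illustration):

```python
# sketch: summarize a long document in overlapping token batches
from summarize import load_model_and_tokenizer, summarize_via_tokenbatches

model, tokenizer = load_model_and_tokenizer(
    "pszemraj/long-t5-tglobal-base-16384-book-summary"  # illustrative checkpoint
)
summaries = summarize_via_tokenbatches(
    long_text,  # placeholder: any long input string
    model,
    tokenizer,
    batch_length=2048,
    num_beams=4,  # beam search so sequences_scores is populated
)
for s in summaries:
    print(s["summary"][0], s["summary_score"])
```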
utils.py
ADDED
@@ -0,0 +1,450 @@
"""
utils.py - Utility functions for the project.
"""
import logging
import re
import subprocess
from collections import defaultdict, deque
from datetime import datetime, timedelta
from itertools import combinations
from pathlib import Path
from typing import List, Optional, Union

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    level=logging.INFO,
)

import torch
from natsort import natsorted
from nltk.tokenize import WhitespaceTokenizer, sent_tokenize, word_tokenize
from rapidfuzz import fuzz

STOPWORDS = set(
    "a about above after again all also am an and any are aren't as at back be because been before being below between both but by can't cannot could couldn't did didn't do does doesn't doing don't down during each few for from further had hadn't has hasn't have haven't having he'd he'll he's hence her here here's hers herself him himself his how how's however i'd i'll i'm i've if in into is isn't it's its itself just let's me more moreover most mustn't my myself new nor now of off on once only or other ought our ours ourselves out over own really same shan't she'd she'll she's should shouldn't so some such than that that's the their theirs them themselves then there there's therefore these they they'd they'll they're they've this those through thus to too under until up use used using very was wasn't we we'd we'll we're we've were weren't what what's when when's where where's which while who who's whom why why's with won't would wouldn't you'd you'll you're you've your yours yourself yourselves".split()
)


def contraction_aware_tokenize(text: str) -> List[str]:
    """contraction_aware_tokenize - merges words containing apostrophes as one token."""

    # Tokenize the text using the WhitespaceTokenizer
    tokenizer = WhitespaceTokenizer()
    tokens = tokenizer.tokenize(text)

    merged_tokens = []
    merged_token = ""

    for token in tokens:
        if re.search(r"\w+'\w+", token):
            # Token contains an apostrophe: keep the contraction intact and
            # flush it immediately so consecutive contractions stay separate
            merged_token += token
            merged_tokens.append(merged_token)
            merged_token = ""
        else:
            # no apostrophe, add previous merged token (if any) and current
            if merged_token:
                merged_tokens.append(merged_token)
                merged_token = ""
            merged_tokens.append(token)

    # Add the last merged token (if any)
    if merged_token:
        merged_tokens.append(merged_token)

    return merged_tokens
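
# Example (sketch): NLTK's word_tokenize would split "don't" into "do" + "n't";
# this helper keeps the contraction whole:
#   contraction_aware_tokenize("This don't need fixing.")
#   -> ['This', "don't", 'need', 'fixing.']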


def remove_stopwords(
    text: str, stopwords: List[str] = STOPWORDS, contraction_tokenize: bool = True
) -> str:
    """
    remove_stopwords - Remove stopwords from text.

    :param str text: input text
    :param List[str] stopwords: list of stopwords, defaults to STOPWORDS
    :param bool contraction_tokenize: use custom apostrophe tokenizer, defaults to True
    :return str: text with stopwords removed
    """
    lines = text.split("\n")
    filtered_lines = []

    def fix_commas(text: str) -> str:
        """fixes commas in text to have a space after them"""
        spaced_text = text.replace(",", ", ")
        return spaced_text.replace("  ", " ").strip()

    for line in lines:
        sentences = sent_tokenize(line)
        filtered_sentences = []

        for sentence in sentences:
            # Add space around punctuations for the regex to work correctly, only if they are followed by a letter
            sentence_with_spaces = re.sub(r"([.,!?])(\w)", r"\1 \2", sentence[:-1])

            words = (
                contraction_aware_tokenize(sentence_with_spaces)
                if contraction_tokenize
                else word_tokenize(sentence_with_spaces)
            )

            filtered_words = []
            for word in words:
                if word.lower() not in stopwords:
                    filtered_words.append(word)

            filtered_sentence = " ".join(filtered_words)
            # Restore original spaces around punctuation marks
            filtered_sentence = re.sub(r"([.,!?])\s*", r"\1", filtered_sentence)

            filtered_sentences.append(filtered_sentence + sentence[-1])

        filtered_line = " ".join(filtered_sentences)

        # Replace multiple consecutive whitespaces with a single space
        filtered_line = re.sub(r"\s+", " ", filtered_line)
        filtered_line = fix_commas(filtered_line.strip())

        filtered_lines.append(filtered_line)

    filtered_text = "\n".join(filtered_lines)

    return filtered_text
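
# Example (sketch):
#   remove_stopwords("This is a test of the system.")
#   -> 'test system.'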


def remove_stagnant_files(
    freq: str = "hourly",
    search_path: str = ".",
    substring="DocSumm",
    remove_suffix=".txt",
):
    """
    remove_stagnant_files - Remove files that have not been modified in a certain amount of time.

    :param str freq: frequency of file removal, defaults to "hourly"
    :param str search_path: location to search for files, defaults to "."
    :param str substring: substring to search for in file names, defaults to "DocSumm"
    :param str remove_suffix: suffix of files to remove, defaults to ".txt"
    :raises ValueError: if freq is not one of "hourly", "daily", or "weekly"
    """
    current_time = datetime.now()
    search_path = Path(search_path)

    if freq == "hourly":
        time_threshold = current_time - timedelta(hours=1)
    elif freq == "daily":
        time_threshold = current_time - timedelta(days=1)
    elif freq == "weekly":
        time_threshold = current_time - timedelta(weeks=1)
    else:
        raise ValueError(
            "Invalid frequency. Supported values are 'hourly', 'daily', and 'weekly'."
        )

    files_to_remove = []
    potential_files = [
        f for f in search_path.iterdir() if f.is_file() and f.suffix == remove_suffix
    ]
    logging.info(f"Found {len(potential_files)} files.")
    for candidate in potential_files:
        if (
            candidate.is_file()
            and substring in candidate.name
            and candidate.stat().st_mtime < time_threshold.timestamp()
        ):
            files_to_remove.append(candidate)
            logging.debug(f"File {candidate} last modified at {candidate.stat().st_mtime}")
    logging.info(f"Removing {len(files_to_remove)} files.")
    for file_path in files_to_remove:
        file_path.unlink()
    logging.debug(f"Removed files: {files_to_remove}")
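
# Example (sketch): clear DocSumm output files older than one day
#   remove_stagnant_files(freq="daily", search_path=".")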


def compare_model_size(model_name: str, threshold: int = 500) -> bool:
    """
    compare_model_size - compare string representations of model size to a threshold

    :param str model_name: the model name to compare
    :param int threshold: the threshold to compare against in millions, defaults to 500
    :return: True if the model size is greater than the threshold, False or None otherwise
    """
    pattern = r"(\d+)(M|G|k|b)?"  # param regex

    matches = re.findall(pattern, model_name)
    if not matches:
        return None

    # Extract the parameter count and unit
    parameter_count, unit = matches[-1]
    parameter_count = int(parameter_count)

    # Convert to the standard form (M for million, G for billion, k for thousand)
    if unit == "G" or unit == "b":
        parameter_count *= 1000
    elif unit == "M":
        pass
    elif unit == "k":
        parameter_count /= 1000
    else:
        return None  # Unknown

    return parameter_count > threshold
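
# Examples (sketch; model names are illustrative):
#   compare_model_size("flan-t5-780M")  -> True   (780M > 500M)
#   compare_model_size("t5-3b")         -> True   (3b -> 3000M)
#   compare_model_size("my-model-60M")  -> False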


def validate_pytorch2(torch_version: str = None) -> bool:
    """
    validate_pytorch2 - validate that the PyTorch version is 2.0 or greater

    :param str torch_version: the PyTorch version to validate, defaults to None
    :return: True if the PyTorch version is 2.0 or greater, False otherwise
    """

    torch_version = torch.__version__ if torch_version is None else torch_version

    pattern = r"^2\.\d+(\.\d+)*"

    return True if re.match(pattern, torch_version) else False


def get_timestamp(detailed=False) -> str:
    """
    get_timestamp - get a timestamp for the current time
    :param bool detailed: whether to include seconds and microseconds, defaults to False
    :return: str, the timestamp
    """
    return (
        datetime.now().strftime("%b%d%Y_%H%M%S%f")
        if detailed
        else datetime.now().strftime("%b%d%Y_%H")
    )


def truncate_word_count(text: str, max_words=1024) -> dict:
    """
    truncate_word_count - truncate a text to a maximum number of words
    :param str text: the text to truncate
    :param int max_words: the maximum number of words to keep, defaults to 1024
    :return: dict, the processed text
    """
    words = contraction_aware_tokenize(str(text))
    processed = {}
    if len(words) > max_words:
        processed["was_truncated"] = True
        processed["processed_text"] = " ".join(words[:max_words])
    else:
        processed["was_truncated"] = False
        processed["processed_text"] = text
    return processed
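
# Example (sketch):
#   truncate_word_count("one two three", max_words=2)
#   -> {'was_truncated': True, 'processed_text': 'one two'}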


def load_examples(src, filetypes=[".txt", ".pdf"]):
    """
    load_examples - a helper function for the gradio module to load examples
    :param str src: the path to the examples
    """
    src = Path(src)
    src.mkdir(exist_ok=True)

    pdf_url = (
        "https://www.dropbox.com/s/y92xy7o5qb88yij/all_you_need_is_attention.pdf?dl=1"
    )
    subprocess.run(["wget", pdf_url, "-O", src / "all_you_need_is_attention.pdf"])
    examples = [f for f in src.iterdir() if f.suffix in filetypes]
    examples = natsorted(examples)
    # load the examples into a list
    text_examples = []
    for example in examples:
        with open(example, "r") as f:
            text = f.read()
            text_examples.append([text, "base", 2, 1024, 0.7, 3.5, 3])

    return text_examples


def load_example_filenames(example_path: Union[str, Path]):
    """
    load_example_filenames - a helper function for the gradio module to load examples
    Returns:
        dict, the examples (filename:full path)
    """
    example_path = Path(example_path)
    # load the examples into a list
    examples = {f.name: f for f in example_path.glob("*.txt")}
    return examples
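
# Example (sketch; the directory name is illustrative):
#   load_example_filenames("examples/")
#   -> {'example1.txt': PosixPath('examples/example1.txt'), ...}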


def textlist2html(text_batches: List[str]) -> str:
    """textlist2html - convert a list of text summaries into a single HTML string"""
    # Step 1: Generate each summary batch as a string of HTML
    formatted_batches = [
        f"""
        <div style="
            margin-bottom: 20px;
            font-size: 18px;
            line-height: 1.5em;
            color: #333;
        ">
            <h2 style="font-size: 22px; color: #555;">Batch {i}:</h2>
            <p style="white-space: pre-line;">{s}</p>
        </div>
        """
        for i, s in enumerate(text_batches, start=1)
    ]

    # Step 2: Join all the summary batches together into one string
    joined_batches = "".join(formatted_batches)

    # Step 3: Wrap the summary string in a larger div with background color, border, and padding
    text_html_block = f"""
    <div style="
        border: 1px solid #ddd;
        border-radius: 5px;
        padding: 20px;
    ">
        {joined_batches}
    </div>
    """

    return text_html_block


def extract_batches(html_string: str, pattern=None, flags=None) -> list:
    """
    Extract batches of text from an HTML string.

    Args:
        html_string (str): The HTML string to extract batches from.
        pattern (str, optional): The regular expression pattern to use. Defaults to a pattern that matches batches in the format provided.
        flags (int, optional): The flags to use with the regular expression. Defaults to re.DOTALL.

    Returns:
        list: A list of dictionaries where each dictionary represents a batch and has 'title' and 'content' keys.
    """
    # Set default pattern if none provided
    if pattern is None:
        pattern = r'<h2 style="font-size: 22px; color: #555;">(.*?)</h2>\s*<p style="white-space: pre-line;">(.*?)</p>'

    # Set default flags if none provided
    if flags is None:
        flags = re.DOTALL

    try:
        # Find all matches in the string
        matches = re.findall(pattern, html_string, flags)

        # Convert matches to a list of dictionaries
        batches = [
            {"title": title.strip(), "content": content.strip()}
            for title, content in matches
        ]

        return batches
    except re.error as e:
        logging.error(f"An error occurred while trying to extract batches: {e}")
        return []
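
# Example (sketch): the two helpers round-trip
#   html = textlist2html(["first summary", "second summary"])
#   extract_batches(html)
#   -> [{'title': 'Batch 1:', 'content': 'first summary'},
#       {'title': 'Batch 2:', 'content': 'second summary'}]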


def extract_keywords(
    text: str, num_keywords: int = 3, window_size: int = 5, kw_max_len: int = 20
) -> List[str]:
    """
    Extracts keywords from a text using a simplified TextRank algorithm.

    Args:
        text: The text to extract keywords from.
        num_keywords: The number of keywords to extract. Default: 3
        window_size: The number of words considered for co-occurrence. Default: 5
        kw_max_len: The maximum length of a keyword (truncate longer keywords to max). Default: 20
    Returns:
        A list of strings, where each string is a keyword extracted from the input text.
    """
    logger = logging.getLogger(__name__)
    # Remove stopwords and tokenize the text into words
    words = [
        word
        for word in re.findall(r"\b\w{3,}\b", text.lower())
        if word not in STOPWORDS
    ]

    # Create a graph of word co-occurrences within a moving window of words
    cooccur = defaultdict(lambda: defaultdict(int))
    deque_words = deque(maxlen=window_size)
    for word in words:
        for w1, w2 in combinations(deque_words, 2):
            cooccur[w1][w2] += 1
            cooccur[w2][w1] += 1
        deque_words.append(word)

    # Assign scores to words using a simplified TextRank algorithm
    scores = defaultdict(float)
    for _ in range(10):
        new_scores = defaultdict(float)
        for word, co_words in cooccur.items():
            new_scores[word] = 0.15 + 0.85 * sum(
                cooccur[word][other] / sum(cooccur[other].values()) * scores[other]
                for other in co_words
            )
        scores = new_scores

    # Sort the words by score and return the top num_keywords keywords
    keywords = sorted(scores, key=scores.get, reverse=True)[:num_keywords]
    logger.debug(f"All keywords: {keywords}")
    # Use fuzzy matching to remove similar keywords
    final_keywords = []
    for keyword in keywords:
        if not any(fuzz.ratio(keyword, other) > 70 for other in final_keywords):
            final_keywords.append(keyword[:kw_max_len])
    logger.debug(f"Keywords (max len. {kw_max_len}):\t{final_keywords}")
    return final_keywords
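
# Example (sketch): exact keywords depend on the co-occurrence statistics of
# the input, so the call below is illustrative rather than deterministic:
#   extract_keywords("transformers use attention to weigh token relevance", num_keywords=2)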


def saves_summary(
    summarize_output,
    outpath: Optional[Union[str, Path]] = None,
    add_signature=True,
    **kwargs,
) -> str:
    """
    saves_summary - save the summary generated from summarize_via_tokenbatches() to a text file

    summarize_output: output from summarize_via_tokenbatches()
    outpath: path to the output file
    add_signature: whether to add a signature to the output file
    kwargs: additional keyword arguments to include in the output file
    """
    logger = logging.getLogger(__name__)
    sum_text = [f"{s['summary'][0]}\n" for s in summarize_output]
    sum_scores = [f"\n - {round(s['summary_score'],4)}" for s in summarize_output]
    scores_text = "\n".join(sum_scores)
    full_summary = "\n".join(sum_text)

    keywords = "_".join(extract_keywords(full_summary, kw_max_len=4))
    logger.debug(f"kw:\t{keywords}")
    outpath = (
        Path.cwd() / f"DocSumm_{keywords}_{get_timestamp()}.txt"
        if outpath is None
        else Path(outpath)
    )
    logger.info(f"Saving summary to:\t{outpath.name}")
    with open(
        outpath,
        "w",
        encoding="utf-8",
    ) as fo:
        fo.writelines(full_summary)
        fo.write("\n\n")
        if add_signature:
            fo.write("\n\n---\n\n")
            fo.write("Generated with the Document Summarization space :)\n\n")
            fo.write("https://hf.co/spaces/pszemraj/document-summarization\n\n")
    with open(
        outpath,
        "a",
        encoding="utf-8",
    ) as fo:
        fo.write("\n")
        fo.write("## Section Scores:\n\n")
        fo.writelines(scores_text)
        fo.write("\n\n")
        fo.write(f"Date: {get_timestamp()}\n\n")
        if kwargs:
            fo.write("---\n\n")
            fo.write("## Parameters:\n\n")
            for key, value in kwargs.items():
                fo.write(f"{key}: {value}\n")
    return str(outpath.resolve())
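
Putting the pieces together, a minimal sketch (continuing the summarize.py example above, where `summaries` is the list returned by summarize_via_tokenbatches):

```python
# sketch: persist a finished summary run to a timestamped text file
from utils import saves_summary

path = saves_summary(summaries, add_signature=True, batch_length=2048)
print(f"Summary written to: {path}")
```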