seanpedrickcase committed
Commit bafcf39 · 1 Parent(s): 3d18b9d

Fixed deprecated GitHub workflow functions. Applied linter and formatter to code throughout. Added tests for GUI load.

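Summary of the workflow fixes in this commit: deprecated v3 actions are bumped to v4 (actions/cache, actions/upload-artifact), and the retired libgl1-mesa-glx apt package is swapped for libgl1-mesa-dri. As an illustrative sketch only, adapted from the ci.yml hunk further below rather than a complete workflow file, an updated step now looks like:

    - name: Cache pip dependencies
      uses: actions/cache@v4   # previously actions/cache@v3 (deprecated)
      with:
        path: ~/.cache/pip
        key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}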
.github/scripts/setup_test_data.py CHANGED
@@ -5,170 +5,199 @@ Creates dummy test files when example data is not available.
5
  """
6
 
7
  import os
8
- import sys
9
  import pandas as pd
10
- from pathlib import Path
11
 
12
  def create_directories():
13
  """Create necessary directories."""
14
- dirs = [
15
- 'example_data',
16
- 'example_data/example_outputs'
17
- ]
18
-
19
  for dir_path in dirs:
20
  os.makedirs(dir_path, exist_ok=True)
21
  print(f"Created directory: {dir_path}")
22
 
 
23
  def create_dummy_pdf():
24
  """Create a dummy PDF for testing."""
25
  try:
26
- from reportlab.pdfgen import canvas
27
  from reportlab.lib.pagesizes import letter
28
-
29
- pdf_path = 'example_data/example_of_emails_sent_to_a_professor_before_applying.pdf'
30
-
 
 
 
31
  c = canvas.Canvas(pdf_path, pagesize=letter)
32
- c.drawString(100, 750, 'This is a test document for redaction testing.')
33
- c.drawString(100, 700, 'Email: [email protected]')
34
- c.drawString(100, 650, 'Phone: 123-456-7890')
35
- c.drawString(100, 600, 'Name: John Doe')
36
- c.drawString(100, 550, 'Address: 123 Test Street, Test City, TC 12345')
37
  c.showPage()
38
-
39
  # Add second page
40
- c.drawString(100, 750, 'Second page content')
41
- c.drawString(100, 700, 'More test data: [email protected]')
42
- c.drawString(100, 650, 'Another phone: 987-654-3210')
43
  c.save()
44
-
45
  print(f"Created dummy PDF: {pdf_path}")
46
-
47
  except ImportError:
48
  print("ReportLab not available, skipping PDF creation")
49
  # Create a simple text file instead
50
- with open('example_data/example_of_emails_sent_to_a_professor_before_applying.pdf', 'w') as f:
 
 
 
51
  f.write("This is a dummy PDF file for testing")
52
  print("Created dummy text file instead of PDF")
53
 
 
54
  def create_dummy_csv():
55
  """Create dummy CSV files for testing."""
56
  # Main CSV
57
  csv_data = {
58
- 'Case Note': [
59
- 'Client visited for consultation regarding housing issues',
60
- 'Follow-up appointment scheduled for next week',
61
- 'Documentation submitted for review'
62
  ],
63
- 'Client': ['John Smith', 'Jane Doe', 'Bob Johnson'],
64
- 'Date': ['2024-01-15', '2024-01-16', '2024-01-17']
65
  }
66
  df = pd.DataFrame(csv_data)
67
- df.to_csv('example_data/combined_case_notes.csv', index=False)
68
  print("Created dummy CSV: example_data/combined_case_notes.csv")
69
-
70
  # Lambeth CSV
71
  lambeth_data = {
72
- 'text': [
73
- 'Lambeth 2030 vision document content',
74
- 'Our Future Our Lambeth strategic plan',
75
- 'Community engagement and development'
76
  ],
77
- 'page': [1, 2, 3]
78
  }
79
  df_lambeth = pd.DataFrame(lambeth_data)
80
- df_lambeth.to_csv('example_data/Lambeth_2030-Our_Future_Our_Lambeth.pdf.csv', index=False)
 
 
81
  print("Created dummy CSV: example_data/Lambeth_2030-Our_Future_Our_Lambeth.pdf.csv")
82
 
 
83
  def create_dummy_word_doc():
84
  """Create dummy Word document."""
85
  try:
86
  from docx import Document
87
-
88
  doc = Document()
89
- doc.add_heading('Test Document for Redaction', 0)
90
- doc.add_paragraph('This is a test document for redaction testing.')
91
- doc.add_paragraph('Contact Information:')
92
- doc.add_paragraph('Email: [email protected]')
93
- doc.add_paragraph('Phone: 123-456-7890')
94
- doc.add_paragraph('Name: John Doe')
95
- doc.add_paragraph('Address: 123 Test Street, Test City, TC 12345')
96
-
97
- doc.save('example_data/Bold minimalist professional cover letter.docx')
98
  print("Created dummy Word document")
99
-
100
  except ImportError:
101
  print("python-docx not available, skipping Word document creation")
102
 
 
103
  def create_allow_deny_lists():
104
  """Create dummy allow/deny lists."""
105
  # Allow lists
106
- allow_data = {'word': ['test', 'example', 'document']}
107
- pd.DataFrame(allow_data).to_csv('example_data/test_allow_list_graduate.csv', index=False)
108
- pd.DataFrame(allow_data).to_csv('example_data/test_allow_list_partnership.csv', index=False)
 
 
 
 
109
  print("Created allow lists")
110
-
111
  # Deny lists
112
- deny_data = {'word': ['sensitive', 'confidential', 'private']}
113
- pd.DataFrame(deny_data).to_csv('example_data/partnership_toolkit_redact_custom_deny_list.csv', index=False)
114
- pd.DataFrame(deny_data).to_csv('example_data/Partnership-Agreement-Toolkit_test_deny_list_para_single_spell.csv', index=False)
 
 
 
 
 
115
  print("Created deny lists")
116
-
117
  # Whole page redaction list
118
- page_data = {'page': [1, 2]}
119
- pd.DataFrame(page_data).to_csv('example_data/partnership_toolkit_redact_some_pages.csv', index=False)
 
 
120
  print("Created whole page redaction list")
121
 
 
122
  def create_ocr_output():
123
  """Create dummy OCR output CSV."""
124
  ocr_data = {
125
- 'file_name': ['test.pdf', 'test.pdf', 'test.pdf'],
126
- 'page_number': [1, 2, 3],
127
- 'text': [
128
- 'This is page 1 content with some text',
129
- 'This is page 2 content with different text',
130
- 'This is page 3 content with more text'
131
  ],
132
- 'confidence': [0.95, 0.92, 0.88]
133
  }
134
  df = pd.DataFrame(ocr_data)
135
- df.to_csv('example_data/example_outputs/doubled_output_joined.pdf_ocr_output.csv', index=False)
 
 
 
136
  print("Created dummy OCR output CSV")
137
 
 
138
  def create_dummy_image():
139
  """Create dummy image for testing."""
140
  try:
141
  from PIL import Image, ImageDraw, ImageFont
142
-
143
- img = Image.new('RGB', (800, 600), color='white')
144
  draw = ImageDraw.Draw(img)
145
-
146
  # Try to use a system font
147
  try:
148
- font = ImageFont.truetype('/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf', 20)
149
- except:
 
 
 
150
  try:
151
- font = ImageFont.truetype('/System/Library/Fonts/Arial.ttf', 20)
152
- except:
 
153
  font = ImageFont.load_default()
154
-
155
  # Add text to image
156
- draw.text((50, 50), 'Test Document for Redaction', fill='black', font=font)
157
- draw.text((50, 100), 'Email: [email protected]', fill='black', font=font)
158
- draw.text((50, 150), 'Phone: 123-456-7890', fill='black', font=font)
159
- draw.text((50, 200), 'Name: John Doe', fill='black', font=font)
160
- draw.text((50, 250), 'Address: 123 Test Street', fill='black', font=font)
161
-
162
- img.save('example_data/example_complaint_letter.jpg')
163
  print("Created dummy image")
164
-
165
  except ImportError:
166
  print("PIL not available, skipping image creation")
167
 
 
168
  def main():
169
  """Main setup function."""
170
  print("Setting up test data for GitHub Actions...")
171
-
172
  create_directories()
173
  create_dummy_pdf()
174
  create_dummy_csv()
@@ -176,12 +205,13 @@ def main():
176
  create_allow_deny_lists()
177
  create_ocr_output()
178
  create_dummy_image()
179
-
180
  print("\nTest data setup complete!")
181
  print("Created files:")
182
- for root, dirs, files in os.walk('example_data'):
183
  for file in files:
184
  print(f" {os.path.join(root, file)}")
185
 
 
186
  if __name__ == "__main__":
187
  main()
 
5
  """
6
 
7
  import os
8
+
9
  import pandas as pd
10
+
11
 
12
  def create_directories():
13
  """Create necessary directories."""
14
+ dirs = ["example_data", "example_data/example_outputs"]
15
+
 
 
 
16
  for dir_path in dirs:
17
  os.makedirs(dir_path, exist_ok=True)
18
  print(f"Created directory: {dir_path}")
19
 
20
+
21
  def create_dummy_pdf():
22
  """Create a dummy PDF for testing."""
23
  try:
 
24
  from reportlab.lib.pagesizes import letter
25
+ from reportlab.pdfgen import canvas
26
+
27
+ pdf_path = (
28
+ "example_data/example_of_emails_sent_to_a_professor_before_applying.pdf"
29
+ )
30
+
31
  c = canvas.Canvas(pdf_path, pagesize=letter)
32
+ c.drawString(100, 750, "This is a test document for redaction testing.")
33
+ c.drawString(100, 700, "Email: [email protected]")
34
+ c.drawString(100, 650, "Phone: 123-456-7890")
35
+ c.drawString(100, 600, "Name: John Doe")
36
+ c.drawString(100, 550, "Address: 123 Test Street, Test City, TC 12345")
37
  c.showPage()
38
+
39
  # Add second page
40
+ c.drawString(100, 750, "Second page content")
41
+ c.drawString(100, 700, "More test data: [email protected]")
42
+ c.drawString(100, 650, "Another phone: 987-654-3210")
43
  c.save()
44
+
45
  print(f"Created dummy PDF: {pdf_path}")
46
+
47
  except ImportError:
48
  print("ReportLab not available, skipping PDF creation")
49
  # Create a simple text file instead
50
+ with open(
51
+ "example_data/example_of_emails_sent_to_a_professor_before_applying.pdf",
52
+ "w",
53
+ ) as f:
54
  f.write("This is a dummy PDF file for testing")
55
  print("Created dummy text file instead of PDF")
56
 
57
+
58
  def create_dummy_csv():
59
  """Create dummy CSV files for testing."""
60
  # Main CSV
61
  csv_data = {
62
+ "Case Note": [
63
+ "Client visited for consultation regarding housing issues",
64
+ "Follow-up appointment scheduled for next week",
65
+ "Documentation submitted for review",
66
  ],
67
+ "Client": ["John Smith", "Jane Doe", "Bob Johnson"],
68
+ "Date": ["2024-01-15", "2024-01-16", "2024-01-17"],
69
  }
70
  df = pd.DataFrame(csv_data)
71
+ df.to_csv("example_data/combined_case_notes.csv", index=False)
72
  print("Created dummy CSV: example_data/combined_case_notes.csv")
73
+
74
  # Lambeth CSV
75
  lambeth_data = {
76
+ "text": [
77
+ "Lambeth 2030 vision document content",
78
+ "Our Future Our Lambeth strategic plan",
79
+ "Community engagement and development",
80
  ],
81
+ "page": [1, 2, 3],
82
  }
83
  df_lambeth = pd.DataFrame(lambeth_data)
84
+ df_lambeth.to_csv(
85
+ "example_data/Lambeth_2030-Our_Future_Our_Lambeth.pdf.csv", index=False
86
+ )
87
  print("Created dummy CSV: example_data/Lambeth_2030-Our_Future_Our_Lambeth.pdf.csv")
88
 
89
+
90
  def create_dummy_word_doc():
91
  """Create dummy Word document."""
92
  try:
93
  from docx import Document
94
+
95
  doc = Document()
96
+ doc.add_heading("Test Document for Redaction", 0)
97
+ doc.add_paragraph("This is a test document for redaction testing.")
98
+ doc.add_paragraph("Contact Information:")
99
+ doc.add_paragraph("Email: [email protected]")
100
+ doc.add_paragraph("Phone: 123-456-7890")
101
+ doc.add_paragraph("Name: John Doe")
102
+ doc.add_paragraph("Address: 123 Test Street, Test City, TC 12345")
103
+
104
+ doc.save("example_data/Bold minimalist professional cover letter.docx")
105
  print("Created dummy Word document")
106
+
107
  except ImportError:
108
  print("python-docx not available, skipping Word document creation")
109
 
110
+
111
  def create_allow_deny_lists():
112
  """Create dummy allow/deny lists."""
113
  # Allow lists
114
+ allow_data = {"word": ["test", "example", "document"]}
115
+ pd.DataFrame(allow_data).to_csv(
116
+ "example_data/test_allow_list_graduate.csv", index=False
117
+ )
118
+ pd.DataFrame(allow_data).to_csv(
119
+ "example_data/test_allow_list_partnership.csv", index=False
120
+ )
121
  print("Created allow lists")
122
+
123
  # Deny lists
124
+ deny_data = {"word": ["sensitive", "confidential", "private"]}
125
+ pd.DataFrame(deny_data).to_csv(
126
+ "example_data/partnership_toolkit_redact_custom_deny_list.csv", index=False
127
+ )
128
+ pd.DataFrame(deny_data).to_csv(
129
+ "example_data/Partnership-Agreement-Toolkit_test_deny_list_para_single_spell.csv",
130
+ index=False,
131
+ )
132
  print("Created deny lists")
133
+
134
  # Whole page redaction list
135
+ page_data = {"page": [1, 2]}
136
+ pd.DataFrame(page_data).to_csv(
137
+ "example_data/partnership_toolkit_redact_some_pages.csv", index=False
138
+ )
139
  print("Created whole page redaction list")
140
 
141
+
142
  def create_ocr_output():
143
  """Create dummy OCR output CSV."""
144
  ocr_data = {
145
+ "file_name": ["test.pdf", "test.pdf", "test.pdf"],
146
+ "page_number": [1, 2, 3],
147
+ "text": [
148
+ "This is page 1 content with some text",
149
+ "This is page 2 content with different text",
150
+ "This is page 3 content with more text",
151
  ],
152
+ "confidence": [0.95, 0.92, 0.88],
153
  }
154
  df = pd.DataFrame(ocr_data)
155
+ df.to_csv(
156
+ "example_data/example_outputs/doubled_output_joined.pdf_ocr_output.csv",
157
+ index=False,
158
+ )
159
  print("Created dummy OCR output CSV")
160
 
161
+
162
  def create_dummy_image():
163
  """Create dummy image for testing."""
164
  try:
165
  from PIL import Image, ImageDraw, ImageFont
166
+
167
+ img = Image.new("RGB", (800, 600), color="white")
168
  draw = ImageDraw.Draw(img)
169
+
170
  # Try to use a system font
171
  try:
172
+ font = ImageFont.truetype(
173
+ "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 20
174
+ )
175
+ except Exception as e:
176
+ print(f"Error loading DejaVuSans font: {e}")
177
  try:
178
+ font = ImageFont.truetype("/System/Library/Fonts/Arial.ttf", 20)
179
+ except Exception as e:
180
+ print(f"Error loading Arial font: {e}")
181
  font = ImageFont.load_default()
182
+
183
  # Add text to image
184
+ draw.text((50, 50), "Test Document for Redaction", fill="black", font=font)
185
+ draw.text((50, 100), "Email: [email protected]", fill="black", font=font)
186
+ draw.text((50, 150), "Phone: 123-456-7890", fill="black", font=font)
187
+ draw.text((50, 200), "Name: John Doe", fill="black", font=font)
188
+ draw.text((50, 250), "Address: 123 Test Street", fill="black", font=font)
189
+
190
+ img.save("example_data/example_complaint_letter.jpg")
191
  print("Created dummy image")
192
+
193
  except ImportError:
194
  print("PIL not available, skipping image creation")
195
 
196
+
197
  def main():
198
  """Main setup function."""
199
  print("Setting up test data for GitHub Actions...")
200
+
201
  create_directories()
202
  create_dummy_pdf()
203
  create_dummy_csv()
 
205
  create_allow_deny_lists()
206
  create_ocr_output()
207
  create_dummy_image()
208
+
209
  print("\nTest data setup complete!")
210
  print("Created files:")
211
+ for root, dirs, files in os.walk("example_data"):
212
  for file in files:
213
  print(f" {os.path.join(root, file)}")
214
 
215
+
216
  if __name__ == "__main__":
217
  main()
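For context, a minimal sketch of how a CI job might call this setup script before running the GUI load tests. The step names and the pytest invocation are assumptions for illustration and are not taken from the workflow files below:

    - name: Set up test data
      run: python .github/scripts/setup_test_data.py
    - name: Run tests
      run: python -m pytest   # test runner and arguments assumed for illustration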
.github/workflows/ci.yml CHANGED
@@ -49,7 +49,7 @@ jobs:
49
  python-version: ${{ matrix.python-version }}
50
 
51
  - name: Cache pip dependencies
52
- uses: actions/cache@v3
53
  with:
54
  path: ~/.cache/pip
55
  key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
@@ -63,7 +63,7 @@ jobs:
63
  tesseract-ocr \
64
  tesseract-ocr-eng \
65
  poppler-utils \
66
- libgl1-mesa-glx \
67
  libglib2.0-0 \
68
  libsm6 \
69
  libxext6 \
@@ -107,7 +107,7 @@ jobs:
107
  fail_ci_if_error: false
108
 
109
  - name: Upload test results
110
- uses: actions/upload-artifact@v3
111
  if: always()
112
  with:
113
  name: test-results-python-${{ matrix.python-version }}
@@ -141,7 +141,7 @@ jobs:
141
  tesseract-ocr \
142
  tesseract-ocr-eng \
143
  poppler-utils \
144
- libgl1-mesa-glx \
145
  libglib2.0-0
146
 
147
  - name: Download spaCy model
@@ -189,7 +189,7 @@ jobs:
189
  bandit -r . -f json -o bandit-report.json || true
190
 
191
  - name: Upload security report
192
- uses: actions/upload-artifact@v3
193
  if: always()
194
  with:
195
  name: security-report
@@ -222,7 +222,7 @@ jobs:
222
  twine check dist/*
223
 
224
  - name: Upload build artifacts
225
- uses: actions/upload-artifact@v3
226
  with:
227
  name: dist
228
  path: dist/
 
49
  python-version: ${{ matrix.python-version }}
50
 
51
  - name: Cache pip dependencies
52
+ uses: actions/cache@v4
53
  with:
54
  path: ~/.cache/pip
55
  key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
 
63
  tesseract-ocr \
64
  tesseract-ocr-eng \
65
  poppler-utils \
66
+ libgl1-mesa-dri \
67
  libglib2.0-0 \
68
  libsm6 \
69
  libxext6 \
 
107
  fail_ci_if_error: false
108
 
109
  - name: Upload test results
110
+ uses: actions/upload-artifact@v4
111
  if: always()
112
  with:
113
  name: test-results-python-${{ matrix.python-version }}
 
141
  tesseract-ocr \
142
  tesseract-ocr-eng \
143
  poppler-utils \
144
+ libgl1-mesa-dri \
145
  libglib2.0-0
146
 
147
  - name: Download spaCy model
 
189
  bandit -r . -f json -o bandit-report.json || true
190
 
191
  - name: Upload security report
192
+ uses: actions/upload-artifact@v4
193
  if: always()
194
  with:
195
  name: security-report
 
222
  twine check dist/*
223
 
224
  - name: Upload build artifacts
225
+ uses: actions/upload-artifact@v4
226
  with:
227
  name: dist
228
  path: dist/
.github/workflows/multi-os-test.yml CHANGED
@@ -36,7 +36,7 @@ jobs:
36
  tesseract-ocr \
37
  tesseract-ocr-eng \
38
  poppler-utils \
39
- libgl1-mesa-glx \
40
  libglib2.0-0
41
 
42
  - name: Install system dependencies (macOS)
 
36
  tesseract-ocr \
37
  tesseract-ocr-eng \
38
  poppler-utils \
39
+ libgl1-mesa-dri \
40
  libglib2.0-0
41
 
42
  - name: Install system dependencies (macOS)
.github/workflows/simple-test.yml CHANGED
@@ -25,7 +25,7 @@ jobs:
25
  tesseract-ocr \
26
  tesseract-ocr-eng \
27
  poppler-utils \
28
- libgl1-mesa-glx \
29
  libglib2.0-0
30
 
31
  - name: Install Python dependencies
 
25
  tesseract-ocr \
26
  tesseract-ocr-eng \
27
  poppler-utils \
28
+ libgl1-mesa-dri \
29
  libglib2.0-0
30
 
31
  - name: Install Python dependencies
.github/workflows/test.yml CHANGED
@@ -22,7 +22,7 @@ jobs:
22
  python-version: ${{ matrix.python-version }}
23
 
24
  - name: Cache pip dependencies
25
- uses: actions/cache@v3
26
  with:
27
  path: ~/.cache/pip
28
  key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
@@ -36,7 +36,7 @@ jobs:
36
  tesseract-ocr \
37
  tesseract-ocr-eng \
38
  poppler-utils \
39
- libgl1-mesa-glx \
40
  libglib2.0-0
41
 
42
  - name: Install Python dependencies
@@ -201,7 +201,7 @@ jobs:
201
  fail_ci_if_error: false
202
 
203
  - name: Upload test results
204
- uses: actions/upload-artifact@v3
205
  if: always()
206
  with:
207
  name: test-results-python-${{ matrix.python-version }}
 
22
  python-version: ${{ matrix.python-version }}
23
 
24
  - name: Cache pip dependencies
25
+ uses: actions/cache@v4
26
  with:
27
  path: ~/.cache/pip
28
  key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
 
36
  tesseract-ocr \
37
  tesseract-ocr-eng \
38
  poppler-utils \
39
+ libgl1-mesa-dri \
40
  libglib2.0-0
41
 
42
  - name: Install Python dependencies
 
201
  fail_ci_if_error: false
202
 
203
  - name: Upload test results
204
+ uses: actions/upload-artifact@v4
205
  if: always()
206
  with:
207
  name: test-results-python-${{ matrix.python-version }}
app.py CHANGED
The diff for this file is too large to render. See raw diff
 
cdk/app.py CHANGED
@@ -1,11 +1,12 @@
1
  import os
2
- from aws_cdk import (App, Environment)
3
 
4
- # Assuming these are still relevant for you
5
- from check_resources import check_and_set_context, CONTEXT_FILE
6
  from cdk_config import AWS_ACCOUNT_ID, AWS_REGION, RUN_USEAST_STACK, USE_CLOUDFRONT
7
- from cdk_stack import CdkStack, CdkStackCloudfront#, CdkStackMain
8
- from cdk_functions import load_context_from_file, create_basic_config_env
 
 
 
9
 
10
  # Initialize the CDK app
11
  app = App()
@@ -25,7 +26,9 @@ print("Running pre-check script to generate application context...")
25
  try:
26
  check_and_set_context()
27
  if not os.path.exists(CONTEXT_FILE):
28
- raise RuntimeError(f"check_and_set_context() finished, but {CONTEXT_FILE} was not created.")
 
 
29
  print(f"Context generated successfully at {CONTEXT_FILE}.")
30
  except Exception as e:
31
  raise RuntimeError(f"Failed to generate context via check_and_set_context(): {e}")
@@ -56,12 +59,11 @@ aws_env_regional = Environment(account=AWS_ACCOUNT_ID, region=AWS_REGION)
56
  # public_route_tables=regional_stack.params["public_route_tables"],
57
  # cross_region_references=True)
58
 
59
- regional_stack = CdkStack(app,
60
- "RedactionStack",
61
- env=aws_env_regional,
62
- cross_region_references=True)
63
 
64
- if USE_CLOUDFRONT == 'True' and RUN_USEAST_STACK == 'True':
65
  # Define the environment for the CloudFront stack (always us-east-1 for CF-level resources like WAFv2 WebACLs for CF)
66
  aws_env_us_east_1 = Environment(account=AWS_ACCOUNT_ID, region="us-east-1")
67
 
@@ -72,10 +74,10 @@ if USE_CLOUDFRONT == 'True' and RUN_USEAST_STACK == 'True':
72
  env=aws_env_us_east_1,
73
  alb_arn=regional_stack.params["alb_arn_output"],
74
  alb_sec_group_id=regional_stack.params["alb_security_group_id"],
75
- alb_dns_name=regional_stack.params["alb_dns_name"],
76
- cross_region_references=True
77
  )
78
 
79
 
80
  # Synthesize the CloudFormation template
81
- app.synth(validate_on_synthesis=True)
 
1
  import os
 
2
 
3
+ from aws_cdk import App, Environment
 
4
  from cdk_config import AWS_ACCOUNT_ID, AWS_REGION, RUN_USEAST_STACK, USE_CLOUDFRONT
5
+ from cdk_functions import create_basic_config_env, load_context_from_file
6
+ from cdk_stack import CdkStack, CdkStackCloudfront # , CdkStackMain
7
+
8
+ # Assuming these are still relevant for you
9
+ from check_resources import CONTEXT_FILE, check_and_set_context
10
 
11
  # Initialize the CDK app
12
  app = App()
 
26
  try:
27
  check_and_set_context()
28
  if not os.path.exists(CONTEXT_FILE):
29
+ raise RuntimeError(
30
+ f"check_and_set_context() finished, but {CONTEXT_FILE} was not created."
31
+ )
32
  print(f"Context generated successfully at {CONTEXT_FILE}.")
33
  except Exception as e:
34
  raise RuntimeError(f"Failed to generate context via check_and_set_context(): {e}")
 
59
  # public_route_tables=regional_stack.params["public_route_tables"],
60
  # cross_region_references=True)
61
 
62
+ regional_stack = CdkStack(
63
+ app, "RedactionStack", env=aws_env_regional, cross_region_references=True
64
+ )
 
65
 
66
+ if USE_CLOUDFRONT == "True" and RUN_USEAST_STACK == "True":
67
  # Define the environment for the CloudFront stack (always us-east-1 for CF-level resources like WAFv2 WebACLs for CF)
68
  aws_env_us_east_1 = Environment(account=AWS_ACCOUNT_ID, region="us-east-1")
69
 
 
74
  env=aws_env_us_east_1,
75
  alb_arn=regional_stack.params["alb_arn_output"],
76
  alb_sec_group_id=regional_stack.params["alb_security_group_id"],
77
+ alb_dns_name=regional_stack.params["alb_dns_name"],
78
+ cross_region_references=True,
79
  )
80
 
81
 
82
  # Synthesize the CloudFormation template
83
+ app.synth(validate_on_synthesis=True)
cdk/cdk_config.py CHANGED
@@ -1,28 +1,31 @@
1
  import os
2
  import tempfile
 
3
  from dotenv import load_dotenv
4
 
5
  # Set or retrieve configuration variables for CDK redaction deployment
6
 
7
- def get_or_create_env_var(var_name:str, default_value:str, print_val:bool=False):
8
- '''
 
9
  Get an environmental variable, and set it to a default value if it doesn't exist
10
- '''
11
  # Get the environment variable if it exists
12
  value = os.environ.get(var_name)
13
-
14
  # If it doesn't exist, set the environment variable to the default value
15
  if value is None:
16
  os.environ[var_name] = default_value
17
  value = default_value
18
 
19
- if print_val == True:
20
- print(f'The value of {var_name} is {value}')
21
-
22
  return value
23
 
24
- def ensure_folder_exists(output_folder:str):
25
- """Checks if the specified folder exists, creates it if not."""
 
26
 
27
  if not os.path.exists(output_folder):
28
  # Create the folder if it doesn't exist
@@ -31,10 +34,11 @@ def ensure_folder_exists(output_folder:str):
31
  else:
32
  print(f"The {output_folder} folder already exists.")
33
 
 
34
  def add_folder_to_path(folder_path: str):
35
- '''
36
  Check if a folder exists on your system. If so, get the absolute path and then add it to the system Path variable if it doesn't already exist. Function is only relevant for locally-created executable files based on this app (when using pyinstaller it creates a _internal folder that contains tesseract and poppler. These need to be added to the system path to enable the app to run)
37
- '''
38
 
39
  if os.path.exists(folder_path) and os.path.isdir(folder_path):
40
  print(folder_path, "folder exists.")
@@ -42,192 +46,295 @@ def add_folder_to_path(folder_path: str):
42
  # Resolve relative path to absolute path
43
  absolute_path = os.path.abspath(folder_path)
44
 
45
- current_path = os.environ['PATH']
46
  if absolute_path not in current_path.split(os.pathsep):
47
  full_path_extension = absolute_path + os.pathsep + current_path
48
- os.environ['PATH'] = full_path_extension
49
- #print(f"Updated PATH with: ", full_path_extension)
50
  else:
51
  print(f"Directory {folder_path} already exists in PATH.")
52
  else:
53
  print(f"Folder not found at {folder_path} - not added to PATH")
54
 
 
55
  ###
56
  # LOAD CONFIG FROM ENV FILE
57
  ###
58
- CONFIG_FOLDER = get_or_create_env_var('CONFIG_FOLDER', "config/")
59
 
60
  ensure_folder_exists(CONFIG_FOLDER)
61
 
62
  # If you have an aws_config env file in the config folder, you can load in app variables this way, e.g. 'config/cdk_config.env'
63
- CDK_CONFIG_PATH = get_or_create_env_var('CDK_CONFIG_PATH', 'config/cdk_config.env') # e.g. config/cdk_config.env
 
 
64
 
65
  if CDK_CONFIG_PATH:
66
  if os.path.exists(CDK_CONFIG_PATH):
67
  print(f"Loading CDK variables from config file {CDK_CONFIG_PATH}")
68
  load_dotenv(CDK_CONFIG_PATH)
69
- else: print("CDK config file not found at location:", CDK_CONFIG_PATH)
 
70
 
71
  ###
72
  # AWS OPTIONS
73
  ###
74
- AWS_REGION = get_or_create_env_var('AWS_REGION', '')
75
- AWS_ACCOUNT_ID = get_or_create_env_var('AWS_ACCOUNT_ID', '')
76
 
77
  ###
78
  # CDK OPTIONS
79
  ###
80
- CDK_PREFIX = get_or_create_env_var('CDK_PREFIX', '')
81
- CONTEXT_FILE = get_or_create_env_var('CONTEXT_FILE', 'cdk.context.json') # Define the CDK output context file name
82
- CDK_FOLDER = get_or_create_env_var('CDK_FOLDER', '') # FULL_PATH_TO_CDK_FOLDER_HERE (with forward slash)
83
- RUN_USEAST_STACK = get_or_create_env_var('RUN_USEAST_STACK', 'False')
 
 
 
 
84
 
85
  ### VPC and connections
86
- VPC_NAME = get_or_create_env_var('VPC_NAME', '')
87
- NEW_VPC_DEFAULT_NAME = get_or_create_env_var('NEW_VPC_DEFAULT_NAME', f'{CDK_PREFIX}vpc')
88
- NEW_VPC_CIDR = get_or_create_env_var('NEW_VPC_CIDR', '') # "10.0.0.0/24"
89
 
90
 
91
- EXISTING_IGW_ID = get_or_create_env_var('EXISTING_IGW_ID', '')
92
- SINGLE_NAT_GATEWAY_ID = get_or_create_env_var('SINGLE_NAT_GATEWAY_ID', '')
93
 
94
  ### SUBNETS / ROUTE TABLES / NAT GATEWAY
95
- PUBLIC_SUBNETS_TO_USE = get_or_create_env_var('PUBLIC_SUBNETS_TO_USE', '') # e.g. ['PublicSubnet1', 'PublicSubnet2']
96
- PUBLIC_SUBNET_CIDR_BLOCKS = get_or_create_env_var('PUBLIC_SUBNET_CIDR_BLOCKS', '') # e.g. ["10.0.1.0/24", "10.0.2.0/24"]
97
- PUBLIC_SUBNET_AVAILABILITY_ZONES = get_or_create_env_var('PUBLIC_SUBNET_AVAILABILITY_ZONES', '') # e.g. ["eu-east-1b", "eu-east1b"]
98
-
99
- PRIVATE_SUBNETS_TO_USE = get_or_create_env_var('PRIVATE_SUBNETS_TO_USE', '') # e.g. ['PrivateSubnet1', 'PrivateSubnet2']
100
- PRIVATE_SUBNET_CIDR_BLOCKS = get_or_create_env_var('PRIVATE_SUBNET_CIDR_BLOCKS', '') # e.g. ["10.0.1.0/24", "10.0.2.0/24"]
101
- PRIVATE_SUBNET_AVAILABILITY_ZONES = get_or_create_env_var('PRIVATE_SUBNET_AVAILABILITY_ZONES', '') # e.g. ["eu-east-1b", "eu-east1b"]
102
-
103
- ROUTE_TABLE_BASE_NAME = get_or_create_env_var('ROUTE_TABLE_BASE_NAME', f'{CDK_PREFIX}PrivateRouteTable')
104
- NAT_GATEWAY_EIP_NAME = get_or_create_env_var('NAT_GATEWAY_EIP_NAME', f"{CDK_PREFIX}NatGatewayEip")
105
- NAT_GATEWAY_NAME = get_or_create_env_var('NAT_GATEWAY_NAME', f"{CDK_PREFIX}NatGateway")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
  # IAM roles
108
- AWS_MANAGED_TASK_ROLES_LIST = get_or_create_env_var('AWS_MANAGED_TASK_ROLES_LIST', '["AmazonCognitoReadOnly", "service-role/AmazonECSTaskExecutionRolePolicy", "AmazonS3FullAccess", "AmazonTextractFullAccess", "ComprehendReadOnly", "AmazonDynamoDBFullAccess", "service-role/AWSAppSyncPushToCloudWatchLogs"]')
109
- POLICY_FILE_LOCATIONS = get_or_create_env_var('POLICY_FILE_LOCATIONS', '') # e.g. '["config/sts_permissions.json"]'
110
- POLICY_FILE_ARNS = get_or_create_env_var('POLICY_FILE_ARNS', '')
 
 
 
 
 
111
 
112
  # GITHUB REPO
113
- GITHUB_REPO_USERNAME = get_or_create_env_var('GITHUB_REPO_USERNAME', 'seanpedrick-case')
114
- GITHUB_REPO_NAME = get_or_create_env_var('GITHUB_REPO_NAME', 'doc_redaction')
115
- GITHUB_REPO_BRANCH = get_or_create_env_var('GITHUB_REPO_BRANCH', 'main')
116
 
117
  ### CODEBUILD
118
- CODEBUILD_ROLE_NAME = get_or_create_env_var('CODEBUILD_ROLE_NAME', f"{CDK_PREFIX}CodeBuildRole")
119
- CODEBUILD_PROJECT_NAME = get_or_create_env_var('CODEBUILD_PROJECT_NAME', f"{CDK_PREFIX}CodeBuildProject")
 
 
 
 
120
 
121
  ### ECR
122
- ECR_REPO_NAME = get_or_create_env_var('ECR_REPO_NAME', 'doc-redaction') # Beware - cannot have underscores and must be lower case
123
- ECR_CDK_REPO_NAME = get_or_create_env_var('ECR_CDK_REPO_NAME', f"{CDK_PREFIX}{ECR_REPO_NAME}".lower())
 
 
 
 
124
 
125
  ### S3
126
- S3_LOG_CONFIG_BUCKET_NAME = get_or_create_env_var('S3_LOG_CONFIG_BUCKET_NAME', f"{CDK_PREFIX}s3-logs".lower()) # S3 bucket names need to be lower case
127
- S3_OUTPUT_BUCKET_NAME = get_or_create_env_var('S3_OUTPUT_BUCKET_NAME', f"{CDK_PREFIX}s3-output".lower())
 
 
 
 
128
 
129
  ### KMS KEYS FOR S3 AND SECRETS MANAGER
130
- USE_CUSTOM_KMS_KEY = get_or_create_env_var('USE_CUSTOM_KMS_KEY', '1')
131
- CUSTOM_KMS_KEY_NAME = get_or_create_env_var('CUSTOM_KMS_KEY_NAME', f"alias/{CDK_PREFIX}kms-key".lower())
 
 
132
 
133
  ### ECS
134
- FARGATE_TASK_DEFINITION_NAME = get_or_create_env_var('FARGATE_TASK_DEFINITION_NAME', f"{CDK_PREFIX}FargateTaskDefinition")
135
- TASK_DEFINITION_FILE_LOCATION = get_or_create_env_var('TASK_DEFINITION_FILE_LOCATION', CDK_FOLDER + CONFIG_FOLDER + "task_definition.json")
136
-
137
- CLUSTER_NAME = get_or_create_env_var('CLUSTER_NAME', f"{CDK_PREFIX}Cluster")
138
- ECS_SERVICE_NAME = get_or_create_env_var('ECS_SERVICE_NAME', f"{CDK_PREFIX}ECSService")
139
- ECS_TASK_ROLE_NAME = get_or_create_env_var('ECS_TASK_ROLE_NAME', f"{CDK_PREFIX}TaskRole")
140
- ECS_TASK_EXECUTION_ROLE_NAME = get_or_create_env_var('ECS_TASK_EXECUTION_ROLE_NAME', f"{CDK_PREFIX}ExecutionRole")
141
- ECS_SECURITY_GROUP_NAME = get_or_create_env_var('ECS_SECURITY_GROUP_NAME', f"{CDK_PREFIX}SecurityGroupECS")
142
- ECS_LOG_GROUP_NAME = get_or_create_env_var('ECS_LOG_GROUP_NAME', f"/ecs/{ECS_SERVICE_NAME}-logs".lower())
143
-
144
- ECS_TASK_CPU_SIZE = get_or_create_env_var('ECS_TASK_CPU_SIZE', '1024')
145
- ECS_TASK_MEMORY_SIZE = get_or_create_env_var('ECS_TASK_MEMORY_SIZE', '4096')
146
- ECS_USE_FARGATE_SPOT = get_or_create_env_var('USE_FARGATE_SPOT', 'False')
147
- ECS_READ_ONLY_FILE_SYSTEM = get_or_create_env_var('ECS_READ_ONLY_FILE_SYSTEM', 'True')
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
  ### Cognito
150
- COGNITO_USER_POOL_NAME = get_or_create_env_var('COGNITO_USER_POOL_NAME', f"{CDK_PREFIX}UserPool")
151
- COGNITO_USER_POOL_CLIENT_NAME = get_or_create_env_var('COGNITO_USER_POOL_CLIENT_NAME', f"{CDK_PREFIX}UserPoolClient")
152
- COGNITO_USER_POOL_CLIENT_SECRET_NAME = get_or_create_env_var('COGNITO_USER_POOL_CLIENT_SECRET_NAME', f"{CDK_PREFIX}ParamCognitoSecret")
153
- COGNITO_USER_POOL_DOMAIN_PREFIX = get_or_create_env_var('COGNITO_USER_POOL_DOMAIN_PREFIX', "redaction-app-domain") # Should change this to something unique or you'll probably hit an error
 
 
 
 
 
 
 
 
154
 
155
  # Application load balancer
156
- ALB_NAME = get_or_create_env_var('ALB_NAME', f"{CDK_PREFIX}Alb"[-32:]) # Application load balancer name can be max 32 characters, so taking the last 32 characters of the suggested name
157
- ALB_NAME_SECURITY_GROUP_NAME = get_or_create_env_var('ALB_SECURITY_GROUP_NAME', f"{CDK_PREFIX}SecurityGroupALB")
158
- ALB_TARGET_GROUP_NAME = get_or_create_env_var('ALB_TARGET_GROUP_NAME', f"{CDK_PREFIX}-tg"[-32:]) # Max 32 characters
159
- EXISTING_LOAD_BALANCER_ARN = get_or_create_env_var('EXISTING_LOAD_BALANCER_ARN', '')
160
- EXISTING_LOAD_BALANCER_DNS = get_or_create_env_var('EXISTING_LOAD_BALANCER_ARN', 'placeholder_load_balancer_dns.net')
 
 
 
 
 
 
 
 
161
 
162
  ## CLOUDFRONT
163
- USE_CLOUDFRONT = get_or_create_env_var('USE_CLOUDFRONT', 'True')
164
- CLOUDFRONT_PREFIX_LIST_ID = get_or_create_env_var('CLOUDFRONT_PREFIX_LIST_ID', 'pl-93a247fa')
165
- CLOUDFRONT_GEO_RESTRICTION = get_or_create_env_var('CLOUDFRONT_GEO_RESTRICTION', '') # A country that Cloudfront restricts access to. See here: https://docs.aws.amazon.com/AmazonCloudFront/latest/DeveloperGuide/georestrictions.html
166
- CLOUDFRONT_DISTRIBUTION_NAME = get_or_create_env_var('CLOUDFRONT_DISTRIBUTION_NAME', f"{CDK_PREFIX}CfDist")
167
- CLOUDFRONT_DOMAIN = get_or_create_env_var('CLOUDFRONT_DOMAIN', "cloudfront_placeholder.net")
 
 
 
 
 
 
 
 
168
 
169
 
170
  # Certificate for Application load balancer (optional, for HTTPS and logins through the ALB)
171
- ACM_SSL_CERTIFICATE_ARN = get_or_create_env_var('ACM_SSL_CERTIFICATE_ARN', '')
172
- SSL_CERTIFICATE_DOMAIN = get_or_create_env_var('SSL_CERTIFICATE_DOMAIN', '') # e.g. example.com or www.example.com
 
 
173
 
174
  # This should be the CloudFront domain, the domain linked to your ACM certificate, or the DNS of your application load balancer in console afterwards
175
  if USE_CLOUDFRONT == "True":
176
- COGNITO_REDIRECTION_URL = get_or_create_env_var('COGNITO_REDIRECTION_URL', "https://" + CLOUDFRONT_DOMAIN)
 
 
177
  elif SSL_CERTIFICATE_DOMAIN:
178
- COGNITO_REDIRECTION_URL = get_or_create_env_var('COGNITO_REDIRECTION_URL', "https://" + SSL_CERTIFICATE_DOMAIN)
 
 
179
  else:
180
- COGNITO_REDIRECTION_URL = get_or_create_env_var('COGNITO_REDIRECTION_URL', "https://" + EXISTING_LOAD_BALANCER_DNS)
 
 
181
 
182
  # Custom headers e.g. if routing traffic through Cloudfront
183
- CUSTOM_HEADER = get_or_create_env_var('CUSTOM_HEADER', '') # Retrieving or setting CUSTOM_HEADER
184
- CUSTOM_HEADER_VALUE = get_or_create_env_var('CUSTOM_HEADER_VALUE', '') # Retrieving or setting CUSTOM_HEADER_VALUE
 
 
 
 
185
 
186
  # Firewall on top of load balancer
187
- LOAD_BALANCER_WEB_ACL_NAME = get_or_create_env_var('LOAD_BALANCER_WEB_ACL_NAME', f"{CDK_PREFIX}alb-web-acl")
 
 
188
 
189
  # Firewall on top of CloudFront
190
- WEB_ACL_NAME = get_or_create_env_var('WEB_ACL_NAME', f"{CDK_PREFIX}cloudfront-web-acl")
191
 
192
  ###
193
  # File I/O options
194
  ###
195
 
196
- OUTPUT_FOLDER = get_or_create_env_var('GRADIO_OUTPUT_FOLDER', 'output/') # 'output/'
197
- INPUT_FOLDER = get_or_create_env_var('GRADIO_INPUT_FOLDER', 'input/') # 'input/'
198
 
199
  # Allow for files to be saved in a temporary folder for increased security in some instances
200
- if OUTPUT_FOLDER == "TEMP" or INPUT_FOLDER == "TEMP":
201
  # Create a temporary directory
202
  with tempfile.TemporaryDirectory() as temp_dir:
203
- print(f'Temporary directory created at: {temp_dir}')
204
 
205
- if OUTPUT_FOLDER == "TEMP": OUTPUT_FOLDER = temp_dir + "/"
206
- if INPUT_FOLDER == "TEMP": INPUT_FOLDER = temp_dir + "/"
 
 
207
 
208
  ###
209
  # LOGGING OPTIONS
210
  ###
211
 
212
- SAVE_LOGS_TO_CSV = get_or_create_env_var('SAVE_LOGS_TO_CSV', 'True')
213
 
214
  ### DYNAMODB logs. Whether to save to DynamoDB, and the headers of the table
215
- SAVE_LOGS_TO_DYNAMODB = get_or_create_env_var('SAVE_LOGS_TO_DYNAMODB', 'True')
216
- ACCESS_LOG_DYNAMODB_TABLE_NAME = get_or_create_env_var('ACCESS_LOG_DYNAMODB_TABLE_NAME', f"{CDK_PREFIX}dynamodb-access-logs".lower())
217
- FEEDBACK_LOG_DYNAMODB_TABLE_NAME = get_or_create_env_var('FEEDBACK_LOG_DYNAMODB_TABLE_NAME', f"{CDK_PREFIX}dynamodb-feedback-logs".lower())
218
- USAGE_LOG_DYNAMODB_TABLE_NAME = get_or_create_env_var('USAGE_LOG_DYNAMODB_TABLE_NAME', f"{CDK_PREFIX}dynamodb-usage-logs".lower())
 
 
 
 
 
 
219
 
220
  ###
221
  # REDACTION OPTIONS
222
  ###
223
 
224
  # Get some environment variables and Launch the Gradio app
225
- COGNITO_AUTH = get_or_create_env_var('COGNITO_AUTH', '0')
226
 
227
- GRADIO_SERVER_PORT = int(get_or_create_env_var('GRADIO_SERVER_PORT', '7860'))
228
 
229
  ###
230
  # WHOLE DOCUMENT API OPTIONS
231
  ###
232
 
233
- DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS = get_or_create_env_var('DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS', '7') # How many days into the past should whole document Textract jobs be displayed? After that, the data is not deleted from the Textract jobs csv, but it is just filtered out. Included to align with S3 buckets where the file outputs will be automatically deleted after X days.
 
 
 
1
  import os
2
  import tempfile
3
+
4
  from dotenv import load_dotenv
5
 
6
  # Set or retrieve configuration variables for CDK redaction deployment
7
 
8
+
9
+ def get_or_create_env_var(var_name: str, default_value: str, print_val: bool = False):
10
+ """
11
  Get an environmental variable, and set it to a default value if it doesn't exist
12
+ """
13
  # Get the environment variable if it exists
14
  value = os.environ.get(var_name)
15
+
16
  # If it doesn't exist, set the environment variable to the default value
17
  if value is None:
18
  os.environ[var_name] = default_value
19
  value = default_value
20
 
21
+ if print_val is True:
22
+ print(f"The value of {var_name} is {value}")
23
+
24
  return value
25
 
26
+
27
+ def ensure_folder_exists(output_folder: str):
28
+ """Checks if the specified folder exists, creates it if not."""
29
 
30
  if not os.path.exists(output_folder):
31
  # Create the folder if it doesn't exist
 
34
  else:
35
  print(f"The {output_folder} folder already exists.")
36
 
37
+
38
  def add_folder_to_path(folder_path: str):
39
+ """
40
  Check if a folder exists on your system. If so, get the absolute path and then add it to the system Path variable if it doesn't already exist. Function is only relevant for locally-created executable files based on this app (when using pyinstaller it creates a _internal folder that contains tesseract and poppler. These need to be added to the system path to enable the app to run)
41
+ """
42
 
43
  if os.path.exists(folder_path) and os.path.isdir(folder_path):
44
  print(folder_path, "folder exists.")
 
46
  # Resolve relative path to absolute path
47
  absolute_path = os.path.abspath(folder_path)
48
 
49
+ current_path = os.environ["PATH"]
50
  if absolute_path not in current_path.split(os.pathsep):
51
  full_path_extension = absolute_path + os.pathsep + current_path
52
+ os.environ["PATH"] = full_path_extension
53
+ # print(f"Updated PATH with: ", full_path_extension)
54
  else:
55
  print(f"Directory {folder_path} already exists in PATH.")
56
  else:
57
  print(f"Folder not found at {folder_path} - not added to PATH")
58
 
59
+
60
  ###
61
  # LOAD CONFIG FROM ENV FILE
62
  ###
63
+ CONFIG_FOLDER = get_or_create_env_var("CONFIG_FOLDER", "config/")
64
 
65
  ensure_folder_exists(CONFIG_FOLDER)
66
 
67
  # If you have an aws_config env file in the config folder, you can load in app variables this way, e.g. 'config/cdk_config.env'
68
+ CDK_CONFIG_PATH = get_or_create_env_var(
69
+ "CDK_CONFIG_PATH", "config/cdk_config.env"
70
+ ) # e.g. config/cdk_config.env
71
 
72
  if CDK_CONFIG_PATH:
73
  if os.path.exists(CDK_CONFIG_PATH):
74
  print(f"Loading CDK variables from config file {CDK_CONFIG_PATH}")
75
  load_dotenv(CDK_CONFIG_PATH)
76
+ else:
77
+ print("CDK config file not found at location:", CDK_CONFIG_PATH)
78
 
79
  ###
80
  # AWS OPTIONS
81
  ###
82
+ AWS_REGION = get_or_create_env_var("AWS_REGION", "")
83
+ AWS_ACCOUNT_ID = get_or_create_env_var("AWS_ACCOUNT_ID", "")
84
 
85
  ###
86
  # CDK OPTIONS
87
  ###
88
+ CDK_PREFIX = get_or_create_env_var("CDK_PREFIX", "")
89
+ CONTEXT_FILE = get_or_create_env_var(
90
+ "CONTEXT_FILE", "cdk.context.json"
91
+ ) # Define the CDK output context file name
92
+ CDK_FOLDER = get_or_create_env_var(
93
+ "CDK_FOLDER", ""
94
+ ) # FULL_PATH_TO_CDK_FOLDER_HERE (with forward slash)
95
+ RUN_USEAST_STACK = get_or_create_env_var("RUN_USEAST_STACK", "False")
96
 
97
  ### VPC and connections
98
+ VPC_NAME = get_or_create_env_var("VPC_NAME", "")
99
+ NEW_VPC_DEFAULT_NAME = get_or_create_env_var("NEW_VPC_DEFAULT_NAME", f"{CDK_PREFIX}vpc")
100
+ NEW_VPC_CIDR = get_or_create_env_var("NEW_VPC_CIDR", "") # "10.0.0.0/24"
101
 
102
 
103
+ EXISTING_IGW_ID = get_or_create_env_var("EXISTING_IGW_ID", "")
104
+ SINGLE_NAT_GATEWAY_ID = get_or_create_env_var("SINGLE_NAT_GATEWAY_ID", "")
105
 
106
  ### SUBNETS / ROUTE TABLES / NAT GATEWAY
107
+ PUBLIC_SUBNETS_TO_USE = get_or_create_env_var(
108
+ "PUBLIC_SUBNETS_TO_USE", ""
109
+ ) # e.g. ['PublicSubnet1', 'PublicSubnet2']
110
+ PUBLIC_SUBNET_CIDR_BLOCKS = get_or_create_env_var(
111
+ "PUBLIC_SUBNET_CIDR_BLOCKS", ""
112
+ ) # e.g. ["10.0.1.0/24", "10.0.2.0/24"]
113
+ PUBLIC_SUBNET_AVAILABILITY_ZONES = get_or_create_env_var(
114
+ "PUBLIC_SUBNET_AVAILABILITY_ZONES", ""
115
+ ) # e.g. ["eu-east-1b", "eu-east1b"]
116
+
117
+ PRIVATE_SUBNETS_TO_USE = get_or_create_env_var(
118
+ "PRIVATE_SUBNETS_TO_USE", ""
119
+ ) # e.g. ['PrivateSubnet1', 'PrivateSubnet2']
120
+ PRIVATE_SUBNET_CIDR_BLOCKS = get_or_create_env_var(
121
+ "PRIVATE_SUBNET_CIDR_BLOCKS", ""
122
+ ) # e.g. ["10.0.1.0/24", "10.0.2.0/24"]
123
+ PRIVATE_SUBNET_AVAILABILITY_ZONES = get_or_create_env_var(
124
+ "PRIVATE_SUBNET_AVAILABILITY_ZONES", ""
125
+ ) # e.g. ["eu-east-1b", "eu-east1b"]
126
+
127
+ ROUTE_TABLE_BASE_NAME = get_or_create_env_var(
128
+ "ROUTE_TABLE_BASE_NAME", f"{CDK_PREFIX}PrivateRouteTable"
129
+ )
130
+ NAT_GATEWAY_EIP_NAME = get_or_create_env_var(
131
+ "NAT_GATEWAY_EIP_NAME", f"{CDK_PREFIX}NatGatewayEip"
132
+ )
133
+ NAT_GATEWAY_NAME = get_or_create_env_var("NAT_GATEWAY_NAME", f"{CDK_PREFIX}NatGateway")
134
 
135
  # IAM roles
136
+ AWS_MANAGED_TASK_ROLES_LIST = get_or_create_env_var(
137
+ "AWS_MANAGED_TASK_ROLES_LIST",
138
+ '["AmazonCognitoReadOnly", "service-role/AmazonECSTaskExecutionRolePolicy", "AmazonS3FullAccess", "AmazonTextractFullAccess", "ComprehendReadOnly", "AmazonDynamoDBFullAccess", "service-role/AWSAppSyncPushToCloudWatchLogs"]',
139
+ )
140
+ POLICY_FILE_LOCATIONS = get_or_create_env_var(
141
+ "POLICY_FILE_LOCATIONS", ""
142
+ ) # e.g. '["config/sts_permissions.json"]'
143
+ POLICY_FILE_ARNS = get_or_create_env_var("POLICY_FILE_ARNS", "")
144
 
145
  # GITHUB REPO
146
+ GITHUB_REPO_USERNAME = get_or_create_env_var("GITHUB_REPO_USERNAME", "seanpedrick-case")
147
+ GITHUB_REPO_NAME = get_or_create_env_var("GITHUB_REPO_NAME", "doc_redaction")
148
+ GITHUB_REPO_BRANCH = get_or_create_env_var("GITHUB_REPO_BRANCH", "main")
149
 
150
  ### CODEBUILD
151
+ CODEBUILD_ROLE_NAME = get_or_create_env_var(
152
+ "CODEBUILD_ROLE_NAME", f"{CDK_PREFIX}CodeBuildRole"
153
+ )
154
+ CODEBUILD_PROJECT_NAME = get_or_create_env_var(
155
+ "CODEBUILD_PROJECT_NAME", f"{CDK_PREFIX}CodeBuildProject"
156
+ )
157
 
158
  ### ECR
159
+ ECR_REPO_NAME = get_or_create_env_var(
160
+ "ECR_REPO_NAME", "doc-redaction"
161
+ ) # Beware - cannot have underscores and must be lower case
162
+ ECR_CDK_REPO_NAME = get_or_create_env_var(
163
+ "ECR_CDK_REPO_NAME", f"{CDK_PREFIX}{ECR_REPO_NAME}".lower()
164
+ )
165
 
166
  ### S3
167
+ S3_LOG_CONFIG_BUCKET_NAME = get_or_create_env_var(
168
+ "S3_LOG_CONFIG_BUCKET_NAME", f"{CDK_PREFIX}s3-logs".lower()
169
+ ) # S3 bucket names need to be lower case
170
+ S3_OUTPUT_BUCKET_NAME = get_or_create_env_var(
171
+ "S3_OUTPUT_BUCKET_NAME", f"{CDK_PREFIX}s3-output".lower()
172
+ )
173
 
174
  ### KMS KEYS FOR S3 AND SECRETS MANAGER
175
+ USE_CUSTOM_KMS_KEY = get_or_create_env_var("USE_CUSTOM_KMS_KEY", "1")
176
+ CUSTOM_KMS_KEY_NAME = get_or_create_env_var(
177
+ "CUSTOM_KMS_KEY_NAME", f"alias/{CDK_PREFIX}kms-key".lower()
178
+ )
179
 
180
  ### ECS
181
+ FARGATE_TASK_DEFINITION_NAME = get_or_create_env_var(
182
+ "FARGATE_TASK_DEFINITION_NAME", f"{CDK_PREFIX}FargateTaskDefinition"
183
+ )
184
+ TASK_DEFINITION_FILE_LOCATION = get_or_create_env_var(
185
+ "TASK_DEFINITION_FILE_LOCATION", CDK_FOLDER + CONFIG_FOLDER + "task_definition.json"
186
+ )
187
+
188
+ CLUSTER_NAME = get_or_create_env_var("CLUSTER_NAME", f"{CDK_PREFIX}Cluster")
189
+ ECS_SERVICE_NAME = get_or_create_env_var("ECS_SERVICE_NAME", f"{CDK_PREFIX}ECSService")
190
+ ECS_TASK_ROLE_NAME = get_or_create_env_var(
191
+ "ECS_TASK_ROLE_NAME", f"{CDK_PREFIX}TaskRole"
192
+ )
193
+ ECS_TASK_EXECUTION_ROLE_NAME = get_or_create_env_var(
194
+ "ECS_TASK_EXECUTION_ROLE_NAME", f"{CDK_PREFIX}ExecutionRole"
195
+ )
196
+ ECS_SECURITY_GROUP_NAME = get_or_create_env_var(
197
+ "ECS_SECURITY_GROUP_NAME", f"{CDK_PREFIX}SecurityGroupECS"
198
+ )
199
+ ECS_LOG_GROUP_NAME = get_or_create_env_var(
200
+ "ECS_LOG_GROUP_NAME", f"/ecs/{ECS_SERVICE_NAME}-logs".lower()
201
+ )
202
+
203
+ ECS_TASK_CPU_SIZE = get_or_create_env_var("ECS_TASK_CPU_SIZE", "1024")
204
+ ECS_TASK_MEMORY_SIZE = get_or_create_env_var("ECS_TASK_MEMORY_SIZE", "4096")
205
+ ECS_USE_FARGATE_SPOT = get_or_create_env_var("USE_FARGATE_SPOT", "False")
206
+ ECS_READ_ONLY_FILE_SYSTEM = get_or_create_env_var("ECS_READ_ONLY_FILE_SYSTEM", "True")
207
 
208
  ### Cognito
209
+ COGNITO_USER_POOL_NAME = get_or_create_env_var(
210
+ "COGNITO_USER_POOL_NAME", f"{CDK_PREFIX}UserPool"
211
+ )
212
+ COGNITO_USER_POOL_CLIENT_NAME = get_or_create_env_var(
213
+ "COGNITO_USER_POOL_CLIENT_NAME", f"{CDK_PREFIX}UserPoolClient"
214
+ )
215
+ COGNITO_USER_POOL_CLIENT_SECRET_NAME = get_or_create_env_var(
216
+ "COGNITO_USER_POOL_CLIENT_SECRET_NAME", f"{CDK_PREFIX}ParamCognitoSecret"
217
+ )
218
+ COGNITO_USER_POOL_DOMAIN_PREFIX = get_or_create_env_var(
219
+ "COGNITO_USER_POOL_DOMAIN_PREFIX", "redaction-app-domain"
220
+ ) # Should change this to something unique or you'll probably hit an error
221
 
222
  # Application load balancer
223
+ ALB_NAME = get_or_create_env_var(
224
+ "ALB_NAME", f"{CDK_PREFIX}Alb"[-32:]
225
+ ) # Application load balancer name can be max 32 characters, so taking the last 32 characters of the suggested name
226
+ ALB_NAME_SECURITY_GROUP_NAME = get_or_create_env_var(
227
+ "ALB_SECURITY_GROUP_NAME", f"{CDK_PREFIX}SecurityGroupALB"
228
+ )
229
+ ALB_TARGET_GROUP_NAME = get_or_create_env_var(
230
+ "ALB_TARGET_GROUP_NAME", f"{CDK_PREFIX}-tg"[-32:]
231
+ ) # Max 32 characters
232
+ EXISTING_LOAD_BALANCER_ARN = get_or_create_env_var("EXISTING_LOAD_BALANCER_ARN", "")
233
+ EXISTING_LOAD_BALANCER_DNS = get_or_create_env_var(
234
+ "EXISTING_LOAD_BALANCER_ARN", "placeholder_load_balancer_dns.net"
235
+ )
236
 
237
  ## CLOUDFRONT
238
+ USE_CLOUDFRONT = get_or_create_env_var("USE_CLOUDFRONT", "True")
239
+ CLOUDFRONT_PREFIX_LIST_ID = get_or_create_env_var(
240
+ "CLOUDFRONT_PREFIX_LIST_ID", "pl-93a247fa"
241
+ )
242
+ CLOUDFRONT_GEO_RESTRICTION = get_or_create_env_var(
243
+ "CLOUDFRONT_GEO_RESTRICTION", ""
244
+ ) # A country that Cloudfront restricts access to. See here: https://docs.aws.amazon.com/AmazonCloudFront/latest/DeveloperGuide/georestrictions.html
245
+ CLOUDFRONT_DISTRIBUTION_NAME = get_or_create_env_var(
246
+ "CLOUDFRONT_DISTRIBUTION_NAME", f"{CDK_PREFIX}CfDist"
247
+ )
248
+ CLOUDFRONT_DOMAIN = get_or_create_env_var(
249
+ "CLOUDFRONT_DOMAIN", "cloudfront_placeholder.net"
250
+ )
251
 
252
 
253
  # Certificate for Application load balancer (optional, for HTTPS and logins through the ALB)
254
+ ACM_SSL_CERTIFICATE_ARN = get_or_create_env_var("ACM_SSL_CERTIFICATE_ARN", "")
255
+ SSL_CERTIFICATE_DOMAIN = get_or_create_env_var(
256
+ "SSL_CERTIFICATE_DOMAIN", ""
257
+ ) # e.g. example.com or www.example.com
258
 
259
  # This should be the CloudFront domain, the domain linked to your ACM certificate, or the DNS of your application load balancer in console afterwards
260
  if USE_CLOUDFRONT == "True":
261
+ COGNITO_REDIRECTION_URL = get_or_create_env_var(
262
+ "COGNITO_REDIRECTION_URL", "https://" + CLOUDFRONT_DOMAIN
263
+ )
264
  elif SSL_CERTIFICATE_DOMAIN:
265
+ COGNITO_REDIRECTION_URL = get_or_create_env_var(
266
+ "COGNITO_REDIRECTION_URL", "https://" + SSL_CERTIFICATE_DOMAIN
267
+ )
268
  else:
269
+ COGNITO_REDIRECTION_URL = get_or_create_env_var(
270
+ "COGNITO_REDIRECTION_URL", "https://" + EXISTING_LOAD_BALANCER_DNS
271
+ )
272
 
273
  # Custom headers e.g. if routing traffic through Cloudfront
274
+ CUSTOM_HEADER = get_or_create_env_var(
275
+ "CUSTOM_HEADER", ""
276
+ ) # Retrieving or setting CUSTOM_HEADER
277
+ CUSTOM_HEADER_VALUE = get_or_create_env_var(
278
+ "CUSTOM_HEADER_VALUE", ""
279
+ ) # Retrieving or setting CUSTOM_HEADER_VALUE
280
 
281
  # Firewall on top of load balancer
282
+ LOAD_BALANCER_WEB_ACL_NAME = get_or_create_env_var(
283
+ "LOAD_BALANCER_WEB_ACL_NAME", f"{CDK_PREFIX}alb-web-acl"
284
+ )
285
 
286
  # Firewall on top of CloudFront
287
+ WEB_ACL_NAME = get_or_create_env_var("WEB_ACL_NAME", f"{CDK_PREFIX}cloudfront-web-acl")
288
 
289
  ###
290
  # File I/O options
291
  ###
292
 
293
+ OUTPUT_FOLDER = get_or_create_env_var("GRADIO_OUTPUT_FOLDER", "output/") # 'output/'
294
+ INPUT_FOLDER = get_or_create_env_var("GRADIO_INPUT_FOLDER", "input/") # 'input/'
295
 
296
  # Allow for files to be saved in a temporary folder for increased security in some instances
297
+ if OUTPUT_FOLDER == "TEMP" or INPUT_FOLDER == "TEMP":
298
  # Create a temporary directory
299
  with tempfile.TemporaryDirectory() as temp_dir:
300
+ print(f"Temporary directory created at: {temp_dir}")
301
 
302
+ if OUTPUT_FOLDER == "TEMP":
303
+ OUTPUT_FOLDER = temp_dir + "/"
304
+ if INPUT_FOLDER == "TEMP":
305
+ INPUT_FOLDER = temp_dir + "/"
306
 
307
  ###
308
  # LOGGING OPTIONS
309
  ###
310
 
311
+ SAVE_LOGS_TO_CSV = get_or_create_env_var("SAVE_LOGS_TO_CSV", "True")
312
 
313
  ### DYNAMODB logs. Whether to save to DynamoDB, and the headers of the table
314
+ SAVE_LOGS_TO_DYNAMODB = get_or_create_env_var("SAVE_LOGS_TO_DYNAMODB", "True")
315
+ ACCESS_LOG_DYNAMODB_TABLE_NAME = get_or_create_env_var(
316
+ "ACCESS_LOG_DYNAMODB_TABLE_NAME", f"{CDK_PREFIX}dynamodb-access-logs".lower()
317
+ )
318
+ FEEDBACK_LOG_DYNAMODB_TABLE_NAME = get_or_create_env_var(
319
+ "FEEDBACK_LOG_DYNAMODB_TABLE_NAME", f"{CDK_PREFIX}dynamodb-feedback-logs".lower()
320
+ )
321
+ USAGE_LOG_DYNAMODB_TABLE_NAME = get_or_create_env_var(
322
+ "USAGE_LOG_DYNAMODB_TABLE_NAME", f"{CDK_PREFIX}dynamodb-usage-logs".lower()
323
+ )
324
 
325
  ###
326
  # REDACTION OPTIONS
327
  ###
328
 
329
  # Get some environment variables and Launch the Gradio app
330
+ COGNITO_AUTH = get_or_create_env_var("COGNITO_AUTH", "0")
331
 
332
+ GRADIO_SERVER_PORT = int(get_or_create_env_var("GRADIO_SERVER_PORT", "7860"))
333
 
334
  ###
335
  # WHOLE DOCUMENT API OPTIONS
336
  ###
337
 
338
+ DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS = get_or_create_env_var(
339
+ "DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS", "7"
340
+ ) # How many days into the past should whole document Textract jobs be displayed? After that, the data is not deleted from the Textract jobs csv, but it is just filtered out. Included to align with S3 buckets where the file outputs will be automatically deleted after X days.
cdk/cdk_functions.py CHANGED
@@ -1,34 +1,42 @@
1
- import boto3
2
- from botocore.exceptions import ClientError
3
  import json
4
  import os
 
 
 
5
  import pandas as pd
6
- import ipaddress
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  from constructs import Construct
8
  from dotenv import set_key
9
- from typing import List, Tuple, Optional, Dict, Any
10
- from aws_cdk import (
11
- App,
12
- CfnTag,
13
- aws_ec2 as ec2,
14
- aws_wafv2 as wafv2,
15
- aws_elasticloadbalancingv2 as elb,
16
- aws_elasticloadbalancingv2_actions as elb_act,
17
- aws_certificatemanager as acm, # You might need this if you were looking up a cert, but not strictly for ARN
18
- aws_cognito as cognito,
19
- aws_iam as iam,
20
- CfnOutput,
21
- Tags
22
- )
23
 
24
 
25
-
26
- from cdk_config import PUBLIC_SUBNETS_TO_USE, PRIVATE_SUBNETS_TO_USE, PUBLIC_SUBNET_CIDR_BLOCKS, PRIVATE_SUBNET_CIDR_BLOCKS, PUBLIC_SUBNET_AVAILABILITY_ZONES, PRIVATE_SUBNET_AVAILABILITY_ZONES, POLICY_FILE_LOCATIONS, NAT_GATEWAY_EIP_NAME, S3_LOG_CONFIG_BUCKET_NAME, S3_OUTPUT_BUCKET_NAME, ACCESS_LOG_DYNAMODB_TABLE_NAME, FEEDBACK_LOG_DYNAMODB_TABLE_NAME, USAGE_LOG_DYNAMODB_TABLE_NAME, AWS_REGION
27
-
28
  # --- Function to load context from file ---
29
  def load_context_from_file(app: App, file_path: str):
30
  if os.path.exists(file_path):
31
- with open(file_path, 'r') as f:
32
  context_data = json.load(f)
33
  for key, value in context_data.items():
34
  app.node.set_context(key, value)
@@ -36,35 +44,47 @@ def load_context_from_file(app: App, file_path: str):
36
  else:
37
  print(f"Context file not found: {file_path}")
38
 
 
39
  # --- Helper to parse environment variables into lists ---
40
  def _get_env_list(env_var_name: str) -> List[str]:
41
  """Parses a comma-separated environment variable into a list of strings."""
42
- value = env_var_name[1:-1].strip().replace('\"', '').replace("\'","")
43
  if not value:
44
  return []
45
  # Split by comma and filter out any empty strings that might result from extra commas
46
- return [s.strip() for s in value.split(',') if s.strip()]
 
47
 
48
  # 1. Try to load CIDR/AZs from environment variables
49
- if PUBLIC_SUBNETS_TO_USE: PUBLIC_SUBNETS_TO_USE = _get_env_list(PUBLIC_SUBNETS_TO_USE)
50
- if PRIVATE_SUBNETS_TO_USE: PRIVATE_SUBNETS_TO_USE = _get_env_list(PRIVATE_SUBNETS_TO_USE)
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
- if PUBLIC_SUBNET_CIDR_BLOCKS: PUBLIC_SUBNET_CIDR_BLOCKS = _get_env_list("PUBLIC_SUBNET_CIDR_BLOCKS")
53
- if PUBLIC_SUBNET_AVAILABILITY_ZONES: PUBLIC_SUBNET_AVAILABILITY_ZONES = _get_env_list("PUBLIC_SUBNET_AVAILABILITY_ZONES")
54
- if PRIVATE_SUBNET_CIDR_BLOCKS: PRIVATE_SUBNET_CIDR_BLOCKS = _get_env_list("PRIVATE_SUBNET_CIDR_BLOCKS")
55
- if PRIVATE_SUBNET_AVAILABILITY_ZONES: PRIVATE_SUBNET_AVAILABILITY_ZONES = _get_env_list("PRIVATE_SUBNET_AVAILABILITY_ZONES")
56
 
57
- if POLICY_FILE_LOCATIONS: POLICY_FILE_LOCATIONS = _get_env_list(POLICY_FILE_LOCATIONS)
58
 
59
- def check_for_existing_role(role_name:str):
60
  try:
61
- iam = boto3.client('iam')
62
- #iam.get_role(RoleName=role_name)
63
-
64
  response = iam.get_role(RoleName=role_name)
65
- role = response['Role']['Arn']
66
 
67
- print("Response Role:", role)
68
 
69
  return True, role, ""
70
  except iam.exceptions.NoSuchEntityException:
@@ -72,12 +92,8 @@ def check_for_existing_role(role_name:str):
72
  except Exception as e:
73
  raise Exception("Getting information on IAM role failed due to:", e)
74
 
75
- import json
76
- from typing import List, Dict, Any, Union, Optional
77
- from aws_cdk import (
78
- aws_iam as iam,
79
- )
80
- from constructs import Construct
81
 
82
  # Assume POLICY_FILE_LOCATIONS is defined globally or passed as a default
83
  # For example:
@@ -93,11 +109,13 @@ def add_statement_to_policy(role: iam.IRole, policy_document: Dict[str, Any]):
93
  policy_document: A Python dictionary representing an IAM policy document.
94
  """
95
  # Ensure the loaded JSON is a valid policy document structure
96
- if 'Statement' not in policy_document or not isinstance(policy_document['Statement'], list):
97
- print(f"Warning: Policy document does not contain a 'Statement' list. Skipping.")
98
- return # Do not return role, just log and exit
 
 
99
 
100
- for statement_dict in policy_document['Statement']:
101
  try:
102
  # Create a CDK PolicyStatement from the dictionary
103
  cdk_policy_statement = iam.PolicyStatement.from_json(statement_dict)
@@ -106,13 +124,16 @@ def add_statement_to_policy(role: iam.IRole, policy_document: Dict[str, Any]):
106
  role.add_to_policy(cdk_policy_statement)
107
  print(f" - Added statement: {statement_dict.get('Sid', 'No Sid')}")
108
  except Exception as e:
109
- print(f"Warning: Could not process policy statement: {statement_dict}. Error: {e}")
 
 
 
110
 
111
  def add_custom_policies(
112
- scope: Construct, # Not strictly used here, but good practice if you expand to ManagedPolicies
113
  role: iam.IRole,
114
  policy_file_locations: Optional[List[str]] = None,
115
- custom_policy_text: Optional[str] = None
116
  ) -> iam.IRole:
117
  """
118
  Loads custom policies from JSON files or a string and attaches them to a CDK Role.
@@ -129,7 +150,7 @@ def add_custom_policies(
129
  if policy_file_locations is None:
130
  policy_file_locations = []
131
 
132
- current_source = "unknown source" # For error messages
133
 
134
  try:
135
  if policy_file_locations:
@@ -137,20 +158,26 @@ def add_custom_policies(
137
  for path in policy_file_locations:
138
  current_source = f"file: {path}"
139
  try:
140
- with open(path, 'r') as f:
141
  policy_document = json.load(f)
142
  print(f"Processing policy from {current_source}...")
143
  add_statement_to_policy(role, policy_document)
144
  except FileNotFoundError:
145
  print(f"Warning: Policy file not found at {path}. Skipping.")
146
  except json.JSONDecodeError as e:
147
- print(f"Warning: Invalid JSON in policy file {path}: {e}. Skipping.")
 
 
148
  except Exception as e:
149
- print(f"An unexpected error occurred processing policy from {path}: {e}. Skipping.")
 
 
150
 
151
  if custom_policy_text:
152
  current_source = "custom policy text string"
153
- print(f"Attempting to add policy from custom text to role {role.node.id}...")
 
 
154
  try:
155
  # *** FIX: Parse the JSON string into a Python dictionary ***
156
  policy_document = json.loads(custom_policy_text)
@@ -159,20 +186,28 @@ def add_custom_policies(
159
  except json.JSONDecodeError as e:
160
  print(f"Warning: Invalid JSON in custom_policy_text: {e}. Skipping.")
161
  except Exception as e:
162
- print(f"An unexpected error occurred processing policy from custom_policy_text: {e}. Skipping.")
 
 
163
 
164
  # You might want a final success message, but individual processing messages are also good.
165
  print(f"Finished processing custom policies for role {role.node.id}.")
166
 
167
  except Exception as e:
168
- print(f"An unhandled error occurred during policy addition for {current_source}: {e}")
 
 
169
 
170
  return role
171
 
 
172
  # Import the S3 Bucket class if you intend to return a CDK object later
173
  # from aws_cdk import aws_s3 as s3
174
 
175
- def check_s3_bucket_exists(bucket_name: str): # Return type hint depends on what you return
 
 
 
176
  """
177
  Checks if an S3 bucket with the given name exists and is accessible.
178
 
@@ -186,72 +221,89 @@ def check_s3_bucket_exists(bucket_name: str): # Return type hint depends on what
186
  or the bucket name for CDK lookups/creations.
187
  For this example, let's return the boolean and the name.
188
  """
189
- s3_client = boto3.client('s3')
190
  try:
191
  # Use head_bucket to check for existence and access
192
  s3_client.head_bucket(Bucket=bucket_name)
193
  print(f"Bucket '{bucket_name}' exists and is accessible.")
194
- return True, bucket_name # Return True and the bucket name
195
 
196
  except ClientError as e:
197
  # If a ClientError occurs, check the error code.
198
  # '404' means the bucket does not exist.
199
  # '403' means the bucket exists but you don't have permission.
200
- error_code = e.response['Error']['Code']
201
- if error_code == '404':
202
  print(f"Bucket '{bucket_name}' does not exist.")
203
  return False, None
204
- elif error_code == '403':
205
- # The bucket exists, but you can't access it.
206
- # Depending on your requirements, this might be treated as "exists"
207
- # or "not accessible for our purpose". For checking existence,
208
- # we'll say it exists here, but note the permission issue.
209
- # NOTE - when I tested this, it was returning 403 even for buckets that don't exist. So I will return False instead
210
- print(f"Bucket '{bucket_name}' returned 403, which indicates it may exist but is not accessible due to permissions, or that it doesn't exist. Returning False for existence just in case.")
211
- return False, bucket_name # It exists, even if not accessible
 
 
212
  else:
213
  # For other errors, it's better to raise the exception
214
  # to indicate something unexpected happened.
215
- print(f"An unexpected AWS ClientError occurred checking bucket '{bucket_name}': {e}")
 
 
216
  # Decide how to handle other errors - raising might be safer
217
- raise # Re-raise the original exception
218
  except Exception as e:
219
- print(f"An unexpected non-ClientError occurred checking bucket '{bucket_name}': {e}")
 
 
220
  # Decide how to handle other errors
221
- raise # Re-raise the original exception
 
222
 
223
  # Example usage in your check_resources.py:
224
  # exists, bucket_name_if_exists = check_s3_bucket_exists(log_bucket_name)
225
  # context_data[f"exists:{log_bucket_name}"] = exists
226
  # # You don't necessarily need to store the name in context if using from_bucket_name
227
 
 
228
  # Delete an S3 bucket
229
- def delete_s3_bucket(bucket_name:str):
230
- s3 = boto3.client('s3')
231
-
232
  try:
233
  # List and delete all objects
234
  response = s3.list_object_versions(Bucket=bucket_name)
235
- versions = response.get('Versions', []) + response.get('DeleteMarkers', [])
236
  for version in versions:
237
- s3.delete_object(Bucket=bucket_name, Key=version['Key'], VersionId=version['VersionId'])
238
-
 
 
239
  # Delete the bucket
240
  s3.delete_bucket(Bucket=bucket_name)
241
- return {'Status': 'SUCCESS'}
242
  except Exception as e:
243
- return {'Status': 'FAILED', 'Reason': str(e)}
 
244
 
245
  # Function to get subnet ID from subnet name
246
- def get_subnet_id(vpc:str, ec2_client:str, subnet_name:str):
247
- response = ec2_client.describe_subnets(Filters=[{'Name': 'vpc-id', 'Values': [vpc.vpc_id]}])
248
 
249
- for subnet in response['Subnets']:
250
- if subnet['Tags'] and any(tag['Key'] == 'Name' and tag['Value'] == subnet_name for tag in subnet['Tags']):
251
- return subnet['SubnetId']
252
-
253
  return None
254
 
 
255
  def check_ecr_repo_exists(repo_name: str) -> tuple[bool, dict]:
256
  """
257
  Checks if an ECR repository with the given name exists.
@@ -262,16 +314,16 @@ def check_ecr_repo_exists(repo_name: str) -> tuple[bool, dict]:
262
  Returns:
263
  True if the repository exists, False otherwise.
264
  """
265
- ecr_client = boto3.client('ecr')
266
  try:
267
  print("ecr repo_name to check:", repo_name)
268
  response = ecr_client.describe_repositories(repositoryNames=[repo_name])
269
  # If describe_repositories succeeds and returns a list of repositories,
270
  # and the list is not empty, the repository exists.
271
- return len(response['repositories']) > 0, response['repositories'][0]
272
  except ClientError as e:
273
  # Check for the specific error code indicating the repository doesn't exist
274
- if e.response['Error']['Code'] == 'RepositoryNotFoundException':
275
  return False, {}
276
  else:
277
  # Re-raise other exceptions to handle unexpected errors
@@ -279,8 +331,11 @@ def check_ecr_repo_exists(repo_name: str) -> tuple[bool, dict]:
279
  except Exception as e:
280
  print(f"An unexpected error occurred: {e}")
281
  return False, {}
282
-
283
- def check_codebuild_project_exists(project_name: str): # Adjust return type hint as needed
 
 
 
284
  """
285
  Checks if a CodeBuild project with the given name exists.
286
 
@@ -293,27 +348,35 @@ def check_codebuild_project_exists(project_name: str): # Adjust return type hint
293
  - The second element is the project ARN if found,
294
  None otherwise.
295
  """
296
- codebuild_client = boto3.client('codebuild')
297
  try:
298
  # Use batch_get_projects with a list containing the single project name
299
  response = codebuild_client.batch_get_projects(names=[project_name])
300
 
301
  # The response for batch_get_projects includes 'projects' (found)
302
  # and 'projectsNotFound' (not found).
303
- if response['projects']:
304
  # If the project is found in the 'projects' list
305
  print(f"CodeBuild project '{project_name}' found.")
306
- return True, response['projects'][0]['arn'] # Return True and the project details dict
307
- elif response['projectsNotFound'] and project_name in response['projectsNotFound']:
308
- # If the project name is explicitly in the 'projectsNotFound' list
309
- print(f"CodeBuild project '{project_name}' not found.")
310
- return False, None
 
 
 
 
 
 
311
  else:
312
  # This case is less expected for a single name lookup,
313
  # but could happen if there's an internal issue or the response
314
  # structure is slightly different than expected for an error.
315
  # It's safer to assume it wasn't found if not in 'projects'.
316
- print(f"CodeBuild project '{project_name}' not found (not in 'projects' list).")
 
 
317
  return False, None
318
 
319
  except ClientError as e:
@@ -321,50 +384,53 @@ def check_codebuild_project_exists(project_name: str): # Adjust return type hint
321
  # 'InvalidInputException' for a non-existent project name if the
322
  # name format is valid. It typically just lists it in projectsNotFound.
323
  # However, other ClientErrors are possible (e.g., permissions).
324
- print(f"An AWS ClientError occurred checking CodeBuild project '{project_name}': {e}")
 
 
325
  # Decide how to handle other ClientErrors - raising might be safer
326
- raise # Re-raise the original exception
327
  except Exception as e:
328
- print(f"An unexpected non-ClientError occurred checking CodeBuild project '{project_name}': {e}")
 
 
329
  # Decide how to handle other errors
330
- raise # Re-raise the original exception
 
331
 
332
  def get_vpc_id_by_name(vpc_name: str) -> Optional[str]:
333
  """
334
  Finds a VPC ID by its 'Name' tag.
335
  """
336
- ec2_client = boto3.client('ec2')
337
  try:
338
  response = ec2_client.describe_vpcs(
339
- Filters=[
340
- {'Name': 'tag:Name', 'Values': [vpc_name]}
341
- ]
342
  )
343
- if response and response['Vpcs']:
344
- vpc_id = response['Vpcs'][0]['VpcId']
345
  print(f"VPC '{vpc_name}' found with ID: {vpc_id}")
346
 
347
  # In get_vpc_id_by_name, after finding VPC ID:
348
 
349
  # Look for NAT Gateways in this VPC
350
- ec2_client = boto3.client('ec2')
351
  nat_gateways = []
352
  try:
353
  response = ec2_client.describe_nat_gateways(
354
  Filters=[
355
- {'Name': 'vpc-id', 'Values': [vpc_id]},
356
  # Optional: Add a tag filter if you consistently tag your NATs
357
  # {'Name': 'tag:Name', 'Values': [f"{prefix}-nat-gateway"]}
358
  ]
359
  )
360
- nat_gateways = response.get('NatGateways', [])
361
  except Exception as e:
362
- print(f"Warning: Could not describe NAT Gateways in VPC '{vpc_id}': {e}")
 
 
363
  # Decide how to handle this error - proceed or raise?
364
 
365
  # Decide how to identify the specific NAT Gateway you want to check for.
366
-
367
-
368
 
369
  return vpc_id, nat_gateways
370
  else:
@@ -374,6 +440,7 @@ def get_vpc_id_by_name(vpc_name: str) -> Optional[str]:
374
  print(f"An unexpected error occurred finding VPC '{vpc_name}': {e}")
375
  raise
376
 
 
377
  # --- Helper to fetch all existing subnets in a VPC once ---
378
  def _get_existing_subnets_in_vpc(vpc_id: str) -> Dict[str, Any]:
379
  """
@@ -381,21 +448,26 @@ def _get_existing_subnets_in_vpc(vpc_id: str) -> Dict[str, Any]:
381
  Returns a dictionary with 'by_name' (map of name to subnet data),
382
  'by_id' (map of id to subnet data), and 'cidr_networks' (list of ipaddress.IPv4Network).
383
  """
384
- ec2_client = boto3.client('ec2')
385
  existing_subnets_data = {
386
  "by_name": {}, # {subnet_name: {'id': 'subnet-id', 'cidr': 'x.x.x.x/x'}}
387
- "by_id": {}, # {subnet_id: {'name': 'subnet-name', 'cidr': 'x.x.x.x/x'}}
388
- "cidr_networks": [] # List of ipaddress.IPv4Network objects
389
  }
390
  try:
391
- response = ec2_client.describe_subnets(Filters=[{'Name': 'vpc-id', 'Values': [vpc_id]}])
392
- for s in response.get('Subnets', []):
393
- subnet_id = s['SubnetId']
394
- cidr_block = s.get('CidrBlock')
 
 
395
  # Extract 'Name' tag, which is crucial for lookup by name
396
- name_tag = next((tag['Value'] for tag in s.get('Tags', []) if tag['Key'] == 'Name'), None)
 
 
 
397
 
398
- subnet_info = {'id': subnet_id, 'cidr': cidr_block, 'name': name_tag}
399
 
400
  if name_tag:
401
  existing_subnets_data["by_name"][name_tag] = subnet_info
@@ -403,22 +475,35 @@ def _get_existing_subnets_in_vpc(vpc_id: str) -> Dict[str, Any]:
403
 
404
  if cidr_block:
405
  try:
406
- existing_subnets_data["cidr_networks"].append(ipaddress.ip_network(cidr_block, strict=False))
 
 
407
  except ValueError:
408
- print(f"Warning: Existing subnet {subnet_id} has an invalid CIDR: {cidr_block}. Skipping for overlap check.")
 
 
409
 
410
- print(f"Fetched {len(response.get('Subnets', []))} existing subnets from VPC '{vpc_id}'.")
 
 
411
  except Exception as e:
412
- print(f"Error describing existing subnets in VPC '{vpc_id}': {e}. Cannot perform full validation.")
413
- raise # Re-raise if this essential step fails
 
 
414
 
415
  return existing_subnets_data
416
 
 
417
  # --- Modified validate_subnet_creation_parameters to take pre-fetched data ---
418
  def validate_subnet_creation_parameters(
419
  vpc_id: str,
420
- proposed_subnets_data: List[Dict[str, str]], # e.g., [{'name': 'my-public-subnet', 'cidr': '10.0.0.0/24', 'az': 'us-east-1a'}]
421
- existing_aws_subnets_data: Dict[str, Any] # Pre-fetched data from _get_existing_subnets_in_vpc
 
 
 
 
422
  ) -> None:
423
  """
424
  Validates proposed subnet names and CIDR blocks against existing AWS subnets
@@ -440,9 +525,11 @@ def validate_subnet_creation_parameters(
440
  print("No proposed subnet data provided for validation. Skipping.")
441
  return
442
 
443
- print(f"--- Starting pre-synth validation for VPC '{vpc_id}' with proposed subnets ---")
 
 
444
 
445
- print("Existing subnet data:", pd.DataFrame(existing_aws_subnets_data['by_name']))
446
 
447
  existing_aws_subnet_names = set(existing_aws_subnets_data["by_name"].keys())
448
  existing_aws_cidr_networks = existing_aws_subnets_data["cidr_networks"]
@@ -452,27 +539,35 @@ def validate_subnet_creation_parameters(
452
  proposed_cidr_networks_seen: List[ipaddress.IPv4Network] = []
453
 
454
  for i, proposed_subnet in enumerate(proposed_subnets_data):
455
- subnet_name = proposed_subnet.get('name')
456
- cidr_block_str = proposed_subnet.get('cidr')
457
- availability_zone = proposed_subnet.get('az')
458
 
459
  if not all([subnet_name, cidr_block_str, availability_zone]):
460
- raise ValueError(f"Proposed subnet at index {i} is incomplete. Requires 'name', 'cidr', and 'az'.")
 
 
461
 
462
  # 1. Check for duplicate names within the proposed batch
463
  if subnet_name in proposed_names_seen:
464
- raise ValueError(f"Proposed subnet name '{subnet_name}' is duplicated within the input list.")
 
 
465
  proposed_names_seen.add(subnet_name)
466
 
467
  # 2. Check for duplicate names against existing AWS subnets
468
  if subnet_name in existing_aws_subnet_names:
469
- print(f"Proposed subnet name '{subnet_name}' already exists in VPC '{vpc_id}'.")
 
 
470
 
471
  # Parse proposed CIDR
472
  try:
473
  proposed_net = ipaddress.ip_network(cidr_block_str, strict=False)
474
  except ValueError as e:
475
- raise ValueError(f"Invalid CIDR format '{cidr_block_str}' for proposed subnet '{subnet_name}': {e}")
 
 
476
 
477
  # 3. Check for overlapping CIDRs within the proposed batch
478
  for existing_proposed_net in proposed_cidr_networks_seen:
@@ -494,14 +589,18 @@ def validate_subnet_creation_parameters(
494
 
495
  # If all checks pass for this subnet, add its network to the list for subsequent checks
496
  proposed_cidr_networks_seen.append(proposed_net)
497
- print(f"Validation successful for proposed subnet '{subnet_name}' with CIDR '{cidr_block_str}'.")
 
 
 
 
 
 
498
 
499
- print(f"--- All proposed subnets passed pre-synth validation checks for VPC '{vpc_id}'. ---")
500
 
501
  # --- Modified check_subnet_exists_by_name (Uses pre-fetched data) ---
502
  def check_subnet_exists_by_name(
503
- subnet_name: str,
504
- existing_aws_subnets_data: Dict[str, Any]
505
  ) -> Tuple[bool, Optional[str]]:
506
  """
507
  Checks if a subnet with the given name exists within the pre-fetched data.
@@ -519,51 +618,63 @@ def check_subnet_exists_by_name(
519
  subnet_info = existing_aws_subnets_data["by_name"].get(subnet_name)
520
  if subnet_info:
521
  print(f"Subnet '{subnet_name}' found with ID: {subnet_info['id']}")
522
- return True, subnet_info['id']
523
  else:
524
  print(f"Subnet '{subnet_name}' not found.")
525
  return False, None
526
 
 
527
  def create_nat_gateway(
528
  scope: Construct,
529
- public_subnet_for_nat: ec2.ISubnet, # Expects a proper ISubnet
530
  nat_gateway_name: str,
531
- nat_gateway_id_context_key: str
532
  ) -> str:
533
  """
534
  Creates a single NAT Gateway in the specified public subnet.
535
  It does not handle lookup from context; the calling stack should do that.
536
  Returns the CloudFormation Ref of the NAT Gateway ID.
537
  """
538
- print(f"Defining a new NAT Gateway '{nat_gateway_name}' in subnet '{public_subnet_for_nat.subnet_id}'.")
 
 
539
 
540
  # Create an Elastic IP for the NAT Gateway
541
- eip = ec2.CfnEIP(scope, NAT_GATEWAY_EIP_NAME,
542
- tags=[CfnTag(key="Name", value=NAT_GATEWAY_EIP_NAME)]
 
 
543
  )
544
 
545
  # Create the NAT Gateway
546
- nat_gateway_logical_id = nat_gateway_name.replace('-', '') + "NatGateway"
547
- nat_gateway = ec2.CfnNatGateway(scope, nat_gateway_logical_id,
 
 
548
  subnet_id=public_subnet_for_nat.subnet_id, # Associate with the public subnet
549
- allocation_id=eip.attr_allocation_id, # Associate with the EIP
550
- tags=[CfnTag(key="Name", value=nat_gateway_name)]
551
  )
552
  # The NAT GW depends on the EIP. The dependency on the subnet is implicit via subnet_id.
553
  nat_gateway.add_dependency(eip)
554
 
555
  # *** CRUCIAL: Use CfnOutput to export the ID after deployment ***
556
  # This is how you will get the ID to put into cdk.context.json
557
- CfnOutput(scope, "SingleNatGatewayIdOutput",
 
 
558
  value=nat_gateway.ref,
559
  description=f"Physical ID of the Single NAT Gateway. Add this to cdk.context.json under the key '{nat_gateway_id_context_key}'.",
560
- export_name=f"{scope.stack_name}-NatGatewayId" # Make export name unique
561
  )
562
 
563
- print(f"CDK: Defined new NAT Gateway '{nat_gateway.ref}'. Its physical ID will be available in the stack outputs after deployment.")
 
 
564
  # Return the tokenised reference for use within this synthesis
565
  return nat_gateway.ref
566
 
 
567
  def create_subnets(
568
  scope: Construct,
569
  vpc: ec2.IVpc,
@@ -573,7 +684,7 @@ def create_subnets(
573
  availability_zones: List[str],
574
  is_public: bool,
575
  internet_gateway_id: Optional[str] = None,
576
- single_nat_gateway_id: Optional[str] = None
577
  ) -> Tuple[List[ec2.CfnSubnet], List[ec2.CfnRouteTable]]:
578
  """
579
  Creates subnets using L2 constructs but returns the underlying L1 Cfn objects
@@ -581,11 +692,15 @@ def create_subnets(
581
  """
582
  # --- Validations remain the same ---
583
  if not (len(subnet_names) == len(cidr_blocks) == len(availability_zones) > 0):
584
- raise ValueError("Subnet names, CIDR blocks, and Availability Zones lists must be non-empty and match in length.")
 
 
585
  if is_public and not internet_gateway_id:
586
  raise ValueError("internet_gateway_id must be provided for public subnets.")
587
  if not is_public and not single_nat_gateway_id:
588
- raise ValueError("single_nat_gateway_id must be provided for private subnets when using a single NAT Gateway.")
 
 
589
 
590
  # --- We will populate these lists with the L1 objects to return ---
591
  created_subnets: List[ec2.CfnSubnet] = []
@@ -603,16 +718,16 @@ def create_subnets(
603
  vpc_id=vpc.vpc_id,
604
  cidr_block=cidr_blocks[i],
605
  availability_zone=availability_zones[i],
606
- map_public_ip_on_launch=is_public
607
  )
608
  Tags.of(subnet).add("Name", subnet_name)
609
  Tags.of(subnet).add("Type", subnet_type_tag)
610
-
611
  if is_public:
612
  # The subnet's route_table is automatically created by the L2 Subnet construct
613
  try:
614
  subnet.add_route(
615
- "DefaultInternetRoute", # A logical ID for the CfnRoute resource
616
  router_id=internet_gateway_id,
617
  router_type=ec2.RouterType.GATEWAY,
618
  # destination_cidr_block="0.0.0.0/0" is the default for this method
@@ -624,22 +739,25 @@ def create_subnets(
624
  try:
625
  # Using .add_route() for private subnets as well for consistency
626
  subnet.add_route(
627
- "DefaultNatRoute", # A logical ID for the CfnRoute resource
628
  router_id=single_nat_gateway_id,
629
  router_type=ec2.RouterType.NAT_GATEWAY,
630
  )
631
  except Exception as e:
632
  print("Could not create NAT gateway route for public subnet due to:", e)
633
- print(f"CDK: Defined private L2 subnet '{subnet_name}' and added NAT GW route.")
 
 
634
 
635
  route_table = subnet.route_table
636
-
637
  created_subnets.append(subnet)
638
  created_route_tables.append(route_table)
639
 
640
  return created_subnets, created_route_tables
641
-
642
- def ingress_rule_exists(security_group:str, peer:str, port:str):
 
643
  for rule in security_group.connections.security_groups:
644
  if port:
645
  if rule.peer == peer and rule.connection == port:
@@ -649,26 +767,32 @@ def ingress_rule_exists(security_group:str, peer:str, port:str):
649
  return True
650
  return False
651
 
652
- def check_for_existing_user_pool(user_pool_name:str):
 
653
  cognito_client = boto3.client("cognito-idp")
654
- list_pools_response = cognito_client.list_user_pools(MaxResults=60) # MaxResults up to 60
655
-
 
 
656
  # ListUserPools might require pagination if you have more than 60 pools
657
  # This simple example doesn't handle pagination, which could miss your pool
658
 
659
  existing_user_pool_id = ""
660
 
661
- for pool in list_pools_response.get('UserPools', []):
662
- if pool.get('Name') == user_pool_name:
663
- existing_user_pool_id = pool['Id']
664
- print(f"Found existing user pool by name '{user_pool_name}' with ID: {existing_user_pool_id}")
665
- break # Found the one we're looking for
 
 
666
 
667
  if existing_user_pool_id:
668
  return True, existing_user_pool_id, pool
669
  else:
670
  return False, "", ""
671
-
 
672
  def check_for_existing_user_pool_client(user_pool_id: str, user_pool_client_name: str):
673
  """
674
  Checks if a Cognito User Pool Client with the given name exists in the specified User Pool.
@@ -683,39 +807,39 @@ def check_for_existing_user_pool_client(user_pool_id: str, user_pool_client_name
683
  - False, "", {} otherwise.
684
  """
685
  cognito_client = boto3.client("cognito-idp")
686
- next_token = 'string'
687
-
688
 
689
  while True:
690
  try:
691
  response = cognito_client.list_user_pool_clients(
692
- UserPoolId=user_pool_id,
693
- MaxResults=60,
694
- NextToken=next_token
695
  )
696
  except cognito_client.exceptions.ResourceNotFoundException:
697
  print(f"Error: User pool with ID '{user_pool_id}' not found.")
698
  return False, "", {}
699
-
700
  except cognito_client.exceptions.InvalidParameterException:
701
  print(f"Error: No app clients for '{user_pool_id}' found.")
702
  return False, "", {}
703
-
704
  except Exception as e:
705
  print("Could not check User Pool clients due to:", e)
706
 
707
- for client in response.get('UserPoolClients', []):
708
- if client.get('ClientName') == user_pool_client_name:
709
- print(f"Found existing user pool client '{user_pool_client_name}' with ID: {client['ClientId']}")
710
- return True, client['ClientId'], client
 
 
711
 
712
- next_token = response.get('NextToken')
713
  if not next_token:
714
  break
715
 
716
  return False, "", {}
717
 
718
- def check_for_secret(secret_name: str, secret_value: dict=""):
 
719
  """
720
  Checks if a Secrets Manager secret with the given name exists.
721
  If it doesn't exist, it creates the secret.
@@ -741,8 +865,11 @@ def check_for_secret(secret_name: str, secret_value: dict=""):
741
  # Handle other potential exceptions during the get operation
742
  print(f"Error checking for secret '{secret_name}': {e}")
743
  return False, {}
744
-
745
- def check_alb_exists(load_balancer_name: str, region_name: str = None) -> tuple[bool, dict]:
 
 
 
746
  """
747
  Checks if an Application Load Balancer (ALB) with the given name exists.
748
 
@@ -759,18 +886,21 @@ def check_alb_exists(load_balancer_name: str, region_name: str = None) -> tuple[
759
  the LoadBalancers list from the describe_load_balancers response.
760
  """
761
  if region_name:
762
- elbv2_client = boto3.client('elbv2', region_name=region_name)
763
  else:
764
- elbv2_client = boto3.client('elbv2')
765
  try:
766
  response = elbv2_client.describe_load_balancers(Names=[load_balancer_name])
767
- if response['LoadBalancers']:
768
- return True, response['LoadBalancers'][0] # Return True and the first ALB object
 
 
 
769
  else:
770
  return False, {}
771
  except ClientError as e:
772
  # If the error indicates the ALB doesn't exist, return False
773
- if e.response['Error']['Code'] == 'LoadBalancerNotFound':
774
  return False, {}
775
  else:
776
  # Re-raise other exceptions
@@ -778,8 +908,11 @@ def check_alb_exists(load_balancer_name: str, region_name: str = None) -> tuple[
778
  except Exception as e:
779
  print(f"An unexpected error occurred: {e}")
780
  return False, {}
781
-
782
- def check_fargate_task_definition_exists(task_definition_name: str, region_name: str = None) -> tuple[bool, dict]:
 
 
 
783
  """
784
  Checks if a Fargate task definition with the given name exists.
785
 
@@ -796,17 +929,23 @@ def check_fargate_task_definition_exists(task_definition_name: str, region_name:
796
  taskDefinitions list from the describe_task_definition response.
797
  """
798
  if region_name:
799
- ecs_client = boto3.client('ecs', region_name=region_name)
800
  else:
801
- ecs_client = boto3.client('ecs')
802
  try:
803
- response = ecs_client.describe_task_definition(taskDefinition=task_definition_name)
 
 
804
  # If describe_task_definition succeeds, it returns the task definition.
805
  # We can directly return True and the task definition.
806
- return True, response['taskDefinition']
807
  except ClientError as e:
808
  # Check for the error code indicating the task definition doesn't exist.
809
- if e.response['Error']['Code'] == 'ClientException' and 'Task definition' in e.response['Message'] and 'does not exist' in e.response['Message']:
 
 
 
 
810
  return False, {}
811
  else:
812
  # Re-raise other exceptions.
@@ -814,8 +953,11 @@ def check_fargate_task_definition_exists(task_definition_name: str, region_name:
814
  except Exception as e:
815
  print(f"An unexpected error occurred: {e}")
816
  return False, {}
817
-
818
- def check_ecs_service_exists(cluster_name: str, service_name: str, region_name: str = None) -> tuple[bool, dict]:
 
 
 
819
  """
820
  Checks if an ECS service with the given name exists in the specified cluster.
821
 
@@ -832,20 +974,25 @@ def check_ecs_service_exists(cluster_name: str, service_name: str, region_name:
832
  None otherwise.
833
  """
834
  if region_name:
835
- ecs_client = boto3.client('ecs', region_name=region_name)
836
  else:
837
- ecs_client = boto3.client('ecs')
838
  try:
839
- response = ecs_client.describe_services(cluster=cluster_name, services=[service_name])
840
- if response['services']:
841
- return True, response['services'][0] # Return True and the first service object
 
 
 
 
 
842
  else:
843
  return False, {}
844
  except ClientError as e:
845
  # Check for the error code indicating the service doesn't exist.
846
- if e.response['Error']['Code'] == 'ClusterNotFoundException':
847
  return False, {}
848
- elif e.response['Error']['Code'] == 'ServiceNotFoundException':
849
  return False, {}
850
  else:
851
  # Re-raise other exceptions.
@@ -853,8 +1000,11 @@ def check_ecs_service_exists(cluster_name: str, service_name: str, region_name:
853
  except Exception as e:
854
  print(f"An unexpected error occurred: {e}")
855
  return False, {}
856
-
857
- def check_cloudfront_distribution_exists(distribution_name: str, region_name: str = None) -> tuple[bool, dict | None]:
 
 
 
858
  """
859
  Checks if a CloudFront distribution with the given name exists.
860
 
@@ -873,22 +1023,25 @@ def check_cloudfront_distribution_exists(distribution_name: str, region_name: st
873
  DistributionList from the ListDistributions response.
874
  """
875
  if region_name:
876
- cf_client = boto3.client('cloudfront', region_name=region_name)
877
  else:
878
- cf_client = boto3.client('cloudfront')
879
  try:
880
  response = cf_client.list_distributions()
881
- if 'Items' in response['DistributionList']:
882
- for distribution in response['DistributionList']['Items']:
883
  # CloudFront doesn't directly filter by name, so we have to iterate.
884
- if distribution['AliasSet']['Items'] and distribution['AliasSet']['Items'][0] == distribution_name:
 
 
 
885
  return True, distribution
886
  return False, None
887
  else:
888
  return False, None
889
  except ClientError as e:
890
  # If the error indicates the Distribution doesn't exist, return False
891
- if e.response['Error']['Code'] == 'NoSuchDistribution':
892
  return False, None
893
  else:
894
  # Re-raise other exceptions
@@ -897,19 +1050,22 @@ def check_cloudfront_distribution_exists(distribution_name: str, region_name: st
897
  print(f"An unexpected error occurred: {e}")
898
  return False, None
899
 
900
- def create_web_acl_with_common_rules(scope:Construct, web_acl_name: str, waf_scope:str="CLOUDFRONT"):
901
- '''
 
 
 
902
  Use CDK to create a web ACL based on an AWS common rule set with overrides.
903
  This function now expects a 'scope' argument, typically 'self' from your stack,
904
  as CfnWebACL requires a construct scope.
905
- '''
906
 
907
  # Create full list of rules
908
  rules = []
909
  aws_ruleset_names = [
910
  "AWSManagedRulesCommonRuleSet",
911
  "AWSManagedRulesKnownBadInputsRuleSet",
912
- "AWSManagedRulesAmazonIpReputationList"
913
  ]
914
 
915
  # Use a separate counter to assign unique priorities sequentially
@@ -917,7 +1073,7 @@ def create_web_acl_with_common_rules(scope:Construct, web_acl_name: str, waf_sco
917
 
918
  for aws_rule_name in aws_ruleset_names:
919
  current_rule_action_overrides = None
920
-
921
  # All managed rule groups need an override_action.
922
  # 'none' means use the managed rule group's default action.
923
  current_override_action = wafv2.CfnWebACL.OverrideActionProperty(none={})
@@ -929,9 +1085,7 @@ def create_web_acl_with_common_rules(scope:Construct, web_acl_name: str, waf_sco
929
  current_rule_action_overrides = [
930
  wafv2.CfnWebACL.RuleActionOverrideProperty(
931
  name="SizeRestrictions_BODY",
932
- action_to_use=wafv2.CfnWebACL.RuleActionProperty(
933
- allow={}
934
- )
935
  )
936
  ]
937
  # No need to set current_override_action here, it's already set above.
@@ -945,39 +1099,38 @@ def create_web_acl_with_common_rules(scope:Construct, web_acl_name: str, waf_sco
945
  managed_rule_group_statement=wafv2.CfnWebACL.ManagedRuleGroupStatementProperty(
946
  vendor_name="AWS",
947
  name=aws_rule_name,
948
- rule_action_overrides=current_rule_action_overrides
949
  )
950
  ),
951
  visibility_config=wafv2.CfnWebACL.VisibilityConfigProperty(
952
  cloud_watch_metrics_enabled=True,
953
  metric_name=aws_rule_name,
954
- sampled_requests_enabled=True
955
  ),
956
- override_action=current_override_action # THIS IS THE CRUCIAL PART FOR ALL MANAGED RULES
957
  )
958
 
959
  rules.append(rule_property)
960
 
961
  # Add the rate limit rule
962
- rate_limit_priority = priority_counter # Use the next available priority
963
- rules.append(wafv2.CfnWebACL.RuleProperty(
964
- name="RateLimitRule",
965
- priority=rate_limit_priority,
966
- statement=wafv2.CfnWebACL.StatementProperty(
967
- rate_based_statement=wafv2.CfnWebACL.RateBasedStatementProperty(
968
- limit=1000,
969
- aggregate_key_type="IP"
970
- )
971
- ),
972
- visibility_config=wafv2.CfnWebACL.VisibilityConfigProperty(
973
- cloud_watch_metrics_enabled=True,
974
- metric_name="RateLimitRule",
975
- sampled_requests_enabled=True
976
- ),
977
- action=wafv2.CfnWebACL.RuleActionProperty(
978
- block={}
979
  )
980
- ))
981
 
982
  web_acl = wafv2.CfnWebACL(
983
  scope,
@@ -988,16 +1141,19 @@ def create_web_acl_with_common_rules(scope:Construct, web_acl_name: str, waf_sco
988
  visibility_config=wafv2.CfnWebACL.VisibilityConfigProperty(
989
  cloud_watch_metrics_enabled=True,
990
  metric_name="webACL",
991
- sampled_requests_enabled=True
992
  ),
993
- rules=rules
994
  )
995
 
996
  CfnOutput(scope, "WebACLArn", value=web_acl.attr_arn)
997
 
998
  return web_acl
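 # Example usage from a CloudFront-facing stack (construct ID is hypothetical); the resulting
 # web_acl.attr_arn can then be passed to the distribution's web ACL setting:
 # web_acl = create_web_acl_with_common_rules(self, "redaction-app-web-acl", waf_scope="CLOUDFRONT")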
999
-
1000
- def check_web_acl_exists(web_acl_name: str, scope: str, region_name: str = None) -> tuple[bool, dict]:
 
 
 
1001
  """
1002
  Checks if a Web ACL with the given name and scope exists.
1003
 
@@ -1014,33 +1170,35 @@ def check_web_acl_exists(web_acl_name: str, scope: str, region_name: str = None)
1014
  - The second element is the Web ACL object (dictionary) if found,
1015
  None otherwise.
1016
  """
1017
- if scope not in ['CLOUDFRONT', 'REGIONAL']:
1018
  raise ValueError("Scope must be either 'CLOUDFRONT' or 'REGIONAL'")
1019
 
1020
- if scope == 'REGIONAL' and not region_name:
1021
  raise ValueError("Region name is required for REGIONAL scope")
1022
 
1023
- if scope == 'CLOUDFRONT':
1024
- region_name = 'us-east-1' # CloudFront scope requires us-east-1
1025
-
1026
  if region_name:
1027
- waf_client = boto3.client('wafv2', region_name=region_name)
1028
  else:
1029
- waf_client = boto3.client('wafv2')
1030
  try:
1031
  response = waf_client.list_web_acls(Scope=scope)
1032
- if 'WebACLs' in response:
1033
- for web_acl in response['WebACLs']:
1034
- if web_acl['Name'] == web_acl_name:
1035
  # Describe the Web ACL to get the full object.
1036
- describe_response = waf_client.describe_web_acl(Name=web_acl_name, Scope=scope)
1037
- return True, describe_response['WebACL']
 
 
1038
  return False, {}
1039
  else:
1040
  return False, {}
1041
  except ClientError as e:
1042
  # Check for the error code indicating the web ACL doesn't exist.
1043
- if e.response['Error']['Code'] == 'ResourceNotFoundException':
1044
  return False, {}
1045
  else:
1046
  # Re-raise other exceptions.
@@ -1048,23 +1206,30 @@ def check_web_acl_exists(web_acl_name: str, scope: str, region_name: str = None)
1048
  except Exception as e:
1049
  print(f"An unexpected error occurred: {e}")
1050
  return False, {}
1051
-
 
1052
  def add_alb_https_listener_with_cert(
1053
  scope: Construct,
1054
- logical_id: str, # A unique ID for this listener construct
1055
  alb: elb.ApplicationLoadBalancer,
1056
- acm_certificate_arn: Optional[str], # Optional: If None, no HTTPS listener will be created
1057
- default_target_group: elb.ITargetGroup, # Mandatory: The target group to forward traffic to
 
 
1058
  listener_port_https: int = 443,
1059
- listener_open_to_internet: bool = False, # Be cautious with True, ensure ALB security group restricts access
1060
  # --- Cognito Authentication Parameters ---
1061
  enable_cognito_auth: bool = False,
1062
  cognito_user_pool: Optional[cognito.IUserPool] = None,
1063
  cognito_user_pool_client: Optional[cognito.IUserPoolClient] = None,
1064
- cognito_user_pool_domain: Optional[str] = None, # E.g., "my-app-domain" for "my-app-domain.auth.region.amazoncognito.com"
1065
- cognito_auth_scope: Optional[str] = "openid profile email", # Default recommended scope
 
 
 
 
1066
  cognito_auth_on_unauthenticated_request: elb.UnauthenticatedAction = elb.UnauthenticatedAction.AUTHENTICATE,
1067
- stickiness_cookie_duration=None
1068
  # --- End Cognito Parameters ---
1069
  ) -> Optional[elb.ApplicationListener]:
1070
  """
@@ -1098,25 +1263,33 @@ def add_alb_https_listener_with_cert(
1098
  https_listener = None
1099
  if acm_certificate_arn:
1100
  certificates_list = [elb.ListenerCertificate.from_arn(acm_certificate_arn)]
1101
- print(f"Attempting to add ALB HTTPS listener on port {listener_port_https} with ACM certificate: {acm_certificate_arn}")
 
 
1102
 
1103
  # Determine the default action based on whether Cognito auth is enabled
1104
  default_action = None
1105
- if enable_cognito_auth == True:
1106
- if not all([cognito_user_pool, cognito_user_pool_client, cognito_user_pool_domain]):
 
 
1107
  raise ValueError(
1108
  "Cognito User Pool, Client, and Domain must be provided if enable_cognito_auth is True."
1109
  )
1110
- print(f"Enabling Cognito authentication with User Pool: {cognito_user_pool.user_pool_id}")
 
 
1111
 
1112
  default_action = elb_act.AuthenticateCognitoAction(
1113
- next=elb.ListenerAction.forward([default_target_group]), # After successful auth, forward to TG
 
 
1114
  user_pool=cognito_user_pool,
1115
  user_pool_client=cognito_user_pool_client,
1116
- user_pool_domain=cognito_user_pool_domain,
1117
  scope=cognito_auth_scope,
1118
- on_unauthenticated_request=cognito_auth_on_unauthenticated_request,
1119
- session_timeout=stickiness_cookie_duration
1120
  # Additional options you might want to configure:
1121
  # session_cookie_name="AWSELBCookies"
1122
  )
@@ -1130,7 +1303,7 @@ def add_alb_https_listener_with_cert(
1130
  port=listener_port_https,
1131
  open=listener_open_to_internet,
1132
  certificates=certificates_list,
1133
- default_action=default_action # Use the determined default action
1134
  )
1135
  print(f"ALB HTTPS listener on port {listener_port_https} defined.")
1136
  else:
@@ -1139,8 +1312,8 @@ def add_alb_https_listener_with_cert(
1139
  return https_listener
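 # Example usage (names, ARNs and constructs are hypothetical):
 # https_listener = add_alb_https_listener_with_cert(
 #     self, "AlbHttpsListener", alb,
 #     acm_certificate_arn=cert_arn,
 #     default_target_group=target_group,
 #     enable_cognito_auth=True,
 #     cognito_user_pool=user_pool,
 #     cognito_user_pool_client=user_pool_client,
 #     cognito_user_pool_domain="my-app-domain",
 # )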
1140
 
1141
 
1142
- def ensure_folder_exists(output_folder:str):
1143
- """Checks if the specified folder exists, creates it if not."""
1144
 
1145
  if not os.path.exists(output_folder):
1146
  # Create the folder if it doesn't exist
@@ -1149,62 +1322,70 @@ def ensure_folder_exists(output_folder:str):
1149
  else:
1150
  print(f"The {output_folder} folder already exists.")
1151
 
1152
- def create_basic_config_env(out_dir:str="config", S3_LOG_CONFIG_BUCKET_NAME=S3_LOG_CONFIG_BUCKET_NAME, S3_OUTPUT_BUCKET_NAME=S3_OUTPUT_BUCKET_NAME, ACCESS_LOG_DYNAMODB_TABLE_NAME=ACCESS_LOG_DYNAMODB_TABLE_NAME, FEEDBACK_LOG_DYNAMODB_TABLE_NAME=FEEDBACK_LOG_DYNAMODB_TABLE_NAME, USAGE_LOG_DYNAMODB_TABLE_NAME=USAGE_LOG_DYNAMODB_TABLE_NAME):
1153
- '''
 
 
 
 
 
 
 
 
1154
  Create a basic config.env file for the user to use with their newly deployed redaction app.
1155
- '''
1156
  variables = {
1157
- 'COGNITO_AUTH':'1',
1158
- 'RUN_AWS_FUNCTIONS':'1',
1159
- 'DISPLAY_FILE_NAMES_IN_LOGS':'False',
1160
- 'SESSION_OUTPUT_FOLDER':'True',
1161
- 'SAVE_LOGS_TO_DYNAMODB':'True',
1162
- 'SHOW_COSTS':'True',
1163
- 'SHOW_WHOLE_DOCUMENT_TEXTRACT_CALL_OPTIONS':'True',
1164
- 'LOAD_PREVIOUS_TEXTRACT_JOBS_S3':'True',
1165
- 'DOCUMENT_REDACTION_BUCKET':S3_LOG_CONFIG_BUCKET_NAME,
1166
- 'TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_BUCKET':S3_OUTPUT_BUCKET_NAME,
1167
- 'ACCESS_LOG_DYNAMODB_TABLE_NAME':ACCESS_LOG_DYNAMODB_TABLE_NAME,
1168
- 'FEEDBACK_LOG_DYNAMODB_TABLE_NAME':FEEDBACK_LOG_DYNAMODB_TABLE_NAME,
1169
- 'USAGE_LOG_DYNAMODB_TABLE_NAME':USAGE_LOG_DYNAMODB_TABLE_NAME,
1170
- 'DISPLAY_FILE_NAMES_IN_LOGS':'False'
1171
  }
1172
 
1173
  # Write variables to .env file
1174
  ensure_folder_exists(out_dir + "/")
1175
- env_file_path = os.path.abspath(os.path.join(out_dir, 'config.env'))
1176
 
1177
  # It's good practice to ensure the file exists before calling set_key repeatedly.
1178
  # set_key will create it, but for a loop, it might be cleaner to ensure it's empty/exists once.
1179
  if not os.path.exists(env_file_path):
1180
- with open(env_file_path, 'w') as f:
1181
- pass # Create empty file
1182
 
1183
  for key, value in variables.items():
1184
  set_key(env_file_path, key, str(value), quote_mode="never")
1185
 
1186
  return variables
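 # Example usage: writes config/config.env using the bucket and DynamoDB table names from cdk_config:
 # create_basic_config_env(out_dir="config")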
1187
 
1188
- def start_codebuild_build(PROJECT_NAME:str, AWS_REGION:str = AWS_REGION):
1189
- '''
 
1190
  Start an existing Codebuild project build
1191
- '''
1192
 
1193
  # --- Initialize CodeBuild client ---
1194
- client = boto3.client('codebuild', region_name=AWS_REGION)
1195
 
1196
  try:
1197
  print(f"Attempting to start build for project: {PROJECT_NAME}")
1198
 
1199
- response = client.start_build(
1200
- projectName=PROJECT_NAME
1201
- )
1202
 
1203
- build_id = response['build']['id']
1204
  print(f"Successfully started build with ID: {build_id}")
1205
  print(f"Build ARN: {response['build']['arn']}")
1206
- print(f"Build URL (approximate - construct based on region and ID):")
1207
- print(f"https://{AWS_REGION}.console.aws.amazon.com/codesuite/codebuild/projects/{PROJECT_NAME}/build/{build_id.split(':')[-1]}/detail")
 
 
1208
 
1209
  # You can inspect the full response if needed
1210
  # print("\nFull response:")
@@ -1216,7 +1397,13 @@ def start_codebuild_build(PROJECT_NAME:str, AWS_REGION:str = AWS_REGION):
1216
  except Exception as e:
1217
  print(f"An unexpected error occurred: {e}")
1218
 
1219
- def upload_file_to_s3(local_file_paths:List[str], s3_key:str, s3_bucket:str, RUN_AWS_FUNCTIONS:str = "1"):
 
 
 
 
 
 
1220
  """
1221
  Uploads a file from local machine to Amazon S3.
1222
 
@@ -1235,14 +1422,14 @@ def upload_file_to_s3(local_file_paths:List[str], s3_key:str, s3_bucket:str, RUN
1235
  try:
1236
  if s3_bucket and local_file_paths:
1237
 
1238
- s3_client = boto3.client('s3', region_name=AWS_REGION)
1239
 
1240
  if isinstance(local_file_paths, str):
1241
  local_file_paths = [local_file_paths]
1242
 
1243
  for file in local_file_paths:
1244
  if s3_client:
1245
- #print(s3_client)
1246
  try:
1247
  # Get file name off file path
1248
  file_name = os.path.basename(file)
@@ -1251,18 +1438,24 @@ def upload_file_to_s3(local_file_paths:List[str], s3_key:str, s3_bucket:str, RUN
1251
  print("S3 key: ", s3_key_full)
1252
 
1253
  s3_client.upload_file(file, s3_bucket, s3_key_full)
1254
- out_message = "File " + file_name + " uploaded successfully!"
 
 
1255
  print(out_message)
1256
-
1257
  except Exception as e:
1258
  out_message = f"Error uploading file(s): {e}"
1259
  print(out_message)
1260
 
1261
  final_out_message.append(out_message)
1262
- final_out_message_str = '\n'.join(final_out_message)
1263
 
1264
- else: final_out_message_str = "Could not connect to AWS."
1265
- else: final_out_message_str = "At least one essential variable is empty, could not upload to S3"
 
 
 
 
1266
  except Exception as e:
1267
  final_out_message_str = "Could not upload files to S3 due to: " + str(e)
1268
  print(final_out_message_str)
@@ -1271,23 +1464,19 @@ def upload_file_to_s3(local_file_paths:List[str], s3_key:str, s3_bucket:str, RUN
1271
 
1272
  return final_out_message_str
1273
 
 
1274
  # Initialize ECS client
1275
  def start_ecs_task(cluster_name, service_name):
1276
- ecs_client = boto3.client('ecs')
1277
-
1278
  try:
1279
  # Update the service to set the desired count to 1
1280
- response = ecs_client.update_service(
1281
- cluster=cluster_name,
1282
- service=service_name,
1283
- desiredCount=1
1284
  )
1285
  return {
1286
  "statusCode": 200,
1287
- "body": f"Service {service_name} in cluster {cluster_name} has been updated to 1 task."
1288
  }
1289
  except Exception as e:
1290
- return {
1291
- "statusCode": 500,
1292
- "body": f"Error updating service: {str(e)}"
1293
- }
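 # Example usage (cluster and service names are hypothetical):
 # result = start_ecs_task("redaction-app-cluster", "redaction-app-service")
 # print(result["statusCode"], result["body"])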
 
1
+ import ipaddress
 
2
  import json
3
  import os
4
+ from typing import Any, Dict, List, Optional, Tuple
5
+
6
+ import boto3
7
  import pandas as pd
8
+ from aws_cdk import App, CfnOutput, CfnTag, Tags
9
+ from aws_cdk import aws_cognito as cognito
10
+ from aws_cdk import aws_ec2 as ec2
11
+ from aws_cdk import aws_elasticloadbalancingv2 as elb
12
+ from aws_cdk import aws_elasticloadbalancingv2_actions as elb_act
13
+ from aws_cdk import aws_iam as iam
14
+ from aws_cdk import aws_wafv2 as wafv2
15
+ from botocore.exceptions import ClientError
16
+ from cdk_config import (
17
+ ACCESS_LOG_DYNAMODB_TABLE_NAME,
18
+ AWS_REGION,
19
+ FEEDBACK_LOG_DYNAMODB_TABLE_NAME,
20
+ NAT_GATEWAY_EIP_NAME,
21
+ POLICY_FILE_LOCATIONS,
22
+ PRIVATE_SUBNET_AVAILABILITY_ZONES,
23
+ PRIVATE_SUBNET_CIDR_BLOCKS,
24
+ PRIVATE_SUBNETS_TO_USE,
25
+ PUBLIC_SUBNET_AVAILABILITY_ZONES,
26
+ PUBLIC_SUBNET_CIDR_BLOCKS,
27
+ PUBLIC_SUBNETS_TO_USE,
28
+ S3_LOG_CONFIG_BUCKET_NAME,
29
+ S3_OUTPUT_BUCKET_NAME,
30
+ USAGE_LOG_DYNAMODB_TABLE_NAME,
31
+ )
32
  from constructs import Construct
33
  from dotenv import set_key
34
 
35
 
 
 
 
36
  # --- Function to load context from file ---
37
  def load_context_from_file(app: App, file_path: str):
38
  if os.path.exists(file_path):
39
+ with open(file_path, "r") as f:
40
  context_data = json.load(f)
41
  for key, value in context_data.items():
42
  app.node.set_context(key, value)
 
44
  else:
45
  print(f"Context file not found: {file_path}")
46
 
47
+
48
  # --- Helper to parse environment variables into lists ---
49
  def _get_env_list(env_var_name: str) -> List[str]:
50
  """Parses a comma-separated environment variable into a list of strings."""
51
+ value = env_var_name[1:-1].strip().replace('"', "").replace("'", "")
52
  if not value:
53
  return []
54
  # Split by comma and filter out any empty strings that might result from extra commas
55
+ return [s.strip() for s in value.split(",") if s.strip()]
56
+
57
 
58
  # 1. Try to load CIDR/AZs from environment variables
59
+ if PUBLIC_SUBNETS_TO_USE:
60
+ PUBLIC_SUBNETS_TO_USE = _get_env_list(PUBLIC_SUBNETS_TO_USE)
61
+ if PRIVATE_SUBNETS_TO_USE:
62
+ PRIVATE_SUBNETS_TO_USE = _get_env_list(PRIVATE_SUBNETS_TO_USE)
63
+
64
+ if PUBLIC_SUBNET_CIDR_BLOCKS:
+     PUBLIC_SUBNET_CIDR_BLOCKS = _get_env_list(PUBLIC_SUBNET_CIDR_BLOCKS)
+ if PUBLIC_SUBNET_AVAILABILITY_ZONES:
+     PUBLIC_SUBNET_AVAILABILITY_ZONES = _get_env_list(PUBLIC_SUBNET_AVAILABILITY_ZONES)
+ if PRIVATE_SUBNET_CIDR_BLOCKS:
+     PRIVATE_SUBNET_CIDR_BLOCKS = _get_env_list(PRIVATE_SUBNET_CIDR_BLOCKS)
+ if PRIVATE_SUBNET_AVAILABILITY_ZONES:
+     PRIVATE_SUBNET_AVAILABILITY_ZONES = _get_env_list(PRIVATE_SUBNET_AVAILABILITY_ZONES)
74
 
75
+ if POLICY_FILE_LOCATIONS:
76
+ POLICY_FILE_LOCATIONS = _get_env_list(POLICY_FILE_LOCATIONS)
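 # Example of what _get_env_list produces for a list-style value (illustrative only);
 # it strips the surrounding brackets and quotes before splitting on commas:
 # _get_env_list("['10.0.1.0/24', '10.0.2.0/24']")  ->  ['10.0.1.0/24', '10.0.2.0/24']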
 
 
77
 
 
78
 
79
+ def check_for_existing_role(role_name: str):
80
  try:
81
+ iam = boto3.client("iam")
82
+ # iam.get_role(RoleName=role_name)
83
+
84
  response = iam.get_role(RoleName=role_name)
85
+ role = response["Role"]["Arn"]
86
 
87
+ print("Response Role:", role)
88
 
89
  return True, role, ""
90
  except iam.exceptions.NoSuchEntityException:
 
92
  except Exception as e:
93
  raise Exception("Getting information on IAM role failed due to:", e)
94
 
95
+
96
 
 
 
 
97
 
98
  # Assume POLICY_FILE_LOCATIONS is defined globally or passed as a default
99
  # For example:
 
109
  policy_document: A Python dictionary representing an IAM policy document.
110
  """
111
  # Ensure the loaded JSON is a valid policy document structure
112
+ if "Statement" not in policy_document or not isinstance(
113
+ policy_document["Statement"], list
114
+ ):
115
+ print("Warning: Policy document does not contain a 'Statement' list. Skipping.")
116
+ return # Do not return role, just log and exit
117
 
118
+ for statement_dict in policy_document["Statement"]:
119
  try:
120
  # Create a CDK PolicyStatement from the dictionary
121
  cdk_policy_statement = iam.PolicyStatement.from_json(statement_dict)
 
124
  role.add_to_policy(cdk_policy_statement)
125
  print(f" - Added statement: {statement_dict.get('Sid', 'No Sid')}")
126
  except Exception as e:
127
+ print(
128
+ f"Warning: Could not process policy statement: {statement_dict}. Error: {e}"
129
+ )
130
+
131
 
132
  def add_custom_policies(
133
+ scope: Construct, # Not strictly used here, but good practice if you expand to ManagedPolicies
134
  role: iam.IRole,
135
  policy_file_locations: Optional[List[str]] = None,
136
+ custom_policy_text: Optional[str] = None,
137
  ) -> iam.IRole:
138
  """
139
  Loads custom policies from JSON files or a string and attaches them to a CDK Role.
 
150
  if policy_file_locations is None:
151
  policy_file_locations = []
152
 
153
+ current_source = "unknown source" # For error messages
154
 
155
  try:
156
  if policy_file_locations:
 
158
  for path in policy_file_locations:
159
  current_source = f"file: {path}"
160
  try:
161
+ with open(path, "r") as f:
162
  policy_document = json.load(f)
163
  print(f"Processing policy from {current_source}...")
164
  add_statement_to_policy(role, policy_document)
165
  except FileNotFoundError:
166
  print(f"Warning: Policy file not found at {path}. Skipping.")
167
  except json.JSONDecodeError as e:
168
+ print(
169
+ f"Warning: Invalid JSON in policy file {path}: {e}. Skipping."
170
+ )
171
  except Exception as e:
172
+ print(
173
+ f"An unexpected error occurred processing policy from {path}: {e}. Skipping."
174
+ )
175
 
176
  if custom_policy_text:
177
  current_source = "custom policy text string"
178
+ print(
179
+ f"Attempting to add policy from custom text to role {role.node.id}..."
180
+ )
181
  try:
182
  # *** FIX: Parse the JSON string into a Python dictionary ***
183
  policy_document = json.loads(custom_policy_text)
 
186
  except json.JSONDecodeError as e:
187
  print(f"Warning: Invalid JSON in custom_policy_text: {e}. Skipping.")
188
  except Exception as e:
189
+ print(
190
+ f"An unexpected error occurred processing policy from custom_policy_text: {e}. Skipping."
191
+ )
192
 
193
  # You might want a final success message, but individual processing messages are also good.
194
  print(f"Finished processing custom policies for role {role.node.id}.")
195
 
196
  except Exception as e:
197
+ print(
198
+ f"An unhandled error occurred during policy addition for {current_source}: {e}"
199
+ )
200
 
201
  return role
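 # Example usage inside a stack's __init__ (names are hypothetical):
 # task_role = iam.Role(self, "TaskRole", assumed_by=iam.ServicePrincipal("ecs-tasks.amazonaws.com"))
 # add_custom_policies(self, task_role, policy_file_locations=POLICY_FILE_LOCATIONS)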
202
 
203
+
204
  # Import the S3 Bucket class if you intend to return a CDK object later
205
  # from aws_cdk import aws_s3 as s3
206
 
207
+
208
+ def check_s3_bucket_exists(
209
+ bucket_name: str,
210
+ ): # Return type hint depends on what you return
211
  """
212
  Checks if an S3 bucket with the given name exists and is accessible.
213
 
 
221
  or the bucket name for CDK lookups/creations.
222
  For this example, let's return the boolean and the name.
223
  """
224
+ s3_client = boto3.client("s3")
225
  try:
226
  # Use head_bucket to check for existence and access
227
  s3_client.head_bucket(Bucket=bucket_name)
228
  print(f"Bucket '{bucket_name}' exists and is accessible.")
229
+ return True, bucket_name # Return True and the bucket name
230
 
231
  except ClientError as e:
232
  # If a ClientError occurs, check the error code.
233
  # '404' means the bucket does not exist.
234
  # '403' means the bucket exists but you don't have permission.
235
+ error_code = e.response["Error"]["Code"]
236
+ if error_code == "404":
237
  print(f"Bucket '{bucket_name}' does not exist.")
238
  return False, None
239
+ elif error_code == "403":
240
+ # The bucket exists, but you can't access it.
241
+ # Depending on your requirements, this might be treated as "exists"
242
+ # or "not accessible for our purpose". For checking existence,
243
+ # we'll say it exists here, but note the permission issue.
244
+ # NOTE - when I tested this, it was returning 403 even for buckets that don't exist. So I will return False instead
245
+ print(
246
+ f"Bucket '{bucket_name}' returned 403, which indicates it may exist but is not accessible due to permissions, or that it doesn't exist. Returning False for existence just in case."
247
+ )
248
+ return False, bucket_name  # Not confirmed to exist; still return the name for reference
249
  else:
250
  # For other errors, it's better to raise the exception
251
  # to indicate something unexpected happened.
252
+ print(
253
+ f"An unexpected AWS ClientError occurred checking bucket '{bucket_name}': {e}"
254
+ )
255
  # Decide how to handle other errors - raising might be safer
256
+ raise # Re-raise the original exception
257
  except Exception as e:
258
+ print(
259
+ f"An unexpected non-ClientError occurred checking bucket '{bucket_name}': {e}"
260
+ )
261
  # Decide how to handle other errors
262
+ raise # Re-raise the original exception
263
+
264
 
265
  # Example usage in your check_resources.py:
266
  # exists, bucket_name_if_exists = check_s3_bucket_exists(log_bucket_name)
267
  # context_data[f"exists:{log_bucket_name}"] = exists
268
  # # You don't necessarily need to store the name in context if using from_bucket_name
269
 
270
+
271
  # Delete an S3 bucket
272
+ def delete_s3_bucket(bucket_name: str):
273
+ s3 = boto3.client("s3")
274
+
275
  try:
276
  # List and delete all objects
277
  response = s3.list_object_versions(Bucket=bucket_name)
278
+ versions = response.get("Versions", []) + response.get("DeleteMarkers", [])
279
  for version in versions:
280
+ s3.delete_object(
281
+ Bucket=bucket_name, Key=version["Key"], VersionId=version["VersionId"]
282
+ )
283
+
284
  # Delete the bucket
285
  s3.delete_bucket(Bucket=bucket_name)
286
+ return {"Status": "SUCCESS"}
287
  except Exception as e:
288
+ return {"Status": "FAILED", "Reason": str(e)}
289
+
290
 
291
  # Function to get subnet ID from subnet name
292
+ def get_subnet_id(vpc: ec2.IVpc, ec2_client, subnet_name: str):
293
+ response = ec2_client.describe_subnets(
294
+ Filters=[{"Name": "vpc-id", "Values": [vpc.vpc_id]}]
295
+ )
296
+
297
+ for subnet in response["Subnets"]:
298
+ if subnet["Tags"] and any(
299
+ tag["Key"] == "Name" and tag["Value"] == subnet_name
300
+ for tag in subnet["Tags"]
301
+ ):
302
+ return subnet["SubnetId"]
303
 
 
 
 
 
304
  return None
305
 
306
+
307
  def check_ecr_repo_exists(repo_name: str) -> tuple[bool, dict]:
308
  """
309
  Checks if an ECR repository with the given name exists.
 
314
  Returns:
315
  A tuple: (True, repository description dict) if the repository exists, (False, {}) otherwise.
316
  """
317
+ ecr_client = boto3.client("ecr")
318
  try:
319
  print("ecr repo_name to check:", repo_name)
320
  response = ecr_client.describe_repositories(repositoryNames=[repo_name])
321
  # If describe_repositories succeeds and returns a list of repositories,
322
  # and the list is not empty, the repository exists.
323
+ return len(response["repositories"]) > 0, response["repositories"][0]
324
  except ClientError as e:
325
  # Check for the specific error code indicating the repository doesn't exist
326
+ if e.response["Error"]["Code"] == "RepositoryNotFoundException":
327
  return False, {}
328
  else:
329
  # Re-raise other exceptions to handle unexpected errors
 
331
  except Exception as e:
332
  print(f"An unexpected error occurred: {e}")
333
  return False, {}
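 # Example usage (repository name is hypothetical):
 # repo_exists, repo_details = check_ecr_repo_exists("my-app-repo")
 # if repo_exists:
 #     print(repo_details["repositoryUri"])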
334
+
335
+
336
+ def check_codebuild_project_exists(
337
+ project_name: str,
338
+ ): # Adjust return type hint as needed
339
  """
340
  Checks if a CodeBuild project with the given name exists.
341
 
 
348
  - The second element is the project ARN if found,
349
  None otherwise.
350
  """
351
+ codebuild_client = boto3.client("codebuild")
352
  try:
353
  # Use batch_get_projects with a list containing the single project name
354
  response = codebuild_client.batch_get_projects(names=[project_name])
355
 
356
  # The response for batch_get_projects includes 'projects' (found)
357
  # and 'projectsNotFound' (not found).
358
+ if response["projects"]:
359
  # If the project is found in the 'projects' list
360
  print(f"CodeBuild project '{project_name}' found.")
361
+ return (
362
+ True,
363
+ response["projects"][0]["arn"],
364
+ )  # Return True and the project ARN
365
+ elif (
366
+ response["projectsNotFound"]
367
+ and project_name in response["projectsNotFound"]
368
+ ):
369
+ # If the project name is explicitly in the 'projectsNotFound' list
370
+ print(f"CodeBuild project '{project_name}' not found.")
371
+ return False, None
372
  else:
373
  # This case is less expected for a single name lookup,
374
  # but could happen if there's an internal issue or the response
375
  # structure is slightly different than expected for an error.
376
  # It's safer to assume it wasn't found if not in 'projects'.
377
+ print(
378
+ f"CodeBuild project '{project_name}' not found (not in 'projects' list)."
379
+ )
380
  return False, None
381
 
382
  except ClientError as e:
 
384
  # 'InvalidInputException' for a non-existent project name if the
385
  # name format is valid. It typically just lists it in projectsNotFound.
386
  # However, other ClientErrors are possible (e.g., permissions).
387
+ print(
388
+ f"An AWS ClientError occurred checking CodeBuild project '{project_name}': {e}"
389
+ )
390
  # Decide how to handle other ClientErrors - raising might be safer
391
+ raise # Re-raise the original exception
392
  except Exception as e:
393
+ print(
394
+ f"An unexpected non-ClientError occurred checking CodeBuild project '{project_name}': {e}"
395
+ )
396
  # Decide how to handle other errors
397
+ raise # Re-raise the original exception
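 # Example usage (project name is hypothetical); the second element is the project ARN when found:
 # project_exists, project_arn = check_codebuild_project_exists("my-app-image-build")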
398
+
399
 
400
  def get_vpc_id_by_name(vpc_name: str) -> Optional[str]:
401
  """
402
  Finds a VPC ID by its 'Name' tag.
403
  """
404
+ ec2_client = boto3.client("ec2")
405
  try:
406
  response = ec2_client.describe_vpcs(
407
+ Filters=[{"Name": "tag:Name", "Values": [vpc_name]}]
 
 
408
  )
409
+ if response and response["Vpcs"]:
410
+ vpc_id = response["Vpcs"][0]["VpcId"]
411
  print(f"VPC '{vpc_name}' found with ID: {vpc_id}")
412
 
413
  # In get_vpc_id_by_name, after finding VPC ID:
414
 
415
  # Look for NAT Gateways in this VPC
416
+ ec2_client = boto3.client("ec2")
417
  nat_gateways = []
418
  try:
419
  response = ec2_client.describe_nat_gateways(
420
  Filters=[
421
+ {"Name": "vpc-id", "Values": [vpc_id]},
422
  # Optional: Add a tag filter if you consistently tag your NATs
423
  # {'Name': 'tag:Name', 'Values': [f"{prefix}-nat-gateway"]}
424
  ]
425
  )
426
+ nat_gateways = response.get("NatGateways", [])
427
  except Exception as e:
428
+ print(
429
+ f"Warning: Could not describe NAT Gateways in VPC '{vpc_id}': {e}"
430
+ )
431
  # Decide how to handle this error - proceed or raise?
432
 
433
  # Decide how to identify the specific NAT Gateway you want to check for.
 
 
434
 
435
  return vpc_id, nat_gateways
436
  else:
 
440
  print(f"An unexpected error occurred finding VPC '{vpc_name}': {e}")
441
  raise
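 # Example usage (VPC name is hypothetical). When the VPC is found the function returns a
 # (vpc_id, nat_gateways) pair rather than just the ID:
 # vpc_id, nat_gateways = get_vpc_id_by_name("my-app-vpc")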
442
 
443
+
444
  # --- Helper to fetch all existing subnets in a VPC once ---
445
  def _get_existing_subnets_in_vpc(vpc_id: str) -> Dict[str, Any]:
446
  """
 
448
  Returns a dictionary with 'by_name' (map of name to subnet data),
449
  'by_id' (map of id to subnet data), and 'cidr_networks' (list of ipaddress.IPv4Network).
450
  """
451
+ ec2_client = boto3.client("ec2")
452
  existing_subnets_data = {
453
  "by_name": {}, # {subnet_name: {'id': 'subnet-id', 'cidr': 'x.x.x.x/x'}}
454
+ "by_id": {}, # {subnet_id: {'name': 'subnet-name', 'cidr': 'x.x.x.x/x'}}
455
+ "cidr_networks": [], # List of ipaddress.IPv4Network objects
456
  }
457
  try:
458
+ response = ec2_client.describe_subnets(
459
+ Filters=[{"Name": "vpc-id", "Values": [vpc_id]}]
460
+ )
461
+ for s in response.get("Subnets", []):
462
+ subnet_id = s["SubnetId"]
463
+ cidr_block = s.get("CidrBlock")
464
  # Extract 'Name' tag, which is crucial for lookup by name
465
+ name_tag = next(
466
+ (tag["Value"] for tag in s.get("Tags", []) if tag["Key"] == "Name"),
467
+ None,
468
+ )
469
 
470
+ subnet_info = {"id": subnet_id, "cidr": cidr_block, "name": name_tag}
471
 
472
  if name_tag:
473
  existing_subnets_data["by_name"][name_tag] = subnet_info
 
475
 
476
  if cidr_block:
477
  try:
478
+ existing_subnets_data["cidr_networks"].append(
479
+ ipaddress.ip_network(cidr_block, strict=False)
480
+ )
481
  except ValueError:
482
+ print(
483
+ f"Warning: Existing subnet {subnet_id} has an invalid CIDR: {cidr_block}. Skipping for overlap check."
484
+ )
485
 
486
+ print(
487
+ f"Fetched {len(response.get('Subnets', []))} existing subnets from VPC '{vpc_id}'."
488
+ )
489
  except Exception as e:
490
+ print(
491
+ f"Error describing existing subnets in VPC '{vpc_id}': {e}. Cannot perform full validation."
492
+ )
493
+ raise # Re-raise if this essential step fails
494
 
495
  return existing_subnets_data
496
 
497
+
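A short sketch of how the returned structure can be queried for name clashes and CIDR overlaps. The VPC ID, subnet name and CIDR below are made up, and boto3 credentials are assumed to be available.

import ipaddress

existing = _get_existing_subnets_in_vpc("vpc-0123456789abcdef0")  # illustrative VPC ID
proposed = ipaddress.ip_network("10.0.5.0/24", strict=False)
overlaps = any(proposed.overlaps(net) for net in existing["cidr_networks"])
name_taken = "redact-public-subnet-1" in existing["by_name"]  # illustrative name
print(f"CIDR overlap: {overlaps}, name already in use: {name_taken}")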
498
  # --- Modified validate_subnet_creation_parameters to take pre-fetched data ---
499
  def validate_subnet_creation_parameters(
500
  vpc_id: str,
501
+ proposed_subnets_data: List[
502
+ Dict[str, str]
503
+ ], # e.g., [{'name': 'my-public-subnet', 'cidr': '10.0.0.0/24', 'az': 'us-east-1a'}]
504
+ existing_aws_subnets_data: Dict[
505
+ str, Any
506
+ ], # Pre-fetched data from _get_existing_subnets_in_vpc
507
  ) -> None:
508
  """
509
  Validates proposed subnet names and CIDR blocks against existing AWS subnets
 
525
  print("No proposed subnet data provided for validation. Skipping.")
526
  return
527
 
528
+ print(
529
+ f"--- Starting pre-synth validation for VPC '{vpc_id}' with proposed subnets ---"
530
+ )
531
 
532
+ print("Existing subnet data:", pd.DataFrame(existing_aws_subnets_data["by_name"]))
533
 
534
  existing_aws_subnet_names = set(existing_aws_subnets_data["by_name"].keys())
535
  existing_aws_cidr_networks = existing_aws_subnets_data["cidr_networks"]
 
539
  proposed_cidr_networks_seen: List[ipaddress.IPv4Network] = []
540
 
541
  for i, proposed_subnet in enumerate(proposed_subnets_data):
542
+ subnet_name = proposed_subnet.get("name")
543
+ cidr_block_str = proposed_subnet.get("cidr")
544
+ availability_zone = proposed_subnet.get("az")
545
 
546
  if not all([subnet_name, cidr_block_str, availability_zone]):
547
+ raise ValueError(
548
+ f"Proposed subnet at index {i} is incomplete. Requires 'name', 'cidr', and 'az'."
549
+ )
550
 
551
  # 1. Check for duplicate names within the proposed batch
552
  if subnet_name in proposed_names_seen:
553
+ raise ValueError(
554
+ f"Proposed subnet name '{subnet_name}' is duplicated within the input list."
555
+ )
556
  proposed_names_seen.add(subnet_name)
557
 
558
  # 2. Check for duplicate names against existing AWS subnets
559
  if subnet_name in existing_aws_subnet_names:
560
+ print(
561
+ f"Proposed subnet name '{subnet_name}' already exists in VPC '{vpc_id}'."
562
+ )
563
 
564
  # Parse proposed CIDR
565
  try:
566
  proposed_net = ipaddress.ip_network(cidr_block_str, strict=False)
567
  except ValueError as e:
568
+ raise ValueError(
569
+ f"Invalid CIDR format '{cidr_block_str}' for proposed subnet '{subnet_name}': {e}"
570
+ )
571
 
572
  # 3. Check for overlapping CIDRs within the proposed batch
573
  for existing_proposed_net in proposed_cidr_networks_seen:
 
589
 
590
  # If all checks pass for this subnet, add its network to the list for subsequent checks
591
  proposed_cidr_networks_seen.append(proposed_net)
592
+ print(
593
+ f"Validation successful for proposed subnet '{subnet_name}' with CIDR '{cidr_block_str}'."
594
+ )
595
+
596
+ print(
597
+ f"--- All proposed subnets passed pre-synth validation checks for VPC '{vpc_id}'. ---"
598
+ )
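For reference, a sketch of the input shape the validator expects; the names, CIDRs and AZs are placeholders.

proposed_subnets = [
    {"name": "redact-public-1", "cidr": "10.0.1.0/24", "az": "eu-west-2a"},
    {"name": "redact-private-1", "cidr": "10.0.2.0/24", "az": "eu-west-2b"},
]
existing = _get_existing_subnets_in_vpc(vpc_id)
# Raises ValueError on duplicate names, invalid CIDRs or overlapping ranges; returns None on success
validate_subnet_creation_parameters(vpc_id, proposed_subnets, existing)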
599
 
 
600
 
601
  # --- Modified check_subnet_exists_by_name (Uses pre-fetched data) ---
602
  def check_subnet_exists_by_name(
603
+ subnet_name: str, existing_aws_subnets_data: Dict[str, Any]
 
604
  ) -> Tuple[bool, Optional[str]]:
605
  """
606
  Checks if a subnet with the given name exists within the pre-fetched data.
 
618
  subnet_info = existing_aws_subnets_data["by_name"].get(subnet_name)
619
  if subnet_info:
620
  print(f"Subnet '{subnet_name}' found with ID: {subnet_info['id']}")
621
+ return True, subnet_info["id"]
622
  else:
623
  print(f"Subnet '{subnet_name}' not found.")
624
  return False, None
625
 
626
+
627
  def create_nat_gateway(
628
  scope: Construct,
629
+ public_subnet_for_nat: ec2.ISubnet, # Expects a proper ISubnet
630
  nat_gateway_name: str,
631
+ nat_gateway_id_context_key: str,
632
  ) -> str:
633
  """
634
  Creates a single NAT Gateway in the specified public subnet.
635
  It does not handle lookup from context; the calling stack should do that.
636
  Returns the CloudFormation Ref of the NAT Gateway ID.
637
  """
638
+ print(
639
+ f"Defining a new NAT Gateway '{nat_gateway_name}' in subnet '{public_subnet_for_nat.subnet_id}'."
640
+ )
641
 
642
  # Create an Elastic IP for the NAT Gateway
643
+ eip = ec2.CfnEIP(
644
+ scope,
645
+ NAT_GATEWAY_EIP_NAME,
646
+ tags=[CfnTag(key="Name", value=NAT_GATEWAY_EIP_NAME)],
647
  )
648
 
649
  # Create the NAT Gateway
650
+ nat_gateway_logical_id = nat_gateway_name.replace("-", "") + "NatGateway"
651
+ nat_gateway = ec2.CfnNatGateway(
652
+ scope,
653
+ nat_gateway_logical_id,
654
  subnet_id=public_subnet_for_nat.subnet_id, # Associate with the public subnet
655
+ allocation_id=eip.attr_allocation_id, # Associate with the EIP
656
+ tags=[CfnTag(key="Name", value=nat_gateway_name)],
657
  )
658
  # The NAT GW depends on the EIP. The dependency on the subnet is implicit via subnet_id.
659
  nat_gateway.add_dependency(eip)
660
 
661
  # *** CRUCIAL: Use CfnOutput to export the ID after deployment ***
662
  # This is how you will get the ID to put into cdk.context.json
663
+ CfnOutput(
664
+ scope,
665
+ "SingleNatGatewayIdOutput",
666
  value=nat_gateway.ref,
667
  description=f"Physical ID of the Single NAT Gateway. Add this to cdk.context.json under the key '{nat_gateway_id_context_key}'.",
668
+ export_name=f"{scope.stack_name}-NatGatewayId", # Make export name unique
669
  )
670
 
671
+ print(
672
+ f"CDK: Defined new NAT Gateway '{nat_gateway.ref}'. Its physical ID will be available in the stack outputs after deployment."
673
+ )
674
  # Return the tokenised reference for use within this synthesis
675
  return nat_gateway.ref
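A sketch of the post-deploy step the CfnOutput comment refers to: copying the physical NAT Gateway ID into cdk.context.json. The key name and the ID value are assumptions; use whatever nat_gateway_id_context_key your stack passes in.

import json

with open("cdk.context.json", "r+") as f:
    ctx = json.load(f)
    # "id:NatGateway" mirrors the key used by check_and_set_context; the ID comes from the SingleNatGatewayIdOutput stack output
    ctx["id:NatGateway"] = "nat-0123456789abcdef0"
    f.seek(0)
    json.dump(ctx, f, indent=2)
    f.truncate()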
676
 
677
+
678
  def create_subnets(
679
  scope: Construct,
680
  vpc: ec2.IVpc,
 
684
  availability_zones: List[str],
685
  is_public: bool,
686
  internet_gateway_id: Optional[str] = None,
687
+ single_nat_gateway_id: Optional[str] = None,
688
  ) -> Tuple[List[ec2.CfnSubnet], List[ec2.CfnRouteTable]]:
689
  """
690
  Creates subnets using L2 constructs but returns the underlying L1 Cfn objects
 
692
  """
693
  # --- Validations remain the same ---
694
  if not (len(subnet_names) == len(cidr_blocks) == len(availability_zones) > 0):
695
+ raise ValueError(
696
+ "Subnet names, CIDR blocks, and Availability Zones lists must be non-empty and match in length."
697
+ )
698
  if is_public and not internet_gateway_id:
699
  raise ValueError("internet_gateway_id must be provided for public subnets.")
700
  if not is_public and not single_nat_gateway_id:
701
+ raise ValueError(
702
+ "single_nat_gateway_id must be provided for private subnets when using a single NAT Gateway."
703
+ )
704
 
705
  # --- We will populate these lists with the L1 objects to return ---
706
  created_subnets: List[ec2.CfnSubnet] = []
 
718
  vpc_id=vpc.vpc_id,
719
  cidr_block=cidr_blocks[i],
720
  availability_zone=availability_zones[i],
721
+ map_public_ip_on_launch=is_public,
722
  )
723
  Tags.of(subnet).add("Name", subnet_name)
724
  Tags.of(subnet).add("Type", subnet_type_tag)
725
+
726
  if is_public:
727
  # The subnet's route_table is automatically created by the L2 Subnet construct
728
  try:
729
  subnet.add_route(
730
+ "DefaultInternetRoute", # A logical ID for the CfnRoute resource
731
  router_id=internet_gateway_id,
732
  router_type=ec2.RouterType.GATEWAY,
733
  # destination_cidr_block="0.0.0.0/0" is the default for this method
 
739
  try:
740
  # Using .add_route() for private subnets as well for consistency
741
  subnet.add_route(
742
+ "DefaultNatRoute", # A logical ID for the CfnRoute resource
743
  router_id=single_nat_gateway_id,
744
  router_type=ec2.RouterType.NAT_GATEWAY,
745
  )
746
  except Exception as e:
747
  print("Could not create NAT gateway route for public subnet due to:", e)
748
+ print(
749
+ f"CDK: Defined private L2 subnet '{subnet_name}' and added NAT GW route."
750
+ )
751
 
752
  route_table = subnet.route_table
753
+
754
  created_subnets.append(subnet)
755
  created_route_tables.append(route_table)
756
 
757
  return created_subnets, created_route_tables
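An illustrative call from within a stack. The keyword names for the subnet-name and CIDR lists are inferred from the function body (that part of the signature is collapsed in this hunk), and all values are placeholders.

public_subnets, public_route_tables = create_subnets(
    scope=self,
    vpc=vpc,
    subnet_names=["redact-public-1", "redact-public-2"],  # parameter name inferred from the body
    cidr_blocks=["10.0.1.0/24", "10.0.2.0/24"],  # parameter name inferred from the body
    availability_zones=["eu-west-2a", "eu-west-2b"],
    is_public=True,
    internet_gateway_id=igw_id,  # from a context lookup
)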
758
+
759
+
760
+ def ingress_rule_exists(security_group: ec2.SecurityGroup, peer, port):  # inspects security_group.connections, so it needs the construct, not a string
761
  for rule in security_group.connections.security_groups:
762
  if port:
763
  if rule.peer == peer and rule.connection == port:
 
767
  return True
768
  return False
769
 
770
+
771
+ def check_for_existing_user_pool(user_pool_name: str):
772
  cognito_client = boto3.client("cognito-idp")
773
+ list_pools_response = cognito_client.list_user_pools(
774
+ MaxResults=60
775
+ ) # MaxResults up to 60
776
+
777
  # ListUserPools might require pagination if you have more than 60 pools
778
  # This simple example doesn't handle pagination, which could miss your pool
779
 
780
  existing_user_pool_id = ""
781
 
782
+ for pool in list_pools_response.get("UserPools", []):
783
+ if pool.get("Name") == user_pool_name:
784
+ existing_user_pool_id = pool["Id"]
785
+ print(
786
+ f"Found existing user pool by name '{user_pool_name}' with ID: {existing_user_pool_id}"
787
+ )
788
+ break # Found the one we're looking for
789
 
790
  if existing_user_pool_id:
791
  return True, existing_user_pool_id, pool
792
  else:
793
  return False, "", ""
794
+
795
+
796
  def check_for_existing_user_pool_client(user_pool_id: str, user_pool_client_name: str):
797
  """
798
  Checks if a Cognito User Pool Client with the given name exists in the specified User Pool.
 
807
  - False, "", {} otherwise.
808
  """
809
  cognito_client = boto3.client("cognito-idp")
810
+ next_token = "string"
 
811
 
812
  while True:
  try:
+ # Only send NextToken when continuing a previous page; the API rejects placeholder tokens
+ request_kwargs = {"UserPoolId": user_pool_id, "MaxResults": 60}
+ if next_token:
+ request_kwargs["NextToken"] = next_token
+ response = cognito_client.list_user_pool_clients(**request_kwargs)
817
  except cognito_client.exceptions.ResourceNotFoundException:
818
  print(f"Error: User pool with ID '{user_pool_id}' not found.")
819
  return False, "", {}
820
+
821
  except cognito_client.exceptions.InvalidParameterException:
822
  print(f"Error: No app clients for '{user_pool_id}' found.")
823
  return False, "", {}
824
+
825
  except Exception as e:
826
  print("Could not check User Pool clients due to:", e)
827
 
828
+ for client in response.get("UserPoolClients", []):
829
+ if client.get("ClientName") == user_pool_client_name:
830
+ print(
831
+ f"Found existing user pool client '{user_pool_client_name}' with ID: {client['ClientId']}"
832
+ )
833
+ return True, client["ClientId"], client
834
 
835
+ next_token = response.get("NextToken")
836
  if not next_token:
837
  break
838
 
839
  return False, "", {}
840
 
841
+
842
+ def check_for_secret(secret_name: str, secret_value: dict = ""):
843
  """
844
  Checks if a Secrets Manager secret with the given name exists.
845
  If it doesn't exist, it creates the secret.
 
865
  # Handle other potential exceptions during the get operation
866
  print(f"Error checking for secret '{secret_name}': {e}")
867
  return False, {}
868
+
869
+
870
+ def check_alb_exists(
871
+ load_balancer_name: str, region_name: str = None
872
+ ) -> tuple[bool, dict]:
873
  """
874
  Checks if an Application Load Balancer (ALB) with the given name exists.
875
 
 
886
  the LoadBalancers list from the describe_load_balancers response.
887
  """
888
  if region_name:
889
+ elbv2_client = boto3.client("elbv2", region_name=region_name)
890
  else:
891
+ elbv2_client = boto3.client("elbv2")
892
  try:
893
  response = elbv2_client.describe_load_balancers(Names=[load_balancer_name])
894
+ if response["LoadBalancers"]:
895
+ return (
896
+ True,
897
+ response["LoadBalancers"][0],
898
+ ) # Return True and the first ALB object
899
  else:
900
  return False, {}
901
  except ClientError as e:
902
  # If the error indicates the ALB doesn't exist, return False
903
+ if e.response["Error"]["Code"] == "LoadBalancerNotFound":
904
  return False, {}
905
  else:
906
  # Re-raise other exceptions
 
908
  except Exception as e:
909
  print(f"An unexpected error occurred: {e}")
910
  return False, {}
911
+
912
+
913
+ def check_fargate_task_definition_exists(
914
+ task_definition_name: str, region_name: str = None
915
+ ) -> tuple[bool, dict]:
916
  """
917
  Checks if a Fargate task definition with the given name exists.
918
 
 
929
  taskDefinitions list from the describe_task_definition response.
930
  """
931
  if region_name:
932
+ ecs_client = boto3.client("ecs", region_name=region_name)
933
  else:
934
+ ecs_client = boto3.client("ecs")
935
  try:
936
+ response = ecs_client.describe_task_definition(
937
+ taskDefinition=task_definition_name
938
+ )
939
  # If describe_task_definition succeeds, it returns the task definition.
940
  # We can directly return True and the task definition.
941
+ return True, response["taskDefinition"]
942
  except ClientError as e:
943
  # Check for the error code indicating the task definition doesn't exist.
944
+ if (
945
+ e.response["Error"]["Code"] == "ClientException"
946
+ and "Task definition" in e.response["Message"]
947
+ and "does not exist" in e.response["Message"]
948
+ ):
949
  return False, {}
950
  else:
951
  # Re-raise other exceptions.
 
953
  except Exception as e:
954
  print(f"An unexpected error occurred: {e}")
955
  return False, {}
956
+
957
+
958
+ def check_ecs_service_exists(
959
+ cluster_name: str, service_name: str, region_name: str = None
960
+ ) -> tuple[bool, dict]:
961
  """
962
  Checks if an ECS service with the given name exists in the specified cluster.
963
 
 
974
  None otherwise.
975
  """
976
  if region_name:
977
+ ecs_client = boto3.client("ecs", region_name=region_name)
978
  else:
979
+ ecs_client = boto3.client("ecs")
980
  try:
981
+ response = ecs_client.describe_services(
982
+ cluster=cluster_name, services=[service_name]
983
+ )
984
+ if response["services"]:
985
+ return (
986
+ True,
987
+ response["services"][0],
988
+ ) # Return True and the first service object
989
  else:
990
  return False, {}
991
  except ClientError as e:
992
  # Check for the error code indicating the service doesn't exist.
993
+ if e.response["Error"]["Code"] == "ClusterNotFoundException":
994
  return False, {}
995
+ elif e.response["Error"]["Code"] == "ServiceNotFoundException":
996
  return False, {}
997
  else:
998
  # Re-raise other exceptions.
 
1000
  except Exception as e:
1001
  print(f"An unexpected error occurred: {e}")
1002
  return False, {}
1003
+
1004
+
1005
+ def check_cloudfront_distribution_exists(
1006
+ distribution_name: str, region_name: str = None
1007
+ ) -> tuple[bool, dict | None]:
1008
  """
1009
  Checks if a CloudFront distribution with the given name exists.
1010
 
 
1023
  DistributionList from the ListDistributions response.
1024
  """
1025
  if region_name:
1026
+ cf_client = boto3.client("cloudfront", region_name=region_name)
1027
  else:
1028
+ cf_client = boto3.client("cloudfront")
1029
  try:
1030
  response = cf_client.list_distributions()
1031
+ if "Items" in response["DistributionList"]:
1032
+ for distribution in response["DistributionList"]["Items"]:
1033
  # CloudFront doesn't directly filter by name, so we have to iterate.
1034
+ if (
1035
+ distribution["AliasSet"]["Items"]
1036
+ and distribution["AliasSet"]["Items"][0] == distribution_name
1037
+ ):
1038
  return True, distribution
1039
  return False, None
1040
  else:
1041
  return False, None
1042
  except ClientError as e:
1043
  # If the error indicates the Distribution doesn't exist, return False
1044
+ if e.response["Error"]["Code"] == "NoSuchDistribution":
1045
  return False, None
1046
  else:
1047
  # Re-raise other exceptions
 
1050
  print(f"An unexpected error occurred: {e}")
1051
  return False, None
1052
 
1053
+
1054
+ def create_web_acl_with_common_rules(
1055
+ scope: Construct, web_acl_name: str, waf_scope: str = "CLOUDFRONT"
1056
+ ):
1057
+ """
1058
  Use CDK to create a web ACL based on an AWS common rule set with overrides.
1059
  This function now expects a 'scope' argument, typically 'self' from your stack,
1060
  as CfnWebACL requires a construct scope.
1061
+ """
1062
 
1063
  # Create full list of rules
1064
  rules = []
1065
  aws_ruleset_names = [
1066
  "AWSManagedRulesCommonRuleSet",
1067
  "AWSManagedRulesKnownBadInputsRuleSet",
1068
+ "AWSManagedRulesAmazonIpReputationList",
1069
  ]
1070
 
1071
  # Use a separate counter to assign unique priorities sequentially
 
1073
 
1074
  for aws_rule_name in aws_ruleset_names:
1075
  current_rule_action_overrides = None
1076
+
1077
  # All managed rule groups need an override_action.
1078
  # 'none' means use the managed rule group's default action.
1079
  current_override_action = wafv2.CfnWebACL.OverrideActionProperty(none={})
 
1085
  current_rule_action_overrides = [
1086
  wafv2.CfnWebACL.RuleActionOverrideProperty(
1087
  name="SizeRestrictions_BODY",
1088
+ action_to_use=wafv2.CfnWebACL.RuleActionProperty(allow={}),
 
 
1089
  )
1090
  ]
1091
  # No need to set current_override_action here, it's already set above.
 
1099
  managed_rule_group_statement=wafv2.CfnWebACL.ManagedRuleGroupStatementProperty(
1100
  vendor_name="AWS",
1101
  name=aws_rule_name,
1102
+ rule_action_overrides=current_rule_action_overrides,
1103
  )
1104
  ),
1105
  visibility_config=wafv2.CfnWebACL.VisibilityConfigProperty(
1106
  cloud_watch_metrics_enabled=True,
1107
  metric_name=aws_rule_name,
1108
+ sampled_requests_enabled=True,
1109
  ),
1110
+ override_action=current_override_action, # THIS IS THE CRUCIAL PART FOR ALL MANAGED RULES
1111
  )
1112
 
1113
  rules.append(rule_property)
1114
 
1115
  # Add the rate limit rule
1116
+ rate_limit_priority = priority_counter # Use the next available priority
1117
+ rules.append(
1118
+ wafv2.CfnWebACL.RuleProperty(
1119
+ name="RateLimitRule",
1120
+ priority=rate_limit_priority,
1121
+ statement=wafv2.CfnWebACL.StatementProperty(
1122
+ rate_based_statement=wafv2.CfnWebACL.RateBasedStatementProperty(
1123
+ limit=1000, aggregate_key_type="IP"
1124
+ )
1125
+ ),
1126
+ visibility_config=wafv2.CfnWebACL.VisibilityConfigProperty(
1127
+ cloud_watch_metrics_enabled=True,
1128
+ metric_name="RateLimitRule",
1129
+ sampled_requests_enabled=True,
1130
+ ),
1131
+ action=wafv2.CfnWebACL.RuleActionProperty(block={}),
 
1132
  )
1133
+ )
1134
 
1135
  web_acl = wafv2.CfnWebACL(
1136
  scope,
 
1141
  visibility_config=wafv2.CfnWebACL.VisibilityConfigProperty(
1142
  cloud_watch_metrics_enabled=True,
1143
  metric_name="webACL",
1144
+ sampled_requests_enabled=True,
1145
  ),
1146
+ rules=rules,
1147
  )
1148
 
1149
  CfnOutput(scope, "WebACLArn", value=web_acl.attr_arn)
1150
 
1151
  return web_acl
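A hedged sketch of attaching the returned CLOUDFRONT-scoped ACL to a distribution; the origin domain and construct IDs are placeholders, and the call is assumed to run inside a stack (`self`).

from aws_cdk import aws_cloudfront as cloudfront
from aws_cdk import aws_cloudfront_origins as origins

web_acl = create_web_acl_with_common_rules(self, "redaction-web-acl")
distribution = cloudfront.Distribution(
    self,
    "AppDistribution",
    default_behavior=cloudfront.BehaviorOptions(origin=origins.HttpOrigin("alb.example.com")),
    web_acl_id=web_acl.attr_arn,  # CLOUDFRONT-scoped ACLs are referenced by ARN
)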
1152
+
1153
+
1154
+ def check_web_acl_exists(
1155
+ web_acl_name: str, scope: str, region_name: str = None
1156
+ ) -> tuple[bool, dict]:
1157
  """
1158
  Checks if a Web ACL with the given name and scope exists.
1159
 
 
1170
  - The second element is the Web ACL object (dictionary) if found,
1171
  None otherwise.
1172
  """
1173
+ if scope not in ["CLOUDFRONT", "REGIONAL"]:
1174
  raise ValueError("Scope must be either 'CLOUDFRONT' or 'REGIONAL'")
1175
 
1176
+ if scope == "REGIONAL" and not region_name:
1177
  raise ValueError("Region name is required for REGIONAL scope")
1178
 
1179
+ if scope == "CLOUDFRONT":
1180
+ region_name = "us-east-1" # CloudFront scope requires us-east-1
1181
+
1182
  if region_name:
1183
+ waf_client = boto3.client("wafv2", region_name=region_name)
1184
  else:
1185
+ waf_client = boto3.client("wafv2")
1186
  try:
1187
  response = waf_client.list_web_acls(Scope=scope)
1188
+ if "WebACLs" in response:
1189
+ for web_acl in response["WebACLs"]:
1190
+ if web_acl["Name"] == web_acl_name:
1191
  # Describe the Web ACL to get the full object.
1192
+ describe_response = waf_client.describe_web_acl(
1193
+ Name=web_acl_name, Scope=scope
1194
+ )
1195
+ return True, describe_response["WebACL"]
1196
  return False, {}
1197
  else:
1198
  return False, {}
1199
  except ClientError as e:
1200
  # Check for the error code indicating the web ACL doesn't exist.
1201
+ if e.response["Error"]["Code"] == "ResourceNotFoundException":
1202
  return False, {}
1203
  else:
1204
  # Re-raise other exceptions.
 
1206
  except Exception as e:
1207
  print(f"An unexpected error occurred: {e}")
1208
  return False, {}
1209
+
1210
+
1211
  def add_alb_https_listener_with_cert(
1212
  scope: Construct,
1213
+ logical_id: str, # A unique ID for this listener construct
1214
  alb: elb.ApplicationLoadBalancer,
1215
+ acm_certificate_arn: Optional[
1216
+ str
1217
+ ], # Optional: If None, no HTTPS listener will be created
1218
+ default_target_group: elb.ITargetGroup, # Mandatory: The target group to forward traffic to
1219
  listener_port_https: int = 443,
1220
+ listener_open_to_internet: bool = False, # Be cautious with True, ensure ALB security group restricts access
1221
  # --- Cognito Authentication Parameters ---
1222
  enable_cognito_auth: bool = False,
1223
  cognito_user_pool: Optional[cognito.IUserPool] = None,
1224
  cognito_user_pool_client: Optional[cognito.IUserPoolClient] = None,
1225
+ cognito_user_pool_domain: Optional[
1226
+ str
1227
+ ] = None, # E.g., "my-app-domain" for "my-app-domain.auth.region.amazoncognito.com"
1228
+ cognito_auth_scope: Optional[
1229
+ str
1230
+ ] = "openid profile email", # Default recommended scope
1231
  cognito_auth_on_unauthenticated_request: elb.UnauthenticatedAction = elb.UnauthenticatedAction.AUTHENTICATE,
1232
+ stickiness_cookie_duration=None,
1233
  # --- End Cognito Parameters ---
1234
  ) -> Optional[elb.ApplicationListener]:
1235
  """
 
1263
  https_listener = None
1264
  if acm_certificate_arn:
1265
  certificates_list = [elb.ListenerCertificate.from_arn(acm_certificate_arn)]
1266
+ print(
1267
+ f"Attempting to add ALB HTTPS listener on port {listener_port_https} with ACM certificate: {acm_certificate_arn}"
1268
+ )
1269
 
1270
  # Determine the default action based on whether Cognito auth is enabled
1271
  default_action = None
1272
+ if enable_cognito_auth is True:
1273
+ if not all(
1274
+ [cognito_user_pool, cognito_user_pool_client, cognito_user_pool_domain]
1275
+ ):
1276
  raise ValueError(
1277
  "Cognito User Pool, Client, and Domain must be provided if enable_cognito_auth is True."
1278
  )
1279
+ print(
1280
+ f"Enabling Cognito authentication with User Pool: {cognito_user_pool.user_pool_id}"
1281
+ )
1282
 
1283
  default_action = elb_act.AuthenticateCognitoAction(
1284
+ next=elb.ListenerAction.forward(
1285
+ [default_target_group]
1286
+ ), # After successful auth, forward to TG
1287
  user_pool=cognito_user_pool,
1288
  user_pool_client=cognito_user_pool_client,
1289
+ user_pool_domain=cognito_user_pool_domain,
1290
  scope=cognito_auth_scope,
1291
+ on_unauthenticated_request=cognito_auth_on_unauthenticated_request,
1292
+ session_timeout=stickiness_cookie_duration,
1293
  # Additional options you might want to configure:
1294
  # session_cookie_name="AWSELBCookies"
1295
  )
 
1303
  port=listener_port_https,
1304
  open=listener_open_to_internet,
1305
  certificates=certificates_list,
1306
+ default_action=default_action, # Use the determined default action
1307
  )
1308
  print(f"ALB HTTPS listener on port {listener_port_https} defined.")
1309
  else:
 
1312
  return https_listener
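An illustrative call site following the signature above; the certificate ARN, Cognito constructs and domain are placeholders.

from aws_cdk import Duration

https_listener = add_alb_https_listener_with_cert(
    scope=self,
    logical_id="AlbHttpsListener",
    alb=alb,
    acm_certificate_arn="arn:aws:acm:eu-west-2:123456789012:certificate/example",
    default_target_group=target_group,
    enable_cognito_auth=True,
    cognito_user_pool=user_pool,
    cognito_user_pool_client=user_pool_client,
    cognito_user_pool_domain=user_pool_domain,
    stickiness_cookie_duration=Duration.hours(12),  # passed through as the Cognito session timeout
)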
1313
 
1314
 
1315
+ def ensure_folder_exists(output_folder: str):
1316
+ """Checks if the specified folder exists, creates it if not."""
1317
 
1318
  if not os.path.exists(output_folder):
1319
  # Create the folder if it doesn't exist
 
1322
  else:
1323
  print(f"The {output_folder} folder already exists.")
1324
 
1325
+
1326
+ def create_basic_config_env(
1327
+ out_dir: str = "config",
1328
+ S3_LOG_CONFIG_BUCKET_NAME=S3_LOG_CONFIG_BUCKET_NAME,
1329
+ S3_OUTPUT_BUCKET_NAME=S3_OUTPUT_BUCKET_NAME,
1330
+ ACCESS_LOG_DYNAMODB_TABLE_NAME=ACCESS_LOG_DYNAMODB_TABLE_NAME,
1331
+ FEEDBACK_LOG_DYNAMODB_TABLE_NAME=FEEDBACK_LOG_DYNAMODB_TABLE_NAME,
1332
+ USAGE_LOG_DYNAMODB_TABLE_NAME=USAGE_LOG_DYNAMODB_TABLE_NAME,
1333
+ ):
1334
+ """
1335
  Create a basic config.env file for the user to use with their newly deployed redaction app.
1336
+ """
1337
  variables = {
1338
+ "COGNITO_AUTH": "1",
1339
+ "RUN_AWS_FUNCTIONS": "1",
1340
+ "DISPLAY_FILE_NAMES_IN_LOGS": "False",
1341
+ "SESSION_OUTPUT_FOLDER": "True",
1342
+ "SAVE_LOGS_TO_DYNAMODB": "True",
1343
+ "SHOW_COSTS": "True",
1344
+ "SHOW_WHOLE_DOCUMENT_TEXTRACT_CALL_OPTIONS": "True",
1345
+ "LOAD_PREVIOUS_TEXTRACT_JOBS_S3": "True",
1346
+ "DOCUMENT_REDACTION_BUCKET": S3_LOG_CONFIG_BUCKET_NAME,
1347
+ "TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_BUCKET": S3_OUTPUT_BUCKET_NAME,
1348
+ "ACCESS_LOG_DYNAMODB_TABLE_NAME": ACCESS_LOG_DYNAMODB_TABLE_NAME,
1349
+ "FEEDBACK_LOG_DYNAMODB_TABLE_NAME": FEEDBACK_LOG_DYNAMODB_TABLE_NAME,
1350
+ "USAGE_LOG_DYNAMODB_TABLE_NAME": USAGE_LOG_DYNAMODB_TABLE_NAME,
 
1351
  }
1352
 
1353
  # Write variables to .env file
1354
  ensure_folder_exists(out_dir + "/")
1355
+ env_file_path = os.path.abspath(os.path.join(out_dir, "config.env"))
1356
 
1357
  # It's good practice to ensure the file exists before calling set_key repeatedly.
1358
  # set_key will create it, but for a loop, it might be cleaner to ensure it's empty/exists once.
1359
  if not os.path.exists(env_file_path):
1360
+ with open(env_file_path, "w"):
1361
+ pass # Create empty file
1362
 
1363
  for key, value in variables.items():
1364
  set_key(env_file_path, key, str(value), quote_mode="never")
1365
 
1366
  return variables
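A sketch of consuming the generated file from the app side with python-dotenv (already a dependency here via set_key); the path assumes the default out_dir of "config".

import os
from dotenv import load_dotenv

load_dotenv("config/config.env")
print(os.environ.get("COGNITO_AUTH"), os.environ.get("DOCUMENT_REDACTION_BUCKET"))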
1367
 
1368
+
1369
+ def start_codebuild_build(PROJECT_NAME: str, AWS_REGION: str = AWS_REGION):
1370
+ """
1371
  Start an existing Codebuild project build
1372
+ """
1373
 
1374
  # --- Initialize CodeBuild client ---
1375
+ client = boto3.client("codebuild", region_name=AWS_REGION)
1376
 
1377
  try:
1378
  print(f"Attempting to start build for project: {PROJECT_NAME}")
1379
 
1380
+ response = client.start_build(projectName=PROJECT_NAME)
 
 
1381
 
1382
+ build_id = response["build"]["id"]
1383
  print(f"Successfully started build with ID: {build_id}")
1384
  print(f"Build ARN: {response['build']['arn']}")
1385
+ print("Build URL (approximate - construct based on region and ID):")
1386
+ print(
1387
+ f"https://{AWS_REGION}.console.aws.amazon.com/codesuite/codebuild/projects/{PROJECT_NAME}/build/{build_id.split(':')[-1]}/detail"
1388
+ )
1389
 
1390
  # You can inspect the full response if needed
1391
  # print("\nFull response:")
 
1397
  except Exception as e:
1398
  print(f"An unexpected error occurred: {e}")
1399
 
1400
+
1401
+ def upload_file_to_s3(
1402
+ local_file_paths: List[str],
1403
+ s3_key: str,
1404
+ s3_bucket: str,
1405
+ RUN_AWS_FUNCTIONS: str = "1",
1406
+ ):
1407
  """
1408
  Uploads a file from local machine to Amazon S3.
1409
 
 
1422
  try:
1423
  if s3_bucket and local_file_paths:
1424
 
1425
+ s3_client = boto3.client("s3", region_name=AWS_REGION)
1426
 
1427
  if isinstance(local_file_paths, str):
1428
  local_file_paths = [local_file_paths]
1429
 
1430
  for file in local_file_paths:
1431
  if s3_client:
1432
+ # print(s3_client)
1433
  try:
1434
  # Get file name off file path
1435
  file_name = os.path.basename(file)
 
1438
  print("S3 key: ", s3_key_full)
1439
 
1440
  s3_client.upload_file(file, s3_bucket, s3_key_full)
1441
+ out_message = (
1442
+ "File " + file_name + " uploaded successfully!"
1443
+ )
1444
  print(out_message)
1445
+
1446
  except Exception as e:
1447
  out_message = f"Error uploading file(s): {e}"
1448
  print(out_message)
1449
 
1450
  final_out_message.append(out_message)
1451
+ final_out_message_str = "\n".join(final_out_message)
1452
 
1453
+ else:
1454
+ final_out_message_str = "Could not connect to AWS."
1455
+ else:
1456
+ final_out_message_str = (
1457
+ "At least one essential variable is empty, could not upload to S3"
1458
+ )
1459
  except Exception as e:
1460
  final_out_message_str = "Could not upload files to S3 due to: " + str(e)
1461
  print(final_out_message_str)
 
1464
 
1465
  return final_out_message_str
1466
 
1467
+
1468
  # Initialize ECS client
1469
  def start_ecs_task(cluster_name, service_name):
1470
+ ecs_client = boto3.client("ecs")
1471
+
1472
  try:
1473
  # Update the service to set the desired count to 1
1474
+ ecs_client.update_service(
1475
+ cluster=cluster_name, service=service_name, desiredCount=1
 
 
1476
  )
1477
  return {
1478
  "statusCode": 200,
1479
+ "body": f"Service {service_name} in cluster {cluster_name} has been updated to 1 task.",
1480
  }
1481
  except Exception as e:
1482
+ return {"statusCode": 500, "body": f"Error updating service: {str(e)}"}
 
 
 
cdk/cdk_stack.py CHANGED
The diff for this file is too large to render. See raw diff
 
cdk/check_resources.py CHANGED
@@ -1,50 +1,85 @@
1
  import json
2
  import os
3
- from cdk_config import CDK_PREFIX, VPC_NAME, AWS_REGION, PUBLIC_SUBNETS_TO_USE, PRIVATE_SUBNETS_TO_USE, CODEBUILD_ROLE_NAME, ECS_TASK_ROLE_NAME, ECS_TASK_EXECUTION_ROLE_NAME, S3_LOG_CONFIG_BUCKET_NAME, S3_OUTPUT_BUCKET_NAME, ECR_CDK_REPO_NAME, CODEBUILD_PROJECT_NAME, ALB_NAME, COGNITO_USER_POOL_NAME, COGNITO_USER_POOL_CLIENT_NAME, COGNITO_USER_POOL_CLIENT_SECRET_NAME, WEB_ACL_NAME, CONTEXT_FILE, PUBLIC_SUBNET_CIDR_BLOCKS, PRIVATE_SUBNET_CIDR_BLOCKS, PUBLIC_SUBNET_AVAILABILITY_ZONES, PRIVATE_SUBNET_AVAILABILITY_ZONES, CDK_FOLDER, CDK_CONFIG_PATH # Import necessary config
4
- from cdk_functions import ( # Import your check functions (assuming they use Boto3)
5
- get_vpc_id_by_name,
6
- check_subnet_exists_by_name,
7
- check_for_existing_role,
8
- check_s3_bucket_exists,
9
- check_ecr_repo_exists,
10
- check_codebuild_project_exists,
11
  check_alb_exists,
 
 
 
12
  check_for_existing_user_pool,
13
  check_for_existing_user_pool_client,
14
  check_for_secret,
15
- check_cloudfront_distribution_exists,
 
16
  check_web_acl_exists,
17
- _get_existing_subnets_in_vpc,
18
- validate_subnet_creation_parameters
19
  # Add other check functions as needed
20
  )
21
 
22
- from typing import List, Dict, Any
23
-
24
- cdk_folder = CDK_FOLDER #<FULL_PATH_TO_CDK_FOLDER_HERE>
25
 
26
  # Full path needed to find config file
27
  os.environ["CDK_CONFIG_PATH"] = cdk_folder + CDK_CONFIG_PATH
28
 
 
29
  # --- Helper to parse environment variables into lists ---
30
  def _get_env_list(env_var_name: str) -> List[str]:
31
  """Parses a comma-separated environment variable into a list of strings."""
32
- value = env_var_name[1:-1].strip().replace('\"', '').replace("\'","")
33
  if not value:
34
  return []
35
  # Split by comma and filter out any empty strings that might result from extra commas
36
- return [s.strip() for s in value.split(',') if s.strip()]
37
-
38
-
39
- if PUBLIC_SUBNETS_TO_USE and not isinstance(PUBLIC_SUBNETS_TO_USE, list): PUBLIC_SUBNETS_TO_USE = _get_env_list(PUBLIC_SUBNETS_TO_USE)
40
- if PRIVATE_SUBNETS_TO_USE and not isinstance(PRIVATE_SUBNETS_TO_USE, list): PRIVATE_SUBNETS_TO_USE = _get_env_list(PRIVATE_SUBNETS_TO_USE)
41
- if PUBLIC_SUBNET_CIDR_BLOCKS and not isinstance(PUBLIC_SUBNET_CIDR_BLOCKS, list): PUBLIC_SUBNET_CIDR_BLOCKS = _get_env_list(PUBLIC_SUBNET_CIDR_BLOCKS)
42
- if PUBLIC_SUBNET_AVAILABILITY_ZONES and not isinstance(PUBLIC_SUBNET_AVAILABILITY_ZONES, list): PUBLIC_SUBNET_AVAILABILITY_ZONES = _get_env_list(PUBLIC_SUBNET_AVAILABILITY_ZONES)
43
- if PRIVATE_SUBNET_CIDR_BLOCKS and not isinstance(PRIVATE_SUBNET_CIDR_BLOCKS, list): PRIVATE_SUBNET_CIDR_BLOCKS = _get_env_list(PRIVATE_SUBNET_CIDR_BLOCKS)
44
- if PRIVATE_SUBNET_AVAILABILITY_ZONES and not isinstance(PRIVATE_SUBNET_AVAILABILITY_ZONES, list): PRIVATE_SUBNET_AVAILABILITY_ZONES = _get_env_list(PRIVATE_SUBNET_AVAILABILITY_ZONES)
45
 
46
  # Check for the existence of elements in your AWS environment to see if it's necessary to create new versions of the same
47
 
 
48
  def check_and_set_context():
49
  context_data = {}
50
 
@@ -58,7 +93,9 @@ def check_and_set_context():
58
  # For simplicity, let's just check if *any* NAT exists in the VPC
59
  # A more robust check would match by subnet, AZ, or a specific tag.
60
  context_data["exists:NatGateway"] = True
61
- context_data["id:NatGateway"] = nat_gateways[0]['NatGatewayId'] # Store the ID of the first one found
 
 
62
  else:
63
  context_data["exists:NatGateway"] = False
64
  context_data["id:NatGateway"] = None
@@ -66,9 +103,11 @@ def check_and_set_context():
66
  if not vpc_id:
67
  # If the VPC doesn't exist, you might not be able to check/create subnets.
68
  # Decide how to handle this: raise an error, set a flag, etc.
69
- raise RuntimeError(f"Required VPC '{VPC_NAME}' not found. Cannot proceed with subnet checks.")
 
 
70
 
71
- context_data["vpc_id"] = vpc_id # Store VPC ID in context
72
 
73
  # SUBNET CHECKS
74
  context_data: Dict[str, Any] = {}
@@ -80,14 +119,14 @@ def check_and_set_context():
80
  # Determine if full validation mode is possible/desired
81
  # It's 'desired' if CIDR/AZs are provided, and their lengths match the name lists.
82
  public_ready_for_full_validation = (
83
- len(PUBLIC_SUBNETS_TO_USE) > 0 and
84
- len(PUBLIC_SUBNET_CIDR_BLOCKS) == len(PUBLIC_SUBNETS_TO_USE) and
85
- len(PUBLIC_SUBNET_AVAILABILITY_ZONES) == len(PUBLIC_SUBNETS_TO_USE)
86
  )
87
  private_ready_for_full_validation = (
88
- len(PRIVATE_SUBNETS_TO_USE) > 0 and
89
- len(PRIVATE_SUBNET_CIDR_BLOCKS) == len(PRIVATE_SUBNETS_TO_USE) and
90
- len(PRIVATE_SUBNET_AVAILABILITY_ZONES) == len(PRIVATE_SUBNETS_TO_USE)
91
  )
92
 
93
  # Activate full validation if *any* type of subnet (public or private) has its full details provided.
@@ -96,27 +135,42 @@ def check_and_set_context():
96
  full_validation_mode = True
97
 
98
  # If some are ready but others aren't, print a warning or raise an error based on your strictness
99
- if public_ready_for_full_validation and not private_ready_for_full_validation and PRIVATE_SUBNETS_TO_USE:
100
- print("Warning: Public subnets have CIDRs/AZs, but private subnets do not. Only public will be fully validated/created with CIDRs.")
101
- if private_ready_for_full_validation and not public_ready_for_full_validation and PUBLIC_SUBNETS_TO_USE:
102
- print("Warning: Private subnets have CIDRs/AZs, but public subnets do not. Only private will be fully validated/created with CIDRs.")
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
  # Prepare data for validate_subnet_creation_parameters for all subnets that have full details
105
  if public_ready_for_full_validation:
106
  for i, name in enumerate(PUBLIC_SUBNETS_TO_USE):
107
- all_proposed_subnets_data.append({
108
- 'name': name,
109
- 'cidr': PUBLIC_SUBNET_CIDR_BLOCKS[i],
110
- 'az': PUBLIC_SUBNET_AVAILABILITY_ZONES[i]
111
- })
 
 
112
  if private_ready_for_full_validation:
113
  for i, name in enumerate(PRIVATE_SUBNETS_TO_USE):
114
- all_proposed_subnets_data.append({
115
- 'name': name,
116
- 'cidr': PRIVATE_SUBNET_CIDR_BLOCKS[i],
117
- 'az': PRIVATE_SUBNET_AVAILABILITY_ZONES[i]
118
- })
119
-
 
120
 
121
  print(f"Target VPC ID for Boto3 lookup: {vpc_id}")
122
 
@@ -125,22 +179,28 @@ def check_and_set_context():
125
  existing_aws_subnets = _get_existing_subnets_in_vpc(vpc_id)
126
  except Exception as e:
127
  print(f"Failed to fetch existing VPC subnets. Aborting. Error: {e}")
128
- raise SystemExit(1) # Exit immediately if we can't get baseline data
129
-
130
  print("\n--- Running Name-Only Subnet Existence Check Mode ---")
131
  # Fallback: check only by name using the existing data
132
  checked_public_subnets = {}
133
  if PUBLIC_SUBNETS_TO_USE:
134
  for subnet_name in PUBLIC_SUBNETS_TO_USE:
135
  print("subnet_name:", subnet_name)
136
- exists, subnet_id = check_subnet_exists_by_name(subnet_name, existing_aws_subnets)
137
- checked_public_subnets[subnet_name] = {"exists": exists, "id": subnet_id}
 
 
 
 
 
138
 
139
  # If the subnet exists, remove it from the proposed subnets list
140
- if checked_public_subnets[subnet_name]["exists"] == True:
141
  all_proposed_subnets_data = [
142
- subnet for subnet in all_proposed_subnets_data
143
- if subnet['name'] != subnet_name
 
144
  ]
145
 
146
  context_data["checked_public_subnets"] = checked_public_subnets
@@ -149,74 +209,86 @@ def check_and_set_context():
149
  if PRIVATE_SUBNETS_TO_USE:
150
  for subnet_name in PRIVATE_SUBNETS_TO_USE:
151
  print("subnet_name:", subnet_name)
152
- exists, subnet_id = check_subnet_exists_by_name(subnet_name, existing_aws_subnets)
153
- checked_private_subnets[subnet_name] = {"exists": exists, "id": subnet_id}
 
 
 
 
 
154
 
155
  # If the subnet exists, remove it from the proposed subnets list
156
- if checked_private_subnets[subnet_name]["exists"] == True:
157
  all_proposed_subnets_data = [
158
- subnet for subnet in all_proposed_subnets_data
159
- if subnet['name'] != subnet_name
 
160
  ]
161
 
162
  context_data["checked_private_subnets"] = checked_private_subnets
163
 
164
-
165
-
166
  print("\nName-only existence subnet check complete.\n")
167
 
168
  if full_validation_mode:
169
- print("\n--- Running in Full Subnet Validation Mode (CIDR/AZs provided) ---")
 
 
170
  try:
171
- validate_subnet_creation_parameters(vpc_id, all_proposed_subnets_data, existing_aws_subnets)
 
 
172
  print("\nPre-synth validation successful. Proceeding with CDK synth.\n")
173
 
174
  # Populate context_data for downstream CDK construct creation
175
  context_data["public_subnets_to_create"] = []
176
  if public_ready_for_full_validation:
177
  for i, name in enumerate(PUBLIC_SUBNETS_TO_USE):
178
- context_data["public_subnets_to_create"].append({
179
- 'name': name,
180
- 'cidr': PUBLIC_SUBNET_CIDR_BLOCKS[i],
181
- 'az': PUBLIC_SUBNET_AVAILABILITY_ZONES[i],
182
- 'is_public': True
183
- })
 
 
184
  context_data["private_subnets_to_create"] = []
185
  if private_ready_for_full_validation:
186
  for i, name in enumerate(PRIVATE_SUBNETS_TO_USE):
187
- context_data["private_subnets_to_create"].append({
188
- 'name': name,
189
- 'cidr': PRIVATE_SUBNET_CIDR_BLOCKS[i],
190
- 'az': PRIVATE_SUBNET_AVAILABILITY_ZONES[i],
191
- 'is_public': False
192
- })
 
 
193
 
194
  except (ValueError, Exception) as e:
195
  print(f"\nFATAL ERROR: Subnet parameter validation failed: {e}\n")
196
- raise SystemExit(1) # Exit if validation fails
197
 
198
  # Example checks and setting context values
199
  # IAM Roles
200
  role_name = CODEBUILD_ROLE_NAME
201
  exists, _, _ = check_for_existing_role(role_name)
202
- context_data[f"exists:{role_name}"] = exists # Use boolean
203
  if exists:
204
- _, role_arn, _ = check_for_existing_role(role_name) # Get ARN if needed
205
- context_data[f"arn:{role_name}"] = role_arn
206
 
207
  role_name = ECS_TASK_ROLE_NAME
208
  exists, _, _ = check_for_existing_role(role_name)
209
  context_data[f"exists:{role_name}"] = exists
210
  if exists:
211
- _, role_arn, _ = check_for_existing_role(role_name)
212
- context_data[f"arn:{role_name}"] = role_arn
213
 
214
  role_name = ECS_TASK_EXECUTION_ROLE_NAME
215
  exists, _, _ = check_for_existing_role(role_name)
216
  context_data[f"exists:{role_name}"] = exists
217
  if exists:
218
- _, role_arn, _ = check_for_existing_role(role_name)
219
- context_data[f"arn:{role_name}"] = role_arn
220
 
221
  # S3 Buckets
222
  bucket_name = S3_LOG_CONFIG_BUCKET_NAME
@@ -230,33 +302,36 @@ def check_and_set_context():
230
  exists, _ = check_s3_bucket_exists(output_bucket_name)
231
  context_data[f"exists:{output_bucket_name}"] = exists
232
  if exists:
233
- pass
234
 
235
  # ECR Repository
236
  repo_name = ECR_CDK_REPO_NAME
237
  exists, _ = check_ecr_repo_exists(repo_name)
238
  context_data[f"exists:{repo_name}"] = exists
239
  if exists:
240
- pass # from_repository_name is sufficient
241
 
242
  # CodeBuild Project
243
  project_name = CODEBUILD_PROJECT_NAME
244
  exists, _ = check_codebuild_project_exists(project_name)
245
  context_data[f"exists:{project_name}"] = exists
246
  if exists:
247
- # Need a way to get the ARN from the check function
248
- _, project_arn = check_codebuild_project_exists(project_name) # Assuming it returns ARN
249
- context_data[f"arn:{project_name}"] = project_arn
 
 
250
 
251
  # ALB (by name lookup)
252
  alb_name = ALB_NAME
253
  exists, _ = check_alb_exists(alb_name, region_name=AWS_REGION)
254
  context_data[f"exists:{alb_name}"] = exists
255
  if exists:
256
- _, alb_object = check_alb_exists(alb_name, region_name=AWS_REGION) # Assuming check returns object
 
 
257
  print("alb_object:", alb_object)
258
- context_data[f"arn:{alb_name}"] = alb_object['LoadBalancerArn']
259
-
260
 
261
  # Cognito User Pool (by name)
262
  user_pool_name = COGNITO_USER_POOL_NAME
@@ -267,10 +342,12 @@ def check_and_set_context():
267
 
268
  # Cognito User Pool Client (by name and pool ID) - requires User Pool ID from check
269
  if user_pool_id:
270
- user_pool_id_for_client_check = user_pool_id #context_data.get(f"id:{user_pool_name}") # Use ID from context
271
  user_pool_client_name = COGNITO_USER_POOL_CLIENT_NAME
272
  if user_pool_id_for_client_check:
273
- exists, client_id, _ = check_for_existing_user_pool_client(user_pool_client_name, user_pool_id_for_client_check)
 
 
274
  context_data[f"exists:{user_pool_client_name}"] = exists
275
  if exists:
276
  context_data[f"id:{user_pool_client_name}"] = client_id
@@ -281,10 +358,11 @@ def check_and_set_context():
281
  context_data[f"exists:{secret_name}"] = exists
282
  # You might not need the ARN if using from_secret_name_v2
283
 
284
-
285
  # WAF Web ACL (by name and scope)
286
  web_acl_name = WEB_ACL_NAME
287
- exists, _ = check_web_acl_exists(web_acl_name, scope="CLOUDFRONT") # Assuming check returns object
 
 
288
  context_data[f"exists:{web_acl_name}"] = exists
289
  if exists:
290
  _, existing_web_acl = check_web_acl_exists(web_acl_name, scope="CLOUDFRONT")
@@ -295,4 +373,3 @@ def check_and_set_context():
295
  json.dump(context_data, f, indent=2)
296
 
297
  print(f"Context data written to {CONTEXT_FILE}")
298
-
 
1
  import json
2
  import os
3
+ from typing import Any, Dict, List
4
+
5
+ from cdk_config import ( # Import necessary config
6
+ ALB_NAME,
7
+ AWS_REGION,
8
+ CDK_CONFIG_PATH,
9
+ CDK_FOLDER,
10
+ CODEBUILD_PROJECT_NAME,
11
+ CODEBUILD_ROLE_NAME,
12
+ COGNITO_USER_POOL_CLIENT_NAME,
13
+ COGNITO_USER_POOL_CLIENT_SECRET_NAME,
14
+ COGNITO_USER_POOL_NAME,
15
+ CONTEXT_FILE,
16
+ ECR_CDK_REPO_NAME,
17
+ ECS_TASK_EXECUTION_ROLE_NAME,
18
+ ECS_TASK_ROLE_NAME,
19
+ PRIVATE_SUBNET_AVAILABILITY_ZONES,
20
+ PRIVATE_SUBNET_CIDR_BLOCKS,
21
+ PRIVATE_SUBNETS_TO_USE,
22
+ PUBLIC_SUBNET_AVAILABILITY_ZONES,
23
+ PUBLIC_SUBNET_CIDR_BLOCKS,
24
+ PUBLIC_SUBNETS_TO_USE,
25
+ S3_LOG_CONFIG_BUCKET_NAME,
26
+ S3_OUTPUT_BUCKET_NAME,
27
+ VPC_NAME,
28
+ WEB_ACL_NAME,
29
+ )
30
+ from cdk_functions import ( # Import your check functions (assuming they use Boto3)
31
+ _get_existing_subnets_in_vpc,
32
  check_alb_exists,
33
+ check_codebuild_project_exists,
34
+ check_ecr_repo_exists,
35
+ check_for_existing_role,
36
  check_for_existing_user_pool,
37
  check_for_existing_user_pool_client,
38
  check_for_secret,
39
+ check_s3_bucket_exists,
40
+ check_subnet_exists_by_name,
41
  check_web_acl_exists,
42
+ get_vpc_id_by_name,
43
+ validate_subnet_creation_parameters,
44
  # Add other check functions as needed
45
  )
46
 
47
+ cdk_folder = CDK_FOLDER # <FULL_PATH_TO_CDK_FOLDER_HERE>
 
 
48
 
49
  # Full path needed to find config file
50
  os.environ["CDK_CONFIG_PATH"] = cdk_folder + CDK_CONFIG_PATH
51
 
52
+
53
  # --- Helper to parse environment variables into lists ---
54
  def _get_env_list(env_var_name: str) -> List[str]:
55
  """Parses a comma-separated environment variable into a list of strings."""
56
+ value = env_var_name[1:-1].strip().replace('"', "").replace("'", "")
57
  if not value:
58
  return []
59
  # Split by comma and filter out any empty strings that might result from extra commas
60
+ return [s.strip() for s in value.split(",") if s.strip()]
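For clarity, the parser above expects the raw value of a bracketed, comma-separated variable (it strips the first and last characters), for example:

print(_get_env_list('["subnet-a", "subnet-b"]'))    # -> ['subnet-a', 'subnet-b']
print(_get_env_list("[10.0.1.0/24, 10.0.2.0/24]"))  # -> ['10.0.1.0/24', '10.0.2.0/24']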
61
+
62
+
63
+ if PUBLIC_SUBNETS_TO_USE and not isinstance(PUBLIC_SUBNETS_TO_USE, list):
64
+ PUBLIC_SUBNETS_TO_USE = _get_env_list(PUBLIC_SUBNETS_TO_USE)
65
+ if PRIVATE_SUBNETS_TO_USE and not isinstance(PRIVATE_SUBNETS_TO_USE, list):
66
+ PRIVATE_SUBNETS_TO_USE = _get_env_list(PRIVATE_SUBNETS_TO_USE)
67
+ if PUBLIC_SUBNET_CIDR_BLOCKS and not isinstance(PUBLIC_SUBNET_CIDR_BLOCKS, list):
68
+ PUBLIC_SUBNET_CIDR_BLOCKS = _get_env_list(PUBLIC_SUBNET_CIDR_BLOCKS)
69
+ if PUBLIC_SUBNET_AVAILABILITY_ZONES and not isinstance(
70
+ PUBLIC_SUBNET_AVAILABILITY_ZONES, list
71
+ ):
72
+ PUBLIC_SUBNET_AVAILABILITY_ZONES = _get_env_list(PUBLIC_SUBNET_AVAILABILITY_ZONES)
73
+ if PRIVATE_SUBNET_CIDR_BLOCKS and not isinstance(PRIVATE_SUBNET_CIDR_BLOCKS, list):
74
+ PRIVATE_SUBNET_CIDR_BLOCKS = _get_env_list(PRIVATE_SUBNET_CIDR_BLOCKS)
75
+ if PRIVATE_SUBNET_AVAILABILITY_ZONES and not isinstance(
76
+ PRIVATE_SUBNET_AVAILABILITY_ZONES, list
77
+ ):
78
+ PRIVATE_SUBNET_AVAILABILITY_ZONES = _get_env_list(PRIVATE_SUBNET_AVAILABILITY_ZONES)
79
 
80
  # Check for the existence of elements in your AWS environment to see if it's necessary to create new versions of the same
81
 
82
+
83
  def check_and_set_context():
84
  context_data = {}
85
 
 
93
  # For simplicity, let's just check if *any* NAT exists in the VPC
94
  # A more robust check would match by subnet, AZ, or a specific tag.
95
  context_data["exists:NatGateway"] = True
96
+ context_data["id:NatGateway"] = nat_gateways[0][
97
+ "NatGatewayId"
98
+ ] # Store the ID of the first one found
99
  else:
100
  context_data["exists:NatGateway"] = False
101
  context_data["id:NatGateway"] = None
 
103
  if not vpc_id:
104
  # If the VPC doesn't exist, you might not be able to check/create subnets.
105
  # Decide how to handle this: raise an error, set a flag, etc.
106
+ raise RuntimeError(
107
+ f"Required VPC '{VPC_NAME}' not found. Cannot proceed with subnet checks."
108
+ )
109
 
110
+ context_data["vpc_id"] = vpc_id # Store VPC ID in context
111
 
112
  # SUBNET CHECKS
113
  context_data: Dict[str, Any] = {}
 
119
  # Determine if full validation mode is possible/desired
120
  # It's 'desired' if CIDR/AZs are provided, and their lengths match the name lists.
121
  public_ready_for_full_validation = (
122
+ len(PUBLIC_SUBNETS_TO_USE) > 0
123
+ and len(PUBLIC_SUBNET_CIDR_BLOCKS) == len(PUBLIC_SUBNETS_TO_USE)
124
+ and len(PUBLIC_SUBNET_AVAILABILITY_ZONES) == len(PUBLIC_SUBNETS_TO_USE)
125
  )
126
  private_ready_for_full_validation = (
127
+ len(PRIVATE_SUBNETS_TO_USE) > 0
128
+ and len(PRIVATE_SUBNET_CIDR_BLOCKS) == len(PRIVATE_SUBNETS_TO_USE)
129
+ and len(PRIVATE_SUBNET_AVAILABILITY_ZONES) == len(PRIVATE_SUBNETS_TO_USE)
130
  )
131
 
132
  # Activate full validation if *any* type of subnet (public or private) has its full details provided.
 
135
  full_validation_mode = True
136
 
137
  # If some are ready but others aren't, print a warning or raise an error based on your strictness
138
+ if (
139
+ public_ready_for_full_validation
140
+ and not private_ready_for_full_validation
141
+ and PRIVATE_SUBNETS_TO_USE
142
+ ):
143
+ print(
144
+ "Warning: Public subnets have CIDRs/AZs, but private subnets do not. Only public will be fully validated/created with CIDRs."
145
+ )
146
+ if (
147
+ private_ready_for_full_validation
148
+ and not public_ready_for_full_validation
149
+ and PUBLIC_SUBNETS_TO_USE
150
+ ):
151
+ print(
152
+ "Warning: Private subnets have CIDRs/AZs, but public subnets do not. Only private will be fully validated/created with CIDRs."
153
+ )
154
 
155
  # Prepare data for validate_subnet_creation_parameters for all subnets that have full details
156
  if public_ready_for_full_validation:
157
  for i, name in enumerate(PUBLIC_SUBNETS_TO_USE):
158
+ all_proposed_subnets_data.append(
159
+ {
160
+ "name": name,
161
+ "cidr": PUBLIC_SUBNET_CIDR_BLOCKS[i],
162
+ "az": PUBLIC_SUBNET_AVAILABILITY_ZONES[i],
163
+ }
164
+ )
165
  if private_ready_for_full_validation:
166
  for i, name in enumerate(PRIVATE_SUBNETS_TO_USE):
167
+ all_proposed_subnets_data.append(
168
+ {
169
+ "name": name,
170
+ "cidr": PRIVATE_SUBNET_CIDR_BLOCKS[i],
171
+ "az": PRIVATE_SUBNET_AVAILABILITY_ZONES[i],
172
+ }
173
+ )
174
 
175
  print(f"Target VPC ID for Boto3 lookup: {vpc_id}")
176
 
 
179
  existing_aws_subnets = _get_existing_subnets_in_vpc(vpc_id)
180
  except Exception as e:
181
  print(f"Failed to fetch existing VPC subnets. Aborting. Error: {e}")
182
+ raise SystemExit(1) # Exit immediately if we can't get baseline data
183
+
184
  print("\n--- Running Name-Only Subnet Existence Check Mode ---")
185
  # Fallback: check only by name using the existing data
186
  checked_public_subnets = {}
187
  if PUBLIC_SUBNETS_TO_USE:
188
  for subnet_name in PUBLIC_SUBNETS_TO_USE:
189
  print("subnet_name:", subnet_name)
190
+ exists, subnet_id = check_subnet_exists_by_name(
191
+ subnet_name, existing_aws_subnets
192
+ )
193
+ checked_public_subnets[subnet_name] = {
194
+ "exists": exists,
195
+ "id": subnet_id,
196
+ }
197
 
198
  # If the subnet exists, remove it from the proposed subnets list
199
+ if checked_public_subnets[subnet_name]["exists"] is True:
200
  all_proposed_subnets_data = [
201
+ subnet
202
+ for subnet in all_proposed_subnets_data
203
+ if subnet["name"] != subnet_name
204
  ]
205
 
206
  context_data["checked_public_subnets"] = checked_public_subnets
 
209
  if PRIVATE_SUBNETS_TO_USE:
210
  for subnet_name in PRIVATE_SUBNETS_TO_USE:
211
  print("subnet_name:", subnet_name)
212
+ exists, subnet_id = check_subnet_exists_by_name(
213
+ subnet_name, existing_aws_subnets
214
+ )
215
+ checked_private_subnets[subnet_name] = {
216
+ "exists": exists,
217
+ "id": subnet_id,
218
+ }
219
 
220
  # If the subnet exists, remove it from the proposed subnets list
221
+ if checked_private_subnets[subnet_name]["exists"] is True:
222
  all_proposed_subnets_data = [
223
+ subnet
224
+ for subnet in all_proposed_subnets_data
225
+ if subnet["name"] != subnet_name
226
  ]
227
 
228
  context_data["checked_private_subnets"] = checked_private_subnets
229
 
 
 
230
  print("\nName-only existence subnet check complete.\n")
231
 
232
  if full_validation_mode:
233
+ print(
234
+ "\n--- Running in Full Subnet Validation Mode (CIDR/AZs provided) ---"
235
+ )
236
  try:
237
+ validate_subnet_creation_parameters(
238
+ vpc_id, all_proposed_subnets_data, existing_aws_subnets
239
+ )
240
  print("\nPre-synth validation successful. Proceeding with CDK synth.\n")
241
 
242
  # Populate context_data for downstream CDK construct creation
243
  context_data["public_subnets_to_create"] = []
244
  if public_ready_for_full_validation:
245
  for i, name in enumerate(PUBLIC_SUBNETS_TO_USE):
246
+ context_data["public_subnets_to_create"].append(
247
+ {
248
+ "name": name,
249
+ "cidr": PUBLIC_SUBNET_CIDR_BLOCKS[i],
250
+ "az": PUBLIC_SUBNET_AVAILABILITY_ZONES[i],
251
+ "is_public": True,
252
+ }
253
+ )
254
  context_data["private_subnets_to_create"] = []
255
  if private_ready_for_full_validation:
256
  for i, name in enumerate(PRIVATE_SUBNETS_TO_USE):
257
+ context_data["private_subnets_to_create"].append(
258
+ {
259
+ "name": name,
260
+ "cidr": PRIVATE_SUBNET_CIDR_BLOCKS[i],
261
+ "az": PRIVATE_SUBNET_AVAILABILITY_ZONES[i],
262
+ "is_public": False,
263
+ }
264
+ )
265
 
266
  except (ValueError, Exception) as e:
267
  print(f"\nFATAL ERROR: Subnet parameter validation failed: {e}\n")
268
+ raise SystemExit(1) # Exit if validation fails
269
 
270
  # Example checks and setting context values
271
  # IAM Roles
272
  role_name = CODEBUILD_ROLE_NAME
273
  exists, _, _ = check_for_existing_role(role_name)
274
+ context_data[f"exists:{role_name}"] = exists # Use boolean
275
  if exists:
276
+ _, role_arn, _ = check_for_existing_role(role_name) # Get ARN if needed
277
+ context_data[f"arn:{role_name}"] = role_arn
278
 
279
  role_name = ECS_TASK_ROLE_NAME
280
  exists, _, _ = check_for_existing_role(role_name)
281
  context_data[f"exists:{role_name}"] = exists
282
  if exists:
283
+ _, role_arn, _ = check_for_existing_role(role_name)
284
+ context_data[f"arn:{role_name}"] = role_arn
285
 
286
  role_name = ECS_TASK_EXECUTION_ROLE_NAME
287
  exists, _, _ = check_for_existing_role(role_name)
288
  context_data[f"exists:{role_name}"] = exists
289
  if exists:
290
+ _, role_arn, _ = check_for_existing_role(role_name)
291
+ context_data[f"arn:{role_name}"] = role_arn
292
 
293
  # S3 Buckets
294
  bucket_name = S3_LOG_CONFIG_BUCKET_NAME
 
302
  exists, _ = check_s3_bucket_exists(output_bucket_name)
303
  context_data[f"exists:{output_bucket_name}"] = exists
304
  if exists:
305
+ pass
306
 
307
  # ECR Repository
308
  repo_name = ECR_CDK_REPO_NAME
309
  exists, _ = check_ecr_repo_exists(repo_name)
310
  context_data[f"exists:{repo_name}"] = exists
311
  if exists:
312
+ pass # from_repository_name is sufficient
313
 
314
  # CodeBuild Project
315
  project_name = CODEBUILD_PROJECT_NAME
316
  exists, _ = check_codebuild_project_exists(project_name)
317
  context_data[f"exists:{project_name}"] = exists
318
  if exists:
319
+ # Need a way to get the ARN from the check function
320
+ _, project_arn = check_codebuild_project_exists(
321
+ project_name
322
+ ) # Assuming it returns ARN
323
+ context_data[f"arn:{project_name}"] = project_arn
324
 
325
  # ALB (by name lookup)
326
  alb_name = ALB_NAME
327
  exists, _ = check_alb_exists(alb_name, region_name=AWS_REGION)
328
  context_data[f"exists:{alb_name}"] = exists
329
  if exists:
330
+ _, alb_object = check_alb_exists(
331
+ alb_name, region_name=AWS_REGION
332
+ ) # Assuming check returns object
333
  print("alb_object:", alb_object)
334
+ context_data[f"arn:{alb_name}"] = alb_object["LoadBalancerArn"]
 
335
 
336
  # Cognito User Pool (by name)
337
  user_pool_name = COGNITO_USER_POOL_NAME
 
342
 
343
  # Cognito User Pool Client (by name and pool ID) - requires User Pool ID from check
344
  if user_pool_id:
345
+ user_pool_id_for_client_check = user_pool_id # context_data.get(f"id:{user_pool_name}") # Use ID from context
346
  user_pool_client_name = COGNITO_USER_POOL_CLIENT_NAME
347
  if user_pool_id_for_client_check:
348
+ exists, client_id, _ = check_for_existing_user_pool_client(
349
+ user_pool_id_for_client_check, user_pool_client_name  # user_pool_id first, to match the function signature
350
+ )
351
  context_data[f"exists:{user_pool_client_name}"] = exists
352
  if exists:
353
  context_data[f"id:{user_pool_client_name}"] = client_id
 
358
  context_data[f"exists:{secret_name}"] = exists
359
  # You might not need the ARN if using from_secret_name_v2
360
 
 
361
  # WAF Web ACL (by name and scope)
362
  web_acl_name = WEB_ACL_NAME
363
+ exists, _ = check_web_acl_exists(
364
+ web_acl_name, scope="CLOUDFRONT"
365
+ ) # Assuming check returns object
366
  context_data[f"exists:{web_acl_name}"] = exists
367
  if exists:
368
  _, existing_web_acl = check_web_acl_exists(web_acl_name, scope="CLOUDFRONT")
 
373
  json.dump(context_data, f, indent=2)
374
 
375
  print(f"Context data written to {CONTEXT_FILE}")
 
cdk/post_cdk_build_quickstart.py CHANGED
@@ -1,6 +1,17 @@
1
  import time
2
- from cdk_config import CODEBUILD_PROJECT_NAME, S3_LOG_CONFIG_BUCKET_NAME, CLUSTER_NAME, ECS_SERVICE_NAME
3
- from cdk_functions import start_codebuild_build, upload_file_to_s3, start_ecs_task, create_basic_config_env
 
 
 
 
 
 
 
 
 
 
 
4
  from tqdm import tqdm
5
 
6
  # Create basic config.env file that user can use to run the app later. Input is the folder it is saved into.
@@ -11,10 +22,12 @@ print("Starting CodeBuild project.")
11
  start_codebuild_build(PROJECT_NAME=CODEBUILD_PROJECT_NAME)
12
 
13
  # Upload config.env file to S3 bucket
14
- upload_file_to_s3(local_file_paths="config/config.env", s3_key="", s3_bucket=S3_LOG_CONFIG_BUCKET_NAME)
 
 
15
 
16
- total_seconds = 660 # 11 minutes
17
- update_interval = 1 # Update every second
18
 
19
  print("Waiting 11 minutes for the CodeBuild container to build.")
20
 
@@ -24,4 +37,4 @@ for i in tqdm(range(total_seconds), desc="Building container"):
24
 
25
  # Start task on ECS
26
  print("Starting ECS task")
27
- start_ecs_task(cluster_name=CLUSTER_NAME, service_name=ECS_SERVICE_NAME)
 
1
  import time
2
+
3
+ from cdk_config import (
4
+ CLUSTER_NAME,
5
+ CODEBUILD_PROJECT_NAME,
6
+ ECS_SERVICE_NAME,
7
+ S3_LOG_CONFIG_BUCKET_NAME,
8
+ )
9
+ from cdk_functions import (
10
+ create_basic_config_env,
11
+ start_codebuild_build,
12
+ start_ecs_task,
13
+ upload_file_to_s3,
14
+ )
15
  from tqdm import tqdm
16
 
17
  # Create basic config.env file that user can use to run the app later. Input is the folder it is saved into.
 
22
  start_codebuild_build(PROJECT_NAME=CODEBUILD_PROJECT_NAME)
23
 
24
  # Upload config.env file to S3 bucket
25
+ upload_file_to_s3(
26
+ local_file_paths="config/config.env", s3_key="", s3_bucket=S3_LOG_CONFIG_BUCKET_NAME
27
+ )
28
 
29
+ total_seconds = 660 # 11 minutes
30
+ update_interval = 1 # Update every second
31
 
32
  print("Waiting 11 minutes for the CodeBuild container to build.")
33
 
 
37
 
38
  # Start task on ECS
39
  print("Starting ECS task")
40
+ start_ecs_task(cluster_name=CLUSTER_NAME, service_name=ECS_SERVICE_NAME)
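The quickstart relies on create_basic_config_env, whose body is not shown in this diff. A rough, purely illustrative sketch of what such a helper plausibly does (the keys and the real cdk_functions implementation may differ):

    import os

    def create_basic_config_env(folder: str = "config") -> str:
        # Write a minimal config.env into `folder` for the app to load later (illustrative only).
        os.makedirs(folder, exist_ok=True)
        path = os.path.join(folder, "config.env")
        # Assumed keys; not taken from the repository's actual config.
        settings = {"OUTPUT_FOLDER": "output/", "INPUT_FOLDER": "input/"}
        with open(path, "w") as f:
            for key, value in settings.items():
                f.write(f"{key}={value}\n")
        return path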
cli_redact.py CHANGED
@@ -1,70 +1,145 @@
1
  import argparse
2
  import os
3
- import pandas as pd
4
  import time
5
  import uuid
6
- from tools.config import LOCAL_PII_OPTION, AWS_PII_OPTION, OUTPUT_FOLDER, DEFAULT_LANGUAGE, CHOSEN_COMPREHEND_ENTITIES, FULL_COMPREHEND_ENTITY_LIST, CHOSEN_REDACT_ENTITIES, FULL_ENTITY_LIST, CUSTOM_ENTITIES, AWS_ACCESS_KEY, AWS_SECRET_KEY, AWS_REGION, DOCUMENT_REDACTION_BUCKET, DEFAULT_COST_CODE, SAVE_LOGS_TO_CSV, SAVE_LOGS_TO_DYNAMODB, DISPLAY_FILE_NAMES_IN_LOGS, DEFAULT_HANDWRITE_SIGNATURE_CHECKBOX, DO_INITIAL_TABULAR_DATA_CLEAN, ALLOW_LIST_PATH, DENY_LIST_PATH, WHOLE_PAGE_REDACTION_LIST_PATH, PREPROCESS_LOCAL_OCR_IMAGES, IMAGES_DPI, RETURN_PDF_END_OF_REDACTION, COMPRESS_REDACTED_PDF, CHOSEN_LOCAL_OCR_MODEL, DEFAULT_TABULAR_ANONYMISATION_STRATEGY, DEFAULT_FUZZY_SPELLING_MISTAKES_NUM, DEFAULT_DUPLICATE_DETECTION_THRESHOLD, DEFAULT_MIN_WORD_COUNT, DEFAULT_MIN_CONSECUTIVE_PAGES, USE_GREEDY_DUPLICATE_DETECTION, DEFAULT_COMBINE_PAGES, REMOVE_DUPLICATE_ROWS, TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_BUCKET, TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_INPUT_SUBFOLDER, TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_OUTPUT_SUBFOLDER, INPUT_FOLDER, TEXTRACT_JOBS_S3_LOC, TEXTRACT_JOBS_LOCAL_LOC, SESSION_OUTPUT_FOLDER, DIRECT_MODE_DEFAULT_USER, RUN_AWS_FUNCTIONS, S3_USAGE_LOGS_FOLDER
7
 
8
- from tools.helper_functions import ensure_output_folder_exists
 
 
 
 
 
 
 
 
9
 
10
 
11
  def _generate_session_hash() -> str:
12
  """Generate a unique session hash for logging purposes."""
13
  return str(uuid.uuid4())[:8]
14
 
15
- def get_username_and_folders(username:str = "",
16
- output_folder_textbox:str=OUTPUT_FOLDER,
17
- input_folder_textbox:str=INPUT_FOLDER,
18
- session_output_folder:str=SESSION_OUTPUT_FOLDER,
19
- textract_document_upload_input_folder:str=TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_INPUT_SUBFOLDER,
20
- textract_document_upload_output_folder:str=TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_OUTPUT_SUBFOLDER,
21
- s3_textract_document_logs_subfolder:str=TEXTRACT_JOBS_S3_LOC,
22
- local_textract_document_logs_subfolder:str=TEXTRACT_JOBS_LOCAL_LOC):
23
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  # Generate session hash for logging. Either from input user name or generated
26
  if username:
27
  out_session_hash = username
28
  else:
29
- out_session_hash = _generate_session_hash()
30
-
31
 
32
- if session_output_folder == 'True' or session_output_folder == True:
33
  output_folder = output_folder_textbox + out_session_hash + "/"
34
  input_folder = input_folder_textbox + out_session_hash + "/"
35
 
36
- textract_document_upload_input_folder = textract_document_upload_input_folder + "/" + out_session_hash
37
- textract_document_upload_output_folder = textract_document_upload_output_folder + "/" + out_session_hash
 
 
 
 
38
 
39
- s3_textract_document_logs_subfolder = s3_textract_document_logs_subfolder + "/" + out_session_hash
40
- local_textract_document_logs_subfolder = local_textract_document_logs_subfolder + "/" + out_session_hash + "/"
 
 
 
 
41
 
42
  else:
43
  output_folder = output_folder_textbox
44
  input_folder = input_folder_textbox
45
 
46
- if not os.path.exists(output_folder): os.mkdir(output_folder)
47
- if not os.path.exists(input_folder): os.mkdir(input_folder)
48
-
49
- return out_session_hash, output_folder, out_session_hash, input_folder, textract_document_upload_input_folder, textract_document_upload_output_folder, s3_textract_document_logs_subfolder, local_textract_document_logs_subfolder
 
 
 
 
 
 
 
 
 
 
 
50
 
51
 
52
  def _get_env_list(env_var_name: str) -> list[str]:
53
  """Parses a comma-separated environment variable into a list of strings."""
54
- value = env_var_name[1:-1].strip().replace('\"', '').replace("\'","")
55
  if not value:
56
  return []
57
  # Split by comma and filter out any empty strings that might result from extra commas
58
- return [s.strip() for s in value.split(',') if s.strip()]
 
59
 
60
  # --- Constants and Configuration ---
61
 
62
- if CHOSEN_COMPREHEND_ENTITIES: CHOSEN_COMPREHEND_ENTITIES = _get_env_list(CHOSEN_COMPREHEND_ENTITIES)
63
- if FULL_COMPREHEND_ENTITY_LIST: FULL_COMPREHEND_ENTITY_LIST = _get_env_list(FULL_COMPREHEND_ENTITY_LIST)
64
- if CHOSEN_REDACT_ENTITIES: CHOSEN_REDACT_ENTITIES = _get_env_list(CHOSEN_REDACT_ENTITIES)
65
- if FULL_ENTITY_LIST: FULL_ENTITY_LIST = _get_env_list(FULL_ENTITY_LIST)
66
- if CUSTOM_ENTITIES: CUSTOM_ENTITIES = _get_env_list(CUSTOM_ENTITIES)
67
- if DEFAULT_HANDWRITE_SIGNATURE_CHECKBOX: DEFAULT_HANDWRITE_SIGNATURE_CHECKBOX = _get_env_list(DEFAULT_HANDWRITE_SIGNATURE_CHECKBOX)
 
 
 
 
 
 
 
 
68
 
69
  # Add custom spacy recognisers to the Comprehend list, so that local Spacy model can be used to pick up e.g. titles, streetnames, UK postcodes that are sometimes missed by comprehend
70
  CHOSEN_COMPREHEND_ENTITIES.extend(CUSTOM_ENTITIES)
@@ -76,19 +151,20 @@ chosen_comprehend_entities = CHOSEN_COMPREHEND_ENTITIES
76
  full_comprehend_entity_list = FULL_COMPREHEND_ENTITY_LIST
77
  default_handwrite_signature_checkbox = DEFAULT_HANDWRITE_SIGNATURE_CHECKBOX
78
 
 
79
  # --- Main CLI Function ---
80
  def main(direct_mode_args={}):
81
  """
82
  A unified command-line interface to prepare, redact, and anonymise various document types.
83
-
84
  Args:
85
  direct_mode_args (dict, optional): Dictionary of arguments for direct mode execution.
86
  If provided, uses these instead of parsing command line arguments.
87
  """
88
  parser = argparse.ArgumentParser(
89
- description='A versatile CLI for redacting PII from PDF/image files and anonymising Word/tabular data.',
90
  formatter_class=argparse.RawTextHelpFormatter,
91
- epilog='''
92
  Examples:
93
 
94
  To run these, you need to do the following:
@@ -171,95 +247,334 @@ python cli_redact.py --task textract --textract_action retrieve --job_id 1234567
171
  ## List recent Textract jobs:
172
  python cli_redact.py --task textract --textract_action list
173
 
174
- '''
175
  )
176
 
177
  # --- Task Selection ---
178
- task_group = parser.add_argument_group('Task Selection')
179
- task_group.add_argument('--task',
180
- choices=['redact', 'deduplicate', 'textract'],
181
- default='redact',
182
- help='Task to perform: redact (PII redaction/anonymisation), deduplicate (find duplicate content), or textract (AWS Textract batch operations).')
 
 
183
 
184
  # --- General Arguments (apply to all file types) ---
185
- general_group = parser.add_argument_group('General Options')
186
- general_group.add_argument('--input_file', nargs='+', help='Path to the input file(s) to process. Separate multiple files with a space, and use quotes if there are spaces in the file name.')
187
- general_group.add_argument('--output_dir', default=OUTPUT_FOLDER, help='Directory for all output files.')
188
- general_group.add_argument('--input_dir', default=INPUT_FOLDER, help='Directory for all input files.')
189
- general_group.add_argument('--language', default=DEFAULT_LANGUAGE, help='Language of the document content.')
190
- general_group.add_argument('--allow_list', default=ALLOW_LIST_PATH, help='Path to a CSV file with words to exclude from redaction.')
191
- general_group.add_argument('--pii_detector', choices=[LOCAL_PII_OPTION, AWS_PII_OPTION, "None"], default=LOCAL_PII_OPTION,
192
- help='Core PII detection method (Local or AWS Comprehend, or None).')
193
- general_group.add_argument('--username', default=DIRECT_MODE_DEFAULT_USER, help='Username for the session.')
194
- general_group.add_argument('--save_to_user_folders', default=SESSION_OUTPUT_FOLDER, help='Whether to save to user folders or not.')
195
-
196
- general_group.add_argument('--local_redact_entities', nargs='+', choices=full_entity_list, default=chosen_redact_entities,
197
- help=f'Local redaction entities to use. Default: {chosen_redact_entities}. Full list: {full_entity_list}.')
198
-
199
- general_group.add_argument('--aws_redact_entities', nargs='+', choices=full_comprehend_entity_list, default=chosen_comprehend_entities,
200
- help=f'AWS redaction entities to use. Default: {chosen_comprehend_entities}. Full list: {full_comprehend_entity_list}.')
201
-
202
- general_group.add_argument('--aws_access_key', default=AWS_ACCESS_KEY, help='Your AWS Access Key ID.')
203
- general_group.add_argument('--aws_secret_key', default=AWS_SECRET_KEY, help='Your AWS Secret Access Key.')
204
- general_group.add_argument('--cost_code', default=DEFAULT_COST_CODE, help='Cost code for tracking usage.')
205
- general_group.add_argument('--aws_region', default=AWS_REGION, help='AWS region for cloud services.')
206
- general_group.add_argument('--s3_bucket', default=DOCUMENT_REDACTION_BUCKET, help='S3 bucket name for cloud operations.')
207
- general_group.add_argument('--do_initial_clean', default=DO_INITIAL_TABULAR_DATA_CLEAN, help='Perform initial text cleaning for tabular data.')
208
- general_group.add_argument('--save_logs_to_csv', default=SAVE_LOGS_TO_CSV, help='Save processing logs to CSV files.')
209
- general_group.add_argument('--save_logs_to_dynamodb', default=SAVE_LOGS_TO_DYNAMODB, help='Save processing logs to DynamoDB.')
210
- general_group.add_argument('--display_file_names_in_logs', default=DISPLAY_FILE_NAMES_IN_LOGS, help='Include file names in log outputs.')
211
- general_group.add_argument('--upload_logs_to_s3', default=RUN_AWS_FUNCTIONS == "1", help='Upload log files to S3 after processing.')
212
- general_group.add_argument('--s3_logs_prefix', default=S3_USAGE_LOGS_FOLDER, help='S3 prefix for usage log files.')
 
 
 
 
 
 
 
 
 
 
 
 
 
213
 
214
  # --- PDF/Image Redaction Arguments ---
215
- pdf_group = parser.add_argument_group('PDF/Image Redaction Options (.pdf, .png, .jpg)')
216
- pdf_group.add_argument('--ocr_method', choices=["AWS Textract", "Local OCR", "Local text"], default="Local OCR", help='OCR method for text extraction from images.')
217
- pdf_group.add_argument('--page_min', type=int, default=0, help='First page to redact.')
218
- pdf_group.add_argument('--page_max', type=int, default=0, help='Last page to redact.')
219
- pdf_group.add_argument('--images_dpi', type=float, default=float(IMAGES_DPI), help='DPI for image processing.')
220
- pdf_group.add_argument('--chosen_local_ocr_model', choices=['tesseract', 'hybrid', 'paddle'], default=CHOSEN_LOCAL_OCR_MODEL, help='Local OCR model to use.')
221
- pdf_group.add_argument('--preprocess_local_ocr_images', default=PREPROCESS_LOCAL_OCR_IMAGES, help='Preprocess images before OCR.')
222
- pdf_group.add_argument('--compress_redacted_pdf', default=COMPRESS_REDACTED_PDF, help='Compress the final redacted PDF.')
223
- pdf_group.add_argument('--return_pdf_end_of_redaction', default=RETURN_PDF_END_OF_REDACTION, help='Return PDF at end of redaction process.')
224
- pdf_group.add_argument('--deny_list_file', default=DENY_LIST_PATH, help='Custom words file to recognize for redaction.')
225
- pdf_group.add_argument('--allow_list_file', default=ALLOW_LIST_PATH, help='Custom words file to recognize for redaction.')
226
- pdf_group.add_argument('--redact_whole_page_file', default=WHOLE_PAGE_REDACTION_LIST_PATH, help='File for pages to redact completely.')
227
- pdf_group.add_argument('--handwrite_signature_extraction', nargs='+', default=default_handwrite_signature_checkbox, help='Handwriting and signature extraction options. Choose from "Extract handwriting", "Extract signatures".')
228
- pdf_group.add_argument('--extract_forms', action='store_true', help='Extract forms during Textract analysis.')
229
- pdf_group.add_argument('--extract_tables', action='store_true', help='Extract tables during Textract analysis.')
230
- pdf_group.add_argument('--extract_layout', action='store_true', help='Extract layout during Textract analysis.')
 
 
 
 
 
 
 
 
 
 
 
 
 
231
 
232
  # --- Word/Tabular Anonymisation Arguments ---
233
- tabular_group = parser.add_argument_group('Word/Tabular Anonymisation Options (.docx, .csv, .xlsx)')
234
- tabular_group.add_argument('--anon_strategy', choices=['redact', 'redact completely', 'replace_redacted', 'entity_type', 'encrypt', 'hash', 'replace with \'REDACTED\'', 'replace with <ENTITY_NAME>', 'mask', 'fake_first_name'], default=DEFAULT_TABULAR_ANONYMISATION_STRATEGY, help='The anonymisation strategy to apply.')
235
- tabular_group.add_argument('--text_columns', nargs='+', default=list(), help='A list of column names to anonymise or deduplicate in tabular data.')
236
- tabular_group.add_argument('--excel_sheets', nargs='+', default=list(), help='Specific Excel sheet names to process.')
237
- tabular_group.add_argument('--fuzzy_mistakes', type=int, default=DEFAULT_FUZZY_SPELLING_MISTAKES_NUM, help='Number of allowed spelling mistakes for fuzzy matching.')
238
- tabular_group.add_argument('--match_fuzzy_whole_phrase_bool', default=True, help='Match fuzzy whole phrase boolean.')
 
 
 
 
 
 
 
 
 
239
  # --- Duplicate Detection Arguments ---
240
- duplicate_group = parser.add_argument_group('Duplicate Detection Options')
241
- duplicate_group.add_argument('--duplicate_type', choices=['pages', 'tabular'], default='pages', help='Type of duplicate detection: pages (for OCR files) or tabular (for CSV/Excel files).')
242
- duplicate_group.add_argument('--similarity_threshold', type=float, default=DEFAULT_DUPLICATE_DETECTION_THRESHOLD, help='Similarity threshold (0-1) to consider content as duplicates.')
243
- duplicate_group.add_argument('--min_word_count', type=int, default=DEFAULT_MIN_WORD_COUNT, help='Minimum word count for text to be considered in duplicate analysis.')
244
- duplicate_group.add_argument('--min_consecutive_pages', type=int, default=DEFAULT_MIN_CONSECUTIVE_PAGES, help='Minimum number of consecutive pages to consider as a match.')
245
- duplicate_group.add_argument('--greedy_match', default=USE_GREEDY_DUPLICATE_DETECTION, help='Use greedy matching strategy for consecutive pages.')
246
- duplicate_group.add_argument('--combine_pages', default=DEFAULT_COMBINE_PAGES, help='Combine text from the same page number within a file. Alternative will enable line-level duplicate detection.')
247
- duplicate_group.add_argument('--remove_duplicate_rows', default=REMOVE_DUPLICATE_ROWS, help='Remove duplicate rows from the output.')
 
 
 
 
 
 
 
 
248
 
249
  # --- Textract Batch Operations Arguments ---
250
- textract_group = parser.add_argument_group('Textract Batch Operations Options')
251
- textract_group.add_argument('--textract_action',
252
- choices=['submit', 'retrieve', 'list'],
253
- help='Textract action to perform: submit (submit document for analysis), retrieve (get results by job ID), or list (show recent jobs).')
254
- textract_group.add_argument('--job_id', help='Textract job ID for retrieve action.')
255
- textract_group.add_argument('--extract_signatures', action='store_true', help='Extract signatures during Textract analysis (for submit action).')
256
- textract_group.add_argument('--textract_bucket', default=TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_BUCKET, help='S3 bucket name for Textract operations (overrides default).')
257
- textract_group.add_argument('--textract_input_prefix', default=TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_INPUT_SUBFOLDER, help='S3 prefix for input files in Textract operations.')
258
- textract_group.add_argument('--textract_output_prefix', default=TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_OUTPUT_SUBFOLDER, help='S3 prefix for output files in Textract operations.')
259
- textract_group.add_argument('--s3_textract_document_logs_subfolder', default=TEXTRACT_JOBS_S3_LOC, help='S3 prefix for logs in Textract operations.')
260
- textract_group.add_argument('--local_textract_document_logs_subfolder', default=TEXTRACT_JOBS_LOCAL_LOC, help='Local prefix for logs in Textract operations.')
261
- textract_group.add_argument('--poll_interval', type=int, default=30, help='Polling interval in seconds for Textract job status.')
262
- textract_group.add_argument('--max_poll_attempts', type=int, default=120, help='Maximum number of polling attempts for Textract job completion.')
 
 
 
 
 
 
 
 
263
  # Parse arguments - either from command line or direct mode
264
  if direct_mode_args:
265
  # Use direct mode arguments
@@ -270,42 +585,70 @@ python cli_redact.py --task textract --textract_action list
270
 
271
  # --- Initial Setup ---
272
  # Convert string boolean variables to boolean
273
- if args.preprocess_local_ocr_images == "True": args.preprocess_local_ocr_images = True
274
- else: args.preprocess_local_ocr_images = False
275
- if args.greedy_match == "True": args.greedy_match = True
276
- else: args.greedy_match = False
277
- if args.combine_pages == "True": args.combine_pages = True
278
- else: args.combine_pages = False
279
- if args.remove_duplicate_rows == "True": args.remove_duplicate_rows = True
280
- else: args.remove_duplicate_rows = False
281
- if args.return_pdf_end_of_redaction == "True": args.return_pdf_end_of_redaction = True
282
- else: args.return_pdf_end_of_redaction = False
283
- if args.compress_redacted_pdf == "True": args.compress_redacted_pdf = True
284
- else: args.compress_redacted_pdf = False
285
- if args.do_initial_clean == "True": args.do_initial_clean = True
286
- else: args.do_initial_clean = False
287
- if args.save_logs_to_csv == "True": args.save_logs_to_csv = True
288
- else: args.save_logs_to_csv = False
289
- if args.save_logs_to_dynamodb == "True": args.save_logs_to_dynamodb = True
290
- else: args.save_logs_to_dynamodb = False
291
- if args.display_file_names_in_logs == "True": args.display_file_names_in_logs = True
292
- else: args.display_file_names_in_logs = False
293
- if args.match_fuzzy_whole_phrase_bool == "True": args.match_fuzzy_whole_phrase_bool = True
294
- else: args.match_fuzzy_whole_phrase_bool = False
295
- if args.save_to_user_folders == "True": args.save_to_user_folders = True
296
- else: args.save_to_user_folders = False
 
 
 
 
 
297
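The block of string-to-boolean conversions above repeats the same pattern for a dozen flags; a compact equivalent, shown only as a sketch and not part of the commit:

    def _coerce_flag(value) -> bool:
        # Mirrors the inline checks: only the literal string "True" maps to True.
        # Note that, like the inline version, a genuine boolean True also ends up
        # False, which may or may not be intended.
        return value == "True"

    for flag in (
        "preprocess_local_ocr_images", "greedy_match", "combine_pages",
        "remove_duplicate_rows", "return_pdf_end_of_redaction",
        "compress_redacted_pdf", "do_initial_clean", "save_logs_to_csv",
        "save_logs_to_dynamodb", "display_file_names_in_logs",
        "match_fuzzy_whole_phrase_bool", "save_to_user_folders",
    ):
        setattr(args, flag, _coerce_flag(getattr(args, flag)))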
 
298
  # Combine extraction options
299
- extraction_options = list(args.handwrite_signature_extraction) if args.handwrite_signature_extraction else []
 
 
 
 
300
  if args.extract_forms:
301
- extraction_options.append('Extract forms')
302
  if args.extract_tables:
303
- extraction_options.append('Extract tables')
304
  if args.extract_layout:
305
- extraction_options.append('Extract layout')
306
  args.handwrite_signature_extraction = extraction_options
307
 
308
- if args.task in ['redact', 'deduplicate']:
309
  if args.input_file:
310
  if isinstance(args.input_file, str):
311
  args.input_file = [args.input_file]
@@ -314,25 +657,46 @@ python cli_redact.py --task textract --textract_action list
314
  file_extension = file_extension.lower()
315
  else:
316
  raise ValueError("Error: --input_file is required for 'redact' task.")
317
-
318
  # Initialise usage logger if logging is enabled
319
  usage_logger = None
320
  if args.save_logs_to_csv or args.save_logs_to_dynamodb:
321
  from tools.cli_usage_logger import create_cli_usage_logger
 
322
  try:
323
  usage_logger = create_cli_usage_logger()
324
  except Exception as e:
325
  print(f"Warning: Could not initialise usage logger: {e}")
326
 
327
  # Get username and folders
328
- session_hash, args.output_dir, _, args.input_dir, args.textract_input_prefix, args.textract_output_prefix, args.s3_textract_document_logs_subfolder, args.local_textract_document_logs_subfolder = get_username_and_folders(username=args.username, output_folder_textbox=args.output_dir, input_folder_textbox=args.input_dir, session_output_folder=args.save_to_user_folders, textract_document_upload_input_folder=args.textract_input_prefix, textract_document_upload_output_folder=args.textract_output_prefix, s3_textract_document_logs_subfolder=args.s3_textract_document_logs_subfolder, local_textract_document_logs_subfolder=args.local_textract_document_logs_subfolder)
 
 
 
 
 
329
 
330
- print(f"Conducting analyses with user {args.username}. Outputs will be saved to {args.output_dir}.")
 
 
331
 
332
  # --- Route to the Correct Workflow Based on Task and File Type ---
333
 
334
  # Validate input_file requirement for tasks that need it
335
- if args.task in ['redact', 'deduplicate'] and not args.input_file:
336
  print(f"Error: --input_file is required for '{args.task}' task.")
337
  return
338
 
@@ -342,70 +706,151 @@ python cli_redact.py --task textract --textract_action list
342
  args.prepare_images = False
343
 
344
  from tools.cli_usage_logger import create_cli_usage_logger, log_redaction_usage
 
345
  # Task 1: Redaction/Anonymisation
346
- if args.task == 'redact':
347
 
348
  # Workflow 1: PDF/Image Redaction
349
- if file_extension in ['.pdf', '.png', '.jpg', '.jpeg']:
350
  print("--- Detected PDF/Image file. Starting Redaction Workflow... ---")
351
  start_time = time.time()
352
  try:
353
  from tools.file_conversion import prepare_image_or_pdf
354
  from tools.file_redaction import choose_and_run_redactor
 
355
  # Step 1: Prepare the document
356
  print("\nStep 1: Preparing document...")
357
  (
358
- prep_summary, prepared_pdf_paths, image_file_paths, _, _, pdf_doc,
359
- image_annotations, _, original_cropboxes, page_sizes, _, _, _, _, _
 
 
 
 
 
 
 
 
 
 
 
 
 
360
  ) = prepare_image_or_pdf(
361
- file_paths=args.input_file, text_extract_method=args.ocr_method, all_line_level_ocr_results_df=pd.DataFrame(), all_page_line_level_ocr_results_with_words_df=pd.DataFrame(),
362
- first_loop_state=True, prepare_for_review=False,
363
- output_folder=args.output_dir, input_folder=args.input_dir, prepare_images=args.prepare_images
 
 
 
 
 
 
364
  )
365
  print(f"Preparation complete. {prep_summary}")
366
 
367
  # Step 2: Redact the prepared document
368
  print("\nStep 2: Running redaction...")
369
  (
370
- output_summary, output_files, _, _, log_files, _, _, _, _, _, _, _, _, _, comprehend_query_number, _, _, _, _, _, _, page_sizes, _, _, _, total_textract_query_number, _, _, _, _, _, _
 
 
 
 
 
 
371
  ) = choose_and_run_redactor(
372
- file_paths=args.input_file, prepared_pdf_file_paths=prepared_pdf_paths,
373
- pdf_image_file_paths=image_file_paths, chosen_redact_entities=args.local_redact_entities,
374
- chosen_redact_comprehend_entities=args.aws_redact_entities, text_extraction_method=args.ocr_method,
375
- in_allow_list=args.allow_list_file, in_deny_list=args.deny_list_file,
376
- redact_whole_page_list=args.redact_whole_page_file, first_loop_state=True,
377
- page_min=args.page_min, page_max=args.page_max, handwrite_signature_checkbox=args.handwrite_signature_extraction, max_fuzzy_spelling_mistakes_num=args.fuzzy_mistakes, match_fuzzy_whole_phrase_bool=args.match_fuzzy_whole_phrase_bool,
378
- pymupdf_doc=pdf_doc, annotations_all_pages=image_annotations, page_sizes=page_sizes,
379
- document_cropboxes=original_cropboxes, pii_identification_method=args.pii_detector,
380
- aws_access_key_textbox=args.aws_access_key, aws_secret_key_textbox=args.aws_secret_key,
381
- language=args.language, output_folder=args.output_dir, input_folder=args.input_dir
 
 
 
 
 
382
  )
383
 
384
  # Calculate processing time
385
  end_time = time.time()
386
  processing_time = end_time - start_time
387
-
388
  # Log usage data if logger is available
389
  if usage_logger:
390
  try:
391
  # Extract file name for logging
392
  print("Saving logs to CSV")
393
- doc_file_name = os.path.basename(args.input_file[0]) if args.display_file_names_in_logs else "document"
 
 
 
 
394
  data_file_name = "" # Not applicable for PDF/image redaction
395
-
396
  # Determine if this was a Textract API call
397
  is_textract_call = args.ocr_method == "AWS Textract"
398
-
399
  # Count pages (approximate from page_sizes if available)
400
  total_pages = len(page_sizes) if page_sizes else 1
401
-
402
  # Count API calls (approximate - would need to be tracked in the redaction function)
403
- textract_queries = int(total_textract_query_number) if is_textract_call else 0
404
- comprehend_queries = int(comprehend_query_number) if args.pii_detector == "AWS Comprehend" else 0
405
-
 
 
 
 
 
 
406
  # Format handwriting/signature options
407
- handwriting_signature = ", ".join(args.handwrite_signature_extraction) if args.handwrite_signature_extraction else ""
408
-
 
 
 
 
409
  log_redaction_usage(
410
  logger=usage_logger,
411
  session_hash=session_hash,
@@ -424,33 +869,47 @@ python cli_redact.py --task textract --textract_action list
424
  save_to_dynamodb=args.save_logs_to_dynamodb,
425
  save_to_s3=args.upload_logs_to_s3,
426
  s3_bucket=args.s3_bucket,
427
- s3_key_prefix=args.s3_logs_prefix
428
  )
429
  except Exception as e:
430
  print(f"Warning: Could not log usage data: {e}")
431
-
432
  print("\n--- Redaction Process Complete ---")
433
  print(f"Summary: {output_summary}")
434
  print(f"Processing time: {processing_time:.2f} seconds")
435
  print(f"\nOutput files saved to: {args.output_dir}")
436
  print("Generated Files:", sorted(output_files))
437
- if log_files: print("Log Files:", sorted(log_files))
 
438
 
439
  except Exception as e:
440
- print(f"\nAn error occurred during the PDF/Image redaction workflow: {e}")
 
 
441
 
442
  # Workflow 2: Word/Tabular Data Anonymisation
443
- elif file_extension in ['.docx', '.xlsx', '.xls', '.csv', '.parquet']:
444
- print("--- Detected Word/Tabular file. Starting Anonymisation Workflow... ---")
 
 
445
  start_time = time.time()
446
  try:
447
  from tools.data_anonymise import anonymise_files_with_open_text
448
-
449
  # Run the anonymisation function directly
450
 
451
- output_summary, output_files, _, _, log_files, _, processing_time, comprehend_query_number = anonymise_files_with_open_text(
 
 
 
 
 
 
 
 
 
452
  file_paths=args.input_file,
453
- in_text="", # Not used for file-based operations
454
  anon_strategy=args.anon_strategy,
455
  chosen_cols=args.text_columns,
456
  chosen_redact_entities=args.local_redact_entities,
@@ -465,34 +924,42 @@ python cli_redact.py --task textract --textract_action list
465
  aws_access_key_textbox=args.aws_access_key,
466
  aws_secret_key_textbox=args.aws_secret_key,
467
  language=args.language,
468
- do_initial_clean=args.do_initial_clean
469
  )
470
 
471
  # Calculate processing time
472
  end_time = time.time()
473
  processing_time = end_time - start_time
474
-
475
  # Log usage data if logger is available
476
  if usage_logger:
477
  try:
478
  print("Saving logs to CSV")
479
  # Extract file name for logging
480
  doc_file_name = "" # Not applicable for tabular data
481
- data_file_name = os.path.basename(args.input_file[0]) if args.display_file_names_in_logs else "data_file"
482
-
 
 
 
 
483
  # Determine if this was a Textract API call (not applicable for tabular)
484
  is_textract_call = False
485
-
486
  # Count pages (not applicable for tabular data)
487
  total_pages = 0
488
-
489
  # Count API calls (approximate - would need to be tracked in the anonymisation function)
490
  textract_queries = 0 # Not applicable for tabular data
491
- comprehend_queries = comprehend_query_number if args.pii_detector == "AWS Comprehend" else 0
492
-
 
 
 
 
493
  # Format handwriting/signature options (not applicable for tabular)
494
  handwriting_signature = ""
495
-
496
  log_redaction_usage(
497
  logger=usage_logger,
498
  session_hash=session_hash,
@@ -511,7 +978,7 @@ python cli_redact.py --task textract --textract_action list
511
  save_to_dynamodb=args.save_logs_to_dynamodb,
512
  save_to_s3=args.upload_logs_to_s3,
513
  s3_bucket=args.s3_bucket,
514
- s3_key_prefix=args.s3_logs_prefix
515
  )
516
  except Exception as e:
517
  print(f"Warning: Could not log usage data: {e}")
@@ -521,54 +988,71 @@ python cli_redact.py --task textract --textract_action list
521
  print(f"Processing time: {processing_time:.2f} seconds")
522
  print(f"\nOutput files saved to: {args.output_dir}")
523
  print("Generated Files:", sorted(output_files))
524
- if log_files: print("Log Files:", sorted(log_files))
 
525
 
526
  except Exception as e:
527
- print(f"\nAn error occurred during the Word/Tabular anonymisation workflow: {e}")
528
-
 
 
529
  else:
530
  print(f"Error: Unsupported file type '{file_extension}' for redaction.")
531
  print("Supported types for redaction: .pdf, .png, .jpg, .jpeg")
532
- print("Supported types for anonymisation: .docx, .xlsx, .xls, .csv, .parquet")
 
 
533
 
534
  # Task 2: Duplicate Detection
535
- elif args.task == 'deduplicate':
536
  print("--- Starting Duplicate Detection Workflow... ---")
537
  try:
538
  from tools.find_duplicate_pages import run_duplicate_analysis
539
- if args.duplicate_type == 'pages':
 
540
  # Page duplicate detection
541
- if file_extension == '.csv':
542
- print("--- Detected OCR CSV file. Starting Page Duplicate Detection... ---")
543
-
 
 
544
  start_time = time.time()
545
 
546
- if args.combine_pages == True:
547
  print("Combining pages...")
548
  else:
549
  print("Using line-level duplicate detection...")
550
 
551
  # Load the CSV file as a list for the duplicate analysis function
552
- results_df, output_paths, full_data_by_file, processing_time, task_textbox = run_duplicate_analysis(
 
 
 
 
 
 
553
  files=args.input_file,
554
  threshold=args.similarity_threshold,
555
  min_words=args.min_word_count,
556
  min_consecutive=args.min_consecutive_pages,
557
  greedy_match=args.greedy_match,
558
  combine_pages=args.combine_pages,
559
- output_folder=args.output_dir
560
  )
561
-
562
  end_time = time.time()
563
  processing_time = end_time - start_time
564
 
565
  print("\n--- Page Duplicate Detection Complete ---")
566
  print(f"Found {len(results_df)} duplicate matches")
567
  print(f"\nOutput files saved to: {args.output_dir}")
568
- if output_paths: print("Generated Files:", sorted(output_paths))
569
-
 
570
  else:
571
- print(f"Error: Page duplicate detection requires CSV files with OCR data.")
 
 
572
  print("Please provide a CSV file containing OCR output data.")
573
 
574
  # Log usage data if logger is available
@@ -576,22 +1060,28 @@ python cli_redact.py --task textract --textract_action list
576
  try:
577
  # Extract file name for logging
578
  print("Saving logs to CSV")
579
- doc_file_name = os.path.basename(args.input_file[0]) if args.display_file_names_in_logs else "document"
580
- data_file_name = "" # Not applicable for PDF/image redaction
581
-
 
 
 
 
 
 
582
  # Determine if this was a Textract API call
583
  is_textract_call = False
584
-
585
  # Count pages (approximate from page_sizes if available)
586
  total_pages = len(page_sizes) if page_sizes else 1
587
-
588
  # Count API calls (approximate - would need to be tracked in the redaction function)
589
  textract_queries = 0
590
  comprehend_queries = 0
591
-
592
  # Format handwriting/signature options
593
  handwriting_signature = ""
594
-
595
  log_redaction_usage(
596
  logger=usage_logger,
597
  session_hash=session_hash,
@@ -610,20 +1100,29 @@ python cli_redact.py --task textract --textract_action list
610
  save_to_dynamodb=args.save_logs_to_dynamodb,
611
  save_to_s3=args.upload_logs_to_s3,
612
  s3_bucket=args.s3_bucket,
613
- s3_key_prefix=args.s3_logs_prefix
614
  )
615
  except Exception as e:
616
  print(f"Warning: Could not log usage data: {e}")
617
-
618
- elif args.duplicate_type == 'tabular':
619
  # Tabular duplicate detection
620
  from tools.find_duplicate_tabular import run_tabular_duplicate_detection
621
- if file_extension in ['.csv', '.xlsx', '.xls', '.parquet']:
622
- print("--- Detected tabular file. Starting Tabular Duplicate Detection... ---")
 
 
 
623
 
624
  start_time = time.time()
625
-
626
- results_df, output_paths, full_data_by_file, processing_time, task_textbox = run_tabular_duplicate_detection(
 
 
 
 
 
 
627
  files=args.input_file,
628
  threshold=args.similarity_threshold,
629
  min_words=args.min_word_count,
@@ -631,7 +1130,7 @@ python cli_redact.py --task textract --textract_action list
631
  output_folder=args.output_dir,
632
  do_initial_clean_dup=args.do_initial_clean,
633
  in_excel_tabular_sheets=args.excel_sheets,
634
- remove_duplicate_rows=args.remove_duplicate_rows
635
  )
636
 
637
  end_time = time.time()
@@ -643,21 +1142,25 @@ python cli_redact.py --task textract --textract_action list
643
  # Extract file name for logging
644
  print("Saving logs to CSV")
645
  doc_file_name = ""
646
- data_file_name = os.path.basename(args.input_file[0]) if args.display_file_names_in_logs else "data_file"
647
-
 
 
 
 
648
  # Determine if this was a Textract API call
649
  is_textract_call = False
650
-
651
  # Count pages (approximate from page_sizes if available)
652
  total_pages = len(page_sizes) if page_sizes else 1
653
-
654
  # Count API calls (approximate - would need to be tracked in the redaction function)
655
  textract_queries = 0
656
  comprehend_queries = 0
657
-
658
  # Format handwriting/signature options
659
  handwriting_signature = ""
660
-
661
  log_redaction_usage(
662
  logger=usage_logger,
663
  session_hash=session_hash,
@@ -676,58 +1179,80 @@ python cli_redact.py --task textract --textract_action list
676
  save_to_dynamodb=args.save_logs_to_dynamodb,
677
  save_to_s3=args.upload_logs_to_s3,
678
  s3_bucket=args.s3_bucket,
679
- s3_key_prefix=args.s3_logs_prefix
680
  )
681
  except Exception as e:
682
  print(f"Warning: Could not log usage data: {e}")
683
-
684
  print("\n--- Tabular Duplicate Detection Complete ---")
685
  print(f"Found {len(results_df)} duplicate matches")
686
  print(f"\nOutput files saved to: {args.output_dir}")
687
- if output_paths: print("Generated Files:", sorted(output_paths))
688
-
 
689
  else:
690
- print(f"Error: Tabular duplicate detection requires CSV, Excel, or Parquet files.")
 
 
691
  print("Supported types: .csv, .xlsx, .xls, .parquet")
692
  else:
693
  print(f"Error: Invalid duplicate type '{args.duplicate_type}'.")
694
  print("Valid options: 'pages' or 'tabular'")
695
-
696
  except Exception as e:
697
  print(f"\nAn error occurred during the duplicate detection workflow: {e}")
698
 
699
  # Task 3: Textract Batch Operations
700
- elif args.task == 'textract':
701
  print("--- Starting Textract Batch Operations Workflow... ---")
702
-
703
  if not args.textract_action:
704
  print("Error: --textract_action is required for textract task.")
705
  print("Valid options: 'submit', 'retrieve', or 'list'")
706
  return
707
-
708
  try:
709
- if args.textract_action == 'submit':
710
- from tools.textract_batch_call import analyse_document_with_textract_api, load_in_textract_job_details
 
 
 
 
711
  # Submit document to Textract for analysis
712
  if not args.input_file:
713
  print("Error: --input_file is required for submit action.")
714
  return
715
-
716
  print(f"--- Submitting document to Textract: {args.input_file} ---")
717
 
718
  start_time = time.time()
719
-
720
  # Load existing job details
721
- job_df = load_in_textract_job_details(load_s3_jobs_loc=args.s3_textract_document_logs_subfolder, load_local_jobs_loc=args.local_textract_document_logs_subfolder)
722
-
 
 
 
723
  # Determine signature extraction options
724
- signature_options = ['Extract handwriting', 'Extract signatures'] if args.extract_signatures else ['Extract handwriting']
725
-
 
 
 
 
726
  # Use configured bucket or override
727
  textract_bucket = args.textract_bucket if args.textract_bucket else ""
728
-
729
  # Submit the job
730
- result_message, job_id, job_type, successful_job_number, is_textract_call, total_pages, task_textbox = analyse_document_with_textract_api(
 
 
 
 
 
 
 
 
731
  local_pdf_path=args.input_file,
732
  s3_input_prefix=args.textract_input_prefix,
733
  s3_output_prefix=args.textract_output_prefix,
@@ -736,13 +1261,13 @@ python cli_redact.py --task textract --textract_action list
736
  general_s3_bucket_name=args.s3_bucket,
737
  local_output_dir=args.output_dir,
738
  handwrite_signature_checkbox=signature_options,
739
- aws_region=args.aws_region
740
  )
741
 
742
  end_time = time.time()
743
  processing_time = end_time - start_time
744
-
745
- print(f"\n--- Textract Job Submitted Successfully ---")
746
  print(f"Job ID: {job_id}")
747
  print(f"Job Type: {job_type}")
748
  print(f"Message: {result_message}")
@@ -753,20 +1278,24 @@ python cli_redact.py --task textract --textract_action list
753
  try:
754
  # Extract file name for logging
755
  print("Saving logs to CSV")
756
- doc_file_name = os.path.basename(args.input_file[0]) if args.display_file_names_in_logs else "document"
 
 
 
 
757
  data_file_name = ""
758
-
759
  # Determine if this was a Textract API call
760
  is_textract_call = True
761
  args.ocr_method == "AWS Textract"
762
-
763
  # Count API calls (approximate - would need to be tracked in the redaction function)
764
  textract_queries = total_pages
765
  comprehend_queries = 0
766
-
767
  # Format handwriting/signature options
768
  handwriting_signature = ""
769
-
770
  log_redaction_usage(
771
  logger=usage_logger,
772
  session_hash=session_hash,
@@ -785,62 +1314,75 @@ python cli_redact.py --task textract --textract_action list
785
  save_to_dynamodb=args.save_logs_to_dynamodb,
786
  save_to_s3=args.upload_logs_to_s3,
787
  s3_bucket=args.s3_bucket,
788
- s3_key_prefix=args.s3_logs_prefix
789
  )
790
  except Exception as e:
791
  print(f"Warning: Could not log usage data: {e}")
792
-
793
- elif args.textract_action == 'retrieve':
794
  print(f"--- Retrieving Textract results for Job ID: {args.job_id} ---")
795
 
796
- from tools.textract_batch_call import poll_whole_document_textract_analysis_progress_and_download, load_in_textract_job_details
 
 
 
 
797
  # Retrieve results by job ID
798
  if not args.job_id:
799
  print("Error: --job_id is required for retrieve action.")
800
  return
801
-
802
  # Load existing job details to get job type
803
  print("Loading existing job details...")
804
- job_df = load_in_textract_job_details(load_s3_jobs_loc=args.s3_textract_document_logs_subfolder, load_local_jobs_loc=args.local_textract_document_logs_subfolder)
805
-
 
 
 
806
  # Find job type from the dataframe
807
  job_type = "document_text_detection" # default
808
  if not job_df.empty and "job_id" in job_df.columns:
809
  matching_jobs = job_df.loc[job_df["job_id"] == args.job_id]
810
  if not matching_jobs.empty and "job_type" in matching_jobs.columns:
811
  job_type = matching_jobs.iloc[0]["job_type"]
812
-
813
  # Use configured bucket or override
814
  textract_bucket = args.textract_bucket if args.textract_bucket else ""
815
-
816
  # Poll for completion and download results
817
  print("Polling for completion and downloading results...")
818
- downloaded_file_path, job_status, updated_job_df, output_filename = poll_whole_document_textract_analysis_progress_and_download(
819
- job_id=args.job_id,
820
- job_type_dropdown=job_type,
821
- s3_output_prefix=args.textract_output_prefix,
822
- pdf_filename="", # Will be determined from job details
823
- job_df=job_df,
824
- s3_bucket_name=textract_bucket,
825
- load_s3_jobs_loc=args.s3_textract_document_logs_subfolder,
826
- load_local_jobs_loc=args.local_textract_document_logs_subfolder,
827
- local_output_dir=args.output_dir,
828
- poll_interval_seconds=args.poll_interval,
829
- max_polling_attempts=args.max_poll_attempts
 
 
830
  )
831
-
832
- print(f"\n--- Textract Results Retrieved Successfully ---")
833
  print(f"Job Status: {job_status}")
834
  print(f"Downloaded File: {downloaded_file_path}")
835
- #print(f"Output Filename: {output_filename}")
836
-
837
- elif args.textract_action == 'list':
838
  from tools.textract_batch_call import load_in_textract_job_details
 
839
  # List recent Textract jobs
840
  print("--- Listing Recent Textract Jobs ---")
841
-
842
- job_df = load_in_textract_job_details(load_s3_jobs_loc=args.s3_textract_document_logs_subfolder, load_local_jobs_loc=args.local_textract_document_logs_subfolder)
843
-
 
 
 
844
  if job_df.empty:
845
  print("No recent Textract jobs found.")
846
  else:
@@ -853,17 +1395,18 @@ python cli_redact.py --task textract --textract_action list
853
  print(f"Signatures: {job.get('signature_extraction', 'N/A')}")
854
  print(f"Date: {job.get('job_date_time', 'N/A')}")
855
  print("-" * 80)
856
-
857
  else:
858
  print(f"Error: Invalid textract_action '{args.textract_action}'.")
859
  print("Valid options: 'submit', 'retrieve', or 'list'")
860
-
861
  except Exception as e:
862
  print(f"\nAn error occurred during the Textract workflow: {e}")
863
-
864
  else:
865
  print(f"Error: Invalid task '{args.task}'.")
866
  print("Valid options: 'redact', 'deduplicate', or 'textract'")
867
 
 
868
  if __name__ == "__main__":
869
- main()
 
1
  import argparse
2
  import os
 
3
  import time
4
  import uuid
 
5
 
6
+ import pandas as pd
7
+
8
+ from tools.config import (
9
+ ALLOW_LIST_PATH,
10
+ AWS_ACCESS_KEY,
11
+ AWS_PII_OPTION,
12
+ AWS_REGION,
13
+ AWS_SECRET_KEY,
14
+ CHOSEN_COMPREHEND_ENTITIES,
15
+ CHOSEN_LOCAL_OCR_MODEL,
16
+ CHOSEN_REDACT_ENTITIES,
17
+ COMPRESS_REDACTED_PDF,
18
+ CUSTOM_ENTITIES,
19
+ DEFAULT_COMBINE_PAGES,
20
+ DEFAULT_COST_CODE,
21
+ DEFAULT_DUPLICATE_DETECTION_THRESHOLD,
22
+ DEFAULT_FUZZY_SPELLING_MISTAKES_NUM,
23
+ DEFAULT_HANDWRITE_SIGNATURE_CHECKBOX,
24
+ DEFAULT_LANGUAGE,
25
+ DEFAULT_MIN_CONSECUTIVE_PAGES,
26
+ DEFAULT_MIN_WORD_COUNT,
27
+ DEFAULT_TABULAR_ANONYMISATION_STRATEGY,
28
+ DENY_LIST_PATH,
29
+ DIRECT_MODE_DEFAULT_USER,
30
+ DISPLAY_FILE_NAMES_IN_LOGS,
31
+ DO_INITIAL_TABULAR_DATA_CLEAN,
32
+ DOCUMENT_REDACTION_BUCKET,
33
+ FULL_COMPREHEND_ENTITY_LIST,
34
+ FULL_ENTITY_LIST,
35
+ IMAGES_DPI,
36
+ INPUT_FOLDER,
37
+ LOCAL_PII_OPTION,
38
+ OUTPUT_FOLDER,
39
+ PREPROCESS_LOCAL_OCR_IMAGES,
40
+ REMOVE_DUPLICATE_ROWS,
41
+ RETURN_PDF_END_OF_REDACTION,
42
+ RUN_AWS_FUNCTIONS,
43
+ S3_USAGE_LOGS_FOLDER,
44
+ SAVE_LOGS_TO_CSV,
45
+ SAVE_LOGS_TO_DYNAMODB,
46
+ SESSION_OUTPUT_FOLDER,
47
+ TEXTRACT_JOBS_LOCAL_LOC,
48
+ TEXTRACT_JOBS_S3_LOC,
49
+ TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_BUCKET,
50
+ TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_INPUT_SUBFOLDER,
51
+ TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_OUTPUT_SUBFOLDER,
52
+ USE_GREEDY_DUPLICATE_DETECTION,
53
+ WHOLE_PAGE_REDACTION_LIST_PATH,
54
+ )
55
 
56
 
57
  def _generate_session_hash() -> str:
58
  """Generate a unique session hash for logging purposes."""
59
  return str(uuid.uuid4())[:8]
60
 
 
 
 
 
 
 
 
 
61
 
62
+ def get_username_and_folders(
63
+ username: str = "",
64
+ output_folder_textbox: str = OUTPUT_FOLDER,
65
+ input_folder_textbox: str = INPUT_FOLDER,
66
+ session_output_folder: str = SESSION_OUTPUT_FOLDER,
67
+ textract_document_upload_input_folder: str = TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_INPUT_SUBFOLDER,
68
+ textract_document_upload_output_folder: str = TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_OUTPUT_SUBFOLDER,
69
+ s3_textract_document_logs_subfolder: str = TEXTRACT_JOBS_S3_LOC,
70
+ local_textract_document_logs_subfolder: str = TEXTRACT_JOBS_LOCAL_LOC,
71
+ ):
72
 
73
  # Generate session hash for logging. Either from input user name or generated
74
  if username:
75
  out_session_hash = username
76
  else:
77
+ out_session_hash = _generate_session_hash()
 
78
 
79
+ if session_output_folder == "True" or session_output_folder is True:
80
  output_folder = output_folder_textbox + out_session_hash + "/"
81
  input_folder = input_folder_textbox + out_session_hash + "/"
82
 
83
+ textract_document_upload_input_folder = (
84
+ textract_document_upload_input_folder + "/" + out_session_hash
85
+ )
86
+ textract_document_upload_output_folder = (
87
+ textract_document_upload_output_folder + "/" + out_session_hash
88
+ )
89
 
90
+ s3_textract_document_logs_subfolder = (
91
+ s3_textract_document_logs_subfolder + "/" + out_session_hash
92
+ )
93
+ local_textract_document_logs_subfolder = (
94
+ local_textract_document_logs_subfolder + "/" + out_session_hash + "/"
95
+ )
96
 
97
  else:
98
  output_folder = output_folder_textbox
99
  input_folder = input_folder_textbox
100
 
101
+ if not os.path.exists(output_folder):
102
+ os.mkdir(output_folder)
103
+ if not os.path.exists(input_folder):
104
+ os.mkdir(input_folder)
105
+
106
+ return (
107
+ out_session_hash,
108
+ output_folder,
109
+ out_session_hash,
110
+ input_folder,
111
+ textract_document_upload_input_folder,
112
+ textract_document_upload_output_folder,
113
+ s3_textract_document_logs_subfolder,
114
+ local_textract_document_logs_subfolder,
115
+ )
116
 
117
 
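A quick worked example of the behaviour of get_username_and_folders, assuming OUTPUT_FOLDER is "output/" and INPUT_FOLDER is "input/" (the real configured values may differ):

    (
        session_hash, output_folder, _, input_folder,
        textract_in, textract_out, s3_logs, local_logs,
    ) = get_username_and_folders(username="alice", session_output_folder="True")
    # session_hash  == "alice"          (used instead of a generated hash)
    # output_folder == "output/alice/"  (created on disk if missing)
    # input_folder  == "input/alice/"   (created on disk if missing)
    # textract_in   == "<textract input prefix>/alice", and similarly for the
    # output prefix and the S3/local Textract log locations.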
118
  def _get_env_list(env_var_name: str) -> list[str]:
119
  """Parses a comma-separated environment variable into a list of strings."""
120
+ value = env_var_name[1:-1].strip().replace('"', "").replace("'", "")
121
  if not value:
122
  return []
123
  # Split by comma and filter out any empty strings that might result from extra commas
124
+ return [s.strip() for s in value.split(",") if s.strip()]
125
+
126
 
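For reference, _get_env_list expects the bracketed, quoted list format used for these config values; a couple of illustrative calls (the input strings are examples, not actual configured values):

    _get_env_list("['TITLES', 'UKPOSTCODE', 'STREETNAME']")
    # -> ['TITLES', 'UKPOSTCODE', 'STREETNAME']
    _get_env_list("[]")
    # -> []
    # The [1:-1] slice strips the surrounding brackets, quote characters are
    # removed, and empty entries from stray commas are dropped.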
127
  # --- Constants and Configuration ---
128
 
129
+ if CHOSEN_COMPREHEND_ENTITIES:
130
+ CHOSEN_COMPREHEND_ENTITIES = _get_env_list(CHOSEN_COMPREHEND_ENTITIES)
131
+ if FULL_COMPREHEND_ENTITY_LIST:
132
+ FULL_COMPREHEND_ENTITY_LIST = _get_env_list(FULL_COMPREHEND_ENTITY_LIST)
133
+ if CHOSEN_REDACT_ENTITIES:
134
+ CHOSEN_REDACT_ENTITIES = _get_env_list(CHOSEN_REDACT_ENTITIES)
135
+ if FULL_ENTITY_LIST:
136
+ FULL_ENTITY_LIST = _get_env_list(FULL_ENTITY_LIST)
137
+ if CUSTOM_ENTITIES:
138
+ CUSTOM_ENTITIES = _get_env_list(CUSTOM_ENTITIES)
139
+ if DEFAULT_HANDWRITE_SIGNATURE_CHECKBOX:
140
+ DEFAULT_HANDWRITE_SIGNATURE_CHECKBOX = _get_env_list(
141
+ DEFAULT_HANDWRITE_SIGNATURE_CHECKBOX
142
+ )
143
 
144
  # Add custom spacy recognisers to the Comprehend list, so that local Spacy model can be used to pick up e.g. titles, streetnames, UK postcodes that are sometimes missed by comprehend
145
  CHOSEN_COMPREHEND_ENTITIES.extend(CUSTOM_ENTITIES)
 
151
  full_comprehend_entity_list = FULL_COMPREHEND_ENTITY_LIST
152
  default_handwrite_signature_checkbox = DEFAULT_HANDWRITE_SIGNATURE_CHECKBOX
153
 
154
+
155
  # --- Main CLI Function ---
156
  def main(direct_mode_args={}):
157
  """
158
  A unified command-line interface to prepare, redact, and anonymise various document types.
159
+
160
  Args:
161
  direct_mode_args (dict, optional): Dictionary of arguments for direct mode execution.
162
  If provided, uses these instead of parsing command line arguments.
163
  """
164
  parser = argparse.ArgumentParser(
165
+ description="A versatile CLI for redacting PII from PDF/image files and anonymising Word/tabular data.",
166
  formatter_class=argparse.RawTextHelpFormatter,
167
+ epilog="""
168
  Examples:
169
 
170
  To run these, you need to do the following:
 
247
  ## List recent Textract jobs:
248
  python cli_redact.py --task textract --textract_action list
249
 
250
+ """,
251
  )
252
 
253
  # --- Task Selection ---
254
+ task_group = parser.add_argument_group("Task Selection")
255
+ task_group.add_argument(
256
+ "--task",
257
+ choices=["redact", "deduplicate", "textract"],
258
+ default="redact",
259
+ help="Task to perform: redact (PII redaction/anonymisation), deduplicate (find duplicate content), or textract (AWS Textract batch operations).",
260
+ )
261
 
262
  # --- General Arguments (apply to all file types) ---
263
+ general_group = parser.add_argument_group("General Options")
264
+ general_group.add_argument(
265
+ "--input_file",
266
+ nargs="+",
267
+ help="Path to the input file(s) to process. Separate multiple files with a space, and use quotes if there are spaces in the file name.",
268
+ )
269
+ general_group.add_argument(
270
+ "--output_dir", default=OUTPUT_FOLDER, help="Directory for all output files."
271
+ )
272
+ general_group.add_argument(
273
+ "--input_dir", default=INPUT_FOLDER, help="Directory for all input files."
274
+ )
275
+ general_group.add_argument(
276
+ "--language", default=DEFAULT_LANGUAGE, help="Language of the document content."
277
+ )
278
+ general_group.add_argument(
279
+ "--allow_list",
280
+ default=ALLOW_LIST_PATH,
281
+ help="Path to a CSV file with words to exclude from redaction.",
282
+ )
283
+ general_group.add_argument(
284
+ "--pii_detector",
285
+ choices=[LOCAL_PII_OPTION, AWS_PII_OPTION, "None"],
286
+ default=LOCAL_PII_OPTION,
287
+ help="Core PII detection method (Local or AWS Comprehend, or None).",
288
+ )
289
+ general_group.add_argument(
290
+ "--username", default=DIRECT_MODE_DEFAULT_USER, help="Username for the session."
291
+ )
292
+ general_group.add_argument(
293
+ "--save_to_user_folders",
294
+ default=SESSION_OUTPUT_FOLDER,
295
+ help="Whether to save to user folders or not.",
296
+ )
297
+
298
+ general_group.add_argument(
299
+ "--local_redact_entities",
300
+ nargs="+",
301
+ choices=full_entity_list,
302
+ default=chosen_redact_entities,
303
+ help=f"Local redaction entities to use. Default: {chosen_redact_entities}. Full list: {full_entity_list}.",
304
+ )
305
+
306
+ general_group.add_argument(
307
+ "--aws_redact_entities",
308
+ nargs="+",
309
+ choices=full_comprehend_entity_list,
310
+ default=chosen_comprehend_entities,
311
+ help=f"AWS redaction entities to use. Default: {chosen_comprehend_entities}. Full list: {full_comprehend_entity_list}.",
312
+ )
313
+
314
+ general_group.add_argument(
315
+ "--aws_access_key", default=AWS_ACCESS_KEY, help="Your AWS Access Key ID."
316
+ )
317
+ general_group.add_argument(
318
+ "--aws_secret_key", default=AWS_SECRET_KEY, help="Your AWS Secret Access Key."
319
+ )
320
+ general_group.add_argument(
321
+ "--cost_code", default=DEFAULT_COST_CODE, help="Cost code for tracking usage."
322
+ )
323
+ general_group.add_argument(
324
+ "--aws_region", default=AWS_REGION, help="AWS region for cloud services."
325
+ )
326
+ general_group.add_argument(
327
+ "--s3_bucket",
328
+ default=DOCUMENT_REDACTION_BUCKET,
329
+ help="S3 bucket name for cloud operations.",
330
+ )
331
+ general_group.add_argument(
332
+ "--do_initial_clean",
333
+ default=DO_INITIAL_TABULAR_DATA_CLEAN,
334
+ help="Perform initial text cleaning for tabular data.",
335
+ )
336
+ general_group.add_argument(
337
+ "--save_logs_to_csv",
338
+ default=SAVE_LOGS_TO_CSV,
339
+ help="Save processing logs to CSV files.",
340
+ )
341
+ general_group.add_argument(
342
+ "--save_logs_to_dynamodb",
343
+ default=SAVE_LOGS_TO_DYNAMODB,
344
+ help="Save processing logs to DynamoDB.",
345
+ )
346
+ general_group.add_argument(
347
+ "--display_file_names_in_logs",
348
+ default=DISPLAY_FILE_NAMES_IN_LOGS,
349
+ help="Include file names in log outputs.",
350
+ )
351
+ general_group.add_argument(
352
+ "--upload_logs_to_s3",
353
+ default=RUN_AWS_FUNCTIONS == "1",
354
+ help="Upload log files to S3 after processing.",
355
+ )
356
+ general_group.add_argument(
357
+ "--s3_logs_prefix",
358
+ default=S3_USAGE_LOGS_FOLDER,
359
+ help="S3 prefix for usage log files.",
360
+ )
361
 
362
  # --- PDF/Image Redaction Arguments ---
363
+ pdf_group = parser.add_argument_group(
364
+ "PDF/Image Redaction Options (.pdf, .png, .jpg)"
365
+ )
366
+ pdf_group.add_argument(
367
+ "--ocr_method",
368
+ choices=["AWS Textract", "Local OCR", "Local text"],
369
+ default="Local OCR",
370
+ help="OCR method for text extraction from images.",
371
+ )
372
+ pdf_group.add_argument(
373
+ "--page_min", type=int, default=0, help="First page to redact."
374
+ )
375
+ pdf_group.add_argument(
376
+ "--page_max", type=int, default=0, help="Last page to redact."
377
+ )
378
+ pdf_group.add_argument(
379
+ "--images_dpi",
380
+ type=float,
381
+ default=float(IMAGES_DPI),
382
+ help="DPI for image processing.",
383
+ )
384
+ pdf_group.add_argument(
385
+ "--chosen_local_ocr_model",
386
+ choices=["tesseract", "hybrid", "paddle"],
387
+ default=CHOSEN_LOCAL_OCR_MODEL,
388
+ help="Local OCR model to use.",
389
+ )
390
+ pdf_group.add_argument(
391
+ "--preprocess_local_ocr_images",
392
+ default=PREPROCESS_LOCAL_OCR_IMAGES,
393
+ help="Preprocess images before OCR.",
394
+ )
395
+ pdf_group.add_argument(
396
+ "--compress_redacted_pdf",
397
+ default=COMPRESS_REDACTED_PDF,
398
+ help="Compress the final redacted PDF.",
399
+ )
400
+ pdf_group.add_argument(
401
+ "--return_pdf_end_of_redaction",
402
+ default=RETURN_PDF_END_OF_REDACTION,
403
+ help="Return PDF at end of redaction process.",
404
+ )
405
+ pdf_group.add_argument(
406
+ "--deny_list_file",
407
+ default=DENY_LIST_PATH,
408
+ help="Custom words file to recognize for redaction.",
409
+ )
410
+ pdf_group.add_argument(
411
+ "--allow_list_file",
412
+ default=ALLOW_LIST_PATH,
413
+ help="Custom words file to recognize for redaction.",
414
+ )
415
+ pdf_group.add_argument(
416
+ "--redact_whole_page_file",
417
+ default=WHOLE_PAGE_REDACTION_LIST_PATH,
418
+ help="File for pages to redact completely.",
419
+ )
420
+ pdf_group.add_argument(
421
+ "--handwrite_signature_extraction",
422
+ nargs="+",
423
+ default=default_handwrite_signature_checkbox,
424
+ help='Handwriting and signature extraction options. Choose from "Extract handwriting", "Extract signatures".',
425
+ )
426
+ pdf_group.add_argument(
427
+ "--extract_forms",
428
+ action="store_true",
429
+ help="Extract forms during Textract analysis.",
430
+ )
431
+ pdf_group.add_argument(
432
+ "--extract_tables",
433
+ action="store_true",
434
+ help="Extract tables during Textract analysis.",
435
+ )
436
+ pdf_group.add_argument(
437
+ "--extract_layout",
438
+ action="store_true",
439
+ help="Extract layout during Textract analysis.",
440
+ )
441
 
442
  # --- Word/Tabular Anonymisation Arguments ---
443
+ tabular_group = parser.add_argument_group(
444
+ "Word/Tabular Anonymisation Options (.docx, .csv, .xlsx)"
445
+ )
446
+ tabular_group.add_argument(
447
+ "--anon_strategy",
448
+ choices=[
449
+ "redact",
450
+ "redact completely",
451
+ "replace_redacted",
452
+ "entity_type",
453
+ "encrypt",
454
+ "hash",
455
+ "replace with 'REDACTED'",
456
+ "replace with <ENTITY_NAME>",
457
+ "mask",
458
+ "fake_first_name",
459
+ ],
460
+ default=DEFAULT_TABULAR_ANONYMISATION_STRATEGY,
461
+ help="The anonymisation strategy to apply.",
462
+ )
463
+ tabular_group.add_argument(
464
+ "--text_columns",
465
+ nargs="+",
466
+ default=list(),
467
+ help="A list of column names to anonymise or deduplicate in tabular data.",
468
+ )
469
+ tabular_group.add_argument(
470
+ "--excel_sheets",
471
+ nargs="+",
472
+ default=list(),
473
+ help="Specific Excel sheet names to process.",
474
+ )
475
+ tabular_group.add_argument(
476
+ "--fuzzy_mistakes",
477
+ type=int,
478
+ default=DEFAULT_FUZZY_SPELLING_MISTAKES_NUM,
479
+ help="Number of allowed spelling mistakes for fuzzy matching.",
480
+ )
481
+ tabular_group.add_argument(
482
+ "--match_fuzzy_whole_phrase_bool",
483
+ default=True,
484
+ help="Match fuzzy whole phrase boolean.",
485
+ )
486
  # --- Duplicate Detection Arguments ---
487
+ duplicate_group = parser.add_argument_group("Duplicate Detection Options")
488
+ duplicate_group.add_argument(
489
+ "--duplicate_type",
490
+ choices=["pages", "tabular"],
491
+ default="pages",
492
+ help="Type of duplicate detection: pages (for OCR files) or tabular (for CSV/Excel files).",
493
+ )
494
+ duplicate_group.add_argument(
495
+ "--similarity_threshold",
496
+ type=float,
497
+ default=DEFAULT_DUPLICATE_DETECTION_THRESHOLD,
498
+ help="Similarity threshold (0-1) to consider content as duplicates.",
499
+ )
500
+ duplicate_group.add_argument(
501
+ "--min_word_count",
502
+ type=int,
503
+ default=DEFAULT_MIN_WORD_COUNT,
504
+ help="Minimum word count for text to be considered in duplicate analysis.",
505
+ )
506
+ duplicate_group.add_argument(
507
+ "--min_consecutive_pages",
508
+ type=int,
509
+ default=DEFAULT_MIN_CONSECUTIVE_PAGES,
510
+ help="Minimum number of consecutive pages to consider as a match.",
511
+ )
512
+ duplicate_group.add_argument(
513
+ "--greedy_match",
514
+ default=USE_GREEDY_DUPLICATE_DETECTION,
515
+ help="Use greedy matching strategy for consecutive pages.",
516
+ )
517
+ duplicate_group.add_argument(
518
+ "--combine_pages",
519
+ default=DEFAULT_COMBINE_PAGES,
520
+ help="Combine text from the same page number within a file. Alternative will enable line-level duplicate detection.",
521
+ )
522
+ duplicate_group.add_argument(
523
+ "--remove_duplicate_rows",
524
+ default=REMOVE_DUPLICATE_ROWS,
525
+ help="Remove duplicate rows from the output.",
526
+ )
527
 
528
  # --- Textract Batch Operations Arguments ---
529
+ textract_group = parser.add_argument_group("Textract Batch Operations Options")
530
+ textract_group.add_argument(
531
+ "--textract_action",
532
+ choices=["submit", "retrieve", "list"],
533
+ help="Textract action to perform: submit (submit document for analysis), retrieve (get results by job ID), or list (show recent jobs).",
534
+ )
535
+ textract_group.add_argument("--job_id", help="Textract job ID for retrieve action.")
536
+ textract_group.add_argument(
537
+ "--extract_signatures",
538
+ action="store_true",
539
+ help="Extract signatures during Textract analysis (for submit action).",
540
+ )
541
+ textract_group.add_argument(
542
+ "--textract_bucket",
543
+ default=TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_BUCKET,
544
+ help="S3 bucket name for Textract operations (overrides default).",
545
+ )
546
+ textract_group.add_argument(
547
+ "--textract_input_prefix",
548
+ default=TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_INPUT_SUBFOLDER,
549
+ help="S3 prefix for input files in Textract operations.",
550
+ )
551
+ textract_group.add_argument(
552
+ "--textract_output_prefix",
553
+ default=TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_OUTPUT_SUBFOLDER,
554
+ help="S3 prefix for output files in Textract operations.",
555
+ )
556
+ textract_group.add_argument(
557
+ "--s3_textract_document_logs_subfolder",
558
+ default=TEXTRACT_JOBS_S3_LOC,
559
+ help="S3 prefix for logs in Textract operations.",
560
+ )
561
+ textract_group.add_argument(
562
+ "--local_textract_document_logs_subfolder",
563
+ default=TEXTRACT_JOBS_LOCAL_LOC,
564
+ help="Local prefix for logs in Textract operations.",
565
+ )
566
+ textract_group.add_argument(
567
+ "--poll_interval",
568
+ type=int,
569
+ default=30,
570
+ help="Polling interval in seconds for Textract job status.",
571
+ )
572
+ textract_group.add_argument(
573
+ "--max_poll_attempts",
574
+ type=int,
575
+ default=120,
576
+ help="Maximum number of polling attempts for Textract job completion.",
577
+ )
578
  # Parse arguments - either from command line or direct mode
579
  if direct_mode_args:
580
  # Use direct mode arguments
 
585
 
586
  # --- Initial Setup ---
587
  # Convert string boolean variables to boolean
588
+ if args.preprocess_local_ocr_images == "True":
589
+ args.preprocess_local_ocr_images = True
590
+ else:
591
+ args.preprocess_local_ocr_images = False
592
+ if args.greedy_match == "True":
593
+ args.greedy_match = True
594
+ else:
595
+ args.greedy_match = False
596
+ if args.combine_pages == "True":
597
+ args.combine_pages = True
598
+ else:
599
+ args.combine_pages = False
600
+ if args.remove_duplicate_rows == "True":
601
+ args.remove_duplicate_rows = True
602
+ else:
603
+ args.remove_duplicate_rows = False
604
+ if args.return_pdf_end_of_redaction == "True":
605
+ args.return_pdf_end_of_redaction = True
606
+ else:
607
+ args.return_pdf_end_of_redaction = False
608
+ if args.compress_redacted_pdf == "True":
609
+ args.compress_redacted_pdf = True
610
+ else:
611
+ args.compress_redacted_pdf = False
612
+ if args.do_initial_clean == "True":
613
+ args.do_initial_clean = True
614
+ else:
615
+ args.do_initial_clean = False
616
+ if args.save_logs_to_csv == "True":
617
+ args.save_logs_to_csv = True
618
+ else:
619
+ args.save_logs_to_csv = False
620
+ if args.save_logs_to_dynamodb == "True":
621
+ args.save_logs_to_dynamodb = True
622
+ else:
623
+ args.save_logs_to_dynamodb = False
624
+ if args.display_file_names_in_logs == "True":
625
+ args.display_file_names_in_logs = True
626
+ else:
627
+ args.display_file_names_in_logs = False
628
+ if args.match_fuzzy_whole_phrase_bool == "True":
629
+ args.match_fuzzy_whole_phrase_bool = True
630
+ else:
631
+ args.match_fuzzy_whole_phrase_bool = False
632
+ if args.save_to_user_folders == "True":
633
+ args.save_to_user_folders = True
634
+ else:
635
+ args.save_to_user_folders = False
636
 
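The repeated string-to-boolean conversions above could be collapsed into a single helper; a minimal sketch (the helper name is hypothetical and not part of the commit):

```python
def _str_to_bool(value) -> bool:
    """Treat the string "True" (or an actual boolean True) as True; everything else as False."""
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() == "true"

# Example usage for one of the flags converted above:
# args.greedy_match = _str_to_bool(args.greedy_match)
```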
637
  # Combine extraction options
638
+ extraction_options = (
639
+ list(args.handwrite_signature_extraction)
640
+ if args.handwrite_signature_extraction
641
+ else []
642
+ )
643
  if args.extract_forms:
644
+ extraction_options.append("Extract forms")
645
  if args.extract_tables:
646
+ extraction_options.append("Extract tables")
647
  if args.extract_layout:
648
+ extraction_options.append("Extract layout")
649
  args.handwrite_signature_extraction = extraction_options
650
 
651
+ if args.task in ["redact", "deduplicate"]:
652
  if args.input_file:
653
  if isinstance(args.input_file, str):
654
  args.input_file = [args.input_file]
 
657
  file_extension = file_extension.lower()
658
  else:
659
  raise ValueError("Error: --input_file is required for 'redact' task.")
660
+
661
  # Initialise usage logger if logging is enabled
662
  usage_logger = None
663
  if args.save_logs_to_csv or args.save_logs_to_dynamodb:
664
  from tools.cli_usage_logger import create_cli_usage_logger
665
+
666
  try:
667
  usage_logger = create_cli_usage_logger()
668
  except Exception as e:
669
  print(f"Warning: Could not initialise usage logger: {e}")
670
 
671
  # Get username and folders
672
+ (
673
+ session_hash,
674
+ args.output_dir,
675
+ _,
676
+ args.input_dir,
677
+ args.textract_input_prefix,
678
+ args.textract_output_prefix,
679
+ args.s3_textract_document_logs_subfolder,
680
+ args.local_textract_document_logs_subfolder,
681
+ ) = get_username_and_folders(
682
+ username=args.username,
683
+ output_folder_textbox=args.output_dir,
684
+ input_folder_textbox=args.input_dir,
685
+ session_output_folder=args.save_to_user_folders,
686
+ textract_document_upload_input_folder=args.textract_input_prefix,
687
+ textract_document_upload_output_folder=args.textract_output_prefix,
688
+ s3_textract_document_logs_subfolder=args.s3_textract_document_logs_subfolder,
689
+ local_textract_document_logs_subfolder=args.local_textract_document_logs_subfolder,
690
+ )
691
 
692
+ print(
693
+ f"Conducting analyses with user {args.username}. Outputs will be saved to {args.output_dir}."
694
+ )
695
 
696
  # --- Route to the Correct Workflow Based on Task and File Type ---
697
 
698
  # Validate input_file requirement for tasks that need it
699
+ if args.task in ["redact", "deduplicate"] and not args.input_file:
700
  print(f"Error: --input_file is required for '{args.task}' task.")
701
  return
702
 
 
706
  args.prepare_images = False
707
 
708
  from tools.cli_usage_logger import create_cli_usage_logger, log_redaction_usage
709
+
710
  # Task 1: Redaction/Anonymisation
711
+ if args.task == "redact":
712
 
713
  # Workflow 1: PDF/Image Redaction
714
+ if file_extension in [".pdf", ".png", ".jpg", ".jpeg"]:
715
  print("--- Detected PDF/Image file. Starting Redaction Workflow... ---")
716
  start_time = time.time()
717
  try:
718
  from tools.file_conversion import prepare_image_or_pdf
719
  from tools.file_redaction import choose_and_run_redactor
720
+
721
  # Step 1: Prepare the document
722
  print("\nStep 1: Preparing document...")
723
  (
724
+ prep_summary,
725
+ prepared_pdf_paths,
726
+ image_file_paths,
727
+ _,
728
+ _,
729
+ pdf_doc,
730
+ image_annotations,
731
+ _,
732
+ original_cropboxes,
733
+ page_sizes,
734
+ _,
735
+ _,
736
+ _,
737
+ _,
738
+ _,
739
  ) = prepare_image_or_pdf(
740
+ file_paths=args.input_file,
741
+ text_extract_method=args.ocr_method,
742
+ all_line_level_ocr_results_df=pd.DataFrame(),
743
+ all_page_line_level_ocr_results_with_words_df=pd.DataFrame(),
744
+ first_loop_state=True,
745
+ prepare_for_review=False,
746
+ output_folder=args.output_dir,
747
+ input_folder=args.input_dir,
748
+ prepare_images=args.prepare_images,
749
  )
750
  print(f"Preparation complete. {prep_summary}")
751
 
752
  # Step 2: Redact the prepared document
753
  print("\nStep 2: Running redaction...")
754
  (
755
+ output_summary,
756
+ output_files,
757
+ _,
758
+ _,
759
+ log_files,
760
+ _,
761
+ _,
762
+ _,
763
+ _,
764
+ _,
765
+ _,
766
+ _,
767
+ _,
768
+ _,
769
+ comprehend_query_number,
770
+ _,
771
+ _,
772
+ _,
773
+ _,
774
+ _,
775
+ _,
776
+ page_sizes,
777
+ _,
778
+ _,
779
+ _,
780
+ total_textract_query_number,
781
+ _,
782
+ _,
783
+ _,
784
+ _,
785
+ _,
786
+ _,
787
  ) = choose_and_run_redactor(
788
+ file_paths=args.input_file,
789
+ prepared_pdf_file_paths=prepared_pdf_paths,
790
+ pdf_image_file_paths=image_file_paths,
791
+ chosen_redact_entities=args.local_redact_entities,
792
+ chosen_redact_comprehend_entities=args.aws_redact_entities,
793
+ text_extraction_method=args.ocr_method,
794
+ in_allow_list=args.allow_list_file,
795
+ in_deny_list=args.deny_list_file,
796
+ redact_whole_page_list=args.redact_whole_page_file,
797
+ first_loop_state=True,
798
+ page_min=args.page_min,
799
+ page_max=args.page_max,
800
+ handwrite_signature_checkbox=args.handwrite_signature_extraction,
801
+ max_fuzzy_spelling_mistakes_num=args.fuzzy_mistakes,
802
+ match_fuzzy_whole_phrase_bool=args.match_fuzzy_whole_phrase_bool,
803
+ pymupdf_doc=pdf_doc,
804
+ annotations_all_pages=image_annotations,
805
+ page_sizes=page_sizes,
806
+ document_cropboxes=original_cropboxes,
807
+ pii_identification_method=args.pii_detector,
808
+ aws_access_key_textbox=args.aws_access_key,
809
+ aws_secret_key_textbox=args.aws_secret_key,
810
+ language=args.language,
811
+ output_folder=args.output_dir,
812
+ input_folder=args.input_dir,
813
  )
814
 
815
  # Calculate processing time
816
  end_time = time.time()
817
  processing_time = end_time - start_time
818
+
819
  # Log usage data if logger is available
820
  if usage_logger:
821
  try:
822
  # Extract file name for logging
823
  print("Saving logs to CSV")
824
+ doc_file_name = (
825
+ os.path.basename(args.input_file[0])
826
+ if args.display_file_names_in_logs
827
+ else "document"
828
+ )
829
  data_file_name = "" # Not applicable for PDF/image redaction
830
+
831
  # Determine if this was a Textract API call
832
  is_textract_call = args.ocr_method == "AWS Textract"
833
+
834
  # Count pages (approximate from page_sizes if available)
835
  total_pages = len(page_sizes) if page_sizes else 1
836
+
837
  # Count API calls (approximate - would need to be tracked in the redaction function)
838
+ textract_queries = (
839
+ int(total_textract_query_number) if is_textract_call else 0
840
+ )
841
+ comprehend_queries = (
842
+ int(comprehend_query_number)
843
+ if args.pii_detector == "AWS Comprehend"
844
+ else 0
845
+ )
846
+
847
  # Format handwriting/signature options
848
+ handwriting_signature = (
849
+ ", ".join(args.handwrite_signature_extraction)
850
+ if args.handwrite_signature_extraction
851
+ else ""
852
+ )
853
+
854
  log_redaction_usage(
855
  logger=usage_logger,
856
  session_hash=session_hash,
 
869
  save_to_dynamodb=args.save_logs_to_dynamodb,
870
  save_to_s3=args.upload_logs_to_s3,
871
  s3_bucket=args.s3_bucket,
872
+ s3_key_prefix=args.s3_logs_prefix,
873
  )
874
  except Exception as e:
875
  print(f"Warning: Could not log usage data: {e}")
876
+
877
  print("\n--- Redaction Process Complete ---")
878
  print(f"Summary: {output_summary}")
879
  print(f"Processing time: {processing_time:.2f} seconds")
880
  print(f"\nOutput files saved to: {args.output_dir}")
881
  print("Generated Files:", sorted(output_files))
882
+ if log_files:
883
+ print("Log Files:", sorted(log_files))
884
 
885
  except Exception as e:
886
+ print(
887
+ f"\nAn error occurred during the PDF/Image redaction workflow: {e}"
888
+ )
889
 
890
  # Workflow 2: Word/Tabular Data Anonymisation
891
+ elif file_extension in [".docx", ".xlsx", ".xls", ".csv", ".parquet"]:
892
+ print(
893
+ "--- Detected Word/Tabular file. Starting Anonymisation Workflow... ---"
894
+ )
895
  start_time = time.time()
896
  try:
897
  from tools.data_anonymise import anonymise_files_with_open_text
898
+
899
  # Run the anonymisation function directly
900
 
901
+ (
902
+ output_summary,
903
+ output_files,
904
+ _,
905
+ _,
906
+ log_files,
907
+ _,
908
+ processing_time,
909
+ comprehend_query_number,
910
+ ) = anonymise_files_with_open_text(
911
  file_paths=args.input_file,
912
+ in_text="", # Not used for file-based operations
913
  anon_strategy=args.anon_strategy,
914
  chosen_cols=args.text_columns,
915
  chosen_redact_entities=args.local_redact_entities,
 
924
  aws_access_key_textbox=args.aws_access_key,
925
  aws_secret_key_textbox=args.aws_secret_key,
926
  language=args.language,
927
+ do_initial_clean=args.do_initial_clean,
928
  )
929
 
930
  # Calculate processing time
931
  end_time = time.time()
932
  processing_time = end_time - start_time
933
+
934
  # Log usage data if logger is available
935
  if usage_logger:
936
  try:
937
  print("Saving logs to CSV")
938
  # Extract file name for logging
939
  doc_file_name = "" # Not applicable for tabular data
940
+ data_file_name = (
941
+ os.path.basename(args.input_file[0])
942
+ if args.display_file_names_in_logs
943
+ else "data_file"
944
+ )
945
+
946
  # Determine if this was a Textract API call (not applicable for tabular)
947
  is_textract_call = False
948
+
949
  # Count pages (not applicable for tabular data)
950
  total_pages = 0
951
+
952
  # Count API calls (approximate - would need to be tracked in the anonymisation function)
953
  textract_queries = 0 # Not applicable for tabular data
954
+ comprehend_queries = (
955
+ comprehend_query_number
956
+ if args.pii_detector == "AWS Comprehend"
957
+ else 0
958
+ )
959
+
960
  # Format handwriting/signature options (not applicable for tabular)
961
  handwriting_signature = ""
962
+
963
  log_redaction_usage(
964
  logger=usage_logger,
965
  session_hash=session_hash,
 
978
  save_to_dynamodb=args.save_logs_to_dynamodb,
979
  save_to_s3=args.upload_logs_to_s3,
980
  s3_bucket=args.s3_bucket,
981
+ s3_key_prefix=args.s3_logs_prefix,
982
  )
983
  except Exception as e:
984
  print(f"Warning: Could not log usage data: {e}")
 
988
  print(f"Processing time: {processing_time:.2f} seconds")
989
  print(f"\nOutput files saved to: {args.output_dir}")
990
  print("Generated Files:", sorted(output_files))
991
+ if log_files:
992
+ print("Log Files:", sorted(log_files))
993
 
994
  except Exception as e:
995
+ print(
996
+ f"\nAn error occurred during the Word/Tabular anonymisation workflow: {e}"
997
+ )
998
+
999
  else:
1000
  print(f"Error: Unsupported file type '{file_extension}' for redaction.")
1001
  print("Supported types for redaction: .pdf, .png, .jpg, .jpeg")
1002
+ print(
1003
+ "Supported types for anonymisation: .docx, .xlsx, .xls, .csv, .parquet"
1004
+ )
1005
 
1006
  # Task 2: Duplicate Detection
1007
+ elif args.task == "deduplicate":
1008
  print("--- Starting Duplicate Detection Workflow... ---")
1009
  try:
1010
  from tools.find_duplicate_pages import run_duplicate_analysis
1011
+
1012
+ if args.duplicate_type == "pages":
1013
  # Page duplicate detection
1014
+ if file_extension == ".csv":
1015
+ print(
1016
+ "--- Detected OCR CSV file. Starting Page Duplicate Detection... ---"
1017
+ )
1018
+
1019
  start_time = time.time()
1020
 
1021
+ if args.combine_pages is True:
1022
  print("Combining pages...")
1023
  else:
1024
  print("Using line-level duplicate detection...")
1025
 
1026
  # Load the CSV file as a list for the duplicate analysis function
1027
+ (
1028
+ results_df,
1029
+ output_paths,
1030
+ full_data_by_file,
1031
+ processing_time,
1032
+ task_textbox,
1033
+ ) = run_duplicate_analysis(
1034
  files=args.input_file,
1035
  threshold=args.similarity_threshold,
1036
  min_words=args.min_word_count,
1037
  min_consecutive=args.min_consecutive_pages,
1038
  greedy_match=args.greedy_match,
1039
  combine_pages=args.combine_pages,
1040
+ output_folder=args.output_dir,
1041
  )
1042
+
1043
  end_time = time.time()
1044
  processing_time = end_time - start_time
1045
 
1046
  print("\n--- Page Duplicate Detection Complete ---")
1047
  print(f"Found {len(results_df)} duplicate matches")
1048
  print(f"\nOutput files saved to: {args.output_dir}")
1049
+ if output_paths:
1050
+ print("Generated Files:", sorted(output_paths))
1051
+
1052
  else:
1053
+ print(
1054
+ "Error: Page duplicate detection requires CSV files with OCR data."
1055
+ )
1056
  print("Please provide a CSV file containing OCR output data.")
1057
 
1058
  # Log usage data if logger is available
 
1060
  try:
1061
  # Extract file name for logging
1062
  print("Saving logs to CSV")
1063
+ doc_file_name = (
1064
+ os.path.basename(args.input_file[0])
1065
+ if args.display_file_names_in_logs
1066
+ else "document"
1067
+ )
1068
+ data_file_name = (
1069
+ "" # Not applicable for PDF/image redaction
1070
+ )
1071
+
1072
  # Determine if this was a Textract API call
1073
  is_textract_call = False
1074
+
1075
  # Count pages (approximate from page_sizes if available)
1076
  total_pages = len(page_sizes) if page_sizes else 1
1077
+
1078
  # Count API calls (approximate - would need to be tracked in the redaction function)
1079
  textract_queries = 0
1080
  comprehend_queries = 0
1081
+
1082
  # Format handwriting/signature options
1083
  handwriting_signature = ""
1084
+
1085
  log_redaction_usage(
1086
  logger=usage_logger,
1087
  session_hash=session_hash,
 
1100
  save_to_dynamodb=args.save_logs_to_dynamodb,
1101
  save_to_s3=args.upload_logs_to_s3,
1102
  s3_bucket=args.s3_bucket,
1103
+ s3_key_prefix=args.s3_logs_prefix,
1104
  )
1105
  except Exception as e:
1106
  print(f"Warning: Could not log usage data: {e}")
1107
+
1108
+ elif args.duplicate_type == "tabular":
1109
  # Tabular duplicate detection
1110
  from tools.find_duplicate_tabular import run_tabular_duplicate_detection
1111
+
1112
+ if file_extension in [".csv", ".xlsx", ".xls", ".parquet"]:
1113
+ print(
1114
+ "--- Detected tabular file. Starting Tabular Duplicate Detection... ---"
1115
+ )
1116
 
1117
  start_time = time.time()
1118
+
1119
+ (
1120
+ results_df,
1121
+ output_paths,
1122
+ full_data_by_file,
1123
+ processing_time,
1124
+ task_textbox,
1125
+ ) = run_tabular_duplicate_detection(
1126
  files=args.input_file,
1127
  threshold=args.similarity_threshold,
1128
  min_words=args.min_word_count,
 
1130
  output_folder=args.output_dir,
1131
  do_initial_clean_dup=args.do_initial_clean,
1132
  in_excel_tabular_sheets=args.excel_sheets,
1133
+ remove_duplicate_rows=args.remove_duplicate_rows,
1134
  )
1135
 
1136
  end_time = time.time()
 
1142
  # Extract file name for logging
1143
  print("Saving logs to CSV")
1144
  doc_file_name = ""
1145
+ data_file_name = (
1146
+ os.path.basename(args.input_file[0])
1147
+ if args.display_file_names_in_logs
1148
+ else "data_file"
1149
+ )
1150
+
1151
  # Determine if this was a Textract API call
1152
  is_textract_call = False
1153
+
1154
  # Count pages (approximate from page_sizes if available)
1155
  total_pages = len(page_sizes) if page_sizes else 1
1156
+
1157
  # Count API calls (approximate - would need to be tracked in the redaction function)
1158
  textract_queries = 0
1159
  comprehend_queries = 0
1160
+
1161
  # Format handwriting/signature options
1162
  handwriting_signature = ""
1163
+
1164
  log_redaction_usage(
1165
  logger=usage_logger,
1166
  session_hash=session_hash,
 
1179
  save_to_dynamodb=args.save_logs_to_dynamodb,
1180
  save_to_s3=args.upload_logs_to_s3,
1181
  s3_bucket=args.s3_bucket,
1182
+ s3_key_prefix=args.s3_logs_prefix,
1183
  )
1184
  except Exception as e:
1185
  print(f"Warning: Could not log usage data: {e}")
1186
+
1187
  print("\n--- Tabular Duplicate Detection Complete ---")
1188
  print(f"Found {len(results_df)} duplicate matches")
1189
  print(f"\nOutput files saved to: {args.output_dir}")
1190
+ if output_paths:
1191
+ print("Generated Files:", sorted(output_paths))
1192
+
1193
  else:
1194
+ print(
1195
+ "Error: Tabular duplicate detection requires CSV, Excel, or Parquet files."
1196
+ )
1197
  print("Supported types: .csv, .xlsx, .xls, .parquet")
1198
  else:
1199
  print(f"Error: Invalid duplicate type '{args.duplicate_type}'.")
1200
  print("Valid options: 'pages' or 'tabular'")
1201
+
1202
  except Exception as e:
1203
  print(f"\nAn error occurred during the duplicate detection workflow: {e}")
1204
 
1205
  # Task 3: Textract Batch Operations
1206
+ elif args.task == "textract":
1207
  print("--- Starting Textract Batch Operations Workflow... ---")
1208
+
1209
  if not args.textract_action:
1210
  print("Error: --textract_action is required for textract task.")
1211
  print("Valid options: 'submit', 'retrieve', or 'list'")
1212
  return
1213
+
1214
  try:
1215
+ if args.textract_action == "submit":
1216
+ from tools.textract_batch_call import (
1217
+ analyse_document_with_textract_api,
1218
+ load_in_textract_job_details,
1219
+ )
1220
+
1221
  # Submit document to Textract for analysis
1222
  if not args.input_file:
1223
  print("Error: --input_file is required for submit action.")
1224
  return
1225
+
1226
  print(f"--- Submitting document to Textract: {args.input_file} ---")
1227
 
1228
  start_time = time.time()
1229
+
1230
  # Load existing job details
1231
+ job_df = load_in_textract_job_details(
1232
+ load_s3_jobs_loc=args.s3_textract_document_logs_subfolder,
1233
+ load_local_jobs_loc=args.local_textract_document_logs_subfolder,
1234
+ )
1235
+
1236
  # Determine signature extraction options
1237
+ signature_options = (
1238
+ ["Extract handwriting", "Extract signatures"]
1239
+ if args.extract_signatures
1240
+ else ["Extract handwriting"]
1241
+ )
1242
+
1243
  # Use configured bucket or override
1244
  textract_bucket = args.textract_bucket if args.textract_bucket else ""
1245
+
1246
  # Submit the job
1247
+ (
1248
+ result_message,
1249
+ job_id,
1250
+ job_type,
1251
+ successful_job_number,
1252
+ is_textract_call,
1253
+ total_pages,
1254
+ task_textbox,
1255
+ ) = analyse_document_with_textract_api(
1256
  local_pdf_path=args.input_file,
1257
  s3_input_prefix=args.textract_input_prefix,
1258
  s3_output_prefix=args.textract_output_prefix,
 
1261
  general_s3_bucket_name=args.s3_bucket,
1262
  local_output_dir=args.output_dir,
1263
  handwrite_signature_checkbox=signature_options,
1264
+ aws_region=args.aws_region,
1265
  )
1266
 
1267
  end_time = time.time()
1268
  processing_time = end_time - start_time
1269
+
1270
+ print("\n--- Textract Job Submitted Successfully ---")
1271
  print(f"Job ID: {job_id}")
1272
  print(f"Job Type: {job_type}")
1273
  print(f"Message: {result_message}")
 
1278
  try:
1279
  # Extract file name for logging
1280
  print("Saving logs to CSV")
1281
+ doc_file_name = (
1282
+ os.path.basename(args.input_file[0])
1283
+ if args.display_file_names_in_logs
1284
+ else "document"
1285
+ )
1286
  data_file_name = ""
1287
+
1288
  # Determine if this was a Textract API call
1289
  is_textract_call = True
1290
1291
+
1292
  # Count API calls (approximate - would need to be tracked in the redaction function)
1293
  textract_queries = total_pages
1294
  comprehend_queries = 0
1295
+
1296
  # Format handwriting/signature options
1297
  handwriting_signature = ""
1298
+
1299
  log_redaction_usage(
1300
  logger=usage_logger,
1301
  session_hash=session_hash,
 
1314
  save_to_dynamodb=args.save_logs_to_dynamodb,
1315
  save_to_s3=args.upload_logs_to_s3,
1316
  s3_bucket=args.s3_bucket,
1317
+ s3_key_prefix=args.s3_logs_prefix,
1318
  )
1319
  except Exception as e:
1320
  print(f"Warning: Could not log usage data: {e}")
1321
+
1322
+ elif args.textract_action == "retrieve":
1323
  print(f"--- Retrieving Textract results for Job ID: {args.job_id} ---")
1324
 
1325
+ from tools.textract_batch_call import (
1326
+ load_in_textract_job_details,
1327
+ poll_whole_document_textract_analysis_progress_and_download,
1328
+ )
1329
+
1330
  # Retrieve results by job ID
1331
  if not args.job_id:
1332
  print("Error: --job_id is required for retrieve action.")
1333
  return
1334
+
1335
  # Load existing job details to get job type
1336
  print("Loading existing job details...")
1337
+ job_df = load_in_textract_job_details(
1338
+ load_s3_jobs_loc=args.s3_textract_document_logs_subfolder,
1339
+ load_local_jobs_loc=args.local_textract_document_logs_subfolder,
1340
+ )
1341
+
1342
  # Find job type from the dataframe
1343
  job_type = "document_text_detection" # default
1344
  if not job_df.empty and "job_id" in job_df.columns:
1345
  matching_jobs = job_df.loc[job_df["job_id"] == args.job_id]
1346
  if not matching_jobs.empty and "job_type" in matching_jobs.columns:
1347
  job_type = matching_jobs.iloc[0]["job_type"]
1348
+
1349
  # Use configured bucket or override
1350
  textract_bucket = args.textract_bucket if args.textract_bucket else ""
1351
+
1352
  # Poll for completion and download results
1353
  print("Polling for completion and downloading results...")
1354
+ downloaded_file_path, job_status, updated_job_df, output_filename = (
1355
+ poll_whole_document_textract_analysis_progress_and_download(
1356
+ job_id=args.job_id,
1357
+ job_type_dropdown=job_type,
1358
+ s3_output_prefix=args.textract_output_prefix,
1359
+ pdf_filename="", # Will be determined from job details
1360
+ job_df=job_df,
1361
+ s3_bucket_name=textract_bucket,
1362
+ load_s3_jobs_loc=args.s3_textract_document_logs_subfolder,
1363
+ load_local_jobs_loc=args.local_textract_document_logs_subfolder,
1364
+ local_output_dir=args.output_dir,
1365
+ poll_interval_seconds=args.poll_interval,
1366
+ max_polling_attempts=args.max_poll_attempts,
1367
+ )
1368
  )
1369
+
1370
+ print("\n--- Textract Results Retrieved Successfully ---")
1371
  print(f"Job Status: {job_status}")
1372
  print(f"Downloaded File: {downloaded_file_path}")
1373
+ # print(f"Output Filename: {output_filename}")
1374
+
1375
+ elif args.textract_action == "list":
1376
  from tools.textract_batch_call import load_in_textract_job_details
1377
+
1378
  # List recent Textract jobs
1379
  print("--- Listing Recent Textract Jobs ---")
1380
+
1381
+ job_df = load_in_textract_job_details(
1382
+ load_s3_jobs_loc=args.s3_textract_document_logs_subfolder,
1383
+ load_local_jobs_loc=args.local_textract_document_logs_subfolder,
1384
+ )
1385
+
1386
  if job_df.empty:
1387
  print("No recent Textract jobs found.")
1388
  else:
 
1395
  print(f"Signatures: {job.get('signature_extraction', 'N/A')}")
1396
  print(f"Date: {job.get('job_date_time', 'N/A')}")
1397
  print("-" * 80)
1398
+
1399
  else:
1400
  print(f"Error: Invalid textract_action '{args.textract_action}'.")
1401
  print("Valid options: 'submit', 'retrieve', or 'list'")
1402
+
1403
  except Exception as e:
1404
  print(f"\nAn error occurred during the Textract workflow: {e}")
1405
+
1406
  else:
1407
  print(f"Error: Invalid task '{args.task}'.")
1408
  print("Valid options: 'redact', 'deduplicate', or 'textract'")
1409
 
1410
+
1411
  if __name__ == "__main__":
1412
+ main()
lambda_entrypoint.py CHANGED
@@ -1,6 +1,7 @@
1
- import boto3
2
- import os
3
  import json
 
 
 
4
 
5
  # Import the main function from your CLI script
6
  from cli_redact import main as cli_main
@@ -16,6 +17,7 @@ TMP_DIR = "/tmp"
16
  INPUT_DIR = os.path.join(TMP_DIR, "input")
17
  OUTPUT_DIR = os.path.join(TMP_DIR, "output")
18
 
 
19
  def download_file_from_s3(bucket_name, key, download_path):
20
  """Download a file from S3 to the local filesystem."""
21
  try:
@@ -25,6 +27,7 @@ def download_file_from_s3(bucket_name, key, download_path):
25
  print(f"Error downloading from S3: {e}")
26
  raise
27
 
 
28
  def upload_directory_to_s3(local_directory, bucket_name, s3_prefix):
29
  """Upload all files from a local directory to an S3 prefix."""
30
  for root, _, files in os.walk(local_directory):
@@ -33,14 +36,17 @@ def upload_directory_to_s3(local_directory, bucket_name, s3_prefix):
33
  # Create a relative path to maintain directory structure if needed
34
  relative_path = os.path.relpath(local_file_path, local_directory)
35
  output_key = os.path.join(s3_prefix, relative_path)
36
-
37
  try:
38
  s3_client.upload_file(local_file_path, bucket_name, output_key)
39
- print(f"Successfully uploaded {local_file_path} to s3://{bucket_name}/{output_key}")
 
 
40
  except Exception as e:
41
  print(f"Error uploading to S3: {e}")
42
  raise
43
 
 
44
  def lambda_handler(event, context):
45
  print(f"Received event: {json.dumps(event)}")
46
 
@@ -51,29 +57,33 @@ def lambda_handler(event, context):
51
  # 2. Extract information from the event
52
  # Assumes the event is triggered by S3 and may contain an 'arguments' payload
53
  try:
54
- record = event['Records'][0]
55
- bucket_name = record['s3']['bucket']['name']
56
- input_key = record['s3']['object']['key']
57
-
58
  # The user metadata can be used to pass arguments
59
  # This is more robust than embedding them in the main event body
60
  response = s3_client.head_object(Bucket=bucket_name, Key=input_key)
61
- metadata = response.get('Metadata', {})
62
  # Arguments can be passed as a JSON string in metadata
63
- arguments = json.loads(metadata.get('arguments', '{}'))
64
 
65
  except (KeyError, IndexError) as e:
66
- print(f"Could not parse S3 event record: {e}. Checking for direct invocation payload.")
 
 
67
  # Fallback for direct invocation (e.g., from Step Functions or manual test)
68
- bucket_name = event.get('bucket_name')
69
- input_key = event.get('input_key')
70
- arguments = event.get('arguments', {})
71
  if not all([bucket_name, input_key]):
72
- raise ValueError("Missing 'bucket_name' or 'input_key' in direct invocation event.")
 
 
73
 
74
  print(f"Processing s3://{bucket_name}/{input_key}")
75
  print(f"With arguments: {arguments}")
76
-
77
  # 3. Download the main input file
78
  input_file_path = os.path.join(INPUT_DIR, os.path.basename(input_key))
79
  download_file_from_s3(bucket_name, input_key, input_file_path)
@@ -81,106 +91,117 @@ def lambda_handler(event, context):
81
  # 4. Prepare arguments for the CLI function
82
  # This dictionary should mirror the one in your app.py's "direct mode"
83
  cli_args = {
84
- 'task': arguments.get('task', 'redact'),
85
- 'input_file': input_file_path,
86
- 'output_dir': OUTPUT_DIR,
87
- 'input_dir': INPUT_DIR,
88
- 'language': arguments.get('language', 'en_core_web_lg'),
89
- 'pii_detector': arguments.get('pii_detector', 'Local'), # Default to local
90
- 'username': arguments.get('username', 'lambda_user'),
91
- 'save_to_user_folders': arguments.get('save_to_user_folders', 'False'),
92
- 'ocr_method': arguments.get('ocr_method', 'Tesseract OCR - all PDF types'),
93
- 'page_min': int(arguments.get('page_min', 0)),
94
- 'page_max': int(arguments.get('page_max', 0)),
95
- 'handwrite_signature_extraction': arguments.get('handwrite_signature_checkbox', ['Extract handwriting', 'Extract signatures']),
96
- 'extract_forms': arguments.get('extract_forms', False),
97
- 'extract_tables': arguments.get('extract_tables', False),
98
- 'extract_layout': arguments.get('extract_layout', False),
99
-
 
 
100
  # General arguments
101
- 'local_redact_entities': arguments.get('local_redact_entities', []),
102
- 'aws_redact_entities': arguments.get('aws_redact_entities', []),
103
- 'cost_code': arguments.get('cost_code', ''),
104
- 'save_logs_to_csv': arguments.get('save_logs_to_csv', 'False'),
105
- 'save_logs_to_dynamodb': arguments.get('save_logs_to_dynamodb', 'False'),
106
- 'display_file_names_in_logs': arguments.get('display_file_names_in_logs', 'True'),
107
- 'upload_logs_to_s3': arguments.get('upload_logs_to_s3', 'False'),
108
- 's3_logs_prefix': arguments.get('s3_logs_prefix', ''),
109
- 'do_initial_clean': arguments.get('do_initial_clean', 'False'),
110
-
 
111
  # PDF/Image specific arguments
112
- 'images_dpi': float(arguments.get('images_dpi', 300.0)),
113
- 'chosen_local_ocr_model': arguments.get('chosen_local_ocr_model', 'tesseract'),
114
- 'preprocess_local_ocr_images': arguments.get('preprocess_local_ocr_images', 'False'),
115
-
 
116
  # Handle optional files like allow/deny lists
117
- 'allow_list_file': arguments.get('allow_list_file', ""),
118
- 'deny_list_file': arguments.get('deny_list_file', ""),
119
- 'redact_whole_page_file': arguments.get('redact_whole_page_file', ""),
120
-
121
  # Tabular/Anonymisation arguments
122
- 'excel_sheets': arguments.get('excel_sheets', []),
123
- 'fuzzy_mistakes': int(arguments.get('fuzzy_mistakes', 0)),
124
- 'match_fuzzy_whole_phrase_bool': arguments.get('match_fuzzy_whole_phrase_bool', 'True'),
125
-
 
126
  # Deduplication specific arguments
127
- 'duplicate_type': arguments.get('duplicate_type', 'pages'),
128
- 'similarity_threshold': float(arguments.get('similarity_threshold', 0.95)),
129
- 'min_word_count': int(arguments.get('min_word_count', 3)),
130
- 'min_consecutive_pages': int(arguments.get('min_consecutive_pages', 1)),
131
- 'greedy_match': arguments.get('greedy_match', 'False'),
132
- 'combine_pages': arguments.get('combine_pages', 'True'),
133
- 'search_query': arguments.get('search_query', ""),
134
- 'text_columns': arguments.get('text_columns', []),
135
- 'remove_duplicate_rows': arguments.get('remove_duplicate_rows', 'True'),
136
- 'anon_strategy': arguments.get('anon_strategy', 'redact'),
137
-
138
  # Textract specific arguments
139
- 'textract_action': arguments.get('textract_action', ''),
140
- 'job_id': arguments.get('job_id', ''),
141
- 'extract_signatures': arguments.get('extract_signatures', False),
142
- 'textract_bucket': arguments.get('textract_bucket', ''),
143
- 'textract_input_prefix': arguments.get('textract_input_prefix', ''),
144
- 'textract_output_prefix': arguments.get('textract_output_prefix', ''),
145
- 's3_textract_document_logs_subfolder': arguments.get('s3_textract_document_logs_subfolder', ''),
146
- 'local_textract_document_logs_subfolder': arguments.get('local_textract_document_logs_subfolder', ''),
147
- 'poll_interval': int(arguments.get('poll_interval', 30)),
148
- 'max_poll_attempts': int(arguments.get('max_poll_attempts', 120)),
149
-
 
 
 
150
  # AWS credentials (use IAM Role instead of keys)
151
- 'aws_access_key': None,
152
- 'aws_secret_key': None,
153
- 'aws_region': os.getenv("AWS_REGION", ""),
154
- 's3_bucket': bucket_name,
155
-
156
  # Set defaults for boolean flags
157
- 'prepare_images': arguments.get('prepare_images', True),
158
- 'compress_redacted_pdf': arguments.get('compress_redacted_pdf', False),
159
- 'return_pdf_end_of_redaction': arguments.get('return_pdf_end_of_redaction', True)
 
 
160
  }
161
 
162
  # Combine extraction options
163
- extraction_options = list(cli_args['handwrite_signature_extraction']) if cli_args['handwrite_signature_extraction'] else []
164
- if cli_args['extract_forms']:
165
- extraction_options.append('Extract forms')
166
- if cli_args['extract_tables']:
167
- extraction_options.append('Extract tables')
168
- if cli_args['extract_layout']:
169
- extraction_options.append('Extract layout')
170
- cli_args['handwrite_signature_extraction'] = extraction_options
 
 
 
 
171
 
172
  # Download optional files if they are specified
173
- allow_list_key = arguments.get('allow_list_file')
174
  if allow_list_key:
175
- allow_list_path = os.path.join(INPUT_DIR, 'allow_list.csv')
176
  download_file_from_s3(bucket_name, allow_list_key, allow_list_path)
177
- cli_args['allow_list_file'] = allow_list_path
178
-
179
- deny_list_key = arguments.get('deny_list_file')
180
  if deny_list_key:
181
- deny_list_path = os.path.join(INPUT_DIR, 'deny_list.csv')
182
  download_file_from_s3(bucket_name, deny_list_key, deny_list_path)
183
- cli_args['deny_list_file'] = deny_list_path
184
 
185
  # 5. Execute the main application logic
186
  try:
@@ -195,10 +216,14 @@ def lambda_handler(event, context):
195
 
196
  # 6. Upload results back to S3
197
  output_s3_prefix = f"output/{os.path.splitext(os.path.basename(input_key))[0]}"
198
- print(f"Uploading contents of {OUTPUT_DIR} to s3://{bucket_name}/{output_s3_prefix}/")
 
 
199
  upload_directory_to_s3(OUTPUT_DIR, bucket_name, output_s3_prefix)
200
 
201
  return {
202
  "statusCode": 200,
203
- "body": json.dumps(f"Processing complete for {input_key}. Output saved to s3://{bucket_name}/{output_s3_prefix}/")
204
- }
 
 
 
 
 
1
  import json
2
+ import os
3
+
4
+ import boto3
5
 
6
  # Import the main function from your CLI script
7
  from cli_redact import main as cli_main
 
17
  INPUT_DIR = os.path.join(TMP_DIR, "input")
18
  OUTPUT_DIR = os.path.join(TMP_DIR, "output")
19
 
20
+
21
  def download_file_from_s3(bucket_name, key, download_path):
22
  """Download a file from S3 to the local filesystem."""
23
  try:
 
27
  print(f"Error downloading from S3: {e}")
28
  raise
29
 
30
+
31
  def upload_directory_to_s3(local_directory, bucket_name, s3_prefix):
32
  """Upload all files from a local directory to an S3 prefix."""
33
  for root, _, files in os.walk(local_directory):
 
36
  # Create a relative path to maintain directory structure if needed
37
  relative_path = os.path.relpath(local_file_path, local_directory)
38
  output_key = os.path.join(s3_prefix, relative_path)
39
+
40
  try:
41
  s3_client.upload_file(local_file_path, bucket_name, output_key)
42
+ print(
43
+ f"Successfully uploaded {local_file_path} to s3://{bucket_name}/{output_key}"
44
+ )
45
  except Exception as e:
46
  print(f"Error uploading to S3: {e}")
47
  raise
48
 
49
+
50
  def lambda_handler(event, context):
51
  print(f"Received event: {json.dumps(event)}")
52
 
 
57
  # 2. Extract information from the event
58
  # Assumes the event is triggered by S3 and may contain an 'arguments' payload
59
  try:
60
+ record = event["Records"][0]
61
+ bucket_name = record["s3"]["bucket"]["name"]
62
+ input_key = record["s3"]["object"]["key"]
63
+
64
  # The user metadata can be used to pass arguments
65
  # This is more robust than embedding them in the main event body
66
  response = s3_client.head_object(Bucket=bucket_name, Key=input_key)
67
+ metadata = response.get("Metadata", {})
68
  # Arguments can be passed as a JSON string in metadata
69
+ arguments = json.loads(metadata.get("arguments", "{}"))
70
 
71
  except (KeyError, IndexError) as e:
72
+ print(
73
+ f"Could not parse S3 event record: {e}. Checking for direct invocation payload."
74
+ )
75
  # Fallback for direct invocation (e.g., from Step Functions or manual test)
76
+ bucket_name = event.get("bucket_name")
77
+ input_key = event.get("input_key")
78
+ arguments = event.get("arguments", {})
79
  if not all([bucket_name, input_key]):
80
+ raise ValueError(
81
+ "Missing 'bucket_name' or 'input_key' in direct invocation event."
82
+ )
83
 
84
  print(f"Processing s3://{bucket_name}/{input_key}")
85
  print(f"With arguments: {arguments}")
86
+
87
  # 3. Download the main input file
88
  input_file_path = os.path.join(INPUT_DIR, os.path.basename(input_key))
89
  download_file_from_s3(bucket_name, input_key, input_file_path)
 
91
  # 4. Prepare arguments for the CLI function
92
  # This dictionary should mirror the one in your app.py's "direct mode"
93
  cli_args = {
94
+ "task": arguments.get("task", "redact"),
95
+ "input_file": input_file_path,
96
+ "output_dir": OUTPUT_DIR,
97
+ "input_dir": INPUT_DIR,
98
+ "language": arguments.get("language", "en_core_web_lg"),
99
+ "pii_detector": arguments.get("pii_detector", "Local"), # Default to local
100
+ "username": arguments.get("username", "lambda_user"),
101
+ "save_to_user_folders": arguments.get("save_to_user_folders", "False"),
102
+ "ocr_method": arguments.get("ocr_method", "Tesseract OCR - all PDF types"),
103
+ "page_min": int(arguments.get("page_min", 0)),
104
+ "page_max": int(arguments.get("page_max", 0)),
105
+ "handwrite_signature_extraction": arguments.get(
106
+ "handwrite_signature_checkbox",
107
+ ["Extract handwriting", "Extract signatures"],
108
+ ),
109
+ "extract_forms": arguments.get("extract_forms", False),
110
+ "extract_tables": arguments.get("extract_tables", False),
111
+ "extract_layout": arguments.get("extract_layout", False),
112
  # General arguments
113
+ "local_redact_entities": arguments.get("local_redact_entities", []),
114
+ "aws_redact_entities": arguments.get("aws_redact_entities", []),
115
+ "cost_code": arguments.get("cost_code", ""),
116
+ "save_logs_to_csv": arguments.get("save_logs_to_csv", "False"),
117
+ "save_logs_to_dynamodb": arguments.get("save_logs_to_dynamodb", "False"),
118
+ "display_file_names_in_logs": arguments.get(
119
+ "display_file_names_in_logs", "True"
120
+ ),
121
+ "upload_logs_to_s3": arguments.get("upload_logs_to_s3", "False"),
122
+ "s3_logs_prefix": arguments.get("s3_logs_prefix", ""),
123
+ "do_initial_clean": arguments.get("do_initial_clean", "False"),
124
  # PDF/Image specific arguments
125
+ "images_dpi": float(arguments.get("images_dpi", 300.0)),
126
+ "chosen_local_ocr_model": arguments.get("chosen_local_ocr_model", "tesseract"),
127
+ "preprocess_local_ocr_images": arguments.get(
128
+ "preprocess_local_ocr_images", "False"
129
+ ),
130
  # Handle optional files like allow/deny lists
131
+ "allow_list_file": arguments.get("allow_list_file", ""),
132
+ "deny_list_file": arguments.get("deny_list_file", ""),
133
+ "redact_whole_page_file": arguments.get("redact_whole_page_file", ""),
 
134
  # Tabular/Anonymisation arguments
135
+ "excel_sheets": arguments.get("excel_sheets", []),
136
+ "fuzzy_mistakes": int(arguments.get("fuzzy_mistakes", 0)),
137
+ "match_fuzzy_whole_phrase_bool": arguments.get(
138
+ "match_fuzzy_whole_phrase_bool", "True"
139
+ ),
140
  # Deduplication specific arguments
141
+ "duplicate_type": arguments.get("duplicate_type", "pages"),
142
+ "similarity_threshold": float(arguments.get("similarity_threshold", 0.95)),
143
+ "min_word_count": int(arguments.get("min_word_count", 3)),
144
+ "min_consecutive_pages": int(arguments.get("min_consecutive_pages", 1)),
145
+ "greedy_match": arguments.get("greedy_match", "False"),
146
+ "combine_pages": arguments.get("combine_pages", "True"),
147
+ "search_query": arguments.get("search_query", ""),
148
+ "text_columns": arguments.get("text_columns", []),
149
+ "remove_duplicate_rows": arguments.get("remove_duplicate_rows", "True"),
150
+ "anon_strategy": arguments.get("anon_strategy", "redact"),
 
151
  # Textract specific arguments
152
+ "textract_action": arguments.get("textract_action", ""),
153
+ "job_id": arguments.get("job_id", ""),
154
+ "extract_signatures": arguments.get("extract_signatures", False),
155
+ "textract_bucket": arguments.get("textract_bucket", ""),
156
+ "textract_input_prefix": arguments.get("textract_input_prefix", ""),
157
+ "textract_output_prefix": arguments.get("textract_output_prefix", ""),
158
+ "s3_textract_document_logs_subfolder": arguments.get(
159
+ "s3_textract_document_logs_subfolder", ""
160
+ ),
161
+ "local_textract_document_logs_subfolder": arguments.get(
162
+ "local_textract_document_logs_subfolder", ""
163
+ ),
164
+ "poll_interval": int(arguments.get("poll_interval", 30)),
165
+ "max_poll_attempts": int(arguments.get("max_poll_attempts", 120)),
166
  # AWS credentials (use IAM Role instead of keys)
167
+ "aws_access_key": None,
168
+ "aws_secret_key": None,
169
+ "aws_region": os.getenv("AWS_REGION", ""),
170
+ "s3_bucket": bucket_name,
 
171
  # Set defaults for boolean flags
172
+ "prepare_images": arguments.get("prepare_images", True),
173
+ "compress_redacted_pdf": arguments.get("compress_redacted_pdf", False),
174
+ "return_pdf_end_of_redaction": arguments.get(
175
+ "return_pdf_end_of_redaction", True
176
+ ),
177
  }
178
 
179
  # Combine extraction options
180
+ extraction_options = (
181
+ list(cli_args["handwrite_signature_extraction"])
182
+ if cli_args["handwrite_signature_extraction"]
183
+ else []
184
+ )
185
+ if cli_args["extract_forms"]:
186
+ extraction_options.append("Extract forms")
187
+ if cli_args["extract_tables"]:
188
+ extraction_options.append("Extract tables")
189
+ if cli_args["extract_layout"]:
190
+ extraction_options.append("Extract layout")
191
+ cli_args["handwrite_signature_extraction"] = extraction_options
192
 
193
  # Download optional files if they are specified
194
+ allow_list_key = arguments.get("allow_list_file")
195
  if allow_list_key:
196
+ allow_list_path = os.path.join(INPUT_DIR, "allow_list.csv")
197
  download_file_from_s3(bucket_name, allow_list_key, allow_list_path)
198
+ cli_args["allow_list_file"] = allow_list_path
199
+
200
+ deny_list_key = arguments.get("deny_list_file")
201
  if deny_list_key:
202
+ deny_list_path = os.path.join(INPUT_DIR, "deny_list.csv")
203
  download_file_from_s3(bucket_name, deny_list_key, deny_list_path)
204
+ cli_args["deny_list_file"] = deny_list_path
205
 
206
  # 5. Execute the main application logic
207
  try:
 
216
 
217
  # 6. Upload results back to S3
218
  output_s3_prefix = f"output/{os.path.splitext(os.path.basename(input_key))[0]}"
219
+ print(
220
+ f"Uploading contents of {OUTPUT_DIR} to s3://{bucket_name}/{output_s3_prefix}/"
221
+ )
222
  upload_directory_to_s3(OUTPUT_DIR, bucket_name, output_s3_prefix)
223
 
224
  return {
225
  "statusCode": 200,
226
+ "body": json.dumps(
227
+ f"Processing complete for {input_key}. Output saved to s3://{bucket_name}/{output_s3_prefix}/"
228
+ ),
229
+ }
load_dynamo_logs.py CHANGED
@@ -1,20 +1,25 @@
1
- import boto3
2
  import csv
3
- from decimal import Decimal
4
  import datetime
5
- from boto3.dynamodb.conditions import Key
6
 
7
- from tools.config import AWS_REGION, ACCESS_LOG_DYNAMODB_TABLE_NAME, FEEDBACK_LOG_DYNAMODB_TABLE_NAME, USAGE_LOG_DYNAMODB_TABLE_NAME, OUTPUT_FOLDER
 
 
 
 
 
 
8
 
9
  # Replace with your actual table name and region
10
- TABLE_NAME = USAGE_LOG_DYNAMODB_TABLE_NAME # Choose as appropriate
11
  REGION = AWS_REGION
12
- CSV_OUTPUT = OUTPUT_FOLDER + 'dynamodb_logs_export.csv'
13
 
14
  # Create DynamoDB resource
15
- dynamodb = boto3.resource('dynamodb', region_name=REGION)
16
  table = dynamodb.Table(TABLE_NAME)
17
 
 
18
  # Helper function to convert Decimal to float or int
19
  def convert_types(item):
20
  new_item = {}
@@ -25,11 +30,11 @@ def convert_types(item):
25
  # Handle Strings that might be dates
26
  elif isinstance(value, str):
27
  try:
28
- # Attempt to parse a common ISO 8601 format.
29
  # The .replace() handles the 'Z' for Zulu/UTC time.
30
- dt_obj = datetime.datetime.fromisoformat(value.replace('Z', '+00:00'))
31
  # Now that we have a datetime object, format it as desired
32
- new_item[key] = dt_obj.strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]
33
  except (ValueError, TypeError):
34
  # If it fails to parse, it's just a regular string
35
  new_item[key] = value
@@ -38,18 +43,20 @@ def convert_types(item):
38
  new_item[key] = value
39
  return new_item
40
 
 
41
  # Paginated scan
42
  def scan_table():
43
  items = []
44
  response = table.scan()
45
- items.extend(response['Items'])
46
 
47
- while 'LastEvaluatedKey' in response:
48
- response = table.scan(ExclusiveStartKey=response['LastEvaluatedKey'])
49
- items.extend(response['Items'])
50
 
51
  return items
52
 
 
53
  # Export to CSV
54
  # Export to CSV
55
  def export_to_csv(items, output_path, fields_to_drop: list = None):
@@ -59,25 +66,22 @@ def export_to_csv(items, output_path, fields_to_drop: list = None):
59
 
60
  # Use a set for efficient lookup
61
  drop_set = set(fields_to_drop or [])
62
-
63
  # Get a comprehensive list of all possible headers from all items
64
  all_keys = set()
65
  for item in items:
66
  all_keys.update(item.keys())
67
-
68
  # Determine the final fieldnames by subtracting the ones to drop
69
  fieldnames = sorted(list(all_keys - drop_set))
70
-
71
  print("Final CSV columns will be:", fieldnames)
72
 
73
- with open(output_path, 'w', newline='', encoding='utf-8-sig') as csvfile:
74
  # The key fix is here: extrasaction='ignore'
75
  # restval='' is also good practice to handle rows that are missing a key
76
  writer = csv.DictWriter(
77
- csvfile,
78
- fieldnames=fieldnames,
79
- extrasaction='ignore',
80
- restval=''
81
  )
82
  writer.writeheader()
83
 
@@ -88,6 +92,7 @@ def export_to_csv(items, output_path, fields_to_drop: list = None):
88
 
89
  print(f"Exported {len(items)} items to {output_path}")
90
 
 
91
  # Run export
92
  items = scan_table()
93
- export_to_csv(items, CSV_OUTPUT, fields_to_drop=[])
 
 
1
  import csv
 
2
  import datetime
3
+ from decimal import Decimal
4
 
5
+ import boto3
6
+
7
+ from tools.config import (
8
+ AWS_REGION,
9
+ OUTPUT_FOLDER,
10
+ USAGE_LOG_DYNAMODB_TABLE_NAME,
11
+ )
12
 
13
  # Replace with your actual table name and region
14
+ TABLE_NAME = USAGE_LOG_DYNAMODB_TABLE_NAME # Choose as appropriate
15
  REGION = AWS_REGION
16
+ CSV_OUTPUT = OUTPUT_FOLDER + "dynamodb_logs_export.csv"
17
 
18
  # Create DynamoDB resource
19
+ dynamodb = boto3.resource("dynamodb", region_name=REGION)
20
  table = dynamodb.Table(TABLE_NAME)
21
 
22
+
23
  # Helper function to convert Decimal to float or int
24
  def convert_types(item):
25
  new_item = {}
 
30
  # Handle Strings that might be dates
31
  elif isinstance(value, str):
32
  try:
33
+ # Attempt to parse a common ISO 8601 format.
34
  # The .replace() handles the 'Z' for Zulu/UTC time.
35
+ dt_obj = datetime.datetime.fromisoformat(value.replace("Z", "+00:00"))
36
  # Now that we have a datetime object, format it as desired
37
+ new_item[key] = dt_obj.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
38
  except (ValueError, TypeError):
39
  # If it fails to parse, it's just a regular string
40
  new_item[key] = value
 
43
  new_item[key] = value
44
  return new_item
45
 
46
+
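As an illustration of the date handling above, an ISO 8601 string with a trailing 'Z' is reformatted to a millisecond-precision timestamp (the Decimal handling is defined elsewhere in the script and is not shown here):

```python
import datetime

value = "2025-04-09T12:34:56Z"  # example value only
dt_obj = datetime.datetime.fromisoformat(value.replace("Z", "+00:00"))
print(dt_obj.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3])  # -> 2025-04-09 12:34:56.000
```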
47
  # Paginated scan
48
  def scan_table():
49
  items = []
50
  response = table.scan()
51
+ items.extend(response["Items"])
52
 
53
+ while "LastEvaluatedKey" in response:
54
+ response = table.scan(ExclusiveStartKey=response["LastEvaluatedKey"])
55
+ items.extend(response["Items"])
56
 
57
  return items
58
 
59
+
60
  # Export to CSV
61
62
  def export_to_csv(items, output_path, fields_to_drop: list = None):
 
66
 
67
  # Use a set for efficient lookup
68
  drop_set = set(fields_to_drop or [])
69
+
70
  # Get a comprehensive list of all possible headers from all items
71
  all_keys = set()
72
  for item in items:
73
  all_keys.update(item.keys())
74
+
75
  # Determine the final fieldnames by subtracting the ones to drop
76
  fieldnames = sorted(list(all_keys - drop_set))
77
+
78
  print("Final CSV columns will be:", fieldnames)
79
 
80
+ with open(output_path, "w", newline="", encoding="utf-8-sig") as csvfile:
81
  # The key fix is here: extrasaction='ignore'
82
  # restval='' is also good practice to handle rows that are missing a key
83
  writer = csv.DictWriter(
84
+ csvfile, fieldnames=fieldnames, extrasaction="ignore", restval=""
 
 
 
85
  )
86
  writer.writeheader()
87
 
 
92
 
93
  print(f"Exported {len(items)} items to {output_path}")
94
 
95
+
96
  # Run export
97
  items = scan_table()
98
+ export_to_csv(items, CSV_OUTPUT, fields_to_drop=[])
load_s3_logs.py CHANGED
@@ -1,39 +1,53 @@
 
 
 
1
  import boto3
2
  import pandas as pd
3
- from io import StringIO
4
- from datetime import datetime
5
- from tools.config import DOCUMENT_REDACTION_BUCKET, AWS_ACCESS_KEY, AWS_SECRET_KEY, AWS_REGION, OUTPUT_FOLDER
 
 
 
 
 
6
 
7
  # Combine together log files that can be then used for e.g. dashboarding and financial tracking.
8
 
9
  # S3 setup. Try to use provided keys (needs S3 permissions), otherwise assume AWS SSO connection
10
  if AWS_ACCESS_KEY and AWS_SECRET_KEY and AWS_REGION:
11
- s3 = boto3.client('s3',
12
- aws_access_key_id=AWS_ACCESS_KEY,
13
- aws_secret_access_key=AWS_SECRET_KEY,
14
- region_name=AWS_REGION)
15
- else: s3 = boto3.client('s3')
 
 
 
16
 
17
  bucket_name = DOCUMENT_REDACTION_BUCKET
18
- prefix = 'usage/' # 'feedback/' # 'logs/' # Change as needed - top-level folder where logs are stored
19
- earliest_date = '20250409' # Earliest date of logs folder retrieved
20
- latest_date = '20250423' # Latest date of logs folder retrieved
 
21
 
22
  # Function to list all files in a folder
23
  def list_files_in_s3(bucket, prefix):
24
  response = s3.list_objects_v2(Bucket=bucket, Prefix=prefix)
25
- if 'Contents' in response:
26
- return [content['Key'] for content in response['Contents']]
27
  return []
28
 
 
29
  # Function to filter date range
30
  def is_within_date_range(date_str, start_date, end_date):
31
- date_obj = datetime.strptime(date_str, '%Y%m%d')
32
  return start_date <= date_obj <= end_date
33
 
 
34
  # Define the date range
35
- start_date = datetime.strptime(earliest_date, '%Y%m%d') # Replace with your start date
36
- end_date = datetime.strptime(latest_date, '%Y%m%d') # Replace with your end date
37
 
38
  # List all subfolders under 'usage/'
39
  all_files = list_files_in_s3(bucket_name, prefix)
@@ -41,10 +55,13 @@ all_files = list_files_in_s3(bucket_name, prefix)
41
  # Filter based on date range
42
  log_files = []
43
  for file in all_files:
44
- parts = file.split('/')
45
  if len(parts) >= 3:
46
  date_str = parts[1]
47
- if is_within_date_range(date_str, start_date, end_date) and parts[-1] == 'log.csv':
 
 
 
48
  log_files.append(file)
49
 
50
  # Download, read and concatenate CSV files into a pandas DataFrame
@@ -53,9 +70,10 @@ for log_file in log_files:
53
  # Download the file
54
  obj = s3.get_object(Bucket=bucket_name, Key=log_file)
55
  try:
56
- csv_content = obj['Body'].read().decode('utf-8')
57
- except:
58
- csv_content = obj['Body'].read().decode('latin-1')
 
59
 
60
  # Read CSV content into pandas DataFrame
61
  try:
@@ -71,7 +89,7 @@ if df_list:
71
  concatenated_df = pd.concat(df_list, ignore_index=True)
72
 
73
  # Save the concatenated DataFrame to a CSV file
74
- concatenated_df.to_csv(OUTPUT_FOLDER + 'consolidated_s3_logs.csv', index=False)
75
  print("Consolidated CSV saved as 'consolidated_s3_logs.csv'")
76
  else:
77
  print("No log files found in the given date range.")
 
1
+ from datetime import datetime
2
+ from io import StringIO
3
+
4
  import boto3
5
  import pandas as pd
6
+
7
+ from tools.config import (
8
+ AWS_ACCESS_KEY,
9
+ AWS_REGION,
10
+ AWS_SECRET_KEY,
11
+ DOCUMENT_REDACTION_BUCKET,
12
+ OUTPUT_FOLDER,
13
+ )
14
 
15
  # Combine together log files that can then be used for e.g. dashboarding and financial tracking.
16
 
17
  # S3 setup. Try to use provided keys (needs S3 permissions), otherwise assume AWS SSO connection
18
  if AWS_ACCESS_KEY and AWS_SECRET_KEY and AWS_REGION:
19
+ s3 = boto3.client(
20
+ "s3",
21
+ aws_access_key_id=AWS_ACCESS_KEY,
22
+ aws_secret_access_key=AWS_SECRET_KEY,
23
+ region_name=AWS_REGION,
24
+ )
25
+ else:
26
+ s3 = boto3.client("s3")
27
 
28
  bucket_name = DOCUMENT_REDACTION_BUCKET
29
+ prefix = "usage/" # 'feedback/' # 'logs/' # Change as needed - top-level folder where logs are stored
30
+ earliest_date = "20250409" # Earliest date of logs folder retrieved
31
+ latest_date = "20250423" # Latest date of logs folder retrieved
32
+
33
 
34
  # Function to list all files in a folder
35
  def list_files_in_s3(bucket, prefix):
36
  response = s3.list_objects_v2(Bucket=bucket, Prefix=prefix)
37
+ if "Contents" in response:
38
+ return [content["Key"] for content in response["Contents"]]
39
  return []
40
 
41
+
42
  # Function to filter date range
43
  def is_within_date_range(date_str, start_date, end_date):
44
+ date_obj = datetime.strptime(date_str, "%Y%m%d")
45
  return start_date <= date_obj <= end_date
46
 
47
+
48
  # Define the date range
49
+ start_date = datetime.strptime(earliest_date, "%Y%m%d") # Replace with your start date
50
+ end_date = datetime.strptime(latest_date, "%Y%m%d") # Replace with your end date
51
 
52
  # List all subfolders under 'usage/'
53
  all_files = list_files_in_s3(bucket_name, prefix)
 
55
  # Filter based on date range
56
  log_files = []
57
  for file in all_files:
58
+ parts = file.split("/")
59
  if len(parts) >= 3:
60
  date_str = parts[1]
61
+ if (
62
+ is_within_date_range(date_str, start_date, end_date)
63
+ and parts[-1] == "log.csv"
64
+ ):
65
  log_files.append(file)
66
 
67
  # Download, read and concatenate CSV files into a pandas DataFrame
 
70
  # Download the file
71
  obj = s3.get_object(Bucket=bucket_name, Key=log_file)
72
  try:
73
+ csv_content = obj["Body"].read().decode("utf-8")
74
+ except Exception as e:
75
+ print("Could not load in log file:", log_file, "due to:", e)
76
+ csv_content = obj["Body"].read().decode("latin-1")
77
 
78
  # Read CSV content into pandas DataFrame
79
  try:
 
89
  concatenated_df = pd.concat(df_list, ignore_index=True)
90
 
91
  # Save the concatenated DataFrame to a CSV file
92
+ concatenated_df.to_csv(OUTPUT_FOLDER + "consolidated_s3_logs.csv", index=False)
93
  print("Consolidated CSV saved as 'consolidated_s3_logs.csv'")
94
  else:
95
  print("No log files found in the given date range.")
pyproject.toml CHANGED
@@ -52,7 +52,16 @@ test = ["pytest", "pytest-cov"]
52
  # Configuration for Ruff linter:
53
  [tool.ruff]
54
  line-length = 88
 
 
55
  select = ["E", "F", "I"]
 
 
 
 
 
 
 
56
 
57
  # Configuration for the Black formatter:
58
  [tool.black]
 
52
  # Configuration for Ruff linter:
53
  [tool.ruff]
54
  line-length = 88
55
+
56
+ [tool.ruff.lint]
57
  select = ["E", "F", "I"]
58
+ ignore = [
59
+ "E501", # line-too-long (handled with Black)
60
+ "E402", # module-import-not-at-top-of-file (sometimes needed for conditional imports)
61
+ ]
62
+
63
+ [tool.ruff.lint.per-file-ignores]
64
+ "__init__.py" = ["F401"] # Allow unused imports in __init__.py
65
 
66
  # Configuration for the Black formatter:
67
  [tool.black]
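The move of `select` under `[tool.ruff.lint]` matches the layout newer Ruff releases expect, where top-level lint settings in `[tool.ruff]` are deprecated. As a hedged illustration of the resulting table structure (not part of the commit), the configuration can be inspected with the standard-library TOML parser:

```python
# Minimal sketch: print the [tool.ruff.lint] settings defined above.
# tomllib ships with Python 3.11+; older interpreters can use the `tomli` package.
import tomllib

with open("pyproject.toml", "rb") as f:
    config = tomllib.load(f)

lint = config["tool"]["ruff"]["lint"]
print("selected rule groups:", lint["select"])        # ['E', 'F', 'I']
print("ignored rules:", lint["ignore"])               # ['E501', 'E402']
print("per-file ignores:", lint["per-file-ignores"])  # {'__init__.py': ['F401']}
```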
test/GUI_TEST_README.md ADDED
@@ -0,0 +1,111 @@
 
 
1
+ # GUI Testing for Document Redaction App
2
+
3
+ This directory contains tests specifically for verifying that the GUI application (`app.py`) loads correctly.
4
+
5
+ ## Test Files
6
+
7
+ ### `test_gui_only.py`
8
+ A standalone script that tests only the GUI functionality. This is useful for:
9
+ - Quick verification that the Gradio interface loads without errors
10
+ - CI/CD pipelines where you want to test GUI separately from CLI functionality
11
+ - Development testing when you only want to check GUI components
12
+
13
+ **Usage:**
14
+
15
+ Option 1 - Manual activation:
16
+ ```bash
17
+ conda activate redaction
18
+ cd test
19
+ python test_gui_only.py
20
+ ```
21
+
22
+ Option 2 - Using helper scripts (Windows):
23
+ ```bash
24
+ cd test
25
+ # For Command Prompt:
26
+ run_gui_test.bat
27
+
28
+ # For PowerShell:
29
+ .\run_gui_test.ps1
30
+ ```
31
+
32
+ ### `test.py` (Updated)
33
+ The main test suite now includes both CLI and GUI tests. The GUI tests are in the `TestGUIApp` class.
34
+
35
+ **Usage:**
36
+
37
+ Option 1 - Manual activation:
38
+ ```bash
39
+ conda activate redaction
40
+ cd test
41
+ python test.py
42
+ ```
43
+
44
+ Option 2 - Using helper scripts (Windows; note that these run only the GUI tests via `test_gui_only.py`, not the full suite):
45
+ ```bash
46
+ cd test
47
+ # For Command Prompt:
48
+ run_gui_test.bat
49
+
50
+ # For PowerShell:
51
+ .\run_gui_test.ps1
52
+ ```
53
+
54
+ ## What the GUI Tests Check
55
+
56
+ 1. **App Import and Initialization** (`test_app_import_and_initialization`)
57
+ - Verifies that `app.py` can be imported without errors
58
+ - Checks that the Gradio `app` object is created successfully
59
+ - Ensures the app is a proper Gradio Blocks instance
60
+
61
+ 2. **App Launch in Headless Mode** (`test_app_launch_headless`)
62
+ - Tests that the app can be launched without opening a browser
63
+ - Verifies the Gradio server starts successfully
64
+ - Uses threading to prevent blocking the test execution
65
+
66
+ 3. **Configuration Loading** (`test_app_configuration_loading`)
67
+ - Verifies that configuration variables are loaded correctly
68
+ - Checks key settings like server port, file size limits, language settings
69
+ - Ensures the app has access to all required configuration
70
+
71
+ ## Test Requirements
72
+
73
+ - **Conda environment 'redaction' must be activated** before running tests
74
+ - Python environment with all dependencies installed
75
+ - Access to the `tools.config` module
76
+ - Gradio and related GUI dependencies (including `gradio_image_annotation`)
77
+ - The `app.py` file in the parent directory
78
+
79
+ ### Prerequisites
80
+
81
+ Before running the GUI tests, ensure you have activated the conda environment:
82
+
83
+ ```bash
84
+ conda activate redaction
85
+ ```
86
+
87
+ The `gradio_image_annotation` package is already installed in the 'redaction' environment.
88
+
89
+ ## Expected Behavior
90
+
91
+ - All tests should pass if the GUI loads correctly
92
+ - Tests will fail if there are import errors, missing dependencies, or configuration issues
93
+ - The headless launch test may take up to 10 seconds to complete
94
+
95
+ ## Troubleshooting
96
+
97
+ If tests fail:
98
+ 1. Check that all dependencies are installed (`pip install -r requirements.txt`)
99
+ 2. Verify that `app.py` exists in the parent directory
100
+ 3. Ensure configuration files are properly set up
101
+ 4. Check for any missing environment variables or configuration issues
102
+
103
+ ## Integration with CI/CD
104
+
105
+ These tests are designed to run in headless environments and are suitable for:
106
+ - GitHub Actions
107
+ - Jenkins pipelines
108
+ - Docker containers
109
+ - Any automated testing environment
110
+
111
+ The tests do not require a display or browser to be available.
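To make the README concrete, here is a hedged, minimal sketch of the kind of headless load check it describes. It is not the repository's `test_gui_only.py`; it assumes only what the README states, namely that `app.py` exposes a Gradio Blocks object named `app`. The repository tests launch the server in a background thread; the sketch uses `prevent_thread_lock=True` instead, as a simpler way to launch without blocking.

```python
# Hedged sketch of a headless GUI load test (assumes app.py exposes `app`).
import unittest

import gradio as gr


class TestGUILoads(unittest.TestCase):
    def test_app_imports_as_blocks(self):
        from app import app  # the import itself should not raise

        self.assertIsInstance(app, gr.Blocks)

    def test_app_launches_headless(self):
        from app import app

        # prevent_thread_lock returns control to the test instead of blocking,
        # and inbrowser=False keeps the run fully headless.
        app.launch(prevent_thread_lock=True, inbrowser=False, quiet=True)
        try:
            self.assertTrue(app.local_url, "server URL should be set after launch")
        finally:
            app.close()


if __name__ == "__main__":
    unittest.main()
```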
test/demo_single_test.py CHANGED
@@ -7,84 +7,100 @@ to test a specific CLI example.
7
  """
8
 
9
  import os
 
10
  import sys
11
  import tempfile
12
- import shutil
13
 
14
  # Add the parent directory to the path
15
  sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
16
 
17
  from test.test import run_cli_redact
18
 
 
19
  def demo_pdf_redaction():
20
  """Demonstrate how to run a single PDF redaction test."""
21
  print("=== Demo: PDF Redaction with Default Settings ===")
22
-
23
  # Set up paths
24
- script_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "cli_redact.py")
25
- input_file = os.path.join(os.path.dirname(os.path.dirname(__file__)), "example_data", "example_of_emails_sent_to_a_professor_before_applying.pdf")
 
 
 
 
 
 
26
  output_dir = tempfile.mkdtemp(prefix="demo_output_")
27
-
28
  print(f"Script: {script_path}")
29
  print(f"Input: {input_file}")
30
  print(f"Output: {output_dir}")
31
-
32
  # Check if files exist
33
  if not os.path.isfile(script_path):
34
  print(f"❌ Script not found: {script_path}")
35
  return False
36
-
37
  if not os.path.isfile(input_file):
38
  print(f"❌ Input file not found: {input_file}")
39
- print("Make sure you have the example data files in the example_data/ directory")
 
 
40
  return False
41
-
42
  try:
43
  # Run the test
44
  print("\nRunning PDF redaction with default settings...")
45
  result = run_cli_redact(
46
- script_path=script_path,
47
- input_file=input_file,
48
- output_dir=output_dir
49
  )
50
-
51
  if result:
52
  print("✅ Test completed successfully!")
53
  print(f"Check the output directory for results: {output_dir}")
54
  else:
55
  print("❌ Test failed!")
56
-
57
  return result
58
-
59
  finally:
60
  # Clean up
61
  if os.path.exists(output_dir):
62
  shutil.rmtree(output_dir)
63
  print(f"Cleaned up: {output_dir}")
64
 
 
65
  def demo_csv_anonymisation():
66
  """Demonstrate how to run a CSV anonymisation test."""
67
  print("\n=== Demo: CSV Anonymisation ===")
68
-
69
  # Set up paths
70
- script_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "cli_redact.py")
71
- input_file = os.path.join(os.path.dirname(os.path.dirname(__file__)), "example_data", "combined_case_notes.csv")
 
 
 
 
 
 
72
  output_dir = tempfile.mkdtemp(prefix="demo_output_")
73
-
74
  print(f"Script: {script_path}")
75
  print(f"Input: {input_file}")
76
  print(f"Output: {output_dir}")
77
-
78
  # Check if files exist
79
  if not os.path.isfile(script_path):
80
  print(f"❌ Script not found: {script_path}")
81
  return False
82
-
83
  if not os.path.isfile(input_file):
84
  print(f"❌ Input file not found: {input_file}")
85
- print("Make sure you have the example data files in the example_data/ directory")
 
 
86
  return False
87
-
88
  try:
89
  # Run the test
90
  print("\nRunning CSV anonymisation...")
@@ -93,40 +109,41 @@ def demo_csv_anonymisation():
93
  input_file=input_file,
94
  output_dir=output_dir,
95
  text_columns=["Case Note", "Client"],
96
- anon_strategy="replace_redacted"
97
  )
98
-
99
  if result:
100
  print("✅ Test completed successfully!")
101
  print(f"Check the output directory for results: {output_dir}")
102
  else:
103
  print("❌ Test failed!")
104
-
105
  return result
106
-
107
  finally:
108
  # Clean up
109
  if os.path.exists(output_dir):
110
  shutil.rmtree(output_dir)
111
  print(f"Cleaned up: {output_dir}")
112
 
 
113
  if __name__ == "__main__":
114
  print("CLI Redaction Test Demo")
115
  print("=" * 50)
116
  print("This script demonstrates how to run individual tests.")
117
  print("=" * 50)
118
-
119
  # Run the demos
120
  success1 = demo_pdf_redaction()
121
  success2 = demo_csv_anonymisation()
122
-
123
  print("\n" + "=" * 50)
124
  print("Demo Summary")
125
  print("=" * 50)
126
  print(f"PDF Redaction: {'✅ PASSED' if success1 else '❌ FAILED'}")
127
  print(f"CSV Anonymisation: {'✅ PASSED' if success2 else '❌ FAILED'}")
128
-
129
  overall_success = success1 and success2
130
  print(f"\nOverall: {'✅ PASSED' if overall_success else '❌ FAILED'}")
131
-
132
  sys.exit(0 if overall_success else 1)
 
7
  """
8
 
9
  import os
10
+ import shutil
11
  import sys
12
  import tempfile
 
13
 
14
  # Add the parent directory to the path
15
  sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
16
 
17
  from test.test import run_cli_redact
18
 
19
+
20
  def demo_pdf_redaction():
21
  """Demonstrate how to run a single PDF redaction test."""
22
  print("=== Demo: PDF Redaction with Default Settings ===")
23
+
24
  # Set up paths
25
+ script_path = os.path.join(
26
+ os.path.dirname(os.path.dirname(__file__)), "cli_redact.py"
27
+ )
28
+ input_file = os.path.join(
29
+ os.path.dirname(os.path.dirname(__file__)),
30
+ "example_data",
31
+ "example_of_emails_sent_to_a_professor_before_applying.pdf",
32
+ )
33
  output_dir = tempfile.mkdtemp(prefix="demo_output_")
34
+
35
  print(f"Script: {script_path}")
36
  print(f"Input: {input_file}")
37
  print(f"Output: {output_dir}")
38
+
39
  # Check if files exist
40
  if not os.path.isfile(script_path):
41
  print(f"❌ Script not found: {script_path}")
42
  return False
43
+
44
  if not os.path.isfile(input_file):
45
  print(f"❌ Input file not found: {input_file}")
46
+ print(
47
+ "Make sure you have the example data files in the example_data/ directory"
48
+ )
49
  return False
50
+
51
  try:
52
  # Run the test
53
  print("\nRunning PDF redaction with default settings...")
54
  result = run_cli_redact(
55
+ script_path=script_path, input_file=input_file, output_dir=output_dir
 
 
56
  )
57
+
58
  if result:
59
  print("✅ Test completed successfully!")
60
  print(f"Check the output directory for results: {output_dir}")
61
  else:
62
  print("❌ Test failed!")
63
+
64
  return result
65
+
66
  finally:
67
  # Clean up
68
  if os.path.exists(output_dir):
69
  shutil.rmtree(output_dir)
70
  print(f"Cleaned up: {output_dir}")
71
 
72
+
73
  def demo_csv_anonymisation():
74
  """Demonstrate how to run a CSV anonymisation test."""
75
  print("\n=== Demo: CSV Anonymisation ===")
76
+
77
  # Set up paths
78
+ script_path = os.path.join(
79
+ os.path.dirname(os.path.dirname(__file__)), "cli_redact.py"
80
+ )
81
+ input_file = os.path.join(
82
+ os.path.dirname(os.path.dirname(__file__)),
83
+ "example_data",
84
+ "combined_case_notes.csv",
85
+ )
86
  output_dir = tempfile.mkdtemp(prefix="demo_output_")
87
+
88
  print(f"Script: {script_path}")
89
  print(f"Input: {input_file}")
90
  print(f"Output: {output_dir}")
91
+
92
  # Check if files exist
93
  if not os.path.isfile(script_path):
94
  print(f"❌ Script not found: {script_path}")
95
  return False
96
+
97
  if not os.path.isfile(input_file):
98
  print(f"❌ Input file not found: {input_file}")
99
+ print(
100
+ "Make sure you have the example data files in the example_data/ directory"
101
+ )
102
  return False
103
+
104
  try:
105
  # Run the test
106
  print("\nRunning CSV anonymisation...")
 
109
  input_file=input_file,
110
  output_dir=output_dir,
111
  text_columns=["Case Note", "Client"],
112
+ anon_strategy="replace_redacted",
113
  )
114
+
115
  if result:
116
  print("✅ Test completed successfully!")
117
  print(f"Check the output directory for results: {output_dir}")
118
  else:
119
  print("❌ Test failed!")
120
+
121
  return result
122
+
123
  finally:
124
  # Clean up
125
  if os.path.exists(output_dir):
126
  shutil.rmtree(output_dir)
127
  print(f"Cleaned up: {output_dir}")
128
 
129
+
130
  if __name__ == "__main__":
131
  print("CLI Redaction Test Demo")
132
  print("=" * 50)
133
  print("This script demonstrates how to run individual tests.")
134
  print("=" * 50)
135
+
136
  # Run the demos
137
  success1 = demo_pdf_redaction()
138
  success2 = demo_csv_anonymisation()
139
+
140
  print("\n" + "=" * 50)
141
  print("Demo Summary")
142
  print("=" * 50)
143
  print(f"PDF Redaction: {'✅ PASSED' if success1 else '❌ FAILED'}")
144
  print(f"CSV Anonymisation: {'✅ PASSED' if success2 else '❌ FAILED'}")
145
+
146
  overall_success = success1 and success2
147
  print(f"\nOverall: {'✅ PASSED' if overall_success else '❌ FAILED'}")
148
+
149
  sys.exit(0 if overall_success else 1)
test/run_gui_test.bat ADDED
@@ -0,0 +1,26 @@
 
 
1
+ @echo off
2
+ REM Batch script to run GUI tests with conda environment activated
3
+ REM This script activates the 'redaction' conda environment and runs the GUI tests
4
+
5
+ echo Activating conda environment 'redaction'...
6
+ call conda activate redaction
7
+
8
+ if %errorlevel% neq 0 (
9
+ echo Failed to activate conda environment 'redaction'
10
+ echo Please ensure conda is installed and the 'redaction' environment exists
11
+ pause
12
+ exit /b 1
13
+ )
14
+
15
+ echo Running GUI tests...
16
+ python test_gui_only.py
17
+
18
+ if %errorlevel% neq 0 (
19
+ echo GUI tests failed
20
+ pause
21
+ exit /b 1
22
+ ) else (
23
+ echo GUI tests passed successfully
24
+ )
25
+
26
+ pause
test/run_gui_test.ps1 ADDED
@@ -0,0 +1,34 @@
 
 
 
 
1
+ # PowerShell script to run GUI tests with conda environment activated
2
+ # This script activates the 'redaction' conda environment and runs the GUI tests
3
+
4
+ Write-Host "Activating conda environment 'redaction'..." -ForegroundColor Green
5
+
6
+ try {
7
+ # Try to activate the conda environment
8
+ conda activate redaction
9
+
10
+ if ($LASTEXITCODE -ne 0) {
11
+ Write-Host "Failed to activate conda environment 'redaction'" -ForegroundColor Red
12
+ Write-Host "Please ensure conda is installed and the 'redaction' environment exists" -ForegroundColor Red
13
+ Read-Host "Press Enter to exit"
14
+ exit 1
15
+ }
16
+
17
+ Write-Host "Running GUI tests..." -ForegroundColor Green
18
+ python test_gui_only.py
19
+
20
+ if ($LASTEXITCODE -ne 0) {
21
+ Write-Host "GUI tests failed" -ForegroundColor Red
22
+ Read-Host "Press Enter to exit"
23
+ exit 1
24
+ } else {
25
+ Write-Host "GUI tests passed successfully" -ForegroundColor Green
26
+ }
27
+
28
+ } catch {
29
+ Write-Host "An error occurred: $_" -ForegroundColor Red
30
+ Read-Host "Press Enter to exit"
31
+ exit 1
32
+ }
33
+
34
+ Read-Host "Press Enter to exit"
test/run_tests.py CHANGED
@@ -6,8 +6,8 @@ This script demonstrates how to run the comprehensive test suite
6
  that covers all the examples from the CLI epilog.
7
  """
8
 
9
- import sys
10
  import os
 
11
 
12
  # Add the parent directory to the path so we can import the test module
13
  sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -18,9 +18,9 @@ if __name__ == "__main__":
18
  print("Starting CLI Redaction Test Suite...")
19
  print("This will test all examples from the CLI epilog.")
20
  print("=" * 60)
21
-
22
  success = run_all_tests()
23
-
24
  if success:
25
  print("\n🎉 All tests passed successfully!")
26
  sys.exit(0)
 
6
  that covers all the examples from the CLI epilog.
7
  """
8
 
 
9
  import os
10
+ import sys
11
 
12
  # Add the parent directory to the path so we can import the test module
13
  sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
18
  print("Starting CLI Redaction Test Suite...")
19
  print("This will test all examples from the CLI epilog.")
20
  print("=" * 60)
21
+
22
  success = run_all_tests()
23
+
24
  if success:
25
  print("\n🎉 All tests passed successfully!")
26
  sys.exit(0)
test/test.py CHANGED
@@ -1,17 +1,20 @@
1
- from typing import List, Optional
2
  import os
 
3
  import subprocess
4
- import unittest
5
  import tempfile
6
- import shutil
 
 
 
 
 
7
 
8
  def run_cli_redact(
9
  script_path: str,
10
  input_file: str,
11
  output_dir: str,
12
- task: str = 'redact',
13
- timeout: int = 600, # 10-minute timeout
14
-
15
  # --- General Arguments ---
16
  input_dir: Optional[str] = None,
17
  language: Optional[str] = None,
@@ -32,7 +35,6 @@ def run_cli_redact(
32
  display_file_names_in_logs: Optional[bool] = None,
33
  upload_logs_to_s3: Optional[bool] = None,
34
  s3_logs_prefix: Optional[str] = None,
35
-
36
  # --- PDF/Image Redaction Arguments ---
37
  ocr_method: Optional[str] = None,
38
  page_min: Optional[int] = None,
@@ -49,14 +51,12 @@ def run_cli_redact(
49
  extract_forms: Optional[bool] = None,
50
  extract_tables: Optional[bool] = None,
51
  extract_layout: Optional[bool] = None,
52
-
53
  # --- Word/Tabular Anonymisation Arguments ---
54
  anon_strategy: Optional[str] = None,
55
  text_columns: Optional[List[str]] = None,
56
  excel_sheets: Optional[List[str]] = None,
57
  fuzzy_mistakes: Optional[int] = None,
58
  match_fuzzy_whole_phrase_bool: Optional[bool] = None,
59
-
60
  # --- Duplicate Detection Arguments ---
61
  duplicate_type: Optional[str] = None,
62
  similarity_threshold: Optional[float] = None,
@@ -65,7 +65,6 @@ def run_cli_redact(
65
  greedy_match: Optional[bool] = None,
66
  combine_pages: Optional[bool] = None,
67
  remove_duplicate_rows: Optional[bool] = None,
68
-
69
  # --- Textract Batch Operations Arguments ---
70
  textract_action: Optional[str] = None,
71
  job_id: Optional[str] = None,
@@ -76,7 +75,7 @@ def run_cli_redact(
76
  s3_textract_document_logs_subfolder: Optional[str] = None,
77
  local_textract_document_logs_subfolder: Optional[str] = None,
78
  poll_interval: Optional[int] = None,
79
- max_poll_attempts: Optional[int] = None
80
  ) -> bool:
81
  """
82
  Executes the cli_redact.py script with specified arguments using a subprocess.
@@ -87,7 +86,7 @@ def run_cli_redact(
87
  output_dir (str): The path to the directory for output files.
88
  task (str): The main task to perform ('redact', 'deduplicate', or 'textract').
89
  timeout (int): Timeout in seconds for the subprocess.
90
-
91
  # General Arguments
92
  input_dir (str): Directory for all input files.
93
  language (str): Language of the document content.
@@ -108,7 +107,7 @@ def run_cli_redact(
108
  display_file_names_in_logs (bool): Include file names in log outputs.
109
  upload_logs_to_s3 (bool): Upload log files to S3 after processing.
110
  s3_logs_prefix (str): S3 prefix for usage log files.
111
-
112
  # PDF/Image Redaction Arguments
113
  ocr_method (str): OCR method for text extraction from images.
114
  page_min (int): First page to redact.
@@ -125,14 +124,14 @@ def run_cli_redact(
125
  extract_forms (bool): Extract forms during Textract analysis.
126
  extract_tables (bool): Extract tables during Textract analysis.
127
  extract_layout (bool): Extract layout during Textract analysis.
128
-
129
  # Word/Tabular Anonymisation Arguments
130
  anon_strategy (str): The anonymisation strategy to apply.
131
  text_columns (List[str]): A list of column names to anonymise or deduplicate.
132
  excel_sheets (List[str]): Specific Excel sheet names to process.
133
  fuzzy_mistakes (int): Number of allowed spelling mistakes for fuzzy matching.
134
  match_fuzzy_whole_phrase_bool (bool): Match fuzzy whole phrase boolean.
135
-
136
  # Duplicate Detection Arguments
137
  duplicate_type (str): Type of duplicate detection (pages or tabular).
138
  similarity_threshold (float): Similarity threshold (0-1) to consider content as duplicates.
@@ -141,7 +140,7 @@ def run_cli_redact(
141
  greedy_match (bool): Use greedy matching strategy for consecutive pages.
142
  combine_pages (bool): Combine text from the same page number within a file.
143
  remove_duplicate_rows (bool): Remove duplicate rows from the output.
144
-
145
  # Textract Batch Operations Arguments
146
  textract_action (str): Textract action to perform (submit, retrieve, or list).
147
  job_id (str): Textract job ID for retrieve action.
@@ -160,9 +159,9 @@ def run_cli_redact(
160
  # 1. Get absolute paths and perform pre-checks
161
  script_abs_path = os.path.abspath(script_path)
162
  output_abs_dir = os.path.abspath(output_dir)
163
-
164
  # Handle input file based on task and action
165
- if task == 'textract' and textract_action in ['retrieve', 'list']:
166
  # For retrieve and list actions, input file is not required
167
  input_abs_path = None
168
  else:
@@ -172,25 +171,27 @@ def run_cli_redact(
172
  input_abs_path = os.path.abspath(input_file)
173
  if not os.path.isfile(input_abs_path):
174
  raise FileNotFoundError(f"Input file not found: {input_abs_path}")
175
-
176
  if not os.path.isfile(script_abs_path):
177
  raise FileNotFoundError(f"Script not found: {script_abs_path}")
178
-
179
  if not os.path.isdir(output_abs_dir):
180
  # Create the output directory if it doesn't exist
181
  print(f"Output directory not found. Creating: {output_abs_dir}")
182
  os.makedirs(output_abs_dir)
183
-
184
  script_folder = os.path.dirname(script_abs_path)
185
 
186
  # 2. Dynamically build the command list
187
  command = [
188
  "python",
189
  script_abs_path,
190
- "--output_dir", output_abs_dir,
191
- "--task", task,
 
 
192
  ]
193
-
194
  # Add input_file only if it's not None
195
  if input_abs_path is not None:
196
  command.extend(["--input_file", input_abs_path])
@@ -231,7 +232,9 @@ def run_cli_redact(
231
  if save_logs_to_dynamodb is not None:
232
  command.extend(["--save_logs_to_dynamodb", str(save_logs_to_dynamodb)])
233
  if display_file_names_in_logs is not None:
234
- command.extend(["--display_file_names_in_logs", str(display_file_names_in_logs)])
 
 
235
  if upload_logs_to_s3 is not None:
236
  command.extend(["--upload_logs_to_s3", str(upload_logs_to_s3)])
237
  if s3_logs_prefix:
@@ -249,17 +252,23 @@ def run_cli_redact(
249
  if chosen_local_ocr_model:
250
  command.extend(["--chosen_local_ocr_model", chosen_local_ocr_model])
251
  if preprocess_local_ocr_images is not None:
252
- command.extend(["--preprocess_local_ocr_images", str(preprocess_local_ocr_images)])
 
 
253
  if compress_redacted_pdf is not None:
254
  command.extend(["--compress_redacted_pdf", str(compress_redacted_pdf)])
255
  if return_pdf_end_of_redaction is not None:
256
- command.extend(["--return_pdf_end_of_redaction", str(return_pdf_end_of_redaction)])
 
 
257
  if deny_list_file and os.path.isfile(deny_list_file):
258
  command.extend(["--deny_list_file", os.path.abspath(deny_list_file)])
259
  if allow_list_file and os.path.isfile(allow_list_file):
260
  command.extend(["--allow_list_file", os.path.abspath(allow_list_file)])
261
  if redact_whole_page_file and os.path.isfile(redact_whole_page_file):
262
- command.extend(["--redact_whole_page_file", os.path.abspath(redact_whole_page_file)])
 
 
263
  if handwrite_signature_extraction:
264
  command.append("--handwrite_signature_extraction")
265
  command.extend(handwrite_signature_extraction)
@@ -282,7 +291,9 @@ def run_cli_redact(
282
  if fuzzy_mistakes is not None:
283
  command.extend(["--fuzzy_mistakes", str(fuzzy_mistakes)])
284
  if match_fuzzy_whole_phrase_bool is not None:
285
- command.extend(["--match_fuzzy_whole_phrase_bool", str(match_fuzzy_whole_phrase_bool)])
 
 
286
 
287
  # Add duplicate detection arguments
288
  if duplicate_type:
@@ -315,16 +326,26 @@ def run_cli_redact(
315
  if textract_output_prefix:
316
  command.extend(["--textract_output_prefix", textract_output_prefix])
317
  if s3_textract_document_logs_subfolder:
318
- command.extend(["--s3_textract_document_logs_subfolder", s3_textract_document_logs_subfolder])
 
 
 
 
 
319
  if local_textract_document_logs_subfolder:
320
- command.extend(["--local_textract_document_logs_subfolder", local_textract_document_logs_subfolder])
 
 
 
 
 
321
  if poll_interval is not None:
322
  command.extend(["--poll_interval", str(poll_interval)])
323
  if max_poll_attempts is not None:
324
  command.extend(["--max_poll_attempts", str(max_poll_attempts)])
325
 
326
  # Filter out None values before joining
327
- command_str = ' '.join(str(arg) for arg in command if arg is not None)
328
  print(f"Executing command: {command_str}")
329
 
330
  # 3. Execute the command using subprocess
@@ -334,12 +355,12 @@ def run_cli_redact(
334
  stdout=subprocess.PIPE,
335
  stderr=subprocess.PIPE,
336
  text=True,
337
- cwd=script_folder # Important for relative paths within the script
338
  )
339
 
340
  # Communicate with the process to get output and handle timeout
341
  stdout, stderr = result.communicate(timeout=timeout)
342
-
343
  print("--- SCRIPT STDOUT ---")
344
  if stdout:
345
  print(stdout)
@@ -366,124 +387,150 @@ def run_cli_redact(
366
 
367
  class TestCLIRedactExamples(unittest.TestCase):
368
  """Test suite for CLI redaction examples from the epilog."""
369
-
370
  @classmethod
371
  def setUpClass(cls):
372
  """Set up test environment before running tests."""
373
- cls.script_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "cli_redact.py")
374
- cls.example_data_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "example_data")
 
 
 
 
375
  cls.temp_output_dir = tempfile.mkdtemp(prefix="test_output_")
376
-
377
  # Verify script exists
378
  if not os.path.isfile(cls.script_path):
379
  raise FileNotFoundError(f"CLI script not found: {cls.script_path}")
380
-
381
  print(f"Test setup complete. Script: {cls.script_path}")
382
  print(f"Example data directory: {cls.example_data_dir}")
383
  print(f"Temp output directory: {cls.temp_output_dir}")
384
-
385
  @classmethod
386
  def tearDownClass(cls):
387
  """Clean up test environment after running tests."""
388
  if os.path.exists(cls.temp_output_dir):
389
  shutil.rmtree(cls.temp_output_dir)
390
  print(f"Cleaned up temp directory: {cls.temp_output_dir}")
391
-
392
  def test_pdf_redaction_default_settings(self):
393
  """Test: Redact a PDF with default settings (local OCR)"""
394
  print("\n=== Testing PDF redaction with default settings ===")
395
- input_file = os.path.join(self.example_data_dir, "example_of_emails_sent_to_a_professor_before_applying.pdf")
396
-
 
 
 
397
  if not os.path.isfile(input_file):
398
  self.skipTest(f"Example file not found: {input_file}")
399
-
400
  result = run_cli_redact(
401
  script_path=self.script_path,
402
  input_file=input_file,
403
- output_dir=self.temp_output_dir
404
  )
405
-
406
  self.assertTrue(result, "PDF redaction with default settings should succeed")
407
  print("✅ PDF redaction with default settings passed")
408
-
409
  def test_pdf_text_extraction_only(self):
410
  """Test: Extract text from a PDF only (i.e. no redaction), using local OCR"""
411
  print("\n=== Testing PDF text extraction only ===")
412
- input_file = os.path.join(self.example_data_dir, "Partnership-Agreement-Toolkit_0_0.pdf")
413
- whole_page_file = os.path.join(self.example_data_dir, "partnership_toolkit_redact_some_pages.csv")
414
-
 
 
 
 
415
  if not os.path.isfile(input_file):
416
  self.skipTest(f"Example file not found: {input_file}")
417
  if not os.path.isfile(whole_page_file):
418
  self.skipTest(f"Whole page file not found: {whole_page_file}")
419
-
420
  result = run_cli_redact(
421
  script_path=self.script_path,
422
  input_file=input_file,
423
  output_dir=self.temp_output_dir,
424
  redact_whole_page_file=whole_page_file,
425
- pii_detector="None"
426
  )
427
-
428
  self.assertTrue(result, "PDF text extraction should succeed")
429
  print("✅ PDF text extraction only passed")
430
-
431
  def test_pdf_text_extraction_with_whole_page_redaction(self):
432
  """Test: Extract text from a PDF only with a whole page redaction list"""
433
  print("\n=== Testing PDF text extraction with whole page redaction ===")
434
- input_file = os.path.join(self.example_data_dir, "Partnership-Agreement-Toolkit_0_0.pdf")
435
- whole_page_file = os.path.join(self.example_data_dir, "partnership_toolkit_redact_some_pages.csv")
436
-
 
 
 
 
437
  if not os.path.isfile(input_file):
438
  self.skipTest(f"Example file not found: {input_file}")
439
  if not os.path.isfile(whole_page_file):
440
  self.skipTest(f"Whole page file not found: {whole_page_file}")
441
-
442
  result = run_cli_redact(
443
  script_path=self.script_path,
444
  input_file=input_file,
445
  output_dir=self.temp_output_dir,
446
  redact_whole_page_file=whole_page_file,
447
  pii_detector="Local",
448
- local_redact_entities=["CUSTOM"]
 
 
 
 
449
  )
450
-
451
- self.assertTrue(result, "PDF text extraction with whole page redaction should succeed")
452
  print("✅ PDF text extraction with whole page redaction passed")
453
-
454
  def test_pdf_redaction_with_allow_list(self):
455
  """Test: Redact a PDF with allow list (local OCR) and custom list of redaction entities"""
456
  print("\n=== Testing PDF redaction with allow list ===")
457
- input_file = os.path.join(self.example_data_dir, "graduate-job-example-cover-letter.pdf")
458
- allow_list_file = os.path.join(self.example_data_dir, "test_allow_list_graduate.csv")
459
-
 
 
 
 
460
  if not os.path.isfile(input_file):
461
  self.skipTest(f"Example file not found: {input_file}")
462
  if not os.path.isfile(allow_list_file):
463
  self.skipTest(f"Allow list file not found: {allow_list_file}")
464
-
465
  result = run_cli_redact(
466
  script_path=self.script_path,
467
  input_file=input_file,
468
  output_dir=self.temp_output_dir,
469
  allow_list_file=allow_list_file,
470
- local_redact_entities=["TITLES", "PERSON", "DATE_TIME"]
471
  )
472
-
473
  self.assertTrue(result, "PDF redaction with allow list should succeed")
474
  print("✅ PDF redaction with allow list passed")
475
-
476
  def test_pdf_redaction_limited_pages_with_custom_fuzzy(self):
477
  """Test: Redact a PDF with limited pages and text extraction method with custom fuzzy matching"""
478
  print("\n=== Testing PDF redaction with limited pages and fuzzy matching ===")
479
- input_file = os.path.join(self.example_data_dir, "Partnership-Agreement-Toolkit_0_0.pdf")
480
- deny_list_file = os.path.join(self.example_data_dir, "Partnership-Agreement-Toolkit_test_deny_list_para_single_spell.csv")
481
-
 
 
 
 
 
482
  if not os.path.isfile(input_file):
483
  self.skipTest(f"Example file not found: {input_file}")
484
  if not os.path.isfile(deny_list_file):
485
  self.skipTest(f"Deny list file not found: {deny_list_file}")
486
-
487
  result = run_cli_redact(
488
  script_path=self.script_path,
489
  input_file=input_file,
@@ -493,20 +540,30 @@ class TestCLIRedactExamples(unittest.TestCase):
493
  page_min=1,
494
  page_max=3,
495
  ocr_method="Local text",
496
- fuzzy_mistakes=3
 
 
 
 
497
  )
498
-
499
- self.assertTrue(result, "PDF redaction with limited pages and fuzzy matching should succeed")
500
  print("✅ PDF redaction with limited pages and fuzzy matching passed")
501
-
502
  def test_pdf_redaction_with_custom_lists(self):
503
  """Test: Redaction with custom deny list, allow list, and whole page redaction list"""
504
  print("\n=== Testing PDF redaction with custom lists ===")
505
- input_file = os.path.join(self.example_data_dir, "Partnership-Agreement-Toolkit_0_0.pdf")
506
- deny_list_file = os.path.join(self.example_data_dir, "partnership_toolkit_redact_custom_deny_list.csv")
507
- whole_page_file = os.path.join(self.example_data_dir, "partnership_toolkit_redact_some_pages.csv")
508
- allow_list_file = os.path.join(self.example_data_dir, "test_allow_list_partnership.csv")
509
-
 
 
 
 
 
 
 
 
510
  if not os.path.isfile(input_file):
511
  self.skipTest(f"Example file not found: {input_file}")
512
  if not os.path.isfile(deny_list_file):
@@ -515,164 +572,186 @@ class TestCLIRedactExamples(unittest.TestCase):
515
  self.skipTest(f"Whole page file not found: {whole_page_file}")
516
  if not os.path.isfile(allow_list_file):
517
  self.skipTest(f"Allow list file not found: {allow_list_file}")
518
-
519
  result = run_cli_redact(
520
  script_path=self.script_path,
521
  input_file=input_file,
522
  output_dir=self.temp_output_dir,
523
  deny_list_file=deny_list_file,
524
  redact_whole_page_file=whole_page_file,
525
- allow_list_file=allow_list_file
526
  )
527
-
528
  self.assertTrue(result, "PDF redaction with custom lists should succeed")
529
  print("✅ PDF redaction with custom lists passed")
530
-
531
  def test_image_redaction(self):
532
  """Test: Redact an image"""
533
  print("\n=== Testing image redaction ===")
534
  input_file = os.path.join(self.example_data_dir, "example_complaint_letter.jpg")
535
-
536
  if not os.path.isfile(input_file):
537
  self.skipTest(f"Example file not found: {input_file}")
538
-
539
  result = run_cli_redact(
540
  script_path=self.script_path,
541
  input_file=input_file,
542
- output_dir=self.temp_output_dir
543
  )
544
-
545
  self.assertTrue(result, "Image redaction should succeed")
546
  print("✅ Image redaction passed")
547
-
548
  def test_csv_anonymisation_specific_columns(self):
549
  """Test: Anonymise csv file with specific columns"""
550
  print("\n=== Testing CSV anonymisation with specific columns ===")
551
  input_file = os.path.join(self.example_data_dir, "combined_case_notes.csv")
552
-
553
  if not os.path.isfile(input_file):
554
  self.skipTest(f"Example file not found: {input_file}")
555
-
556
  result = run_cli_redact(
557
  script_path=self.script_path,
558
  input_file=input_file,
559
  output_dir=self.temp_output_dir,
560
  text_columns=["Case Note", "Client"],
561
- anon_strategy="replace_redacted"
 
 
 
 
562
  )
563
-
564
- self.assertTrue(result, "CSV anonymisation with specific columns should succeed")
565
  print("✅ CSV anonymisation with specific columns passed")
566
-
567
  def test_csv_anonymisation_different_strategy(self):
568
  """Test: Anonymise csv file with a different strategy (remove text completely)"""
569
  print("\n=== Testing CSV anonymisation with different strategy ===")
570
  input_file = os.path.join(self.example_data_dir, "combined_case_notes.csv")
571
-
572
  if not os.path.isfile(input_file):
573
  self.skipTest(f"Example file not found: {input_file}")
574
-
575
  result = run_cli_redact(
576
  script_path=self.script_path,
577
  input_file=input_file,
578
  output_dir=self.temp_output_dir,
579
  text_columns=["Case Note", "Client"],
580
- anon_strategy="redact"
 
 
 
 
581
  )
582
-
583
- self.assertTrue(result, "CSV anonymisation with different strategy should succeed")
584
  print("✅ CSV anonymisation with different strategy passed")
585
-
586
  def test_word_document_anonymisation(self):
587
  """Test: Anonymise a word document"""
588
  print("\n=== Testing Word document anonymisation ===")
589
- input_file = os.path.join(self.example_data_dir, "Bold minimalist professional cover letter.docx")
590
-
 
 
591
  if not os.path.isfile(input_file):
592
  self.skipTest(f"Example file not found: {input_file}")
593
-
594
  result = run_cli_redact(
595
  script_path=self.script_path,
596
  input_file=input_file,
597
  output_dir=self.temp_output_dir,
598
- anon_strategy="replace_redacted"
599
  )
600
-
601
  self.assertTrue(result, "Word document anonymisation should succeed")
602
  print("✅ Word document anonymisation passed")
603
-
604
  def test_aws_textract_comprehend_redaction(self):
605
  """Test: Use Textract and Comprehend for redaction"""
606
  print("\n=== Testing AWS Textract and Comprehend redaction ===")
607
- input_file = os.path.join(self.example_data_dir, "example_of_emails_sent_to_a_professor_before_applying.pdf")
608
-
 
 
 
609
  if not os.path.isfile(input_file):
610
  self.skipTest(f"Example file not found: {input_file}")
611
-
612
  # Skip this test if AWS credentials are not available
613
  # This is a conditional test that may not work in all environments
614
- result = run_cli_redact(
615
  script_path=self.script_path,
616
  input_file=input_file,
617
  output_dir=self.temp_output_dir,
618
  ocr_method="AWS Textract",
619
- pii_detector="AWS Comprehend"
620
  )
621
-
622
  # Note: This test may fail if AWS credentials are not configured
623
  # We'll mark it as passed if it runs without crashing
624
  print("✅ AWS Textract and Comprehend redaction test completed")
625
-
626
  def test_aws_textract_signature_extraction(self):
627
  """Test: Redact specific pages with AWS OCR and signature extraction"""
628
  print("\n=== Testing AWS Textract with signature extraction ===")
629
- input_file = os.path.join(self.example_data_dir, "Partnership-Agreement-Toolkit_0_0.pdf")
630
-
 
 
631
  if not os.path.isfile(input_file):
632
  self.skipTest(f"Example file not found: {input_file}")
633
-
634
  # Skip this test if AWS credentials are not available
635
- result = run_cli_redact(
636
  script_path=self.script_path,
637
  input_file=input_file,
638
  output_dir=self.temp_output_dir,
639
  page_min=6,
640
  page_max=7,
641
  ocr_method="AWS Textract",
642
- handwrite_signature_extraction=["Extract handwriting", "Extract signatures"]
 
 
 
643
  )
644
-
645
  # Note: This test may fail if AWS credentials are not configured
646
  print("✅ AWS Textract with signature extraction test completed")
647
-
648
  def test_duplicate_pages_detection(self):
649
  """Test: Find duplicate pages in OCR files"""
650
  print("\n=== Testing duplicate pages detection ===")
651
- input_file = os.path.join(self.example_data_dir, "example_outputs", "doubled_output_joined.pdf_ocr_output.csv")
652
-
 
 
 
 
653
  if not os.path.isfile(input_file):
654
  self.skipTest(f"Example OCR file not found: {input_file}")
655
-
656
  result = run_cli_redact(
657
  script_path=self.script_path,
658
  input_file=input_file,
659
  output_dir=self.temp_output_dir,
660
  task="deduplicate",
661
  duplicate_type="pages",
662
- similarity_threshold=0.95
663
  )
664
-
665
  self.assertTrue(result, "Duplicate pages detection should succeed")
666
  print("✅ Duplicate pages detection passed")
667
-
668
  def test_duplicate_line_level_detection(self):
669
  """Test: Find duplicate in OCR files at the line level"""
670
  print("\n=== Testing duplicate line level detection ===")
671
- input_file = os.path.join(self.example_data_dir, "example_outputs", "doubled_output_joined.pdf_ocr_output.csv")
672
-
 
 
 
 
673
  if not os.path.isfile(input_file):
674
  self.skipTest(f"Example OCR file not found: {input_file}")
675
-
676
  result = run_cli_redact(
677
  script_path=self.script_path,
678
  input_file=input_file,
@@ -681,20 +760,22 @@ class TestCLIRedactExamples(unittest.TestCase):
681
  duplicate_type="pages",
682
  similarity_threshold=0.95,
683
  combine_pages=False,
684
- min_word_count=3
685
  )
686
-
687
  self.assertTrue(result, "Duplicate line level detection should succeed")
688
  print("✅ Duplicate line level detection passed")
689
-
690
  def test_duplicate_tabular_detection(self):
691
  """Test: Find duplicate rows in tabular data"""
692
  print("\n=== Testing duplicate tabular detection ===")
693
- input_file = os.path.join(self.example_data_dir, "Lambeth_2030-Our_Future_Our_Lambeth.pdf.csv")
694
-
 
 
695
  if not os.path.isfile(input_file):
696
  self.skipTest(f"Example CSV file not found: {input_file}")
697
-
698
  result = run_cli_redact(
699
  script_path=self.script_path,
700
  input_file=input_file,
@@ -702,124 +783,284 @@ class TestCLIRedactExamples(unittest.TestCase):
702
  task="deduplicate",
703
  duplicate_type="tabular",
704
  text_columns=["text"],
705
- similarity_threshold=0.95
706
  )
707
-
708
  self.assertTrue(result, "Duplicate tabular detection should succeed")
709
  print("✅ Duplicate tabular detection passed")
710
-
711
  def test_textract_submit_document(self):
712
  """Test: Submit document to Textract for basic text analysis"""
713
  print("\n=== Testing Textract document submission ===")
714
- input_file = os.path.join(self.example_data_dir, "example_of_emails_sent_to_a_professor_before_applying.pdf")
715
-
 
 
 
716
  if not os.path.isfile(input_file):
717
  self.skipTest(f"Example file not found: {input_file}")
718
-
719
  # Skip this test if AWS credentials are not available
720
  try:
721
- result = run_cli_redact(
722
  script_path=self.script_path,
723
  input_file=input_file,
724
  output_dir=self.temp_output_dir,
725
  task="textract",
726
- textract_action="submit"
727
  )
728
  except Exception as e:
729
  print(f"Textract test failed (expected without AWS credentials): {e}")
730
- result = True # Mark as passed since this is expected to fail without credentials
731
-
732
  # Note: This test may fail if AWS credentials are not configured
733
  print("✅ Textract document submission test completed")
734
-
735
  def test_textract_submit_with_signatures(self):
736
  """Test: Submit document to Textract for analysis with signature extraction"""
737
  print("\n=== Testing Textract submission with signature extraction ===")
738
- input_file = os.path.join(self.example_data_dir, "Partnership-Agreement-Toolkit_0_0.pdf")
739
-
 
 
740
  if not os.path.isfile(input_file):
741
  self.skipTest(f"Example file not found: {input_file}")
742
-
743
  # Skip this test if AWS credentials are not available
744
  try:
745
- result = run_cli_redact(
746
  script_path=self.script_path,
747
  input_file=input_file,
748
  output_dir=self.temp_output_dir,
749
  task="textract",
750
  textract_action="submit",
751
- extract_signatures=True
752
  )
753
  except Exception as e:
754
  print(f"Textract test failed (expected without AWS credentials): {e}")
755
- result = True # Mark as passed since this is expected to fail without credentials
756
-
757
  # Note: This test may fail if AWS credentials are not configured
758
  print("✅ Textract submission with signature extraction test completed")
759
-
760
  def test_textract_retrieve_results(self):
761
  """Test: Retrieve Textract results by job ID"""
762
  print("\n=== Testing Textract results retrieval ===")
763
-
764
  # Skip this test if AWS credentials are not available
765
  # This would require a valid job ID from a previous submission
766
  # For retrieve and list actions, we don't need a real input file
767
  try:
768
- result = run_cli_redact(
769
  script_path=self.script_path,
770
  input_file=None, # No input file needed for retrieve action
771
  output_dir=self.temp_output_dir,
772
  task="textract",
773
  textract_action="retrieve",
774
- job_id="12345678-1234-1234-1234-123456789012" # Dummy job ID
775
  )
776
  except Exception as e:
777
  print(f"Textract test failed (expected without AWS credentials): {e}")
778
- result = True # Mark as passed since this is expected to fail without credentials
779
-
780
  # Note: This test will likely fail with a dummy job ID, but that's expected
781
  print("✅ Textract results retrieval test completed")
782
-
783
  def test_textract_list_jobs(self):
784
  """Test: List recent Textract jobs"""
785
  print("\n=== Testing Textract jobs listing ===")
786
-
787
  # Skip this test if AWS credentials are not available
788
  # For list action, we don't need a real input file
789
  try:
790
- result = run_cli_redact(
791
  script_path=self.script_path,
792
  input_file=None, # No input file needed for list action
793
  output_dir=self.temp_output_dir,
794
  task="textract",
795
- textract_action="list"
796
  )
797
  except Exception as e:
798
  print(f"Textract test failed (expected without AWS credentials): {e}")
799
- result = True # Mark as passed since this is expected to fail without credentials
800
-
801
  # Note: This test may fail if AWS credentials are not configured
802
  print("✅ Textract jobs listing test completed")
803
 
804
 
 
 
 
 
 
 
 
 
805
  def run_all_tests():
806
  """Run all test examples and report results."""
807
  print("=" * 80)
808
- print("DOCUMENT REDACTION CLI TEST SUITE")
809
  print("=" * 80)
810
- print("This test suite runs through all the examples from the CLI epilog.")
 
 
811
  print("Tests will be skipped if required example files are not found.")
812
  print("AWS-related tests may fail if credentials are not configured.")
813
  print("=" * 80)
814
-
815
  # Create test suite
816
  loader = unittest.TestLoader()
817
- suite = loader.loadTestsFromTestCase(TestCLIRedactExamples)
818
 
 
 
 
 
 
 
 
 
819
  # Run tests with detailed output
820
  runner = unittest.TextTestRunner(verbosity=2, stream=None)
821
  result = runner.run(suite)
822
-
823
  # Print summary
824
  print("\n" + "=" * 80)
825
  print("TEST SUMMARY")
@@ -828,25 +1069,25 @@ def run_all_tests():
828
  print(f"Failures: {len(result.failures)}")
829
  print(f"Errors: {len(result.errors)}")
830
  print(f"Skipped: {len(result.skipped) if hasattr(result, 'skipped') else 0}")
831
-
832
  if result.failures:
833
  print("\nFAILURES:")
834
  for test, traceback in result.failures:
835
  print(f"- {test}: {traceback}")
836
-
837
  if result.errors:
838
  print("\nERRORS:")
839
  for test, traceback in result.errors:
840
  print(f"- {test}: {traceback}")
841
-
842
  success = len(result.failures) == 0 and len(result.errors) == 0
843
  print(f"\nOverall result: {'✅ PASSED' if success else '❌ FAILED'}")
844
  print("=" * 80)
845
-
846
  return success
847
 
848
 
849
  if __name__ == "__main__":
850
  # Run the test suite
851
  success = run_all_tests()
852
- exit(0 if success else 1)
 
 
1
  import os
2
+ import shutil
3
  import subprocess
 
4
  import tempfile
5
+ import unittest
6
+ import sys
7
+ import threading
8
+ import time
9
+ from typing import List, Optional
10
+
11
 
12
  def run_cli_redact(
13
  script_path: str,
14
  input_file: str,
15
  output_dir: str,
16
+ task: str = "redact",
17
+ timeout: int = 600, # 10-minute timeout
 
18
  # --- General Arguments ---
19
  input_dir: Optional[str] = None,
20
  language: Optional[str] = None,
 
35
  display_file_names_in_logs: Optional[bool] = None,
36
  upload_logs_to_s3: Optional[bool] = None,
37
  s3_logs_prefix: Optional[str] = None,
 
38
  # --- PDF/Image Redaction Arguments ---
39
  ocr_method: Optional[str] = None,
40
  page_min: Optional[int] = None,
 
51
  extract_forms: Optional[bool] = None,
52
  extract_tables: Optional[bool] = None,
53
  extract_layout: Optional[bool] = None,
 
54
  # --- Word/Tabular Anonymisation Arguments ---
55
  anon_strategy: Optional[str] = None,
56
  text_columns: Optional[List[str]] = None,
57
  excel_sheets: Optional[List[str]] = None,
58
  fuzzy_mistakes: Optional[int] = None,
59
  match_fuzzy_whole_phrase_bool: Optional[bool] = None,
 
60
  # --- Duplicate Detection Arguments ---
61
  duplicate_type: Optional[str] = None,
62
  similarity_threshold: Optional[float] = None,
 
65
  greedy_match: Optional[bool] = None,
66
  combine_pages: Optional[bool] = None,
67
  remove_duplicate_rows: Optional[bool] = None,
 
68
  # --- Textract Batch Operations Arguments ---
69
  textract_action: Optional[str] = None,
70
  job_id: Optional[str] = None,
 
75
  s3_textract_document_logs_subfolder: Optional[str] = None,
76
  local_textract_document_logs_subfolder: Optional[str] = None,
77
  poll_interval: Optional[int] = None,
78
+ max_poll_attempts: Optional[int] = None,
79
  ) -> bool:
80
  """
81
  Executes the cli_redact.py script with specified arguments using a subprocess.
 
86
  output_dir (str): The path to the directory for output files.
87
  task (str): The main task to perform ('redact', 'deduplicate', or 'textract').
88
  timeout (int): Timeout in seconds for the subprocess.
89
+
90
  # General Arguments
91
  input_dir (str): Directory for all input files.
92
  language (str): Language of the document content.
 
107
  display_file_names_in_logs (bool): Include file names in log outputs.
108
  upload_logs_to_s3 (bool): Upload log files to S3 after processing.
109
  s3_logs_prefix (str): S3 prefix for usage log files.
110
+
111
  # PDF/Image Redaction Arguments
112
  ocr_method (str): OCR method for text extraction from images.
113
  page_min (int): First page to redact.
 
124
  extract_forms (bool): Extract forms during Textract analysis.
125
  extract_tables (bool): Extract tables during Textract analysis.
126
  extract_layout (bool): Extract layout during Textract analysis.
127
+
128
  # Word/Tabular Anonymisation Arguments
129
  anon_strategy (str): The anonymisation strategy to apply.
130
  text_columns (List[str]): A list of column names to anonymise or deduplicate.
131
  excel_sheets (List[str]): Specific Excel sheet names to process.
132
  fuzzy_mistakes (int): Number of allowed spelling mistakes for fuzzy matching.
133
  match_fuzzy_whole_phrase_bool (bool): Match fuzzy whole phrase boolean.
134
+
135
  # Duplicate Detection Arguments
136
  duplicate_type (str): Type of duplicate detection (pages or tabular).
137
  similarity_threshold (float): Similarity threshold (0-1) to consider content as duplicates.
 
140
  greedy_match (bool): Use greedy matching strategy for consecutive pages.
141
  combine_pages (bool): Combine text from the same page number within a file.
142
  remove_duplicate_rows (bool): Remove duplicate rows from the output.
143
+
144
  # Textract Batch Operations Arguments
145
  textract_action (str): Textract action to perform (submit, retrieve, or list).
146
  job_id (str): Textract job ID for retrieve action.
 
159
  # 1. Get absolute paths and perform pre-checks
160
  script_abs_path = os.path.abspath(script_path)
161
  output_abs_dir = os.path.abspath(output_dir)
162
+
163
  # Handle input file based on task and action
164
+ if task == "textract" and textract_action in ["retrieve", "list"]:
165
  # For retrieve and list actions, input file is not required
166
  input_abs_path = None
167
  else:
 
171
  input_abs_path = os.path.abspath(input_file)
172
  if not os.path.isfile(input_abs_path):
173
  raise FileNotFoundError(f"Input file not found: {input_abs_path}")
174
+
175
  if not os.path.isfile(script_abs_path):
176
  raise FileNotFoundError(f"Script not found: {script_abs_path}")
177
+
178
  if not os.path.isdir(output_abs_dir):
179
  # Create the output directory if it doesn't exist
180
  print(f"Output directory not found. Creating: {output_abs_dir}")
181
  os.makedirs(output_abs_dir)
182
+
183
  script_folder = os.path.dirname(script_abs_path)
184
 
185
  # 2. Dynamically build the command list
186
  command = [
187
  "python",
188
  script_abs_path,
189
+ "--output_dir",
190
+ output_abs_dir,
191
+ "--task",
192
+ task,
193
  ]
194
+
195
  # Add input_file only if it's not None
196
  if input_abs_path is not None:
197
  command.extend(["--input_file", input_abs_path])
 
232
  if save_logs_to_dynamodb is not None:
233
  command.extend(["--save_logs_to_dynamodb", str(save_logs_to_dynamodb)])
234
  if display_file_names_in_logs is not None:
235
+ command.extend(
236
+ ["--display_file_names_in_logs", str(display_file_names_in_logs)]
237
+ )
238
  if upload_logs_to_s3 is not None:
239
  command.extend(["--upload_logs_to_s3", str(upload_logs_to_s3)])
240
  if s3_logs_prefix:
 
252
  if chosen_local_ocr_model:
253
  command.extend(["--chosen_local_ocr_model", chosen_local_ocr_model])
254
  if preprocess_local_ocr_images is not None:
255
+ command.extend(
256
+ ["--preprocess_local_ocr_images", str(preprocess_local_ocr_images)]
257
+ )
258
  if compress_redacted_pdf is not None:
259
  command.extend(["--compress_redacted_pdf", str(compress_redacted_pdf)])
260
  if return_pdf_end_of_redaction is not None:
261
+ command.extend(
262
+ ["--return_pdf_end_of_redaction", str(return_pdf_end_of_redaction)]
263
+ )
264
  if deny_list_file and os.path.isfile(deny_list_file):
265
  command.extend(["--deny_list_file", os.path.abspath(deny_list_file)])
266
  if allow_list_file and os.path.isfile(allow_list_file):
267
  command.extend(["--allow_list_file", os.path.abspath(allow_list_file)])
268
  if redact_whole_page_file and os.path.isfile(redact_whole_page_file):
269
+ command.extend(
270
+ ["--redact_whole_page_file", os.path.abspath(redact_whole_page_file)]
271
+ )
272
  if handwrite_signature_extraction:
273
  command.append("--handwrite_signature_extraction")
274
  command.extend(handwrite_signature_extraction)
 
291
  if fuzzy_mistakes is not None:
292
  command.extend(["--fuzzy_mistakes", str(fuzzy_mistakes)])
293
  if match_fuzzy_whole_phrase_bool is not None:
294
+ command.extend(
295
+ ["--match_fuzzy_whole_phrase_bool", str(match_fuzzy_whole_phrase_bool)]
296
+ )
297
 
298
  # Add duplicate detection arguments
299
  if duplicate_type:
 
326
  if textract_output_prefix:
327
  command.extend(["--textract_output_prefix", textract_output_prefix])
328
  if s3_textract_document_logs_subfolder:
329
+ command.extend(
330
+ [
331
+ "--s3_textract_document_logs_subfolder",
332
+ s3_textract_document_logs_subfolder,
333
+ ]
334
+ )
335
  if local_textract_document_logs_subfolder:
336
+ command.extend(
337
+ [
338
+ "--local_textract_document_logs_subfolder",
339
+ local_textract_document_logs_subfolder,
340
+ ]
341
+ )
342
  if poll_interval is not None:
343
  command.extend(["--poll_interval", str(poll_interval)])
344
  if max_poll_attempts is not None:
345
  command.extend(["--max_poll_attempts", str(max_poll_attempts)])
346
 
347
  # Filter out None values before joining
348
+ command_str = " ".join(str(arg) for arg in command if arg is not None)
349
  print(f"Executing command: {command_str}")
350
 
351
  # 3. Execute the command using subprocess
 
355
  stdout=subprocess.PIPE,
356
  stderr=subprocess.PIPE,
357
  text=True,
358
+ cwd=script_folder, # Important for relative paths within the script
359
  )
360
 
361
  # Communicate with the process to get output and handle timeout
362
  stdout, stderr = result.communicate(timeout=timeout)
363
+
364
  print("--- SCRIPT STDOUT ---")
365
  if stdout:
366
  print(stdout)
 
387
 
388
  class TestCLIRedactExamples(unittest.TestCase):
389
  """Test suite for CLI redaction examples from the epilog."""
390
+
391
  @classmethod
392
  def setUpClass(cls):
393
  """Set up test environment before running tests."""
394
+ cls.script_path = os.path.join(
395
+ os.path.dirname(os.path.dirname(__file__)), "cli_redact.py"
396
+ )
397
+ cls.example_data_dir = os.path.join(
398
+ os.path.dirname(os.path.dirname(__file__)), "example_data"
399
+ )
400
  cls.temp_output_dir = tempfile.mkdtemp(prefix="test_output_")
401
+
402
  # Verify script exists
403
  if not os.path.isfile(cls.script_path):
404
  raise FileNotFoundError(f"CLI script not found: {cls.script_path}")
405
+
406
  print(f"Test setup complete. Script: {cls.script_path}")
407
  print(f"Example data directory: {cls.example_data_dir}")
408
  print(f"Temp output directory: {cls.temp_output_dir}")
409
+
410
  @classmethod
411
  def tearDownClass(cls):
412
  """Clean up test environment after running tests."""
413
  if os.path.exists(cls.temp_output_dir):
414
  shutil.rmtree(cls.temp_output_dir)
415
  print(f"Cleaned up temp directory: {cls.temp_output_dir}")
416
+
417
  def test_pdf_redaction_default_settings(self):
418
  """Test: Redact a PDF with default settings (local OCR)"""
419
  print("\n=== Testing PDF redaction with default settings ===")
420
+ input_file = os.path.join(
421
+ self.example_data_dir,
422
+ "example_of_emails_sent_to_a_professor_before_applying.pdf",
423
+ )
424
+
425
  if not os.path.isfile(input_file):
426
  self.skipTest(f"Example file not found: {input_file}")
427
+
428
  result = run_cli_redact(
429
  script_path=self.script_path,
430
  input_file=input_file,
431
+ output_dir=self.temp_output_dir,
432
  )
433
+
434
  self.assertTrue(result, "PDF redaction with default settings should succeed")
435
  print("✅ PDF redaction with default settings passed")
436
+
437
  def test_pdf_text_extraction_only(self):
438
  """Test: Extract text from a PDF only (i.e. no redaction), using local OCR"""
439
  print("\n=== Testing PDF text extraction only ===")
440
+ input_file = os.path.join(
441
+ self.example_data_dir, "Partnership-Agreement-Toolkit_0_0.pdf"
442
+ )
443
+ whole_page_file = os.path.join(
444
+ self.example_data_dir, "partnership_toolkit_redact_some_pages.csv"
445
+ )
446
+
447
  if not os.path.isfile(input_file):
448
  self.skipTest(f"Example file not found: {input_file}")
449
  if not os.path.isfile(whole_page_file):
450
  self.skipTest(f"Whole page file not found: {whole_page_file}")
451
+
452
  result = run_cli_redact(
453
  script_path=self.script_path,
454
  input_file=input_file,
455
  output_dir=self.temp_output_dir,
456
  redact_whole_page_file=whole_page_file,
457
+ pii_detector="None",
458
  )
459
+
460
  self.assertTrue(result, "PDF text extraction should succeed")
461
  print("✅ PDF text extraction only passed")
462
+
463
  def test_pdf_text_extraction_with_whole_page_redaction(self):
464
  """Test: Extract text from a PDF only with a whole page redaction list"""
465
  print("\n=== Testing PDF text extraction with whole page redaction ===")
466
+ input_file = os.path.join(
467
+ self.example_data_dir, "Partnership-Agreement-Toolkit_0_0.pdf"
468
+ )
469
+ whole_page_file = os.path.join(
470
+ self.example_data_dir, "partnership_toolkit_redact_some_pages.csv"
471
+ )
472
+
473
  if not os.path.isfile(input_file):
474
  self.skipTest(f"Example file not found: {input_file}")
475
  if not os.path.isfile(whole_page_file):
476
  self.skipTest(f"Whole page file not found: {whole_page_file}")
477
+
478
  result = run_cli_redact(
479
  script_path=self.script_path,
480
  input_file=input_file,
481
  output_dir=self.temp_output_dir,
482
  redact_whole_page_file=whole_page_file,
483
  pii_detector="Local",
484
+ local_redact_entities=["CUSTOM"],
485
+ )
486
+
487
+ self.assertTrue(
488
+ result, "PDF text extraction with whole page redaction should succeed"
489
  )
 
 
490
  print("✅ PDF text extraction with whole page redaction passed")
491
+
492
  def test_pdf_redaction_with_allow_list(self):
493
  """Test: Redact a PDF with allow list (local OCR) and custom list of redaction entities"""
494
  print("\n=== Testing PDF redaction with allow list ===")
495
+ input_file = os.path.join(
496
+ self.example_data_dir, "graduate-job-example-cover-letter.pdf"
497
+ )
498
+ allow_list_file = os.path.join(
499
+ self.example_data_dir, "test_allow_list_graduate.csv"
500
+ )
501
+
502
  if not os.path.isfile(input_file):
503
  self.skipTest(f"Example file not found: {input_file}")
504
  if not os.path.isfile(allow_list_file):
505
  self.skipTest(f"Allow list file not found: {allow_list_file}")
506
+
507
  result = run_cli_redact(
508
  script_path=self.script_path,
509
  input_file=input_file,
510
  output_dir=self.temp_output_dir,
511
  allow_list_file=allow_list_file,
512
+ local_redact_entities=["TITLES", "PERSON", "DATE_TIME"],
513
  )
514
+
515
  self.assertTrue(result, "PDF redaction with allow list should succeed")
516
  print("✅ PDF redaction with allow list passed")
517
+
518
  def test_pdf_redaction_limited_pages_with_custom_fuzzy(self):
519
  """Test: Redact a PDF with limited pages and text extraction method with custom fuzzy matching"""
520
  print("\n=== Testing PDF redaction with limited pages and fuzzy matching ===")
521
+ input_file = os.path.join(
522
+ self.example_data_dir, "Partnership-Agreement-Toolkit_0_0.pdf"
523
+ )
524
+ deny_list_file = os.path.join(
525
+ self.example_data_dir,
526
+ "Partnership-Agreement-Toolkit_test_deny_list_para_single_spell.csv",
527
+ )
528
+
529
  if not os.path.isfile(input_file):
530
  self.skipTest(f"Example file not found: {input_file}")
531
  if not os.path.isfile(deny_list_file):
532
  self.skipTest(f"Deny list file not found: {deny_list_file}")
533
+
534
  result = run_cli_redact(
535
  script_path=self.script_path,
536
  input_file=input_file,
 
540
  page_min=1,
541
  page_max=3,
542
  ocr_method="Local text",
543
+ fuzzy_mistakes=3,
544
+ )
545
+
546
+ self.assertTrue(
547
+ result, "PDF redaction with limited pages and fuzzy matching should succeed"
548
  )
 
 
549
  print("✅ PDF redaction with limited pages and fuzzy matching passed")
550
+
551
  def test_pdf_redaction_with_custom_lists(self):
552
  """Test: Redaction with custom deny list, allow list, and whole page redaction list"""
553
  print("\n=== Testing PDF redaction with custom lists ===")
554
+ input_file = os.path.join(
555
+ self.example_data_dir, "Partnership-Agreement-Toolkit_0_0.pdf"
556
+ )
557
+ deny_list_file = os.path.join(
558
+ self.example_data_dir, "partnership_toolkit_redact_custom_deny_list.csv"
559
+ )
560
+ whole_page_file = os.path.join(
561
+ self.example_data_dir, "partnership_toolkit_redact_some_pages.csv"
562
+ )
563
+ allow_list_file = os.path.join(
564
+ self.example_data_dir, "test_allow_list_partnership.csv"
565
+ )
566
+
567
  if not os.path.isfile(input_file):
568
  self.skipTest(f"Example file not found: {input_file}")
569
  if not os.path.isfile(deny_list_file):
 
572
  self.skipTest(f"Whole page file not found: {whole_page_file}")
573
  if not os.path.isfile(allow_list_file):
574
  self.skipTest(f"Allow list file not found: {allow_list_file}")
575
+
576
  result = run_cli_redact(
577
  script_path=self.script_path,
578
  input_file=input_file,
579
  output_dir=self.temp_output_dir,
580
  deny_list_file=deny_list_file,
581
  redact_whole_page_file=whole_page_file,
582
+ allow_list_file=allow_list_file,
583
  )
584
+
585
  self.assertTrue(result, "PDF redaction with custom lists should succeed")
586
  print("✅ PDF redaction with custom lists passed")
587
+
588
  def test_image_redaction(self):
589
  """Test: Redact an image"""
590
  print("\n=== Testing image redaction ===")
591
  input_file = os.path.join(self.example_data_dir, "example_complaint_letter.jpg")
592
+
593
  if not os.path.isfile(input_file):
594
  self.skipTest(f"Example file not found: {input_file}")
595
+
596
  result = run_cli_redact(
597
  script_path=self.script_path,
598
  input_file=input_file,
599
+ output_dir=self.temp_output_dir,
600
  )
601
+
602
  self.assertTrue(result, "Image redaction should succeed")
603
  print("✅ Image redaction passed")
604
+
605
  def test_csv_anonymisation_specific_columns(self):
606
  """Test: Anonymise csv file with specific columns"""
607
  print("\n=== Testing CSV anonymisation with specific columns ===")
608
  input_file = os.path.join(self.example_data_dir, "combined_case_notes.csv")
609
+
610
  if not os.path.isfile(input_file):
611
  self.skipTest(f"Example file not found: {input_file}")
612
+
613
  result = run_cli_redact(
614
  script_path=self.script_path,
615
  input_file=input_file,
616
  output_dir=self.temp_output_dir,
617
  text_columns=["Case Note", "Client"],
618
+ anon_strategy="replace_redacted",
619
+ )
620
+
621
+ self.assertTrue(
622
+ result, "CSV anonymisation with specific columns should succeed"
623
  )
 
 
624
  print("✅ CSV anonymisation with specific columns passed")
625
+
626
  def test_csv_anonymisation_different_strategy(self):
627
  """Test: Anonymise csv file with a different strategy (remove text completely)"""
628
  print("\n=== Testing CSV anonymisation with different strategy ===")
629
  input_file = os.path.join(self.example_data_dir, "combined_case_notes.csv")
630
+
631
  if not os.path.isfile(input_file):
632
  self.skipTest(f"Example file not found: {input_file}")
633
+
634
  result = run_cli_redact(
635
  script_path=self.script_path,
636
  input_file=input_file,
637
  output_dir=self.temp_output_dir,
638
  text_columns=["Case Note", "Client"],
639
+ anon_strategy="redact",
640
+ )
641
+
642
+ self.assertTrue(
643
+ result, "CSV anonymisation with different strategy should succeed"
644
  )
 
 
645
  print("✅ CSV anonymisation with different strategy passed")
646
+
647
  def test_word_document_anonymisation(self):
648
  """Test: Anonymise a word document"""
649
  print("\n=== Testing Word document anonymisation ===")
650
+ input_file = os.path.join(
651
+ self.example_data_dir, "Bold minimalist professional cover letter.docx"
652
+ )
653
+
654
  if not os.path.isfile(input_file):
655
  self.skipTest(f"Example file not found: {input_file}")
656
+
657
  result = run_cli_redact(
658
  script_path=self.script_path,
659
  input_file=input_file,
660
  output_dir=self.temp_output_dir,
661
+ anon_strategy="replace_redacted",
662
  )
663
+
664
  self.assertTrue(result, "Word document anonymisation should succeed")
665
  print("✅ Word document anonymisation passed")
666
+
667
  def test_aws_textract_comprehend_redaction(self):
668
  """Test: Use Textract and Comprehend for redaction"""
669
  print("\n=== Testing AWS Textract and Comprehend redaction ===")
670
+ input_file = os.path.join(
671
+ self.example_data_dir,
672
+ "example_of_emails_sent_to_a_professor_before_applying.pdf",
673
+ )
674
+
675
  if not os.path.isfile(input_file):
676
  self.skipTest(f"Example file not found: {input_file}")
677
+
678
  # Skip this test if AWS credentials are not available
679
  # This is a conditional test that may not work in all environments
680
+ run_cli_redact(
681
  script_path=self.script_path,
682
  input_file=input_file,
683
  output_dir=self.temp_output_dir,
684
  ocr_method="AWS Textract",
685
+ pii_detector="AWS Comprehend",
686
  )
687
+
688
  # Note: This test may fail if AWS credentials are not configured
689
  # We'll mark it as passed if it runs without crashing
690
  print("✅ AWS Textract and Comprehend redaction test completed")
691
+
692
  def test_aws_textract_signature_extraction(self):
693
  """Test: Redact specific pages with AWS OCR and signature extraction"""
694
  print("\n=== Testing AWS Textract with signature extraction ===")
695
+ input_file = os.path.join(
696
+ self.example_data_dir, "Partnership-Agreement-Toolkit_0_0.pdf"
697
+ )
698
+
699
  if not os.path.isfile(input_file):
700
  self.skipTest(f"Example file not found: {input_file}")
701
+
702
  # Skip this test if AWS credentials are not available
703
+ run_cli_redact(
704
  script_path=self.script_path,
705
  input_file=input_file,
706
  output_dir=self.temp_output_dir,
707
  page_min=6,
708
  page_max=7,
709
  ocr_method="AWS Textract",
710
+ handwrite_signature_extraction=[
711
+ "Extract handwriting",
712
+ "Extract signatures",
713
+ ],
714
  )
715
+
716
  # Note: This test may fail if AWS credentials are not configured
717
  print("✅ AWS Textract with signature extraction test completed")
718
+
719
  def test_duplicate_pages_detection(self):
720
  """Test: Find duplicate pages in OCR files"""
721
  print("\n=== Testing duplicate pages detection ===")
722
+ input_file = os.path.join(
723
+ self.example_data_dir,
724
+ "example_outputs",
725
+ "doubled_output_joined.pdf_ocr_output.csv",
726
+ )
727
+
728
  if not os.path.isfile(input_file):
729
  self.skipTest(f"Example OCR file not found: {input_file}")
730
+
731
  result = run_cli_redact(
732
  script_path=self.script_path,
733
  input_file=input_file,
734
  output_dir=self.temp_output_dir,
735
  task="deduplicate",
736
  duplicate_type="pages",
737
+ similarity_threshold=0.95,
738
  )
739
+
740
  self.assertTrue(result, "Duplicate pages detection should succeed")
741
  print("✅ Duplicate pages detection passed")
742
+
743
  def test_duplicate_line_level_detection(self):
744
  """Test: Find duplicate in OCR files at the line level"""
745
  print("\n=== Testing duplicate line level detection ===")
746
+ input_file = os.path.join(
747
+ self.example_data_dir,
748
+ "example_outputs",
749
+ "doubled_output_joined.pdf_ocr_output.csv",
750
+ )
751
+
752
  if not os.path.isfile(input_file):
753
  self.skipTest(f"Example OCR file not found: {input_file}")
754
+
755
  result = run_cli_redact(
756
  script_path=self.script_path,
757
  input_file=input_file,
 
760
  duplicate_type="pages",
761
  similarity_threshold=0.95,
762
  combine_pages=False,
763
+ min_word_count=3,
764
  )
765
+
766
  self.assertTrue(result, "Duplicate line level detection should succeed")
767
  print("✅ Duplicate line level detection passed")
768
+
769
  def test_duplicate_tabular_detection(self):
770
  """Test: Find duplicate rows in tabular data"""
771
  print("\n=== Testing duplicate tabular detection ===")
772
+ input_file = os.path.join(
773
+ self.example_data_dir, "Lambeth_2030-Our_Future_Our_Lambeth.pdf.csv"
774
+ )
775
+
776
  if not os.path.isfile(input_file):
777
  self.skipTest(f"Example CSV file not found: {input_file}")
778
+
779
  result = run_cli_redact(
780
  script_path=self.script_path,
781
  input_file=input_file,
 
783
  task="deduplicate",
784
  duplicate_type="tabular",
785
  text_columns=["text"],
786
+ similarity_threshold=0.95,
787
  )
788
+
789
  self.assertTrue(result, "Duplicate tabular detection should succeed")
790
  print("✅ Duplicate tabular detection passed")
791
+
792
  def test_textract_submit_document(self):
793
  """Test: Submit document to Textract for basic text analysis"""
794
  print("\n=== Testing Textract document submission ===")
795
+ input_file = os.path.join(
796
+ self.example_data_dir,
797
+ "example_of_emails_sent_to_a_professor_before_applying.pdf",
798
+ )
799
+
800
  if not os.path.isfile(input_file):
801
  self.skipTest(f"Example file not found: {input_file}")
802
+
803
  # Skip this test if AWS credentials are not available
804
  try:
805
+ run_cli_redact(
806
  script_path=self.script_path,
807
  input_file=input_file,
808
  output_dir=self.temp_output_dir,
809
  task="textract",
810
+ textract_action="submit",
811
  )
812
  except Exception as e:
813
  print(f"Textract test failed (expected without AWS credentials): {e}")
814
+
 
815
  # Note: This test may fail if AWS credentials are not configured
816
  print("✅ Textract document submission test completed")
817
+
818
  def test_textract_submit_with_signatures(self):
819
  """Test: Submit document to Textract for analysis with signature extraction"""
820
  print("\n=== Testing Textract submission with signature extraction ===")
821
+ input_file = os.path.join(
822
+ self.example_data_dir, "Partnership-Agreement-Toolkit_0_0.pdf"
823
+ )
824
+
825
  if not os.path.isfile(input_file):
826
  self.skipTest(f"Example file not found: {input_file}")
827
+
828
  # Skip this test if AWS credentials are not available
829
  try:
830
+ run_cli_redact(
831
  script_path=self.script_path,
832
  input_file=input_file,
833
  output_dir=self.temp_output_dir,
834
  task="textract",
835
  textract_action="submit",
836
+ extract_signatures=True,
837
  )
838
  except Exception as e:
839
  print(f"Textract test failed (expected without AWS credentials): {e}")
840
+
 
841
  # Note: This test may fail if AWS credentials are not configured
842
  print("✅ Textract submission with signature extraction test completed")
843
+
844
  def test_textract_retrieve_results(self):
845
  """Test: Retrieve Textract results by job ID"""
846
  print("\n=== Testing Textract results retrieval ===")
847
+
848
  # Skip this test if AWS credentials are not available
849
  # This would require a valid job ID from a previous submission
850
  # For retrieve and list actions, we don't need a real input file
851
  try:
852
+ run_cli_redact(
853
  script_path=self.script_path,
854
  input_file=None, # No input file needed for retrieve action
855
  output_dir=self.temp_output_dir,
856
  task="textract",
857
  textract_action="retrieve",
858
+ job_id="12345678-1234-1234-1234-123456789012", # Dummy job ID
859
  )
860
  except Exception as e:
861
  print(f"Textract test failed (expected without AWS credentials): {e}")
862
+
 
863
  # Note: This test will likely fail with a dummy job ID, but that's expected
864
  print("✅ Textract results retrieval test completed")
865
+
866
  def test_textract_list_jobs(self):
867
  """Test: List recent Textract jobs"""
868
  print("\n=== Testing Textract jobs listing ===")
869
+
870
  # Skip this test if AWS credentials are not available
871
  # For list action, we don't need a real input file
872
  try:
873
+ run_cli_redact(
874
  script_path=self.script_path,
875
  input_file=None, # No input file needed for list action
876
  output_dir=self.temp_output_dir,
877
  task="textract",
878
+ textract_action="list",
879
  )
880
  except Exception as e:
881
  print(f"Textract test failed (expected without AWS credentials): {e}")
882
+
 
883
  # Note: This test may fail if AWS credentials are not configured
884
  print("✅ Textract jobs listing test completed")
885
 
886
 
887
+ class TestGUIApp(unittest.TestCase):
888
+ """Test suite for GUI application loading and basic functionality."""
889
+
890
+ @classmethod
891
+ def setUpClass(cls):
892
+ """Set up test environment for GUI tests."""
893
+ cls.app_path = os.path.join(
894
+ os.path.dirname(os.path.dirname(__file__)), "app.py"
895
+ )
896
+
897
+ # Verify app.py exists
898
+ if not os.path.isfile(cls.app_path):
899
+ raise FileNotFoundError(f"App file not found: {cls.app_path}")
900
+
901
+ print(f"GUI test setup complete. App: {cls.app_path}")
902
+
903
+ def test_app_import_and_initialization(self):
904
+ """Test: Import app.py and check if the Gradio app object is created successfully."""
905
+ print("\n=== Testing GUI app import and initialization ===")
906
+
907
+ try:
908
+ # Add the parent directory to the path so we can import app
909
+ parent_dir = os.path.dirname(os.path.dirname(__file__))
910
+ if parent_dir not in sys.path:
911
+ sys.path.insert(0, parent_dir)
912
+
913
+ # Import the app module
914
+ import app
915
+
916
+ # Check if the app object exists and is a Gradio Blocks object
917
+ self.assertTrue(hasattr(app, 'app'), "App object should exist in the module")
918
+
919
+ # Check if it's a Gradio Blocks instance
920
+ import gradio as gr
921
+ self.assertIsInstance(app.app, gr.Blocks, "App should be a Gradio Blocks instance")
922
+
923
+ print("✅ GUI app import and initialization passed")
924
+
925
+ except ImportError as e:
926
+ error_msg = f"Failed to import app module: {e}"
927
+ if "gradio_image_annotation" in str(e):
928
+ error_msg += "\n\nNOTE: This test requires the 'redaction' conda environment to be activated."
929
+ error_msg += "\nPlease run: conda activate redaction"
930
+ error_msg += "\nThen run this test again."
931
+ self.fail(error_msg)
932
+ except Exception as e:
933
+ self.fail(f"Unexpected error during app initialization: {e}")
934
+
935
+ def test_app_launch_headless(self):
936
+ """Test: Launch the app in headless mode to verify it starts without errors."""
937
+ print("\n=== Testing GUI app launch in headless mode ===")
938
+
939
+ try:
940
+ # Add the parent directory to the path
941
+ parent_dir = os.path.dirname(os.path.dirname(__file__))
942
+ if parent_dir not in sys.path:
943
+ sys.path.insert(0, parent_dir)
944
+
945
+ # Import the app module
946
+ import app
947
+ import gradio as gr
948
+
949
+ # Set up a flag to track if the app launched successfully
950
+ app_launched = threading.Event()
951
+ launch_error = None
952
+
953
+ def launch_app():
+ nonlocal launch_error  # write to the enclosing test method's variable on failure
954
+ try:
955
+ # Launch the app in headless mode with a short timeout
956
+ app.app.launch(
957
+ show_error=True,
958
+ inbrowser=False, # Don't open browser
959
+ server_port=0, # Use any available port
960
+ quiet=True, # Suppress output
961
+ prevent_thread_lock=True # Don't block the main thread
962
+ )
963
+ app_launched.set()
964
+ except Exception as e:
965
+ launch_error = e
966
+ app_launched.set()
967
+
968
+ # Start the app in a separate thread
969
+ launch_thread = threading.Thread(target=launch_app)
970
+ launch_thread.daemon = True
971
+ launch_thread.start()
972
+
973
+ # Wait for the app to launch (with timeout)
974
+ if app_launched.wait(timeout=10): # 10 second timeout
975
+ if launch_error:
976
+ self.fail(f"App launch failed: {launch_error}")
977
+ else:
978
+ print("✅ GUI app launch in headless mode passed")
979
+ else:
980
+ self.fail("App launch timed out after 10 seconds")
981
+
982
+ except Exception as e:
983
+ error_msg = f"Unexpected error during app launch test: {e}"
984
+ if "gradio_image_annotation" in str(e):
985
+ error_msg += "\n\nNOTE: This test requires the 'redaction' conda environment to be activated."
986
+ error_msg += "\nPlease run: conda activate redaction"
987
+ error_msg += "\nThen run this test again."
988
+ self.fail(error_msg)
989
+
990
+ def test_app_configuration_loading(self):
991
+ """Test: Verify that the app can load its configuration without errors."""
992
+ print("\n=== Testing GUI app configuration loading ===")
993
+
994
+ try:
995
+ # Add the parent directory to the path
996
+ parent_dir = os.path.dirname(os.path.dirname(__file__))
997
+ if parent_dir not in sys.path:
998
+ sys.path.insert(0, parent_dir)
999
+
1000
+ # Import the app module
1001
+ import app
1002
+
1003
+ # Check if key configuration variables are accessible
1004
+ # These should be imported from tools.config
1005
+ from tools.config import (
1006
+ GRADIO_SERVER_PORT,
1007
+ MAX_FILE_SIZE,
1008
+ DEFAULT_LANGUAGE,
1009
+ PII_DETECTION_MODELS
1010
+ )
1011
+
1012
+ # Verify these are not None/empty
1013
+ self.assertIsNotNone(GRADIO_SERVER_PORT, "GRADIO_SERVER_PORT should be configured")
1014
+ self.assertIsNotNone(MAX_FILE_SIZE, "MAX_FILE_SIZE should be configured")
1015
+ self.assertIsNotNone(DEFAULT_LANGUAGE, "DEFAULT_LANGUAGE should be configured")
1016
+ self.assertIsNotNone(PII_DETECTION_MODELS, "PII_DETECTION_MODELS should be configured")
1017
+
1018
+ print("✅ GUI app configuration loading passed")
1019
+
1020
+ except ImportError as e:
1021
+ error_msg = f"Failed to import configuration: {e}"
1022
+ if "gradio_image_annotation" in str(e):
1023
+ error_msg += "\n\nNOTE: This test requires the 'redaction' conda environment to be activated."
1024
+ error_msg += "\nPlease run: conda activate redaction"
1025
+ error_msg += "\nThen run this test again."
1026
+ self.fail(error_msg)
1027
+ except Exception as e:
1028
+ error_msg = f"Unexpected error during configuration test: {e}"
1029
+ if "gradio_image_annotation" in str(e):
1030
+ error_msg += "\n\nNOTE: This test requires the 'redaction' conda environment to be activated."
1031
+ error_msg += "\nPlease run: conda activate redaction"
1032
+ error_msg += "\nThen run this test again."
1033
+ self.fail(error_msg)
1034
+
1035
+
1036
  def run_all_tests():
1037
  """Run all test examples and report results."""
1038
  print("=" * 80)
1039
+ print("DOCUMENT REDACTION TEST SUITE")
1040
  print("=" * 80)
1041
+ print("This test suite includes:")
1042
+ print("- CLI examples from the epilog")
1043
+ print("- GUI application loading and initialization tests")
1044
  print("Tests will be skipped if required example files are not found.")
1045
  print("AWS-related tests may fail if credentials are not configured.")
1046
  print("=" * 80)
1047
+
1048
  # Create test suite
1049
  loader = unittest.TestLoader()
1050
+ suite = unittest.TestSuite()
1051
 
1052
+ # Add CLI tests
1053
+ cli_suite = loader.loadTestsFromTestCase(TestCLIRedactExamples)
1054
+ suite.addTests(cli_suite)
1055
+
1056
+ # Add GUI tests
1057
+ gui_suite = loader.loadTestsFromTestCase(TestGUIApp)
1058
+ suite.addTests(gui_suite)
1059
+
1060
  # Run tests with detailed output
1061
  runner = unittest.TextTestRunner(verbosity=2, stream=None)
1062
  result = runner.run(suite)
1063
+
1064
  # Print summary
1065
  print("\n" + "=" * 80)
1066
  print("TEST SUMMARY")
 
1069
  print(f"Failures: {len(result.failures)}")
1070
  print(f"Errors: {len(result.errors)}")
1071
  print(f"Skipped: {len(result.skipped) if hasattr(result, 'skipped') else 0}")
1072
+
1073
  if result.failures:
1074
  print("\nFAILURES:")
1075
  for test, traceback in result.failures:
1076
  print(f"- {test}: {traceback}")
1077
+
1078
  if result.errors:
1079
  print("\nERRORS:")
1080
  for test, traceback in result.errors:
1081
  print(f"- {test}: {traceback}")
1082
+
1083
  success = len(result.failures) == 0 and len(result.errors) == 0
1084
  print(f"\nOverall result: {'✅ PASSED' if success else '❌ FAILED'}")
1085
  print("=" * 80)
1086
+
1087
  return success
1088
 
1089
 
1090
  if __name__ == "__main__":
1091
  # Run the test suite
1092
  success = run_all_tests()
1093
+ exit(0 if success else 1)
test/test_gui_only.py ADDED
@@ -0,0 +1,198 @@
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Standalone GUI test script for the document redaction application.
4
+
5
+ This script tests only the GUI functionality of app.py to ensure it loads correctly.
6
+ Run this script to verify that the Gradio interface can be imported and initialized.
7
+ """
8
+
9
+ import os
10
+ import sys
11
+ import unittest
12
+ import threading
13
+ import time
14
+
15
+ # Add the parent directory to the path so we can import the app
16
+ parent_dir = os.path.dirname(os.path.dirname(__file__))
17
+ if parent_dir not in sys.path:
18
+ sys.path.insert(0, parent_dir)
19
+
20
+
21
+ class TestGUIAppOnly(unittest.TestCase):
22
+ """Test suite for GUI application loading and basic functionality."""
23
+
24
+ @classmethod
25
+ def setUpClass(cls):
26
+ """Set up test environment for GUI tests."""
27
+ cls.app_path = os.path.join(parent_dir, "app.py")
28
+
29
+ # Verify app.py exists
30
+ if not os.path.isfile(cls.app_path):
31
+ raise FileNotFoundError(f"App file not found: {cls.app_path}")
32
+
33
+ print(f"GUI test setup complete. App: {cls.app_path}")
34
+
35
+ def test_app_import_and_initialization(self):
36
+ """Test: Import app.py and check if the Gradio app object is created successfully."""
37
+ print("\n=== Testing GUI app import and initialization ===")
38
+
39
+ try:
40
+ # Import the app module
41
+ import app
42
+
43
+ # Check if the app object exists and is a Gradio Blocks object
44
+ self.assertTrue(hasattr(app, 'app'), "App object should exist in the module")
45
+
46
+ # Check if it's a Gradio Blocks instance
47
+ import gradio as gr
48
+ self.assertIsInstance(app.app, gr.Blocks, "App should be a Gradio Blocks instance")
49
+
50
+ print("✅ GUI app import and initialization passed")
51
+
52
+ except ImportError as e:
53
+ error_msg = f"Failed to import app module: {e}"
54
+ if "gradio_image_annotation" in str(e):
55
+ error_msg += "\n\nNOTE: This test requires the 'redaction' conda environment to be activated."
56
+ error_msg += "\nPlease run: conda activate redaction"
57
+ error_msg += "\nThen run this test again."
58
+ self.fail(error_msg)
59
+ except Exception as e:
60
+ self.fail(f"Unexpected error during app initialization: {e}")
61
+
62
+ def test_app_launch_headless(self):
63
+ """Test: Launch the app in headless mode to verify it starts without errors."""
64
+ print("\n=== Testing GUI app launch in headless mode ===")
65
+
66
+ try:
67
+ # Import the app module
68
+ import app
69
+ import gradio as gr
70
+
71
+ # Set up a flag to track if the app launched successfully
72
+ app_launched = threading.Event()
73
+ launch_error = None
74
+
75
+ def launch_app():
+ nonlocal launch_error  # write to the enclosing test method's variable on failure
76
+ try:
77
+ # Launch the app in headless mode with a short timeout
78
+ app.app.launch(
79
+ show_error=True,
80
+ inbrowser=False, # Don't open browser
81
+ server_port=0, # Use any available port
82
+ quiet=True, # Suppress output
83
+ prevent_thread_lock=True # Don't block the main thread
84
+ )
85
+ app_launched.set()
86
+ except Exception as e:
87
+ launch_error = e
88
+ app_launched.set()
89
+
90
+ # Start the app in a separate thread
91
+ launch_thread = threading.Thread(target=launch_app)
92
+ launch_thread.daemon = True
93
+ launch_thread.start()
94
+
95
+ # Wait for the app to launch (with timeout)
96
+ if app_launched.wait(timeout=10): # 10 second timeout
97
+ if launch_error:
98
+ self.fail(f"App launch failed: {launch_error}")
99
+ else:
100
+ print("✅ GUI app launch in headless mode passed")
101
+ else:
102
+ self.fail("App launch timed out after 10 seconds")
103
+
104
+ except Exception as e:
105
+ error_msg = f"Unexpected error during app launch test: {e}"
106
+ if "gradio_image_annotation" in str(e):
107
+ error_msg += "\n\nNOTE: This test requires the 'redaction' conda environment to be activated."
108
+ error_msg += "\nPlease run: conda activate redaction"
109
+ error_msg += "\nThen run this test again."
110
+ self.fail(error_msg)
111
+
112
+ def test_app_configuration_loading(self):
113
+ """Test: Verify that the app can load its configuration without errors."""
114
+ print("\n=== Testing GUI app configuration loading ===")
115
+
116
+ try:
117
+ # Import the app module
118
+ import app
119
+
120
+ # Check if key configuration variables are accessible
121
+ # These should be imported from tools.config
122
+ from tools.config import (
123
+ GRADIO_SERVER_PORT,
124
+ MAX_FILE_SIZE,
125
+ DEFAULT_LANGUAGE,
126
+ PII_DETECTION_MODELS
127
+ )
128
+
129
+ # Verify these are not None/empty
130
+ self.assertIsNotNone(GRADIO_SERVER_PORT, "GRADIO_SERVER_PORT should be configured")
131
+ self.assertIsNotNone(MAX_FILE_SIZE, "MAX_FILE_SIZE should be configured")
132
+ self.assertIsNotNone(DEFAULT_LANGUAGE, "DEFAULT_LANGUAGE should be configured")
133
+ self.assertIsNotNone(PII_DETECTION_MODELS, "PII_DETECTION_MODELS should be configured")
134
+
135
+ print("✅ GUI app configuration loading passed")
136
+
137
+ except ImportError as e:
138
+ error_msg = f"Failed to import configuration: {e}"
139
+ if "gradio_image_annotation" in str(e):
140
+ error_msg += "\n\nNOTE: This test requires the 'redaction' conda environment to be activated."
141
+ error_msg += "\nPlease run: conda activate redaction"
142
+ error_msg += "\nThen run this test again."
143
+ self.fail(error_msg)
144
+ except Exception as e:
145
+ error_msg = f"Unexpected error during configuration test: {e}"
146
+ if "gradio_image_annotation" in str(e):
147
+ error_msg += "\n\nNOTE: This test requires the 'redaction' conda environment to be activated."
148
+ error_msg += "\nPlease run: conda activate redaction"
149
+ error_msg += "\nThen run this test again."
150
+ self.fail(error_msg)
151
+
152
+
153
+ def run_gui_tests():
154
+ """Run GUI tests and report results."""
155
+ print("=" * 80)
156
+ print("DOCUMENT REDACTION GUI TEST SUITE")
157
+ print("=" * 80)
158
+ print("This test suite verifies that the GUI application loads correctly.")
159
+ print("=" * 80)
160
+
161
+ # Create test suite
162
+ loader = unittest.TestLoader()
163
+ suite = loader.loadTestsFromTestCase(TestGUIAppOnly)
164
+
165
+ # Run tests with detailed output
166
+ runner = unittest.TextTestRunner(verbosity=2, stream=None)
167
+ result = runner.run(suite)
168
+
169
+ # Print summary
170
+ print("\n" + "=" * 80)
171
+ print("GUI TEST SUMMARY")
172
+ print("=" * 80)
173
+ print(f"Tests run: {result.testsRun}")
174
+ print(f"Failures: {len(result.failures)}")
175
+ print(f"Errors: {len(result.errors)}")
176
+ print(f"Skipped: {len(result.skipped) if hasattr(result, 'skipped') else 0}")
177
+
178
+ if result.failures:
179
+ print("\nFAILURES:")
180
+ for test, traceback in result.failures:
181
+ print(f"- {test}: {traceback}")
182
+
183
+ if result.errors:
184
+ print("\nERRORS:")
185
+ for test, traceback in result.errors:
186
+ print(f"- {test}: {traceback}")
187
+
188
+ success = len(result.failures) == 0 and len(result.errors) == 0
189
+ print(f"\nOverall result: {'✅ PASSED' if success else '❌ FAILED'}")
190
+ print("=" * 80)
191
+
192
+ return success
193
+
194
+
195
+ if __name__ == "__main__":
196
+ # Run the GUI test suite
197
+ success = run_gui_tests()
198
+ exit(0 if success else 1)
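Note: the new standalone script adds the repository root to sys.path itself and propagates the suite result as the process exit code, so it can be run directly with `python test/test_gui_only.py` from the repository root. A minimal sketch of driving the same checks programmatically (assuming the repository root as the working directory and the project environment active) might look like this:

import os
import sys
import unittest

# Make test/ importable without packaging it; assumes this snippet is run
# from the repository root.
sys.path.insert(0, os.path.join(os.getcwd(), "test"))

from test_gui_only import TestGUIAppOnly  # defined in test/test_gui_only.py above

suite = unittest.TestLoader().loadTestsFromTestCase(TestGUIAppOnly)
result = unittest.TextTestRunner(verbosity=2).run(suite)
sys.exit(0 if result.wasSuccessful() else 1)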
tools/auth.py CHANGED
@@ -1,22 +1,33 @@
1
- #import os
2
- import boto3
3
- #import gradio as gr
4
- import hmac
5
- import hashlib
6
  import base64
7
- from tools.config import AWS_CLIENT_ID, AWS_CLIENT_SECRET, AWS_USER_POOL_ID, AWS_REGION
 
 
 
 
 
 
 
8
 
9
- def calculate_secret_hash(client_id:str, client_secret:str, username:str):
 
10
  message = username + client_id
11
  dig = hmac.new(
12
- str(client_secret).encode('utf-8'),
13
- msg=str(message).encode('utf-8'),
14
- digestmod=hashlib.sha256
15
  ).digest()
16
  secret_hash = base64.b64encode(dig).decode()
17
  return secret_hash
18
 
19
- def authenticate_user(username:str, password:str, user_pool_id:str=AWS_USER_POOL_ID, client_id:str=AWS_CLIENT_ID, client_secret:str=AWS_CLIENT_SECRET):
 
 
 
 
 
 
 
20
  """Authenticates a user against an AWS Cognito user pool.
21
 
22
  Args:
@@ -30,36 +41,38 @@ def authenticate_user(username:str, password:str, user_pool_id:str=AWS_USER_POOL
30
  bool: True if the user is authenticated, False otherwise.
31
  """
32
 
33
- client = boto3.client('cognito-idp', region_name=AWS_REGION) # Cognito Identity Provider client
 
 
34
 
35
  # Compute the secret hash
36
  secret_hash = calculate_secret_hash(client_id, client_secret, username)
37
 
38
  try:
39
 
40
- if client_secret == '':
41
  response = client.initiate_auth(
42
- AuthFlow='USER_PASSWORD_AUTH',
43
  AuthParameters={
44
- 'USERNAME': username,
45
- 'PASSWORD': password,
46
  },
47
- ClientId=client_id
48
  )
49
 
50
  else:
51
  response = client.initiate_auth(
52
- AuthFlow='USER_PASSWORD_AUTH',
53
- AuthParameters={
54
- 'USERNAME': username,
55
- 'PASSWORD': password,
56
- 'SECRET_HASH': secret_hash
57
- },
58
- ClientId=client_id
59
  )
60
 
61
  # If successful, you'll receive an AuthenticationResult in the response
62
- if response.get('AuthenticationResult'):
63
  return True
64
  else:
65
  return False
@@ -72,4 +85,4 @@ def authenticate_user(username:str, password:str, user_pool_id:str=AWS_USER_POOL
72
  out_message = f"An error occurred: {e}"
73
  print(out_message)
74
  raise Exception(out_message)
75
- return False
 
1
+ # import os
 
 
 
 
2
  import base64
3
+ import hashlib
4
+
5
+ # import gradio as gr
6
+ import hmac
7
+
8
+ import boto3
9
+
10
+ from tools.config import AWS_CLIENT_ID, AWS_CLIENT_SECRET, AWS_REGION, AWS_USER_POOL_ID
11
 
12
+
13
+ def calculate_secret_hash(client_id: str, client_secret: str, username: str):
14
  message = username + client_id
15
  dig = hmac.new(
16
+ str(client_secret).encode("utf-8"),
17
+ msg=str(message).encode("utf-8"),
18
+ digestmod=hashlib.sha256,
19
  ).digest()
20
  secret_hash = base64.b64encode(dig).decode()
21
  return secret_hash
22
 
23
+
24
+ def authenticate_user(
25
+ username: str,
26
+ password: str,
27
+ user_pool_id: str = AWS_USER_POOL_ID,
28
+ client_id: str = AWS_CLIENT_ID,
29
+ client_secret: str = AWS_CLIENT_SECRET,
30
+ ):
31
  """Authenticates a user against an AWS Cognito user pool.
32
 
33
  Args:
 
41
  bool: True if the user is authenticated, False otherwise.
42
  """
43
 
44
+ client = boto3.client(
45
+ "cognito-idp", region_name=AWS_REGION
46
+ ) # Cognito Identity Provider client
47
 
48
  # Compute the secret hash
49
  secret_hash = calculate_secret_hash(client_id, client_secret, username)
50
 
51
  try:
52
 
53
+ if client_secret == "":
54
  response = client.initiate_auth(
55
+ AuthFlow="USER_PASSWORD_AUTH",
56
  AuthParameters={
57
+ "USERNAME": username,
58
+ "PASSWORD": password,
59
  },
60
+ ClientId=client_id,
61
  )
62
 
63
  else:
64
  response = client.initiate_auth(
65
+ AuthFlow="USER_PASSWORD_AUTH",
66
+ AuthParameters={
67
+ "USERNAME": username,
68
+ "PASSWORD": password,
69
+ "SECRET_HASH": secret_hash,
70
+ },
71
+ ClientId=client_id,
72
  )
73
 
74
  # If successful, you'll receive an AuthenticationResult in the response
75
+ if response.get("AuthenticationResult"):
76
  return True
77
  else:
78
  return False
 
85
  out_message = f"An error occurred: {e}"
86
  print(out_message)
87
  raise Exception(out_message)
88
+ return False
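Note: for reference, a minimal sketch of how the reformatted helpers in tools/auth.py are called; the credential values below are placeholders, not real configuration:

from tools.auth import authenticate_user, calculate_secret_hash

# SECRET_HASH is the base64-encoded HMAC-SHA256 of username + client_id,
# keyed with the app client secret, which is what calculate_secret_hash computes.
secret_hash = calculate_secret_hash(
    client_id="example-app-client-id",
    client_secret="example-app-client-secret",
    username="example.user",
)
print(secret_hash)

# authenticate_user falls back to AWS_USER_POOL_ID / AWS_CLIENT_ID /
# AWS_CLIENT_SECRET from tools.config when the pool arguments are omitted,
# and returns True only if Cognito responds with an AuthenticationResult.
authenticated = authenticate_user("example.user", "example-password")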
tools/aws_functions.py CHANGED
@@ -1,28 +1,37 @@
1
- from typing import Type, List
2
- import pandas as pd
3
- import boto3
4
- import tempfile
5
  import os
6
- from tools.config import AWS_REGION, RUN_AWS_FUNCTIONS, DOCUMENT_REDACTION_BUCKET, SAVE_LOGS_TO_CSV
 
 
 
 
 
 
 
 
 
 
 
7
  PandasDataFrame = Type[pd.DataFrame]
8
 
 
9
  def get_assumed_role_info():
10
- sts_endpoint = 'https://sts.' + AWS_REGION + '.amazonaws.com'
11
- sts = boto3.client('sts', region_name=AWS_REGION, endpoint_url=sts_endpoint)
12
  response = sts.get_caller_identity()
13
 
14
  # Extract ARN of the assumed role
15
- assumed_role_arn = response['Arn']
16
-
17
  # Extract the name of the assumed role from the ARN
18
- assumed_role_name = assumed_role_arn.split('/')[-1]
19
-
20
  return assumed_role_arn, assumed_role_name
21
 
 
22
  if RUN_AWS_FUNCTIONS == "1":
23
- try:
24
- session = boto3.Session(region_name=AWS_REGION)
25
-
26
  except Exception as e:
27
  print("Could not start boto3 session:", e)
28
 
@@ -30,14 +39,20 @@ if RUN_AWS_FUNCTIONS == "1":
30
  assumed_role_arn, assumed_role_name = get_assumed_role_info()
31
 
32
  print("Successfully assumed ARN role")
33
- #print("Assumed Role ARN:", assumed_role_arn)
34
- #print("Assumed Role Name:", assumed_role_name)
35
 
36
  except Exception as e:
37
  print("Could not get assumed role from STS:", e)
38
 
 
39
  # Download direct from S3 - requires login credentials
40
- def download_file_from_s3(bucket_name:str, key:str, local_file_path_and_name:str, RUN_AWS_FUNCTIONS:str = RUN_AWS_FUNCTIONS):
 
 
 
 
 
41
 
42
  if RUN_AWS_FUNCTIONS == "1":
43
 
@@ -45,30 +60,39 @@ def download_file_from_s3(bucket_name:str, key:str, local_file_path_and_name:str
45
  # Ensure the local directory exists
46
  os.makedirs(os.path.dirname(local_file_path_and_name), exist_ok=True)
47
 
48
- s3 = boto3.client('s3', region_name=AWS_REGION)
49
  s3.download_file(bucket_name, key, local_file_path_and_name)
50
- print(f"File downloaded from s3://{bucket_name}/{key} to {local_file_path_and_name}")
 
 
51
  except Exception as e:
52
  print("Could not download file:", key, "from s3 due to", e)
53
 
54
-
55
- def download_folder_from_s3(bucket_name:str, s3_folder:str, local_folder:str, RUN_AWS_FUNCTIONS:str = RUN_AWS_FUNCTIONS):
 
 
 
 
 
56
  """
57
  Download all files from an S3 folder to a local folder.
58
  """
59
  if RUN_AWS_FUNCTIONS == "1":
60
  if bucket_name and s3_folder and local_folder:
61
 
62
- s3 = boto3.client('s3', region_name=AWS_REGION)
63
 
64
  # List objects in the specified S3 folder
65
  response = s3.list_objects_v2(Bucket=bucket_name, Prefix=s3_folder)
66
 
67
  # Download each object
68
- for obj in response.get('Contents', []):
69
  # Extract object key and construct local file path
70
- object_key = obj['Key']
71
- local_file_path = os.path.join(local_folder, os.path.relpath(object_key, s3_folder))
 
 
72
 
73
  # Create directories if necessary
74
  os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
@@ -76,12 +100,24 @@ def download_folder_from_s3(bucket_name:str, s3_folder:str, local_folder:str, RU
76
  # Download the object
77
  try:
78
  s3.download_file(bucket_name, object_key, local_file_path)
79
- print(f"Downloaded 's3://{bucket_name}/{object_key}' to '{local_file_path}'")
 
 
80
  except Exception as e:
81
  print(f"Error downloading 's3://{bucket_name}/{object_key}':", e)
82
- else: print("One or more required variables are empty, could not download from S3")
83
-
84
- def download_files_from_s3(bucket_name:str, s3_folder:str, local_folder:str, filenames:List[str], RUN_AWS_FUNCTIONS:str = RUN_AWS_FUNCTIONS):
 
 
 
 
 
 
 
 
 
 
85
  """
86
  Download specific files from an S3 folder to a local folder.
87
  """
@@ -89,18 +125,20 @@ def download_files_from_s3(bucket_name:str, s3_folder:str, local_folder:str, fil
89
  if RUN_AWS_FUNCTIONS == "1":
90
  if bucket_name and s3_folder and local_folder and filenames:
91
 
92
- s3 = boto3.client('s3', region_name=AWS_REGION)
93
 
94
  print("Trying to download file: ", filenames)
95
 
96
- if filenames == '*':
97
  # List all objects in the S3 folder
98
  print("Trying to download all files in AWS folder: ", s3_folder)
99
  response = s3.list_objects_v2(Bucket=bucket_name, Prefix=s3_folder)
100
 
101
- print("Found files in AWS folder: ", response.get('Contents', []))
102
 
103
- filenames = [obj['Key'].split('/')[-1] for obj in response.get('Contents', [])]
 
 
104
 
105
  print("Found filenames in AWS folder: ", filenames)
106
 
@@ -114,13 +152,24 @@ def download_files_from_s3(bucket_name:str, s3_folder:str, local_folder:str, fil
114
  # Download the object
115
  try:
116
  s3.download_file(bucket_name, object_key, local_file_path)
117
- print(f"Downloaded 's3://{bucket_name}/{object_key}' to '{local_file_path}'")
 
 
118
  except Exception as e:
119
  print(f"Error downloading 's3://{bucket_name}/{object_key}':", e)
120
 
121
- else: print("One or more required variables are empty, could not download from S3")
 
 
 
122
 
123
- def upload_file_to_s3(local_file_paths:List[str], s3_key:str, s3_bucket:str=DOCUMENT_REDACTION_BUCKET, RUN_AWS_FUNCTIONS:str = RUN_AWS_FUNCTIONS):
 
 
 
 
 
 
124
  """
125
  Uploads a file from local machine to Amazon S3.
126
 
@@ -139,14 +188,14 @@ def upload_file_to_s3(local_file_paths:List[str], s3_key:str, s3_bucket:str=DOCU
139
  try:
140
  if s3_bucket and s3_key and local_file_paths:
141
 
142
- s3_client = boto3.client('s3', region_name=AWS_REGION)
143
 
144
  if isinstance(local_file_paths, str):
145
  local_file_paths = [local_file_paths]
146
 
147
  for file in local_file_paths:
148
  if s3_client:
149
- #print(s3_client)
150
  try:
151
  # Get file name off file path
152
  file_name = os.path.basename(file)
@@ -155,28 +204,40 @@ def upload_file_to_s3(local_file_paths:List[str], s3_key:str, s3_bucket:str=DOCU
155
  print("S3 key: ", s3_key_full)
156
 
157
  s3_client.upload_file(file, s3_bucket, s3_key_full)
158
- out_message = "File " + file_name + " uploaded successfully!"
 
 
159
  print(out_message)
160
-
161
  except Exception as e:
162
  out_message = f"Error uploading file(s): {e}"
163
  print(out_message)
164
 
165
  final_out_message.append(out_message)
166
- final_out_message_str = '\n'.join(final_out_message)
167
-
168
- else: final_out_message_str = "Could not connect to AWS."
169
- else: final_out_message_str = "At least one essential variable is empty, could not upload to S3"
 
 
 
 
170
  except Exception as e:
171
  final_out_message_str = "Could not upload files to S3 due to: " + str(e)
172
  print(final_out_message_str)
173
  else:
174
- final_out_message_str = "App config will not AWS functions"
175
 
176
  return final_out_message_str
177
 
178
 
179
- def upload_log_file_to_s3(local_file_paths:List[str], s3_key:str, s3_bucket:str=DOCUMENT_REDACTION_BUCKET, RUN_AWS_FUNCTIONS:str = RUN_AWS_FUNCTIONS, SAVE_LOGS_TO_CSV:str=SAVE_LOGS_TO_CSV):
 
 
 
 
 
 
180
  """
181
  Uploads a log file from local machine to Amazon S3.
182
 
@@ -195,14 +256,14 @@ def upload_log_file_to_s3(local_file_paths:List[str], s3_key:str, s3_bucket:str=
195
  try:
196
  if s3_bucket and s3_key and local_file_paths:
197
 
198
- s3_client = boto3.client('s3', region_name=AWS_REGION)
199
 
200
  if isinstance(local_file_paths, str):
201
  local_file_paths = [local_file_paths]
202
 
203
  for file in local_file_paths:
204
  if s3_client:
205
- #print(s3_client)
206
  try:
207
  # Get file name off file path
208
  file_name = os.path.basename(file)
@@ -211,23 +272,29 @@ def upload_log_file_to_s3(local_file_paths:List[str], s3_key:str, s3_bucket:str=
211
  print("S3 key: ", s3_key_full)
212
 
213
  s3_client.upload_file(file, s3_bucket, s3_key_full)
214
- out_message = "File " + file_name + " uploaded successfully!"
 
 
215
  print(out_message)
216
-
217
  except Exception as e:
218
  out_message = f"Error uploading file(s): {e}"
219
  print(out_message)
220
 
221
  final_out_message.append(out_message)
222
- final_out_message_str = '\n'.join(final_out_message)
223
-
224
- else: final_out_message_str = "Could not connect to AWS."
225
- else: final_out_message_str = "At least one essential variable is empty, could not upload to S3"
 
 
 
 
226
  except Exception as e:
227
  final_out_message_str = "Could not upload files to S3 due to: " + str(e)
228
  print(final_out_message_str)
229
  else:
230
- final_out_message_str = "App config will not AWS functions"
231
  print(final_out_message_str)
232
 
233
  return final_out_message_str
 
 
 
 
 
1
  import os
2
+ from typing import List, Type
3
+
4
+ import boto3
5
+ import pandas as pd
6
+
7
+ from tools.config import (
8
+ AWS_REGION,
9
+ DOCUMENT_REDACTION_BUCKET,
10
+ RUN_AWS_FUNCTIONS,
11
+ SAVE_LOGS_TO_CSV,
12
+ )
13
+
14
  PandasDataFrame = Type[pd.DataFrame]
15
 
16
+
17
  def get_assumed_role_info():
18
+ sts_endpoint = "https://sts." + AWS_REGION + ".amazonaws.com"
19
+ sts = boto3.client("sts", region_name=AWS_REGION, endpoint_url=sts_endpoint)
20
  response = sts.get_caller_identity()
21
 
22
  # Extract ARN of the assumed role
23
+ assumed_role_arn = response["Arn"]
24
+
25
  # Extract the name of the assumed role from the ARN
26
+ assumed_role_name = assumed_role_arn.split("/")[-1]
27
+
28
  return assumed_role_arn, assumed_role_name
29
 
30
+
31
  if RUN_AWS_FUNCTIONS == "1":
32
+ try:
33
+ session = boto3.Session(region_name=AWS_REGION)
34
+
35
  except Exception as e:
36
  print("Could not start boto3 session:", e)
37
 
 
39
  assumed_role_arn, assumed_role_name = get_assumed_role_info()
40
 
41
  print("Successfully assumed ARN role")
42
+ # print("Assumed Role ARN:", assumed_role_arn)
43
+ # print("Assumed Role Name:", assumed_role_name)
44
 
45
  except Exception as e:
46
  print("Could not get assumed role from STS:", e)
47
 
48
+
49
  # Download direct from S3 - requires login credentials
50
+ def download_file_from_s3(
51
+ bucket_name: str,
52
+ key: str,
53
+ local_file_path_and_name: str,
54
+ RUN_AWS_FUNCTIONS: str = RUN_AWS_FUNCTIONS,
55
+ ):
56
 
57
  if RUN_AWS_FUNCTIONS == "1":
58
 
 
60
  # Ensure the local directory exists
61
  os.makedirs(os.path.dirname(local_file_path_and_name), exist_ok=True)
62
 
63
+ s3 = boto3.client("s3", region_name=AWS_REGION)
64
  s3.download_file(bucket_name, key, local_file_path_and_name)
65
+ print(
66
+ f"File downloaded from s3://{bucket_name}/{key} to {local_file_path_and_name}"
67
+ )
68
  except Exception as e:
69
  print("Could not download file:", key, "from s3 due to", e)
70
 
71
+
72
+ def download_folder_from_s3(
73
+ bucket_name: str,
74
+ s3_folder: str,
75
+ local_folder: str,
76
+ RUN_AWS_FUNCTIONS: str = RUN_AWS_FUNCTIONS,
77
+ ):
78
  """
79
  Download all files from an S3 folder to a local folder.
80
  """
81
  if RUN_AWS_FUNCTIONS == "1":
82
  if bucket_name and s3_folder and local_folder:
83
 
84
+ s3 = boto3.client("s3", region_name=AWS_REGION)
85
 
86
  # List objects in the specified S3 folder
87
  response = s3.list_objects_v2(Bucket=bucket_name, Prefix=s3_folder)
88
 
89
  # Download each object
90
+ for obj in response.get("Contents", []):
91
  # Extract object key and construct local file path
92
+ object_key = obj["Key"]
93
+ local_file_path = os.path.join(
94
+ local_folder, os.path.relpath(object_key, s3_folder)
95
+ )
96
 
97
  # Create directories if necessary
98
  os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
 
100
  # Download the object
101
  try:
102
  s3.download_file(bucket_name, object_key, local_file_path)
103
+ print(
104
+ f"Downloaded 's3://{bucket_name}/{object_key}' to '{local_file_path}'"
105
+ )
106
  except Exception as e:
107
  print(f"Error downloading 's3://{bucket_name}/{object_key}':", e)
108
+ else:
109
+ print(
110
+ "One or more required variables are empty, could not download from S3"
111
+ )
112
+
113
+
114
+ def download_files_from_s3(
115
+ bucket_name: str,
116
+ s3_folder: str,
117
+ local_folder: str,
118
+ filenames: List[str],
119
+ RUN_AWS_FUNCTIONS: str = RUN_AWS_FUNCTIONS,
120
+ ):
121
  """
122
  Download specific files from an S3 folder to a local folder.
123
  """
 
125
  if RUN_AWS_FUNCTIONS == "1":
126
  if bucket_name and s3_folder and local_folder and filenames:
127
 
128
+ s3 = boto3.client("s3", region_name=AWS_REGION)
129
 
130
  print("Trying to download file: ", filenames)
131
 
132
+ if filenames == "*":
133
  # List all objects in the S3 folder
134
  print("Trying to download all files in AWS folder: ", s3_folder)
135
  response = s3.list_objects_v2(Bucket=bucket_name, Prefix=s3_folder)
136
 
137
+ print("Found files in AWS folder: ", response.get("Contents", []))
138
 
139
+ filenames = [
140
+ obj["Key"].split("/")[-1] for obj in response.get("Contents", [])
141
+ ]
142
 
143
  print("Found filenames in AWS folder: ", filenames)
144
 
 
152
  # Download the object
153
  try:
154
  s3.download_file(bucket_name, object_key, local_file_path)
155
+ print(
156
+ f"Downloaded 's3://{bucket_name}/{object_key}' to '{local_file_path}'"
157
+ )
158
  except Exception as e:
159
  print(f"Error downloading 's3://{bucket_name}/{object_key}':", e)
160
 
161
+ else:
162
+ print(
163
+ "One or more required variables are empty, could not download from S3"
164
+ )
165
 
166
+
167
+ def upload_file_to_s3(
168
+ local_file_paths: List[str],
169
+ s3_key: str,
170
+ s3_bucket: str = DOCUMENT_REDACTION_BUCKET,
171
+ RUN_AWS_FUNCTIONS: str = RUN_AWS_FUNCTIONS,
172
+ ):
173
  """
174
  Uploads a file from local machine to Amazon S3.
175
 
 
188
  try:
189
  if s3_bucket and s3_key and local_file_paths:
190
 
191
+ s3_client = boto3.client("s3", region_name=AWS_REGION)
192
 
193
  if isinstance(local_file_paths, str):
194
  local_file_paths = [local_file_paths]
195
 
196
  for file in local_file_paths:
197
  if s3_client:
198
+ # print(s3_client)
199
  try:
200
  # Get file name off file path
201
  file_name = os.path.basename(file)
 
204
  print("S3 key: ", s3_key_full)
205
 
206
  s3_client.upload_file(file, s3_bucket, s3_key_full)
207
+ out_message = (
208
+ "File " + file_name + " uploaded successfully!"
209
+ )
210
  print(out_message)
211
+
212
  except Exception as e:
213
  out_message = f"Error uploading file(s): {e}"
214
  print(out_message)
215
 
216
  final_out_message.append(out_message)
217
+ final_out_message_str = "\n".join(final_out_message)
218
+
219
+ else:
220
+ final_out_message_str = "Could not connect to AWS."
221
+ else:
222
+ final_out_message_str = (
223
+ "At least one essential variable is empty, could not upload to S3"
224
+ )
225
  except Exception as e:
226
  final_out_message_str = "Could not upload files to S3 due to: " + str(e)
227
  print(final_out_message_str)
228
  else:
229
+ final_out_message_str = "App config will not run AWS functions"
230
 
231
  return final_out_message_str
232
 
233
 
234
+ def upload_log_file_to_s3(
235
+ local_file_paths: List[str],
236
+ s3_key: str,
237
+ s3_bucket: str = DOCUMENT_REDACTION_BUCKET,
238
+ RUN_AWS_FUNCTIONS: str = RUN_AWS_FUNCTIONS,
239
+ SAVE_LOGS_TO_CSV: str = SAVE_LOGS_TO_CSV,
240
+ ):
241
  """
242
  Uploads a log file from local machine to Amazon S3.
243
 
 
256
  try:
257
  if s3_bucket and s3_key and local_file_paths:
258
 
259
+ s3_client = boto3.client("s3", region_name=AWS_REGION)
260
 
261
  if isinstance(local_file_paths, str):
262
  local_file_paths = [local_file_paths]
263
 
264
  for file in local_file_paths:
265
  if s3_client:
266
+ # print(s3_client)
267
  try:
268
  # Get file name off file path
269
  file_name = os.path.basename(file)
 
272
  print("S3 key: ", s3_key_full)
273
 
274
  s3_client.upload_file(file, s3_bucket, s3_key_full)
275
+ out_message = (
276
+ "File " + file_name + " uploaded successfully!"
277
+ )
278
  print(out_message)
279
+
280
  except Exception as e:
281
  out_message = f"Error uploading file(s): {e}"
282
  print(out_message)
283
 
284
  final_out_message.append(out_message)
285
+ final_out_message_str = "\n".join(final_out_message)
286
+
287
+ else:
288
+ final_out_message_str = "Could not connect to AWS."
289
+ else:
290
+ final_out_message_str = (
291
+ "At least one essential variable is empty, could not upload to S3"
292
+ )
293
  except Exception as e:
294
  final_out_message_str = "Could not upload files to S3 due to: " + str(e)
295
  print(final_out_message_str)
296
  else:
297
+ final_out_message_str = "App config will not run AWS functions"
298
  print(final_out_message_str)
299
 
300
  return final_out_message_str
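Note: a minimal usage sketch for the reformatted S3 helpers; the bucket, key and local paths below are placeholders, and each helper skips its S3 call unless RUN_AWS_FUNCTIONS is set to "1" in tools.config:

from tools.aws_functions import download_file_from_s3, upload_file_to_s3

# Pull a single object down to a local path (parent directories are created if needed).
download_file_from_s3(
    bucket_name="example-bucket",
    key="input/example_document.pdf",
    local_file_path_and_name="input/example_document.pdf",
)

# Push one or more local files up under an S3 key prefix; the helper returns a
# human-readable status message rather than raising on most failures.
status = upload_file_to_s3(
    local_file_paths=["output/example_document_redacted.pdf"],
    s3_key="output/",
)
print(status)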
tools/aws_textract.py CHANGED
@@ -1,27 +1,44 @@
1
- import boto3
2
- from typing import List
3
  import io
4
- import os
5
  import json
6
- import pikepdf
7
  import time
 
 
 
8
  import pandas as pd
9
- from tools.custom_image_analyser_engine import OCRResult, CustomImageRecognizerResult
10
- from tools.config import AWS_ACCESS_KEY, AWS_SECRET_KEY, AWS_REGION, RUN_AWS_FUNCTIONS, PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS
 
 
 
 
 
 
 
 
 
11
 
12
- def extract_textract_metadata(response:object):
13
  """Extracts metadata from an AWS Textract response."""
14
 
15
- request_id = response['ResponseMetadata']['RequestId']
16
- pages = response['DocumentMetadata']['Pages']
17
 
18
- return str({
19
- 'RequestId': request_id,
20
- 'Pages': pages
21
- })
22
 
23
- def analyse_page_with_textract(pdf_page_bytes:object, page_no:int, client:str="", handwrite_signature_checkbox:List[str]=["Extract handwriting"], textract_output_found:bool=False, aws_access_key_textbox:str=AWS_ACCESS_KEY, aws_secret_key_textbox:str=AWS_SECRET_KEY, RUN_AWS_FUNCTIONS:str=RUN_AWS_FUNCTIONS, PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS:str=PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS):
24
- '''
 
 
 
 
 
 
 
 
 
 
 
25
  Analyzes a single page of a document using AWS Textract to extract text and other features.
26
 
27
  Args:
@@ -53,45 +70,63 @@ def analyse_page_with_textract(pdf_page_bytes:object, page_no:int, client:str=""
53
  Tuple[List[Dict], str]: A tuple containing:
54
  - A list of dictionaries, where each dictionary represents a Textract block (e.g., LINE, WORD, FORM, TABLE).
55
  - A string containing metadata about the Textract request.
56
- '''
57
 
58
- #print("handwrite_signature_checkbox in analyse_page_with_textract:", handwrite_signature_checkbox)
59
  if client == "":
60
- try:
61
  # Try to connect to AWS Textract Client if using that text extraction method
62
- if RUN_AWS_FUNCTIONS == "1" and PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS == "1":
 
 
 
63
  print("Connecting to Textract via existing SSO connection")
64
- client = boto3.client('textract', region_name=AWS_REGION)
65
  elif aws_access_key_textbox and aws_secret_key_textbox:
66
- print("Connecting to Textract using AWS access key and secret keys from user input.")
67
- client = boto3.client('textract',
68
- aws_access_key_id=aws_access_key_textbox,
69
- aws_secret_access_key=aws_secret_key_textbox, region_name=AWS_REGION)
 
 
 
 
 
70
  elif RUN_AWS_FUNCTIONS == "1":
71
  print("Connecting to Textract via existing SSO connection")
72
- client = boto3.client('textract', region_name=AWS_REGION)
73
  elif AWS_ACCESS_KEY and AWS_SECRET_KEY:
74
  print("Getting Textract credentials from environment variables.")
75
- client = boto3.client('textract',
76
- aws_access_key_id=AWS_ACCESS_KEY,
77
- aws_secret_access_key=AWS_SECRET_KEY, region_name=AWS_REGION)
78
- elif textract_output_found==True:
79
- print("Existing Textract data found for file, no need to connect to AWS Textract")
80
- client = boto3.client('textract', region_name=AWS_REGION)
 
 
 
 
 
81
  else:
82
  client = ""
83
  out_message = "Cannot connect to AWS Textract service."
84
  print(out_message)
85
  raise Exception(out_message)
86
- except:
87
  out_message = "Cannot connect to AWS Textract"
88
- print(out_message)
89
  raise Exception(out_message)
90
  return [], "" # Return an empty list and an empty string
91
-
92
  # Redact signatures if specified
93
  feature_types = list()
94
- if "Extract signatures" in handwrite_signature_checkbox or "Extract forms" in handwrite_signature_checkbox or "Extract layout" in handwrite_signature_checkbox or "Extract tables" in handwrite_signature_checkbox:
 
 
 
 
 
95
  if "Extract signatures" in handwrite_signature_checkbox:
96
  feature_types.append("SIGNATURES")
97
  if "Extract forms" in handwrite_signature_checkbox:
@@ -101,39 +136,50 @@ def analyse_page_with_textract(pdf_page_bytes:object, page_no:int, client:str=""
101
  if "Extract tables" in handwrite_signature_checkbox:
102
  feature_types.append("TABLES")
103
  try:
104
- response = client.analyze_document(Document={'Bytes': pdf_page_bytes}, FeatureTypes=feature_types)
 
 
105
  except Exception as e:
106
  print("Textract call failed due to:", e, "trying again in 3 seconds.")
107
  time.sleep(3)
108
- response = client.analyze_document(Document={'Bytes': pdf_page_bytes}, FeatureTypes=feature_types)
109
-
110
-
111
- if not "Extract signatures" in handwrite_signature_checkbox and not "Extract forms" in handwrite_signature_checkbox and not "Extract layout" in handwrite_signature_checkbox and not "Extract tables" in handwrite_signature_checkbox:
112
  # Call detect_document_text to extract plain text
113
  try:
114
- response = client.detect_document_text(Document={'Bytes': pdf_page_bytes})
115
  except Exception as e:
116
  print("Textract call failed due to:", e, "trying again in 5 seconds.")
117
  time.sleep(5)
118
- response = client.detect_document_text(Document={'Bytes': pdf_page_bytes})
119
 
120
- # Add the 'Page' attribute to each block
121
  if "Blocks" in response:
122
  for block in response["Blocks"]:
123
  block["Page"] = page_no # Inject the page number into each block
124
 
125
  # Wrap the response with the page number in the desired format
126
- wrapped_response = {
127
- 'page_no': page_no,
128
- 'data': response
129
- }
130
 
131
- request_metadata = extract_textract_metadata(response) # Metadata comes out as a string
 
 
132
 
133
  # Return a list containing the wrapped response and the metadata
134
- return wrapped_response, request_metadata # Return as a list to match the desired structure
 
136
- def convert_pike_pdf_page_to_bytes(pdf:object, page_num:int):
137
  # Create a new empty PDF
138
  new_pdf = pikepdf.Pdf.new()
139
 
@@ -155,8 +201,11 @@ def convert_pike_pdf_page_to_bytes(pdf:object, page_num:int):
155
 
156
  return pdf_bytes
157
 
158
- def json_to_ocrresult(json_data:dict, page_width:float, page_height:float, page_no:int):
159
- '''
160
  Convert the json response from Textract to the OCRResult format used elsewhere in the code.
161
  Looks for lines, words, and signatures. Handwriting and signatures are set aside especially
162
  for later in case the user wants to override the default behaviour and redact all
@@ -175,7 +224,7 @@ def json_to_ocrresult(json_data:dict, page_width:float, page_height:float, page_
175
  - list: Bounding boxes identified specifically as signatures.
176
  - list: Bounding boxes identified specifically as handwriting.
177
  - dict: OCR results with word-level detail, structured for further processing.
178
- '''
179
  all_ocr_results = list()
180
  signature_or_handwriting_recogniser_results = list()
181
  signature_recogniser_results = list()
@@ -183,14 +232,14 @@ def json_to_ocrresult(json_data:dict, page_width:float, page_height:float, page_
183
  signatures = list()
184
  handwriting = list()
185
  ocr_results_with_words = dict()
186
- text_block=dict()
187
 
188
  text_line_number = 1
189
 
190
  # Assuming json_data is structured as a dictionary with a "pages" key
191
 
192
  # Find the specific page data
193
- page_json_data = json_data #next((page for page in json_data["pages"] if page["page_no"] == page_no), None)
194
 
195
  if "Blocks" in page_json_data:
196
  # Access the data for the specific page
@@ -198,14 +247,17 @@ def json_to_ocrresult(json_data:dict, page_width:float, page_height:float, page_
198
  # This is a new page
199
  elif "page_no" in page_json_data:
200
  text_blocks = page_json_data["data"]["Blocks"]
201
- else: text_blocks = []
 
202
 
203
  is_signature = False
204
  is_handwriting = False
205
 
206
- for text_block in text_blocks:
207
-
208
- if (text_block['BlockType'] == 'LINE') | (text_block['BlockType'] == 'SIGNATURE'): # (text_block['BlockType'] == 'WORD') |
 
 
209
 
210
  # Extract text and bounding box for the line
211
  line_bbox = text_block["Geometry"]["BoundingBox"]
@@ -217,26 +269,41 @@ def json_to_ocrresult(json_data:dict, page_width:float, page_height:float, page_
217
  width_abs = int(line_bbox["Width"] * page_width)
218
  height_abs = int(line_bbox["Height"] * page_height)
219
 
220
- if text_block['BlockType'] == 'LINE':
221
-
222
  # Extract text and bounding box for the line
223
- line_text = text_block.get('Text', '')
224
  words = []
225
- current_line_handwriting_results = [] # Track handwriting results for this line
226
-
227
- if 'Relationships' in text_block:
228
- for relationship in text_block['Relationships']:
229
- if relationship['Type'] == 'CHILD':
230
- for child_id in relationship['Ids']:
231
- child_block = next((block for block in text_blocks if block['Id'] == child_id), None)
232
- if child_block and child_block['BlockType'] == 'WORD':
233
- word_text = child_block.get('Text', '')
234
  word_bbox = child_block["Geometry"]["BoundingBox"]
235
- confidence = child_block.get('Confidence','')
236
  word_left = int(word_bbox["Left"] * page_width)
237
  word_top = int(word_bbox["Top"] * page_height)
238
- word_right = int((word_bbox["Left"] + word_bbox["Width"]) * page_width)
239
- word_bottom = int((word_bbox["Top"] + word_bbox["Height"]) * page_height)
240
 
241
  # Extract BoundingBox details
242
  word_width = word_bbox["Width"]
@@ -245,13 +312,20 @@ def json_to_ocrresult(json_data:dict, page_width:float, page_height:float, page_
245
  # Convert proportional coordinates to absolute coordinates
246
  word_width_abs = int(word_width * page_width)
247
  word_height_abs = int(word_height * page_height)
248
-
249
- words.append({
250
- 'text': word_text,
251
- 'bounding_box': (word_left, word_top, word_right, word_bottom)
252
- })
253
  # Check for handwriting
254
- text_type = child_block.get("TextType", '')
255
 
256
  if text_type == "HANDWRITING":
257
  is_handwriting = True
@@ -267,22 +341,28 @@ def json_to_ocrresult(json_data:dict, page_width:float, page_height:float, page_
267
  left=word_left,
268
  top=word_top,
269
  width=word_width_abs,
270
- height=word_height_abs
271
  )
272
 
273
  # Add to handwriting collections immediately
274
  handwriting.append(recogniser_result)
275
- handwriting_recogniser_results.append(recogniser_result)
276
- signature_or_handwriting_recogniser_results.append(recogniser_result)
277
- current_line_handwriting_results.append(recogniser_result)
278
 
279
- # If handwriting or signature, add to bounding box
280
 
281
- elif (text_block['BlockType'] == 'SIGNATURE'):
282
  line_text = "SIGNATURE"
283
  is_signature = True
284
  entity_name = "SIGNATURE"
285
- confidence = text_block.get('Confidence', 0)
286
  word_end = len(line_text)
287
 
288
  recogniser_result = CustomImageRecognizerResult(
@@ -294,7 +374,7 @@ def json_to_ocrresult(json_data:dict, page_width:float, page_height:float, page_
294
  left=line_left,
295
  top=line_top,
296
  width=width_abs,
297
- height=height_abs
298
  )
299
 
300
  # Add to signature collections immediately
@@ -302,13 +382,15 @@ def json_to_ocrresult(json_data:dict, page_width:float, page_height:float, page_
302
  signature_recogniser_results.append(recogniser_result)
303
  signature_or_handwriting_recogniser_results.append(recogniser_result)
304
 
305
- words = [{
306
- 'text': line_text,
307
- 'bounding_box': (line_left, line_top, line_right, line_bottom)
308
- }]
 
 
309
  else:
310
  line_text = ""
311
- words=[]
312
  line_left = 0
313
  line_top = 0
314
  line_right = 0
@@ -320,14 +402,22 @@ def json_to_ocrresult(json_data:dict, page_width:float, page_height:float, page_
320
 
321
  ocr_results_with_words["text_line_" + str(text_line_number)] = {
322
  "line": text_line_number,
323
- 'text': line_text,
324
- 'bounding_box': (line_left, line_top, line_right, line_bottom),
325
- 'words': words,
326
- 'page': page_no
327
  }
328
 
329
  # Create OCRResult with absolute coordinates
330
- ocr_result = OCRResult(line_text, line_left, line_top, width_abs, height_abs, conf=confidence, line=text_line_number)
331
  all_ocr_results.append(ocr_result)
332
 
333
  # Increase line number
@@ -337,35 +427,50 @@ def json_to_ocrresult(json_data:dict, page_width:float, page_height:float, page_
337
 
338
  # If it is signature or handwriting, will overwrite the default behaviour of the PII analyser
339
  if is_signature_or_handwriting:
340
- if recogniser_result not in signature_or_handwriting_recogniser_results:
341
  signature_or_handwriting_recogniser_results.append(recogniser_result)
342
 
343
  if is_signature:
344
- if recogniser_result not in signature_recogniser_results:
345
  signature_recogniser_results.append(recogniser_result)
346
 
347
- if is_handwriting:
348
- if recogniser_result not in handwriting_recogniser_results:
349
  handwriting_recogniser_results.append(recogniser_result)
350
 
351
-
352
-
353
  # Add page key to the line level results
354
  all_ocr_results_with_page = {"page": page_no, "results": all_ocr_results}
355
- ocr_results_with_words_with_page = {"page": page_no, "results": ocr_results_with_words}
356
 
357
- return all_ocr_results_with_page, signature_or_handwriting_recogniser_results, signature_recogniser_results, handwriting_recogniser_results, ocr_results_with_words_with_page
358
 
359
- def load_and_convert_textract_json(textract_json_file_path:str, log_files_output_paths:str, page_sizes_df:pd.DataFrame):
360
  """
361
  Loads Textract JSON from a file, detects if conversion is needed, and converts if necessary.
362
  """
363
-
364
  if not os.path.exists(textract_json_file_path):
365
  print("No existing Textract results file found.")
366
- return {}, True, log_files_output_paths # Return empty dict and flag indicating missing file
367
-
368
- no_textract_file = False
369
  print("Found existing Textract json results file.")
370
 
371
  # Track log files
@@ -373,7 +478,7 @@ def load_and_convert_textract_json(textract_json_file_path:str, log_files_output
373
  log_files_output_paths.append(textract_json_file_path)
374
 
375
  try:
376
- with open(textract_json_file_path, 'r', encoding='utf-8') as json_file:
377
  textract_data = json.load(json_file)
378
  except json.JSONDecodeError:
379
  print("Error: Failed to parse Textract JSON file. Returning empty data.")
@@ -387,21 +492,30 @@ def load_and_convert_textract_json(textract_json_file_path:str, log_files_output
387
  if "Blocks" in textract_data:
388
  print("Need to convert Textract JSON to app format.")
389
  try:
390
-
391
  textract_data = restructure_textract_output(textract_data, page_sizes_df)
392
- return textract_data, False, log_files_output_paths # Successfully converted
393
-
394
  except Exception as e:
395
  print("Failed to convert JSON data to app format due to:", e)
396
  return {}, True, log_files_output_paths # Conversion failed
397
  else:
398
  print("Invalid Textract JSON format: 'Blocks' missing.")
399
- #print("textract data:", textract_data)
400
- return {}, True, log_files_output_paths # Return empty data if JSON is not recognized
401
 
402
- def restructure_textract_output(textract_output: dict, page_sizes_df:pd.DataFrame):
 
403
  """
404
- Reorganise Textract output from the bulk Textract analysis option on AWS
405
  into a format that works in this redaction app, reducing size.
406
  """
407
  pages_dict = {}
@@ -410,8 +524,8 @@ def restructure_textract_output(textract_output: dict, page_sizes_df:pd.DataFram
410
  document_metadata = textract_output.get("DocumentMetadata", {})
411
 
412
  # For efficient lookup, set 'page' as index if it's not already
413
- if 'page' in page_sizes_df.columns:
414
- page_sizes_df = page_sizes_df.set_index('page')
415
 
416
  for block in textract_output.get("Blocks", []):
417
  page_no = block.get("Page", 1) # Default to 1 if missing
@@ -419,29 +533,30 @@ def restructure_textract_output(textract_output: dict, page_sizes_df:pd.DataFram
419
  # --- Geometry Conversion Logic ---
420
  try:
421
  page_info = page_sizes_df.loc[page_no]
422
- cb_width = page_info['cropbox_width']
423
- cb_height = page_info['cropbox_height']
424
- mb_width = page_info['mediabox_width']
425
- mb_height = page_info['mediabox_height']
426
- cb_x_offset = page_info['cropbox_x_offset']
427
- cb_y_offset_top = page_info['cropbox_y_offset_from_top']
428
 
429
  # Check if conversion is needed (and avoid division by zero)
430
  needs_conversion = (
431
- abs(cb_width - mb_width) > 1e-6 or \
432
- abs(cb_height - mb_height) > 1e-6
433
- ) and mb_width > 1e-6 and mb_height > 1e-6 # Avoid division by zero
 
434
 
435
- if needs_conversion and 'Geometry' in block:
436
- geometry = block['Geometry'] # Work directly on the block's geometry
437
 
438
  # --- Convert BoundingBox ---
439
- if 'BoundingBox' in geometry:
440
- bbox = geometry['BoundingBox']
441
- old_left = bbox['Left']
442
- old_top = bbox['Top']
443
- old_width = bbox['Width']
444
- old_height = bbox['Height']
445
 
446
  # Calculate absolute coordinates within CropBox
447
  abs_cb_x = old_left * cb_width
@@ -454,15 +569,19 @@ def restructure_textract_output(textract_output: dict, page_sizes_df:pd.DataFram
454
  abs_mb_y = cb_y_offset_top + abs_cb_y
455
 
456
  # Convert back to normalized coordinates relative to MediaBox
457
- bbox['Left'] = abs_mb_x / mb_width
458
- bbox['Top'] = abs_mb_y / mb_height
459
- bbox['Width'] = abs_cb_width / mb_width
460
- bbox['Height'] = abs_cb_height / mb_height
461
  except KeyError:
462
- print(f"Warning: Page number {page_no} not found in page_sizes_df. Skipping coordinate conversion for this block.")
 
 
463
  # Decide how to handle missing page info: skip conversion, raise error, etc.
464
  except ZeroDivisionError:
465
- print(f"Warning: MediaBox width or height is zero for page {page_no}. Skipping coordinate conversion for this block.")
 
 
466
 
467
  # Initialise page structure if not already present
468
  if page_no not in pages_dict:
@@ -470,16 +589,25 @@ def restructure_textract_output(textract_output: dict, page_sizes_df:pd.DataFram
470
 
471
  # Keep only essential fields to reduce size
472
  filtered_block = {
473
- key: block[key] for key in ["BlockType", "Confidence", "Text", "Geometry", "Page", "Id", "Relationships"]
474
  if key in block
475
  }
476
-
477
  pages_dict[page_no]["data"]["Blocks"].append(filtered_block)
478
 
479
  # Convert pages dictionary to a sorted list
480
  structured_output = {
481
  "DocumentMetadata": document_metadata, # Store metadata separately
482
- "pages": [pages_dict[page] for page in sorted(pages_dict.keys())]
483
  }
484
 
485
  return structured_output
 
 
 
1
  import io
 
2
  import json
3
+ import os
4
  import time
5
+ from typing import List
6
+
7
+ import boto3
8
  import pandas as pd
9
+ import pikepdf
10
+
11
+ from tools.config import (
12
+ AWS_ACCESS_KEY,
13
+ AWS_REGION,
14
+ AWS_SECRET_KEY,
15
+ PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS,
16
+ RUN_AWS_FUNCTIONS,
17
+ )
18
+ from tools.custom_image_analyser_engine import CustomImageRecognizerResult, OCRResult
19
+
20
 
21
+ def extract_textract_metadata(response: object):
22
  """Extracts metadata from an AWS Textract response."""
23
 
24
+ request_id = response["ResponseMetadata"]["RequestId"]
25
+ pages = response["DocumentMetadata"]["Pages"]
26
 
27
+ return str({"RequestId": request_id, "Pages": pages})
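For reference, a minimal sketch of the metadata string this helper produces; the response dict below is a made-up stand-in for a real Textract response:

fake_response = {
    "ResponseMetadata": {"RequestId": "example-request-id"},
    "DocumentMetadata": {"Pages": 1},
}
print(extract_textract_metadata(fake_response))
# prints "{'RequestId': 'example-request-id', 'Pages': 1}"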
28
 
29
+
30
+ def analyse_page_with_textract(
31
+ pdf_page_bytes: object,
32
+ page_no: int,
33
+ client: str = "",
34
+ handwrite_signature_checkbox: List[str] = ["Extract handwriting"],
35
+ textract_output_found: bool = False,
36
+ aws_access_key_textbox: str = AWS_ACCESS_KEY,
37
+ aws_secret_key_textbox: str = AWS_SECRET_KEY,
38
+ RUN_AWS_FUNCTIONS: str = RUN_AWS_FUNCTIONS,
39
+ PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS: str = PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS,
40
+ ):
41
+ """
42
  Analyzes a single page of a document using AWS Textract to extract text and other features.
43
 
44
  Args:
 
70
  Tuple[List[Dict], str]: A tuple containing:
71
  - A list of dictionaries, where each dictionary represents a Textract block (e.g., LINE, WORD, FORM, TABLE).
72
  - A string containing metadata about the Textract request.
73
+ """
74
 
75
+ # print("handwrite_signature_checkbox in analyse_page_with_textract:", handwrite_signature_checkbox)
76
  if client == "":
77
+ try:
78
  # Try to connect to AWS Textract Client if using that text extraction method
79
+ if (
80
+ RUN_AWS_FUNCTIONS == "1"
81
+ and PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS == "1"
82
+ ):
83
  print("Connecting to Textract via existing SSO connection")
84
+ client = boto3.client("textract", region_name=AWS_REGION)
85
  elif aws_access_key_textbox and aws_secret_key_textbox:
86
+ print(
87
+ "Connecting to Textract using AWS access key and secret keys from user input."
88
+ )
89
+ client = boto3.client(
90
+ "textract",
91
+ aws_access_key_id=aws_access_key_textbox,
92
+ aws_secret_access_key=aws_secret_key_textbox,
93
+ region_name=AWS_REGION,
94
+ )
95
  elif RUN_AWS_FUNCTIONS == "1":
96
  print("Connecting to Textract via existing SSO connection")
97
+ client = boto3.client("textract", region_name=AWS_REGION)
98
  elif AWS_ACCESS_KEY and AWS_SECRET_KEY:
99
  print("Getting Textract credentials from environment variables.")
100
+ client = boto3.client(
101
+ "textract",
102
+ aws_access_key_id=AWS_ACCESS_KEY,
103
+ aws_secret_access_key=AWS_SECRET_KEY,
104
+ region_name=AWS_REGION,
105
+ )
106
+ elif textract_output_found is True:
107
+ print(
108
+ "Existing Textract data found for file, no need to connect to AWS Textract"
109
+ )
110
+ client = boto3.client("textract", region_name=AWS_REGION)
111
  else:
112
  client = ""
113
  out_message = "Cannot connect to AWS Textract service."
114
  print(out_message)
115
  raise Exception(out_message)
116
+ except Exception as e:
117
  out_message = "Cannot connect to AWS Textract"
118
+ print(out_message, "due to:", e)
119
  raise Exception(out_message)
120
  return [], "" # Return an empty list and an empty string
121
+
122
  # Redact signatures if specified
123
  feature_types = list()
124
+ if (
125
+ "Extract signatures" in handwrite_signature_checkbox
126
+ or "Extract forms" in handwrite_signature_checkbox
127
+ or "Extract layout" in handwrite_signature_checkbox
128
+ or "Extract tables" in handwrite_signature_checkbox
129
+ ):
130
  if "Extract signatures" in handwrite_signature_checkbox:
131
  feature_types.append("SIGNATURES")
132
  if "Extract forms" in handwrite_signature_checkbox:
 
136
  if "Extract tables" in handwrite_signature_checkbox:
137
  feature_types.append("TABLES")
138
  try:
139
+ response = client.analyze_document(
140
+ Document={"Bytes": pdf_page_bytes}, FeatureTypes=feature_types
141
+ )
142
  except Exception as e:
143
  print("Textract call failed due to:", e, "trying again in 3 seconds.")
144
  time.sleep(3)
145
+ response = client.analyze_document(
146
+ Document={"Bytes": pdf_page_bytes}, FeatureTypes=feature_types
147
+ )
148
+
149
+ if (
150
+ "Extract signatures" not in handwrite_signature_checkbox
151
+ and "Extract forms" not in handwrite_signature_checkbox
152
+ and "Extract layout" not in handwrite_signature_checkbox
153
+ and "Extract tables" not in handwrite_signature_checkbox
154
+ ):
155
  # Call detect_document_text to extract plain text
156
  try:
157
+ response = client.detect_document_text(Document={"Bytes": pdf_page_bytes})
158
  except Exception as e:
159
  print("Textract call failed due to:", e, "trying again in 5 seconds.")
160
  time.sleep(5)
161
+ response = client.detect_document_text(Document={"Bytes": pdf_page_bytes})
162
 
163
+ # Add the 'Page' attribute to each block
164
  if "Blocks" in response:
165
  for block in response["Blocks"]:
166
  block["Page"] = page_no # Inject the page number into each block
167
 
168
  # Wrap the response with the page number in the desired format
169
+ wrapped_response = {"page_no": page_no, "data": response}
170
 
171
+ request_metadata = extract_textract_metadata(
172
+ response
173
+ ) # Metadata comes out as a string
174
 
175
  # Return a list containing the wrapped response and the metadata
176
+ return (
177
+ wrapped_response,
178
+ request_metadata,
179
+ ) # Return as a list to match the desired structure
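A hedged usage sketch for the function above; `page_bytes` and the checkbox values are illustrative, and a real call needs working AWS Textract credentials:

# page_bytes would typically come from convert_pike_pdf_page_to_bytes(pdf, 0)
wrapped_response, request_metadata = analyse_page_with_textract(
    page_bytes,
    page_no=1,
    handwrite_signature_checkbox=["Extract handwriting", "Extract signatures"],
)
blocks = wrapped_response["data"]["Blocks"]  # LINE/WORD/SIGNATURE blocks, each tagged with "Page"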
180
+
181
 
182
+ def convert_pike_pdf_page_to_bytes(pdf: object, page_num: int):
183
  # Create a new empty PDF
184
  new_pdf = pikepdf.Pdf.new()
185
 
 
201
 
202
  return pdf_bytes
203
 
204
+
205
+ def json_to_ocrresult(
206
+ json_data: dict, page_width: float, page_height: float, page_no: int
207
+ ):
208
+ """
209
  Convert the json response from Textract to the OCRResult format used elsewhere in the code.
210
  Looks for lines, words, and signatures. Handwriting and signatures are set aside especially
211
  for later in case the user wants to override the default behaviour and redact all
 
224
  - list: Bounding boxes identified specifically as signatures.
225
  - list: Bounding boxes identified specifically as handwriting.
226
  - dict: OCR results with word-level detail, structured for further processing.
227
+ """
228
  all_ocr_results = list()
229
  signature_or_handwriting_recogniser_results = list()
230
  signature_recogniser_results = list()
 
232
  signatures = list()
233
  handwriting = list()
234
  ocr_results_with_words = dict()
235
+ text_block = dict()
236
 
237
  text_line_number = 1
238
 
239
  # Assuming json_data is structured as a dictionary with a "pages" key
240
 
241
  # Find the specific page data
242
+ page_json_data = json_data # next((page for page in json_data["pages"] if page["page_no"] == page_no), None)
243
 
244
  if "Blocks" in page_json_data:
245
  # Access the data for the specific page
 
247
  # This is a new page
248
  elif "page_no" in page_json_data:
249
  text_blocks = page_json_data["data"]["Blocks"]
250
+ else:
251
+ text_blocks = []
252
 
253
  is_signature = False
254
  is_handwriting = False
255
 
256
+ for text_block in text_blocks:
257
+
258
+ if (text_block["BlockType"] == "LINE") | (
259
+ text_block["BlockType"] == "SIGNATURE"
260
+ ): # (text_block['BlockType'] == 'WORD') |
261
 
262
  # Extract text and bounding box for the line
263
  line_bbox = text_block["Geometry"]["BoundingBox"]
 
269
  width_abs = int(line_bbox["Width"] * page_width)
270
  height_abs = int(line_bbox["Height"] * page_height)
271
 
272
+ if text_block["BlockType"] == "LINE":
273
+
274
  # Extract text and bounding box for the line
275
+ line_text = text_block.get("Text", "")
276
  words = []
277
+ current_line_handwriting_results = (
278
+ []
279
+ ) # Track handwriting results for this line
280
+
281
+ if "Relationships" in text_block:
282
+ for relationship in text_block["Relationships"]:
283
+ if relationship["Type"] == "CHILD":
284
+ for child_id in relationship["Ids"]:
285
+ child_block = next(
286
+ (
287
+ block
288
+ for block in text_blocks
289
+ if block["Id"] == child_id
290
+ ),
291
+ None,
292
+ )
293
+ if child_block and child_block["BlockType"] == "WORD":
294
+ word_text = child_block.get("Text", "")
295
  word_bbox = child_block["Geometry"]["BoundingBox"]
296
+ confidence = child_block.get("Confidence", "")
297
  word_left = int(word_bbox["Left"] * page_width)
298
  word_top = int(word_bbox["Top"] * page_height)
299
+ word_right = int(
300
+ (word_bbox["Left"] + word_bbox["Width"])
301
+ * page_width
302
+ )
303
+ word_bottom = int(
304
+ (word_bbox["Top"] + word_bbox["Height"])
305
+ * page_height
306
+ )
307
 
308
  # Extract BoundingBox details
309
  word_width = word_bbox["Width"]
 
312
  # Convert proportional coordinates to absolute coordinates
313
  word_width_abs = int(word_width * page_width)
314
  word_height_abs = int(word_height * page_height)
315
+
316
+ words.append(
317
+ {
318
+ "text": word_text,
319
+ "bounding_box": (
320
+ word_left,
321
+ word_top,
322
+ word_right,
323
+ word_bottom,
324
+ ),
325
+ }
326
+ )
327
  # Check for handwriting
328
+ text_type = child_block.get("TextType", "")
329
 
330
  if text_type == "HANDWRITING":
331
  is_handwriting = True
 
341
  left=word_left,
342
  top=word_top,
343
  width=word_width_abs,
344
+ height=word_height_abs,
345
  )
346
 
347
  # Add to handwriting collections immediately
348
  handwriting.append(recogniser_result)
349
+ handwriting_recogniser_results.append(
350
+ recogniser_result
351
+ )
352
+ signature_or_handwriting_recogniser_results.append(
353
+ recogniser_result
354
+ )
355
+ current_line_handwriting_results.append(
356
+ recogniser_result
357
+ )
358
 
359
+ # If handwriting or signature, add to bounding box
360
 
361
+ elif text_block["BlockType"] == "SIGNATURE":
362
  line_text = "SIGNATURE"
363
  is_signature = True
364
  entity_name = "SIGNATURE"
365
+ confidence = text_block.get("Confidence", 0)
366
  word_end = len(line_text)
367
 
368
  recogniser_result = CustomImageRecognizerResult(
 
374
  left=line_left,
375
  top=line_top,
376
  width=width_abs,
377
+ height=height_abs,
378
  )
379
 
380
  # Add to signature collections immediately
 
382
  signature_recogniser_results.append(recogniser_result)
383
  signature_or_handwriting_recogniser_results.append(recogniser_result)
384
 
385
+ words = [
386
+ {
387
+ "text": line_text,
388
+ "bounding_box": (line_left, line_top, line_right, line_bottom),
389
+ }
390
+ ]
391
  else:
392
  line_text = ""
393
+ words = []
394
  line_left = 0
395
  line_top = 0
396
  line_right = 0
 
402
 
403
  ocr_results_with_words["text_line_" + str(text_line_number)] = {
404
  "line": text_line_number,
405
+ "text": line_text,
406
+ "bounding_box": (line_left, line_top, line_right, line_bottom),
407
+ "words": words,
408
+ "page": page_no,
409
  }
410
 
411
  # Create OCRResult with absolute coordinates
412
+ ocr_result = OCRResult(
413
+ line_text,
414
+ line_left,
415
+ line_top,
416
+ width_abs,
417
+ height_abs,
418
+ conf=confidence,
419
+ line=text_line_number,
420
+ )
421
  all_ocr_results.append(ocr_result)
422
 
423
  # Increase line number
 
427
 
428
  # If it is signature or handwriting, will overwrite the default behaviour of the PII analyser
429
  if is_signature_or_handwriting:
430
+ if recogniser_result not in signature_or_handwriting_recogniser_results:
431
  signature_or_handwriting_recogniser_results.append(recogniser_result)
432
 
433
  if is_signature:
434
+ if recogniser_result not in signature_recogniser_results:
435
  signature_recogniser_results.append(recogniser_result)
436
 
437
+ if is_handwriting:
438
+ if recogniser_result not in handwriting_recogniser_results:
439
  handwriting_recogniser_results.append(recogniser_result)
440
 
 
 
441
  # Add page key to the line level results
442
  all_ocr_results_with_page = {"page": page_no, "results": all_ocr_results}
443
+ ocr_results_with_words_with_page = {
444
+ "page": page_no,
445
+ "results": ocr_results_with_words,
446
+ }
447
+
448
+ return (
449
+ all_ocr_results_with_page,
450
+ signature_or_handwriting_recogniser_results,
451
+ signature_recogniser_results,
452
+ handwriting_recogniser_results,
453
+ ocr_results_with_words_with_page,
454
+ )
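Following on from the call sketch above, the five return values are usually unpacked like this (the variable names and the A4-point page size are illustrative):

(
    line_level_ocr,
    signature_or_handwriting_boxes,
    signature_boxes,
    handwriting_boxes,
    word_level_ocr,
) = json_to_ocrresult(wrapped_response["data"], page_width=595, page_height=842, page_no=1)
# word_level_ocr["results"]["text_line_1"] holds the line text, its bounding box,
# the per-word boxes, and the page number.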
455
 
 
456
 
457
+ def load_and_convert_textract_json(
458
+ textract_json_file_path: str,
459
+ log_files_output_paths: str,
460
+ page_sizes_df: pd.DataFrame,
461
+ ):
462
  """
463
  Loads Textract JSON from a file, detects if conversion is needed, and converts if necessary.
464
  """
465
+
466
  if not os.path.exists(textract_json_file_path):
467
  print("No existing Textract results file found.")
468
+ return (
469
+ {},
470
+ True,
471
+ log_files_output_paths,
472
+ ) # Return empty dict and flag indicating missing file
473
+
474
  print("Found existing Textract json results file.")
475
 
476
  # Track log files
 
478
  log_files_output_paths.append(textract_json_file_path)
479
 
480
  try:
481
+ with open(textract_json_file_path, "r", encoding="utf-8") as json_file:
482
  textract_data = json.load(json_file)
483
  except json.JSONDecodeError:
484
  print("Error: Failed to parse Textract JSON file. Returning empty data.")
 
492
  if "Blocks" in textract_data:
493
  print("Need to convert Textract JSON to app format.")
494
  try:
495
+
496
  textract_data = restructure_textract_output(textract_data, page_sizes_df)
497
+ return (
498
+ textract_data,
499
+ False,
500
+ log_files_output_paths,
501
+ ) # Successfully converted
502
+
503
  except Exception as e:
504
  print("Failed to convert JSON data to app format due to:", e)
505
  return {}, True, log_files_output_paths # Conversion failed
506
  else:
507
  print("Invalid Textract JSON format: 'Blocks' missing.")
508
+ # print("textract data:", textract_data)
509
+ return (
510
+ {},
511
+ True,
512
+ log_files_output_paths,
513
+ ) # Return empty data if JSON is not recognized
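An illustrative call for the loader above; the JSON path and `page_sizes_df` are assumptions for the sketch:

textract_data, file_missing_or_invalid, log_files_output_paths = load_and_convert_textract_json(
    "output/example_textract.json", [], page_sizes_df
)
if not file_missing_or_invalid:
    # Data is now in the page-wise app format produced by restructure_textract_output
    first_page_blocks = textract_data["pages"][0]["data"]["Blocks"]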
514
 
515
+
516
+ def restructure_textract_output(textract_output: dict, page_sizes_df: pd.DataFrame):
517
  """
518
+ Reorganise Textract output from the bulk Textract analysis option on AWS
519
  into a format that works in this redaction app, reducing size.
520
  """
521
  pages_dict = {}
 
524
  document_metadata = textract_output.get("DocumentMetadata", {})
525
 
526
  # For efficient lookup, set 'page' as index if it's not already
527
+ if "page" in page_sizes_df.columns:
528
+ page_sizes_df = page_sizes_df.set_index("page")
529
 
530
  for block in textract_output.get("Blocks", []):
531
  page_no = block.get("Page", 1) # Default to 1 if missing
 
533
  # --- Geometry Conversion Logic ---
534
  try:
535
  page_info = page_sizes_df.loc[page_no]
536
+ cb_width = page_info["cropbox_width"]
537
+ cb_height = page_info["cropbox_height"]
538
+ mb_width = page_info["mediabox_width"]
539
+ mb_height = page_info["mediabox_height"]
540
+ cb_x_offset = page_info["cropbox_x_offset"]
541
+ cb_y_offset_top = page_info["cropbox_y_offset_from_top"]
542
 
543
  # Check if conversion is needed (and avoid division by zero)
544
  needs_conversion = (
545
+ (abs(cb_width - mb_width) > 1e-6 or abs(cb_height - mb_height) > 1e-6)
546
+ and mb_width > 1e-6
547
+ and mb_height > 1e-6
548
+ ) # Avoid division by zero
549
 
550
+ if needs_conversion and "Geometry" in block:
551
+ geometry = block["Geometry"] # Work directly on the block's geometry
552
 
553
  # --- Convert BoundingBox ---
554
+ if "BoundingBox" in geometry:
555
+ bbox = geometry["BoundingBox"]
556
+ old_left = bbox["Left"]
557
+ old_top = bbox["Top"]
558
+ old_width = bbox["Width"]
559
+ old_height = bbox["Height"]
560
 
561
  # Calculate absolute coordinates within CropBox
562
  abs_cb_x = old_left * cb_width
 
569
  abs_mb_y = cb_y_offset_top + abs_cb_y
570
 
571
  # Convert back to normalized coordinates relative to MediaBox
572
+ bbox["Left"] = abs_mb_x / mb_width
573
+ bbox["Top"] = abs_mb_y / mb_height
574
+ bbox["Width"] = abs_cb_width / mb_width
575
+ bbox["Height"] = abs_cb_height / mb_height
576
  except KeyError:
577
+ print(
578
+ f"Warning: Page number {page_no} not found in page_sizes_df. Skipping coordinate conversion for this block."
579
+ )
580
  # Decide how to handle missing page info: skip conversion, raise error, etc.
581
  except ZeroDivisionError:
582
+ print(
583
+ f"Warning: MediaBox width or height is zero for page {page_no}. Skipping coordinate conversion for this block."
584
+ )
585
 
586
  # Initialise page structure if not already present
587
  if page_no not in pages_dict:
 
589
 
590
  # Keep only essential fields to reduce size
591
  filtered_block = {
592
+ key: block[key]
593
+ for key in [
594
+ "BlockType",
595
+ "Confidence",
596
+ "Text",
597
+ "Geometry",
598
+ "Page",
599
+ "Id",
600
+ "Relationships",
601
+ ]
602
  if key in block
603
  }
604
+
605
  pages_dict[page_no]["data"]["Blocks"].append(filtered_block)
606
 
607
  # Convert pages dictionary to a sorted list
608
  structured_output = {
609
  "DocumentMetadata": document_metadata, # Store metadata separately
610
+ "pages": [pages_dict[page] for page in sorted(pages_dict.keys())],
611
  }
612
 
613
  return structured_output
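To make the CropBox-to-MediaBox geometry conversion above easier to follow, here is a small worked example with made-up page dimensions: a 500 x 700 point CropBox offset by (50, 70) inside a 595 x 842 point MediaBox.

cb_width, cb_height = 500.0, 700.0
mb_width, mb_height = 595.0, 842.0
cb_x_offset, cb_y_offset_top = 50.0, 70.0

old_left, old_top, old_width, old_height = 0.2, 0.1, 0.3, 0.05  # normalised to the CropBox

abs_cb_x, abs_cb_y = old_left * cb_width, old_top * cb_height                # 100.0, 70.0
abs_cb_width, abs_cb_height = old_width * cb_width, old_height * cb_height   # 150.0, 35.0

new_left = (cb_x_offset + abs_cb_x) / mb_width       # 150 / 595 ≈ 0.252
new_top = (cb_y_offset_top + abs_cb_y) / mb_height   # 140 / 842 ≈ 0.166
new_width = abs_cb_width / mb_width                  # 150 / 595 ≈ 0.252
new_height = abs_cb_height / mb_height               #  35 / 842 ≈ 0.042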
tools/cli_usage_logger.py CHANGED
@@ -5,29 +5,29 @@ This module provides functionality to log usage data from CLI operations to CSV
5
 
6
  import csv
7
  import os
8
- import time
9
  import uuid
10
  from datetime import datetime
11
  from pathlib import Path
12
- from typing import Any, List, Optional
 
13
  import boto3
14
- import botocore
15
  from tools.aws_functions import upload_log_file_to_s3
16
  from tools.config import (
17
- USAGE_LOGS_FOLDER,
18
- SAVE_LOGS_TO_CSV,
19
- SAVE_LOGS_TO_DYNAMODB,
20
- USAGE_LOG_DYNAMODB_TABLE_NAME,
21
- DYNAMODB_USAGE_LOG_HEADERS,
22
  CSV_USAGE_LOG_HEADERS,
23
  DISPLAY_FILE_NAMES_IN_LOGS,
 
 
24
  HOST_NAME,
25
- AWS_REGION,
26
- AWS_ACCESS_KEY,
27
- AWS_SECRET_KEY,
28
  RUN_AWS_FUNCTIONS,
29
  S3_USAGE_LOGS_FOLDER,
30
- DOCUMENT_REDACTION_BUCKET
31
  )
32
 
33
 
@@ -36,11 +36,11 @@ class CLIUsageLogger:
36
  A simplified usage logger for CLI operations that mimics the functionality
37
  of the Gradio CSVLogger_custom class.
38
  """
39
-
40
  def __init__(self, dataset_file_name: str = "usage_log.csv"):
41
  """
42
  Initialize the CLI usage logger.
43
-
44
  Args:
45
  dataset_file_name: Name of the CSV file to store logs
46
  """
@@ -48,34 +48,36 @@ class CLIUsageLogger:
48
  self.flagging_dir = Path(USAGE_LOGS_FOLDER)
49
  self.dataset_filepath = None
50
  self.headers = None
51
-
52
  def setup(self, headers: List[str]):
53
  """
54
  Setup the logger with the specified headers.
55
-
56
  Args:
57
  headers: List of column headers for the CSV file
58
  """
59
  self.headers = headers
60
  self._create_dataset_file()
61
-
62
  def _create_dataset_file(self):
63
  """Create the dataset CSV file with headers if it doesn't exist."""
64
  os.makedirs(self.flagging_dir, exist_ok=True)
65
-
66
  # Add ID and timestamp to headers (matching custom_csvlogger.py structure)
67
  full_headers = self.headers + ["id", "timestamp"]
68
-
69
  self.dataset_filepath = self.flagging_dir / self.dataset_file_name
70
-
71
  if not Path(self.dataset_filepath).exists():
72
- with open(self.dataset_filepath, "w", newline="", encoding="utf-8") as csvfile:
 
 
73
  writer = csv.writer(csvfile)
74
  writer.writerow(full_headers)
75
  print(f"Created usage log file at: {self.dataset_filepath}")
76
  else:
77
  print(f"Using existing usage log file at: {self.dataset_filepath}")
78
-
79
  def log_usage(
80
  self,
81
  data: List[Any],
@@ -86,11 +88,11 @@ class CLIUsageLogger:
86
  s3_key_prefix: str = None,
87
  dynamodb_table_name: str = None,
88
  dynamodb_headers: List[str] = None,
89
- replacement_headers: List[str] = None
90
  ) -> int:
91
  """
92
  Log usage data to CSV and optionally DynamoDB and S3.
93
-
94
  Args:
95
  data: List of data values to log
96
  save_to_csv: Whether to save to CSV (defaults to config setting)
@@ -101,17 +103,17 @@ class CLIUsageLogger:
101
  dynamodb_table_name: DynamoDB table name (defaults to config setting)
102
  dynamodb_headers: DynamoDB headers (defaults to config setting)
103
  replacement_headers: Replacement headers for CSV (defaults to config setting)
104
-
105
  Returns:
106
  Number of lines written
107
  """
108
  # Use config defaults if not specified
109
  if save_to_csv is None:
110
- save_to_csv = SAVE_LOGS_TO_CSV == 'True'
111
  if save_to_dynamodb is None:
112
- save_to_dynamodb = SAVE_LOGS_TO_DYNAMODB == 'True'
113
  if save_to_s3 is None:
114
- save_to_s3 = RUN_AWS_FUNCTIONS == "1" and SAVE_LOGS_TO_CSV == 'True'
115
  if s3_bucket is None:
116
  s3_bucket = DOCUMENT_REDACTION_BUCKET
117
  if s3_key_prefix is None:
@@ -122,18 +124,22 @@ class CLIUsageLogger:
122
  dynamodb_headers = DYNAMODB_USAGE_LOG_HEADERS
123
  if replacement_headers is None:
124
  replacement_headers = CSV_USAGE_LOG_HEADERS
125
-
126
  # Generate unique ID and add timestamp (matching custom_csvlogger.py structure)
127
  generated_id = str(uuid.uuid4())
128
- timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3] # Correct format for Amazon Athena
 
 
129
  csv_data = data + [generated_id, timestamp]
130
-
131
  line_count = 0
132
-
133
  # Save to CSV
134
  if save_to_csv and self.dataset_filepath:
135
  try:
136
137
  writer = csv.writer(csvfile)
138
  writer.writerow(csv_data)
139
  line_count = 1
@@ -150,70 +156,86 @@ class CLIUsageLogger:
150
  s3_key=s3_key_prefix,
151
  s3_bucket=s3_bucket,
152
  RUN_AWS_FUNCTIONS=RUN_AWS_FUNCTIONS,
153
- SAVE_LOGS_TO_CSV=SAVE_LOGS_TO_CSV
154
  )
155
  print(f"S3 upload result: {upload_result}")
156
  except Exception as e:
157
  print(f"Error uploading log file to S3: {e}")
158
-
159
  # Save to DynamoDB
160
  if save_to_dynamodb and dynamodb_table_name and dynamodb_headers:
161
  try:
162
  # Initialize DynamoDB client
163
  if AWS_ACCESS_KEY and AWS_SECRET_KEY:
164
  dynamodb = boto3.resource(
165
- 'dynamodb',
166
  region_name=AWS_REGION,
167
  aws_access_key_id=AWS_ACCESS_KEY,
168
- aws_secret_access_key=AWS_SECRET_KEY
169
  )
170
  else:
171
- dynamodb = boto3.resource('dynamodb', region_name=AWS_REGION)
172
-
173
  table = dynamodb.Table(dynamodb_table_name)
174
-
175
  # Generate unique ID
176
  generated_id = str(uuid.uuid4())
177
-
178
  # Prepare the DynamoDB item
179
  item = {
180
- 'id': generated_id,
181
- 'timestamp': timestamp,
182
  }
183
-
184
  # Map the headers to values
185
- item.update({header: str(value) for header, value in zip(dynamodb_headers, data)})
186
-
 
  table.put_item(Item=item)
188
  print("Successfully uploaded usage log to DynamoDB")
189
-
190
  except Exception as e:
191
  print(f"Could not upload usage log to DynamoDB: {e}")
192
-
193
  return line_count
194
 
195
 
196
  def create_cli_usage_logger() -> CLIUsageLogger:
197
  """
198
  Create and setup a CLI usage logger with the standard headers.
199
-
200
  Returns:
201
  Configured CLIUsageLogger instance
202
  """
203
  # Parse CSV headers from config
204
  import json
 
205
  try:
206
  headers = json.loads(CSV_USAGE_LOG_HEADERS)
207
- except:
 
208
  # Fallback headers if parsing fails
209
  headers = [
210
- "session_hash_textbox", "doc_full_file_name_textbox", "data_full_file_name_textbox",
211
- "actual_time_taken_number", "total_page_count", "textract_query_number",
212
- "pii_detection_method", "comprehend_query_number", "cost_code",
213
- "textract_handwriting_signature", "host_name_textbox", "text_extraction_method",
214
- "is_this_a_textract_api_call", "task"
215
  ]
216
-
217
  logger = CLIUsageLogger()
218
  logger.setup(headers)
219
  return logger
@@ -237,11 +259,11 @@ def log_redaction_usage(
237
  save_to_dynamodb: bool = None,
238
  save_to_s3: bool = None,
239
  s3_bucket: str = None,
240
- s3_key_prefix: str = None
241
  ):
242
  """
243
  Log redaction usage data using the provided logger.
244
-
245
  Args:
246
  logger: CLIUsageLogger instance
247
  session_hash: Session identifier
@@ -263,7 +285,7 @@ def log_redaction_usage(
263
  s3_key_prefix: S3 key prefix (overrides config default)
264
  """
265
  # Use placeholder names if not displaying file names in logs
266
- if DISPLAY_FILE_NAMES_IN_LOGS != 'True':
267
  if doc_file_name:
268
  doc_file_name = "document"
269
  data_file_name = ""
@@ -275,7 +297,7 @@ def log_redaction_usage(
275
  data_file_name = data_file_name
276
 
277
  rounded_time_taken = round(time_taken, 2)
278
-
279
  data = [
280
  session_hash,
281
  doc_file_name,
@@ -290,13 +312,13 @@ def log_redaction_usage(
290
  HOST_NAME,
291
  text_extraction_method,
292
  is_textract_call,
293
- task
294
  ]
295
-
296
  logger.log_usage(
297
- data,
298
  save_to_dynamodb=save_to_dynamodb,
299
  save_to_s3=save_to_s3,
300
  s3_bucket=s3_bucket,
301
- s3_key_prefix=s3_key_prefix
302
  )
 
5
 
6
  import csv
7
  import os
 
8
  import uuid
9
  from datetime import datetime
10
  from pathlib import Path
11
+ from typing import Any, List
12
+
13
  import boto3
14
+
15
  from tools.aws_functions import upload_log_file_to_s3
16
  from tools.config import (
17
+ AWS_ACCESS_KEY,
18
+ AWS_REGION,
19
+ AWS_SECRET_KEY,
 
 
20
  CSV_USAGE_LOG_HEADERS,
21
  DISPLAY_FILE_NAMES_IN_LOGS,
22
+ DOCUMENT_REDACTION_BUCKET,
23
+ DYNAMODB_USAGE_LOG_HEADERS,
24
  HOST_NAME,
 
 
 
25
  RUN_AWS_FUNCTIONS,
26
  S3_USAGE_LOGS_FOLDER,
27
+ SAVE_LOGS_TO_CSV,
28
+ SAVE_LOGS_TO_DYNAMODB,
29
+ USAGE_LOG_DYNAMODB_TABLE_NAME,
30
+ USAGE_LOGS_FOLDER,
31
  )
32
 
33
 
 
36
  A simplified usage logger for CLI operations that mimics the functionality
37
  of the Gradio CSVLogger_custom class.
38
  """
39
+
40
  def __init__(self, dataset_file_name: str = "usage_log.csv"):
41
  """
42
  Initialize the CLI usage logger.
43
+
44
  Args:
45
  dataset_file_name: Name of the CSV file to store logs
46
  """
 
48
  self.flagging_dir = Path(USAGE_LOGS_FOLDER)
49
  self.dataset_filepath = None
50
  self.headers = None
51
+
52
  def setup(self, headers: List[str]):
53
  """
54
  Setup the logger with the specified headers.
55
+
56
  Args:
57
  headers: List of column headers for the CSV file
58
  """
59
  self.headers = headers
60
  self._create_dataset_file()
61
+
62
  def _create_dataset_file(self):
63
  """Create the dataset CSV file with headers if it doesn't exist."""
64
  os.makedirs(self.flagging_dir, exist_ok=True)
65
+
66
  # Add ID and timestamp to headers (matching custom_csvlogger.py structure)
67
  full_headers = self.headers + ["id", "timestamp"]
68
+
69
  self.dataset_filepath = self.flagging_dir / self.dataset_file_name
70
+
71
  if not Path(self.dataset_filepath).exists():
72
+ with open(
73
+ self.dataset_filepath, "w", newline="", encoding="utf-8"
74
+ ) as csvfile:
75
  writer = csv.writer(csvfile)
76
  writer.writerow(full_headers)
77
  print(f"Created usage log file at: {self.dataset_filepath}")
78
  else:
79
  print(f"Using existing usage log file at: {self.dataset_filepath}")
80
+
81
  def log_usage(
82
  self,
83
  data: List[Any],
 
88
  s3_key_prefix: str = None,
89
  dynamodb_table_name: str = None,
90
  dynamodb_headers: List[str] = None,
91
+ replacement_headers: List[str] = None,
92
  ) -> int:
93
  """
94
  Log usage data to CSV and optionally DynamoDB and S3.
95
+
96
  Args:
97
  data: List of data values to log
98
  save_to_csv: Whether to save to CSV (defaults to config setting)
 
103
  dynamodb_table_name: DynamoDB table name (defaults to config setting)
104
  dynamodb_headers: DynamoDB headers (defaults to config setting)
105
  replacement_headers: Replacement headers for CSV (defaults to config setting)
106
+
107
  Returns:
108
  Number of lines written
109
  """
110
  # Use config defaults if not specified
111
  if save_to_csv is None:
112
+ save_to_csv = SAVE_LOGS_TO_CSV == "True"
113
  if save_to_dynamodb is None:
114
+ save_to_dynamodb = SAVE_LOGS_TO_DYNAMODB == "True"
115
  if save_to_s3 is None:
116
+ save_to_s3 = RUN_AWS_FUNCTIONS == "1" and SAVE_LOGS_TO_CSV == "True"
117
  if s3_bucket is None:
118
  s3_bucket = DOCUMENT_REDACTION_BUCKET
119
  if s3_key_prefix is None:
 
124
  dynamodb_headers = DYNAMODB_USAGE_LOG_HEADERS
125
  if replacement_headers is None:
126
  replacement_headers = CSV_USAGE_LOG_HEADERS
127
+
128
  # Generate unique ID and add timestamp (matching custom_csvlogger.py structure)
129
  generated_id = str(uuid.uuid4())
130
+ timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[
131
+ :-3
132
+ ] # Correct format for Amazon Athena
133
  csv_data = data + [generated_id, timestamp]
134
+
135
  line_count = 0
136
+
137
  # Save to CSV
138
  if save_to_csv and self.dataset_filepath:
139
  try:
140
+ with open(
141
+ self.dataset_filepath, "a", newline="", encoding="utf-8-sig"
142
+ ) as csvfile:
143
  writer = csv.writer(csvfile)
144
  writer.writerow(csv_data)
145
  line_count = 1
 
156
  s3_key=s3_key_prefix,
157
  s3_bucket=s3_bucket,
158
  RUN_AWS_FUNCTIONS=RUN_AWS_FUNCTIONS,
159
+ SAVE_LOGS_TO_CSV=SAVE_LOGS_TO_CSV,
160
  )
161
  print(f"S3 upload result: {upload_result}")
162
  except Exception as e:
163
  print(f"Error uploading log file to S3: {e}")
164
+
165
  # Save to DynamoDB
166
  if save_to_dynamodb and dynamodb_table_name and dynamodb_headers:
167
  try:
168
  # Initialize DynamoDB client
169
  if AWS_ACCESS_KEY and AWS_SECRET_KEY:
170
  dynamodb = boto3.resource(
171
+ "dynamodb",
172
  region_name=AWS_REGION,
173
  aws_access_key_id=AWS_ACCESS_KEY,
174
+ aws_secret_access_key=AWS_SECRET_KEY,
175
  )
176
  else:
177
+ dynamodb = boto3.resource("dynamodb", region_name=AWS_REGION)
178
+
179
  table = dynamodb.Table(dynamodb_table_name)
180
+
181
  # Generate unique ID
182
  generated_id = str(uuid.uuid4())
183
+
184
  # Prepare the DynamoDB item
185
  item = {
186
+ "id": generated_id,
187
+ "timestamp": timestamp,
188
  }
189
+
190
  # Map the headers to values
191
+ item.update(
192
+ {
193
+ header: str(value)
194
+ for header, value in zip(dynamodb_headers, data)
195
+ }
196
+ )
197
+
198
  table.put_item(Item=item)
199
  print("Successfully uploaded usage log to DynamoDB")
200
+
201
  except Exception as e:
202
  print(f"Could not upload usage log to DynamoDB: {e}")
203
+
204
  return line_count
205
 
206
 
207
  def create_cli_usage_logger() -> CLIUsageLogger:
208
  """
209
  Create and setup a CLI usage logger with the standard headers.
210
+
211
  Returns:
212
  Configured CLIUsageLogger instance
213
  """
214
  # Parse CSV headers from config
215
  import json
216
+
217
  try:
218
  headers = json.loads(CSV_USAGE_LOG_HEADERS)
219
+ except Exception as e:
220
+ print(f"Error parsing CSV usage log headers: {e}")
221
  # Fallback headers if parsing fails
222
  headers = [
223
+ "session_hash_textbox",
224
+ "doc_full_file_name_textbox",
225
+ "data_full_file_name_textbox",
226
+ "actual_time_taken_number",
227
+ "total_page_count",
228
+ "textract_query_number",
229
+ "pii_detection_method",
230
+ "comprehend_query_number",
231
+ "cost_code",
232
+ "textract_handwriting_signature",
233
+ "host_name_textbox",
234
+ "text_extraction_method",
235
+ "is_this_a_textract_api_call",
236
+ "task",
237
  ]
238
+
239
  logger = CLIUsageLogger()
240
  logger.setup(headers)
241
  return logger
 
259
  save_to_dynamodb: bool = None,
260
  save_to_s3: bool = None,
261
  s3_bucket: str = None,
262
+ s3_key_prefix: str = None,
263
  ):
264
  """
265
  Log redaction usage data using the provided logger.
266
+
267
  Args:
268
  logger: CLIUsageLogger instance
269
  session_hash: Session identifier
 
285
  s3_key_prefix: S3 key prefix (overrides config default)
286
  """
287
  # Use placeholder names if not displaying file names in logs
288
+ if DISPLAY_FILE_NAMES_IN_LOGS != "True":
289
  if doc_file_name:
290
  doc_file_name = "document"
291
  data_file_name = ""
 
297
  data_file_name = data_file_name
298
 
299
  rounded_time_taken = round(time_taken, 2)
300
+
301
  data = [
302
  session_hash,
303
  doc_file_name,
 
312
  HOST_NAME,
313
  text_extraction_method,
314
  is_textract_call,
315
+ task,
316
  ]
317
+
318
  logger.log_usage(
319
+ data,
320
  save_to_dynamodb=save_to_dynamodb,
321
  save_to_s3=save_to_s3,
322
  s3_bucket=s3_bucket,
323
+ s3_key_prefix=s3_key_prefix,
324
  )
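A hedged sketch of wiring the CLI logger together; the values in `data` are placeholders that follow the fallback CSV headers defined above, and S3/DynamoDB writes are switched off for the example:

logger = create_cli_usage_logger()
logger.log_usage(
    data=[
        "example-session",       # session_hash_textbox
        "example.pdf",           # doc_full_file_name_textbox
        "",                      # data_full_file_name_textbox
        12.34,                   # actual_time_taken_number
        3,                       # total_page_count
        3,                       # textract_query_number
        "Local",                 # pii_detection_method
        0,                       # comprehend_query_number
        "",                      # cost_code
        "Extract handwriting",   # textract_handwriting_signature
        "example-host",          # host_name_textbox
        "Local OCR model - PDFs without selectable text",  # text_extraction_method
        False,                   # is_this_a_textract_api_call
        "redact",                # task
    ],
    save_to_dynamodb=False,
    save_to_s3=False,
)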
tools/config.py CHANGED
@@ -1,147 +1,171 @@
 
1
  import os
2
- import tempfile
3
  import socket
4
- import logging
5
  from datetime import datetime
 
 
6
  from dotenv import load_dotenv
7
  from tldextract import TLDExtract
8
- from typing import List
9
 
10
  today_rev = datetime.now().strftime("%Y%m%d")
11
  HOST_NAME = socket.gethostname()
12
 
 
13
  def _get_env_list(env_var_name: str) -> List[str]:
14
  """Parses a comma-separated environment variable into a list of strings."""
15
- value = env_var_name[1:-1].strip().replace('\"', '').replace("\'","")
16
  if not value:
17
  return []
18
  # Split by comma and filter out any empty strings that might result from extra commas
19
- return [s.strip() for s in value.split(',') if s.strip()]
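As defined here, the helper parses the list-like string it is given (brackets and quotes stripped, comma-separated); an illustrative input/output pair:

# _get_env_list("['NAME', 'EMAIL_ADDRESS']") -> ['NAME', 'EMAIL_ADDRESS']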
 
20
 
21
  # Set or retrieve configuration variables for the redaction app
22
 
23
- def get_or_create_env_var(var_name:str, default_value:str, print_val:bool=False):
24
- '''
 
25
  Get an environmental variable, and set it to a default value if it doesn't exist
26
- '''
27
  # Get the environment variable if it exists
28
  value = os.environ.get(var_name)
29
-
30
  # If it doesn't exist, set the environment variable to the default value
31
  if value is None:
32
  os.environ[var_name] = default_value
33
  value = default_value
34
 
35
- if print_val == True:
36
- print(f'The value of {var_name} is {value}')
37
-
38
  return value
39
 
 
40
  def add_folder_to_path(folder_path: str):
41
- '''
42
  Check if a folder exists on your system. If so, get the absolute path and then add it to the system Path variable if it doesn't already exist. Function is only relevant for locally-created executable files based on this app (when using pyinstaller it creates a _internal folder that contains tesseract and poppler. These need to be added to the system path to enable the app to run)
43
- '''
44
 
45
  if os.path.exists(folder_path) and os.path.isdir(folder_path):
46
- #print(folder_path, "folder exists.")
47
 
48
  # Resolve relative path to absolute path
49
  absolute_path = os.path.abspath(folder_path)
50
 
51
- current_path = os.environ['PATH']
52
  if absolute_path not in current_path.split(os.pathsep):
53
  full_path_extension = absolute_path + os.pathsep + current_path
54
- os.environ['PATH'] = full_path_extension
55
- #print(f"Updated PATH with: ", full_path_extension)
56
  else:
57
  pass
58
- #print(f"Directory {folder_path} already exists in PATH.")
59
  else:
60
  print(f"Folder not found at {folder_path} - not added to PATH")
61
 
 
62
  ###
63
  # LOAD CONFIG FROM ENV FILE
64
  ###
65
 
66
- CONFIG_FOLDER = get_or_create_env_var('CONFIG_FOLDER', 'config/')
67
 
68
  # If you have an aws_config env file in the config folder, you can load in app variables this way, e.g. 'config/app_config.env'
69
- APP_CONFIG_PATH = get_or_create_env_var('APP_CONFIG_PATH', CONFIG_FOLDER + 'app_config.env') # e.g. config/app_config.env
 
 
70
 
71
  if APP_CONFIG_PATH:
72
  if os.path.exists(APP_CONFIG_PATH):
73
  print(f"Loading app variables from config file {APP_CONFIG_PATH}")
74
  load_dotenv(APP_CONFIG_PATH)
75
- else: print("App config file not found at location:", APP_CONFIG_PATH)
 
76
 
77
  ###
78
  # AWS OPTIONS
79
  ###
80
 
81
  # If you have an aws_config env file in the config folder, you can load in AWS keys this way, e.g. 'env/aws_config.env'
82
- AWS_CONFIG_PATH = get_or_create_env_var('AWS_CONFIG_PATH', '') # e.g. config/aws_config.env
 
 
83
 
84
  if AWS_CONFIG_PATH:
85
  if os.path.exists(AWS_CONFIG_PATH):
86
  print(f"Loading AWS variables from config file {AWS_CONFIG_PATH}")
87
  load_dotenv(AWS_CONFIG_PATH)
88
- else: print("AWS config file not found at location:", AWS_CONFIG_PATH)
 
89
 
90
  RUN_AWS_FUNCTIONS = get_or_create_env_var("RUN_AWS_FUNCTIONS", "0")
91
 
92
- AWS_REGION = get_or_create_env_var('AWS_REGION', '')
93
 
94
- AWS_CLIENT_ID = get_or_create_env_var('AWS_CLIENT_ID', '')
95
 
96
- AWS_CLIENT_SECRET = get_or_create_env_var('AWS_CLIENT_SECRET', '')
97
 
98
- AWS_USER_POOL_ID = get_or_create_env_var('AWS_USER_POOL_ID', '')
99
 
100
- AWS_ACCESS_KEY = get_or_create_env_var('AWS_ACCESS_KEY', '')
101
- #if AWS_ACCESS_KEY: print(f'AWS_ACCESS_KEY found in environment variables')
102
 
103
- AWS_SECRET_KEY = get_or_create_env_var('AWS_SECRET_KEY', '')
104
- #if AWS_SECRET_KEY: print(f'AWS_SECRET_KEY found in environment variables')
105
 
106
- DOCUMENT_REDACTION_BUCKET = get_or_create_env_var('DOCUMENT_REDACTION_BUCKET', '')
107
 
108
  # Should the app prioritise using AWS SSO over using API keys stored in environment variables/secrets (defaults to yes)
109
- PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS = get_or_create_env_var('PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS', '1')
 
 
110
 
111
  # Custom headers e.g. if routing traffic through Cloudfront
112
  # Retrieving or setting CUSTOM_HEADER
113
- CUSTOM_HEADER = get_or_create_env_var('CUSTOM_HEADER', '')
114
 
115
  # Retrieving or setting CUSTOM_HEADER_VALUE
116
- CUSTOM_HEADER_VALUE = get_or_create_env_var('CUSTOM_HEADER_VALUE', '')
117
 
118
  ###
119
  # Image options
120
  ###
121
- IMAGES_DPI = float(get_or_create_env_var('IMAGES_DPI', '300.0'))
122
- LOAD_TRUNCATED_IMAGES = get_or_create_env_var('LOAD_TRUNCATED_IMAGES', 'True')
123
- MAX_IMAGE_PIXELS = get_or_create_env_var('MAX_IMAGE_PIXELS', '') # Changed to None if blank in file_conversion.py
 
 
124
 
125
  ###
126
  # File I/O options
127
  ###
128
 
129
- SESSION_OUTPUT_FOLDER = get_or_create_env_var('SESSION_OUTPUT_FOLDER', 'False') # i.e. do you want your input and output folders saved within a subfolder based on session hash value within output/input folders
 
 
130
 
131
- OUTPUT_FOLDER = get_or_create_env_var('GRADIO_OUTPUT_FOLDER', 'output/') # 'output/'
132
- INPUT_FOLDER = get_or_create_env_var('GRADIO_INPUT_FOLDER', 'input/') # 'input/'
133
 
134
  # Allow for files to be saved in a temporary folder for increased security in some instances
135
- if OUTPUT_FOLDER == "TEMP" or INPUT_FOLDER == "TEMP":
136
  # Create a temporary directory
137
  with tempfile.TemporaryDirectory() as temp_dir:
138
- print(f'Temporary directory created at: {temp_dir}')
139
 
140
- if OUTPUT_FOLDER == "TEMP": OUTPUT_FOLDER = temp_dir + "/"
141
- if INPUT_FOLDER == "TEMP": INPUT_FOLDER = temp_dir + "/"
 
 
142
 
143
- GRADIO_TEMP_DIR = get_or_create_env_var('GRADIO_TEMP_DIR', 'tmp/gradio_tmp/') # Default Gradio temp folder
144
- MPLCONFIGDIR = get_or_create_env_var('MPLCONFIGDIR', 'tmp/matplotlib_cache/') # Matplotlib cache folder
 
146
  ###
147
  # LOGGING OPTIONS
@@ -150,57 +174,82 @@ MPLCONFIGDIR = get_or_create_env_var('MPLCONFIGDIR', 'tmp/matplotlib_cache/') #
150
  # By default, logs are put into a subfolder of today's date and the host name of the instance running the app. This is to avoid at all possible the possibility of log files from one instance overwriting the logs of another instance on S3. If running the app on one system always, or just locally, it is not necessary to make the log folders so specific.
151
  # Another way to address this issue would be to write logs to another type of storage, e.g. database such as dynamodb. I may look into this in future.
152
 
153
- SAVE_LOGS_TO_CSV = get_or_create_env_var('SAVE_LOGS_TO_CSV', 'True')
154
 
155
- USE_LOG_SUBFOLDERS = get_or_create_env_var('USE_LOG_SUBFOLDERS', 'True')
156
 
157
- FEEDBACK_LOGS_FOLDER = get_or_create_env_var('FEEDBACK_LOGS_FOLDER', 'feedback/')
158
- ACCESS_LOGS_FOLDER = get_or_create_env_var('ACCESS_LOGS_FOLDER', 'logs/')
159
- USAGE_LOGS_FOLDER = get_or_create_env_var('USAGE_LOGS_FOLDER', 'usage/')
160
 
161
  if USE_LOG_SUBFOLDERS == "True":
162
- day_log_subfolder = today_rev + '/'
163
- host_name_subfolder = HOST_NAME + '/'
164
  full_log_subfolder = day_log_subfolder + host_name_subfolder
165
 
166
  FEEDBACK_LOGS_FOLDER = FEEDBACK_LOGS_FOLDER + full_log_subfolder
167
  ACCESS_LOGS_FOLDER = ACCESS_LOGS_FOLDER + full_log_subfolder
168
  USAGE_LOGS_FOLDER = USAGE_LOGS_FOLDER + full_log_subfolder
169
 
170
- S3_FEEDBACK_LOGS_FOLDER = get_or_create_env_var('S3_FEEDBACK_LOGS_FOLDER', 'feedback/' + full_log_subfolder)
171
- S3_ACCESS_LOGS_FOLDER = get_or_create_env_var('S3_ACCESS_LOGS_FOLDER', 'logs/' + full_log_subfolder)
172
- S3_USAGE_LOGS_FOLDER = get_or_create_env_var('S3_USAGE_LOGS_FOLDER', 'usage/' + full_log_subfolder)
173
 
174
  # Should the redacted file name be included in the logs? In some instances, the names of the files themselves could be sensitive, and should not be disclosed beyond the app. So, by default this is false.
175
- DISPLAY_FILE_NAMES_IN_LOGS = get_or_create_env_var('DISPLAY_FILE_NAMES_IN_LOGS', 'False')
 
 
176
 
177
  # Further customisation options for CSV logs
178
- CSV_ACCESS_LOG_HEADERS = get_or_create_env_var('CSV_ACCESS_LOG_HEADERS', '') # If blank, uses component labels
179
- CSV_FEEDBACK_LOG_HEADERS = get_or_create_env_var('CSV_FEEDBACK_LOG_HEADERS', '') # If blank, uses component labels
180
- CSV_USAGE_LOG_HEADERS = get_or_create_env_var('CSV_USAGE_LOG_HEADERS', '["session_hash_textbox", "doc_full_file_name_textbox", "data_full_file_name_textbox", "actual_time_taken_number", "total_page_count", "textract_query_number", "pii_detection_method", "comprehend_query_number", "cost_code", "textract_handwriting_signature", "host_name_textbox", "text_extraction_method", "is_this_a_textract_api_call", "task"]') # If blank, uses component labels
181
 
182
  ### DYNAMODB logs. Whether to save to DynamoDB, and the headers of the table
183
- SAVE_LOGS_TO_DYNAMODB = get_or_create_env_var('SAVE_LOGS_TO_DYNAMODB', 'False')
184
 
185
- ACCESS_LOG_DYNAMODB_TABLE_NAME = get_or_create_env_var('ACCESS_LOG_DYNAMODB_TABLE_NAME', 'redaction_access_log')
186
- DYNAMODB_ACCESS_LOG_HEADERS = get_or_create_env_var('DYNAMODB_ACCESS_LOG_HEADERS', '')
 
 
187
 
188
- FEEDBACK_LOG_DYNAMODB_TABLE_NAME = get_or_create_env_var('FEEDBACK_LOG_DYNAMODB_TABLE_NAME', 'redaction_feedback')
189
- DYNAMODB_FEEDBACK_LOG_HEADERS = get_or_create_env_var('DYNAMODB_FEEDBACK_LOG_HEADERS', '')
190
 
191
- USAGE_LOG_DYNAMODB_TABLE_NAME = get_or_create_env_var('USAGE_LOG_DYNAMODB_TABLE_NAME', 'redaction_usage')
192
- DYNAMODB_USAGE_LOG_HEADERS = get_or_create_env_var('DYNAMODB_USAGE_LOG_HEADERS', '')
 
 
193
 
194
  # Report logging to console?
195
- LOGGING = get_or_create_env_var('LOGGING', 'False')
196
 
197
- if LOGGING == 'True':
198
  # Configure logging
199
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
 
200
 
201
- LOG_FILE_NAME = get_or_create_env_var('LOG_FILE_NAME', 'log.csv')
202
- USAGE_LOG_FILE_NAME = get_or_create_env_var('USAGE_LOG_FILE_NAME', LOG_FILE_NAME)
203
- FEEDBACK_LOG_FILE_NAME = get_or_create_env_var('FEEDBACK_LOG_FILE_NAME', LOG_FILE_NAME)
204
 
205
 
206
  ###
@@ -208,164 +257,265 @@ FEEDBACK_LOG_FILE_NAME = get_or_create_env_var('FEEDBACK_LOG_FILE_NAME', LOG_FIL
208
  ###
209
 
210
  # Create Tesseract and Poppler folders if you have installed them locally
211
- TESSERACT_FOLDER = get_or_create_env_var('TESSERACT_FOLDER', "") # # If installing for Windows, install Tesseract 5.5.0 from here: https://github.com/UB-Mannheim/tesseract/wiki. Then this environment variable should point to the Tesseract folder e.g. tesseract/
212
- TESSERACT_DATA_FOLDER = get_or_create_env_var('TESSERACT_DATA_FOLDER', "/usr/share/tessdata")
213
- POPPLER_FOLDER = get_or_create_env_var('POPPLER_FOLDER', "") # If installing on Windows,install Poppler from here https://github.com/oschwartz10612/poppler-windows. This variable needs to point to the poppler bin folder e.g. poppler/poppler-24.02.0/Library/bin/
214
-
215
- if TESSERACT_FOLDER: add_folder_to_path(TESSERACT_FOLDER)
216
- if POPPLER_FOLDER: add_folder_to_path(POPPLER_FOLDER)
 
 
 
 
 
 
 
 
217
 
218
  # List of models to use for text extraction and PII detection
219
  # Text extraction models
220
- SELECTABLE_TEXT_EXTRACT_OPTION = get_or_create_env_var('SELECTABLE_TEXT_EXTRACT_OPTION', "Local model - selectable text")
221
- TESSERACT_TEXT_EXTRACT_OPTION = get_or_create_env_var('TESSERACT_TEXT_EXTRACT_OPTION', "Local OCR model - PDFs without selectable text")
222
- TEXTRACT_TEXT_EXTRACT_OPTION = get_or_create_env_var('TEXTRACT_TEXT_EXTRACT_OPTION', "AWS Textract service - all PDF types")
 
 
 
 
 
 
223
 
224
  # PII detection models
225
- NO_REDACTION_PII_OPTION = get_or_create_env_var('NO_REDACTION_PII_OPTION', "Only extract text (no redaction)")
226
- LOCAL_PII_OPTION = get_or_create_env_var('LOCAL_PII_OPTION', "Local")
227
- AWS_PII_OPTION = get_or_create_env_var('AWS_PII_OPTION', "AWS Comprehend")
228
-
229
- SHOW_LOCAL_TEXT_EXTRACTION_OPTIONS = get_or_create_env_var('SHOW_LOCAL_TEXT_EXTRACTION_OPTIONS', 'True')
230
- SHOW_AWS_TEXT_EXTRACTION_OPTIONS = get_or_create_env_var('SHOW_AWS_TEXT_EXTRACTION_OPTIONS', 'True')
 
 
 
 
 
 
231
 
232
  # Show at least local options if everything mistakenly removed
233
- if SHOW_LOCAL_TEXT_EXTRACTION_OPTIONS != "True" and SHOW_AWS_TEXT_EXTRACTION_OPTIONS != "True":
 
 
 
234
  SHOW_LOCAL_TEXT_EXTRACTION_OPTIONS = "True"
235
 
236
  local_model_options = []
237
  aws_model_options = []
238
  text_extraction_models = []
239
 
240
- if SHOW_LOCAL_TEXT_EXTRACTION_OPTIONS == 'True':
241
  local_model_options.append(SELECTABLE_TEXT_EXTRACT_OPTION)
242
  local_model_options.append(TESSERACT_TEXT_EXTRACT_OPTION)
243
 
244
- if SHOW_AWS_TEXT_EXTRACTION_OPTIONS == 'True':
245
  aws_model_options.append(TEXTRACT_TEXT_EXTRACT_OPTION)
246
 
247
  TEXT_EXTRACTION_MODELS = local_model_options + aws_model_options
248
- DO_INITIAL_TABULAR_DATA_CLEAN = get_or_create_env_var('DO_INITIAL_TABULAR_DATA_CLEAN', 'True')
249
-
250
- SHOW_LOCAL_PII_DETECTION_OPTIONS = get_or_create_env_var('SHOW_LOCAL_PII_DETECTION_OPTIONS', 'True')
251
- SHOW_AWS_PII_DETECTION_OPTIONS = get_or_create_env_var('SHOW_AWS_PII_DETECTION_OPTIONS', 'True')
252
-
253
- if SHOW_LOCAL_PII_DETECTION_OPTIONS != "True" and SHOW_AWS_PII_DETECTION_OPTIONS != "True":
 
 
 
 
 
 
 
 
 
254
  SHOW_LOCAL_PII_DETECTION_OPTIONS = "True"
255
 
256
  local_model_options = [NO_REDACTION_PII_OPTION]
257
  aws_model_options = []
258
  pii_detection_models = []
259
 
260
- if SHOW_LOCAL_PII_DETECTION_OPTIONS == 'True':
261
  local_model_options.append(LOCAL_PII_OPTION)
262
 
263
- if SHOW_AWS_PII_DETECTION_OPTIONS == 'True':
264
  aws_model_options.append(AWS_PII_OPTION)
265
 
266
  PII_DETECTION_MODELS = local_model_options + aws_model_options
267
 
268
  if SHOW_AWS_TEXT_EXTRACTION_OPTIONS == "True":
269
- DEFAULT_TEXT_EXTRACTION_MODEL = get_or_create_env_var('DEFAULT_TEXT_EXTRACTION_MODEL', TEXTRACT_TEXT_EXTRACT_OPTION)
 
 
270
  else:
271
- DEFAULT_TEXT_EXTRACTION_MODEL = get_or_create_env_var('DEFAULT_TEXT_EXTRACTION_MODEL', SELECTABLE_TEXT_EXTRACT_OPTION)
 
 
272
 
273
  if SHOW_AWS_PII_DETECTION_OPTIONS == "True":
274
- DEFAULT_PII_DETECTION_MODEL = get_or_create_env_var('DEFAULT_PII_DETECTION_MODEL', AWS_PII_OPTION)
 
 
275
  else:
276
- DEFAULT_PII_DETECTION_MODEL = get_or_create_env_var('DEFAULT_PII_DETECTION_MODEL', LOCAL_PII_OPTION)
 
 
277
 
278
  # Create list of PII detection models for tabular redaction
279
  TABULAR_PII_DETECTION_MODELS = PII_DETECTION_MODELS.copy()
280
  if NO_REDACTION_PII_OPTION in TABULAR_PII_DETECTION_MODELS:
281
  TABULAR_PII_DETECTION_MODELS.remove(NO_REDACTION_PII_OPTION)
282
 
283
- DEFAULT_TEXT_COLUMNS = get_or_create_env_var('DEFAULT_TEXT_COLUMNS', "[]")
284
- DEFAULT_EXCEL_SHEETS = get_or_create_env_var('DEFAULT_EXCEL_SHEETS', "[]")
285
 
286
- DEFAULT_TABULAR_ANONYMISATION_STRATEGY = get_or_create_env_var('DEFAULT_TABULAR_ANONYMISATION_STRATEGY', "redact completely")
 
 
287
 
288
  ### Local OCR model - Tesseract vs PaddleOCR
289
- CHOSEN_LOCAL_OCR_MODEL = get_or_create_env_var('CHOSEN_LOCAL_OCR_MODEL', "tesseract") # Choose between "tesseract", "hybrid", and "paddle". "paddle" will only return whole line text extraction, and so will only work for OCR, not redaction. "hybrid" is a combination of the two - first pass through the redactions will be done with Tesseract, and then a second pass will be done with PaddleOCR on words with low confidence.
 
 
290
 
291
- PREPROCESS_LOCAL_OCR_IMAGES = get_or_create_env_var('PREPROCESS_LOCAL_OCR_IMAGES', "True") # Whether to try and preprocess images before extracting text. NOTE: I have found in testing that this often results in WORSE results for scanned pages, so it is default False
 
 
292
 
293
  # Entities for redaction
294
- CHOSEN_COMPREHEND_ENTITIES = get_or_create_env_var('CHOSEN_COMPREHEND_ENTITIES', "['BANK_ACCOUNT_NUMBER','BANK_ROUTING','CREDIT_DEBIT_NUMBER','CREDIT_DEBIT_CVV','CREDIT_DEBIT_EXPIRY','PIN','EMAIL','ADDRESS','NAME','PHONE', 'PASSPORT_NUMBER','DRIVER_ID', 'USERNAME','PASSWORD', 'IP_ADDRESS','MAC_ADDRESS', 'LICENSE_PLATE','VEHICLE_IDENTIFICATION_NUMBER','UK_NATIONAL_INSURANCE_NUMBER', 'INTERNATIONAL_BANK_ACCOUNT_NUMBER','SWIFT_CODE','UK_NATIONAL_HEALTH_SERVICE_NUMBER']")
 
 
 
295
 
296
- FULL_COMPREHEND_ENTITY_LIST = get_or_create_env_var('FULL_COMPREHEND_ENTITY_LIST', "['BANK_ACCOUNT_NUMBER','BANK_ROUTING','CREDIT_DEBIT_NUMBER','CREDIT_DEBIT_CVV','CREDIT_DEBIT_EXPIRY','PIN','EMAIL','ADDRESS','NAME','PHONE','SSN','DATE_TIME','PASSPORT_NUMBER','DRIVER_ID','URL','AGE','USERNAME','PASSWORD','AWS_ACCESS_KEY','AWS_SECRET_KEY','IP_ADDRESS','MAC_ADDRESS','ALL','LICENSE_PLATE','VEHICLE_IDENTIFICATION_NUMBER','UK_NATIONAL_INSURANCE_NUMBER','CA_SOCIAL_INSURANCE_NUMBER','US_INDIVIDUAL_TAX_IDENTIFICATION_NUMBER','UK_UNIQUE_TAXPAYER_REFERENCE_NUMBER','IN_PERMANENT_ACCOUNT_NUMBER','IN_NREGA','INTERNATIONAL_BANK_ACCOUNT_NUMBER','SWIFT_CODE','UK_NATIONAL_HEALTH_SERVICE_NUMBER','CA_HEALTH_NUMBER','IN_AADHAAR','IN_VOTER_NUMBER', 'CUSTOM_FUZZY']")
 
 
 
297
 
298
  # Entities for local PII redaction option
299
- CHOSEN_REDACT_ENTITIES = get_or_create_env_var('CHOSEN_REDACT_ENTITIES', "['TITLES', 'PERSON', 'PHONE_NUMBER', 'EMAIL_ADDRESS', 'STREETNAME', 'UKPOSTCODE', 'CUSTOM']")
300
-
301
- FULL_ENTITY_LIST = get_or_create_env_var('FULL_ENTITY_LIST', "['TITLES', 'PERSON', 'PHONE_NUMBER', 'EMAIL_ADDRESS', 'STREETNAME', 'UKPOSTCODE', 'CREDIT_CARD', 'CRYPTO', 'DATE_TIME', 'IBAN_CODE', 'IP_ADDRESS', 'NRP', 'LOCATION', 'MEDICAL_LICENSE', 'URL', 'UK_NHS', 'CUSTOM', 'CUSTOM_FUZZY']")
302
-
303
- CUSTOM_ENTITIES = get_or_create_env_var('CUSTOM_ENTITIES', "['TITLES', 'UKPOSTCODE', 'STREETNAME', 'CUSTOM']")
304
-
305
-
306
-
307
- DEFAULT_HANDWRITE_SIGNATURE_CHECKBOX = get_or_create_env_var('DEFAULT_HANDWRITE_SIGNATURE_CHECKBOX', "['Extract handwriting']")
308
-
309
- HANDWRITE_SIGNATURE_TEXTBOX_FULL_OPTIONS = get_or_create_env_var('HANDWRITE_SIGNATURE_TEXTBOX_FULL_OPTIONS', "['Extract handwriting', 'Extract signatures']")
310
-
311
- if HANDWRITE_SIGNATURE_TEXTBOX_FULL_OPTIONS: HANDWRITE_SIGNATURE_TEXTBOX_FULL_OPTIONS = _get_env_list(HANDWRITE_SIGNATURE_TEXTBOX_FULL_OPTIONS)
312
-
313
- INCLUDE_FORM_EXTRACTION_TEXTRACT_OPTION = get_or_create_env_var('INCLUDE_FORM_EXTRACTION_TEXTRACT_OPTION', "False")
314
- INCLUDE_LAYOUT_EXTRACTION_TEXTRACT_OPTION = get_or_create_env_var('INCLUDE_LAYOUT_EXTRACTION_TEXTRACT_OPTION', "False")
315
- INCLUDE_TABLE_EXTRACTION_TEXTRACT_OPTION = get_or_create_env_var('INCLUDE_TABLE_EXTRACTION_TEXTRACT_OPTION', "False")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
316
 
317
  if INCLUDE_FORM_EXTRACTION_TEXTRACT_OPTION == "True":
318
- HANDWRITE_SIGNATURE_TEXTBOX_FULL_OPTIONS.append('Extract forms')
319
  if INCLUDE_LAYOUT_EXTRACTION_TEXTRACT_OPTION == "True":
320
- HANDWRITE_SIGNATURE_TEXTBOX_FULL_OPTIONS.append('Extract layout')
321
  if INCLUDE_TABLE_EXTRACTION_TEXTRACT_OPTION == "True":
322
- HANDWRITE_SIGNATURE_TEXTBOX_FULL_OPTIONS.append('Extract tables')
323
 
324
 
325
- DEFAULT_SEARCH_QUERY = get_or_create_env_var('DEFAULT_SEARCH_QUERY', '')
326
- DEFAULT_FUZZY_SPELLING_MISTAKES_NUM = int(get_or_create_env_var('DEFAULT_FUZZY_SPELLING_MISTAKES_NUM', '1'))
 
 
327
 
328
- DEFAULT_PAGE_MIN = int(get_or_create_env_var('DEFAULT_PAGE_MIN', '0'))
329
 
330
- DEFAULT_PAGE_MAX = int(get_or_create_env_var('DEFAULT_PAGE_MAX', '999'))
331
 
332
 
333
  # Number of pages to loop through before breaking the function and restarting from the last finished page (not currently activated).
334
- PAGE_BREAK_VALUE = int(get_or_create_env_var('PAGE_BREAK_VALUE', '99999'))
335
 
336
- MAX_TIME_VALUE = int(get_or_create_env_var('MAX_TIME_VALUE', '999999'))
337
- MAX_SIMULTANEOUS_FILES = int(get_or_create_env_var('MAX_SIMULTANEOUS_FILES', '10'))
338
- MAX_DOC_PAGES = int(get_or_create_env_var('MAX_DOC_PAGES', '3000'))
339
- MAX_TABLE_ROWS = int(get_or_create_env_var('MAX_TABLE_ROWS', '250000'))
340
- MAX_TABLE_COLUMNS = int(get_or_create_env_var('MAX_TABLE_COLUMNS', '100'))
341
- MAX_OPEN_TEXT_CHARACTERS = int(get_or_create_env_var('MAX_OPEN_TEXT_CHARACTERS', '50000'))
 
 
342
 
343
- CUSTOM_BOX_COLOUR = get_or_create_env_var("CUSTOM_BOX_COLOUR", "") # only "grey" is currently supported as a custom box colour
 
 
344
 
345
  ### Language selection options
346
 
347
  SHOW_LANGUAGE_SELECTION = get_or_create_env_var("SHOW_LANGUAGE_SELECTION", "False")
348
 
349
- DEFAULT_LANGUAGE_FULL_NAME = get_or_create_env_var("DEFAULT_LANGUAGE_FULL_NAME", "english")
350
- DEFAULT_LANGUAGE = get_or_create_env_var("DEFAULT_LANGUAGE", "en") # For tesseract, ensure the Tesseract language data (e.g., fra.traineddata) is installed on your system. You can find the relevant language packs here: https://github.com/tesseract-ocr/tessdata.
 
 
 
 
351
  # For paddle, ensure the paddle language data (e.g., fra.traineddata) is installed on your system. You can find information on supported languages here: https://www.paddleocr.ai/main/en/version3.x/algorithm/PP-OCRv5/PP-OCRv5_multi_languages.html
352
  # For AWS Comprehend, only English and Spanish are supported https://docs.aws.amazon.com/comprehend/latest/dg/how-pii.html ['en', 'es']
353
  # AWS Textract automatically detects the language of the document and supports the following languages: https://aws.amazon.com/textract/faqs/#topic-0. 'English, Spanish, Italian, Portuguese, French, German. Handwriting, Invoices and Receipts, Identity documents and Queries processing are in English only'
354
 
355
- textract_language_choices = get_or_create_env_var("textract_language_choices", "['en', 'es', 'fr', 'de', 'it', 'pt']")
356
- aws_comprehend_language_choices = get_or_create_env_var("aws_comprehend_language_choices", "['en', 'es']")
 
 
 
 
357
 
358
  # The choices that the user sees
359
- MAPPED_LANGUAGE_CHOICES = get_or_create_env_var("MAPPED_LANGUAGE_CHOICES", "['english', 'french', 'german', 'spanish', 'italian', 'dutch', 'portuguese', 'chinese', 'japanese', 'korean', 'lithuanian', 'macedonian', 'norwegian_bokmaal', 'polish', 'romanian', 'russian', 'slovenian', 'swedish', 'catalan', 'ukrainian']")
360
- LANGUAGE_CHOICES = get_or_create_env_var("LANGUAGE_CHOICES", "['en', 'fr', 'de', 'es', 'it', 'nl', 'pt', 'zh', 'ja', 'ko', 'lt', 'mk', 'nb', 'pl', 'ro', 'ru', 'sl', 'sv', 'ca', 'uk']")
 
 
 
 
 
 
361
 
362
  ###
363
  # Duplicate detection settings
364
  ###
365
- DEFAULT_DUPLICATE_DETECTION_THRESHOLD = float(get_or_create_env_var("DEFAULT_DUPLICATE_DETECTION_THRESHOLD", "0.95"))
366
- DEFAULT_MIN_CONSECUTIVE_PAGES = int(get_or_create_env_var("DEFAULT_MIN_CONSECUTIVE_PAGES", "1"))
367
- USE_GREEDY_DUPLICATE_DETECTION = get_or_create_env_var("USE_GREEDY_DUPLICATE_DETECTION", "True")
368
- DEFAULT_COMBINE_PAGES = get_or_create_env_var("DEFAULT_COMBINE_PAGES", "True") # Combine text from the same page number within a file. Alternative will enable line-level duplicate detection.
 
 
 
 
 
 
 
 
369
  DEFAULT_MIN_WORD_COUNT = int(get_or_create_env_var("DEFAULT_MIN_WORD_COUNT", "10"))
370
  REMOVE_DUPLICATE_ROWS = get_or_create_env_var("REMOVE_DUPLICATE_ROWS", "False")
371
 
@@ -373,118 +523,182 @@ REMOVE_DUPLICATE_ROWS = get_or_create_env_var("REMOVE_DUPLICATE_ROWS", "False")
373
  ###
374
  # File output options
375
  ###
376
- RETURN_PDF_END_OF_REDACTION = get_or_create_env_var("RETURN_PDF_END_OF_REDACTION", "True") # Return a redacted PDF at the end of the redaction task. Could be useful to set this to "False" if you want to ensure that the user always goes to the 'Review Redactions' tab before getting the final redacted PDF product.
 
 
377
 
378
- COMPRESS_REDACTED_PDF = get_or_create_env_var("COMPRESS_REDACTED_PDF","False") # On low memory systems, the compression options in pymupdf can cause the app to crash if the PDF is longer than 500 pages or so. Setting this to False will save the PDF only with a basic cleaning option enabled
 
 
379
 
380
  ###
381
  # APP RUN OPTIONS
382
  ###
383
 
384
- TLDEXTRACT_CACHE = get_or_create_env_var('TLDEXTRACT_CACHE', 'tmp/tld/')
385
- try: extract = TLDExtract(cache_dir=TLDEXTRACT_CACHE)
386
- except: extract = TLDExtract(cache_dir=None)
 
 
 
387
 
388
  # Get some environment variables and Launch the Gradio app
389
- COGNITO_AUTH = get_or_create_env_var('COGNITO_AUTH', '0')
390
 
391
- RUN_DIRECT_MODE = get_or_create_env_var('RUN_DIRECT_MODE', '0')
392
 
393
  # Direct mode configuration options
394
- DIRECT_MODE_DEFAULT_USER = get_or_create_env_var('DIRECT_MODE_DEFAULT_USER', '') # Default username for cli/direct mode requests
395
- DIRECT_MODE_TASK = get_or_create_env_var('DIRECT_MODE_TASK', 'redact') # 'redact' or 'deduplicate'
396
- DIRECT_MODE_INPUT_FILE = get_or_create_env_var('DIRECT_MODE_INPUT_FILE', '') # Path to input file
397
- DIRECT_MODE_OUTPUT_DIR = get_or_create_env_var('DIRECT_MODE_OUTPUT_DIR', OUTPUT_FOLDER) # Output directory
398
- DIRECT_MODE_DUPLICATE_TYPE = get_or_create_env_var('DIRECT_MODE_DUPLICATE_TYPE', 'pages') # 'pages' or 'tabular'
 
 
 
 
 
 
 
 
 
 
399
 
400
- MAX_QUEUE_SIZE = int(get_or_create_env_var('MAX_QUEUE_SIZE', '5'))
401
 
402
- MAX_FILE_SIZE = get_or_create_env_var('MAX_FILE_SIZE', '250mb').lower()
403
 
404
- GRADIO_SERVER_PORT = int(get_or_create_env_var('GRADIO_SERVER_PORT', '7860'))
405
 
406
- ROOT_PATH = get_or_create_env_var('ROOT_PATH', '')
407
 
408
- DEFAULT_CONCURRENCY_LIMIT = int(get_or_create_env_var('DEFAULT_CONCURRENCY_LIMIT', '3'))
409
 
410
- FILE_INPUT_HEIGHT = get_or_create_env_var('FILE_INPUT_HEIGHT', '200')
411
 
412
  ### ALLOW LIST
413
 
414
- GET_DEFAULT_ALLOW_LIST = get_or_create_env_var('GET_DEFAULT_ALLOW_LIST', 'False')
415
 
416
- ALLOW_LIST_PATH = get_or_create_env_var('ALLOW_LIST_PATH', '') # config/default_allow_list.csv
 
 
417
 
418
- S3_ALLOW_LIST_PATH = get_or_create_env_var('S3_ALLOW_LIST_PATH', '') # default_allow_list.csv # This is a path within the DOCUMENT_REDACTION_BUCKET
 
 
419
 
420
- if ALLOW_LIST_PATH: OUTPUT_ALLOW_LIST_PATH = ALLOW_LIST_PATH
421
- else: OUTPUT_ALLOW_LIST_PATH = 'config/default_allow_list.csv'
 
 
422
 
423
  ### DENY LIST
424
 
425
- GET_DEFAULT_DENY_LIST = get_or_create_env_var('GET_DEFAULT_DENY_LIST', 'False')
426
 
427
- S3_DENY_LIST_PATH = get_or_create_env_var('S3_DENY_LIST_PATH', '') # default_deny_list.csv # This is a path within the DOCUMENT_REDACTION_BUCKET
 
 
428
 
429
- DENY_LIST_PATH = get_or_create_env_var('DENY_LIST_PATH', '') # config/default_deny_list.csv
 
 
430
 
431
- if DENY_LIST_PATH: OUTPUT_DENY_LIST_PATH = DENY_LIST_PATH
432
- else: OUTPUT_DENY_LIST_PATH = 'config/default_deny_list.csv'
 
 
433
 
434
  ### WHOLE PAGE REDACTION LIST
435
 
436
- GET_DEFAULT_WHOLE_PAGE_REDACTION_LIST = get_or_create_env_var('GET_DEFAULT_WHOLE_PAGE_REDACTION_LIST', 'False')
 
 
437
 
438
- S3_WHOLE_PAGE_REDACTION_LIST_PATH = get_or_create_env_var('S3_WHOLE_PAGE_REDACTION_LIST_PATH', '') # default_whole_page_redaction_list.csv # This is a path within the DOCUMENT_REDACTION_BUCKET
 
 
439
 
440
- WHOLE_PAGE_REDACTION_LIST_PATH = get_or_create_env_var('WHOLE_PAGE_REDACTION_LIST_PATH', '') # config/default_whole_page_redaction_list.csv
 
 
441
 
442
- if WHOLE_PAGE_REDACTION_LIST_PATH: OUTPUT_WHOLE_PAGE_REDACTION_LIST_PATH = WHOLE_PAGE_REDACTION_LIST_PATH
443
- else: OUTPUT_WHOLE_PAGE_REDACTION_LIST_PATH = 'config/default_whole_page_redaction_list.csv'
 
 
 
 
444
 
445
  ###
446
  # COST CODE OPTIONS
447
  ###
448
 
449
- SHOW_COSTS = get_or_create_env_var('SHOW_COSTS', 'False')
450
 
451
- GET_COST_CODES = get_or_create_env_var('GET_COST_CODES', 'False')
452
 
453
- DEFAULT_COST_CODE = get_or_create_env_var('DEFAULT_COST_CODE', '')
454
 
455
- COST_CODES_PATH = get_or_create_env_var('COST_CODES_PATH', '') # 'config/COST_CENTRES.csv' # file should be a csv file with a single table in it that has two columns with a header. First column should contain cost codes, second column should contain a name or description for the cost code
 
 
456
 
457
- S3_COST_CODES_PATH = get_or_create_env_var('S3_COST_CODES_PATH', '') # COST_CENTRES.csv # This is a path within the DOCUMENT_REDACTION_BUCKET
458
-
459
- # A default path in case s3 cost code location is provided but no local cost code location given
460
- if COST_CODES_PATH: OUTPUT_COST_CODES_PATH = COST_CODES_PATH
461
- else: OUTPUT_COST_CODES_PATH = 'config/cost_codes.csv'
462
-
463
- ENFORCE_COST_CODES = get_or_create_env_var('ENFORCE_COST_CODES', 'False') # If you have cost codes listed, is it compulsory to choose one before redacting?
464
 
465
- if ENFORCE_COST_CODES == 'True': GET_COST_CODES = 'True'
 
 
 
 
466
 
 
 
 
467
 
 
 
468
 
469
 
470
  ###
471
  # WHOLE DOCUMENT API OPTIONS
472
  ###
473
 
474
- SHOW_WHOLE_DOCUMENT_TEXTRACT_CALL_OPTIONS = get_or_create_env_var('SHOW_WHOLE_DOCUMENT_TEXTRACT_CALL_OPTIONS', 'False') # This feature not currently implemented
 
 
475
 
476
- TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_BUCKET = get_or_create_env_var('TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_BUCKET', '')
 
 
477
 
478
- TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_INPUT_SUBFOLDER = get_or_create_env_var('TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_INPUT_SUBFOLDER', 'input')
 
 
479
 
480
- TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_OUTPUT_SUBFOLDER = get_or_create_env_var('TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_OUTPUT_SUBFOLDER', 'output')
 
 
481
 
482
- LOAD_PREVIOUS_TEXTRACT_JOBS_S3 = get_or_create_env_var('LOAD_PREVIOUS_TEXTRACT_JOBS_S3', 'False') # Whether or not to load previous Textract jobs from S3
 
 
483
 
484
- TEXTRACT_JOBS_S3_LOC = get_or_create_env_var('TEXTRACT_JOBS_S3_LOC', 'output') # Subfolder in the DOCUMENT_REDACTION_BUCKET where the Textract jobs are stored
 
 
485
 
486
- TEXTRACT_JOBS_S3_INPUT_LOC = get_or_create_env_var('TEXTRACT_JOBS_S3_INPUT_LOC', 'input') # Subfolder in the DOCUMENT_REDACTION_BUCKET where the Textract jobs are stored
 
 
487
 
488
- TEXTRACT_JOBS_LOCAL_LOC = get_or_create_env_var('TEXTRACT_JOBS_LOCAL_LOC', 'output') # Local subfolder where the Textract jobs are stored
 
 
489
 
490
- DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS = int(get_or_create_env_var('DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS', '7')) # How many days into the past should whole document Textract jobs be displayed? After that, the data is not deleted from the Textract jobs csv, but it is just filtered out. Included to align with S3 buckets where the file outputs will be automatically deleted after X days.
 
 
 
1
+ import logging
2
  import os
 
3
  import socket
4
+ import tempfile
5
  from datetime import datetime
6
+ from typing import List
7
+
8
  from dotenv import load_dotenv
9
  from tldextract import TLDExtract
 
10
 
11
  today_rev = datetime.now().strftime("%Y%m%d")
12
  HOST_NAME = socket.gethostname()
13
 
14
+
15
  def _get_env_list(env_var_name: str) -> List[str]:
16
  """Parses a comma-separated environment variable into a list of strings."""
17
+ value = env_var_name[1:-1].strip().replace('"', "").replace("'", "")
18
  if not value:
19
  return []
20
  # Split by comma and filter out any empty strings that might result from extra commas
21
+ return [s.strip() for s in value.split(",") if s.strip()]
22
+
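A minimal usage sketch (editor's illustration, not part of the commit; it assumes the helper is importable from tools.config): the surrounding brackets and quotes are stripped before splitting on commas.

from tools.config import _get_env_list

options = _get_env_list("['Extract handwriting', 'Extract signatures']")
# options == ["Extract handwriting", "Extract signatures"]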
23
 
24
  # Set or retrieve configuration variables for the redaction app
25
 
26
+
27
+ def get_or_create_env_var(var_name: str, default_value: str, print_val: bool = False):
28
+ """
29
  Get an environment variable, and set it to a default value if it doesn't exist
30
+ """
31
  # Get the environment variable if it exists
32
  value = os.environ.get(var_name)
33
+
34
  # If it doesn't exist, set the environment variable to the default value
35
  if value is None:
36
  os.environ[var_name] = default_value
37
  value = default_value
38
 
39
+ if print_val is True:
40
+ print(f"The value of {var_name} is {value}")
41
+
42
  return value
43
 
44
+
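For illustration only (editor's sketch, not from the original code): when the variable is unset, the helper returns the default and also writes it back into the process environment.

import os

from tools.config import get_or_create_env_var

os.environ.pop("GRADIO_SERVER_PORT", None)  # ensure the variable is unset for the demo
port = get_or_create_env_var("GRADIO_SERVER_PORT", "7860")
# port == "7860" and os.environ["GRADIO_SERVER_PORT"] == "7860"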
45
  def add_folder_to_path(folder_path: str):
46
+ """
47
  Check if a folder exists on your system. If so, get the absolute path and then add it to the system Path variable if it doesn't already exist. Function is only relevant for locally-created executable files based on this app (when using pyinstaller it creates a _internal folder that contains tesseract and poppler. These need to be added to the system path to enable the app to run)
48
+ """
49
 
50
  if os.path.exists(folder_path) and os.path.isdir(folder_path):
51
+ # print(folder_path, "folder exists.")
52
 
53
  # Resolve relative path to absolute path
54
  absolute_path = os.path.abspath(folder_path)
55
 
56
+ current_path = os.environ["PATH"]
57
  if absolute_path not in current_path.split(os.pathsep):
58
  full_path_extension = absolute_path + os.pathsep + current_path
59
+ os.environ["PATH"] = full_path_extension
60
+ # print(f"Updated PATH with: ", full_path_extension)
61
  else:
62
  pass
63
+ # print(f"Directory {folder_path} already exists in PATH.")
64
  else:
65
  print(f"Folder not found at {folder_path} - not added to PATH")
66
 
67
+
68
  ###
69
  # LOAD CONFIG FROM ENV FILE
70
  ###
71
 
72
+ CONFIG_FOLDER = get_or_create_env_var("CONFIG_FOLDER", "config/")
73
 
74
  # If you have an aws_config env file in the config folder, you can load in app variables this way, e.g. 'config/app_config.env'
75
+ APP_CONFIG_PATH = get_or_create_env_var(
76
+ "APP_CONFIG_PATH", CONFIG_FOLDER + "app_config.env"
77
+ ) # e.g. config/app_config.env
78
 
79
  if APP_CONFIG_PATH:
80
  if os.path.exists(APP_CONFIG_PATH):
81
  print(f"Loading app variables from config file {APP_CONFIG_PATH}")
82
  load_dotenv(APP_CONFIG_PATH)
83
+ else:
84
+ print("App config file not found at location:", APP_CONFIG_PATH)
85
 
86
  ###
87
  # AWS OPTIONS
88
  ###
89
 
90
  # If you have an aws_config env file in the config folder, you can load in AWS keys this way, e.g. 'env/aws_config.env'
91
+ AWS_CONFIG_PATH = get_or_create_env_var(
92
+ "AWS_CONFIG_PATH", ""
93
+ ) # e.g. config/aws_config.env
94
 
95
  if AWS_CONFIG_PATH:
96
  if os.path.exists(AWS_CONFIG_PATH):
97
  print(f"Loading AWS variables from config file {AWS_CONFIG_PATH}")
98
  load_dotenv(AWS_CONFIG_PATH)
99
+ else:
100
+ print("AWS config file not found at location:", AWS_CONFIG_PATH)
101
 
102
  RUN_AWS_FUNCTIONS = get_or_create_env_var("RUN_AWS_FUNCTIONS", "0")
103
 
104
+ AWS_REGION = get_or_create_env_var("AWS_REGION", "")
105
 
106
+ AWS_CLIENT_ID = get_or_create_env_var("AWS_CLIENT_ID", "")
107
 
108
+ AWS_CLIENT_SECRET = get_or_create_env_var("AWS_CLIENT_SECRET", "")
109
 
110
+ AWS_USER_POOL_ID = get_or_create_env_var("AWS_USER_POOL_ID", "")
111
 
112
+ AWS_ACCESS_KEY = get_or_create_env_var("AWS_ACCESS_KEY", "")
113
+ # if AWS_ACCESS_KEY: print(f'AWS_ACCESS_KEY found in environment variables')
114
 
115
+ AWS_SECRET_KEY = get_or_create_env_var("AWS_SECRET_KEY", "")
116
+ # if AWS_SECRET_KEY: print(f'AWS_SECRET_KEY found in environment variables')
117
 
118
+ DOCUMENT_REDACTION_BUCKET = get_or_create_env_var("DOCUMENT_REDACTION_BUCKET", "")
119
 
120
  # Should the app prioritise using AWS SSO over using API keys stored in environment variables/secrets (defaults to yes)
121
+ PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS = get_or_create_env_var(
122
+ "PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS", "1"
123
+ )
124
 
125
  # Custom headers e.g. if routing traffic through Cloudfront
126
  # Retrieving or setting CUSTOM_HEADER
127
+ CUSTOM_HEADER = get_or_create_env_var("CUSTOM_HEADER", "")
128
 
129
  # Retrieving or setting CUSTOM_HEADER_VALUE
130
+ CUSTOM_HEADER_VALUE = get_or_create_env_var("CUSTOM_HEADER_VALUE", "")
131
 
132
  ###
133
  # Image options
134
  ###
135
+ IMAGES_DPI = float(get_or_create_env_var("IMAGES_DPI", "300.0"))
136
+ LOAD_TRUNCATED_IMAGES = get_or_create_env_var("LOAD_TRUNCATED_IMAGES", "True")
137
+ MAX_IMAGE_PIXELS = get_or_create_env_var(
138
+ "MAX_IMAGE_PIXELS", ""
139
+ ) # Changed to None if blank in file_conversion.py
140
 
141
  ###
142
  # File I/O options
143
  ###
144
 
145
+ SESSION_OUTPUT_FOLDER = get_or_create_env_var(
146
+ "SESSION_OUTPUT_FOLDER", "False"
147
+ ) # i.e. should input and output files be saved within a subfolder (named after the session hash) inside the output/input folders
148
 
149
+ OUTPUT_FOLDER = get_or_create_env_var("GRADIO_OUTPUT_FOLDER", "output/") # 'output/'
150
+ INPUT_FOLDER = get_or_create_env_var("GRADIO_INPUT_FOLDER", "input/") # 'input/'
151
 
152
  # Allow for files to be saved in a temporary folder for increased security in some instances
153
+ if OUTPUT_FOLDER == "TEMP" or INPUT_FOLDER == "TEMP":
154
  # Create a temporary directory
155
  with tempfile.TemporaryDirectory() as temp_dir:
156
+ print(f"Temporary directory created at: {temp_dir}")
157
 
158
+ if OUTPUT_FOLDER == "TEMP":
159
+ OUTPUT_FOLDER = temp_dir + "/"
160
+ if INPUT_FOLDER == "TEMP":
161
+ INPUT_FOLDER = temp_dir + "/"
162
 
163
+ GRADIO_TEMP_DIR = get_or_create_env_var(
164
+ "GRADIO_TEMP_DIR", "tmp/gradio_tmp/"
165
+ ) # Default Gradio temp folder
166
+ MPLCONFIGDIR = get_or_create_env_var(
167
+ "MPLCONFIGDIR", "tmp/matplotlib_cache/"
168
+ ) # Matplotlib cache folder
169
 
170
  ###
171
  # LOGGING OPTIONS
 
174
  # By default, logs are put into a subfolder named after today's date and the host name of the instance running the app. This is to avoid, as far as possible, log files from one instance overwriting the logs of another instance on S3. If the app always runs on one system, or just locally, it is not necessary to make the log folders so specific.
175
  # Another way to address this issue would be to write logs to another type of storage, e.g. database such as dynamodb. I may look into this in future.
176
 
177
+ SAVE_LOGS_TO_CSV = get_or_create_env_var("SAVE_LOGS_TO_CSV", "True")
178
 
179
+ USE_LOG_SUBFOLDERS = get_or_create_env_var("USE_LOG_SUBFOLDERS", "True")
180
 
181
+ FEEDBACK_LOGS_FOLDER = get_or_create_env_var("FEEDBACK_LOGS_FOLDER", "feedback/")
182
+ ACCESS_LOGS_FOLDER = get_or_create_env_var("ACCESS_LOGS_FOLDER", "logs/")
183
+ USAGE_LOGS_FOLDER = get_or_create_env_var("USAGE_LOGS_FOLDER", "usage/")
184
 
185
  if USE_LOG_SUBFOLDERS == "True":
186
+ day_log_subfolder = today_rev + "/"
187
+ host_name_subfolder = HOST_NAME + "/"
188
  full_log_subfolder = day_log_subfolder + host_name_subfolder
189
 
190
  FEEDBACK_LOGS_FOLDER = FEEDBACK_LOGS_FOLDER + full_log_subfolder
191
  ACCESS_LOGS_FOLDER = ACCESS_LOGS_FOLDER + full_log_subfolder
192
  USAGE_LOGS_FOLDER = USAGE_LOGS_FOLDER + full_log_subfolder
193
 
194
+ S3_FEEDBACK_LOGS_FOLDER = get_or_create_env_var(
195
+ "S3_FEEDBACK_LOGS_FOLDER", "feedback/" + full_log_subfolder
196
+ )
197
+ S3_ACCESS_LOGS_FOLDER = get_or_create_env_var(
198
+ "S3_ACCESS_LOGS_FOLDER", "logs/" + full_log_subfolder
199
+ )
200
+ S3_USAGE_LOGS_FOLDER = get_or_create_env_var(
201
+ "S3_USAGE_LOGS_FOLDER", "usage/" + full_log_subfolder
202
+ )
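A concrete illustration of the folder composition above (editor's sketch; the date and host name are hypothetical):

# With USE_LOG_SUBFOLDERS == "True", today_rev == "20250101" and HOST_NAME == "redaction-host-1":
#   FEEDBACK_LOGS_FOLDER    == "feedback/20250101/redaction-host-1/"
#   S3_FEEDBACK_LOGS_FOLDER == "feedback/20250101/redaction-host-1/"
# Access and usage logs follow the same pattern under "logs/" and "usage/".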
203
 
204
  # Should the redacted file name be included in the logs? In some instances, the names of the files themselves could be sensitive, and should not be disclosed beyond the app. So, by default this is false.
205
+ DISPLAY_FILE_NAMES_IN_LOGS = get_or_create_env_var(
206
+ "DISPLAY_FILE_NAMES_IN_LOGS", "False"
207
+ )
208
 
209
  # Further customisation options for CSV logs
210
+ CSV_ACCESS_LOG_HEADERS = get_or_create_env_var(
211
+ "CSV_ACCESS_LOG_HEADERS", ""
212
+ ) # If blank, uses component labels
213
+ CSV_FEEDBACK_LOG_HEADERS = get_or_create_env_var(
214
+ "CSV_FEEDBACK_LOG_HEADERS", ""
215
+ ) # If blank, uses component labels
216
+ CSV_USAGE_LOG_HEADERS = get_or_create_env_var(
217
+ "CSV_USAGE_LOG_HEADERS",
218
+ '["session_hash_textbox", "doc_full_file_name_textbox", "data_full_file_name_textbox", "actual_time_taken_number", "total_page_count", "textract_query_number", "pii_detection_method", "comprehend_query_number", "cost_code", "textract_handwriting_signature", "host_name_textbox", "text_extraction_method", "is_this_a_textract_api_call", "task"]',
219
+ ) # If blank, uses component labels
220
 
221
  ### DYNAMODB logs. Whether to save to DynamoDB, and the headers of the table
222
+ SAVE_LOGS_TO_DYNAMODB = get_or_create_env_var("SAVE_LOGS_TO_DYNAMODB", "False")
223
 
224
+ ACCESS_LOG_DYNAMODB_TABLE_NAME = get_or_create_env_var(
225
+ "ACCESS_LOG_DYNAMODB_TABLE_NAME", "redaction_access_log"
226
+ )
227
+ DYNAMODB_ACCESS_LOG_HEADERS = get_or_create_env_var("DYNAMODB_ACCESS_LOG_HEADERS", "")
228
 
229
+ FEEDBACK_LOG_DYNAMODB_TABLE_NAME = get_or_create_env_var(
230
+ "FEEDBACK_LOG_DYNAMODB_TABLE_NAME", "redaction_feedback"
231
+ )
232
+ DYNAMODB_FEEDBACK_LOG_HEADERS = get_or_create_env_var(
233
+ "DYNAMODB_FEEDBACK_LOG_HEADERS", ""
234
+ )
235
 
236
+ USAGE_LOG_DYNAMODB_TABLE_NAME = get_or_create_env_var(
237
+ "USAGE_LOG_DYNAMODB_TABLE_NAME", "redaction_usage"
238
+ )
239
+ DYNAMODB_USAGE_LOG_HEADERS = get_or_create_env_var("DYNAMODB_USAGE_LOG_HEADERS", "")
240
 
241
  # Report logging to console?
242
+ LOGGING = get_or_create_env_var("LOGGING", "False")
243
 
244
+ if LOGGING == "True":
245
  # Configure logging
246
+ logging.basicConfig(
247
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
248
+ )
249
 
250
+ LOG_FILE_NAME = get_or_create_env_var("LOG_FILE_NAME", "log.csv")
251
+ USAGE_LOG_FILE_NAME = get_or_create_env_var("USAGE_LOG_FILE_NAME", LOG_FILE_NAME)
252
+ FEEDBACK_LOG_FILE_NAME = get_or_create_env_var("FEEDBACK_LOG_FILE_NAME", LOG_FILE_NAME)
253
 
254
 
255
  ###
 
257
  ###
258
 
259
  # Create Tesseract and Poppler folders if you have installed them locally
260
+ TESSERACT_FOLDER = get_or_create_env_var(
261
+ "TESSERACT_FOLDER", ""
262
+ ) # If installing for Windows, install Tesseract 5.5.0 from here: https://github.com/UB-Mannheim/tesseract/wiki. Then this environment variable should point to the Tesseract folder e.g. tesseract/
263
+ TESSERACT_DATA_FOLDER = get_or_create_env_var(
264
+ "TESSERACT_DATA_FOLDER", "/usr/share/tessdata"
265
+ )
266
+ POPPLER_FOLDER = get_or_create_env_var(
267
+ "POPPLER_FOLDER", ""
268
+ ) # If installing on Windows, install Poppler from here: https://github.com/oschwartz10612/poppler-windows. This variable needs to point to the Poppler bin folder e.g. poppler/poppler-24.02.0/Library/bin/
269
+
270
+ if TESSERACT_FOLDER:
271
+ add_folder_to_path(TESSERACT_FOLDER)
272
+ if POPPLER_FOLDER:
273
+ add_folder_to_path(POPPLER_FOLDER)
274
 
275
  # List of models to use for text extraction and PII detection
276
  # Text extraction models
277
+ SELECTABLE_TEXT_EXTRACT_OPTION = get_or_create_env_var(
278
+ "SELECTABLE_TEXT_EXTRACT_OPTION", "Local model - selectable text"
279
+ )
280
+ TESSERACT_TEXT_EXTRACT_OPTION = get_or_create_env_var(
281
+ "TESSERACT_TEXT_EXTRACT_OPTION", "Local OCR model - PDFs without selectable text"
282
+ )
283
+ TEXTRACT_TEXT_EXTRACT_OPTION = get_or_create_env_var(
284
+ "TEXTRACT_TEXT_EXTRACT_OPTION", "AWS Textract service - all PDF types"
285
+ )
286
 
287
  # PII detection models
288
+ NO_REDACTION_PII_OPTION = get_or_create_env_var(
289
+ "NO_REDACTION_PII_OPTION", "Only extract text (no redaction)"
290
+ )
291
+ LOCAL_PII_OPTION = get_or_create_env_var("LOCAL_PII_OPTION", "Local")
292
+ AWS_PII_OPTION = get_or_create_env_var("AWS_PII_OPTION", "AWS Comprehend")
293
+
294
+ SHOW_LOCAL_TEXT_EXTRACTION_OPTIONS = get_or_create_env_var(
295
+ "SHOW_LOCAL_TEXT_EXTRACTION_OPTIONS", "True"
296
+ )
297
+ SHOW_AWS_TEXT_EXTRACTION_OPTIONS = get_or_create_env_var(
298
+ "SHOW_AWS_TEXT_EXTRACTION_OPTIONS", "True"
299
+ )
300
 
301
  # Show at least local options if everything mistakenly removed
302
+ if (
303
+ SHOW_LOCAL_TEXT_EXTRACTION_OPTIONS != "True"
304
+ and SHOW_AWS_TEXT_EXTRACTION_OPTIONS != "True"
305
+ ):
306
  SHOW_LOCAL_TEXT_EXTRACTION_OPTIONS = "True"
307
 
308
  local_model_options = []
309
  aws_model_options = []
310
  text_extraction_models = []
311
 
312
+ if SHOW_LOCAL_TEXT_EXTRACTION_OPTIONS == "True":
313
  local_model_options.append(SELECTABLE_TEXT_EXTRACT_OPTION)
314
  local_model_options.append(TESSERACT_TEXT_EXTRACT_OPTION)
315
 
316
+ if SHOW_AWS_TEXT_EXTRACTION_OPTIONS == "True":
317
  aws_model_options.append(TEXTRACT_TEXT_EXTRACT_OPTION)
318
 
319
  TEXT_EXTRACTION_MODELS = local_model_options + aws_model_options
320
+ DO_INITIAL_TABULAR_DATA_CLEAN = get_or_create_env_var(
321
+ "DO_INITIAL_TABULAR_DATA_CLEAN", "True"
322
+ )
323
+
324
+ SHOW_LOCAL_PII_DETECTION_OPTIONS = get_or_create_env_var(
325
+ "SHOW_LOCAL_PII_DETECTION_OPTIONS", "True"
326
+ )
327
+ SHOW_AWS_PII_DETECTION_OPTIONS = get_or_create_env_var(
328
+ "SHOW_AWS_PII_DETECTION_OPTIONS", "True"
329
+ )
330
+
331
+ if (
332
+ SHOW_LOCAL_PII_DETECTION_OPTIONS != "True"
333
+ and SHOW_AWS_PII_DETECTION_OPTIONS != "True"
334
+ ):
335
  SHOW_LOCAL_PII_DETECTION_OPTIONS = "True"
336
 
337
  local_model_options = [NO_REDACTION_PII_OPTION]
338
  aws_model_options = []
339
  pii_detection_models = []
340
 
341
+ if SHOW_LOCAL_PII_DETECTION_OPTIONS == "True":
342
  local_model_options.append(LOCAL_PII_OPTION)
343
 
344
+ if SHOW_AWS_PII_DETECTION_OPTIONS == "True":
345
  aws_model_options.append(AWS_PII_OPTION)
346
 
347
  PII_DETECTION_MODELS = local_model_options + aws_model_options
348
 
349
  if SHOW_AWS_TEXT_EXTRACTION_OPTIONS == "True":
350
+ DEFAULT_TEXT_EXTRACTION_MODEL = get_or_create_env_var(
351
+ "DEFAULT_TEXT_EXTRACTION_MODEL", TEXTRACT_TEXT_EXTRACT_OPTION
352
+ )
353
  else:
354
+ DEFAULT_TEXT_EXTRACTION_MODEL = get_or_create_env_var(
355
+ "DEFAULT_TEXT_EXTRACTION_MODEL", SELECTABLE_TEXT_EXTRACT_OPTION
356
+ )
357
 
358
  if SHOW_AWS_PII_DETECTION_OPTIONS == "True":
359
+ DEFAULT_PII_DETECTION_MODEL = get_or_create_env_var(
360
+ "DEFAULT_PII_DETECTION_MODEL", AWS_PII_OPTION
361
+ )
362
  else:
363
+ DEFAULT_PII_DETECTION_MODEL = get_or_create_env_var(
364
+ "DEFAULT_PII_DETECTION_MODEL", LOCAL_PII_OPTION
365
+ )
366
 
367
  # Create list of PII detection models for tabular redaction
368
  TABULAR_PII_DETECTION_MODELS = PII_DETECTION_MODELS.copy()
369
  if NO_REDACTION_PII_OPTION in TABULAR_PII_DETECTION_MODELS:
370
  TABULAR_PII_DETECTION_MODELS.remove(NO_REDACTION_PII_OPTION)
371
 
372
+ DEFAULT_TEXT_COLUMNS = get_or_create_env_var("DEFAULT_TEXT_COLUMNS", "[]")
373
+ DEFAULT_EXCEL_SHEETS = get_or_create_env_var("DEFAULT_EXCEL_SHEETS", "[]")
374
 
375
+ DEFAULT_TABULAR_ANONYMISATION_STRATEGY = get_or_create_env_var(
376
+ "DEFAULT_TABULAR_ANONYMISATION_STRATEGY", "redact completely"
377
+ )
378
 
379
  ### Local OCR model - Tesseract vs PaddleOCR
380
+ CHOSEN_LOCAL_OCR_MODEL = get_or_create_env_var(
381
+ "CHOSEN_LOCAL_OCR_MODEL", "tesseract"
382
+ ) # Choose between "tesseract", "hybrid", and "paddle". "paddle" will only return whole-line text extraction, and so will only work for OCR, not redaction. "hybrid" is a combination of the two: a first pass through the redactions is done with Tesseract, and then a second pass is done with PaddleOCR on words with low confidence.
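As a hedged illustration (not part of the commit), the hybrid Tesseract/PaddleOCR pass could be selected through the app config file loaded earlier (e.g. config/app_config.env):

CHOSEN_LOCAL_OCR_MODEL=hybrid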
383
 
384
+ PREPROCESS_LOCAL_OCR_IMAGES = get_or_create_env_var(
385
+ "PREPROCESS_LOCAL_OCR_IMAGES", "True"
386
+ ) # Whether to try to preprocess images before extracting text. NOTE: I have found in testing that this doesn't necessarily improve results, and it greatly slows down extraction.
387
 
388
  # Entities for redaction
389
+ CHOSEN_COMPREHEND_ENTITIES = get_or_create_env_var(
390
+ "CHOSEN_COMPREHEND_ENTITIES",
391
+ "['BANK_ACCOUNT_NUMBER','BANK_ROUTING','CREDIT_DEBIT_NUMBER','CREDIT_DEBIT_CVV','CREDIT_DEBIT_EXPIRY','PIN','EMAIL','ADDRESS','NAME','PHONE', 'PASSPORT_NUMBER','DRIVER_ID', 'USERNAME','PASSWORD', 'IP_ADDRESS','MAC_ADDRESS', 'LICENSE_PLATE','VEHICLE_IDENTIFICATION_NUMBER','UK_NATIONAL_INSURANCE_NUMBER', 'INTERNATIONAL_BANK_ACCOUNT_NUMBER','SWIFT_CODE','UK_NATIONAL_HEALTH_SERVICE_NUMBER']",
392
+ )
393
 
394
+ FULL_COMPREHEND_ENTITY_LIST = get_or_create_env_var(
395
+ "FULL_COMPREHEND_ENTITY_LIST",
396
+ "['BANK_ACCOUNT_NUMBER','BANK_ROUTING','CREDIT_DEBIT_NUMBER','CREDIT_DEBIT_CVV','CREDIT_DEBIT_EXPIRY','PIN','EMAIL','ADDRESS','NAME','PHONE','SSN','DATE_TIME','PASSPORT_NUMBER','DRIVER_ID','URL','AGE','USERNAME','PASSWORD','AWS_ACCESS_KEY','AWS_SECRET_KEY','IP_ADDRESS','MAC_ADDRESS','ALL','LICENSE_PLATE','VEHICLE_IDENTIFICATION_NUMBER','UK_NATIONAL_INSURANCE_NUMBER','CA_SOCIAL_INSURANCE_NUMBER','US_INDIVIDUAL_TAX_IDENTIFICATION_NUMBER','UK_UNIQUE_TAXPAYER_REFERENCE_NUMBER','IN_PERMANENT_ACCOUNT_NUMBER','IN_NREGA','INTERNATIONAL_BANK_ACCOUNT_NUMBER','SWIFT_CODE','UK_NATIONAL_HEALTH_SERVICE_NUMBER','CA_HEALTH_NUMBER','IN_AADHAAR','IN_VOTER_NUMBER', 'CUSTOM_FUZZY']",
397
+ )
398
 
399
  # Entities for local PII redaction option
400
+ CHOSEN_REDACT_ENTITIES = get_or_create_env_var(
401
+ "CHOSEN_REDACT_ENTITIES",
402
+ "['TITLES', 'PERSON', 'PHONE_NUMBER', 'EMAIL_ADDRESS', 'STREETNAME', 'UKPOSTCODE', 'CUSTOM']",
403
+ )
404
+
405
+ FULL_ENTITY_LIST = get_or_create_env_var(
406
+ "FULL_ENTITY_LIST",
407
+ "['TITLES', 'PERSON', 'PHONE_NUMBER', 'EMAIL_ADDRESS', 'STREETNAME', 'UKPOSTCODE', 'CREDIT_CARD', 'CRYPTO', 'DATE_TIME', 'IBAN_CODE', 'IP_ADDRESS', 'NRP', 'LOCATION', 'MEDICAL_LICENSE', 'URL', 'UK_NHS', 'CUSTOM', 'CUSTOM_FUZZY']",
408
+ )
409
+
410
+ CUSTOM_ENTITIES = get_or_create_env_var(
411
+ "CUSTOM_ENTITIES", "['TITLES', 'UKPOSTCODE', 'STREETNAME', 'CUSTOM']"
412
+ )
413
+
414
+
415
+ DEFAULT_HANDWRITE_SIGNATURE_CHECKBOX = get_or_create_env_var(
416
+ "DEFAULT_HANDWRITE_SIGNATURE_CHECKBOX", "['Extract handwriting']"
417
+ )
418
+
419
+ HANDWRITE_SIGNATURE_TEXTBOX_FULL_OPTIONS = get_or_create_env_var(
420
+ "HANDWRITE_SIGNATURE_TEXTBOX_FULL_OPTIONS",
421
+ "['Extract handwriting', 'Extract signatures']",
422
+ )
423
+
424
+ if HANDWRITE_SIGNATURE_TEXTBOX_FULL_OPTIONS:
425
+ HANDWRITE_SIGNATURE_TEXTBOX_FULL_OPTIONS = _get_env_list(
426
+ HANDWRITE_SIGNATURE_TEXTBOX_FULL_OPTIONS
427
+ )
428
+
429
+ INCLUDE_FORM_EXTRACTION_TEXTRACT_OPTION = get_or_create_env_var(
430
+ "INCLUDE_FORM_EXTRACTION_TEXTRACT_OPTION", "False"
431
+ )
432
+ INCLUDE_LAYOUT_EXTRACTION_TEXTRACT_OPTION = get_or_create_env_var(
433
+ "INCLUDE_LAYOUT_EXTRACTION_TEXTRACT_OPTION", "False"
434
+ )
435
+ INCLUDE_TABLE_EXTRACTION_TEXTRACT_OPTION = get_or_create_env_var(
436
+ "INCLUDE_TABLE_EXTRACTION_TEXTRACT_OPTION", "False"
437
+ )
438
 
439
  if INCLUDE_FORM_EXTRACTION_TEXTRACT_OPTION == "True":
440
+ HANDWRITE_SIGNATURE_TEXTBOX_FULL_OPTIONS.append("Extract forms")
441
  if INCLUDE_LAYOUT_EXTRACTION_TEXTRACT_OPTION == "True":
442
+ HANDWRITE_SIGNATURE_TEXTBOX_FULL_OPTIONS.append("Extract layout")
443
  if INCLUDE_TABLE_EXTRACTION_TEXTRACT_OPTION == "True":
444
+ HANDWRITE_SIGNATURE_TEXTBOX_FULL_OPTIONS.append("Extract tables")
445
 
446
 
447
+ DEFAULT_SEARCH_QUERY = get_or_create_env_var("DEFAULT_SEARCH_QUERY", "")
448
+ DEFAULT_FUZZY_SPELLING_MISTAKES_NUM = int(
449
+ get_or_create_env_var("DEFAULT_FUZZY_SPELLING_MISTAKES_NUM", "1")
450
+ )
451
 
452
+ DEFAULT_PAGE_MIN = int(get_or_create_env_var("DEFAULT_PAGE_MIN", "0"))
453
 
454
+ DEFAULT_PAGE_MAX = int(get_or_create_env_var("DEFAULT_PAGE_MAX", "999"))
455
 
456
 
457
  # Number of pages to loop through before breaking the function and restarting from the last finished page (not currently activated).
458
+ PAGE_BREAK_VALUE = int(get_or_create_env_var("PAGE_BREAK_VALUE", "99999"))
459
 
460
+ MAX_TIME_VALUE = int(get_or_create_env_var("MAX_TIME_VALUE", "999999"))
461
+ MAX_SIMULTANEOUS_FILES = int(get_or_create_env_var("MAX_SIMULTANEOUS_FILES", "10"))
462
+ MAX_DOC_PAGES = int(get_or_create_env_var("MAX_DOC_PAGES", "3000"))
463
+ MAX_TABLE_ROWS = int(get_or_create_env_var("MAX_TABLE_ROWS", "250000"))
464
+ MAX_TABLE_COLUMNS = int(get_or_create_env_var("MAX_TABLE_COLUMNS", "100"))
465
+ MAX_OPEN_TEXT_CHARACTERS = int(
466
+ get_or_create_env_var("MAX_OPEN_TEXT_CHARACTERS", "50000")
467
+ )
468
 
469
+ CUSTOM_BOX_COLOUR = get_or_create_env_var(
470
+ "CUSTOM_BOX_COLOUR", ""
471
+ ) # only "grey" is currently supported as a custom box colour
472
 
473
  ### Language selection options
474
 
475
  SHOW_LANGUAGE_SELECTION = get_or_create_env_var("SHOW_LANGUAGE_SELECTION", "False")
476
 
477
+ DEFAULT_LANGUAGE_FULL_NAME = get_or_create_env_var(
478
+ "DEFAULT_LANGUAGE_FULL_NAME", "english"
479
+ )
480
+ DEFAULT_LANGUAGE = get_or_create_env_var(
481
+ "DEFAULT_LANGUAGE", "en"
482
+ ) # For tesseract, ensure the Tesseract language data (e.g., fra.traineddata) is installed on your system. You can find the relevant language packs here: https://github.com/tesseract-ocr/tessdata.
483
  # For paddle, ensure the paddle language data (e.g., fra.traineddata) is installed on your system. You can find information on supported languages here: https://www.paddleocr.ai/main/en/version3.x/algorithm/PP-OCRv5/PP-OCRv5_multi_languages.html
484
  # For AWS Comprehend, only English and Spanish are supported https://docs.aws.amazon.com/comprehend/latest/dg/how-pii.html ['en', 'es']
485
  # AWS Textract automatically detects the language of the document and supports the following languages: https://aws.amazon.com/textract/faqs/#topic-0. 'English, Spanish, Italian, Portuguese, French, German. Handwriting, Invoices and Receipts, Identity documents and Queries processing are in English only'
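For concreteness (editor's sketch; the choice of language is hypothetical): running the local models in French would mean setting the two defaults below, e.g. in config/app_config.env, and installing the matching Tesseract language pack (fra.traineddata).

DEFAULT_LANGUAGE_FULL_NAME=french
DEFAULT_LANGUAGE=fr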
486
 
487
+ textract_language_choices = get_or_create_env_var(
488
+ "textract_language_choices", "['en', 'es', 'fr', 'de', 'it', 'pt']"
489
+ )
490
+ aws_comprehend_language_choices = get_or_create_env_var(
491
+ "aws_comprehend_language_choices", "['en', 'es']"
492
+ )
493
 
494
  # The choices that the user sees
495
+ MAPPED_LANGUAGE_CHOICES = get_or_create_env_var(
496
+ "MAPPED_LANGUAGE_CHOICES",
497
+ "['english', 'french', 'german', 'spanish', 'italian', 'dutch', 'portuguese', 'chinese', 'japanese', 'korean', 'lithuanian', 'macedonian', 'norwegian_bokmaal', 'polish', 'romanian', 'russian', 'slovenian', 'swedish', 'catalan', 'ukrainian']",
498
+ )
499
+ LANGUAGE_CHOICES = get_or_create_env_var(
500
+ "LANGUAGE_CHOICES",
501
+ "['en', 'fr', 'de', 'es', 'it', 'nl', 'pt', 'zh', 'ja', 'ko', 'lt', 'mk', 'nb', 'pl', 'ro', 'ru', 'sl', 'sv', 'ca', 'uk']",
502
+ )
503
 
504
  ###
505
  # Duplicate detection settings
506
  ###
507
+ DEFAULT_DUPLICATE_DETECTION_THRESHOLD = float(
508
+ get_or_create_env_var("DEFAULT_DUPLICATE_DETECTION_THRESHOLD", "0.95")
509
+ )
510
+ DEFAULT_MIN_CONSECUTIVE_PAGES = int(
511
+ get_or_create_env_var("DEFAULT_MIN_CONSECUTIVE_PAGES", "1")
512
+ )
513
+ USE_GREEDY_DUPLICATE_DETECTION = get_or_create_env_var(
514
+ "USE_GREEDY_DUPLICATE_DETECTION", "True"
515
+ )
516
+ DEFAULT_COMBINE_PAGES = get_or_create_env_var(
517
+ "DEFAULT_COMBINE_PAGES", "True"
518
+ ) # Combine text from the same page number within a file. The alternative enables line-level duplicate detection.
519
  DEFAULT_MIN_WORD_COUNT = int(get_or_create_env_var("DEFAULT_MIN_WORD_COUNT", "10"))
520
  REMOVE_DUPLICATE_ROWS = get_or_create_env_var("REMOVE_DUPLICATE_ROWS", "False")
521
 
 
523
  ###
524
  # File output options
525
  ###
526
+ RETURN_PDF_END_OF_REDACTION = get_or_create_env_var(
527
+ "RETURN_PDF_END_OF_REDACTION", "True"
528
+ ) # Return a redacted PDF at the end of the redaction task. It can be useful to set this to "False" if you want to ensure that the user always goes to the 'Review Redactions' tab before getting the final redacted PDF.
529
 
530
+ COMPRESS_REDACTED_PDF = get_or_create_env_var(
531
+ "COMPRESS_REDACTED_PDF", "False"
532
+ ) # On low-memory systems, the compression options in pymupdf can cause the app to crash if the PDF is longer than 500 pages or so. Setting this to False will save the PDF with only a basic cleaning option enabled.
533
 
534
  ###
535
  # APP RUN OPTIONS
536
  ###
537
 
538
+ TLDEXTRACT_CACHE = get_or_create_env_var("TLDEXTRACT_CACHE", "tmp/tld/")
539
+ try:
540
+ extract = TLDExtract(cache_dir=TLDEXTRACT_CACHE)
541
+ except Exception as e:
542
+ print(f"Error initialising TLDExtract: {e}")
543
+ extract = TLDExtract(cache_dir=None)
544
 
545
  # Get some environment variables and Launch the Gradio app
546
+ COGNITO_AUTH = get_or_create_env_var("COGNITO_AUTH", "0")
547
 
548
+ RUN_DIRECT_MODE = get_or_create_env_var("RUN_DIRECT_MODE", "0")
549
 
550
  # Direct mode configuration options
551
+ DIRECT_MODE_DEFAULT_USER = get_or_create_env_var(
552
+ "DIRECT_MODE_DEFAULT_USER", ""
553
+ ) # Default username for cli/direct mode requests
554
+ DIRECT_MODE_TASK = get_or_create_env_var(
555
+ "DIRECT_MODE_TASK", "redact"
556
+ ) # 'redact' or 'deduplicate'
557
+ DIRECT_MODE_INPUT_FILE = get_or_create_env_var(
558
+ "DIRECT_MODE_INPUT_FILE", ""
559
+ ) # Path to input file
560
+ DIRECT_MODE_OUTPUT_DIR = get_or_create_env_var(
561
+ "DIRECT_MODE_OUTPUT_DIR", OUTPUT_FOLDER
562
+ ) # Output directory
563
+ DIRECT_MODE_DUPLICATE_TYPE = get_or_create_env_var(
564
+ "DIRECT_MODE_DUPLICATE_TYPE", "pages"
565
+ ) # 'pages' or 'tabular'
566
 
567
+ MAX_QUEUE_SIZE = int(get_or_create_env_var("MAX_QUEUE_SIZE", "5"))
568
 
569
+ MAX_FILE_SIZE = get_or_create_env_var("MAX_FILE_SIZE", "250mb").lower()
570
 
571
+ GRADIO_SERVER_PORT = int(get_or_create_env_var("GRADIO_SERVER_PORT", "7860"))
572
 
573
+ ROOT_PATH = get_or_create_env_var("ROOT_PATH", "")
574
 
575
+ DEFAULT_CONCURRENCY_LIMIT = int(get_or_create_env_var("DEFAULT_CONCURRENCY_LIMIT", "3"))
576
 
577
+ FILE_INPUT_HEIGHT = get_or_create_env_var("FILE_INPUT_HEIGHT", "200")
578
 
579
  ### ALLOW LIST
580
 
581
+ GET_DEFAULT_ALLOW_LIST = get_or_create_env_var("GET_DEFAULT_ALLOW_LIST", "False")
582
 
583
+ ALLOW_LIST_PATH = get_or_create_env_var(
584
+ "ALLOW_LIST_PATH", ""
585
+ ) # config/default_allow_list.csv
586
 
587
+ S3_ALLOW_LIST_PATH = get_or_create_env_var(
588
+ "S3_ALLOW_LIST_PATH", ""
589
+ ) # default_allow_list.csv # This is a path within the DOCUMENT_REDACTION_BUCKET
590
 
591
+ if ALLOW_LIST_PATH:
592
+ OUTPUT_ALLOW_LIST_PATH = ALLOW_LIST_PATH
593
+ else:
594
+ OUTPUT_ALLOW_LIST_PATH = "config/default_allow_list.csv"
595
 
596
  ### DENY LIST
597
 
598
+ GET_DEFAULT_DENY_LIST = get_or_create_env_var("GET_DEFAULT_DENY_LIST", "False")
599
 
600
+ S3_DENY_LIST_PATH = get_or_create_env_var(
601
+ "S3_DENY_LIST_PATH", ""
602
+ ) # default_deny_list.csv # This is a path within the DOCUMENT_REDACTION_BUCKET
603
 
604
+ DENY_LIST_PATH = get_or_create_env_var(
605
+ "DENY_LIST_PATH", ""
606
+ ) # config/default_deny_list.csv
607
 
608
+ if DENY_LIST_PATH:
609
+ OUTPUT_DENY_LIST_PATH = DENY_LIST_PATH
610
+ else:
611
+ OUTPUT_DENY_LIST_PATH = "config/default_deny_list.csv"
612
 
613
  ### WHOLE PAGE REDACTION LIST
614
 
615
+ GET_DEFAULT_WHOLE_PAGE_REDACTION_LIST = get_or_create_env_var(
616
+ "GET_DEFAULT_WHOLE_PAGE_REDACTION_LIST", "False"
617
+ )
618
 
619
+ S3_WHOLE_PAGE_REDACTION_LIST_PATH = get_or_create_env_var(
620
+ "S3_WHOLE_PAGE_REDACTION_LIST_PATH", ""
621
+ ) # default_whole_page_redaction_list.csv # This is a path within the DOCUMENT_REDACTION_BUCKET
622
 
623
+ WHOLE_PAGE_REDACTION_LIST_PATH = get_or_create_env_var(
624
+ "WHOLE_PAGE_REDACTION_LIST_PATH", ""
625
+ ) # config/default_whole_page_redaction_list.csv
626
 
627
+ if WHOLE_PAGE_REDACTION_LIST_PATH:
628
+ OUTPUT_WHOLE_PAGE_REDACTION_LIST_PATH = WHOLE_PAGE_REDACTION_LIST_PATH
629
+ else:
630
+ OUTPUT_WHOLE_PAGE_REDACTION_LIST_PATH = (
631
+ "config/default_whole_page_redaction_list.csv"
632
+ )
633
 
634
  ###
635
  # COST CODE OPTIONS
636
  ###
637
 
638
+ SHOW_COSTS = get_or_create_env_var("SHOW_COSTS", "False")
639
 
640
+ GET_COST_CODES = get_or_create_env_var("GET_COST_CODES", "False")
641
 
642
+ DEFAULT_COST_CODE = get_or_create_env_var("DEFAULT_COST_CODE", "")
643
 
644
+ COST_CODES_PATH = get_or_create_env_var(
645
+ "COST_CODES_PATH", ""
646
+ ) # 'config/COST_CENTRES.csv' # file should be a csv file with a single table in it that has two columns with a header. First column should contain cost codes, second column should contain a name or description for the cost code
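For illustration (editor's sketch; the header names, codes and descriptions are made up), a cost codes file matching the description above could look like:

cost_code,description
CC-1001,Housing casework
CC-2002,Adult social care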
647
 
648
+ S3_COST_CODES_PATH = get_or_create_env_var(
649
+ "S3_COST_CODES_PATH", ""
650
+ ) # COST_CENTRES.csv # This is a path within the DOCUMENT_REDACTION_BUCKET
 
 
 
 
651
 
652
+ # A default path in case s3 cost code location is provided but no local cost code location given
653
+ if COST_CODES_PATH:
654
+ OUTPUT_COST_CODES_PATH = COST_CODES_PATH
655
+ else:
656
+ OUTPUT_COST_CODES_PATH = "config/cost_codes.csv"
657
 
658
+ ENFORCE_COST_CODES = get_or_create_env_var(
659
+ "ENFORCE_COST_CODES", "False"
660
+ ) # If you have cost codes listed, is it compulsory to choose one before redacting?
661
 
662
+ if ENFORCE_COST_CODES == "True":
663
+ GET_COST_CODES = "True"
664
 
665
 
666
  ###
667
  # WHOLE DOCUMENT API OPTIONS
668
  ###
669
 
670
+ SHOW_WHOLE_DOCUMENT_TEXTRACT_CALL_OPTIONS = get_or_create_env_var(
671
+ "SHOW_WHOLE_DOCUMENT_TEXTRACT_CALL_OPTIONS", "False"
672
+ ) # This feature not currently implemented
673
 
674
+ TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_BUCKET = get_or_create_env_var(
675
+ "TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_BUCKET", ""
676
+ )
677
 
678
+ TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_INPUT_SUBFOLDER = get_or_create_env_var(
679
+ "TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_INPUT_SUBFOLDER", "input"
680
+ )
681
 
682
+ TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_OUTPUT_SUBFOLDER = get_or_create_env_var(
683
+ "TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_OUTPUT_SUBFOLDER", "output"
684
+ )
685
 
686
+ LOAD_PREVIOUS_TEXTRACT_JOBS_S3 = get_or_create_env_var(
687
+ "LOAD_PREVIOUS_TEXTRACT_JOBS_S3", "False"
688
+ ) # Whether or not to load previous Textract jobs from S3
689
 
690
+ TEXTRACT_JOBS_S3_LOC = get_or_create_env_var(
691
+ "TEXTRACT_JOBS_S3_LOC", "output"
692
+ ) # Subfolder in the DOCUMENT_REDACTION_BUCKET where the Textract jobs are stored
693
 
694
+ TEXTRACT_JOBS_S3_INPUT_LOC = get_or_create_env_var(
695
+ "TEXTRACT_JOBS_S3_INPUT_LOC", "input"
696
+ ) # Subfolder in the DOCUMENT_REDACTION_BUCKET where the Textract job input files are stored
697
 
698
+ TEXTRACT_JOBS_LOCAL_LOC = get_or_create_env_var(
699
+ "TEXTRACT_JOBS_LOCAL_LOC", "output"
700
+ ) # Local subfolder where the Textract jobs are stored
701
 
702
+ DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS = int(
703
+ get_or_create_env_var("DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS", "7")
704
+ ) # How many days into the past should whole document Textract jobs be displayed? After that, the data is not deleted from the Textract jobs csv, but it is just filtered out. Included to align with S3 buckets where the file outputs will be automatically deleted after X days.
tools/custom_csvlogger.py CHANGED
@@ -1,28 +1,30 @@
1
  from __future__ import annotations
2
- import contextlib
3
  import csv
4
- import datetime
5
- from datetime import datetime
6
  import os
7
  import re
8
- import boto3
9
- import botocore
10
- import uuid
11
  import time
 
12
  from collections.abc import Sequence
13
- from multiprocessing import Lock
14
  from pathlib import Path
 
 
 
15
  from typing import TYPE_CHECKING, Any
16
- from gradio_client import utils as client_utils
17
- import gradio as gr
 
18
  from gradio import utils
19
- from tools.config import AWS_REGION, AWS_ACCESS_KEY, AWS_SECRET_KEY, RUN_AWS_FUNCTIONS
20
 
 
21
 
22
  if TYPE_CHECKING:
23
  from gradio.components import Component
 
24
  from gradio.flagging import FlaggingCallback
25
- from threading import Lock
26
 
27
  class CSVLogger_custom(FlaggingCallback):
28
  """
@@ -68,15 +70,15 @@ class CSVLogger_custom(FlaggingCallback):
68
  self.first_time = True
69
 
70
  def _create_dataset_file(
71
- self,
72
- additional_headers: list[str] | None = None,
73
- replacement_headers: list[str] | None = None
74
- ):
75
  os.makedirs(self.flagging_dir, exist_ok=True)
76
 
77
  if replacement_headers:
78
  if additional_headers is None:
79
- additional_headers = list()
80
 
81
  if len(replacement_headers) != len(self.components):
82
  raise ValueError(
@@ -87,10 +89,14 @@ class CSVLogger_custom(FlaggingCallback):
87
  else:
88
  if additional_headers is None:
89
  additional_headers = []
90
- headers = [
91
- getattr(component, "label", None) or f"component {idx}"
92
- for idx, component in enumerate(self.components)
93
- ] + additional_headers + ["timestamp"]
 
 
 
 
94
 
95
  headers = utils.sanitize_list_for_csv(headers)
96
  dataset_files = list(Path(self.flagging_dir).glob("dataset*.csv"))
@@ -130,16 +136,16 @@ class CSVLogger_custom(FlaggingCallback):
130
  print("Using existing dataset file at:", self.dataset_filepath)
131
 
132
  def flag(
133
- self,
134
- flag_data: list[Any],
135
- flag_option: str | None = None,
136
- username: str | None = None,
137
- save_to_csv: bool = True,
138
- save_to_dynamodb: bool = False,
139
- dynamodb_table_name: str | None = None,
140
- dynamodb_headers: list[str] | None = None, # New: specify headers for DynamoDB
141
- replacement_headers: list[str] | None = None
142
- ) -> int:
143
  if self.first_time:
144
  additional_headers = list()
145
  if flag_option is not None:
@@ -147,7 +153,10 @@ class CSVLogger_custom(FlaggingCallback):
147
  if username is not None:
148
  additional_headers.append("username")
149
  additional_headers.append("id")
150
- self._create_dataset_file(additional_headers=additional_headers, replacement_headers=replacement_headers)
 
 
 
151
  self.first_time = False
152
 
153
  csv_data = list()
@@ -180,59 +189,77 @@ class CSVLogger_custom(FlaggingCallback):
180
  generated_id = str(uuid.uuid4())
181
  csv_data.append(generated_id)
182
 
183
- timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3] # Correct format for Amazon Athena
184
- csv_data.append(timestamp)
 
 
185
 
186
  # Build the headers
187
- headers = (
188
- [getattr(component, "label", None) or f"component {idx}" for idx, component in enumerate(self.components)]
189
- )
 
190
  if flag_option is not None:
191
  headers.append("flag")
192
  if username is not None:
193
  headers.append("username")
194
  headers.append("id")
195
- headers.append("timestamp")
196
 
197
  line_count = -1
198
 
199
  if save_to_csv:
200
  with self.lock:
201
- with open(self.dataset_filepath, "a", newline="", encoding="utf-8") as csvfile:
 
 
202
  writer = csv.writer(csvfile)
203
  writer.writerow(utils.sanitize_list_for_csv(csv_data))
204
  with open(self.dataset_filepath, encoding="utf-8") as csvfile:
205
  line_count = len(list(csv.reader(csvfile))) - 1
206
 
207
- if save_to_dynamodb == True:
208
 
209
  if RUN_AWS_FUNCTIONS == "1":
210
  try:
211
  print("Connecting to DynamoDB via existing SSO connection")
212
- dynamodb = boto3.resource('dynamodb', region_name=AWS_REGION)
213
-
214
- test_connection = dynamodb.meta.client.list_tables()
215
 
216
  except Exception as e:
217
- print("No SSO credentials found:", e)
218
  if AWS_ACCESS_KEY and AWS_SECRET_KEY:
219
  print("Trying DynamoDB credentials from environment variables")
220
- dynamodb = boto3.resource('dynamodb',aws_access_key_id=AWS_ACCESS_KEY,
221
- aws_secret_access_key=AWS_SECRET_KEY, region_name=AWS_REGION)
 
 
 
 
222
  else:
223
- raise Exception("AWS credentials for DynamoDB logging not found")
 
 
224
  else:
225
  raise Exception("AWS credentials for DynamoDB logging not found")
226
-
227
  if dynamodb_table_name is None:
228
- raise ValueError("You must provide a dynamodb_table_name if save_to_dynamodb is True")
229
-
230
- if dynamodb_headers: dynamodb_headers = dynamodb_headers
231
- if not dynamodb_headers and replacement_headers: dynamodb_headers = replacement_headers
232
- elif headers: dynamodb_headers = headers
 
 
 
 
 
233
  elif not dynamodb_headers:
234
- raise ValueError("Headers not found. You must provide dynamodb_headers or replacement_headers to create a new table.")
235
-
 
 
236
  if flag_option is not None:
237
  if "flag" not in dynamodb_headers:
238
  dynamodb_headers.append("flag")
@@ -249,22 +276,27 @@ class CSVLogger_custom(FlaggingCallback):
249
  table = dynamodb.Table(dynamodb_table_name)
250
  table.load()
251
  except botocore.exceptions.ClientError as e:
252
- if e.response['Error']['Code'] == 'ResourceNotFoundException':
253
-
254
  attribute_definitions = [
255
- {'AttributeName': 'id', 'AttributeType': 'S'} # Only define key attributes here
 
 
 
256
  ]
257
 
258
  table = dynamodb.create_table(
259
  TableName=dynamodb_table_name,
260
  KeySchema=[
261
- {'AttributeName': 'id', 'KeyType': 'HASH'} # Partition key
262
  ],
263
  AttributeDefinitions=attribute_definitions,
264
- BillingMode='PAY_PER_REQUEST'
265
- )
266
  # Wait until the table exists
267
- table.meta.client.get_waiter('table_exists').wait(TableName=dynamodb_table_name)
 
 
268
  time.sleep(5)
269
  print(f"Table '{dynamodb_table_name}' created successfully.")
270
  else:
@@ -274,12 +306,17 @@ class CSVLogger_custom(FlaggingCallback):
274
 
275
  try:
276
  item = {
277
- 'id': str(generated_id), # UUID primary key
278
- 'timestamp': timestamp,
279
  }
280
 
281
  # Map the headers to values
282
- item.update({header: str(value) for header, value in zip(dynamodb_headers, csv_data)})
 
 
 
 
 
283
 
284
  table.put_item(Item=item)
285
 
@@ -287,4 +324,4 @@ class CSVLogger_custom(FlaggingCallback):
287
  except Exception as e:
288
  print("Could not upload log to DynamobDB due to", e)
289
 
290
- return line_count
 
 from __future__ import annotations
+
 import csv
 import os
 import re
 import time
+import uuid
 from collections.abc import Sequence
+from datetime import datetime
 from pathlib import Path
+
+# from multiprocessing import Lock
+from threading import Lock
 from typing import TYPE_CHECKING, Any
+
+import boto3
+import botocore
 from gradio import utils
+from gradio_client import utils as client_utils
 
+from tools.config import AWS_ACCESS_KEY, AWS_REGION, AWS_SECRET_KEY, RUN_AWS_FUNCTIONS
 
 if TYPE_CHECKING:
     from gradio.components import Component
+
 from gradio.flagging import FlaggingCallback
+
  class CSVLogger_custom(FlaggingCallback):
30
  """
 
70
  self.first_time = True
71
 
72
  def _create_dataset_file(
73
+ self,
74
+ additional_headers: list[str] | None = None,
75
+ replacement_headers: list[str] | None = None,
76
+ ):
77
  os.makedirs(self.flagging_dir, exist_ok=True)
78
 
79
  if replacement_headers:
80
  if additional_headers is None:
81
+ additional_headers = list()
82
 
83
  if len(replacement_headers) != len(self.components):
84
  raise ValueError(
 
89
  else:
90
  if additional_headers is None:
91
  additional_headers = []
92
+ headers = (
93
+ [
94
+ getattr(component, "label", None) or f"component {idx}"
95
+ for idx, component in enumerate(self.components)
96
+ ]
97
+ + additional_headers
98
+ + ["timestamp"]
99
+ )
100
 
101
  headers = utils.sanitize_list_for_csv(headers)
102
  dataset_files = list(Path(self.flagging_dir).glob("dataset*.csv"))
 
136
  print("Using existing dataset file at:", self.dataset_filepath)
137
 
138
  def flag(
139
+ self,
140
+ flag_data: list[Any],
141
+ flag_option: str | None = None,
142
+ username: str | None = None,
143
+ save_to_csv: bool = True,
144
+ save_to_dynamodb: bool = False,
145
+ dynamodb_table_name: str | None = None,
146
+ dynamodb_headers: list[str] | None = None, # New: specify headers for DynamoDB
147
+ replacement_headers: list[str] | None = None,
148
+ ) -> int:
149
  if self.first_time:
150
  additional_headers = list()
151
  if flag_option is not None:
 
153
  if username is not None:
154
  additional_headers.append("username")
155
  additional_headers.append("id")
156
+ self._create_dataset_file(
157
+ additional_headers=additional_headers,
158
+ replacement_headers=replacement_headers,
159
+ )
160
  self.first_time = False
161
 
162
  csv_data = list()
 
189
  generated_id = str(uuid.uuid4())
190
  csv_data.append(generated_id)
191
 
192
+ timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[
193
+ :-3
194
+ ] # Correct format for Amazon Athena
195
+ csv_data.append(timestamp)
196
 
197
  # Build the headers
198
+ headers = [
199
+ getattr(component, "label", None) or f"component {idx}"
200
+ for idx, component in enumerate(self.components)
201
+ ]
202
  if flag_option is not None:
203
  headers.append("flag")
204
  if username is not None:
205
  headers.append("username")
206
  headers.append("id")
207
+ headers.append("timestamp")
208
 
209
  line_count = -1
210
 
211
  if save_to_csv:
212
  with self.lock:
213
+ with open(
214
+ self.dataset_filepath, "a", newline="", encoding="utf-8"
215
+ ) as csvfile:
216
  writer = csv.writer(csvfile)
217
  writer.writerow(utils.sanitize_list_for_csv(csv_data))
218
  with open(self.dataset_filepath, encoding="utf-8") as csvfile:
219
  line_count = len(list(csv.reader(csvfile))) - 1
220
 
221
+ if save_to_dynamodb is True:
222
 
223
  if RUN_AWS_FUNCTIONS == "1":
224
  try:
225
  print("Connecting to DynamoDB via existing SSO connection")
226
+ dynamodb = boto3.resource("dynamodb", region_name=AWS_REGION)
227
+
228
+ dynamodb.meta.client.list_tables()
229
 
230
  except Exception as e:
231
+ print("No SSO credentials found:", e)
232
  if AWS_ACCESS_KEY and AWS_SECRET_KEY:
233
  print("Trying DynamoDB credentials from environment variables")
234
+ dynamodb = boto3.resource(
235
+ "dynamodb",
236
+ aws_access_key_id=AWS_ACCESS_KEY,
237
+ aws_secret_access_key=AWS_SECRET_KEY,
238
+ region_name=AWS_REGION,
239
+ )
240
  else:
241
+ raise Exception(
242
+ "AWS credentials for DynamoDB logging not found"
243
+ )
244
  else:
245
  raise Exception("AWS credentials for DynamoDB logging not found")
246
+
247
  if dynamodb_table_name is None:
248
+ raise ValueError(
249
+ "You must provide a dynamodb_table_name if save_to_dynamodb is True"
250
+ )
251
+
252
+ if dynamodb_headers:
253
+ dynamodb_headers = dynamodb_headers
254
+ if not dynamodb_headers and replacement_headers:
255
+ dynamodb_headers = replacement_headers
256
+ elif headers:
257
+ dynamodb_headers = headers
258
  elif not dynamodb_headers:
259
+ raise ValueError(
260
+ "Headers not found. You must provide dynamodb_headers or replacement_headers to create a new table."
261
+ )
262
+
263
  if flag_option is not None:
264
  if "flag" not in dynamodb_headers:
265
  dynamodb_headers.append("flag")
 
276
  table = dynamodb.Table(dynamodb_table_name)
277
  table.load()
278
  except botocore.exceptions.ClientError as e:
279
+ if e.response["Error"]["Code"] == "ResourceNotFoundException":
280
+
281
  attribute_definitions = [
282
+ {
283
+ "AttributeName": "id",
284
+ "AttributeType": "S",
285
+ } # Only define key attributes here
286
  ]
287
 
288
  table = dynamodb.create_table(
289
  TableName=dynamodb_table_name,
290
  KeySchema=[
291
+ {"AttributeName": "id", "KeyType": "HASH"} # Partition key
292
  ],
293
  AttributeDefinitions=attribute_definitions,
294
+ BillingMode="PAY_PER_REQUEST",
295
+ )
296
  # Wait until the table exists
297
+ table.meta.client.get_waiter("table_exists").wait(
298
+ TableName=dynamodb_table_name
299
+ )
300
  time.sleep(5)
301
  print(f"Table '{dynamodb_table_name}' created successfully.")
302
  else:
 
306
 
307
  try:
308
  item = {
309
+ "id": str(generated_id), # UUID primary key
310
+ "timestamp": timestamp,
311
  }
312
 
313
  # Map the headers to values
314
+ item.update(
315
+ {
316
+ header: str(value)
317
+ for header, value in zip(dynamodb_headers, csv_data)
318
+ }
319
+ )
320
 
321
  table.put_item(Item=item)
322
 
 
324
  except Exception as e:
325
  print("Could not upload log to DynamobDB due to", e)
326
 
327
+ return line_count
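Below is a minimal, hypothetical usage sketch of the logger above. The constructor arguments, component labels, flagging directory and flag option are illustrative, and the setup() call is assumed to follow Gradio's FlaggingCallback interface.

import gradio as gr

logger = CSVLogger_custom()
components = [gr.Textbox(label="query"), gr.Textbox(label="response")]
logger.setup(components, flagging_dir="feedback_logs")  # assumed FlaggingCallback interface

line_count = logger.flag(
    flag_data=["What gets redacted?", "Names and emails"],
    flag_option="thumbs_up",
    username="demo_user",
    save_to_csv=True,
    save_to_dynamodb=False,  # set True only when AWS credentials and a table name are configured
)
print(f"{line_count} rows logged so far")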
tools/custom_image_analyser_engine.py CHANGED
@@ -1,32 +1,40 @@
-import pytesseract
-import numpy as np
-import pandas as pd
-import gradio as gr
-from presidio_analyzer import AnalyzerEngine, RecognizerResult
-from typing import List, Dict, Optional, Union, Tuple, Any
-from dataclasses import dataclass
-import time
-import cv2
-import re
 import copy
-import botocore
 from copy import deepcopy
 from pdfminer.layout import LTChar
 from PIL import Image
-from typing import Optional, Tuple, Union
 from tools.helper_functions import clean_unicode_text
-from tools.presidio_analyzer_custom import recognizer_result_from_dict
 from tools.load_spacy_model_custom_recognisers import custom_entities
-from tools.config import PREPROCESS_LOCAL_OCR_IMAGES, DEFAULT_LANGUAGE, LOCAL_PII_OPTION, AWS_PII_OPTION
 
-if PREPROCESS_LOCAL_OCR_IMAGES == "True": PREPROCESS_LOCAL_OCR_IMAGES = True
-else: PREPROCESS_LOCAL_OCR_IMAGES = False
 
 try:
     from paddleocr import PaddleOCR
 except ImportError:
     PaddleOCR = None
 
 # --- Language utilities ---
 def _normalize_lang(language: str) -> str:
     return language.strip().lower().replace("-", "_") if language else "en"
@@ -38,35 +46,75 @@ def _tesseract_lang_code(language: str) -> str:
 
     mapping = {
         # Common
-        "en": "eng", "eng": "eng",
-        "fr": "fra", "fre": "fra", "fra": "fra",
-        "de": "deu", "ger": "deu", "deu": "deu",
-        "es": "spa", "spa": "spa",
-        "it": "ita", "ita": "ita",
-        "nl": "nld", "dut": "nld", "nld": "nld",
-        "pt": "por", "por": "por",
-        "ru": "rus", "rus": "rus",
-        "ar": "ara", "ara": "ara",
         # Nordics
-        "sv": "swe", "swe": "swe",
-        "no": "nor", "nb": "nor", "nn": "nor", "nor": "nor",
-        "fi": "fin", "fin": "fin",
-        "da": "dan", "dan": "dan",
         # Eastern/Central
-        "pl": "pol", "pol": "pol",
-        "cs": "ces", "cz": "ces", "ces": "ces",
-        "hu": "hun", "hun": "hun",
-        "ro": "ron", "rum": "ron", "ron": "ron",
-        "bg": "bul", "bul": "bul",
-        "el": "ell", "gre": "ell", "ell": "ell",
         # Asian
-        "ja": "jpn", "jp": "jpn", "jpn": "jpn",
-        "zh": "chi_sim", "zh_cn": "chi_sim", "zh_hans": "chi_sim", "chi_sim": "chi_sim",
-        "zh_tw": "chi_tra", "zh_hk": "chi_tra", "zh_tr": "chi_tra", "chi_tra": "chi_tra",
-        "hi": "hin", "hin": "hin",
-        "bn": "ben", "ben": "ben",
-        "ur": "urd", "urd": "urd",
-        "fa": "fas", "per": "fas", "fas": "fas",
     }
71
 
72
  return mapping.get(lang, "eng")
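As a quick illustration of the mapping above (a sketch, assuming the table shown): normalised ISO-style codes resolve to Tesseract traineddata codes, and anything unknown falls back to English.

for requested in ["en", "FR", "zh-CN", "pt", "xx"]:
    print(requested, "->", _tesseract_lang_code(requested))
# en -> eng, FR -> fra, zh-CN -> chi_sim, pt -> por, xx -> eng (fallback)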
@@ -107,6 +155,7 @@ def _paddle_lang_code(language: str) -> str:
107
 
108
  return mapping.get(lang, "en")
109
 
 
110
  @dataclass
111
  class OCRResult:
112
  text: str
@@ -117,6 +166,7 @@ class OCRResult:
117
  conf: float = None
118
  line: int = None
119
 
 
120
  @dataclass
121
  class CustomImageRecognizerResult:
122
  entity_type: str
@@ -127,9 +177,12 @@ class CustomImageRecognizerResult:
127
  top: int
128
  width: int
129
  height: int
130
- text: str
 
 
131
  class ImagePreprocessor:
132
  """ImagePreprocessor class. Parent class for image preprocessing objects."""
 
133
  def __init__(self, use_greyscale: bool = True) -> None:
134
  self.use_greyscale = use_greyscale
135
 
@@ -146,11 +199,13 @@ class ImagePreprocessor:
146
  return img
147
 
148
  @staticmethod
149
- def _get_bg_color(image: np.ndarray, is_greyscale: bool, invert: bool = False) -> Union[int, Tuple[int, int, int]]:
 
 
150
  # Note: Modified to expect numpy array for bincount
151
  if invert:
152
- image = 255 - image # Simple inversion for greyscale numpy array
153
-
154
  if is_greyscale:
155
  bg_color = int(np.bincount(image.flatten()).argmax())
156
  else:
@@ -158,6 +213,7 @@ class ImagePreprocessor:
158
  # For this pipeline, we only use greyscale, so it's fine.
159
  # A simple alternative:
160
  from scipy import stats
 
161
  bg_color = tuple(stats.mode(image.reshape(-1, 3), axis=0)[0][0])
162
  return bg_color
163
 
@@ -166,10 +222,14 @@ class ImagePreprocessor:
166
  contrast = np.std(image)
167
  mean_intensity = np.mean(image)
168
  return contrast, mean_intensity
169
-
 
170
  class BilateralFilter(ImagePreprocessor):
171
  """Applies bilateral filtering."""
172
- def __init__(self, diameter: int = 9, sigma_color: int = 75, sigma_space: int = 75) -> None:
 
 
 
173
  super().__init__(use_greyscale=True)
174
  self.diameter = diameter
175
  self.sigma_color = sigma_color
@@ -177,16 +237,32 @@ class BilateralFilter(ImagePreprocessor):
177
 
178
  def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, dict]:
179
  # Modified to accept and return numpy array for consistency in the pipeline
180
- filtered_image = cv2.bilateralFilter(image, self.diameter, self.sigma_color, self.sigma_space)
181
- metadata = {"diameter": self.diameter, "sigma_color": self.sigma_color, "sigma_space": self.sigma_space}
 
 
 
 
 
 
182
  return filtered_image, metadata
183
-
 
184
  class SegmentedAdaptiveThreshold(ImagePreprocessor):
185
  """Applies adaptive thresholding."""
186
- def __init__(self, block_size: int = 21, contrast_threshold: int = 40, c_low_contrast: int = 5,
187
- c_high_contrast: int = 10, bg_threshold: int = 127) -> None:
 
 
 
 
 
 
 
188
  super().__init__(use_greyscale=True)
189
- self.block_size = block_size if block_size % 2 == 1 else block_size + 1 # Ensure odd
 
 
190
  self.c_low_contrast = c_low_contrast
191
  self.c_high_contrast = c_high_contrast
192
  self.bg_threshold = bg_threshold
@@ -196,20 +272,37 @@ class SegmentedAdaptiveThreshold(ImagePreprocessor):
196
  # Modified to accept and return numpy array
197
  background_color = self._get_bg_color(image, True)
198
  contrast, _ = self._get_image_contrast(image)
199
- c = self.c_low_contrast if contrast <= self.contrast_threshold else self.c_high_contrast
 
 
 
 
200
 
201
- if background_color < self.bg_threshold: # Dark background, light text
202
  adaptive_threshold_image = cv2.adaptiveThreshold(
203
- image, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, self.block_size, -c
 
 
 
 
 
204
  )
205
- else: # Light background, dark text
206
  adaptive_threshold_image = cv2.adaptiveThreshold(
207
- image, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, self.block_size, c
 
 
 
 
 
208
  )
209
  metadata = {"C": c, "background_color": background_color, "contrast": contrast}
210
  return adaptive_threshold_image, metadata
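A self-contained sketch of the same contrast-aware thresholding idea, assuming an 8-bit greyscale numpy image; the constants mirror the defaults above but are illustrative.

import cv2
import numpy as np

def adaptive_binarise(grey: np.ndarray, block_size: int = 21,
                      contrast_threshold: int = 40) -> np.ndarray:
    """Sketch: choose C from image contrast, invert when the background is dark."""
    block_size = block_size if block_size % 2 == 1 else block_size + 1  # block size must be odd
    bg = int(np.bincount(grey.flatten()).argmax())        # most common pixel value = background
    c = 5 if np.std(grey) <= contrast_threshold else 10   # gentler C on low-contrast scans
    mode = cv2.THRESH_BINARY_INV if bg < 127 else cv2.THRESH_BINARY
    sign = -1 if bg < 127 else 1
    return cv2.adaptiveThreshold(grey, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                 mode, block_size, sign * c)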
 
 
211
  class ImageRescaling(ImagePreprocessor):
212
  """Rescales images based on their size."""
 
213
  def __init__(self, target_dpi: int = 300, assumed_input_dpi: int = 96) -> None:
214
  super().__init__(use_greyscale=True)
215
  self.target_dpi = target_dpi
@@ -224,17 +317,19 @@ class ImageRescaling(ImagePreprocessor):
224
  width = int(image.shape[1] * scale_factor)
225
  height = int(image.shape[0] * scale_factor)
226
  dimensions = (width, height)
227
-
228
  # Use better interpolation for upscaling vs downscaling
229
  interpolation = cv2.INTER_CUBIC if scale_factor > 1.0 else cv2.INTER_AREA
230
  rescaled_image = cv2.resize(image, dimensions, interpolation=interpolation)
231
  metadata["scale_factor"] = scale_factor
232
  return rescaled_image, metadata
233
-
234
  return image, metadata
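A sketch of the DPI-based rescaling rule used above, assuming a 96 DPI input and a 300 DPI target as in the defaults.

import cv2
import numpy as np

def rescale_to_target_dpi(grey: np.ndarray, assumed_input_dpi: int = 96,
                          target_dpi: int = 300):
    """Sketch: upscale with cubic interpolation, downscale with area interpolation."""
    scale = target_dpi / assumed_input_dpi
    if scale == 1.0:
        return grey, 1.0
    interp = cv2.INTER_CUBIC if scale > 1.0 else cv2.INTER_AREA
    new_size = (int(grey.shape[1] * scale), int(grey.shape[0] * scale))
    return cv2.resize(grey, new_size, interpolation=interp), scale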
235
 
 
236
  class ContrastSegmentedImageEnhancer(ImagePreprocessor):
237
  """Class containing all logic to perform contrastive segmentation."""
 
238
  def __init__(
239
  self,
240
  bilateral_filter: Optional[BilateralFilter] = None,
@@ -260,7 +355,9 @@ class ContrastSegmentedImageEnhancer(ImagePreprocessor):
260
  adjusted_contrast = contrast
261
  return adjusted_image, contrast, adjusted_contrast
262
 
263
- def preprocess_image(self, image: Image.Image, perform_binarization: bool = False) -> Tuple[Image.Image, dict]:
 
 
264
  """
265
  A corrected, logical pipeline for OCR preprocessing.
266
  Order: Greyscale -> Rescale -> Denoise -> Enhance Contrast -> Binarize
@@ -271,7 +368,9 @@ class ContrastSegmentedImageEnhancer(ImagePreprocessor):
271
  image_np = self.convert_image_to_array(image)
272
 
273
  # 2. Rescale image to optimal DPI (while still greyscale)
274
- rescaled_image_np, scale_metadata = self.image_rescaling.preprocess_image(image_np)
 
 
275
 
276
  # 3. Apply bilateral filtering for noise reduction
277
  filtered_image_np, _ = self.bilateral_filter.preprocess_image(rescaled_image_np)
@@ -281,31 +380,32 @@ class ContrastSegmentedImageEnhancer(ImagePreprocessor):
281
 
282
  # 5. Adaptive Thresholding (Binarization) - This is the final step
283
  if perform_binarization:
284
- final_image_np, threshold_metadata = self.adaptive_threshold.preprocess_image(
285
- adjusted_image_np
286
  )
287
  else:
288
  final_image_np = adjusted_image_np
289
  threshold_metadata = {}
290
-
291
  # Combine metadata
292
  final_metadata = {**scale_metadata, **threshold_metadata}
293
-
294
  # Convert final numpy array back to PIL Image for return
295
  return Image.fromarray(final_image_np), final_metadata
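A compressed sketch of the pipeline order the docstring describes (greyscale -> rescale -> denoise -> enhance contrast -> optional binarise). The histogram equalisation here is a stand-in for the class's own contrast adjustment, which is not shown in this hunk.

import cv2
import numpy as np
from PIL import Image

def preprocess_for_ocr(image: Image.Image, binarise: bool = False) -> Image.Image:
    """Sketch of the order: greyscale -> rescale -> denoise -> contrast -> (binarise)."""
    grey = np.array(image.convert("L"))
    grey = cv2.resize(grey, None, fx=300 / 96, fy=300 / 96, interpolation=cv2.INTER_CUBIC)
    grey = cv2.bilateralFilter(grey, 9, 75, 75)   # denoise while preserving edges
    grey = cv2.equalizeHist(grey)                 # simple contrast enhancement (stand-in)
    if binarise:
        grey = cv2.adaptiveThreshold(grey, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                     cv2.THRESH_BINARY, 21, 10)
    return Image.fromarray(grey)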
296
 
297
- def rescale_ocr_data(ocr_data, scale_factor:float):
298
-
 
299
  # We loop from 0 to the number of detected words.
300
- num_boxes = len(ocr_data['text'])
301
  for i in range(num_boxes):
302
  # We only want to process actual words, not empty boxes Tesseract might find
303
- if int(ocr_data['conf'][i]) > -1: # -1 confidence is for structural elements
304
  # Get coordinates from the processed image using the index 'i'
305
- x_proc = ocr_data['left'][i]
306
- y_proc = ocr_data['top'][i]
307
- w_proc = ocr_data['width'][i]
308
- h_proc = ocr_data['height'][i]
309
 
310
  # Apply the inverse transformation (division)
311
  x_orig = int(x_proc / scale_factor)
@@ -315,49 +415,54 @@ def rescale_ocr_data(ocr_data, scale_factor:float):
315
 
316
  # --- THE MAPPING STEP ---
317
  # Update the dictionary values in-place using the same index 'i'
318
- ocr_data['left'][i] = x_orig
319
- ocr_data['top'][i] = y_orig
320
- ocr_data['width'][i] = w_orig
321
- ocr_data['height'][i] = h_orig
322
-
323
  return ocr_data
324
 
325
- def filter_entities_for_language(entities: List[str], valid_language_entities: List[str], language: str) -> List[str]:
 
 
 
326
 
327
  if not valid_language_entities:
328
  print(f"No valid entities supported for language: {language}")
329
- #raise Warning(f"No valid entities supported for language: {language}")
330
  if not entities:
331
  print(f"No entities provided for language: {language}")
332
- #raise Warning(f"No entities provided for language: {language}")
333
-
334
- #print("entities:", entities)
335
- #print("valid_language_entities:", valid_language_entities)
336
- # print("language:", language)
337
 
338
  filtered_entities = [
339
- entity for entity in entities
340
- if entity in valid_language_entities
341
  ]
342
 
343
  if not filtered_entities:
344
  print(f"No relevant entities supported for language: {language}")
345
- #raise Warning(f"No relevant entities supported for language: {language}")
346
 
347
  if language != "en":
348
- gr.Info(f"Using {str(filtered_entities)} entities for local model analysis for language: {language}")
 
 
349
 
350
  return filtered_entities
351
 
 
352
  class CustomImageAnalyzerEngine:
353
  def __init__(
354
  self,
355
  analyzer_engine: Optional[AnalyzerEngine] = None,
356
- ocr_engine: str = "tesseract",
357
  tesseract_config: Optional[str] = None,
358
  paddle_kwargs: Optional[Dict[str, Any]] = None,
359
  image_preprocessor: Optional[ImagePreprocessor] = None,
360
- language: Optional[str] = DEFAULT_LANGUAGE
361
  ):
362
  """
363
  Initializes the CustomImageAnalyzerEngine.
@@ -370,7 +475,9 @@ class CustomImageAnalyzerEngine:
370
  :param language: Preferred OCR language (e.g., "en", "fr", "de"). Defaults to DEFAULT_LANGUAGE.
371
  """
372
  if ocr_engine not in ["tesseract", "paddle", "hybrid"]:
373
- raise ValueError("ocr_engine must be either 'tesseract', 'hybrid', or 'paddle'")
 
 
374
 
375
  self.ocr_engine = ocr_engine
376
 
@@ -378,23 +485,28 @@ class CustomImageAnalyzerEngine:
378
  self.language = language or DEFAULT_LANGUAGE or "en"
379
  self.tesseract_lang = _tesseract_lang_code(self.language)
380
  self.paddle_lang = _paddle_lang_code(self.language)
381
-
382
  if self.ocr_engine == "paddle" or self.ocr_engine == "hybrid":
383
  if PaddleOCR is None:
384
- raise ImportError("paddleocr is not installed. Please run 'pip install paddleocr paddlepaddle'")
 
 
385
  # Default paddle configuration if none provided
386
  if paddle_kwargs is None:
387
- paddle_kwargs = {'use_textline_orientation': True, 'lang': self.paddle_lang}
 
 
 
388
  else:
389
  # Enforce language if not explicitly provided
390
- paddle_kwargs.setdefault('lang', self.paddle_lang)
391
  self.paddle_ocr = PaddleOCR(**paddle_kwargs)
392
 
393
  if not analyzer_engine:
394
  analyzer_engine = AnalyzerEngine()
395
  self.analyzer_engine = analyzer_engine
396
 
397
- self.tesseract_config = tesseract_config or '--oem 3 --psm 11'
398
 
399
  if not image_preprocessor:
400
  image_preprocessor = ContrastSegmentedImageEnhancer()
@@ -403,71 +515,82 @@ class CustomImageAnalyzerEngine:
403
  def _sanitize_filename(self, text: str, max_length: int = 20) -> str:
404
  """
405
  Sanitizes text for use in filenames by removing invalid characters and limiting length.
406
-
407
  :param text: The text to sanitize
408
  :param max_length: Maximum length of the sanitized text
409
  :return: Sanitized text safe for filenames
410
  """
411
-
412
  # Remove or replace invalid filename characters
413
  # Windows: < > : " | ? * \ /
414
  # Unix: / (forward slash)
415
  # Also remove control characters and other problematic chars
416
  invalid_chars = r'[<>:"|?*\\/\x00-\x1f\x7f-\x9f]'
417
- sanitized = re.sub(invalid_chars, '_', text)
418
-
419
  # Replace multiple consecutive underscores with a single one
420
- sanitized = re.sub(r'_+', '_', sanitized)
421
-
422
  # Remove leading/trailing underscores and spaces
423
- sanitized = sanitized.strip('_ ')
424
-
425
  # If empty after sanitization, use a default value
426
  if not sanitized:
427
- sanitized = 'text'
428
-
429
  # Limit to max_length characters
430
  if len(sanitized) > max_length:
431
  sanitized = sanitized[:max_length]
432
  # Ensure we don't end with an underscore if we cut in the middle
433
- sanitized = sanitized.rstrip('_')
434
-
435
  return sanitized
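A standalone sketch of the same sanitisation rules (strip invalid characters, collapse underscores, trim, cap the length):

import re

def sanitise_for_filename(text: str, max_length: int = 20) -> str:
    """Sketch: make arbitrary OCR text safe to embed in a filename."""
    text = re.sub(r'[<>:"|?*\\/\x00-\x1f\x7f-\x9f]', "_", text)  # drop invalid characters
    text = re.sub(r"_+", "_", text).strip("_ ")                  # collapse and trim underscores
    return (text or "text")[:max_length].rstrip("_")

print(sanitise_for_filename('John: Doe?'))  # -> John_ Doe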
436
 
437
- def _convert_paddle_to_tesseract_format(self, paddle_results: List[Any]) -> Dict[str, List]:
 
 
438
  """Converts PaddleOCR result format to Tesseract's dictionary format. NOTE: This attempts to create word-level bounding boxes by estimating the distance between characters in sentence-level text output. This is currently quite inaccurate, and word-level bounding boxes should not be relied upon."""
439
 
440
- output = {'text': [], 'left': [], 'top': [], 'width': [], 'height': [], 'conf': []}
 
 
 
 
 
 
 
441
 
442
  # paddle_results is now a list of dictionaries with detailed information
443
  if not paddle_results:
444
  return output
445
-
446
  for page_result in paddle_results:
447
  # Extract text recognition results from the new format
448
- rec_texts = page_result.get('rec_texts', [])
449
- rec_scores = page_result.get('rec_scores', [])
450
- rec_polys = page_result.get('rec_polys', [])
451
-
452
- for line_text, line_confidence, bounding_box in zip(rec_texts, rec_scores, rec_polys):
 
 
453
  # bounding_box is now a numpy array with shape (4, 2)
454
  # Convert to list of coordinates if it's a numpy array
455
- if hasattr(bounding_box, 'tolist'):
456
  box = bounding_box.tolist()
457
  else:
458
  box = bounding_box
459
-
460
  # box is [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
461
  x_coords = [p[0] for p in box]
462
  y_coords = [p[1] for p in box]
463
-
464
  line_left = int(min(x_coords))
465
  line_top = int(min(y_coords))
466
  line_width = int(max(x_coords) - line_left)
467
  line_height = int(max(y_coords) - line_top)
468
- #line_y_center = (max(y_coords) + min(y_coords)) / 2
469
 
470
- # 2. Split the line into words
471
  words = line_text.split()
472
  if not words:
473
  continue
@@ -482,69 +605,78 @@ class CustomImageAnalyzerEngine:
482
  for word in words:
483
  word_width = int(len(word) * avg_char_width)
484
  word_left = line_left + int(current_char_offset * avg_char_width)
485
-
486
- output['text'].append(word)
487
- output['left'].append(word_left)
488
- output['top'].append(line_top)
489
- output['width'].append(word_width)
490
- output['height'].append(line_height)
491
  # Use the line's confidence for each word derived from it
492
- output['conf'].append(int(line_confidence * 100))
493
 
494
  # Update offset for the next word (add word length + 1 for the space)
495
  current_char_offset += len(word) + 1
496
-
497
  return output
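A sketch of the geometry the converter above applies to each PaddleOCR line: collapse the quadrilateral to an axis-aligned box, then estimate word boxes from average character width. The polygon and text here are illustrative, and, as the docstring warns, the resulting word boxes are approximate.

# Sketch: reduce one PaddleOCR polygon plus its line text to estimated word boxes.
poly = [[100, 50], [420, 52], [421, 80], [99, 78]]            # illustrative quad (4 corners)
line_text, line_conf = "John Smith 07700 900123", 0.93

xs, ys = [p[0] for p in poly], [p[1] for p in poly]
left, top = int(min(xs)), int(min(ys))
width, height = int(max(xs)) - left, int(max(ys)) - top

avg_char_width = width / len(line_text)
offset = 0
for word in line_text.split():
    word_left = left + int(offset * avg_char_width)
    word_width = int(len(word) * avg_char_width)
    print(word, (word_left, top, word_width, height), int(line_conf * 100))
    offset += len(word) + 1                                    # +1 for the space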
498
-
499
  def _perform_hybrid_ocr(
500
- self,
501
- image: Image.Image,
502
- confidence_threshold: int = 65,
503
- padding: int = 5,
504
- ocr: Optional[Any] = None
505
- ) -> Dict[str, list]:
506
  """
507
  Performs OCR using Tesseract for bounding boxes and PaddleOCR for low-confidence text.
508
  Returns data in the same dictionary format as pytesseract.image_to_data.
509
  """
510
  if ocr is None:
511
- if hasattr(self, 'paddle_ocr') and self.paddle_ocr is not None:
512
  ocr = self.paddle_ocr
513
  else:
514
- raise ValueError("No OCR object provided and 'paddle_ocr' is not initialized.")
515
-
 
 
516
  print("Starting hybrid OCR process...")
517
-
518
  # 1. Get initial word-level results from Tesseract
519
  tesseract_data = pytesseract.image_to_data(
520
  image,
521
  output_type=pytesseract.Output.DICT,
522
  config=self.tesseract_config,
523
- lang=self.tesseract_lang
524
  )
525
 
526
- #tesseract_data['abs_line_id'] = tesseract_data.groupby(['block_num', 'par_num', 'line_num']).ngroup()
527
-
528
- final_data = {'text': [], 'left': [], 'top': [], 'width': [], 'height': [], 'conf': []}
529
-
530
- num_words = len(tesseract_data['text'])
 
 
 
 
 
 
 
531
 
532
  # This handles the "no text on page" case. If num_words is 0, the loop is skipped
533
  # and an empty dictionary with empty lists is returned, which is the correct behavior.
534
  for i in range(num_words):
535
- text = tesseract_data['text'][i]
536
- conf = int(tesseract_data['conf'][i])
537
-
538
  # Skip empty text boxes or non-word elements (like page/block markers)
539
  if not text.strip() or conf == -1:
540
  continue
541
 
542
- left = tesseract_data['left'][i]
543
- top = tesseract_data['top'][i]
544
- width = tesseract_data['width'][i]
545
- height = tesseract_data['height'][i]
546
- #line_number = tesseract_data['abs_line_id'][i]
547
-
548
  # If confidence is low, use PaddleOCR for a second opinion
549
  if conf < confidence_threshold:
550
  img_width, img_height = image.size
@@ -552,66 +684,76 @@ class CustomImageAnalyzerEngine:
552
  crop_top = max(0, top - padding)
553
  crop_right = min(img_width, left + width + padding + 15)
554
  crop_bottom = min(img_height, top + height + padding)
555
-
556
  # Ensure crop dimensions are valid
557
  if crop_right <= crop_left or crop_bottom <= crop_top:
558
- continue # Skip invalid crops
559
 
560
- cropped_image = image.crop((crop_left, crop_top, crop_right, crop_bottom))
 
 
561
  cropped_image_np = np.array(cropped_image)
562
-
563
  if len(cropped_image_np.shape) == 2:
564
  cropped_image_np = np.stack([cropped_image_np] * 3, axis=-1)
565
-
566
  paddle_results = ocr.predict(cropped_image_np)
567
-
568
  if paddle_results and paddle_results[0]:
569
- rec_texts = paddle_results[0].get('rec_texts', [])
570
- rec_scores = paddle_results[0].get('rec_scores', [])
571
-
572
  if rec_texts and rec_scores:
573
  new_text = " ".join(rec_texts)
574
- new_conf = int(round(np.median(rec_scores) * 100,0))
575
 
576
  # Only replace if Paddle's confidence is better
577
  if new_conf > conf:
578
- print(f" Re-OCR'd word: '{text}' (conf: {conf}) -> '{new_text}' (conf: {new_conf:.0f})")
 
 
579
 
580
  # For exporting example image comparisons, not used here
581
  safe_text = self._sanitize_filename(text, max_length=20)
582
- safe_new_text = self._sanitize_filename(new_text, max_length=20)
583
  output_image_path = f"examples/tess_vs_paddle_examples/{conf}_conf_{safe_text}_to_{new_text}_{new_conf}.png"
584
  cropped_image.save(output_image_path)
585
 
586
  text = new_text
587
  conf = new_conf
588
-
589
  else:
590
- print(f" '{text}' (conf: {conf}) -> Paddle result '{new_text}' (conf: {new_conf:.0f}) was not better. Keeping original.")
 
 
591
  else:
592
  # Paddle ran but found nothing, so discard the original low-confidence word
593
- print(f" '{text}' (conf: {conf}) -> No text found by Paddle. Discarding.")
594
- text = ''
 
 
595
  else:
596
  # Paddle found nothing, discard original word
597
- print(f" '{text}' (conf: {conf}) -> No text found by Paddle. Discarding.")
598
- text = ''
 
 
599
 
600
  # Append the final result (either original, replaced, or skipped if empty)
601
  if text.strip():
602
- final_data['text'].append(clean_unicode_text(text))
603
- final_data['left'].append(left)
604
- final_data['top'].append(top)
605
- final_data['width'].append(width)
606
- final_data['height'].append(height)
607
- final_data['conf'].append(int(conf))
608
- #final_data['line_number'].append(int(line_number))
609
-
610
  return final_data
611
-
612
- def perform_ocr(self,
613
- image: Union[str, Image.Image, np.ndarray],
614
- ocr: Optional[Any] = None) -> List[OCRResult]:
615
  """
616
  Performs OCR on the given image using the configured engine.
617
  """
@@ -619,10 +761,12 @@ class CustomImageAnalyzerEngine:
619
  image = Image.open(image)
620
  elif isinstance(image, np.ndarray):
621
  image = Image.fromarray(image)
622
-
623
  # Pre-process image - currently seems to give worse results!
624
- if str(PREPROCESS_LOCAL_OCR_IMAGES).lower() == 'true':
625
- image, preprocessing_metadata = self.image_preprocessor.preprocess_image(image)
 
 
626
  else:
627
  preprocessing_metadata = {}
628
 
@@ -637,68 +781,71 @@ class CustomImageAnalyzerEngine:
637
  image,
638
  output_type=pytesseract.Output.DICT,
639
  config=self.tesseract_config,
640
- lang=self.tesseract_lang # Ensure the Tesseract language data (e.g., fra.traineddata) is installed on your system.
641
  )
642
 
643
- #ocr_data['abs_line_id'] = ocr_data.groupby(['block_num', 'par_num', 'line_num']).ngroup()
644
 
645
  elif self.ocr_engine == "paddle":
646
 
647
- image_np = np.array(image) # image_processed
648
-
649
  # PaddleOCR may need an RGB image. Ensure it has 3 channels.
650
  if len(image_np.shape) == 2:
651
  image_np = np.stack([image_np] * 3, axis=-1)
652
 
653
  if ocr is None:
654
- if hasattr(self, 'paddle_ocr') and self.paddle_ocr is not None:
655
  ocr = self.paddle_ocr
656
  else:
657
- raise ValueError("No OCR object provided and 'paddle_ocr' is not initialised.")
 
 
658
 
659
- #ocr = PaddleOCR(use_textline_orientation=True, lang='en')
660
  paddle_results = ocr.predict(image_np)
661
  ocr_data = self._convert_paddle_to_tesseract_format(paddle_results)
662
 
663
  else:
664
  raise RuntimeError(f"Unsupported OCR engine: {self.ocr_engine}")
665
-
666
  if preprocessing_metadata:
667
- scale_factor = preprocessing_metadata.get('scale_factor', 1.0)
668
  ocr_data = rescale_ocr_data(ocr_data, scale_factor)
669
 
670
  # The rest of your processing pipeline now works for both engines
671
  ocr_result = ocr_data
672
-
673
  # Filter out empty strings and low confidence results
674
  valid_indices = [
675
- i for i, text in enumerate(ocr_result['text'])
676
- if text.strip() and int(ocr_result['conf'][i]) > 0
 
677
  ]
678
-
679
  return [
680
  OCRResult(
681
- text=clean_unicode_text(ocr_result['text'][i]),
682
- left=ocr_result['left'][i],
683
- top=ocr_result['top'][i],
684
- width=ocr_result['width'][i],
685
- height=ocr_result['height'][i]#,
686
- #line_number=ocr_result['abs_line_id'][i]
687
  )
688
  for i in valid_indices
689
  ]
690
 
691
  def analyze_text(
692
- self,
693
- line_level_ocr_results: List[OCRResult],
694
  ocr_results_with_words: Dict[str, Dict],
695
  chosen_redact_comprehend_entities: List[str],
696
  pii_identification_method: str = LOCAL_PII_OPTION,
697
- comprehend_client = "",
698
- custom_entities:List[str]=custom_entities,
699
  language: Optional[str] = DEFAULT_LANGUAGE,
700
  nlp_analyser: AnalyzerEngine = None,
701
- **text_analyzer_kwargs
702
  ) -> List[CustomImageRecognizerResult]:
703
 
704
  page_text = ""
@@ -719,9 +866,11 @@ class CustomImageAnalyzerEngine:
719
  page_text_mapping.append((start_pos, i, line_level_ocr_result, None))
720
 
721
  # Determine language for downstream services
722
- aws_language = language or getattr(self, 'language', None) or 'en'
723
 
724
- valid_language_entities = nlp_analyser.registry.get_supported_entities(languages=[language])
 
 
725
  if "CUSTOM" not in valid_language_entities:
726
  valid_language_entities.append("CUSTOM")
727
  if "CUSTOM_FUZZY" not in valid_language_entities:
@@ -730,53 +879,52 @@ class CustomImageAnalyzerEngine:
730
  # Process using either Local or AWS Comprehend
731
  if pii_identification_method == LOCAL_PII_OPTION:
732
 
733
- language_supported_entities = filter_entities_for_language(custom_entities, valid_language_entities, language)
734
-
 
 
735
  if language_supported_entities:
736
  text_analyzer_kwargs["entities"] = language_supported_entities
737
 
738
- #if language != "en":
739
  # gr.Info(f"Using {str(language_supported_entities)} entities for local model analysis for language: {language}")
740
  else:
741
  print(f"No relevant entities supported for language: {language}")
742
- raise Warning(f"No relevant entities supported for language: {language}")
 
 
743
 
744
  analyzer_result = nlp_analyser.analyze(
745
- text=page_text,
746
- language=language,
747
- **text_analyzer_kwargs
748
  )
749
  all_text_line_results = map_back_entity_results(
750
- analyzer_result,
751
- page_text_mapping,
752
- all_text_line_results
753
  )
754
 
755
- elif pii_identification_method == AWS_PII_OPTION:
756
 
757
  # Handle custom entities first
758
  if custom_entities:
759
  custom_redact_entities = [
760
- entity for entity in chosen_redact_comprehend_entities
 
761
  if entity in custom_entities
762
  ]
763
 
764
  if custom_redact_entities:
765
  # Filter entities to only include those supported by the language
766
- language_supported_entities = filter_entities_for_language(custom_redact_entities, valid_language_entities, language)
 
 
767
 
768
  if language_supported_entities:
769
  text_analyzer_kwargs["entities"] = language_supported_entities
770
 
771
  page_analyser_result = nlp_analyser.analyze(
772
- text=page_text,
773
- language=language,
774
- **text_analyzer_kwargs
775
  )
776
  all_text_line_results = map_back_entity_results(
777
- page_analyser_result,
778
- page_text_mapping,
779
- all_text_line_results
780
  )
781
 
782
  # Process text in batches for AWS Comprehend
@@ -789,14 +937,14 @@ class CustomImageAnalyzerEngine:
789
  words = text_line.text.split()
790
  word_start_positions = list()
791
  current_pos = 0
792
-
793
  for word in words:
794
  word_start_positions.append(current_pos)
795
  current_pos += len(word) + 1
796
 
797
  for word_idx, word in enumerate(words):
798
  new_batch_char_count = len(current_batch) + len(word) + 1
799
-
800
  if batch_word_count >= 50 or new_batch_char_count >= 200:
801
  # Process current batch
802
  all_text_line_results = do_aws_comprehend_call(
@@ -804,17 +952,19 @@ class CustomImageAnalyzerEngine:
804
  current_batch_mapping,
805
  comprehend_client,
806
  aws_language,
807
- text_analyzer_kwargs.get('allow_list', []),
808
  chosen_redact_comprehend_entities,
809
- all_text_line_results
810
  )
811
  comprehend_query_number += 1
812
-
813
  # Reset batch
814
  current_batch = word
815
  batch_word_count = 1
816
  batch_char_count = len(word)
817
- current_batch_mapping = [(0, i, text_line, None, word_start_positions[word_idx])]
 
 
818
  else:
819
  if current_batch:
820
  current_batch += " "
@@ -822,15 +972,20 @@ class CustomImageAnalyzerEngine:
822
  current_batch += word
823
  batch_char_count += len(word)
824
  batch_word_count += 1
825
-
826
- if not current_batch_mapping or current_batch_mapping[-1][1] != i:
827
- current_batch_mapping.append((
828
- batch_char_count - len(word),
829
- i,
830
- text_line,
831
- None,
832
- word_start_positions[word_idx]
833
- ))
 
 
 
 
 
834
 
835
  # Process final batch if any
836
  if current_batch:
@@ -839,33 +994,39 @@ class CustomImageAnalyzerEngine:
839
  current_batch_mapping,
840
  comprehend_client,
841
  aws_language,
842
- text_analyzer_kwargs.get('allow_list', []),
843
  chosen_redact_comprehend_entities,
844
- all_text_line_results
845
  )
846
- comprehend_query_number += 1
847
 
848
  # Process results and create bounding boxes
849
  combined_results = list()
850
  for i, text_line in enumerate(line_level_ocr_results):
851
- line_results = next((results for idx, results in all_text_line_results if idx == i), [])
 
 
852
  if line_results and i < len(ocr_results_with_words):
853
  child_level_key = list(ocr_results_with_words.keys())[i]
854
- ocr_results_with_words_line_level = ocr_results_with_words[child_level_key]
855
-
 
 
856
  for result in line_results:
857
  bbox_results = self.map_analyzer_results_to_bounding_boxes(
858
  [result],
859
- [OCRResult(
860
- text=text_line.text[result.start:result.end],
861
- left=text_line.left,
862
- top=text_line.top,
863
- width=text_line.width,
864
- height=text_line.height
865
- )],
 
 
866
  text_line.text,
867
- text_analyzer_kwargs.get('allow_list', []),
868
- ocr_results_with_words_line_level
869
  )
870
  combined_results.extend(bbox_results)
871
 
@@ -873,61 +1034,65 @@ class CustomImageAnalyzerEngine:
873
 
874
  @staticmethod
875
  def map_analyzer_results_to_bounding_boxes(
876
- text_analyzer_results: List[RecognizerResult],
877
- redaction_relevant_ocr_results: List[OCRResult],
878
- full_text: str,
879
- allow_list: List[str],
880
- ocr_results_with_words_child_info: Dict[str, Dict]
881
- ) -> List[CustomImageRecognizerResult]:
882
  redaction_bboxes = list()
883
 
884
  for redaction_relevant_ocr_result in redaction_relevant_ocr_results:
885
- #print("ocr_results_with_words_child_info:", ocr_results_with_words_child_info)
886
 
887
- line_text = ocr_results_with_words_child_info['text']
888
  line_length = len(line_text)
889
  redaction_text = redaction_relevant_ocr_result.text
890
-
891
  for redaction_result in text_analyzer_results:
892
  # Check if the redaction text is not in the allow list
893
-
894
  if redaction_text not in allow_list:
895
-
896
  # Adjust start and end to be within line bounds
897
  start_in_line = max(0, redaction_result.start)
898
  end_in_line = min(line_length, redaction_result.end)
899
-
900
  # Get the matched text from this line
901
  matched_text = line_text[start_in_line:end_in_line]
902
- matched_words = matched_text.split()
903
-
904
  # Find the corresponding words in the OCR results
905
  matching_word_boxes = list()
906
 
907
  current_position = 0
908
 
909
- for word_info in ocr_results_with_words_child_info.get('words', []):
910
- word_text = word_info['text']
911
  word_length = len(word_text)
912
 
913
  word_start = current_position
914
  word_end = current_position + word_length
915
 
916
  # Update current position for the next word
917
- current_position += word_length + 1 # +1 for the space after the word
918
-
 
 
919
  # Check if the word's bounding box is within the start and end bounds
920
- if word_start >= start_in_line and word_end <= (end_in_line + 1):
921
- matching_word_boxes.append(word_info['bounding_box'])
922
- #print(f"Matched word: {word_info['text']}")
923
-
 
 
924
  if matching_word_boxes:
925
  # Calculate the combined bounding box for all matching words
926
  left = min(box[0] for box in matching_word_boxes)
927
  top = min(box[1] for box in matching_word_boxes)
928
  right = max(box[2] for box in matching_word_boxes)
929
  bottom = max(box[3] for box in matching_word_boxes)
930
-
931
  redaction_bboxes.append(
932
  CustomImageRecognizerResult(
933
  entity_type=redaction_result.entity_type,
@@ -938,12 +1103,12 @@ class CustomImageAnalyzerEngine:
938
  top=top,
939
  width=right - left,
940
  height=bottom - top,
941
- text=matched_text
942
  )
943
  )
944
 
945
  return redaction_bboxes
946
-
947
  @staticmethod
948
  def remove_space_boxes(ocr_result: dict) -> dict:
949
  """Remove OCR bboxes that are for spaces.
@@ -963,7 +1128,7 @@ class CustomImageAnalyzerEngine:
963
  filtered_ocr_result[key] = [ocr_result[key][i] for i in idx]
964
 
965
  return filtered_ocr_result
966
-
967
  @staticmethod
968
  def _scale_bbox_results(
969
  ocr_result: Dict[str, List[Union[int, str]]], scale_factor: float
@@ -994,67 +1159,87 @@ class CustomImageAnalyzerEngine:
994
  # Estimate the x-offset based on character position
995
  # This is a simple estimation and might need refinement for variable-width fonts
996
  return int(start / len(full_text) * len(full_text))
997
-
998
  def estimate_width(self, ocr_result: OCRResult, start: int, end: int) -> int:
999
  # Extract the relevant text portion
1000
  relevant_text = ocr_result.text[start:end]
1001
-
1002
  # If the relevant text is the same as the full text, return the full width
1003
  if relevant_text == ocr_result.text:
1004
  return ocr_result.width
1005
-
1006
  # Estimate width based on the proportion of the relevant text length to the total text length
1007
  total_text_length = len(ocr_result.text)
1008
  relevant_text_length = len(relevant_text)
1009
-
1010
  if total_text_length == 0:
1011
  return 0 # Avoid division by zero
1012
-
1013
  # Proportion of the relevant text to the total text
1014
  proportion = relevant_text_length / total_text_length
1015
-
1016
  # Estimate the width based on the proportion
1017
  estimated_width = int(proportion * ocr_result.width)
1018
-
1019
  return estimated_width
1020
 
1021
 
1022
- def bounding_boxes_overlap(box1:List, box2:List):
1023
  """Check if two bounding boxes overlap."""
1024
- return (box1[0] < box2[2] and box2[0] < box1[2] and
1025
- box1[1] < box2[3] and box2[1] < box1[3])
1026
-
1027
- def map_back_entity_results(page_analyser_result:dict, page_text_mapping:dict, all_text_line_results:List[Tuple]):
 
 
 
 
 
 
 
 
 
1028
  for entity in page_analyser_result:
1029
  entity_start = entity.start
1030
  entity_end = entity.end
1031
-
1032
  # Track if the entity has been added to any line
1033
  added_to_line = False
1034
-
1035
  for batch_start, line_idx, original_line, chars in page_text_mapping:
1036
  batch_end = batch_start + len(original_line.text)
1037
-
1038
  # Check if the entity overlaps with the current line
1039
- if batch_start < entity_end and batch_end > entity_start: # Overlap condition
1040
- relative_start = max(0, entity_start - batch_start) # Adjust start relative to the line
1041
- relative_end = min(entity_end - batch_start, len(original_line.text)) # Adjust end relative to the line
1042
-
 
 
 
 
 
 
1043
  # Create a new adjusted entity
1044
  adjusted_entity = copy.deepcopy(entity)
1045
  adjusted_entity.start = relative_start
1046
  adjusted_entity.end = relative_end
1047
-
1048
  # Check if this line already has an entry
1049
- existing_entry = next((entry for idx, entry in all_text_line_results if idx == line_idx), None)
1050
-
 
 
 
1051
  if existing_entry is None:
1052
  all_text_line_results.append((line_idx, [adjusted_entity]))
1053
  else:
1054
- existing_entry.append(adjusted_entity) # Append to the existing list of entities
1055
-
 
 
1056
  added_to_line = True
1057
-
1058
  # If the entity spans multiple lines, you may want to handle that here
1059
  if not added_to_line:
1060
  # Handle cases where the entity does not fit in any line (optional)
@@ -1062,7 +1247,14 @@ def map_back_entity_results(page_analyser_result:dict, page_text_mapping:dict, a
1062
 
1063
  return all_text_line_results
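The remapping above converts entity offsets measured against the concatenated page text back into per-line offsets by intersecting each entity span with every line's character range. A small sketch with hypothetical lines:

# Sketch: global entity offsets -> per-line offsets via range overlap.
lines = ["Dear John Smith,", "your reference number is 485777."]
entity_start, entity_end = 5, 15                 # "John Smith" within the joined page text

offset = 0
for idx, line in enumerate(lines):
    line_start, line_end = offset, offset + len(line)
    if line_start < entity_end and line_end > entity_start:   # overlap condition
        rel_start = max(0, entity_start - line_start)
        rel_end = min(entity_end - line_start, len(line))
        print(idx, line[rel_start:rel_end])                    # -> 0 John Smith
    offset = line_end + 1                                      # +1 for the joining space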
1064
 
1065
- def map_back_comprehend_entity_results(response:object, current_batch_mapping:List[Tuple], allow_list:List[str], chosen_redact_comprehend_entities:List[str], all_text_line_results:List[Tuple]):
 
 
 
 
 
 
 
1066
  if not response or "Entities" not in response:
1067
  return all_text_line_results
1068
 
@@ -1077,29 +1269,50 @@ def map_back_comprehend_entity_results(response:object, current_batch_mapping:Li
1077
  added_to_line = False
1078
 
1079
  # Find the correct line and offset within that line
1080
- for batch_start, line_idx, original_line, chars, line_offset in current_batch_mapping:
 
 
 
 
 
 
1081
  batch_end = batch_start + len(original_line.text[line_offset:])
1082
 
1083
  # Check if the entity overlaps with the current line
1084
- if batch_start < entity_end and batch_end > entity_start: # Overlap condition
 
 
1085
  # Calculate the absolute position within the line
1086
  relative_start = max(0, entity_start - batch_start + line_offset)
1087
- relative_end = min(entity_end - batch_start + line_offset, len(original_line.text))
 
 
1088
 
1089
  result_text = original_line.text[relative_start:relative_end]
1090
 
1091
  if result_text not in allow_list:
1092
  adjusted_entity = entity.copy()
1093
- adjusted_entity["BeginOffset"] = relative_start # Now relative to the full line
 
 
1094
  adjusted_entity["EndOffset"] = relative_end
1095
 
1096
  recogniser_entity = recognizer_result_from_dict(adjusted_entity)
1097
 
1098
- existing_entry = next((entry for idx, entry in all_text_line_results if idx == line_idx), None)
 
 
 
 
 
 
 
1099
  if existing_entry is None:
1100
  all_text_line_results.append((line_idx, [recogniser_entity]))
1101
  else:
1102
- existing_entry.append(recogniser_entity) # Append to the existing list of entities
 
 
1103
 
1104
  added_to_line = True
1105
 
@@ -1109,7 +1322,16 @@ def map_back_comprehend_entity_results(response:object, current_batch_mapping:Li
1109
 
1110
  return all_text_line_results
1111
 
1112
- def do_aws_comprehend_call(current_batch:str, current_batch_mapping:List[Tuple], comprehend_client:botocore.client.BaseClient, language:str, allow_list:List[str], chosen_redact_comprehend_entities:List[str], all_text_line_results:List[Tuple]):
 
 
 
 
 
 
 
 
 
1113
  if not current_batch:
1114
  return all_text_line_results
1115
 
@@ -1119,26 +1341,26 @@ def do_aws_comprehend_call(current_batch:str, current_batch_mapping:List[Tuple],
1119
  for attempt in range(max_retries):
1120
  try:
1121
  response = comprehend_client.detect_pii_entities(
1122
- Text=current_batch.strip(),
1123
- LanguageCode=language
1124
  )
1125
 
1126
  all_text_line_results = map_back_comprehend_entity_results(
1127
- response,
1128
- current_batch_mapping,
1129
- allow_list,
1130
- chosen_redact_comprehend_entities,
1131
- all_text_line_results
1132
  )
1133
 
1134
  return all_text_line_results
1135
-
1136
  except Exception as e:
1137
  if attempt == max_retries - 1:
1138
  print("AWS Comprehend calls failed due to", e)
1139
  raise
1140
  time.sleep(retry_delay)
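A compact sketch of the batching and retry pattern used for the AWS Comprehend calls: batches are capped at roughly 50 words or 200 characters, and detect_pii_entities is retried with a fixed delay. The client construction, retry count and delay shown here are illustrative.

import time
import boto3

comprehend = boto3.client("comprehend", region_name="eu-west-2")  # illustrative region

def detect_pii_with_retries(text: str, language: str = "en",
                            max_retries: int = 3, retry_delay: int = 3) -> dict:
    """Sketch: call Comprehend detect_pii_entities with simple fixed-delay retries."""
    for attempt in range(max_retries):
        try:
            return comprehend.detect_pii_entities(Text=text.strip(), LanguageCode=language)
        except Exception:
            if attempt == max_retries - 1:
                raise
            time.sleep(retry_delay)

# batching rule used above: start a new batch once it would exceed ~50 words or ~200 characters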
1141
 
 
1142
  def run_page_text_redaction(
1143
  language: str,
1144
  chosen_redact_entities: List[str],
@@ -1147,13 +1369,13 @@ def run_page_text_redaction(
1147
  line_characters: List,
1148
  page_analyser_results: List = list(),
1149
  page_analysed_bounding_boxes: List = list(),
1150
- comprehend_client = None,
1151
  allow_list: List[str] = None,
1152
  pii_identification_method: str = LOCAL_PII_OPTION,
1153
  nlp_analyser: AnalyzerEngine = None,
1154
  score_threshold: float = 0.0,
1155
  custom_entities: List[str] = None,
1156
- comprehend_query_number:int = 0
1157
  ):
1158
  """
1159
  This function performs text redaction on a page based on the specified language and chosen entities.
@@ -1174,7 +1396,7 @@ def run_page_text_redaction(
1174
  custom_entities (List[str], optional): A list of custom entities for redaction. Defaults to None.
1175
  comprehend_query_number (int, optional): A counter for the number of Comprehend queries made. Defaults to 0.
1176
  """
1177
-
1178
  page_text = ""
1179
  page_text_mapping = list()
1180
  all_text_line_results = list()
@@ -1185,13 +1407,14 @@ def run_page_text_redaction(
1185
  if chosen_redact_entities:
1186
  if page_text:
1187
  page_text += " "
1188
-
1189
  start_pos = len(page_text)
1190
  page_text += text_line.text
1191
  page_text_mapping.append((start_pos, i, text_line, line_characters[i]))
1192
 
1193
-
1194
- valid_language_entities = nlp_analyser.registry.get_supported_entities(languages=[language])
 
1195
  if "CUSTOM" not in valid_language_entities:
1196
  valid_language_entities.append("CUSTOM")
1197
  if "CUSTOM_FUZZY" not in valid_language_entities:
@@ -1201,8 +1424,10 @@ def run_page_text_redaction(
1201
  if pii_identification_method == LOCAL_PII_OPTION:
1202
  if not nlp_analyser:
1203
  raise ValueError("nlp_analyser is required for Local identification method")
1204
-
1205
- language_supported_entities = filter_entities_for_language(chosen_redact_entities, valid_language_entities, language)
 
 
1206
 
1207
  page_analyser_result = nlp_analyser.analyze(
1208
  text=page_text,
@@ -1210,14 +1435,11 @@ def run_page_text_redaction(
1210
  entities=language_supported_entities,
1211
  score_threshold=score_threshold,
1212
  return_decision_process=True,
1213
- allow_list=allow_list
1214
  )
1215
 
1216
-
1217
  all_text_line_results = map_back_entity_results(
1218
- page_analyser_result,
1219
- page_text_mapping,
1220
- all_text_line_results
1221
  )
1222
 
1223
  elif pii_identification_method == AWS_PII_OPTION:
@@ -1225,11 +1447,14 @@ def run_page_text_redaction(
1225
  # Process custom entities if any
1226
  if custom_entities:
1227
  custom_redact_entities = [
1228
- entity for entity in chosen_redact_comprehend_entities
 
1229
  if entity in custom_entities
1230
  ]
1231
 
1232
- language_supported_entities = filter_entities_for_language(custom_redact_entities, valid_language_entities, language)
 
 
1233
 
1234
  if language_supported_entities:
1235
  page_analyser_result = nlp_analyser.analyze(
@@ -1238,13 +1463,11 @@ def run_page_text_redaction(
1238
  entities=language_supported_entities,
1239
  score_threshold=score_threshold,
1240
  return_decision_process=True,
1241
- allow_list=allow_list
1242
  )
1243
 
1244
  all_text_line_results = map_back_entity_results(
1245
- page_analyser_result,
1246
- page_text_mapping,
1247
- all_text_line_results
1248
  )
1249
 
1250
  current_batch = ""
@@ -1255,16 +1478,16 @@ def run_page_text_redaction(
1255
  for i, text_line in enumerate(line_level_text_results_list):
1256
  words = text_line.text.split()
1257
  word_start_positions = list()
1258
-
1259
  # Calculate word start positions within the line
1260
  current_pos = 0
1261
  for word in words:
1262
  word_start_positions.append(current_pos)
1263
  current_pos += len(word) + 1 # +1 for space
1264
-
1265
  for word_idx, word in enumerate(words):
1266
  new_batch_char_count = len(current_batch) + len(word) + 1
1267
-
1268
  if batch_word_count >= 50 or new_batch_char_count >= 200:
1269
  # Process current batch
1270
  all_text_line_results = do_aws_comprehend_call(
@@ -1274,15 +1497,23 @@ def run_page_text_redaction(
1274
  language,
1275
  allow_list,
1276
  chosen_redact_comprehend_entities,
1277
- all_text_line_results
1278
  )
1279
  comprehend_query_number += 1
1280
-
1281
  # Start new batch
1282
  current_batch = word
1283
  batch_word_count = 1
1284
  batch_char_count = len(word)
1285
- current_batch_mapping = [(0, i, text_line, line_characters[i], word_start_positions[word_idx])]
 
 
 
 
 
 
 
 
1286
  else:
1287
  if current_batch:
1288
  current_batch += " "
@@ -1290,15 +1521,19 @@ def run_page_text_redaction(
1290
  current_batch += word
1291
  batch_char_count += len(word)
1292
  batch_word_count += 1
1293
-
1294
  if not current_batch_mapping or current_batch_mapping[-1][1] != i:
1295
- current_batch_mapping.append((
1296
- batch_char_count - len(word),
1297
- i,
1298
- text_line,
1299
- line_characters[i],
1300
- word_start_positions[word_idx] # Add the word's start position within its line
1301
- ))
 
 
 
 
1302
 
1303
  # Process final batch
1304
  if current_batch:
@@ -1309,29 +1544,36 @@ def run_page_text_redaction(
1309
  language,
1310
  allow_list,
1311
  chosen_redact_comprehend_entities,
1312
- all_text_line_results
1313
  )
1314
  comprehend_query_number += 1
1315
 
1316
  # Process results for each line
1317
  for i, text_line in enumerate(line_level_text_results_list):
1318
- line_results = next((results for idx, results in all_text_line_results if idx == i), [])
1319
-
 
 
1320
  if line_results:
1321
  text_line_bounding_boxes = merge_text_bounding_boxes(
1322
- line_results,
1323
- line_characters[i]
1324
  )
1325
-
1326
  page_analyser_results.extend(line_results)
1327
  page_analysed_bounding_boxes.extend(text_line_bounding_boxes)
1328
 
1329
  return page_analysed_bounding_boxes
1330
 
1331
- def merge_text_bounding_boxes(analyser_results:dict, characters: List[LTChar], combine_pixel_dist: int = 20, vertical_padding: int = 0):
1332
- '''
 
 
 
 
 
 
1333
  Merge identified bounding boxes containing PII that are very close to one another
1334
- '''
1335
  analysed_bounding_boxes = list()
1336
  original_bounding_boxes = list() # List to hold original bounding boxes
1337
 
@@ -1339,9 +1581,17 @@ def merge_text_bounding_boxes(analyser_results:dict, characters: List[LTChar], c
1339
  # Extract bounding box coordinates for sorting
1340
  bounding_boxes = list()
1341
  for result in analyser_results:
1342
- #print("Result:", result)
1343
- char_boxes = [char.bbox for char in characters[result.start:result.end] if isinstance(char, LTChar)]
1344
- char_text = [char._text for char in characters[result.start:result.end] if isinstance(char, LTChar)]
 
 
 
 
 
 
 
 
1345
  if char_boxes:
1346
  # Calculate the bounding box that encompasses all characters
1347
  left = min(box[0] for box in char_boxes)
@@ -1349,11 +1599,19 @@ def merge_text_bounding_boxes(analyser_results:dict, characters: List[LTChar], c
1349
  right = max(box[2] for box in char_boxes)
1350
  top = max(box[3] for box in char_boxes) + vertical_padding
1351
  bbox = [left, bottom, right, top]
1352
- bounding_boxes.append((bottom, left, result, bbox, char_text)) # (y, x, result, bbox, text)
 
 
1353
 
1354
  # Store original bounding boxes
1355
- original_bounding_boxes.append({"text": "".join(char_text), "boundingBox": bbox, "result": copy.deepcopy(result)})
1356
- #print("Original bounding boxes:", original_bounding_boxes)
 
 
 
 
 
 
1357
 
1358
  # Sort the results by y-coordinate and then by x-coordinate
1359
  bounding_boxes.sort()
@@ -1375,19 +1633,26 @@ def merge_text_bounding_boxes(analyser_results:dict, characters: List[LTChar], c
1375
  vertical_diff_bboxes = abs(next_box[1] - current_y)
1376
  horizontal_diff_bboxes = abs(next_box[0] - current_box[2])
1377
 
1378
- if vertical_diff_bboxes <= 5 and horizontal_diff_bboxes <= combine_pixel_dist:
 
 
 
1379
  # Merge bounding boxes
1380
- #print("Merging boxes")
1381
  merged_box = current_box.copy()
1382
  merged_result = current_result
1383
  merged_text = current_text.copy()
1384
 
1385
  merged_box[2] = next_box[2] # Extend horizontally
1386
  merged_box[3] = max(current_box[3], next_box[3]) # Adjust the top
1387
- merged_result.end = max(current_result.end, result.end) # Extend text range
 
 
1388
  try:
1389
  if current_result.entity_type != result.entity_type:
1390
- merged_result.entity_type = current_result.entity_type + " - " + result.entity_type
 
 
1391
  else:
1392
  merged_result.entity_type = current_result.entity_type
1393
  except Exception as e:
@@ -1396,11 +1661,13 @@ def merge_text_bounding_boxes(analyser_results:dict, characters: List[LTChar], c
1396
  merged_text.append(" ") # Add space between texts
1397
  merged_text.extend(text)
1398
 
1399
- merged_bounding_boxes.append({
1400
- "text": "".join(merged_text),
1401
- "boundingBox": merged_box,
1402
- "result": merged_result
1403
- })
 
 
1404
 
1405
  else:
1406
  # Start a new bounding box
@@ -1413,18 +1680,21 @@ def merge_text_bounding_boxes(analyser_results:dict, characters: List[LTChar], c
1413
  analysed_bounding_boxes.extend(original_bounding_boxes)
1414
  analysed_bounding_boxes.extend(merged_bounding_boxes)
1415
 
1416
- #print("Analysed bounding boxes:", analysed_bounding_boxes)
1417
 
1418
  return analysed_bounding_boxes
1419
 
1420
- def recreate_page_line_level_ocr_results_with_page(page_line_level_ocr_results_with_words: dict):
 
 
 
1421
  reconstructed_results = list()
1422
-
1423
  # Assume all lines belong to the same page, so we can just read it from one item
1424
- #page = next(iter(page_line_level_ocr_results_with_words.values()))["page"]
1425
 
1426
  page = page_line_level_ocr_results_with_words["page"]
1427
-
1428
  for line_data in page_line_level_ocr_results_with_words["results"].values():
1429
  bbox = line_data["bounding_box"]
1430
  text = line_data["text"]
@@ -1438,15 +1708,21 @@ def recreate_page_line_level_ocr_results_with_page(page_line_level_ocr_results_w
1438
  top=bbox[1],
1439
  width=bbox[2] - bbox[0],
1440
  height=bbox[3] - bbox[1],
1441
- line=line_number
1442
  )
1443
  reconstructed_results.append(line_result)
1444
-
1445
- page_line_level_ocr_results_with_page = {"page": page, "results": reconstructed_results}
1446
-
 
 
 
1447
  return page_line_level_ocr_results_with_page
1448
 
1449
- def split_words_and_punctuation_from_line(line_of_words: List[OCRResult]) -> List[OCRResult]:
 
 
 
1450
  """
1451
  Takes a list of OCRResult objects and splits words with trailing/leading punctuation.
1452
 
@@ -1455,85 +1731,119 @@ def split_words_and_punctuation_from_line(line_of_words: List[OCRResult]) -> Lis
1455
  "high-tech" are preserved.
1456
  """
1457
  # Punctuation that will be split off. Hyphen is not included.
1458
- PUNCTUATION_TO_SPLIT = {'.', ',', '?', '!', ':', ';', '(', ')', '[', ']', '{', '}'}
1459
-
1460
  new_word_list = list()
1461
-
1462
  for word_result in line_of_words:
1463
  word_text = word_result.text
1464
-
1465
  # This regex finds a central "core" word, and captures leading and trailing punctuation
1466
  # Handles cases like "(word)." -> group1='(', group2='word', group3='.'
1467
  match = re.match(r"([(\[{]*)(.*?)_?([.,?!:;)\}\]]*)$", word_text)
1468
 
1469
  # Handle words with internal hyphens that might confuse the regex
1470
- if '-' in word_text and not match.group(2):
1471
- core_part_text = word_text
1472
- leading_punc = ""
1473
- trailing_punc = ""
1474
  elif match:
1475
  leading_punc, core_part_text, trailing_punc = match.groups()
1476
- else: # Failsafe
1477
  new_word_list.append(word_result)
1478
  continue
1479
-
1480
  # If no split is needed, just add the original and continue
1481
  if not leading_punc and not trailing_punc:
1482
  new_word_list.append(word_result)
1483
  continue
1484
-
1485
  # --- A split is required ---
1486
  # Estimate new bounding boxes by proportionally allocating width
1487
  original_width = word_result.width
1488
- if not word_text or original_width == 0: continue # Failsafe
1489
-
 
1490
  avg_char_width = original_width / len(word_text)
1491
  current_left = word_result.left
1492
 
1493
  # Add leading punctuation if it exists
1494
  if leading_punc:
1495
  punc_width = avg_char_width * len(leading_punc)
1496
- new_word_list.append(OCRResult(
1497
- text=leading_punc, left=current_left, top=word_result.top,
1498
- width=punc_width, height=word_result.height
1499
- ))
1500
  current_left += punc_width
1501
 
1502
  # Add the core part of the word
1503
  if core_part_text:
1504
  core_width = avg_char_width * len(core_part_text)
1505
- new_word_list.append(OCRResult(
1506
- text=core_part_text, left=current_left, top=word_result.top,
1507
- width=core_width, height=word_result.height
1508
- ))
 
 
 
 
 
1509
  current_left += core_width
1510
 
1511
  # Add trailing punctuation if it exists
1512
  if trailing_punc:
1513
  punc_width = avg_char_width * len(trailing_punc)
1514
- new_word_list.append(OCRResult(
1515
- text=trailing_punc, left=current_left, top=word_result.top,
1516
- width=punc_width, height=word_result.height
1517
- ))
1518
-
 
 
 
 
 
1519
  return new_word_list
1520
 
1521
- def create_ocr_result_with_children(combined_results:dict, i:int, current_bbox:dict, current_line:list):
1522
- combined_results["text_line_" + str(i)] = {
 
 
 
1523
  "line": i,
1524
- 'text': current_bbox.text,
1525
- 'bounding_box': (current_bbox.left, current_bbox.top,
1526
- current_bbox.left + current_bbox.width,
1527
- current_bbox.top + current_bbox.height),
1528
- 'words': [{'text': word.text,
1529
- 'bounding_box': (word.left, word.top,
1530
- word.left + word.width,
1531
- word.top + word.height)}
1532
- for word in current_line]
1533
  }
1534
- return combined_results["text_line_" + str(i)]
 
1535
 
1536
- def combine_ocr_results(ocr_results: List[OCRResult], x_threshold: float = 50.0, y_threshold: float = 12.0, page: int = 1):
 
 
 
 
 
1537
  """
1538
  Group OCR results into lines, splitting words from punctuation.
1539
  """
@@ -1568,28 +1878,36 @@ def combine_ocr_results(ocr_results: List[OCRResult], x_threshold: float = 50.0,
1568
  line_top = min(word.top for word in line)
1569
  line_right = max(word.left + word.width for word in line)
1570
  line_bottom = max(word.top + word.height for word in line)
1571
-
1572
  final_line_bbox = OCRResult(
1573
  text=line_text,
1574
  left=line_left,
1575
  top=line_top,
1576
  width=line_right - line_left,
1577
  height=line_bottom - line_top,
1578
- line=line_counter
1579
  )
1580
-
1581
  page_line_level_ocr_results.append(final_line_bbox)
1582
-
1583
  # Use the PROCESSED line to create the children. Creates a result within page_line_level_ocr_results_with_words
1584
- page_line_level_ocr_results_with_words["text_line_" + str(line_counter)] = create_ocr_result_with_children(
1585
- page_line_level_ocr_results_with_words,
1586
- line_counter,
1587
- final_line_bbox,
1588
- processed_line # <-- Use the new, split list of words
 
 
1589
  )
1590
  line_counter += 1
1591
 
1592
- page_level_results_with_page = {"page": page, "results": page_line_level_ocr_results}
1593
- page_level_results_with_words = {"page": page, "results": page_line_level_ocr_results_with_words}
 
 
 
 
 
 
1594
 
1595
  return page_level_results_with_page, page_level_results_with_words

1
  import copy
2
+ import re
3
+ import time
4
  from copy import deepcopy
5
+ from dataclasses import dataclass
6
+ from typing import Any, Dict, List, Optional, Tuple, Union
7
+
8
+ import botocore
9
+ import cv2
10
+ import gradio as gr
11
+ import numpy as np
12
+ import pytesseract
13
  from pdfminer.layout import LTChar
14
  from PIL import Image
15
+ from presidio_analyzer import AnalyzerEngine, RecognizerResult
16
+
17
+ from tools.config import (
18
+ AWS_PII_OPTION,
19
+ DEFAULT_LANGUAGE,
20
+ LOCAL_PII_OPTION,
21
+ PREPROCESS_LOCAL_OCR_IMAGES,
22
+ )
23
  from tools.helper_functions import clean_unicode_text
 
24
  from tools.load_spacy_model_custom_recognisers import custom_entities
25
+ from tools.presidio_analyzer_custom import recognizer_result_from_dict
26
 
27
+ if PREPROCESS_LOCAL_OCR_IMAGES == "True":
28
+ PREPROCESS_LOCAL_OCR_IMAGES = True
29
+ else:
30
+ PREPROCESS_LOCAL_OCR_IMAGES = False
31
 
32
  try:
33
  from paddleocr import PaddleOCR
34
  except ImportError:
35
  PaddleOCR = None
36
 
37
+
38
  # --- Language utilities ---
39
  def _normalize_lang(language: str) -> str:
40
  return language.strip().lower().replace("-", "_") if language else "en"
 
46
 
47
  mapping = {
48
  # Common
49
+ "en": "eng",
50
+ "eng": "eng",
51
+ "fr": "fra",
52
+ "fre": "fra",
53
+ "fra": "fra",
54
+ "de": "deu",
55
+ "ger": "deu",
56
+ "deu": "deu",
57
+ "es": "spa",
58
+ "spa": "spa",
59
+ "it": "ita",
60
+ "ita": "ita",
61
+ "nl": "nld",
62
+ "dut": "nld",
63
+ "nld": "nld",
64
+ "pt": "por",
65
+ "por": "por",
66
+ "ru": "rus",
67
+ "rus": "rus",
68
+ "ar": "ara",
69
+ "ara": "ara",
70
  # Nordics
71
+ "sv": "swe",
72
+ "swe": "swe",
73
+ "no": "nor",
74
+ "nb": "nor",
75
+ "nn": "nor",
76
+ "nor": "nor",
77
+ "fi": "fin",
78
+ "fin": "fin",
79
+ "da": "dan",
80
+ "dan": "dan",
81
  # Eastern/Central
82
+ "pl": "pol",
83
+ "pol": "pol",
84
+ "cs": "ces",
85
+ "cz": "ces",
86
+ "ces": "ces",
87
+ "hu": "hun",
88
+ "hun": "hun",
89
+ "ro": "ron",
90
+ "rum": "ron",
91
+ "ron": "ron",
92
+ "bg": "bul",
93
+ "bul": "bul",
94
+ "el": "ell",
95
+ "gre": "ell",
96
+ "ell": "ell",
97
  # Asian
98
+ "ja": "jpn",
99
+ "jp": "jpn",
100
+ "jpn": "jpn",
101
+ "zh": "chi_sim",
102
+ "zh_cn": "chi_sim",
103
+ "zh_hans": "chi_sim",
104
+ "chi_sim": "chi_sim",
105
+ "zh_tw": "chi_tra",
106
+ "zh_hk": "chi_tra",
107
+ "zh_tr": "chi_tra",
108
+ "chi_tra": "chi_tra",
109
+ "hi": "hin",
110
+ "hin": "hin",
111
+ "bn": "ben",
112
+ "ben": "ben",
113
+ "ur": "urd",
114
+ "urd": "urd",
115
+ "fa": "fas",
116
+ "per": "fas",
117
+ "fas": "fas",
118
  }
119
 
120
  return mapping.get(lang, "eng")
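A minimal standalone sketch of the language-code normalisation and lookup shown above. TESS_CODES, normalize_lang and tesseract_lang_code below are illustrative stand-ins for the module's private helpers, covering only a subset of the full table:

TESS_CODES = {"en": "eng", "fr": "fra", "de": "deu", "zh": "chi_sim", "zh_tw": "chi_tra"}

def normalize_lang(language: str) -> str:
    return language.strip().lower().replace("-", "_") if language else "en"

def tesseract_lang_code(language: str) -> str:
    # Unknown tags fall back to English traineddata, as in the mapping above
    return TESS_CODES.get(normalize_lang(language), "eng")

print([tesseract_lang_code(tag) for tag in ("EN", "de", "zh-TW", "xx")])
# -> ['eng', 'deu', 'chi_tra', 'eng']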
 
155
 
156
  return mapping.get(lang, "en")
157
 
158
+
159
  @dataclass
160
  class OCRResult:
161
  text: str
 
166
  conf: float = None
167
  line: int = None
168
 
169
+
170
  @dataclass
171
  class CustomImageRecognizerResult:
172
  entity_type: str
 
177
  top: int
178
  width: int
179
  height: int
180
+ text: str
181
+
182
+
183
  class ImagePreprocessor:
184
  """ImagePreprocessor class. Parent class for image preprocessing objects."""
185
+
186
  def __init__(self, use_greyscale: bool = True) -> None:
187
  self.use_greyscale = use_greyscale
188
 
 
199
  return img
200
 
201
  @staticmethod
202
+ def _get_bg_color(
203
+ image: np.ndarray, is_greyscale: bool, invert: bool = False
204
+ ) -> Union[int, Tuple[int, int, int]]:
205
  # Note: Modified to expect numpy array for bincount
206
  if invert:
207
+ image = 255 - image # Simple inversion for greyscale numpy array
208
+
209
  if is_greyscale:
210
  bg_color = int(np.bincount(image.flatten()).argmax())
211
  else:
 
213
  # For this pipeline, we only use greyscale, so it's fine.
214
  # A simple alternative:
215
  from scipy import stats
216
+
217
  bg_color = tuple(stats.mode(image.reshape(-1, 3), axis=0)[0][0])
218
  return bg_color
219
 
 
222
  contrast = np.std(image)
223
  mean_intensity = np.mean(image)
224
  return contrast, mean_intensity
225
+
226
+
227
  class BilateralFilter(ImagePreprocessor):
228
  """Applies bilateral filtering."""
229
+
230
+ def __init__(
231
+ self, diameter: int = 9, sigma_color: int = 75, sigma_space: int = 75
232
+ ) -> None:
233
  super().__init__(use_greyscale=True)
234
  self.diameter = diameter
235
  self.sigma_color = sigma_color
 
237
 
238
  def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, dict]:
239
  # Modified to accept and return numpy array for consistency in the pipeline
240
+ filtered_image = cv2.bilateralFilter(
241
+ image, self.diameter, self.sigma_color, self.sigma_space
242
+ )
243
+ metadata = {
244
+ "diameter": self.diameter,
245
+ "sigma_color": self.sigma_color,
246
+ "sigma_space": self.sigma_space,
247
+ }
248
  return filtered_image, metadata
249
+
250
+
251
  class SegmentedAdaptiveThreshold(ImagePreprocessor):
252
  """Applies adaptive thresholding."""
253
+
254
+ def __init__(
255
+ self,
256
+ block_size: int = 21,
257
+ contrast_threshold: int = 40,
258
+ c_low_contrast: int = 5,
259
+ c_high_contrast: int = 10,
260
+ bg_threshold: int = 127,
261
+ ) -> None:
262
  super().__init__(use_greyscale=True)
263
+ self.block_size = (
264
+ block_size if block_size % 2 == 1 else block_size + 1
265
+ ) # Ensure odd
266
  self.c_low_contrast = c_low_contrast
267
  self.c_high_contrast = c_high_contrast
268
  self.bg_threshold = bg_threshold
 
272
  # Modified to accept and return numpy array
273
  background_color = self._get_bg_color(image, True)
274
  contrast, _ = self._get_image_contrast(image)
275
+ c = (
276
+ self.c_low_contrast
277
+ if contrast <= self.contrast_threshold
278
+ else self.c_high_contrast
279
+ )
280
 
281
+ if background_color < self.bg_threshold: # Dark background, light text
282
  adaptive_threshold_image = cv2.adaptiveThreshold(
283
+ image,
284
+ 255,
285
+ cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
286
+ cv2.THRESH_BINARY_INV,
287
+ self.block_size,
288
+ -c,
289
  )
290
+ else: # Light background, dark text
291
  adaptive_threshold_image = cv2.adaptiveThreshold(
292
+ image,
293
+ 255,
294
+ cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
295
+ cv2.THRESH_BINARY,
296
+ self.block_size,
297
+ c,
298
  )
299
  metadata = {"C": c, "background_color": background_color, "contrast": contrast}
300
  return adaptive_threshold_image, metadata
301
+
302
+
303
  class ImageRescaling(ImagePreprocessor):
304
  """Rescales images based on their size."""
305
+
306
  def __init__(self, target_dpi: int = 300, assumed_input_dpi: int = 96) -> None:
307
  super().__init__(use_greyscale=True)
308
  self.target_dpi = target_dpi
 
317
  width = int(image.shape[1] * scale_factor)
318
  height = int(image.shape[0] * scale_factor)
319
  dimensions = (width, height)
320
+
321
  # Use better interpolation for upscaling vs downscaling
322
  interpolation = cv2.INTER_CUBIC if scale_factor > 1.0 else cv2.INTER_AREA
323
  rescaled_image = cv2.resize(image, dimensions, interpolation=interpolation)
324
  metadata["scale_factor"] = scale_factor
325
  return rescaled_image, metadata
326
+
327
  return image, metadata
328
 
329
+
330
  class ContrastSegmentedImageEnhancer(ImagePreprocessor):
331
  """Class containing all logic to perform contrastive segmentation."""
332
+
333
  def __init__(
334
  self,
335
  bilateral_filter: Optional[BilateralFilter] = None,
 
355
  adjusted_contrast = contrast
356
  return adjusted_image, contrast, adjusted_contrast
357
 
358
+ def preprocess_image(
359
+ self, image: Image.Image, perform_binarization: bool = False
360
+ ) -> Tuple[Image.Image, dict]:
361
  """
362
  A corrected, logical pipeline for OCR preprocessing.
363
  Order: Greyscale -> Rescale -> Denoise -> Enhance Contrast -> Binarize
 
368
  image_np = self.convert_image_to_array(image)
369
 
370
  # 2. Rescale image to optimal DPI (while still greyscale)
371
+ rescaled_image_np, scale_metadata = self.image_rescaling.preprocess_image(
372
+ image_np
373
+ )
374
 
375
  # 3. Apply bilateral filtering for noise reduction
376
  filtered_image_np, _ = self.bilateral_filter.preprocess_image(rescaled_image_np)
 
380
 
381
  # 5. Adaptive Thresholding (Binarization) - This is the final step
382
  if perform_binarization:
383
+ final_image_np, threshold_metadata = (
384
+ self.adaptive_threshold.preprocess_image(adjusted_image_np)
385
  )
386
  else:
387
  final_image_np = adjusted_image_np
388
  threshold_metadata = {}
389
+
390
  # Combine metadata
391
  final_metadata = {**scale_metadata, **threshold_metadata}
392
+
393
  # Convert final numpy array back to PIL Image for return
394
  return Image.fromarray(final_image_np), final_metadata
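A rough, self-contained sketch of the preprocessing order described in the docstring above (greyscale, rescale towards 300 DPI, bilateral denoise, optional adaptive threshold). Parameter values mirror the class defaults in the diff; the contrast-enhancement step is omitted, and preprocess_for_ocr is an illustrative name rather than the project API:

import cv2
import numpy as np
from PIL import Image

def preprocess_for_ocr(pil_img: Image.Image, target_dpi: int = 300,
                       assumed_dpi: int = 96, binarize: bool = False) -> Image.Image:
    grey = np.array(pil_img.convert("L"))            # 1. greyscale numpy array
    scale = target_dpi / assumed_dpi                 # 2. rescale (3.125x for 96 -> 300 DPI)
    interp = cv2.INTER_CUBIC if scale > 1 else cv2.INTER_AREA
    grey = cv2.resize(grey, None, fx=scale, fy=scale, interpolation=interp)
    grey = cv2.bilateralFilter(grey, 9, 75, 75)      # 3. edge-preserving denoise
    if binarize:                                     # 4. optional binarisation
        grey = cv2.adaptiveThreshold(grey, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                     cv2.THRESH_BINARY, 21, 10)
    return Image.fromarray(grey)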
395
 
396
+
397
+ def rescale_ocr_data(ocr_data, scale_factor: float):
398
+
399
  # We loop from 0 to the number of detected words.
400
+ num_boxes = len(ocr_data["text"])
401
  for i in range(num_boxes):
402
  # We only want to process actual words, not empty boxes Tesseract might find
403
+ if int(ocr_data["conf"][i]) > -1: # -1 confidence is for structural elements
404
  # Get coordinates from the processed image using the index 'i'
405
+ x_proc = ocr_data["left"][i]
406
+ y_proc = ocr_data["top"][i]
407
+ w_proc = ocr_data["width"][i]
408
+ h_proc = ocr_data["height"][i]
409
 
410
  # Apply the inverse transformation (division)
411
  x_orig = int(x_proc / scale_factor)
 
415
 
416
  # --- THE MAPPING STEP ---
417
  # Update the dictionary values in-place using the same index 'i'
418
+ ocr_data["left"][i] = x_orig
419
+ ocr_data["top"][i] = y_orig
420
+ ocr_data["width"][i] = w_orig
421
+ ocr_data["height"][i] = h_orig
422
+
423
  return ocr_data
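A worked example of the inverse scaling that rescale_ocr_data applies to Tesseract-style image_to_data output; the dictionary below is hand-made test data, not real OCR output:

ocr_data = {
    "text": ["Hello", "world"],
    "left": [300, 450], "top": [90, 90],
    "width": [120, 130], "height": [30, 30],
    "conf": [91, 88],
}
scale_factor = 3.0   # the image was upscaled 3x before OCR

for i in range(len(ocr_data["text"])):
    if int(ocr_data["conf"][i]) > -1:              # skip structural elements
        for key in ("left", "top", "width", "height"):
            ocr_data[key][i] = int(ocr_data[key][i] / scale_factor)

print(ocr_data["left"], ocr_data["width"])         # -> [100, 150] [40, 43]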
424
 
425
+
426
+ def filter_entities_for_language(
427
+ entities: List[str], valid_language_entities: List[str], language: str
428
+ ) -> List[str]:
429
 
430
  if not valid_language_entities:
431
  print(f"No valid entities supported for language: {language}")
432
+ # raise Warning(f"No valid entities supported for language: {language}")
433
  if not entities:
434
  print(f"No entities provided for language: {language}")
435
+ # raise Warning(f"No entities provided for language: {language}")
436
+
437
+ # print("entities:", entities)
438
+ # print("valid_language_entities:", valid_language_entities)
439
+ # print("language:", language)
440
 
441
  filtered_entities = [
442
+ entity for entity in entities if entity in valid_language_entities
 
443
  ]
444
 
445
  if not filtered_entities:
446
  print(f"No relevant entities supported for language: {language}")
447
+ # raise Warning(f"No relevant entities supported for language: {language}")
448
 
449
  if language != "en":
450
+ gr.Info(
451
+ f"Using {str(filtered_entities)} entities for local model analysis for language: {language}"
452
+ )
453
 
454
  return filtered_entities
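The same filtering idea as filter_entities_for_language, reduced to a standalone snippet with the Gradio notification left out; the entity names used below are illustrative:

def filter_entities(chosen, supported, language):
    kept = [entity for entity in chosen if entity in supported]
    if not kept:
        print(f"No relevant entities supported for language: {language}")
    return kept

print(filter_entities(["PERSON", "EMAIL_ADDRESS", "EXAMPLE_CUSTOM_ENTITY"],
                      ["PERSON", "EMAIL_ADDRESS"], "fr"))
# -> ['PERSON', 'EMAIL_ADDRESS']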
455
 
456
+
457
  class CustomImageAnalyzerEngine:
458
  def __init__(
459
  self,
460
  analyzer_engine: Optional[AnalyzerEngine] = None,
461
+ ocr_engine: str = "tesseract",
462
  tesseract_config: Optional[str] = None,
463
  paddle_kwargs: Optional[Dict[str, Any]] = None,
464
  image_preprocessor: Optional[ImagePreprocessor] = None,
465
+ language: Optional[str] = DEFAULT_LANGUAGE,
466
  ):
467
  """
468
  Initializes the CustomImageAnalyzerEngine.
 
475
  :param language: Preferred OCR language (e.g., "en", "fr", "de"). Defaults to DEFAULT_LANGUAGE.
476
  """
477
  if ocr_engine not in ["tesseract", "paddle", "hybrid"]:
478
+ raise ValueError(
479
+ "ocr_engine must be either 'tesseract', 'hybrid', or 'paddle'"
480
+ )
481
 
482
  self.ocr_engine = ocr_engine
483
 
 
485
  self.language = language or DEFAULT_LANGUAGE or "en"
486
  self.tesseract_lang = _tesseract_lang_code(self.language)
487
  self.paddle_lang = _paddle_lang_code(self.language)
488
+
489
  if self.ocr_engine == "paddle" or self.ocr_engine == "hybrid":
490
  if PaddleOCR is None:
491
+ raise ImportError(
492
+ "paddleocr is not installed. Please run 'pip install paddleocr paddlepaddle'"
493
+ )
494
  # Default paddle configuration if none provided
495
  if paddle_kwargs is None:
496
+ paddle_kwargs = {
497
+ "use_textline_orientation": True,
498
+ "lang": self.paddle_lang,
499
+ }
500
  else:
501
  # Enforce language if not explicitly provided
502
+ paddle_kwargs.setdefault("lang", self.paddle_lang)
503
  self.paddle_ocr = PaddleOCR(**paddle_kwargs)
504
 
505
  if not analyzer_engine:
506
  analyzer_engine = AnalyzerEngine()
507
  self.analyzer_engine = analyzer_engine
508
 
509
+ self.tesseract_config = tesseract_config or "--oem 3 --psm 11"
510
 
511
  if not image_preprocessor:
512
  image_preprocessor = ContrastSegmentedImageEnhancer()
 
515
  def _sanitize_filename(self, text: str, max_length: int = 20) -> str:
516
  """
517
  Sanitizes text for use in filenames by removing invalid characters and limiting length.
518
+
519
  :param text: The text to sanitize
520
  :param max_length: Maximum length of the sanitized text
521
  :return: Sanitized text safe for filenames
522
  """
523
+
524
  # Remove or replace invalid filename characters
525
  # Windows: < > : " | ? * \ /
526
  # Unix: / (forward slash)
527
  # Also remove control characters and other problematic chars
528
  invalid_chars = r'[<>:"|?*\\/\x00-\x1f\x7f-\x9f]'
529
+ sanitized = re.sub(invalid_chars, "_", text)
530
+
531
  # Replace multiple consecutive underscores with a single one
532
+ sanitized = re.sub(r"_+", "_", sanitized)
533
+
534
  # Remove leading/trailing underscores and spaces
535
+ sanitized = sanitized.strip("_ ")
536
+
537
  # If empty after sanitization, use a default value
538
  if not sanitized:
539
+ sanitized = "text"
540
+
541
  # Limit to max_length characters
542
  if len(sanitized) > max_length:
543
  sanitized = sanitized[:max_length]
544
  # Ensure we don't end with an underscore if we cut in the middle
545
+ sanitized = sanitized.rstrip("_")
546
+
547
  return sanitized
548
 
549
+ def _convert_paddle_to_tesseract_format(
550
+ self, paddle_results: List[Any]
551
+ ) -> Dict[str, List]:
552
  """Converts PaddleOCR result format to Tesseract's dictionary format. NOTE: This attempts to create word-level bounding boxes by estimating the distance between characters in sentence-level text output. This is currently quite inaccurate, and word-level bounding boxes should not be relied upon."""
553
 
554
+ output = {
555
+ "text": [],
556
+ "left": [],
557
+ "top": [],
558
+ "width": [],
559
+ "height": [],
560
+ "conf": [],
561
+ }
562
 
563
  # paddle_results is now a list of dictionaries with detailed information
564
  if not paddle_results:
565
  return output
566
+
567
  for page_result in paddle_results:
568
  # Extract text recognition results from the new format
569
+ rec_texts = page_result.get("rec_texts", [])
570
+ rec_scores = page_result.get("rec_scores", [])
571
+ rec_polys = page_result.get("rec_polys", [])
572
+
573
+ for line_text, line_confidence, bounding_box in zip(
574
+ rec_texts, rec_scores, rec_polys
575
+ ):
576
  # bounding_box is now a numpy array with shape (4, 2)
577
  # Convert to list of coordinates if it's a numpy array
578
+ if hasattr(bounding_box, "tolist"):
579
  box = bounding_box.tolist()
580
  else:
581
  box = bounding_box
582
+
583
  # box is [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
584
  x_coords = [p[0] for p in box]
585
  y_coords = [p[1] for p in box]
586
+
587
  line_left = int(min(x_coords))
588
  line_top = int(min(y_coords))
589
  line_width = int(max(x_coords) - line_left)
590
  line_height = int(max(y_coords) - line_top)
591
+ # line_y_center = (max(y_coords) + min(y_coords)) / 2
592
 
593
+ # 2. Split the line into words
594
  words = line_text.split()
595
  if not words:
596
  continue
 
605
  for word in words:
606
  word_width = int(len(word) * avg_char_width)
607
  word_left = line_left + int(current_char_offset * avg_char_width)
608
+
609
+ output["text"].append(word)
610
+ output["left"].append(word_left)
611
+ output["top"].append(line_top)
612
+ output["width"].append(word_width)
613
+ output["height"].append(line_height)
614
  # Use the line's confidence for each word derived from it
615
+ output["conf"].append(int(line_confidence * 100))
616
 
617
  # Update offset for the next word (add word length + 1 for the space)
618
  current_char_offset += len(word) + 1
619
+
620
  return output
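An illustrative walk-through of the conversion above for a single fabricated PaddleOCR line, showing how word boxes are approximated from the line's average character width (the docstring's accuracy caveat applies):

line_text, line_conf = "John Smith", 0.97
box = [[10, 20], [210, 20], [210, 50], [10, 50]]   # fabricated 4-point polygon

xs = [p[0] for p in box]
ys = [p[1] for p in box]
left, top = min(xs), min(ys)
width, height = max(xs) - left, max(ys) - top
avg_char_width = width / len(line_text)            # 200 / 10 = 20 px per character

output = {"text": [], "left": [], "top": [], "width": [], "height": [], "conf": []}
offset = 0
for word in line_text.split():
    output["text"].append(word)
    output["left"].append(left + int(offset * avg_char_width))
    output["top"].append(top)
    output["width"].append(int(len(word) * avg_char_width))
    output["height"].append(height)
    output["conf"].append(int(line_conf * 100))
    offset += len(word) + 1                        # +1 for the space

print(output["text"], output["left"], output["width"])
# -> ['John', 'Smith'] [10, 110] [80, 100]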
621
+
622
  def _perform_hybrid_ocr(
623
+ self,
624
+ image: Image.Image,
625
+ confidence_threshold: int = 65,
626
+ padding: int = 5,
627
+ ocr: Optional[Any] = None,
628
+ ) -> Dict[str, list]:
629
  """
630
  Performs OCR using Tesseract for bounding boxes and PaddleOCR for low-confidence text.
631
  Returns data in the same dictionary format as pytesseract.image_to_data.
632
  """
633
  if ocr is None:
634
+ if hasattr(self, "paddle_ocr") and self.paddle_ocr is not None:
635
  ocr = self.paddle_ocr
636
  else:
637
+ raise ValueError(
638
+ "No OCR object provided and 'paddle_ocr' is not initialized."
639
+ )
640
+
641
  print("Starting hybrid OCR process...")
642
+
643
  # 1. Get initial word-level results from Tesseract
644
  tesseract_data = pytesseract.image_to_data(
645
  image,
646
  output_type=pytesseract.Output.DICT,
647
  config=self.tesseract_config,
648
+ lang=self.tesseract_lang,
649
  )
650
 
651
+ # tesseract_data['abs_line_id'] = tesseract_data.groupby(['block_num', 'par_num', 'line_num']).ngroup()
652
+
653
+ final_data = {
654
+ "text": [],
655
+ "left": [],
656
+ "top": [],
657
+ "width": [],
658
+ "height": [],
659
+ "conf": [],
660
+ }
661
+
662
+ num_words = len(tesseract_data["text"])
663
 
664
  # This handles the "no text on page" case. If num_words is 0, the loop is skipped
665
  # and an empty dictionary with empty lists is returned, which is the correct behavior.
666
  for i in range(num_words):
667
+ text = tesseract_data["text"][i]
668
+ conf = int(tesseract_data["conf"][i])
669
+
670
  # Skip empty text boxes or non-word elements (like page/block markers)
671
  if not text.strip() or conf == -1:
672
  continue
673
 
674
+ left = tesseract_data["left"][i]
675
+ top = tesseract_data["top"][i]
676
+ width = tesseract_data["width"][i]
677
+ height = tesseract_data["height"][i]
678
+ # line_number = tesseract_data['abs_line_id'][i]
679
+
680
  # If confidence is low, use PaddleOCR for a second opinion
681
  if conf < confidence_threshold:
682
  img_width, img_height = image.size
 
684
  crop_top = max(0, top - padding)
685
  crop_right = min(img_width, left + width + padding + 15)
686
  crop_bottom = min(img_height, top + height + padding)
687
+
688
  # Ensure crop dimensions are valid
689
  if crop_right <= crop_left or crop_bottom <= crop_top:
690
+ continue # Skip invalid crops
691
 
692
+ cropped_image = image.crop(
693
+ (crop_left, crop_top, crop_right, crop_bottom)
694
+ )
695
  cropped_image_np = np.array(cropped_image)
696
+
697
  if len(cropped_image_np.shape) == 2:
698
  cropped_image_np = np.stack([cropped_image_np] * 3, axis=-1)
699
+
700
  paddle_results = ocr.predict(cropped_image_np)
701
+
702
  if paddle_results and paddle_results[0]:
703
+ rec_texts = paddle_results[0].get("rec_texts", [])
704
+ rec_scores = paddle_results[0].get("rec_scores", [])
705
+
706
  if rec_texts and rec_scores:
707
  new_text = " ".join(rec_texts)
708
+ new_conf = int(round(np.median(rec_scores) * 100, 0))
709
 
710
  # Only replace if Paddle's confidence is better
711
  if new_conf > conf:
712
+ print(
713
+ f" Re-OCR'd word: '{text}' (conf: {conf}) -> '{new_text}' (conf: {new_conf:.0f})"
714
+ )
715
 
716
  # For exporting example image comparisons, not used here
717
  safe_text = self._sanitize_filename(text, max_length=20)
718
+ self._sanitize_filename(new_text, max_length=20)
719
  output_image_path = f"examples/tess_vs_paddle_examples/{conf}_conf_{safe_text}_to_{new_text}_{new_conf}.png"
720
  cropped_image.save(output_image_path)
721
 
722
  text = new_text
723
  conf = new_conf
724
+
725
  else:
726
+ print(
727
+ f" '{text}' (conf: {conf}) -> Paddle result '{new_text}' (conf: {new_conf:.0f}) was not better. Keeping original."
728
+ )
729
  else:
730
  # Paddle ran but found nothing, so discard the original low-confidence word
731
+ print(
732
+ f" '{text}' (conf: {conf}) -> No text found by Paddle. Discarding."
733
+ )
734
+ text = ""
735
  else:
736
  # Paddle found nothing, discard original word
737
+ print(
738
+ f" '{text}' (conf: {conf}) -> No text found by Paddle. Discarding."
739
+ )
740
+ text = ""
741
 
742
  # Append the final result (either original, replaced, or skipped if empty)
743
  if text.strip():
744
+ final_data["text"].append(clean_unicode_text(text))
745
+ final_data["left"].append(left)
746
+ final_data["top"].append(top)
747
+ final_data["width"].append(width)
748
+ final_data["height"].append(height)
749
+ final_data["conf"].append(int(conf))
750
+ # final_data['line_number'].append(int(line_number))
751
+
752
  return final_data
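The replacement decision inside the hybrid loop above, reduced to a toy function. The real loop also crops the word region, runs PaddleOCR on the crop and saves comparison images; none of that is reproduced here:

def choose_word(tess_text, tess_conf, paddle_text, paddle_conf, threshold=65):
    if tess_conf >= threshold:
        return tess_text, tess_conf        # confident enough: keep Tesseract's word
    if not paddle_text:
        return "", tess_conf               # second pass found nothing: discard word
    if paddle_conf > tess_conf:
        return paddle_text, paddle_conf    # better second opinion: replace
    return tess_text, tess_conf            # second opinion no better: keep original

print(choose_word("Srnith", 41, "Smith", 96))   # -> ('Smith', 96)
print(choose_word("Srnith", 41, "", 0))         # -> ('', 41), i.e. word dropped
print(choose_word("Smith", 88, None, None))     # -> ('Smith', 88)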
753
+
754
+ def perform_ocr(
755
+ self, image: Union[str, Image.Image, np.ndarray], ocr: Optional[Any] = None
756
+ ) -> List[OCRResult]:
757
  """
758
  Performs OCR on the given image using the configured engine.
759
  """
 
761
  image = Image.open(image)
762
  elif isinstance(image, np.ndarray):
763
  image = Image.fromarray(image)
764
+
765
  # Pre-process image - currently seems to give worse results!
766
+ if str(PREPROCESS_LOCAL_OCR_IMAGES).lower() == "true":
767
+ image, preprocessing_metadata = self.image_preprocessor.preprocess_image(
768
+ image
769
+ )
770
  else:
771
  preprocessing_metadata = {}
772
 
 
781
  image,
782
  output_type=pytesseract.Output.DICT,
783
  config=self.tesseract_config,
784
+ lang=self.tesseract_lang, # Ensure the Tesseract language data (e.g., fra.traineddata) is installed on your system.
785
  )
786
 
787
+ # ocr_data['abs_line_id'] = ocr_data.groupby(['block_num', 'par_num', 'line_num']).ngroup()
788
 
789
  elif self.ocr_engine == "paddle":
790
 
791
+ image_np = np.array(image) # image_processed
792
+
793
  # PaddleOCR may need an RGB image. Ensure it has 3 channels.
794
  if len(image_np.shape) == 2:
795
  image_np = np.stack([image_np] * 3, axis=-1)
796
 
797
  if ocr is None:
798
+ if hasattr(self, "paddle_ocr") and self.paddle_ocr is not None:
799
  ocr = self.paddle_ocr
800
  else:
801
+ raise ValueError(
802
+ "No OCR object provided and 'paddle_ocr' is not initialised."
803
+ )
804
 
805
+ # ocr = PaddleOCR(use_textline_orientation=True, lang='en')
806
  paddle_results = ocr.predict(image_np)
807
  ocr_data = self._convert_paddle_to_tesseract_format(paddle_results)
808
 
809
  else:
810
  raise RuntimeError(f"Unsupported OCR engine: {self.ocr_engine}")
811
+
812
  if preprocessing_metadata:
813
+ scale_factor = preprocessing_metadata.get("scale_factor", 1.0)
814
  ocr_data = rescale_ocr_data(ocr_data, scale_factor)
815
 
816
  # The rest of your processing pipeline now works for both engines
817
  ocr_result = ocr_data
818
+
819
  # Filter out empty strings and low confidence results
820
  valid_indices = [
821
+ i
822
+ for i, text in enumerate(ocr_result["text"])
823
+ if text.strip() and int(ocr_result["conf"][i]) > 0
824
  ]
825
+
826
  return [
827
  OCRResult(
828
+ text=clean_unicode_text(ocr_result["text"][i]),
829
+ left=ocr_result["left"][i],
830
+ top=ocr_result["top"][i],
831
+ width=ocr_result["width"][i],
832
+ height=ocr_result["height"][i], # ,
833
+ # line_number=ocr_result['abs_line_id'][i]
834
  )
835
  for i in valid_indices
836
  ]
837
 
838
  def analyze_text(
839
+ self,
840
+ line_level_ocr_results: List[OCRResult],
841
  ocr_results_with_words: Dict[str, Dict],
842
  chosen_redact_comprehend_entities: List[str],
843
  pii_identification_method: str = LOCAL_PII_OPTION,
844
+ comprehend_client="",
845
+ custom_entities: List[str] = custom_entities,
846
  language: Optional[str] = DEFAULT_LANGUAGE,
847
  nlp_analyser: AnalyzerEngine = None,
848
+ **text_analyzer_kwargs,
849
  ) -> List[CustomImageRecognizerResult]:
850
 
851
  page_text = ""
 
866
  page_text_mapping.append((start_pos, i, line_level_ocr_result, None))
867
 
868
  # Determine language for downstream services
869
+ aws_language = language or getattr(self, "language", None) or "en"
870
 
871
+ valid_language_entities = nlp_analyser.registry.get_supported_entities(
872
+ languages=[language]
873
+ )
874
  if "CUSTOM" not in valid_language_entities:
875
  valid_language_entities.append("CUSTOM")
876
  if "CUSTOM_FUZZY" not in valid_language_entities:
 
879
  # Process using either Local or AWS Comprehend
880
  if pii_identification_method == LOCAL_PII_OPTION:
881
 
882
+ language_supported_entities = filter_entities_for_language(
883
+ custom_entities, valid_language_entities, language
884
+ )
885
+
886
  if language_supported_entities:
887
  text_analyzer_kwargs["entities"] = language_supported_entities
888
 
889
+ # if language != "en":
890
  # gr.Info(f"Using {str(language_supported_entities)} entities for local model analysis for language: {language}")
891
  else:
892
  print(f"No relevant entities supported for language: {language}")
893
+ raise Warning(
894
+ f"No relevant entities supported for language: {language}"
895
+ )
896
 
897
  analyzer_result = nlp_analyser.analyze(
898
+ text=page_text, language=language, **text_analyzer_kwargs
 
 
899
  )
900
  all_text_line_results = map_back_entity_results(
901
+ analyzer_result, page_text_mapping, all_text_line_results
 
 
902
  )
903
 
904
+ elif pii_identification_method == AWS_PII_OPTION:
905
 
906
  # Handle custom entities first
907
  if custom_entities:
908
  custom_redact_entities = [
909
+ entity
910
+ for entity in chosen_redact_comprehend_entities
911
  if entity in custom_entities
912
  ]
913
 
914
  if custom_redact_entities:
915
  # Filter entities to only include those supported by the language
916
+ language_supported_entities = filter_entities_for_language(
917
+ custom_redact_entities, valid_language_entities, language
918
+ )
919
 
920
  if language_supported_entities:
921
  text_analyzer_kwargs["entities"] = language_supported_entities
922
 
923
  page_analyser_result = nlp_analyser.analyze(
924
+ text=page_text, language=language, **text_analyzer_kwargs
 
 
925
  )
926
  all_text_line_results = map_back_entity_results(
927
+ page_analyser_result, page_text_mapping, all_text_line_results
 
 
928
  )
929
 
930
  # Process text in batches for AWS Comprehend
 
937
  words = text_line.text.split()
938
  word_start_positions = list()
939
  current_pos = 0
940
+
941
  for word in words:
942
  word_start_positions.append(current_pos)
943
  current_pos += len(word) + 1
944
 
945
  for word_idx, word in enumerate(words):
946
  new_batch_char_count = len(current_batch) + len(word) + 1
947
+
948
  if batch_word_count >= 50 or new_batch_char_count >= 200:
949
  # Process current batch
950
  all_text_line_results = do_aws_comprehend_call(
 
952
  current_batch_mapping,
953
  comprehend_client,
954
  aws_language,
955
+ text_analyzer_kwargs.get("allow_list", []),
956
  chosen_redact_comprehend_entities,
957
+ all_text_line_results,
958
  )
959
  comprehend_query_number += 1
960
+
961
  # Reset batch
962
  current_batch = word
963
  batch_word_count = 1
964
  batch_char_count = len(word)
965
+ current_batch_mapping = [
966
+ (0, i, text_line, None, word_start_positions[word_idx])
967
+ ]
968
  else:
969
  if current_batch:
970
  current_batch += " "
 
972
  current_batch += word
973
  batch_char_count += len(word)
974
  batch_word_count += 1
975
+
976
+ if (
977
+ not current_batch_mapping
978
+ or current_batch_mapping[-1][1] != i
979
+ ):
980
+ current_batch_mapping.append(
981
+ (
982
+ batch_char_count - len(word),
983
+ i,
984
+ text_line,
985
+ None,
986
+ word_start_positions[word_idx],
987
+ )
988
+ )
989
 
990
  # Process final batch if any
991
  if current_batch:
 
994
  current_batch_mapping,
995
  comprehend_client,
996
  aws_language,
997
+ text_analyzer_kwargs.get("allow_list", []),
998
  chosen_redact_comprehend_entities,
999
+ all_text_line_results,
1000
  )
1001
+ comprehend_query_number += 1
1002
 
1003
  # Process results and create bounding boxes
1004
  combined_results = list()
1005
  for i, text_line in enumerate(line_level_ocr_results):
1006
+ line_results = next(
1007
+ (results for idx, results in all_text_line_results if idx == i), []
1008
+ )
1009
  if line_results and i < len(ocr_results_with_words):
1010
  child_level_key = list(ocr_results_with_words.keys())[i]
1011
+ ocr_results_with_words_line_level = ocr_results_with_words[
1012
+ child_level_key
1013
+ ]
1014
+
1015
  for result in line_results:
1016
  bbox_results = self.map_analyzer_results_to_bounding_boxes(
1017
  [result],
1018
+ [
1019
+ OCRResult(
1020
+ text=text_line.text[result.start : result.end],
1021
+ left=text_line.left,
1022
+ top=text_line.top,
1023
+ width=text_line.width,
1024
+ height=text_line.height,
1025
+ )
1026
+ ],
1027
  text_line.text,
1028
+ text_analyzer_kwargs.get("allow_list", []),
1029
+ ocr_results_with_words_line_level,
1030
  )
1031
  combined_results.extend(bbox_results)
1032
 
 
1034
 
1035
  @staticmethod
1036
  def map_analyzer_results_to_bounding_boxes(
1037
+ text_analyzer_results: List[RecognizerResult],
1038
+ redaction_relevant_ocr_results: List[OCRResult],
1039
+ full_text: str,
1040
+ allow_list: List[str],
1041
+ ocr_results_with_words_child_info: Dict[str, Dict],
1042
+ ) -> List[CustomImageRecognizerResult]:
1043
  redaction_bboxes = list()
1044
 
1045
  for redaction_relevant_ocr_result in redaction_relevant_ocr_results:
1046
+ # print("ocr_results_with_words_child_info:", ocr_results_with_words_child_info)
1047
 
1048
+ line_text = ocr_results_with_words_child_info["text"]
1049
  line_length = len(line_text)
1050
  redaction_text = redaction_relevant_ocr_result.text
1051
+
1052
  for redaction_result in text_analyzer_results:
1053
  # Check if the redaction text is not in the allow list
1054
+
1055
  if redaction_text not in allow_list:
1056
+
1057
  # Adjust start and end to be within line bounds
1058
  start_in_line = max(0, redaction_result.start)
1059
  end_in_line = min(line_length, redaction_result.end)
1060
+
1061
  # Get the matched text from this line
1062
  matched_text = line_text[start_in_line:end_in_line]
1063
+ matched_text.split()
1064
+
1065
  # Find the corresponding words in the OCR results
1066
  matching_word_boxes = list()
1067
 
1068
  current_position = 0
1069
 
1070
+ for word_info in ocr_results_with_words_child_info.get("words", []):
1071
+ word_text = word_info["text"]
1072
  word_length = len(word_text)
1073
 
1074
  word_start = current_position
1075
  word_end = current_position + word_length
1076
 
1077
  # Update current position for the next word
1078
+ current_position += (
1079
+ word_length + 1
1080
+ ) # +1 for the space after the word
1081
+
1082
  # Check if the word's bounding box is within the start and end bounds
1083
+ if word_start >= start_in_line and word_end <= (
1084
+ end_in_line + 1
1085
+ ):
1086
+ matching_word_boxes.append(word_info["bounding_box"])
1087
+ # print(f"Matched word: {word_info['text']}")
1088
+
1089
  if matching_word_boxes:
1090
  # Calculate the combined bounding box for all matching words
1091
  left = min(box[0] for box in matching_word_boxes)
1092
  top = min(box[1] for box in matching_word_boxes)
1093
  right = max(box[2] for box in matching_word_boxes)
1094
  bottom = max(box[3] for box in matching_word_boxes)
1095
+
1096
  redaction_bboxes.append(
1097
  CustomImageRecognizerResult(
1098
  entity_type=redaction_result.entity_type,
 
1103
  top=top,
1104
  width=right - left,
1105
  height=bottom - top,
1106
+ text=matched_text,
1107
  )
1108
  )
1109
 
1110
  return redaction_bboxes
1111
+
1112
  @staticmethod
1113
  def remove_space_boxes(ocr_result: dict) -> dict:
1114
  """Remove OCR bboxes that are for spaces.
 
1128
  filtered_ocr_result[key] = [ocr_result[key][i] for i in idx]
1129
 
1130
  return filtered_ocr_result
1131
+
1132
  @staticmethod
1133
  def _scale_bbox_results(
1134
  ocr_result: Dict[str, List[Union[int, str]]], scale_factor: float
 
1159
  # Estimate the x-offset based on character position
1160
  # This is a simple estimation and might need refinement for variable-width fonts
1161
  return int(start / len(full_text) * len(full_text))
1162
+
1163
  def estimate_width(self, ocr_result: OCRResult, start: int, end: int) -> int:
1164
  # Extract the relevant text portion
1165
  relevant_text = ocr_result.text[start:end]
1166
+
1167
  # If the relevant text is the same as the full text, return the full width
1168
  if relevant_text == ocr_result.text:
1169
  return ocr_result.width
1170
+
1171
  # Estimate width based on the proportion of the relevant text length to the total text length
1172
  total_text_length = len(ocr_result.text)
1173
  relevant_text_length = len(relevant_text)
1174
+
1175
  if total_text_length == 0:
1176
  return 0 # Avoid division by zero
1177
+
1178
  # Proportion of the relevant text to the total text
1179
  proportion = relevant_text_length / total_text_length
1180
+
1181
  # Estimate the width based on the proportion
1182
  estimated_width = int(proportion * ocr_result.width)
1183
+
1184
  return estimated_width
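A worked example of the proportional width estimate above. (As written, estimate_x_offset reduces to int(start), since the two len(full_text) terms cancel, so the offset is effectively the character index.)

text, line_width = "Contact John Smith today", 240
start, end = 8, 18                       # the span covering "John Smith"
proportion = (end - start) / len(text)   # 10 / 24
print(int(proportion * line_width))      # -> 100 estimated pixels of width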
1185
 
1186
 
1187
+ def bounding_boxes_overlap(box1: List, box2: List):
1188
  """Check if two bounding boxes overlap."""
1189
+ return (
1190
+ box1[0] < box2[2]
1191
+ and box2[0] < box1[2]
1192
+ and box1[1] < box2[3]
1193
+ and box2[1] < box1[3]
1194
+ )
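A quick check of the overlap predicate above with [left, top, right, bottom] boxes; boxes that only share an edge do not count as overlapping:

def boxes_overlap(a, b):
    return a[0] < b[2] and b[0] < a[2] and a[1] < b[3] and b[1] < a[3]

print(boxes_overlap([0, 0, 10, 10], [5, 5, 15, 15]))    # True  (they intersect)
print(boxes_overlap([0, 0, 10, 10], [10, 0, 20, 10]))   # False (edge-adjacent only)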
1195
+
1196
+
1197
+ def map_back_entity_results(
1198
+ page_analyser_result: dict,
1199
+ page_text_mapping: dict,
1200
+ all_text_line_results: List[Tuple],
1201
+ ):
1202
  for entity in page_analyser_result:
1203
  entity_start = entity.start
1204
  entity_end = entity.end
1205
+
1206
  # Track if the entity has been added to any line
1207
  added_to_line = False
1208
+
1209
  for batch_start, line_idx, original_line, chars in page_text_mapping:
1210
  batch_end = batch_start + len(original_line.text)
1211
+
1212
  # Check if the entity overlaps with the current line
1213
+ if (
1214
+ batch_start < entity_end and batch_end > entity_start
1215
+ ): # Overlap condition
1216
+ relative_start = max(
1217
+ 0, entity_start - batch_start
1218
+ ) # Adjust start relative to the line
1219
+ relative_end = min(
1220
+ entity_end - batch_start, len(original_line.text)
1221
+ ) # Adjust end relative to the line
1222
+
1223
  # Create a new adjusted entity
1224
  adjusted_entity = copy.deepcopy(entity)
1225
  adjusted_entity.start = relative_start
1226
  adjusted_entity.end = relative_end
1227
+
1228
  # Check if this line already has an entry
1229
+ existing_entry = next(
1230
+ (entry for idx, entry in all_text_line_results if idx == line_idx),
1231
+ None,
1232
+ )
1233
+
1234
  if existing_entry is None:
1235
  all_text_line_results.append((line_idx, [adjusted_entity]))
1236
  else:
1237
+ existing_entry.append(
1238
+ adjusted_entity
1239
+ ) # Append to the existing list of entities
1240
+
1241
  added_to_line = True
1242
+
1243
  # If the entity spans multiple lines, you may want to handle that here
1244
  if not added_to_line:
1245
  # Handle cases where the entity does not fit in any line (optional)
 
1247
 
1248
  return all_text_line_results
1249
 
1250
+
1251
+ def map_back_comprehend_entity_results(
1252
+ response: object,
1253
+ current_batch_mapping: List[Tuple],
1254
+ allow_list: List[str],
1255
+ chosen_redact_comprehend_entities: List[str],
1256
+ all_text_line_results: List[Tuple],
1257
+ ):
1258
  if not response or "Entities" not in response:
1259
  return all_text_line_results
1260
 
 
1269
  added_to_line = False
1270
 
1271
  # Find the correct line and offset within that line
1272
+ for (
1273
+ batch_start,
1274
+ line_idx,
1275
+ original_line,
1276
+ chars,
1277
+ line_offset,
1278
+ ) in current_batch_mapping:
1279
  batch_end = batch_start + len(original_line.text[line_offset:])
1280
 
1281
  # Check if the entity overlaps with the current line
1282
+ if (
1283
+ batch_start < entity_end and batch_end > entity_start
1284
+ ): # Overlap condition
1285
  # Calculate the absolute position within the line
1286
  relative_start = max(0, entity_start - batch_start + line_offset)
1287
+ relative_end = min(
1288
+ entity_end - batch_start + line_offset, len(original_line.text)
1289
+ )
1290
 
1291
  result_text = original_line.text[relative_start:relative_end]
1292
 
1293
  if result_text not in allow_list:
1294
  adjusted_entity = entity.copy()
1295
+ adjusted_entity["BeginOffset"] = (
1296
+ relative_start # Now relative to the full line
1297
+ )
1298
  adjusted_entity["EndOffset"] = relative_end
1299
 
1300
  recogniser_entity = recognizer_result_from_dict(adjusted_entity)
1301
 
1302
+ existing_entry = next(
1303
+ (
1304
+ entry
1305
+ for idx, entry in all_text_line_results
1306
+ if idx == line_idx
1307
+ ),
1308
+ None,
1309
+ )
1310
  if existing_entry is None:
1311
  all_text_line_results.append((line_idx, [recogniser_entity]))
1312
  else:
1313
+ existing_entry.append(
1314
+ recogniser_entity
1315
+ ) # Append to the existing list of entities
1316
 
1317
  added_to_line = True
1318
 
 
1322
 
1323
  return all_text_line_results
1324
 
1325
+
1326
+ def do_aws_comprehend_call(
1327
+ current_batch: str,
1328
+ current_batch_mapping: List[Tuple],
1329
+ comprehend_client: botocore.client.BaseClient,
1330
+ language: str,
1331
+ allow_list: List[str],
1332
+ chosen_redact_comprehend_entities: List[str],
1333
+ all_text_line_results: List[Tuple],
1334
+ ):
1335
  if not current_batch:
1336
  return all_text_line_results
1337
 
 
1341
  for attempt in range(max_retries):
1342
  try:
1343
  response = comprehend_client.detect_pii_entities(
1344
+ Text=current_batch.strip(), LanguageCode=language
 
1345
  )
1346
 
1347
  all_text_line_results = map_back_comprehend_entity_results(
1348
+ response,
1349
+ current_batch_mapping,
1350
+ allow_list,
1351
+ chosen_redact_comprehend_entities,
1352
+ all_text_line_results,
1353
  )
1354
 
1355
  return all_text_line_results
1356
+
1357
  except Exception as e:
1358
  if attempt == max_retries - 1:
1359
  print("AWS Comprehend calls failed due to", e)
1360
  raise
1361
  time.sleep(retry_delay)
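The retry pattern used by do_aws_comprehend_call, sketched with a stand-in client so it runs without AWS credentials. FakeComprehend is purely illustrative; the real call is comprehend_client.detect_pii_entities on a boto3 Comprehend client:

import time

class FakeComprehend:
    def __init__(self):
        self.calls = 0

    def detect_pii_entities(self, Text, LanguageCode):
        self.calls += 1
        if self.calls < 2:
            raise RuntimeError("transient error")   # fail once, then succeed
        return {"Entities": []}

def call_with_retries(client, text, language, max_retries=3, retry_delay=0.1):
    for attempt in range(max_retries):
        try:
            return client.detect_pii_entities(Text=text.strip(), LanguageCode=language)
        except Exception:
            if attempt == max_retries - 1:
                raise
            time.sleep(retry_delay)

print(call_with_retries(FakeComprehend(), "John lives in Leeds", "en"))
# -> {'Entities': []} after one retried failure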
1362
 
1363
+
1364
  def run_page_text_redaction(
1365
  language: str,
1366
  chosen_redact_entities: List[str],
 
1369
  line_characters: List,
1370
  page_analyser_results: List = list(),
1371
  page_analysed_bounding_boxes: List = list(),
1372
+ comprehend_client=None,
1373
  allow_list: List[str] = None,
1374
  pii_identification_method: str = LOCAL_PII_OPTION,
1375
  nlp_analyser: AnalyzerEngine = None,
1376
  score_threshold: float = 0.0,
1377
  custom_entities: List[str] = None,
1378
+ comprehend_query_number: int = 0,
1379
  ):
1380
  """
1381
  This function performs text redaction on a page based on the specified language and chosen entities.
 
1396
  custom_entities (List[str], optional): A list of custom entities for redaction. Defaults to None.
1397
  comprehend_query_number (int, optional): A counter for the number of Comprehend queries made. Defaults to 0.
1398
  """
1399
+
1400
  page_text = ""
1401
  page_text_mapping = list()
1402
  all_text_line_results = list()
 
1407
  if chosen_redact_entities:
1408
  if page_text:
1409
  page_text += " "
1410
+
1411
  start_pos = len(page_text)
1412
  page_text += text_line.text
1413
  page_text_mapping.append((start_pos, i, text_line, line_characters[i]))
1414
 
1415
+ valid_language_entities = nlp_analyser.registry.get_supported_entities(
1416
+ languages=[language]
1417
+ )
1418
  if "CUSTOM" not in valid_language_entities:
1419
  valid_language_entities.append("CUSTOM")
1420
  if "CUSTOM_FUZZY" not in valid_language_entities:
 
1424
  if pii_identification_method == LOCAL_PII_OPTION:
1425
  if not nlp_analyser:
1426
  raise ValueError("nlp_analyser is required for Local identification method")
1427
+
1428
+ language_supported_entities = filter_entities_for_language(
1429
+ chosen_redact_entities, valid_language_entities, language
1430
+ )
1431
 
1432
  page_analyser_result = nlp_analyser.analyze(
1433
  text=page_text,
 
1435
  entities=language_supported_entities,
1436
  score_threshold=score_threshold,
1437
  return_decision_process=True,
1438
+ allow_list=allow_list,
1439
  )
1440
 
 
1441
  all_text_line_results = map_back_entity_results(
1442
+ page_analyser_result, page_text_mapping, all_text_line_results
 
 
1443
  )
1444
 
1445
  elif pii_identification_method == AWS_PII_OPTION:
 
1447
  # Process custom entities if any
1448
  if custom_entities:
1449
  custom_redact_entities = [
1450
+ entity
1451
+ for entity in chosen_redact_comprehend_entities
1452
  if entity in custom_entities
1453
  ]
1454
 
1455
+ language_supported_entities = filter_entities_for_language(
1456
+ custom_redact_entities, valid_language_entities, language
1457
+ )
1458
 
1459
  if language_supported_entities:
1460
  page_analyser_result = nlp_analyser.analyze(
 
1463
  entities=language_supported_entities,
1464
  score_threshold=score_threshold,
1465
  return_decision_process=True,
1466
+ allow_list=allow_list,
1467
  )
1468
 
1469
  all_text_line_results = map_back_entity_results(
1470
+ page_analyser_result, page_text_mapping, all_text_line_results
 
 
1471
  )
1472
 
1473
  current_batch = ""
 
1478
  for i, text_line in enumerate(line_level_text_results_list):
1479
  words = text_line.text.split()
1480
  word_start_positions = list()
1481
+
1482
  # Calculate word start positions within the line
1483
  current_pos = 0
1484
  for word in words:
1485
  word_start_positions.append(current_pos)
1486
  current_pos += len(word) + 1 # +1 for space
1487
+
1488
  for word_idx, word in enumerate(words):
1489
  new_batch_char_count = len(current_batch) + len(word) + 1
1490
+
1491
  if batch_word_count >= 50 or new_batch_char_count >= 200:
1492
  # Process current batch
1493
  all_text_line_results = do_aws_comprehend_call(
 
1497
  language,
1498
  allow_list,
1499
  chosen_redact_comprehend_entities,
1500
+ all_text_line_results,
1501
  )
1502
  comprehend_query_number += 1
1503
+
1504
  # Start new batch
1505
  current_batch = word
1506
  batch_word_count = 1
1507
  batch_char_count = len(word)
1508
+ current_batch_mapping = [
1509
+ (
1510
+ 0,
1511
+ i,
1512
+ text_line,
1513
+ line_characters[i],
1514
+ word_start_positions[word_idx],
1515
+ )
1516
+ ]
1517
  else:
1518
  if current_batch:
1519
  current_batch += " "
 
1521
  current_batch += word
1522
  batch_char_count += len(word)
1523
  batch_word_count += 1
1524
+
1525
  if not current_batch_mapping or current_batch_mapping[-1][1] != i:
1526
+ current_batch_mapping.append(
1527
+ (
1528
+ batch_char_count - len(word),
1529
+ i,
1530
+ text_line,
1531
+ line_characters[i],
1532
+ word_start_positions[
1533
+ word_idx
1534
+ ], # Add the word's start position within its line
1535
+ )
1536
+ )
1537
 
1538
  # Process final batch
1539
  if current_batch:
 
1544
  language,
1545
  allow_list,
1546
  chosen_redact_comprehend_entities,
1547
+ all_text_line_results,
1548
  )
1549
  comprehend_query_number += 1
1550
 
1551
  # Process results for each line
1552
  for i, text_line in enumerate(line_level_text_results_list):
1553
+ line_results = next(
1554
+ (results for idx, results in all_text_line_results if idx == i), []
1555
+ )
1556
+
1557
  if line_results:
1558
  text_line_bounding_boxes = merge_text_bounding_boxes(
1559
+ line_results, line_characters[i]
 
1560
  )
1561
+
1562
  page_analyser_results.extend(line_results)
1563
  page_analysed_bounding_boxes.extend(text_line_bounding_boxes)
1564
 
1565
  return page_analysed_bounding_boxes
1566
 
1567
+
1568
+ def merge_text_bounding_boxes(
1569
+ analyser_results: dict,
1570
+ characters: List[LTChar],
1571
+ combine_pixel_dist: int = 20,
1572
+ vertical_padding: int = 0,
1573
+ ):
1574
+ """
1575
  Merge identified bounding boxes containing PII that are very close to one another
1576
+ """
1577
  analysed_bounding_boxes = list()
1578
  original_bounding_boxes = list() # List to hold original bounding boxes
1579
 
 
1581
  # Extract bounding box coordinates for sorting
1582
  bounding_boxes = list()
1583
  for result in analyser_results:
1584
+ # print("Result:", result)
1585
+ char_boxes = [
1586
+ char.bbox
1587
+ for char in characters[result.start : result.end]
1588
+ if isinstance(char, LTChar)
1589
+ ]
1590
+ char_text = [
1591
+ char._text
1592
+ for char in characters[result.start : result.end]
1593
+ if isinstance(char, LTChar)
1594
+ ]
1595
  if char_boxes:
1596
  # Calculate the bounding box that encompasses all characters
1597
  left = min(box[0] for box in char_boxes)
 
1599
  right = max(box[2] for box in char_boxes)
1600
  top = max(box[3] for box in char_boxes) + vertical_padding
1601
  bbox = [left, bottom, right, top]
1602
+ bounding_boxes.append(
1603
+ (bottom, left, result, bbox, char_text)
1604
+ ) # (y, x, result, bbox, text)
1605
 
1606
  # Store original bounding boxes
1607
+ original_bounding_boxes.append(
1608
+ {
1609
+ "text": "".join(char_text),
1610
+ "boundingBox": bbox,
1611
+ "result": copy.deepcopy(result),
1612
+ }
1613
+ )
1614
+ # print("Original bounding boxes:", original_bounding_boxes)
1615
 
1616
  # Sort the results by y-coordinate and then by x-coordinate
1617
  bounding_boxes.sort()
 
1633
  vertical_diff_bboxes = abs(next_box[1] - current_y)
1634
  horizontal_diff_bboxes = abs(next_box[0] - current_box[2])
1635
 
1636
+ if (
1637
+ vertical_diff_bboxes <= 5
1638
+ and horizontal_diff_bboxes <= combine_pixel_dist
1639
+ ):
1640
  # Merge bounding boxes
1641
+ # print("Merging boxes")
1642
  merged_box = current_box.copy()
1643
  merged_result = current_result
1644
  merged_text = current_text.copy()
1645
 
1646
  merged_box[2] = next_box[2] # Extend horizontally
1647
  merged_box[3] = max(current_box[3], next_box[3]) # Adjust the top
1648
+ merged_result.end = max(
1649
+ current_result.end, result.end
1650
+ ) # Extend text range
1651
  try:
1652
  if current_result.entity_type != result.entity_type:
1653
+ merged_result.entity_type = (
1654
+ current_result.entity_type + " - " + result.entity_type
1655
+ )
1656
  else:
1657
  merged_result.entity_type = current_result.entity_type
1658
  except Exception as e:
 
1661
  merged_text.append(" ") # Add space between texts
1662
  merged_text.extend(text)
1663
 
1664
+ merged_bounding_boxes.append(
1665
+ {
1666
+ "text": "".join(merged_text),
1667
+ "boundingBox": merged_box,
1668
+ "result": merged_result,
1669
+ }
1670
+ )
1671
 
1672
  else:
1673
  # Start a new bounding box
 
1680
  analysed_bounding_boxes.extend(original_bounding_boxes)
1681
  analysed_bounding_boxes.extend(merged_bounding_boxes)
1682
 
1683
+ # print("Analysed bounding boxes:", analysed_bounding_boxes)
1684
 
1685
  return analysed_bounding_boxes
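The merge condition above, paraphrased as a standalone predicate over [left, bottom, right, top] boxes sorted left-to-right on one line:

def should_merge(current_box, next_box, combine_pixel_dist=20):
    vertical_diff = abs(next_box[1] - current_box[1])     # compare baselines
    horizontal_gap = abs(next_box[0] - current_box[2])    # gap after the previous box
    return vertical_diff <= 5 and horizontal_gap <= combine_pixel_dist

print(should_merge([10, 100, 60, 112], [70, 101, 120, 112]))    # True  (10 px gap)
print(should_merge([10, 100, 60, 112], [200, 101, 250, 112]))   # False (140 px gap)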
1686
 
1687
+
1688
+ def recreate_page_line_level_ocr_results_with_page(
1689
+ page_line_level_ocr_results_with_words: dict,
1690
+ ):
1691
  reconstructed_results = list()
1692
+
1693
  # Assume all lines belong to the same page, so we can just read it from one item
1694
+ # page = next(iter(page_line_level_ocr_results_with_words.values()))["page"]
1695
 
1696
  page = page_line_level_ocr_results_with_words["page"]
1697
+
1698
  for line_data in page_line_level_ocr_results_with_words["results"].values():
1699
  bbox = line_data["bounding_box"]
1700
  text = line_data["text"]
 
1708
  top=bbox[1],
1709
  width=bbox[2] - bbox[0],
1710
  height=bbox[3] - bbox[1],
1711
+ line=line_number,
1712
  )
1713
  reconstructed_results.append(line_result)
1714
+
1715
+ page_line_level_ocr_results_with_page = {
1716
+ "page": page,
1717
+ "results": reconstructed_results,
1718
+ }
1719
+
1720
  return page_line_level_ocr_results_with_page
1721
 
1722
+
1723
+ def split_words_and_punctuation_from_line(
1724
+ line_of_words: List[OCRResult],
1725
+ ) -> List[OCRResult]:
1726
  """
1727
  Takes a list of OCRResult objects and splits words with trailing/leading punctuation.
1728
 
 
1731
  "high-tech" are preserved.
1732
  """
1733
  # Punctuation that will be split off. Hyphen is not included.
1734
+
 
1735
  new_word_list = list()
1736
+
1737
  for word_result in line_of_words:
1738
  word_text = word_result.text
1739
+
1740
  # This regex finds a central "core" word, and captures leading and trailing punctuation
1741
  # Handles cases like "(word)." -> group1='(', group2='word', group3='.'
1742
  match = re.match(r"([(\[{]*)(.*?)_?([.,?!:;)\}\]]*)$", word_text)
1743
 
1744
  # Handle words with internal hyphens that might confuse the regex
1745
+ if "-" in word_text and not match.group(2):
1746
+ core_part_text = word_text
1747
+ leading_punc = ""
1748
+ trailing_punc = ""
1749
  elif match:
1750
  leading_punc, core_part_text, trailing_punc = match.groups()
1751
+ else: # Failsafe
1752
  new_word_list.append(word_result)
1753
  continue
1754
+
1755
  # If no split is needed, just add the original and continue
1756
  if not leading_punc and not trailing_punc:
1757
  new_word_list.append(word_result)
1758
  continue
1759
+
1760
  # --- A split is required ---
1761
  # Estimate new bounding boxes by proportionally allocating width
1762
  original_width = word_result.width
1763
+ if not word_text or original_width == 0:
1764
+ continue # Failsafe
1765
+
1766
  avg_char_width = original_width / len(word_text)
1767
  current_left = word_result.left
1768
 
1769
  # Add leading punctuation if it exists
1770
  if leading_punc:
1771
  punc_width = avg_char_width * len(leading_punc)
1772
+ new_word_list.append(
1773
+ OCRResult(
1774
+ text=leading_punc,
1775
+ left=current_left,
1776
+ top=word_result.top,
1777
+ width=punc_width,
1778
+ height=word_result.height,
1779
+ )
1780
+ )
1781
  current_left += punc_width
1782
 
1783
  # Add the core part of the word
1784
  if core_part_text:
1785
  core_width = avg_char_width * len(core_part_text)
1786
+ new_word_list.append(
1787
+ OCRResult(
1788
+ text=core_part_text,
1789
+ left=current_left,
1790
+ top=word_result.top,
1791
+ width=core_width,
1792
+ height=word_result.height,
1793
+ )
1794
+ )
1795
  current_left += core_width
1796
 
1797
  # Add trailing punctuation if it exists
1798
  if trailing_punc:
1799
  punc_width = avg_char_width * len(trailing_punc)
1800
+ new_word_list.append(
1801
+ OCRResult(
1802
+ text=trailing_punc,
1803
+ left=current_left,
1804
+ top=word_result.top,
1805
+ width=punc_width,
1806
+ height=word_result.height,
1807
+ )
1808
+ )
1809
+
1810
  return new_word_list
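A standalone check of the punctuation-splitting regex above; the real function additionally handles the hyphen special case and apportions bounding-box widths, so only the text split is shown here:

import re

def split_punct(word):
    match = re.match(r"([(\[{]*)(.*?)_?([.,?!:;)\}\]]*)$", word)
    leading, core, trailing = match.groups()
    return [part for part in (leading, core, trailing) if part]

print(split_punct("(word)."))     # -> ['(', 'word', ').']
print(split_punct("high-tech,"))  # -> ['high-tech', ','] hyphen kept inside the word
print(split_punct("plain"))       # -> ['plain'] unchanged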
1811
 
1812
+
1813
+ def create_ocr_result_with_children(
1814
+ combined_results: dict, i: int, current_bbox: dict, current_line: list
1815
+ ):
1816
+ combined_results["text_line_" + str(i)] = {
1817
  "line": i,
1818
+ "text": current_bbox.text,
1819
+ "bounding_box": (
1820
+ current_bbox.left,
1821
+ current_bbox.top,
1822
+ current_bbox.left + current_bbox.width,
1823
+ current_bbox.top + current_bbox.height,
1824
+ ),
1825
+ "words": [
1826
+ {
1827
+ "text": word.text,
1828
+ "bounding_box": (
1829
+ word.left,
1830
+ word.top,
1831
+ word.left + word.width,
1832
+ word.top + word.height,
1833
+ ),
1834
+ }
1835
+ for word in current_line
1836
+ ],
1837
  }
1838
+ return combined_results["text_line_" + str(i)]
1839
+
1840
 
1841
+ def combine_ocr_results(
1842
+ ocr_results: List[OCRResult],
1843
+ x_threshold: float = 50.0,
1844
+ y_threshold: float = 12.0,
1845
+ page: int = 1,
1846
+ ):
1847
  """
1848
  Group OCR results into lines, splitting words from punctuation.
1849
  """
 
1878
  line_top = min(word.top for word in line)
1879
  line_right = max(word.left + word.width for word in line)
1880
  line_bottom = max(word.top + word.height for word in line)
1881
+
1882
  final_line_bbox = OCRResult(
1883
  text=line_text,
1884
  left=line_left,
1885
  top=line_top,
1886
  width=line_right - line_left,
1887
  height=line_bottom - line_top,
1888
+ line=line_counter,
1889
  )
1890
+
1891
  page_line_level_ocr_results.append(final_line_bbox)
1892
+
1893
  # Use the PROCESSED line to create the children. Creates a result within page_line_level_ocr_results_with_words
1894
+ page_line_level_ocr_results_with_words["text_line_" + str(line_counter)] = (
1895
+ create_ocr_result_with_children(
1896
+ page_line_level_ocr_results_with_words,
1897
+ line_counter,
1898
+ final_line_bbox,
1899
+ processed_line, # <-- Use the new, split list of words
1900
+ )
1901
  )
1902
  line_counter += 1
1903
 
1904
+ page_level_results_with_page = {
1905
+ "page": page,
1906
+ "results": page_line_level_ocr_results,
1907
+ }
1908
+ page_level_results_with_words = {
1909
+ "page": page,
1910
+ "results": page_line_level_ocr_results_with_words,
1911
+ }
1912
 
1913
  return page_level_results_with_page, page_level_results_with_words
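For orientation, the two dictionaries returned here are expected to look roughly as follows (a sketch derived from the code above, with placeholder values):

# page_level_results_with_page  -> {"page": 1, "results": [OCRResult(...), ...]}
# page_level_results_with_words -> {
#     "page": 1,
#     "results": {
#         "text_line_1": {
#             "line": 1,
#             "text": "An example line",
#             "bounding_box": (x0, y0, x1, y1),
#             "words": [{"text": "An", "bounding_box": (x0, y0, x1, y1)}, ...],
#         },
#         ...
#     },
# }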
tools/data_anonymise.py CHANGED
@@ -1,65 +1,112 @@
1
- import re
2
  import os
 
3
  import secrets
4
- import base64
5
  import time
 
 
 
6
  import boto3
7
  import botocore
8
- import pandas as pd
9
- import polars as pl
10
- import unicodedata
11
  import docx
12
  import gradio as gr
13
- from openpyxl import Workbook
 
 
14
  from faker import Faker
15
  from gradio import Progress
16
- from typing import List, Dict, Any, Optional
17
- from botocore.client import BaseClient
18
- from presidio_anonymizer.entities import OperatorConfig, ConflictResolutionStrategy
19
- from presidio_analyzer import AnalyzerEngine, BatchAnalyzerEngine, DictAnalyzerResult, RecognizerResult
 
 
 
20
  from presidio_anonymizer import AnonymizerEngine, BatchAnonymizerEngine
21
- from tools.config import RUN_AWS_FUNCTIONS, AWS_ACCESS_KEY, AWS_SECRET_KEY, OUTPUT_FOLDER, DEFAULT_LANGUAGE, aws_comprehend_language_choices, DO_INITIAL_TABULAR_DATA_CLEAN, CUSTOM_ENTITIES, PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS, AWS_REGION, MAX_TABLE_ROWS, MAX_TABLE_COLUMNS, MAX_SIMULTANEOUS_FILES
22
- from tools.helper_functions import get_file_name_without_type, read_file, detect_file_type, _get_env_list
23
- from tools.load_spacy_model_custom_recognisers import nlp_analyser, score_threshold, custom_word_list_recogniser, CustomWordFuzzyRecognizer, create_nlp_analyser, load_spacy_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  # Use custom version of analyze_dict to be able to track progress
25
  from tools.presidio_analyzer_custom import analyze_dict
26
 
27
- if DO_INITIAL_TABULAR_DATA_CLEAN == "True": DO_INITIAL_TABULAR_DATA_CLEAN = True
28
- else: DO_INITIAL_TABULAR_DATA_CLEAN = False
 
 
29
 
30
- if CUSTOM_ENTITIES: CUSTOM_ENTITIES = _get_env_list(CUSTOM_ENTITIES)
 
31
 
32
  custom_entities = CUSTOM_ENTITIES
33
 
34
  fake = Faker("en_UK")
 
 
35
  def fake_first_name(x):
36
  return fake.first_name()
37
 
 
38
  # #### Some of my cleaning functions
39
- url_pattern = r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+|(?:www\.)[a-zA-Z0-9._-]+\.[a-zA-Z]{2,}'
40
- html_pattern_regex = r'<.*?>|&([a-z0-9]+|#[0-9]{1,6}|#x[0-9a-f]{1,6});|\xa0|&nbsp;'
41
- html_start_pattern_end_dots_regex = r'<(.*?)\.\.'
42
- non_ascii_pattern = r'[^\x00-\x7F]+'
43
- and_sign_regex = r'&'
44
- multiple_spaces_regex = r'\s{2,}'
45
- multiple_new_lines_regex = r'(\r\n|\n)+'
46
  multiple_punctuation_regex = r"(\p{P})\p{P}+"
47
 
48
- def initial_clean(texts:pd.Series) -> pd.Series:
49
- '''
 
50
  This function cleans the text by removing URLs, HTML tags, and non-ASCII characters.
51
- '''
52
  for text in texts:
53
  if not text or pd.isnull(text):
54
  text = ""
55
 
56
  # Normalize unicode characters to decompose any special forms
57
- normalized_text = unicodedata.normalize('NFKC', text)
58
 
59
  # Replace smart quotes and special punctuation with standard ASCII equivalents
60
  replacements = {
61
- '‘': "'", '’': "'", '“': '"', '”': '"',
62
- '–': '-', '—': '-', '…': '...', '•': '*',
 
 
 
 
 
 
63
  }
64
 
65
  # Perform replacements
@@ -70,55 +117,70 @@ def initial_clean(texts:pd.Series) -> pd.Series:
70
 
71
  # Convert to polars Series
72
  texts = pl.Series(texts).str.strip_chars()
73
-
74
  # Define a list of patterns and their replacements
75
  patterns = [
76
- (multiple_new_lines_regex, ' '),
77
- (r'\r', ''),
78
- (url_pattern, ' '),
79
- (html_pattern_regex, ' '),
80
- (html_start_pattern_end_dots_regex, ' '),
81
- (non_ascii_pattern, ' '),
82
- (multiple_spaces_regex, ' '),
83
  (multiple_punctuation_regex, "${1}"),
84
- (and_sign_regex, 'and')
85
  ]
86
-
87
  # Apply each regex replacement
88
  for pattern, replacement in patterns:
89
  texts = texts.str.replace_all(pattern, replacement)
90
-
91
  # Convert the series back to a list
92
  texts = texts.to_list()
93
-
94
  return texts
95
 
96
- def process_recognizer_result(result:RecognizerResult, recognizer_result:RecognizerResult, data_row:int, dictionary_key:int, df_dict:Dict[str, List[Any]], keys_to_keep:List[str]) -> List[str]:
97
- output = list()
98
 
99
- if hasattr(result, 'value'):
100
- text = result.value[data_row]
101
- else:
102
- text = ""
103
-
104
- if isinstance(recognizer_result, list):
105
- for sub_result in recognizer_result:
106
- if isinstance(text, str):
107
- found_text = text[sub_result.start:sub_result.end]
108
- else:
109
- found_text = ''
110
- analysis_explanation = {key: sub_result.__dict__[key] for key in keys_to_keep}
111
- analysis_explanation.update({
112
- 'data_row': str(data_row),
113
- 'column': list(df_dict.keys())[dictionary_key],
114
- 'entity': found_text
115
- })
116
- output.append(str(analysis_explanation))
117
-
118
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  # Writing decision making process to file
121
- def generate_decision_process_output(analyzer_results: List[DictAnalyzerResult], df_dict: Dict[str, List[Any]]) -> str:
 
 
122
  """
123
  Generate a detailed output of the decision process for entity recognition.
124
 
@@ -135,35 +197,46 @@ def generate_decision_process_output(analyzer_results: List[DictAnalyzerResult],
135
  str: A string containing the detailed decision process output.
136
  """
137
  decision_process_output = list()
138
- keys_to_keep = ['entity_type', 'start', 'end']
139
 
140
  # Run through each column to analyse for PII
141
  for i, result in enumerate(analyzer_results):
142
 
143
  # If a single result
144
  if isinstance(result, RecognizerResult):
145
- decision_process_output.extend(process_recognizer_result(result, result, 0, i, df_dict, keys_to_keep))
 
 
146
 
147
  # If a list of results
148
  elif isinstance(result, list) or isinstance(result, DictAnalyzerResult):
149
  for x, recognizer_result in enumerate(result.recognizer_results):
150
- decision_process_output.extend(process_recognizer_result(result, recognizer_result, x, i, df_dict, keys_to_keep))
 
 
 
 
151
 
152
  else:
153
  try:
154
- decision_process_output.extend(process_recognizer_result(result, result, 0, i, df_dict, keys_to_keep))
 
 
 
 
155
  except Exception as e:
156
  print(e)
157
 
158
- decision_process_output_str = '\n'.join(decision_process_output)
159
 
160
  return decision_process_output_str
161
 
162
- def anon_consistent_names(df:pd.DataFrame) -> pd.DataFrame:
 
163
  # ## Pick out common names and replace them with the same person value
164
  df_dict = df.to_dict(orient="list")
165
 
166
- #analyzer = AnalyzerEngine()
167
  batch_analyzer = BatchAnalyzerEngine(analyzer_engine=nlp_analyser)
168
 
169
  analyzer_results = batch_analyzer.analyze_dict(df_dict, language=DEFAULT_LANGUAGE)
@@ -177,17 +250,17 @@ def anon_consistent_names(df:pd.DataFrame) -> pd.DataFrame:
177
 
178
  # Adjusting the parse_dict function to handle trailing ']'
179
  # Splitting the main data string into individual list strings
180
- list_strs = data_str[1:-1].split('], [')
181
 
182
  def parse_dict(s):
183
- s = s.strip('[]') # Removing any surrounding brackets
184
- items = s.split(', ')
185
  d = {}
186
  for item in items:
187
- key, value = item.split(': ')
188
- if key == 'score':
189
  d[key] = float(value)
190
- elif key in ['start', 'end']:
191
  d[key] = int(value)
192
  else:
193
  d[key] = value
@@ -199,9 +272,11 @@ def anon_consistent_names(df:pd.DataFrame) -> pd.DataFrame:
199
 
200
  for lst_str in list_strs:
201
  # Splitting each list string into individual dictionary strings
202
- dict_strs = lst_str.split(', type: ')
203
- dict_strs = [dict_strs[0]] + ['type: ' + s for s in dict_strs[1:]] # Prepending "type: " back to the split strings
204
-
 
 
205
  # Parsing each dictionary string
206
  dicts = [parse_dict(d) for d in dict_strs]
207
  result.append(dicts)
@@ -211,29 +286,36 @@ def anon_consistent_names(df:pd.DataFrame) -> pd.DataFrame:
211
  for idx, paragraph in enumerate(text):
212
  paragraph_texts = list()
213
  for dictionary in result[idx]:
214
- if dictionary['type'] == 'PERSON':
215
- paragraph_texts.append(paragraph[dictionary['start']:dictionary['end']])
 
 
216
  names.append(paragraph_texts)
217
 
218
  # Flatten the list of lists and extract unique names
219
  unique_names = list(set(name for sublist in names for name in sublist))
220
-
221
  fake_names = pd.Series(unique_names).apply(fake_first_name)
222
 
223
- mapping_df = pd.DataFrame(data={"Unique names":unique_names,
224
- "Fake names": fake_names})
 
225
 
226
  # Convert mapping dataframe to dictionary, adding word boundaries for full-word match
227
- name_map = {r'\b' + k + r'\b': v for k, v in zip(mapping_df['Unique names'], mapping_df['Fake names'])}
 
 
 
228
 
229
  name_map
230
 
231
- scrubbed_df_consistent_names = df.replace(name_map, regex = True)
232
 
233
  scrubbed_df_consistent_names
234
 
235
  return scrubbed_df_consistent_names
236
 
 
237
  def handle_docx_anonymisation(
238
  file_path: str,
239
  output_folder: str,
@@ -247,7 +329,7 @@ def handle_docx_anonymisation(
247
  comprehend_query_number: int,
248
  comprehend_client: BaseClient,
249
  language: Optional[str] = DEFAULT_LANGUAGE,
250
- nlp_analyser: AnalyzerEngine = nlp_analyser
251
  ):
252
  """
253
  Anonymises a .docx file by extracting text, processing it, and re-inserting it.
@@ -255,11 +337,13 @@ def handle_docx_anonymisation(
255
  Returns:
256
  A tuple containing the output .docx path, the decision log path, the anonymised CSV path, and the updated Comprehend query count.
257
  """
258
-
259
  # 1. Load the document and extract text elements
260
  doc = docx.Document(file_path)
261
- text_elements = list() # This will store the actual docx objects (paragraphs, cells)
262
- original_texts = list() # This will store the text from those objects
 
 
263
 
264
  paragraph_count = len(doc.paragraphs)
265
 
@@ -278,18 +362,18 @@ def handle_docx_anonymisation(
278
  for table in doc.tables:
279
  for row in table.rows:
280
  for cell in row.cells:
281
- if cell.text.strip(): # Only process non-empty cells
282
  text_elements.append(cell)
283
  original_texts.append(cell.text)
284
-
285
  # If there's no text to process, return early
286
  if not original_texts:
287
  print(f"No text found in {file_path}. Skipping.")
288
  return None, None, None, comprehend_query_number
289
 
290
  # 2. Convert to a DataFrame for the existing anonymisation script
291
- df_to_anonymise = pd.DataFrame({'text_to_redact': original_texts})
292
-
293
  # 3. Call the core anonymisation script
294
  anonymised_df, _, decision_log, comprehend_query_number = anonymise_script(
295
  df=df_to_anonymise,
@@ -303,10 +387,10 @@ def handle_docx_anonymisation(
303
  chosen_redact_comprehend_entities=chosen_redact_comprehend_entities,
304
  comprehend_query_number=comprehend_query_number,
305
  comprehend_client=comprehend_client,
306
- nlp_analyser=nlp_analyser
307
  )
308
 
309
- anonymised_texts = anonymised_df['text_to_redact'].tolist()
310
 
311
  # 4. Re-insert the anonymised text back into the document objects
312
  for element, new_text in zip(text_elements, anonymised_texts):
@@ -321,44 +405,53 @@ def handle_docx_anonymisation(
321
  # 5. Save the redacted document and the log file
322
  base_name = os.path.basename(file_path)
323
  file_name_without_ext = os.path.splitext(base_name)[0]
324
-
325
- output_docx_path = os.path.join(output_folder, f"{file_name_without_ext}_redacted.docx")
326
- log_file_path = os.path.join(output_folder, f"{file_name_without_ext}_redacted_log.txt")
327
 
328
- output_xlsx_path = os.path.join(output_folder, f"{file_name_without_ext}_redacted.csv")
 
 
 
 
 
329
 
330
- anonymised_df.to_csv(output_xlsx_path, encoding="utf-8-sig", index=None)
 
 
 
 
331
  doc.save(output_docx_path)
332
-
333
  with open(log_file_path, "w", encoding="utf-8-sig") as f:
334
  f.write(decision_log)
335
 
336
  return output_docx_path, log_file_path, output_xlsx_path, comprehend_query_number
337
 
338
- def anonymise_files_with_open_text(file_paths: List[str],
339
- in_text: str,
340
- anon_strategy: str,
341
- chosen_cols: List[str],
342
- chosen_redact_entities: List[str],
343
- in_allow_list: List[str] = None,
344
- latest_file_completed: int = 0,
345
- out_message: list = list(),
346
- out_file_paths: list = list(),
347
- log_files_output_paths: list = list(),
348
- in_excel_sheets: list = list(),
349
- first_loop_state: bool = False,
350
- output_folder: str = OUTPUT_FOLDER,
351
- in_deny_list:list[str]=list(),
352
- max_fuzzy_spelling_mistakes_num:int=0,
353
- pii_identification_method:str="Local",
354
- chosen_redact_comprehend_entities:List[str]=list(),
355
- comprehend_query_number:int=0,
356
- aws_access_key_textbox:str='',
357
- aws_secret_key_textbox:str='',
358
- actual_time_taken_number:float=0,
359
- do_initial_clean:bool=DO_INITIAL_TABULAR_DATA_CLEAN,
360
- language: Optional[str] = None,
361
- progress: Progress = Progress(track_tqdm=True)):
 
 
 
362
  """
363
  This function anonymises data files based on the provided parameters.
364
 
@@ -379,7 +472,7 @@ def anonymise_files_with_open_text(file_paths: List[str],
379
  - output_folder (str, optional): The output folder path. Defaults to the global output_folder variable.
380
  - in_deny_list (list[str], optional): A list of specific terms to redact.
381
  - max_fuzzy_spelling_mistakes_num (int, optional): The maximum number of spelling mistakes allowed in a searched phrase for fuzzy matching. Can range from 0-9.
382
- - pii_identification_method (str, optional): The method to redact personal information. Either 'Local' (spacy model), or 'AWS Comprehend' (AWS Comprehend API).
383
  - chosen_redact_comprehend_entities (List[str]): A list of entity types to redact from files, chosen from the official list from AWS Comprehend service.
384
  - comprehend_query_number (int, optional): A counter tracking the number of queries to AWS Comprehend.
385
  - aws_access_key_textbox (str, optional): AWS access key for account with Textract and Comprehend permissions.
@@ -389,13 +482,14 @@ def anonymise_files_with_open_text(file_paths: List[str],
389
  - progress (Progress, optional): A Progress object to track progress. Defaults to a Progress object with track_tqdm=True.
390
  - do_initial_clean (bool, optional): Whether to perform an initial cleaning of the text. Defaults to True.
391
  """
392
-
393
  tic = time.perf_counter()
394
  comprehend_client = ""
395
 
396
  # If output folder doesn't end with a forward slash, add one
397
- if not output_folder.endswith('/'): output_folder = output_folder + '/'
398
-
 
399
  # Use provided language or default
400
  language = language or DEFAULT_LANGUAGE
401
 
@@ -405,7 +499,7 @@ def anonymise_files_with_open_text(file_paths: List[str],
405
  raise Warning(out_message)
406
 
407
  # If this is the first time around, set variables to 0/blank
408
- if first_loop_state==True:
409
  latest_file_completed = 0
410
  out_message = list()
411
  out_file_paths = list()
@@ -415,14 +509,14 @@ def anonymise_files_with_open_text(file_paths: List[str],
415
  if isinstance(out_message, str):
416
  out_message = [out_message]
417
 
418
- #print("log_files_output_paths:",log_files_output_paths)
419
 
420
  if isinstance(log_files_output_paths, str):
421
  log_files_output_paths = list()
422
 
423
  if not out_file_paths:
424
- out_file_paths = list()
425
-
426
  if isinstance(in_allow_list, list):
427
  if in_allow_list:
428
  in_allow_list_flat = in_allow_list
@@ -435,39 +529,45 @@ def anonymise_files_with_open_text(file_paths: List[str],
435
  in_allow_list_flat = list()
436
  else:
437
  in_allow_list_flat = list()
438
-
439
  anon_df = pd.DataFrame()
440
 
441
- # Try to connect to AWS services directly only if RUN_AWS_FUNCTIONS environmental variable is 1, otherwise an environment variable or direct textbox input is needed.
442
  if pii_identification_method == "AWS Comprehend":
443
  print("Trying to connect to AWS Comprehend service")
444
  if RUN_AWS_FUNCTIONS == "1" and PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS == "1":
445
  print("Connecting to Comprehend via existing SSO connection")
446
- comprehend_client = boto3.client('comprehend', region_name=AWS_REGION)
447
  elif aws_access_key_textbox and aws_secret_key_textbox:
448
- print("Connecting to Comprehend using AWS access key and secret keys from textboxes.")
 
 
449
  print("aws_access_key_textbox:", aws_access_key_textbox)
450
  print("aws_secret_access_key:", aws_secret_key_textbox)
451
- comprehend_client = boto3.client('comprehend',
452
- aws_access_key_id=aws_access_key_textbox,
453
- aws_secret_access_key=aws_secret_key_textbox)
 
 
454
  elif RUN_AWS_FUNCTIONS == "1":
455
  print("Connecting to Comprehend via existing SSO connection")
456
- comprehend_client = boto3.client('comprehend')
457
  elif AWS_ACCESS_KEY and AWS_SECRET_KEY:
458
  print("Getting Comprehend credentials from environment variables")
459
- comprehend_client = boto3.client('comprehend',
460
- aws_access_key_id=AWS_ACCESS_KEY,
461
- aws_secret_access_key=AWS_SECRET_KEY)
 
 
462
  else:
463
  comprehend_client = ""
464
  out_message = "Cannot connect to AWS Comprehend service. Please provide access keys under Textract settings on the Redaction settings tab, or choose another PII identification method."
465
- raise(out_message)
466
-
467
  # Check if files and text exist
468
  if not file_paths:
469
  if in_text:
470
- file_paths=['open_text']
471
  else:
472
  out_message = "Please enter text or a file to redact."
473
  raise Exception(out_message)
@@ -479,31 +579,73 @@ def anonymise_files_with_open_text(file_paths: List[str],
479
  out_message = f"Number of files to anonymise is greater than {MAX_SIMULTANEOUS_FILES}. Please submit a smaller number of files."
480
  print(out_message)
481
  raise Exception(out_message)
482
-
483
  # If we have already redacted the last file, return the input out_message and file list to the relevant components
484
  if latest_file_completed >= len(file_paths):
485
- print("Last file reached") #, returning files:", str(latest_file_completed))
486
  # Set to a very high number so as not to mess with subsequent file processing by the user
487
- #latest_file_completed = 99
488
- final_out_message = '\n'.join(out_message)
489
- return final_out_message, out_file_paths, out_file_paths, latest_file_completed, log_files_output_paths, log_files_output_paths, actual_time_taken_number, comprehend_query_number
490
-
 
 
 
 
 
 
 
 
 
491
  file_path_loop = [file_paths[int(latest_file_completed)]]
492
 
493
- for anon_file in progress.tqdm(file_path_loop, desc="Anonymising files", unit = "files"):
 
 
494
 
495
  # Get a string file path
496
- if isinstance(anon_file, str): file_path = anon_file
497
- else: file_path = anon_file
 
 
498
 
499
- if anon_file=='open_text':
500
- anon_df = pd.DataFrame(data={'text':[in_text]})
501
- chosen_cols=['text']
502
  out_file_part = anon_file
503
  sheet_name = ""
504
  file_type = ""
505
 
506
- out_file_paths, out_message, key_string, log_files_output_paths, comprehend_query_number = tabular_anonymise_wrapper_func(file_path, anon_df, chosen_cols, out_file_paths, out_file_part, out_message, sheet_name, anon_strategy, language, chosen_redact_entities, in_allow_list, file_type, "", log_files_output_paths, in_deny_list, max_fuzzy_spelling_mistakes_num, pii_identification_method, chosen_redact_comprehend_entities, comprehend_query_number, comprehend_client, output_folder=OUTPUT_FOLDER, do_initial_clean=do_initial_clean)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
507
  else:
508
  # If file is an xlsx, we are going to run through all the Excel sheets to anonymise them separately.
509
  file_type = detect_file_type(file_path)
@@ -511,20 +653,22 @@ def anonymise_files_with_open_text(file_paths: List[str],
511
 
512
  out_file_part = get_file_name_without_type(file_path)
513
 
514
- if file_type == 'docx':
515
- output_path, log_path, output_xlsx_path, comprehend_query_number = handle_docx_anonymisation(
516
- file_path=file_path,
517
- output_folder=output_folder,
518
- anon_strategy=anon_strategy,
519
- chosen_redact_entities=chosen_redact_entities,
520
- in_allow_list=in_allow_list_flat,
521
- in_deny_list=in_deny_list,
522
- max_fuzzy_spelling_mistakes_num=max_fuzzy_spelling_mistakes_num,
523
- pii_identification_method=pii_identification_method,
524
- chosen_redact_comprehend_entities=chosen_redact_comprehend_entities,
525
- comprehend_query_number=comprehend_query_number,
526
- comprehend_client=comprehend_client,
527
- language=language
 
 
528
  )
529
  if output_path:
530
  out_file_paths.append(output_path)
@@ -532,34 +676,100 @@ def anonymise_files_with_open_text(file_paths: List[str],
532
  out_file_paths.append(output_xlsx_path)
533
  if log_path:
534
  log_files_output_paths.append(log_path)
535
-
536
- elif file_type == 'xlsx':
537
  print("Running through all xlsx sheets")
538
- #anon_xlsx = pd.ExcelFile(anon_file)
539
  if not in_excel_sheets:
540
- out_message.append("No Excel sheets selected. Please select at least one to anonymise.")
 
 
541
  continue
542
 
543
  # Create xlsx file:
544
- anon_xlsx = pd.ExcelFile(file_path)
545
- anon_xlsx_export_file_name = output_folder + out_file_part + "_redacted.xlsx"
 
 
546
 
547
  # Iterate through the sheet names
548
- for sheet_name in progress.tqdm(in_excel_sheets, desc="Anonymising sheets", unit = "sheets"):
 
 
549
  # Read each sheet into a DataFrame
550
  if sheet_name not in anon_xlsx.sheet_names:
551
  continue
552
 
553
  anon_df = pd.read_excel(file_path, sheet_name=sheet_name)
554
 
555
- out_file_paths, out_message, key_string, log_files_output_paths, comprehend_query_number = tabular_anonymise_wrapper_func(anon_file, anon_df, chosen_cols, out_file_paths, out_file_part, out_message, sheet_name, anon_strategy, language, chosen_redact_entities, in_allow_list, file_type, anon_xlsx_export_file_name, log_files_output_paths, in_deny_list, max_fuzzy_spelling_mistakes_num, pii_identification_method, language, chosen_redact_comprehend_entities, comprehend_query_number, comprehend_client, output_folder=output_folder, do_initial_clean=do_initial_clean)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
556
 
557
  else:
558
  sheet_name = ""
559
  anon_df = read_file(file_path)
560
  out_file_part = get_file_name_without_type(file_path)
561
 
562
- out_file_paths, out_message, key_string, log_files_output_paths, comprehend_query_number = tabular_anonymise_wrapper_func(anon_file, anon_df, chosen_cols, out_file_paths, out_file_part, out_message, sheet_name, anon_strategy, language, chosen_redact_entities, in_allow_list, file_type, "", log_files_output_paths, in_deny_list, max_fuzzy_spelling_mistakes_num, pii_identification_method, language, chosen_redact_comprehend_entities, comprehend_query_number, comprehend_client, output_folder=output_folder, do_initial_clean=do_initial_clean)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
563
 
564
  # Increase latest file completed count unless we are at the last file
565
  if latest_file_completed != len(file_paths):
@@ -575,46 +785,61 @@ def anonymise_files_with_open_text(file_paths: List[str],
575
 
576
  if isinstance(out_message, str):
577
  out_message = [out_message]
578
-
579
- out_message.append("Anonymisation of file '" + out_file_part + "' successfully completed in")
580
 
581
- out_message_out = '\n'.join(out_message)
 
 
 
 
582
  out_message_out = out_message_out + " " + out_time
583
 
584
  if anon_strategy == "encrypt":
585
- out_message_out.append(". Your decryption key is " + key_string)
586
-
587
- out_message_out = out_message_out + "\n\nGo to the Redaction settings tab to see redaction logs. Please give feedback on the results below to help improve this app."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
588
 
589
- out_message_out = re.sub(r'^\n+|^\. ', '', out_message_out).strip()
590
-
591
- return out_message_out, out_file_paths, out_file_paths, latest_file_completed, log_files_output_paths, log_files_output_paths, actual_time_taken_number, comprehend_query_number
592
 
593
  def tabular_anonymise_wrapper_func(
594
- anon_file: str,
595
- anon_df: pd.DataFrame,
596
- chosen_cols: List[str],
597
- out_file_paths: List[str],
598
- out_file_part: str,
599
- out_message: str,
600
- excel_sheet_name: str,
601
- anon_strategy: str,
602
  language: str,
603
- chosen_redact_entities: List[str],
604
- in_allow_list: List[str],
605
- file_type: str,
606
- anon_xlsx_export_file_name: str,
607
  log_files_output_paths: List[str],
608
- in_deny_list: List[str]=list(),
609
- max_fuzzy_spelling_mistakes_num:int=0,
610
- pii_identification_method:str="Local",
611
  comprehend_language: Optional[str] = None,
612
- chosen_redact_comprehend_entities:List[str]=list(),
613
- comprehend_query_number:int=0,
614
- comprehend_client:botocore.client.BaseClient="",
615
  nlp_analyser: AnalyzerEngine = nlp_analyser,
616
  output_folder: str = OUTPUT_FOLDER,
617
- do_initial_clean:bool=DO_INITIAL_TABULAR_DATA_CLEAN
618
  ):
619
  """
620
  This function wraps the anonymisation process for a given dataframe. It filters the dataframe based on chosen columns, applies the specified anonymisation strategy using the anonymise_script function, and exports the anonymised data to a file.
@@ -636,16 +861,17 @@ def tabular_anonymise_wrapper_func(
636
  - log_files_output_paths: A list of paths where the log files will be saved.
637
  - in_deny_list: List of specific terms to remove from the data.
638
  - max_fuzzy_spelling_mistakes_num (int, optional): The maximum number of spelling mistakes allowed in a searched phrase for fuzzy matching. Can range from 0-9.
639
- - pii_identification_method (str, optional): The method to redact personal information. Either 'Local' (spacy model), or 'AWS Comprehend' (AWS Comprehend API).
640
  - chosen_redact_comprehend_entities (List[str]): A list of entity types to redact from files, chosen from the official list from AWS Comprehend service.
641
  - comprehend_query_number (int, optional): A counter tracking the number of queries to AWS Comprehend.
642
- - comprehend_client (optional): The client object from AWS containing a client connection to AWS Comprehend if that option is chosen on the first tab.
643
  - output_folder: The folder where the anonymized files will be saved. Defaults to the 'output_folder' variable.
644
  - do_initial_clean (bool, optional): Whether to perform an initial cleaning of the text. Defaults to True.
645
  """
 
646
  def check_lists(list1, list2):
647
- return any(string in list2 for string in list1)
648
-
649
  def get_common_strings(list1, list2):
650
  """
651
  Finds the common strings between two lists.
@@ -664,62 +890,95 @@ def tabular_anonymise_wrapper_func(
664
  return common_strings
665
 
666
  if pii_identification_method == "AWS Comprehend" and comprehend_client == "":
667
- raise("Connection to AWS Comprehend service not found, please check connection details.")
668
-
 
 
669
  # Check for chosen col, skip file if not found
670
  all_cols_original_order = list(anon_df.columns)
671
 
672
  any_cols_found = check_lists(chosen_cols, all_cols_original_order)
673
 
674
- if any_cols_found == False:
675
  out_message = "No chosen columns found in dataframe: " + out_file_part
676
  key_string = ""
677
  print(out_message)
678
- return out_file_paths, out_message, key_string, log_files_output_paths, comprehend_query_number
 
 
 
 
 
 
679
  else:
680
- chosen_cols_in_anon_df = get_common_strings(chosen_cols, all_cols_original_order)
 
 
681
 
682
  # Split dataframe to keep only selected columns
683
- #print("Remaining columns to redact:", chosen_cols_in_anon_df)
684
-
685
  if not anon_df.index.is_unique:
686
  anon_df = anon_df.reset_index(drop=True)
687
 
688
  anon_df_part = anon_df[chosen_cols_in_anon_df]
689
- anon_df_remain = anon_df.drop(chosen_cols_in_anon_df, axis = 1)
690
 
691
  row_count = anon_df_part.shape[0]
692
-
693
  if row_count > MAX_TABLE_ROWS:
694
  out_message = f"Number of rows in dataframe is greater than {MAX_TABLE_ROWS}. Please submit a smaller dataframe."
695
  print(out_message)
696
  raise Exception(out_message)
697
-
698
  column_count = anon_df_part.shape[1]
699
-
700
  if column_count > MAX_TABLE_COLUMNS:
701
  out_message = f"Number of columns in dataframe is greater than {MAX_TABLE_COLUMNS}. Please submit a smaller dataframe."
702
  print(out_message)
703
  raise Exception(out_message)
704
 
705
  # Anonymise the selected columns
706
- anon_df_part_out, key_string, decision_process_output_str, comprehend_query_number = anonymise_script(anon_df_part, anon_strategy, language, chosen_redact_entities, in_allow_list, in_deny_list, max_fuzzy_spelling_mistakes_num, pii_identification_method, chosen_redact_comprehend_entities, comprehend_query_number, comprehend_client, nlp_analyser=nlp_analyser, do_initial_clean=do_initial_clean)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
707
 
708
  anon_df_part_out.replace("^nan$", "", regex=True, inplace=True)
709
 
710
  # Rejoin the dataframe together
711
- anon_df_out = pd.concat([anon_df_part_out, anon_df_remain], axis = 1)
712
  anon_df_out = anon_df_out[all_cols_original_order]
713
-
714
  # Export file
715
  # Rename anonymisation strategy for file path naming
716
- if anon_strategy == "replace with 'REDACTED'": anon_strat_txt = "redact_replace"
717
- elif anon_strategy == "replace with <ENTITY_NAME>": anon_strat_txt = "redact_entity_type"
718
- elif anon_strategy == "redact completely": anon_strat_txt = "redact_remove"
719
- else: anon_strat_txt = anon_strategy
 
 
 
 
720
 
721
  # If the file is an xlsx, add a new sheet to the existing xlsx. Otherwise, write to csv
722
- if file_type == 'xlsx':
723
 
724
  anon_export_file_name = anon_xlsx_export_file_name
725
 
@@ -730,19 +989,33 @@ def tabular_anonymise_wrapper_func(
730
  wb.save(anon_xlsx_export_file_name)
731
 
732
  # Create a Pandas Excel writer using XlsxWriter as the engine.
733
- with pd.ExcelWriter(anon_xlsx_export_file_name, engine='openpyxl', mode='a', if_sheet_exists='replace') as writer:
 
 
 
 
 
734
  # Write each DataFrame to a different worksheet.
735
  anon_df_out.to_excel(writer, sheet_name=excel_sheet_name, index=None)
736
 
737
- decision_process_log_output_file = anon_xlsx_export_file_name + "_" + excel_sheet_name + "_decision_process_output.txt"
 
 
 
 
 
738
  with open(decision_process_log_output_file, "w") as f:
739
  f.write(decision_process_output_str)
740
 
741
  else:
742
- anon_export_file_name = output_folder + out_file_part + "_anon_" + anon_strat_txt + ".csv"
743
- anon_df_out.to_csv(anon_export_file_name, index = None, encoding="utf-8-sig")
744
-
745
- decision_process_log_output_file = anon_export_file_name + "_decision_process_output.txt"
 
 
 
 
746
  with open(decision_process_log_output_file, "w") as f:
747
  f.write(decision_process_output_str)
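The xlsx branch above appends each anonymised sheet into one workbook; the pattern in isolation, as a hedged standalone sketch with placeholder file and sheet names, is:

import pandas as pd
from openpyxl import Workbook

# Illustrative only: placeholder path and sheet name.
xlsx_path = "example_redacted.xlsx"
Workbook().save(xlsx_path)  # make sure the workbook exists before opening in append mode

anon_df_out = pd.DataFrame({"Case Note": ["REDACTED"]})
with pd.ExcelWriter(
    xlsx_path, engine="openpyxl", mode="a", if_sheet_exists="replace"
) as writer:
    anon_df_out.to_excel(writer, sheet_name="Sheet1", index=None)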
748
 
@@ -753,27 +1026,36 @@ def tabular_anonymise_wrapper_func(
753
  out_file_paths = list(set(out_file_paths))
754
 
755
  # Print result text to output text box if just anonymising open text
756
- if anon_file=='open_text':
757
- out_message = ["'" + anon_df_out['text'][0] + "'"]
758
-
759
- return out_file_paths, out_message, key_string, log_files_output_paths, comprehend_query_number
760
-
761
- def anonymise_script(df:pd.DataFrame,
762
- anon_strategy:str,
763
- language:str,
764
- chosen_redact_entities:List[str],
765
- in_allow_list:List[str]=list(),
766
- in_deny_list:List[str]=list(),
767
- max_fuzzy_spelling_mistakes_num:int=0,
768
- pii_identification_method:str="Local",
769
- chosen_redact_comprehend_entities:List[str]=list(),
770
- comprehend_query_number:int=0,
771
- comprehend_client:botocore.client.BaseClient="",
772
- custom_entities:List[str]=custom_entities,
773
- nlp_analyser: AnalyzerEngine = nlp_analyser,
774
- do_initial_clean:bool=DO_INITIAL_TABULAR_DATA_CLEAN,
775
- progress:Progress=Progress(track_tqdm=True)):
776
- '''
 
 
 
 
 
 
 
 
 
777
  Conduct anonymisation of a dataframe using Presidio and/or AWS Comprehend if chosen.
778
 
779
  Args:
@@ -792,14 +1074,14 @@ def anonymise_script(df:pd.DataFrame,
792
  nlp_analyser (AnalyzerEngine, optional): The Presidio AnalyzerEngine instance to use. Defaults to `nlp_analyser`.
793
  do_initial_clean (bool, optional): Whether to perform an initial cleaning of the text. Defaults to True.
794
  progress (Progress, optional): Gradio Progress object for tracking progress. Defaults to Progress(track_tqdm=False).
795
- '''
796
 
797
  print("Identifying personal information")
798
  analyse_tic = time.perf_counter()
799
 
800
  # Initialize analyzer_results as an empty dictionary to store results by column
801
  results_by_column = dict()
802
- key_string = ""
803
 
804
  if isinstance(in_allow_list, list):
805
  if in_allow_list:
@@ -818,20 +1100,22 @@ def anonymise_script(df:pd.DataFrame,
818
  try:
819
  if language != "en":
820
  progress(0.1, desc=f"Loading spaCy model for {language}")
821
-
822
  load_spacy_model(language)
823
-
824
  except Exception as e:
825
  out_message = f"Error downloading language packs for {language}: {e}"
826
  print(out_message)
827
  raise Exception(out_message)
828
-
829
  # Try updating the supported languages for the spacy analyser
830
  try:
831
  nlp_analyser = create_nlp_analyser(language, existing_nlp_analyser=nlp_analyser)
832
  # Check list of nlp_analyser recognisers and languages
833
  if language != "en":
834
- gr.Info(f"Language: {language} only supports the following entity detection: {str(nlp_analyser.registry.get_supported_entities(languages=[language]))}")
 
 
835
 
836
  except Exception as e:
837
  out_message = f"Error creating nlp_analyser for {language}: {e}"
@@ -848,40 +1132,49 @@ def anonymise_script(df:pd.DataFrame,
848
  # Sort the strings in order from the longest string to the shortest
849
  in_deny_list = sorted(in_deny_list, key=len, reverse=True)
850
 
851
- if in_deny_list:
852
  nlp_analyser.registry.remove_recognizer("CUSTOM")
853
  new_custom_recogniser = custom_word_list_recogniser(in_deny_list)
854
  nlp_analyser.registry.add_recognizer(new_custom_recogniser)
855
 
856
  nlp_analyser.registry.remove_recognizer("CustomWordFuzzyRecognizer")
857
- new_custom_fuzzy_recogniser = CustomWordFuzzyRecognizer(supported_entities=["CUSTOM_FUZZY"], custom_list=in_deny_list, spelling_mistakes_max=in_deny_list, search_whole_phrase=max_fuzzy_spelling_mistakes_num)
 
 
 
 
 
858
  nlp_analyser.registry.add_recognizer(new_custom_fuzzy_recogniser)
859
 
860
- #analyzer = nlp_analyser #AnalyzerEngine()
861
  batch_analyzer = BatchAnalyzerEngine(analyzer_engine=nlp_analyser)
862
- anonymizer = AnonymizerEngine()#conflict_resolution=ConflictResolutionStrategy.MERGE_SIMILAR_OR_CONTAINED)
863
- batch_anonymizer = BatchAnonymizerEngine(anonymizer_engine = anonymizer)
 
 
864
  analyzer_results = list()
865
 
866
  if do_initial_clean:
867
  progress(0.2, desc="Cleaning text")
868
- for col in progress.tqdm(df.columns, desc="Cleaning text", unit = "Columns"):
869
  df[col] = initial_clean(df[col])
870
 
871
  # DataFrame to dict
872
- df_dict = df.to_dict(orient="list")
873
 
874
  if pii_identification_method == "Local":
875
 
876
  # Use custom analyzer to be able to track progress with Gradio
877
- custom_results = analyze_dict(batch_analyzer,
878
- df_dict,
879
- language=language,
880
- entities=chosen_redact_entities,
881
- score_threshold=score_threshold,
882
- return_decision_process=True,
883
- allow_list=in_allow_list_flat)
884
-
 
 
885
  # Initialize results_by_column with custom entity results
886
  for result in custom_results:
887
  results_by_column[result.key] = result
@@ -891,23 +1184,26 @@ def anonymise_script(df:pd.DataFrame,
891
 
892
  # AWS Comprehend calls
893
  elif pii_identification_method == "AWS Comprehend" and comprehend_client:
894
-
895
  # Only run Local anonymisation for entities that are not covered by AWS Comprehend
896
  if custom_entities:
897
  custom_redact_entities = [
898
- entity for entity in chosen_redact_comprehend_entities
 
899
  if entity in custom_entities
900
  ]
901
  if custom_redact_entities:
902
  # Get results from analyze_dict
903
- custom_results = analyze_dict(batch_analyzer,
904
- df_dict,
905
- language=language,
906
- entities=custom_redact_entities,
907
- score_threshold=score_threshold,
908
- return_decision_process=True,
909
- allow_list=in_allow_list_flat)
910
-
 
 
911
  # Initialize results_by_column with custom entity results
912
  for result in custom_results:
913
  results_by_column[result.key] = result
@@ -916,47 +1212,56 @@ def anonymise_script(df:pd.DataFrame,
916
  retry_delay = 3
917
 
918
  # Process each text column in the dictionary
919
- for column_name, texts in progress.tqdm(df_dict.items(), desc="Querying AWS Comprehend service.", unit = "Columns"):
 
 
920
  # Get or create DictAnalyzerResult for this column
921
  if column_name in results_by_column:
922
  column_results = results_by_column[column_name]
923
  else:
924
  column_results = DictAnalyzerResult(
925
- recognizer_results=[[] for _ in texts],
926
- key=column_name,
927
- value=texts
928
  )
929
 
930
  # Process each text in the column
931
- for text_idx, text in progress.tqdm(enumerate(texts), desc="Querying AWS Comprehend service.", unit = "Row"):
 
 
932
 
933
  for attempt in range(max_retries):
934
  try:
935
  response = comprehend_client.detect_pii_entities(
936
- Text=str(text),
937
- LanguageCode=language
938
  )
939
 
940
  comprehend_query_number += 1
941
 
942
  # Add all entities from this text to the column's recognizer_results
943
  for entity in response["Entities"]:
944
- if entity.get("Type") not in chosen_redact_comprehend_entities:
 
 
 
945
  continue
946
 
947
  recognizer_result = RecognizerResult(
948
  entity_type=entity["Type"],
949
  start=entity["BeginOffset"],
950
  end=entity["EndOffset"],
951
- score=entity["Score"]
952
  )
953
- column_results.recognizer_results[text_idx].append(recognizer_result)
954
-
 
 
955
  break # Success, exit retry loop
956
-
957
  except Exception as e:
958
  if attempt == max_retries - 1:
959
- print(f"AWS Comprehend calls failed for text: {text[:100]}... due to", e)
 
 
 
960
  raise
961
  time.sleep(retry_delay)
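Read on its own, the retry loop around the Comprehend call follows this pattern (a generic sketch, not code from this commit; the fixed delay mirrors the retry_delay set earlier in the function):

import time

def call_with_retries(fn, max_retries=3, retry_delay=3):
    # Generic retry helper: re-raise on the final attempt,
    # otherwise wait a fixed delay and try again.
    for attempt in range(max_retries):
        try:
            return fn()
        except Exception:
            if attempt == max_retries - 1:
                raise
            time.sleep(retry_delay)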
962
 
@@ -967,56 +1272,79 @@ def anonymise_script(df:pd.DataFrame,
967
  analyzer_results = list(results_by_column.values())
968
 
969
  elif (pii_identification_method == "AWS Comprehend") & (not comprehend_client):
970
- raise("Unable to redact, Comprehend connection details not found.")
971
-
972
  else:
973
  print("Unable to redact.")
974
 
975
  # Usage in the main function:
976
- decision_process_output_str = generate_decision_process_output(analyzer_results, df_dict)
 
 
977
 
978
  analyse_toc = time.perf_counter()
979
- analyse_time_out = f"Analysing the text took {analyse_toc - analyse_tic:0.1f} seconds."
 
 
980
  print(analyse_time_out)
981
 
982
  # Set up the anonymization configuration WITHOUT DATE_TIME
983
- simple_replace_config = eval('{"DEFAULT": OperatorConfig("replace", {"new_value": "REDACTED"})}')
984
- replace_config = eval('{"DEFAULT": OperatorConfig("replace")}')
985
- redact_config = eval('{"DEFAULT": OperatorConfig("redact")}')
986
- hash_config = eval('{"DEFAULT": OperatorConfig("hash")}')
987
- mask_config = eval('{"DEFAULT": OperatorConfig("mask", {"masking_char":"*", "chars_to_mask":100, "from_end":True})}')
988
- people_encrypt_config = eval('{"PERSON": OperatorConfig("encrypt", {"key": key_string})}') # The encryption is using AES cypher in CBC mode and requires a cryptographic key as an input for both the encryption and the decryption.
989
- fake_first_name_config = eval('{"PERSON": OperatorConfig("custom", {"lambda": fake_first_name})}')
990
-
991
- if anon_strategy == "replace with 'REDACTED'": chosen_mask_config = simple_replace_config
992
- elif anon_strategy == "replace_redacted": chosen_mask_config = simple_replace_config
993
- elif anon_strategy == "replace with <ENTITY_NAME>": chosen_mask_config = replace_config
994
- elif anon_strategy == "entity_type": chosen_mask_config = replace_config
995
- elif anon_strategy == "redact completely": chosen_mask_config = redact_config
996
- elif anon_strategy == "redact": chosen_mask_config = redact_config
997
- elif anon_strategy == "hash": chosen_mask_config = hash_config
998
- elif anon_strategy == "mask": chosen_mask_config = mask_config
999
- elif anon_strategy == "encrypt":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1000
  chosen_mask_config = people_encrypt_config
1001
  key = secrets.token_bytes(16) # 128 bits = 16 bytes
1002
- key_string = base64.b64encode(key).decode('utf-8')
1003
-
1004
  # Now inject the key into the operator config
1005
  for entity, operator in chosen_mask_config.items():
1006
  if operator.operator_name == "encrypt":
1007
  operator.params = {"key": key_string}
1008
- elif anon_strategy == "fake_first_name": chosen_mask_config = fake_first_name_config
 
1009
  else:
1010
  print("Anonymisation strategy not found. Redacting completely by default.")
1011
- chosen_mask_config = redact_config # Redact completely by default
1012
-
1013
- # I think in general people will want to keep date / times - removed Mar 2025 as I don't want to assume for people.
1014
- #keep_date_config = eval('{"DATE_TIME": OperatorConfig("keep")}')
1015
 
1016
- combined_config = {**chosen_mask_config} #, **keep_date_config}
1017
 
1018
- anonymizer_results = batch_anonymizer.anonymize_dict(analyzer_results, operators=combined_config)
 
1019
 
1020
  scrubbed_df = pd.DataFrame(anonymizer_results)
1021
-
1022
- return scrubbed_df, key_string, decision_process_output_str, comprehend_query_number
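The block removed above constructed the Presidio operator configurations through eval; an equivalent eval-free sketch, which may not match the exact replacement introduced by this commit, would be:

import base64
import secrets

from presidio_anonymizer.entities import OperatorConfig

# Plain dict literals instead of eval'd strings.
simple_replace_config = {"DEFAULT": OperatorConfig("replace", {"new_value": "REDACTED"})}
replace_config = {"DEFAULT": OperatorConfig("replace")}
redact_config = {"DEFAULT": OperatorConfig("redact")}
hash_config = {"DEFAULT": OperatorConfig("hash")}
mask_config = {
    "DEFAULT": OperatorConfig(
        "mask", {"masking_char": "*", "chars_to_mask": 100, "from_end": True}
    )
}

# Presidio's encrypt operator uses AES-CBC and needs a caller-supplied key.
key = secrets.token_bytes(16)  # 128-bit key
key_string = base64.b64encode(key).decode("utf-8")
people_encrypt_config = {"PERSON": OperatorConfig("encrypt", {"key": key_string})}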
 
1
+ import base64
2
  import os
3
+ import re
4
  import secrets
 
5
  import time
6
+ import unicodedata
7
+ from typing import Any, Dict, List, Optional
8
+
9
  import boto3
10
  import botocore
 
 
 
11
  import docx
12
  import gradio as gr
13
+ import pandas as pd
14
+ import polars as pl
15
+ from botocore.client import BaseClient
16
  from faker import Faker
17
  from gradio import Progress
18
+ from openpyxl import Workbook
19
+ from presidio_analyzer import (
20
+ AnalyzerEngine,
21
+ BatchAnalyzerEngine,
22
+ DictAnalyzerResult,
23
+ RecognizerResult,
24
+ )
25
  from presidio_anonymizer import AnonymizerEngine, BatchAnonymizerEngine
26
+ from presidio_anonymizer.entities import OperatorConfig
27
+
28
+ from tools.config import (
29
+ AWS_ACCESS_KEY,
30
+ AWS_REGION,
31
+ AWS_SECRET_KEY,
32
+ CUSTOM_ENTITIES,
33
+ DEFAULT_LANGUAGE,
34
+ DO_INITIAL_TABULAR_DATA_CLEAN,
35
+ MAX_SIMULTANEOUS_FILES,
36
+ MAX_TABLE_COLUMNS,
37
+ MAX_TABLE_ROWS,
38
+ OUTPUT_FOLDER,
39
+ PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS,
40
+ RUN_AWS_FUNCTIONS,
41
+ aws_comprehend_language_choices,
42
+ )
43
+ from tools.helper_functions import (
44
+ _get_env_list,
45
+ detect_file_type,
46
+ get_file_name_without_type,
47
+ read_file,
48
+ )
49
+ from tools.load_spacy_model_custom_recognisers import (
50
+ CustomWordFuzzyRecognizer,
51
+ create_nlp_analyser,
52
+ custom_word_list_recogniser,
53
+ load_spacy_model,
54
+ nlp_analyser,
55
+ score_threshold,
56
+ )
57
+
58
  # Use custom version of analyze_dict to be able to track progress
59
  from tools.presidio_analyzer_custom import analyze_dict
60
 
61
+ if DO_INITIAL_TABULAR_DATA_CLEAN == "True":
62
+ DO_INITIAL_TABULAR_DATA_CLEAN = True
63
+ else:
64
+ DO_INITIAL_TABULAR_DATA_CLEAN = False
65
 
66
+ if CUSTOM_ENTITIES:
67
+ CUSTOM_ENTITIES = _get_env_list(CUSTOM_ENTITIES)
68
 
69
  custom_entities = CUSTOM_ENTITIES
70
 
71
  fake = Faker("en_UK")
72
+
73
+
74
  def fake_first_name(x):
75
  return fake.first_name()
76
 
77
+
78
  # #### Some of my cleaning functions
79
+ url_pattern = r"http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+|(?:www\.)[a-zA-Z0-9._-]+\.[a-zA-Z]{2,}"
80
+ html_pattern_regex = r"<.*?>|&([a-z0-9]+|#[0-9]{1,6}|#x[0-9a-f]{1,6});|\xa0|&nbsp;"
81
+ html_start_pattern_end_dots_regex = r"<(.*?)\.\."
82
+ non_ascii_pattern = r"[^\x00-\x7F]+"
83
+ and_sign_regex = r"&"
84
+ multiple_spaces_regex = r"\s{2,}"
85
+ multiple_new_lines_regex = r"(\r\n|\n)+"
86
  multiple_punctuation_regex = r"(\p{P})\p{P}+"
87
 
88
+
89
+ def initial_clean(texts: pd.Series) -> pd.Series:
90
+ """
91
  This function cleans the text by removing URLs, HTML tags, and non-ASCII characters.
92
+ """
93
  for text in texts:
94
  if not text or pd.isnull(text):
95
  text = ""
96
 
97
  # Normalize unicode characters to decompose any special forms
98
+ normalized_text = unicodedata.normalize("NFKC", text)
99
 
100
  # Replace smart quotes and special punctuation with standard ASCII equivalents
101
  replacements = {
102
+ "‘": "'",
103
+ "’": "'",
104
+ "“": '"',
105
+ "”": '"',
106
+ "–": "-",
107
+ "—": "-",
108
+ "…": "...",
109
+ "•": "*",
110
  }
111
 
112
  # Perform replacements
 
117
 
118
  # Convert to polars Series
119
  texts = pl.Series(texts).str.strip_chars()
120
+
121
  # Define a list of patterns and their replacements
122
  patterns = [
123
+ (multiple_new_lines_regex, " "),
124
+ (r"\r", ""),
125
+ (url_pattern, " "),
126
+ (html_pattern_regex, " "),
127
+ (html_start_pattern_end_dots_regex, " "),
128
+ (non_ascii_pattern, " "),
129
+ (multiple_spaces_regex, " "),
130
  (multiple_punctuation_regex, "${1}"),
131
+ (and_sign_regex, "and"),
132
  ]
133
+
134
  # Apply each regex replacement
135
  for pattern, replacement in patterns:
136
  texts = texts.str.replace_all(pattern, replacement)
137
+
138
  # Convert the series back to a list
139
  texts = texts.to_list()
140
+
141
  return texts
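A minimal usage sketch for initial_clean, with invented input strings:

import pandas as pd

dirty = pd.Series([
    "See  https://example.com for more…",
    "<b>Smart “quotes” &amp; bullets</b>",
])
cleaned = initial_clean(dirty)
# Expect roughly: URLs and HTML tags removed, smart punctuation normalised to
# ASCII, repeated whitespace collapsed and '&' rewritten as 'and'.
print(cleaned)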
142
 
 
 
143
 
144
+ def process_recognizer_result(
145
+ result: RecognizerResult,
146
+ recognizer_result: RecognizerResult,
147
+ data_row: int,
148
+ dictionary_key: int,
149
+ df_dict: Dict[str, List[Any]],
150
+ keys_to_keep: List[str],
151
+ ) -> List[str]:
152
+ output = list()
153
+
154
+ if hasattr(result, "value"):
155
+ text = result.value[data_row]
156
+ else:
157
+ text = ""
158
+
159
+ if isinstance(recognizer_result, list):
160
+ for sub_result in recognizer_result:
161
+ if isinstance(text, str):
162
+ found_text = text[sub_result.start : sub_result.end]
163
+ else:
164
+ found_text = ""
165
+ analysis_explanation = {
166
+ key: sub_result.__dict__[key] for key in keys_to_keep
167
+ }
168
+ analysis_explanation.update(
169
+ {
170
+ "data_row": str(data_row),
171
+ "column": list(df_dict.keys())[dictionary_key],
172
+ "entity": found_text,
173
+ }
174
+ )
175
+ output.append(str(analysis_explanation))
176
+
177
+ return output
178
+
179
 
180
  # Writing decision making process to file
181
+ def generate_decision_process_output(
182
+ analyzer_results: List[DictAnalyzerResult], df_dict: Dict[str, List[Any]]
183
+ ) -> str:
184
  """
185
  Generate a detailed output of the decision process for entity recognition.
186
 
 
197
  str: A string containing the detailed decision process output.
198
  """
199
  decision_process_output = list()
200
+ keys_to_keep = ["entity_type", "start", "end"]
201
 
202
  # Run through each column to analyse for PII
203
  for i, result in enumerate(analyzer_results):
204
 
205
  # If a single result
206
  if isinstance(result, RecognizerResult):
207
+ decision_process_output.extend(
208
+ process_recognizer_result(result, result, 0, i, df_dict, keys_to_keep)
209
+ )
210
 
211
  # If a list of results
212
  elif isinstance(result, list) or isinstance(result, DictAnalyzerResult):
213
  for x, recognizer_result in enumerate(result.recognizer_results):
214
+ decision_process_output.extend(
215
+ process_recognizer_result(
216
+ result, recognizer_result, x, i, df_dict, keys_to_keep
217
+ )
218
+ )
219
 
220
  else:
221
  try:
222
+ decision_process_output.extend(
223
+ process_recognizer_result(
224
+ result, result, 0, i, df_dict, keys_to_keep
225
+ )
226
+ )
227
  except Exception as e:
228
  print(e)
229
 
230
+ decision_process_output_str = "\n".join(decision_process_output)
231
 
232
  return decision_process_output_str
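Each line of the resulting decision log is the string form of a small dictionary; an illustrative (made-up) example:

# Example decision-log line (values are illustrative):
# {'entity_type': 'PERSON', 'start': 0, 'end': 8, 'data_row': '3',
#  'column': 'Case Note', 'entity': 'John Doe'}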
233
 
234
+
235
+ def anon_consistent_names(df: pd.DataFrame) -> pd.DataFrame:
236
  # ## Pick out common names and replace them with the same person value
237
  df_dict = df.to_dict(orient="list")
238
 
239
+ # analyzer = AnalyzerEngine()
240
  batch_analyzer = BatchAnalyzerEngine(analyzer_engine=nlp_analyser)
241
 
242
  analyzer_results = batch_analyzer.analyze_dict(df_dict, language=DEFAULT_LANGUAGE)
 
250
 
251
  # Adjusting the parse_dict function to handle trailing ']'
252
  # Splitting the main data string into individual list strings
253
+ list_strs = data_str[1:-1].split("], [")
254
 
255
  def parse_dict(s):
256
+ s = s.strip("[]") # Removing any surrounding brackets
257
+ items = s.split(", ")
258
  d = {}
259
  for item in items:
260
+ key, value = item.split(": ")
261
+ if key == "score":
262
  d[key] = float(value)
263
+ elif key in ["start", "end"]:
264
  d[key] = int(value)
265
  else:
266
  d[key] = value
 
272
 
273
  for lst_str in list_strs:
274
  # Splitting each list string into individual dictionary strings
275
+ dict_strs = lst_str.split(", type: ")
276
+ dict_strs = [dict_strs[0]] + [
277
+ "type: " + s for s in dict_strs[1:]
278
+ ] # Prepending "type: " back to the split strings
279
+
280
  # Parsing each dictionary string
281
  dicts = [parse_dict(d) for d in dict_strs]
282
  result.append(dicts)
 
286
  for idx, paragraph in enumerate(text):
287
  paragraph_texts = list()
288
  for dictionary in result[idx]:
289
+ if dictionary["type"] == "PERSON":
290
+ paragraph_texts.append(
291
+ paragraph[dictionary["start"] : dictionary["end"]]
292
+ )
293
  names.append(paragraph_texts)
294
 
295
  # Flatten the list of lists and extract unique names
296
  unique_names = list(set(name for sublist in names for name in sublist))
297
+
298
  fake_names = pd.Series(unique_names).apply(fake_first_name)
299
 
300
+ mapping_df = pd.DataFrame(
301
+ data={"Unique names": unique_names, "Fake names": fake_names}
302
+ )
303
 
304
  # Convert mapping dataframe to dictionary, adding word boundaries for full-word match
305
+ name_map = {
306
+ r"\b" + k + r"\b": v
307
+ for k, v in zip(mapping_df["Unique names"], mapping_df["Fake names"])
308
+ }
309
 
310
  name_map
311
 
312
+ scrubbed_df_consistent_names = df.replace(name_map, regex=True)
313
 
314
  scrubbed_df_consistent_names
315
 
316
  return scrubbed_df_consistent_names
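The word-boundary mapping used above behaves like this minimal sketch (names invented for illustration):

import pandas as pd

name_map = {r"\bJohn\b": "Gary", r"\bJo\b": "Priya"}
df = pd.DataFrame({"note": ["John met Jo", "Johnson was away"]})
print(df.replace(name_map, regex=True))
# Roughly:
#   0    Gary met Priya
#   1    Johnson was away   <- \b prevents partial matches inside 'Johnson'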
317
 
318
+
319
  def handle_docx_anonymisation(
320
  file_path: str,
321
  output_folder: str,
 
329
  comprehend_query_number: int,
330
  comprehend_client: BaseClient,
331
  language: Optional[str] = DEFAULT_LANGUAGE,
332
+ nlp_analyser: AnalyzerEngine = nlp_analyser,
333
  ):
334
  """
335
  Anonymises a .docx file by extracting text, processing it, and re-inserting it.
 
337
  Returns:
338
  A tuple containing the output file path and the log file path.
339
  """
340
+
341
  # 1. Load the document and extract text elements
342
  doc = docx.Document(file_path)
343
+ text_elements = (
344
+ list()
345
+ ) # This will store the actual docx objects (paragraphs, cells)
346
+ original_texts = list() # This will store the text from those objects
347
 
348
  paragraph_count = len(doc.paragraphs)
349
 
 
362
  for table in doc.tables:
363
  for row in table.rows:
364
  for cell in row.cells:
365
+ if cell.text.strip(): # Only process non-empty cells
366
  text_elements.append(cell)
367
  original_texts.append(cell.text)
368
+
369
  # If there's no text to process, return early
370
  if not original_texts:
371
  print(f"No text found in {file_path}. Skipping.")
372
  return None, None, 0
373
 
374
  # 2. Convert to a DataFrame for the existing anonymisation script
375
+ df_to_anonymise = pd.DataFrame({"text_to_redact": original_texts})
376
+
377
  # 3. Call the core anonymisation script
378
  anonymised_df, _, decision_log, comprehend_query_number = anonymise_script(
379
  df=df_to_anonymise,
 
387
  chosen_redact_comprehend_entities=chosen_redact_comprehend_entities,
388
  comprehend_query_number=comprehend_query_number,
389
  comprehend_client=comprehend_client,
390
+ nlp_analyser=nlp_analyser,
391
  )
392
 
393
+ anonymised_texts = anonymised_df["text_to_redact"].tolist()
394
 
395
  # 4. Re-insert the anonymised text back into the document objects
396
  for element, new_text in zip(text_elements, anonymised_texts):
 
405
  # 5. Save the redacted document and the log file
406
  base_name = os.path.basename(file_path)
407
  file_name_without_ext = os.path.splitext(base_name)[0]
 
 
 
408
 
409
+ output_docx_path = os.path.join(
410
+ output_folder, f"{file_name_without_ext}_redacted.docx"
411
+ )
412
+ log_file_path = os.path.join(
413
+ output_folder, f"{file_name_without_ext}_redacted_log.txt"
414
+ )
415
 
416
+ output_xlsx_path = os.path.join(
417
+ output_folder, f"{file_name_without_ext}_redacted.csv"
418
+ )
419
+
420
+ anonymised_df.to_csv(output_xlsx_path, encoding="utf-8-sig", index=None)
421
  doc.save(output_docx_path)
422
+
423
  with open(log_file_path, "w", encoding="utf-8-sig") as f:
424
  f.write(decision_log)
425
 
426
  return output_docx_path, log_file_path, output_xlsx_path, comprehend_query_number
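Step 4 writes the anonymised text back into either a paragraph or a table cell; the loop body is elided in this diff, but one common python-docx approach, shown here as a hypothetical helper rather than the repository's own implementation, is:

from docx.table import _Cell

def reinsert_text(element, new_text):
    # Hypothetical helper: table cells expose a .text setter, while
    # paragraphs are updated through their runs to keep formatting.
    if isinstance(element, _Cell):
        element.text = new_text
    elif element.runs:
        element.runs[0].text = new_text
        for run in element.runs[1:]:
            run.text = ""
    else:
        element.add_run(new_text)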
427
 
428
+
429
+ def anonymise_files_with_open_text(
430
+ file_paths: List[str],
431
+ in_text: str,
432
+ anon_strategy: str,
433
+ chosen_cols: List[str],
434
+ chosen_redact_entities: List[str],
435
+ in_allow_list: List[str] = None,
436
+ latest_file_completed: int = 0,
437
+ out_message: list = list(),
438
+ out_file_paths: list = list(),
439
+ log_files_output_paths: list = list(),
440
+ in_excel_sheets: list = list(),
441
+ first_loop_state: bool = False,
442
+ output_folder: str = OUTPUT_FOLDER,
443
+ in_deny_list: list[str] = list(),
444
+ max_fuzzy_spelling_mistakes_num: int = 0,
445
+ pii_identification_method: str = "Local",
446
+ chosen_redact_comprehend_entities: List[str] = list(),
447
+ comprehend_query_number: int = 0,
448
+ aws_access_key_textbox: str = "",
449
+ aws_secret_key_textbox: str = "",
450
+ actual_time_taken_number: float = 0,
451
+ do_initial_clean: bool = DO_INITIAL_TABULAR_DATA_CLEAN,
452
+ language: Optional[str] = None,
453
+ progress: Progress = Progress(track_tqdm=True),
454
+ ):
455
  """
456
  This function anonymises data files based on the provided parameters.
457
 
 
472
  - output_folder (str, optional): The output folder path. Defaults to the global output_folder variable.
473
  - in_deny_list (list[str], optional): A list of specific terms to redact.
474
  - max_fuzzy_spelling_mistakes_num (int, optional): The maximum number of spelling mistakes allowed in a searched phrase for fuzzy matching. Can range from 0-9.
475
+ - pii_identification_method (str, optional): The method to redact personal information. Either 'Local' (spacy model), or 'AWS Comprehend' (AWS Comprehend API).
476
  - chosen_redact_comprehend_entities (List[str]): A list of entity types to redact from files, chosen from the official list from AWS Comprehend service.
477
  - comprehend_query_number (int, optional): A counter tracking the number of queries to AWS Comprehend.
478
  - aws_access_key_textbox (str, optional): AWS access key for account with Textract and Comprehend permissions.
 
482
  - progress (Progress, optional): A Progress object to track progress. Defaults to a Progress object with track_tqdm=True.
483
  - do_initial_clean (bool, optional): Whether to perform an initial cleaning of the text. Defaults to True.
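+
+ Returns a tuple of (output message, output file paths, output file paths, latest file completed index, log file paths, log file paths, time taken, Comprehend query count), matching the final return statement of this function.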
484
  """
485
+
486
  tic = time.perf_counter()
487
  comprehend_client = ""
488
 
489
  # If output folder doesn't end with a forward slash, add one
490
+ if not output_folder.endswith("/"):
491
+ output_folder = output_folder + "/"
492
+
493
  # Use provided language or default
494
  language = language or DEFAULT_LANGUAGE
495
 
 
499
  raise Warning(out_message)
500
 
501
  # If this is the first time around, set variables to 0/blank
502
+ if first_loop_state is True:
503
  latest_file_completed = 0
504
  out_message = list()
505
  out_file_paths = list()
 
509
  if isinstance(out_message, str):
510
  out_message = [out_message]
511
 
512
+ # print("log_files_output_paths:",log_files_output_paths)
513
 
514
  if isinstance(log_files_output_paths, str):
515
  log_files_output_paths = list()
516
 
517
  if not out_file_paths:
518
+ out_file_paths = list()
519
+
520
  if isinstance(in_allow_list, list):
521
  if in_allow_list:
522
  in_allow_list_flat = in_allow_list
 
529
  in_allow_list_flat = list()
530
  else:
531
  in_allow_list_flat = list()
532
+
533
  anon_df = pd.DataFrame()
534
 
535
+ # Try to connect to AWS services directly only if RUN_AWS_FUNCTIONS environmental variable is 1, otherwise an environment variable or direct textbox input is needed.
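+ # Order of preference below: an existing SSO session (when PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS is set), access keys typed into the UI, an SSO session, then the AWS_ACCESS_KEY/AWS_SECRET_KEY environment variables; if none are available an error is raised.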
536
  if pii_identification_method == "AWS Comprehend":
537
  print("Trying to connect to AWS Comprehend service")
538
  if RUN_AWS_FUNCTIONS == "1" and PRIORITISE_SSO_OVER_AWS_ENV_ACCESS_KEYS == "1":
539
  print("Connecting to Comprehend via existing SSO connection")
540
+ comprehend_client = boto3.client("comprehend", region_name=AWS_REGION)
541
  elif aws_access_key_textbox and aws_secret_key_textbox:
542
+ print(
543
+ "Connecting to Comprehend using AWS access key and secret keys from textboxes."
544
+ )
545
  print("aws_access_key_textbox:", aws_access_key_textbox)
546
  print("aws_secret_access_key:", aws_secret_key_textbox)
547
+ comprehend_client = boto3.client(
548
+ "comprehend",
549
+ aws_access_key_id=aws_access_key_textbox,
550
+ aws_secret_access_key=aws_secret_key_textbox,
551
+ )
552
  elif RUN_AWS_FUNCTIONS == "1":
553
  print("Connecting to Comprehend via existing SSO connection")
554
+ comprehend_client = boto3.client("comprehend")
555
  elif AWS_ACCESS_KEY and AWS_SECRET_KEY:
556
  print("Getting Comprehend credentials from environment variables")
557
+ comprehend_client = boto3.client(
558
+ "comprehend",
559
+ aws_access_key_id=AWS_ACCESS_KEY,
560
+ aws_secret_access_key=AWS_SECRET_KEY,
561
+ )
562
  else:
563
  comprehend_client = ""
564
  out_message = "Cannot connect to AWS Comprehend service. Please provide access keys under Textract settings on the Redaction settings tab, or choose another PII identification method."
565
+ raise Exception(out_message)
566
+
567
  # Check if files and text exist
568
  if not file_paths:
569
  if in_text:
570
+ file_paths = ["open_text"]
571
  else:
572
  out_message = "Please enter text or a file to redact."
573
  raise Exception(out_message)
 
579
  out_message = f"Number of files to anonymise is greater than {MAX_SIMULTANEOUS_FILES}. Please submit a smaller number of files."
580
  print(out_message)
581
  raise Exception(out_message)
582
+
583
  # If we have already redacted the last file, return the input out_message and file list to the relevant components
584
  if latest_file_completed >= len(file_paths):
585
+ print("Last file reached") # , returning files:", str(latest_file_completed))
586
  # Set to a very high number so as not to mess with subsequent file processing by the user
587
+ # latest_file_completed = 99
588
+ final_out_message = "\n".join(out_message)
589
+ return (
590
+ final_out_message,
591
+ out_file_paths,
592
+ out_file_paths,
593
+ latest_file_completed,
594
+ log_files_output_paths,
595
+ log_files_output_paths,
596
+ actual_time_taken_number,
597
+ comprehend_query_number,
598
+ )
599
+
600
  file_path_loop = [file_paths[int(latest_file_completed)]]
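+ # Only one file is processed per call; latest_file_completed tracks how far through the list previous calls have got.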
601
 
602
+ for anon_file in progress.tqdm(
603
+ file_path_loop, desc="Anonymising files", unit="files"
604
+ ):
605
 
606
  # Get a string file path
607
+ if isinstance(anon_file, str):
608
+ file_path = anon_file
609
+ else:
610
+ file_path = anon_file
611
 
612
+ if anon_file == "open_text":
613
+ anon_df = pd.DataFrame(data={"text": [in_text]})
614
+ chosen_cols = ["text"]
615
  out_file_part = anon_file
616
  sheet_name = ""
617
  file_type = ""
618
 
619
+ (
620
+ out_file_paths,
621
+ out_message,
622
+ key_string,
623
+ log_files_output_paths,
624
+ comprehend_query_number,
625
+ ) = tabular_anonymise_wrapper_func(
626
+ file_path,
627
+ anon_df,
628
+ chosen_cols,
629
+ out_file_paths,
630
+ out_file_part,
631
+ out_message,
632
+ sheet_name,
633
+ anon_strategy,
634
+ language,
635
+ chosen_redact_entities,
636
+ in_allow_list,
637
+ file_type,
638
+ "",
639
+ log_files_output_paths,
640
+ in_deny_list,
641
+ max_fuzzy_spelling_mistakes_num,
642
+ pii_identification_method,
+ language,
643
+ chosen_redact_comprehend_entities,
644
+ comprehend_query_number,
645
+ comprehend_client,
646
+ output_folder=output_folder,
647
+ do_initial_clean=do_initial_clean,
648
+ )
649
  else:
650
  # If file is an xlsx, we are going to run through all the Excel sheets to anonymise them separately.
651
  file_type = detect_file_type(file_path)
 
653
 
654
  out_file_part = get_file_name_without_type(file_path)
655
 
656
+ if file_type == "docx":
657
+ output_path, log_path, output_xlsx_path, comprehend_query_number = (
658
+ handle_docx_anonymisation(
659
+ file_path=file_path,
660
+ output_folder=output_folder,
661
+ anon_strategy=anon_strategy,
662
+ chosen_redact_entities=chosen_redact_entities,
663
+ in_allow_list=in_allow_list_flat,
664
+ in_deny_list=in_deny_list,
665
+ max_fuzzy_spelling_mistakes_num=max_fuzzy_spelling_mistakes_num,
666
+ pii_identification_method=pii_identification_method,
667
+ chosen_redact_comprehend_entities=chosen_redact_comprehend_entities,
668
+ comprehend_query_number=comprehend_query_number,
669
+ comprehend_client=comprehend_client,
670
+ language=language,
671
+ )
672
  )
673
  if output_path:
674
  out_file_paths.append(output_path)
 
676
  out_file_paths.append(output_xlsx_path)
677
  if log_path:
678
  log_files_output_paths.append(log_path)
679
+
680
+ elif file_type == "xlsx":
681
  print("Running through all xlsx sheets")
682
+ # anon_xlsx = pd.ExcelFile(anon_file)
683
  if not in_excel_sheets:
684
+ out_message.append(
685
+ "No Excel sheets selected. Please select at least one to anonymise."
686
+ )
687
  continue
688
 
689
  # Create xlsx file:
690
+ anon_xlsx = pd.ExcelFile(file_path)
691
+ anon_xlsx_export_file_name = (
692
+ output_folder + out_file_part + "_redacted.xlsx"
693
+ )
694
 
695
  # Iterate through the sheet names
696
+ for sheet_name in progress.tqdm(
697
+ in_excel_sheets, desc="Anonymising sheets", unit="sheets"
698
+ ):
699
  # Read each sheet into a DataFrame
700
  if sheet_name not in anon_xlsx.sheet_names:
701
  continue
702
 
703
  anon_df = pd.read_excel(file_path, sheet_name=sheet_name)
704
 
705
+ (
706
+ out_file_paths,
707
+ out_message,
708
+ key_string,
709
+ log_files_output_paths,
710
+ comprehend_query_number,
711
+ ) = tabular_anonymise_wrapper_func(
712
+ anon_file,
713
+ anon_df,
714
+ chosen_cols,
715
+ out_file_paths,
716
+ out_file_part,
717
+ out_message,
718
+ sheet_name,
719
+ anon_strategy,
720
+ language,
721
+ chosen_redact_entities,
722
+ in_allow_list,
723
+ file_type,
724
+ anon_xlsx_export_file_name,
725
+ log_files_output_paths,
726
+ in_deny_list,
727
+ max_fuzzy_spelling_mistakes_num,
728
+ pii_identification_method,
729
+ language,
730
+ chosen_redact_comprehend_entities,
731
+ comprehend_query_number,
732
+ comprehend_client,
733
+ output_folder=output_folder,
734
+ do_initial_clean=do_initial_clean,
735
+ )
736
 
737
  else:
738
  sheet_name = ""
739
  anon_df = read_file(file_path)
740
  out_file_part = get_file_name_without_type(file_path)
741
 
742
+ (
743
+ out_file_paths,
744
+ out_message,
745
+ key_string,
746
+ log_files_output_paths,
747
+ comprehend_query_number,
748
+ ) = tabular_anonymise_wrapper_func(
749
+ anon_file,
750
+ anon_df,
751
+ chosen_cols,
752
+ out_file_paths,
753
+ out_file_part,
754
+ out_message,
755
+ sheet_name,
756
+ anon_strategy,
757
+ language,
758
+ chosen_redact_entities,
759
+ in_allow_list,
760
+ file_type,
761
+ "",
762
+ log_files_output_paths,
763
+ in_deny_list,
764
+ max_fuzzy_spelling_mistakes_num,
765
+ pii_identification_method,
766
+ language,
767
+ chosen_redact_comprehend_entities,
768
+ comprehend_query_number,
769
+ comprehend_client,
770
+ output_folder=output_folder,
771
+ do_initial_clean=do_initial_clean,
772
+ )
773
 
774
  # Increase latest file completed count unless we are at the last file
775
  if latest_file_completed != len(file_paths):
 
785
 
786
  if isinstance(out_message, str):
787
  out_message = [out_message]
 
 
788
 
789
+ out_message.append(
790
+ "Anonymisation of file '" + out_file_part + "' successfully completed in"
791
+ )
792
+
793
+ out_message_out = "\n".join(out_message)
794
  out_message_out = out_message_out + " " + out_time
795
 
796
  if anon_strategy == "encrypt":
797
+ out_message_out = out_message_out + ". Your decryption key is " + key_string
798
+
799
+ out_message_out = (
800
+ out_message_out
801
+ + "\n\nGo to to the Redaction settings tab to see redaction logs. Please give feedback on the results below to help improve this app."
802
+ )
803
+
804
+ out_message_out = re.sub(r"^\n+|^\. ", "", out_message_out).strip()
805
+
806
+ return (
807
+ out_message_out,
808
+ out_file_paths,
809
+ out_file_paths,
810
+ latest_file_completed,
811
+ log_files_output_paths,
812
+ log_files_output_paths,
813
+ actual_time_taken_number,
814
+ comprehend_query_number,
815
+ )
816
 
 
 
 
817
 
818
  def tabular_anonymise_wrapper_func(
819
+ anon_file: str,
820
+ anon_df: pd.DataFrame,
821
+ chosen_cols: List[str],
822
+ out_file_paths: List[str],
823
+ out_file_part: str,
824
+ out_message: str,
825
+ excel_sheet_name: str,
826
+ anon_strategy: str,
827
  language: str,
828
+ chosen_redact_entities: List[str],
829
+ in_allow_list: List[str],
830
+ file_type: str,
831
+ anon_xlsx_export_file_name: str,
832
  log_files_output_paths: List[str],
833
+ in_deny_list: List[str] = list(),
834
+ max_fuzzy_spelling_mistakes_num: int = 0,
835
+ pii_identification_method: str = "Local",
836
  comprehend_language: Optional[str] = None,
837
+ chosen_redact_comprehend_entities: List[str] = list(),
838
+ comprehend_query_number: int = 0,
839
+ comprehend_client: botocore.client.BaseClient = "",
840
  nlp_analyser: AnalyzerEngine = nlp_analyser,
841
  output_folder: str = OUTPUT_FOLDER,
842
+ do_initial_clean: bool = DO_INITIAL_TABULAR_DATA_CLEAN,
843
  ):
844
  """
845
  This function wraps the anonymisation process for a given dataframe. It filters the dataframe based on chosen columns, applies the specified anonymisation strategy using the anonymise_script function, and exports the anonymised data to a file.
 
861
  - log_files_output_paths: A list of paths where the log files will be saved.
862
  - in_deny_list: List of specific terms to remove from the data.
863
  - max_fuzzy_spelling_mistakes_num (int, optional): The maximum number of spelling mistakes allowed in a searched phrase for fuzzy matching. Can range from 0-9.
864
+ - pii_identification_method (str, optional): The method to redact personal information. Either 'Local' (spacy model), or 'AWS Comprehend' (AWS Comprehend API).
865
  - chosen_redact_comprehend_entities (List[str]): A list of entity types to redact from files, chosen from the official list from AWS Comprehend service.
866
  - comprehend_query_number (int, optional): A counter tracking the number of queries to AWS Comprehend.
867
+ - comprehend_client (optional): The client object from AWS containing a client connection to AWS Comprehend if that option is chosen on the first tab.
868
  - output_folder: The folder where the anonymized files will be saved. Defaults to the 'output_folder' variable.
869
  - do_initial_clean (bool, optional): Whether to perform an initial cleaning of the text. Defaults to True.
870
  """
871
+
872
  def check_lists(list1, list2):
873
+ return any(string in list2 for string in list1)
874
+
875
  def get_common_strings(list1, list2):
876
  """
877
  Finds the common strings between two lists.
 
890
  return common_strings
891
 
892
  if pii_identification_method == "AWS Comprehend" and comprehend_client == "":
893
+ raise Exception(
894
+ "Connection to AWS Comprehend service not found, please check connection details."
895
+ )
896
+
897
  # Check for chosen col, skip file if not found
898
  all_cols_original_order = list(anon_df.columns)
899
 
900
  any_cols_found = check_lists(chosen_cols, all_cols_original_order)
901
 
902
+ if any_cols_found is False:
903
  out_message = "No chosen columns found in dataframe: " + out_file_part
904
  key_string = ""
905
  print(out_message)
906
+ return (
907
+ out_file_paths,
908
+ out_message,
909
+ key_string,
910
+ log_files_output_paths,
911
+ comprehend_query_number,
912
+ )
913
  else:
914
+ chosen_cols_in_anon_df = get_common_strings(
915
+ chosen_cols, all_cols_original_order
916
+ )
917
 
918
  # Split dataframe to keep only selected columns
919
+ # print("Remaining columns to redact:", chosen_cols_in_anon_df)
920
+
921
  if not anon_df.index.is_unique:
922
  anon_df = anon_df.reset_index(drop=True)
923
 
924
  anon_df_part = anon_df[chosen_cols_in_anon_df]
925
+ anon_df_remain = anon_df.drop(chosen_cols_in_anon_df, axis=1)
926
 
927
  row_count = anon_df_part.shape[0]
928
+
929
  if row_count > MAX_TABLE_ROWS:
930
  out_message = f"Number of rows in dataframe is greater than {MAX_TABLE_ROWS}. Please submit a smaller dataframe."
931
  print(out_message)
932
  raise Exception(out_message)
933
+
934
  column_count = anon_df_part.shape[1]
935
+
936
  if column_count > MAX_TABLE_COLUMNS:
937
  out_message = f"Number of columns in dataframe is greater than {MAX_TABLE_COLUMNS}. Please submit a smaller dataframe."
938
  print(out_message)
939
  raise Exception(out_message)
940
 
941
  # Anonymise the selected columns
942
+ (
943
+ anon_df_part_out,
944
+ key_string,
945
+ decision_process_output_str,
946
+ comprehend_query_number,
947
+ ) = anonymise_script(
948
+ anon_df_part,
949
+ anon_strategy,
950
+ language,
951
+ chosen_redact_entities,
952
+ in_allow_list,
953
+ in_deny_list,
954
+ max_fuzzy_spelling_mistakes_num,
955
+ pii_identification_method,
956
+ chosen_redact_comprehend_entities,
957
+ comprehend_query_number,
958
+ comprehend_client,
959
+ nlp_analyser=nlp_analyser,
960
+ do_initial_clean=do_initial_clean,
961
+ )
962
 
963
  anon_df_part_out.replace("^nan$", "", regex=True, inplace=True)
964
 
965
  # Rejoin the dataframe together
966
+ anon_df_out = pd.concat([anon_df_part_out, anon_df_remain], axis=1)
967
  anon_df_out = anon_df_out[all_cols_original_order]
968
+
969
  # Export file
970
  # Rename anonymisation strategy for file path naming
971
+ if anon_strategy == "replace with 'REDACTED'":
972
+ anon_strat_txt = "redact_replace"
973
+ elif anon_strategy == "replace with <ENTITY_NAME>":
974
+ anon_strat_txt = "redact_entity_type"
975
+ elif anon_strategy == "redact completely":
976
+ anon_strat_txt = "redact_remove"
977
+ else:
978
+ anon_strat_txt = anon_strategy
979
 
980
  # If the file is an xlsx, add a new sheet to the existing xlsx. Otherwise, write to csv
981
+ if file_type == "xlsx":
982
 
983
  anon_export_file_name = anon_xlsx_export_file_name
984
 
 
989
  wb.save(anon_xlsx_export_file_name)
990
 
991
  # Create a Pandas Excel writer using XlsxWriter as the engine.
992
+ with pd.ExcelWriter(
993
+ anon_xlsx_export_file_name,
994
+ engine="openpyxl",
995
+ mode="a",
996
+ if_sheet_exists="replace",
997
+ ) as writer:
998
  # Write each DataFrame to a different worksheet.
999
  anon_df_out.to_excel(writer, sheet_name=excel_sheet_name, index=None)
1000
 
1001
+ decision_process_log_output_file = (
1002
+ anon_xlsx_export_file_name
1003
+ + "_"
1004
+ + excel_sheet_name
1005
+ + "_decision_process_output.txt"
1006
+ )
1007
  with open(decision_process_log_output_file, "w") as f:
1008
  f.write(decision_process_output_str)
1009
 
1010
  else:
1011
+ anon_export_file_name = (
1012
+ output_folder + out_file_part + "_anon_" + anon_strat_txt + ".csv"
1013
+ )
1014
+ anon_df_out.to_csv(anon_export_file_name, index=None, encoding="utf-8-sig")
1015
+
1016
+ decision_process_log_output_file = (
1017
+ anon_export_file_name + "_decision_process_output.txt"
1018
+ )
1019
  with open(decision_process_log_output_file, "w") as f:
1020
  f.write(decision_process_output_str)
1021
 
 
1026
  out_file_paths = list(set(out_file_paths))
1027
 
1028
  # Print result text to output text box if just anonymising open text
1029
+ if anon_file == "open_text":
1030
+ out_message = ["'" + anon_df_out["text"][0] + "'"]
1031
+
1032
+ return (
1033
+ out_file_paths,
1034
+ out_message,
1035
+ key_string,
1036
+ log_files_output_paths,
1037
+ comprehend_query_number,
1038
+ )
1039
+
1040
+
1041
+ def anonymise_script(
1042
+ df: pd.DataFrame,
1043
+ anon_strategy: str,
1044
+ language: str,
1045
+ chosen_redact_entities: List[str],
1046
+ in_allow_list: List[str] = list(),
1047
+ in_deny_list: List[str] = list(),
1048
+ max_fuzzy_spelling_mistakes_num: int = 0,
1049
+ pii_identification_method: str = "Local",
1050
+ chosen_redact_comprehend_entities: List[str] = list(),
1051
+ comprehend_query_number: int = 0,
1052
+ comprehend_client: botocore.client.BaseClient = "",
1053
+ custom_entities: List[str] = custom_entities,
1054
+ nlp_analyser: AnalyzerEngine = nlp_analyser,
1055
+ do_initial_clean: bool = DO_INITIAL_TABULAR_DATA_CLEAN,
1056
+ progress: Progress = Progress(track_tqdm=True),
1057
+ ):
1058
+ """
1059
  Conduct anonymisation of a dataframe using Presidio and/or AWS Comprehend if chosen.
1060
 
1061
  Args:
 
1074
  nlp_analyser (AnalyzerEngine, optional): The Presidio AnalyzerEngine instance to use. Defaults to `nlp_analyser`.
1075
  do_initial_clean (bool, optional): Whether to perform an initial cleaning of the text. Defaults to True.
1076
  progress (Progress, optional): Gradio Progress object for tracking progress. Defaults to Progress(track_tqdm=False).
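+
+ A minimal usage sketch (assumes the default local Presidio pipeline and its standard entity labels, e.g. PERSON and PHONE_NUMBER):
+ scrubbed_df, key_string, decision_log, n_queries = anonymise_script(
+ pd.DataFrame({"notes": ["Call John Smith on 07700 900123"]}),
+ anon_strategy="replace with 'REDACTED'",
+ language="en",
+ chosen_redact_entities=["PERSON", "PHONE_NUMBER"],
+ )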
1077
+ """
1078
 
1079
  print("Identifying personal information")
1080
  analyse_tic = time.perf_counter()
1081
 
1082
  # Initialize analyzer_results as an empty dictionary to store results by column
1083
  results_by_column = dict()
1084
+ key_string = ""
1085
 
1086
  if isinstance(in_allow_list, list):
1087
  if in_allow_list:
 
1100
  try:
1101
  if language != "en":
1102
  progress(0.1, desc=f"Loading spaCy model for {language}")
1103
+
1104
  load_spacy_model(language)
1105
+
1106
  except Exception as e:
1107
  out_message = f"Error downloading language packs for {language}: {e}"
1108
  print(out_message)
1109
  raise Exception(out_message)
1110
+
1111
  # Try updating the supported languages for the spacy analyser
1112
  try:
1113
  nlp_analyser = create_nlp_analyser(language, existing_nlp_analyser=nlp_analyser)
1114
  # Check list of nlp_analyser recognisers and languages
1115
  if language != "en":
1116
+ gr.Info(
1117
+ f"Language: {language} only supports the following entity detection: {str(nlp_analyser.registry.get_supported_entities(languages=[language]))}"
1118
+ )
1119
 
1120
  except Exception as e:
1121
  out_message = f"Error creating nlp_analyser for {language}: {e}"
 
1132
  # Sort the strings in order from the longest string to the shortest
1133
  in_deny_list = sorted(in_deny_list, key=len, reverse=True)
1134
 
1135
+ if in_deny_list:
1136
  nlp_analyser.registry.remove_recognizer("CUSTOM")
1137
  new_custom_recogniser = custom_word_list_recogniser(in_deny_list)
1138
  nlp_analyser.registry.add_recognizer(new_custom_recogniser)
1139
 
1140
  nlp_analyser.registry.remove_recognizer("CustomWordFuzzyRecognizer")
1141
+ new_custom_fuzzy_recogniser = CustomWordFuzzyRecognizer(
1142
+ supported_entities=["CUSTOM_FUZZY"],
1143
+ custom_list=in_deny_list,
1144
+ spelling_mistakes_max=max_fuzzy_spelling_mistakes_num,
1145
+ search_whole_phrase=True,
1146
+ )
1147
  nlp_analyser.registry.add_recognizer(new_custom_fuzzy_recogniser)
1148
 
1149
+ # analyzer = nlp_analyser #AnalyzerEngine()
1150
  batch_analyzer = BatchAnalyzerEngine(analyzer_engine=nlp_analyser)
1151
+ anonymizer = (
1152
+ AnonymizerEngine()
1153
+ ) # conflict_resolution=ConflictResolutionStrategy.MERGE_SIMILAR_OR_CONTAINED)
1154
+ batch_anonymizer = BatchAnonymizerEngine(anonymizer_engine=anonymizer)
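+ # The batch analyzer finds PII spans for each column; the batch anonymizer then rewrites those spans using the operator config chosen further down.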
1155
  analyzer_results = list()
1156
 
1157
  if do_initial_clean:
1158
  progress(0.2, desc="Cleaning text")
1159
+ for col in progress.tqdm(df.columns, desc="Cleaning text", unit="Columns"):
1160
  df[col] = initial_clean(df[col])
1161
 
1162
  # DataFrame to dict
1163
+ df_dict = df.to_dict(orient="list")
1164
 
1165
  if pii_identification_method == "Local":
1166
 
1167
  # Use custom analyzer to be able to track progress with Gradio
1168
+ custom_results = analyze_dict(
1169
+ batch_analyzer,
1170
+ df_dict,
1171
+ language=language,
1172
+ entities=chosen_redact_entities,
1173
+ score_threshold=score_threshold,
1174
+ return_decision_process=True,
1175
+ allow_list=in_allow_list_flat,
1176
+ )
1177
+
1178
  # Initialize results_by_column with custom entity results
1179
  for result in custom_results:
1180
  results_by_column[result.key] = result
 
1184
 
1185
  # AWS Comprehend calls
1186
  elif pii_identification_method == "AWS Comprehend" and comprehend_client:
1187
+
1188
  # Only run Local anonymisation for entities that are not covered by AWS Comprehend
1189
  if custom_entities:
1190
  custom_redact_entities = [
1191
+ entity
1192
+ for entity in chosen_redact_comprehend_entities
1193
  if entity in custom_entities
1194
  ]
1195
  if custom_redact_entities:
1196
  # Get results from analyze_dict
1197
+ custom_results = analyze_dict(
1198
+ batch_analyzer,
1199
+ df_dict,
1200
+ language=language,
1201
+ entities=custom_redact_entities,
1202
+ score_threshold=score_threshold,
1203
+ return_decision_process=True,
1204
+ allow_list=in_allow_list_flat,
1205
+ )
1206
+
1207
  # Initialize results_by_column with custom entity results
1208
  for result in custom_results:
1209
  results_by_column[result.key] = result
 
1212
  retry_delay = 3
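+ # Each detect_pii_entities call below is retried up to max_retries times, sleeping retry_delay seconds between attempts.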
1213
 
1214
  # Process each text column in the dictionary
1215
+ for column_name, texts in progress.tqdm(
1216
+ df_dict.items(), desc="Querying AWS Comprehend service.", unit="Columns"
1217
+ ):
1218
  # Get or create DictAnalyzerResult for this column
1219
  if column_name in results_by_column:
1220
  column_results = results_by_column[column_name]
1221
  else:
1222
  column_results = DictAnalyzerResult(
1223
+ recognizer_results=[[] for _ in texts], key=column_name, value=texts
 
 
1224
  )
1225
 
1226
  # Process each text in the column
1227
+ for text_idx, text in progress.tqdm(
1228
+ enumerate(texts), desc="Querying AWS Comprehend service.", unit="Row"
1229
+ ):
1230
 
1231
  for attempt in range(max_retries):
1232
  try:
1233
  response = comprehend_client.detect_pii_entities(
1234
+ Text=str(text), LanguageCode=language
 
1235
  )
1236
 
1237
  comprehend_query_number += 1
1238
 
1239
  # Add all entities from this text to the column's recognizer_results
1240
  for entity in response["Entities"]:
1241
+ if (
1242
+ entity.get("Type")
1243
+ not in chosen_redact_comprehend_entities
1244
+ ):
1245
  continue
1246
 
1247
  recognizer_result = RecognizerResult(
1248
  entity_type=entity["Type"],
1249
  start=entity["BeginOffset"],
1250
  end=entity["EndOffset"],
1251
+ score=entity["Score"],
1252
  )
1253
+ column_results.recognizer_results[text_idx].append(
1254
+ recognizer_result
1255
+ )
1256
+
1257
  break # Success, exit retry loop
1258
+
1259
  except Exception as e:
1260
  if attempt == max_retries - 1:
1261
+ print(
1262
+ f"AWS Comprehend calls failed for text: {text[:100]}... due to",
1263
+ e,
1264
+ )
1265
  raise
1266
  time.sleep(retry_delay)
1267
 
 
1272
  analyzer_results = list(results_by_column.values())
1273
 
1274
  elif (pii_identification_method == "AWS Comprehend") & (not comprehend_client):
1275
+ raise ("Unable to redact, Comprehend connection details not found.")
1276
+
1277
  else:
1278
  print("Unable to redact.")
1279
 
1280
  # Usage in the main function:
1281
+ decision_process_output_str = generate_decision_process_output(
1282
+ analyzer_results, df_dict
1283
+ )
1284
 
1285
  analyse_toc = time.perf_counter()
1286
+ analyse_time_out = (
1287
+ f"Analysing the text took {analyse_toc - analyse_tic:0.1f} seconds."
1288
+ )
1289
  print(analyse_time_out)
1290
 
1291
  # Set up the anonymization configuration WITHOUT DATE_TIME
1292
+ simple_replace_config = {
1293
+ "DEFAULT": OperatorConfig("replace", {"new_value": "REDACTED"})
1294
+ }
1295
+ replace_config = {"DEFAULT": OperatorConfig("replace")}
1296
+ redact_config = {"DEFAULT": OperatorConfig("redact")}
1297
+ hash_config = {"DEFAULT": OperatorConfig("hash")}
1298
+ mask_config = {
1299
+ "DEFAULT": OperatorConfig("mask", {
1300
+ "masking_char": "*",
1301
+ "chars_to_mask": 100,
1302
+ "from_end": True
1303
+ })
1304
+ }
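+ # e.g. with the "mask" strategy a detected email such as "john.doe@example.com" becomes a same-length string of "*" (up to 100 characters are masked, working from the end).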
1305
+ people_encrypt_config = {
1306
+ "PERSON": OperatorConfig("encrypt", {"key": key_string})
1307
+ } # The encryption is using AES cypher in CBC mode and requires a cryptographic key as an input for both the encryption and the decryption.
1308
+ fake_first_name_config = {
1309
+ "PERSON": OperatorConfig("custom", {"lambda": fake_first_name})
1310
+ }
1311
+
1312
+ if anon_strategy == "replace with 'REDACTED'":
1313
+ chosen_mask_config = simple_replace_config
1314
+ elif anon_strategy == "replace_redacted":
1315
+ chosen_mask_config = simple_replace_config
1316
+ elif anon_strategy == "replace with <ENTITY_NAME>":
1317
+ chosen_mask_config = replace_config
1318
+ elif anon_strategy == "entity_type":
1319
+ chosen_mask_config = replace_config
1320
+ elif anon_strategy == "redact completely":
1321
+ chosen_mask_config = redact_config
1322
+ elif anon_strategy == "redact":
1323
+ chosen_mask_config = redact_config
1324
+ elif anon_strategy == "hash":
1325
+ chosen_mask_config = hash_config
1326
+ elif anon_strategy == "mask":
1327
+ chosen_mask_config = mask_config
1328
+ elif anon_strategy == "encrypt":
1329
  chosen_mask_config = people_encrypt_config
1330
  key = secrets.token_bytes(16) # 128 bits = 16 bytes
1331
+ key_string = base64.b64encode(key).decode("utf-8")
1332
+
1333
  # Now inject the key into the operator config
1334
  for entity, operator in chosen_mask_config.items():
1335
  if operator.operator_name == "encrypt":
1336
  operator.params = {"key": key_string}
1337
+ elif anon_strategy == "fake_first_name":
1338
+ chosen_mask_config = fake_first_name_config
1339
  else:
1340
  print("Anonymisation strategy not found. Redacting completely by default.")
1341
+ chosen_mask_config = redact_config # Redact completely by default
 
 
 
1342
 
1343
+ combined_config = {**chosen_mask_config}
1344
 
1345
+ anonymizer_results = batch_anonymizer.anonymize_dict(
1346
+ analyzer_results, operators=combined_config)
1347
 
1348
  scrubbed_df = pd.DataFrame(anonymizer_results)
1349
+
1350
+ return scrubbed_df, key_string, decision_process_output_str, comprehend_query_number
tools/file_conversion.py CHANGED
The diff for this file is too large to render. See raw diff
 
tools/file_redaction.py CHANGED
The diff for this file is too large to render. See raw diff
 
tools/find_duplicate_pages.py CHANGED
@@ -1,24 +1,30 @@
1
- import pandas as pd
2
  import os
3
  import re
4
  import time
5
- from sklearn.feature_extraction.text import TfidfVectorizer
6
- from sklearn.metrics.pairwise import cosine_similarity
7
- from typing import List, Tuple, Optional, Dict, Union
8
  from collections import defaultdict
 
 
 
9
  import gradio as gr
 
10
  from gradio import Progress
11
- from pathlib import Path
12
- from typing import List
13
- from tools.helper_functions import OUTPUT_FOLDER
14
  from tools.config import MAX_SIMULTANEOUS_FILES
15
- from tools.file_conversion import redact_whole_pymupdf_page, convert_annotation_data_to_dataframe, fill_missing_box_ids_each_box
 
 
 
 
 
16
  from tools.load_spacy_model_custom_recognisers import nlp
17
 
18
- number_of_zeros_to_add_to_index = 7 # Number of zeroes to add between page number and line numbers to get a unique page/line index value
19
  ID_MULTIPLIER = 100000
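+ # e.g. page 12, line 345 -> 12 * 100000 + 345 = 1200345, giving a unique, sortable page/line ID (assumes fewer than 100000 lines per page).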
20
  # Define the set of punctuation characters for efficient lookup
21
- PUNCTUATION_TO_STRIP = {'.', ',', '?', '!', ':', ';'}
 
22
 
23
  def split_text_with_punctuation(text: str) -> List[str]:
24
  """
@@ -29,84 +35,89 @@ def split_text_with_punctuation(text: str) -> List[str]:
29
  # 1. A sequence of one or more punctuation marks `[.,?!:;]+`
30
  # 2. OR a sequence of one or more characters that are NOT punctuation or whitespace `[^.,?!:;\s]+`
31
  pattern = re.compile(r"([.,?!:;]+|[^.,?!:;\s]+)")
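+ # e.g. "Hello, world!" -> ["Hello", ",", "world", "!"]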
32
-
33
  final_list = []
34
  # We first split by whitespace to handle sentences correctly
35
  for word in text.split():
36
  # Then, for each whitespace-separated word, we tokenize it further
37
  final_list.extend(pattern.findall(word))
38
-
39
  return final_list
40
 
 
41
  def extract_indices_from_page_ranges(
42
  results_df: pd.DataFrame,
43
- start_col: str = 'Page2_Start_Page',
44
- end_col: str = 'Page2_End_Page',
45
- modulo_divisor_number_of_zeros: int = number_of_zeros_to_add_to_index, # Search for number of added
46
- converted_index: bool = False # Has the index been converted to the page_no + 0000 + line number format that needs the modulo divisor to convert back?
47
  ) -> List[int]:
48
  all_indices = set()
49
- modulo_divisor = int("1" + modulo_divisor_number_of_zeros*"0")
50
 
51
  for _, row in results_df.iterrows():
52
  start_page = row[start_col]
53
  end_page = row[end_col]
54
  for encoded_page_id in range(start_page, end_page + 1):
55
- if converted_index == True:
56
- original_page, original_index = _parse_page_line_id(encoded_page_id)#(encoded_page_id % modulo_divisor) - 1
 
 
57
  else:
58
  original_index = encoded_page_id
59
 
60
  all_indices.add(original_index)
61
  return sorted(list(all_indices))
62
 
 
63
  def punctuation_at_word_text_end(word_level_df_orig: pd.DataFrame) -> bool:
64
  """
65
- Check the first 1000 rows of word_level_df_orig to see if any of the strings
66
  in 'word_text' end with a full stop '.', exclamation mark '!', or question mark '?',
67
  for strings that do not contain these characters alone.
68
-
69
  Args:
70
  word_level_df_orig (pd.DataFrame): DataFrame containing word-level OCR data with 'word_text' column
71
-
72
  Returns:
73
  bool: True if any strings end with punctuation marks, False otherwise
74
  """
75
  # Get the first 1000 rows or all rows if less than 1000
76
  sample_df = word_level_df_orig.head(1000)
77
-
78
  # Check if 'word_text' column exists
79
- if 'word_text' not in sample_df.columns:
80
  return False
81
-
82
  # Define punctuation marks to check for
83
- punctuation_marks = ['.', '!', '?']
84
-
85
  # Check each word_text string
86
- for word_text in sample_df['word_text']:
87
  if pd.isna(word_text) or not isinstance(word_text, str):
88
  continue
89
-
90
  # Skip strings that contain only punctuation marks
91
  if word_text.strip() in punctuation_marks:
92
  continue
93
-
94
  # Check if the string ends with any of the punctuation marks
95
  if any(word_text.rstrip().endswith(punct) for punct in punctuation_marks):
96
  return True
97
-
98
  return False
99
 
 
100
  def run_full_search_and_analysis(
101
  search_query_text: str,
102
  word_level_df_orig: pd.DataFrame,
103
  similarity_threshold: float = 1,
104
- combine_pages: bool = False,
105
  min_word_count: int = 1,
106
  min_consecutive_pages: int = 1,
107
  greedy_match: bool = True,
108
  remake_index: bool = False,
109
- progress=gr.Progress(track_tqdm=True)
110
  ):
111
  """
112
  This function orchestrates the entire pipeline for finding duplicate pages based on a user's search query. It takes in the search query text, the original word-level OCR data, and various parameters to control the analysis. The function then:
@@ -120,7 +131,7 @@ def run_full_search_and_analysis(
120
  - search_query_text (str): The text entered by the user to search for in the OCR data.
121
  - word_level_df_orig (pd.DataFrame): The original DataFrame containing word-level OCR data.
122
  - similarity_threshold (float, optional): The minimum similarity score required for two pages to be considered duplicates. Defaults to 1.
123
- - combine_pages (bool, optional): A flag indicating whether to combine text from the same page number within a file. Defaults to False.
124
  - min_word_count (int, optional): The minimum number of words required for a page to be considered in the analysis. Defaults to 1.
125
  - min_consecutive_pages (int, optional): The minimum number of consecutive pages required to be considered a match. Defaults to 1.
126
  - greedy_match (bool, optional): A flag indicating whether to use a greedy strategy for matching consecutive pages. Defaults to True.
@@ -133,25 +144,32 @@ def run_full_search_and_analysis(
133
  if len(search_query_text) > 100:
134
  raise Warning("Please use a search query with at less than 100 characters.")
135
 
136
- if punctuation_at_word_text_end(word_level_df_orig) == True: do_punctuation_split = False
137
- else: do_punctuation_split = True
 
 
138
 
139
  # Step 1: Process the user's search query string
140
- search_query_data, query_word_length = create_dataframe_from_string(search_query_text, file_name="user_search_query", split_words=True, split_punctuation=do_punctuation_split)
 
 
 
 
 
141
  if not search_query_data:
142
  # Handle case where user submits an empty search string
143
- raise Warning("Could not convert search string to required format")
144
 
145
  if query_word_length > 25:
146
  # Handle case where user submits an empty search string
147
- raise Warning("Please use a query with less than 25 words")
148
 
149
  # Overwrite min_consecutive_pages with the search string length
150
  min_consecutive_pages = query_word_length
151
-
152
  # Create word index from reference table
153
  word_level_df_orig["index"] = word_level_df_orig.index
154
- word_level_df = word_level_df_orig.copy()
155
 
156
  # Step 2: Process the main word-level OCR DataFrame
157
  word_level_data = convert_word_level_df(word_level_df, file_name="source_document")
@@ -160,13 +178,13 @@ def run_full_search_and_analysis(
160
  all_data_to_process = search_query_data + word_level_data
161
  if not all_data_to_process:
162
  raise gr.Error("No data to process. Please check your inputs.")
163
-
164
  # Step 4: Run the combination logic
165
  combined_df, _, full_out_ocr_df = combine_ocr_dataframes(
166
  input_data=all_data_to_process,
167
  combine_pages=combine_pages,
168
- output_folder=None, # No need to save this intermediate file
169
- remake_index=remake_index
170
  )
171
 
172
  # Step 5: Run the final similarity analysis on the combined data
@@ -181,28 +199,37 @@ def run_full_search_and_analysis(
181
  do_text_clean=False,
182
  file1_name="user_search_query",
183
  file2_name="source_document",
184
- progress=progress
185
  )
186
 
187
  print("Finished text search")
188
 
189
  # Map the results back to the reference data file
190
- if remake_index == True:
191
- results_df_index_list = extract_indices_from_page_ranges(results_df, converted_index=True)
 
 
192
  else:
193
- results_df_index_list = extract_indices_from_page_ranges(results_df, converted_index=False)
 
 
194
 
195
- word_level_df_out = word_level_df_orig.loc[word_level_df_orig["index"].isin(results_df_index_list)]
 
 
196
 
197
  return word_level_df_out, duplicate_files, full_data
198
 
199
- def create_all_data_to_process(converted_data:pd.DataFrame, other_data_list:List[Tuple]):
 
 
 
200
  all_data_to_process = converted_data + other_data_list
201
  return all_data_to_process
202
 
 
203
  def convert_word_level_df(
204
- word_level_df: pd.DataFrame,
205
- file_name: str = "converted_dataframe"
206
  ) -> List[Tuple[str, pd.DataFrame]]:
207
  """
208
  Converts a word-level OCR DataFrame to the format for
@@ -225,34 +252,37 @@ def convert_word_level_df(
225
  DataFrame will have 'page' and 'text' columns.
226
  """
227
  # --- 1. Validate Input ---
228
- required_columns = ['page', 'line', 'word_text']
229
  if not all(col in word_level_df.columns for col in required_columns):
230
- raise ValueError(f"Input DataFrame must contain all of the following columns: {required_columns}")
 
 
231
 
232
  df = word_level_df.copy()
233
 
234
  # --- 2. Process the DataFrame ---
235
  # Ensure word_text is a string to allow for joining
236
- df['word_text'] = df['word_text'].astype(str)
237
 
238
  # Group by page and line number, then join the words with a space (not needed for word level search)
239
  # The result is a Series with a MultiIndex (page, line)
240
- #line_text_series = df.groupby(['page', 'line'])['word_text'].apply(' '.join)
241
 
242
  # Convert the Series back to a DataFrame and reset the index
243
- #line_level_df = line_text_series.reset_index()
244
 
245
  # Rename the aggregated column from 'word_text' to the required 'text'
246
- df = df.rename(columns={'word_text': 'text'})
247
 
248
  # --- 3. Finalise the structure ---
249
  # We now have a DataFrame with columns [page, line, text].
250
- final_df = df[['page', 'text']]
251
 
252
  # --- 4. Package for output ---
253
  # Return in the required List[Tuple[str, DataFrame]] format
254
  return [(file_name, final_df)]
255
 
 
256
  def create_dataframe_from_string(
257
  text_string: str,
258
  file_name: str = "user_search_query",
@@ -292,24 +322,22 @@ def create_dataframe_from_string(
292
 
293
  if split_words:
294
  # --- Split string into words, one per row, based on similar punctuation split technique used to create ocr_results_with_words objects ---
295
- if split_punctuation == True:
296
  words = split_text_with_punctuation(text_string)
297
  else:
298
  words = text_string.split()
299
-
300
- #words = text_string.split()
301
  len_words = len(words)
302
- data = {
303
- 'page': [page_number] * len_words, # Assign the same page number to every word
304
- 'text': words # The list of words becomes the text column
 
305
  }
306
  else:
307
  # --- Entire string in one row ---
308
  len_words = 1
309
- data = {
310
- 'page': [page_number],
311
- 'text': [text_string]
312
- }
313
 
314
  # Create the DataFrame from the prepared data
315
  df = pd.DataFrame(data)
@@ -319,13 +347,14 @@ def create_dataframe_from_string(
319
  # Return it in the required format: a list containing one (name, df) tuple
320
  return [(file_name, df)], len_words
321
 
 
322
  def combine_ocr_dataframes(
323
  input_data: List[Tuple[str, pd.DataFrame]],
324
  combine_pages: bool = True,
325
  output_folder: str = OUTPUT_FOLDER,
326
  output_filename: str = "combined_ocr_output.csv",
327
  number_of_added_zeros: int = number_of_zeros_to_add_to_index,
328
- remake_index:bool = True
329
  ) -> Tuple[pd.DataFrame, List[str]]:
330
  """
331
  Combines text from multiple pandas DataFrames containing page and text columns.
@@ -358,18 +387,20 @@ def combine_ocr_dataframes(
358
  df = df_initial.copy() # Work on a copy to avoid side effects
359
 
360
  # --- Validation ---
361
- if 'page' not in df.columns or 'text' not in df.columns:
362
- print(f"Warning: Skipping data for '{file_identifier}' - missing required columns 'page' and 'text'.")
 
 
363
  continue
364
 
365
  # --- Processing ---
366
- df['text'] = df['text'].fillna('').astype(str)
367
 
368
  if combine_pages:
369
  # Group by page and concatenate text into a single string
370
- processed_df = df.groupby('page')['text'].apply(' '.join).reset_index()
371
  else:
372
- if remake_index == True:
373
  # # Create a unique, sortable page ID for each line without combining
374
  # df['line_number_by_page'] = df.groupby('page').cumcount() + 1
375
  # df['original_page'] = df['page']
@@ -377,41 +408,47 @@ def combine_ocr_dataframes(
377
  # df['page'] = (
378
  # df['page'].astype(str).str.zfill(number_of_added_zeros) +
379
  # df['line_number_by_page'].astype(str).str.zfill(number_of_added_zeros)
380
- # ).astype(int)
381
 
382
  # Define the multiplier based on the max expected lines per page.
383
  # If you expect up to 99,999 lines, use 100,000.
384
 
385
- df['line_number_by_page'] = df.groupby('page').cumcount() + 1
386
- df['original_page'] = df['page']
387
 
388
  # Create the new combined ID using arithmetic
389
- df['page'] = (df['original_page'] * ID_MULTIPLIER) + df['line_number_by_page']
390
-
 
 
391
  else:
392
- if not 'index' in df.columns:
393
- df['index'] = df.index
394
- df['page'] = df['index']
395
-
396
  processed_df = df
397
 
398
  # Add the file identifier column
399
- processed_df['file'] = file_identifier
400
  all_data.append(processed_df)
401
 
402
  if not all_data:
403
- raise ValueError("No valid DataFrames were processed. Ensure input data is not empty and DataFrames have 'page' and 'text' columns.")
 
 
404
 
405
  # --- Final Combination ---
406
  combined_df = pd.concat(all_data, ignore_index=True)
407
 
408
  # Reorder columns to a standard format, dropping intermediate columns
409
- final_columns = ['file', 'page', 'text']
410
- if 'original_page' in combined_df.columns:
411
- final_columns.append('original_page') # Keep for context if created
412
-
413
  # Ensure all final columns exist before trying to select them
414
- existing_final_columns = [col for col in final_columns if col in combined_df.columns]
 
 
415
 
416
  full_out_ocr_df = combined_df
417
  combined_df = combined_df.copy()[existing_final_columns]
@@ -427,11 +464,12 @@ def combine_ocr_dataframes(
427
 
428
  return combined_df, output_files, full_out_ocr_df
429
 
 
430
  def combine_ocr_output_text(
431
  input_files: Union[str, List[str]],
432
  combine_pages: bool = True,
433
  remake_index: bool = True,
434
- output_folder: str = OUTPUT_FOLDER
435
  ) -> Tuple[pd.DataFrame, List[str]]:
436
  """
437
  Reads multiple OCR CSV files, combines them, and saves the result.
@@ -472,19 +510,20 @@ def combine_ocr_output_text(
472
  input_data=data_to_process,
473
  combine_pages=combine_pages,
474
  output_folder=output_folder,
475
- output_filename="combined_ocr_from_files.csv", # Specific name for this path
476
- remake_index=remake_index
477
  )
478
 
479
- def clean_and_stem_text_series(df:pd.DataFrame, column:str):
480
- '''
 
481
  Clean and stem text columns in a data frame
482
- '''
483
-
484
  def _clean_text(raw_text):
485
  # Remove HTML tags
486
- clean = re.sub(r'<.*?>', '', raw_text)
487
- clean = ' '.join(clean.split())
488
  # Join the cleaned words back into a string
489
  return clean
490
 
@@ -492,59 +531,106 @@ def clean_and_stem_text_series(df:pd.DataFrame, column:str):
492
  def _apply_lemmatization(text):
493
  doc = nlp(text)
494
  # Keep only alphabetic tokens and remove stopwords
495
- lemmatized_words = [token.lemma_ for token in doc if token.is_alpha and not token.is_stop]
496
- return ' '.join(lemmatized_words)
497
-
498
- df['text_clean'] = df[column].apply(_clean_text)
 
 
 
 
499
 
500
- df['text_clean'] = df['text_clean'].apply(_apply_lemmatization)
501
-
502
  return df
503
 
504
- def map_metadata_single_page(similarity_df:pd.DataFrame, metadata_source_df:pd.DataFrame, preview_length:int=200):
 
 
 
 
 
505
  """Helper to map metadata for single page results."""
506
- metadata_df = metadata_source_df[['file', 'page', 'text']]
507
- results_df = similarity_df.merge(metadata_df, left_on='Page1_Index', right_index=True)\
508
- .rename(columns={'file': 'Page1_File', 'page': 'Page1_Page', 'text': 'Page1_Text'})
509
- results_df = results_df.merge(metadata_df, left_on='Page2_Index', right_index=True, suffixes=('_1', '_2'))\
510
- .rename(columns={'file': 'Page2_File', 'page': 'Page2_Page', 'text': 'Page2_Text'})
 
 
511
  results_df["Similarity_Score"] = results_df["Similarity_Score"].round(3)
512
- final_df = results_df[['Page1_File', 'Page1_Page', 'Page2_File', 'Page2_Page', 'Similarity_Score', 'Page1_Text', 'Page2_Text']]
513
- final_df = final_df.sort_values(["Page1_File", "Page1_Page", "Page2_File", "Page2_Page"])
514
- final_df['Page1_Text'] = final_df['Page1_Text'].str[:preview_length]
515
- final_df['Page2_Text'] = final_df['Page2_Text'].str[:preview_length]
 
 
 
 
 
 
 
 
 
 
 
 
516
  return final_df
517
 
518
- def map_metadata_subdocument(subdocument_df:pd.DataFrame, metadata_source_df:pd.DataFrame, preview_length:int=200):
 
 
 
 
 
519
  """Helper to map metadata for subdocument results."""
520
- metadata_df = metadata_source_df[['file', 'page', 'text']]
521
-
522
- subdocument_df = subdocument_df.merge(metadata_df, left_on='Page1_Start_Index', right_index=True)\
523
- .rename(columns={'file': 'Page1_File', 'page': 'Page1_Start_Page', 'text': 'Page1_Text'})
524
- subdocument_df = subdocument_df.merge(metadata_df[['page']], left_on='Page1_End_Index', right_index=True)\
525
- .rename(columns={'page': 'Page1_End_Page'})
526
- subdocument_df = subdocument_df.merge(metadata_df, left_on='Page2_Start_Index', right_index=True)\
527
- .rename(columns={'file': 'Page2_File', 'page': 'Page2_Start_Page', 'text': 'Page2_Text'})
528
- subdocument_df = subdocument_df.merge(metadata_df[['page']], left_on='Page2_End_Index', right_index=True)\
529
- .rename(columns={'page': 'Page2_End_Page'})
530
-
531
- cols = ['Page1_File', 'Page1_Start_Page', 'Page1_End_Page',
532
- 'Page2_File', 'Page2_Start_Page', 'Page2_End_Page',
533
- 'Match_Length', 'Page1_Text', 'Page2_Text']
534
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
535
  # Add Avg_Similarity if it exists (it won't for greedy match unless we add it)
536
- if 'Avg_Similarity' in subdocument_df.columns:
537
- subdocument_df['Avg_Similarity'] = subdocument_df['Avg_Similarity'].round(3)
538
- cols.insert(7, 'Avg_Similarity')
539
 
540
  final_df = subdocument_df[cols]
541
- final_df = final_df.sort_values(['Page1_File', 'Page1_Start_Page', 'Page2_File', 'Page2_Start_Page'])
542
- final_df['Page1_Text'] = final_df['Page1_Text'].str[:preview_length]
543
- final_df['Page2_Text'] = final_df['Page2_Text'].str[:preview_length]
 
 
544
 
545
  return final_df
546
 
547
- def save_results_and_redaction_lists(final_df: pd.DataFrame, output_folder: str, combine_pages:bool = True) -> list:
 
 
 
548
  """
549
  Saves the main results DataFrame and generates per-file redaction lists.
550
  This function is extracted to be reusable.
@@ -566,44 +652,53 @@ def save_results_and_redaction_lists(final_df: pd.DataFrame, output_folder: str,
566
  return []
567
 
568
  # 1. Save the main results DataFrame
569
- similarity_file_output_path = output_folder_path / 'page_similarity_results.csv'
570
  final_df.to_csv(similarity_file_output_path, index=False, encoding="utf-8-sig")
571
 
572
  output_paths.append(str(similarity_file_output_path))
573
- #print(f"Main results saved to {similarity_file_output_path}")
574
 
575
  # 2. Save per-file redaction lists
576
  # Use 'Page2_File' as the source of duplicate content
577
- if combine_pages == True:
578
- grouping_col = 'Page2_File'
579
  if grouping_col not in final_df.columns:
580
- print("Warning: 'Page2_File' column not found. Cannot generate redaction lists.")
 
 
581
  return output_paths
582
 
583
  for redact_file, group in final_df.groupby(grouping_col):
584
  output_file_name_stem = Path(redact_file).stem
585
- output_file_path = output_folder_path / f"{output_file_name_stem}_pages_to_redact.csv"
586
-
 
 
587
  all_pages_to_redact = set()
588
- is_subdocument_match = 'Page2_Start_Page' in group.columns
589
 
590
  if is_subdocument_match:
591
  for _, row in group.iterrows():
592
- pages_in_range = range(int(row['Page2_Start_Page']), int(row['Page2_End_Page']) + 1)
 
 
593
  all_pages_to_redact.update(pages_in_range)
594
  else:
595
- pages = group['Page2_Page'].unique()
596
  all_pages_to_redact.update(pages)
597
-
598
  if all_pages_to_redact:
599
- redaction_df = pd.DataFrame(sorted(list(all_pages_to_redact)), columns=['Page_to_Redact'])
 
 
600
  redaction_df.to_csv(output_file_path, header=False, index=False)
601
 
602
  output_paths.append(str(output_file_path))
603
  print(f"Redaction list for {redact_file} saved to {output_file_path}")
604
-
605
  return output_paths
606
 
 
607
  def _sequences_match(query_seq: List[str], ref_seq: List[str]) -> bool:
608
  """
609
  Helper function to compare two sequences of tokens with punctuation flexibility.
@@ -627,9 +722,9 @@ def _sequences_match(query_seq: List[str], ref_seq: List[str]) -> bool:
627
  # - Its last character must be in our punctuation set
628
  # - The token without its last character must match the query token
629
  if (
630
- len(ref_token) > 1 and
631
- ref_token[-1] in PUNCTUATION_TO_STRIP and
632
- ref_token[:-1] == query_token
633
  ):
634
  continue
635
 
@@ -639,10 +734,9 @@ def _sequences_match(query_seq: List[str], ref_seq: List[str]) -> bool:
639
  # If the loop completes, every token has matched.
640
  return True
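+ # e.g. the query tokens ["find", "this"] match the reference tokens ["find", "this,"], because a single trailing punctuation mark on a reference token is ignored.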
641
 
 
642
  def find_consecutive_sequence_matches(
643
- df_filtered: pd.DataFrame,
644
- search_file_name: str,
645
- reference_file_name: str
646
  ) -> pd.DataFrame:
647
  """
648
  Finds all occurrences of a consecutive sequence of tokens from a search file
@@ -659,22 +753,22 @@ def find_consecutive_sequence_matches(
659
  A DataFrame with two columns ('Page1_Index', 'Page2_Index') mapping the
660
  consecutive match, or an empty DataFrame if no match is found.
661
  """
662
- #print(f"Starting sequence search for '{search_file_name}' in '{reference_file_name}'...")
663
 
664
  # Step 1: Isolate the data for each file
665
- search_df = df_filtered[df_filtered['file'] == search_file_name]
666
- reference_df = df_filtered[df_filtered['file'] == reference_file_name]
667
 
668
  if search_df.empty or reference_df.empty:
669
  print("Error: One or both files not found or are empty.")
670
- return pd.DataFrame(columns=['Page1_Index', 'Page2_Index'])
671
 
672
  # Step 2: Convert the token data into lists for easy comparison.
673
  # We need both the text tokens and their original global indices.
674
- query_tokens = search_df['text_clean'].tolist()
675
  query_indices = search_df.index.tolist()
676
-
677
- reference_tokens = reference_df['text_clean'].tolist()
678
  reference_indices = reference_df.index.tolist()
679
 
680
  query_len = len(query_tokens)
@@ -689,29 +783,32 @@ def find_consecutive_sequence_matches(
689
 
690
  # Step 4: If the window matches the query with or without punctuation on end
691
  if _sequences_match(query_tokens, window):
692
- #print(f"Found a consecutive match starting at reference index: {reference_indices[i]}")
693
-
694
  # Get the global indices for this entire matching block
695
  matching_reference_indices = reference_indices[i : i + query_len]
696
-
697
  # Create the mapping between query indices and the found reference indices
698
  for j in range(query_len):
699
  all_found_matches.append(
700
  (query_indices[j], matching_reference_indices[j], 1)
701
  )
702
-
703
  # If you only want the *first* match, you can uncomment the next line:
704
- # break
705
 
706
  if not all_found_matches:
707
  print("No matches found")
708
  gr.Info("No matches found")
709
- return pd.DataFrame(columns=['Page1_Index', 'Page2_Index', 'Similarity_Score'])
710
 
711
  # Step 5: Create the final DataFrame in the desired format
712
- result_df = pd.DataFrame(all_found_matches, columns=['Page1_Index', 'Page2_Index', 'Similarity_Score'])
 
 
713
  return result_df
714
 
 
715
  def identify_similar_text_sequences(
716
  df_combined: pd.DataFrame,
717
  similarity_threshold: float = 1,
@@ -720,11 +817,11 @@ def identify_similar_text_sequences(
720
  greedy_match: bool = True,
721
  combine_pages: bool = False,
722
  inter_file_only: bool = False,
723
- do_text_clean:bool = True,
724
- file1_name: str = '',
725
- file2_name: str = '',
726
  output_folder: str = OUTPUT_FOLDER,
727
- progress=Progress(track_tqdm=True)
728
  ) -> Tuple[pd.DataFrame, List[str], pd.DataFrame]:
729
  """
730
  Identifies similar pages. Uses a highly optimized path for inter_file_only=True.
@@ -732,14 +829,17 @@ def identify_similar_text_sequences(
732
  progress(0.1, desc="Processing and filtering text")
733
 
734
  if do_text_clean:
735
- df = clean_and_stem_text_series(df_combined, 'text') # Will produce the column 'text_clean'
 
 
736
  else:
737
  df = df_combined.copy()
738
- df['text_clean'] = df['text'].str.lower()#.str.replace(r'[^\w\s]', '', regex=True)
739
-
 
740
 
741
- df['word_count'] = df['text_clean'].str.split().str.len().fillna(0)
742
- #df['word_count'] = pd.to_numeric(df['word_count'], errors='coerce').fillna(0).astype('int64')
743
 
744
  # ensure min_word_count is an int (e.g., from Gradio/text input)
745
  try:
@@ -748,57 +848,66 @@ def identify_similar_text_sequences(
748
  min_word_count = 0 # or raise/log, depending on your preference
749
 
750
  original_row_count = len(df)
751
- df_filtered = df[df['word_count'] >= min_word_count].copy()
752
  df_filtered.reset_index(drop=True, inplace=True)
753
-
754
- print(f"Filtered out {original_row_count - len(df_filtered)} pages with fewer than {min_word_count} words.")
 
 
755
  if len(df_filtered) < 2:
756
  return pd.DataFrame(), [], df_combined
757
 
758
-
759
-
760
  # Similarity calculated differently if comparing between files only (inter_file_only==True), or within the same file
761
  if inter_file_only:
762
 
763
  progress(0.2, desc="Finding direct text matches...")
764
-
765
- #base_similarity_df = _debug_similarity_between_two_files(df_filtered, vectorizer, similarity_threshold, file1_name, file2_name)
766
- base_similarity_df = find_consecutive_sequence_matches(df_filtered, file1_name, file2_name)
 
 
767
  if base_similarity_df.empty:
768
- return pd.DataFrame(), [], df_combined
769
-
770
  else:
771
  # Use the original, simpler path for all-to-all comparisons (including intra-file).
772
  vectorizer = TfidfVectorizer()
773
  print("Standard Path: Calculating all-to-all similarity.")
774
  progress(0.2, desc="Vectorizing text...")
775
- tfidf_matrix = vectorizer.fit_transform(df_filtered['text_clean'])
776
 
777
  progress(0.3, desc="Calculating similarity matrix...")
778
  similarity_matrix = cosine_similarity(tfidf_matrix, dense_output=False)
779
  coo_matrix = similarity_matrix.tocoo()
780
 
781
  similar_pages = [
782
- (r, c, v) for r, c, v in zip(coo_matrix.row, coo_matrix.col, coo_matrix.data)
 
783
  if r < c and v >= similarity_threshold
784
  ]
785
 
786
  if not similar_pages:
787
  return pd.DataFrame(), [], df_combined
788
-
789
- base_similarity_df = pd.DataFrame(similar_pages, columns=['Page1_Index', 'Page2_Index', 'Similarity_Score'])
 
 
790
 
791
  progress(0.7, desc="Aggregating results based on matching strategy")
792
 
793
  if greedy_match or min_consecutive_pages > 1:
794
- #print("Finding all consecutive page matches of minimum length:", min_consecutive_pages)
795
-
796
  # Sort the dataframe to ensure consecutive pages are adjacent
797
- similarity_df = base_similarity_df #.sort_values(['Page1_Index', 'Page2_Index']).copy()
 
 
798
 
799
  # A new sequence starts if the difference from the previous row is not (1, 1)
800
  # is_consecutive will be True if a row continues the sequence, False if it's a new one.
801
- is_consecutive = (similarity_df['Page1_Index'].diff() == 1) & (similarity_df['Page2_Index'].diff() == 1)
 
 
802
 
803
  # Use cumsum() on the inverted boolean series to create a unique ID for each block.
804
  # Every time a 'False' appears (a new block starts), the sum increases.
@@ -809,46 +918,53 @@ def identify_similar_text_sequences(
809
 
810
  # Aggregate each group to get the start, end, and length of the match
811
  agg_results = grouped.agg(
812
- Page1_Start_Index=('Page1_Index', 'first'),
813
- Page2_Start_Index=('Page2_Index', 'first'),
814
- Page1_End_Index=('Page1_Index', 'last'),
815
- Page2_End_Index=('Page2_Index', 'last'),
816
- Match_Length=('Page1_Index', 'size'),
817
- Avg_Similarity=('Similarity_Score', 'mean')
818
  ).reset_index(drop=True)
819
 
820
  # If greedy_match=True, we keep all matches. If min_consecutive_pages > 1, we filter.
821
  if greedy_match and min_consecutive_pages <= 1:
822
  subdocument_df = agg_results
823
  else:
824
- # This handles the case for min_consecutive_pages > 1
825
- subdocument_df = agg_results[agg_results['Match_Length'] >= min_consecutive_pages].copy()
 
 
826
 
827
  if subdocument_df.empty:
828
  gr.Info("No matches found")
829
  return pd.DataFrame(), [], df_combined
830
-
831
  final_df = map_metadata_subdocument(subdocument_df, df_filtered)
832
  else:
833
- print(f"Finding single page matches, not greedy (min_consecutive_pages=1)")
834
  # This part of your code would handle the non-sequential case
835
  final_df = map_metadata_single_page(base_similarity_df, df_filtered)
836
- #subdocument_df = final_df # To align variable names for saving
837
 
838
  if final_df.empty:
839
  gr.Info("No matches found")
840
  return pd.DataFrame(), [], df_combined
841
 
842
  progress(0.9, desc="Saving output files")
843
-
844
- output_paths = save_results_and_redaction_lists(final_df, output_folder, combine_pages)
 
 
845
 
846
  gr.Info(f"Found {final_df.shape[0]} match(es)")
847
  print(f"Found {final_df.shape[0]} match(es)")
848
 
849
  return final_df, output_paths, df_combined
850
-
851
- def handle_selection_and_preview(evt: gr.SelectData, results_df:pd.DataFrame, full_duplicate_data_by_file: dict):
 
 
 
852
  """
853
  This single function handles a user selecting a row. It:
854
  1. Determines the selected row index.
@@ -857,18 +973,23 @@ def handle_selection_and_preview(evt: gr.SelectData, results_df:pd.DataFrame, fu
857
  """
858
  # If the user deselects, the event might be None.
859
  if not evt:
860
- return None, None, None # Clear state and both preview panes
861
 
862
  # 1. Get the selected index
863
  selected_index = evt.index[0]
864
 
865
  # 2. Get the preview data
866
- page1_data, page2_data = show_page_previews(full_duplicate_data_by_file, results_df, evt)
 
 
867
 
868
  # 3. Return all three outputs in the correct order
869
  return selected_index, page1_data, page2_data
870
 
871
- def exclude_match(results_df:pd.DataFrame, selected_index:int, output_folder=OUTPUT_FOLDER):
 
 
 
872
  """
873
  Removes a selected row from the results DataFrame, regenerates output files,
874
  and clears the text preview panes.
@@ -877,23 +998,34 @@ def exclude_match(results_df:pd.DataFrame, selected_index:int, output_folder=OUT
877
  gr.Warning("No match selected. Please click on a row in the table first.")
878
  # Return the original dataframe and update=False for the files
879
  return results_df, gr.update(), None, None
880
-
881
  if results_df.empty:
882
  gr.Warning("No duplicate page results found, nothing to exclude.")
883
  return results_df, gr.update(), None, None
884
 
885
  # Drop the selected row
886
  updated_df = results_df.drop(selected_index).reset_index(drop=True)
887
-
888
  # Recalculate all output files using the helper function
889
  new_output_paths = save_results_and_redaction_lists(updated_df, output_folder)
890
-
891
  gr.Info(f"Match at row {selected_index} excluded. Output files have been updated.")
892
-
893
  # Return the updated dataframe, the new file list, and clear the preview panes
894
  return updated_df, new_output_paths, None, None
895
 
896
- def run_duplicate_analysis(files:list[str], threshold:float, min_words:int, min_consecutive:int, greedy_match:bool, combine_pages:bool=True, preview_length:int=500, output_folder:str=OUTPUT_FOLDER, progress=gr.Progress(track_tqdm=True)):
897
  """
898
  Main wrapper function to orchestrate the duplicate page analysis process.
899
  It handles file loading, text combination, similarity identification,
@@ -911,9 +1043,11 @@ def run_duplicate_analysis(files:list[str], threshold:float, min_words:int, min_
911
  progress (gr.Progress, optional): A Gradio progress tracker object to display progress in the UI.
912
  """
913
 
914
- if not files: raise Warning("Please upload files to analyse.")
 
915
 
916
- if isinstance(files, str): files = [files]
 
917
 
918
  if len(files) > MAX_SIMULTANEOUS_FILES:
919
  out_message = f"Number of files to deduplicate is greater than {MAX_SIMULTANEOUS_FILES}. Please submit a smaller number of files."
@@ -923,9 +1057,11 @@ def run_duplicate_analysis(files:list[str], threshold:float, min_words:int, min_
923
  start_time = time.time()
924
 
925
  task_textbox = "deduplicate"
926
-
927
  progress(0, desc="Combining input files...")
928
- df_combined, _, full_out_ocr_df = combine_ocr_output_text(files, combine_pages=combine_pages)
 
 
929
 
930
  if df_combined.empty:
931
  raise Warning("No data found in the uploaded files.")
@@ -939,26 +1075,32 @@ def run_duplicate_analysis(files:list[str], threshold:float, min_words:int, min_
939
  greedy_match=greedy_match,
940
  combine_pages=combine_pages,
941
  output_folder=output_folder,
942
- progress=progress
943
  )
944
 
945
  # Clip text to first 200 characters
946
- full_df['text'] = full_df['text'].str[:preview_length]
947
  # Preprocess full_data (without preview text) for fast access (run once)
948
  full_data_by_file = {
949
- file: df.sort_values('page').set_index('page')
950
- for file, df in full_df.drop(["text_clean"],axis=1).groupby('file')
951
  }
952
 
953
  if results_df.empty:
954
- gr.Info(f"No duplicate pages found, no results returned.")
955
 
956
  end_time = time.time()
957
  processing_time = round(end_time - start_time, 2)
958
-
959
  return results_df, output_paths, full_data_by_file, processing_time, task_textbox
960
 
961
- def show_page_previews(full_data_by_file: dict, results_df: pd.DataFrame, evt: gr.SelectData, preview_length:int=500):
 
 
 
 
 
 
962
  """
963
  Optimized version using pre-partitioned and indexed full_data.
964
  Triggered when a user selects a row in the results DataFrame.
@@ -968,26 +1110,35 @@ def show_page_previews(full_data_by_file: dict, results_df: pd.DataFrame, evt: g
968
 
969
  selected_row = results_df.iloc[evt.index[0], :]
970
 
971
- is_subdocument_match = 'Page1_Start_Page' in selected_row
972
 
973
  if is_subdocument_match:
974
- file1, start1, end1 = selected_row['Page1_File'], selected_row['Page1_Start_Page'], selected_row['Page1_End_Page']
975
- file2, start2, end2 = selected_row['Page2_File'], selected_row['Page2_Start_Page'], selected_row['Page2_End_Page']
976
-
977
- page1_data = full_data_by_file[file1].loc[start1:end1, ['text']].reset_index()
978
- page2_data = full_data_by_file[file2].loc[start2:end2, ['text']].reset_index()
979
 
980
  else:
981
- file1, page1 = selected_row['Page1_File'], selected_row['Page1_Page']
982
- file2, page2 = selected_row['Page2_File'], selected_row['Page2_Page']
 
 
 
983
 
984
- page1_data = full_data_by_file[file1].loc[[page1], ['text']].reset_index()
985
- page2_data = full_data_by_file[file2].loc[[page2], ['text']].reset_index()
986
 
987
- page1_data['text'] = page1_data['text'].str[:preview_length]
988
- page2_data['text'] = page2_data['text'].str[:preview_length]
989
 
990
- return page1_data[['page', 'text']], page2_data[['page', 'text']]
991
 
992
  def get_page_image_info(page_num: int, page_sizes: List[Dict]) -> Optional[Dict]:
993
  """
@@ -995,10 +1146,9 @@ def get_page_image_info(page_num: int, page_sizes: List[Dict]) -> Optional[Dict]
995
  """
996
  return next((size for size in page_sizes if size["page"] == page_num), None)
997
 
 
998
  def add_new_annotations_to_existing_page_annotations(
999
- all_annotations: List[Dict],
1000
- image_path: str,
1001
- new_annotation_boxes: List[Dict]
1002
  ) -> Tuple[List[Dict], Dict]:
1003
  """
1004
  Adds a list of new annotation boxes to the annotations for a specific page.
@@ -1018,8 +1168,12 @@ def add_new_annotations_to_existing_page_annotations(
1018
  """
1019
  # Find the annotation group for the current page/image
1020
  current_page_group = next(
1021
- (annot_group for annot_group in all_annotations if annot_group["image"] == image_path),
1022
- None
 
 
 
 
1023
  )
1024
 
1025
  if current_page_group:
@@ -1027,22 +1181,27 @@ def add_new_annotations_to_existing_page_annotations(
1027
  current_page_group["boxes"].extend(new_annotation_boxes)
1028
  else:
1029
  # This is the first set of annotations for this page, create a new group
1030
- new_group = {
1031
- "image": image_path,
1032
- "boxes": new_annotation_boxes
1033
- }
1034
  all_annotations.append(new_group)
1035
 
1036
  # This object represents all annotations that were just added for this page
1037
- newly_added_annotation_group = {
1038
- "image": image_path,
1039
- "boxes": new_annotation_boxes
1040
- }
1041
 
1042
  return all_annotations, newly_added_annotation_group
1043
 
1044
- def apply_whole_page_redactions_from_list(duplicate_page_numbers_df: pd.DataFrame, doc_file_name_with_extension_textbox: str, review_file_state: pd.DataFrame, duplicate_output_paths: list[str], pymupdf_doc: object, page_sizes: list[dict], all_existing_annotations: list[dict], combine_pages:bool=True, new_annotations_with_bounding_boxes:List[dict]=list()):
1045
- '''
1046
  This function applies redactions to whole pages based on a provided list of duplicate page numbers. It supports two modes of operation: combining pages and not combining pages. When combining pages is enabled, it attempts to identify duplicate pages across different files and applies redactions accordingly. If combining pages is disabled, it relies on new annotations with bounding boxes to determine which pages to redact. The function utilises a PyMuPDF document object to manipulate the PDF file, and it also considers the sizes of pages to ensure accurate redaction application.
1047
 
1048
  Args:
@@ -1055,7 +1214,7 @@ def apply_whole_page_redactions_from_list(duplicate_page_numbers_df: pd.DataFram
1055
  all_existing_annotations (list[dict]): A list of all existing annotations in the document.
1056
  combine_pages (bool, optional): A flag indicating whether to combine pages for redaction. Defaults to True.
1057
  new_annotations_with_bounding_boxes (List[dict], optional): A list of new annotations with bounding boxes. Defaults to an empty list.
1058
- '''
1059
  if all_existing_annotations is None:
1060
  all_existing_annotations = []
1061
 
@@ -1069,31 +1228,39 @@ def apply_whole_page_redactions_from_list(duplicate_page_numbers_df: pd.DataFram
1069
  print(f"Warning: {message}")
1070
  raise Warning(message)
1071
 
1072
- list_whole_pages_to_redact = []
1073
 
1074
- if combine_pages == True:
1075
  # Get list of pages to redact from either dataframe or file
1076
  if not duplicate_page_numbers_df.empty:
1077
  list_whole_pages_to_redact = duplicate_page_numbers_df.iloc[:, 0].tolist()
1078
  elif duplicate_output_paths:
1079
- expected_duplicate_pages_to_redact_name = f"{doc_file_name_with_extension_textbox}"
 
 
1080
  whole_pages_list = pd.DataFrame() # Initialize empty DataFrame
1081
-
1082
  for output_file in duplicate_output_paths:
1083
  # Note: output_file.name might not be available if output_file is just a string path
1084
  # If it's a Path object or similar, .name is fine. Otherwise, parse from string.
1085
- file_name_from_path = output_file.split('/')[-1] if isinstance(output_file, str) else output_file.name
 
 
 
 
1086
  if expected_duplicate_pages_to_redact_name in file_name_from_path:
1087
- whole_pages_list = pd.read_csv(output_file, header=None) # Use output_file directly if it's a path
1088
- break
 
 
1089
  else:
1090
  message = "No relevant list of whole pages to redact found."
1091
  print(message)
1092
  raise Warning(message)
1093
-
1094
  if not whole_pages_list.empty:
1095
  list_whole_pages_to_redact = whole_pages_list.iloc[:, 0].tolist()
1096
-
1097
  list_whole_pages_to_redact = list(set(list_whole_pages_to_redact))
1098
 
1099
  else:
@@ -1101,19 +1268,20 @@ def apply_whole_page_redactions_from_list(duplicate_page_numbers_df: pd.DataFram
1101
  message = "Can't find any new annotations to add"
1102
  print(message)
1103
  raise Warning(message)
1104
-
1105
  list_whole_pages_to_redact = []
1106
  for annotation in new_annotations_with_bounding_boxes:
1107
- match = re.search(r'_(\d+)\.png$', annotation["image"])
1108
  if match:
1109
  page = int(match.group(1)) + 1
1110
  list_whole_pages_to_redact.append(page)
1111
  else:
1112
- print(f"Warning: Could not extract page number from {annotation['image']}")
 
 
1113
 
1114
  list_whole_pages_to_redact = list(set(list_whole_pages_to_redact))
1115
 
1116
-
1117
  new_annotations = []
1118
  # Process each page for redaction
1119
  for page in list_whole_pages_to_redact:
@@ -1130,47 +1298,65 @@ def apply_whole_page_redactions_from_list(duplicate_page_numbers_df: pd.DataFram
1130
  continue
1131
 
1132
  image_path = page_info["image_path"]
1133
- page_annotation_group = next((g for g in all_annotations if g["image"] == image_path), None)
1134
- if page_annotation_group and any(box["label"] == "Whole page" for box in page_annotation_group["boxes"]):
1135
- print(f"Whole page redaction for page {page_num} already exists, skipping.")
 
 
 
 
 
 
1136
  continue
1137
-
1138
  # --- Create a LIST of boxes to add.---
1139
  boxes_to_add = []
1140
-
1141
  pymupdf_page = pymupdf_doc[page_index]
1142
 
1143
- if combine_pages==True:
1144
  whole_page_box = redact_whole_pymupdf_page(
1145
  rect_height=page_info["cropbox_height"],
1146
  rect_width=page_info["cropbox_width"],
1147
- page=pymupdf_page, border=0.005, redact_pdf=False
 
 
1148
  )
1149
  boxes_to_add.append(whole_page_box)
1150
  else:
1151
  # Find the specific annotation group that matches the current page's image path
1152
  relevant_box_group = next(
1153
- (group for group in new_annotations_with_bounding_boxes if group.get('image') == image_path),
1154
- None # Default to None if no match is found
 
 
 
 
1155
  )
1156
-
1157
  # Check if we found a matching group of boxes for this page
1158
  if relevant_box_group:
1159
- boxes_to_add.extend(relevant_box_group['boxes'])
1160
  else:
1161
  # This case would be unexpected, but it's good to handle.
1162
  # It means a page was in list_whole_pages_to_redact but had no
1163
  # corresponding boxes generated in new_annotations_with_bounding_boxes.
1164
- print(f"Warning: No new annotation boxes found for page {page_num} ({image_path}).")
1165
-
 
 
1166
  # === Use the modified helper function to add a LIST of boxes ===
1167
- all_annotations, new_annotations_for_page = add_new_annotations_to_existing_page_annotations(
1168
- all_annotations=all_annotations,
1169
- image_path=image_path,
1170
- new_annotation_boxes=boxes_to_add # Pass the list here
 
 
1171
  )
1172
 
1173
- new_annotations_for_page = fill_missing_box_ids_each_box(new_annotations_for_page)
 
 
1174
  new_annotations.append(new_annotations_for_page)
1175
 
1176
  except Exception as e:
@@ -1185,36 +1371,58 @@ def apply_whole_page_redactions_from_list(duplicate_page_numbers_df: pd.DataFram
1185
  gr.Info(message)
1186
  return review_file_state, all_annotations
1187
 
1188
- expected_cols = ['image', 'page', 'label', 'color', 'xmin', 'ymin', 'xmax', 'ymax', 'text', 'id']
1189
  for col in expected_cols:
1190
- if col not in review_file_state.columns: review_file_state[col] = pd.NA
1191
- if col not in whole_page_review_file.columns: whole_page_review_file[col] = pd.NA
1192
 
1193
- review_file_out = pd.concat([review_file_state, whole_page_review_file], ignore_index=True)
1194
- review_file_out = review_file_out.sort_values(by=["page", "ymin", "xmin"]).reset_index(drop=True)
1195
- review_file_out = review_file_out.drop_duplicates(subset=['page', 'label', 'text', 'id'], keep='first')
1196
-
1197
  out_message = "Successfully created duplicate text redactions."
1198
  print(out_message)
1199
  gr.Info(out_message)
1200
 
1201
  return review_file_out, all_annotations
1202
 
 
1203
  def _parse_page_line_id(combined_id: int) -> Tuple[int, int]:
1204
  """Parses a combined ID using modular arithmetic."""
1205
  if int(combined_id) < ID_MULTIPLIER:
1206
  # Handle cases where page is 0 (or just an edge case)
1207
  return 0, combined_id
1208
-
1209
  page = combined_id // ID_MULTIPLIER
1210
  line = combined_id % ID_MULTIPLIER
1211
  return page, line
1212
 
 
1213
  def create_annotation_objects_from_duplicates(
1214
- duplicates_df: pd.DataFrame,
1215
  ocr_results_df: pd.DataFrame,
1216
  page_sizes: List[Dict],
1217
- combine_pages:bool=False) -> List[Dict]:
 
1218
  """
1219
  Creates structured annotation objects from duplicate line ranges, mapping
1220
  page numbers to image paths.
@@ -1233,33 +1441,44 @@ def create_annotation_objects_from_duplicates(
1233
  if duplicates_df.empty:
1234
  raise Warning("No duplicates found")
1235
  if ocr_results_df.empty:
1236
- raise Warning("No OCR results found for file under review. Please upload relevant OCR_output file and original PDF document on the review tab.")
 
 
1237
 
1238
- if combine_pages == False:
1239
- page_to_image_map = {item['page']: item['image_path'] for item in page_sizes}
1240
 
1241
  # Prepare OCR Data: Add a line number column if it doesn't exist
1242
- if 'line_number_by_page' not in ocr_results_df.columns:
1243
- ocr_results_df = ocr_results_df.sort_values(by=['page', 'top', 'left']).reset_index(drop=True)
1244
- ocr_results_df['line_number_by_page'] = ocr_results_df.groupby('page').cumcount() + 1
1245
-
 
 
 
 
1246
  annotations_by_page = defaultdict(list)
1247
 
1248
  # Iterate through each duplicate range (this logic is unchanged)
1249
  for _, row in duplicates_df.iterrows():
1250
- start_page, start_line = _parse_page_line_id(row['Page2_Start_Page'])
1251
- end_page, end_line = _parse_page_line_id(row['Page2_End_Page'])
1252
-
1253
  # Select OCR Lines based on the range (this logic is unchanged)
1254
  if start_page == end_page:
1255
- condition = (
1256
- (ocr_results_df['page'] == start_page) &
1257
- (ocr_results_df['line_number_by_page'].between(start_line, end_line))
1258
  )
1259
  else:
1260
- cond_start = (ocr_results_df['page'] == start_page) & (ocr_results_df['line_number_by_page'] >= start_line)
1261
- cond_middle = ocr_results_df['page'].between(start_page + 1, end_page - 1)
1262
- cond_end = (ocr_results_df['page'] == end_page) & (ocr_results_df['line_number_by_page'] <= end_line)
 
 
 
 
 
 
1263
  condition = cond_start | cond_middle | cond_end
1264
 
1265
  lines_to_annotate = ocr_results_df[condition]
@@ -1268,30 +1487,27 @@ def create_annotation_objects_from_duplicates(
1268
  for _, line_row in lines_to_annotate.iterrows():
1269
  box = {
1270
  "label": "Duplicate text",
1271
- "color": (0,0,0),
1272
- "xmin": line_row['left'],
1273
- "ymin": line_row['top'],
1274
- "xmax": line_row['left'] + line_row['width'],
1275
- "ymax": line_row['top'] + line_row['height'],
1276
- "text": line_row['text'],
1277
- "id": "" # to be filled in after
1278
  }
1279
- page_number = line_row['page']
1280
-
1281
  annotations_by_page[page_number].append(box)
1282
-
1283
  # --- Format the final output list using the page-to-image map ---
1284
  final_output = []
1285
  # Sort by page number for a predictable order
1286
  for page_num, boxes in sorted(annotations_by_page.items()):
1287
  # Look up the image path using the page number
1288
  image_path = page_to_image_map.get(page_num)
1289
-
1290
  if image_path:
1291
- page_boxes = {
1292
- "image": image_path,
1293
- "boxes": boxes
1294
- }
1295
 
1296
  # Fill in missing IDs for the new data entries
1297
  page_boxes = fill_missing_box_ids_each_box(page_boxes)
@@ -1300,7 +1516,9 @@ def create_annotation_objects_from_duplicates(
1300
  final_output.append(page_boxes)
1301
  else:
1302
  # Handle cases where a page might not have a corresponding image path
1303
- print(f"Warning: Page {page_num} found in OCR data but has no corresponding "
1304
- f"entry in the 'page_sizes' object. This page's annotations will be skipped.")
1305
-
1306
- return final_output
 
 
 
 
1
  import os
2
  import re
3
  import time
 
 
 
4
  from collections import defaultdict
5
+ from pathlib import Path
6
+ from typing import Dict, List, Optional, Tuple, Union
7
+
8
  import gradio as gr
9
+ import pandas as pd
10
  from gradio import Progress
11
+ from sklearn.feature_extraction.text import TfidfVectorizer
12
+ from sklearn.metrics.pairwise import cosine_similarity
13
+
14
  from tools.config import MAX_SIMULTANEOUS_FILES
15
+ from tools.file_conversion import (
16
+ convert_annotation_data_to_dataframe,
17
+ fill_missing_box_ids_each_box,
18
+ redact_whole_pymupdf_page,
19
+ )
20
+ from tools.helper_functions import OUTPUT_FOLDER
21
  from tools.load_spacy_model_custom_recognisers import nlp
22
 
23
+ number_of_zeros_to_add_to_index = 7 # Number of zeroes to add between page number and line numbers to get a unique page/line index value
24
  ID_MULTIPLIER = 100000
25
  # Define the set of punctuation characters for efficient lookup
26
+ PUNCTUATION_TO_STRIP = {".", ",", "?", "!", ":", ";"}
27
+
28
 
29
  def split_text_with_punctuation(text: str) -> List[str]:
30
  """
 
35
  # 1. A sequence of one or more punctuation marks `[.,?!:;]+`
36
  # 2. OR a sequence of one or more characters that are NOT punctuation or whitespace `[^.,?!:;\s]+`
37
  pattern = re.compile(r"([.,?!:;]+|[^.,?!:;\s]+)")
38
+
39
  final_list = []
40
  # We first split by whitespace to handle sentences correctly
41
  for word in text.split():
42
  # Then, for each whitespace-separated word, we tokenize it further
43
  final_list.extend(pattern.findall(word))
44
+
45
  return final_list
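
As a quick illustration (the sentence below is made up), the tokeniser above keeps trailing punctuation as separate tokens, which is what later lets the sequence matcher tolerate punctuation differences:

import re

# Worked example of the pattern used by split_text_with_punctuation above.
pattern = re.compile(r"([.,?!:;]+|[^.,?!:;\s]+)")
tokens = []
for word in "Dear Prof. Smith, hello!".split():
    tokens.extend(pattern.findall(word))
print(tokens)  # ['Dear', 'Prof', '.', 'Smith', ',', 'hello', '!']
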
46
 
47
+
48
  def extract_indices_from_page_ranges(
49
  results_df: pd.DataFrame,
50
+ start_col: str = "Page2_Start_Page",
51
+ end_col: str = "Page2_End_Page",
52
+ modulo_divisor_number_of_zeros: int = number_of_zeros_to_add_to_index, # Number of zeros added between the page and line numbers when the combined index was created
53
+ converted_index: bool = False, # Has the index been converted to the page_no + 0000 + line number format that needs the modulo divisor to convert back?
54
  ) -> List[int]:
55
  all_indices = set()
56
+ int("1" + modulo_divisor_number_of_zeros * "0")
57
 
58
  for _, row in results_df.iterrows():
59
  start_page = row[start_col]
60
  end_page = row[end_col]
61
  for encoded_page_id in range(start_page, end_page + 1):
62
+ if converted_index is True:
63
+ original_page, original_index = _parse_page_line_id(
64
+ encoded_page_id
65
+ ) # (encoded_page_id % modulo_divisor) - 1
66
  else:
67
  original_index = encoded_page_id
68
 
69
  all_indices.add(original_index)
70
  return sorted(list(all_indices))
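
For context, a minimal sketch of the combined page/line index that this helper decodes; the page and line values below are made up, and it assumes ID_MULTIPLIER = 100000 as defined above:

ID_MULTIPLIER = 100000  # supports up to 99,999 lines per page

def encode_page_line(page: int, line: int) -> int:
    # e.g. page 12, line 7 -> 1200007
    return page * ID_MULTIPLIER + line

def decode_page_line(combined_id: int) -> tuple:
    # Mirrors _parse_page_line_id: integer division recovers the page,
    # the remainder recovers the line number.
    return combined_id // ID_MULTIPLIER, combined_id % ID_MULTIPLIER

assert decode_page_line(encode_page_line(12, 7)) == (12, 7)
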
71
 
72
+
73
  def punctuation_at_word_text_end(word_level_df_orig: pd.DataFrame) -> bool:
74
  """
75
+ Check the first 1000 rows of word_level_df_orig to see if any of the strings
76
  in 'word_text' end with a full stop '.', exclamation mark '!', or question mark '?',
77
  ignoring strings that consist of these punctuation characters alone.
78
+
79
  Args:
80
  word_level_df_orig (pd.DataFrame): DataFrame containing word-level OCR data with 'word_text' column
81
+
82
  Returns:
83
  bool: True if any strings end with punctuation marks, False otherwise
84
  """
85
  # Get the first 1000 rows or all rows if less than 1000
86
  sample_df = word_level_df_orig.head(1000)
87
+
88
  # Check if 'word_text' column exists
89
+ if "word_text" not in sample_df.columns:
90
  return False
91
+
92
  # Define punctuation marks to check for
93
+ punctuation_marks = [".", "!", "?"]
94
+
95
  # Check each word_text string
96
+ for word_text in sample_df["word_text"]:
97
  if pd.isna(word_text) or not isinstance(word_text, str):
98
  continue
99
+
100
  # Skip strings that contain only punctuation marks
101
  if word_text.strip() in punctuation_marks:
102
  continue
103
+
104
  # Check if the string ends with any of the punctuation marks
105
  if any(word_text.rstrip().endswith(punct) for punct in punctuation_marks):
106
  return True
107
+
108
  return False
109
 
110
+
111
  def run_full_search_and_analysis(
112
  search_query_text: str,
113
  word_level_df_orig: pd.DataFrame,
114
  similarity_threshold: float = 1,
115
+ combine_pages: bool = False,
116
  min_word_count: int = 1,
117
  min_consecutive_pages: int = 1,
118
  greedy_match: bool = True,
119
  remake_index: bool = False,
120
+ progress=gr.Progress(track_tqdm=True),
121
  ):
122
  """
123
  This function orchestrates the entire pipeline for finding duplicate pages based on a user's search query. It takes in the search query text, the original word-level OCR data, and various parameters to control the analysis. The function then:
 
131
  - search_query_text (str): The text entered by the user to search for in the OCR data.
132
  - word_level_df_orig (pd.DataFrame): The original DataFrame containing word-level OCR data.
133
  - similarity_threshold (float, optional): The minimum similarity score required for two pages to be considered duplicates. Defaults to 1.
134
+ - combine_pages (bool, optional): A flag indicating whether to combine text from the same page number within a file. Defaults to False.
135
  - min_word_count (int, optional): The minimum number of words required for a page to be considered in the analysis. Defaults to 1.
136
  - min_consecutive_pages (int, optional): The minimum number of consecutive pages required to be considered a match. Defaults to 1.
137
  - greedy_match (bool, optional): A flag indicating whether to use a greedy strategy for matching consecutive pages. Defaults to True.
 
144
  if len(search_query_text) > 100:
145
  raise Warning("Please use a search query with at less than 100 characters.")
146
 
147
+ if punctuation_at_word_text_end(word_level_df_orig) is True:
148
+ do_punctuation_split = False
149
+ else:
150
+ do_punctuation_split = True
151
 
152
  # Step 1: Process the user's search query string
153
+ search_query_data, query_word_length = create_dataframe_from_string(
154
+ search_query_text,
155
+ file_name="user_search_query",
156
+ split_words=True,
157
+ split_punctuation=do_punctuation_split,
158
+ )
159
  if not search_query_data:
160
  # Handle case where user submits an empty search string
161
+ raise Warning("Could not convert search string to required format")
162
 
163
  if query_word_length > 25:
164
  # Handle case where the search query is longer than the allowed number of words
165
+ raise Warning("Please use a query with fewer than 25 words")
166
 
167
  # Overwrite min_consecutive_pages with the search string length
168
  min_consecutive_pages = query_word_length
169
+
170
  # Create word index from reference table
171
  word_level_df_orig["index"] = word_level_df_orig.index
172
+ word_level_df = word_level_df_orig.copy()
173
 
174
  # Step 2: Process the main word-level OCR DataFrame
175
  word_level_data = convert_word_level_df(word_level_df, file_name="source_document")
 
178
  all_data_to_process = search_query_data + word_level_data
179
  if not all_data_to_process:
180
  raise gr.Error("No data to process. Please check your inputs.")
181
+
182
  # Step 4: Run the combination logic
183
  combined_df, _, full_out_ocr_df = combine_ocr_dataframes(
184
  input_data=all_data_to_process,
185
  combine_pages=combine_pages,
186
+ output_folder=None, # No need to save this intermediate file
187
+ remake_index=remake_index,
188
  )
189
 
190
  # Step 5: Run the final similarity analysis on the combined data
 
199
  do_text_clean=False,
200
  file1_name="user_search_query",
201
  file2_name="source_document",
202
+ progress=progress,
203
  )
204
 
205
  print("Finished text search")
206
 
207
  # Map the results back to the reference data file
208
+ if remake_index is True:
209
+ results_df_index_list = extract_indices_from_page_ranges(
210
+ results_df, converted_index=True
211
+ )
212
  else:
213
+ results_df_index_list = extract_indices_from_page_ranges(
214
+ results_df, converted_index=False
215
+ )
216
 
217
+ word_level_df_out = word_level_df_orig.loc[
218
+ word_level_df_orig["index"].isin(results_df_index_list)
219
+ ]
220
 
221
  return word_level_df_out, duplicate_files, full_data
222
 
223
+
224
+ def create_all_data_to_process(
225
+ converted_data: pd.DataFrame, other_data_list: List[Tuple]
226
+ ):
227
  all_data_to_process = converted_data + other_data_list
228
  return all_data_to_process
229
 
230
+
231
  def convert_word_level_df(
232
+ word_level_df: pd.DataFrame, file_name: str = "converted_dataframe"
 
233
  ) -> List[Tuple[str, pd.DataFrame]]:
234
  """
235
  Converts a word-level OCR DataFrame to the format for
 
252
  DataFrame will have 'page' and 'text' columns.
253
  """
254
  # --- 1. Validate Input ---
255
+ required_columns = ["page", "line", "word_text"]
256
  if not all(col in word_level_df.columns for col in required_columns):
257
+ raise ValueError(
258
+ f"Input DataFrame must contain all of the following columns: {required_columns}"
259
+ )
260
 
261
  df = word_level_df.copy()
262
 
263
  # --- 2. Process the DataFrame ---
264
  # Ensure word_text is a string to allow for joining
265
+ df["word_text"] = df["word_text"].astype(str)
266
 
267
  # Group by page and line number, then join the words with a space (not needed for word level search)
268
  # The result is a Series with a MultiIndex (page, line)
269
+ # line_text_series = df.groupby(['page', 'line'])['word_text'].apply(' '.join)
270
 
271
  # Convert the Series back to a DataFrame and reset the index
272
+ # line_level_df = line_text_series.reset_index()
273
 
274
  # Rename the aggregated column from 'word_text' to the required 'text'
275
+ df = df.rename(columns={"word_text": "text"})
276
 
277
  # --- 3. Finalise the structure ---
278
  # We now have a DataFrame with columns [page, line, text].
279
+ final_df = df[["page", "text"]]
280
 
281
  # --- 4. Package for output ---
282
  # Return in the required List[Tuple[str, DataFrame]] format
283
  return [(file_name, final_df)]
284
 
285
+
286
  def create_dataframe_from_string(
287
  text_string: str,
288
  file_name: str = "user_search_query",
 
322
 
323
  if split_words:
324
  # --- Split string into words, one per row, based on similar punctuation split technique used to create ocr_results_with_words objects ---
325
+ if split_punctuation is True:
326
  words = split_text_with_punctuation(text_string)
327
  else:
328
  words = text_string.split()
329
+
330
+ # words = text_string.split()
331
  len_words = len(words)
332
+ data = {
333
+ "page": [page_number]
334
+ * len_words, # Assign the same page number to every word
335
+ "text": words, # The list of words becomes the text column
336
  }
337
  else:
338
  # --- Entire string in one row ---
339
  len_words = 1
340
+ data = {"page": [page_number], "text": [text_string]}
 
 
 
341
 
342
  # Create the DataFrame from the prepared data
343
  df = pd.DataFrame(data)
 
347
  # Return it in the required format: a list containing one (name, df) tuple
348
  return [(file_name, df)], len_words
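
A small illustration of the output format produced here (the query words are invented): one row per token, all assigned the same page number, wrapped in the (file_name, DataFrame) tuple that the combination step expects.

import pandas as pd

words = ["flat", "deposit", "refund"]  # hypothetical search query tokens
df = pd.DataFrame({"page": [1] * len(words), "text": words})
search_query_data = [("user_search_query", df)]
print(search_query_data[0][0], len(df))  # user_search_query 3
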
349
 
350
+
351
  def combine_ocr_dataframes(
352
  input_data: List[Tuple[str, pd.DataFrame]],
353
  combine_pages: bool = True,
354
  output_folder: str = OUTPUT_FOLDER,
355
  output_filename: str = "combined_ocr_output.csv",
356
  number_of_added_zeros: int = number_of_zeros_to_add_to_index,
357
+ remake_index: bool = True,
358
  ) -> Tuple[pd.DataFrame, List[str]]:
359
  """
360
  Combines text from multiple pandas DataFrames containing page and text columns.
 
387
  df = df_initial.copy() # Work on a copy to avoid side effects
388
 
389
  # --- Validation ---
390
+ if "page" not in df.columns or "text" not in df.columns:
391
+ print(
392
+ f"Warning: Skipping data for '{file_identifier}' - missing required columns 'page' and 'text'."
393
+ )
394
  continue
395
 
396
  # --- Processing ---
397
+ df["text"] = df["text"].fillna("").astype(str)
398
 
399
  if combine_pages:
400
  # Group by page and concatenate text into a single string
401
+ processed_df = df.groupby("page")["text"].apply(" ".join).reset_index()
402
  else:
403
+ if remake_index is True:
404
  # # Create a unique, sortable page ID for each line without combining
405
  # df['line_number_by_page'] = df.groupby('page').cumcount() + 1
406
  # df['original_page'] = df['page']
 
408
  # df['page'] = (
409
  # df['page'].astype(str).str.zfill(number_of_added_zeros) +
410
  # df['line_number_by_page'].astype(str).str.zfill(number_of_added_zeros)
411
+ # ).astype(int)
412
 
413
  # Define the multiplier based on the max expected lines per page.
414
  # If you expect up to 99,999 lines, use 100,000.
415
 
416
+ df["line_number_by_page"] = df.groupby("page").cumcount() + 1
417
+ df["original_page"] = df["page"]
418
 
419
  # Create the new combined ID using arithmetic
420
+ df["page"] = (df["original_page"] * ID_MULTIPLIER) + df[
421
+ "line_number_by_page"
422
+ ]
423
+
424
  else:
425
+ if "index" not in df.columns:
426
+ df["index"] = df.index
427
+ df["page"] = df["index"]
428
+
429
  processed_df = df
430
 
431
  # Add the file identifier column
432
+ processed_df["file"] = file_identifier
433
  all_data.append(processed_df)
434
 
435
  if not all_data:
436
+ raise ValueError(
437
+ "No valid DataFrames were processed. Ensure input data is not empty and DataFrames have 'page' and 'text' columns."
438
+ )
439
 
440
  # --- Final Combination ---
441
  combined_df = pd.concat(all_data, ignore_index=True)
442
 
443
  # Reorder columns to a standard format, dropping intermediate columns
444
+ final_columns = ["file", "page", "text"]
445
+ if "original_page" in combined_df.columns:
446
+ final_columns.append("original_page") # Keep for context if created
447
+
448
  # Ensure all final columns exist before trying to select them
449
+ existing_final_columns = [
450
+ col for col in final_columns if col in combined_df.columns
451
+ ]
452
 
453
  full_out_ocr_df = combined_df
454
  combined_df = combined_df.copy()[existing_final_columns]
 
464
 
465
  return combined_df, output_files, full_out_ocr_df
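
As a rough sketch of the combine_pages=True branch above (toy data only), all text rows that share a page number are concatenated into a single row before similarity is computed:

import pandas as pd

df = pd.DataFrame({"page": [1, 1, 2], "text": ["Dear", "Professor", "Thanks"]})
per_page = df.groupby("page")["text"].apply(" ".join).reset_index()
print(per_page["text"].tolist())  # ['Dear Professor', 'Thanks']
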
466
 
467
+
468
  def combine_ocr_output_text(
469
  input_files: Union[str, List[str]],
470
  combine_pages: bool = True,
471
  remake_index: bool = True,
472
+ output_folder: str = OUTPUT_FOLDER,
473
  ) -> Tuple[pd.DataFrame, List[str]]:
474
  """
475
  Reads multiple OCR CSV files, combines them, and saves the result.
 
510
  input_data=data_to_process,
511
  combine_pages=combine_pages,
512
  output_folder=output_folder,
513
+ output_filename="combined_ocr_from_files.csv", # Specific name for this path
514
+ remake_index=remake_index,
515
  )
516
 
517
+
518
+ def clean_and_stem_text_series(df: pd.DataFrame, column: str):
519
+ """
520
  Clean and stem text columns in a data frame
521
+ """
522
+
523
  def _clean_text(raw_text):
524
  # Remove HTML tags
525
+ clean = re.sub(r"<.*?>", "", raw_text)
526
+ clean = " ".join(clean.split())
527
  # Join the cleaned words back into a string
528
  return clean
529
 
 
531
  def _apply_lemmatization(text):
532
  doc = nlp(text)
533
  # Keep only alphabetic tokens and remove stopwords
534
+ lemmatized_words = [
535
+ token.lemma_ for token in doc if token.is_alpha and not token.is_stop
536
+ ]
537
+ return " ".join(lemmatized_words)
538
+
539
+ df["text_clean"] = df[column].apply(_clean_text)
540
+
541
+ df["text_clean"] = df["text_clean"].apply(_apply_lemmatization)
542
 
 
 
543
  return df
544
 
545
+
546
+ def map_metadata_single_page(
547
+ similarity_df: pd.DataFrame,
548
+ metadata_source_df: pd.DataFrame,
549
+ preview_length: int = 200,
550
+ ):
551
  """Helper to map metadata for single page results."""
552
+ metadata_df = metadata_source_df[["file", "page", "text"]]
553
+ results_df = similarity_df.merge(
554
+ metadata_df, left_on="Page1_Index", right_index=True
555
+ ).rename(columns={"file": "Page1_File", "page": "Page1_Page", "text": "Page1_Text"})
556
+ results_df = results_df.merge(
557
+ metadata_df, left_on="Page2_Index", right_index=True, suffixes=("_1", "_2")
558
+ ).rename(columns={"file": "Page2_File", "page": "Page2_Page", "text": "Page2_Text"})
559
  results_df["Similarity_Score"] = results_df["Similarity_Score"].round(3)
560
+ final_df = results_df[
561
+ [
562
+ "Page1_File",
563
+ "Page1_Page",
564
+ "Page2_File",
565
+ "Page2_Page",
566
+ "Similarity_Score",
567
+ "Page1_Text",
568
+ "Page2_Text",
569
+ ]
570
+ ]
571
+ final_df = final_df.sort_values(
572
+ ["Page1_File", "Page1_Page", "Page2_File", "Page2_Page"]
573
+ )
574
+ final_df["Page1_Text"] = final_df["Page1_Text"].str[:preview_length]
575
+ final_df["Page2_Text"] = final_df["Page2_Text"].str[:preview_length]
576
  return final_df
577
 
578
+
579
+ def map_metadata_subdocument(
580
+ subdocument_df: pd.DataFrame,
581
+ metadata_source_df: pd.DataFrame,
582
+ preview_length: int = 200,
583
+ ):
584
  """Helper to map metadata for subdocument results."""
585
+ metadata_df = metadata_source_df[["file", "page", "text"]]
586
+
587
+ subdocument_df = subdocument_df.merge(
588
+ metadata_df, left_on="Page1_Start_Index", right_index=True
589
+ ).rename(
590
+ columns={"file": "Page1_File", "page": "Page1_Start_Page", "text": "Page1_Text"}
591
+ )
592
+ subdocument_df = subdocument_df.merge(
593
+ metadata_df[["page"]], left_on="Page1_End_Index", right_index=True
594
+ ).rename(columns={"page": "Page1_End_Page"})
595
+ subdocument_df = subdocument_df.merge(
596
+ metadata_df, left_on="Page2_Start_Index", right_index=True
597
+ ).rename(
598
+ columns={"file": "Page2_File", "page": "Page2_Start_Page", "text": "Page2_Text"}
599
+ )
600
+ subdocument_df = subdocument_df.merge(
601
+ metadata_df[["page"]], left_on="Page2_End_Index", right_index=True
602
+ ).rename(columns={"page": "Page2_End_Page"})
603
+
604
+ cols = [
605
+ "Page1_File",
606
+ "Page1_Start_Page",
607
+ "Page1_End_Page",
608
+ "Page2_File",
609
+ "Page2_Start_Page",
610
+ "Page2_End_Page",
611
+ "Match_Length",
612
+ "Page1_Text",
613
+ "Page2_Text",
614
+ ]
615
+
616
  # Add Avg_Similarity if it exists (it won't for greedy match unless we add it)
617
+ if "Avg_Similarity" in subdocument_df.columns:
618
+ subdocument_df["Avg_Similarity"] = subdocument_df["Avg_Similarity"].round(3)
619
+ cols.insert(7, "Avg_Similarity")
620
 
621
  final_df = subdocument_df[cols]
622
+ final_df = final_df.sort_values(
623
+ ["Page1_File", "Page1_Start_Page", "Page2_File", "Page2_Start_Page"]
624
+ )
625
+ final_df["Page1_Text"] = final_df["Page1_Text"].str[:preview_length]
626
+ final_df["Page2_Text"] = final_df["Page2_Text"].str[:preview_length]
627
 
628
  return final_df
629
 
630
+
631
+ def save_results_and_redaction_lists(
632
+ final_df: pd.DataFrame, output_folder: str, combine_pages: bool = True
633
+ ) -> list:
634
  """
635
  Saves the main results DataFrame and generates per-file redaction lists.
636
  This function is extracted to be reusable.
 
652
  return []
653
 
654
  # 1. Save the main results DataFrame
655
+ similarity_file_output_path = output_folder_path / "page_similarity_results.csv"
656
  final_df.to_csv(similarity_file_output_path, index=False, encoding="utf-8-sig")
657
 
658
  output_paths.append(str(similarity_file_output_path))
659
+ # print(f"Main results saved to {similarity_file_output_path}")
660
 
661
  # 2. Save per-file redaction lists
662
  # Use 'Page2_File' as the source of duplicate content
663
+ if combine_pages is True:
664
+ grouping_col = "Page2_File"
665
  if grouping_col not in final_df.columns:
666
+ print(
667
+ "Warning: 'Page2_File' column not found. Cannot generate redaction lists."
668
+ )
669
  return output_paths
670
 
671
  for redact_file, group in final_df.groupby(grouping_col):
672
  output_file_name_stem = Path(redact_file).stem
673
+ output_file_path = (
674
+ output_folder_path / f"{output_file_name_stem}_pages_to_redact.csv"
675
+ )
676
+
677
  all_pages_to_redact = set()
678
+ is_subdocument_match = "Page2_Start_Page" in group.columns
679
 
680
  if is_subdocument_match:
681
  for _, row in group.iterrows():
682
+ pages_in_range = range(
683
+ int(row["Page2_Start_Page"]), int(row["Page2_End_Page"]) + 1
684
+ )
685
  all_pages_to_redact.update(pages_in_range)
686
  else:
687
+ pages = group["Page2_Page"].unique()
688
  all_pages_to_redact.update(pages)
689
+
690
  if all_pages_to_redact:
691
+ redaction_df = pd.DataFrame(
692
+ sorted(list(all_pages_to_redact)), columns=["Page_to_Redact"]
693
+ )
694
  redaction_df.to_csv(output_file_path, header=False, index=False)
695
 
696
  output_paths.append(str(output_file_path))
697
  print(f"Redaction list for {redact_file} saved to {output_file_path}")
698
+
699
  return output_paths
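
A minimal sketch of how the per-file redaction list is built from sub-document matches (the page ranges here are invented): each Page2 start/end range is expanded, de-duplicated and sorted.

import pandas as pd

group = pd.DataFrame({"Page2_Start_Page": [3, 10], "Page2_End_Page": [5, 11]})

all_pages_to_redact = set()
for _, row in group.iterrows():
    all_pages_to_redact.update(
        range(int(row["Page2_Start_Page"]), int(row["Page2_End_Page"]) + 1)
    )

redaction_df = pd.DataFrame(sorted(all_pages_to_redact), columns=["Page_to_Redact"])
print(redaction_df["Page_to_Redact"].tolist())  # [3, 4, 5, 10, 11]
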
700
 
701
+
702
  def _sequences_match(query_seq: List[str], ref_seq: List[str]) -> bool:
703
  """
704
  Helper function to compare two sequences of tokens with punctuation flexibility.
 
722
  # - Its last character must be in our punctuation set
723
  # - The token without its last character must match the query token
724
  if (
725
+ len(ref_token) > 1
726
+ and ref_token[-1] in PUNCTUATION_TO_STRIP
727
+ and ref_token[:-1] == query_token
728
  ):
729
  continue
730
 
 
734
  # If the loop completes, every token has matched.
735
  return True
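
A simplified restatement of the per-token rule applied above (example tokens are made up): a reference token matches if it equals the query token, or is the query token plus a single trailing punctuation mark.

PUNCTUATION_TO_STRIP = {".", ",", "?", "!", ":", ";"}

def tokens_match(query_token: str, ref_token: str) -> bool:
    if query_token == ref_token:
        return True
    return (
        len(ref_token) > 1
        and ref_token[-1] in PUNCTUATION_TO_STRIP
        and ref_token[:-1] == query_token
    )

assert tokens_match("smith", "smith,")
assert not tokens_match("smith", "smithy")
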
736
 
737
+
738
  def find_consecutive_sequence_matches(
739
+ df_filtered: pd.DataFrame, search_file_name: str, reference_file_name: str
 
 
740
  ) -> pd.DataFrame:
741
  """
742
  Finds all occurrences of a consecutive sequence of tokens from a search file
 
753
  A DataFrame with two columns ('Page1_Index', 'Page2_Index') mapping the
754
  consecutive match, or an empty DataFrame if no match is found.
755
  """
756
+ # print(f"Starting sequence search for '{search_file_name}' in '{reference_file_name}'...")
757
 
758
  # Step 1: Isolate the data for each file
759
+ search_df = df_filtered[df_filtered["file"] == search_file_name]
760
+ reference_df = df_filtered[df_filtered["file"] == reference_file_name]
761
 
762
  if search_df.empty or reference_df.empty:
763
  print("Error: One or both files not found or are empty.")
764
+ return pd.DataFrame(columns=["Page1_Index", "Page2_Index"])
765
 
766
  # Step 2: Convert the token data into lists for easy comparison.
767
  # We need both the text tokens and their original global indices.
768
+ query_tokens = search_df["text_clean"].tolist()
769
  query_indices = search_df.index.tolist()
770
+
771
+ reference_tokens = reference_df["text_clean"].tolist()
772
  reference_indices = reference_df.index.tolist()
773
 
774
  query_len = len(query_tokens)
 
783
 
784
  # Step 4: If the window matches the query with or without punctuation on end
785
  if _sequences_match(query_tokens, window):
786
+ # print(f"Found a consecutive match starting at reference index: {reference_indices[i]}")
787
+
788
  # Get the global indices for this entire matching block
789
  matching_reference_indices = reference_indices[i : i + query_len]
790
+
791
  # Create the mapping between query indices and the found reference indices
792
  for j in range(query_len):
793
  all_found_matches.append(
794
  (query_indices[j], matching_reference_indices[j], 1)
795
  )
796
+
797
  # If you only want the *first* match, you can uncomment the next line:
798
+ # break
799
 
800
  if not all_found_matches:
801
  print("No matches found")
802
  gr.Info("No matches found")
803
+ return pd.DataFrame(columns=["Page1_Index", "Page2_Index", "Similarity_Score"])
804
 
805
  # Step 5: Create the final DataFrame in the desired format
806
+ result_df = pd.DataFrame(
807
+ all_found_matches, columns=["Page1_Index", "Page2_Index", "Similarity_Score"]
808
+ )
809
  return result_df
810
 
811
+
812
  def identify_similar_text_sequences(
813
  df_combined: pd.DataFrame,
814
  similarity_threshold: float = 1,
 
817
  greedy_match: bool = True,
818
  combine_pages: bool = False,
819
  inter_file_only: bool = False,
820
+ do_text_clean: bool = True,
821
+ file1_name: str = "",
822
+ file2_name: str = "",
823
  output_folder: str = OUTPUT_FOLDER,
824
+ progress=Progress(track_tqdm=True),
825
  ) -> Tuple[pd.DataFrame, List[str], pd.DataFrame]:
826
  """
827
  Identifies similar pages. Uses a highly optimized path for inter_file_only=True.
 
829
  progress(0.1, desc="Processing and filtering text")
830
 
831
  if do_text_clean:
832
+ df = clean_and_stem_text_series(
833
+ df_combined, "text"
834
+ ) # Will produce the column 'text_clean'
835
  else:
836
  df = df_combined.copy()
837
+ df["text_clean"] = df[
838
+ "text"
839
+ ].str.lower() # .str.replace(r'[^\w\s]', '', regex=True)
840
 
841
+ df["word_count"] = df["text_clean"].str.split().str.len().fillna(0)
842
+ # df['word_count'] = pd.to_numeric(df['word_count'], errors='coerce').fillna(0).astype('int64')
843
 
844
  # ensure min_word_count is an int (e.g., from Gradio/text input)
845
  try:
 
848
  min_word_count = 0 # or raise/log, depending on your preference
849
 
850
  original_row_count = len(df)
851
+ df_filtered = df[df["word_count"] >= min_word_count].copy()
852
  df_filtered.reset_index(drop=True, inplace=True)
853
+
854
+ print(
855
+ f"Filtered out {original_row_count - len(df_filtered)} pages with fewer than {min_word_count} words."
856
+ )
857
  if len(df_filtered) < 2:
858
  return pd.DataFrame(), [], df_combined
859
 
 
 
860
  # Similarity calculated differently if comparing between files only (inter_file_only==True), or within the same file
861
  if inter_file_only:
862
 
863
  progress(0.2, desc="Finding direct text matches...")
864
+
865
+ # base_similarity_df = _debug_similarity_between_two_files(df_filtered, vectorizer, similarity_threshold, file1_name, file2_name)
866
+ base_similarity_df = find_consecutive_sequence_matches(
867
+ df_filtered, file1_name, file2_name
868
+ )
869
  if base_similarity_df.empty:
870
+ return pd.DataFrame(), [], df_combined
871
+
872
  else:
873
  # Use the original, simpler path for all-to-all comparisons (including intra-file).
874
  vectorizer = TfidfVectorizer()
875
  print("Standard Path: Calculating all-to-all similarity.")
876
  progress(0.2, desc="Vectorizing text...")
877
+ tfidf_matrix = vectorizer.fit_transform(df_filtered["text_clean"])
878
 
879
  progress(0.3, desc="Calculating similarity matrix...")
880
  similarity_matrix = cosine_similarity(tfidf_matrix, dense_output=False)
881
  coo_matrix = similarity_matrix.tocoo()
882
 
883
  similar_pages = [
884
+ (r, c, v)
885
+ for r, c, v in zip(coo_matrix.row, coo_matrix.col, coo_matrix.data)
886
  if r < c and v >= similarity_threshold
887
  ]
888
 
889
  if not similar_pages:
890
  return pd.DataFrame(), [], df_combined
891
+
892
+ base_similarity_df = pd.DataFrame(
893
+ similar_pages, columns=["Page1_Index", "Page2_Index", "Similarity_Score"]
894
+ )
895
 
896
  progress(0.7, desc="Aggregating results based on matching strategy")
897
 
898
  if greedy_match or min_consecutive_pages > 1:
899
+ # print("Finding all consecutive page matches of minimum length:", min_consecutive_pages)
900
+
901
  # Sort the dataframe to ensure consecutive pages are adjacent
902
+ similarity_df = (
903
+ base_similarity_df # .sort_values(['Page1_Index', 'Page2_Index']).copy()
904
+ )
905
 
906
  # A new sequence starts if the difference from the previous row is not (1, 1)
907
  # is_consecutive will be True if a row continues the sequence, False if it's a new one.
908
+ is_consecutive = (similarity_df["Page1_Index"].diff() == 1) & (
909
+ similarity_df["Page2_Index"].diff() == 1
910
+ )
911
 
912
  # Use cumsum() on the inverted boolean series to create a unique ID for each block.
913
  # Every time a 'False' appears (a new block starts), the sum increases.
 
918
 
919
  # Aggregate each group to get the start, end, and length of the match
920
  agg_results = grouped.agg(
921
+ Page1_Start_Index=("Page1_Index", "first"),
922
+ Page2_Start_Index=("Page2_Index", "first"),
923
+ Page1_End_Index=("Page1_Index", "last"),
924
+ Page2_End_Index=("Page2_Index", "last"),
925
+ Match_Length=("Page1_Index", "size"),
926
+ Avg_Similarity=("Similarity_Score", "mean"),
927
  ).reset_index(drop=True)
928
 
929
  # If greedy_match=True, we keep all matches. If min_consecutive_pages > 1, we filter.
930
  if greedy_match and min_consecutive_pages <= 1:
931
  subdocument_df = agg_results
932
  else:
933
+ # This handles the case for min_consecutive_pages > 1
934
+ subdocument_df = agg_results[
935
+ agg_results["Match_Length"] >= min_consecutive_pages
936
+ ].copy()
937
 
938
  if subdocument_df.empty:
939
  gr.Info("No matches found")
940
  return pd.DataFrame(), [], df_combined
941
+
942
  final_df = map_metadata_subdocument(subdocument_df, df_filtered)
943
  else:
944
+ print("Finding single page matches, not greedy (min_consecutive_pages=1)")
945
  # This part of your code would handle the non-sequential case
946
  final_df = map_metadata_single_page(base_similarity_df, df_filtered)
947
+ # subdocument_df = final_df # To align variable names for saving
948
 
949
  if final_df.empty:
950
  gr.Info("No matches found")
951
  return pd.DataFrame(), [], df_combined
952
 
953
  progress(0.9, desc="Saving output files")
954
+
955
+ output_paths = save_results_and_redaction_lists(
956
+ final_df, output_folder, combine_pages
957
+ )
958
 
959
  gr.Info(f"Found {final_df.shape[0]} match(es)")
960
  print(f"Found {final_df.shape[0]} match(es)")
961
 
962
  return final_df, output_paths, df_combined
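
For the aggregation step above, a minimal pandas sketch with invented index pairs shows how diff() and cumsum() turn runs of consecutive (Page1, Page2) pairs into sub-document blocks:

import pandas as pd

pairs = pd.DataFrame(
    {"Page1_Index": [3, 4, 5, 9], "Page2_Index": [10, 11, 12, 30],
     "Similarity_Score": [1.0, 0.95, 0.9, 1.0]}
)
# True while both indices advance by exactly 1, i.e. the run continues
is_consecutive = (pairs["Page1_Index"].diff() == 1) & (pairs["Page2_Index"].diff() == 1)
block_id = (~is_consecutive).cumsum()  # increments whenever a new run starts

blocks = pairs.groupby(block_id).agg(
    Page1_Start_Index=("Page1_Index", "first"),
    Page1_End_Index=("Page1_Index", "last"),
    Page2_Start_Index=("Page2_Index", "first"),
    Page2_End_Index=("Page2_Index", "last"),
    Match_Length=("Page1_Index", "size"),
    Avg_Similarity=("Similarity_Score", "mean"),
).reset_index(drop=True)
# Rows (3,10)-(5,12) collapse into one block of length 3; (9,30) stands alone.
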
963
+
964
+
965
+ def handle_selection_and_preview(
966
+ evt: gr.SelectData, results_df: pd.DataFrame, full_duplicate_data_by_file: dict
967
+ ):
968
  """
969
  This single function handles a user selecting a row. It:
970
  1. Determines the selected row index.
 
973
  """
974
  # If the user deselects, the event might be None.
975
  if not evt:
976
+ return None, None, None # Clear state and both preview panes
977
 
978
  # 1. Get the selected index
979
  selected_index = evt.index[0]
980
 
981
  # 2. Get the preview data
982
+ page1_data, page2_data = show_page_previews(
983
+ full_duplicate_data_by_file, results_df, evt
984
+ )
985
 
986
  # 3. Return all three outputs in the correct order
987
  return selected_index, page1_data, page2_data
988
 
989
+
990
+ def exclude_match(
991
+ results_df: pd.DataFrame, selected_index: int, output_folder=OUTPUT_FOLDER
992
+ ):
993
  """
994
  Removes a selected row from the results DataFrame, regenerates output files,
995
  and clears the text preview panes.
 
998
  gr.Warning("No match selected. Please click on a row in the table first.")
999
  # Return the original dataframe and update=False for the files
1000
  return results_df, gr.update(), None, None
1001
+
1002
  if results_df.empty:
1003
  gr.Warning("No duplicate page results found, nothing to exclude.")
1004
  return results_df, gr.update(), None, None
1005
 
1006
  # Drop the selected row
1007
  updated_df = results_df.drop(selected_index).reset_index(drop=True)
1008
+
1009
  # Recalculate all output files using the helper function
1010
  new_output_paths = save_results_and_redaction_lists(updated_df, output_folder)
1011
+
1012
  gr.Info(f"Match at row {selected_index} excluded. Output files have been updated.")
1013
+
1014
  # Return the updated dataframe, the new file list, and clear the preview panes
1015
  return updated_df, new_output_paths, None, None
1016
 
1017
+
1018
+ def run_duplicate_analysis(
1019
+ files: list[str],
1020
+ threshold: float,
1021
+ min_words: int,
1022
+ min_consecutive: int,
1023
+ greedy_match: bool,
1024
+ combine_pages: bool = True,
1025
+ preview_length: int = 500,
1026
+ output_folder: str = OUTPUT_FOLDER,
1027
+ progress=gr.Progress(track_tqdm=True),
1028
+ ):
1029
  """
1030
  Main wrapper function to orchestrate the duplicate page analysis process.
1031
  It handles file loading, text combination, similarity identification,
 
1043
  progress (gr.Progress, optional): A Gradio progress tracker object to display progress in the UI.
1044
  """
1045
 
1046
+ if not files:
1047
+ raise Warning("Please upload files to analyse.")
1048
 
1049
+ if isinstance(files, str):
1050
+ files = [files]
1051
 
1052
  if len(files) > MAX_SIMULTANEOUS_FILES:
1053
  out_message = f"Number of files to deduplicate is greater than {MAX_SIMULTANEOUS_FILES}. Please submit a smaller number of files."
 
1057
  start_time = time.time()
1058
 
1059
  task_textbox = "deduplicate"
1060
+
1061
  progress(0, desc="Combining input files...")
1062
+ df_combined, _, full_out_ocr_df = combine_ocr_output_text(
1063
+ files, combine_pages=combine_pages
1064
+ )
1065
 
1066
  if df_combined.empty:
1067
  raise Warning("No data found in the uploaded files.")
 
1075
  greedy_match=greedy_match,
1076
  combine_pages=combine_pages,
1077
  output_folder=output_folder,
1078
+ progress=progress,
1079
  )
1080
 
1081
  # Clip text to the first preview_length characters for the preview
1082
+ full_df["text"] = full_df["text"].str[:preview_length]
1083
  # Preprocess full_data (without preview text) for fast access (run once)
1084
  full_data_by_file = {
1085
+ file: df.sort_values("page").set_index("page")
1086
+ for file, df in full_df.drop(["text_clean"], axis=1).groupby("file")
1087
  }
1088
 
1089
  if results_df.empty:
1090
+ gr.Info("No duplicate pages found, no results returned.")
1091
 
1092
  end_time = time.time()
1093
  processing_time = round(end_time - start_time, 2)
1094
+
1095
  return results_df, output_paths, full_data_by_file, processing_time, task_textbox
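
The per-file preview index built above can be pictured with toy data: one page-indexed frame per file, so later row selections can slice page ranges with .loc directly.

import pandas as pd

full_df = pd.DataFrame(
    {"file": ["a.csv", "a.csv", "b.csv"], "page": [2, 1, 1],
     "text": ["second page", "first page", "other file"]}
)
full_data_by_file = {
    file: df.sort_values("page").set_index("page")
    for file, df in full_df.groupby("file")
}
print(full_data_by_file["a.csv"].loc[1:2, "text"].tolist())
# ['first page', 'second page']
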
1096
 
1097
+
1098
+ def show_page_previews(
1099
+ full_data_by_file: dict,
1100
+ results_df: pd.DataFrame,
1101
+ evt: gr.SelectData,
1102
+ preview_length: int = 500,
1103
+ ):
1104
  """
1105
  Optimized version using pre-partitioned and indexed full_data.
1106
  Triggered when a user selects a row in the results DataFrame.
 
1110
 
1111
  selected_row = results_df.iloc[evt.index[0], :]
1112
 
1113
+ is_subdocument_match = "Page1_Start_Page" in selected_row
1114
 
1115
  if is_subdocument_match:
1116
+ file1, start1, end1 = (
1117
+ selected_row["Page1_File"],
1118
+ selected_row["Page1_Start_Page"],
1119
+ selected_row["Page1_End_Page"],
1120
+ )
1121
+ file2, start2, end2 = (
1122
+ selected_row["Page2_File"],
1123
+ selected_row["Page2_Start_Page"],
1124
+ selected_row["Page2_End_Page"],
1125
+ )
1126
+
1127
+ page1_data = full_data_by_file[file1].loc[start1:end1, ["text"]].reset_index()
1128
+ page2_data = full_data_by_file[file2].loc[start2:end2, ["text"]].reset_index()
1129
 
1130
  else:
1131
+ file1, page1 = selected_row["Page1_File"], selected_row["Page1_Page"]
1132
+ file2, page2 = selected_row["Page2_File"], selected_row["Page2_Page"]
1133
+
1134
+ page1_data = full_data_by_file[file1].loc[[page1], ["text"]].reset_index()
1135
+ page2_data = full_data_by_file[file2].loc[[page2], ["text"]].reset_index()
1136
 
1137
+ page1_data["text"] = page1_data["text"].str[:preview_length]
1138
+ page2_data["text"] = page2_data["text"].str[:preview_length]
1139
 
1140
+ return page1_data[["page", "text"]], page2_data[["page", "text"]]
 
1141
 
 
1142
 
1143
  def get_page_image_info(page_num: int, page_sizes: List[Dict]) -> Optional[Dict]:
1144
  """
 
1146
  """
1147
  return next((size for size in page_sizes if size["page"] == page_num), None)
1148
 
1149
+
1150
  def add_new_annotations_to_existing_page_annotations(
1151
+ all_annotations: List[Dict], image_path: str, new_annotation_boxes: List[Dict]
 
 
1152
  ) -> Tuple[List[Dict], Dict]:
1153
  """
1154
  Adds a list of new annotation boxes to the annotations for a specific page.
 
1168
  """
1169
  # Find the annotation group for the current page/image
1170
  current_page_group = next(
1171
+ (
1172
+ annot_group
1173
+ for annot_group in all_annotations
1174
+ if annot_group["image"] == image_path
1175
+ ),
1176
+ None,
1177
  )
1178
 
1179
  if current_page_group:
 
1181
  current_page_group["boxes"].extend(new_annotation_boxes)
1182
  else:
1183
  # This is the first set of annotations for this page, create a new group
1184
+ new_group = {"image": image_path, "boxes": new_annotation_boxes}
 
 
 
1185
  all_annotations.append(new_group)
1186
 
1187
  # This object represents all annotations that were just added for this page
1188
+ newly_added_annotation_group = {"image": image_path, "boxes": new_annotation_boxes}
 
 
 
1189
 
1190
  return all_annotations, newly_added_annotation_group
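
The whole-page redaction logic below relies on image paths ending in "_<zero-based page index>.png"; a small sketch of that conversion (the file name is invented):

import re
from typing import Optional

def page_from_image_path(image_path: str) -> Optional[int]:
    match = re.search(r"_(\d+)\.png$", image_path)
    return int(match.group(1)) + 1 if match else None

print(page_from_image_path("example_doc_0.png"))  # 1
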
1191
 
1192
+
1193
+ def apply_whole_page_redactions_from_list(
1194
+ duplicate_page_numbers_df: pd.DataFrame,
1195
+ doc_file_name_with_extension_textbox: str,
1196
+ review_file_state: pd.DataFrame,
1197
+ duplicate_output_paths: list[str],
1198
+ pymupdf_doc: object,
1199
+ page_sizes: list[dict],
1200
+ all_existing_annotations: list[dict],
1201
+ combine_pages: bool = True,
1202
+ new_annotations_with_bounding_boxes: List[dict] = list(),
1203
+ ):
1204
+ """
1205
  This function applies redactions to whole pages based on a provided list of duplicate page numbers. It supports two modes of operation: combining pages and not combining pages. When combining pages is enabled, it attempts to identify duplicate pages across different files and applies redactions accordingly. If combining pages is disabled, it relies on new annotations with bounding boxes to determine which pages to redact. The function utilises a PyMuPDF document object to manipulate the PDF file, and it also considers the sizes of pages to ensure accurate redaction application.
1206
 
1207
  Args:
 
1214
  all_existing_annotations (list[dict]): A list of all existing annotations in the document.
1215
  combine_pages (bool, optional): A flag indicating whether to combine pages for redaction. Defaults to True.
1216
  new_annotations_with_bounding_boxes (List[dict], optional): A list of new annotations with bounding boxes. Defaults to an empty list.
1217
+ """
1218
  if all_existing_annotations is None:
1219
  all_existing_annotations = []
1220
 
 
1228
  print(f"Warning: {message}")
1229
  raise Warning(message)
1230
 
1231
+ list_whole_pages_to_redact = []
1232
 
1233
+ if combine_pages is True:
1234
  # Get list of pages to redact from either dataframe or file
1235
  if not duplicate_page_numbers_df.empty:
1236
  list_whole_pages_to_redact = duplicate_page_numbers_df.iloc[:, 0].tolist()
1237
  elif duplicate_output_paths:
1238
+ expected_duplicate_pages_to_redact_name = (
1239
+ f"{doc_file_name_with_extension_textbox}"
1240
+ )
1241
  whole_pages_list = pd.DataFrame() # Initialize empty DataFrame
1242
+
1243
  for output_file in duplicate_output_paths:
1244
  # Note: output_file.name might not be available if output_file is just a string path
1245
  # If it's a Path object or similar, .name is fine. Otherwise, parse from string.
1246
+ file_name_from_path = (
1247
+ output_file.split("/")[-1]
1248
+ if isinstance(output_file, str)
1249
+ else output_file.name
1250
+ )
1251
  if expected_duplicate_pages_to_redact_name in file_name_from_path:
1252
+ whole_pages_list = pd.read_csv(
1253
+ output_file, header=None
1254
+ ) # Use output_file directly if it's a path
1255
+ break
1256
  else:
1257
  message = "No relevant list of whole pages to redact found."
1258
  print(message)
1259
  raise Warning(message)
1260
+
1261
  if not whole_pages_list.empty:
1262
  list_whole_pages_to_redact = whole_pages_list.iloc[:, 0].tolist()
1263
+
1264
  list_whole_pages_to_redact = list(set(list_whole_pages_to_redact))
1265
 
1266
  else:
 
1268
  message = "Can't find any new annotations to add"
1269
  print(message)
1270
  raise Warning(message)
1271
+
1272
  list_whole_pages_to_redact = []
1273
  for annotation in new_annotations_with_bounding_boxes:
1274
+ match = re.search(r"_(\d+)\.png$", annotation["image"])
1275
  if match:
1276
  page = int(match.group(1)) + 1
1277
  list_whole_pages_to_redact.append(page)
1278
  else:
1279
+ print(
1280
+ f"Warning: Could not extract page number from {annotation['image']}"
1281
+ )
1282
 
1283
  list_whole_pages_to_redact = list(set(list_whole_pages_to_redact))
1284
 
 
1285
  new_annotations = []
1286
  # Process each page for redaction
1287
  for page in list_whole_pages_to_redact:
 
1298
  continue
1299
 
1300
  image_path = page_info["image_path"]
1301
+ page_annotation_group = next(
1302
+ (g for g in all_annotations if g["image"] == image_path), None
1303
+ )
1304
+ if page_annotation_group and any(
1305
+ box["label"] == "Whole page" for box in page_annotation_group["boxes"]
1306
+ ):
1307
+ print(
1308
+ f"Whole page redaction for page {page_num} already exists, skipping."
1309
+ )
1310
  continue
1311
+
1312
  # --- Create a LIST of boxes to add.---
1313
  boxes_to_add = []
1314
+
1315
  pymupdf_page = pymupdf_doc[page_index]
1316
 
1317
+ if combine_pages is True:
1318
  whole_page_box = redact_whole_pymupdf_page(
1319
  rect_height=page_info["cropbox_height"],
1320
  rect_width=page_info["cropbox_width"],
1321
+ page=pymupdf_page,
1322
+ border=0.005,
1323
+ redact_pdf=False,
1324
  )
1325
  boxes_to_add.append(whole_page_box)
1326
  else:
1327
  # Find the specific annotation group that matches the current page's image path
1328
  relevant_box_group = next(
1329
+ (
1330
+ group
1331
+ for group in new_annotations_with_bounding_boxes
1332
+ if group.get("image") == image_path
1333
+ ),
1334
+ None, # Default to None if no match is found
1335
  )
1336
+
1337
  # Check if we found a matching group of boxes for this page
1338
  if relevant_box_group:
1339
+ boxes_to_add.extend(relevant_box_group["boxes"])
1340
  else:
1341
  # This case would be unexpected, but it's good to handle.
1342
  # It means a page was in list_whole_pages_to_redact but had no
1343
  # corresponding boxes generated in new_annotations_with_bounding_boxes.
1344
+ print(
1345
+ f"Warning: No new annotation boxes found for page {page_num} ({image_path})."
1346
+ )
1347
+
1348
  # === Use the modified helper function to add a LIST of boxes ===
1349
+ all_annotations, new_annotations_for_page = (
1350
+ add_new_annotations_to_existing_page_annotations(
1351
+ all_annotations=all_annotations,
1352
+ image_path=image_path,
1353
+ new_annotation_boxes=boxes_to_add, # Pass the list here
1354
+ )
1355
  )
1356
 
1357
+ new_annotations_for_page = fill_missing_box_ids_each_box(
1358
+ new_annotations_for_page
1359
+ )
1360
  new_annotations.append(new_annotations_for_page)
1361
 
1362
  except Exception as e:
 
1371
  gr.Info(message)
1372
  return review_file_state, all_annotations
1373
 
1374
+ expected_cols = [
1375
+ "image",
1376
+ "page",
1377
+ "label",
1378
+ "color",
1379
+ "xmin",
1380
+ "ymin",
1381
+ "xmax",
1382
+ "ymax",
1383
+ "text",
1384
+ "id",
1385
+ ]
1386
  for col in expected_cols:
1387
+ if col not in review_file_state.columns:
1388
+ review_file_state[col] = pd.NA
1389
+ if col not in whole_page_review_file.columns:
1390
+ whole_page_review_file[col] = pd.NA
1391
+
1392
+ review_file_out = pd.concat(
1393
+ [review_file_state, whole_page_review_file], ignore_index=True
1394
+ )
1395
+ review_file_out = review_file_out.sort_values(
1396
+ by=["page", "ymin", "xmin"]
1397
+ ).reset_index(drop=True)
1398
+ review_file_out = review_file_out.drop_duplicates(
1399
+ subset=["page", "label", "text", "id"], keep="first"
1400
+ )
1401
 
 
 
 
 
1402
  out_message = "Successfully created duplicate text redactions."
1403
  print(out_message)
1404
  gr.Info(out_message)
1405
 
1406
  return review_file_out, all_annotations
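A rough usage sketch for the combine_pages=True path; the PDF name, page_sizes entries and empty review state are placeholders, the import path is an assumption, and page_sizes only includes the keys actually read above:

import pandas as pd
import fitz  # PyMuPDF

from tools.find_duplicate_pages import apply_whole_page_redactions_from_list  # path assumed

duplicate_pages = pd.DataFrame({"Page": [2, 5]})  # first column = pages to redact
pymupdf_doc = fitz.open("example.pdf")  # hypothetical input document
page_sizes = [
    {"page": 2, "image_path": "example_2.png", "cropbox_width": 595, "cropbox_height": 842},
    {"page": 5, "image_path": "example_5.png", "cropbox_width": 595, "cropbox_height": 842},
]

review_file_out, all_annotations = apply_whole_page_redactions_from_list(
    duplicate_page_numbers_df=duplicate_pages,
    doc_file_name_with_extension_textbox="example.pdf",
    review_file_state=pd.DataFrame(),
    duplicate_output_paths=[],
    pymupdf_doc=pymupdf_doc,
    page_sizes=page_sizes,
    all_existing_annotations=[],
    combine_pages=True,
)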
1407
 
1408
+
1409
  def _parse_page_line_id(combined_id: int) -> Tuple[int, int]:
1410
  """Parses a combined ID using modular arithmetic."""
1411
  if int(combined_id) < ID_MULTIPLIER:
1412
  # Handle cases where page is 0 (or just an edge case)
1413
  return 0, combined_id
1414
+
1415
  page = combined_id // ID_MULTIPLIER
1416
  line = combined_id % ID_MULTIPLIER
1417
  return page, line
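For example (ID_MULTIPLIER is defined elsewhere in this module; the value below is assumed purely for illustration), the combined ID packs the page into the high digits and the line into the low digits:

ID_MULTIPLIER = 1000  # assumed value, for illustration only

combined_id = 42 * ID_MULTIPLIER + 7  # page 42, line 7 -> 42007
page, line = combined_id // ID_MULTIPLIER, combined_id % ID_MULTIPLIER
assert (page, line) == (42, 7)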
1418
 
1419
+
1420
  def create_annotation_objects_from_duplicates(
1421
+ duplicates_df: pd.DataFrame,
1422
  ocr_results_df: pd.DataFrame,
1423
  page_sizes: List[Dict],
1424
+ combine_pages: bool = False,
1425
+ ) -> List[Dict]:
1426
  """
1427
  Creates structured annotation objects from duplicate line ranges, mapping
1428
  page numbers to image paths.
 
1441
  if duplicates_df.empty:
1442
  raise Warning("No duplicates found")
1443
  if ocr_results_df.empty:
1444
+ raise Warning(
1445
+ "No OCR results found for file under review. Please upload relevant OCR_output file and original PDF document on the review tab."
1446
+ )
1447
 
1448
+ if combine_pages is False:
1449
+ page_to_image_map = {item["page"]: item["image_path"] for item in page_sizes}
1450
 
1451
  # Prepare OCR Data: Add a line number column if it doesn't exist
1452
+ if "line_number_by_page" not in ocr_results_df.columns:
1453
+ ocr_results_df = ocr_results_df.sort_values(
1454
+ by=["page", "top", "left"]
1455
+ ).reset_index(drop=True)
1456
+ ocr_results_df["line_number_by_page"] = (
1457
+ ocr_results_df.groupby("page").cumcount() + 1
1458
+ )
1459
+
1460
  annotations_by_page = defaultdict(list)
1461
 
1462
  # Iterate through each duplicate range (this logic is unchanged)
1463
  for _, row in duplicates_df.iterrows():
1464
+ start_page, start_line = _parse_page_line_id(row["Page2_Start_Page"])
1465
+ end_page, end_line = _parse_page_line_id(row["Page2_End_Page"])
1466
+
1467
  # Select OCR Lines based on the range (this logic is unchanged)
1468
  if start_page == end_page:
1469
+ condition = (ocr_results_df["page"] == start_page) & (
1470
+ ocr_results_df["line_number_by_page"].between(start_line, end_line)
 
1471
  )
1472
  else:
1473
+ cond_start = (ocr_results_df["page"] == start_page) & (
1474
+ ocr_results_df["line_number_by_page"] >= start_line
1475
+ )
1476
+ cond_middle = ocr_results_df["page"].between(
1477
+ start_page + 1, end_page - 1
1478
+ )
1479
+ cond_end = (ocr_results_df["page"] == end_page) & (
1480
+ ocr_results_df["line_number_by_page"] <= end_line
1481
+ )
1482
  condition = cond_start | cond_middle | cond_end
1483
 
1484
  lines_to_annotate = ocr_results_df[condition]
 
1487
  for _, line_row in lines_to_annotate.iterrows():
1488
  box = {
1489
  "label": "Duplicate text",
1490
+ "color": (0, 0, 0),
1491
+ "xmin": line_row["left"],
1492
+ "ymin": line_row["top"],
1493
+ "xmax": line_row["left"] + line_row["width"],
1494
+ "ymax": line_row["top"] + line_row["height"],
1495
+ "text": line_row["text"],
1496
+ "id": "", # to be filled in after
1497
  }
1498
+ page_number = line_row["page"]
1499
+
1500
  annotations_by_page[page_number].append(box)
1501
+
1502
  # --- Format the final output list using the page-to-image map ---
1503
  final_output = []
1504
  # Sort by page number for a predictable order
1505
  for page_num, boxes in sorted(annotations_by_page.items()):
1506
  # Look up the image path using the page number
1507
  image_path = page_to_image_map.get(page_num)
1508
+
1509
  if image_path:
1510
+ page_boxes = {"image": image_path, "boxes": boxes}
 
 
 
1511
 
1512
  # Fill in missing IDs for the new data entries
1513
  page_boxes = fill_missing_box_ids_each_box(page_boxes)
 
1516
  final_output.append(page_boxes)
1517
  else:
1518
  # Handle cases where a page might not have a corresponding image path
1519
+ print(
1520
+ f"Warning: Page {page_num} found in OCR data but has no corresponding "
1521
+ f"entry in the 'page_sizes' object. This page's annotations will be skipped."
1522
+ )
1523
+
1524
+ return final_output
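The returned final_output holds one entry per page that contains duplicated lines; all values below are invented and only show the shape produced by the loop above:

final_output = [
    {
        "image": "output/doc_1.png",
        "boxes": [
            {
                "label": "Duplicate text",
                "color": (0, 0, 0),
                "xmin": 0.12, "ymin": 0.30, "xmax": 0.88, "ymax": 0.33,
                "text": "A line flagged as duplicated",
                "id": "abc123",  # populated by fill_missing_box_ids_each_box
            }
        ],
    }
]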
tools/find_duplicate_tabular.py CHANGED
@@ -1,22 +1,36 @@
1
- import pandas as pd
2
  import os
3
  import re
4
  import time
5
- from sklearn.feature_extraction.text import TfidfVectorizer
6
- from sklearn.metrics.pairwise import cosine_similarity
7
- from typing import List, Tuple, Dict
8
  import gradio as gr
 
9
  from gradio import Progress
10
- from pathlib import Path
11
- from tools.helper_functions import OUTPUT_FOLDER, read_file
 
 
 
 
 
 
 
12
  from tools.data_anonymise import initial_clean
 
13
  from tools.load_spacy_model_custom_recognisers import nlp
14
- from tools.config import DO_INITIAL_TABULAR_DATA_CLEAN, REMOVE_DUPLICATE_ROWS, MAX_SIMULTANEOUS_FILES, MAX_TABLE_ROWS
15
 
16
- if REMOVE_DUPLICATE_ROWS == "True": REMOVE_DUPLICATE_ROWS = True
17
- else: REMOVE_DUPLICATE_ROWS = False
 
 
 
18
 
19
- def clean_and_stem_text_series(df: pd.DataFrame, column: str, do_initial_clean_dup: bool = DO_INITIAL_TABULAR_DATA_CLEAN):
 
 
 
 
20
  """
21
  Clean and stem text columns in a data frame for tabular data
22
  """
@@ -25,31 +39,34 @@ def clean_and_stem_text_series(df: pd.DataFrame, column: str, do_initial_clean_d
25
  def _apply_lemmatization(text):
26
  doc = nlp(text)
27
  # Keep only alphabetic tokens and remove stopwords
28
- lemmatized_words = [token.lemma_ for token in doc if token.is_alpha and not token.is_stop]
29
- return ' '.join(lemmatized_words)
30
-
 
 
31
  if do_initial_clean_dup:
32
- df['text_clean'] = initial_clean(df[column])
 
 
 
 
 
33
 
34
- df['text_clean'] = df['text_clean'].apply(_apply_lemmatization)
35
- df['text_clean'] = df[column].str.lower()#.str.replace(r'[^\w\s]', '', regex=True)
36
-
37
  return df
38
 
 
39
  def convert_tabular_data_to_analysis_format(
40
- df: pd.DataFrame,
41
- file_name: str,
42
- text_columns: List[str] = None
43
  ) -> List[Tuple[str, pd.DataFrame]]:
44
  """
45
  Convert tabular data (CSV/XLSX) to the format needed for duplicate analysis.
46
-
47
  Args:
48
  df (pd.DataFrame): The input DataFrame
49
  file_name (str): Name of the file
50
- text_columns (List[str], optional): Columns to analyze for duplicates.
51
  If None, uses all string columns.
52
-
53
  Returns:
54
  List[Tuple[str, pd.DataFrame]]: List containing (file_name, processed_df) tuple
55
  """
@@ -60,34 +77,39 @@ def convert_tabular_data_to_analysis_format(
60
  # text_columns = df.select_dtypes(include=['object', 'string']).columns.tolist()
61
 
62
  text_columns = [col for col in text_columns if col in df.columns]
63
-
64
  if not text_columns:
65
  print(f"No text columns found in {file_name}")
66
  return list()
67
-
68
  # Create a copy to avoid modifying original
69
  df_copy = df.copy()
70
-
71
  # Create a combined text column from all text columns
72
- df_copy['combined_text'] = df_copy[text_columns].fillna('').astype(str).agg(' '.join, axis=1)
73
-
 
 
74
  # Add row identifier
75
- df_copy['row_id'] = df_copy.index
76
-
77
  # Create the format expected by the duplicate detection system
78
  # Using 'row_number' as row number and 'text' as the combined text
79
- processed_df = pd.DataFrame({
80
- 'row_number': df_copy['row_id'],
81
- 'text': df_copy['combined_text'],
82
- 'file': file_name
83
- })
84
-
 
 
85
  # Add original row data for reference
86
  for col in text_columns:
87
- processed_df[f'original_{col}'] = df_copy[col]
88
-
89
  return [(file_name, processed_df)]
90
 
 
91
  def find_duplicate_cells_in_tabular_data(
92
  input_files: List[str],
93
  similarity_threshold: float = 0.95,
@@ -97,11 +119,11 @@ def find_duplicate_cells_in_tabular_data(
97
  do_initial_clean_dup: bool = DO_INITIAL_TABULAR_DATA_CLEAN,
98
  remove_duplicate_rows: bool = REMOVE_DUPLICATE_ROWS,
99
  in_excel_tabular_sheets: str = "",
100
- progress: Progress = Progress(track_tqdm=True)
101
  ) -> Tuple[pd.DataFrame, List[str], Dict[str, pd.DataFrame]]:
102
  """
103
  Find duplicate cells/text in tabular data files (CSV, XLSX, Parquet).
104
-
105
  Args:
106
  input_files (List[str]): List of file paths to analyze
107
  similarity_threshold (float): Minimum similarity score to consider duplicates
@@ -110,27 +132,27 @@ def find_duplicate_cells_in_tabular_data(
110
  output_folder (str, optional): Output folder for results
111
  do_initial_clean_dup (bool, optional): Whether to do initial clean of text
112
  progress (Progress): Progress tracking object
113
-
114
  Returns:
115
  Tuple containing:
116
  - results_df: DataFrame with duplicate matches
117
  - output_paths: List of output file paths
118
  - full_data_by_file: Dictionary of processed data by file
119
  """
120
-
121
  if not input_files:
122
  raise gr.Error("Please upload files to analyze.")
123
-
124
  progress(0.1, desc="Loading and processing files...")
125
-
126
  all_data_to_process = list()
127
  full_data_by_file = dict()
128
- file_paths = list()
129
-
130
  # Process each file
131
  for file_path in input_files:
132
  try:
133
- if file_path.endswith('.xlsx') or file_path.endswith('.xls'):
134
  temp_df = pd.DataFrame()
135
 
136
  # Try finding each sheet in the given list until a match is found
@@ -147,12 +169,12 @@ def find_duplicate_cells_in_tabular_data(
147
 
148
  file_name = os.path.basename(file_path) + "_" + sheet_name
149
  file_paths.append(file_path)
150
-
151
  # Convert to analysis format
152
  processed_data = convert_tabular_data_to_analysis_format(
153
  temp_df, file_name, text_columns
154
  )
155
-
156
  if processed_data:
157
  all_data_to_process.extend(processed_data)
158
  full_data_by_file[file_name] = processed_data[0][1]
@@ -165,99 +187,129 @@ def find_duplicate_cells_in_tabular_data(
165
  out_message = f"Number of rows in {file_path} is greater than {MAX_TABLE_ROWS}. Please submit a smaller file."
166
  print(out_message)
167
  raise Exception(out_message)
168
-
169
  file_name = os.path.basename(file_path)
170
  file_paths.append(file_path)
171
-
172
  # Convert to analysis format
173
  processed_data = convert_tabular_data_to_analysis_format(
174
  temp_df, file_name, text_columns
175
  )
176
-
177
  if processed_data:
178
  all_data_to_process.extend(processed_data)
179
  full_data_by_file[file_name] = processed_data[0][1]
180
-
181
  except Exception as e:
182
  print(f"Error processing {file_path}: {e}")
183
  continue
184
-
185
  if not all_data_to_process:
186
  raise gr.Error("No valid data found in uploaded files.")
187
-
188
  progress(0.2, desc="Combining data...")
189
-
190
  # Combine all data
191
- combined_df = pd.concat([data[1] for data in all_data_to_process], ignore_index=True)
 
 
 
 
192
 
193
- combined_df = combined_df.drop_duplicates(subset=['row_number', 'file'])
194
-
195
  progress(0.3, desc="Cleaning and preparing text...")
196
-
197
  # Clean and prepare text
198
- combined_df = clean_and_stem_text_series(combined_df, 'text', do_initial_clean_dup=do_initial_clean_dup)
199
-
 
 
200
  # Filter by minimum word count
201
- combined_df['word_count'] = combined_df['text_clean'].str.split().str.len().fillna(0)
202
- combined_df = combined_df[combined_df['word_count'] >= min_word_count].copy()
203
-
 
 
204
  if len(combined_df) < 2:
205
  return pd.DataFrame(), [], full_data_by_file
206
-
207
  progress(0.4, desc="Calculating similarities...")
208
-
209
  # Calculate similarities
210
  vectorizer = TfidfVectorizer()
211
- tfidf_matrix = vectorizer.fit_transform(combined_df['text_clean'])
212
  similarity_matrix = cosine_similarity(tfidf_matrix, dense_output=False)
213
-
214
  # Find similar pairs
215
  coo_matrix = similarity_matrix.tocoo()
216
  similar_pairs = [
217
- (r, c, v) for r, c, v in zip(coo_matrix.row, coo_matrix.col, coo_matrix.data)
 
218
  if r < c and v >= similarity_threshold
219
  ]
220
-
221
  if not similar_pairs:
222
  gr.Info("No duplicate cells found.")
223
  return pd.DataFrame(), [], full_data_by_file
224
-
225
  progress(0.7, desc="Processing results...")
226
-
227
  # Create results DataFrame
228
  results_data = []
229
  for row1, row2, similarity in similar_pairs:
230
  row1_data = combined_df.iloc[row1]
231
  row2_data = combined_df.iloc[row2]
232
-
233
- results_data.append({
234
- 'File1': row1_data['file'],
235
- 'Row1': int(row1_data['row_number']),
236
- 'File2': row2_data['file'],
237
- 'Row2': int(row2_data['row_number']),
238
- 'Similarity_Score': round(similarity, 3),
239
- 'Text1': row1_data['text'][:200] + '...' if len(row1_data['text']) > 200 else row1_data['text'],
240
- 'Text2': row2_data['text'][:200] + '...' if len(row2_data['text']) > 200 else row2_data['text'],
241
- 'Original_Index1': row1,
242
- 'Original_Index2': row2
243
- })
244
-
 
 
 
 
 
 
 
 
 
 
245
  results_df = pd.DataFrame(results_data)
246
- results_df = results_df.sort_values(['File1', 'Row1', 'File2', 'Row2'])
247
-
248
  progress(0.9, desc="Saving results...")
249
-
250
  # Save results
251
- output_paths = save_tabular_duplicate_results(results_df, output_folder, file_paths, remove_duplicate_rows=remove_duplicate_rows, in_excel_tabular_sheets=in_excel_tabular_sheets)
252
-
 
 
 
 
 
 
253
  gr.Info(f"Found {len(results_df)} duplicate cell matches")
254
-
255
  return results_df, output_paths, full_data_by_file
256
 
257
- def save_tabular_duplicate_results(results_df: pd.DataFrame, output_folder: str, file_paths: List[str], remove_duplicate_rows: bool = REMOVE_DUPLICATE_ROWS, in_excel_tabular_sheets: List[str] = []) -> List[str]:
 
 
 
 
 
 
 
258
  """
259
  Save tabular duplicate detection results to files.
260
-
261
  Args:
262
  results_df (pd.DataFrame): Results DataFrame
263
  output_folder (str): Output folder path
@@ -270,78 +322,87 @@ def save_tabular_duplicate_results(results_df: pd.DataFrame, output_folder: str,
270
  output_paths = list()
271
  output_folder_path = Path(output_folder)
272
  output_folder_path.mkdir(exist_ok=True)
273
-
274
  if results_df.empty:
275
  print("No duplicate matches to save.")
276
  return list()
277
-
278
  # Save main results
279
- results_file = output_folder_path / 'tabular_duplicate_results.csv'
280
  results_df.to_csv(results_file, index=False, encoding="utf-8-sig")
281
  output_paths.append(str(results_file))
282
-
283
  # Group results by original file to handle Excel files properly
284
- excel_files_processed = dict() # Track which Excel files have been processed
285
-
286
  # Save per-file duplicate lists
287
- for file_name, group in results_df.groupby('File2'):
288
  # Check for matches with original file names
289
  for original_file in file_paths:
290
  original_file_name = os.path.basename(original_file)
291
 
292
  if original_file_name in file_name:
293
  original_file_extension = os.path.splitext(original_file)[-1]
294
- if original_file_extension in ['.xlsx', '.xls']:
295
-
296
  # Split the string using a regex to handle both .xlsx_ and .xls_ delimiters
297
  # The regex r'\.xlsx_|\.xls_' correctly matches either ".xlsx_" or ".xls_" as a delimiter.
298
- parts = re.split(r'\.xlsx_|\.xls_', os.path.basename(file_name))
299
  # The sheet name is the last part after splitting
300
  file_sheet_name = parts[-1]
301
 
302
  file_path = original_file
303
-
304
  # Initialize Excel file tracking if not already done
305
  if file_path not in excel_files_processed:
306
  excel_files_processed[file_path] = {
307
- 'sheets_data': dict(),
308
- 'all_sheets': list(),
309
- 'processed_sheets': set()
310
  }
311
-
312
  # Read the original Excel file to get all sheet names
313
- if not excel_files_processed[file_path]['all_sheets']:
314
  try:
315
  excel_file = pd.ExcelFile(file_path)
316
- excel_files_processed[file_path]['all_sheets'] = excel_file.sheet_names
 
 
317
  except Exception as e:
318
  print(f"Error reading Excel file {file_path}: {e}")
319
  continue
320
-
321
  # Read the current sheet
322
  df = read_file(file_path, excel_sheet_name=file_sheet_name)
323
-
324
  # Create duplicate rows file for this sheet
325
  file_stem = Path(file_name).stem
326
- duplicate_rows_file = output_folder_path / f"{file_stem}_{file_sheet_name}_duplicate_rows.csv"
327
-
 
 
 
328
  # Get unique row numbers to remove
329
- rows_to_remove = sorted(group['Row2'].unique())
330
- duplicate_df = pd.DataFrame({'Row_to_Remove': rows_to_remove})
331
  duplicate_df.to_csv(duplicate_rows_file, index=False)
332
  output_paths.append(str(duplicate_rows_file))
333
-
334
  # Process the sheet data
335
  df_cleaned = df.copy()
336
  df_cleaned["duplicated"] = False
337
  df_cleaned.loc[rows_to_remove, "duplicated"] = True
338
  if remove_duplicate_rows:
339
  df_cleaned = df_cleaned.drop(index=rows_to_remove)
340
-
341
  # Store the processed sheet data
342
- excel_files_processed[file_path]['sheets_data'][file_sheet_name] = df_cleaned
343
- excel_files_processed[file_path]['processed_sheets'].add(file_sheet_name)
344
-
 
 
 
 
345
  else:
346
  file_sheet_name = ""
347
  file_path = original_file
@@ -350,11 +411,13 @@ def save_tabular_duplicate_results(results_df: pd.DataFrame, output_folder: str,
350
  df = read_file(file_path)
351
 
352
  file_stem = Path(file_name).stem
353
- duplicate_rows_file = output_folder_path / f"{file_stem}_duplicate_rows.csv"
354
-
 
 
355
  # Get unique row numbers to remove
356
- rows_to_remove = sorted(group['Row2'].unique())
357
- duplicate_df = pd.DataFrame({'Row_to_Remove': rows_to_remove})
358
  duplicate_df.to_csv(duplicate_rows_file, index=False)
359
  output_paths.append(str(duplicate_rows_file))
360
 
@@ -366,63 +429,66 @@ def save_tabular_duplicate_results(results_df: pd.DataFrame, output_folder: str,
366
 
367
  file_ext = os.path.splitext(file_name)[-1]
368
 
369
- if file_ext in ['.parquet']:
370
- output_path = os.path.join(output_folder, f"{file_base_name}_deduplicated.parquet")
 
 
371
  df_cleaned.to_parquet(output_path, index=False)
372
  else:
373
- output_path = os.path.join(output_folder, f"{file_base_name}_deduplicated.csv")
374
- df_cleaned.to_csv(output_path, index=False, encoding="utf-8-sig")
375
-
 
 
 
 
376
  output_paths.append(str(output_path))
377
  break
378
-
379
  # Process Excel files to create complete deduplicated files
380
  for file_path, file_data in excel_files_processed.items():
381
  try:
382
  # Create output filename
383
  file_base_name = os.path.splitext(os.path.basename(file_path))[0]
384
  file_ext = os.path.splitext(file_path)[-1]
385
- output_path = os.path.join(output_folder, f"{file_base_name}_deduplicated{file_ext}")
386
-
 
 
387
  # Create Excel writer
388
- with pd.ExcelWriter(output_path, engine='openpyxl') as writer:
389
  # Write all sheets
390
- for sheet_name in file_data['all_sheets']:
391
- if sheet_name in file_data['processed_sheets']:
392
  # Use the processed (deduplicated) version
393
- file_data['sheets_data'][sheet_name].to_excel(
394
- writer,
395
- sheet_name=sheet_name,
396
- index=False
397
  )
398
  else:
399
  # Use the original sheet (no duplicates found)
400
  original_df = read_file(file_path, excel_sheet_name=sheet_name)
401
- original_df.to_excel(
402
- writer,
403
- sheet_name=sheet_name,
404
- index=False
405
- )
406
-
407
  output_paths.append(str(output_path))
408
  print(f"Created deduplicated Excel file: {output_path}")
409
-
410
  except Exception as e:
411
  print(f"Error creating deduplicated Excel file for {file_path}: {e}")
412
  continue
413
-
414
  return output_paths
415
 
 
416
  def remove_duplicate_rows_from_tabular_data(
417
  file_path: str,
418
  duplicate_rows: List[int],
419
  output_folder: str = OUTPUT_FOLDER,
420
  in_excel_tabular_sheets: List[str] = [],
421
- remove_duplicate_rows: bool = REMOVE_DUPLICATE_ROWS
422
  ) -> str:
423
  """
424
  Remove duplicate rows from a tabular data file.
425
-
426
  Args:
427
  file_path (str): Path to the input file
428
  duplicate_rows (List[int]): List of row indices to remove
@@ -434,31 +500,39 @@ def remove_duplicate_rows_from_tabular_data(
434
  """
435
  try:
436
  # Load the file
437
- df = read_file(file_path, excel_sheet_name=in_excel_tabular_sheets if in_excel_tabular_sheets else "")
438
-
 
 
 
439
  # Remove duplicate rows (0-indexed)
440
  df_cleaned = df.drop(index=duplicate_rows).reset_index(drop=True)
441
-
442
  # Save cleaned file
443
  file_name = os.path.basename(file_path)
444
  file_stem = os.path.splitext(file_name)[0]
445
  file_ext = os.path.splitext(file_name)[-1]
446
-
447
  output_path = os.path.join(output_folder, f"{file_stem}_deduplicated{file_ext}")
448
-
449
- if file_ext in ['.xlsx', '.xls']:
450
- df_cleaned.to_excel(output_path, index=False, sheet_name=in_excel_tabular_sheets if in_excel_tabular_sheets else [])
451
- elif file_ext in ['.parquet']:
 
 
 
 
452
  df_cleaned.to_parquet(output_path, index=False)
453
  else:
454
  df_cleaned.to_csv(output_path, index=False, encoding="utf-8-sig")
455
-
456
  return output_path
457
-
458
  except Exception as e:
459
  print(f"Error removing duplicates from {file_path}: {e}")
460
  raise
461
 
 
462
  def run_tabular_duplicate_analysis(
463
  files: List[str],
464
  threshold: float,
@@ -468,11 +542,11 @@ def run_tabular_duplicate_analysis(
468
  do_initial_clean_dup: bool = DO_INITIAL_TABULAR_DATA_CLEAN,
469
  remove_duplicate_rows: bool = REMOVE_DUPLICATE_ROWS,
470
  in_excel_tabular_sheets: List[str] = [],
471
- progress: Progress = Progress(track_tqdm=True)
472
  ) -> Tuple[pd.DataFrame, List[str], Dict[str, pd.DataFrame]]:
473
  """
474
  Main function to run tabular duplicate analysis.
475
-
476
  Args:
477
  files (List[str]): List of file paths
478
  threshold (float): Similarity threshold
@@ -480,7 +554,7 @@ def run_tabular_duplicate_analysis(
480
  text_columns (List[str], optional): Specific columns to analyze
481
  output_folder (str, optional): Output folder for results
482
  progress (Progress): Progress tracking
483
-
484
  Returns:
485
  Tuple containing results DataFrame, output paths, and full data by file
486
  """
@@ -491,43 +565,58 @@ def run_tabular_duplicate_analysis(
491
  text_columns=text_columns if text_columns else [],
492
  output_folder=output_folder,
493
  do_initial_clean_dup=do_initial_clean_dup,
494
- in_excel_tabular_sheets=in_excel_tabular_sheets if in_excel_tabular_sheets else [],
495
- remove_duplicate_rows=remove_duplicate_rows
 
 
496
  )
497
 
498
 
499
-
500
  # Function to update column choices when files are uploaded
501
  def update_tabular_column_choices(files, in_excel_tabular_sheets: List[str] = []):
502
  if not files:
503
  return gr.update(choices=[])
504
-
505
  all_columns = set()
506
  for file in files:
507
  try:
508
  file_extension = os.path.splitext(file.name)[-1]
509
- if file_extension in ['.xlsx', '.xls']:
510
  for sheet_name in in_excel_tabular_sheets:
511
  df = read_file(file.name, excel_sheet_name=sheet_name)
512
- text_cols = df.select_dtypes(include=['object', 'string']).columns.tolist()
 
 
513
  all_columns.update(text_cols)
514
  else:
515
  df = read_file(file.name)
516
- text_cols = df.select_dtypes(include=['object', 'string']).columns.tolist()
 
 
517
  all_columns.update(text_cols)
518
 
519
  # Get text columns
520
- text_cols = df.select_dtypes(include=['object', 'string']).columns.tolist()
521
 
522
  all_columns.update(text_cols)
523
  except Exception as e:
524
  print(f"Error reading {file.name}: {e}")
525
  continue
526
-
527
  return gr.Dropdown(choices=sorted(list(all_columns)))
528
 
 
529
  # Function to handle tabular duplicate detection
530
- def run_tabular_duplicate_detection(files, threshold, min_words, text_columns, output_folder: str = OUTPUT_FOLDER, do_initial_clean_dup: bool = DO_INITIAL_TABULAR_DATA_CLEAN, in_excel_tabular_sheets: List[str] = [], remove_duplicate_rows: bool = REMOVE_DUPLICATE_ROWS):
 
 
 
 
 
 
 
 
 
531
  if not files:
532
  print("No files uploaded")
533
  return pd.DataFrame(), [], gr.Dropdown(choices=[]), 0, "deduplicate"
@@ -537,8 +626,9 @@ def run_tabular_duplicate_detection(files, threshold, min_words, text_columns, o
537
  task_textbox = "deduplicate"
538
 
539
  # If output folder doesn't end with a forward slash, add one
540
- if not output_folder.endswith('/'): output_folder = output_folder + '/'
541
-
 
542
  file_paths = list()
543
  if isinstance(files, str):
544
  # If 'files' is a single string, treat it as a list with one element
@@ -549,18 +639,22 @@ def run_tabular_duplicate_detection(files, threshold, min_words, text_columns, o
549
  if isinstance(f_item, str):
550
  # If an element is a string, it's a direct file path
551
  file_paths.append(f_item)
552
- elif hasattr(f_item, 'name'):
553
  # If an element has a '.name' attribute (e.g., a Gradio File object), use its name
554
  file_paths.append(f_item.name)
555
  else:
556
  # Log a warning for unexpected element types within the list
557
- print(f"Warning: Skipping an element in 'files' list that is neither a string nor has a '.name' attribute: {type(f_item)}")
558
- elif hasattr(files, 'name'):
 
 
559
  # Handle the case where a single file object (e.g., gr.File) is passed directly, not in a list
560
  file_paths.append(files.name)
561
  else:
562
  # Raise an error for any other unexpected type of the 'files' argument itself
563
- raise TypeError(f"Unexpected type for 'files' argument: {type(files)}. Expected str, list of str/file objects, or a single file object.")
 
 
564
 
565
  if len(file_paths) > MAX_SIMULTANEOUS_FILES:
566
  out_message = f"Number of files to deduplicate is greater than {MAX_SIMULTANEOUS_FILES}. Please submit a smaller number of files."
@@ -574,21 +668,30 @@ def run_tabular_duplicate_detection(files, threshold, min_words, text_columns, o
574
  text_columns=text_columns if text_columns else [],
575
  output_folder=output_folder,
576
  do_initial_clean_dup=do_initial_clean_dup,
577
- in_excel_tabular_sheets=in_excel_tabular_sheets if in_excel_tabular_sheets else None,
578
- remove_duplicate_rows=remove_duplicate_rows
 
 
579
  )
580
-
581
  # Update file choices for cleaning
582
  file_choices = list(set([f for f in file_paths]))
583
 
584
  end_time = time.time()
585
  processing_time = round(end_time - start_time, 2)
586
-
587
- return results_df, output_paths, gr.Dropdown(choices=file_choices), processing_time, task_textbox
 
 
 
 
 
 
 
588
 
589
  # Function to handle row selection for preview
590
- def handle_tabular_row_selection(results_df, evt:gr.SelectData):
591
-
592
  if not evt:
593
  return None, "", ""
594
 
@@ -596,25 +699,32 @@ def handle_tabular_row_selection(results_df, evt:gr.SelectData):
596
  return None, "", ""
597
  elif results_df.empty:
598
  return None, "", ""
599
-
600
  selected_index = evt.index[0]
601
  if selected_index >= len(results_df):
602
  return None, "", ""
603
-
604
  row = results_df.iloc[selected_index]
605
- return selected_index, row['Text1'], row['Text2']
 
606
 
607
  # Function to clean duplicates from selected file
608
- def clean_tabular_duplicates(file_name, results_df, output_folder, in_excel_tabular_sheets: str = "", remove_duplicate_rows: bool = REMOVE_DUPLICATE_ROWS):
 
 
 
 
 
 
609
  if not file_name or results_df.empty:
610
  return None
611
-
612
  # Get duplicate rows for this file
613
- file_duplicates = results_df[results_df['File2'] == file_name]['Row2'].tolist()
614
-
615
  if not file_duplicates:
616
  return None
617
-
618
  try:
619
  # Find the original file path
620
  # This is a simplified approach - in practice you might want to store file paths
@@ -623,9 +733,9 @@ def clean_tabular_duplicates(file_name, results_df, output_folder, in_excel_tabu
623
  duplicate_rows=file_duplicates,
624
  output_folder=output_folder,
625
  in_excel_tabular_sheets=in_excel_tabular_sheets,
626
- remove_duplicate_rows=remove_duplicate_rows
627
  )
628
  return cleaned_file
629
  except Exception as e:
630
  print(f"Error cleaning duplicates: {e}")
631
- return None
 
 
1
  import os
2
  import re
3
  import time
4
+ from pathlib import Path
5
+ from typing import Dict, List, Tuple
6
+
7
  import gradio as gr
8
+ import pandas as pd
9
  from gradio import Progress
10
+ from sklearn.feature_extraction.text import TfidfVectorizer
11
+ from sklearn.metrics.pairwise import cosine_similarity
12
+
13
+ from tools.config import (
14
+ DO_INITIAL_TABULAR_DATA_CLEAN,
15
+ MAX_SIMULTANEOUS_FILES,
16
+ MAX_TABLE_ROWS,
17
+ REMOVE_DUPLICATE_ROWS,
18
+ )
19
  from tools.data_anonymise import initial_clean
20
+ from tools.helper_functions import OUTPUT_FOLDER, read_file
21
  from tools.load_spacy_model_custom_recognisers import nlp
 
22
 
23
+ if REMOVE_DUPLICATE_ROWS == "True":
24
+ REMOVE_DUPLICATE_ROWS = True
25
+ else:
26
+ REMOVE_DUPLICATE_ROWS = False
27
+
28
 
29
+ def clean_and_stem_text_series(
30
+ df: pd.DataFrame,
31
+ column: str,
32
+ do_initial_clean_dup: bool = DO_INITIAL_TABULAR_DATA_CLEAN,
33
+ ):
34
  """
35
  Clean and stem text columns in a data frame for tabular data
36
  """
 
39
  def _apply_lemmatization(text):
40
  doc = nlp(text)
41
  # Keep only alphabetic tokens and remove stopwords
42
+ lemmatized_words = [
43
+ token.lemma_ for token in doc if token.is_alpha and not token.is_stop
44
+ ]
45
+ return " ".join(lemmatized_words)
46
+
47
  if do_initial_clean_dup:
48
+ df["text_clean"] = initial_clean(df[column])
49
+
50
+ df["text_clean"] = df["text_clean"].apply(_apply_lemmatization)
51
+ df["text_clean"] = df[
52
+ column
53
+ ].str.lower() # .str.replace(r'[^\w\s]', '', regex=True)
54
 
 
 
 
55
  return df
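The inner _apply_lemmatization step boils down to spaCy lemmas with stopwords and non-alphabetic tokens dropped; a self-contained sketch, assuming a standard English pipeline such as en_core_web_sm is installed (the app itself supplies its nlp object via tools.load_spacy_model_custom_recognisers):

import spacy

nlp = spacy.load("en_core_web_sm")  # assumed model, for illustration only
doc = nlp("The clients were visiting the offices")
lemmas = [token.lemma_ for token in doc if token.is_alpha and not token.is_stop]
print(" ".join(lemmas))  # e.g. "client visit office"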
56
 
57
+
58
  def convert_tabular_data_to_analysis_format(
59
+ df: pd.DataFrame, file_name: str, text_columns: List[str] = None
 
 
60
  ) -> List[Tuple[str, pd.DataFrame]]:
61
  """
62
  Convert tabular data (CSV/XLSX) to the format needed for duplicate analysis.
63
+
64
  Args:
65
  df (pd.DataFrame): The input DataFrame
66
  file_name (str): Name of the file
67
+ text_columns (List[str], optional): Columns to analyze for duplicates.
68
  If None, uses all string columns.
69
+
70
  Returns:
71
  List[Tuple[str, pd.DataFrame]]: List containing (file_name, processed_df) tuple
72
  """
 
77
  # text_columns = df.select_dtypes(include=['object', 'string']).columns.tolist()
78
 
79
  text_columns = [col for col in text_columns if col in df.columns]
80
+
81
  if not text_columns:
82
  print(f"No text columns found in {file_name}")
83
  return list()
84
+
85
  # Create a copy to avoid modifying original
86
  df_copy = df.copy()
87
+
88
  # Create a combined text column from all text columns
89
+ df_copy["combined_text"] = (
90
+ df_copy[text_columns].fillna("").astype(str).agg(" ".join, axis=1)
91
+ )
92
+
93
  # Add row identifier
94
+ df_copy["row_id"] = df_copy.index
95
+
96
  # Create the format expected by the duplicate detection system
97
  # Using 'row_number' as row number and 'text' as the combined text
98
+ processed_df = pd.DataFrame(
99
+ {
100
+ "row_number": df_copy["row_id"],
101
+ "text": df_copy["combined_text"],
102
+ "file": file_name,
103
+ }
104
+ )
105
+
106
  # Add original row data for reference
107
  for col in text_columns:
108
+ processed_df[f"original_{col}"] = df_copy[col]
109
+
110
  return [(file_name, processed_df)]
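A minimal sketch of the conversion output; the frame and column names are invented, and the import relies on the module path shown in this diff:

import pandas as pd

from tools.find_duplicate_tabular import convert_tabular_data_to_analysis_format

df = pd.DataFrame({"Case Note": ["Visited client", "Visited client"], "Officer": ["A", "B"]})
[(name, processed)] = convert_tabular_data_to_analysis_format(
    df, "notes.csv", text_columns=["Case Note", "Officer"]
)
# processed carries one row per input row: row_number 0/1, text "Visited client A" /
# "Visited client B", file "notes.csv", plus original_* copies of the source columns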
111
 
112
+
113
  def find_duplicate_cells_in_tabular_data(
114
  input_files: List[str],
115
  similarity_threshold: float = 0.95,
 
119
  do_initial_clean_dup: bool = DO_INITIAL_TABULAR_DATA_CLEAN,
120
  remove_duplicate_rows: bool = REMOVE_DUPLICATE_ROWS,
121
  in_excel_tabular_sheets: str = "",
122
+ progress: Progress = Progress(track_tqdm=True),
123
  ) -> Tuple[pd.DataFrame, List[str], Dict[str, pd.DataFrame]]:
124
  """
125
  Find duplicate cells/text in tabular data files (CSV, XLSX, Parquet).
126
+
127
  Args:
128
  input_files (List[str]): List of file paths to analyze
129
  similarity_threshold (float): Minimum similarity score to consider duplicates
 
132
  output_folder (str, optional): Output folder for results
133
  do_initial_clean_dup (bool, optional): Whether to do initial clean of text
134
  progress (Progress): Progress tracking object
135
+
136
  Returns:
137
  Tuple containing:
138
  - results_df: DataFrame with duplicate matches
139
  - output_paths: List of output file paths
140
  - full_data_by_file: Dictionary of processed data by file
141
  """
142
+
143
  if not input_files:
144
  raise gr.Error("Please upload files to analyze.")
145
+
146
  progress(0.1, desc="Loading and processing files...")
147
+
148
  all_data_to_process = list()
149
  full_data_by_file = dict()
150
+ file_paths = list()
151
+
152
  # Process each file
153
  for file_path in input_files:
154
  try:
155
+ if file_path.endswith(".xlsx") or file_path.endswith(".xls"):
156
  temp_df = pd.DataFrame()
157
 
158
  # Try finding each sheet in the given list until a match is found
 
169
 
170
  file_name = os.path.basename(file_path) + "_" + sheet_name
171
  file_paths.append(file_path)
172
+
173
  # Convert to analysis format
174
  processed_data = convert_tabular_data_to_analysis_format(
175
  temp_df, file_name, text_columns
176
  )
177
+
178
  if processed_data:
179
  all_data_to_process.extend(processed_data)
180
  full_data_by_file[file_name] = processed_data[0][1]
 
187
  out_message = f"Number of rows in {file_path} is greater than {MAX_TABLE_ROWS}. Please submit a smaller file."
188
  print(out_message)
189
  raise Exception(out_message)
190
+
191
  file_name = os.path.basename(file_path)
192
  file_paths.append(file_path)
193
+
194
  # Convert to analysis format
195
  processed_data = convert_tabular_data_to_analysis_format(
196
  temp_df, file_name, text_columns
197
  )
198
+
199
  if processed_data:
200
  all_data_to_process.extend(processed_data)
201
  full_data_by_file[file_name] = processed_data[0][1]
202
+
203
  except Exception as e:
204
  print(f"Error processing {file_path}: {e}")
205
  continue
206
+
207
  if not all_data_to_process:
208
  raise gr.Error("No valid data found in uploaded files.")
209
+
210
  progress(0.2, desc="Combining data...")
211
+
212
  # Combine all data
213
+ combined_df = pd.concat(
214
+ [data[1] for data in all_data_to_process], ignore_index=True
215
+ )
216
+
217
+ combined_df = combined_df.drop_duplicates(subset=["row_number", "file"])
218
 
 
 
219
  progress(0.3, desc="Cleaning and preparing text...")
220
+
221
  # Clean and prepare text
222
+ combined_df = clean_and_stem_text_series(
223
+ combined_df, "text", do_initial_clean_dup=do_initial_clean_dup
224
+ )
225
+
226
  # Filter by minimum word count
227
+ combined_df["word_count"] = (
228
+ combined_df["text_clean"].str.split().str.len().fillna(0)
229
+ )
230
+ combined_df = combined_df[combined_df["word_count"] >= min_word_count].copy()
231
+
232
  if len(combined_df) < 2:
233
  return pd.DataFrame(), [], full_data_by_file
234
+
235
  progress(0.4, desc="Calculating similarities...")
236
+
237
  # Calculate similarities
238
  vectorizer = TfidfVectorizer()
239
+ tfidf_matrix = vectorizer.fit_transform(combined_df["text_clean"])
240
  similarity_matrix = cosine_similarity(tfidf_matrix, dense_output=False)
241
+
242
  # Find similar pairs
243
  coo_matrix = similarity_matrix.tocoo()
244
  similar_pairs = [
245
+ (r, c, v)
246
+ for r, c, v in zip(coo_matrix.row, coo_matrix.col, coo_matrix.data)
247
  if r < c and v >= similarity_threshold
248
  ]
249
+
250
  if not similar_pairs:
251
  gr.Info("No duplicate cells found.")
252
  return pd.DataFrame(), [], full_data_by_file
253
+
254
  progress(0.7, desc="Processing results...")
255
+
256
  # Create results DataFrame
257
  results_data = []
258
  for row1, row2, similarity in similar_pairs:
259
  row1_data = combined_df.iloc[row1]
260
  row2_data = combined_df.iloc[row2]
261
+
262
+ results_data.append(
263
+ {
264
+ "File1": row1_data["file"],
265
+ "Row1": int(row1_data["row_number"]),
266
+ "File2": row2_data["file"],
267
+ "Row2": int(row2_data["row_number"]),
268
+ "Similarity_Score": round(similarity, 3),
269
+ "Text1": (
270
+ row1_data["text"][:200] + "..."
271
+ if len(row1_data["text"]) > 200
272
+ else row1_data["text"]
273
+ ),
274
+ "Text2": (
275
+ row2_data["text"][:200] + "..."
276
+ if len(row2_data["text"]) > 200
277
+ else row2_data["text"]
278
+ ),
279
+ "Original_Index1": row1,
280
+ "Original_Index2": row2,
281
+ }
282
+ )
283
+
284
  results_df = pd.DataFrame(results_data)
285
+ results_df = results_df.sort_values(["File1", "Row1", "File2", "Row2"])
286
+
287
  progress(0.9, desc="Saving results...")
288
+
289
  # Save results
290
+ output_paths = save_tabular_duplicate_results(
291
+ results_df,
292
+ output_folder,
293
+ file_paths,
294
+ remove_duplicate_rows=remove_duplicate_rows,
295
+ in_excel_tabular_sheets=in_excel_tabular_sheets,
296
+ )
297
+
298
  gr.Info(f"Found {len(results_df)} duplicate cell matches")
299
+
300
  return results_df, output_paths, full_data_by_file
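The matching step above reduces to TF-IDF vectors plus sparse cosine similarity; a self-contained sketch of that core (texts and threshold invented):

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

texts = ["the cat sat on the mat", "the cat sat on the mat", "an entirely different row"]
tfidf = TfidfVectorizer().fit_transform(texts)
sim = cosine_similarity(tfidf, dense_output=False).tocoo()

threshold = 0.95
pairs = [
    (r, c, v)
    for r, c, v in zip(sim.row, sim.col, sim.data)
    if r < c and v >= threshold
]
# e.g. [(0, 1, 1.0)]: rows 0 and 1 are flagged as near-duplicates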
301
 
302
+
303
+ def save_tabular_duplicate_results(
304
+ results_df: pd.DataFrame,
305
+ output_folder: str,
306
+ file_paths: List[str],
307
+ remove_duplicate_rows: bool = REMOVE_DUPLICATE_ROWS,
308
+ in_excel_tabular_sheets: List[str] = [],
309
+ ) -> List[str]:
310
  """
311
  Save tabular duplicate detection results to files.
312
+
313
  Args:
314
  results_df (pd.DataFrame): Results DataFrame
315
  output_folder (str): Output folder path
 
322
  output_paths = list()
323
  output_folder_path = Path(output_folder)
324
  output_folder_path.mkdir(exist_ok=True)
325
+
326
  if results_df.empty:
327
  print("No duplicate matches to save.")
328
  return list()
329
+
330
  # Save main results
331
+ results_file = output_folder_path / "tabular_duplicate_results.csv"
332
  results_df.to_csv(results_file, index=False, encoding="utf-8-sig")
333
  output_paths.append(str(results_file))
334
+
335
  # Group results by original file to handle Excel files properly
336
+ excel_files_processed = dict() # Track which Excel files have been processed
337
+
338
  # Save per-file duplicate lists
339
+ for file_name, group in results_df.groupby("File2"):
340
  # Check for matches with original file names
341
  for original_file in file_paths:
342
  original_file_name = os.path.basename(original_file)
343
 
344
  if original_file_name in file_name:
345
  original_file_extension = os.path.splitext(original_file)[-1]
346
+ if original_file_extension in [".xlsx", ".xls"]:
347
+
348
  # Split the string using a regex to handle both .xlsx_ and .xls_ delimiters
349
  # The regex r'\.xlsx_|\.xls_' correctly matches either ".xlsx_" or ".xls_" as a delimiter.
350
+ parts = re.split(r"\.xlsx_|\.xls_", os.path.basename(file_name))
351
  # The sheet name is the last part after splitting
352
  file_sheet_name = parts[-1]
353
 
354
  file_path = original_file
355
+
356
  # Initialize Excel file tracking if not already done
357
  if file_path not in excel_files_processed:
358
  excel_files_processed[file_path] = {
359
+ "sheets_data": dict(),
360
+ "all_sheets": list(),
361
+ "processed_sheets": set(),
362
  }
363
+
364
  # Read the original Excel file to get all sheet names
365
+ if not excel_files_processed[file_path]["all_sheets"]:
366
  try:
367
  excel_file = pd.ExcelFile(file_path)
368
+ excel_files_processed[file_path][
369
+ "all_sheets"
370
+ ] = excel_file.sheet_names
371
  except Exception as e:
372
  print(f"Error reading Excel file {file_path}: {e}")
373
  continue
374
+
375
  # Read the current sheet
376
  df = read_file(file_path, excel_sheet_name=file_sheet_name)
377
+
378
  # Create duplicate rows file for this sheet
379
  file_stem = Path(file_name).stem
380
+ duplicate_rows_file = (
381
+ output_folder_path
382
+ / f"{file_stem}_{file_sheet_name}_duplicate_rows.csv"
383
+ )
384
+
385
  # Get unique row numbers to remove
386
+ rows_to_remove = sorted(group["Row2"].unique())
387
+ duplicate_df = pd.DataFrame({"Row_to_Remove": rows_to_remove})
388
  duplicate_df.to_csv(duplicate_rows_file, index=False)
389
  output_paths.append(str(duplicate_rows_file))
390
+
391
  # Process the sheet data
392
  df_cleaned = df.copy()
393
  df_cleaned["duplicated"] = False
394
  df_cleaned.loc[rows_to_remove, "duplicated"] = True
395
  if remove_duplicate_rows:
396
  df_cleaned = df_cleaned.drop(index=rows_to_remove)
397
+
398
  # Store the processed sheet data
399
+ excel_files_processed[file_path]["sheets_data"][
400
+ file_sheet_name
401
+ ] = df_cleaned
402
+ excel_files_processed[file_path]["processed_sheets"].add(
403
+ file_sheet_name
404
+ )
405
+
406
  else:
407
  file_sheet_name = ""
408
  file_path = original_file
 
411
  df = read_file(file_path)
412
 
413
  file_stem = Path(file_name).stem
414
+ duplicate_rows_file = (
415
+ output_folder_path / f"{file_stem}_duplicate_rows.csv"
416
+ )
417
+
418
  # Get unique row numbers to remove
419
+ rows_to_remove = sorted(group["Row2"].unique())
420
+ duplicate_df = pd.DataFrame({"Row_to_Remove": rows_to_remove})
421
  duplicate_df.to_csv(duplicate_rows_file, index=False)
422
  output_paths.append(str(duplicate_rows_file))
423
 
 
429
 
430
  file_ext = os.path.splitext(file_name)[-1]
431
 
432
+ if file_ext in [".parquet"]:
433
+ output_path = os.path.join(
434
+ output_folder, f"{file_base_name}_deduplicated.parquet"
435
+ )
436
  df_cleaned.to_parquet(output_path, index=False)
437
  else:
438
+ output_path = os.path.join(
439
+ output_folder, f"{file_base_name}_deduplicated.csv"
440
+ )
441
+ df_cleaned.to_csv(
442
+ output_path, index=False, encoding="utf-8-sig"
443
+ )
444
+
445
  output_paths.append(str(output_path))
446
  break
447
+
448
  # Process Excel files to create complete deduplicated files
449
  for file_path, file_data in excel_files_processed.items():
450
  try:
451
  # Create output filename
452
  file_base_name = os.path.splitext(os.path.basename(file_path))[0]
453
  file_ext = os.path.splitext(file_path)[-1]
454
+ output_path = os.path.join(
455
+ output_folder, f"{file_base_name}_deduplicated{file_ext}"
456
+ )
457
+
458
  # Create Excel writer
459
+ with pd.ExcelWriter(output_path, engine="openpyxl") as writer:
460
  # Write all sheets
461
+ for sheet_name in file_data["all_sheets"]:
462
+ if sheet_name in file_data["processed_sheets"]:
463
  # Use the processed (deduplicated) version
464
+ file_data["sheets_data"][sheet_name].to_excel(
465
+ writer, sheet_name=sheet_name, index=False
 
 
466
  )
467
  else:
468
  # Use the original sheet (no duplicates found)
469
  original_df = read_file(file_path, excel_sheet_name=sheet_name)
470
+ original_df.to_excel(writer, sheet_name=sheet_name, index=False)
471
+
 
 
 
 
472
  output_paths.append(str(output_path))
473
  print(f"Created deduplicated Excel file: {output_path}")
474
+
475
  except Exception as e:
476
  print(f"Error creating deduplicated Excel file for {file_path}: {e}")
477
  continue
478
+
479
  return output_paths
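For a single CSV input, the paths collected above typically look like the following (folder and file names invented; the patterns are taken from the code above):

output_paths = [
    "output/tabular_duplicate_results.csv",  # all matches across files
    "output/notes_duplicate_rows.csv",       # Row_to_Remove list for notes.csv
    "output/notes_deduplicated.csv",         # copy with a 'duplicated' flag / rows removed
]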
480
 
481
+
482
  def remove_duplicate_rows_from_tabular_data(
483
  file_path: str,
484
  duplicate_rows: List[int],
485
  output_folder: str = OUTPUT_FOLDER,
486
  in_excel_tabular_sheets: List[str] = [],
487
+ remove_duplicate_rows: bool = REMOVE_DUPLICATE_ROWS,
488
  ) -> str:
489
  """
490
  Remove duplicate rows from a tabular data file.
491
+
492
  Args:
493
  file_path (str): Path to the input file
494
  duplicate_rows (List[int]): List of row indices to remove
 
500
  """
501
  try:
502
  # Load the file
503
+ df = read_file(
504
+ file_path,
505
+ excel_sheet_name=in_excel_tabular_sheets if in_excel_tabular_sheets else "",
506
+ )
507
+
508
  # Remove duplicate rows (0-indexed)
509
  df_cleaned = df.drop(index=duplicate_rows).reset_index(drop=True)
510
+
511
  # Save cleaned file
512
  file_name = os.path.basename(file_path)
513
  file_stem = os.path.splitext(file_name)[0]
514
  file_ext = os.path.splitext(file_name)[-1]
515
+
516
  output_path = os.path.join(output_folder, f"{file_stem}_deduplicated{file_ext}")
517
+
518
+ if file_ext in [".xlsx", ".xls"]:
519
+ df_cleaned.to_excel(
520
+ output_path,
521
+ index=False,
522
+ sheet_name=in_excel_tabular_sheets if in_excel_tabular_sheets else [],
523
+ )
524
+ elif file_ext in [".parquet"]:
525
  df_cleaned.to_parquet(output_path, index=False)
526
  else:
527
  df_cleaned.to_csv(output_path, index=False, encoding="utf-8-sig")
528
+
529
  return output_path
530
+
531
  except Exception as e:
532
  print(f"Error removing duplicates from {file_path}: {e}")
533
  raise
534
 
535
+
536
  def run_tabular_duplicate_analysis(
537
  files: List[str],
538
  threshold: float,
 
542
  do_initial_clean_dup: bool = DO_INITIAL_TABULAR_DATA_CLEAN,
543
  remove_duplicate_rows: bool = REMOVE_DUPLICATE_ROWS,
544
  in_excel_tabular_sheets: List[str] = [],
545
+ progress: Progress = Progress(track_tqdm=True),
546
  ) -> Tuple[pd.DataFrame, List[str], Dict[str, pd.DataFrame]]:
547
  """
548
  Main function to run tabular duplicate analysis.
549
+
550
  Args:
551
  files (List[str]): List of file paths
552
  threshold (float): Similarity threshold
 
554
  text_columns (List[str], optional): Specific columns to analyze
555
  output_folder (str, optional): Output folder for results
556
  progress (Progress): Progress tracking
557
+
558
  Returns:
559
  Tuple containing results DataFrame, output paths, and full data by file
560
  """
 
565
  text_columns=text_columns if text_columns else [],
566
  output_folder=output_folder,
567
  do_initial_clean_dup=do_initial_clean_dup,
568
+ in_excel_tabular_sheets=(
569
+ in_excel_tabular_sheets if in_excel_tabular_sheets else []
570
+ ),
571
+ remove_duplicate_rows=remove_duplicate_rows,
572
  )
573
 
574
 
 
575
  # Function to update column choices when files are uploaded
576
  def update_tabular_column_choices(files, in_excel_tabular_sheets: List[str] = []):
577
  if not files:
578
  return gr.update(choices=[])
579
+
580
  all_columns = set()
581
  for file in files:
582
  try:
583
  file_extension = os.path.splitext(file.name)[-1]
584
+ if file_extension in [".xlsx", ".xls"]:
585
  for sheet_name in in_excel_tabular_sheets:
586
  df = read_file(file.name, excel_sheet_name=sheet_name)
587
+ text_cols = df.select_dtypes(
588
+ include=["object", "string"]
589
+ ).columns.tolist()
590
  all_columns.update(text_cols)
591
  else:
592
  df = read_file(file.name)
593
+ text_cols = df.select_dtypes(
594
+ include=["object", "string"]
595
+ ).columns.tolist()
596
  all_columns.update(text_cols)
597
 
598
  # Get text columns
599
+ text_cols = df.select_dtypes(include=["object", "string"]).columns.tolist()
600
 
601
  all_columns.update(text_cols)
602
  except Exception as e:
603
  print(f"Error reading {file.name}: {e}")
604
  continue
605
+
606
  return gr.Dropdown(choices=sorted(list(all_columns)))
607
 
608
+
609
  # Function to handle tabular duplicate detection
610
+ def run_tabular_duplicate_detection(
611
+ files,
612
+ threshold,
613
+ min_words,
614
+ text_columns,
615
+ output_folder: str = OUTPUT_FOLDER,
616
+ do_initial_clean_dup: bool = DO_INITIAL_TABULAR_DATA_CLEAN,
617
+ in_excel_tabular_sheets: List[str] = [],
618
+ remove_duplicate_rows: bool = REMOVE_DUPLICATE_ROWS,
619
+ ):
620
  if not files:
621
  print("No files uploaded")
622
  return pd.DataFrame(), [], gr.Dropdown(choices=[]), 0, "deduplicate"
 
626
  task_textbox = "deduplicate"
627
 
628
  # If output folder doesn't end with a forward slash, add one
629
+ if not output_folder.endswith("/"):
630
+ output_folder = output_folder + "/"
631
+
632
  file_paths = list()
633
  if isinstance(files, str):
634
  # If 'files' is a single string, treat it as a list with one element
 
639
  if isinstance(f_item, str):
640
  # If an element is a string, it's a direct file path
641
  file_paths.append(f_item)
642
+ elif hasattr(f_item, "name"):
643
  # If an element has a '.name' attribute (e.g., a Gradio File object), use its name
644
  file_paths.append(f_item.name)
645
  else:
646
  # Log a warning for unexpected element types within the list
647
+ print(
648
+ f"Warning: Skipping an element in 'files' list that is neither a string nor has a '.name' attribute: {type(f_item)}"
649
+ )
650
+ elif hasattr(files, "name"):
651
  # Handle the case where a single file object (e.g., gr.File) is passed directly, not in a list
652
  file_paths.append(files.name)
653
  else:
654
  # Raise an error for any other unexpected type of the 'files' argument itself
655
+ raise TypeError(
656
+ f"Unexpected type for 'files' argument: {type(files)}. Expected str, list of str/file objects, or a single file object."
657
+ )
658
 
659
  if len(file_paths) > MAX_SIMULTANEOUS_FILES:
660
  out_message = f"Number of files to deduplicate is greater than {MAX_SIMULTANEOUS_FILES}. Please submit a smaller number of files."
 
668
  text_columns=text_columns if text_columns else [],
669
  output_folder=output_folder,
670
  do_initial_clean_dup=do_initial_clean_dup,
671
+ in_excel_tabular_sheets=(
672
+ in_excel_tabular_sheets if in_excel_tabular_sheets else None
673
+ ),
674
+ remove_duplicate_rows=remove_duplicate_rows,
675
  )
676
+
677
  # Update file choices for cleaning
678
  file_choices = list(set([f for f in file_paths]))
679
 
680
  end_time = time.time()
681
  processing_time = round(end_time - start_time, 2)
682
+
683
+ return (
684
+ results_df,
685
+ output_paths,
686
+ gr.Dropdown(choices=file_choices),
687
+ processing_time,
688
+ task_textbox,
689
+ )
690
+
691
 
692
  # Function to handle row selection for preview
693
+ def handle_tabular_row_selection(results_df, evt: gr.SelectData):
694
+
695
  if not evt:
696
  return None, "", ""
697
 
 
699
  return None, "", ""
700
  elif results_df.empty:
701
  return None, "", ""
702
+
703
  selected_index = evt.index[0]
704
  if selected_index >= len(results_df):
705
  return None, "", ""
706
+
707
  row = results_df.iloc[selected_index]
708
+ return selected_index, row["Text1"], row["Text2"]
709
+
710
 
711
  # Function to clean duplicates from selected file
712
+ def clean_tabular_duplicates(
713
+ file_name,
714
+ results_df,
715
+ output_folder,
716
+ in_excel_tabular_sheets: str = "",
717
+ remove_duplicate_rows: bool = REMOVE_DUPLICATE_ROWS,
718
+ ):
719
  if not file_name or results_df.empty:
720
  return None
721
+
722
  # Get duplicate rows for this file
723
+ file_duplicates = results_df[results_df["File2"] == file_name]["Row2"].tolist()
724
+
725
  if not file_duplicates:
726
  return None
727
+
728
  try:
729
  # Find the original file path
730
  # This is a simplified approach - in practice you might want to store file paths
 
733
  duplicate_rows=file_duplicates,
734
  output_folder=output_folder,
735
  in_excel_tabular_sheets=in_excel_tabular_sheets,
736
+ remove_duplicate_rows=remove_duplicate_rows,
737
  )
738
  return cleaned_file
739
  except Exception as e:
740
  print(f"Error cleaning duplicates: {e}")
741
+ return None
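A hedged sketch of how the row-selection handler could be wired into a Gradio Blocks layout; the component names and layout are assumptions, not the app's actual interface code:

import gradio as gr

from tools.find_duplicate_tabular import handle_tabular_row_selection

with gr.Blocks() as demo:
    results_df = gr.Dataframe(interactive=False, label="Duplicate matches")
    selected_row_index = gr.Number(visible=False)
    text1_preview = gr.Textbox(label="Text 1")
    text2_preview = gr.Textbox(label="Text 2")

    # Clicking a row passes a gr.SelectData event into the handler,
    # which returns the selected index and the two text previews
    results_df.select(
        handle_tabular_row_selection,
        inputs=[results_df],
        outputs=[selected_row_index, text1_preview, text2_preview],
    )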
tools/helper_functions.py CHANGED
@@ -1,66 +1,119 @@
1
  import os
2
  import re
 
 
 
 
3
  import boto3
4
- from botocore.exceptions import ClientError
5
  import gradio as gr
6
- import pandas as pd
7
  import numpy as np
8
- import unicodedata
9
- from typing import List
10
- from math import ceil
11
  from gradio_image_annotation import image_annotator
12
- from tools.config import CUSTOM_HEADER_VALUE, CUSTOM_HEADER, OUTPUT_FOLDER, INPUT_FOLDER, SESSION_OUTPUT_FOLDER, AWS_USER_POOL_ID, TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_INPUT_SUBFOLDER, TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_OUTPUT_SUBFOLDER, TEXTRACT_JOBS_S3_LOC, TEXTRACT_JOBS_LOCAL_LOC, SELECTABLE_TEXT_EXTRACT_OPTION, TESSERACT_TEXT_EXTRACT_OPTION, TEXTRACT_TEXT_EXTRACT_OPTION, NO_REDACTION_PII_OPTION, AWS_PII_OPTION, MAPPED_LANGUAGE_CHOICES, LANGUAGE_CHOICES, textract_language_choices, aws_comprehend_language_choices, DEFAULT_LANGUAGE
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  def _get_env_list(env_var_name: str) -> List[str]:
15
  """Parses a comma-separated environment variable into a list of strings."""
16
- value = env_var_name[1:-1].strip().replace('\"', '').replace("\'","")
17
  if not value:
18
  return []
19
  # Split by comma and filter out any empty strings that might result from extra commas
20
- return [s.strip() for s in value.split(',') if s.strip()]
 
21
 
22
- if textract_language_choices: textract_language_choices = _get_env_list(textract_language_choices)
23
- if aws_comprehend_language_choices: aws_comprehend_language_choices = _get_env_list(aws_comprehend_language_choices)
 
 
24
 
25
- if MAPPED_LANGUAGE_CHOICES: MAPPED_LANGUAGE_CHOICES = _get_env_list(MAPPED_LANGUAGE_CHOICES)
26
- if LANGUAGE_CHOICES: LANGUAGE_CHOICES = _get_env_list(LANGUAGE_CHOICES)
 
 
27
 
28
  LANGUAGE_MAP = dict(zip(MAPPED_LANGUAGE_CHOICES, LANGUAGE_CHOICES))
29
 
 
30
  def reset_state_vars():
31
- return [], pd.DataFrame(), pd.DataFrame(), 0, "", image_annotator(
 
 
 
 
 
 
32
  label="Modify redaction boxes",
33
  label_list=["Redaction"],
34
  label_colors=[(0, 0, 0)],
35
  show_label=False,
36
- sources=None,#["upload"],
37
  show_clear_button=False,
38
  show_share_button=False,
39
  show_remove_button=False,
40
- interactive=False
41
- ), [], [], pd.DataFrame(), pd.DataFrame(), [], [], "", False, 0, []
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  def reset_ocr_results_state():
44
  return pd.DataFrame(), pd.DataFrame(), []
45
 
 
46
  def reset_review_vars():
47
  return pd.DataFrame(), pd.DataFrame()
48
 
 
49
  def reset_data_vars():
50
  return 0, [], 0
51
 
 
52
  def reset_aws_call_vars():
53
  return 0, 0
54
 
 
55
  def load_in_default_allow_list(allow_list_file_path):
56
  if isinstance(allow_list_file_path, str):
57
  allow_list_file_path = [allow_list_file_path]
58
  return allow_list_file_path
59
 
60
- def load_in_default_cost_codes(cost_codes_path:str, default_cost_code:str=""):
61
- '''
 
62
  Load in the cost codes list from file.
63
- '''
64
  cost_codes_df = pd.read_csv(cost_codes_path)
65
  dropdown_choices = cost_codes_df.iloc[:, 0].astype(str).tolist()
66
 
@@ -76,36 +129,50 @@ def load_in_default_cost_codes(cost_codes_path:str, default_cost_code:str=""):
76
  value=default_cost_code if default_cost_code in dropdown_choices else "",
77
  label="Choose cost code for analysis",
78
  choices=dropdown_choices,
79
- allow_custom_value=False
80
  )
81
-
82
  return cost_codes_df, cost_codes_df, out_dropdown
83
 
84
- def enforce_cost_codes(enforce_cost_code_textbox:str, cost_code_choice:str, cost_code_df:pd.DataFrame, verify_cost_codes:bool=True):
85
- '''
 
 
 
 
 
 
86
  Check if the enforce cost codes variable is set to true, and then check that a cost code has been chosen. If not, raise an error. Then, check against the values in the cost code dataframe to ensure that the cost code exists.
87
- '''
88
 
89
  if enforce_cost_code_textbox == "True":
90
  if not cost_code_choice:
91
  raise Exception("Please choose a cost code before continuing")
92
-
93
- if verify_cost_codes == True:
94
  if cost_code_df.empty:
95
  raise Exception("No cost codes present in dataframe for verification")
96
  else:
97
- valid_cost_codes_list = list(cost_code_df.iloc[:,0].unique())
98
 
99
- if not cost_code_choice in valid_cost_codes_list:
100
- raise Exception("Selected cost code not found in list. Please contact Finance if you cannot find the correct cost code from the given list of suggestions.")
 
 
101
  return
102
 
103
- def update_cost_code_dataframe_from_dropdown_select(cost_dropdown_selection:str, cost_code_df:pd.DataFrame):
104
- cost_code_df = cost_code_df.loc[cost_code_df.iloc[:,0] == cost_dropdown_selection, :]
 
 
 
 
 
105
  return cost_code_df
106
 
107
- def ensure_folder_exists(output_folder:str):
108
- """Checks if the specified folder exists, creates it if not."""
 
109
 
110
  if not os.path.exists(output_folder):
111
  # Create the folder if it doesn't exist
@@ -114,58 +181,80 @@ def ensure_folder_exists(output_folder:str):
114
  else:
115
  print(f"The {output_folder} folder already exists.")
116
 
117
- def update_dataframe(df:pd.DataFrame):
 
118
  df_copy = df.copy()
119
  return df_copy
120
 
 
121
  def get_file_name_without_type(file_path):
122
  # First, get the basename of the file (e.g., "example.txt" from "/path/to/example.txt")
123
  basename = os.path.basename(file_path)
124
-
125
  # Then, split the basename and its extension and return only the basename without the extension
126
  filename_without_extension, _ = os.path.splitext(basename)
127
 
128
- #print(filename_without_extension)
129
-
130
  return filename_without_extension
131
 
132
- def detect_file_type(filename:str):
 
133
  """Detect the file type based on its extension."""
134
  if not isinstance(filename, str):
135
  filename = str(filename)
136
 
137
- if (filename.endswith('.csv')) | (filename.endswith('.csv.gz')) | (filename.endswith('.zip')): return 'csv'
138
- elif filename.endswith('.xlsx'): return 'xlsx'
139
- elif filename.endswith('.xls'): return 'xls'
140
- elif filename.endswith('.parquet'): return 'parquet'
141
- elif filename.endswith('.pdf'): return 'pdf'
142
- elif filename.endswith('.jpg'): return 'jpg'
143
- elif filename.endswith('.jpeg'): return 'jpeg'
144
- elif filename.endswith('.png'): return 'png'
145
- elif filename.endswith('.xfdf'): return 'xfdf'
146
- elif filename.endswith('.docx'): return 'docx'
147
- else: raise ValueError("Unsupported file type.")
148
-
149
- def read_file(filename:str, excel_sheet_name: str = ""):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  """Read the file based on its detected type."""
151
  file_type = detect_file_type(filename)
152
-
153
- if file_type == 'csv':
154
  return pd.read_csv(filename, low_memory=False)
155
- elif file_type == 'xlsx':
156
  if excel_sheet_name:
157
  try:
158
  return pd.read_excel(filename, sheet_name=excel_sheet_name)
159
  except Exception as e:
160
- print(f"Error reading {filename} with sheet name {excel_sheet_name}: {e}")
 
 
161
  return pd.DataFrame()
162
  else:
163
  return pd.read_excel(filename)
164
- elif file_type == 'parquet':
165
  return pd.read_parquet(filename)
166
 
167
- def ensure_output_folder_exists(output_folder:str):
168
- """Checks if the specified folder exists, creates it if not."""
 
169
 
170
  if not os.path.exists(output_folder):
171
  # Create the folder if it doesn't exist
@@ -175,11 +264,10 @@ def ensure_output_folder_exists(output_folder:str):
175
  print(f"The {output_folder} folder already exists.")
176
 
177
 
178
-
179
- def custom_regex_load(in_file:List[str], file_type:str = "allow_list"):
180
- '''
181
  When file is loaded, update the column dropdown choices and write to relevant data states.
182
- '''
183
  custom_regex_df = pd.DataFrame()
184
 
185
  if in_file:
@@ -188,11 +276,13 @@ def custom_regex_load(in_file:List[str], file_type:str = "allow_list"):
188
  regex_file_names = [string for string in file_list if "csv" in string.lower()]
189
  if regex_file_names:
190
  regex_file_name = regex_file_names[0]
191
- custom_regex_df = pd.read_csv(regex_file_name, low_memory=False, header=None)
192
-
 
 
193
  # Select just first columns
194
- custom_regex_df = pd.DataFrame(custom_regex_df.iloc[:,[0]])
195
- custom_regex_df.rename(columns={0:file_type}, inplace=True)
196
 
197
  custom_regex_df.columns = custom_regex_df.columns.astype(str)
198
 
@@ -200,23 +290,24 @@ def custom_regex_load(in_file:List[str], file_type:str = "allow_list"):
200
  print(output_text)
201
  else:
202
  output_text = "No file provided."
203
- #print(output_text)
204
  return output_text, custom_regex_df
205
-
206
  return output_text, custom_regex_df
207
 
208
- def put_columns_in_df(in_file:List[str]):
 
209
  new_choices = []
210
  concat_choices = []
211
  all_sheet_names = []
212
  number_of_excel_files = 0
213
-
214
  for file in in_file:
215
  file_name = file.name
216
  file_type = detect_file_type(file_name)
217
  print("File type is:", file_type)
218
 
219
- if (file_type == 'xlsx') | (file_type == 'xls'):
220
  number_of_excel_files += 1
221
  new_choices = []
222
  print("Running through all xlsx sheets")
@@ -240,47 +331,65 @@ def put_columns_in_df(in_file:List[str]):
240
  new_choices = []
241
 
242
  concat_choices.extend(new_choices)
243
-
244
  # Drop duplicate columns
245
  concat_choices = list(set(concat_choices))
246
 
247
- if number_of_excel_files > 0:
248
- return gr.Dropdown(choices=concat_choices, value=concat_choices), gr.Dropdown(choices=all_sheet_names, value=all_sheet_names, visible=True)
 
 
249
  else:
250
- return gr.Dropdown(choices=concat_choices, value=concat_choices), gr.Dropdown(visible=False)
 
 
251
 
252
- def check_for_existing_textract_file(doc_file_name_no_extension_textbox:str, output_folder:str=OUTPUT_FOLDER):
253
- textract_output_path = os.path.join(output_folder, doc_file_name_no_extension_textbox + "_textract.json")
 
 
 
 
 
254
 
255
  if os.path.exists(textract_output_path):
256
- print("Existing Textract analysis output file found.")
257
  return True
258
-
259
  else:
260
  return False
261
-
262
- def check_for_relevant_ocr_output_with_words(doc_file_name_no_extension_textbox:str, text_extraction_method:str, output_folder:str=OUTPUT_FOLDER):
263
- if text_extraction_method == SELECTABLE_TEXT_EXTRACT_OPTION: file_ending = "_ocr_results_with_words_local_text.json"
264
- elif text_extraction_method == TESSERACT_TEXT_EXTRACT_OPTION: file_ending = "_ocr_results_with_words_local_ocr.json"
265
- elif text_extraction_method == TEXTRACT_TEXT_EXTRACT_OPTION: file_ending = "_ocr_results_with_words_textract.json"
 
 
 
 
 
 
 
 
266
  else:
267
  print("No valid text extraction method found. Returning False")
268
  return False
269
-
270
  doc_file_with_ending = doc_file_name_no_extension_textbox + file_ending
271
 
272
  local_ocr_output_path = os.path.join(output_folder, doc_file_with_ending)
273
 
274
  if os.path.exists(local_ocr_output_path):
275
- print("Existing OCR with words analysis output file found.")
276
- return True
277
  else:
278
  return False
279
 
 
280
  def add_folder_to_path(folder_path: str):
281
- '''
282
  Check if a folder exists on your system. If so, get the absolute path and then add it to the system Path variable if it doesn't already exist. This function is only relevant for locally-created executable files based on this app (when using pyinstaller it creates an _internal folder that contains tesseract and poppler; these need to be added to the system path to enable the app to run)
283
- '''
284
 
285
  if os.path.exists(folder_path) and os.path.isdir(folder_path):
286
  print(folder_path, "folder exists.")
@@ -288,21 +397,31 @@ def add_folder_to_path(folder_path: str):
288
  # Resolve relative path to absolute path
289
  absolute_path = os.path.abspath(folder_path)
290
 
291
- current_path = os.environ['PATH']
292
  if absolute_path not in current_path.split(os.pathsep):
293
  full_path_extension = absolute_path + os.pathsep + current_path
294
- os.environ['PATH'] = full_path_extension
295
- #print(f"Updated PATH with: ", full_path_extension)
296
  else:
297
  print(f"Directory {folder_path} already exists in PATH.")
298
  else:
299
  print(f"Folder not found at {folder_path} - not added to PATH")
300
 
 
301
  # Upon running a process, the feedback buttons are revealed
302
  def reveal_feedback_buttons():
303
- return gr.Radio(visible=True, label="Please give some feedback about the results of the redaction. A reminder that the app is only expected to identify about 60% of personally identifiable information in a given (typed) document."), gr.Textbox(visible=True), gr.Button(visible=True), gr.Markdown(visible=True)
 
 
 
 
 
 
 
 
304
 
305
- def wipe_logs(feedback_logs_loc:str, usage_logs_loc:str):
 
306
  try:
307
  os.remove(feedback_logs_loc)
308
  except Exception as e:
@@ -312,7 +431,8 @@ def wipe_logs(feedback_logs_loc:str, usage_logs_loc:str):
312
  except Exception as e:
313
  print("Could not remove usage logs file", e)
314
 
315
- def merge_csv_files(file_list:List[str], output_folder:str=OUTPUT_FOLDER):
 
316
 
317
  # Initialise an empty list to hold DataFrames
318
  dataframes = []
@@ -325,14 +445,16 @@ def merge_csv_files(file_list:List[str], output_folder:str=OUTPUT_FOLDER):
325
  dataframes.append(df)
326
 
327
  # Concatenate all DataFrames into a single DataFrame
328
- merged_df = pd.concat(dataframes, ignore_index=True)
329
 
330
- for col in ['xmin', 'xmax', 'ymin', 'ymax']:
331
  merged_df[col] = np.floor(merged_df[col])
332
 
333
- merged_df = merged_df.drop_duplicates(subset=['page', 'label', 'color', 'xmin', 'ymin', 'xmax', 'ymax'])
 
 
334
 
335
- merged_df = merged_df.sort_values(['page', 'ymin', 'xmin', 'label'])
336
 
337
  file_out_name = os.path.basename(file_list[0])
338
 
@@ -344,51 +466,58 @@ def merge_csv_files(file_list:List[str], output_folder:str=OUTPUT_FOLDER):
344
 
345
  return output_files
346
 
347
- async def get_connection_params(request: gr.Request,
348
- output_folder_textbox:str=OUTPUT_FOLDER,
349
- input_folder_textbox:str=INPUT_FOLDER,
350
- session_output_folder:str=SESSION_OUTPUT_FOLDER,
351
- textract_document_upload_input_folder:str=TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_INPUT_SUBFOLDER,
352
- textract_document_upload_output_folder:str=TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_OUTPUT_SUBFOLDER,
353
- s3_textract_document_logs_subfolder:str=TEXTRACT_JOBS_S3_LOC,
354
- local_textract_document_logs_subfolder:str=TEXTRACT_JOBS_LOCAL_LOC):
355
 
356
- #print("Session hash:", request.session_hash)
 
 
 
 
 
 
 
 
 
 
 
357
 
358
  if CUSTOM_HEADER and CUSTOM_HEADER_VALUE:
359
- if CUSTOM_HEADER in request.headers:
360
- supplied_custom_header_value = request.headers[CUSTOM_HEADER]
361
- if supplied_custom_header_value == CUSTOM_HEADER_VALUE:
362
- print("Custom header supplied and matches CUSTOM_HEADER_VALUE")
363
- else:
364
- print("Custom header value does not match expected value.")
365
- raise ValueError("Custom header value does not match expected value.")
366
  else:
367
- print("Custom header value not found.")
368
- raise ValueError("Custom header value not found.")
 
 
 
369
 
370
  # Get output save folder from 1 - username passed in from direct Cognito login, 2 - Cognito ID header passed through a Lambda authenticator, 3 - the session hash.
371
 
372
  if request.username:
373
  out_session_hash = request.username
374
- #print("Request username found:", out_session_hash)
375
 
376
- elif 'x-cognito-id' in request.headers:
377
- out_session_hash = request.headers['x-cognito-id']
378
- #print("Cognito ID found:", out_session_hash)
379
 
380
- elif 'x-amzn-oidc-identity' in request.headers:
381
- out_session_hash = request.headers['x-amzn-oidc-identity']
382
 
383
  # Fetch email address using Cognito client
384
- cognito_client = boto3.client('cognito-idp')
385
  try:
386
  response = cognito_client.admin_get_user(
387
  UserPoolId=AWS_USER_POOL_ID, # Replace with your User Pool ID
388
- Username=out_session_hash
389
  )
390
- email = next(attr['Value'] for attr in response['UserAttributes'] if attr['Name'] == 'email')
391
- #print("Email address found:", email)
 
 
 
 
392
 
393
  out_session_hash = email
394
  except ClientError as e:
@@ -400,33 +529,59 @@ async def get_connection_params(request: gr.Request,
400
  else:
401
  out_session_hash = request.session_hash
402
 
403
- if session_output_folder == 'True':
404
  output_folder = output_folder_textbox + out_session_hash + "/"
405
  input_folder = input_folder_textbox + out_session_hash + "/"
406
 
407
- textract_document_upload_input_folder = textract_document_upload_input_folder + "/" + out_session_hash
408
- textract_document_upload_output_folder = textract_document_upload_output_folder + "/" + out_session_hash
 
 
 
 
409
 
410
- s3_textract_document_logs_subfolder = s3_textract_document_logs_subfolder + "/" + out_session_hash
411
- local_textract_document_logs_subfolder = local_textract_document_logs_subfolder + "/" + out_session_hash + "/"
 
 
 
 
412
 
413
  else:
414
  output_folder = output_folder_textbox
415
  input_folder = input_folder_textbox
416
 
417
- if not os.path.exists(output_folder): os.mkdir(output_folder)
418
- if not os.path.exists(input_folder): os.mkdir(input_folder)
 
 
 
 
 
 
 
 
 
 
 
 
 
419
 
420
- return out_session_hash, output_folder, out_session_hash, input_folder, textract_document_upload_input_folder, textract_document_upload_output_folder, s3_textract_document_logs_subfolder, local_textract_document_logs_subfolder
421
 
422
- def clean_unicode_text(text:str):
423
  # Step 1: Normalise unicode characters to decompose any special forms
424
- normalized_text = unicodedata.normalize('NFKC', text)
425
 
426
  # Step 2: Replace smart quotes and special punctuation with standard ASCII equivalents
427
  replacements = {
428
- '‘': "'", '’': "'", '“': '"', '”': '"',
429
- '–': '-', '—': '-', '…': '...', '•': '*',
 
 
 
 
 
 
430
  }
431
 
432
  # Perform replacements
@@ -436,14 +591,15 @@ def clean_unicode_text(text:str):
436
  # Step 3: Optionally remove non-ASCII characters if needed
437
  # This regex removes any remaining non-ASCII characters, if desired.
438
  # Comment this line if you want to keep all Unicode characters.
439
- cleaned_text = re.sub(r'[^\x00-\x7F]+', '', normalized_text)
440
 
441
  return cleaned_text
442
-
443
- def load_all_output_files(folder_path:str=OUTPUT_FOLDER) -> List[str]:
 
444
  """Get the file paths of all files in the given folder."""
445
  file_paths = []
446
-
447
  # List all files in the specified folder
448
  for filename in os.listdir(folder_path):
449
  # Construct full file path
@@ -451,26 +607,29 @@ def load_all_output_files(folder_path:str=OUTPUT_FOLDER) -> List[str]:
451
  # Check if it's a file (not a directory)
452
  if os.path.isfile(full_path):
453
  file_paths.append(full_path)
454
-
455
  return file_paths
456
 
457
- def calculate_aws_costs(number_of_pages:str,
458
- text_extract_method_radio:str,
459
- handwrite_signature_checkbox:List[str],
460
- pii_identification_method:str,
461
- textract_output_found_checkbox:bool,
462
- only_extract_text_radio:bool,
463
- convert_to_gbp:bool=True,
464
- usd_gbp_conversion_rate:float=0.76,
465
- textract_page_cost:float=1.5/1000,
466
- textract_signature_cost:float=2.0/1000,
467
- comprehend_unit_cost:float=0.0001,
468
- comprehend_size_unit_average:float=250,
469
- average_characters_per_page:float=2000,
470
- TEXTRACT_TEXT_EXTRACT_OPTION:str=TEXTRACT_TEXT_EXTRACT_OPTION,
471
- NO_REDACTION_PII_OPTION:str=NO_REDACTION_PII_OPTION,
472
- AWS_PII_OPTION:str=AWS_PII_OPTION):
473
- '''
 
 
 
474
  Calculate the approximate cost of submitting a document to AWS Textract and/or AWS Comprehend, assuming that Textract outputs do not already exist in the output folder.
475
 
476
  - number_of_pages: The number of pages in the uploaded document(s).
@@ -489,49 +648,57 @@ def calculate_aws_costs(number_of_pages:str,
489
  - TEXTRACT_TEXT_EXTRACT_OPTION (str, optional): String label for the text_extract_method_radio button for AWS Textract.
490
  - NO_REDACTION_PII_OPTION (str, optional): String label for pii_identification_method_drop for no redaction.
491
  - AWS_PII_OPTION (str, optional): String label for pii_identification_method_drop for AWS Comprehend.
492
- '''
493
  text_extraction_cost = 0
494
  pii_identification_cost = 0
495
  calculated_aws_cost = 0
496
  number_of_pages = int(number_of_pages)
497
-
498
- if textract_output_found_checkbox != True:
499
  if text_extract_method_radio == TEXTRACT_TEXT_EXTRACT_OPTION:
500
  text_extraction_cost = number_of_pages * textract_page_cost
501
 
502
  if "Extract signatures" in handwrite_signature_checkbox:
503
- text_extraction_cost += (textract_signature_cost * number_of_pages)
504
 
505
  if pii_identification_method != NO_REDACTION_PII_OPTION:
506
  if pii_identification_method == AWS_PII_OPTION:
507
- comprehend_page_cost = ceil(average_characters_per_page / comprehend_size_unit_average) * comprehend_unit_cost
 
 
 
508
  pii_identification_cost = comprehend_page_cost * number_of_pages
509
 
510
- calculated_aws_cost = calculated_aws_cost + text_extraction_cost + pii_identification_cost
 
 
511
 
512
- if convert_to_gbp == True:
513
  calculated_aws_cost *= usd_gbp_conversion_rate
514
 
515
  return calculated_aws_cost
516
 
517
- def calculate_time_taken(number_of_pages:str,
518
- text_extract_method_radio:str,
519
- pii_identification_method:str,
520
- textract_output_found_checkbox:bool,
521
- only_extract_text_radio:bool,
522
- local_ocr_output_found_checkbox:bool,
523
- convert_page_time:float=0.5,
524
- textract_page_time:float=1.2,
525
- comprehend_page_time:float=1.2,
526
- local_text_extraction_page_time:float=0.3,
527
- local_pii_redaction_page_time:float=0.5,
528
- local_ocr_extraction_page_time:float=1.5,
529
- TEXTRACT_TEXT_EXTRACT_OPTION:str=TEXTRACT_TEXT_EXTRACT_OPTION,
530
- SELECTABLE_TEXT_EXTRACT_OPTION:str=SELECTABLE_TEXT_EXTRACT_OPTION,
531
- local_ocr_option:str=TESSERACT_TEXT_EXTRACT_OPTION,
532
- NO_REDACTION_PII_OPTION:str=NO_REDACTION_PII_OPTION,
533
- AWS_PII_OPTION:str=AWS_PII_OPTION):
534
- '''
 
 
 
535
  Calculate the approximate time to redact a document.
536
 
537
  - number_of_pages: The number of pages in the uploaded document(s).
@@ -548,9 +715,9 @@ def calculate_time_taken(number_of_pages:str,
548
  - TEXTRACT_TEXT_EXTRACT_OPTION (str, optional): String label for the text_extract_method_radio button for AWS Textract.
549
  - SELECTABLE_TEXT_EXTRACT_OPTION (str, optional): String label for text_extract_method_radio for text extraction.
550
  - local_ocr_option (str, optional): String label for text_extract_method_radio for local OCR.
551
- - NO_REDACTION_PII_OPTION (str, optional): String label for pii_identification_method_drop for no redaction.
552
  - AWS_PII_OPTION (str, optional): String label for pii_identification_method_drop for AWS Comprehend.
553
- '''
554
  calculated_time_taken = 0
555
  page_conversion_time_taken = 0
556
  page_extraction_time_taken = 0
@@ -559,16 +726,20 @@ def calculate_time_taken(number_of_pages:str,
559
  number_of_pages = int(number_of_pages)
560
 
561
  # Page preparation/conversion to image time
562
- if (text_extract_method_radio != SELECTABLE_TEXT_EXTRACT_OPTION) and (textract_output_found_checkbox != True):
 
 
563
  page_conversion_time_taken = number_of_pages * convert_page_time
564
 
565
  # Page text extraction time
566
  if text_extract_method_radio == TEXTRACT_TEXT_EXTRACT_OPTION:
567
- if textract_output_found_checkbox != True:
568
  page_extraction_time_taken = number_of_pages * textract_page_time
569
  elif text_extract_method_radio == local_ocr_option:
570
- if local_ocr_output_found_checkbox != True:
571
- page_extraction_time_taken = number_of_pages * local_ocr_extraction_page_time
 
 
572
  elif text_extract_method_radio == SELECTABLE_TEXT_EXTRACT_OPTION:
573
  page_conversion_time_taken = number_of_pages * local_text_extraction_page_time
574
 
@@ -579,47 +750,91 @@ def calculate_time_taken(number_of_pages:str,
579
  else:
580
  page_redaction_time_taken = number_of_pages * local_pii_redaction_page_time
581
 
582
- calculated_time_taken = (page_conversion_time_taken + page_extraction_time_taken + page_redaction_time_taken)/60
 
 
 
 
583
 
584
  return calculated_time_taken
585
 
586
- def reset_base_dataframe(df:pd.DataFrame):
 
587
  return df
588
 
589
- def reset_ocr_base_dataframe(df:pd.DataFrame):
 
590
  if df.empty:
591
  return pd.DataFrame(columns=["page", "line", "text"])
592
  else:
593
  return df.loc[:, ["page", "line", "text"]]
594
 
595
- def reset_ocr_with_words_base_dataframe(df:pd.DataFrame, page_entity_dropdown_redaction_value:str):
596
-
 
 
 
597
  df["index"] = df.index
598
  output_df = df.copy()
599
 
600
- df["page"]=df["page"].astype(str)
601
-
602
- output_df_filtered = df.loc[df["page"]==str(page_entity_dropdown_redaction_value), ["page", "line", "word_text", "word_x0", "word_y0", "word_x1", "word_y1", "index"]]
 
 
 
 
 
 
 
 
 
 
 
 
603
  return output_df_filtered, output_df
604
 
605
- def update_language_dropdown(chosen_language_full_name_drop, textract_language_choices=textract_language_choices, aws_comprehend_language_choices=aws_comprehend_language_choices, LANGUAGE_MAP=LANGUAGE_MAP):
 
 
 
 
 
 
606
 
607
  try:
608
  full_language_name = chosen_language_full_name_drop.lower()
609
  matched_language = LANGUAGE_MAP[full_language_name]
610
 
611
- chosen_language_drop = gr.Dropdown(value = matched_language, choices = LANGUAGE_CHOICES, label="Chosen language short code", multiselect=False, visible=True)
612
-
613
- if matched_language not in aws_comprehend_language_choices and matched_language not in textract_language_choices:
614
- gr.Info(f"Note that {full_language_name} is not supported by AWS Comprehend or AWS Textract")
 
 
 
 
 
 
 
 
 
 
 
615
  elif matched_language not in aws_comprehend_language_choices:
616
- gr.Info(f"Note that {full_language_name} is not supported by AWS Comprehend")
 
 
617
  elif matched_language not in textract_language_choices:
618
  gr.Info(f"Note that {full_language_name} is not supported by AWS Textract")
619
  except Exception as e:
620
  print(e)
621
  gr.Info("Could not find language in list")
622
- chosen_language_drop = gr.Dropdown(value = DEFAULT_LANGUAGE, choices = LANGUAGE_CHOICES, label="Chosen language short code", multiselect=False)
 
 
 
 
 
623
 
624
  return chosen_language_drop
625
-
 
1
  import os
2
  import re
3
+ import unicodedata
4
+ from math import ceil
5
+ from typing import List
6
+
7
  import boto3
 
8
  import gradio as gr
 
9
  import numpy as np
10
+ import pandas as pd
11
+ from botocore.exceptions import ClientError
 
12
  from gradio_image_annotation import image_annotator
13
+
14
+ from tools.config import (
15
+ AWS_PII_OPTION,
16
+ AWS_USER_POOL_ID,
17
+ CUSTOM_HEADER,
18
+ CUSTOM_HEADER_VALUE,
19
+ DEFAULT_LANGUAGE,
20
+ INPUT_FOLDER,
21
+ LANGUAGE_CHOICES,
22
+ MAPPED_LANGUAGE_CHOICES,
23
+ NO_REDACTION_PII_OPTION,
24
+ OUTPUT_FOLDER,
25
+ SELECTABLE_TEXT_EXTRACT_OPTION,
26
+ SESSION_OUTPUT_FOLDER,
27
+ TESSERACT_TEXT_EXTRACT_OPTION,
28
+ TEXTRACT_JOBS_LOCAL_LOC,
29
+ TEXTRACT_JOBS_S3_LOC,
30
+ TEXTRACT_TEXT_EXTRACT_OPTION,
31
+ TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_INPUT_SUBFOLDER,
32
+ TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_OUTPUT_SUBFOLDER,
33
+ aws_comprehend_language_choices,
34
+ textract_language_choices,
35
+ )
36
+
37
 
38
  def _get_env_list(env_var_name: str) -> List[str]:
39
  """Parses a comma-separated environment variable into a list of strings."""
40
+ value = env_var_name[1:-1].strip().replace('"', "").replace("'", "")
41
  if not value:
42
  return []
43
  # Split by comma and filter out any empty strings that might result from extra commas
44
+ return [s.strip() for s in value.split(",") if s.strip()]
45
+
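As a quick illustration of the parsing above, _get_env_list expects the whole bracketed string from the config (the sample value below is hypothetical, not taken from the project config):

    _get_env_list("['en', 'fr', 'de']")  # strips the brackets and quotes, then splits -> ['en', 'fr', 'de']
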
46
 
47
+ if textract_language_choices:
48
+ textract_language_choices = _get_env_list(textract_language_choices)
49
+ if aws_comprehend_language_choices:
50
+ aws_comprehend_language_choices = _get_env_list(aws_comprehend_language_choices)
51
 
52
+ if MAPPED_LANGUAGE_CHOICES:
53
+ MAPPED_LANGUAGE_CHOICES = _get_env_list(MAPPED_LANGUAGE_CHOICES)
54
+ if LANGUAGE_CHOICES:
55
+ LANGUAGE_CHOICES = _get_env_list(LANGUAGE_CHOICES)
56
 
57
  LANGUAGE_MAP = dict(zip(MAPPED_LANGUAGE_CHOICES, LANGUAGE_CHOICES))
58
 
59
+
60
  def reset_state_vars():
61
+ return (
62
+ [],
63
+ pd.DataFrame(),
64
+ pd.DataFrame(),
65
+ 0,
66
+ "",
67
+ image_annotator(
68
  label="Modify redaction boxes",
69
  label_list=["Redaction"],
70
  label_colors=[(0, 0, 0)],
71
  show_label=False,
72
+ sources=None, # ["upload"],
73
  show_clear_button=False,
74
  show_share_button=False,
75
  show_remove_button=False,
76
+ interactive=False,
77
+ ),
78
+ [],
79
+ [],
80
+ pd.DataFrame(),
81
+ pd.DataFrame(),
82
+ [],
83
+ [],
84
+ "",
85
+ False,
86
+ 0,
87
+ [],
88
+ )
89
+
90
 
91
  def reset_ocr_results_state():
92
  return pd.DataFrame(), pd.DataFrame(), []
93
 
94
+
95
  def reset_review_vars():
96
  return pd.DataFrame(), pd.DataFrame()
97
 
98
+
99
  def reset_data_vars():
100
  return 0, [], 0
101
 
102
+
103
  def reset_aws_call_vars():
104
  return 0, 0
105
 
106
+
107
  def load_in_default_allow_list(allow_list_file_path):
108
  if isinstance(allow_list_file_path, str):
109
  allow_list_file_path = [allow_list_file_path]
110
  return allow_list_file_path
111
 
112
+
113
+ def load_in_default_cost_codes(cost_codes_path: str, default_cost_code: str = ""):
114
+ """
115
  Load in the cost codes list from file.
116
+ """
117
  cost_codes_df = pd.read_csv(cost_codes_path)
118
  dropdown_choices = cost_codes_df.iloc[:, 0].astype(str).tolist()
119
 
 
129
  value=default_cost_code if default_cost_code in dropdown_choices else "",
130
  label="Choose cost code for analysis",
131
  choices=dropdown_choices,
132
+ allow_custom_value=False,
133
  )
134
+
135
  return cost_codes_df, cost_codes_df, out_dropdown
136
 
137
+
138
+ def enforce_cost_codes(
139
+ enforce_cost_code_textbox: str,
140
+ cost_code_choice: str,
141
+ cost_code_df: pd.DataFrame,
142
+ verify_cost_codes: bool = True,
143
+ ):
144
+ """
145
  Check if the enforce cost codes variable is set to true, and then check that a cost code has been chosen. If not, raise an error. Then, check against the values in the cost code dataframe to ensure that the cost code exists.
146
+ """
147
 
148
  if enforce_cost_code_textbox == "True":
149
  if not cost_code_choice:
150
  raise Exception("Please choose a cost code before continuing")
151
+
152
+ if verify_cost_codes is True:
153
  if cost_code_df.empty:
154
  raise Exception("No cost codes present in dataframe for verification")
155
  else:
156
+ valid_cost_codes_list = list(cost_code_df.iloc[:, 0].unique())
157
 
158
+ if cost_code_choice not in valid_cost_codes_list:
159
+ raise Exception(
160
+ "Selected cost code not found in list. Please contact Finance if you cannot find the correct cost code from the given list of suggestions."
161
+ )
162
  return
163
 
164
+
165
+ def update_cost_code_dataframe_from_dropdown_select(
166
+ cost_dropdown_selection: str, cost_code_df: pd.DataFrame
167
+ ):
168
+ cost_code_df = cost_code_df.loc[
169
+ cost_code_df.iloc[:, 0] == cost_dropdown_selection, :
170
+ ]
171
  return cost_code_df
172
 
173
+
174
+ def ensure_folder_exists(output_folder: str):
175
+ """Checks if the specified folder exists, creates it if not."""
176
 
177
  if not os.path.exists(output_folder):
178
  # Create the folder if it doesn't exist
 
181
  else:
182
  print(f"The {output_folder} folder already exists.")
183
 
184
+
185
+ def update_dataframe(df: pd.DataFrame):
186
  df_copy = df.copy()
187
  return df_copy
188
 
189
+
190
  def get_file_name_without_type(file_path):
191
  # First, get the basename of the file (e.g., "example.txt" from "/path/to/example.txt")
192
  basename = os.path.basename(file_path)
193
+
194
  # Then, split the basename and its extension and return only the basename without the extension
195
  filename_without_extension, _ = os.path.splitext(basename)
196
 
197
+ # print(filename_without_extension)
198
+
199
  return filename_without_extension
200
 
201
+
202
+ def detect_file_type(filename: str):
203
  """Detect the file type based on its extension."""
204
  if not isinstance(filename, str):
205
  filename = str(filename)
206
 
207
+ if (
208
+ (filename.endswith(".csv"))
209
+ | (filename.endswith(".csv.gz"))
210
+ | (filename.endswith(".zip"))
211
+ ):
212
+ return "csv"
213
+ elif filename.endswith(".xlsx"):
214
+ return "xlsx"
215
+ elif filename.endswith(".xls"):
216
+ return "xls"
217
+ elif filename.endswith(".parquet"):
218
+ return "parquet"
219
+ elif filename.endswith(".pdf"):
220
+ return "pdf"
221
+ elif filename.endswith(".jpg"):
222
+ return "jpg"
223
+ elif filename.endswith(".jpeg"):
224
+ return "jpeg"
225
+ elif filename.endswith(".png"):
226
+ return "png"
227
+ elif filename.endswith(".xfdf"):
228
+ return "xfdf"
229
+ elif filename.endswith(".docx"):
230
+ return "docx"
231
+ else:
232
+ raise ValueError("Unsupported file type.")
233
+
234
+
235
+ def read_file(filename: str, excel_sheet_name: str = ""):
236
  """Read the file based on its detected type."""
237
  file_type = detect_file_type(filename)
238
+
239
+ if file_type == "csv":
240
  return pd.read_csv(filename, low_memory=False)
241
+ elif file_type == "xlsx":
242
  if excel_sheet_name:
243
  try:
244
  return pd.read_excel(filename, sheet_name=excel_sheet_name)
245
  except Exception as e:
246
+ print(
247
+ f"Error reading {filename} with sheet name {excel_sheet_name}: {e}"
248
+ )
249
  return pd.DataFrame()
250
  else:
251
  return pd.read_excel(filename)
252
+ elif file_type == "parquet":
253
  return pd.read_parquet(filename)
254
 
255
+
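A short usage sketch for the two helpers above, with a hypothetical file name (the sheet name is only used for .xlsx files, and unsupported extensions raise ValueError):

    detect_file_type("case_notes.xlsx")  # -> "xlsx"
    df = read_file("case_notes.xlsx", excel_sheet_name="Sheet1")
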
256
+ def ensure_output_folder_exists(output_folder: str):
257
+ """Checks if the specified folder exists, creates it if not."""
258
 
259
  if not os.path.exists(output_folder):
260
  # Create the folder if it doesn't exist
 
264
  print(f"The {output_folder} folder already exists.")
265
 
266
 
267
+ def custom_regex_load(in_file: List[str], file_type: str = "allow_list"):
268
+ """
 
269
  When file is loaded, update the column dropdown choices and write to relevant data states.
270
+ """
271
  custom_regex_df = pd.DataFrame()
272
 
273
  if in_file:
 
276
  regex_file_names = [string for string in file_list if "csv" in string.lower()]
277
  if regex_file_names:
278
  regex_file_name = regex_file_names[0]
279
+ custom_regex_df = pd.read_csv(
280
+ regex_file_name, low_memory=False, header=None
281
+ )
282
+
283
  # Select just first columns
284
+ custom_regex_df = pd.DataFrame(custom_regex_df.iloc[:, [0]])
285
+ custom_regex_df.rename(columns={0: file_type}, inplace=True)
286
 
287
  custom_regex_df.columns = custom_regex_df.columns.astype(str)
288
 
 
290
  print(output_text)
291
  else:
292
  output_text = "No file provided."
293
+ # print(output_text)
294
  return output_text, custom_regex_df
295
+
296
  return output_text, custom_regex_df
297
 
298
+
299
+ def put_columns_in_df(in_file: List[str]):
300
  new_choices = []
301
  concat_choices = []
302
  all_sheet_names = []
303
  number_of_excel_files = 0
304
+
305
  for file in in_file:
306
  file_name = file.name
307
  file_type = detect_file_type(file_name)
308
  print("File type is:", file_type)
309
 
310
+ if (file_type == "xlsx") | (file_type == "xls"):
311
  number_of_excel_files += 1
312
  new_choices = []
313
  print("Running through all xlsx sheets")
 
331
  new_choices = []
332
 
333
  concat_choices.extend(new_choices)
334
+
335
  # Drop duplicate columns
336
  concat_choices = list(set(concat_choices))
337
 
338
+ if number_of_excel_files > 0:
339
+ return gr.Dropdown(choices=concat_choices, value=concat_choices), gr.Dropdown(
340
+ choices=all_sheet_names, value=all_sheet_names, visible=True
341
+ )
342
  else:
343
+ return gr.Dropdown(choices=concat_choices, value=concat_choices), gr.Dropdown(
344
+ visible=False
345
+ )
346
 
347
+
348
+ def check_for_existing_textract_file(
349
+ doc_file_name_no_extension_textbox: str, output_folder: str = OUTPUT_FOLDER
350
+ ):
351
+ textract_output_path = os.path.join(
352
+ output_folder, doc_file_name_no_extension_textbox + "_textract.json"
353
+ )
354
 
355
  if os.path.exists(textract_output_path):
356
+ print("Existing Textract analysis output file found.")
357
  return True
358
+
359
  else:
360
  return False
361
+
362
+
363
+ def check_for_relevant_ocr_output_with_words(
364
+ doc_file_name_no_extension_textbox: str,
365
+ text_extraction_method: str,
366
+ output_folder: str = OUTPUT_FOLDER,
367
+ ):
368
+ if text_extraction_method == SELECTABLE_TEXT_EXTRACT_OPTION:
369
+ file_ending = "_ocr_results_with_words_local_text.json"
370
+ elif text_extraction_method == TESSERACT_TEXT_EXTRACT_OPTION:
371
+ file_ending = "_ocr_results_with_words_local_ocr.json"
372
+ elif text_extraction_method == TEXTRACT_TEXT_EXTRACT_OPTION:
373
+ file_ending = "_ocr_results_with_words_textract.json"
374
  else:
375
  print("No valid text extraction method found. Returning False")
376
  return False
377
+
378
  doc_file_with_ending = doc_file_name_no_extension_textbox + file_ending
379
 
380
  local_ocr_output_path = os.path.join(output_folder, doc_file_with_ending)
381
 
382
  if os.path.exists(local_ocr_output_path):
383
+ print("Existing OCR with words analysis output file found.")
384
+ return True
385
  else:
386
  return False
387
 
388
+
389
  def add_folder_to_path(folder_path: str):
390
+ """
391
  Check if a folder exists on your system. If so, get the absolute path and then add it to the system Path variable if it doesn't already exist. This function is only relevant for locally-created executable files based on this app (when using pyinstaller it creates an _internal folder that contains tesseract and poppler; these need to be added to the system path to enable the app to run)
392
+ """
393
 
394
  if os.path.exists(folder_path) and os.path.isdir(folder_path):
395
  print(folder_path, "folder exists.")
 
397
  # Resolve relative path to absolute path
398
  absolute_path = os.path.abspath(folder_path)
399
 
400
+ current_path = os.environ["PATH"]
401
  if absolute_path not in current_path.split(os.pathsep):
402
  full_path_extension = absolute_path + os.pathsep + current_path
403
+ os.environ["PATH"] = full_path_extension
404
+ # print(f"Updated PATH with: ", full_path_extension)
405
  else:
406
  print(f"Directory {folder_path} already exists in PATH.")
407
  else:
408
  print(f"Folder not found at {folder_path} - not added to PATH")
409
 
410
+
411
  # Upon running a process, the feedback buttons are revealed
412
  def reveal_feedback_buttons():
413
+ return (
414
+ gr.Radio(
415
+ visible=True,
416
+ label="Please give some feedback about the results of the redaction. A reminder that the app is only expected to identify about 60% of personally identifiable information in a given (typed) document.",
417
+ ),
418
+ gr.Textbox(visible=True),
419
+ gr.Button(visible=True),
420
+ gr.Markdown(visible=True),
421
+ )
422
 
423
+
424
+ def wipe_logs(feedback_logs_loc: str, usage_logs_loc: str):
425
  try:
426
  os.remove(feedback_logs_loc)
427
  except Exception as e:
 
431
  except Exception as e:
432
  print("Could not remove usage logs file", e)
433
 
434
+
435
+ def merge_csv_files(file_list: List[str], output_folder: str = OUTPUT_FOLDER):
436
 
437
  # Initialise an empty list to hold DataFrames
438
  dataframes = []
 
445
  dataframes.append(df)
446
 
447
  # Concatenate all DataFrames into a single DataFrame
448
+ merged_df = pd.concat(dataframes, ignore_index=True)
449
 
450
+ for col in ["xmin", "xmax", "ymin", "ymax"]:
451
  merged_df[col] = np.floor(merged_df[col])
452
 
453
+ merged_df = merged_df.drop_duplicates(
454
+ subset=["page", "label", "color", "xmin", "ymin", "xmax", "ymax"]
455
+ )
456
 
457
+ merged_df = merged_df.sort_values(["page", "ymin", "xmin", "label"])
458
 
459
  file_out_name = os.path.basename(file_list[0])
460
 
 
466
 
467
  return output_files
468
 
 
 
 
 
 
 
 
 
469
 
470
+ async def get_connection_params(
471
+ request: gr.Request,
472
+ output_folder_textbox: str = OUTPUT_FOLDER,
473
+ input_folder_textbox: str = INPUT_FOLDER,
474
+ session_output_folder: str = SESSION_OUTPUT_FOLDER,
475
+ textract_document_upload_input_folder: str = TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_INPUT_SUBFOLDER,
476
+ textract_document_upload_output_folder: str = TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_OUTPUT_SUBFOLDER,
477
+ s3_textract_document_logs_subfolder: str = TEXTRACT_JOBS_S3_LOC,
478
+ local_textract_document_logs_subfolder: str = TEXTRACT_JOBS_LOCAL_LOC,
479
+ ):
480
+
481
+ # print("Session hash:", request.session_hash)
482
 
483
  if CUSTOM_HEADER and CUSTOM_HEADER_VALUE:
484
+ if CUSTOM_HEADER in request.headers:
485
+ supplied_custom_header_value = request.headers[CUSTOM_HEADER]
486
+ if supplied_custom_header_value == CUSTOM_HEADER_VALUE:
487
+ print("Custom header supplied and matches CUSTOM_HEADER_VALUE")
 
 
 
488
  else:
489
+ print("Custom header value does not match expected value.")
490
+ raise ValueError("Custom header value does not match expected value.")
491
+ else:
492
+ print("Custom header value not found.")
493
+ raise ValueError("Custom header value not found.")
494
 
495
  # Get output save folder from 1 - username passed in from direct Cognito login, 2 - Cognito ID header passed through a Lambda authenticator, 3 - the session hash.
496
 
497
  if request.username:
498
  out_session_hash = request.username
499
+ # print("Request username found:", out_session_hash)
500
 
501
+ elif "x-cognito-id" in request.headers:
502
+ out_session_hash = request.headers["x-cognito-id"]
503
+ # print("Cognito ID found:", out_session_hash)
504
 
505
+ elif "x-amzn-oidc-identity" in request.headers:
506
+ out_session_hash = request.headers["x-amzn-oidc-identity"]
507
 
508
  # Fetch email address using Cognito client
509
+ cognito_client = boto3.client("cognito-idp")
510
  try:
511
  response = cognito_client.admin_get_user(
512
  UserPoolId=AWS_USER_POOL_ID, # Replace with your User Pool ID
513
+ Username=out_session_hash,
514
  )
515
+ email = next(
516
+ attr["Value"]
517
+ for attr in response["UserAttributes"]
518
+ if attr["Name"] == "email"
519
+ )
520
+ # print("Email address found:", email)
521
 
522
  out_session_hash = email
523
  except ClientError as e:
 
529
  else:
530
  out_session_hash = request.session_hash
531
 
532
+ if session_output_folder == "True":
533
  output_folder = output_folder_textbox + out_session_hash + "/"
534
  input_folder = input_folder_textbox + out_session_hash + "/"
535
 
536
+ textract_document_upload_input_folder = (
537
+ textract_document_upload_input_folder + "/" + out_session_hash
538
+ )
539
+ textract_document_upload_output_folder = (
540
+ textract_document_upload_output_folder + "/" + out_session_hash
541
+ )
542
 
543
+ s3_textract_document_logs_subfolder = (
544
+ s3_textract_document_logs_subfolder + "/" + out_session_hash
545
+ )
546
+ local_textract_document_logs_subfolder = (
547
+ local_textract_document_logs_subfolder + "/" + out_session_hash + "/"
548
+ )
549
 
550
  else:
551
  output_folder = output_folder_textbox
552
  input_folder = input_folder_textbox
553
 
554
+ if not os.path.exists(output_folder):
555
+ os.mkdir(output_folder)
556
+ if not os.path.exists(input_folder):
557
+ os.mkdir(input_folder)
558
+
559
+ return (
560
+ out_session_hash,
561
+ output_folder,
562
+ out_session_hash,
563
+ input_folder,
564
+ textract_document_upload_input_folder,
565
+ textract_document_upload_output_folder,
566
+ s3_textract_document_logs_subfolder,
567
+ local_textract_document_logs_subfolder,
568
+ )
569
 
 
570
 
571
+ def clean_unicode_text(text: str):
572
  # Step 1: Normalise unicode characters to decompose any special forms
573
+ normalized_text = unicodedata.normalize("NFKC", text)
574
 
575
  # Step 2: Replace smart quotes and special punctuation with standard ASCII equivalents
576
  replacements = {
577
+ "‘": "'",
578
+ "’": "'",
579
+ "“": '"',
580
+ "”": '"',
581
+ "–": "-",
582
+ "—": "-",
583
+ "…": "...",
584
+ "•": "*",
585
  }
586
 
587
  # Perform replacements
 
591
  # Step 3: Optionally remove non-ASCII characters if needed
592
  # This regex removes any remaining non-ASCII characters, if desired.
593
  # Comment this line if you want to keep all Unicode characters.
594
+ cleaned_text = re.sub(r"[^\x00-\x7F]+", "", normalized_text)
595
 
596
  return cleaned_text
597
+
598
+
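A rough before/after example for clean_unicode_text (illustrative input; NFKC normalisation already converts the ellipsis, and the final regex drops any remaining non-ASCII characters):

    clean_unicode_text("“Smart quotes” – dashes and ellipses…")
    # -> '"Smart quotes" - dashes and ellipses...'
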
599
+ def load_all_output_files(folder_path: str = OUTPUT_FOLDER) -> List[str]:
600
  """Get the file paths of all files in the given folder."""
601
  file_paths = []
602
+
603
  # List all files in the specified folder
604
  for filename in os.listdir(folder_path):
605
  # Construct full file path
 
607
  # Check if it's a file (not a directory)
608
  if os.path.isfile(full_path):
609
  file_paths.append(full_path)
610
+
611
  return file_paths
612
 
613
+
614
+ def calculate_aws_costs(
615
+ number_of_pages: str,
616
+ text_extract_method_radio: str,
617
+ handwrite_signature_checkbox: List[str],
618
+ pii_identification_method: str,
619
+ textract_output_found_checkbox: bool,
620
+ only_extract_text_radio: bool,
621
+ convert_to_gbp: bool = True,
622
+ usd_gbp_conversion_rate: float = 0.76,
623
+ textract_page_cost: float = 1.5 / 1000,
624
+ textract_signature_cost: float = 2.0 / 1000,
625
+ comprehend_unit_cost: float = 0.0001,
626
+ comprehend_size_unit_average: float = 250,
627
+ average_characters_per_page: float = 2000,
628
+ TEXTRACT_TEXT_EXTRACT_OPTION: str = TEXTRACT_TEXT_EXTRACT_OPTION,
629
+ NO_REDACTION_PII_OPTION: str = NO_REDACTION_PII_OPTION,
630
+ AWS_PII_OPTION: str = AWS_PII_OPTION,
631
+ ):
632
+ """
633
  Calculate the approximate cost of submitting a document to AWS Textract and/or AWS Comprehend, assuming that Textract outputs do not already exist in the output folder.
634
 
635
  - number_of_pages: The number of pages in the uploaded document(s).
 
648
  - TEXTRACT_TEXT_EXTRACT_OPTION (str, optional): String label for the text_extract_method_radio button for AWS Textract.
649
  - NO_REDACTION_PII_OPTION (str, optional): String label for pii_identification_method_drop for no redaction.
650
  - AWS_PII_OPTION (str, optional): String label for pii_identification_method_drop for AWS Comprehend.
651
+ """
652
  text_extraction_cost = 0
653
  pii_identification_cost = 0
654
  calculated_aws_cost = 0
655
  number_of_pages = int(number_of_pages)
656
+
657
+ if textract_output_found_checkbox is not True:
658
  if text_extract_method_radio == TEXTRACT_TEXT_EXTRACT_OPTION:
659
  text_extraction_cost = number_of_pages * textract_page_cost
660
 
661
  if "Extract signatures" in handwrite_signature_checkbox:
662
+ text_extraction_cost += textract_signature_cost * number_of_pages
663
 
664
  if pii_identification_method != NO_REDACTION_PII_OPTION:
665
  if pii_identification_method == AWS_PII_OPTION:
666
+ comprehend_page_cost = (
667
+ ceil(average_characters_per_page / comprehend_size_unit_average)
668
+ * comprehend_unit_cost
669
+ )
670
  pii_identification_cost = comprehend_page_cost * number_of_pages
671
 
672
+ calculated_aws_cost = (
673
+ calculated_aws_cost + text_extraction_cost + pii_identification_cost
674
+ )
675
 
676
+ if convert_to_gbp is True:
677
  calculated_aws_cost *= usd_gbp_conversion_rate
678
 
679
  return calculated_aws_cost
680
 
681
+
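Working through calculate_aws_costs for a hypothetical 10-page document sent to Textract (no signature extraction) with AWS Comprehend PII detection, using the default rates above and converting to GBP:

    from math import ceil

    text_extraction_cost = 10 * (1.5 / 1000)                             # 0.015 USD
    comprehend_page_cost = ceil(2000 / 250) * 0.0001                     # 8 units per page -> 0.0008 USD
    pii_identification_cost = comprehend_page_cost * 10                  # 0.008 USD
    total_gbp = (text_extraction_cost + pii_identification_cost) * 0.76  # ~0.0175 GBP
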
682
+ def calculate_time_taken(
683
+ number_of_pages: str,
684
+ text_extract_method_radio: str,
685
+ pii_identification_method: str,
686
+ textract_output_found_checkbox: bool,
687
+ only_extract_text_radio: bool,
688
+ local_ocr_output_found_checkbox: bool,
689
+ convert_page_time: float = 0.5,
690
+ textract_page_time: float = 1.2,
691
+ comprehend_page_time: float = 1.2,
692
+ local_text_extraction_page_time: float = 0.3,
693
+ local_pii_redaction_page_time: float = 0.5,
694
+ local_ocr_extraction_page_time: float = 1.5,
695
+ TEXTRACT_TEXT_EXTRACT_OPTION: str = TEXTRACT_TEXT_EXTRACT_OPTION,
696
+ SELECTABLE_TEXT_EXTRACT_OPTION: str = SELECTABLE_TEXT_EXTRACT_OPTION,
697
+ local_ocr_option: str = TESSERACT_TEXT_EXTRACT_OPTION,
698
+ NO_REDACTION_PII_OPTION: str = NO_REDACTION_PII_OPTION,
699
+ AWS_PII_OPTION: str = AWS_PII_OPTION,
700
+ ):
701
+ """
702
  Calculate the approximate time to redact a document.
703
 
704
  - number_of_pages: The number of pages in the uploaded document(s).
 
715
  - TEXTRACT_TEXT_EXTRACT_OPTION (str, optional): String label for the text_extract_method_radio button for AWS Textract.
716
  - SELECTABLE_TEXT_EXTRACT_OPTION (str, optional): String label for text_extract_method_radio for text extraction.
717
  - local_ocr_option (str, optional): String label for text_extract_method_radio for local OCR.
718
+ - NO_REDACTION_PII_OPTION (str, optional): String label for pii_identification_method_drop for no redaction.
719
  - AWS_PII_OPTION (str, optional): String label for pii_identification_method_drop for AWS Comprehend.
720
+ """
721
  calculated_time_taken = 0
722
  page_conversion_time_taken = 0
723
  page_extraction_time_taken = 0
 
726
  number_of_pages = int(number_of_pages)
727
 
728
  # Page preparation/conversion to image time
729
+ if (text_extract_method_radio != SELECTABLE_TEXT_EXTRACT_OPTION) and (
730
+ textract_output_found_checkbox is not True
731
+ ):
732
  page_conversion_time_taken = number_of_pages * convert_page_time
733
 
734
  # Page text extraction time
735
  if text_extract_method_radio == TEXTRACT_TEXT_EXTRACT_OPTION:
736
+ if textract_output_found_checkbox is not True:
737
  page_extraction_time_taken = number_of_pages * textract_page_time
738
  elif text_extract_method_radio == local_ocr_option:
739
+ if local_ocr_output_found_checkbox is not True:
740
+ page_extraction_time_taken = (
741
+ number_of_pages * local_ocr_extraction_page_time
742
+ )
743
  elif text_extract_method_radio == SELECTABLE_TEXT_EXTRACT_OPTION:
744
  page_conversion_time_taken = number_of_pages * local_text_extraction_page_time
745
 
 
750
  else:
751
  page_redaction_time_taken = number_of_pages * local_pii_redaction_page_time
752
 
753
+ calculated_time_taken = (
754
+ page_conversion_time_taken
755
+ + page_extraction_time_taken
756
+ + page_redaction_time_taken
757
+ ) / 60
758
 
759
  return calculated_time_taken
760
 
761
+
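A comparable back-of-the-envelope run of calculate_time_taken for the same hypothetical 10-page scan (Textract extraction plus AWS Comprehend redaction, no cached outputs; the elided Comprehend branch is assumed to apply comprehend_page_time):

    page_conversion_time_taken = 10 * 0.5         # 5 seconds
    page_extraction_time_taken = 10 * 1.2         # 12 seconds
    page_redaction_time_taken = 10 * 1.2          # 12 seconds
    calculated_time_taken = (5 + 12 + 12) / 60    # ~0.48 minutes
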
762
+ def reset_base_dataframe(df: pd.DataFrame):
763
  return df
764
 
765
+
766
+ def reset_ocr_base_dataframe(df: pd.DataFrame):
767
  if df.empty:
768
  return pd.DataFrame(columns=["page", "line", "text"])
769
  else:
770
  return df.loc[:, ["page", "line", "text"]]
771
 
772
+
773
+ def reset_ocr_with_words_base_dataframe(
774
+ df: pd.DataFrame, page_entity_dropdown_redaction_value: str
775
+ ):
776
+
777
  df["index"] = df.index
778
  output_df = df.copy()
779
 
780
+ df["page"] = df["page"].astype(str)
781
+
782
+ output_df_filtered = df.loc[
783
+ df["page"] == str(page_entity_dropdown_redaction_value),
784
+ [
785
+ "page",
786
+ "line",
787
+ "word_text",
788
+ "word_x0",
789
+ "word_y0",
790
+ "word_x1",
791
+ "word_y1",
792
+ "index",
793
+ ],
794
+ ]
795
  return output_df_filtered, output_df
796
 
797
+
798
+ def update_language_dropdown(
799
+ chosen_language_full_name_drop,
800
+ textract_language_choices=textract_language_choices,
801
+ aws_comprehend_language_choices=aws_comprehend_language_choices,
802
+ LANGUAGE_MAP=LANGUAGE_MAP,
803
+ ):
804
 
805
  try:
806
  full_language_name = chosen_language_full_name_drop.lower()
807
  matched_language = LANGUAGE_MAP[full_language_name]
808
 
809
+ chosen_language_drop = gr.Dropdown(
810
+ value=matched_language,
811
+ choices=LANGUAGE_CHOICES,
812
+ label="Chosen language short code",
813
+ multiselect=False,
814
+ visible=True,
815
+ )
816
+
817
+ if (
818
+ matched_language not in aws_comprehend_language_choices
819
+ and matched_language not in textract_language_choices
820
+ ):
821
+ gr.Info(
822
+ f"Note that {full_language_name} is not supported by AWS Comprehend or AWS Textract"
823
+ )
824
  elif matched_language not in aws_comprehend_language_choices:
825
+ gr.Info(
826
+ f"Note that {full_language_name} is not supported by AWS Comprehend"
827
+ )
828
  elif matched_language not in textract_language_choices:
829
  gr.Info(f"Note that {full_language_name} is not supported by AWS Textract")
830
  except Exception as e:
831
  print(e)
832
  gr.Info("Could not find language in list")
833
+ chosen_language_drop = gr.Dropdown(
834
+ value=DEFAULT_LANGUAGE,
835
+ choices=LANGUAGE_CHOICES,
836
+ label="Chosen language short code",
837
+ multiselect=False,
838
+ )
839
 
840
  return chosen_language_drop
 
tools/load_spacy_model_custom_recognisers.py CHANGED
@@ -1,36 +1,58 @@
1
  from typing import List
2
- from presidio_analyzer import AnalyzerEngine, PatternRecognizer, EntityRecognizer, Pattern, RecognizerResult
3
- from presidio_analyzer.nlp_engine import SpacyNlpEngine, NlpArtifacts, NerModelConfiguration
4
  import spacy
 
 
 
 
 
 
 
 
 
 
 
 
5
  from spacy.matcher import Matcher
6
  from spaczz.matcher import FuzzyMatcher
 
7
  spacy.prefer_gpu()
8
- from spacy.cli.download import download
9
- import Levenshtein
10
- import re
11
  import os
12
- import requests
 
13
  import gradio as gr
14
- from tools.config import DEFAULT_LANGUAGE, TESSERACT_DATA_FOLDER, CUSTOM_ENTITIES
 
 
 
 
15
  from tools.helper_functions import _get_env_list
16
 
17
  score_threshold = 0.001
18
 
19
- if CUSTOM_ENTITIES: CUSTOM_ENTITIES = _get_env_list(CUSTOM_ENTITIES)
 
20
  custom_entities = CUSTOM_ENTITIES
21
 
 
22
  # Create a class inheriting from SpacyNlpEngine
23
  class LoadedSpacyNlpEngine(SpacyNlpEngine):
24
  def __init__(self, loaded_spacy_model, language_code: str):
25
- super().__init__(ner_model_configuration=NerModelConfiguration(labels_to_ignore=["CARDINAL", "ORDINAL"])) # Ignore non-relevant labels
 
 
 
 
26
  self.nlp = {language_code: loaded_spacy_model}
27
 
 
28
  def _base_language_code(language: str) -> str:
29
  lang = _normalize_language_input(language)
30
  if "_" in lang:
31
  return lang.split("_")[0]
32
  return lang
33
 
 
34
  def load_spacy_model(language: str = DEFAULT_LANGUAGE):
35
  """
36
  Load a spaCy model for the requested language and return it as `nlp`.
@@ -78,32 +100,39 @@ def load_spacy_model(language: str = DEFAULT_LANGUAGE):
78
  "en_trf": ["en_core_web_trf"],
79
  "en_md": ["en_core_web_md"],
80
  "en_sm": ["en_core_web_sm"],
81
-
82
  # Major languages (news pipelines)
83
- "ca": ["ca_core_news_lg", "ca_core_news_md", "ca_core_news_sm"], # Catalan
84
- "da": ["da_core_news_lg", "da_core_news_md", "da_core_news_sm"], # Danish
85
- "de": ["de_core_news_lg", "de_core_news_md", "de_core_news_sm"], # German
86
- "el": ["el_core_news_lg", "el_core_news_md", "el_core_news_sm"], # Greek
87
- "es": ["es_core_news_lg", "es_core_news_md", "es_core_news_sm"], # Spanish
88
- "fi": ["fi_core_news_lg", "fi_core_news_md", "fi_core_news_sm"], # Finnish
89
- "fr": ["fr_core_news_lg", "fr_core_news_md", "fr_core_news_sm"], # French
90
- "hr": ["hr_core_news_lg", "hr_core_news_md", "hr_core_news_sm"], # Croatian
91
- "it": ["it_core_news_lg", "it_core_news_md", "it_core_news_sm"], # Italian
92
- "ja": ["ja_core_news_lg", "ja_core_news_md", "ja_core_news_sm"], # Japanese
93
- "ko": ["ko_core_news_lg", "ko_core_news_md", "ko_core_news_sm"], # Korean
94
- "lt": ["lt_core_news_lg", "lt_core_news_md", "lt_core_news_sm"], # Lithuanian
95
- "mk": ["mk_core_news_lg", "mk_core_news_md", "mk_core_news_sm"], # Macedonian
96
- "nb": ["nb_core_news_lg", "nb_core_news_md", "nb_core_news_sm"], # Norwegian Bokmål
97
- "nl": ["nl_core_news_lg", "nl_core_news_md", "nl_core_news_sm"], # Dutch
98
- "pl": ["pl_core_news_lg", "pl_core_news_md", "pl_core_news_sm"], # Polish
99
- "pt": ["pt_core_news_lg", "pt_core_news_md", "pt_core_news_sm"], # Portuguese
100
- "ro": ["ro_core_news_lg", "ro_core_news_md", "ro_core_news_sm"], # Romanian
101
- "ru": ["ru_core_news_lg", "ru_core_news_md", "ru_core_news_sm"], # Russian
102
- "sl": ["sl_core_news_lg", "sl_core_news_md", "sl_core_news_sm"], # Slovenian
103
- "sv": ["sv_core_news_lg", "sv_core_news_md", "sv_core_news_sm"], # Swedish
104
- "uk": ["uk_core_news_lg", "uk_core_news_md", "uk_core_news_sm"], # Ukrainian
105
- "zh": ["zh_core_web_lg", "zh_core_web_mod", "zh_core_web_sm", "zh_core_web_trf"], # Chinese
106
-
 
 
 
 
 
 
 
 
107
  # Multilingual NER
108
  "xx": ["xx_ent_wiki_sm"],
109
  }
@@ -158,17 +187,22 @@ def load_spacy_model(language: str = DEFAULT_LANGUAGE):
158
  last_error = e
159
  continue
160
 
161
- raise RuntimeError(f"Failed to load spaCy model for language '{language}'. Last error: {last_error}")
 
 
 
162
 
163
  # Language-aware spaCy model loader
164
  def _normalize_language_input(language: str) -> str:
165
  return language.strip().lower().replace("-", "_")
166
 
 
167
  # Update the global variables to use the new function
168
  ACTIVE_LANGUAGE_CODE = _base_language_code(DEFAULT_LANGUAGE)
169
- nlp = None # Placeholder, will be loaded in the create_nlp_analyser function below #load_spacy_model(DEFAULT_LANGUAGE)
 
170
 
171
- def get_tesseract_lang_code(short_code:str):
172
  """
173
  Maps a two-letter language code to the corresponding Tesseract OCR code.
174
 
@@ -200,12 +234,15 @@ def get_tesseract_lang_code(short_code:str):
200
  "ru": "rus",
201
  "sl": "slv",
202
  "sv": "swe",
203
- "uk": "ukr"
204
  }
205
 
206
  return lang_map.get(short_code)
207
 
208
- def download_tesseract_lang_pack(short_lang_code:str, tessdata_dir=TESSERACT_DATA_FOLDER):
 
 
 
209
  """
210
  Downloads a Tesseract language pack to a local directory.
211
 
@@ -214,7 +251,7 @@ def download_tesseract_lang_pack(short_lang_code:str, tessdata_dir=TESSERACT_DAT
214
  tessdata_dir (str, optional): The directory to save the language pack.
215
  Defaults to "tessdata".
216
  """
217
-
218
  # Create the directory if it doesn't exist
219
  if not os.path.exists(tessdata_dir):
220
  os.makedirs(tessdata_dir)
@@ -223,16 +260,18 @@ def download_tesseract_lang_pack(short_lang_code:str, tessdata_dir=TESSERACT_DAT
223
  lang_code = get_tesseract_lang_code(short_lang_code)
224
 
225
  if lang_code is None:
226
- raise ValueError(f"Language code {short_lang_code} not found in Tesseract language map")
227
-
 
  # Set the local file path
229
  file_path = os.path.join(tessdata_dir, f"{lang_code}.traineddata")
230
-
231
  # Check if the file already exists
232
  if os.path.exists(file_path):
233
  print(f"Language pack {lang_code}.traineddata already exists at {file_path}")
234
  return file_path
235
-
236
  # Construct the URL for the language pack
237
  url = f"https://raw.githubusercontent.com/tesseract-ocr/tessdata/main/{lang_code}.traineddata"
238
 
@@ -252,35 +291,61 @@ def download_tesseract_lang_pack(short_lang_code:str, tessdata_dir=TESSERACT_DAT
252
  print(f"Error downloading {lang_code}.traineddata: {e}")
253
  return None
254
 
 
255
  #### Custom recognisers
256
- def custom_word_list_recogniser(custom_list:List[str]=[]):
257
  # Create regex pattern, handling quotes carefully
258
 
259
  quote_str = '"'
260
  replace_str = '(?:"|“|”)'
261
 
262
- custom_regex = '|'.join(
263
- rf'(?<!\w){re.escape(term.strip()).replace(quote_str, replace_str)}(?!\w)'
264
  for term in custom_list
265
  )
266
- #print(custom_regex)
267
 
268
- custom_pattern = Pattern(name="custom_pattern", regex=custom_regex, score = 1)
269
-
270
- custom_recogniser = PatternRecognizer(supported_entity="CUSTOM", name="CUSTOM", patterns = [custom_pattern],
271
- global_regex_flags=re.DOTALL | re.MULTILINE | re.IGNORECASE)
272
 
273
  return custom_recogniser
274
 
 
275
  # Initialise custom recogniser that will be overwritten later
276
  custom_recogniser = custom_word_list_recogniser()
277
 
278
  # Custom title recogniser
279
- titles_list = ["Sir", "Ma'am", "Madam", "Mr", "Mr.", "Mrs", "Mrs.", "Ms", "Ms.", "Miss", "Dr", "Dr.", "Professor"]
280
- titles_regex = '\\b' + '\\b|\\b'.join(rf"{re.escape(title)}" for title in titles_list) + '\\b'
281
- titles_pattern = Pattern(name="titles_pattern",regex=titles_regex, score = 1)
282
- titles_recogniser = PatternRecognizer(supported_entity="TITLES", name="TITLES", patterns = [titles_pattern],
283
- global_regex_flags=re.DOTALL | re.MULTILINE)
284
 
285
  # %%
286
  # Custom postcode recogniser
@@ -289,38 +354,117 @@ titles_recogniser = PatternRecognizer(supported_entity="TITLES", name="TITLES",
289
  ukpostcode_pattern = Pattern(
290
  name="ukpostcode_pattern",
291
  regex=r"\b([A-Z]{1,2}\d[A-Z\d]? ?\d[A-Z]{2}|GIR ?0AA)\b",
292
- score=1
293
  )
294
 
295
  # Define the recognizer with one or more patterns
296
- ukpostcode_recogniser = PatternRecognizer(supported_entity="UKPOSTCODE", name = "UKPOSTCODE", patterns = [ukpostcode_pattern])
297
 
298
  ### Street name
299
 
300
- def extract_street_name(text:str) -> str:
 
301
  """
302
  Extracts the street name and preceding word (that should contain at least one number) from the given text.
303
 
304
- """
305
-
306
  street_types = [
307
- 'Street', 'St', 'Boulevard', 'Blvd', 'Highway', 'Hwy', 'Broadway', 'Freeway',
308
- 'Causeway', 'Cswy', 'Expressway', 'Way', 'Walk', 'Lane', 'Ln', 'Road', 'Rd',
309
- 'Avenue', 'Ave', 'Circle', 'Cir', 'Cove', 'Cv', 'Drive', 'Dr', 'Parkway', 'Pkwy',
310
- 'Park', 'Court', 'Ct', 'Square', 'Sq', 'Loop', 'Place', 'Pl', 'Parade', 'Estate',
311
- 'Alley', 'Arcade', 'Avenue', 'Ave', 'Bay', 'Bend', 'Brae', 'Byway', 'Close', 'Corner', 'Cove',
312
- 'Crescent', 'Cres', 'Cul-de-sac', 'Dell', 'Drive', 'Dr', 'Esplanade', 'Glen', 'Green', 'Grove', 'Heights', 'Hts',
313
- 'Mews', 'Parade', 'Path', 'Piazza', 'Promenade', 'Quay', 'Ridge', 'Row', 'Terrace', 'Ter', 'Track', 'Trail', 'View', 'Villas',
314
- 'Marsh', 'Embankment', 'Cut', 'Hill', 'Passage', 'Rise', 'Vale', 'Side'
315
  ]
316
 
317
  # Dynamically construct the regex pattern with all possible street types
318
- street_types_pattern = '|'.join(rf"{re.escape(street_type)}" for street_type in street_types)
319
 
320
  # The overall regex pattern to capture the street name and preceding word(s)
321
 
322
- pattern = rf'(?P<preceding_word>\w*\d\w*)\s*'
323
- pattern += rf'(?P<street_name>\w+\s*\b(?:{street_types_pattern})\b)'
324
 
325
  # Find all matches in text
326
  matches = re.finditer(pattern, text, re.DOTALL | re.MULTILINE | re.IGNORECASE)
@@ -329,26 +473,29 @@ def extract_street_name(text:str) -> str:
329
  end_positions = []
330
 
331
  for match in matches:
332
- preceding_word = match.group('preceding_word').strip()
333
- street_name = match.group('street_name').strip()
334
  start_pos = match.start()
335
  end_pos = match.end()
336
- #print(f"Start: {start_pos}, End: {end_pos}")
337
- #print(f"Preceding words: {preceding_word}")
338
- #print(f"Street name: {street_name}")
339
 
340
  start_positions.append(start_pos)
341
  end_positions.append(end_pos)
342
 
343
  return start_positions, end_positions
344
 
 
345
  class StreetNameRecognizer(EntityRecognizer):
346
 
347
  def load(self) -> None:
348
  """No loading is required."""
349
  pass
350
 
351
- def analyze(self, text: str, entities: List[str], nlp_artifacts: NlpArtifacts) -> List[RecognizerResult]:
352
  """
353
  Logic for detecting a specific PII
354
  """
@@ -360,32 +507,33 @@ class StreetNameRecognizer(EntityRecognizer):
360
  for i in range(0, len(start_pos)):
361
 
362
  result = RecognizerResult(
363
- entity_type="STREETNAME",
364
- start = start_pos[i],
365
- end = end_pos[i],
366
- score= 1
367
- )
368
-
369
  results.append(result)
370
-
371
  return results
372
-
 
373
  street_recogniser = StreetNameRecognizer(supported_entities=["STREETNAME"])
374
 
 
375
  ## Custom fuzzy match recogniser for list of strings
376
- def custom_fuzzy_word_list_regex(text:str, custom_list:List[str]=[]):
377
  # Create regex pattern, handling quotes carefully
378
 
379
  quote_str = '"'
380
  replace_str = '(?:"|“|”)'
381
 
382
- custom_regex_pattern = '|'.join(
383
- rf'(?<!\w){re.escape(term.strip()).replace(quote_str, replace_str)}(?!\w)'
384
  for term in custom_list
385
  )
386
 
387
  # Find all matches in text
388
- matches = re.finditer(custom_regex_pattern, text, re.DOTALL | re.MULTILINE | re.IGNORECASE)
389
 
390
  start_positions = []
391
  end_positions = []
@@ -401,46 +549,69 @@ def custom_fuzzy_word_list_regex(text:str, custom_list:List[str]=[]):
401
 
402
 
403
  class CustomWordFuzzyRecognizer(EntityRecognizer):
404
- def __init__(self, supported_entities: List[str], custom_list: List[str] = [], spelling_mistakes_max: int = 1, search_whole_phrase: bool = True):
405
  super().__init__(supported_entities=supported_entities)
406
  self.custom_list = custom_list # Store the custom_list as an instance attribute
407
- self.spelling_mistakes_max = spelling_mistakes_max # Store the max spelling mistakes
408
- self.search_whole_phrase = search_whole_phrase # Store the search whole phrase flag
409
 
410
  def load(self) -> None:
411
  """No loading is required."""
412
  pass
413
 
414
- def analyze(self, text: str, entities: List[str], nlp_artifacts: NlpArtifacts) -> List[RecognizerResult]:
415
  """
416
  Logic for detecting a specific PII
417
  """
418
- start_pos, end_pos = spacy_fuzzy_search(text, self.custom_list, self.spelling_mistakes_max, self.search_whole_phrase) # Pass new parameters
419
 
420
  results = []
421
 
422
  for i in range(0, len(start_pos)):
423
  result = RecognizerResult(
424
- entity_type="CUSTOM_FUZZY",
425
- start=start_pos[i],
426
- end=end_pos[i],
427
- score=1
428
  )
429
  results.append(result)
430
 
431
  return results
432
-
 
433
  custom_list_default = []
434
- custom_word_fuzzy_recognizer = CustomWordFuzzyRecognizer(supported_entities=["CUSTOM_FUZZY"], custom_list=custom_list_default)
435
 
436
  # Pass the loaded model to the new LoadedSpacyNlpEngine
437
- loaded_nlp_engine = LoadedSpacyNlpEngine(loaded_spacy_model = nlp, language_code = ACTIVE_LANGUAGE_CODE)
 
439
- def create_nlp_analyser(language: str = DEFAULT_LANGUAGE, custom_list: List[str] = None,
440
- spelling_mistakes_max: int = 1, search_whole_phrase: bool = True, existing_nlp_analyser: AnalyzerEngine = None, return_also_model: bool = False):
441
  """
442
  Create an nlp_analyser object based on the specified language input.
443
-
444
  Args:
445
  language (str): Language code (e.g., "en", "de", "fr", "es", etc.)
446
  custom_list (List[str], optional): List of custom words to recognize. Defaults to None.
@@ -448,12 +619,12 @@ def create_nlp_analyser(language: str = DEFAULT_LANGUAGE, custom_list: List[str]
448
  search_whole_phrase (bool, optional): Whether to search for whole phrases or individual words. Defaults to True.
449
  existing_nlp_analyser (AnalyzerEngine, optional): Existing nlp_analyser object to use. Defaults to None.
450
  return_also_model (bool, optional): Whether to return the nlp_model object as well. Defaults to False.
451
-
452
  Returns:
453
  AnalyzerEngine: Configured nlp_analyser object with custom recognizers
454
  """
455
 
456
- if existing_nlp_analyser is None:
457
  pass
458
  else:
459
  if existing_nlp_analyser.supported_languages[0] == language:
@@ -463,28 +634,27 @@ def create_nlp_analyser(language: str = DEFAULT_LANGUAGE, custom_list: List[str]
463
 
464
  # Load spaCy model for the specified language
465
  nlp_model = load_spacy_model(language)
466
-
467
  # Get base language code
468
  base_lang_code = _base_language_code(language)
469
-
470
  # Create custom recognizers
471
  if custom_list is None:
472
  custom_list = []
473
-
474
  custom_recogniser = custom_word_list_recogniser(custom_list)
475
  custom_word_fuzzy_recognizer = CustomWordFuzzyRecognizer(
476
- supported_entities=["CUSTOM_FUZZY"],
477
  custom_list=custom_list,
478
  spelling_mistakes_max=spelling_mistakes_max,
479
- search_whole_phrase=search_whole_phrase
480
  )
481
-
482
  # Create NLP engine with loaded model
483
  loaded_nlp_engine = LoadedSpacyNlpEngine(
484
- loaded_spacy_model=nlp_model,
485
- language_code=base_lang_code
486
  )
487
-
488
  # Create analyzer engine
489
  nlp_analyser = AnalyzerEngine(
490
  nlp_engine=loaded_nlp_engine,
@@ -492,11 +662,11 @@ def create_nlp_analyser(language: str = DEFAULT_LANGUAGE, custom_list: List[str]
492
  supported_languages=[base_lang_code],
493
  log_decision_process=False,
494
  )
495
-
496
  # Add custom recognizers to nlp_analyser
497
  nlp_analyser.registry.add_recognizer(custom_recogniser)
498
  nlp_analyser.registry.add_recognizer(custom_word_fuzzy_recognizer)
499
-
500
  # Add language-specific recognizers for English
501
  if base_lang_code == "en":
502
  nlp_analyser.registry.add_recognizer(street_recogniser)
@@ -505,21 +675,30 @@ def create_nlp_analyser(language: str = DEFAULT_LANGUAGE, custom_list: List[str]
505
 
506
  if return_also_model:
507
  return nlp_analyser, nlp_model
508
-
509
  return nlp_analyser
510
 
 
511
  # Create the default nlp_analyser using the new function
512
  nlp_analyser, nlp = create_nlp_analyser(DEFAULT_LANGUAGE, return_also_model=True)
513
 
514
- def spacy_fuzzy_search(text: str, custom_query_list:List[str]=[], spelling_mistakes_max:int = 1, search_whole_phrase:bool=True, nlp=nlp, progress=gr.Progress(track_tqdm=True)):
515
516
 
517
  all_matches = []
518
  all_start_positions = []
519
  all_end_positions = []
520
  all_ratios = []
521
 
522
- #print("custom_query_list:", custom_query_list)
523
 
524
  if not text:
525
  out_message = "No text data found. Skipping page."
@@ -530,23 +709,31 @@ def spacy_fuzzy_search(text: str, custom_query_list:List[str]=[], spelling_mista
530
 
531
  query = nlp(string_query)
532
 
533
- if search_whole_phrase == False:
534
  # Keep only words that are not stop words
535
- token_query = [token.text for token in query if not token.is_space and not token.is_stop and not token.is_punct]
536
 
537
  spelling_mistakes_fuzzy_pattern = "FUZZY" + str(spelling_mistakes_max)
538
 
539
  if len(token_query) > 1:
540
- #pattern_lemma = [{"LEMMA": {"IN": query}}]
541
- pattern_fuzz = [{"TEXT": {spelling_mistakes_fuzzy_pattern: {"IN": token_query}}}]
542
  else:
543
- #pattern_lemma = [{"LEMMA": query[0]}]
544
- pattern_fuzz = [{"TEXT": {spelling_mistakes_fuzzy_pattern: token_query[0]}}]
545
 
546
- matcher = Matcher(nlp.vocab)
547
  matcher.add(string_query, [pattern_fuzz])
548
- #matcher.add(string_query, [pattern_lemma])
549
-
550
  else:
551
  # If matching a whole phrase, use Spacy PhraseMatcher, then consider similarity after using Levenshtein distance.
552
  # If you want to match the whole phrase, use phrase matcher
@@ -558,52 +745,51 @@ def spacy_fuzzy_search(text: str, custom_query_list:List[str]=[], spelling_mista
558
  docs = nlp.pipe([text], batch_size=batch_size)
559
 
560
  # Get number of matches per doc
561
- for doc in docs: #progress.tqdm(docs, desc = "Searching text", unit = "rows"):
562
  matches = matcher(doc)
563
  match_count = len(matches)
564
 
565
  # If considering each sub term individually, append match. If considering together, consider weight of the relevance to that of the whole phrase.
566
- if search_whole_phrase==False:
567
  all_matches.append(match_count)
568
 
569
  for match_id, start, end in matches:
570
  span = str(doc[start:end]).strip()
571
  query_search = str(query).strip()
572
 
573
-
574
  # Convert word positions to character positions
575
  start_char = doc[start].idx # Start character position
576
- end_char = doc[end - 1].idx + len(doc[end - 1]) # End character position
577
 
578
  # The positions here are word position, not character position
579
  all_matches.append(match_count)
580
  all_start_positions.append(start_char)
581
  all_end_positions.append(end_char)
582
-
583
  else:
584
  for match_id, start, end, ratio, pattern in matches:
585
  span = str(doc[start:end]).strip()
586
  query_search = str(query).strip()
587
-
588
  # Calculate Levenshtein distance. Only keep matches with less than specified number of spelling mistakes
589
  distance = Levenshtein.distance(query_search.lower(), span.lower())
590
 
591
- #print("Levenshtein distance:", distance)
592
-
593
- if distance > spelling_mistakes_max:
594
  match_count = match_count - 1
595
  else:
596
  # Convert word positions to character positions
597
  start_char = doc[start].idx # Start character position
598
- end_char = doc[end - 1].idx + len(doc[end - 1]) # End character position
599
 
600
  all_matches.append(match_count)
601
  all_start_positions.append(start_char)
602
  all_end_positions.append(end_char)
603
- all_ratios.append(ratio)
604
-
605
 
606
  return all_start_positions, all_end_positions
607
-
608
-
609
-
 
1
  from typing import List
2
+
 
3
  import spacy
4
+ from presidio_analyzer import (
5
+ AnalyzerEngine,
6
+ EntityRecognizer,
7
+ Pattern,
8
+ PatternRecognizer,
9
+ RecognizerResult,
10
+ )
11
+ from presidio_analyzer.nlp_engine import (
12
+ NerModelConfiguration,
13
+ NlpArtifacts,
14
+ SpacyNlpEngine,
15
+ )
16
  from spacy.matcher import Matcher
17
  from spaczz.matcher import FuzzyMatcher
18
+
19
  spacy.prefer_gpu()
20
  import os
21
+ import re
22
+
23
  import gradio as gr
24
+ import Levenshtein
25
+ import requests
26
+ from spacy.cli.download import download
27
+
28
+ from tools.config import CUSTOM_ENTITIES, DEFAULT_LANGUAGE, TESSERACT_DATA_FOLDER
29
  from tools.helper_functions import _get_env_list
30
 
31
  score_threshold = 0.001
32
 
33
+ if CUSTOM_ENTITIES:
34
+ CUSTOM_ENTITIES = _get_env_list(CUSTOM_ENTITIES)
35
  custom_entities = CUSTOM_ENTITIES
36
 
37
+
38
  # Create a class inheriting from SpacyNlpEngine
39
  class LoadedSpacyNlpEngine(SpacyNlpEngine):
40
  def __init__(self, loaded_spacy_model, language_code: str):
41
+ super().__init__(
42
+ ner_model_configuration=NerModelConfiguration(
43
+ labels_to_ignore=["CARDINAL", "ORDINAL"]
44
+ )
45
+ ) # Ignore non-relevant labels
46
  self.nlp = {language_code: loaded_spacy_model}
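For orientation, this wrapper lets Presidio reuse a pipeline that has already been loaded rather than re-loading it by name. A hedged usage sketch (it assumes en_core_web_sm is installed locally; the example text is arbitrary):

    import spacy
    from presidio_analyzer import AnalyzerEngine

    nlp_model = spacy.load("en_core_web_sm")  # assumption: this model is installed
    engine = LoadedSpacyNlpEngine(loaded_spacy_model=nlp_model, language_code="en")
    analyser = AnalyzerEngine(nlp_engine=engine, supported_languages=["en"])
    print(analyser.analyze(text="Contact Jane Doe on 123-456-7890", language="en"))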
47
 
48
+
49
  def _base_language_code(language: str) -> str:
50
  lang = _normalize_language_input(language)
51
  if "_" in lang:
52
  return lang.split("_")[0]
53
  return lang
54
 
55
+
56
  def load_spacy_model(language: str = DEFAULT_LANGUAGE):
57
  """
58
  Load a spaCy model for the requested language and return it as `nlp`.
 
100
  "en_trf": ["en_core_web_trf"],
101
  "en_md": ["en_core_web_md"],
102
  "en_sm": ["en_core_web_sm"],
 
103
  # Major languages (news pipelines)
104
+ "ca": ["ca_core_news_lg", "ca_core_news_md", "ca_core_news_sm"], # Catalan
105
+ "da": ["da_core_news_lg", "da_core_news_md", "da_core_news_sm"], # Danish
106
+ "de": ["de_core_news_lg", "de_core_news_md", "de_core_news_sm"], # German
107
+ "el": ["el_core_news_lg", "el_core_news_md", "el_core_news_sm"], # Greek
108
+ "es": ["es_core_news_lg", "es_core_news_md", "es_core_news_sm"], # Spanish
109
+ "fi": ["fi_core_news_lg", "fi_core_news_md", "fi_core_news_sm"], # Finnish
110
+ "fr": ["fr_core_news_lg", "fr_core_news_md", "fr_core_news_sm"], # French
111
+ "hr": ["hr_core_news_lg", "hr_core_news_md", "hr_core_news_sm"], # Croatian
112
+ "it": ["it_core_news_lg", "it_core_news_md", "it_core_news_sm"], # Italian
113
+ "ja": ["ja_core_news_lg", "ja_core_news_md", "ja_core_news_sm"], # Japanese
114
+ "ko": ["ko_core_news_lg", "ko_core_news_md", "ko_core_news_sm"], # Korean
115
+ "lt": ["lt_core_news_lg", "lt_core_news_md", "lt_core_news_sm"], # Lithuanian
116
+ "mk": ["mk_core_news_lg", "mk_core_news_md", "mk_core_news_sm"], # Macedonian
117
+ "nb": [
118
+ "nb_core_news_lg",
119
+ "nb_core_news_md",
120
+ "nb_core_news_sm",
121
+ ], # Norwegian Bokmål
122
+ "nl": ["nl_core_news_lg", "nl_core_news_md", "nl_core_news_sm"], # Dutch
123
+ "pl": ["pl_core_news_lg", "pl_core_news_md", "pl_core_news_sm"], # Polish
124
+ "pt": ["pt_core_news_lg", "pt_core_news_md", "pt_core_news_sm"], # Portuguese
125
+ "ro": ["ro_core_news_lg", "ro_core_news_md", "ro_core_news_sm"], # Romanian
126
+ "ru": ["ru_core_news_lg", "ru_core_news_md", "ru_core_news_sm"], # Russian
127
+ "sl": ["sl_core_news_lg", "sl_core_news_md", "sl_core_news_sm"], # Slovenian
128
+ "sv": ["sv_core_news_lg", "sv_core_news_md", "sv_core_news_sm"], # Swedish
129
+ "uk": ["uk_core_news_lg", "uk_core_news_md", "uk_core_news_sm"], # Ukrainian
130
+ "zh": [
131
+ "zh_core_web_lg",
132
+ "zh_core_web_mod",
133
+ "zh_core_web_sm",
134
+ "zh_core_web_trf",
135
+ ], # Chinese
136
  # Multilingual NER
137
  "xx": ["xx_ent_wiki_sm"],
138
  }
 
187
  last_error = e
188
  continue
189
 
190
+ raise RuntimeError(
191
+ f"Failed to load spaCy model for language '{language}'. Last error: {last_error}"
192
+ )
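The body of load_spacy_model between the candidate table and this RuntimeError is outside the hunks shown here. A minimal sketch of the assumed load-or-download fallback over the candidates (the helper name is illustrative; only spacy.load, download and the final RuntimeError message are taken from this file):

    import spacy
    from spacy.cli.download import download

    def _try_load_candidates(candidates, language):
        # Sketch: try each candidate model, downloading it on a missing-model error.
        last_error = None
        for model_name in candidates:
            try:
                return spacy.load(model_name)
            except OSError:
                try:
                    download(model_name)
                    return spacy.load(model_name)
                except Exception as e:  # keep the last failure for the error message
                    last_error = e
                    continue
        raise RuntimeError(
            f"Failed to load spaCy model for language '{language}'. Last error: {last_error}"
        )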
193
+
194
 
195
  # Language-aware spaCy model loader
196
  def _normalize_language_input(language: str) -> str:
197
  return language.strip().lower().replace("-", "_")
198
 
199
+
200
  # Update the global variables to use the new function
201
  ACTIVE_LANGUAGE_CODE = _base_language_code(DEFAULT_LANGUAGE)
202
+ nlp = None # Placeholder, will be loaded in the create_nlp_analyser function below #load_spacy_model(DEFAULT_LANGUAGE)
203
+
204
 
205
+ def get_tesseract_lang_code(short_code: str):
206
  """
207
  Maps a two-letter language code to the corresponding Tesseract OCR code.
208
 
 
234
  "ru": "rus",
235
  "sl": "slv",
236
  "sv": "swe",
237
+ "uk": "ukr",
238
  }
239
 
240
  return lang_map.get(short_code)
241
 
242
+
243
+ def download_tesseract_lang_pack(
244
+ short_lang_code: str, tessdata_dir=TESSERACT_DATA_FOLDER
245
+ ):
246
  """
247
  Downloads a Tesseract language pack to a local directory.
248
 
 
251
  tessdata_dir (str, optional): The directory to save the language pack.
252
  Defaults to "tessdata".
253
  """
254
+
255
  # Create the directory if it doesn't exist
256
  if not os.path.exists(tessdata_dir):
257
  os.makedirs(tessdata_dir)
 
260
  lang_code = get_tesseract_lang_code(short_lang_code)
261
 
262
  if lang_code is None:
263
+ raise ValueError(
264
+ f"Language code {short_lang_code} not found in Tesseract language map"
265
+ )
266
+
267
  # Set the local file path
268
  file_path = os.path.join(tessdata_dir, f"{lang_code}.traineddata")
269
+
270
  # Check if the file already exists
271
  if os.path.exists(file_path):
272
  print(f"Language pack {lang_code}.traineddata already exists at {file_path}")
273
  return file_path
274
+
275
  # Construct the URL for the language pack
276
  url = f"https://raw.githubusercontent.com/tesseract-ocr/tessdata/main/{lang_code}.traineddata"
277
 
 
291
  print(f"Error downloading {lang_code}.traineddata: {e}")
292
  return None
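A brief usage sketch for the two helpers above, using a language code that is visible in the map ("uk" maps to "ukr"); the tessdata directory is a placeholder:

    lang_code = get_tesseract_lang_code("uk")  # -> "ukr", per the map above
    traineddata_path = download_tesseract_lang_pack("uk", tessdata_dir="tessdata")
    print(lang_code, traineddata_path)
    # Pointing Tesseract at this folder afterwards (e.g. via TESSDATA_PREFIX) is an
    # assumption about the surrounding setup, not something this file does.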
293
 
294
+
295
  #### Custom recognisers
296
+ def custom_word_list_recogniser(custom_list: List[str] = []):
297
  # Create regex pattern, handling quotes carefully
298
 
299
  quote_str = '"'
300
  replace_str = '(?:"|“|”)'
301
 
302
+ custom_regex = "|".join(
303
+ rf"(?<!\w){re.escape(term.strip()).replace(quote_str, replace_str)}(?!\w)"
304
  for term in custom_list
305
  )
306
+ # print(custom_regex)
307
 
308
+ custom_pattern = Pattern(name="custom_pattern", regex=custom_regex, score=1)
309
+
310
+ custom_recogniser = PatternRecognizer(
311
+ supported_entity="CUSTOM",
312
+ name="CUSTOM",
313
+ patterns=[custom_pattern],
314
+ global_regex_flags=re.DOTALL | re.MULTILINE | re.IGNORECASE,
315
+ )
316
 
317
  return custom_recogniser
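To make the role of this factory concrete, here is a hedged sketch of registering a deny-list recogniser and running it; nlp_analyser refers to the module-level analyser created further down in this file, and the deny-list terms are invented for the example:

    deny_list_recogniser = custom_word_list_recogniser(["Project Falcon", "Case 12345"])
    nlp_analyser.registry.add_recognizer(deny_list_recogniser)
    results = nlp_analyser.analyze(
        text="Notes on project falcon and Case 12345.",
        language="en",
        entities=["CUSTOM"],
    )
    print(results)  # matches scored 1, case-insensitive because of re.IGNORECASE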
318
 
319
+
320
  # Initialise custom recogniser that will be overwritten later
321
  custom_recogniser = custom_word_list_recogniser()
322
 
323
  # Custom title recogniser
324
+ titles_list = [
325
+ "Sir",
326
+ "Ma'am",
327
+ "Madam",
328
+ "Mr",
329
+ "Mr.",
330
+ "Mrs",
331
+ "Mrs.",
332
+ "Ms",
333
+ "Ms.",
334
+ "Miss",
335
+ "Dr",
336
+ "Dr.",
337
+ "Professor",
338
+ ]
339
+ titles_regex = (
340
+ "\\b" + "\\b|\\b".join(rf"{re.escape(title)}" for title in titles_list) + "\\b"
341
+ )
342
+ titles_pattern = Pattern(name="titles_pattern", regex=titles_regex, score=1)
343
+ titles_recogniser = PatternRecognizer(
344
+ supported_entity="TITLES",
345
+ name="TITLES",
346
+ patterns=[titles_pattern],
347
+ global_regex_flags=re.DOTALL | re.MULTILINE,
348
+ )
349
 
350
  # %%
351
  # Custom postcode recogniser
 
354
  ukpostcode_pattern = Pattern(
355
  name="ukpostcode_pattern",
356
  regex=r"\b([A-Z]{1,2}\d[A-Z\d]? ?\d[A-Z]{2}|GIR ?0AA)\b",
357
+ score=1,
358
  )
359
 
360
  # Define the recognizer with one or more patterns
361
+ ukpostcode_recogniser = PatternRecognizer(
362
+ supported_entity="UKPOSTCODE", name="UKPOSTCODE", patterns=[ukpostcode_pattern]
363
+ )
364
 
365
  ### Street name
366
 
367
+
368
+ def extract_street_name(text: str) -> str:
369
  """
370
  Extracts the street name and preceding word (that should contain at least one number) from the given text.
371
 
372
+ """
373
+
374
  street_types = [
375
+ "Street",
376
+ "St",
377
+ "Boulevard",
378
+ "Blvd",
379
+ "Highway",
380
+ "Hwy",
381
+ "Broadway",
382
+ "Freeway",
383
+ "Causeway",
384
+ "Cswy",
385
+ "Expressway",
386
+ "Way",
387
+ "Walk",
388
+ "Lane",
389
+ "Ln",
390
+ "Road",
391
+ "Rd",
392
+ "Avenue",
393
+ "Ave",
394
+ "Circle",
395
+ "Cir",
396
+ "Cove",
397
+ "Cv",
398
+ "Drive",
399
+ "Dr",
400
+ "Parkway",
401
+ "Pkwy",
402
+ "Park",
403
+ "Court",
404
+ "Ct",
405
+ "Square",
406
+ "Sq",
407
+ "Loop",
408
+ "Place",
409
+ "Pl",
410
+ "Parade",
411
+ "Estate",
412
+ "Alley",
413
+ "Arcade",
414
+ "Avenue",
415
+ "Ave",
416
+ "Bay",
417
+ "Bend",
418
+ "Brae",
419
+ "Byway",
420
+ "Close",
421
+ "Corner",
422
+ "Cove",
423
+ "Crescent",
424
+ "Cres",
425
+ "Cul-de-sac",
426
+ "Dell",
427
+ "Drive",
428
+ "Dr",
429
+ "Esplanade",
430
+ "Glen",
431
+ "Green",
432
+ "Grove",
433
+ "Heights",
434
+ "Hts",
435
+ "Mews",
436
+ "Parade",
437
+ "Path",
438
+ "Piazza",
439
+ "Promenade",
440
+ "Quay",
441
+ "Ridge",
442
+ "Row",
443
+ "Terrace",
444
+ "Ter",
445
+ "Track",
446
+ "Trail",
447
+ "View",
448
+ "Villas",
449
+ "Marsh",
450
+ "Embankment",
451
+ "Cut",
452
+ "Hill",
453
+ "Passage",
454
+ "Rise",
455
+ "Vale",
456
+ "Side",
457
  ]
458
 
459
  # Dynamically construct the regex pattern with all possible street types
460
+ street_types_pattern = "|".join(
461
+ rf"{re.escape(street_type)}" for street_type in street_types
462
+ )
463
 
464
  # The overall regex pattern to capture the street name and preceding word(s)
465
 
466
+ pattern = r"(?P<preceding_word>\w*\d\w*)\s*"
467
+ pattern += rf"(?P<street_name>\w+\s*\b(?:{street_types_pattern})\b)"
468
 
469
  # Find all matches in text
470
  matches = re.finditer(pattern, text, re.DOTALL | re.MULTILINE | re.IGNORECASE)
 
473
  end_positions = []
474
 
475
  for match in matches:
476
+ match.group("preceding_word").strip()
477
+ match.group("street_name").strip()
478
  start_pos = match.start()
479
  end_pos = match.end()
480
+ # print(f"Start: {start_pos}, End: {end_pos}")
481
+ # print(f"Preceding words: {preceding_word}")
482
+ # print(f"Street name: {street_name}")
483
 
484
  start_positions.append(start_pos)
485
  end_positions.append(end_pos)
486
 
487
  return start_positions, end_positions
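For orientation, the spans this helper returns are character offsets into the input text, covering the numbered token plus the street-type phrase; the address below is invented:

    text = "The parcel was left at 123 Example Street by the courier."
    starts, ends = extract_street_name(text)
    for s, e in zip(starts, ends):
        print(text[s:e])  # expected: "123 Example Street"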
488
 
489
+
490
  class StreetNameRecognizer(EntityRecognizer):
491
 
492
  def load(self) -> None:
493
  """No loading is required."""
494
  pass
495
 
496
+ def analyze(
497
+ self, text: str, entities: List[str], nlp_artifacts: NlpArtifacts
498
+ ) -> List[RecognizerResult]:
499
  """
500
  Logic for detecting a specific PII
501
  """
 
507
  for i in range(0, len(start_pos)):
508
 
509
  result = RecognizerResult(
510
+ entity_type="STREETNAME", start=start_pos[i], end=end_pos[i], score=1
511
+ )
512
+
513
  results.append(result)
514
+
515
  return results
516
+
517
+
518
  street_recogniser = StreetNameRecognizer(supported_entities=["STREETNAME"])
519
 
520
+
521
  ## Custom fuzzy match recogniser for list of strings
522
+ def custom_fuzzy_word_list_regex(text: str, custom_list: List[str] = []):
523
  # Create regex pattern, handling quotes carefully
524
 
525
  quote_str = '"'
526
  replace_str = '(?:"|“|”)'
527
 
528
+ custom_regex_pattern = "|".join(
529
+ rf"(?<!\w){re.escape(term.strip()).replace(quote_str, replace_str)}(?!\w)"
530
  for term in custom_list
531
  )
532
 
533
  # Find all matches in text
534
+ matches = re.finditer(
535
+ custom_regex_pattern, text, re.DOTALL | re.MULTILINE | re.IGNORECASE
536
+ )
537
 
538
  start_positions = []
539
  end_positions = []
 
549
 
550
 
551
  class CustomWordFuzzyRecognizer(EntityRecognizer):
552
+ def __init__(
553
+ self,
554
+ supported_entities: List[str],
555
+ custom_list: List[str] = [],
556
+ spelling_mistakes_max: int = 1,
557
+ search_whole_phrase: bool = True,
558
+ ):
559
  super().__init__(supported_entities=supported_entities)
560
  self.custom_list = custom_list # Store the custom_list as an instance attribute
561
+ self.spelling_mistakes_max = (
562
+ spelling_mistakes_max # Store the max spelling mistakes
563
+ )
564
+ self.search_whole_phrase = (
565
+ search_whole_phrase # Store the search whole phrase flag
566
+ )
567
 
568
  def load(self) -> None:
569
  """No loading is required."""
570
  pass
571
 
572
+ def analyze(
573
+ self, text: str, entities: List[str], nlp_artifacts: NlpArtifacts
574
+ ) -> List[RecognizerResult]:
575
  """
576
  Logic for detecting a specific PII
577
  """
578
+ start_pos, end_pos = spacy_fuzzy_search(
579
+ text, self.custom_list, self.spelling_mistakes_max, self.search_whole_phrase
580
+ ) # Pass new parameters
581
 
582
  results = []
583
 
584
  for i in range(0, len(start_pos)):
585
  result = RecognizerResult(
586
+ entity_type="CUSTOM_FUZZY", start=start_pos[i], end=end_pos[i], score=1
587
  )
588
  results.append(result)
589
 
590
  return results
591
+
592
+
593
  custom_list_default = []
594
+ custom_word_fuzzy_recognizer = CustomWordFuzzyRecognizer(
595
+ supported_entities=["CUSTOM_FUZZY"], custom_list=custom_list_default
596
+ )
597
 
598
  # Pass the loaded model to the new LoadedSpacyNlpEngine
599
+ loaded_nlp_engine = LoadedSpacyNlpEngine(
600
+ loaded_spacy_model=nlp, language_code=ACTIVE_LANGUAGE_CODE
601
+ )
602
+
603
 
604
+ def create_nlp_analyser(
605
+ language: str = DEFAULT_LANGUAGE,
606
+ custom_list: List[str] = None,
607
+ spelling_mistakes_max: int = 1,
608
+ search_whole_phrase: bool = True,
609
+ existing_nlp_analyser: AnalyzerEngine = None,
610
+ return_also_model: bool = False,
611
+ ):
612
  """
613
  Create an nlp_analyser object based on the specified language input.
614
+
615
  Args:
616
  language (str): Language code (e.g., "en", "de", "fr", "es", etc.)
617
  custom_list (List[str], optional): List of custom words to recognize. Defaults to None.
 
619
  search_whole_phrase (bool, optional): Whether to search for whole phrases or individual words. Defaults to True.
620
  existing_nlp_analyser (AnalyzerEngine, optional): Existing nlp_analyser object to use. Defaults to None.
621
  return_also_model (bool, optional): Whether to return the nlp_model object as well. Defaults to False.
622
+
623
  Returns:
624
  AnalyzerEngine: Configured nlp_analyser object with custom recognizers
625
  """
626
 
627
+ if existing_nlp_analyser is None:
628
  pass
629
  else:
630
  if existing_nlp_analyser.supported_languages[0] == language:
 
634
 
635
  # Load spaCy model for the specified language
636
  nlp_model = load_spacy_model(language)
637
+
638
  # Get base language code
639
  base_lang_code = _base_language_code(language)
640
+
641
  # Create custom recognizers
642
  if custom_list is None:
643
  custom_list = []
644
+
645
  custom_recogniser = custom_word_list_recogniser(custom_list)
646
  custom_word_fuzzy_recognizer = CustomWordFuzzyRecognizer(
647
+ supported_entities=["CUSTOM_FUZZY"],
648
  custom_list=custom_list,
649
  spelling_mistakes_max=spelling_mistakes_max,
650
+ search_whole_phrase=search_whole_phrase,
651
  )
652
+
653
  # Create NLP engine with loaded model
654
  loaded_nlp_engine = LoadedSpacyNlpEngine(
655
+ loaded_spacy_model=nlp_model, language_code=base_lang_code
 
656
  )
657
+
658
  # Create analyzer engine
659
  nlp_analyser = AnalyzerEngine(
660
  nlp_engine=loaded_nlp_engine,
 
662
  supported_languages=[base_lang_code],
663
  log_decision_process=False,
664
  )
665
+
666
  # Add custom recognizers to nlp_analyser
667
  nlp_analyser.registry.add_recognizer(custom_recogniser)
668
  nlp_analyser.registry.add_recognizer(custom_word_fuzzy_recognizer)
669
+
670
  # Add language-specific recognizers for English
671
  if base_lang_code == "en":
672
  nlp_analyser.registry.add_recognizer(street_recogniser)
 
675
 
676
  if return_also_model:
677
  return nlp_analyser, nlp_model
678
+
679
  return nlp_analyser
680
 
681
+
682
  # Create the default nlp_analyser using the new function
683
  nlp_analyser, nlp = create_nlp_analyser(DEFAULT_LANGUAGE, return_also_model=True)
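A usage sketch of the factory above with a non-default language and a deny list; the first run may trigger a spaCy model download, and the German example strings are arbitrary:

    de_analyser = create_nlp_analyser(
        "de",
        custom_list=["Aktenzeichen XYZ"],
        spelling_mistakes_max=1,
        search_whole_phrase=True,
    )
    print(de_analyser.analyze(text="Das Aktenzeichen XYZ betrifft Max Mustermann.", language="de"))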
684
 
685
+
686
+ def spacy_fuzzy_search(
687
+ text: str,
688
+ custom_query_list: List[str] = [],
689
+ spelling_mistakes_max: int = 1,
690
+ search_whole_phrase: bool = True,
691
+ nlp=nlp,
692
+ progress=gr.Progress(track_tqdm=True),
693
+ ):
694
+ """Conduct fuzzy match on a list of text data."""
695
 
696
  all_matches = []
697
  all_start_positions = []
698
  all_end_positions = []
699
  all_ratios = []
700
 
701
+ # print("custom_query_list:", custom_query_list)
702
 
703
  if not text:
704
  out_message = "No text data found. Skipping page."
 
709
 
710
  query = nlp(string_query)
711
 
712
+ if search_whole_phrase is False:
713
  # Keep only words that are not stop words
714
+ token_query = [
715
+ token.text
716
+ for token in query
717
+ if not token.is_space and not token.is_stop and not token.is_punct
718
+ ]
719
 
720
  spelling_mistakes_fuzzy_pattern = "FUZZY" + str(spelling_mistakes_max)
721
 
722
  if len(token_query) > 1:
723
+ # pattern_lemma = [{"LEMMA": {"IN": query}}]
724
+ pattern_fuzz = [
725
+ {"TEXT": {spelling_mistakes_fuzzy_pattern: {"IN": token_query}}}
726
+ ]
727
  else:
728
+ # pattern_lemma = [{"LEMMA": query[0]}]
729
+ pattern_fuzz = [
730
+ {"TEXT": {spelling_mistakes_fuzzy_pattern: token_query[0]}}
731
+ ]
732
 
733
+ matcher = Matcher(nlp.vocab)
734
  matcher.add(string_query, [pattern_fuzz])
735
+ # matcher.add(string_query, [pattern_lemma])
736
+
737
  else:
738
  # If matching a whole phrase, use Spacy PhraseMatcher, then consider similarity after using Levenshtein distance.
739
  # If you want to match the whole phrase, use phrase matcher
 
745
  docs = nlp.pipe([text], batch_size=batch_size)
746
 
747
  # Get number of matches per doc
748
+ for doc in docs: # progress.tqdm(docs, desc = "Searching text", unit = "rows"):
749
  matches = matcher(doc)
750
  match_count = len(matches)
751
 
752
  # If considering each sub term individually, append match. If considering together, consider weight of the relevance to that of the whole phrase.
753
+ if search_whole_phrase is False:
754
  all_matches.append(match_count)
755
 
756
  for match_id, start, end in matches:
757
  span = str(doc[start:end]).strip()
758
  query_search = str(query).strip()
759
 
 
760
  # Convert word positions to character positions
761
  start_char = doc[start].idx # Start character position
762
+ end_char = doc[end - 1].idx + len(
763
+ doc[end - 1]
764
+ ) # End character position
765
 
766
  # The positions here are word position, not character position
767
  all_matches.append(match_count)
768
  all_start_positions.append(start_char)
769
  all_end_positions.append(end_char)
770
+
771
  else:
772
  for match_id, start, end, ratio, pattern in matches:
773
  span = str(doc[start:end]).strip()
774
  query_search = str(query).strip()
775
+
776
  # Calculate Levenshtein distance. Only keep matches with less than specified number of spelling mistakes
777
  distance = Levenshtein.distance(query_search.lower(), span.lower())
778
 
779
+ # print("Levenshtein distance:", distance)
780
+
781
+ if distance > spelling_mistakes_max:
782
  match_count = match_count - 1
783
  else:
784
  # Convert word positions to character positions
785
  start_char = doc[start].idx # Start character position
786
+ end_char = doc[end - 1].idx + len(
787
+ doc[end - 1]
788
+ ) # End character position
789
 
790
  all_matches.append(match_count)
791
  all_start_positions.append(start_char)
792
  all_end_positions.append(end_char)
793
+ all_ratios.append(ratio)
 
794
 
795
  return all_start_positions, all_end_positions
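A small illustration of the whole-phrase branch; the misspelling in the text is one edit away from the query, so it should fall within spelling_mistakes_max=1:

    starts, ends = spacy_fuzzy_search(
        text="Please contact Jon Smith about the invoice.",
        custom_query_list=["John Smith"],
        spelling_mistakes_max=1,
        search_whole_phrase=True,
    )
    print(starts, ends)  # character offsets of the fuzzy match, if one is found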
tools/presidio_analyzer_custom.py CHANGED
@@ -1,10 +1,12 @@
 
 
1
  import gradio as gr
2
- from typing import List, Iterable, Dict, Union, Any, Optional, Iterator, Tuple
3
- #from tqdm import tqdm
4
 
 
5
  from presidio_analyzer import DictAnalyzerResult, RecognizerResult
6
  from presidio_analyzer.nlp_engine import NlpArtifacts
7
 
 
8
  def recognizer_result_from_dict(data: Dict) -> RecognizerResult:
9
  """
10
  Create RecognizerResult from a dictionary.
@@ -25,116 +27,116 @@ def recognizer_result_from_dict(data: Dict) -> RecognizerResult:
25
  score = data.get("Score")
26
  analysis_explanation = None
27
  recognition_metadata = None
28
-
29
- return RecognizerResult(entity_type, start, end, score, analysis_explanation, recognition_metadata)
30
 
31
  def analyze_iterator_custom(
32
- self,
33
- texts: Iterable[Union[str, bool, float, int]],
34
- language: str,
35
- list_length:int,
36
- progress=gr.Progress(),
37
- **kwargs,
38
- ) -> List[List[RecognizerResult]]:
39
- """
40
- Analyze an iterable of strings.
41
-
42
- :param texts: A list containing strings to be analyzed.
43
- :param language: Input language
44
- :param list_length: Length of the input list.
45
- :param kwargs: Additional parameters for the `AnalyzerEngine.analyze` method.
46
- """
47
-
48
- # validate types
49
- texts = self._validate_types(texts)
50
-
51
- # Process the texts as batch for improved performance
52
- nlp_artifacts_batch: Iterator[
53
- Tuple[str, NlpArtifacts]
54
- ] = self.analyzer_engine.nlp_engine.process_batch(
55
- texts=texts, language=language
56
- )
57
 
58
-
59
 
60
- list_results = []
 
61
 
62
- # Uncomment this if you want to show progress within a file
63
- #for text, nlp_artifacts in progress.tqdm(nlp_artifacts_batch, total = list_length, desc = "Analysing text for personal information", unit = "rows"):
64
- for text, nlp_artifacts in nlp_artifacts_batch:
65
- results = self.analyzer_engine.analyze(
66
- text=str(text), nlp_artifacts=nlp_artifacts, language=language, **kwargs
67
- )
68
 
69
- list_results.append(results)
70
 
71
- return list_results
72
 
73
  def analyze_dict(
74
- self,
75
- input_dict: Dict[str, Union[Any, Iterable[Any]]],
76
- language: str,
77
- keys_to_skip: Optional[List[str]] = None,
78
- **kwargs,
79
- ) -> Iterator[DictAnalyzerResult]:
80
- """
81
- Analyze a dictionary of keys (strings) and values/iterable of values.
82
-
83
- Non-string values are returned as is.
84
-
85
- :param input_dict: The input dictionary for analysis
86
- :param language: Input language
87
- :param keys_to_skip: Keys to ignore during analysis
88
- :param kwargs: Additional keyword arguments
89
- for the `AnalyzerEngine.analyze` method.
90
- Use this to pass arguments to the analyze method,
91
- such as `ad_hoc_recognizers`, `context`, `return_decision_process`.
92
- See `AnalyzerEngine.analyze` for the full list.
93
- """
94
-
95
- context = []
96
- if "context" in kwargs:
97
- context = kwargs["context"]
98
- del kwargs["context"]
99
-
100
- if not keys_to_skip:
101
- keys_to_skip = []
102
-
103
-
104
- for key, value in input_dict.items():
105
- if not value or key in keys_to_skip:
106
- yield DictAnalyzerResult(key=key, value=value, recognizer_results=[])
107
- continue # skip this key as requested
108
-
109
- # Add the key as an additional context
110
- specific_context = context[:]
111
- specific_context.append(key)
112
-
113
- if type(value) in (str, int, bool, float):
114
- results: List[RecognizerResult] = self.analyzer_engine.analyze(
115
- text=str(value), language=language, context=[key], **kwargs
116
- )
117
- elif isinstance(value, dict):
118
- new_keys_to_skip = self._get_nested_keys_to_skip(key, keys_to_skip)
119
- results = self.analyze_dict(
120
- input_dict=value,
121
- language=language,
122
- context=specific_context,
123
- keys_to_skip=new_keys_to_skip,
124
- **kwargs,
125
- )
126
- elif isinstance(value, Iterable):
127
- # Recursively iterate nested dicts
128
- list_length = len(value)
129
-
130
- results: List[List[RecognizerResult]] = analyze_iterator_custom(self,
131
- texts=value,
132
- language=language,
133
- context=specific_context,
134
- list_length=list_length,
135
- **kwargs,
136
- )
137
- else:
138
- raise ValueError(f"type {type(value)} is unsupported.")
139
-
140
- yield DictAnalyzerResult(key=key, value=value, recognizer_results=results)
 
1
+ from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union
2
+
3
  import gradio as gr
4
 
5
+ # from tqdm import tqdm
6
  from presidio_analyzer import DictAnalyzerResult, RecognizerResult
7
  from presidio_analyzer.nlp_engine import NlpArtifacts
8
 
9
+
10
  def recognizer_result_from_dict(data: Dict) -> RecognizerResult:
11
  """
12
  Create RecognizerResult from a dictionary.
 
27
  score = data.get("Score")
28
  analysis_explanation = None
29
  recognition_metadata = None
30
+
31
+ return RecognizerResult(
32
+ entity_type, start, end, score, analysis_explanation, recognition_metadata
33
+ )
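An input shaped the way this helper expects; "Score" is the only key visible in the body shown here, and the other capitalised key names are an assumption made for the example:

    row = {"Type": "PERSON", "Start": 8, "End": 18, "Score": 0.85}
    result = recognizer_result_from_dict(row)
    print(result.entity_type, result.start, result.end, result.score)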
34
+
35
 
36
  def analyze_iterator_custom(
37
+ self,
38
+ texts: Iterable[Union[str, bool, float, int]],
39
+ language: str,
40
+ list_length: int,
41
+ progress=gr.Progress(),
42
+ **kwargs,
43
+ ) -> List[List[RecognizerResult]]:
44
+ """
45
46
 
47
+ :param texts: A list containing strings to be analyzed.
48
+ :param language: Input language
49
+ :param list_length: Length of the input list.
50
+ :param kwargs: Additional parameters for the `AnalyzerEngine.analyze` method.
51
+ """
52
 
53
+ # validate types
54
+ texts = self._validate_types(texts)
55
 
56
+ # Process the texts as batch for improved performance
57
+ nlp_artifacts_batch: Iterator[Tuple[str, NlpArtifacts]] = (
58
+ self.analyzer_engine.nlp_engine.process_batch(texts=texts, language=language)
59
+ )
60
+
61
+ list_results = []
62
+
63
+ # Uncomment this if you want to show progress within a file
64
+ # for text, nlp_artifacts in progress.tqdm(nlp_artifacts_batch, total = list_length, desc = "Analysing text for personal information", unit = "rows"):
65
+ for text, nlp_artifacts in nlp_artifacts_batch:
66
+ results = self.analyzer_engine.analyze(
67
+ text=str(text), nlp_artifacts=nlp_artifacts, language=language, **kwargs
68
+ )
69
+
70
+ list_results.append(results)
71
 
72
+ return list_results
73
 
 
74
 
75
  def analyze_dict(
76
+ self,
77
+ input_dict: Dict[str, Union[Any, Iterable[Any]]],
78
+ language: str,
79
+ keys_to_skip: Optional[List[str]] = None,
80
+ **kwargs,
81
+ ) -> Iterator[DictAnalyzerResult]:
82
+ """
83
+ Analyze a dictionary of keys (strings) and values/iterable of values.
84
+
85
+ Non-string values are returned as is.
86
+
87
+ :param input_dict: The input dictionary for analysis
88
+ :param language: Input language
89
+ :param keys_to_skip: Keys to ignore during analysis
90
+ :param kwargs: Additional keyword arguments
91
+ for the `AnalyzerEngine.analyze` method.
92
+ Use this to pass arguments to the analyze method,
93
+ such as `ad_hoc_recognizers`, `context`, `return_decision_process`.
94
+ See `AnalyzerEngine.analyze` for the full list.
95
+ """
96
+
97
+ context = []
98
+ if "context" in kwargs:
99
+ context = kwargs["context"]
100
+ del kwargs["context"]
101
+
102
+ if not keys_to_skip:
103
+ keys_to_skip = []
104
+
105
+ for key, value in input_dict.items():
106
+ if not value or key in keys_to_skip:
107
+ yield DictAnalyzerResult(key=key, value=value, recognizer_results=[])
108
+ continue # skip this key as requested
109
+
110
+ # Add the key as an additional context
111
+ specific_context = context[:]
112
+ specific_context.append(key)
113
+
114
+ if type(value) in (str, int, bool, float):
115
+ results: List[RecognizerResult] = self.analyzer_engine.analyze(
116
+ text=str(value), language=language, context=[key], **kwargs
117
+ )
118
+ elif isinstance(value, dict):
119
+ new_keys_to_skip = self._get_nested_keys_to_skip(key, keys_to_skip)
120
+ results = self.analyze_dict(
121
+ input_dict=value,
122
+ language=language,
123
+ context=specific_context,
124
+ keys_to_skip=new_keys_to_skip,
125
+ **kwargs,
126
+ )
127
+ elif isinstance(value, Iterable):
128
+ # Recursively iterate nested dicts
129
+ list_length = len(value)
130
+
131
+ results: List[List[RecognizerResult]] = analyze_iterator_custom(
132
+ self,
133
+ texts=value,
134
+ language=language,
135
+ context=specific_context,
136
+ list_length=list_length,
137
+ **kwargs,
138
+ )
139
+ else:
140
+ raise ValueError(f"type {type(value)} is unsupported.")
141
+
142
+ yield DictAnalyzerResult(key=key, value=value, recognizer_results=results)
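Both functions are written as unbound helpers that expect an object exposing analyzer_engine (and the private batch helpers), for example Presidio's BatchAnalyzerEngine; treating them that way is an assumption of this sketch rather than something the diff states:

    from presidio_analyzer import BatchAnalyzerEngine

    batch_analyzer = BatchAnalyzerEngine(analyzer_engine=nlp_analyser)  # analyser from the previous module
    data = {"Case Note": ["Call John Smith on 123-456-7890"], "Reference": ["A-42"]}
    for dict_result in analyze_dict(batch_analyzer, input_dict=data, language="en"):
        print(dict_result.key, dict_result.recognizer_results)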
tools/redaction_review.py CHANGED
The diff for this file is too large to render. See raw diff
 
tools/textract_batch_call.py CHANGED
@@ -1,36 +1,54 @@
1
- import boto3
2
- import os
3
- import pandas as pd
4
  import json
5
  import logging
6
- import datetime
7
- import pymupdf
8
  import gradio as gr
9
  from gradio import FileData
10
- from typing import List
11
- from io import StringIO
12
- from urllib.parse import urlparse
13
- from botocore.exceptions import ClientError, NoCredentialsError, PartialCredentialsError, TokenRetrievalError
14
- from tools.config import TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_BUCKET, OUTPUT_FOLDER, AWS_REGION, DOCUMENT_REDACTION_BUCKET, LOAD_PREVIOUS_TEXTRACT_JOBS_S3, TEXTRACT_JOBS_S3_LOC, TEXTRACT_JOBS_LOCAL_LOC, RUN_AWS_FUNCTIONS, INPUT_FOLDER, DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS
15
  from tools.aws_functions import download_file_from_s3
16
  from tools.file_conversion import get_input_file_names
17
  from tools.helper_functions import get_file_name_without_type
18
 
19
  DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS = int(DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS)
20
 
 
21
  def analyse_document_with_textract_api(
22
  local_pdf_path: str,
23
  s3_input_prefix: str,
24
  s3_output_prefix: str,
25
- job_df:pd.DataFrame,
26
  s3_bucket_name: str = TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_BUCKET,
27
- local_output_dir: str = OUTPUT_FOLDER,
28
- handwrite_signature_checkbox:List[str] = list(),
29
- successful_job_number:int=0,
30
- total_document_page_count:int=1,
31
  general_s3_bucket_name: str = DOCUMENT_REDACTION_BUCKET,
32
- aws_region: str = AWS_REGION # Optional: specify region if not default
33
- ):
34
  """
35
  Uploads a local PDF to S3, starts a Textract analysis job (detecting text & signatures),
36
  waits for completion, and downloads the output JSON from S3 to a local directory.
@@ -42,7 +60,7 @@ def analyse_document_with_textract_api(
42
  s3_output_prefix (str): S3 prefix (folder) where Textract should write output.
43
  job_df (pd.DataFrame): Dataframe containing information from previous Textract API calls.
44
  s3_bucket_name (str, optional): S3 bucket in which to save API call outputs.
45
- local_output_dir (str, optional): Local directory to save the downloaded JSON results.
46
  handwrite_signature_checkbox (List[str], optional): List of feature types to extract from the document.
47
  successful_job_number (int): The number of successful jobs that have been submitted in this session.
48
  total_document_page_count (int): The number of pages in the document
@@ -71,7 +89,7 @@ def analyse_document_with_textract_api(
71
  file_extension = os.path.splitext(local_pdf_path)[1].lower()
72
 
73
  # Load pdf to get page count if not provided
74
- if not total_document_page_count and file_extension in ['.pdf']:
75
  print("Page count not provided. Loading PDF to get page count")
76
  try:
77
  pymupdf_doc = pymupdf.open(local_pdf_path)
@@ -81,7 +99,7 @@ def analyse_document_with_textract_api(
81
  except Exception as e:
82
  print("Failed to load PDF to get page count:", e, "setting page count to 1")
83
  total_document_page_count = 1
84
- #raise Exception(f"Failed to load PDF to get page count: {e}")
85
  else:
86
  total_document_page_count = 1
87
 
@@ -89,50 +107,92 @@ def analyse_document_with_textract_api(
89
  os.makedirs(local_output_dir)
90
  log_message = f"Created local output directory: {local_output_dir}"
91
  print(log_message)
92
- #logging.info(log_message)
93
 
94
  # Initialize boto3 clients
95
  session = boto3.Session(region_name=aws_region)
96
- s3_client = session.client('s3')
97
- textract_client = session.client('textract')
98
 
99
  # --- 1. Upload PDF to S3 ---
100
  pdf_filename = os.path.basename(local_pdf_path)
101
- s3_input_key = os.path.join(s3_input_prefix, pdf_filename).replace("\\", "/") # Ensure forward slashes for S3
102
 
103
- log_message = f"Uploading '{local_pdf_path}' to 's3://{s3_bucket_name}/{s3_input_key}'..."
104
  print(log_message)
105
- #logging.info(log_message)
106
  try:
107
  s3_client.upload_file(local_pdf_path, s3_bucket_name, s3_input_key)
108
  log_message = "Upload successful."
109
  print(log_message)
110
- #logging.info(log_message)
111
  except Exception as e:
112
  log_message = f"Failed to upload PDF to S3: {e}"
113
  print(log_message)
114
- #logging.error(log_message)
115
  raise
116
 
117
  # Filter job_df to include rows only where the analysis date is after the current date - DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS
118
  if not job_df.empty:
119
120
 
121
  # If job_df is not empty
122
- if not job_df.empty:
123
 
124
  if "file_name" in job_df.columns:
125
- matching_job_id_file_names = job_df.loc[(job_df["file_name"] == pdf_filename) & (job_df["signature_extraction"].astype(str) == str(handwrite_signature_checkbox)), "file_name"]
126
- matching_job_id_file_names_dates = job_df.loc[(job_df["file_name"] == pdf_filename) & (job_df["signature_extraction"].astype(str) == str(handwrite_signature_checkbox)), "job_date_time"]
127
- matching_job_id = job_df.loc[(job_df["file_name"] == pdf_filename) & (job_df["signature_extraction"].astype(str) == str(handwrite_signature_checkbox)), "job_id"]
128
- matching_handwrite_signature = job_df.loc[(job_df["file_name"] == pdf_filename) & (job_df["signature_extraction"].astype(str) == str(handwrite_signature_checkbox)), "signature_extraction"]
129
 
130
  if len(matching_job_id) > 0:
131
  pass
132
  else:
133
  matching_job_id = "unknown_job_id"
134
 
135
- if len(matching_job_id_file_names) > 0 and len(matching_handwrite_signature) > 0:
136
  out_message = f"Existing Textract outputs found for file {pdf_filename} from date {matching_job_id_file_names_dates.iloc[0]}. No need to re-analyse. Please download existing results from the list with job ID {matching_job_id.iloc[0]}"
137
  gr.Warning(out_message)
138
  raise Exception(out_message)
@@ -142,104 +202,121 @@ def analyse_document_with_textract_api(
142
  print(message)
143
 
144
  try:
145
- if "Extract signatures" in handwrite_signature_checkbox or "Extract forms" in handwrite_signature_checkbox or "Extract layout" in handwrite_signature_checkbox or "Extract tables" in handwrite_signature_checkbox:
146
  feature_types = list()
147
- if 'Extract signatures' in handwrite_signature_checkbox:
148
- feature_types.append('SIGNATURES')
149
  if "Extract forms" in handwrite_signature_checkbox:
150
- feature_types.append('FORMS')
151
  if "Extract layout" in handwrite_signature_checkbox:
152
- feature_types.append('LAYOUT')
153
  if "Extract tables" in handwrite_signature_checkbox:
154
- feature_types.append('TABLES')
155
  response = textract_client.start_document_analysis(
156
  DocumentLocation={
157
- 'S3Object': {
158
- 'Bucket': s3_bucket_name,
159
- 'Name': s3_input_key
160
- }
161
  },
162
- FeatureTypes=feature_types, # Analyze for signatures, forms, and tables
163
- OutputConfig={
164
- 'S3Bucket': s3_bucket_name,
165
- 'S3Prefix': s3_output_prefix
166
- }
167
  )
168
- job_type="document_analysis"
169
-
170
- if not "Extract signatures" in handwrite_signature_checkbox and not "Extract forms" in handwrite_signature_checkbox and not "Extract layout" in handwrite_signature_checkbox and not "Extract tables" in handwrite_signature_checkbox:
171
  response = textract_client.start_document_text_detection(
172
  DocumentLocation={
173
- 'S3Object': {
174
- 'Bucket': s3_bucket_name,
175
- 'Name': s3_input_key
176
- }
177
  },
178
- OutputConfig={
179
- 'S3Bucket': s3_bucket_name,
180
- 'S3Prefix': s3_output_prefix
181
- }
182
  )
183
- job_type="document_text_detection"
184
 
185
- job_id = response['JobId']
186
  print(f"Textract job started with JobId: {job_id}")
187
 
188
  # Prepare CSV in memory
189
  log_csv_key_location = f"{s3_output_prefix}/textract_document_jobs.csv"
190
- job_location_full = f"s3://{s3_bucket_name}/{s3_output_prefix}/{job_id}/"
191
 
192
- csv_buffer = StringIO()
193
- log_df = pd.DataFrame([{
194
- 'job_id': job_id,
195
- 'file_name': pdf_filename,
196
- 'job_type': job_type,
197
- 'signature_extraction':handwrite_signature_checkbox,
198
- 'job_date_time': datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
199
- }])
200
 
201
  # File path
202
  log_file_path = os.path.join(local_output_dir, "textract_document_jobs.csv")
203
- log_file_path_job_id = os.path.join(local_output_dir, pdf_filename + "_textract_document_jobs_job_id.txt")
204
 
205
  # Write latest job ID to local text file
206
- with open(log_file_path_job_id, 'w') as f:
207
  f.write(job_id)
208
 
209
  # Check if file exists
210
  file_exists = os.path.exists(log_file_path)
211
 
212
  # Append to CSV if it exists, otherwise write with header
213
- log_df.to_csv(log_file_path, mode='a', index=False, header=not file_exists)
214
-
215
- #log_df.to_csv(csv_buffer)
216
 
217
  # Upload the file
218
- s3_client.upload_file(log_file_path, general_s3_bucket_name, log_csv_key_location)
219
 
220
  # Upload to S3 (overwrite existing file)
221
- #s3_client.put_object(Bucket=general_s3_bucket_name, Key=log_csv_key_location, Body=csv_buffer.getvalue())
222
  print(f"Job ID written to {log_csv_key_location}")
223
- #logging.info(f"Job ID written to s3://{s3_bucket_name}/{s3_output_prefix}/textract_document_jobs.csv")
224
 
225
  except Exception as e:
226
  error = f"Failed to start Textract job: {e}"
227
  print(error)
228
- #logging.error(error)
229
  raise
230
 
231
  successful_job_number += 1
232
  total_number_of_textract_page_calls = total_document_page_count
233
 
234
- return f"Textract analysis job submitted, job ID:{job_id}", job_id, job_type, successful_job_number, is_a_textract_api_call, total_number_of_textract_page_calls, task_textbox
235
 
236
- def return_job_status(job_id:str,
237
- response:dict,
238
- attempts:int,
239
- poll_interval_seconds: int = 0,
240
- max_polling_attempts: int = 1 # ~10 minutes total wait time
241
- ):
242
- '''
 
243
  Polls the AWS Textract service to retrieve the current status of an asynchronous document analysis job.
244
  This function checks the job status from the provided response and logs relevant information or errors.
245
 
@@ -255,87 +332,103 @@ def return_job_status(job_id:str,
255
 
256
  Raises:
257
  Exception: If the Textract job status is 'FAILED' or 'PARTIAL_SUCCESS', or if an unexpected status is encountered.
258
- '''
259
 
260
- job_status = response['JobStatus']
261
- logging.info(f"Polling attempt {attempts}/{max_polling_attempts}. Job status: {job_status}")
262
 
263
- if job_status == 'IN_PROGRESS':
264
  pass
265
- #time.sleep(poll_interval_seconds)
266
- elif job_status == 'SUCCEEDED':
267
  logging.info("Textract job succeeded.")
268
- elif job_status in ['FAILED', 'PARTIAL_SUCCESS']:
269
- status_message = response.get('StatusMessage', 'No status message provided.')
270
- warnings = response.get('Warnings', [])
271
- logging.error(f"Textract job ended with status: {job_status}. Message: {status_message}")
272
  if warnings:
273
  logging.warning(f"Warnings: {warnings}")
274
  # Decide if PARTIAL_SUCCESS should proceed or raise error
275
  # For simplicity here, we raise for both FAILED and PARTIAL_SUCCESS
276
- raise Exception(f"Textract job {job_id} failed or partially failed. Status: {job_status}. Message: {status_message}")
277
  else:
278
  # Should not happen based on documentation, but handle defensively
279
  raise Exception(f"Unexpected Textract job status: {job_status}")
280
-
281
  return job_status
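A minimal polling loop built around this helper; textract_client and job_id are assumed to exist in the caller, and the five-second sleep is illustrative:

    import time

    attempts = 0
    job_status = "IN_PROGRESS"
    while job_status == "IN_PROGRESS" and attempts < 60:
        attempts += 1
        response = textract_client.get_document_analysis(JobId=job_id, MaxResults=1)
        job_status = return_job_status(
            job_id, response, attempts, poll_interval_seconds=5, max_polling_attempts=60
        )
        if job_status == "IN_PROGRESS":
            time.sleep(5)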
282
 
283
- def download_textract_job_files(s3_client:str,
284
- s3_bucket_name:str,
285
- s3_output_key_prefix:str,
286
- pdf_filename:str,
287
- job_id:str,
288
- local_output_dir:str):
289
- '''
290
  Download and combine selected job files from the AWS Textract service.
291
- '''
292
 
293
- #print("s3_output_key_prefix at download:", s3_output_key_prefix)
294
 
295
  list_response = s3_client.list_objects_v2(
296
- Bucket=s3_bucket_name,
297
- Prefix=s3_output_key_prefix
298
  )
299
 
300
- output_files = list_response.get('Contents', [])
301
  if not output_files:
302
  # Sometimes Textract might take a moment longer to write the output after SUCCEEDED status
303
- #logging.warning("No output files found immediately after job success. Waiting briefly and retrying list...")
304
- #time.sleep(5)
305
  list_response = s3_client.list_objects_v2(
306
- Bucket=s3_bucket_name,
307
- Prefix=s3_output_key_prefix
308
  )
309
- output_files = list_response.get('Contents', [])
310
 
311
  if not output_files:
312
- logging.error(f"No output files found in s3://{s3_bucket_name}/{s3_output_key_prefix}")
 
 
313
  # You could alternatively try getting results via get_document_analysis pagination here
314
  # but sticking to the request to download from S3 output path.
315
- raise FileNotFoundError(f"Textract output files not found in S3 path: s3://{s3_bucket_name}/{s3_output_key_prefix}")
316
 
317
  # Usually, we only need the first/main JSON output file(s)
318
  # For simplicity, download the first one found. A more complex scenario might merge multiple files.
319
  # Filter out potential directory markers if any key ends with '/'
320
  json_files_to_download = [
321
- f for f in output_files
322
- if f['Key'] != s3_output_key_prefix and not f['Key'].endswith('/') and 'access_check' not in f['Key']
323
- ]
 
 
 
324
 
325
- #print("json_files_to_download:", json_files_to_download)
326
 
327
  if not json_files_to_download:
328
  error = f"No JSON files found (only prefix marker?) in s3://{s3_bucket_name}/{s3_output_key_prefix}"
329
  print(error)
330
- #logging.error(error)
331
  raise FileNotFoundError(error)
332
 
333
  combined_blocks = []
334
 
335
- for f in sorted(json_files_to_download, key=lambda x: x['Key']): # Optional: sort to ensure consistent order
336
- obj = s3_client.get_object(Bucket=s3_bucket_name, Key=f['Key'])
337
- data = json.loads(obj['Body'].read())
338
-
 
 
339
  # Assuming Textract-style output with a "Blocks" key
340
  if "Blocks" in data:
341
  combined_blocks.extend(data["Blocks"])
@@ -345,10 +438,10 @@ def download_textract_job_files(s3_client:str,
345
  # Build final combined JSON structure
346
  combined_output = {
347
  "DocumentMetadata": {
348
- "Pages": len(set(block.get('Page', 1) for block in combined_blocks))
349
  },
350
  "Blocks": combined_blocks,
351
- "JobStatus": "SUCCEEDED"
352
  }
353
 
354
  output_filename_base = os.path.basename(pdf_filename)
@@ -356,7 +449,7 @@ def download_textract_job_files(s3_client:str,
356
  local_output_filename = f"{output_filename_base_no_ext}_textract.json"
357
  local_output_path = os.path.join(local_output_dir, local_output_filename)
358
 
359
- with open(local_output_path, 'w') as f:
360
  json.dump(combined_output, f)
361
 
362
  print(f"Combined Textract output written to {local_output_path}")
@@ -367,133 +460,201 @@ def download_textract_job_files(s3_client:str,
367
  downloaded_file_path = local_output_path
368
 
369
  # Log if multiple files were found, as user might need to handle them
370
- #if len(json_files_to_download) > 1:
371
  # logging.warning(f"Multiple output files found in S3 output location. Downloaded the first: '{s3_output_key}'. Other files exist.")
372
 
373
  return downloaded_file_path
374
 
375
- def check_for_provided_job_id(job_id:str):
 
376
  if not job_id:
377
- raise Exception("Please provide a job ID.")
378
  return
379
 
 
380
  def load_pdf_job_file_from_s3(
381
  load_s3_jobs_input_loc,
382
  pdf_filename,
383
  local_output_dir,
384
  s3_bucket_name,
385
- RUN_AWS_FUNCTIONS=RUN_AWS_FUNCTIONS):
 
386
 
387
  try:
388
- pdf_file_location = ''
389
- doc_file_name_no_extension_textbox = ''
390
 
391
- s3_input_key_prefix = os.path.join(load_s3_jobs_input_loc, pdf_filename).replace("\\", "/")
 
 
392
  s3_input_key_prefix = s3_input_key_prefix + ".pdf"
393
-
394
  local_input_file_path = os.path.join(local_output_dir, pdf_filename)
395
  local_input_file_path = local_input_file_path + ".pdf"
396
 
397
- download_file_from_s3(s3_bucket_name, s3_input_key_prefix, local_input_file_path, RUN_AWS_FUNCTIONS= RUN_AWS_FUNCTIONS)
398
-
399
  pdf_file_location = [local_input_file_path]
400
  doc_file_name_no_extension_textbox = get_file_name_without_type(pdf_filename)
401
  except Exception as e:
402
- print("Could not download PDF job file from S3 due to:", e)
403
 
404
  return pdf_file_location, doc_file_name_no_extension_textbox
405
 
406
- def replace_existing_pdf_input_for_whole_document_outputs(
407
- load_s3_jobs_input_loc:str,
408
- pdf_filename:str,
409
- local_output_dir:str,
410
- s3_bucket_name:str,
411
- in_doc_files:FileData=[],
412
- input_folder:str=INPUT_FOLDER,
 
413
  RUN_AWS_FUNCTIONS=RUN_AWS_FUNCTIONS,
414
- progress = gr.Progress(track_tqdm=True)):
 
415
 
416
  progress(0.1, "Loading PDF from s3")
417
 
418
  if in_doc_files:
419
- doc_file_name_no_extension_textbox, doc_file_name_with_extension_textbox, doc_full_file_name_textbox, doc_file_name_textbox_list, total_pdf_page_count = get_input_file_names(in_doc_files)
420
 
421
  if pdf_filename == doc_file_name_no_extension_textbox:
422
  print("Existing loaded PDF file has same name as file from S3")
423
  doc_file_name_no_extension_textbox = pdf_filename
424
  downloaded_pdf_file_location = in_doc_files
425
  else:
426
- downloaded_pdf_file_location, doc_file_name_no_extension_textbox = load_pdf_job_file_from_s3(load_s3_jobs_input_loc, pdf_filename, local_output_dir, s3_bucket_name, RUN_AWS_FUNCTIONS=RUN_AWS_FUNCTIONS)
427
 
428
- doc_file_name_no_extension_textbox, doc_file_name_with_extension_textbox, doc_full_file_name_textbox, doc_file_name_textbox_list, total_pdf_page_count = get_input_file_names(downloaded_pdf_file_location)
429
- else:
430
- downloaded_pdf_file_location, doc_file_name_no_extension_textbox = load_pdf_job_file_from_s3(load_s3_jobs_input_loc, pdf_filename, local_output_dir, s3_bucket_name, RUN_AWS_FUNCTIONS=RUN_AWS_FUNCTIONS)
431
 
432
- doc_file_name_no_extension_textbox, doc_file_name_with_extension_textbox, doc_full_file_name_textbox, doc_file_name_textbox_list, total_pdf_page_count = get_input_file_names(downloaded_pdf_file_location)
433
 
434
- return downloaded_pdf_file_location, doc_file_name_no_extension_textbox, doc_file_name_with_extension_textbox, doc_full_file_name_textbox, doc_file_name_textbox_list, total_pdf_page_count
435
 
436
  def poll_whole_document_textract_analysis_progress_and_download(
437
- job_id:str,
438
- job_type_dropdown:str,
439
  s3_output_prefix: str,
440
- pdf_filename:str,
441
- job_df:pd.DataFrame,
442
  s3_bucket_name: str = TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_BUCKET,
443
  local_output_dir: str = OUTPUT_FOLDER,
444
- load_s3_jobs_loc:str=TEXTRACT_JOBS_S3_LOC,
445
- load_local_jobs_loc:str=TEXTRACT_JOBS_LOCAL_LOC,
446
- aws_region: str = AWS_REGION, # Optional: specify region if not default
447
- load_jobs_from_s3:str = LOAD_PREVIOUS_TEXTRACT_JOBS_S3,
448
  poll_interval_seconds: int = 1,
449
- max_polling_attempts: int = 1, # ~10 minutes total wait time):
450
  DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS: int = DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS,
451
- progress = gr.Progress(track_tqdm=True)
452
- ):
453
- '''
454
  Poll AWS for the status of a Textract API job. Return status, and if finished, combine and download results into a locally-stored json file for further processing by the app.
455
- '''
456
 
457
  progress(0.1, "Querying AWS Textract for status of document analysis job")
458
 
459
  if job_id:
460
  # Initialize boto3 clients
461
  session = boto3.Session(region_name=aws_region)
462
- s3_client = session.client('s3')
463
- textract_client = session.client('textract')
464
 
465
  # --- 3. Poll for Job Completion ---
466
- job_status = 'IN_PROGRESS'
467
  attempts = 0
468
 
469
  message = "Polling Textract for job completion status..."
470
  print(message)
471
- #logging.info("Polling Textract for job completion status...")
472
 
473
  # Update Textract document history df
474
  try:
475
- job_df = load_in_textract_job_details(load_s3_jobs=load_jobs_from_s3,
476
- load_s3_jobs_loc=load_s3_jobs_loc,
477
- load_local_jobs_loc=load_local_jobs_loc)
 
 
478
  except Exception as e:
479
- #logging.error(f"Failed to update job details dataframe: {e}")
480
  print(f"Failed to update job details dataframe: {e}")
481
- #raise
482
 
483
- while job_status == 'IN_PROGRESS' and attempts <= max_polling_attempts:
484
  attempts += 1
485
  try:
486
- if job_type_dropdown=="document_analysis":
487
  response = textract_client.get_document_analysis(JobId=job_id)
488
- job_status = return_job_status(job_id, response, attempts, poll_interval_seconds, max_polling_attempts)
489
- elif job_type_dropdown=="document_text_detection":
490
  response = textract_client.get_document_text_detection(JobId=job_id)
491
- job_status = return_job_status(job_id, response, attempts, poll_interval_seconds, max_polling_attempts)
492
  else:
493
- error = f"Unknown job type, cannot poll job"
494
  print(error)
495
- logging.error(f"Invalid JobId: {job_id}. This might happen if the job expired (older than {DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS} days) or never existed.")
496
- raise Exception(error_message)
497
 
498
  except textract_client.exceptions.InvalidJobIdException:
499
  error_message = f"Invalid JobId: {job_id}. This might happen if the job expired (older than {DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS} days) or never existed."
@@ -501,17 +662,19 @@ def poll_whole_document_textract_analysis_progress_and_download(
501
  logging.error(error_message)
502
  raise Exception(error_message)
503
  except Exception as e:
504
- error_message = f"Error while polling Textract status for job {job_id}: {e}"
 
 
505
  print(error_message)
506
  logging.error(error_message)
507
  raise Exception(error_message)
508
 
509
  downloaded_file_path = None
510
- if job_status == 'SUCCEEDED':
511
- #raise TimeoutError(f"Textract job {job_id} did not complete successfully within the polling limit.")
512
- # 3b - Replace PDF file name if it exists in the job dataframe
513
 
514
- progress(0.5, "Document analysis task outputs found. Downloading from S3")
515
 
516
  # If job_df is not empty
517
 
@@ -520,7 +683,9 @@ def poll_whole_document_textract_analysis_progress_and_download(
520
 
521
  if not job_df.empty:
522
  if "file_name" in job_df.columns:
523
- matching_job_id_file_names = job_df.loc[job_df["job_id"] == job_id, "file_name"]
 
 
524
 
525
  if pdf_filename and not matching_job_id_file_names.empty:
526
  if pdf_filename == matching_job_id_file_names.iloc[0]:
@@ -537,58 +702,77 @@ def poll_whole_document_textract_analysis_progress_and_download(
537
  # Textract typically creates output under s3_output_prefix/job_id/
538
  # There might be multiple JSON files if pagination occurred during writing.
539
  # Usually, for smaller docs, there's one file, often named '1'.
540
- # For robust handling, list objects and find the JSON(s).
541
 
542
- s3_output_key_prefix = os.path.join(s3_output_prefix, job_id).replace("\\", "/") + "/"
543
- logging.info(f"Searching for output files in s3://{s3_bucket_name}/{s3_output_key_prefix}")
544
 
545
  try:
546
- downloaded_file_path = download_textract_job_files(s3_client,
547
- s3_bucket_name,
548
- s3_output_key_prefix,
549
- pdf_filename,
550
- job_id,
551
- local_output_dir)
 
 
552
 
553
  except Exception as e:
554
- #logging.error(f"Failed to download or process Textract output from S3: {e}")
555
  print(f"Failed to download or process Textract output from S3: {e}")
556
  raise
557
 
558
  else:
559
- raise Exception("No Job ID provided.")
560
-
561
  output_pdf_filename = get_file_name_without_type(pdf_filename)
562
 
563
  return downloaded_file_path, job_status, job_df, output_pdf_filename
564
 
565
- def load_in_textract_job_details(load_s3_jobs:str=LOAD_PREVIOUS_TEXTRACT_JOBS_S3,
566
- load_s3_jobs_loc:str=TEXTRACT_JOBS_S3_LOC,
567
- load_local_jobs_loc:str=TEXTRACT_JOBS_LOCAL_LOC,
568
- document_redaction_bucket:str=DOCUMENT_REDACTION_BUCKET,
569
- aws_region:str=AWS_REGION,
570
- DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS:int=DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS):
571
- '''
 
 
 
572
  Load in a dataframe of jobs previously submitted to the Textract API service.
573
- '''
574
- job_df = pd.DataFrame(columns=['job_id','file_name','job_type','signature_extraction','job_date_time'])
575
 
576
  # Initialize boto3 clients
577
  session = boto3.Session(region_name=aws_region)
578
- s3_client = session.client('s3')
 
 
579
 
580
- local_output_path = f'{load_local_jobs_loc}/textract_document_jobs.csv'
 
581
 
582
- if load_s3_jobs == 'True':
583
- s3_output_key = f'{load_s3_jobs_loc}/textract_document_jobs.csv'
584
-
585
  try:
586
  s3_client.head_object(Bucket=document_redaction_bucket, Key=s3_output_key)
587
- #print(f"File exists. Downloading from '{s3_output_key}' to '{local_output_path}'...")
588
- s3_client.download_file(document_redaction_bucket, s3_output_key, local_output_path)
589
- #print("Download successful.")
 
 
590
  except ClientError as e:
591
- if e.response['Error']['Code'] == '404':
592
  print("Log file does not exist in S3.")
593
  else:
594
  print(f"Unexpected error occurred: {e}")
@@ -602,22 +786,37 @@ def load_in_textract_job_details(load_s3_jobs:str=LOAD_PREVIOUS_TEXTRACT_JOBS_S3
602
  job_df = pd.read_csv(local_output_path)
603
 
604
  if "job_date_time" in job_df.columns:
605
- job_df["job_date_time"] = pd.to_datetime(job_df["job_date_time"], errors='coerce')
 
 
606
  # Keep only jobs that have been completed in the last 'DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS' days
607
- cutoff_time = pd.Timestamp.now() - pd.Timedelta(days=DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS)
608
- job_df = job_df.loc[job_df["job_date_time"] > cutoff_time,:]
 
 
609
 
610
  try:
611
- job_df = job_df[['job_id','file_name','job_type','signature_extraction','job_date_time']]
612
  except Exception as e:
613
- print("Could not find one or more columns in Textract whole document list dataframe:", e)
 
 
 
614
 
615
  return job_df
616
 
617
- def download_textract_output(job_id:str,
618
- output_bucket:str,
619
- output_prefix:str,
620
- local_folder:str):
621
  """
622
  Checks the status of a Textract job and downloads the output ZIP file if the job is complete.
623
 
@@ -626,23 +825,26 @@ def download_textract_output(job_id:str,
626
  :param output_prefix: The prefix (folder path) in S3 where the output file is stored.
627
  :param local_folder: The local directory where the ZIP file should be saved.
628
  """
629
- textract_client = boto3.client('textract')
630
- s3_client = boto3.client('s3')
631
 
632
  # Check job status
633
  while True:
634
  response = textract_client.get_document_analysis(JobId=job_id)
635
- status = response['JobStatus']
636
-
637
- if status == 'SUCCEEDED':
638
  print("Job completed successfully.")
639
  break
640
- elif status == 'FAILED':
641
- print("Job failed:", response.get("StatusMessage", "No error message provided."))
 
 
 
642
  return
643
  else:
644
  print(f"Job is still {status}.")
645
- #time.sleep(10) # Wait before checking again
646
 
647
  # Find output ZIP file in S3
648
  output_file_key = f"{output_prefix}/{job_id}.zip"
@@ -655,8 +857,12 @@ def download_textract_output(job_id:str,
655
  except Exception as e:
656
  print(f"Error downloading file: {e}")
657
 
 
658
  def check_textract_outputs_exist(textract_output_found_checkbox):
659
- if textract_output_found_checkbox == True:
660
- print("Textract outputs found")
661
- return
662
- else: raise Exception("Relevant Textract outputs not found. Please ensure you have selected to correct results output and you have uploaded the relevant document file in 'Choose document or image file...' above")
1
+ import datetime
 
 
2
  import json
3
  import logging
4
+ import os
+ import time
5
+ from io import StringIO
6
+ from typing import List
7
+
8
+ import boto3
9
  import gradio as gr
10
+ import pandas as pd
11
+ import pymupdf
12
+ from botocore.exceptions import (
13
+ ClientError,
14
+ NoCredentialsError,
15
+ PartialCredentialsError,
16
+ TokenRetrievalError,
17
+ )
18
  from gradio import FileData
19
+
20
  from tools.aws_functions import download_file_from_s3
21
+ from tools.config import (
22
+ AWS_REGION,
23
+ DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS,
24
+ DOCUMENT_REDACTION_BUCKET,
25
+ INPUT_FOLDER,
26
+ LOAD_PREVIOUS_TEXTRACT_JOBS_S3,
27
+ OUTPUT_FOLDER,
28
+ RUN_AWS_FUNCTIONS,
29
+ TEXTRACT_JOBS_LOCAL_LOC,
30
+ TEXTRACT_JOBS_S3_LOC,
31
+ TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_BUCKET,
32
+ )
33
  from tools.file_conversion import get_input_file_names
34
  from tools.helper_functions import get_file_name_without_type
35
 
36
  DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS = int(DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS)
37
 
38
+
39
  def analyse_document_with_textract_api(
40
  local_pdf_path: str,
41
  s3_input_prefix: str,
42
  s3_output_prefix: str,
43
+ job_df: pd.DataFrame,
44
  s3_bucket_name: str = TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_BUCKET,
45
+ local_output_dir: str = OUTPUT_FOLDER,
46
+ handwrite_signature_checkbox: List[str] = list(),
47
+ successful_job_number: int = 0,
48
+ total_document_page_count: int = 1,
49
  general_s3_bucket_name: str = DOCUMENT_REDACTION_BUCKET,
50
+ aws_region: str = AWS_REGION, # Optional: specify region if not default
51
+ ):
52
  """
53
  Uploads a local PDF to S3 and starts an asynchronous Textract analysis job (text detection plus any selected feature types).
54
  Polling for completion and downloading the output JSON from S3 are handled by the separate helpers later in this module.
 
60
  s3_output_prefix (str): S3 prefix (folder) where Textract should write output.
61
  job_df (pd.DataFrame): Dataframe containing information from previous Textract API calls.
62
  s3_bucket_name (str, optional): S3 bucket in which to save API call outputs.
63
+ local_output_dir (str, optional): Local directory to save the downloaded JSON results.
64
  handwrite_signature_checkbox (List[str], optional): List of feature types to extract from the document.
65
  successful_job_number (int): The number of successful jobs that have been submitted in this session.
66
  total_document_page_count (int): The number of pages in the document
 
89
  file_extension = os.path.splitext(local_pdf_path)[1].lower()
90
 
91
  # Load pdf to get page count if not provided
92
+ if not total_document_page_count and file_extension in [".pdf"]:
93
  print("Page count not provided. Loading PDF to get page count")
94
  try:
95
  pymupdf_doc = pymupdf.open(local_pdf_path)
 
99
  except Exception as e:
100
  print("Failed to load PDF to get page count:", e, "setting page count to 1")
101
  total_document_page_count = 1
102
+ # raise Exception(f"Failed to load PDF to get page count: {e}")
103
  else:
104
  total_document_page_count = 1
105
 
 
107
  os.makedirs(local_output_dir)
108
  log_message = f"Created local output directory: {local_output_dir}"
109
  print(log_message)
110
+ # logging.info(log_message)
111
 
112
  # Initialize boto3 clients
113
  session = boto3.Session(region_name=aws_region)
114
+ s3_client = session.client("s3")
115
+ textract_client = session.client("textract")
116
 
117
  # --- 1. Upload PDF to S3 ---
118
  pdf_filename = os.path.basename(local_pdf_path)
119
+ s3_input_key = os.path.join(s3_input_prefix, pdf_filename).replace(
120
+ "\\", "/"
121
+ ) # Ensure forward slashes for S3
122
 
123
+ log_message = (
124
+ f"Uploading '{local_pdf_path}' to 's3://{s3_bucket_name}/{s3_input_key}'..."
125
+ )
126
  print(log_message)
127
+ # logging.info(log_message)
128
  try:
129
  s3_client.upload_file(local_pdf_path, s3_bucket_name, s3_input_key)
130
  log_message = "Upload successful."
131
  print(log_message)
132
+ # logging.info(log_message)
133
  except Exception as e:
134
  log_message = f"Failed to upload PDF to S3: {e}"
135
  print(log_message)
136
+ # logging.error(log_message)
137
  raise
138
 
139
  # Filter job_df to include rows only where the analysis date is after the current date - DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS
140
  if not job_df.empty:
141
+ job_df = job_df.loc[
142
+ job_df["job_date_time"]
143
+ > (
144
+ datetime.datetime.now()
145
+ - datetime.timedelta(days=DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS)
146
+ ),
147
+ :,
148
+ ]
149
 
150
  # If job_df is not empty
151
+ if not job_df.empty:
152
 
153
  if "file_name" in job_df.columns:
154
+ matching_job_id_file_names = job_df.loc[
155
+ (job_df["file_name"] == pdf_filename)
156
+ & (
157
+ job_df["signature_extraction"].astype(str)
158
+ == str(handwrite_signature_checkbox)
159
+ ),
160
+ "file_name",
161
+ ]
162
+ matching_job_id_file_names_dates = job_df.loc[
163
+ (job_df["file_name"] == pdf_filename)
164
+ & (
165
+ job_df["signature_extraction"].astype(str)
166
+ == str(handwrite_signature_checkbox)
167
+ ),
168
+ "job_date_time",
169
+ ]
170
+ matching_job_id = job_df.loc[
171
+ (job_df["file_name"] == pdf_filename)
172
+ & (
173
+ job_df["signature_extraction"].astype(str)
174
+ == str(handwrite_signature_checkbox)
175
+ ),
176
+ "job_id",
177
+ ]
178
+ matching_handwrite_signature = job_df.loc[
179
+ (job_df["file_name"] == pdf_filename)
180
+ & (
181
+ job_df["signature_extraction"].astype(str)
182
+ == str(handwrite_signature_checkbox)
183
+ ),
184
+ "signature_extraction",
185
+ ]
186
 
187
  if len(matching_job_id) > 0:
188
  pass
189
  else:
190
  matching_job_id = "unknown_job_id"
191
 
192
+ if (
193
+ len(matching_job_id_file_names) > 0
194
+ and len(matching_handwrite_signature) > 0
195
+ ):
196
  out_message = f"Existing Textract outputs found for file {pdf_filename} from date {matching_job_id_file_names_dates.iloc[0]}. No need to re-analyse. Please download existing results from the list with job ID {matching_job_id.iloc[0]}"
197
  gr.Warning(out_message)
198
  raise Exception(out_message)
 
202
  print(message)
203
 
204
  try:
205
+ if (
206
+ "Extract signatures" in handwrite_signature_checkbox
207
+ or "Extract forms" in handwrite_signature_checkbox
208
+ or "Extract layout" in handwrite_signature_checkbox
209
+ or "Extract tables" in handwrite_signature_checkbox
210
+ ):
211
  feature_types = list()
212
+ if "Extract signatures" in handwrite_signature_checkbox:
213
+ feature_types.append("SIGNATURES")
214
  if "Extract forms" in handwrite_signature_checkbox:
215
+ feature_types.append("FORMS")
216
  if "Extract layout" in handwrite_signature_checkbox:
217
+ feature_types.append("LAYOUT")
218
  if "Extract tables" in handwrite_signature_checkbox:
219
+ feature_types.append("TABLES")
220
  response = textract_client.start_document_analysis(
221
  DocumentLocation={
222
+ "S3Object": {"Bucket": s3_bucket_name, "Name": s3_input_key}
 
 
 
223
  },
224
+ FeatureTypes=feature_types,  # Analyze for whichever of signatures, forms, layout and tables were selected
225
+ OutputConfig={"S3Bucket": s3_bucket_name, "S3Prefix": s3_output_prefix},
 
 
 
226
  )
227
+ job_type = "document_analysis"
228
+
229
+ if (
230
+ "Extract signatures" not in handwrite_signature_checkbox
231
+ and "Extract forms" not in handwrite_signature_checkbox
232
+ and "Extract layout" not in handwrite_signature_checkbox
233
+ and "Extract tables" not in handwrite_signature_checkbox
234
+ ):
235
  response = textract_client.start_document_text_detection(
236
  DocumentLocation={
237
+ "S3Object": {"Bucket": s3_bucket_name, "Name": s3_input_key}
 
 
 
238
  },
239
+ OutputConfig={"S3Bucket": s3_bucket_name, "S3Prefix": s3_output_prefix},
 
 
 
240
  )
241
+ job_type = "document_text_detection"
242
 
243
+ job_id = response["JobId"]
244
  print(f"Textract job started with JobId: {job_id}")
245
 
246
  # Prepare CSV in memory
247
  log_csv_key_location = f"{s3_output_prefix}/textract_document_jobs.csv"
 
248
 
249
+ # csv_buffer = StringIO()  # in-memory buffer no longer needed; the log CSV is appended to a local file and uploaded below
250
+ log_df = pd.DataFrame(
251
+ [
252
+ {
253
+ "job_id": job_id,
254
+ "file_name": pdf_filename,
255
+ "job_type": job_type,
256
+ "signature_extraction": handwrite_signature_checkbox,
257
+ "job_date_time": datetime.datetime.now().strftime(
258
+ "%Y-%m-%d %H:%M:%S"
259
+ ),
260
+ }
261
+ ]
262
+ )
263
 
264
  # File path
265
  log_file_path = os.path.join(local_output_dir, "textract_document_jobs.csv")
266
+ log_file_path_job_id = os.path.join(
267
+ local_output_dir, pdf_filename + "_textract_document_jobs_job_id.txt"
268
+ )
269
 
270
  # Write latest job ID to local text file
271
+ with open(log_file_path_job_id, "w") as f:
272
  f.write(job_id)
273
 
274
  # Check if file exists
275
  file_exists = os.path.exists(log_file_path)
276
 
277
  # Append to CSV if it exists, otherwise write with header
278
+ log_df.to_csv(log_file_path, mode="a", index=False, header=not file_exists)
279
+
280
+ # log_df.to_csv(csv_buffer)
281
 
282
  # Upload the file
283
+ s3_client.upload_file(
284
+ log_file_path, general_s3_bucket_name, log_csv_key_location
285
+ )
286
 
287
  # Upload to S3 (overwrite existing file)
288
+ # s3_client.put_object(Bucket=general_s3_bucket_name, Key=log_csv_key_location, Body=csv_buffer.getvalue())
289
  print(f"Job ID written to {log_csv_key_location}")
290
+ # logging.info(f"Job ID written to s3://{s3_bucket_name}/{s3_output_prefix}/textract_document_jobs.csv")
291
 
292
  except Exception as e:
293
  error = f"Failed to start Textract job: {e}"
294
  print(error)
295
+ # logging.error(error)
296
  raise
297
 
298
  successful_job_number += 1
299
  total_number_of_textract_page_calls = total_document_page_count
300
 
301
+ return (
302
+ f"Textract analysis job submitted, job ID:{job_id}",
303
+ job_id,
304
+ job_type,
305
+ successful_job_number,
306
+ is_a_textract_api_call,
307
+ total_number_of_textract_page_calls,
308
+ task_textbox,
309
+ )
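For orientation, a minimal sketch of how this submission helper might be driven outside the Gradio app. The import path, bucket, prefixes and file names below are hypothetical placeholders, and AWS credentials with S3 and Textract permissions are assumed to be configured.

import pandas as pd

# Hypothetical import path; adjust to wherever this module lives in the repository.
from tools.textract_batch_call import analyse_document_with_textract_api

# Start with an empty job history, using the same columns the module keeps in its job log.
job_df = pd.DataFrame(
    columns=["job_id", "file_name", "job_type", "signature_extraction", "job_date_time"]
)

(
    message,
    job_id,
    job_type,
    successful_job_number,
    is_a_textract_api_call,
    page_calls,
    task,
) = analyse_document_with_textract_api(
    local_pdf_path="example_data/example.pdf",  # hypothetical input file
    s3_input_prefix="textract/input",  # hypothetical prefix
    s3_output_prefix="textract/output",  # hypothetical prefix
    job_df=job_df,
    s3_bucket_name="my-textract-bucket",  # hypothetical bucket
    handwrite_signature_checkbox=["Extract signatures"],
)
print(message)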
310
+
311
 
312
+ def return_job_status(
313
+ job_id: str,
314
+ response: dict,
315
+ attempts: int,
316
+ poll_interval_seconds: int = 0,
317
+ max_polling_attempts: int = 1, # ~10 minutes total wait time
318
+ ):
319
+ """
320
  Interprets the status of an asynchronous Textract document analysis job from a response already fetched by the caller.
321
  This function checks the job status from the provided response and logs relevant information or errors.
322
 
 
332
 
333
  Raises:
334
  Exception: If the Textract job status is 'FAILED' or 'PARTIAL_SUCCESS', or if an unexpected status is encountered.
335
+ """
336
 
337
+ job_status = response["JobStatus"]
338
+ logging.info(
339
+ f"Polling attempt {attempts}/{max_polling_attempts}. Job status: {job_status}"
340
+ )
341
 
342
+ if job_status == "IN_PROGRESS":
343
  pass
344
+ # time.sleep(poll_interval_seconds)
345
+ elif job_status == "SUCCEEDED":
346
  logging.info("Textract job succeeded.")
347
+ elif job_status in ["FAILED", "PARTIAL_SUCCESS"]:
348
+ status_message = response.get("StatusMessage", "No status message provided.")
349
+ warnings = response.get("Warnings", [])
350
+ logging.error(
351
+ f"Textract job ended with status: {job_status}. Message: {status_message}"
352
+ )
353
  if warnings:
354
  logging.warning(f"Warnings: {warnings}")
355
  # Decide if PARTIAL_SUCCESS should proceed or raise error
356
  # For simplicity here, we raise for both FAILED and PARTIAL_SUCCESS
357
+ raise Exception(
358
+ f"Textract job {job_id} failed or partially failed. Status: {job_status}. Message: {status_message}"
359
+ )
360
  else:
361
  # Should not happen based on documentation, but handle defensively
362
  raise Exception(f"Unexpected Textract job status: {job_status}")
363
+
364
  return job_status
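To make the branching above concrete, a small illustration of how this helper reacts to the possible JobStatus values; the response dicts are hand-built stand-ins for what get_document_analysis would return, and the import path is hypothetical.

from tools.textract_batch_call import return_job_status  # hypothetical module path

# A successful poll simply returns the status string.
status = return_job_status(
    job_id="job-123",  # hypothetical job ID
    response={"JobStatus": "SUCCEEDED"},
    attempts=1,
)
assert status == "SUCCEEDED"

# FAILED (or PARTIAL_SUCCESS) raises, carrying the StatusMessage through to the caller.
try:
    return_job_status(
        job_id="job-123",
        response={"JobStatus": "FAILED", "StatusMessage": "Document too large"},
        attempts=1,
    )
except Exception as e:
    print(e)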
365
 
366
+
367
+ def download_textract_job_files(
368
+ s3_client: str,
369
+ s3_bucket_name: str,
370
+ s3_output_key_prefix: str,
371
+ pdf_filename: str,
372
+ job_id: str,
373
+ local_output_dir: str,
374
+ ):
375
+ """
376
  Download and combine selected job files from the AWS Textract service.
377
+ """
378
 
379
+ # print("s3_output_key_prefix at download:", s3_output_key_prefix)
380
 
381
  list_response = s3_client.list_objects_v2(
382
+ Bucket=s3_bucket_name, Prefix=s3_output_key_prefix
 
383
  )
384
 
385
+ output_files = list_response.get("Contents", [])
386
  if not output_files:
387
  # Sometimes Textract might take a moment longer to write the output after SUCCEEDED status
388
+ # logging.warning("No output files found immediately after job success. Waiting briefly and retrying list...")
389
+ # time.sleep(5)
390
  list_response = s3_client.list_objects_v2(
391
+ Bucket=s3_bucket_name, Prefix=s3_output_key_prefix
 
392
  )
393
+ output_files = list_response.get("Contents", [])
394
 
395
  if not output_files:
396
+ logging.error(
397
+ f"No output files found in s3://{s3_bucket_name}/{s3_output_key_prefix}"
398
+ )
399
  # You could alternatively try getting results via get_document_analysis pagination here
400
  # but sticking to the request to download from S3 output path.
401
+ raise FileNotFoundError(
402
+ f"Textract output files not found in S3 path: s3://{s3_bucket_name}/{s3_output_key_prefix}"
403
+ )
404
 
405
  # Usually, we only need the first/main JSON output file(s)
406
  # For simplicity, download the first one found. A more complex scenario might merge multiple files.
407
  # Filter out potential directory markers if any key ends with '/'
408
  json_files_to_download = [
409
+ f
410
+ for f in output_files
411
+ if f["Key"] != s3_output_key_prefix
412
+ and not f["Key"].endswith("/")
413
+ and "access_check" not in f["Key"]
414
+ ]
415
 
416
+ # print("json_files_to_download:", json_files_to_download)
417
 
418
  if not json_files_to_download:
419
  error = f"No JSON files found (only prefix marker?) in s3://{s3_bucket_name}/{s3_output_key_prefix}"
420
  print(error)
421
+ # logging.error(error)
422
  raise FileNotFoundError(error)
423
 
424
  combined_blocks = []
425
 
426
+ for f in sorted(
427
+ json_files_to_download, key=lambda x: x["Key"]
428
+ ): # Optional: sort to ensure consistent order
429
+ obj = s3_client.get_object(Bucket=s3_bucket_name, Key=f["Key"])
430
+ data = json.loads(obj["Body"].read())
431
+
432
  # Assuming Textract-style output with a "Blocks" key
433
  if "Blocks" in data:
434
  combined_blocks.extend(data["Blocks"])
 
438
  # Build final combined JSON structure
439
  combined_output = {
440
  "DocumentMetadata": {
441
+ "Pages": len(set(block.get("Page", 1) for block in combined_blocks))
442
  },
443
  "Blocks": combined_blocks,
444
+ "JobStatus": "SUCCEEDED",
445
  }
446
 
447
  output_filename_base = os.path.basename(pdf_filename)
 
449
  local_output_filename = f"{output_filename_base_no_ext}_textract.json"
450
  local_output_path = os.path.join(local_output_dir, local_output_filename)
451
 
452
+ with open(local_output_path, "w") as f:
453
  json.dump(combined_output, f)
454
 
455
  print(f"Combined Textract output written to {local_output_path}")
 
460
  downloaded_file_path = local_output_path
461
 
462
  # Log if multiple files were found, as user might need to handle them
463
+ # if len(json_files_to_download) > 1:
464
  # logging.warning(f"Multiple output files found in S3 output location. Downloaded the first: '{s3_output_key}'. Other files exist.")
465
 
466
  return downloaded_file_path
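The combined file written above follows the usual Textract response shape, so it can be inspected with nothing but the standard library. A small sketch, assuming a <name>_textract.json file produced by this helper exists at the (hypothetical) path shown.

import json
from collections import Counter

# Hypothetical local path written by download_textract_job_files.
with open("output/example_textract.json") as f:
    combined = json.load(f)

print("Pages:", combined["DocumentMetadata"]["Pages"])

# Tally block types per page as a quick sanity check on the merge.
per_page = Counter(
    (block.get("Page", 1), block.get("BlockType", "UNKNOWN"))
    for block in combined["Blocks"]
)
for (page, block_type), count in sorted(per_page.items()):
    print(f"page {page}: {count} x {block_type}")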
467
 
468
+
469
+ def check_for_provided_job_id(job_id: str):
470
  if not job_id:
471
+ raise Exception("Please provide a job ID.")
472
  return
473
 
474
+
475
  def load_pdf_job_file_from_s3(
476
  load_s3_jobs_input_loc,
477
  pdf_filename,
478
  local_output_dir,
479
  s3_bucket_name,
480
+ RUN_AWS_FUNCTIONS=RUN_AWS_FUNCTIONS,
481
+ ):
482
 
483
  try:
484
+ pdf_file_location = ""
485
+ doc_file_name_no_extension_textbox = ""
486
 
487
+ s3_input_key_prefix = os.path.join(
488
+ load_s3_jobs_input_loc, pdf_filename
489
+ ).replace("\\", "/")
490
  s3_input_key_prefix = s3_input_key_prefix + ".pdf"
491
+
492
  local_input_file_path = os.path.join(local_output_dir, pdf_filename)
493
  local_input_file_path = local_input_file_path + ".pdf"
494
 
495
+ download_file_from_s3(
496
+ s3_bucket_name,
497
+ s3_input_key_prefix,
498
+ local_input_file_path,
499
+ RUN_AWS_FUNCTIONS=RUN_AWS_FUNCTIONS,
500
+ )
501
+
502
  pdf_file_location = [local_input_file_path]
503
  doc_file_name_no_extension_textbox = get_file_name_without_type(pdf_filename)
504
  except Exception as e:
505
+ print("Could not download PDF job file from S3 due to:", e)
506
 
507
  return pdf_file_location, doc_file_name_no_extension_textbox
508
 
509
+
510
+ def replace_existing_pdf_input_for_whole_document_outputs(
511
+ load_s3_jobs_input_loc: str,
512
+ pdf_filename: str,
513
+ local_output_dir: str,
514
+ s3_bucket_name: str,
515
+ in_doc_files: FileData = [],
516
+ input_folder: str = INPUT_FOLDER,
517
  RUN_AWS_FUNCTIONS=RUN_AWS_FUNCTIONS,
518
+ progress=gr.Progress(track_tqdm=True),
519
+ ):
520
 
521
  progress(0.1, "Loading PDF from s3")
522
 
523
  if in_doc_files:
524
+ (
525
+ doc_file_name_no_extension_textbox,
526
+ doc_file_name_with_extension_textbox,
527
+ doc_full_file_name_textbox,
528
+ doc_file_name_textbox_list,
529
+ total_pdf_page_count,
530
+ ) = get_input_file_names(in_doc_files)
531
 
532
  if pdf_filename == doc_file_name_no_extension_textbox:
533
  print("Existing loaded PDF file has same name as file from S3")
534
  doc_file_name_no_extension_textbox = pdf_filename
535
  downloaded_pdf_file_location = in_doc_files
536
  else:
537
+ downloaded_pdf_file_location, doc_file_name_no_extension_textbox = (
538
+ load_pdf_job_file_from_s3(
539
+ load_s3_jobs_input_loc,
540
+ pdf_filename,
541
+ local_output_dir,
542
+ s3_bucket_name,
543
+ RUN_AWS_FUNCTIONS=RUN_AWS_FUNCTIONS,
544
+ )
545
+ )
546
 
547
+ (
548
+ doc_file_name_no_extension_textbox,
549
+ doc_file_name_with_extension_textbox,
550
+ doc_full_file_name_textbox,
551
+ doc_file_name_textbox_list,
552
+ total_pdf_page_count,
553
+ ) = get_input_file_names(downloaded_pdf_file_location)
554
+ else:
555
+ downloaded_pdf_file_location, doc_file_name_no_extension_textbox = (
556
+ load_pdf_job_file_from_s3(
557
+ load_s3_jobs_input_loc,
558
+ pdf_filename,
559
+ local_output_dir,
560
+ s3_bucket_name,
561
+ RUN_AWS_FUNCTIONS=RUN_AWS_FUNCTIONS,
562
+ )
563
+ )
564
 
565
+ (
566
+ doc_file_name_no_extension_textbox,
567
+ doc_file_name_with_extension_textbox,
568
+ doc_full_file_name_textbox,
569
+ doc_file_name_textbox_list,
570
+ total_pdf_page_count,
571
+ ) = get_input_file_names(downloaded_pdf_file_location)
572
+
573
+ return (
574
+ downloaded_pdf_file_location,
575
+ doc_file_name_no_extension_textbox,
576
+ doc_file_name_with_extension_textbox,
577
+ doc_full_file_name_textbox,
578
+ doc_file_name_textbox_list,
579
+ total_pdf_page_count,
580
+ )
581
 
 
582
 
583
  def poll_whole_document_textract_analysis_progress_and_download(
584
+ job_id: str,
585
+ job_type_dropdown: str,
586
  s3_output_prefix: str,
587
+ pdf_filename: str,
588
+ job_df: pd.DataFrame,
589
  s3_bucket_name: str = TEXTRACT_WHOLE_DOCUMENT_ANALYSIS_BUCKET,
590
  local_output_dir: str = OUTPUT_FOLDER,
591
+ load_s3_jobs_loc: str = TEXTRACT_JOBS_S3_LOC,
592
+ load_local_jobs_loc: str = TEXTRACT_JOBS_LOCAL_LOC,
593
+ aws_region: str = AWS_REGION, # Optional: specify region if not default
594
+ load_jobs_from_s3: str = LOAD_PREVIOUS_TEXTRACT_JOBS_S3,
595
  poll_interval_seconds: int = 1,
596
  max_polling_attempts: int = 1,  # ~10 minutes total wait time
597
  DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS: int = DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS,
598
+ progress=gr.Progress(track_tqdm=True),
599
+ ):
600
+ """
601
  Poll AWS for the status of a Textract API job. Return status, and if finished, combine and download results into a locally-stored json file for further processing by the app.
602
+ """
603
 
604
  progress(0.1, "Querying AWS Textract for status of document analysis job")
605
 
606
  if job_id:
607
  # Initialize boto3 clients
608
  session = boto3.Session(region_name=aws_region)
609
+ s3_client = session.client("s3")
610
+ textract_client = session.client("textract")
611
 
612
  # --- 3. Poll for Job Completion ---
613
+ job_status = "IN_PROGRESS"
614
  attempts = 0
615
 
616
  message = "Polling Textract for job completion status..."
617
  print(message)
618
+ # logging.info("Polling Textract for job completion status...")
619
 
620
  # Update Textract document history df
621
  try:
622
+ job_df = load_in_textract_job_details(
623
+ load_s3_jobs=load_jobs_from_s3,
624
+ load_s3_jobs_loc=load_s3_jobs_loc,
625
+ load_local_jobs_loc=load_local_jobs_loc,
626
+ )
627
  except Exception as e:
628
+ # logging.error(f"Failed to update job details dataframe: {e}")
629
  print(f"Failed to update job details dataframe: {e}")
630
+ # raise
631
 
632
+ while job_status == "IN_PROGRESS" and attempts <= max_polling_attempts:
633
  attempts += 1
634
  try:
635
+ if job_type_dropdown == "document_analysis":
636
  response = textract_client.get_document_analysis(JobId=job_id)
637
+ job_status = return_job_status(
638
+ job_id,
639
+ response,
640
+ attempts,
641
+ poll_interval_seconds,
642
+ max_polling_attempts,
643
+ )
644
+ elif job_type_dropdown == "document_text_detection":
645
  response = textract_client.get_document_text_detection(JobId=job_id)
646
+ job_status = return_job_status(
647
+ job_id,
648
+ response,
649
+ attempts,
650
+ poll_interval_seconds,
651
+ max_polling_attempts,
652
+ )
653
  else:
654
+ error = "Unknown job type, cannot poll job"
655
  print(error)
656
+ logging.error(error)
657
+ raise Exception(error)
658
 
659
  except textract_client.exceptions.InvalidJobIdException:
660
  error_message = f"Invalid JobId: {job_id}. This might happen if the job expired (older than {DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS} days) or never existed."
 
662
  logging.error(error_message)
663
  raise Exception(error_message)
664
  except Exception as e:
665
+ error_message = (
666
+ f"Error while polling Textract status for job {job_id}: {e}"
667
+ )
668
  print(error_message)
669
  logging.error(error_message)
670
  raise Exception(error_message)
671
 
672
  downloaded_file_path = None
673
+ if job_status == "SUCCEEDED":
674
+ # raise TimeoutError(f"Textract job {job_id} did not complete successfully within the polling limit.")
675
+ # 3b - Replace PDF file name if it exists in the job dataframe
676
 
677
+ progress(0.5, "Document analysis task outputs found. Downloading from S3")
678
 
679
  # If job_df is not empty
680
 
 
683
 
684
  if not job_df.empty:
685
  if "file_name" in job_df.columns:
686
+ matching_job_id_file_names = job_df.loc[
687
+ job_df["job_id"] == job_id, "file_name"
688
+ ]
689
 
690
  if pdf_filename and not matching_job_id_file_names.empty:
691
  if pdf_filename == matching_job_id_file_names.iloc[0]:
 
702
  # Textract typically creates output under s3_output_prefix/job_id/
703
  # There might be multiple JSON files if pagination occurred during writing.
704
  # Usually, for smaller docs, there's one file, often named '1'.
705
+ # For robust handling, list objects and find the JSON(s).
706
 
707
+ s3_output_key_prefix = (
708
+ os.path.join(s3_output_prefix, job_id).replace("\\", "/") + "/"
709
+ )
710
+ logging.info(
711
+ f"Searching for output files in s3://{s3_bucket_name}/{s3_output_key_prefix}"
712
+ )
713
 
714
  try:
715
+ downloaded_file_path = download_textract_job_files(
716
+ s3_client,
717
+ s3_bucket_name,
718
+ s3_output_key_prefix,
719
+ pdf_filename,
720
+ job_id,
721
+ local_output_dir,
722
+ )
723
 
724
  except Exception as e:
725
+ # logging.error(f"Failed to download or process Textract output from S3: {e}")
726
  print(f"Failed to download or process Textract output from S3: {e}")
727
  raise
728
 
729
  else:
730
+ raise Exception("No Job ID provided.")
731
+
732
  output_pdf_filename = get_file_name_without_type(pdf_filename)
733
 
734
  return downloaded_file_path, job_status, job_df, output_pdf_filename
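This wrapper deliberately polls only max_polling_attempts times so the UI stays responsive. For contexts where a blocking wait is acceptable, here is a stand-alone sketch of a simple poll loop built directly on boto3; the region, job ID and interval are placeholders.

import time

import boto3


def wait_for_textract_job(job_id, region="eu-west-2", interval_seconds=10, max_attempts=60):
    # Block until the Textract job leaves IN_PROGRESS, then return its final status.
    textract = boto3.Session(region_name=region).client("textract")
    for attempt in range(1, max_attempts + 1):
        response = textract.get_document_analysis(JobId=job_id)
        status = response["JobStatus"]
        print(f"Attempt {attempt}: {status}")
        if status != "IN_PROGRESS":
            return status
        time.sleep(interval_seconds)
    raise TimeoutError(f"Job {job_id} still IN_PROGRESS after {max_attempts} attempts")


# Example call (hypothetical job ID):
# wait_for_textract_job("0123456789abcdef0123456789abcdef")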
735
 
736
+
737
+ def load_in_textract_job_details(
738
+ load_s3_jobs: str = LOAD_PREVIOUS_TEXTRACT_JOBS_S3,
739
+ load_s3_jobs_loc: str = TEXTRACT_JOBS_S3_LOC,
740
+ load_local_jobs_loc: str = TEXTRACT_JOBS_LOCAL_LOC,
741
+ document_redaction_bucket: str = DOCUMENT_REDACTION_BUCKET,
742
+ aws_region: str = AWS_REGION,
743
+ DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS: int = DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS,
744
+ ):
745
+ """
746
  Load in a dataframe of jobs previously submitted to the Textract API service.
747
+ """
748
+ job_df = pd.DataFrame(
749
+ columns=[
750
+ "job_id",
751
+ "file_name",
752
+ "job_type",
753
+ "signature_extraction",
754
+ "job_date_time",
755
+ ]
756
+ )
757
 
758
  # Initialize boto3 clients
759
  session = boto3.Session(region_name=aws_region)
760
+ s3_client = session.client("s3")
761
+
762
+ local_output_path = f"{load_local_jobs_loc}/textract_document_jobs.csv"
763
 
764
+ if load_s3_jobs == "True":
765
+ s3_output_key = f"{load_s3_jobs_loc}/textract_document_jobs.csv"
766
767
  try:
768
  s3_client.head_object(Bucket=document_redaction_bucket, Key=s3_output_key)
769
+ # print(f"File exists. Downloading from '{s3_output_key}' to '{local_output_path}'...")
770
+ s3_client.download_file(
771
+ document_redaction_bucket, s3_output_key, local_output_path
772
+ )
773
+ # print("Download successful.")
774
  except ClientError as e:
775
+ if e.response["Error"]["Code"] == "404":
776
  print("Log file does not exist in S3.")
777
  else:
778
  print(f"Unexpected error occurred: {e}")
 
786
  job_df = pd.read_csv(local_output_path)
787
 
788
  if "job_date_time" in job_df.columns:
789
+ job_df["job_date_time"] = pd.to_datetime(
790
+ job_df["job_date_time"], errors="coerce"
791
+ )
792
  # Keep only jobs that have been completed in the last 'DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS' days
793
+ cutoff_time = pd.Timestamp.now() - pd.Timedelta(
794
+ days=DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS
795
+ )
796
+ job_df = job_df.loc[job_df["job_date_time"] > cutoff_time, :]
797
 
798
  try:
799
+ job_df = job_df[
800
+ [
801
+ "job_id",
802
+ "file_name",
803
+ "job_type",
804
+ "signature_extraction",
805
+ "job_date_time",
806
+ ]
807
+ ]
808
  except Exception as e:
809
+ print(
810
+ "Could not find one or more columns in Textract whole document list dataframe:",
811
+ e,
812
+ )
813
 
814
  return job_df
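The job-history CSV consumed here has a small fixed schema. A sketch of a minimal file that this loader would accept, handy for local testing with load_s3_jobs set to "False"; all values and the output path are made up.

import pandas as pd

# Columns expected by load_in_textract_job_details.
sample_jobs = pd.DataFrame(
    [
        {
            "job_id": "0123456789abcdef",
            "file_name": "example.pdf",
            "job_type": "document_analysis",
            "signature_extraction": "['Extract signatures']",
            "job_date_time": "2024-01-15 09:30:00",
        }
    ]
)

# Hypothetical local location; point load_local_jobs_loc at the same folder.
sample_jobs.to_csv("output/textract_document_jobs.csv", index=False)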
815
 
816
+
817
+ def download_textract_output(
818
+ job_id: str, output_bucket: str, output_prefix: str, local_folder: str
819
+ ):
820
  """
821
  Checks the status of a Textract job and downloads the output ZIP file if the job is complete.
822
 
 
825
  :param output_prefix: The prefix (folder path) in S3 where the output file is stored.
826
  :param local_folder: The local directory where the ZIP file should be saved.
827
  """
828
+ textract_client = boto3.client("textract")
829
+ s3_client = boto3.client("s3")
830
 
831
  # Check job status
832
  while True:
833
  response = textract_client.get_document_analysis(JobId=job_id)
834
+ status = response["JobStatus"]
835
+
836
+ if status == "SUCCEEDED":
837
  print("Job completed successfully.")
838
  break
839
+ elif status == "FAILED":
840
+ print(
841
+ "Job failed:",
842
+ response.get("StatusMessage", "No error message provided."),
843
+ )
844
  return
845
  else:
846
  print(f"Job is still {status}.")
847
+ time.sleep(10)  # Wait before checking again, so the loop does not spin while the job is IN_PROGRESS
848
 
849
  # Find output ZIP file in S3
850
  output_file_key = f"{output_prefix}/{job_id}.zip"
 
857
  except Exception as e:
858
  print(f"Error downloading file: {e}")
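A minimal call sketch for this ZIP download helper, with placeholder names; the job should already have finished, and the output is expected at <output_prefix>/<job_id>.zip in the given bucket.

from tools.textract_batch_call import download_textract_output  # hypothetical module path

download_textract_output(
    job_id="0123456789abcdef",  # hypothetical Textract job ID
    output_bucket="my-textract-bucket",  # hypothetical bucket
    output_prefix="textract/output",
    local_folder="output",
)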
859
 
860
+
861
  def check_textract_outputs_exist(textract_output_found_checkbox):
862
+ if textract_output_found_checkbox is True:
863
+ print("Textract outputs found")
864
+ return
865
+ else:
866
+ raise Exception(
867
+ "Relevant Textract outputs not found. Please ensure you have selected the correct results output and you have uploaded the relevant document file in 'Choose document or image file...' above"
868
+ )