Spaces:
Runtime error
Runtime error
import os | |
import tempfile | |
import fitz | |
import gradio as gr | |
import PIL | |
import skimage | |
from fastai.learner import load_learner | |
from fastai.vision.all import * | |
from fpdf import FPDF | |
from huggingface_hub import hf_hub_download | |
from icevision.all import * | |
from icevision.models.checkpoint import * | |
from PIL import Image as PILImage | |
# --- Object-detection model (IceVision) -----------------------------------
# The checkpoint bundles the model together with the metadata needed to run
# inference with it.
checkpoint_path = "./allsynthetic-imgsize768.pth"
checkpoint_and_model = model_from_checkpoint(checkpoint_path)
model, model_type, class_map, img_size = (
    checkpoint_and_model[key]
    for key in ("model", "model_type", "class_map", "img_size")
)

# Inference-time transforms: resize/pad to the checkpoint's image size, then
# normalize.
valid_tfms = tfms.A.Adapter(
    [*tfms.A.resize_and_pad(img_size), tfms.A.Normalize()]
)

# --- Page classifier (fastai) ---------------------------------------------
# The exported learner is fetched from the Hugging Face Hub and loaded.
learn = load_learner(
    hf_hub_download("strickvl/redaction-classifier-fastai", "model.pkl")
)
# Class names of the classifier, in the order used by its predictions.
labels = learn.dls.vocab
def get_content_area(pred_dict) -> int:
    """Return the area of the first detected "content" bounding box.

    Args:
        pred_dict: Detection result holding ``pred_dict["detection"]`` with
            parallel ``"labels"`` and ``"bboxes"`` sequences. Each bbox must
            expose ``xmin``, ``xmax``, ``ymin`` and ``ymax``.

    Returns:
        ``width * height`` of the first box labelled ``"content"``, or ``0``
        when no such label is present. Any further "content" boxes are not
        counted.
    """
    detection = pred_dict["detection"]
    first_content_idx = next(
        (
            idx
            for idx, label in enumerate(detection["labels"])
            if label == "content"
        ),
        None,
    )
    if first_content_idx is None:
        return 0

    box = detection["bboxes"][first_content_idx]
    box_width = box.xmax - box.xmin
    box_height = box.ymax - box.ymin
    return box_width * box_height
def get_redaction_area(pred_dict) -> int:
    """Return the summed area of every detected "redaction" bounding box.

    Args:
        pred_dict: Detection result holding ``pred_dict["detection"]`` with
            parallel ``"labels"`` and ``"bboxes"`` sequences. Each bbox must
            expose ``xmin``, ``xmax``, ``ymin`` and ``ymax``.

    Returns:
        The sum of ``width * height`` over all boxes labelled
        ``"redaction"``; ``0`` when there are none. Overlapping boxes are
        each counted in full.
    """
    detection = pred_dict["detection"]
    area = 0
    for idx, label in enumerate(detection["labels"]):
        if label != "redaction":
            continue
        box = detection["bboxes"][idx]
        area += (box.xmax - box.xmin) * (box.ymax - box.ymin)
    return area
def predict(pdf, confidence, generate_file):
    """Find redacted pages in a PDF and optionally build an annotated report.

    Every page is rendered to an image and classified with the module-level
    ``learn`` classifier. Pages whose "redacted" probability exceeds the
    threshold are saved as PNG files in a per-document folder under the
    system temp directory. When ``generate_file`` is truthy, each saved page
    is additionally run through the object-detection model, the annotated
    images are collected into a PDF report, and area statistics are appended
    to the text summary.

    Args:
        pdf: Uploaded file object; only its ``name`` attribute (the path of
            the PDF on disk) is used.
        confidence: Threshold as a percentage (it is divided by 100). Used
            both as the classifier cut-off and as the detection threshold.
        generate_file: Whether to run detection and produce the PDF report.

    Returns:
        A ``(text, images, report)`` tuple:
        ``text`` is the human-readable summary, ``images`` is a list of
        ``[caption, png_path]`` pairs for the redacted pages, and ``report``
        is the path of the generated PDF, or ``None`` when no report was
        produced (``generate_file`` is falsy or no page was redacted).
    """
    threshold = confidence / 100

    # Use only the base name: ``pdf.name`` may be an absolute path, and
    # joining an absolute path onto the temp dir would discard the temp dir.
    # ``splitext`` also copes with extensions that are not 4 characters long.
    filename_without_extension = os.path.splitext(
        os.path.basename(pdf.name)
    )[0]
    output_dir = os.path.join(
        tempfile.gettempdir(), filename_without_extension
    )

    images = []
    redacted_pages = []

    document = fitz.open(pdf.name)
    try:
        for page_num, page in enumerate(document, start=1):
            image_pixmap = page.get_pixmap()
            _, _, probs = learn.predict(image_pixmap.tobytes())
            page_probs = {
                label: float(prob) for label, prob in zip(labels, probs)
            }
            # Look the probability up by label so that the saved images and
            # the reported page numbers always agree.
            if page_probs["redacted"] > threshold:
                os.makedirs(output_dir, exist_ok=True)
                page_path = os.path.join(output_dir, f"page-{page_num}.png")
                image_pixmap.save(page_path)
                redacted_pages.append(str(page_num))
                images.append(
                    [
                        f"Redacted page #{len(images) + 1} on page {page_num}",
                        page_path,
                    ]
                )
    finally:
        # Release the document handle even if rendering/classification fails.
        document.close()

    text_output = f"A total of {len(redacted_pages)} pages were redacted. \n\nThe redacted page numbers were: {', '.join(redacted_pages)}. \n\n"

    # Without redacted pages there is nothing to put in a report (and no
    # output folder to read from), so skip report generation entirely.
    if not generate_file or not images:
        return text_output, images, None

    total_image_areas = 0
    total_content_areas = 0
    total_redaction_area = 0

    report = os.path.join(output_dir, "redacted_pages.pdf")
    report_pdf = FPDF()
    report_pdf.set_auto_page_break(0)

    # Iterate over the pages saved by this call, in page order, instead of
    # listing the folder: a directory listing sorts "page-10" before
    # "page-2" and would also pick up PNGs left over from earlier runs.
    for _, page_path in images:
        with PILImage.open(page_path) as img:
            width, height = img.size
            report_pdf.add_page(orientation="L" if width > height else "P")
            pred_dict = model_type.end2end_detect(
                img,
                valid_tfms,
                model,
                class_map=class_map,
                detection_threshold=threshold,
                display_label=True,
                display_bbox=True,
                return_img=True,
                font_size=16,
                label_color="#FF59D6",
            )
            total_image_areas += pred_dict["width"] * pred_dict["height"]
            total_content_areas += get_content_area(pred_dict)
            total_redaction_area += get_redaction_area(pred_dict)

            annotated_path = os.path.join(
                output_dir, f"pred-{os.path.basename(page_path)}"
            )
            pred_dict["img"].save(annotated_path)
            # TODO: resize image such that it fits the pdf
            report_pdf.image(annotated_path)
    report_pdf.output(report, "F")

    # Guard both ratios: the detector may find no "content" box at all, and
    # dividing by a zero total would crash after the report was written.
    if total_image_areas:
        total_redaction_proportion = round(
            (total_redaction_area / total_image_areas) * 100, 1
        )
        area_line = f"- {total_redaction_proportion}% of the total area of the redacted pages was redacted. "
    else:
        area_line = "- The total area of the redacted pages could not be determined. "

    if total_content_areas:
        content_redaction_proportion = round(
            (total_redaction_area / total_content_areas) * 100, 1
        )
        content_line = f"- {content_redaction_proportion}% of the actual content of those redacted pages was redacted."
    else:
        content_line = "- No content area was detected on those redacted pages, so the share of content redacted could not be computed."

    redaction_analysis = f"{area_line}\n{content_line}"
    return text_output + redaction_analysis, images, report
# Static text and settings used when building the Gradio interface.
title = "Redaction Detector for PDFs"
description = "An MVP app for detection, extraction and analysis of PDF documents that contain redactions. Two models are used for this demo, both trained on publicly released redacted (and unredacted) FOIA documents. Classification model trained using [fastai](https://github.com/fastai/fastai) and the object detection model trained using [IceVision](https://airctic.com/0.12.0/)."

# Read the long-form article shown with the demo. The encoding is given
# explicitly so the result does not depend on the platform's default locale.
with open("article.md", encoding="utf-8") as f:
    article = f.read()

# Each example row is [pdf path, confidence, generate_file], matching the
# positional inputs of ``predict``.
examples = [["test1.pdf", 80, False], ["test2.pdf", 80, False]]
interpretation = "default"
enable_queue = True
theme = "grass"
allow_flagging = "never"
# Input widgets, in the positional order expected by ``predict``:
# the PDF file, the confidence percentage, and the report toggle.
_input_components = [
    gr.inputs.File(label="PDF file", file_count="single"),
    gr.inputs.Slider(
        minimum=0,
        maximum=100,
        step=None,
        default=80,
        label="Confidence",
        optional=False,
    ),
    gr.inputs.Checkbox(label="Extract redacted images"),
]

# Output widgets, matching the (text, images, report) tuple returned by
# ``predict``.
_output_components = [
    gr.outputs.Textbox(label="Document Analysis"),
    gr.outputs.Carousel(["text", "image"], label="Redacted pages"),
    gr.outputs.File(label="Download redacted pages"),
]

demo = gr.Interface(
    fn=predict,
    inputs=_input_components,
    outputs=_output_components,
    title=title,
    description=description,
    article=article,
    theme=theme,
    allow_flagging=allow_flagging,
    examples=examples,
    interpretation=interpretation,
)

# Start the app with example caching enabled and the queue setting above.
demo.launch(cache_examples=True, enable_queue=enable_queue)