import os
import tempfile

import fitz
import gradio as gr
import PIL
import skimage
from fastai.learner import load_learner
from fastai.vision.all import *
from fpdf import FPDF
from huggingface_hub import hf_hub_download
from icevision.all import *
from icevision.models.checkpoint import *
from PIL import Image as PILImage

# Object-detection model (IceVision) restored from a local checkpoint; the
# checkpoint dict carries the model together with the metadata needed to run it.
checkpoint_path = "./allsynthetic-imgsize768.pth"
checkpoint_and_model = model_from_checkpoint(checkpoint_path)
model, model_type, class_map, img_size = (
    checkpoint_and_model[key]
    for key in ("model", "model_type", "class_map", "img_size")
)

# Inference transforms: resize/pad to the checkpoint's image size, then normalize.
valid_tfms = tfms.A.Adapter([*tfms.A.resize_and_pad(img_size), tfms.A.Normalize()])

# Page-level redaction classifier (fastai), fetched from the Hugging Face Hub.
learn = load_learner(
    hf_hub_download("strickvl/redaction-classifier-fastai", "model.pkl")
)

# Class names; `predict` pairs labels[i] with the classifier's probs[i].
labels = learn.dls.vocab


def get_content_area(pred_dict) -> int:
    """Return the area of the first "content" box in a detection result.

    ``pred_dict["detection"]`` holds parallel ``labels`` and ``bboxes``
    sequences; each box exposes ``xmin``/``ymin``/``xmax``/``ymax``. Only the
    first box labelled "content" is measured; any further ones are ignored.
    Returns 0 when no "content" label was predicted.
    """
    detection = pred_dict["detection"]
    content_boxes = (
        detection["bboxes"][position]
        for position, name in enumerate(detection["labels"])
        if name == "content"
    )
    first_box = next(content_boxes, None)
    if first_box is None:
        return 0
    width = first_box.xmax - first_box.xmin
    height = first_box.ymax - first_box.ymin
    return width * height


def get_redaction_area(pred_dict) -> int:
    """Return the summed area of every "redaction" box in a detection result.

    ``pred_dict["detection"]`` holds parallel ``labels`` and ``bboxes``
    sequences; each box exposes ``xmin``/``ymin``/``xmax``/``ymax``. Returns 0
    when no "redaction" label was predicted.
    """
    detection = pred_dict["detection"]
    total = 0
    for position, name in enumerate(detection["labels"]):
        if name != "redaction":
            continue
        box = detection["bboxes"][position]
        total += (box.xmax - box.xmin) * (box.ymax - box.ymin)
    return total


def _detect_and_build_report(page_paths, page_dir, threshold, report_path):
    """Annotate saved pages with the detector, write them into a PDF, total areas.

    Args:
        page_paths: PNG paths of the redacted pages, in page order.
        page_dir: Directory the annotated ``pred-<page file>`` images go to.
        threshold: Detection threshold in the range 0-1.
        report_path: Destination path of the generated PDF report.

    Returns:
        ``(image_area, content_area, redaction_area)`` summed over all pages.
    """
    report_pdf = FPDF(unit="cm", format="A4")
    report_pdf.set_auto_page_break(0)
    image_area = content_area = redaction_area = 0
    for page_path in page_paths:
        pred_path = os.path.join(page_dir, f"pred-{os.path.basename(page_path)}")
        with PILImage.open(page_path) as img:
            width, height = img.size
            # Match the report page orientation to the page image.
            report_pdf.add_page(orientation="L" if width > height else "P")
            pred_dict = model_type.end2end_detect(
                img,
                valid_tfms,
                model,
                class_map=class_map,
                detection_threshold=threshold,
                display_label=True,
                display_bbox=True,
                return_img=True,
                font_size=16,
                label_color="#FF59D6",
            )
            image_area += pred_dict["width"] * pred_dict["height"]
            content_area += get_content_area(pred_dict)
            redaction_area += get_redaction_area(pred_dict)
            pred_dict["img"].save(pred_path)
        # TODO: resize image such that it fits the pdf (it is currently drawn
        # at full page width and height, ignoring its own aspect ratio).
        report_pdf.image(pred_path, w=report_pdf.w, h=report_pdf.h)
    report_pdf.output(report_path, "F")
    return image_area, content_area, redaction_area


def predict(pdf, confidence, generate_file):
    """Find redacted pages in a PDF and optionally analyse the redactions.

    Every page is rendered to a pixmap and scored by the fastai classifier
    (``learn``). Pages whose "redacted" probability exceeds the threshold are
    saved as PNGs in a per-document directory. When ``generate_file`` is set
    and at least one page was redacted, those pages are also run through the
    IceVision detector, the annotated images are collected into a PDF report,
    and the redacted share of page area and of detected content is summarised.

    Args:
        pdf: Uploaded file object; only its ``.name`` path is used.
        confidence: Threshold in percent (0-100) applied to both models.
        generate_file: Whether to run the detector and build the PDF report.

    Returns:
        ``(text, images, report_path)``. ``images`` is a list of
        ``[caption, png_path]`` pairs in page order. ``report_path`` is None
        when no report was requested or no page was redacted.
    """
    threshold = confidence / 100
    # NOTE: if ``pdf.name`` is absolute, os.path.join discards the temp dir
    # and the page directory is created next to the uploaded file instead.
    page_dir = os.path.join(tempfile.gettempdir(), os.path.splitext(pdf.name)[0])
    images = []
    redacted_pages = []

    document = fitz.open(pdf.name)
    try:
        for page_num, page in enumerate(document, start=1):
            pixmap = page.get_pixmap()
            _, _, probs = learn.predict(pixmap.tobytes())
            page_probs = {label: float(prob) for label, prob in zip(labels, probs)}
            # Select by label name (not by position) and use the same float
            # value for both the gallery and the page count, so they agree.
            if page_probs["redacted"] > threshold:
                os.makedirs(page_dir, exist_ok=True)
                page_path = os.path.join(page_dir, f"page-{page_num}.png")
                pixmap.save(page_path)
                redacted_pages.append(str(page_num))
                images.append(
                    [
                        f"Redacted page #{len(images) + 1} on page {page_num}",
                        page_path,
                    ]
                )
    finally:
        document.close()

    text_output = f"A total of {len(redacted_pages)} pages were redacted. \n\nThe redacted page numbers were: {', '.join(redacted_pages)}. \n\n"

    # Without redacted pages there is nothing to put in a report, and the
    # area ratios below would divide by zero.
    if not generate_file or not images:
        return text_output, images, None

    # Only this run's pages are processed (in page order) instead of listing
    # the directory, which could hold stale pages or earlier pred-* images.
    report = os.path.join(page_dir, "redacted_pages.pdf")
    image_area, content_area, redaction_area = _detect_and_build_report(
        [page_path for _, page_path in images], page_dir, threshold, report
    )

    total_redaction_proportion = round((redaction_area / image_area) * 100, 1)
    redaction_analysis = f"- {total_redaction_proportion}% of the total area of the redacted pages was redacted. \n"
    if content_area:
        content_redaction_proportion = round(
            (redaction_area / content_area) * 100, 1
        )
        redaction_analysis += f"- {content_redaction_proportion}% of the actual content of those redacted pages was redacted."
    else:
        redaction_analysis += "- No content regions were detected on the redacted pages, so the redacted share of the content could not be computed."

    return text_output + redaction_analysis, images, report


title = "Redaction Detector for PDFs"

description = "An MVP app for detection, extraction and analysis of PDF documents that contain redactions. Two models are used for this demo, both trained on publicly released redacted (and unredacted) FOIA documents: \n\n - Classification model trained using [fastai](https://github.com/fastai/fastai) \n- Object detection model trained using [IceVision](https://airctic.com/0.12.0/)"

# Read the article explicitly as UTF-8 so decoding does not depend on the
# host's locale default encoding.
with open("article.md", encoding="utf-8") as f:
    article = f.read()

# Each example row is [pdf path, confidence %, generate report?].
examples = [
    ["test1.pdf", 80, True],
    ["test2.pdf", 80, False],
    ["test3.pdf", 80, True],
    ["test4.pdf", 80, False],
    ["test5.pdf", 80, False],
]
interpretation = "default"
enable_queue = True
theme = "grass"
allow_flagging = "never"

# Inputs map positionally onto predict(pdf, confidence, generate_file); the
# three outputs match its (text, images, report_path) return tuple.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.inputs.File(label="PDF file", file_count="single"),
        gr.inputs.Slider(
            minimum=0,
            maximum=100,
            step=None,
            default=80,
            label="Confidence",
            optional=False,
        ),
        gr.inputs.Checkbox(
            label="Analyse and extract redacted images", default=True
        ),
    ],
    outputs=[
        gr.outputs.Textbox(label="Document Analysis"),
        gr.outputs.Carousel(["text", "image"], label="Redacted pages"),
        gr.outputs.File(label="Download redacted pages"),
    ],
    title=title,
    description=description,
    article=article,
    theme=theme,
    allow_flagging=allow_flagging,
    examples=examples,
    interpretation=interpretation,
)

demo.launch(
    cache_examples=True,
    enable_queue=enable_queue,
)