import os
import tempfile

import fitz
import gradio as gr
import PIL
import skimage
from fastai.learner import load_learner
from fastai.vision.all import *
from fpdf import FPDF
from huggingface_hub import hf_hub_download
from icevision.all import *
from icevision.models.checkpoint import *
from PIL import Image as PILImage

# icevision detection checkpoint; predict() uses it through
# model_type.end2end_detect to draw boxes on redacted pages.
checkpoint_path = "./allsynthetic-imgsize768.pth"
checkpoint_and_model = model_from_checkpoint(checkpoint_path)
model, model_type, class_map, img_size = (
    checkpoint_and_model[key]
    for key in ("model", "model_type", "class_map", "img_size")
)

# Inference transforms: resize/pad to the checkpoint's img_size, then normalize.
valid_tfms = tfms.A.Adapter(
    [*tfms.A.resize_and_pad(img_size), tfms.A.Normalize()]
)


# Page-level redacted/unredacted classifier pulled from the Hugging Face Hub.
learn = load_learner(
    hf_hub_download("strickvl/redaction-classifier-fastai", "model.pkl")
)

# Class names, aligned index-for-index with the learner's probability output.
labels = learn.dls.vocab


def _render_redaction_report(image_paths, report_path, threshold):
    """Annotate page images with detected redactions and bundle them into a PDF.

    For each PNG in ``image_paths`` the icevision detector draws its
    predictions. The annotated image is saved next to the source as
    ``pred-<name>`` and placed on its own report page. The page orientation
    follows the image, and the image is scaled to fit while keeping its
    aspect ratio.

    Args:
        image_paths: Page PNGs to include, in report order.
        report_path: Destination path of the finished PDF.
        threshold: Detection threshold as a fraction (``confidence / 100``).
    """
    report_pdf = FPDF()
    report_pdf.set_auto_page_break(0)
    for image_path in image_paths:
        folder, name = os.path.split(image_path)
        pred_path = os.path.join(folder, f"pred-{name}")
        with PILImage.open(image_path) as img:
            width, height = img.size
            # Landscape page for wide images, portrait otherwise.
            report_pdf.add_page("L" if width > height else "P")
            pred_dict = model_type.end2end_detect(
                img,
                valid_tfms,
                model,
                class_map=class_map,
                detection_threshold=threshold,
                display_label=True,
                display_bbox=True,
                return_img=True,
                font_size=16,
                label_color="#FF59D6",
            )
            pred_img = pred_dict["img"]
            pred_img.save(pred_path)
        # Scale the annotated image to fit the current page (keeping its
        # aspect ratio) and centre it, instead of letting FPDF size it from
        # its pixel dimensions.
        pred_width, pred_height = pred_img.size
        scale = min(report_pdf.w / pred_width, report_pdf.h / pred_height)
        fitted_width = pred_width * scale
        fitted_height = pred_height * scale
        report_pdf.image(
            pred_path,
            x=(report_pdf.w - fitted_width) / 2,
            y=(report_pdf.h - fitted_height) / 2,
            w=fitted_width,
            h=fitted_height,
        )
    report_pdf.output(report_path, "F")


def predict(pdf, confidence, generate_file):
    """Classify every page of a PDF as redacted or not.

    Each page is rendered to a pixmap and scored by the fastai learner.
    A page counts as redacted when its "redacted" probability exceeds
    ``confidence / 100``. Such pages are saved as ``page-<n>.png`` in a
    per-document folder under the system temp directory. When
    ``generate_file`` is true, those pages are also annotated by the
    detector and collected into ``redacted_pages.pdf`` in the same folder.

    Args:
        pdf: Uploaded file object; only ``pdf.name``, the path of the PDF
            on disk, is used.
        confidence: Threshold as a percentage (0-100).
        generate_file: Whether to build the annotated PDF report.

    Returns:
        Tuple ``(text_output, images, report)``:
        - ``text_output`` is a summary string.
        - ``images`` is a list of ``[caption, png_path]`` pairs for the
          carousel.
        - ``report`` is the report path, or ``None`` when ``generate_file``
          is false.
    """
    threshold = confidence / 100
    # Use the base name so the folder really lands under the temp dir
    # (os.path.join drops tmp_dir when given an absolute path), and strip
    # the extension properly instead of assuming it is 4 characters.
    stem = os.path.splitext(os.path.basename(pdf.name))[0]
    out_dir = os.path.join(tempfile.gettempdir(), stem)
    # Created up front so writing the report cannot fail on a missing
    # folder when no page was redacted.
    os.makedirs(out_dir, exist_ok=True)

    images = []
    redacted_pages = []
    with fitz.open(pdf.name) as document:
        for page_num, page in enumerate(document, start=1):
            image_pixmap = page.get_pixmap()
            _, _, probs = learn.predict(image_pixmap.tobytes())
            page_probs = {
                label: float(prob) for label, prob in zip(labels, probs)
            }
            # Look the class up by name (not position) so the gallery and
            # the summary always agree on which pages are redacted.
            if page_probs["redacted"] <= threshold:
                continue
            redacted_pages.append(str(page_num))
            image_path = os.path.join(out_dir, f"page-{page_num}.png")
            image_pixmap.save(image_path)
            images.append(
                [
                    f"Redacted page #{len(images) + 1} on page {page_num}",
                    image_path,
                ]
            )

    report = None
    if generate_file:
        report = os.path.join(out_dir, "redacted_pages.pdf")
        # Use exactly this run's pages in page order. Do not use whatever
        # PNGs (stale pages, earlier pred-* outputs) already sit in out_dir.
        _render_redaction_report(
            [image_path for _, image_path in images], report, threshold
        )

    text_output = f"A total of {len(redacted_pages)} pages were redacted. \n\n The redacted page numbers were: {', '.join(redacted_pages)}."
    return text_output, images, report


title = "Redaction Detector"

description = "A classifier trained on publicly released redacted (and unredacted) FOIA documents, using [fastai](https://github.com/fastai/fastai)."

# Read as UTF-8 explicitly: the default encoding is locale-dependent, so
# non-ASCII characters in the markdown could fail to decode on some hosts.
with open("article.md", encoding="utf-8") as f:
    article = f.read()

# Each row matches predict's inputs: [PDF path, confidence %, build report?].
examples = [["test1.pdf", 80, False], ["test2.pdf", 80, False]]
interpretation = "default"
enable_queue = True
theme = "grass"
allow_flagging = "never"

demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.inputs.File(label="PDF file", file_count="single"),
        gr.inputs.Slider(
            minimum=0,
            maximum=100,
            step=None,
            default=80,
            label="Confidence",
            optional=False,
        ),
        gr.inputs.Checkbox(label="Extract redacted images"),
    ],
    # One output per element of predict's returned tuple.
    outputs=[
        gr.outputs.Textbox(label="Document Analysis"),
        gr.outputs.Carousel(["text", "image"], label="Redacted pages"),
        gr.outputs.File(label="Download redacted pages"),
    ],
    title=title,
    description=description,
    article=article,
    theme=theme,
    allow_flagging=allow_flagging,
    examples=examples,
    interpretation=interpretation,
)

demo.launch(
    cache_examples=True,
    enable_queue=enable_queue,
)