# ------------------------------------------
# TextDiffuser: Diffusion Models as Text Painters
# Paper Link: https://arxiv.org/abs/2305.10855
# Code Link: https://github.com/microsoft/unilm/tree/master/textdiffuser
# Copyright (c) Microsoft Corporation.
# This file defines a set of commonly used utility functions.
# ------------------------------------------

import os
import re
import cv2
import math
import shutil
import string
import textwrap
import numpy as np
from PIL import Image, ImageFont, ImageDraw, ImageOps

from typing import *

# define alphabet and alphabet_dic
# The 95 printable ASCII characters used as segmentation classes.
alphabet = string.digits + string.ascii_lowercase + string.ascii_uppercase + string.punctuation + ' '  # len(alphabet) = 95
# Character -> 1-based class label; label 0 stands for non-character (background).
# A comprehension keeps the loop variables out of the module namespace, so they are
# not re-exported by `from ... import *`.
alphabet_dic = {char: label for label, char in enumerate(alphabet, start=1)}
    


def transform_mask_pil(mask_root, size):
    """
    This function extracts the mask area from an in-memory mask image.

    Args:
        mask_root (PIL.Image.Image or np.ndarray): The mask image itself (not a path,
            despite the parameter name). PIL images of any mode are converted to RGB;
            numpy arrays may be 2-D grayscale, 3-channel RGB or 4-channel RGBA.
            * White pixels (gray value > 250) are the unmasked area.
            * All other pixels (e.g. gray) are the masked area.
        size (int): Side length of the square output.

    Returns:
        np.ndarray: float32 array of shape (size, size); 1.0 where masked, 0.0 where unmasked.
    """
    if isinstance(mask_root, Image.Image):
        # Normalise every PIL mode ('L', '1', 'P', 'RGBA', ...) to 3-channel RGB.
        mask_root = mask_root.convert('RGB')
    img = np.array(mask_root)
    img = cv2.resize(img, (size, size), interpolation=cv2.INTER_NEAREST)
    # PIL/numpy image data is RGB(A), not OpenCV's BGR, so pick the matching conversion;
    # single-channel input is already gray.
    if img.ndim == 2:
        gray = img
    elif img.shape[2] == 4:
        gray = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
    else:
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    # Pixels brighter than 250 become 255, everything else 0.
    _, binary = cv2.threshold(gray, 250, 255, cv2.THRESH_BINARY)
    # Invert so masked pixels are 1.0 and (near-)white pixels are 0.0.
    return 1 - (binary.astype(np.float32) / 255)
    

def transform_mask(mask_root, size):
    """
    This function extracts the mask area from a mask image stored on disk.

    Args:
        mask_root (str): The path of mask image.
            * White pixels (gray value > 250) are the unmasked area.
            * All other pixels (e.g. gray) are the masked area.
        size (int): Side length of the square output.

    Returns:
        np.ndarray: float32 array of shape (size, size); 1.0 where masked, 0.0 where unmasked.

    Raises:
        FileNotFoundError: if no image can be read from ``mask_root``.
    """
    img = cv2.imread(mask_root)
    # cv2.imread reports a missing/unreadable file by returning None rather than
    # raising; without this check the failure surfaces as an opaque cv2.resize error.
    if img is None:
        raise FileNotFoundError(f'Cannot read mask image: {mask_root}')
    img = cv2.resize(img, (size, size), interpolation=cv2.INTER_NEAREST)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)  # imread loads 3-channel BGR by default
    # Pixels brighter than 250 become 255, everything else 0.
    _, binary = cv2.threshold(gray, 250, 255, cv2.THRESH_BINARY)
    # Invert so masked pixels are 1.0 and (near-)white pixels are 0.0.
    return 1 - (binary.astype(np.float32) / 255)


def segmentation_mask_visualization(font_path: str, segmentation_mask: np.array):
    """
    This function visualizes the segmentation masks with characters.

    The mask is resized to 64x64 (nearest neighbour) and each labelled cell is drawn
    as its character in green on a black 512x512 canvas (8 pixels per cell).

    Args:
        font_path (str): The path of font. We recommend using Arial.ttf
        segmentation_mask (np.array): The character-level segmentation mask. Label k in
            1..len(alphabet) denotes alphabet[k - 1]; 0 is background. Any other value
            is ignored.

    Returns:
        PIL.Image.Image: The 512x512 RGB visualization.
    """
    segmentation_mask = cv2.resize(segmentation_mask, (64, 64), interpolation=cv2.INTER_NEAREST)
    font = ImageFont.truetype(font_path, 8)
    blank = Image.new('RGB', (512, 512), (0, 0, 0))
    d = ImageDraw.Draw(blank)
    for i in range(64):
        for j in range(64):
            label = int(segmentation_mask[i][j])
            # Skip background and out-of-range labels. Negative labels must be skipped
            # explicitly: alphabet[label - 1] would otherwise index from the end and
            # draw a wrong character.
            if label < 1 or label > len(alphabet):
                continue
            d.text((j * 8, i * 8), alphabet[label - 1], font=font, fill=(0, 255, 0))
    return blank


def make_caption_pil(font_path: str, captions: List[str]):
    """
    This function converts captions into pil images.

    Each caption is wrapped at 40 characters and drawn in orange onto a white canvas
    framed by a 2-pixel gray border (512x48 overall). Lines that do not fit are clipped.

    Args:
        font_path (str): The path of font. We recommend using Arial.ttf
        captions (List[str]): List of captions.

    Returns:
        List[PIL.Image.Image]: One image per caption, in the same order.
    """
    font = ImageFont.truetype(font_path, 18)
    # FreeTypeFont.getsize was removed in Pillow 10. getbbox()[3] (bottom edge measured
    # from the drawing origin) is the closest replacement for the old height.
    if hasattr(font, 'getsize'):
        glyph_height = font.getsize('A')[1]
    else:
        glyph_height = font.getbbox('A')[3]
    line_height = glyph_height + 4

    border_size = 2
    caption_pil_list = []
    for caption in captions:
        # Shrink the inner canvas by the border so the framed image is exactly 512x48.
        img = Image.new('RGB', (512 - 2 * border_size, 48 - 2 * border_size), (255, 255, 255))
        img = ImageOps.expand(img, border=border_size, fill=(127, 127, 127))
        draw = ImageDraw.Draw(img)
        x, y = 4, 4
        for line in textwrap.wrap(caption, width=40):
            draw.text((x, y), line, font=font, fill=(200, 127, 0))
            y += line_height
        caption_pil_list.append(img)
    return caption_pil_list


def filter_segmentation_mask(segmentation_mask: np.array):
    """
    Suppress noisy character predictions by resetting them to background.

    Cells labelled as '-' or ' ' are set to 0 via boolean-mask assignment, which
    modifies ``segmentation_mask`` in place; the same array is returned.

    Args:
        segmentation_mask (np.array): The character-level segmentation mask.
    """
    for noisy_char in ('-', ' '):
        segmentation_mask[segmentation_mask == alphabet_dic[noisy_char]] = 0
    return segmentation_mask
    
    

def combine_image(args, resolution, sub_output_dir: str, pred_image_list: List, image_pil: Image, character_mask_pil: Image, character_mask_highlight_pil: Image, caption_pil_list: List):
    """
    Tile the predicted images into a single canvas.

    Only ``resolution`` and ``pred_image_list`` are used; the other parameters are
    accepted for interface compatibility and ignored.

    Layout: a single image is returned unchanged; 2 or 3 images form one row; 4 images
    form a 2x2 grid; larger counts use ceil(sqrt(n)) columns.

    Args:
        args (argparse.ArgumentParser): The arguments (unused).
        resolution (int): Side length of each tile in pixels.
        pred_image_list (List): List of predicted images.
        image_pil (Image): The original image (unused).
        character_mask_pil (Image): The character-level segmentation mask (unused).
        character_mask_highlight_pil (Image): The segmentation mask highlighting character regions in green (unused).
        caption_pil_list (List): List of captions (unused).

    Returns:
        PIL.Image.Image: The tiled image (or the only prediction when there is one).

    Raises:
        ValueError: if ``pred_image_list`` is empty.
    """
    count = len(pred_image_list)
    if count == 0:
        raise ValueError('pred_image_list must contain at least one image')
    if count == 1:
        return pred_image_list[0]

    # One row for 2-3 images, otherwise a near-square grid (4 -> 2x2).
    cols = count if count <= 3 else math.ceil(math.sqrt(count))
    rows = math.ceil(count / cols)
    blank = Image.new('RGB', (resolution * cols, resolution * rows), (0, 0, 0))
    for idx, pred_image in enumerate(pred_image_list):
        row, col = divmod(idx, cols)
        blank.paste(pred_image, (col * resolution, row * resolution))
    return blank
    
    
def combine_image_gradio(args, size, sub_output_dir: str, pred_image_list: List, image_pil: Image, character_mask_pil: Image, character_mask_highlight_pil: Image, caption_pil_list: List):
    """
    Tile the predicted images into a single canvas (Gradio variant).

    ``size`` is the side length of each tile. It used to be overwritten by the number
    of images, which produced a canvas only a few pixels wide; the image count now
    lives in its own variable. Only ``size`` and ``pred_image_list`` are used.

    Layout: a single image is returned unchanged; 2 or 3 images form one row; 4 images
    form a 2x2 grid; larger counts use ceil(sqrt(n)) columns.

    Args:
        args (argparse.ArgumentParser): The arguments (unused).
        size (int): Side length of each tile in pixels.
        pred_image_list (List): List of predicted images.
        image_pil (Image): The original image (unused).
        character_mask_pil (Image): The character-level segmentation mask (unused).
        character_mask_highlight_pil (Image): The segmentation mask highlighting character regions in green (unused).
        caption_pil_list (List): List of captions (unused).

    Returns:
        PIL.Image.Image: The tiled image (or the only prediction when there is one).

    Raises:
        ValueError: if ``pred_image_list`` is empty.
    """
    count = len(pred_image_list)
    if count == 0:
        raise ValueError('pred_image_list must contain at least one image')
    if count == 1:
        return pred_image_list[0]

    # One row for 2-3 images, otherwise a near-square grid (4 -> 2x2).
    cols = count if count <= 3 else math.ceil(math.sqrt(count))
    rows = math.ceil(count / cols)
    blank = Image.new('RGB', (size * cols, size * rows), (0, 0, 0))
    for idx, pred_image in enumerate(pred_image_list):
        row, col = divmod(idx, cols)
        blank.paste(pred_image, (col * size, row * size))
    return blank
    
def get_width(font_path, text):
    """
    This function calculates the rendered width of the text at font size 24.

    Args:
        font_path (str): Path of the TrueType font file.
        text (str): The text to measure.

    Returns:
        The width of ``text`` in pixels.
    """
    font = ImageFont.truetype(font_path, 24)
    # FreeTypeFont.getsize was removed in Pillow 10. getbbox()[2] (right edge measured
    # from the drawing origin) is the closest replacement for the old width.
    if hasattr(font, 'getsize'):
        width, _ = font.getsize(text)
    else:
        width = font.getbbox(text)[2]
    return width



def get_key_words(text: str):
    """
    Collect the keywords that a user prompt encloses in single quotes.

    Every '...' span is split on whitespace, and the pieces are returned in order of
    appearance. The keywords are used to guide the layout generation.

    Args:
        text (str): user prompt.

    Returns:
        List[str]: The keywords, or an empty list when there are eight or more of them.
    """
    # Non-greedy match, so each pair of quotes is captured separately.
    quoted_spans = re.findall(r"'(.*?)'", text)
    keywords = [word for span in quoted_spans for word in span.split()]
    return [] if len(keywords) >= 8 else keywords


def adjust_overlap_box(box_output, current_index):
    """
    Shift the current box so it no longer starts inside the box placed just before it.

    Only the immediately preceding box is considered. The box is edited in place inside
    ``box_output``, and that same object is returned.

    Args:
        box_output: Indexed as ``box_output[0, i, :]`` -> (xmin, ymin, xmax, ymax);
            presumably a (1, N, 4) array/tensor of normalized coordinates — TODO confirm.
        current_index (int): the index of current box.

    Returns:
        The (possibly modified) ``box_output``.
    """
    if current_index == 0:
        return box_output

    prev_left, prev_top, prev_right, prev_bottom = box_output[0, current_index - 1, :]
    left, top, right, bottom = box_output[0, current_index, :]

    starts_inside_x = prev_left <= left <= prev_right

    if starts_inside_x and prev_top <= top <= prev_bottom:
        # Top-left corner lies inside the previous box: move out along the axis with
        # the smaller overlap, keeping the box's width/height unchanged.
        print('adjust overlapping')
        overlap_x = prev_right - left
        overlap_y = prev_bottom - top
        if overlap_x <= overlap_y:
            shifted_left = prev_right + 0.025
            shifted_right = right - left + prev_right + 0.025
            box_output[0, current_index, 0] = shifted_left
            box_output[0, current_index, 2] = shifted_right
        else:
            shifted_top = prev_bottom + 0.025
            shifted_bottom = bottom - top + prev_bottom + 0.025
            box_output[0, current_index, 1] = shifted_top
            box_output[0, current_index, 3] = shifted_bottom
    elif starts_inside_x and prev_top <= bottom <= prev_bottom:
        # Bottom-left corner lies inside the previous box: move it to the right.
        print('adjust overlapping')
        shifted_left = prev_right + 0.05
        shifted_right = right - left + prev_right + 0.05
        box_output[0, current_index, 0] = shifted_left
        box_output[0, current_index, 2] = shifted_right

    return box_output
    
    
def shrink_box(box, scale_factor=0.9):
    """
    Scale a box about its centre.

    Args:
        box: An (x1, y1, x2, y2) sequence.
        scale_factor (float): Fraction of the original width/height to keep.

    Returns:
        tuple: The scaled (x1, y1, x2, y2).
    """
    left, top, right, bottom = box
    # Half of the removed extent comes off each side, so the centre stays fixed.
    margin_x = (right - left) * (1 - scale_factor) / 2
    margin_y = (bottom - top) * (1 - scale_factor) / 2
    return (left + margin_x, top + margin_y, right - margin_x, bottom - margin_y)


def adjust_font_size(args, width, height, draw, text):
    """
    This function finds the largest font size at which the text fits the given width.

    Sizes are tried from ``height`` downwards in steps of 1. The search stops at the
    first size whose rendered width is strictly below ``width``. It also stops at the
    last size that is still >= 1, so the loop never reaches a non-positive size, which
    cannot be used for a font.

    Args:
        args (argparse.ArgumentParser): The arguments; ``args.font_path`` is the font file.
        width (int): The available width for the text.
        height (int): The starting (largest) font size to try.
        draw (ImageDraw): The ImageDraw object used to measure the text.
        text (str): The text.

    Returns:
        The chosen font size.
    """
    size_start = height
    while True:
        font = ImageFont.truetype(args.font_path, size_start)
        # ImageDraw.textsize was removed in Pillow 10; textbbox from the origin gives
        # the right edge, which serves as the width.
        if hasattr(draw, 'textsize'):
            text_width, _ = draw.textsize(text, font=font)
        else:
            text_width = draw.textbbox((0, 0), text, font=font)[2]
        if text_width < width or size_start - 1 < 1:
            return size_start
        size_start -= 1
    
    
def inpainting_merge_image(original_image, mask_image, inpainting_image, size=512):
    """
    This function merges the original image, mask image and inpainting image.

    All three images are resized to ``size`` x ``size``. Wherever the mask's gray value
    is below 250, the inpainting image is used; near-white mask pixels (>= 250) keep the
    original image.

    Args:
        original_image (PIL.Image): The original image.
        mask_image (PIL.Image): The mask image (any mode; converted to grayscale).
        inpainting_image (PIL.Image): The inpainting image.
        size (int): Side length all images are resized to (default 512).

    Returns:
        PIL.Image: The merged image.
    """
    original_image = original_image.resize((size, size))
    mask_image = mask_image.resize((size, size))
    inpainting_image = inpainting_image.resize((size, size))
    # Image.convert returns a new image. Its result must be kept: point(..., "1") only
    # accepts 'L'/'P' sources, so an RGB mask used to raise here.
    mask_image = mask_image.convert('L')
    threshold = 250
    table = [1 if value < threshold else 0 for value in range(256)]
    mask_image = mask_image.point(table, '1')
    return Image.composite(inpainting_image, original_image, mask_image)