import cv2
import numpy as np
import matplotlib.pyplot as plt
import os
import cairosvg
from potrace import POTRACE_CORNER, Path, Bitmap
import io
from PIL import Image, ImageStat

import streamlit
from PIL import Image

@streamlit.cache_data
def pipeline_svg(image_input, size_value, level=3, streamlit=False, threshold=0, kernel_type=cv2.MORPH_ELLIPSE, dilate_lines_value=0):
    """
    Entry point wrapped in ``streamlit.cache_data``: hand the image to
    ``process_svg`` and return whatever it produces.

    uint8 ==> uint8 (``process_svg`` casts its input to uint8).

    Args:
        image_input: image array, forwarded to ``process_svg`` as ``img``.
        size_value: forwarded to ``process_svg`` unchanged.
        level: accepted but NOT forwarded; ``process_svg`` therefore runs
            with its own default ``level``.
        streamlit: forwarded to ``process_svg``. Inside this function the
            name shadows the imported ``streamlit`` module.
        threshold: accepted but not used by this function.
        kernel_type: structuring-element shape, forwarded unchanged.
        dilate_lines_value: forwarded unchanged.

    Returns:
        The value returned by ``process_svg``.
    """
    # NOTE(review): `level` and `threshold` take part in the call signature
    # but never reach process_svg -- confirm whether `level` should be passed.
    return process_svg(
        image_input,
        size_value=size_value,
        streamlit=streamlit,
        kernel_type=kernel_type,
        dilate_lines_value=dilate_lines_value,
    )

def process_svg(img, size_value=12, level=1, streamlit=False, kernel_type=cv2.MORPH_ELLIPSE, dilate_lines_value=0):
    """
    Trace an image to vector form, enlarge the small features of its
    inverted binary version, and trace the result again.

    Steps, as written below:
      1. Cast ``img`` to uint8 and optionally erode it once.
      2. Round-trip through ``convert_to_svg_and_back``, binarise to 0/255
         and invert (``255 - x``).
      3. For ``i`` in ``0 .. target_min_size`` inclusive: isolate what an
         erode-then-dilate with a ``2*i + 1`` kernel removes, dilate that
         with a ``2*(target_min_size - i) + 1`` kernel, and add it to the
         running result when ``i > level``.
      4. Invert back and round-trip through ``convert_to_svg_and_back`` again.

    Args:
        img: array exposing ``astype``; converted to uint8 on entry.
        size_value: clamped to at least 1 and used as ``target_min_size``.
        level: only iterations with ``i > level`` contribute to the result.
        streamlit: not referenced anywhere in this function body.
        kernel_type: shape passed to ``get_kernel_ellipse`` for every kernel
            built here.
        dilate_lines_value: when > 0, ``img`` is eroded once with a kernel of
            side ``dilate_lines_value + 1`` before tracing.

    Returns:
        The array returned by the final ``convert_to_svg_and_back`` call.
    """

    # Only referenced by the commented-out debug write further down.
    image_path = "input_image.png"
    img = img.astype('uint8')

    # Optional pre-pass for very thin lines.
    # NOTE(review): cv2.erode shrinks bright regions, so on a light background
    # it thickens dark strokes -- presumably why the parameter is named
    # "dilate_lines"; confirm against the expected input polarity.
    if dilate_lines_value > 0:
        size = dilate_lines_value + 1 # A kernel of side 1 would change nothing, hence the + 1.
        kernel = get_kernel_ellipse(size, kernel_type=kernel_type)
        img = cv2.erode(img, kernel, iterations=1)

    #cv2.imwrite(image_path, img)
    img_array = convert_to_svg_and_back(img)

    # Force strict 0/255 values, then invert.
    img_array = binarise(img_array)
    img_bin = (255 - img_array)
    img_bin = img_bin.astype('uint8')
    # Accumulates the small objects already extracted by earlier iterations.
    image_already_added = np.zeros_like(img_bin)

    target_min_size = max(1, size_value)

    # NOTE(review): the `+` / `-` below operate on uint8 arrays, which wrap
    # modulo 256 instead of saturating -- confirm the wraparound is harmless
    # for the value ranges reached here before restructuring this loop.
    image_final = img_bin
    for i in range(target_min_size+1):
        # Kernel side grows 1, 3, 5, ... with the iteration index.
        size = 2 * i + 1
        kernel = get_kernel_ellipse(size, kernel_type=kernel_type)

        # Erode then dilate with the same kernel, on the image minus what
        # previous iterations already extracted.
        erosion = cv2.erode((img_bin - image_already_added), kernel, iterations=1)
        dilation = cv2.dilate(erosion, kernel, iterations=1)

        # "petits objets" = small objects: what the erode/dilate pair removed,
        # cleaned up by remove_solo_pixels with a size-3 kernel.
        image_petits_objets = (img_bin - dilation)
        image_petits_objets = remove_solo_pixels(image_petits_objets, kernel_size=3)

        # Complementary kernel: the smaller the extraction kernel, the larger
        # the dilation applied to what it extracted.
        size = 2 * (target_min_size - i) + 1
        kernel = get_kernel_ellipse(size, kernel_type=kernel_type)
        dilate_image_petits_objets = cv2.dilate(image_petits_objets, kernel, iterations=1)

        image_already_added = (image_already_added + image_petits_objets)

        # Iterations up to and including `level` are extracted (and recorded
        # above) but never added to the output.
        if i > level:
            image_final = (image_final + dilate_image_petits_objets)

    #cv2.imwrite("image_finale.png", (255 - image_final))
    # Undo the inversion and trace once more.
    image = convert_to_svg_and_back((255 - image_final))
    return image
def get_kernel_ellipse(size, kernel_type=cv2.MORPH_ELLIPSE):
    """
    Build a square ``size`` x ``size`` structuring element.

    Despite the function name, the shape is whatever ``kernel_type`` selects;
    the ellipse is only the default. The anchor passed to OpenCV is
    ``int((size - 1) / 2)`` on both axes.

    Args:
        size: side length of the kernel.
        kernel_type: shape constant handed to ``cv2.getStructuringElement``.

    Returns:
        The structuring element produced by ``cv2.getStructuringElement``.
    """
    anchor_coord = int((size - 1) / 2)
    return cv2.getStructuringElement(
        kernel_type,
        (size, size),
        (anchor_coord, anchor_coord),
    )


def binarise(img):
    """
    Threshold an array to pure black/white.

    Elements strictly greater than 200 become 255, all others become 0.

    Args:
        img: array supporting element-wise comparison with an integer.

    Returns:
        uint8 array of the same shape holding only 0 and 255.
    """
    above_threshold = img > 200
    return (above_threshold * 255).astype('uint8')


def imshow(title, image, vmin=0, vmax=1):
    """
    Open a new matplotlib figure showing ``image`` in grayscale.

    The image and both display limits are multiplied by 255 before being
    handed to ``plt.imshow``.

    Args:
        title: figure title.
        image: data to display; must support multiplication by an integer.
        vmin: lower display limit, before the 255 scaling.
        vmax: upper display limit, before the 255 scaling.
    """
    scale = 255
    plt.figure()
    plt.title(title)
    plt.imshow(
        image * scale,
        vmin=vmin * scale,
        vmax=vmax * scale,
        cmap='gray',
    )


def remove_solo_pixels(image, kernel_size=3):
    """
    Erode then dilate ``image`` once each with the same default-shaped kernel.

    Args:
        image: array accepted by ``cv2.erode``.
        kernel_size: side length passed to ``get_kernel_ellipse``.

    Returns:
        The eroded-then-dilated image, cast to uint8.
    """
    structuring_element = get_kernel_ellipse(kernel_size)
    opened = cv2.dilate(
        cv2.erode(image, structuring_element, iterations=1),
        structuring_element,
        iterations=1,
    )
    return opened.astype('uint8')

def convert_to_svg_and_back(image_array) -> np.ndarray:
    """
    Trace an image into vector curves and rasterise those curves back.

    The array is wrapped in a PIL image, traced through ``Bitmap.trace``,
    and the resulting path is rendered by ``backend_svg_no_file`` at the
    same width and height as the input.

    Args:
        image_array: array accepted by ``PIL.Image.fromarray``.

    Returns:
        ``np.ndarray`` built from the image that ``backend_svg_no_file``
        returns.
    """
    image_pil = Image.fromarray(image_array)

    bitmap = Bitmap(image_pil, blacklevel=0.5)

    # Tracing parameters are fixed here; curve optimisation is switched off.
    # NOTE(review): turnpolicy=4 is a bare integer -- presumably one of the
    # potrace turn-policy constants; confirm which one is intended.
    traced_path = bitmap.trace(
        turdsize=2,
        turnpolicy=4,
        alphamax=1,
        opticurve=False,
        opttolerance=0.2,
    )

    rendered = backend_svg_no_file(image_pil, traced_path)
    return np.array(rendered)

def backend_svg_no_file(image, path: Path):
    """
    Render a traced path to a raster image entirely in memory.

    An SVG document sized like ``image`` is assembled as a string, with every
    curve of ``path`` written into one filled ``<path>`` element. The string
    is converted to PNG bytes by cairosvg and decoded with PIL.

    Args:
        image: object exposing ``width`` and ``height``; used only for the
            SVG dimensions and viewBox.
        path: iterable of curves, each with ``start_point`` and ``segments``.

    Returns:
        The last band of the decoded PNG as a PIL image
        (presumably the alpha channel -- TODO confirm against cairosvg output).
    """
    width, height = image.width, image.height
    svg_header = f'<svg version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="{width}" height="{height}" viewBox="0 0 {width} {height}">'

    commands = []
    for curve in path:
        start = curve.start_point
        commands.append("M%f,%f" % (start.x, start.y))
        for segment in curve.segments:
            end = segment.end_point
            if segment.is_corner:
                # Corner: two straight lines, through the corner to the end.
                corner = segment.c
                commands.append("L%f,%f" % (corner.x, corner.y))
                commands.append("L%f,%f" % (end.x, end.y))
            else:
                # Otherwise: one cubic Bezier with two control points.
                ctrl1, ctrl2 = segment.c1, segment.c2
                commands.append(
                    "C%f,%f %f,%f %f,%f"
                    % (ctrl1.x, ctrl1.y, ctrl2.x, ctrl2.y, end.x, end.y)
                )
        commands.append("z")

    path_data = "".join(commands)
    svg_path = f'<path stroke="none" fill="#000000" fill-rule="evenodd" d="{path_data}"/>'
    svg_document = svg_header + svg_path + "</svg>"

    # From svg to png (bytes), then keep only the last band.
    png_bytes = cairosvg.surface.PNGSurface.convert(svg_document)
    bands = Image.open(io.BytesIO(png_bytes)).split()
    return bands[-1]