import logging
import os
import time

import cv2
from diffusers import StableDiffusionPipeline
import gradio as gr
# import mediapipe as mp
import numpy as np
import PIL
import torch.cuda
from transformers import pipeline

# Ask huggingface_hub to use the `hf_transfer` backend for model downloads.
# NOTE(review): huggingface_hub appears to read this variable once, when it is
# first imported -- and diffusers/transformers are already imported above --
# so this assignment may have no effect here. Verify; if faster downloads are
# wanted it likely has to be set before those imports (and the `hf_transfer`
# package must then be installed).
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'


# Configure root logging for the whole process at INFO level. force=True
# removes any root handlers that were already installed (e.g. by an imported
# library) so this format is the one actually used.
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                    force=True)

LOG = logging.getLogger(__name__)

LOG.info("Loading image segmentation model")

# Arguments for the transformers `pipeline` factory: a SegFormer-B0 checkpoint
# fine-tuned on ADE20K. extract_selfie_mask looks for the "person" label in
# this model's output (presumably part of its ADE20K label set -- confirm).
seg_kwargs = {
    "task": "image-segmentation",
    "model": "nvidia/segformer-b0-finetuned-ade-512-512"
}

# Created once at import time and shared by every request.
img_segmentation_model = pipeline(**seg_kwargs)


# Alternative segmentation backend (MediaPipe selfie segmentation), disabled.
# mp_selfie_segmentation = mp.solutions.selfie_segmentation
# img_segmentation_model = mp_selfie_segmentation.SelfieSegmentation(model_selection=0)


LOG.info("Loading diffusion model")

# Text-to-image model used to synthesise the new background; created once at
# import time and shared by every request.
diffusion = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Run the diffusion model on the GPU when one is available; otherwise it stays
# where from_pretrained placed it.
if torch.cuda.is_available():
    LOG.info("Moving diffusion model to GPU")
    diffusion.to('cuda')


def image_preprocess(image: "PIL.Image.Image"):
    """Normalise an uploaded image before segmentation and diffusion.

    The image is converted to 3-channel RGB (dropping alpha/palette modes)
    and rescaled with ``resize_image`` so that its longest side is about
    512 px and both sides are multiples of 8, the sizes the diffusion step
    is fed with. No EXIF-orientation correction is applied here.

    :param image: PIL image received from the Gradio upload widget.
    :returns: a new, resized RGB PIL image.
    """
    LOG.info("Preprocessing image %s", image)
    start = time.time()
    image = image.convert("RGB")
    image = resize_image(image)
    elapsed = time.time() - start
    LOG.info("Image preprocessed, %.2f seconds elapsed", elapsed)
    return image


def resize_image(image: "PIL.Image.Image", size: int = 512, multiple: int = 8):
    """Scale *image* so its longest side is (at most) *size* pixels.

    The aspect ratio is preserved up to rounding. Images are upscaled as well
    as downscaled. Each side is then rounded down to a multiple of *multiple*
    (the diffusion step is fed these dimensions directly).

    Each side is clamped to at least *multiple* pixels. Without the clamp, a
    very thin image (e.g. 5000x10) would have its short side collapse to 0 px,
    which would break the later segmentation and diffusion steps.

    :param image: object with a ``size`` ``(width, height)`` attribute and a
        ``resize((width, height))`` method, i.e. a PIL image.
    :param size: target length of the longest side, in pixels.
    :param multiple: both output sides are multiples of this value.
    :returns: the resized image, as returned by ``image.resize``.
    """
    width, height = image.size
    ratio = max(width / size, height / size)
    new_width = max(multiple, int(width / ratio) // multiple * multiple)
    new_height = max(multiple, int(height / ratio) // multiple * multiple)
    return image.resize((new_width, new_height))


def extract_selfie_mask(threshold, image):
    """Build a soft foreground mask for the most confident "person" segment.

    Runs the module-level segmentation pipeline on *image* and keeps the
    "person" segment with the highest score above 0.99. A missing score
    counts as 1. The segment mask is then post-processed:

    * binarised at *threshold*;
    * dilated with a 5x5 kernel;
    * box-blurred with a 10x10 kernel to feather the edges.

    :param threshold: binarisation threshold applied to the raw mask values
        (values strictly greater become 1, others 0).
        NOTE(review): the pipeline mask is converted without normalisation.
        If it is an 8-bit 0/255 image, as is typical for the transformers
        image-segmentation pipeline (confirm), every threshold in the UI
        range 0.1-1 gives the same result.
    :param image: RGB PIL image to segment.
    :returns: single-channel float32 array of shape (height, width) with
        values in [0, 1]. It is all zeros when no person is found.
    """
    LOG.info("Extracting selfie mask")
    start = time.time()
    segments = img_segmentation_model(image)

    def score_of(segment):
        # Semantic-segmentation outputs carry score=None; count those as certain.
        return 1 if segment['score'] is None else segment['score']

    people = [s for s in segments
              if s['label'] == 'person' and score_of(s) > 0.99]
    if not people:
        LOG.info("No person found in the photo, skipping")
        width, height = image.size
        # Must be single-channel like a real mask: cv2.blendLinear (used when
        # merging) only accepts CV_32FC1 weights, so a 3-channel zero array
        # would make the "no person" path crash.
        mask = np.zeros((height, width), dtype=np.float32)
    else:
        # max() keeps the first of several equally-scored segments.
        kept = max(people, key=score_of)
        mask = np.array(kept['mask'], dtype=np.float32)

    # Binarise (> threshold -> 1, else 0), grow the region slightly, then
    # box-blur so the composite has a soft edge instead of a hard cut.
    cv2.threshold(mask, threshold, 1, cv2.THRESH_BINARY, dst=mask)
    cv2.dilate(mask, np.ones((5, 5), np.uint8), iterations=1, dst=mask)
    cv2.blur(mask, (10, 10), dst=mask)

    elapsed = time.time() - start
    LOG.info("Selfie extracted, %.2f seconds elapsed", elapsed)
    return mask


def generate_background(prompt, num_inference_steps, height, width):
    """Generate a background image for *prompt* with the diffusion pipeline.

    :param prompt: text description of the desired background.
    :param num_inference_steps: number of denoising steps. It is cast to
        ``int`` because the UI slider value may arrive as a float.
    :param height: output height in pixels (callers in this file pass sizes
        produced by ``resize_image``, i.e. multiples of 8).
    :param width: output width in pixels (same origin as *height*).
    :returns: uint8 array of shape (height, width, 3) in BGR channel order
        (OpenCV convention). It is all black when the safety checker flags
        the image as NSFW.
    """
    LOG.info("Generating background")
    start = time.time()
    result = diffusion(
        prompt=prompt,
        num_inference_steps=int(num_inference_steps),
        height=height,
        width=width
    )
    # diffusers reports nsfw_content_detected=None when no safety check was
    # performed. Only treat the image as NSFW when the checker says so.
    flags = result.nsfw_content_detected
    nsfw = bool(flags) and bool(flags[0])

    if nsfw:
        LOG.info('NSFW detected, skipping')
        background = np.zeros((height, width, 3), dtype='uint8')
    else:
        # The pipeline returns RGB PIL images; the merge step works in BGR.
        rgb = np.array(result.images[0])
        background = rgb[:, :, ::-1].copy()

    elapsed = time.time() - start
    LOG.info("Background generated, elapsed %.2f seconds", elapsed)
    return background


def merge_selfie_and_background(selfie, background, mask):
    """Composite *selfie* over *background*, weighting pixels by *mask*.

    ``cv2.blendLinear`` blends each pixel linearly:

    * where *mask* is 1 the selfie pixel is kept;
    * where it is 0 the background pixel is used;
    * values in between mix the two.

    :param selfie: RGB PIL image.
    :param background: BGR uint8 array with the same height/width as *selfie*.
    :param mask: per-pixel selfie weights, the same height/width as *selfie*.
        Assumed to be a single-channel float32 array in [0, 1] (TODO confirm).
    :returns: the composited image as an RGB PIL image.
    """
    LOG.info("Merging extracted selfie and generated background")
    # Flip to OpenCV's BGR order so both inputs share a channel layout.
    selfie_bgr = np.array(selfie)[:, :, ::-1].copy()
    background_weight = 1 - mask
    cv2.blendLinear(selfie_bgr, background, mask, background_weight,
                    dst=selfie_bgr)
    merged_rgb = cv2.cvtColor(selfie_bgr, cv2.COLOR_BGR2RGB)
    return PIL.Image.fromarray(merged_rgb)


def demo(threshold, image, prompt, num_inference_steps):
    """Gradio callback: replace the background of an uploaded selfie.

    The steps are: preprocess, then person mask, then generated background,
    then merge.

    :param threshold: mask binarisation threshold, forwarded to
        ``extract_selfie_mask``.
    :param image: uploaded PIL image, or ``None`` when nothing was uploaded.
    :param prompt: text description of the new background.
    :param num_inference_steps: number of diffusion denoising steps.
    :returns: the composited PIL image.
    :raises ValueError: if no image was uploaded.
    :raises Exception: any failure in the pipeline is logged with its
        traceback and re-raised unchanged.
    """
    LOG.info("Processing image")
    if image is None:
        # Fail clearly instead of with an AttributeError deep in preprocessing.
        raise ValueError("No image was uploaded; please upload a selfie")
    try:
        image = image_preprocess(image)
        mask = extract_selfie_mask(threshold, image)
        background = generate_background(prompt, num_inference_steps,
                                         image.size[1], image.size[0])
        return merge_selfie_and_background(image, background, mask)
    except Exception:
        # LOG.exception records the message and the traceback in one record.
        LOG.exception("Some unexpected error occurred")
        raise


# Gradio front end. The input components are listed in the same order as
# demo()'s parameters: threshold, image, prompt, num_inference_steps.
_demo_inputs = [
    gr.Slider(minimum=0.1, maximum=1, step=0.05,
              label="Selfie segmentation threshold", value=0.8),
    gr.Image(type='pil', label="Upload your selfie"),
    gr.Text(value="a photo of the Eiffel tower on the right side",
            label="Background description"),
    gr.Slider(minimum=5, maximum=100, step=5,
              label="Diffusion inference steps", value=50),
]
_demo_outputs = [gr.Image(label="Invent yourself a life :)")]

iface = gr.Interface(fn=demo, inputs=_demo_inputs, outputs=_demo_outputs)

# Start the web UI when this module runs. To listen on all interfaces instead,
# use e.g. iface.launch(server_name="0.0.0.0", server_port=6443).
iface.launch()