#!/usr/bin/python3
# -*- coding: utf-8 -*-
import os
import subprocess
import sys
from importlib import metadata

# This demo uses the Gradio 3.x API (e.g. ``demo.queue(concurrency_count=...)``,
# which Gradio 4 removed), so make sure the pinned version is present before
# ``gradio`` is imported further down.
_REQUIRED_GRADIO_VERSION = "3.50.0"

try:
    # Reads installed-package metadata only; does not import gradio itself.
    _installed_gradio_version = metadata.version("gradio")
except metadata.PackageNotFoundError:
    _installed_gradio_version = None

if _installed_gradio_version != _REQUIRED_GRADIO_VERSION:
    print("Installing correct gradio version...")
    # Invoke pip through the running interpreter (not whatever ``pip`` is first
    # on PATH) so the package lands in the environment this script imports from.
    # ``pip install pkg==X`` replaces any other installed version, so a separate
    # uninstall step is unnecessary.
    _pip_result = subprocess.run(
        [sys.executable, "-m", "pip", "install", "gradio==" + _REQUIRED_GRADIO_VERSION],
        check=False,
    )
    if _pip_result.returncode != 0:
        # Best effort: keep going, but do not claim success.
        print(
            "WARNING: pip exited with status %d; gradio==%s may not be installed."
            % (_pip_result.returncode, _REQUIRED_GRADIO_VERSION)
        )
    else:
        print("Installing Finished!")

# ---------------------------------------------------------------------------
# Runtime imports and Segment-Anything (SAM) setup.
# gradio is imported here, after the pinned-version install step at the top.
# ---------------------------------------------------------------------------
import gradio as gr
import os
import cv2
from PIL import Image
import numpy as np
from segment_anything import SamPredictor, sam_model_registry
import torch
from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
import random

# Segment-Anything predictor placed on the GPU at import time.
# NOTE: despite the `mobile_` prefix this loads the full ViT-H SAM model
# (`vit_h`, checkpoint sam_vit_h_4b8939.pth), not MobileSAM.
# NOTE(review): `mobile_predictor`, `colors` and `markers` appear unused in this
# file (the point-selection state in the UI is commented out), so this model may
# only cost GPU memory and startup time -- confirm before removing.
mobile_sam = sam_model_registry['vit_h'](checkpoint='data/ckpt/sam_vit_h_4b8939.pth').to("cuda")
mobile_sam.eval()
mobile_predictor = SamPredictor(mobile_sam)
# Presumably an RGB colour and a cv2 marker id per click label (red / green) for
# drawing selected points -- TODO confirm; not referenced in the visible code.
colors = [(255, 0, 0), (0, 255, 0)]
markers = [1, 5]

# - - - - - examples  - - - - -  #
# One row per example: [input image path, prompt, mask path, row index,
# an empty list (presumably selected points -- unused here),
# [pre-rendered result image]]. The result images are opened at import time.
image_examples = [
    [
        "examples/brushnet/src/" + image_file,
        example_prompt,
        "examples/brushnet/src/" + mask_file,
        row_index,
        [],
        [Image.open("examples/brushnet/src/" + result_file)],
    ]
    for row_index, (image_file, example_prompt, mask_file, result_file) in enumerate((
        ("test_image.jpg", "A beautiful cake on the table", "test_mask.jpg", "test_result.png"),
        ("example_1.jpg", "A man in Chinese traditional clothes", "example_1_mask.jpg", "example_1_result.png"),
        ("example_3.jpg", "a cut toy on the table", "example_3_mask.jpg", "example_3_result.png"),
        ("example_4.jpeg", "a car driving in the wild", "example_4_mask.jpg", "example_4_result.png"),
        ("example_5.jpg", "a charming woman wearing dress standing in the dark forest", "example_5_mask.jpg", "example_5_result.png"),
    ))
]


# Choose the base Stable Diffusion model here: a local diffusers-format
# checkpoint directory. The commented line is an alternative Hub model id.
base_model_path = "data/ckpt/realisticVisionV60B1_v51VAE"
# base_model_path = "runwayml/stable-diffusion-v1-5"

# BrushNet checkpoint path (judging by its name, the segmentation-mask variant).
brushnet_path = "data/ckpt/segmentation_mask_brushnet_ckpt"

# Both BrushNet and the base pipeline are loaded in half precision (float16).
brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch.float16)
# NOTE(review): low_cpu_mem_usage=False disables diffusers' low-memory loading
# path; the reason is not documented here -- confirm it is still required.
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
    base_model_path, brushnet=brushnet, torch_dtype=torch.float16, low_cpu_mem_usage=False
)

# Speed up the diffusion process: replace the default scheduler with UniPC,
# reusing the existing scheduler's configuration.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# Uncomment the next line to use xformers attention. Leave it commented out if
# xformers is not installed or when running Torch 2.0+ (which ships its own
# memory-efficient attention).
# pipe.enable_xformers_memory_efficient_attention()
# Memory optimization: sub-models stay on the CPU and are moved to the GPU only
# while they run. The pipeline is never moved with `.to("cuda")`; offloading
# handles device placement.
pipe.enable_model_cpu_offload()

def resize_image(input_image, resolution, multiple=64):
    """Rescale an image so its shorter side is about ``resolution`` pixels.

    Both output dimensions are rounded to the nearest multiple of ``multiple``
    (64 by default) and are never smaller than ``multiple``.

    Args:
        input_image: numpy image array, either ``(H, W)`` (grayscale) or
            ``(H, W, C)``.
        resolution: target length of the shorter side before rounding.
        multiple: grid size the output height and width are snapped to.

    Returns:
        The resized array as produced by ``cv2.resize``.
    """
    # Only the spatial dimensions are needed. Slicing (instead of unpacking
    # ``H, W, C``) also accepts 2-D grayscale arrays.
    H, W = map(float, input_image.shape[:2])
    k = float(resolution) / min(H, W)
    H *= k
    W *= k
    # Snap to the grid. Clamp so a tiny ``resolution`` cannot round a side down
    # to 0, which cv2.resize rejects.
    H = max(multiple, int(np.round(H / float(multiple))) * multiple)
    W = max(multiple, int(np.round(W / float(multiple))) * multiple)
    # Lanczos when enlarging, pixel-area averaging when shrinking.
    interpolation = cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA
    return cv2.resize(input_image, (W, H), interpolation=interpolation)

# once user upload an image, the original image is stored in `original_image`
def store_img(img):
    """Normalize an uploaded image array before it is used for inference.

    If the shorter side exceeds 1024 px, the image is downscaled with
    ``resize_image`` so the shorter side becomes 1024 px, with both sides
    snapped to multiples of 64. The (possibly resized) array is then checked
    so that its long side is at most twice its short side.

    Returns:
        The original array object when no resize was needed, otherwise the
        resized array.

    Raises:
        gr.Error: if the long/short side ratio is greater than 2.0.
    """
    rows, cols = img.shape[0], img.shape[1]
    if min(rows, cols) > 1024:
        # Keep very large uploads at a manageable working resolution.
        img = resize_image(img, 1024)
        rows, cols = img.shape[0], img.shape[1]
    long_side, short_side = max(rows, cols), min(rows, cols)
    if long_side * 1.0 / short_side > 2.0:
        raise gr.Error('image aspect ratio cannot be larger than 2.0')
    return img

def _as_rgb(arr):
    """Return ``arr`` as a 3-channel array: grayscale is replicated, alpha is dropped."""
    if arr.ndim == 2:
        return np.stack([arr] * 3, axis=-1)
    if arr.shape[-1] == 4:
        return arr[..., :3]
    return arr


def process(original_image, input_mask, prompt, negative_prompt, blended, invert_mask, control_strength, seed, randomize_seed, guidance_scale, num_inference_steps):
    """Run BrushNet inpainting on ``original_image`` inside ``input_mask``.

    Args:
        original_image: uploaded image as a numpy array (from ``gr.Image``).
        input_mask: uploaded mask as a numpy array; pixels whose channel sum
            exceeds 255 are inpainted (or kept, when ``invert_mask`` is set).
        prompt / negative_prompt: text conditioning for the pipeline.
        blended: if true, paste the generated content back into the original
            using a Gaussian-blurred mask edge.
        invert_mask: swap masked and unmasked regions.
        control_strength: BrushNet conditioning scale.
        seed / randomize_seed: fixed RNG seed, or draw a random one.
        guidance_scale: classifier-free guidance scale.
        num_inference_steps: number of denoising steps.

    Returns:
        A list of two generated images (the pipeline is run with the prompt
        duplicated), blended back into the original when ``blended`` is set.

    Raises:
        gr.Error: for missing inputs, a bad aspect ratio (via ``store_img``), or
            blurred blending requested with ``control_strength < 1.0``.
    """
    if original_image is None:
        raise gr.Error('Please upload the input image')
    if input_mask is None:
        raise gr.Error("Please upload a white-black Mask image")
    # Resize/validate the image and mask. Normalizing to 3 channels first keeps
    # grayscale or RGBA uploads from breaking the channel-sum threshold below
    # and the broadcasting in the blend step.
    original_image = store_img(_as_rgb(original_image))
    input_mask = store_img(_as_rgb(input_mask))

    H, W = original_image.shape[:2]
    original_mask = cv2.resize(input_mask, (W, H))

    if invert_mask:
        original_mask = 255 - original_mask
    # Binary mask of shape (H, W, 1): 1 where R+G+B > 255 (mean intensity > 85).
    mask = 1. * (original_mask.sum(-1) > 255)[:, :, np.newaxis]
    masked_image = original_image * (1 - mask)
    init_image = Image.fromarray(masked_image.astype(np.uint8)).convert("RGB")
    # The raw (non-binarized) mask is what gets passed to the pipeline.
    mask_image = Image.fromarray(original_mask.astype(np.uint8)).convert("RGB")

    # Reject an invalid option combination *before* the expensive diffusion run.
    # Previously this check fired only after both samples had been generated.
    if blended and control_strength < 1.0:
        raise gr.Error('Using blurred blending with control strength less than 1.0 is not allowed')

    # Slider values may arrive as floats; manual_seed and the step count need ints.
    used_seed = random.randint(0, 2147483647) if randomize_seed else int(seed)
    generator = torch.Generator("cuda").manual_seed(used_seed)
    image = pipe(
        [prompt] * 2,
        init_image,
        mask_image,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        generator=generator,
        brushnet_conditioning_scale=float(control_strength),
        negative_prompt=[negative_prompt] * 2,
    ).images

    if not blended:
        return image

    # Soften the mask edge: blur the 2-D mask, then take the union of the hard
    # and blurred masks so the inpainted region itself stays fully generated.
    mask_blurred = cv2.GaussianBlur(mask[:, :, 0] * 255, (21, 21), 0) / 255
    mask_blurred = mask_blurred[:, :, np.newaxis]
    mask = 1 - (1 - mask) * (1 - mask_blurred)
    blended_image = []
    for image_i in image:
        image_np = np.array(image_i)
        if image_np.shape[:2] != (H, W):
            # The pipeline may return a slightly different size than the input
            # (e.g. rounded to its latent grid); match the original so the
            # arrays broadcast.
            image_np = cv2.resize(image_np, (W, H))
        image_pasted = original_image * (1 - mask) + image_np * mask
        blended_image.append(Image.fromarray(image_pasted.astype(image_np.dtype)))
    return blended_image

# ---------------------------------------------------------------------------
# Gradio UI: inputs and settings in the left column, results in the right.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    with gr.Row():
        # Left column: uploads, prompts and generation settings.
        with gr.Column():
            original_image = gr.Image(label="Original Image", type="numpy")
            input_mask = gr.Image(label="Mask Image", type="numpy")
            prompt = gr.Textbox(label="Prompt")
            negative_prompt = gr.Textbox(value='ugly, low quality', label="Negative Prompt")
            blended = gr.Checkbox(value=False, label="Blurred Blending")
            invert_mask = gr.Checkbox(value=False, label="Invert Mask")
            control_strength = gr.Slider(
                label="Control Strength", minimum=0, maximum=1.1, step=0.01, value=1,
            )
            seed = gr.Slider(
                label="Seed", minimum=0, maximum=2147483647, step=1, value=551793204,
            )
            randomize_seed = gr.Checkbox(value=False, label="Randomize Seed")
            guidance_scale = gr.Slider(
                label="Guidance Scale", minimum=1, maximum=12, step=0.1, value=7.5,
            )
            num_inference_steps = gr.Slider(
                label="Number of Inference Steps", minimum=1, maximum=50, step=1, value=50,
            )
            run_button = gr.Button("Run")

        # Right column: gallery that receives the list returned by `process`.
        with gr.Column():
            result_gallery = gr.Gallery(
                elem_id="gallery", label='Output', show_label=False, preview=True,
            )

    # Order must match the positional parameters of `process`.
    inputs = [
        original_image, input_mask, prompt, negative_prompt,
        blended, invert_mask, control_strength,
        seed, randomize_seed, guidance_scale, num_inference_steps,
    ]
    run_button.click(fn=process, inputs=inputs, outputs=[result_gallery])

# Process one job at a time; the queue holds at most one waiting request.
demo.queue(concurrency_count=1, max_size=1, api_open=True)
demo.launch(show_api=True)