##!/usr/bin/python3
# -*- coding: utf-8 -*-
import os, random, sys
import numpy as np
import requests
import torch
import spaces


import gradio as gr

from PIL import Image


from huggingface_hub import hf_hub_download, snapshot_download
from scipy.ndimage import binary_dilation, binary_erosion
from transformers import (LlavaNextProcessor, LlavaNextForConditionalGeneration, 
                        Qwen2_5_VLForConditionalGeneration, AutoProcessor)

from segment_anything import SamPredictor, build_sam, SamAutomaticMaskGenerator
from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
from diffusers.image_processor  import VaeImageProcessor


from app.src.vlm_pipeline import (
    vlm_response_editing_type, 
    vlm_response_object_wait_for_edit, 
    vlm_response_mask, 
    vlm_response_prompt_after_apply_instruction
)
from app.src.brushedit_all_in_one_pipeline import BrushEdit_Pipeline
from app.utils.utils import load_grounding_dino_model

from app.src.vlm_template import vlms_template
from app.src.base_model_template import base_models_template
from app.src.aspect_ratio_template import aspect_ratios

from openai import OpenAI
# base_openai_url = ""

#### Description ####
# HTML snippets for the demo page header; presumably rendered by the Gradio UI
# defined further down in this file.
logo = r"""
<center><img src='./assets/logo_brushedit.png' alt='BrushEdit logo' style="width:80px; margin-bottom:10px"></center>
"""
head = r"""
<div style="text-align: center;">
    <h1> BrushEdit: All-In-One Image Inpainting and Editing</h1>
    <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
        <a href='https://liyaowei-stu.github.io/project/BrushEdit/'><img src='https://img.shields.io/badge/Project_Page-BrushEdit-green' alt='Project Page'></a>
        <a href='https://arxiv.org/abs/2412.10316'><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
        <a href='https://github.com/TencentARC/BrushEdit'><img src='https://img.shields.io/badge/Code-Github-orange'></a>
        
    </div>
    </br>
</div>
"""
# The title links to the BrushEdit project page (same URL as the badge in `head`).
descriptions = r"""
Official Gradio Demo for <a href='https://liyaowei-stu.github.io/project/BrushEdit/'><b>BrushEdit: All-In-One Image Inpainting and Editing</b></a><br>
🧙 BrushEdit enables precise, user-friendly instruction-based image editing via an inpainting model.<br>
"""

# Usage guide shown to the user (HTML). Model names follow the VLM dropdown
# labels ("GPT4-o ...", default "Qwen2.5-VL-7B-Instruct").
instructions = r"""
Currently, we support two modes: <b>fully automated instruction-based editing</b> and <b>interactive instruction-based editing</b>.

🛠️ <b>Fully automated instruction-based editing</b>:
<ul>
    <li> ⭐️ <b>1.Choose Image: </b> Upload <img src="https://github.com/user-attachments/assets/f2dca1e6-31f9-4716-ae84-907f24415bac" alt="upload" style="display:inline; height:1em; vertical-align:middle;"> or select <img src="https://github.com/user-attachments/assets/de808f7d-c74a-44c7-9cbf-f0dbfc2c1abf" alt="example" style="display:inline; height:1em; vertical-align:middle;"> an image from the Examples. </li>
    <li> ⭐️ <b>2.Input ⌨️ Instructions: </b> Input the instructions (supports addition, deletion, and modification), e.g. remove xxx.</li>
    <li> ⭐️ <b>3.Run: </b> Click the <b>💫 Run</b> button to edit the image automatically.</li>
</ul>

🛠️ <b>Interactive instruction-based editing</b>:
<ul>
    <li> ⭐️ <b>1.Choose Image: </b> Upload <img src="https://github.com/user-attachments/assets/f2dca1e6-31f9-4716-ae84-907f24415bac" alt="upload" style="display:inline; height:1em; vertical-align:middle;"> or select <img src="https://github.com/user-attachments/assets/de808f7d-c74a-44c7-9cbf-f0dbfc2c1abf" alt="example" style="display:inline; height:1em; vertical-align:middle;"> an image from the Examples. </li>
    <li> ⭐️ <b>2.Fine Brushing: </b> Use a brush <img src="https://github.com/user-attachments/assets/c466c5cc-ac8f-4b4a-9bc5-04c4737fe1ef" alt="brush" style="display:inline; height:1em; vertical-align:middle;"> to outline the area you want to edit. You can also use the eraser <img src="https://github.com/user-attachments/assets/b6370369-b080-4550-b0d0-830ff22d9068" alt="eraser" style="display:inline; height:1em; vertical-align:middle;"> to restore it. </li>
    <li> ⭐️ <b>3.Input ⌨️ Instructions: </b> Input the instructions. </li>
    <li> ⭐️ <b>4.Run: </b> Click the <b>💫 Run</b> button to edit the image automatically. </li>
</ul>

<b> We strongly recommend using GPT-4o for reasoning. </b> After selecting GPT4-o as the VLM model, enter the API key and click the Submit and Verify button. If the output reports success, you can use GPT4-o normally. As a second choice, we recommend the Qwen2.5-VL model.

<b> We recommend zooming out in your browser for a better viewing range and experience. </b>

<b> For more detailed feature descriptions, see the bottom. </b>

⚠️ We recommend using GPT4-o. If you are using Qwen2.5-VL-7B and the result does not match your editing intention, you can copy the target prompt shown in the pop-up, modify it as you wish, enter it into <b>Input Target Prompt</b>, and re-run.

☕️ Have fun! 🎄 Wishing you a merry Christmas!
            """

# Feature reference shown to the user (HTML). "Control strength" matches the
# `control_strength` argument of process(); base models are configured in
# base_model_template.py (imported above as base_models_template).
tips = r"""
💡 <b>Some Tips</b>:
<ul>    
    <li> 🤠 After entering the instructions, you can click the <b>Generate Mask</b> button. The mask generated by the VLM will be displayed in the preview panel on the right side. </li>
    <li> 🤠 After generating the mask or drawing it with the brush, you can perform operations such as <b>randomization</b>, <b>dilation</b>, <b>erosion</b>, and <b>movement</b>. </li>
    <li> 🤠 After entering the instructions, you can click the <b>Generate Target Prompt</b> button. The target prompt will be displayed in the text box, and you can modify it according to your ideas. </li>
</ul>

💡 <b>Detailed Features</b>:
<ul>    
    <li> 🎨 <b>Aspect Ratio</b>: Select the aspect ratio of the image. To prevent OOM, 1024px is the maximum resolution.</li>
    <li> 🎨 <b>VLM Model</b>: Select the VLM model. We use preloaded models to save time. To use other VLM models, download them and uncomment the relevant lines in vlm_template.py from our GitHub repo. </li>
    <li> 🎨 <b>Generate Mask</b>: Generate a mask for the area that may need to be edited, according to the input instructions. </li>
    <li> 🎨 <b>Square/Circle Mask</b>: Generate a square or circular mask from the existing mask. (A coarse-grained mask leaves more room for editing imagination.) </li>
    <li> 🎨 <b>Invert Mask</b>: Invert the mask to generate a new mask. </li>
    <li> 🎨 <b>Dilation/Erosion Mask</b>: Expand or shrink the mask to include or exclude more areas. </li>
    <li> 🎨 <b>Move Mask</b>: Move the mask to a new position. </li>
    <li> 🎨 <b>Generate Target Prompt</b>: Generate a target prompt based on the input instructions. </li>
    <li> 🎨 <b>Target Prompt</b>: Description of the masked area. You can enter or modify it manually when the content generated by the VLM does not meet expectations. </li>
    <li> 🎨 <b>Blending</b>: Blend BrushNet's output with the original input to preserve the original image details in the unedited areas. (Turning it off works better for removal.) </li>
    <li> 🎨 <b>Control strength</b>: The intensity of editing and inpainting. </li>
</ul> 

💡 <b>Advanced Features</b>:
<ul>    
    <li> 🎨 <b>Base Model</b>: We use preloaded models to save time. To use other base models, download them and uncomment the relevant lines in base_model_template.py from our GitHub repo. </li>
    <li> 🎨 <b>Blending</b>: Blend BrushNet's output with the original input to preserve the original image details in the unedited areas. (Turning it off works better for removal.) </li>
    <li> 🎨 <b>Control strength</b>: The intensity of editing and inpainting. </li>
    <li> 🎨 <b>Num samples</b>: The number of samples to generate. </li>
    <li> 🎨 <b>Negative prompt</b>: The negative prompt for classifier-free guidance. </li>
    <li> 🎨 <b>Guidance scale</b>: The guidance scale for classifier-free guidance. </li>
</ul> 


"""



# Citation / contact footer (Markdown + HTML). The BibTeX author list must use a
# single "and" between names; "and and" yields an empty author entry.
citation = r"""
If BrushEdit is helpful, please ⭐ the <a href='https://github.com/TencentARC/BrushEdit' target='_blank'>Github Repo</a>. Thanks!
[![GitHub Stars](https://img.shields.io/github/stars/TencentARC/BrushEdit?style=social)](https://github.com/TencentARC/BrushEdit)
---
📝 **Citation**
<br>
If our work is useful for your research, please consider citing:
```bibtex
@misc{li2024brushedit,
  title={BrushEdit: All-In-One Image Inpainting and Editing}, 
  author={Yaowei Li and Yuxuan Bian and Xuan Ju and Zhaoyang Zhang and Junhao Zhuang and Ying Shan and Yuexian Zou and Qiang Xu},
  year={2024},
  eprint={2412.10316},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```
📧 **Contact**
<br>
If you have any questions, please feel free to reach out to me at <b>liyaowei@gmail.com</b>.
"""

# - - - - - examples  - - - - -  #
# One tuple per gallery example: (image path, instruction, seed, asset key, toggle).
# `toggle` becomes the 6th column of each example row; it is presumably the
# blending switch (it is False exactly for the removal examples) -- confirm
# against the gr.Examples inputs defined later in the file.
_EXAMPLE_ROWS = (
    ("./assets/frog/frog.jpeg", "add a magic hat on frog head.", 642087011, "frog", True),
    ("./assets/chinese_girl/chinese_girl.png", "replace the background to ancient China.", 648464818, "chinese_girl", True),
    ("./assets/angel_christmas/angel_christmas.png", "remove the deer.", 648464818, "angel_christmas", False),
    ("./assets/sunflower_girl/sunflower_girl.png", "add a wreath on head.", 648464818, "sunflower_girl", True),
    ("./assets/girl_on_sun/girl_on_sun.png", "add a butterfly fairy.", 648464818, "girl_on_sun", True),
    ("./assets/spider_man_rm/spider_man.png", "remove the christmas hat.", 642087011, "spider_man_rm", False),
    ("./assets/anime_flower/anime_flower.png", "remove the flower.", 642087011, "anime_flower", False),
    ("./assets/chenduling/chengduling.jpg", "replace the clothes to a delicated floral skirt.", 648464818, "chenduling", True),
    ("./assets/hedgehog_rp_bg/hedgehog.png", "make the hedgehog in Italy.", 648464818, "hedgehog_rp_bg", True),
)

# Each example row: [RGBA image, instruction, seed, key, key, toggle, False, VLM name].
# The asset key fills two columns; the 7th column and the VLM name are shared by all rows.
EXAMPLES = [
    [
        Image.open(image_path).convert("RGBA"),
        instruction,
        seed,
        key,
        key,
        toggle,
        False,
        "GPT4-o (Highly Recommended)",
    ]
    for image_path, instruction, seed, key, toggle in _EXAMPLE_ROWS
]

# Every example lives in ./assets/<key>/. For each key record: the input file
# name, the id embedded in its mask / masked-image / edited-image file names,
# and the index of the edited sample that is shown.
_EXAMPLE_ASSETS = {
    "frog": ("frog.jpeg", "f7b350de-6f2c-49e3-b535-995c486d78e7", 1),
    "chinese_girl": ("chinese_girl.png", "54759648-0989-48e0-bc82-f20e28b5ec29", 1),
    "angel_christmas": ("angel_christmas.png", "f15d9b45-c978-4e3d-9f5f-251e308560c3", 0),
    "sunflower_girl": ("sunflower_girl.png", "99cc50b4-7dc4-4de5-8748-ec10772f0317", 3),
    "girl_on_sun": ("girl_on_sun.png", "264eac8b-8b65-479c-9755-020a60880c37", 0),
    "spider_man_rm": ("spider_man.png", "a5d410e6-8e8d-432f-8144-defbc3e1eae9", 0),
    "anime_flower": ("anime_flower.png", "37553172-9b38-4727-bf2e-37d7e2b93461", 2),
    "chenduling": ("chengduling.jpg", "68e3ff6f-da07-4b37-91df-13d6eed7b997", 0),
    "hedgehog_rp_bg": ("hedgehog.png", "db7f8bf8-8349-46d3-b14e-43d67fbe25d3", 3),
}

INPUT_IMAGE_PATH = {
    key: f"./assets/{key}/{file_name}"
    for key, (file_name, _, _) in _EXAMPLE_ASSETS.items()
}
MASK_IMAGE_PATH = {
    key: f"./assets/{key}/mask_{asset_id}.png"
    for key, (_, asset_id, _) in _EXAMPLE_ASSETS.items()
}
MASKED_IMAGE_PATH = {
    key: f"./assets/{key}/masked_image_{asset_id}.png"
    for key, (_, asset_id, _) in _EXAMPLE_ASSETS.items()
}
OUTPUT_IMAGE_PATH = {
    key: f"./assets/{key}/image_edit_{asset_id}_{sample}.png"
    for key, (_, asset_id, sample) in _EXAMPLE_ASSETS.items()
}

# os.environ['GRADIO_TEMP_DIR'] = 'gradio_temp_dir'
# os.makedirs('gradio_temp_dir', exist_ok=True)

# Selectable names, in the order of the imported templates, plus the
# defaults used at start-up.
VLM_MODEL_NAMES = [*vlms_template]
DEFAULT_VLM_MODEL_NAME = "Qwen2.5-VL-7B-Instruct (Default)"
BASE_MODELS = [*base_models_template]
DEFAULT_BASE_MODEL = "realisticVision (Default)"

ASPECT_RATIO_LABELS = [*aspect_ratios]
DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0]  # first entry of aspect_ratios


## init device
# Prefer CUDA, then Apple MPS (macOS only), then CPU.
try:
    if torch.cuda.is_available():
        device = "cuda"
    elif sys.platform == "darwin" and torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
except Exception:
    # Probing is best-effort: if a backend query itself raises, fall back to CPU.
    # Unlike a bare ``except:``, this still lets KeyboardInterrupt / SystemExit through.
    device = "cpu"

# ## init torch dtype
# if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
#     torch_dtype = torch.bfloat16
# else:
#     torch_dtype = torch.float16

# if device == "mps":
#     torch_dtype = torch.float16

# All models below are loaded in half precision regardless of `device`.
torch_dtype = torch.float16



# download hf models
# Fetch the TencentARC/BrushEdit Hub snapshot into ./models/ on first start-up.
# The download is skipped whenever that directory already exists (it is not
# checked for completeness). HF_TOKEN, if set, is passed as the Hub token.
BrushEdit_path = "models/"
if not os.path.exists(BrushEdit_path):
    BrushEdit_path = snapshot_download(
        repo_id="TencentARC/BrushEdit",
        local_dir=BrushEdit_path,
        token=os.getenv("HF_TOKEN"),
    )

## init default VLM
# vlms_template maps a display name to (type, local path, processor, model);
# "" in the processor/model slots means the model was not preloaded.
vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[DEFAULT_VLM_MODEL_NAME]
if vlm_processor != "" and vlm_model != "":
    vlm_model.to(device)
else:
    # Raised at import time, so start-up aborts if the default VLM is not preloaded.
    raise gr.Error("Please Download default VLM model "+ DEFAULT_VLM_MODEL_NAME +" first.")


## init base model
# Checkpoint locations inside the BrushEdit snapshot directory.
base_model_path = os.path.join(BrushEdit_path, "base_model/realisticVisionV60B1_v51VAE")
brushnet_path = os.path.join(BrushEdit_path, "brushnetX")
sam_path = os.path.join(BrushEdit_path, "sam/sam_vit_h_4b8939.pth")
groundingdino_path = os.path.join(BrushEdit_path, "grounding_dino/groundingdino_swint_ogc.pth")


# input brushnetX ckpt path
# Load the BrushNet (brushnetX) weights once and attach them to the default
# realisticVision base model. update_base_model() later rebuilds `pipe` for
# other base models while reusing this same `brushnet` object.
brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch_dtype)
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
        base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
    )
# speed up diffusion process with faster scheduler and memory optimization
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# remove following line if xformers is not installed or when using Torch 2.0.
# pipe.enable_xformers_memory_efficient_attention()
# `pipe` is never moved to `device` explicitly; placement is left to diffusers'
# model CPU offloading.
pipe.enable_model_cpu_offload()


## init SAM
# A single SAM model (checkpoint sam_vit_h_4b8939.pth) backs both the
# prompt-driven predictor and the automatic mask generator.
sam = build_sam(checkpoint=sam_path)
sam.to(device=device)
sam_predictor = SamPredictor(sam)
sam_automask_generator = SamAutomaticMaskGenerator(sam)

## init groundingdino_model
# GroundingDINO (Swin-T OGC config). Together with the SAM objects above it is
# passed to vlm_response_mask() in process() to build masks from the VLM output.
config_file = 'app/utils/GroundingDINO_SwinT_OGC.py'
groundingdino_model = load_grounding_dino_model(config_file, groundingdino_path, device=device)

## Ordinary function
def crop_and_resize(image: Image.Image, 
                    target_width: int, 
                    target_height: int) -> Image.Image:
    """
    Center-crop ``image`` to the target aspect ratio, then resize it.

    The crop removes equal amounts from both sides of the dimension that is
    too long, so the kept region has the aspect ratio
    ``target_width / target_height``. That region is then resized to exactly
    ``(target_width, target_height)`` with nearest-neighbour resampling.

    Args:
        image (Image.Image): Input PIL image.
        target_width (int): Width of the output image in pixels.
        target_height (int): Height of the output image in pixels.

    Returns:
        Image.Image: The cropped and resized image.
    """
    src_w, src_h = image.size
    wanted_ratio = target_width / target_height

    if src_w / src_h > wanted_ratio:
        # Source is too wide: keep the full height and trim left/right equally.
        keep_w = int(src_h * wanted_ratio)
        left = (src_w - keep_w) / 2
        box = (left, 0, left + keep_w, src_h)
    else:
        # Source is too tall (or already matches): keep the full width, trim top/bottom.
        keep_h = int(src_w / wanted_ratio)
        top = (src_h - keep_h) / 2
        box = (0, top, src_w, top + keep_h)

    return image.crop(box).resize((target_width, target_height), Image.NEAREST)


## Ordinary function
def resize(image: Image.Image,
           target_width: int,
           target_height: int) -> Image.Image:
    """
    Resize an image to exactly ``target_width`` x ``target_height``.

    Unlike ``crop_and_resize``, nothing is cropped and the aspect ratio is NOT
    preserved: the image is stretched or squeezed to the requested size.
    Nearest-neighbour resampling is used, so no new (interpolated) pixel values
    are introduced. This matters because ``process`` also resizes masks with
    this function.

    Args:
        image (Image.Image): Input PIL image.
        target_width (int): Width of the output image in pixels.
        target_height (int): Height of the output image in pixels.

    Returns:
        Image.Image: The resized image.
    """
    return image.resize((target_width, target_height), Image.NEAREST)


def move_mask_func(mask, direction, units):
    """
    Shift a mask by ``units`` pixels, filling the vacated pixels with False.

    Args:
        mask (np.ndarray): mask whose non-zero pixels are selected. It is squeezed
            to 2-D first, e.g. an (H, W, 1) array becomes (H, W).
        direction (str): 'up', 'down', 'left' or 'right'. Any other value yields
            an all-False mask.
        units (int): shift distance in pixels. Negative values are treated as 0
            (no movement). Shifts at or beyond the mask size move everything out
            of the frame and yield an all-False mask.

    Returns:
        np.ndarray: boolean 2-D mask with the shape of the squeezed input.
    """
    binary_mask = mask.squeeze() > 0
    rows, cols = binary_mask.shape
    moved_mask = np.zeros_like(binary_mask, dtype=bool)

    # Clamp the shift into [0, extent]. Without this, a shift between the mask
    # size and twice the mask size produced mismatched slice shapes (ValueError),
    # and negative shifts crashed or smeared a single row/column for 'up'/'left'.
    units = max(0, int(units))
    row_shift = min(units, rows)
    col_shift = min(units, cols)

    if direction == 'down':
        moved_mask[row_shift:, :] = binary_mask[:rows - row_shift, :]
    elif direction == 'up':
        moved_mask[:rows - row_shift, :] = binary_mask[row_shift:, :]
    elif direction == 'right':
        moved_mask[:, col_shift:] = binary_mask[:, :cols - col_shift]
    elif direction == 'left':
        moved_mask[:, :cols - col_shift] = binary_mask[:, col_shift:]

    return moved_mask


def random_mask_func(mask, dilation_type='square', dilation_size=20):
    """
    Derive a new mask from ``mask`` by dilation, erosion, or a bounding shape.

    Args:
        mask (np.ndarray): mask whose non-zero pixels mark the selected region.
            It is squeezed to 2-D before processing, e.g. an (H, W, 1) array.
        dilation_type (str): one of
            'square_dilation' / 'square_erosion' -- grow / shrink the region with a
            ``dilation_size`` x ``dilation_size`` square structuring element;
            'bounding_box' -- fill the axis-aligned bounding box of the region;
            'bounding_ellipse' -- fill the ellipse inscribed in that bounding box.
            Note that the default 'square' is not one of these and raises ValueError.
        dilation_size (int): side length of the square structuring element.

    Returns:
        np.ndarray: uint8 mask of shape (H, W, 1) with values 0 / 255. For the
        bounding shapes, ``mask`` itself is returned unchanged when it selects
        no pixels.

    Raises:
        ValueError: if ``dilation_type`` is not supported.
    """
    binary_mask = mask.squeeze() > 0

    if dilation_type == 'square_dilation':
        structure = np.ones((dilation_size, dilation_size), dtype=bool)
        dilated_mask = binary_dilation(binary_mask, structure=structure)
    elif dilation_type == 'square_erosion':
        structure = np.ones((dilation_size, dilation_size), dtype=bool)
        dilated_mask = binary_erosion(binary_mask, structure=structure)
    elif dilation_type in ('bounding_box', 'bounding_ellipse'):
        rows, cols = np.where(binary_mask)
        if len(rows) == 0:
            return mask  # nothing selected: return the original mask unchanged

        min_row, max_row = rows.min(), rows.max()
        min_col, max_col = cols.min(), cols.max()

        if dilation_type == 'bounding_box':
            dilated_mask = np.zeros_like(binary_mask, dtype=bool)
            dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True
        else:
            center_x = (min_col + max_col) // 2
            center_y = (min_row + max_row) // 2
            # Clamp the semi-axes to >= 1. A one- or two-pixel-thin region would
            # otherwise give a zero semi-axis, turning the test into 0/0 = NaN
            # and producing an empty mask.
            a = max((max_col - min_col) // 2, 1)  # horizontal semi-axis
            b = max((max_row - min_row) // 2, 1)  # vertical semi-axis
            # Use the squeezed shape so this also works when mask is already 2-D.
            y, x = np.ogrid[:binary_mask.shape[0], :binary_mask.shape[1]]
            dilated_mask = ((x - center_x) ** 2 / a ** 2 + (y - center_y) ** 2 / b ** 2) <= 1
    else:
        # Previously the ValueError was constructed but never raised, which
        # surfaced as an UnboundLocalError on `dilated_mask` below.
        raise ValueError(
            "dilation_type must be one of 'square_dilation', 'square_erosion', "
            "'bounding_box' or 'bounding_ellipse'"
        )

    # Back to the (H, W, 1) uint8 0/255 layout used for masks in this file.
    return np.uint8(dilated_mask[:, :, np.newaxis]) * 255


## Gradio component function
# Hugging Face Hub repositories used as a download fallback when a VLM is
# neither preloaded in vlms_template nor present at its configured local path.
_LLAVA_NEXT_HUB_IDS = {
    "llava-v1.6-mistral-7b-hf (Preload)": "llava-hf/llava-v1.6-mistral-7b-hf",
    "llama3-llava-next-8b-hf (Preload)": "llava-hf/llama3-llava-next-8b-hf",
    "llava-v1.6-vicuna-13b-hf (Preload)": "llava-hf/llava-v1.6-vicuna-13b-hf",
    "llava-v1.6-34b-hf (Preload)": "llava-hf/llava-v1.6-34b-hf",
    "llava-next-72b-hf (Preload)": "llava-hf/llava-next-72b-hf",
}
_QWEN2_VL_HUB_IDS = {
    "Qwen2.5-VL-7B-Instruct (Default)": "Qwen/Qwen2.5-VL-7B-Instruct",
}


def update_vlm_model(vlm_name):
    """
    Make ``vlm_name`` the active vision-language model (globals
    ``vlm_model`` / ``vlm_processor``).

    The current model reference is dropped, then the ``vlms_template`` entry for
    ``vlm_name`` is used. For "llava-next" and "qwen2-vl" models the lookup
    order is:

    1. a preloaded model, which is moved to ``device``;
    2. weights at the configured local path;
    3. a Hub download for the names in ``_LLAVA_NEXT_HUB_IDS`` /
       ``_QWEN2_VL_HUB_IDS``.

    "openai" entries need no loading here; the client is created by
    ``submit_GPT4o_KEY``.

    Args:
        vlm_name (str): a key of ``vlms_template``.

    Returns:
        ``vlm_model_dropdown`` when a preloaded model was reused, otherwise "success".

    Raises:
        gr.Error: if a llava-next / qwen2-vl model is not preloaded, not found
            locally and not downloadable by name. Previously this case returned
            "success" with an unusable (empty-string) model; raising matches
            update_base_model().
    """
    global vlm_model, vlm_processor
    if vlm_model is not None:
        del vlm_model
        torch.cuda.empty_cache()

    vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[vlm_name]

    ## we recommend using preload models, otherwise it will take a long time to download the model. you can edit the code via vlm_template.py
    if vlm_type == "llava-next":
        processor_cls = LlavaNextProcessor
        model_cls = LlavaNextForConditionalGeneration
        hub_ids = _LLAVA_NEXT_HUB_IDS
    elif vlm_type == "qwen2-vl":
        processor_cls = AutoProcessor
        model_cls = Qwen2_5_VLForConditionalGeneration
        hub_ids = _QWEN2_VL_HUB_IDS
    else:
        # "openai" (configured later via submit_GPT4o_KEY) and any other type.
        return "success"

    # "" in the template slots means "not preloaded".
    if vlm_processor != "" and vlm_model != "":
        vlm_model.to(device)
        return vlm_model_dropdown

    if os.path.exists(vlm_local_path):
        source = vlm_local_path
    elif vlm_name in hub_ids:
        source = hub_ids[vlm_name]
    else:
        raise gr.Error(f"The VLM model {vlm_name} is not preloaded and was not found at {vlm_local_path}")

    vlm_processor = processor_cls.from_pretrained(source)
    vlm_model = model_cls.from_pretrained(source, torch_dtype=torch_dtype, device_map=device)
    return "success"


def update_base_model(base_model_name):
    """
    Replace the global BrushNet pipeline ``pipe`` with the base model
    ``base_model_name``.

    A preloaded pipeline from ``base_models_template`` is moved to ``device``.
    Otherwise a new StableDiffusionBrushNetPipeline is built from the template's
    local path, reusing the shared ``brushnet``, with model CPU offloading.

    Args:
        base_model_name (str): a key of ``base_models_template``.

    Returns:
        str: "success".

    Raises:
        gr.Error: if the model is not preloaded and its local path does not exist.
    """
    global pipe
    ## we recommend using preload models, otherwise it will take a long time to download the model. you can edit the code via base_model_template.py
    if pipe is not None:
        # Release the current pipeline before switching.
        del pipe
        torch.cuda.empty_cache()

    model_dir, pipe = base_models_template[base_model_name]

    if pipe != "":
        pipe.to(device)
        return "success"

    if not os.path.exists(model_dir):
        raise gr.Error(f"The base model {base_model_name} does not exist")

    pipe = StableDiffusionBrushNetPipeline.from_pretrained(
        model_dir, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
    )
    # pipe.enable_xformers_memory_efficient_attention()
    pipe.enable_model_cpu_offload()
    return "success"


def submit_GPT4o_KEY(GPT4o_KEY):
    """
    Switch the active VLM to the OpenAI API and verify ``GPT4o_KEY`` with a
    one-message test request to "gpt-4o-2024-08-06".

    Args:
        GPT4o_KEY (str): OpenAI API key entered by the user.

    Returns:
        tuple[str, str]: a status message ("Success, <reply>" or
        "Invalid GPT4o API Key") and the dropdown value
        "GPT4-o (Highly Recommended)".
    """
    global vlm_model, vlm_processor
    if vlm_model is not None:
        # Rebind instead of `del`: if OpenAI(...) raises below, the global must
        # still exist because update_vlm_model() and process() read `vlm_model`.
        vlm_model = None
        torch.cuda.empty_cache()
    vlm_processor = ""
    try:
        vlm_model = OpenAI(api_key=GPT4o_KEY)
        response = vlm_model.chat.completions.create(
                model="gpt-4o-2024-08-06",
                messages=[
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": "Say this is a test"}
                ]
            )
        # `content` can be None; avoid a TypeError that would be misreported
        # as an invalid key.
        response_str = response.choices[0].message.content or ""

        return "Success, " + response_str, "GPT4-o (Highly Recommended)"
    except Exception:
        # UI boundary: any failure (bad key, network, API error) is reported as
        # an invalid key, as before.
        return "Invalid GPT4o API Key", "GPT4-o (Highly Recommended)"
    

    
@spaces.GPU(duration=180)
def process(input_image, 
    original_image, 
    original_mask, 
    prompt, 
    negative_prompt, 
    control_strength, 
    seed, 
    randomize_seed, 
    guidance_scale, 
    num_inference_steps,
    num_samples,
    blending,
    category,
    target_prompt,
    resize_default,
    aspect_ratio_name,
    invert_mask_state):
    """
    Gradio handler that runs one BrushEdit edit and returns the results.

    Steps, in order:
      1. Take the source image from ``original_image`` or, if that is None, from
         ``input_image["background"]``.
      2. Read the brushed mask from band 3 (alpha) of ``input_image["layers"][0]``.
      3. Resize the image and masks according to ``aspect_ratio_name`` and
         ``resize_default``. With ``resize_default`` the short side is scaled to 640.
      4. Choose the mask. ``original_mask`` is kept when ``invert_mask_state`` is
         set or nothing was brushed; otherwise the brushed mask replaces it.
      5. Query the VLM only for what the caller did not supply: the edit
         category, the mask (VLM + GroundingDINO + SAM) and the target prompt.
      6. Run BrushEdit_Pipeline under ``torch.autocast(device)``.

    Args:
        input_image: editor value, a dict with "background" (``.convert("RGB")``
            is called on it) and "layers". The first layer must have at least 4
            bands because band 3 is read as the mask.
        original_image: source image as a NumPy array (``.shape`` is used), or None.
        original_mask: previously produced mask array, or None.
        prompt: editing instruction for the VLM. May be empty only if
            ``target_prompt`` is given.
        negative_prompt, control_strength, guidance_scale, num_inference_steps,
        num_samples, blending: forwarded unchanged to BrushEdit_Pipeline.
        seed: generator seed, used unless ``randomize_seed`` is true.
        randomize_seed: draw a fresh seed in [0, 2147483647] instead of ``seed``.
        category: edit category. When None, it may be requested from the VLM.
        target_prompt: if non-empty, used directly as the inpainting prompt.
        resize_default: rescale so the short side is 640 px.
        aspect_ratio_name: key of ``aspect_ratios``. An entry with an empty width
            or height keeps the input size.
        invert_mask_state: when true, ``original_mask`` wins over the brushed mask.

    Returns:
        tuple: (images from BrushEdit_Pipeline, [mask_image], [masked_image],
        prompt, '', False).

    Raises:
        gr.Error: if no image or instruction is given, or if any VLM call fails.
    """
    # 1. Source image: fall back to the editor background when no stored image exists.
    if original_image is None:
        if input_image is None:
            raise gr.Error('Please upload the input image')
        else:
            image_pil = input_image["background"].convert("RGB")
            original_image = np.array(image_pil)
    # An instruction is only required when no explicit target prompt is given.
    if prompt is None or prompt == "":
        if target_prompt is None or target_prompt == "":
            raise gr.Error("Please input your instructions, e.g., remove the xxx")
    
    # 2. Brushed mask = alpha band of the first editor layer.
    # NOTE(review): input_image is dereferenced here even when original_image was
    # supplied, so a None input_image would raise TypeError at this point.
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)
    # 3. Output size. An empty width/height means "keep the input size".
    output_w, output_h = aspect_ratios[aspect_ratio_name]
    if output_w == "" or output_h == "":    
        output_h, output_w = original_image.shape[:2]

        if resize_default:
            # Scale so that the short side becomes 640 px.
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
            original_image = np.array(original_image)
            # NOTE(review): input_mask comes from np.asarray above, so this check is always true.
            if input_mask is not None:
                input_mask = resize(Image.fromarray(np.squeeze(input_mask).astype(np.uint8)), target_width=int(output_w), target_height=int(output_h))
                input_mask = np.array(input_mask)
            if original_mask is not None:
                original_mask = resize(Image.fromarray(np.squeeze(original_mask).astype(np.uint8)), target_width=int(output_w), target_height=int(output_h))
                original_mask = np.array(original_mask)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
            pass 
    else:
        # Fixed aspect ratio from the template, optionally rescaled to a 640 px short side.
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
        gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
        original_image = np.array(original_image)
        if input_mask is not None:
            input_mask = resize(Image.fromarray(np.squeeze(input_mask).astype(np.uint8)), target_width=int(output_w), target_height=int(output_h))
            input_mask = np.array(input_mask)
        if original_mask is not None:
            original_mask = resize(Image.fromarray(np.squeeze(original_mask).astype(np.uint8)), target_width=int(output_w), target_height=int(output_h))
            original_mask = np.array(original_mask)

    # 4. Mask selection. The self-assignments below are no-ops that keep original_mask.
    if invert_mask_state:
        original_mask = original_mask
    else:
        if input_mask.max() == 0:
            original_mask = original_mask
        else:
            original_mask = input_mask

    
    ## inpainting directly if target_prompt is not None
    # 5a. Ask the VLM for the edit category only when it is unknown and we do
    # not already have both a target prompt and a mask.
    if category is not None:
        pass
    elif target_prompt is not None and len(target_prompt) >= 1 and original_mask is not None:
        pass
    else:
        try:
            category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
        except Exception as e:
            raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
    

    # 5b. Use the existing mask (clipped to uint8) or build one from the VLM's
    # answer with GroundingDINO + SAM.
    if original_mask is not None:
        original_mask = np.clip(original_mask, 0, 255).astype(np.uint8)
    else:
        try:
            object_wait_for_edit = vlm_response_object_wait_for_edit(
                                                vlm_processor, 
                                                vlm_model, 
                                                original_image,
                                                category, 
                                                prompt,
                                                device)

            original_mask = vlm_response_mask(vlm_processor,
                                            vlm_model,
                                            category, 
                                            original_image, 
                                            prompt, 
                                            object_wait_for_edit, 
                                            sam,
                                            sam_predictor,
                                            sam_automask_generator,
                                            groundingdino_model,
                                            device).astype(np.uint8)
        except Exception as e:
            raise gr.Error("Please select the correct VLM model and input the correct API Key first!")

    # Ensure an (H, W, 1) mask for the pipeline.
    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]
    

    # 5c. Inpainting prompt: the user's target prompt, else ask the VLM.
    if target_prompt is not None and len(target_prompt) >= 1:
        prompt_after_apply_instruction = target_prompt
        
    else:
        try:
            prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
                                                                    vlm_processor, 
                                                                    vlm_model, 
                                                                    original_image,
                                                                    prompt,
                                                                    device)
        except Exception as e:
            raise gr.Error("Please select the correct VLM model and input the correct API Key first!")

    generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)


    # 6. Run the BrushEdit pipeline.
    with torch.autocast(device):
        image, mask_image, mask_np, init_image_np = BrushEdit_Pipeline(pipe, 
                                    prompt_after_apply_instruction,
                                    original_mask,
                                    original_image,
                                    generator,
                                    num_inference_steps,
                                    guidance_scale,
                                    control_strength,
                                    negative_prompt,
                                    num_samples,
                                    blending)
    # Preview of the input with the masked region zeroed out.
    original_image = np.array(init_image_np)
    masked_image = original_image * (1 - (mask_np>0))
    masked_image = masked_image.astype(np.uint8)
    masked_image = Image.fromarray(masked_image)
    # Save the images (optional)
    # import uuid
    # uuid = str(uuid.uuid4())
    # image[0].save(f"outputs/image_edit_{uuid}_0.png")
    # image[1].save(f"outputs/image_edit_{uuid}_1.png")
    # image[2].save(f"outputs/image_edit_{uuid}_2.png")
    # image[3].save(f"outputs/image_edit_{uuid}_3.png")
    # mask_image.save(f"outputs/mask_{uuid}.png")
    # masked_image.save(f"outputs/masked_image_{uuid}.png")
    gr.Info(f"Target Prompt: {prompt_after_apply_instruction}", duration=16)
    return image, [mask_image], [masked_image], prompt, '', False


def generate_target_prompt(input_image,
                           original_image,
                           prompt):
    """Ask the VLM for the target prompt obtained by applying ``prompt``.

    When ``original_image`` is a str (the example-loading case), the
    ``input_image`` value is sent to the VLM in its place.

    Returns:
        Whatever ``vlm_response_prompt_after_apply_instruction`` returns.
    """
    source_image = input_image if isinstance(original_image, str) else original_image
    return vlm_response_prompt_after_apply_instruction(
        vlm_processor,
        vlm_model,
        source_image,
        prompt,
        device,
    )


def process_mask(input_image,
    original_image,
    prompt,
    resize_default,
    aspect_ratio_name):
    """Build the editing mask and a masked preview for the current image.

    If the user painted on the editor canvas (non-zero alpha in
    ``input_image["layers"][0]``), that stroke is the mask. Otherwise the
    VLM pipeline classifies the instruction, picks the object to edit and
    segments it. Image and mask are then resized according to
    ``aspect_ratio_name`` / ``resize_default``.

    Args:
        input_image: editor value; ``input_image["layers"][0]`` is a PIL image
            whose alpha channel holds the user's strokes.
        original_image: image state (presumably an RGB numpy array; the code
            calls ``.shape`` and ``Image.fromarray`` on it), or a str for
            examples, in which case the editor background is used.
        prompt: the user's editing instruction.
        resize_default: if true, scale the output so its short side is 640 px.
        aspect_ratio_name: key into ``aspect_ratios``; an entry with an empty
            width or height keeps the source aspect ratio.

    Returns:
        ``([masked_image], [mask_image], mask as uint8 h,w,1, category)``.
        ``category`` is the result of ``vlm_response_editing_type``, or None
        when the mask was hand-drawn.

    Raises:
        gr.Error: if no image is uploaded, or no instruction is given.
    """
    if original_image is None:
        raise gr.Error('Please upload the input image')
    if prompt is None:
        raise gr.Error("Please input your instructions, e.g., remove the xxx")

    ## load mask: alpha channel of the first editor layer = user strokes
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.array(alpha_mask)

    # load example image
    if isinstance(original_image, str):
        # The editor background is a PIL image. Convert it to the same RGB
        # numpy layout used by the image state so that `.shape` and the array
        # arithmetic below work.
        background = input_image["background"]
        if isinstance(background, Image.Image):
            background = background.convert("RGB")
        original_image = np.array(background)

    if input_mask.max() == 0:
        # Nothing painted: the VLM must locate the region, which needs a real
        # instruction. A blank textbox yields "" rather than None.
        if not str(prompt).strip():
            raise gr.Error("Please input your instructions, e.g., remove the xxx")
        category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)

        object_wait_for_edit = vlm_response_object_wait_for_edit(vlm_processor,
                                                                vlm_model,
                                                                original_image,
                                                                category,
                                                                prompt,
                                                                device)
        # original mask: h,w,1 [0, 255]
        original_mask = vlm_response_mask(
                                vlm_processor,
                                vlm_model,
                                category,
                                original_image,
                                prompt,
                                object_wait_for_edit,
                                sam,
                                sam_predictor,
                                sam_automask_generator,
                                groundingdino_model,
                                device).astype(np.uint8)
    else:
        original_mask = input_mask.astype(np.uint8)
        category = None

    ## resize image and mask to the requested output size
    ratio_w, ratio_h = aspect_ratios[aspect_ratio_name]
    custom_ratio = ratio_w == "" or ratio_h == ""
    if custom_ratio:
        output_h, output_w = original_image.shape[:2]
    else:
        output_w, output_h = ratio_w, ratio_h
    if resize_default:
        scale_ratio = 640 / min(output_w, output_h)
        output_w = int(output_w * scale_ratio)
        output_h = int(output_h * scale_ratio)
    gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
    # With a custom ratio and no default resize, the source size is kept.
    if resize_default or not custom_ratio:
        target_w, target_h = int(output_w), int(output_h)
        original_image = np.array(resize(Image.fromarray(original_image),
                                         target_width=target_w, target_height=target_h))
        original_mask = np.array(resize(Image.fromarray(np.squeeze(original_mask)),
                                        target_width=target_w, target_height=target_h))

    if original_mask.ndim == 2:
        original_mask = original_mask[:, :, None]

    mask_image = Image.fromarray(original_mask.squeeze().astype(np.uint8)).convert("RGB")

    # Zero out the masked region of the image for the preview.
    masked_image = original_image * (1 - (original_mask > 0))
    masked_image = Image.fromarray(masked_image.astype(np.uint8))

    return [masked_image], [mask_image], original_mask.astype(np.uint8), category


def process_random_mask(input_image,
                        original_image,
                        original_mask,
                        resize_default,
                        aspect_ratio_name,
                        ):
    """Replace the current mask with a coarse, randomly shaped version of it.

    The working mask is the user's canvas strokes (alpha channel of
    ``input_image["layers"][0]``) when any were drawn, otherwise the
    ``original_mask`` state. Image and masks are first brought to the output
    size chosen by ``aspect_ratio_name`` / ``resize_default``. Then
    ``random_mask_func`` is called with a shape type drawn at random from
    'bounding_box' and 'bounding_ellipse'.

    Returns:
        ``([masked_image], [mask_image], new mask as uint8 h,w,1)``.

    Raises:
        gr.Error: if neither a drawn nor a stored mask is available.
    """
    strokes = np.asarray(input_image["layers"][0].split()[3])

    ratio_w, ratio_h = aspect_ratios[aspect_ratio_name]
    custom_ratio = ratio_w == "" or ratio_h == ""
    if custom_ratio:
        output_h, output_w = original_image.shape[:2]
    else:
        output_w, output_h = ratio_w, ratio_h
    if resize_default:
        factor = 640 / min(output_w, output_h)
        output_w = int(output_w * factor)
        output_h = int(output_h * factor)

    note = f"Output aspect ratio: {output_w}:{output_h}"
    if not custom_ratio:
        gr.Info(note)
    if resize_default or not custom_ratio:
        width_px, height_px = int(output_w), int(output_h)
        original_image = np.array(resize(Image.fromarray(original_image),
                                         target_width=width_px, target_height=height_px))
        strokes = np.array(resize(Image.fromarray(np.squeeze(strokes)),
                                  target_width=width_px, target_height=height_px))
        if original_mask is not None:
            original_mask = np.array(resize(Image.fromarray(np.squeeze(original_mask)),
                                            target_width=width_px, target_height=height_px))
    if custom_ratio:
        gr.Info(note)

    # Freshly drawn strokes win over the stored mask.
    working = strokes if strokes.max() != 0 else original_mask
    if working is None:
        raise gr.Error('Please generate mask first')
    if working.ndim == 2:
        working = working[:, :, None]

    shape_kind = np.random.choice(['bounding_box', 'bounding_ellipse'])
    new_mask = random_mask_func(working, shape_kind).squeeze()

    mask_image = Image.fromarray(new_mask.astype(np.uint8)).convert("RGB")
    kept = original_image * (1 - (new_mask[:, :, None] > 0))
    masked_image = Image.fromarray(kept.astype(original_image.dtype))

    return [masked_image], [mask_image], new_mask[:, :, None].astype(np.uint8)


def process_dilation_mask(input_image,
                          original_image,
                          original_mask,
                          resize_default,
                          aspect_ratio_name,
                          dilation_size=20):
    """Grow the current mask via ``random_mask_func`` ('square_dilation').

    The working mask is the user's canvas strokes (alpha channel of
    ``input_image["layers"][0]``) when any were drawn, otherwise the
    ``original_mask`` state. Image and masks are first brought to the output
    size chosen by ``aspect_ratio_name`` / ``resize_default``.

    Args:
        dilation_size: forwarded to ``random_mask_func`` as its third argument.

    Returns:
        ``([masked_image], [mask_image], new mask as uint8 h,w,1)``.

    Raises:
        gr.Error: if neither a drawn nor a stored mask is available.
    """
    strokes = np.asarray(input_image["layers"][0].split()[3])

    ratio_w, ratio_h = aspect_ratios[aspect_ratio_name]
    custom_ratio = ratio_w == "" or ratio_h == ""
    if custom_ratio:
        output_h, output_w = original_image.shape[:2]
    else:
        output_w, output_h = ratio_w, ratio_h
    if resize_default:
        factor = 640 / min(output_w, output_h)
        output_w = int(output_w * factor)
        output_h = int(output_h * factor)

    note = f"Output aspect ratio: {output_w}:{output_h}"
    if not custom_ratio:
        gr.Info(note)
    if resize_default or not custom_ratio:
        width_px, height_px = int(output_w), int(output_h)
        original_image = np.array(resize(Image.fromarray(original_image),
                                         target_width=width_px, target_height=height_px))
        strokes = np.array(resize(Image.fromarray(np.squeeze(strokes)),
                                  target_width=width_px, target_height=height_px))
        if original_mask is not None:
            original_mask = np.array(resize(Image.fromarray(np.squeeze(original_mask)),
                                            target_width=width_px, target_height=height_px))
    if custom_ratio:
        gr.Info(note)

    # Freshly drawn strokes win over the stored mask.
    working = strokes if strokes.max() != 0 else original_mask
    if working is None:
        raise gr.Error('Please generate mask first')
    if working.ndim == 2:
        working = working[:, :, None]

    # Kept as a random draw so the global NumPy RNG advances as before.
    shape_kind = np.random.choice(['square_dilation'])
    new_mask = random_mask_func(working, shape_kind, dilation_size).squeeze()

    mask_image = Image.fromarray(new_mask.astype(np.uint8)).convert("RGB")
    kept = original_image * (1 - (new_mask[:, :, None] > 0))
    masked_image = Image.fromarray(kept.astype(original_image.dtype))

    return [masked_image], [mask_image], new_mask[:, :, None].astype(np.uint8)


def process_erosion_mask(input_image,
                         original_image,
                         original_mask,
                         resize_default,
                         aspect_ratio_name,
                         dilation_size=20):
    """Shrink the current mask via ``random_mask_func`` ('square_erosion').

    The working mask is the user's canvas strokes (alpha channel of
    ``input_image["layers"][0]``) when any were drawn, otherwise the
    ``original_mask`` state. Image and masks are first brought to the output
    size chosen by ``aspect_ratio_name`` / ``resize_default``.

    Args:
        dilation_size: forwarded to ``random_mask_func`` as its third argument.

    Returns:
        ``([masked_image], [mask_image], new mask as uint8 h,w,1)``.

    Raises:
        gr.Error: if neither a drawn nor a stored mask is available.
    """
    strokes = np.asarray(input_image["layers"][0].split()[3])

    ratio_w, ratio_h = aspect_ratios[aspect_ratio_name]
    custom_ratio = ratio_w == "" or ratio_h == ""
    if custom_ratio:
        output_h, output_w = original_image.shape[:2]
    else:
        output_w, output_h = ratio_w, ratio_h
    if resize_default:
        factor = 640 / min(output_w, output_h)
        output_w = int(output_w * factor)
        output_h = int(output_h * factor)

    note = f"Output aspect ratio: {output_w}:{output_h}"
    if not custom_ratio:
        gr.Info(note)
    if resize_default or not custom_ratio:
        width_px, height_px = int(output_w), int(output_h)
        original_image = np.array(resize(Image.fromarray(original_image),
                                         target_width=width_px, target_height=height_px))
        strokes = np.array(resize(Image.fromarray(np.squeeze(strokes)),
                                  target_width=width_px, target_height=height_px))
        if original_mask is not None:
            original_mask = np.array(resize(Image.fromarray(np.squeeze(original_mask)),
                                            target_width=width_px, target_height=height_px))
    if custom_ratio:
        gr.Info(note)

    # Freshly drawn strokes win over the stored mask.
    working = strokes if strokes.max() != 0 else original_mask
    if working is None:
        raise gr.Error('Please generate mask first')
    if working.ndim == 2:
        working = working[:, :, None]

    # Kept as a random draw so the global NumPy RNG advances as before.
    shape_kind = np.random.choice(['square_erosion'])
    new_mask = random_mask_func(working, shape_kind, dilation_size).squeeze()

    mask_image = Image.fromarray(new_mask.astype(np.uint8)).convert("RGB")
    kept = original_image * (1 - (new_mask[:, :, None] > 0))
    masked_image = Image.fromarray(kept.astype(original_image.dtype))

    return [masked_image], [mask_image], new_mask[:, :, None].astype(np.uint8)


def move_mask_left(input_image,
                   original_image,
                   original_mask,
                   moving_pixels,
                   resize_default,
                   aspect_ratio_name):
    """Shift the current mask ``moving_pixels`` pixels to the left.

    The working mask is the user's canvas strokes (alpha channel of
    ``input_image["layers"][0]``) when any were drawn, otherwise the
    ``original_mask`` state. Image and masks are first brought to the output
    size chosen by ``aspect_ratio_name`` / ``resize_default``.

    Returns:
        ``([masked_image], [mask_image], moved mask as uint8 h,w,1 in {0, 255})``.

    Raises:
        gr.Error: if neither a drawn nor a stored mask is available.
    """
    input_mask = np.asarray(input_image["layers"][0].split()[3])

    ratio_w, ratio_h = aspect_ratios[aspect_ratio_name]
    custom_ratio = ratio_w == "" or ratio_h == ""
    if custom_ratio:
        output_h, output_w = original_image.shape[:2]
    else:
        output_w, output_h = ratio_w, ratio_h
    if resize_default:
        scale_ratio = 640 / min(output_w, output_h)
        output_w = int(output_w * scale_ratio)
        output_h = int(output_h * scale_ratio)
    gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
    # With a custom ratio and no default resize, the source size is kept.
    if resize_default or not custom_ratio:
        target_w, target_h = int(output_w), int(output_h)
        original_image = np.array(resize(Image.fromarray(original_image),
                                         target_width=target_w, target_height=target_h))
        input_mask = np.array(resize(Image.fromarray(np.squeeze(input_mask)),
                                     target_width=target_w, target_height=target_h))
        if original_mask is not None:
            original_mask = np.array(resize(Image.fromarray(np.squeeze(original_mask)),
                                            target_width=target_w, target_height=target_h))

    # Freshly drawn strokes win over the stored mask.
    if input_mask.max() != 0:
        original_mask = input_mask
    if original_mask is None:
        raise gr.Error('Please generate mask first')
    if original_mask.ndim == 2:
        original_mask = original_mask[:, :, None]

    moved_mask = move_mask_func(original_mask, 'left', int(moving_pixels)).squeeze()
    # Normalise to 0/255 whatever value range move_mask_func returns.
    binary_mask = np.where(moved_mask > 0, 255, 0).astype(np.uint8)
    mask_image = Image.fromarray(binary_mask).convert("RGB")

    masked_image = original_image * (1 - (moved_mask[:, :, None] > 0))
    masked_image = Image.fromarray(masked_image.astype(original_image.dtype))

    # Always return the moved mask as the new state. Previously the state was
    # only updated when moved_mask.max() <= 1, so a 0/255 result left the
    # un-moved mask in state while the preview showed the moved one.
    return [masked_image], [mask_image], binary_mask[:, :, None]


def move_mask_right(input_image,
                    original_image,
                    original_mask,
                    moving_pixels,
                    resize_default,
                    aspect_ratio_name):
    """Shift the current mask ``moving_pixels`` pixels to the right.

    The working mask is the user's canvas strokes (alpha channel of
    ``input_image["layers"][0]``) when any were drawn, otherwise the
    ``original_mask`` state. Image and masks are first brought to the output
    size chosen by ``aspect_ratio_name`` / ``resize_default``.

    Returns:
        ``([masked_image], [mask_image], moved mask as uint8 h,w,1 in {0, 255})``.

    Raises:
        gr.Error: if neither a drawn nor a stored mask is available.
    """
    input_mask = np.asarray(input_image["layers"][0].split()[3])

    ratio_w, ratio_h = aspect_ratios[aspect_ratio_name]
    custom_ratio = ratio_w == "" or ratio_h == ""
    if custom_ratio:
        output_h, output_w = original_image.shape[:2]
    else:
        output_w, output_h = ratio_w, ratio_h
    if resize_default:
        scale_ratio = 640 / min(output_w, output_h)
        output_w = int(output_w * scale_ratio)
        output_h = int(output_h * scale_ratio)
    gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
    # With a custom ratio and no default resize, the source size is kept.
    if resize_default or not custom_ratio:
        target_w, target_h = int(output_w), int(output_h)
        original_image = np.array(resize(Image.fromarray(original_image),
                                         target_width=target_w, target_height=target_h))
        input_mask = np.array(resize(Image.fromarray(np.squeeze(input_mask)),
                                     target_width=target_w, target_height=target_h))
        if original_mask is not None:
            original_mask = np.array(resize(Image.fromarray(np.squeeze(original_mask)),
                                            target_width=target_w, target_height=target_h))

    # Freshly drawn strokes win over the stored mask.
    if input_mask.max() != 0:
        original_mask = input_mask
    if original_mask is None:
        raise gr.Error('Please generate mask first')
    if original_mask.ndim == 2:
        original_mask = original_mask[:, :, None]

    moved_mask = move_mask_func(original_mask, 'right', int(moving_pixels)).squeeze()
    # Normalise to 0/255 whatever value range move_mask_func returns.
    binary_mask = np.where(moved_mask > 0, 255, 0).astype(np.uint8)
    mask_image = Image.fromarray(binary_mask).convert("RGB")

    masked_image = original_image * (1 - (moved_mask[:, :, None] > 0))
    masked_image = Image.fromarray(masked_image.astype(original_image.dtype))

    # Always return the moved mask as the new state. Previously the state was
    # only updated when moved_mask.max() <= 1, so a 0/255 result left the
    # un-moved mask in state while the preview showed the moved one.
    return [masked_image], [mask_image], binary_mask[:, :, None]


def move_mask_up(input_image,
                 original_image,
                 original_mask,
                 moving_pixels,
                 resize_default,
                 aspect_ratio_name):
    """Shift the current mask ``moving_pixels`` pixels upward.

    The working mask is the user's canvas strokes (alpha channel of
    ``input_image["layers"][0]``) when any were drawn, otherwise the
    ``original_mask`` state. Image and masks are first brought to the output
    size chosen by ``aspect_ratio_name`` / ``resize_default``.

    Returns:
        ``([masked_image], [mask_image], moved mask as uint8 h,w,1 in {0, 255})``.

    Raises:
        gr.Error: if neither a drawn nor a stored mask is available.
    """
    input_mask = np.asarray(input_image["layers"][0].split()[3])

    ratio_w, ratio_h = aspect_ratios[aspect_ratio_name]
    custom_ratio = ratio_w == "" or ratio_h == ""
    if custom_ratio:
        output_h, output_w = original_image.shape[:2]
    else:
        output_w, output_h = ratio_w, ratio_h
    if resize_default:
        scale_ratio = 640 / min(output_w, output_h)
        output_w = int(output_w * scale_ratio)
        output_h = int(output_h * scale_ratio)
    gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
    # With a custom ratio and no default resize, the source size is kept.
    if resize_default or not custom_ratio:
        target_w, target_h = int(output_w), int(output_h)
        original_image = np.array(resize(Image.fromarray(original_image),
                                         target_width=target_w, target_height=target_h))
        input_mask = np.array(resize(Image.fromarray(np.squeeze(input_mask)),
                                     target_width=target_w, target_height=target_h))
        if original_mask is not None:
            original_mask = np.array(resize(Image.fromarray(np.squeeze(original_mask)),
                                            target_width=target_w, target_height=target_h))

    # Freshly drawn strokes win over the stored mask.
    if input_mask.max() != 0:
        original_mask = input_mask
    if original_mask is None:
        raise gr.Error('Please generate mask first')
    if original_mask.ndim == 2:
        original_mask = original_mask[:, :, None]

    moved_mask = move_mask_func(original_mask, 'up', int(moving_pixels)).squeeze()
    # Normalise to 0/255 whatever value range move_mask_func returns.
    binary_mask = np.where(moved_mask > 0, 255, 0).astype(np.uint8)
    mask_image = Image.fromarray(binary_mask).convert("RGB")

    masked_image = original_image * (1 - (moved_mask[:, :, None] > 0))
    masked_image = Image.fromarray(masked_image.astype(original_image.dtype))

    # Always return the moved mask as the new state. Previously the state was
    # only updated when moved_mask.max() <= 1, so a 0/255 result left the
    # un-moved mask in state while the preview showed the moved one.
    return [masked_image], [mask_image], binary_mask[:, :, None]


def move_mask_down(input_image,
                   original_image,
                   original_mask,
                   moving_pixels,
                   resize_default,
                   aspect_ratio_name):
    """Shift the current mask ``moving_pixels`` pixels downward.

    The working mask is the user's canvas strokes (alpha channel of
    ``input_image["layers"][0]``) when any were drawn, otherwise the
    ``original_mask`` state. Image and masks are first brought to the output
    size chosen by ``aspect_ratio_name`` / ``resize_default``.

    Returns:
        ``([masked_image], [mask_image], moved mask as uint8 h,w,1 in {0, 255})``.

    Raises:
        gr.Error: if neither a drawn nor a stored mask is available.
    """
    input_mask = np.asarray(input_image["layers"][0].split()[3])

    ratio_w, ratio_h = aspect_ratios[aspect_ratio_name]
    custom_ratio = ratio_w == "" or ratio_h == ""
    if custom_ratio:
        output_h, output_w = original_image.shape[:2]
    else:
        output_w, output_h = ratio_w, ratio_h
    if resize_default:
        scale_ratio = 640 / min(output_w, output_h)
        output_w = int(output_w * scale_ratio)
        output_h = int(output_h * scale_ratio)
    gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
    # With a custom ratio and no default resize, the source size is kept.
    if resize_default or not custom_ratio:
        target_w, target_h = int(output_w), int(output_h)
        original_image = np.array(resize(Image.fromarray(original_image),
                                         target_width=target_w, target_height=target_h))
        input_mask = np.array(resize(Image.fromarray(np.squeeze(input_mask)),
                                     target_width=target_w, target_height=target_h))
        if original_mask is not None:
            original_mask = np.array(resize(Image.fromarray(np.squeeze(original_mask)),
                                            target_width=target_w, target_height=target_h))

    # Freshly drawn strokes win over the stored mask.
    if input_mask.max() != 0:
        original_mask = input_mask
    if original_mask is None:
        raise gr.Error('Please generate mask first')
    if original_mask.ndim == 2:
        original_mask = original_mask[:, :, None]

    moved_mask = move_mask_func(original_mask, 'down', int(moving_pixels)).squeeze()
    # Normalise to 0/255 whatever value range move_mask_func returns.
    binary_mask = np.where(moved_mask > 0, 255, 0).astype(np.uint8)
    mask_image = Image.fromarray(binary_mask).convert("RGB")

    masked_image = original_image * (1 - (moved_mask[:, :, None] > 0))
    masked_image = Image.fromarray(masked_image.astype(original_image.dtype))

    # Always return the moved mask as the new state. Previously the state was
    # only updated when moved_mask.max() <= 1, so a 0/255 result left the
    # un-moved mask in state while the preview showed the moved one.
    return [masked_image], [mask_image], binary_mask[:, :, None]


def invert_mask(input_image,
                original_image,
                original_mask,
                ):
    """Invert the current mask, swapping the edited and preserved regions.

    The mask to invert is the user's canvas strokes (alpha channel of
    ``input_image["layers"][0]``) when any were drawn, otherwise the
    ``original_mask`` state.

    Returns:
        ``([masked_image], [mask_image], inverted mask as uint8 h,w,1 in
        {0, 255}, True)``.

    Raises:
        gr.Error: if no image is loaded, or nothing was drawn and no mask
            has been generated yet.
    """
    if original_image is None:
        raise gr.Error('Please upload the input image')

    input_mask = np.asarray(input_image["layers"][0].split()[3])
    if input_mask.max() == 0:
        # This check must run before the mask is used. It previously ran after
        # `original_mask > 0`, so a missing mask crashed with a TypeError
        # instead of showing this message.
        if original_mask is None:
            raise gr.Error('Please generate mask first')
        source_mask = original_mask
    else:
        source_mask = input_mask

    inverted = (1 - (source_mask > 0).astype(np.uint8)).squeeze()  # values in {0, 1}
    mask_image = Image.fromarray(inverted * 255).convert("RGB")

    if inverted.ndim == 2:
        inverted = inverted[:, :, None]
    # The state stores masks as 0/255 uint8.
    inverted = (inverted * 255).astype(np.uint8)

    masked_image = original_image * (1 - (inverted > 0))
    masked_image = Image.fromarray(masked_image.astype(original_image.dtype))

    return [masked_image], [mask_image], inverted, True


def init_img(base,
             init_type,
             prompt,
             aspect_ratio,
             example_change_times
             ):
    """Reset per-image state when a new image is loaded into the editor.

    If ``init_type`` is a key of ``MASK_IMAGE_PATH`` and
    ``example_change_times < 2``, the example's stored mask, masked image
    and output are loaded. They are resized, together with the input image,
    to the size returned by ``VaeImageProcessor.get_default_height_width``.
    Otherwise the mask state and galleries are cleared.

    Returns:
        A 13-tuple of UI updates in the order produced by the two return
        statements below. ``original_image`` is an RGB numpy array.

    Raises:
        gr.Error: if the image's long/short side ratio exceeds 2.0.
    """
    image_pil = base["background"].convert("RGB")
    original_image = np.array(image_pil)
    height, width = original_image.shape[:2]
    if max(height, width) / min(height, width) > 2.0:
        raise gr.Error('image aspect ratio cannot be larger than 2.0')

    if init_type in MASK_IMAGE_PATH and example_change_times < 2:
        mask_gallery = [Image.open(MASK_IMAGE_PATH[init_type]).convert("L")]
        masked_gallery = [Image.open(MASKED_IMAGE_PATH[init_type]).convert("RGB")]
        result_gallery = [Image.open(OUTPUT_IMAGE_PATH[init_type]).convert("RGB")]
        image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
        height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
        new_size = (width_new, height_new)
        image_pil = image_pil.resize(new_size)
        # Rebuild the image state from the resized image so its shape matches
        # the resized example mask, as update_example does. Previously the
        # un-resized array was returned here.
        original_image = np.array(image_pil)
        mask_gallery[0] = mask_gallery[0].resize(new_size)
        masked_gallery[0] = masked_gallery[0].resize(new_size)
        result_gallery[0] = result_gallery[0].resize(new_size)
        original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:, :, None]  # h,w,1
        return base, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, "", "", "Custom resolution", False, False, example_change_times

    if aspect_ratio not in ASPECT_RATIO_LABELS:
        aspect_ratio = "Custom resolution"
    return base, original_image, None, "", None, None, None, "", "", aspect_ratio, True, False, 0


def reset_func(input_image,
               original_image,
               original_mask,
               prompt,
               target_prompt,
               ):
    """Return the initial values for the editor inputs, states and galleries.

    The incoming values are ignored. When CUDA is available, the PyTorch
    CUDA cache is also emptied.

    Returns:
        ``(None, None, None, '', [], [], [], '', True, False)`` in the order
        (input_image, original_image, original_mask, prompt, mask_gallery,
        masked_gallery, result_gallery, target_prompt, <bool>, <bool>).
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    cleared_image = cleared_original = cleared_mask = None
    return (
        cleared_image,
        cleared_original,
        cleared_mask,
        '',
        [],
        [],
        [],
        '',
        True,
        False,
    )


def update_example(example_type, 
                   prompt, 
                   example_change_times):
    """Load a bundled example together with its precomputed result images.

    Wired to ``example_type.change``. The hidden ``example_type`` textbox is
    filled by ``gr.Examples``, which cannot populate ``gr.Gallery`` components
    itself, so this handler fills the galleries instead.

    Args:
        example_type: key into INPUT_IMAGE_PATH, MASK_IMAGE_PATH,
            MASKED_IMAGE_PATH and OUTPUT_IMAGE_PATH.
        prompt: current instruction text; passed through unchanged.
        example_change_times: counter of example loads; returned incremented.

    Returns:
        An 11-tuple in the order wired to ``example_type.change``:
        (input image path, prompt, resized RGB image as ndarray,
        mask as uint8 array of shape (h, w, 1), [mask], [masked image],
        [output image], "Custom resolution", "" (target prompt),
        False (invert_mask_state), example_change_times + 1).
    """
    image_path = INPUT_IMAGE_PATH[example_type]
    rgb_image = Image.open(image_path).convert("RGB")
    mask_img = Image.open(MASK_IMAGE_PATH[example_type]).convert("L")
    masked_img = Image.open(MASKED_IMAGE_PATH[example_type]).convert("RGB")
    output_img = Image.open(OUTPUT_IMAGE_PATH[example_type]).convert("RGB")

    orig_w, orig_h = rgb_image.size
    vae_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
    # Ask the VAE processor for the size the pipeline would use (presumably
    # snapped to a multiple of the VAE scale factor) so every preview image
    # shares the same resolution as the source.
    snapped_h, snapped_w = vae_processor.get_default_height_width(rgb_image, orig_h, orig_w)
    new_size = (snapped_w, snapped_h)

    rgb_image = rgb_image.resize(new_size)
    mask_img, masked_img, output_img = [
        img.resize(new_size) for img in (mask_img, masked_img, output_img)
    ]

    mask_array = np.array(mask_img).astype(np.uint8)[:, :, None]  # (h, w, 1)
    return (
        image_path,
        prompt,
        np.array(rgb_image),
        mask_array,
        [mask_img],
        [masked_img],
        [output_img],
        "Custom resolution",
        "",
        False,
        example_change_times + 1,
    )


# ---------------------------------------------------------------------------
# Gradio UI: component layout first, then event wiring, then launch.
# Handlers (init_img, process, process_mask, ...) and constants (EXAMPLES,
# VLM_MODEL_NAMES, ASPECT_RATIO_LABELS, ...) are expected to be defined
# earlier in this module.
# ---------------------------------------------------------------------------
# Soft theme with square corners and medium-size text.
block = gr.Blocks(
        theme=gr.themes.Soft(
             radius_size=gr.themes.sizes.radius_none,
             text_size=gr.themes.sizes.text_md
         )
        )
with block as demo:
    # Header, description and collapsible usage instructions.
    with gr.Row():
        with gr.Column(): 
            gr.HTML(head)

    gr.Markdown(descriptions)

    with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"):
        with gr.Row(equal_height=True):
            gr.Markdown(instructions)

    # Per-session hidden state shared between handlers:
    #   original_image / original_mask - set by upload/example/mask handlers,
    #                                    read by the mask and run handlers.
    #   category      - written by process_mask, read by process.
    #   status        - written by the model-dropdown handlers; not read here.
    #   invert_mask_state - written by invert_mask (and reset/upload/example),
    #                       read by process.
    #   example_change_times - incremented by update_example, read by init_img.
    original_image = gr.State(value=None)
    original_mask = gr.State(value=None)
    category = gr.State(value=None)
    status = gr.State(value=None)
    invert_mask_state = gr.State(value=False)
    example_change_times = gr.State(value=0)


    with gr.Row():
        # ---------------- Left column: inputs and controls ----------------
        with gr.Column():
            with gr.Row():
                # Upload-only editor; the brush is fixed to white so the user
                # paints the edit region in white.
                input_image = gr.ImageEditor( 
                    label="Input Image",
                    type="pil",
                    brush=gr.Brush(colors=["#FFFFFF"], default_size = 30, color_mode="fixed"),
                    layers = False,
                    interactive=True,
                    height=1024,
                    sources=["upload"],
                    placeholder="Please click here or the icon below to upload the image.",
                    )

            prompt = gr.Textbox(label="⌨️ Instruction", placeholder="Please input your instruction.", value="",lines=1)
            run_button = gr.Button("💫 Run")
            
            vlm_model_dropdown = gr.Dropdown(label="VLM model", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True)
            # API key entry; submit_GPT4o_KEY updates both the key box and the
            # VLM dropdown (see wiring below).
            with gr.Group():    
                with gr.Row():
                    GPT4o_KEY = gr.Textbox(label="GPT4o API Key", placeholder="Please input your GPT4o API Key when use GPT4o VLM (highly recommended).", value="", lines=1)
                    
                    GPT4o_KEY_submit = gr.Button("Submit and Verify")


            aspect_ratio = gr.Dropdown(label="Output aspect ratio", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO)
            resize_default = gr.Checkbox(label="Short edge resize to 640px", value=True)

            with gr.Row():
                mask_button = gr.Button("Generate Mask")
                random_mask_button = gr.Button("Square/Circle Mask ")
            

            with gr.Row():
                generate_target_prompt_button = gr.Button("Generate Target Prompt")
                
            # NOTE(review): placeholder says "generate if" - probably meant
            # "generate it"; left unchanged here because it is user-facing text.
            target_prompt = gr.Text(
                        label="Input Target Prompt",
                        max_lines=5,
                        placeholder="VLM-generated target prompt, you can first generate if and then modify it (optional)",
                        value='',
                        lines=2
                    )

            # Generation hyper-parameters, collapsed by default.
            with gr.Accordion("Advanced Options", open=False, elem_id="accordion1"):
                base_model_dropdown = gr.Dropdown(label="Base model", choices=BASE_MODELS, value=DEFAULT_BASE_MODEL, interactive=True)
                negative_prompt = gr.Text(
                        label="Negative Prompt",
                        max_lines=5,
                        placeholder="Please input your negative prompt",
                        value='ugly, low quality',lines=1
                    )
                                    
                control_strength = gr.Slider(
                    label="Control Strength: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01
                    )
                with gr.Group():
                    # Upper bound is 2**31 - 1.
                    seed = gr.Slider(
                        label="Seed: ", minimum=0, maximum=2147483647, step=1, value=648464818
                    )
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
                
                blending = gr.Checkbox(label="Blending mode", value=True)

                
                # NOTE(review): minimum=0 lets the user request zero samples;
                # confirm `process` handles that, or raise the minimum to 1.
                num_samples = gr.Slider(
                    label="Num samples", minimum=0, maximum=4, step=1, value=4
                )
                
                with gr.Group():
                    with gr.Row():
                        guidance_scale = gr.Slider(
                            label="Guidance scale",
                            minimum=1,
                            maximum=12,
                            step=0.1,
                            value=7.5,
                        )
                        num_inference_steps = gr.Slider(
                            label="Number of inference steps",
                            minimum=1,
                            maximum=50,
                            step=1,
                            value=50,
                        )

            
        # ---------------- Right column: mask tools and outputs ----------------
        with gr.Column():
            # NOTE(review): elem_id="gallery" is reused by three galleries;
            # HTML ids are meant to be unique - check any CSS that targets it.
            with gr.Row():
                with gr.Tab(elem_classes="feedback", label="Masked Image"):
                    masked_gallery = gr.Gallery(label='Masked Image', show_label=True, elem_id="gallery", preview=True, height=360)
                with gr.Tab(elem_classes="feedback", label="Mask"):
                    mask_gallery = gr.Gallery(label='Mask', show_label=True, elem_id="gallery", preview=True, height=360)
                
            invert_mask_button = gr.Button("Invert Mask")
            # Shared by both the dilation and the erosion buttons below.
            dilation_size = gr.Slider(
                        label="Dilation size: ", minimum=0, maximum=50, step=1, value=20
                    )
            with gr.Row():
                dilation_mask_button = gr.Button("Dilation Generated Mask")
                erosion_mask_button = gr.Button("Erosion Generated Mask")

            # Step size (in pixels) for the four move-mask buttons.
            moving_pixels = gr.Slider(
                    label="Moving pixels:", show_label=True, minimum=0, maximum=50, value=4, step=1
                    )
            with gr.Row():
                move_left_button = gr.Button("Move Left")
                move_right_button = gr.Button("Move Right")
            with gr.Row():
                move_up_button = gr.Button("Move Up")
                move_down_button = gr.Button("Move Down")
            
            with gr.Tab(elem_classes="feedback", label="Output"):
                result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", preview=True, height=400)

            # target_prompt_output = gr.Text(label="Output Target Prompt", value="", lines=1, interactive=False)

            reset_button = gr.Button("Reset")

            # Hidden textboxes filled by gr.Examples (see `inputs` below); they
            # carry the example's key to init_img / update_example.
            init_type = gr.Textbox(label="Init Name", value="", visible=False)
            example_type = gr.Textbox(label="Example Name", value="", visible=False)



    # Each EXAMPLES row is expected to supply one value per entry in `inputs`.
    with gr.Row():
        example = gr.Examples(
            label="Quick Example",
            examples=EXAMPLES,
            inputs=[input_image, prompt, seed, init_type, example_type, blending, resize_default, vlm_model_dropdown],
            examples_per_page=10,
            cache_examples=False,
        )
    

    with gr.Accordion(label="🎬 Feature Details:", open=True, elem_id="accordion"):
        with gr.Row(equal_height=True):
            gr.Markdown(tips)

    with gr.Row():
        gr.Markdown(citation)

    # ======================= Event wiring =======================
    ## gr.Examples cannot update gr.Gallery components directly, so example
    ## loading goes through the hidden `example_type` textbox: its .change
    ## handler (update_example) fills the galleries. init_img handles uploads
    ## and consults init_type / example_change_times to decide whether to
    ## show the precomputed example results or start from a clean state.
    ## Output lists must match each handler's return tuple element-for-element.
    input_image.upload(
        init_img,
        [input_image, init_type, prompt, aspect_ratio, example_change_times],
        [input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, init_type, aspect_ratio, resize_default, invert_mask_state, example_change_times]
    ) 
    example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt, invert_mask_state, example_change_times])
    
    ## vlm and base model dropdown
    vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status])
    base_model_dropdown.change(fn=update_base_model, inputs=[base_model_dropdown], outputs=[status])


    GPT4o_KEY_submit.click(fn=submit_GPT4o_KEY, inputs=[GPT4o_KEY], outputs=[GPT4o_KEY, vlm_model_dropdown])
    invert_mask_button.click(fn=invert_mask, inputs=[input_image, original_image, original_mask], outputs=[masked_gallery, mask_gallery, original_mask, invert_mask_state])


    # Inputs for the main `process` call; Gradio passes them positionally, so
    # this order must match `process`'s parameter order.
    ips=[input_image, 
         original_image, 
         original_mask, 
         prompt, 
         negative_prompt, 
         control_strength, 
         seed, 
         randomize_seed, 
         guidance_scale, 
         num_inference_steps,
         num_samples,
         blending,
         category,
         target_prompt,
         resize_default,
         aspect_ratio,
         invert_mask_state]

    ## run brushedit
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt, invert_mask_state])
    
    ## mask func (process_mask also records the detected `category`)
    mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category])
    random_mask_button.click(fn=process_random_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
    dilation_mask_button.click(fn=process_dilation_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
    erosion_mask_button.click(fn=process_erosion_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])

    ## move mask func
    move_left_button.click(fn=move_mask_left, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
    move_right_button.click(fn=move_mask_right, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
    move_up_button.click(fn=move_mask_up, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
    move_down_button.click(fn=move_mask_down, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])    

    ## prompt func
    generate_target_prompt_button.click(fn=generate_target_prompt, inputs=[input_image, original_image, prompt], outputs=[target_prompt])
    
    ## reset func
    reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, resize_default, invert_mask_state])
    
    
# Starts the Gradio server as soon as this module is executed (no
# `if __name__ == "__main__"` guard), so importing the module also launches it.
demo.launch()