# *************************************************************************
# Copyright (2023) Bytedance Inc.
#
# Copyright (2023) DragDiffusion Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# *************************************************************************

import cv2
import numpy as np
import PIL
from PIL import Image
from PIL.ImageOps import exif_transpose
import os
import gradio as gr
import datetime
import pickle
from copy import deepcopy

LENGTH = 480  # length of the square area displaying/editing images


def clear_all(length=LENGTH):
    """Reset the labeling UI.

    Args:
        length: display height/width (pixels) for both image components.

    Returns:
        New values for, in order: the mask canvas, the click-points image,
        the selected points, the stored original image and the stored mask.
    """
    return (
        gr.Image.update(value=None, height=length, width=length),
        gr.Image.update(value=None, height=length, width=length),
        [],
        None,
        None,
    )


def mask_image(image, mask, color=(255, 0, 0), alpha=0.5):
    """
    Overlay mask on image for visualization purpose.

    Args:
        image (H, W, 3) or (H, W): input image; it is not modified.
        mask (H, W): mask to be overlaid; only pixels equal to 1 are colored.
        color: the color of overlaid mask; must match the channel count of
            `image`.
        alpha: blending weight of the colored layer (0 = invisible,
            1 = fully opaque over masked pixels).

    Returns:
        A new blended image with the same shape as `image`.
    """
    out = deepcopy(image)
    img = deepcopy(image)
    img[mask == 1] = color
    out = cv2.addWeighted(img, alpha, out, 1 - alpha, 0, out)
    return out


def store_img(img, length=512):
    """Handle a new upload / mask stroke on the sketch canvas.

    Args:
        img: value of the sketch canvas; a dict with an "image" array of shape
            (H, W, 3) and a "mask" array whose first channel holds the painted
            strokes in 0-255, or None when the canvas was cleared.
        length: target width in pixels; the height keeps the aspect ratio.

    Returns:
        (resized image, empty point list, preview image with the un-masked
        area darkened, binary uint8 mask of the resized size)
    """
    if img is None:
        # canvas was cleared: reset every dependent state
        return None, [], None, None

    image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255.
    height, width, _ = image.shape
    # keep the aspect ratio; never let a very wide image collapse to 0 rows
    new_size = (length, max(1, int(length * height / width)))

    image = Image.fromarray(image)
    # NOTE(review): an image built from an array carries no EXIF data, so this
    # is presumably a no-op; if it ever rotated the image, `mask` (which is not
    # transposed) would no longer line up.
    image = exif_transpose(image)
    image = image.resize(new_size, PIL.Image.BILINEAR)
    image = np.array(image)

    mask = cv2.resize(mask, new_size, interpolation=cv2.INTER_NEAREST)
    # always store a uint8 {0, 1} mask, even when nothing was painted, so the
    # saved meta data has a consistent dtype
    mask = np.uint8(mask > 0)

    if mask.sum() > 0:
        # darken everything outside the painted (editable) region
        masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3)
    else:
        masked_img = image.copy()

    # when new image is uploaded, `selected_points` should be empty
    return image, [], masked_img, mask


# user click the image to get points, and show the points on the image
def get_points(img, sel_pix, evt: gr.SelectData):
    """Record a clicked point and draw all handle/target points on `img`.

    Clicks alternate: even positions in `sel_pix` are handle points (red),
    odd positions are target points (blue). Each completed handle/target
    pair gets a white arrow from handle to target.

    Args:
        img: the click-points image (drawn on in place when it is an array).
        sel_pix: list of previously clicked points; mutated in place (it is
            the `selected_points` state).
        evt: Gradio selection event; `evt.index` is the clicked position.

    Returns:
        The annotated image as a numpy array (or None if there is no image).
    """
    if img is None:
        # nothing to click on yet; do not record a dangling point
        return img

    # collect the selected point
    sel_pix.append(evt.index)
    # draw points
    points = []
    for idx, point in enumerate(sel_pix):
        if idx % 2 == 0:
            # draw a red circle at the handle point
            cv2.circle(img, tuple(point), 10, (255, 0, 0), -1)
        else:
            # draw a blue circle at the target point
            cv2.circle(img, tuple(point), 10, (0, 0, 255), -1)
        points.append(tuple(point))
        # draw an arrow from handle point to target point
        if len(points) == 2:
            cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 4, tipLength=0.5)
            points = []
    return img if isinstance(img, np.ndarray) else np.array(img)


# clear all handle/target points
def undo_points(original_image, mask):
    """Drop all handle/target points and restore the masked preview image.

    Args:
        original_image: the stored resized image, or None before any upload.
        mask: the stored mask, or None before any upload.

    Returns:
        (preview image without points, empty point list)
    """
    if original_image is None:
        # nothing uploaded yet: just clear the points
        return None, []
    if mask is not None and mask.sum() > 0:
        mask = np.uint8(mask > 0)
        masked_img = mask_image(original_image, 1 - mask, color=[0, 0, 0], alpha=0.3)
    else:
        masked_img = original_image.copy()
    return masked_img, []


def _strip_path_separators(text):
    """Replace path separators so user text stays a single path component."""
    for sep in (os.sep, os.altsep):
        if sep:
            text = text.replace(sep, '_')
    return text


def save_all(category,
             source_image,
             image_with_clicks,
             mask,
             labeler,
             prompt,
             points,
             root_dir='./drag_bench_data'):
    """Save one labeled sample to root_dir/<category>/<labeler>_<timestamp>/.

    Writes `original_image.png`, `user_drag.png` and `meta_data.pkl` (a dict
    with keys 'prompt', 'points' and 'mask').

    Returns:
        A status message for the UI.
    """
    if source_image is None or image_with_clicks is None:
        return "no image to save!"

    # The app is launched with share=True, so these strings can come from
    # anyone holding the link: never let them add extra path components
    # (e.g. "../..") that would write outside `root_dir`.
    category = _strip_path_separators(str(category or ''))
    if category in ('', '.', '..'):
        return "invalid image category, nothing saved!"
    labeler = _strip_path_separators(str(labeler or ''))

    save_prefix = labeler + '_' + datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S")
    save_dir = os.path.join(root_dir, category, save_prefix)
    os.makedirs(save_dir, exist_ok=True)

    # save images
    Image.fromarray(source_image).save(os.path.join(save_dir, 'original_image.png'))
    Image.fromarray(image_with_clicks).save(os.path.join(save_dir, 'user_drag.png'))

    # save meta data
    meta_data = {
        'prompt' : prompt,
        'points' : points,
        'mask' : mask,
    }
    with open(os.path.join(save_dir, 'meta_data.pkl'), 'wb') as f:
        pickle.dump(meta_data, f)

    return save_prefix + " saved!"


with gr.Blocks() as demo:
    # UI components for editing real images
    with gr.Tab(label="Editing Real Image"):
        mask = gr.State(value=None)  # store mask
        selected_points = gr.State([])  # store points
        original_image = gr.State(value=None)  # store original input image
        with gr.Row():
            with gr.Column():
                gr.Markdown("\nDraw Mask\n")
                canvas = gr.Image(type="numpy", tool="sketch", label="Draw Mask",
                    show_label=True, height=LENGTH, width=LENGTH)  # for mask painting
            with gr.Column():
                gr.Markdown("Click Points\n")
                input_image = gr.Image(type="numpy", label="Click Points",
                    show_label=True, height=LENGTH, width=LENGTH)  # for points clicking
        with gr.Row():
            labeler = gr.Textbox(label="Labeler")
            category = gr.Dropdown(value="art_work",
                label="Image Category",
                choices=[
                    'art_work',
                    'land_scape',
                    'building_city_view',
                    'building_countryside_view',
                    'animals',
                    'human_head',
                    'human_upper_body',
                    'human_full_body',
                    'interior_design',
                    'other_objects',
                ]
            )
            prompt = gr.Textbox(label="Prompt")
            save_status = gr.Textbox(label="display saving status")

        with gr.Row():
            undo_button = gr.Button("undo points")
            clear_all_button = gr.Button("clear all")
            save_button = gr.Button("save")

    # event definition
    # event for dragging user-input real image
    canvas.edit(
        store_img,
        [canvas],
        [original_image, selected_points, input_image, mask]
    )
    input_image.select(
        get_points,
        [input_image, selected_points],
        [input_image],
    )
    undo_button.click(
        undo_points,
        [original_image, mask],
        [input_image, selected_points]
    )
    clear_all_button.click(
        clear_all,
        [gr.Number(value=LENGTH, visible=False, precision=0)],
        [canvas,
        input_image,
        selected_points,
        original_image,
        mask]
    )
    save_button.click(
        save_all,
        [category,
        original_image,
        input_image,
        mask,
        labeler,
        prompt,
        selected_points,],
        [save_status]
    )

demo.queue().launch(share=True, debug=True)