#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Gradio demo for BrushNet inpainting driven by a user-supplied mask image.

The script pins gradio to 3.50.0 (the queue/launch arguments at the bottom use
the 3.x API), loads a BrushNet-conditioned Stable Diffusion pipeline, and serves
a UI that regenerates the masked region of an uploaded image from a prompt.
"""
import os
import random
import subprocess
import sys

print("Installing correct gradio version...")
# Invoke pip through the running interpreter so the pinned gradio is installed
# into the same environment this script imports from (a bare ``pip`` on PATH may
# belong to another Python). Best-effort like the original os.system calls:
# failures are reported by pip but do not abort the script.
subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", "gradio"], check=False)
subprocess.run([sys.executable, "-m", "pip", "install", "gradio==3.50.0"], check=False)
print("Installing Finished!")

# Third-party imports come after the install step so the pinned gradio is used.
import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
from segment_anything import SamPredictor, sam_model_registry

# NOTE(review): the SAM predictor, `colors` and `markers` are not referenced by
# any other code in this file; the model only occupies GPU memory here. Kept so
# module-level names are unchanged -- consider removing if nothing imports them.
mobile_sam = sam_model_registry['vit_h'](checkpoint='data/ckpt/sam_vit_h_4b8939.pth').to("cuda")
mobile_sam.eval()
mobile_predictor = SamPredictor(mobile_sam)
colors = [(255, 0, 0), (0, 255, 0)]
markers = [1, 5]

# - - - - - examples - - - - - #
# NOTE(review): this list is not wired into the UI below (no gr.Examples uses it).
image_examples = [
    ["examples/brushnet/src/test_image.jpg", "A beautiful cake on the table", "examples/brushnet/src/test_mask.jpg", 0, [], [Image.open("examples/brushnet/src/test_result.png")]],
    ["examples/brushnet/src/example_1.jpg", "A man in Chinese traditional clothes", "examples/brushnet/src/example_1_mask.jpg", 1, [], [Image.open("examples/brushnet/src/example_1_result.png")]],
    ["examples/brushnet/src/example_3.jpg", "a cut toy on the table", "examples/brushnet/src/example_3_mask.jpg", 2, [], [Image.open("examples/brushnet/src/example_3_result.png")]],
    ["examples/brushnet/src/example_4.jpeg", "a car driving in the wild", "examples/brushnet/src/example_4_mask.jpg", 3, [], [Image.open("examples/brushnet/src/example_4_result.png")]],
    ["examples/brushnet/src/example_5.jpg", "a charming woman wearing dress standing in the dark forest", "examples/brushnet/src/example_5_mask.jpg", 4, [], [Image.open("examples/brushnet/src/example_5_result.png")]],
]

# choose the base model here
base_model_path = "data/ckpt/realisticVisionV60B1_v51VAE"
# base_model_path = "runwayml/stable-diffusion-v1-5"

# input brushnet ckpt path
brushnet_path = "data/ckpt/segmentation_mask_brushnet_ckpt"

brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch.float16)
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
    base_model_path, brushnet=brushnet, torch_dtype=torch.float16, low_cpu_mem_usage=False
)

# speed up diffusion process with faster scheduler and memory optimization
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# remove following line if xformers is not installed or when using Torch 2.0.
# pipe.enable_xformers_memory_efficient_attention()
# memory optimization.
pipe.enable_model_cpu_offload()


def resize_image(input_image, resolution):
    """Scale an image so its shorter side is ~``resolution``, snapping sides to multiples of 64.

    Args:
        input_image: array with three dimensions (H, W, C), as unpacked below.
        resolution: target length of the shorter side before snapping.

    Returns:
        The resized array. Lanczos interpolation is used when enlarging and
        area averaging when shrinking.
    """
    H, W, C = input_image.shape
    H = float(H)
    W = float(W)
    k = float(resolution) / min(H, W)
    H *= k
    W *= k
    # Round each side to the nearest multiple of 64.
    H = int(np.round(H / 64.0)) * 64
    W = int(np.round(W / 64.0)) * 64
    img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
    return img


def store_img(img):
    """Normalize an uploaded image: downscale large inputs and reject extreme aspect ratios.

    Images whose shorter side exceeds 512 px are resized via ``resize_image`` to
    keep processing fast; smaller images are returned unchanged.

    Raises:
        gr.Error: if the long/short side ratio (after any resize) exceeds 2.0.
    """
    # image upload is too slow
    if min(img.shape[0], img.shape[1]) > 512:
        img = resize_image(img, 512)
    if max(img.shape[0], img.shape[1]) * 1.0 / min(img.shape[0], img.shape[1]) > 2.0:
        raise gr.Error('image aspect ratio cannot be larger than 2.0')
    return img


def process(original_image, input_mask, prompt, negative_prompt, blended, invert_mask,
            control_strength, seed, randomize_seed, guidance_scale, num_inference_steps):
    """Inpaint the masked region of ``original_image`` with the BrushNet pipeline.

    Args:
        original_image: uploaded image array from the "Original Image" component.
        input_mask: uploaded mask array; pixels whose channel sum exceeds 255
            (after optional inversion) are regenerated.
        prompt / negative_prompt: text conditioning; each is submitted twice.
        blended: if true, paste results back onto the original through a
            feathered (Gaussian-blurred) mask.
        invert_mask: swap masked and unmasked regions.
        control_strength: BrushNet conditioning scale.
        seed / randomize_seed: RNG seed, or draw a fresh random one.
        guidance_scale: classifier-free guidance scale.
        num_inference_steps: number of denoising steps.

    Returns:
        List of PIL images produced by the pipeline (blended if requested).

    Raises:
        gr.Error: on missing inputs, an aspect ratio > 2.0, or blurred blending
            requested with control strength below 1.0.
    """
    if original_image is None:
        raise gr.Error('Please upload the input image')
    if input_mask is None:
        raise gr.Error("Please upload a white-black Mask image")

    # resizing input image and mask of the object
    original_image = store_img(original_image)
    input_mask = store_img(input_mask)
    H, W = original_image.shape[:2]
    original_mask = cv2.resize(input_mask, (W, H))

    if invert_mask:
        original_mask = 255 - original_mask

    # Binary (H, W, 1) float mask: 1 where the mask's channel sum exceeds 255.
    mask = 1. * (original_mask.sum(-1) > 255)[:, :, np.newaxis]
    masked_image = original_image * (1 - mask)

    init_image = Image.fromarray(masked_image.astype(np.uint8)).convert("RGB")
    mask_image = Image.fromarray(original_mask.astype(np.uint8)).convert("RGB")

    # Reject invalid option combinations before the expensive diffusion call
    # (previously this was only checked after a full generation run).
    if blended and control_strength < 1.0:
        raise gr.Error('Using blurred blending with control strength less than 1.0 is not allowed')

    # Slider values may arrive as floats; manual_seed and the step count need ints.
    actual_seed = random.randint(0, 2147483647) if randomize_seed else int(seed)
    generator = torch.Generator("cuda").manual_seed(actual_seed)

    image = pipe(
        [prompt] * 2,
        init_image,
        mask_image,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        generator=generator,
        brushnet_conditioning_scale=float(control_strength),
        negative_prompt=[negative_prompt] * 2,
    ).images

    if blended:
        # Feather the mask edge: blur the 2-D hard mask, then combine so the
        # hard-masked area keeps full weight (1 - (1-a)(1-b) == 1 where a == 1)
        # and only the surrounding border fades into the original.
        mask_blurred = cv2.GaussianBlur(mask[:, :, 0] * 255, (21, 21), 0) / 255
        mask_blurred = mask_blurred[:, :, np.newaxis]
        blend_mask = 1 - (1 - mask) * (1 - mask_blurred)

        blended_image = []
        for image_i in image:
            image_np = np.array(image_i)
            # The pipeline may return a size that differs slightly from the
            # input (e.g. sides not divisible by the VAE factor); align it so
            # the arrays broadcast instead of raising.
            if image_np.shape[:2] != (H, W):
                image_np = cv2.resize(image_np, (W, H), interpolation=cv2.INTER_LANCZOS4)
            image_pasted = original_image * (1 - blend_mask) + image_np * blend_mask
            image_pasted = image_pasted.astype(image_np.dtype)
            blended_image.append(Image.fromarray(image_pasted))
        image = blended_image

    return image


# Create Gradio interface
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            original_image = gr.Image(type="numpy", label="Original Image")
            input_mask = gr.Image(type="numpy", label="Mask Image")
            prompt = gr.Textbox(label="Prompt")
            negative_prompt = gr.Textbox(label="Negative Prompt", value='ugly, low quality')
            blended = gr.Checkbox(label="Blurred Blending", value=False)
            invert_mask = gr.Checkbox(label="Invert Mask", value=False)
            control_strength = gr.Slider(label="Control Strength", minimum=0, maximum=1.1, value=1, step=0.01)
            seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, value=551793204)
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
            guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=12, step=0.1, value=7.5)
            num_inference_steps = gr.Slider(label="Number of Inference Steps", minimum=1, maximum=50, step=1, value=50)
            run_button = gr.Button("Run")
        with gr.Column():
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", preview=True)

    # Order must match the positional parameters of `process`.
    inputs = [original_image, input_mask, prompt, negative_prompt, blended, invert_mask,
              control_strength, seed, randomize_seed, guidance_scale, num_inference_steps]
    run_button.click(fn=process, inputs=inputs, outputs=[result_gallery])

demo.queue(concurrency_count=1, api_open=True)
demo.launch(show_api=True, enable_queue=True, show_error=True)