# BrushNet1.0 / app.py
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import os
print("Installing correct gradio version...")
os.system("pip uninstall -y gradio")
os.system("pip install gradio==3.50.0")
print("Installing Finished!")
import gradio as gr
import cv2
from PIL import Image
import numpy as np
from segment_anything import SamPredictor, sam_model_registry
import torch
from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
import random
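# Load SAM (ViT-H) for interactive, click-based mask selection. The predictor
# is not used anywhere below because the point-selection UI is disabled
# (see the commented-out `selected_points` state), but it is kept so the
# feature can be re-enabled.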
sam = sam_model_registry['vit_h'](checkpoint='data/ckpt/sam_vit_h_4b8939.pth').to("cuda")
sam.eval()
sam_predictor = SamPredictor(sam)
colors = [(255, 0, 0), (0, 255, 0)]  # marker colors for the disabled point-selection UI
markers = [1, 5]  # marker styles for the disabled point-selection UI
# - - - - - examples - - - - - #
image_examples = [
["examples/brushnet/src/test_image.jpg", "A beautiful cake on the table", "examples/brushnet/src/test_mask.jpg", 0, [], [Image.open("examples/brushnet/src/test_result.png")]],
["examples/brushnet/src/example_1.jpg", "A man in Chinese traditional clothes", "examples/brushnet/src/example_1_mask.jpg", 1, [], [Image.open("examples/brushnet/src/example_1_result.png")]],
["examples/brushnet/src/example_3.jpg", "a cut toy on the table", "examples/brushnet/src/example_3_mask.jpg", 2, [], [Image.open("examples/brushnet/src/example_3_result.png")]],
["examples/brushnet/src/example_4.jpeg", "a car driving in the wild", "examples/brushnet/src/example_4_mask.jpg", 3, [], [Image.open("examples/brushnet/src/example_4_result.png")]],
["examples/brushnet/src/example_5.jpg", "a charming woman wearing dress standing in the dark forest", "examples/brushnet/src/example_5_mask.jpg", 4, [], [Image.open("examples/brushnet/src/example_5_result.png")]],
]
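# Each entry: [image path, prompt, mask path, example index, selected points,
# result images]. Apparently intended for a gr.Examples gallery, though not
# currently wired into the UI below.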
# choose the base model here
base_model_path = "data/ckpt/realisticVisionV60B1_v51VAE"
# base_model_path = "runwayml/stable-diffusion-v1-5"
# BrushNet checkpoint path (the variant trained on segmentation-style masks)
brushnet_path = "data/ckpt/segmentation_mask_brushnet_ckpt"
brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch.float16)
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
base_model_path, brushnet=brushnet, torch_dtype=torch.float16, low_cpu_mem_usage=False
)
# speed up diffusion process with faster scheduler and memory optimization
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# xformers attention can speed up inference on Torch < 2.0; uncomment the
# following line if xformers is installed.
# pipe.enable_xformers_memory_efficient_attention()
# memory optimization: offload model components to CPU and move them to the
# GPU only when needed
pipe.enable_model_cpu_offload()
def resize_image(input_image, resolution):
    # Scale so the short side equals `resolution`, then snap both sides to
    # multiples of 64 (Stable Diffusion works on 64-pixel-aligned dimensions).
    H, W, C = input_image.shape
    k = float(resolution) / min(H, W)
    H = int(np.round(H * k / 64.0)) * 64
    W = int(np.round(W * k / 64.0)) * 64
    img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
    return img
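# e.g. resize_image(img, 512) on a 600x800 input gives k = 512/600, i.e.
# 512x682.7, which snaps to 512x704.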
# once the user uploads an image, the original image is stored in `original_image`
def store_img(img):
    # downscale large uploads so the short side is at most 512 px (keeps
    # transfer and inference fast)
    if min(img.shape[0], img.shape[1]) > 512:
        img = resize_image(img, 512)
    if max(img.shape[0], img.shape[1]) / min(img.shape[0], img.shape[1]) > 2.0:
        raise gr.Error('Image aspect ratio cannot be larger than 2.0')
    return img
def process(original_image, input_mask, prompt, negative_prompt, blended, invert_mask, control_strength, seed, randomize_seed, guidance_scale, num_inference_steps):
if original_image is None:
raise gr.Error('Please upload the input image')
    if input_mask is None:
        raise gr.Error('Please upload a black-and-white mask image')
    # resize the input image and the object mask to the working resolution
original_image = store_img(original_image)
input_mask = store_img(input_mask)
H, W = original_image.shape[:2]
original_mask = cv2.resize(input_mask, (W, H))
if invert_mask:
original_mask = 255 - original_mask
    # binarize: a pixel counts as masked when its channel sum exceeds 255
    # (i.e. it is closer to white than black), then blank out the masked region
    mask = 1. * (original_mask.sum(-1) > 255)[:, :, np.newaxis]
    masked_image = original_image * (1 - mask)
    init_image = Image.fromarray(masked_image.astype(np.uint8)).convert("RGB")
    mask_image = Image.fromarray(original_mask.astype(np.uint8)).convert("RGB")
    # draw a fresh random seed when requested, otherwise use the slider value
    generator = torch.Generator("cuda").manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)
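    # generate two candidate images per run (prompt and negative prompt are batched x2)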
image = pipe(
[prompt]*2,
init_image,
mask_image,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
generator=generator,
brushnet_conditioning_scale=float(control_strength),
negative_prompt=[negative_prompt]*2,
).images
    if blended:
        if control_strength < 1.0:
            raise gr.Error('Using blurred blending with control strength less than 1.0 is not allowed')
        blended_image = []
        # feather the mask edge, then take a soft union of the hard and blurred
        # masks so the pasted region fades smoothly into the original image
        mask_blurred = cv2.GaussianBlur(mask * 255, (21, 21), 0) / 255
        mask_blurred = mask_blurred[:, :, np.newaxis]
        mask = 1 - (1 - mask) * (1 - mask_blurred)
        for image_i in image:
            image_np = np.array(image_i)
            # paste the generated pixels back onto the original with the soft mask
            image_pasted = original_image * (1 - mask) + image_np * mask
            blended_image.append(Image.fromarray(image_pasted.astype(image_np.dtype)))
        image = blended_image
return image
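# A minimal sketch of driving process() without the UI (the file names are
# illustrative, not shipped with this repo):
#
#   img = cv2.cvtColor(cv2.imread("my_image.jpg"), cv2.COLOR_BGR2RGB)
#   msk = cv2.cvtColor(cv2.imread("my_mask.jpg"), cv2.COLOR_BGR2RGB)
#   results = process(img, msk, "a red sofa", "ugly, low quality",
#                     blended=True, invert_mask=False, control_strength=1.0,
#                     seed=42, randomize_seed=False, guidance_scale=7.5,
#                     num_inference_steps=50)
#   results[0].save("result.png")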
# Create Gradio interface
with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
original_image = gr.Image(type="numpy", label="Original Image")
input_mask = gr.Image(type="numpy", label="Mask Image")
prompt = gr.Textbox(label="Prompt")
negative_prompt = gr.Textbox(label="Negative Prompt", value='ugly, low quality')
blended = gr.Checkbox(label="Blurred Blending", value=False)
invert_mask = gr.Checkbox(label="Invert Mask", value=False)
control_strength = gr.Slider(label="Control Strength", minimum=0, maximum=1.1, value=1, step=0.01)
seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, value=551793204)
randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=12, step=0.1, value=7.5)
num_inference_steps = gr.Slider(label="Number of Inference Steps", minimum=1, maximum=50, step=1, value=50)
#selected_points = gr.State([],label="select points")
run_button = gr.Button("Run")
with gr.Column():
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", preview=True)
inputs = [original_image, input_mask, prompt, negative_prompt, blended, invert_mask, control_strength, seed, randomize_seed, guidance_scale, num_inference_steps]
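    # the input order here must match the process() signature above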
run_button.click(fn=process, inputs=inputs, outputs=[result_gallery])
demo.queue(concurrency_count=1, api_open=True)
demo.launch(show_api=True, show_error=True)