#!/usr/bin/python3
# -*- coding: utf-8 -*-
import os

print("Installing correct gradio version...")
os.system("pip uninstall -y gradio")
os.system("pip install gradio==3.50.0")
print("Installing Finished!")

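# NOTE: gradio is (re)installed before it is imported below, so the pinned 3.x
# version is the one `import gradio` picks up; restart the process if a
# different gradio version was already imported in this session.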
import gradio as gr
import os
import cv2
from PIL import Image
import numpy as np
from segment_anything import SamPredictor, sam_model_registry
import torch
from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
import random

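# Load SAM (the full ViT-H checkpoint, despite the `mobile_` names) for
# interactive point-based mask selection. The point-selection UI is commented
# out below, so the predictor is currently unused by `process`.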
mobile_sam = sam_model_registry['vit_h'](checkpoint='data/ckpt/sam_vit_h_4b8939.pth').to("cuda")
mobile_sam.eval()
mobile_predictor = SamPredictor(mobile_sam)
colors = [(255, 0, 0), (0, 255, 0)]  # draw colors for positive/negative clicks
markers = [1, 5]  # cv2 marker types for drawing selected points

# - - - - - examples  - - - - -  #
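# each entry: [image path, prompt, mask path, example id, selected points, result gallery]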
image_examples = [
    ["examples/brushnet/src/test_image.jpg", "A beautiful cake on the table", "examples/brushnet/src/test_mask.jpg", 0, [], [Image.open("examples/brushnet/src/test_result.png")]],
    ["examples/brushnet/src/example_1.jpg", "A man in Chinese traditional clothes", "examples/brushnet/src/example_1_mask.jpg", 1, [], [Image.open("examples/brushnet/src/example_1_result.png")]],
    ["examples/brushnet/src/example_3.jpg", "a cut toy on the table", "examples/brushnet/src/example_3_mask.jpg", 2, [], [Image.open("examples/brushnet/src/example_3_result.png")]],
    ["examples/brushnet/src/example_4.jpeg", "a car driving in the wild", "examples/brushnet/src/example_4_mask.jpg", 3, [], [Image.open("examples/brushnet/src/example_4_result.png")]],
    ["examples/brushnet/src/example_5.jpg", "a charming woman wearing dress standing in the dark forest", "examples/brushnet/src/example_5_mask.jpg", 4, [], [Image.open("examples/brushnet/src/example_5_result.png")]],
]


# choose the base model here
base_model_path = "data/ckpt/realisticVisionV60B1_v51VAE"
# base_model_path = "runwayml/stable-diffusion-v1-5"

# input brushnet ckpt path
brushnet_path = "data/ckpt/segmentation_mask_brushnet_ckpt"

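# fp16 halves memory; low_cpu_mem_usage=False forces eager weight loading,
# which (an assumption here) this custom pipeline needs in place of diffusers'
# default meta-device initialization.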
brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch.float16)
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
    base_model_path, brushnet=brushnet, torch_dtype=torch.float16, low_cpu_mem_usage=False
)

# speed up diffusion process with faster scheduler and memory optimization
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# xFormers attention is optional: uncomment the next line only if xformers is
# installed (not needed on Torch 2.0+, which uses built-in SDPA).
# pipe.enable_xformers_memory_efficient_attention()
# memory optimization: keep submodules on CPU until they are needed
pipe.enable_model_cpu_offload()

def resize_image(input_image, resolution):
    # scale the short side to `resolution`, then snap both sides to multiples
    # of 64 so the diffusion model gets latent-friendly dimensions
    H, W, C = input_image.shape
    k = float(resolution) / min(H, W)
    H = int(np.round(H * k / 64.0)) * 64
    W = int(np.round(W * k / 64.0)) * 64
    # LANCZOS4 for upscaling, AREA for downscaling
    img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
    return img
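# e.g. a 768x1024 (HxW) input at resolution=512 comes out as 512x704:
# the short side lands on 512, and 682.67 rounds to the nearest multiple of 64.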

# once the user uploads an image, the original is stored in `original_image`
def store_img(img):
    # downscale large uploads so transfer and inference stay fast
    if min(img.shape[0], img.shape[1]) > 512:
        img = resize_image(img, 512)
    if max(img.shape[0], img.shape[1]) / min(img.shape[0], img.shape[1]) > 2.0:
        raise gr.Error('Image aspect ratio cannot be larger than 2.0')
    return img

def process(original_image, input_mask, prompt, negative_prompt, blended, invert_mask, control_strength, seed, randomize_seed, guidance_scale, num_inference_steps):
    if original_image is None:
        raise gr.Error('Please upload the input image')
    if input_mask is None:
        raise gr.Error("Please upload a black-and-white mask image")
    # resize the input image and the object mask
    original_image = store_img(original_image)
    input_mask = store_img(input_mask)
    
    H, W = original_image.shape[:2]
    original_mask = cv2.resize(input_mask, (W, H))

    if invert_mask:
        original_mask = 255 - original_mask
    # binarize: a pixel counts as masked if its channel sum exceeds 255
    mask = 1. * (original_mask.sum(-1) > 255)[:, :, np.newaxis]
    # BrushNet is conditioned on the masked image: background kept, hole zeroed
    masked_image = original_image * (1 - mask)
    init_image = Image.fromarray(masked_image.astype(np.uint8)).convert("RGB")
    mask_image = Image.fromarray(original_mask.astype(np.uint8)).convert("RGB")
    generator = torch.Generator("cuda").manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)
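    # the prompt is duplicated so the pipeline returns two candidates per run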
    image = pipe(
        [prompt]*2,
        init_image,
        mask_image,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
        brushnet_conditioning_scale=float(control_strength),
        negative_prompt=[negative_prompt]*2,
    ).images

    if blended:
        if control_strength < 1.0:
            raise gr.Error('Blurred blending requires a control strength of 1.0')
        blended_image = []
        # feather the mask with a Gaussian blur, then take the union of the
        # hard and blurred masks so only the outside of the boundary softens
        mask_blurred = cv2.GaussianBlur(mask * 255, (21, 21), 0) / 255
        mask_blurred = mask_blurred[:, :, np.newaxis]
        mask = 1 - (1 - mask) * (1 - mask_blurred)
        for image_i in image:
            image_np = np.array(image_i)
            # paste generated pixels under the feathered mask; untouched
            # background pixels stay exactly as in the original
            image_pasted = original_image * (1 - mask) + image_np * mask
            image_pasted = image_pasted.astype(image_np.dtype)
            blended_image.append(Image.fromarray(image_pasted))
        image = blended_image

    return image
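# Headless usage sketch (assumes the example assets shipped with the repo):
#   img = cv2.cvtColor(cv2.imread("examples/brushnet/src/test_image.jpg"), cv2.COLOR_BGR2RGB)
#   msk = cv2.cvtColor(cv2.imread("examples/brushnet/src/test_mask.jpg"), cv2.COLOR_BGR2RGB)
#   results = process(img, msk, "A beautiful cake on the table", "ugly, low quality",
#                     False, False, 1.0, 551793204, False, 7.5, 50)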

# Create Gradio interface
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            original_image = gr.Image(type="numpy", label="Original Image")
            input_mask = gr.Image(type="numpy", label="Mask Image")
            prompt = gr.Textbox(label="Prompt")
            negative_prompt = gr.Textbox(label="Negative Prompt", value='ugly, low quality')
            blended = gr.Checkbox(label="Blurred Blending", value=False)
            invert_mask = gr.Checkbox(label="Invert Mask", value=False)
            control_strength = gr.Slider(label="Control Strength", minimum=0, maximum=1.1, value=1, step=0.01)
            seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, value=551793204)
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
            guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=12, step=0.1, value=7.5)
            num_inference_steps = gr.Slider(label="Number of Inference Steps", minimum=1, maximum=50, step=1, value=50)
            # selected_points = gr.State([])  # point-based selection (unused; pairs with the SAM setup above)
            run_button = gr.Button("Run")
        
        with gr.Column():
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", preview=True)
    
    inputs = [original_image, input_mask, prompt, negative_prompt, blended, invert_mask, control_strength, seed, randomize_seed, guidance_scale, num_inference_steps]
    run_button.click(fn=process, inputs=inputs, outputs=[result_gallery])

# one request at a time: the pipeline holds the GPU for the whole run
demo.queue(concurrency_count=1, api_open=True)
demo.launch(show_api=True, show_error=True)