Update app.py
    	
app.py CHANGED

@@ -6,6 +6,20 @@ import os
 from PIL import Image
 hf_token = os.environ.get("HF_TOKEN")
 from diffusers import StableDiffusionXLInpaintPipeline, DDIMScheduler, UNet2DConditionModel
+from diffusers import (
+    AutoencoderKL,
+    LCMScheduler,
+)
+from pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
+from controlnet import ControlNetModel, ControlNetConditioningEmbedding
+import torch
+import numpy as np
+from PIL import Image
+import requests
+import PIL
+from io import BytesIO
+from torchvision import transforms
+
 
 ratios_map =  {
     0.5:{"width":704,"height":1408},
@@ -28,6 +42,30 @@ ratios_map =  {
 }
 ratios = np.array(list(ratios_map.keys()))
 
+image_transforms = transforms.Compose(
+    [
+        transforms.ToTensor(),
+    ]
+)
+
+default_negative_prompt = "Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers"
+
+
+def get_masked_image(image, image_mask, width, height):
+    image_mask = image_mask # inpaint area is white
+    image_mask = image_mask.resize((width, height)) # object to remove is white (1)
+    image_mask_pil = image_mask
+    image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
+    image_mask = np.array(image_mask_pil.convert("L")).astype(np.float32) / 255.0
+    assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size"
+    masked_image_to_present = image.copy()
+    masked_image_to_present[image_mask > 0.5] = (0.5,0.5,0.5)  # set as masked pixel
+    image[image_mask > 0.5] = 0.5  # set as masked pixel - s.t. will be grey
+    image = Image.fromarray((image * 255.0).astype(np.uint8))
+    masked_image_to_present = Image.fromarray((masked_image_to_present * 255.0).astype(np.uint8))
+    return image, image_mask_pil, masked_image_to_present
+
+
 def get_size(init_image):
     w,h=init_image.size
     curr_ratio = w/h
@@ -40,26 +78,33 @@ def get_size(init_image):
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-unet = UNet2DConditionModel.from_pretrained(
-    "briaai/BRIA-2.2-Inpainting",
-    subfolder="unet",
-    torch_dtype=torch.float16,
-)
-pipe.force_zeros_for_empty_prompt = False
 
+# Load, init model
+controlnet = ControlNetModel().from_config('briaai/DEV-ControlNetInpaintingFast', torch_dtype=torch.float16)
+controlnet.controlnet_cond_embedding = ControlNetConditioningEmbedding(
+    conditioning_embedding_channels=320,
+    conditioning_channels = 5
+)
+
+controlnet = ControlNetModel().from_pretrained("briaai/DEV-ControlNetInpaintingFast", torch_dtype=torch.float16)
+vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
+pipe = StableDiffusionXLControlNetPipeline.from_pretrained("briaai/BRIA-2.3", controlnet=controlnet.to(dtype=torch.float16), torch_dtype=torch.float16, vae=vae) #force_zeros_for_empty_prompt=False, # vae=vae)
+
+pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+pipe.load_lora_weights("briaai/BRIA-2.3-FAST-LORA")
+pipe.fuse_lora()
+
+pipe = pipe.to('cuda:0')
+pipe.enable_xformers_memory_efficient_attention()
+
+generator = torch.Generator(device='cuda:0').manual_seed(123456)
+
+vae = pipe.vae
+
+# pipe.force_zeros_for_empty_prompt = False
+
+# default_negative_prompt= "" #"Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers"
 
 
 def read_content(file_path: str) -> str:
@@ -70,26 +115,74 @@ def read_content(file_path: str) -> str:
 
     return content
 
-def predict(dict, prompt="", negative_prompt=
+def predict(dict, prompt="", negative_prompt = default_negative_prompt, guidance_scale=1.2, steps=12, strength=1.0):
     if negative_prompt == "":
         negative_prompt = None
 
 
     init_image = dict["image"].convert("RGB")#.resize((1024, 1024))
-    mask = dict["mask"].convert("
-    init_image = init_image.resize((
-    mask = mask.resize((
+    mask = dict["mask"].convert("L")#.resize((1024, 1024))
+
+    width, height = get_size(init_image)
+
+    init_image = init_image.resize((width, height))
+    mask = mask.resize((width, height))
 
     # Resize to nearest ratio ?
 
-    mask = np.array(mask)
-    mask[mask>0]=255
-    mask = Image.fromarray(mask)
+    # mask = np.array(mask)
+    # mask[mask>0]=255
+    # mask = Image.fromarray(mask)
+
+    masked_image, image_mask, masked_image_to_present = get_masked_image(init_image, mask, width, height)
+    masked_image_tensor = image_transforms(masked_image)
+    masked_image_tensor = (masked_image_tensor - 0.5) / 0.5
+
+    masked_image_tensor = masked_image_tensor.unsqueeze(0).to(device="cuda")
+
+    control_latents = vae.encode(
+            masked_image_tensor[:, :3, :, :].to(vae.dtype)
+        ).latent_dist.sample()
+
+    control_latents = control_latents * vae.config.scaling_factor
+
+    image_mask = np.array(image_mask)[:,:]
+    mask_tensor = torch.tensor(image_mask, dtype=torch.float32)[None, ...]
+    # binarize the mask
+    mask_tensor = torch.where(mask_tensor > 128.0, 255.0, 0)
+
+    mask_tensor = mask_tensor / 255.0
+
+    mask_tensor = mask_tensor.to(device="cuda")
+    mask_resized = torch.nn.functional.interpolate(mask_tensor[None, ...], size=(control_latents.shape[2], control_latents.shape[3]), mode='nearest')
+    # mask_resized = mask_resized.to(torch.float16)
+    masked_image = torch.cat([control_latents, mask_resized], dim=1)
+
+    output = pipe(prompt = prompt,
+                  width=width,
+                  height=height,
+                  negative_prompt=negative_prompt,
+                  image = masked_image, # control image
+                  init_image = init_image,
+                  mask_image=mask_tensor,
+                  guidance_scale=guidance_scale,
+                  num_inference_steps=int(steps),
+                  strength=strength,
+                  generator=generator,
+                  controlnet_conditioning_scale=1.0, )
+
+    # gen_img = pipe(negative_prompt=default_negative_prompt, prompt=prompt,
+    #         controlnet_conditioning_scale=1.0,
+    #         num_inference_steps=12,
+    #         height=height, width=width,
+    #         image = masked_image, # control image
+    #         init_image = init_image,
+    #         mask_image = mask_tensor,
+    #         guidance_scale = 1.2,
+    #         generator=generator).images[0]
 
     return output.images[0] #, gr.update(visible=True)
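
Two reference sketches follow; neither is part of the commit, and both only illustrate how the new code fits together.

Why the conditioning embedding is rebuilt with conditioning_channels=5: inside predict() the control "image" passed to the pipeline is the 4-channel VAE latent of the grey-masked input concatenated with the 1-channel downsampled binary mask. A shape check with stand-in tensors (assuming a 1024x1024 input, so the SDXL VAE latents are 128x128):

import torch
import torch.nn.functional as F

control_latents = torch.randn(1, 4, 128, 128)               # stand-in for vae.encode(...).latent_dist.sample()
mask = torch.zeros(1, 1, 1024, 1024)                        # stand-in binary mask (1 = inpaint)
mask_resized = F.interpolate(mask, size=(128, 128), mode="nearest")
control_image = torch.cat([control_latents, mask_resized], dim=1)
print(control_image.shape)                                  # torch.Size([1, 5, 128, 128])

And a minimal local call into the new predict() entry point. It fakes the Gradio sketch-tool payload the function expects (a dict with PIL "image" and "mask" entries, white mask pixels marking the region to inpaint); "input.jpg", the rectangle, and the prompt are placeholders:

from PIL import Image, ImageDraw

init = Image.open("input.jpg").convert("RGB")
mask = Image.new("L", init.size, 0)                               # black = keep
ImageDraw.Draw(mask).rectangle((100, 100, 400, 400), fill=255)    # white = inpaint

result = predict(
    {"image": init, "mask": mask},
    prompt="a wooden table",
    guidance_scale=1.2,   # LCM-LoRA inference expects low guidance
    steps=12,
)
result.save("out.png")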
 
			

