Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -1,52 +1,178 @@ | |
| 1 | 
             
            import gradio as gr
         | 
| 2 | 
            -
            from PIL import Image
         | 
| 3 | 
            -
            import requests
         | 
| 4 | 
            -
            from io import BytesIO
         | 
| 5 | 
             
            import torch
         | 
| 6 | 
            -
             | 
| 7 | 
            -
             | 
| 8 | 
            -
             | 
| 9 | 
            -
            from  | 
| 10 | 
            -
             | 
| 11 | 
            -
             | 
| 12 | 
            -
             | 
| 13 | 
            -
             | 
| 14 | 
            -
                 | 
| 15 | 
            -
             | 
| 16 | 
            -
             | 
| 17 | 
            -
                 | 
| 18 | 
            -
                 | 
| 19 | 
            -
                 | 
| 20 | 
            -
                 | 
| 21 | 
            -
                 | 
| 22 | 
            -
                 | 
| 23 | 
            -
             | 
| 24 | 
            -
             | 
| 25 | 
            -
             | 
| 26 | 
            -
             | 
| 27 | 
            -
             | 
| 28 | 
            -
                 | 
| 29 | 
            -
                 | 
| 30 | 
            -
                 | 
| 31 | 
            -
             | 
| 32 | 
            -
             | 
| 33 | 
            -
             | 
| 34 | 
            -
             | 
| 35 | 
            -
                 | 
| 36 | 
            -
                 | 
| 37 | 
            -
             | 
| 38 | 
            -
                 | 
| 39 | 
            -
                 | 
| 40 | 
            -
             | 
| 41 | 
            -
             | 
| 42 | 
            -
                return  | 
| 43 | 
            -
             | 
| 44 | 
            -
             | 
| 45 | 
            -
             | 
| 46 | 
            -
             | 
| 47 | 
            -
             | 
| 48 | 
            -
             | 
| 49 | 
            -
             | 
| 50 | 
            -
             | 
| 51 | 
            -
             | 
| 52 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
             
            import gradio as gr
         | 
|  | |
|  | |
|  | |
| 2 | 
             
            import torch
         | 
| 3 | 
            +
            import numpy as np
         | 
| 4 | 
            +
            import diffusers
         | 
| 5 | 
            +
            import os
         | 
| 6 | 
            +
            from PIL import Image
         | 
| 7 | 
            +
            hf_token = os.environ.get("HF_TOKEN")
         | 
| 8 | 
            +
            from diffusers import StableDiffusionXLInpaintPipeline, DDIMScheduler, UNet2DConditionModel
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            ratios_map =  {
         | 
| 11 | 
            +
                0.5:{"width":704,"height":1408},
         | 
| 12 | 
            +
                0.57:{"width":768,"height":1344},
         | 
| 13 | 
            +
                0.68:{"width":832,"height":1216},
         | 
| 14 | 
            +
                0.72:{"width":832,"height":1152},
         | 
| 15 | 
            +
                0.78:{"width":896,"height":1152},
         | 
| 16 | 
            +
                0.82:{"width":896,"height":1088},
         | 
| 17 | 
            +
                0.88:{"width":960,"height":1088},
         | 
| 18 | 
            +
                0.94:{"width":960,"height":1024},
         | 
| 19 | 
            +
                1.00:{"width":1024,"height":1024},
         | 
| 20 | 
            +
                1.13:{"width":1088,"height":960},
         | 
| 21 | 
            +
                1.21:{"width":1088,"height":896},
         | 
| 22 | 
            +
                1.29:{"width":1152,"height":896},
         | 
| 23 | 
            +
                1.38:{"width":1152,"height":832},
         | 
| 24 | 
            +
                1.46:{"width":1216,"height":832},
         | 
| 25 | 
            +
                1.67:{"width":1280,"height":768},
         | 
| 26 | 
            +
                1.75:{"width":1344,"height":768},
         | 
| 27 | 
            +
                2.00:{"width":1408,"height":704}
         | 
| 28 | 
            +
            }
         | 
| 29 | 
            +
            ratios = np.array(list(ratios_map.keys()))
         | 
| 30 | 
            +
             | 
| 31 | 
            +
            def get_size(init_image):
         | 
| 32 | 
            +
                w,h=init_image.size
         | 
| 33 | 
            +
                curr_ratio = w/h
         | 
| 34 | 
            +
                ind = np.argmin(np.abs(curr_ratio-ratios))
         | 
| 35 | 
            +
                ratio = ratios[ind]
         | 
| 36 | 
            +
                chosen_ratio  = ratios_map[ratio]
         | 
| 37 | 
            +
                w,h = chosen_ratio['width'], chosen_ratio['height']
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                return w,h
         | 
| 40 | 
            +
             | 
| 41 | 
            +
            device = "cuda" if torch.cuda.is_available() else "cpu"
         | 
| 42 | 
            +
             | 
| 43 | 
            +
            unet = UNet2DConditionModel.from_pretrained(
         | 
| 44 | 
            +
                "briaai/BRIA-2.2-Inpainting",
         | 
| 45 | 
            +
                subfolder="unet",
         | 
| 46 | 
            +
                torch_dtype=torch.float16,
         | 
| 47 | 
            +
            )
         | 
| 48 | 
            +
             | 
| 49 | 
            +
            scheduler = DDIMScheduler.from_pretrained("briaai/BRIA-2.3", subfolder="scheduler",clip_sample=False)
         | 
| 50 | 
            +
             | 
| 51 | 
            +
            pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
         | 
| 52 | 
            +
                "briaai/BRIA-2.3",
         | 
| 53 | 
            +
                unet=unet,
         | 
| 54 | 
            +
                scheduler=scheduler,
         | 
| 55 | 
            +
                torch_dtype=torch.float16,
         | 
| 56 | 
            +
                force_zeros_for_empty_prompt=False
         | 
| 57 | 
            +
            )
         | 
| 58 | 
            +
             | 
| 59 | 
            +
            pipe = pipe.to(device)
         | 
| 60 | 
            +
            pipe.force_zeros_for_empty_prompt = False
         | 
| 61 | 
            +
             | 
| 62 | 
            +
            default_negative_prompt= "" #"Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers"
         | 
| 63 | 
            +
             | 
| 64 | 
            +
             | 
| 65 | 
            +
            def read_content(file_path: str) -> str:
         | 
| 66 | 
            +
                """read the content of target file
         | 
| 67 | 
            +
                """
         | 
| 68 | 
            +
                with open(file_path, 'r', encoding='utf-8') as f:
         | 
| 69 | 
            +
                    content = f.read()
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                return content
         | 
| 72 | 
            +
             | 
| 73 | 
            +
            def predict(dict, prompt="", negative_prompt="", guidance_scale=5, steps=30, strength=1.0):
         | 
| 74 | 
            +
                if negative_prompt == "":
         | 
| 75 | 
            +
                    negative_prompt = None
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                
         | 
| 78 | 
            +
                init_image = dict["image"].convert("RGB")#.resize((1024, 1024))
         | 
| 79 | 
            +
                mask = dict["mask"].convert("RGB")#.resize((1024, 1024))
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                w,h = get_size(init_image)
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                init_image = init_image.resize((w, h))
         | 
| 84 | 
            +
                mask = mask.resize((w, h))
         | 
| 85 | 
            +
                
         | 
| 86 | 
            +
                # Resize to nearest ratio ?
         | 
| 87 | 
            +
                
         | 
| 88 | 
            +
                mask = np.array(mask)
         | 
| 89 | 
            +
                mask[mask>0]=255
         | 
| 90 | 
            +
                mask = Image.fromarray(mask)
         | 
| 91 | 
            +
                
         | 
| 92 | 
            +
                output = pipe(prompt = prompt,width=w,height=h, negative_prompt=negative_prompt, image=init_image, mask_image=mask, guidance_scale=guidance_scale, num_inference_steps=int(steps), strength=strength)
         | 
| 93 | 
            +
                
         | 
| 94 | 
            +
                return output.images[0] #, gr.update(visible=True)
         | 
| 95 | 
            +
             | 
| 96 | 
            +
             | 
| 97 | 
            +
            css = '''
         | 
| 98 | 
            +
            .gradio-container{max-width: 1100px !important}
         | 
| 99 | 
            +
            #image_upload{min-height:400px}
         | 
| 100 | 
            +
            #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 400px}
         | 
| 101 | 
            +
            #mask_radio .gr-form{background:transparent; border: none}
         | 
| 102 | 
            +
            #word_mask{margin-top: .75em !important}
         | 
| 103 | 
            +
            #word_mask textarea:disabled{opacity: 0.3}
         | 
| 104 | 
            +
            .footer {margin-bottom: 45px;margin-top: 35px;text-align: center;border-bottom: 1px solid #e5e5e5}
         | 
| 105 | 
            +
            .footer>p {font-size: .8rem; display: inline-block; padding: 0 10px;transform: translateY(10px);background: white}
         | 
| 106 | 
            +
            .dark .footer {border-color: #303030}
         | 
| 107 | 
            +
            .dark .footer>p {background: #0b0f19}
         | 
| 108 | 
            +
            .acknowledgments h4{margin: 1.25em 0 .25em 0;font-weight: bold;font-size: 115%}
         | 
| 109 | 
            +
            #image_upload .touch-none{display: flex}
         | 
| 110 | 
            +
            @keyframes spin {
         | 
| 111 | 
            +
                from {
         | 
| 112 | 
            +
                    transform: rotate(0deg);
         | 
| 113 | 
            +
                }
         | 
| 114 | 
            +
                to {
         | 
| 115 | 
            +
                    transform: rotate(360deg);
         | 
| 116 | 
            +
                }
         | 
| 117 | 
            +
            }
         | 
| 118 | 
            +
            #share-btn-container {padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; max-width: 13rem; margin-left: auto;}
         | 
| 119 | 
            +
            div#share-btn-container > div {flex-direction: row;background: black;align-items: center}
         | 
| 120 | 
            +
            #share-btn-container:hover {background-color: #060606}
         | 
| 121 | 
            +
            #share-btn {all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.5rem !important; padding-bottom: 0.5rem !important;right:0;}
         | 
| 122 | 
            +
            #share-btn * {all: unset}
         | 
| 123 | 
            +
            #share-btn-container div:nth-child(-n+2){width: auto !important;min-height: 0px !important;}
         | 
| 124 | 
            +
            #share-btn-container .wrap {display: none !important}
         | 
| 125 | 
            +
            #share-btn-container.hidden {display: none!important}
         | 
| 126 | 
            +
            #prompt input{width: calc(100% - 160px);border-top-right-radius: 0px;border-bottom-right-radius: 0px;}
         | 
| 127 | 
            +
            #run_button{position:absolute;margin-top: 11px;right: 0;margin-right: 0.8em;border-bottom-left-radius: 0px;
         | 
| 128 | 
            +
                border-top-left-radius: 0px;}
         | 
| 129 | 
            +
            #prompt-container{margin-top:-18px;}
         | 
| 130 | 
            +
            #prompt-container .form{border-top-left-radius: 0;border-top-right-radius: 0}
         | 
| 131 | 
            +
            #image_upload{border-bottom-left-radius: 0px;border-bottom-right-radius: 0px}
         | 
| 132 | 
            +
            '''
         | 
| 133 | 
            +
             | 
| 134 | 
            +
            image_blocks = gr.Blocks(css=css, elem_id="total-container")
         | 
| 135 | 
            +
            with image_blocks as demo:
         | 
| 136 | 
            +
                with gr.Column(elem_id="col-container"):
         | 
| 137 | 
            +
                    gr.Markdown("## BRIA 2.2")
         | 
| 138 | 
            +
                    gr.HTML('''
         | 
| 139 | 
            +
                      <p style="margin-bottom: 10px; font-size: 94%">
         | 
| 140 | 
            +
                        This is a demo for 
         | 
| 141 | 
            +
                        <a href="https://huggingface.co/briaai/BRIA-2.2" target="_blank">BRIA 2.2 text-to-image </a>. 
         | 
| 142 | 
            +
                        BRIA 2.2 improve the generation of humans and illustrations compared to BRIA 2.2 while still trained on licensed data, and so provide full legal liability coverage for copyright and privacy infringement.
         | 
| 143 | 
            +
                      </p>
         | 
| 144 | 
            +
                    ''')
         | 
| 145 | 
            +
                with gr.Row():
         | 
| 146 | 
            +
                            with gr.Column():
         | 
| 147 | 
            +
                                image = gr.Image(sources=['upload'], tool='sketch', elem_id="image_upload", type="pil", label="Upload", height=400)
         | 
| 148 | 
            +
                                with gr.Row(elem_id="prompt-container", equal_height=True):
         | 
| 149 | 
            +
                                    with gr.Row():
         | 
| 150 | 
            +
                                        prompt = gr.Textbox(placeholder="Your prompt (what you want in place of what is erased)", show_label=False, elem_id="prompt")
         | 
| 151 | 
            +
                                        btn = gr.Button("Inpaint!", elem_id="run_button")
         | 
| 152 | 
            +
                                
         | 
| 153 | 
            +
                                with gr.Accordion(label="Advanced Settings", open=False):
         | 
| 154 | 
            +
                                    with gr.Row(equal_height=True):
         | 
| 155 | 
            +
                                        guidance_scale = gr.Number(value=5, minimum=1.0, maximum=10.0, step=0.5, label="guidance_scale")
         | 
| 156 | 
            +
                                        steps = gr.Number(value=30, minimum=20, maximum=50, step=1, label="steps")
         | 
| 157 | 
            +
                                        strength = gr.Number(value=1, minimum=0.01, maximum=1.0, step=0.01, label="strength")
         | 
| 158 | 
            +
                                        negative_prompt = gr.Textbox(label="negative_prompt", value=default_negative_prompt, placeholder=default_negative_prompt, info="what you don't want to see in the image")
         | 
| 159 | 
            +
             | 
| 160 | 
            +
                                    
         | 
| 161 | 
            +
                            with gr.Column():
         | 
| 162 | 
            +
                                image_out = gr.Image(label="Output", elem_id="output-img", height=400)
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                        
         | 
| 165 | 
            +
             | 
| 166 | 
            +
                btn.click(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength], outputs=[image_out], api_name='run')
         | 
| 167 | 
            +
                prompt.submit(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength], outputs=[image_out])
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                gr.HTML(
         | 
| 170 | 
            +
                    """
         | 
| 171 | 
            +
                        <div class="footer">
         | 
| 172 | 
            +
                            <p>Model by <a href="https://huggingface.co/diffusers" style="text-decoration: underline;" target="_blank">Diffusers</a> - Gradio Demo by 🤗 Hugging Face
         | 
| 173 | 
            +
                            </p>
         | 
| 174 | 
            +
                        </div>
         | 
| 175 | 
            +
                    """
         | 
| 176 | 
            +
                )
         | 
| 177 | 
            +
             | 
| 178 | 
            +
            image_blocks.queue(max_size=25,api_open=False).launch(show_api=False)
         | 
 
			

