Commit: a0efccd
Parent(s): fb6e008

modify inference

Files changed:
- examples/examples.py             +154 -2
- infer_fft.py                     +178 -0
- infer_lora.py                    +228 -0
- inference/__init__.py            +2 -0
- inference/ace_plus_diffusers.py  +7 -3
- inference/ace_plus_inference.py  +83 -0
- inference/registry.py            +228 -0
- inference/utils.py               +38 -11
    	
examples/examples.py    CHANGED
@@ -2,9 +2,9 @@ all_examples = [
     {
         "input_image": None,
         "input_mask": None,
-        "input_reference_image": "assets/samples/portrait/
+        "input_reference_image": "assets/samples/portrait/human_1.jpg",
         "save_path": "examples/outputs/portrait_human_1.jpg",
-        "instruction": "
+        "instruction": "Maintain the facial features, A girl is wearing a neat police uniform and sporting a badge. She is smiling with a friendly and confident demeanor. The background is blurred, featuring a cartoon logo.",
         "output_h": 1024,
         "output_w": 1024,
         "seed": 4194866942,
@@ -78,4 +78,156 @@ all_examples = [
         "edit_type": "repainting"
     }
 
+]
+
+fft_examples =  [
+    {
+        "input_image": None,
+        "input_mask": None,
+        "input_reference_image": "./assets/samples/portrait/human_1.jpg",
+        "save_path": "examples/outputs/portrait_human_1.jpg",
+        "instruction": "Maintain the facial features, A girl is wearing a neat police uniform and sporting a badge. She is smiling with a friendly and confident demeanor. The background is blurred, featuring a cartoon logo.",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 10000000,
+        "repainting_scale": 1.0,
+        "edit_type": "repainting"
+    },
+    {
+        "input_image": None,
+        "input_mask": None,
+        "input_reference_image": "./assets/samples/subject/subject_1.jpg",
+        "save_path": "examples/outputs/subject_subject_1.jpg",
+        "instruction": "Display the logo in a minimalist style printed in white on a matte black ceramic coffee mug, alongside a steaming cup of coffee on a cozy cafe table.",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 10000000,
+        "repainting_scale": 1.0,
+        "edit_type": "repainting"
+    },
+    {
+        "input_image": "./assets/samples/application/photo_editing/1_2_edit.jpg",
+        "input_mask": "./assets/samples/application/photo_editing/1_2_m.webp",
+        "input_reference_image": "./assets/samples/application/photo_editing/1_ref.png",
+        "save_path": "examples/outputs/photo_editing_1.jpg",
+        "instruction": "The item is put on the table.",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 8006019,
+        "repainting_scale": 1.0,
+        "edit_type": "repainting"
+    },
+    {
+        "input_image": "./assets/samples/application/logo_paste/1_1_edit.png",
+        "input_mask": "./assets/samples/application/logo_paste/1_1_m.png",
+        "input_reference_image": "assets/samples/application/logo_paste/1_ref.png",
+        "save_path": "examples/outputs/logo_paste_1.jpg",
+        "instruction": "The logo is printed on the headphones.",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 934582264,
+        "repainting_scale": 1.0,
+        "edit_type": "repainting"
+    },
+    {
+        "input_image": "./assets/samples/application/try_on/1_1_edit.png",
+        "input_mask": "./assets/samples/application/try_on/1_1_m.png",
+        "input_reference_image": "assets/samples/application/try_on/1_ref.png",
+        "save_path": "examples/outputs/try_on_1.jpg",
+        "instruction": "The woman dresses this skirt.",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 934582264,
+        "repainting_scale": 1.0,
+        "edit_type": "repainting"
+    },
+    {
+        "input_image": "./assets/samples/portrait/human_1.jpg",
+        "input_mask": "assets/samples/application/movie_poster/1_2_m.webp",
+        "input_reference_image": "assets/samples/application/movie_poster/1_ref.png",
+        "save_path": "examples/outputs/movie_poster_1.jpg",
+        "instruction": "{image}, the man faces the camera.",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 3999647,
+        "repainting_scale": 1.0,
+        "edit_type": "repainting"
+    },
+    {
+        "input_image": "./assets/samples/application/sr/sr_tiger.png",
+        "input_mask": "./assets/samples/application/sr/sr_tiger_m.webp",
+        "input_reference_image": None,
+        "save_path": "examples/outputs/mario_recolorizing_1.jpg",
+        "instruction": "{image} features a close-up of a young, furry tiger cub on a rock. The tiger, which appears to be quite young, has distinctive orange, "
+                       "black, and white striped fur, typical of tigers. The cub's eyes have a bright and curious expression, and its ears are perked up, "
+                       "indicating alertness. The cub seems to be in the act of climbing or resting on the rock. The background is a blurred grassland with trees, "
+                       "but the focus is on the cub, which is vividly colored while the rest of the image is in grayscale, drawing attention to the tiger's details."
+                       " The photo captures a moment in the wild, depicting the charming and tenacious nature of this young tiger,"
+                       " as well as its typical interaction with the environment.",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 199999,
+        "repainting_scale": 0.0,
+        "edit_type": "no_preprocess"
+    },
+    {
+        "input_image": "./assets/samples/application/photo_editing/1_ref.png",
+        "input_mask": "./assets/samples/application/photo_editing/1_1_orm.webp",
+        "input_reference_image": None,
+        "save_path": "examples/outputs/mario_repainting_1.jpg",
+        "instruction": "a blue hand",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 63401,
+        "repainting_scale": 1.0,
+        "edit_type": "repainting"
+    },
+    {
+        "input_image": "./assets/samples/application/photo_editing/1_ref.png",
+        "input_mask": "./assets/samples/application/photo_editing/1_1_rm.webp",
+        "input_reference_image": None,
+        "save_path": "examples/outputs/mario_repainting_2.jpg",
+        "instruction": "Mechanical  hands like a robot",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 59107,
+        "repainting_scale": 1.0,
+        "edit_type": "repainting"
+    },
+    {
+        "input_image": "./assets/samples/control/1_1.webp",
+        "input_mask": "./assets/samples/control/1_1_m.webp",
+        "input_reference_image": None,
+        "save_path": "examples/outputs/control_recolorizing.jpg",
+        "instruction": "{image} Beautiful female portrait, Robot with smooth White transparent carbon shell, rococo detailing, Natural lighting, Highly detailed, Cinematic, 4K.",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 9652101,
+        "repainting_scale": 0.0,
+        "edit_type": "recolorizing"
+    },
+    {
+        "input_image": "./assets/samples/control/1_1.webp",
+        "input_mask": "./assets/samples/control/1_1_m.webp",
+        "input_reference_image": None,
+        "save_path": "examples/outputs/control_depth.jpg",
+        "instruction": "{image} Beautiful female portrait, Robot with smooth White transparent carbon shell, rococo detailing, Natural lighting, Highly detailed, Cinematic, 4K.",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 14979476,
+        "repainting_scale": 0.0,
+        "edit_type": "depth_repainting"
+    },
+    {
+        "input_image": "./assets/samples/control/1_1.webp",
+        "input_mask": "./assets/samples/control/1_1_m.webp",
+        "input_reference_image": None,
+        "save_path": "examples/outputs/control_contour.jpg",
+        "instruction": "{image} Beautiful female portrait, Robot with smooth White transparent carbon shell, rococo detailing, Natural lighting, Highly detailed, Cinematic, 4K.",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 4227292472,
+        "repainting_scale": 0.0,
+        "edit_type": "contour_repainting"
+    }
 ]
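
Each fft_examples entry is a plain dict that infer_fft.py (below) expands as keyword arguments into run_one_case: input_image and input_mask select a region to edit, input_reference_image supplies subject or identity guidance, and repainting_scale plus edit_type pick the preprocessing path. A minimal sanity check for such entries, as a sketch; the key set comes from the diff, the checker itself is illustrative and not part of the commit:

    # Illustrative validation of the fft_examples dicts above; REQUIRED_KEYS
    # mirrors the diff, but this checker is not part of the commit.
    REQUIRED_KEYS = {
        "input_image", "input_mask", "input_reference_image",
        "save_path", "instruction", "output_h", "output_w",
        "seed", "repainting_scale", "edit_type",
    }

    def validate_example(example: dict) -> None:
        missing = REQUIRED_KEYS - example.keys()
        if missing:
            raise ValueError(f"example is missing keys: {sorted(missing)}")
        if example["input_image"] is None and example["input_reference_image"] is None:
            raise ValueError("need at least an input_image or an input_reference_image")

    validate_example({
        "input_image": None,
        "input_mask": None,
        "input_reference_image": "./assets/samples/portrait/human_1.jpg",
        "save_path": "examples/outputs/portrait_human_1.jpg",
        "instruction": "Maintain the facial features, ...",
        "output_h": 1024, "output_w": 1024, "seed": 10000000,
        "repainting_scale": 1.0, "edit_type": "repainting",
    })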
    	
infer_fft.py    ADDED
@@ -0,0 +1,178 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import argparse
+import glob
+import importlib
+import io
+import os
+import sys
+
+from PIL import Image
+from scepter.modules.transform.io import pillow_convert
+from scepter.modules.utils.config import Config
+from scepter.modules.utils.file_system import FS
+
+if os.path.exists('__init__.py'):
+    package_name = 'scepter_ext'
+    spec = importlib.util.spec_from_file_location(package_name, '__init__.py')
+    package = importlib.util.module_from_spec(spec)
+    sys.modules[package_name] = package
+    spec.loader.exec_module(package)
+
+from examples.examples import fft_examples as all_examples
+from inference.registry import INFERENCES
+fs_list = [
+    Config(cfg_dict={"NAME": "HuggingfaceFs", "TEMP_DIR": "./cache"}, load=False),
+    Config(cfg_dict={"NAME": "ModelscopeFs", "TEMP_DIR": "./cache"}, load=False),
+    Config(cfg_dict={"NAME": "HttpFs", "TEMP_DIR": "./cache"}, load=False),
+    Config(cfg_dict={"NAME": "LocalFs", "TEMP_DIR": "./cache"}, load=False),
+]
+
+for one_fs in fs_list:
+    FS.init_fs_client(one_fs)
+
+
+def run_one_case(pipe,
+                 input_image = None,
+                 input_mask = None,
+                 input_reference_image = None,
+                 save_path = "examples/output/example.png",
+                 instruction = "",
+                 output_h = 1024,
+                 output_w = 1024,
+                 seed = -1,
+                 sample_steps = None,
+                 guide_scale = None,
+                 repainting_scale = None,
+                 use_change=True,
+                 keep_pixels=True,
+                 keep_pixels_rate=0.8,
+                 **kwargs):
+    if input_image is not None:
+        input_image = Image.open(io.BytesIO(FS.get_object(input_image)))
+        input_image = pillow_convert(input_image, "RGB")
+    if input_mask is not None:
+        input_mask = Image.open(io.BytesIO(FS.get_object(input_mask)))
+        input_mask = pillow_convert(input_mask, "L")
+    if input_reference_image is not None:
+        input_reference_image = Image.open(io.BytesIO(FS.get_object(input_reference_image)))
+        input_reference_image = pillow_convert(input_reference_image, "RGB")
+    print(repainting_scale)
+    image, _, _, _, seed = pipe(
+        reference_image=input_reference_image,
+        edit_image=input_image,
+        edit_mask=input_mask,
+        prompt=instruction,
+        output_height=output_h,
+        output_width=output_w,
+        sampler='flow_euler',
+        sample_steps=sample_steps or pipe.input.get("sample_steps", 28),
+        guide_scale=guide_scale or pipe.input.get("guide_scale", 50),
+        seed=seed,
+        repainting_scale=repainting_scale,
+        use_change=use_change,
+        keep_pixels=keep_pixels,
+        keep_pixels_rate=keep_pixels_rate
+    )
+    with FS.put_to(save_path) as local_path:
+        image.save(local_path)
+    return local_path, seed
+
+
+def run():
+    parser = argparse.ArgumentParser(description='Argparser for Scepter:\n')
+    parser.add_argument('--instruction',
+                        dest='instruction',
+                        help='The instruction for editing or generating!',
+                        default="")
+    parser.add_argument('--output_h',
+                        dest='output_h',
+                        help='The height of output image for generation tasks!',
+                        type=int,
+                        default=1024)
+    parser.add_argument('--output_w',
+                        dest='output_w',
+                        help='The width of output image for generation tasks!',
+                        type=int,
+                        default=1024)
+    parser.add_argument('--input_reference_image',
+                        dest='input_reference_image',
+                        help='The input reference image!',
+                        default=None
+                        )
+    parser.add_argument('--input_image',
+                        dest='input_image',
+                        help='The input image!',
+                        default=None
+                        )
+    parser.add_argument('--input_mask',
+                        dest='input_mask',
+                        help='The input mask!',
+                        default=None
+                        )
+    parser.add_argument('--save_path',
+                        dest='save_path',
+                        help='The save path for output image!',
+                        default='examples/output_images/output.png'
+                        )
+    parser.add_argument('--seed',
+                        dest='seed',
+                        help='The seed for generation!',
+                        type=int,
+                        default=-1)
+
+    parser.add_argument('--step',
+                        dest='step',
+                        help='The sample step for generation!',
+                        type=int,
+                        default=None)
+
+    parser.add_argument('--guide_scale',
+                        dest='guide_scale',
+                        help='The guide scale for generation!',
+                        type=int,
+                        default=None)
+
+    parser.add_argument('--repainting_scale',
+                        dest='repainting_scale',
+                        help='The repainting scale for content filling generation!',
+                        type=int,
+                        default=None)
+
+    cfg = Config(load=True, parser_ins=parser)
+    model_cfg = Config(load=True, cfg_file="config/ace_plus_fft.yaml")
+    pipe = INFERENCES.build(model_cfg)
+
+
+    if cfg.args.instruction == "" and cfg.args.input_image is None and cfg.args.input_reference_image is None:
+        params = {
+            "output_h": cfg.args.output_h,
+            "output_w": cfg.args.output_w,
+            "sample_steps": cfg.args.step,
+            "guide_scale": cfg.args.guide_scale
+        }
+        # run examples
+
+        for example in all_examples:
+            example.update(params)
+            local_path, seed = run_one_case(pipe, **example)
+
+    else:
+        params = {
+            "input_image": cfg.args.input_image,
+            "input_mask": cfg.args.input_mask,
+            "input_reference_image": cfg.args.input_reference_image,
+            "save_path": cfg.args.save_path,
+            "instruction": cfg.args.instruction,
+            "output_h": cfg.args.output_h,
+            "output_w": cfg.args.output_w,
+            "sample_steps": cfg.args.step,
+            "guide_scale": cfg.args.guide_scale,
+            "repainting_scale": cfg.args.repainting_scale,
+        }
+        local_path, seed = run_one_case(pipe, **params)
+        print(local_path, seed)
+
+if __name__ == '__main__':
+    run()
+
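
run_one_case performs the same open-decode-convert sequence three times: once for the edit image ("RGB"), once for the mask ("L"), and once for the reference image ("RGB"), each via FS.get_object plus scepter's pillow_convert. A local-filesystem sketch of that step; the helper below is illustrative, not part of the commit, and only mimics FS/pillow_convert for plain local paths:

    import io
    from typing import Optional
    from PIL import Image

    def load_image(path: Optional[str], mode: str) -> Optional[Image.Image]:
        """Local stand-in for Image.open(io.BytesIO(FS.get_object(path)))
        followed by pillow_convert(img, mode), as used in run_one_case."""
        if path is None:
            return None
        with open(path, "rb") as f:
            img = Image.open(io.BytesIO(f.read()))
        return img.convert(mode)  # pillow_convert handles more corner cases

    # input_image -> "RGB", input_mask -> "L", input_reference_image -> "RGB"
    # img = load_image("./assets/samples/portrait/human_1.jpg", "RGB")

Routing the bytes through FS rather than open() is what lets the same example dicts point at HTTP, ModelScope, or Hugging Face paths as well as local ones, given the four FS clients initialised at import time.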
    	
infer_lora.py    ADDED
@@ -0,0 +1,228 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import argparse
+import glob
+import io
+import os
+
+from PIL import Image
+from scepter.modules.transform.io import pillow_convert
+from scepter.modules.utils.config import Config
+from scepter.modules.utils.file_system import FS
+
+from examples.examples import all_examples
+from inference.ace_plus_diffusers import ACEPlusDiffuserInference
+inference_dict = {
+    "ACE_DIFFUSER_PLUS": ACEPlusDiffuserInference
+}
+
+fs_list = [
+    Config(cfg_dict={"NAME": "HuggingfaceFs", "TEMP_DIR": "./cache"}, load=False),
+    Config(cfg_dict={"NAME": "ModelscopeFs", "TEMP_DIR": "./cache"}, load=False),
+    Config(cfg_dict={"NAME": "HttpFs", "TEMP_DIR": "./cache"}, load=False),
+    Config(cfg_dict={"NAME": "LocalFs", "TEMP_DIR": "./cache"}, load=False),
+]
+
+for one_fs in fs_list:
+    FS.init_fs_client(one_fs)
+
+
+def run_one_case(pipe,
+                 input_image = None,
+                 input_mask = None,
+                 input_reference_image = None,
+                 save_path = "examples/output/example.png",
+                 instruction = "",
+                 output_h = 1024,
+                 output_w = 1024,
+                 seed = -1,
+                 sample_steps = None,
+                 guide_scale = None,
+                 repainting_scale = None,
+                 model_path = None,
+                 **kwargs):
+    if input_image is not None:
+        input_image = Image.open(io.BytesIO(FS.get_object(input_image)))
+        input_image = pillow_convert(input_image, "RGB")
+    if input_mask is not None:
+        input_mask = Image.open(io.BytesIO(FS.get_object(input_mask)))
+        input_mask = pillow_convert(input_mask, "L")
+    if input_reference_image is not None:
+        input_reference_image = Image.open(io.BytesIO(FS.get_object(input_reference_image)))
+        input_reference_image = pillow_convert(input_reference_image, "RGB")
+
+    image, seed = pipe(
+        reference_image=input_reference_image,
+        edit_image=input_image,
+        edit_mask=input_mask,
+        prompt=instruction,
+        output_height=output_h,
+        output_width=output_w,
+        sampler='flow_euler',
+        sample_steps=sample_steps or pipe.input.get("sample_steps", 28),
+        guide_scale=guide_scale or pipe.input.get("guide_scale", 50),
+        seed=seed,
+        repainting_scale=repainting_scale or pipe.input.get("repainting_scale", 1.0),
+        lora_path = model_path
+    )
+    with FS.put_to(save_path) as local_path:
+        image.save(local_path)
+    return local_path, seed
+
+
+def run():
+    parser = argparse.ArgumentParser(description='Argparser for Scepter:\n')
+    parser.add_argument('--instruction',
+                        dest='instruction',
+                        help='The instruction for editing or generating!',
+                        default="")
+    parser.add_argument('--output_h',
+                        dest='output_h',
+                        help='The height of output image for generation tasks!',
+                        type=int,
+                        default=1024)
+    parser.add_argument('--output_w',
+                        dest='output_w',
+                        help='The width of output image for generation tasks!',
+                        type=int,
+                        default=1024)
+    parser.add_argument('--input_reference_image',
+                        dest='input_reference_image',
+                        help='The input reference image!',
+                        default=None
+                        )
+    parser.add_argument('--input_image',
+                        dest='input_image',
+                        help='The input image!',
+                        default=None
+                        )
+    parser.add_argument('--input_mask',
+                        dest='input_mask',
+                        help='The input mask!',
+                        default=None
+                        )
+    parser.add_argument('--save_path',
+                        dest='save_path',
+                        help='The save path for output image!',
+                        default='examples/output_images/output.png'
+                        )
+    parser.add_argument('--seed',
+                        dest='seed',
+                        help='The seed for generation!',
+                        type=int,
+                        default=-1)
+
+    parser.add_argument('--step',
+                        dest='step',
+                        help='The sample step for generation!',
+                        type=int,
+                        default=None)
+
+    parser.add_argument('--guide_scale',
+                        dest='guide_scale',
         
     | 
| 123 | 
         
            +
                                    help='The guide scale for generation!',
         
     | 
| 124 | 
         
            +
                                    type=int,
         
     | 
| 125 | 
         
            +
                                    default=None)
         
     | 
| 126 | 
         
            +
             
     | 
| 127 | 
         
            +
                parser.add_argument('--repainting_scale',
         
     | 
| 128 | 
         
            +
                                    dest='repainting_scale',
         
     | 
| 129 | 
         
            +
                                    help='The repainting scale for content filling generation!',
         
     | 
| 130 | 
         
            +
                                    type=int,
         
     | 
| 131 | 
         
            +
                                    default=None)
         
     | 
| 132 | 
         
            +
             
     | 
| 133 | 
         
            +
                parser.add_argument('--task_type',
         
     | 
| 134 | 
         
            +
                                    dest='task_type',
         
     | 
| 135 | 
         
            +
                                    choices=['portrait', 'subject', 'local_editing'],
         
     | 
| 136 | 
         
            +
                                    help="Choose the task type.",
         
     | 
| 137 | 
         
            +
                                    default='')
         
     | 
| 138 | 
         
            +
             
     | 
| 139 | 
         
            +
                parser.add_argument('--task_model',
         
     | 
| 140 | 
         
            +
                                    dest='task_model',
         
     | 
| 141 | 
         
            +
                                    help='The models list for different tasks!',
         
     | 
| 142 | 
         
            +
                                    default="./models/model_zoo.yaml")
         
     | 
| 143 | 
         
            +
             
     | 
| 144 | 
         
            +
             
     | 
| 145 | 
         
            +
                parser.add_argument('--infer_type',
         
     | 
| 146 | 
         
            +
                                    dest='infer_type',
         
     | 
| 147 | 
         
            +
                                    choices=['diffusers'],
         
     | 
| 148 | 
         
            +
                                    default='diffusers',
         
     | 
| 149 | 
         
            +
                                    help="Choose the inference scripts. 'native' refers to using the official implementation of ace++, "
         
     | 
| 150 | 
         
            +
                                         "while 'diffusers' refers to using the adaptation for diffusers")
         
     | 
| 151 | 
         
            +
             
     | 
| 152 | 
         
            +
                parser.add_argument('--cfg_folder',
         
     | 
| 153 | 
         
            +
                                    dest='cfg_folder',
         
     | 
| 154 | 
         
            +
                                    help='The inference config!',
         
     | 
| 155 | 
         
            +
                                    default="./config")
         
     | 
| 156 | 
         
            +
             
     | 
| 157 | 
         
            +
                cfg = Config(load=True, parser_ins=parser)
         
     | 
| 158 | 
         
            +
             
     | 
| 159 | 
         
            +
                model_yamls = glob.glob(os.path.join(cfg.args.cfg_folder, '*.yaml'))
         
     | 
| 160 | 
         
            +
                model_choices = dict()
         
     | 
| 161 | 
         
            +
                for i in model_yamls:
         
     | 
| 162 | 
         
            +
                    model_cfg = Config(load=True, cfg_file=i)
         
     | 
| 163 | 
         
            +
                    model_name = model_cfg.NAME
         
     | 
| 164 | 
         
            +
                    model_choices[model_name] = model_cfg
         
     | 
| 165 | 
         
            +
             
     | 
| 166 | 
         
            +
                if cfg.args.infer_type == "native":
         
     | 
| 167 | 
         
            +
                    infer_name = "ace_plus_native_infer"
         
     | 
| 168 | 
         
            +
                elif cfg.args.infer_type == "diffusers":
         
     | 
| 169 | 
         
            +
                    infer_name = "ace_plus_diffuser_infer"
         
     | 
| 170 | 
         
            +
                else:
         
     | 
| 171 | 
         
            +
                    raise ValueError("infer_type should be native or diffusers")
         
     | 
| 172 | 
         
            +
             
     | 
| 173 | 
         
            +
                assert infer_name in model_choices
         
     | 
| 174 | 
         
            +
             
     | 
| 175 | 
         
            +
                # choose different model
         
     | 
| 176 | 
         
            +
                task_model_cfg = Config(load=True, cfg_file=cfg.args.task_model)
         
     | 
| 177 | 
         
            +
             
     | 
| 178 | 
         
            +
                task_model_dict = {}
         
     | 
| 179 | 
         
            +
                for task_name, task_model in task_model_cfg.MODEL.items():
         
     | 
| 180 | 
         
            +
                    task_model_dict[task_name] = task_model
         
     | 
| 181 | 
         
            +
             
     | 
| 182 | 
         
            +
             
     | 
| 183 | 
         
            +
                # choose the inference scripts.
         
     | 
| 184 | 
         
            +
                pipe_cfg = model_choices[infer_name]
         
     | 
| 185 | 
         
            +
                infer_name = pipe_cfg.get("INFERENCE_TYPE", "ACE_PLUS")
         
     | 
| 186 | 
         
            +
                pipe = inference_dict[infer_name]()
         
     | 
| 187 | 
         
            +
                pipe.init_from_cfg(pipe_cfg)
         
     | 
| 188 | 
         
            +
             
     | 
| 189 | 
         
            +
                if cfg.args.instruction == "" and cfg.args.input_image is None and cfg.args.input_reference_image is None:
         
     | 
| 190 | 
         
            +
                    params = {
         
     | 
| 191 | 
         
            +
                        "output_h": cfg.args.output_h,
         
     | 
| 192 | 
         
            +
                        "output_w": cfg.args.output_w,
         
     | 
| 193 | 
         
            +
                        "sample_steps": cfg.args.step,
         
     | 
| 194 | 
         
            +
                        "guide_scale": cfg.args.guide_scale
         
     | 
| 195 | 
         
            +
                    }
         
     | 
| 196 | 
         
            +
                    # run examples
         
     | 
| 197 | 
         
            +
             
     | 
| 198 | 
         
            +
                    for example in all_examples:
         
     | 
| 199 | 
         
            +
                        example["model_path"] = FS.get_from(task_model_dict[example["task_type"].upper()]["MODEL_PATH"])
         
     | 
| 200 | 
         
            +
                        example.update(params)
         
     | 
| 201 | 
         
            +
                        if example["edit_type"] == "repainting":
         
     | 
| 202 | 
         
            +
                            example["repainting_scale"] = 1.0
         
     | 
| 203 | 
         
            +
                        else:
         
     | 
| 204 | 
         
            +
                            example["repainting_scale"] = task_model_dict[example["task_type"].upper()].get("REPAINTING_SCALE", 1.0)
         
     | 
| 205 | 
         
            +
                        print(example)
         
     | 
| 206 | 
         
            +
                        local_path, seed = run_one_case(pipe, **example)
         
     | 
| 207 | 
         
            +
             
     | 
| 208 | 
         
            +
                else:
         
     | 
| 209 | 
         
            +
                    assert cfg.args.task_type.upper() in task_model_cfg
         
     | 
| 210 | 
         
            +
                    params = {
         
     | 
| 211 | 
         
            +
                        "input_image": cfg.args.input_image,
         
     | 
| 212 | 
         
            +
                        "input_mask": cfg.args.input_mask,
         
     | 
| 213 | 
         
            +
                        "input_reference_image": cfg.args.input_reference_image,
         
     | 
| 214 | 
         
            +
                        "save_path": cfg.args.save_path,
         
     | 
| 215 | 
         
            +
                        "instruction": cfg.args.instruction,
         
     | 
| 216 | 
         
            +
                        "output_h": cfg.args.output_h,
         
     | 
| 217 | 
         
            +
                        "output_w": cfg.args.output_w,
         
     | 
| 218 | 
         
            +
                        "sample_steps": cfg.args.step,
         
     | 
| 219 | 
         
            +
                        "guide_scale": cfg.args.guide_scale,
         
     | 
| 220 | 
         
            +
                        "repainting_scale": cfg.args.repainting_scale,
         
     | 
| 221 | 
         
            +
                        "model_path": FS.get_from(task_model_dict[cfg.args.task_type.upper()]["MODEL_PATH"])
         
     | 
| 222 | 
         
            +
                    }
         
     | 
| 223 | 
         
            +
                    local_path, seed = run_one_case(pipe, **params)
         
     | 
| 224 | 
         
            +
                    print(local_path, seed)
         
     | 
| 225 | 
         
            +
             
     | 
| 226 | 
         
            +
            if __name__ == '__main__':
         
     | 
| 227 | 
         
            +
                run()
         
     | 
| 228 | 
         
            +
             
     | 
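For orientation, the driver above can also be called from Python rather than through argparse. A minimal sketch, assuming the default --cfg_folder and --task_model layouts; the config file name, the sample path, and the instruction string are placeholders, not files guaranteed by this commit:

# A sketch of driving run_one_case programmatically (all paths are placeholders).
from infer_lora import run_one_case
from inference import ACEPlusDiffuserInference
from scepter.modules.utils.config import Config
from scepter.modules.utils.file_system import FS

pipe_cfg = Config(load=True, cfg_file='./config/ace_plus_diffuser_infer.yaml')  # hypothetical name
pipe = ACEPlusDiffuserInference()
pipe.init_from_cfg(pipe_cfg)

task_model_cfg = Config(load=True, cfg_file='./models/model_zoo.yaml')
params = {
    "input_image": None,
    "input_mask": None,
    "input_reference_image": "assets/samples/portrait/ref.jpg",   # placeholder sample
    "save_path": "examples/output_images/output.png",
    "instruction": "Maintain the facial features; change the outfit.",  # placeholder
    "output_h": 1024,
    "output_w": 1024,
    "sample_steps": None,       # None falls back to pipe.input defaults
    "guide_scale": None,
    "repainting_scale": None,
    "seed": -1,
    "model_path": FS.get_from(task_model_cfg.MODEL["PORTRAIT"]["MODEL_PATH"]),
}
local_path, seed = run_one_case(pipe, **params)
print(local_path, seed)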
    	
inference/__init__.py
CHANGED

@@ -0,0 +1,2 @@
+from .ace_plus_diffusers import ACEPlusDiffuserInference
+from .ace_plus_inference import ACEInference
    	
inference/ace_plus_diffusers.py
CHANGED

@@ -12,7 +12,6 @@ from scepter.modules.utils.logger import get_logger
 from transformers import T5TokenizerFast
 from .utils import ACEPlusImageProcessor
 
-
 class ACEPlusDiffuserInference():
     def __init__(self, logger=None):
         if logger is None:
@@ -39,7 +38,6 @@ class ACEPlusDiffuserInference():
         self.pipe.tokenizer_2 = tokenizer_2
         self.load_default(cfg.DEFAULT_PARAS)
 
-
     def prepare_input(self,
                       image,
                       mask,
@@ -88,7 +86,11 @@ class ACEPlusDiffuserInference():
         if isinstance(prompt, str):
             prompt = [prompt]
         seed = seed if seed >= 0 else random.randint(0, 2 ** 32 - 1)
-
+        # edit_image, edit_mask, change_image, content_image, out_h, out_w, slice_w
+        image, mask, _, _, out_h, out_w, slice_w = self.image_processor.preprocess(reference_image, edit_image, edit_mask,
+                                                                                   width=output_width,
+                                                                                   height=output_height,
+                                                                                   repainting_scale=repainting_scale)
         h, w = image.shape[1:]
         generator = torch.Generator("cpu").manual_seed(seed)
         masked_image_latents = self.prepare_input(image, mask,
@@ -98,6 +100,8 @@ class ACEPlusDiffuserInference():
             with FS.get_from(lora_path) as local_path:
                 self.pipe.load_lora_weights(local_path)
 
+
+
         image = self.pipe(
             prompt=prompt,
             masked_image_latents=masked_image_latents,
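The hunks above route every request through ACEPlusImageProcessor.preprocess and load the task LoRA per call. As a usage sketch of the resulting __call__ contract (pipe built as in infer_lora.py; the image and LoRA paths are placeholders):

# One edit pass through the diffusers adaptation (a sketch; paths are placeholders).
from PIL import Image

image, seed = pipe(
    reference_image=Image.open('ref.png'),    # subject/identity reference, or None
    edit_image=Image.open('scene.png'),       # image to repaint, or None for pure generation
    edit_mask=Image.open('mask.png'),         # white marks the region to change
    prompt='a neat wooden sign in the masked region',
    output_height=1024,
    output_width=1024,
    sample_steps=28,
    guide_scale=50,
    seed=-1,
    repainting_scale=1.0,                     # the repainting examples above use 1.0
    lora_path='path/to/ace_plus_lora.safetensors',  # placeholder; resolved via FS.get_from
)
image.save('output.png')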
    	
inference/ace_plus_inference.py
ADDED

@@ -0,0 +1,83 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import random
+from collections import OrderedDict
+
+import numpy as np
+import torch
+from PIL import Image
+from scepter.modules.model.registry import MODELS
+from scepter.modules.utils.config import Config
+from scepter.modules.utils.distribute import we
+from .registry import BaseInference, INFERENCES
+from .utils import ACEPlusImageProcessor
+
+
+@INFERENCES.register_class()
+class ACEInference(BaseInference):
+    '''
+        Reuses the LDM code for sampling.
+    '''
+    def __init__(self, cfg, logger=None):
+        super().__init__(cfg, logger)
+        self.pipe = MODELS.build(cfg.MODEL, logger=self.logger).eval().to(we.device_id)
+        self.image_processor = ACEPlusImageProcessor(max_seq_len=cfg.MAX_SEQ_LEN)
+        # sampling defaults come from the DEFAULT field of each SAMPLE_ARGS entry
+        self.input = {k.lower(): dict(v).get('DEFAULT', None) if isinstance(v, (dict, OrderedDict, Config)) else v
+                      for k, v in cfg.SAMPLE_ARGS.items()}
+        self.dtype = getattr(torch, cfg.get("DTYPE", "bfloat16"))
+
+    @torch.no_grad()
+    def __call__(self,
+                 reference_image=None,
+                 edit_image=None,
+                 edit_mask=None,
+                 prompt='',
+                 edit_type=None,
+                 output_height=1024,
+                 output_width=1024,
+                 sampler='flow_euler',
+                 sample_steps=28,
+                 guide_scale=50,
+                 lora_path=None,
+                 seed=-1,
+                 repainting_scale=0,
+                 use_change=False,
+                 keep_pixels=False,
+                 keep_pixels_rate=0.8,
+                 **kwargs):
+        # convert the input info to the input of the LDM.
+        if isinstance(prompt, str):
+            prompt = [prompt]
+        seed = seed if seed >= 0 else random.randint(0, 2 ** 24 - 1)
+        image, mask, change_image, content_image, out_h, out_w, slice_w = self.image_processor.preprocess(
+            reference_image, edit_image, edit_mask,
+            height=output_height, width=output_width,
+            repainting_scale=repainting_scale,
+            keep_pixels=keep_pixels,
+            keep_pixels_rate=keep_pixels_rate,
+            use_change=use_change)
+        change_image = [None] if change_image is None else [change_image.to(we.device_id)]
+        image, mask = [image.to(we.device_id)], [mask.to(we.device_id)]
+
+        (src_image_list, src_mask_list, modify_image_list,
+         edit_id, prompt) = [image], [mask], [change_image], [[0]], [prompt]
+
+        with torch.amp.autocast(enabled=True, dtype=self.dtype, device_type='cuda'):
+            out_image = self.pipe(
+                src_image_list=src_image_list,
+                modify_image_list=modify_image_list,
+                src_mask_list=src_mask_list,
+                edit_id=edit_id,
+                image=image,
+                image_mask=mask,
+                prompt=prompt,
+                sampler=sampler,
+                sample_steps=sample_steps,
+                seed=seed,
+                guide_scale=guide_scale,
+                show_process=True,
+            )
+        imgs = [x_i['reconstruct_image'].float().permute(1, 2, 0).cpu().numpy()
+                for x_i in out_image]
+        imgs = [Image.fromarray((img * 255).astype(np.uint8)) for img in imgs]
+        edit_image = Image.fromarray((torch.clamp(image[0] / 2 + 0.5, min=0.0, max=1.0) * 255).float().permute(1, 2, 0).cpu().numpy().astype(np.uint8))
+        # guard: change_image stays None when use_change is off
+        change_image = None if change_image[0] is None else Image.fromarray(
+            (torch.clamp(change_image[0] / 2 + 0.5, min=0.0, max=1.0) * 255).float().permute(1, 2, 0).cpu().numpy().astype(np.uint8))
+        mask = Image.fromarray((mask[0] * 255).squeeze(0).cpu().numpy().astype(np.uint8))
+        return self.image_processor.postprocess(imgs[0], slice_w, out_w, out_h), edit_image, change_image, mask, seed
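Unlike the diffusers path, ACEInference returns a 5-tuple (output image, preprocessed edit image, change image, mask, and the seed actually used). A minimal sketch of one call, assuming a YAML that defines NAME, MODEL, MAX_SEQ_LEN, and SAMPLE_ARGS as this class expects (the file name is hypothetical):

# A sketch of calling ACEInference directly (config path is a placeholder).
from PIL import Image
from scepter.modules.utils.config import Config
from inference import ACEInference

cfg = Config(load=True, cfg_file='./config/ace_plus_fft.yaml')  # hypothetical name
pipe = ACEInference(cfg)

edit_image = Image.new('RGB', (1024, 1024))           # placeholder canvas to repaint
edit_mask = Image.new('L', (1024, 1024), color=255)   # repaint everything

out_image, vis_edit, vis_change, vis_mask, seed = pipe(
    reference_image=None,
    edit_image=edit_image,
    edit_mask=edit_mask,
    prompt='fill the masked region with a mountain lake',
    output_height=1024,
    output_width=1024,
    sample_steps=28,
    guide_scale=50,
    seed=-1,
    repainting_scale=1.0)
out_image.save('output.png')  # postprocess() is expected to return a PIL image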
    	
        inference/registry.py
    ADDED
    
    | 
         @@ -0,0 +1,228 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            # -*- coding: utf-8 -*-
         
     | 
| 2 | 
         
            +
            # Copyright (c) Alibaba, Inc. and its affiliates.
         
     | 
| 3 | 
         
            +
             
     | 
| 4 | 
         
            +
            import torch
         
     | 
| 5 | 
         
            +
            from PIL.Image import Image
         
     | 
| 6 | 
         
            +
            from collections import OrderedDict
         
     | 
| 7 | 
         
            +
            from scepter.modules.utils.distribute import we
         
     | 
| 8 | 
         
            +
            from scepter.modules.utils.config import Config
         
     | 
| 9 | 
         
            +
            from scepter.modules.utils.logger import get_logger
         
     | 
| 10 | 
         
            +
            from scepter.studio.utils.env import get_available_memory
         
     | 
| 11 | 
         
            +
            from scepter.modules.model.registry import MODELS, BACKBONES, EMBEDDERS
         
     | 
| 12 | 
         
            +
            from scepter.modules.utils.registry import Registry, build_from_config
         
     | 
| 13 | 
         
            +
            def get_model(model_tuple):
         
     | 
| 14 | 
         
            +
                assert 'model' in model_tuple
         
     | 
| 15 | 
         
            +
                return model_tuple['model']
         
     | 
| 16 | 
         
            +
             
     | 
| 17 | 
         
            +
            class BaseInference():
         
     | 
| 18 | 
         
            +
                '''
         
     | 
| 19 | 
         
            +
                    support to load the components dynamicly.
         
     | 
| 20 | 
         
            +
                    create and load model when run this model at the first time.
         
     | 
| 21 | 
         
            +
                '''
         
     | 
| 22 | 
         
            +
                def __init__(self, cfg, logger=None):
         
     | 
| 23 | 
         
            +
                    if logger is None:
         
     | 
| 24 | 
         
            +
                        logger = get_logger(name='scepter')
         
     | 
| 25 | 
         
            +
                    self.logger = logger
         
     | 
| 26 | 
         
            +
                    self.name = cfg.NAME
         
     | 
| 27 | 
         
            +
             
     | 
| 28 | 
         
            +
                def init_from_modules(self, modules):
         
     | 
| 29 | 
         
            +
                    for k, v in modules.items():
         
     | 
| 30 | 
         
            +
                        self.__setattr__(k, v)
         
     | 
| 31 | 
         
            +
             
     | 
| 32 | 
         
            +
                def infer_model(self, cfg, module_paras=None):
         
     | 
| 33 | 
         
            +
                    module = {
         
     | 
| 34 | 
         
            +
                        'model': None,
         
     | 
| 35 | 
         
            +
                        'cfg': cfg,
         
     | 
| 36 | 
         
            +
                        'device': 'offline',
         
     | 
| 37 | 
         
            +
                        'name': cfg.NAME,
         
     | 
| 38 | 
         
            +
                        'function_info': {},
         
     | 
| 39 | 
         
            +
                        'paras': {}
         
     | 
| 40 | 
         
            +
                    }
         
     | 
| 41 | 
         
            +
                    if module_paras is None:
         
     | 
| 42 | 
         
            +
                        return module
         
     | 
| 43 | 
         
            +
                    function_info = {}
         
     | 
| 44 | 
         
            +
                    paras = {
         
     | 
| 45 | 
         
            +
                        k.lower(): v
         
     | 
| 46 | 
         
            +
                        for k, v in module_paras.get('PARAS', {}).items()
         
     | 
| 47 | 
         
            +
                    }
         
     | 
| 48 | 
         
            +
                    for function in module_paras.get('FUNCTION', []):
         
     | 
| 49 | 
         
            +
                        input_dict = {}
         
     | 
| 50 | 
         
            +
                        for inp in function.get('INPUT', []):
         
     | 
| 51 | 
         
            +
                            if inp.lower() in self.input:
         
     | 
| 52 | 
         
            +
                                input_dict[inp.lower()] = self.input[inp.lower()]
         
     | 
| 53 | 
         
            +
                        function_info[function.NAME] = {
         
     | 
| 54 | 
         
            +
                            'dtype': function.get('DTYPE', 'float32'),
         
     | 
| 55 | 
         
            +
                            'input': input_dict
         
     | 
| 56 | 
         
            +
                        }
         
     | 
| 57 | 
         
            +
                    module['paras'] = paras
         
     | 
| 58 | 
         
            +
                    module['function_info'] = function_info
         
     | 
| 59 | 
         
            +
                    return module
         
     | 
| 60 | 
         
            +
             
     | 
| 61 | 
         
            +
                def init_from_ckpt(self, path, model, ignore_keys=list()):
         
     | 
| 62 | 
         
            +
                    if path.endswith('safetensors'):
         
     | 
| 63 | 
         
            +
                        from safetensors.torch import load_file as load_safetensors
         
     | 
| 64 | 
         
            +
                        sd = load_safetensors(path)
         
     | 
| 65 | 
         
            +
                    else:
         
     | 
| 66 | 
         
            +
                        sd = torch.load(path, map_location='cpu', weights_only=True)
         
     | 
| 67 | 
         
            +
             
     | 
| 68 | 
         
            +
                    new_sd = OrderedDict()
         
     | 
| 69 | 
         
            +
                    for k, v in sd.items():
         
     | 
| 70 | 
         
            +
                        ignored = False
         
     | 
| 71 | 
         
            +
                        for ik in ignore_keys:
         
     | 
| 72 | 
         
            +
                            if ik in k:
         
     | 
| 73 | 
         
            +
                                if we.rank == 0:
         
     | 
| 74 | 
         
            +
                                    self.logger.info(
         
     | 
| 75 | 
         
            +
                                        'Ignore key {} from state_dict.'.format(k))
         
     | 
| 76 | 
         
            +
                                ignored = True
         
     | 
| 77 | 
         
            +
                                break
         
     | 
| 78 | 
         
            +
                        if not ignored:
         
     | 
| 79 | 
         
            +
                            new_sd[k] = v
         
     | 
| 80 | 
         
            +
             
     | 
| 81 | 
         
            +
                    missing, unexpected = model.load_state_dict(new_sd, strict=False)
         
     | 
| 82 | 
         
            +
                    if we.rank == 0:
         
     | 
| 83 | 
         
            +
                        self.logger.info(
         
     | 
| 84 | 
         
            +
                            f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'
         
     | 
| 85 | 
         
            +
                        )
         
     | 
| 86 | 
         
            +
                        if len(missing) > 0:
         
     | 
| 87 | 
         
            +
                            self.logger.info(f'Missing Keys:\n {missing}')
         
     | 
| 88 | 
         
            +
                        if len(unexpected) > 0:
         
     | 
| 89 | 
         
            +
                            self.logger.info(f'\nUnexpected Keys:\n {unexpected}')
         
     | 
| 90 | 
         
            +
             
     | 
| 91 | 
         
            +
                def load(self, module):
         
     | 
| 92 | 
         
            +
                    if module['device'] == 'offline':
         
     | 
| 93 | 
         
            +
                        from scepter.modules.utils.import_utils import LazyImportModule
         
     | 
| 94 | 
         
            +
                        if (LazyImportModule.get_module_type(('MODELS', module['cfg'].NAME)) or
         
     | 
| 95 | 
         
            +
                                module['cfg'].NAME in MODELS.class_map):
         
     | 
| 96 | 
         
            +
                            model = MODELS.build(module['cfg'], logger=self.logger).eval()
         
     | 
| 97 | 
         
            +
                        elif (LazyImportModule.get_module_type(('BACKBONES', module['cfg'].NAME)) or
         
     | 
| 98 | 
         
            +
                                module['cfg'].NAME in BACKBONES.class_map):
         
     | 
| 99 | 
         
            +
                            model = BACKBONES.build(module['cfg'],
         
     | 
| 100 | 
         
            +
                                                    logger=self.logger).eval()
         
     | 
| 101 | 
         
            +
                        elif (LazyImportModule.get_module_type(('EMBEDDERS', module['cfg'].NAME)) or
         
     | 
| 102 | 
         
            +
                                module['cfg'].NAME in EMBEDDERS.class_map):
         
     | 
| 103 | 
         
            +
                            model = EMBEDDERS.build(module['cfg'],
         
     | 
| 104 | 
         
            +
                                                    logger=self.logger).eval()
         
     | 
| 105 | 
         
            +
                        else:
         
     | 
| 106 | 
         
            +
                            raise NotImplementedError
         
     | 
| 107 | 
         
            +
                        if 'DTYPE' in module['cfg'] and module['cfg']['DTYPE'] is not None:
         
     | 
| 108 | 
         
            +
                            model = model.to(getattr(torch, module['cfg'].DTYPE))
         
     | 
| 109 | 
         
            +
                        if module['cfg'].get('RELOAD_MODEL', None):
         
     | 
| 110 | 
         
            +
                            self.init_from_ckpt(module['cfg'].RELOAD_MODEL, model)
         
     | 
| 111 | 
         
            +
                        module['model'] = model
         
     | 
| 112 | 
         
            +
                        module['device'] = 'cpu'
         
     | 
| 113 | 
         
            +
                    if module['device'] == 'cpu':
         
     | 
| 114 | 
         
            +
                        module['device'] = we.device_id
         
     | 
| 115 | 
         
            +
                        module['model'] = module['model'].to(we.device_id)
         
     | 
| 116 | 
         
            +
                    return module
         
     | 
| 117 | 
         
            +
             
     | 
| 118 | 
         
            +
                def unload(self, module):
         
     | 
| 119 | 
         
            +
                    if module is None:
         
     | 
| 120 | 
         
            +
                        return module
         
     | 
| 121 | 
         
            +
                    mem = get_available_memory()
         
     | 
| 122 | 
         
            +
                    free_mem = int(mem['available'] / (1024**2))
         
     | 
| 123 | 
         
            +
                    total_mem = int(mem['total'] / (1024**2))
         
     | 
| 124 | 
         
            +
                    if free_mem < 0.5 * total_mem:
         
     | 
| 125 | 
         
            +
                        if module['model'] is not None:
         
     | 
| 126 | 
         
            +
                            module['model'] = module['model'].to('cpu')
         
     | 
| 127 | 
         
            +
                            del module['model']
         
     | 
| 128 | 
         
            +
                        module['model'] = None
         
     | 
| 129 | 
         
            +
                        module['device'] = 'offline'
         
     | 
| 130 | 
         
            +
                        print('delete module')
         
     | 
| 131 | 
         
            +
                    else:
         
     | 
| 132 | 
         
            +
                        if module['model'] is not None:
         
     | 
| 133 | 
         
            +
                            module['model'] = module['model'].to('cpu')
         
     | 
| 134 | 
         
            +
                            module['device'] = 'cpu'
         
     | 
| 135 | 
         
            +
                        else:
         
     | 
| 136 | 
         
            +
                            module['device'] = 'offline'
         
     | 
| 137 | 
         
            +
                    if torch.cuda.is_available():
         
     | 
| 138 | 
         
            +
                        torch.cuda.empty_cache()
         
     | 
| 139 | 
         
            +
                        torch.cuda.ipc_collect()
         
     | 
| 140 | 
         
            +
                    return module
         
     | 
| 141 | 
         
            +
             
     | 
| 142 | 
         
            +
                def dynamic_load(self, module=None, name=''):
         
     | 
| 143 | 
         
            +
                    self.logger.info('Loading {} model'.format(name))
         
     | 
| 144 | 
         
            +
                    if name == 'all':
         
     | 
| 145 | 
         
            +
                        for subname in self.loaded_model_name:
         
     | 
| 146 | 
         
            +
                            self.loaded_model[subname] = self.dynamic_load(
         
     | 
| 147 | 
         
            +
                                getattr(self, subname), subname)
         
     | 
| 148 | 
         
            +
                    elif name in self.loaded_model_name:
         
     | 
| 149 | 
         
            +
                        if name in self.loaded_model:
         
     | 
| 150 | 
         
            +
                            if module['cfg'] != self.loaded_model[name]['cfg']:
         
     | 
| 151 | 
         
            +
                                self.unload(self.loaded_model[name])
         
     | 
| 152 | 
         
            +
                                module = self.load(module)
         
     | 
| 153 | 
         
            +
                                self.loaded_model[name] = module
         
     | 
| 154 | 
         
            +
                                return module
         
     | 
| 155 | 
         
            +
                            elif module['device'] == 'cpu' or module['device'] == 'offline':
         
     | 
| 156 | 
         
            +
                                module = self.load(module)
         
     | 
| 157 | 
         
            +
                                return module
         
     | 
| 158 | 
         
            +
                            else:
         
     | 
| 159 | 
         
            +
                                return module
         
     | 
| 160 | 
         
            +
                        else:
         
     | 
| 161 | 
         
            +
                            module = self.load(module)
         
     | 
| 162 | 
         
            +
                            self.loaded_model[name] = module
         
     | 
| 163 | 
         
            +
                            return module
         
     | 
| 164 | 
         
            +
                    else:
         
     | 
| 165 | 
         
            +
                        return self.load(module)
         
     | 
| 166 | 
         
            +
             
     | 
| 167 | 
         
            +
                def dynamic_unload(self, module=None, name='', skip_loaded=False):
         
     | 
| 168 | 
         
            +
                    self.logger.info('Unloading {} model'.format(name))
         
     | 
| 169 | 
         
            +
                    if name == 'all':
         
     | 
| 170 | 
         
            +
                        for name, module in self.loaded_model.items():
         
     | 
| 171 | 
         
            +
                            module = self.unload(self.loaded_model[name])
         
     | 
| 172 | 
         
            +
                            self.loaded_model[name] = module
         
     | 
| 173 | 
         
            +
                    elif name in self.loaded_model_name:
         
     | 
| 174 | 
         
            +
                        if name in self.loaded_model:
         
     | 
| 175 | 
         
            +
                if not skip_loaded:
                    module = self.unload(self.loaded_model[name])
                    self.loaded_model[name] = module
                else:
                    self.unload(module)
            else:
                self.unload(module)

    def load_default(self, cfg):
        module_paras = {}
        if cfg is not None:
            self.paras = cfg.PARAS
            self.input_cfg = {k.lower(): v for k, v in cfg.INPUT.items()}
            self.input = {
                k.lower(): dict(v).get('DEFAULT', None)
                if isinstance(v, (dict, OrderedDict, Config)) else v
                for k, v in cfg.INPUT.items()
            }
            self.output = {k.lower(): v for k, v in cfg.OUTPUT.items()}
            module_paras = cfg.MODULES_PARAS
        return module_paras

    def load_image(self, image, num_samples=1):
        # Placeholder hook: tensors and PIL images pass through unchanged;
        # concrete inference classes override this to normalize their inputs.
        if isinstance(image, torch.Tensor):
            pass
        elif isinstance(image, Image):
            pass

    def get_function_info(self, module, function_name=None):
        # Resolve which registered function to run and the dtype it expects.
        all_function = module['function_info']
        if function_name in all_function:
            return function_name, all_function[function_name]['dtype']
        if function_name is None and len(all_function) == 1:
            for k, v in all_function.items():
                return k, v['dtype']

    @torch.no_grad()
    def __call__(self,
                 input,
                 **kwargs):
        # Abstract entry point; concrete inference classes implement this.
        return

def build_inference(cfg, registry, logger=None, *args, **kwargs):
    """ After building the model, load pretrained weights if the key `pretrain` exists.

    pretrain (str, dict): Describes how to load the pretrained model.
        str: treated as the model path;
        dict: should contain the key `path`, plus other parameters taken by load_pretrained().
    """
    if not isinstance(cfg, Config):
        raise TypeError(f'cfg must be a Config, got {type(cfg)}')
    model = build_from_config(cfg, registry, logger=logger, *args, **kwargs)
    return model

# Register classes for diffusion inference.
INFERENCES = Registry('INFERENCE', build_func=build_inference)
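For orientation, a minimal sketch of how a registry built this way is typically consumed, assuming the scepter-style `register_class()` / `build()` interface this file mirrors; the `MyInference` class, the YAML path, and the config layout are illustrative, not part of this commit:

# Hypothetical usage sketch; class name, cfg file, and fields are assumptions.
@INFERENCES.register_class()
class MyInference():
    def __init__(self, cfg, logger=None):
        self.cfg = cfg  # real subclasses would build their pipelines here

cfg = Config(load=True, cfg_file='config/my_inference.yaml')  # illustrative
runner = INFERENCES.build(cfg, logger=None)  # dispatches to build_inference()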
         
inference/utils.py CHANGED
@@ -49,7 +49,10 @@ class ACEPlusImageProcessor():
                    edit_mask=None,
                    height=1024,
                    width=1024,
-                   repainting_scale = 1.0):
+                   repainting_scale = 1.0,
+                   keep_pixels = False,
+                   keep_pixels_rate = 0.8,
+                   use_change = False):
         reference_image = self.image_check(reference_image)
         edit_image = self.image_check(edit_image)
         # for reference generation
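This hunk threads three new options through `preprocess()`: `keep_pixels` pads the reference up to the edit height instead of rescaling it, `keep_pixels_rate` caps how much of that height the reference may occupy, and `use_change` additionally returns the masked change image. A hedged sketch of a call (the positional argument order and the `ref_img` / `edit_img` / `mask_img` variables are assumptions; the returned tuple matches the new `return` in the final hunk below):

# Illustrative call; only the keyword names and the returned tuple come
# from this diff, everything else is assumed.
processor = ACEPlusImageProcessor()
(edit_image, edit_mask, change_image, content_image,
 out_h, out_w, slice_w) = processor.preprocess(
    ref_img, edit_img, mask_img,
    keep_pixels=True,        # pad the reference instead of stretching it
    keep_pixels_rate=0.8,    # reference capped at 80% of the edit height
    use_change=True,         # also return edit_image * edit_mask
)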
@@ -57,8 +60,12 @@ class ACEPlusImageProcessor():
             edit_image = torch.zeros([3, height, width])
             edit_mask = torch.ones([1, height, width])
         else:
+            if edit_mask is None:
+                _, eH, eW = edit_image.shape
+                edit_mask = np.ones((eH, eW))
+            else:
+                edit_mask = np.asarray(edit_mask)
+                edit_mask = np.where(edit_mask > 128, 1, 0)
             edit_mask = edit_mask.astype(
                 np.float32) if np.any(edit_mask) else np.ones_like(edit_mask).astype(
                 np.float32)
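One detail worth pinning down in this hunk: the `> 128` threshold treats the incoming mask as 8-bit grayscale, and the strict comparison means a pixel of exactly 128 stays outside the mask. A standalone NumPy check of that semantics:

import numpy as np

# Soft-edged grayscale mask: values above 128 become 1 (region to edit),
# everything else 0; a mask with no positive pixels later falls back to
# all-ones via the np.any() branch above.
mask = np.array([[0, 127, 128, 129, 255]])
binary = np.where(mask > 128, 1, 0)
print(binary)  # [[0 0 0 1 1]]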
@@ -71,12 +78,27 @@ class ACEPlusImageProcessor():
 
         assert edit_mask is not None
         if reference_image is not None:
-            # align height with edit_image
             _, H, W = reference_image.shape
             _, eH, eW = edit_image.shape
+            if not keep_pixels:
+                # align height with edit_image
+                scale = eH / H
+                tH, tW = eH, int(W * scale)
+                reference_image = T.Resize((tH, tW), interpolation=T.InterpolationMode.BILINEAR, antialias=True)(
+                    reference_image)
+            else:
+                # padding
+                if H >= keep_pixels_rate * eH:
+                    tH = int(eH * keep_pixels_rate)
+                    scale = tH/H
+                    tW = int(W * scale)
+                    reference_image = T.Resize((tH, tW), interpolation=T.InterpolationMode.BILINEAR, antialias=True)(
+                        reference_image)
+                rH, rW = reference_image.shape[-2:]
+                delta_w = 0
+                delta_h = eH - rH
+                padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
+                reference_image = T.Pad(padding, fill=0, padding_mode="constant")(reference_image)
             edit_image = torch.cat([reference_image, edit_image], dim=-1)
             edit_mask = torch.cat([torch.zeros([1, reference_image.shape[1], reference_image.shape[2]]), edit_mask], dim=-1)
             slice_w = reference_image.shape[-1]
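To make the `keep_pixels` branch concrete, here is its resize-then-pad arithmetic worked through with illustrative numbers (a 1200x800 reference against a 1024-pixel edit height, rate 0.8):

# Since H = 1200 >= 0.8 * 1024 = 819.2, the reference is first shrunk so
# its height is 80% of the edit height, then padded back to full height.
eH, H, W, rate = 1024, 1200, 800, 0.8
tH = int(eH * rate)                  # 819
tW = int(W * (tH / H))               # 546
delta_h = eH - tH                    # 205
padding = (0, delta_h // 2, 0, delta_h - delta_h // 2)  # (left, top, right, bottom)
print(tH, tW, padding)               # 819 546 (0, 102, 0, 103)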
@@ -89,16 +111,21 @@ class ACEPlusImageProcessor():
         rW = int(W * scale) // self.d * self.d
         slice_w = int(slice_w * scale) // self.d * self.d
 
+        edit_image = T.Resize((rH, rW), interpolation=T.InterpolationMode.NEAREST_EXACT, antialias=True)(edit_image)
         edit_mask = T.Resize((rH, rW), interpolation=T.InterpolationMode.NEAREST_EXACT, antialias=True)(edit_mask)
+        content_image = edit_image
+        if use_change:
+            change_image = edit_image * edit_mask
+            edit_image = edit_image * (1 - edit_mask)
+        else:
+            change_image = None
+        return edit_image, edit_mask, change_image, content_image, out_h, out_w, slice_w
 
 
     def postprocess(self, image, slice_w, out_w, out_h):
         w, h = image.size
         if slice_w > 0:
+            output_image = image.crop((slice_w + 30, 0, w, h))
             output_image = output_image.resize((out_w, out_h))
         else:
             output_image = image
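Finally, `postprocess` now starts its crop at `slice_w + 30`, discarding the reference strip plus a 30-pixel margin (presumably to hide seam artifacts at the reference/edit boundary) before resizing back to the requested output size. A self-contained PIL round trip with illustrative sizes:

from PIL import Image

# Generated canvas: a 512 px reference strip on the left, edited region on
# the right. The crop removes the strip plus the 30 px margin, then the
# result is resized to the requested output resolution.
canvas = Image.new('RGB', (1536, 1024))
slice_w, out_w, out_h = 512, 1024, 1024
w, h = canvas.size
output = canvas.crop((slice_w + 30, 0, w, h)).resize((out_w, out_h))
print(output.size)  # (1024, 1024)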