erwold committed
Commit 76678b6 · Parent(s): bb47725

Initial Commit
    	
app.py CHANGED

@@ -13,7 +13,6 @@ import sys
 
 from qwen2_vl.modeling_qwen2_vl import Qwen2VLSimplifiedModel
 from huggingface_hub import snapshot_download
-import spaces
 
 # Set up logging
 logging.basicConfig(

@@ -78,42 +77,53 @@ class FluxInterface:
             return
 
         logger.info("Starting model loading...")
-        # 3. Explicitly configure the behavior of the PyTorch caching allocator
-        torch.cuda.set_per_process_memory_fraction(0.95)  # allow using 95% of VRAM
-        torch.cuda.max_memory_allocated = lambda *args, **kwargs: 0  # ignore the allocated-memory limit
 
+        # 1. Load the smaller models onto the GPU first
         tokenizer = CLIPTokenizer.from_pretrained(os.path.join(MODEL_CACHE_DIR, "flux/tokenizer"))
-        text_encoder = CLIPTextModel.from_pretrained(
-        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(os.path.join(MODEL_CACHE_DIR, "flux/scheduler"), shift=1)
+        text_encoder = CLIPTextModel.from_pretrained(
+            os.path.join(MODEL_CACHE_DIR, "flux/text_encoder")
+        ).to(self.dtype).to(self.device)
+
+        text_encoder_two = T5EncoderModel.from_pretrained(
+            os.path.join(MODEL_CACHE_DIR, "flux/text_encoder_2")
+        ).to(self.dtype).to(self.device)
+
+        tokenizer_two = T5TokenizerFast.from_pretrained(
+            os.path.join(MODEL_CACHE_DIR, "flux/tokenizer_2"))
+
+        # 2. Load the large models onto the CPU initially
+        vae = AutoencoderKL.from_pretrained(
+            os.path.join(MODEL_CACHE_DIR, "flux/vae")
+        ).to(torch.float32).cpu()
+
+        transformer = FluxTransformer2DModel.from_pretrained(
+            os.path.join(MODEL_CACHE_DIR, "flux/transformer")
+        ).to(torch.float32).cpu()
+
+        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+            os.path.join(MODEL_CACHE_DIR, "flux/scheduler"),
+            shift=1
+        )
+
+        # 3. Load Qwen2VL onto the CPU initially
+        qwen2vl = Qwen2VLSimplifiedModel.from_pretrained(
+            os.path.join(MODEL_CACHE_DIR, "qwen2-vl")
+        ).to(torch.float32).cpu()
+
+        # 4. Load the connector and embedder onto the CPU
+        connector = Qwen2Connector().to(torch.float32).cpu()
         connector_path = os.path.join(MODEL_CACHE_DIR, "qwen2-vl/connector.pt")
         connector_state = torch.load(connector_path, map_location='cpu')
-        connector_state = {k: v.to(self.dtype) for k, v in connector_state.items()}
         connector.load_state_dict(connector_state)
-        # Load the T5 embedder
-        self.t5_context_embedder = nn.Linear(4096, 3072).to(self.dtype).to(self.device)
+
+        self.t5_context_embedder = nn.Linear(4096, 3072).to(torch.float32).cpu()
         t5_embedder_path = os.path.join(MODEL_CACHE_DIR, "qwen2-vl/t5_embedder.pt")
         t5_embedder_state = torch.load(t5_embedder_path, map_location='cpu')
-        t5_embedder_state = {k: v.to(self.dtype) for k, v in t5_embedder_state.items()}
         self.t5_context_embedder.load_state_dict(t5_embedder_state)
-        self.t5_context_embedder = self.t5_context_embedder.to(self.device)
 
-        for model in [text_encoder, text_encoder_two, vae, transformer, qwen2vl,
+        # 5. Put all models in eval mode
+        for model in [text_encoder, text_encoder_two, vae, transformer, qwen2vl,
+                      connector, self.t5_context_embedder]:
             model.requires_grad_(False)
             model.eval()
 

@@ -133,9 +143,9 @@ class FluxInterface:
 
         # Initialize processor and pipeline
         self.qwen2vl_processor = AutoProcessor.from_pretrained(
-            self.MODEL_ID, 
+            self.MODEL_ID,
             subfolder="qwen2-vl",
-            min_pixels=256*28*28, 
+            min_pixels=256*28*28,
             max_pixels=256*28*28
         )
 

@@ -145,7 +155,61 @@ class FluxInterface:
             vae=vae,
             text_encoder=text_encoder,
             tokenizer=tokenizer,
-        ) 
+        )
+
+    def move_to_device(self, model, device):
+        """Helper function to move model to specified device"""
+        if hasattr(model, 'to'):
+            return model.to(device)
+        return model
+
+    def process_image(self, image):
+        """Process image with Qwen2VL model"""
+        try:
+            # 1. Move the Qwen2VL models to the GPU
+            self.models['qwen2vl'] = self.move_to_device(self.models['qwen2vl'], self.device)
+            self.models['connector'] = self.move_to_device(self.models['connector'], self.device)
+
+            message = [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "image", "image": image},
+                        {"type": "text", "text": "Describe this image."},
+                    ]
+                }
+            ]
+            text = self.qwen2vl_processor.apply_chat_template(
+                message,
+                tokenize=False,
+                add_generation_prompt=True
+            )
+
+            with torch.no_grad():
+                inputs = self.qwen2vl_processor(
+                    text=[text],
+                    images=[image],
+                    padding=True,
+                    return_tensors="pt"
+                ).to(self.device)
+
+                output_hidden_state, image_token_mask, image_grid_thw = self.models['qwen2vl'](**inputs)
+                image_hidden_state = output_hidden_state[image_token_mask].view(1, -1, output_hidden_state.size(-1))
+                image_hidden_state = self.models['connector'](image_hidden_state)
+
+                # Keep the result on the CPU
+                result = (image_hidden_state.cpu(), image_grid_thw)
+
+            # 2. Move the Qwen2VL models back to the CPU to free VRAM
+            self.models['qwen2vl'] = self.move_to_device(self.models['qwen2vl'], 'cpu')
+            self.models['connector'] = self.move_to_device(self.models['connector'], 'cpu')
+            torch.cuda.empty_cache()
+
+            return result
+
+        except Exception as e:
+            logger.error(f"Error in process_image: {str(e)}")
+            raise
 
     def resize_image(self, img, max_pixels=1050000):
         if not isinstance(img, Image.Image):

@@ -163,28 +227,7 @@ class FluxInterface:
             img = img.resize((new_width, new_height), Image.LANCZOS)
 
         return img
-
-    # [Previous methods remain unchanged...]
-    def process_image(self, image):
-        message = [
-            {
-                "role": "user",
-                "content": [
-                    {"type": "image", "image": image},
-                    {"type": "text", "text": "Describe this image."},
-                ]
-            }
-        ]
-        text = self.qwen2vl_processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
-
-        with torch.no_grad():
-            inputs = self.qwen2vl_processor(text=[text], images=[image], padding=True, return_tensors="pt").to(self.device)
-            output_hidden_state, image_token_mask, image_grid_thw = self.models['qwen2vl'](**inputs)
-            image_hidden_state = output_hidden_state[image_token_mask].view(1, -1, output_hidden_state.size(-1))
-            image_hidden_state = self.models['connector'](image_hidden_state)
-
-        return image_hidden_state, image_grid_thw
-
+
     def compute_t5_text_embeddings(self, prompt):
         """Compute T5 embeddings for text prompt"""
         if prompt == "":

@@ -222,50 +265,39 @@ class FluxInterface:
 
         return pooled_prompt_embeds
 
+    def generate(self, input_image, prompt="", guidance_scale=3.5,
+                 num_inference_steps=28, num_images=2, seed=None, aspect_ratio="1:1"):
         try:
-            logger.info(f"Starting generation with prompt: {prompt}
+            logger.info(f"Starting generation with prompt: {prompt}")
 
             if input_image is None:
                 raise ValueError("No input image provided")
 
             if seed is not None:
                 torch.manual_seed(seed)
-            logger.info("Models loaded successfully")
-            width, height = ASPECT_RATIOS[aspect_ratio]
-            logger.info(f"Using dimensions: {width}x{height}")
-                logger.info(f"Input image resized to: {input_image.size}")
-                qwen2_hidden_state, image_grid_thw = self.process_image(input_image)
-                logger.info("Input image processed successfully")
-            except Exception as e:
-                raise RuntimeError(f"Error processing input image: {str(e)}")
-                # Get T5 embeddings if prompt is provided
-                t5_prompt_embeds = self.compute_t5_text_embeddings(prompt)
-                logger.info("T5 prompt embeddings computed")
-            except Exception as e:
-                raise RuntimeError(f"Error computing embeddings: {str(e)}")
+
+            # 1. Process the image with Qwen2VL
+            qwen2_hidden_state, image_grid_thw = self.process_image(input_image)
+
+            # 2. Compute the text embeddings
+            pooled_prompt_embeds = self.compute_text_embeddings("")
+            t5_prompt_embeds = self.compute_t5_text_embeddings(prompt)
+
+            # 3. Move the transformer and VAE to the GPU
+            self.models['transformer'] = self.move_to_device(self.models['transformer'], self.device)
+            self.models['vae'] = self.move_to_device(self.models['vae'], self.device)
+
+            # Update the models held by the pipeline
+            self.pipeline.transformer = self.models['transformer']
+            self.pipeline.vae = self.models['vae']
+
+            # Get the output dimensions
+            width, height = ASPECT_RATIOS[aspect_ratio]
+
+            # 4. Generate the images
             try:
                 output_images = self.pipeline(
-                    prompt_embeds=qwen2_hidden_state.repeat(num_images, 1, 1),
+                    prompt_embeds=qwen2_hidden_state.to(self.device).repeat(num_images, 1, 1),
                     pooled_prompt_embeds=pooled_prompt_embeds,
                     t5_prompt_embeds=t5_prompt_embeds.repeat(num_images, 1, 1) if t5_prompt_embeds is not None else None,
                     num_inference_steps=num_inference_steps,

@@ -274,10 +306,16 @@ class FluxInterface:
                     width=width,
                 ).images
 
+                # 5. Move the transformer and VAE back to the CPU
+                self.models['transformer'] = self.move_to_device(self.models['transformer'], 'cpu')
+                self.models['vae'] = self.move_to_device(self.models['vae'], 'cpu')
+                torch.cuda.empty_cache()
+
                 return output_images
+
             except Exception as e:
                 raise RuntimeError(f"Error generating images: {str(e)}")
+
         except Exception as e:
             logger.error(f"Error during generation: {str(e)}")
             raise gr.Error(f"Generation failed: {str(e)}")
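
Taken together, the diff replaces the old load-everything-to-GPU approach with manual offloading: the FLUX transformer, VAE, Qwen2VL model and connector stay on the CPU in float32, and each stage of process_image() and generate() moves only the modules it needs onto the GPU, moves them back afterwards, and calls torch.cuda.empty_cache(). The sketch below is only an illustration of that swap-in/swap-out pattern, not code from the Space; the models dict, the module names, and the on_gpu helper are assumptions made for the example.

# Minimal sketch (assumed names, not part of the commit) of the
# move-to-GPU / move-back-to-CPU pattern used in app.py.
from contextlib import contextmanager

import torch

@contextmanager
def on_gpu(models, names, device="cuda"):
    """Temporarily move the named modules to `device`, then return them to the CPU."""
    try:
        for name in names:
            models[name] = models[name].to(device)
        yield models
    finally:
        for name in names:
            models[name] = models[name].to("cpu")
        torch.cuda.empty_cache()  # release cached blocks so the next stage has headroom

# Hypothetical usage mirroring the new process_image():
#     with on_gpu(self.models, ["qwen2vl", "connector"], self.device):
#         output_hidden_state, image_token_mask, image_grid_thw = self.models["qwen2vl"](**inputs)

The trade-off is the same one the commit accepts: each stage only has to fit its own peak usage in VRAM, at the cost of extra host-to-device copies on every call.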