# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import yaml, os
from PIL import Image
from diffusers.pipelines import FluxPipeline
from typing import List, Union, Optional, Dict, Any, Callable

from src.flux.transformer import tranformer_forward
from src.flux.condition import Condition
from diffusers.pipelines.flux.pipeline_flux import (
    FluxPipelineOutput,
    calculate_shift,
    retrieve_timesteps,
    np,
)
from src.flux.pipeline_tools import (
    encode_prompt_with_clip_t5, tokenize_t5_prompt, clear_attn_maps, encode_vae_images
)
from src.flux.pipeline_tools import CustomFluxPipeline, load_modulation_adapter, decode_vae_images, \
    save_attention_maps, gather_attn_maps, clear_attn_maps, load_dit_lora, quantization

from src.utils.data_utils import pad_to_square, pad_to_target, pil2tensor, get_closest_ratio, get_aspect_ratios
from src.utils.modulation_utils import get_word_index, unpad_input_ids


def get_config(config_path: str = None):
    config_path = config_path or os.environ.get("XFL_CONFIG")
    if not config_path:
        return {}
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    return config
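
# Example usage (illustrative only; the YAML path and keys shown here are
# assumptions, not part of this file): get_config() falls back to the path in
# the XFL_CONFIG environment variable when no explicit path is given, and
# generate() below reads the "model" section of the resulting dict.
#
#   os.environ["XFL_CONFIG"] = "configs/example.yaml"
#   cfg = get_config()                # dict parsed from the YAML file
#   model_cfg = cfg.get("model", {})  # what generate() consumes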


def prepare_params(
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = 512,
    width: Optional[int] = 512,
    num_inference_steps: int = 28,
    timesteps: List[int] = None,
    guidance_scale: float = 3.5,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 512,
    verbose: bool = False,
    **kwargs: dict,
):
    return (
        prompt,
        prompt_2,
        height,
        width,
        num_inference_steps,
        timesteps,
        guidance_scale,
        num_images_per_prompt,
        generator,
        latents,
        prompt_embeds,
        pooled_prompt_embeds,
        output_type,
        return_dict,
        joint_attention_kwargs,
        callback_on_step_end,
        callback_on_step_end_tensor_inputs,
        max_sequence_length,
        verbose,
    )


def seed_everything(seed: int = 42):
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(seed)
    np.random.seed(seed)


def generate(
    pipeline: FluxPipeline,
    vae_conditions: List[Condition] = None,
    config_path: str = None,
    model_config: Optional[Dict[str, Any]] = {},
    vae_condition_scale: float = 1.0,
    default_lora: bool = False,
    condition_pad_to: str = "square",
    condition_size: int = 512,
    text_cond_mask: Optional[torch.FloatTensor] = None,
    delta_emb: Optional[torch.FloatTensor] = None,
    delta_emb_pblock: Optional[torch.FloatTensor] = None,
    delta_emb_mask: Optional[torch.FloatTensor] = None,
    delta_start_ends=None,
    condition_latents=None,
    condition_ids=None,
    mod_adapter=None,
    store_attn_map: bool = False,
    vae_skip_iter: str = None,
    control_weight_lambda: str = None,
    double_attention: bool = False,
    single_attention: bool = False,
    ip_scale: str = None,
    use_latent_sblora_control: bool = False,
    latent_sblora_scale: str = None,
    use_condition_sblora_control: bool = False,
    condition_sblora_scale: str = None,
    idips=None,
    **params: dict,
):
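    """Run FLUX denoising with optional per-subject modulation and VAE image conditions.

    Note: this docstring is an added summary of the code below, not original
    documentation. The schedule arguments (`control_weight_lambda`, `ip_scale`,
    `vae_skip_iter`, `latent_sblora_scale`, `condition_sblora_scale`) are strings of
    comma-separated `start-end:value` segments, where `start`/`end` are normalized
    time positions in [0, 1]; for `control_weight_lambda` the value may itself be
    '/'-separated (global/id/ip weights). Keys present in `model_config` (or in the
    YAML pointed to by `config_path`) override the matching keyword arguments, and
    remaining `**params` are forwarded to `prepare_params`.
    """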
    model_config = model_config or get_config(config_path).get("model", {})

    vae_skip_iter = model_config.get("vae_skip_iter", vae_skip_iter)
    double_attention = model_config.get("double_attention", double_attention)
    single_attention = model_config.get("single_attention", single_attention)
    control_weight_lambda = model_config.get("control_weight_lambda", control_weight_lambda)
    ip_scale = model_config.get("ip_scale", ip_scale)
    use_latent_sblora_control = model_config.get("use_latent_sblora_control", use_latent_sblora_control)
    use_condition_sblora_control = model_config.get("use_condition_sblora_control", use_condition_sblora_control)
    latent_sblora_scale = model_config.get("latent_sblora_scale", latent_sblora_scale)
    condition_sblora_scale = model_config.get("condition_sblora_scale", condition_sblora_scale)

    model_config["use_attention_double"] = False
    model_config["use_attention_single"] = False
    use_attention = False
    if idips is not None:
        if control_weight_lambda != "no":
            parts = control_weight_lambda.split(',')
            new_parts = []
            for part in parts:
                if ':' in part:
                    left, right = part.split(':')
                    values = right.split('/')
                    # keep the overall (global) value
                    global_value = values[0]
                    id_value = values[1]
                    ip_value = values[2]
                    new_values = [global_value]
                    for is_id in idips:
                        if is_id:
                            new_values.append(id_value)
                        else:
                            new_values.append(ip_value)
                    new_part = f"{left}:{('/'.join(new_values))}"
                    new_parts.append(new_part)
                else:
                    new_parts.append(part)
            control_weight_lambda = ','.join(new_parts)
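            # Illustrative example (not from the original source): with
            # idips=[True, False, True] the segment "0-1:1.0/2.0/1.5" becomes
            # "0-1:1.0/2.0/1.5/2.0" -- the global weight is kept and each entity
            # gets the id weight (2.0) or the ip weight (1.5) depending on idips.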

    if vae_condition_scale != 1:
        for name, module in pipeline.transformer.named_modules():
            if not name.endswith(".attn"):
                continue
            module.c_factor = torch.ones(1, 1) * vae_condition_scale

    self = pipeline
    (
        prompt,
        prompt_2,
        height,
        width,
        num_inference_steps,
        timesteps,
        guidance_scale,
        num_images_per_prompt,
        generator,
        latents,
        prompt_embeds,
        pooled_prompt_embeds,
        output_type,
        return_dict,
        joint_attention_kwargs,
        callback_on_step_end,
        callback_on_step_end_tensor_inputs,
        max_sequence_length,
        verbose,
    ) = prepare_params(**params)

    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        prompt_2,
        height,
        width,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        max_sequence_length=max_sequence_length,
    )

    self._guidance_scale = guidance_scale
    self._joint_attention_kwargs = joint_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    lora_scale = (
        self.joint_attention_kwargs.get("scale", None)
        if self.joint_attention_kwargs is not None
        else None
    )
    (
        t5_prompt_embeds,
        pooled_prompt_embeds,
        text_ids,
    ) = encode_prompt_with_clip_t5(
        self=self,
        prompt="" if self.text_encoder_2 is None else prompt,
        prompt_2=None,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )

    # 4. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels // 4
    latents, latent_image_ids = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        pooled_prompt_embeds.dtype,
        device,
        generator,
        latents,
    )
    latent_height = height // 16

    # 5. Prepare timesteps
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
    image_seq_len = latents.shape[1]
    mu = calculate_shift(
        image_seq_len,
        self.scheduler.config.base_image_seq_len,
        self.scheduler.config.max_image_seq_len,
        self.scheduler.config.base_shift,
        self.scheduler.config.max_shift,
    )
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        timesteps,
        sigmas,
        mu=mu,
    )
    num_warmup_steps = max(
        len(timesteps) - num_inference_steps * self.scheduler.order, 0
    )
    self._num_timesteps = len(timesteps)

    attn_map = None

    # 6. Denoising loop
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        totalsteps = timesteps[0]

        if control_weight_lambda is not None:
            print("control_weight_lambda", control_weight_lambda)
            control_weight_lambda_schedule = []
            for scale_str in control_weight_lambda.split(','):
                time_region, scale = scale_str.split(':')
                start, end = time_region.split('-')
                scales = [float(s) for s in scale.split('/')]
                control_weight_lambda_schedule.append([(1 - float(start)) * totalsteps, (1 - float(end)) * totalsteps, scales])

        if ip_scale is not None:
            print("ip_scale", ip_scale)
            ip_scale_schedule = []
            for scale_str in ip_scale.split(','):
                time_region, scale = scale_str.split(':')
                start, end = time_region.split('-')
                ip_scale_schedule.append([(1 - float(start)) * totalsteps, (1 - float(end)) * totalsteps, float(scale)])

        if use_latent_sblora_control:
            if latent_sblora_scale is not None:
                print("latent_sblora_scale", latent_sblora_scale)
                latent_sblora_scale_schedule = []
                for scale_str in latent_sblora_scale.split(','):
                    time_region, scale = scale_str.split(':')
                    start, end = time_region.split('-')
                    latent_sblora_scale_schedule.append([(1 - float(start)) * totalsteps, (1 - float(end)) * totalsteps, float(scale)])

        if use_condition_sblora_control:
            if condition_sblora_scale is not None:
                print("condition_sblora_scale", condition_sblora_scale)
                condition_sblora_scale_schedule = []
                for scale_str in condition_sblora_scale.split(','):
                    time_region, scale = scale_str.split(':')
                    start, end = time_region.split('-')
                    condition_sblora_scale_schedule.append([(1 - float(start)) * totalsteps, (1 - float(end)) * totalsteps, float(scale)])

        if vae_skip_iter is not None:
            print("vae_skip_iter", vae_skip_iter)
            vae_skip_iter_schedule = []
            for scale_str in vae_skip_iter.split(','):
                time_region, scale = scale_str.split(':')
                start, end = time_region.split('-')
                vae_skip_iter_schedule.append([(1 - float(start)) * totalsteps, (1 - float(end)) * totalsteps, float(scale)])
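        # Example of the schedule parsing above (illustrative numbers, not from the
        # original source): with totalsteps = timesteps[0] = 1000 and
        # vae_skip_iter = "0-0.5:1,0.75-1:1", the schedule becomes
        # [[1000.0, 500.0, 1.0], [250.0, 0.0, 1.0]], i.e. the skip is active while
        # 500 <= t <= 1000 and again while 0 <= t <= 250.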

        if control_weight_lambda is not None and attn_map is None:
            batch_size = latents.shape[0]
            latent_width = latents.shape[1] // latent_height
            attn_map = torch.ones(batch_size, latent_height, latent_width, 128, device=latents.device, dtype=torch.bfloat16)
            print("control_weight_only", attn_map.shape)

        self.scheduler.set_begin_index(0)
        self.scheduler._init_step_index(0)
        for i, t in enumerate(timesteps):
            # default to an empty schedule so the flag is defined even when
            # control_weight_lambda is not used
            cur_control_weight_lambda = []
            if control_weight_lambda is not None:
                for start, end, scale in control_weight_lambda_schedule:
                    if t <= start and t >= end:
                        cur_control_weight_lambda = scale
                        break
                print(f"timestep:{t}, cur_control_weight_lambda:{cur_control_weight_lambda}")

                if cur_control_weight_lambda:
                    model_config["use_attention_single"] = True
                    use_attention = True
                    model_config["use_atten_lambda"] = cur_control_weight_lambda
                else:
                    model_config["use_attention_single"] = False
                    use_attention = False

            if self.interrupt:
                continue

            if isinstance(delta_emb, list):
                cur_delta_emb = delta_emb[i]
                cur_delta_emb_pblock = delta_emb_pblock[i]
                cur_delta_emb_mask = delta_emb_mask[i]
            else:
                cur_delta_emb = delta_emb
                cur_delta_emb_pblock = delta_emb_pblock
                cur_delta_emb_mask = delta_emb_mask

            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latents.shape[0]).to(latents.dtype) / 1000
            prompt_embeds = t5_prompt_embeds
            text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=prompt_embeds.dtype)

            # handle guidance
            if self.transformer.config.guidance_embeds:
                guidance = torch.tensor([guidance_scale], device=device)
                guidance = guidance.expand(latents.shape[0])
            else:
                guidance = None

            self.transformer.enable_lora()

            lora_weight = 1
            if ip_scale is not None:
                lora_weight = 0
                for start, end, scale in ip_scale_schedule:
                    if t <= start and t >= end:
                        lora_weight = scale
                        break
                if lora_weight != 1:
                    print(f"timestep:{t}, lora_weights:{lora_weight}")

            latent_sblora_weight = None
            if use_latent_sblora_control:
                if latent_sblora_scale is not None:
                    latent_sblora_weight = 0
                    for start, end, scale in latent_sblora_scale_schedule:
                        if t <= start and t >= end:
                            latent_sblora_weight = scale
                            break
                    if latent_sblora_weight != 1:
                        print(f"timestep:{t}, latent_sblora_weight:{latent_sblora_weight}")

            condition_sblora_weight = None
            if use_condition_sblora_control:
                if condition_sblora_scale is not None:
                    condition_sblora_weight = 0
                    for start, end, scale in condition_sblora_scale_schedule:
                        if t <= start and t >= end:
                            condition_sblora_weight = scale
                            break
                    if condition_sblora_weight != 1:
                        print(f"timestep:{t}, condition_sblora_weight:{condition_sblora_weight}")

            vae_skip_iter_t = False
            if vae_skip_iter is not None:
                for start, end, scale in vae_skip_iter_schedule:
                    if t <= start and t >= end:
                        vae_skip_iter_t = bool(scale)
                        break
                if vae_skip_iter_t:
                    print(f"timestep:{t}, skip vae:{vae_skip_iter_t}")

            noise_pred = tranformer_forward(
                self.transformer,
                model_config=model_config,
                # Inputs of the condition (new feature)
                text_cond_mask=text_cond_mask,
                delta_emb=cur_delta_emb,
                delta_emb_pblock=cur_delta_emb_pblock,
                delta_emb_mask=cur_delta_emb_mask,
                delta_start_ends=delta_start_ends,
                condition_latents=None if vae_skip_iter_t else condition_latents,
                condition_ids=None if vae_skip_iter_t else condition_ids,
                condition_type_ids=None,
                # Inputs to the original transformer
                hidden_states=latents,
                # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
                timestep=timestep,
                guidance=guidance,
                pooled_projections=pooled_prompt_embeds,
                encoder_hidden_states=prompt_embeds,
                txt_ids=text_ids,
                img_ids=latent_image_ids,
                joint_attention_kwargs={'scale': lora_weight, "latent_sblora_weight": latent_sblora_weight, "condition_sblora_weight": condition_sblora_weight},
                store_attn_map=use_attention,
                last_attn_map=attn_map if cur_control_weight_lambda else None,
                use_text_mod=model_config["modulation"]["use_text_mod"],
                use_img_mod=model_config["modulation"]["use_img_mod"],
                mod_adapter=mod_adapter,
                latent_height=latent_height,
                return_dict=False,
            )[0]

            if use_attention:
                attn_maps, _ = gather_attn_maps(self.transformer, clear=True)

            # compute the previous noisy sample x_t -> x_t-1
            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            if latents.dtype != latents_dtype:
                if torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    latents = latents.to(latents_dtype)

            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

            # call the callback, if provided
            if i == len(timesteps) - 1 or (
                (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
            ):
                progress_bar.update()

    if output_type == "latent":
        image = latents
    else:
        latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
        latents = (
            latents / self.vae.config.scaling_factor
        ) + self.vae.config.shift_factor
        image = self.vae.decode(latents, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    self.transformer.enable_lora()

    if vae_condition_scale != 1:
        for name, module in pipeline.transformer.named_modules():
            if not name.endswith(".attn"):
                continue
            del module.c_factor

    if not return_dict:
        return (image,)

    return FluxPipelineOutput(images=image)
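
# Minimal direct-call sketch (illustrative only; assumes `pipe` is this project's
# prepared FLUX pipeline and that the YAML path exists -- both are assumptions):
#
#   model_cfg = get_config("configs/example.yaml")["model"]
#   out = generate(
#       pipe,
#       model_config=model_cfg,
#       prompt="a photo of a cat",
#       height=512,
#       width=512,
#       num_inference_steps=28,
#   )
#   out.images[0].save("result.png")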


def generate_from_test_sample(
    test_sample, pipe, config,
    num_images=1,
    vae_skip_iter: str = None,
    target_height: int = None,
    target_width: int = None,
    seed: int = 42,
    control_weight_lambda: str = None,
    double_attention: bool = False,
    single_attention: bool = False,
    ip_scale: str = None,
    use_latent_sblora_control: bool = False,
    latent_sblora_scale: str = None,
    use_condition_sblora_control: bool = False,
    condition_sblora_scale: str = None,
    use_idip=False,
    **kargs
):
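    """Build conditions from a test sample and call `generate` `num_images` times.

    Added summary (not original documentation): `test_sample` is a dict providing at
    least 'input_images', 'position_delta' and 'prompt', with optional
    'original_image', 'condition_type' and 'modulation' entries; target/condition
    sizes, padding mode and seed defaults come from `config`. Returns a single image
    when `num_images == 1`, otherwise a list of images.
    """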
    target_size = config["train"]["dataset"]["val_target_size"]
    condition_size = config["train"]["dataset"].get("val_condition_size", target_size // 2)
    condition_pad_to = config["train"]["dataset"]["condition_pad_to"]
    pos_offset_type = config["model"].get("pos_offset_type", "width")
    seed = config["model"].get("seed", seed)

    device = pipe._execution_device

    condition_imgs = test_sample['input_images']
    position_delta = test_sample['position_delta']
    prompt = test_sample['prompt']
    original_image = test_sample.get('original_image', None)
    condition_type = test_sample.get('condition_type', "subject")
    modulation_input = test_sample.get('modulation', None)

    delta_start_ends = None
    condition_latents = condition_ids = None
    text_cond_mask = None

    delta_embs = None
    delta_embs_pblock = None
    delta_embs_mask = None

    try:
        max_length = config["model"]["modulation"]["max_text_len"]
    except Exception as e:
        print(e)
        max_length = 512

    if modulation_input is None or len(modulation_input) == 0:
        delta_emb = delta_emb_pblock = delta_emb_mask = None
    else:
        dtype = torch.bfloat16
        batch_size = 1
        N = config["model"]["modulation"].get("per_block_adapter_single_blocks", 0) + 19
        guidance = torch.tensor([3.5]).to(device).expand(batch_size)
        out_dim = config["model"]["modulation"]["out_dim"]

        tar_text_inputs = tokenize_t5_prompt(pipe, prompt, max_length)
        tar_padding_mask = tar_text_inputs.attention_mask.to(device).bool()
        tar_tokens = tar_text_inputs.input_ids.to(device)
        if config["model"]["modulation"]["eos_exclude"]:
            tar_padding_mask[tar_tokens == 1] = False

        def get_start_end_by_pompt_matching(src_prompts, tar_prompts):
            text_cond_mask = torch.zeros(batch_size, max_length, device=device, dtype=torch.bool)
            tar_prompt_input_ids = tokenize_t5_prompt(pipe, tar_prompts, max_length).input_ids
            src_prompt_count = 1
            start_ends = []
            for i, (src_prompt, tar_prompt, tar_prompt_tokens) in enumerate(zip(src_prompts, tar_prompts, tar_prompt_input_ids)):
                try:
                    tar_start, tar_end = get_word_index(pipe, tar_prompt, tar_prompt_tokens, src_prompt, src_prompt_count, max_length, verbose=False)
                    start_ends.append([tar_start, tar_end])
                    text_cond_mask[i, tar_start:tar_end] = True
                except Exception as e:
                    print(e)
            return start_ends, text_cond_mask

        def encode_mod_image(pil_images):
            if config["model"]["modulation"]["use_dit"]:
                raise NotImplementedError()
            else:
                pil_images = [pad_to_square(img).resize((224, 224)) for img in pil_images]
                if config["model"]["modulation"]["use_vae"]:
                    raise NotImplementedError()
                else:
                    clip_pixel_values = pipe.clip_processor(
                        text=None, images=pil_images, do_resize=False, do_center_crop=False, return_tensors="pt",
                    ).pixel_values.to(dtype=dtype, device=device)
                    clip_outputs = pipe.clip_model(clip_pixel_values, output_hidden_states=True, interpolate_pos_encoding=True, return_dict=True)
                    return clip_outputs

        def rgba_to_white_background(input_path, background=(255, 255, 255)):
            with Image.open(input_path).convert("RGBA") as img:
                img_np = np.array(img)
                alpha = img_np[:, :, 3] / 255.0  # normalize the alpha channel to [0, 1]
                rgb = img_np[:, :, :3].astype(float)  # extract the RGB channels
                background_np = np.full_like(rgb, background, dtype=float)  # build the background from the given color
                # blend: foreground * alpha + background * (1 - alpha)
                result_np = rgb * alpha[..., np.newaxis] + \
                    background_np * (1 - alpha[..., np.newaxis])
                result = Image.fromarray(result_np.astype(np.uint8), "RGB")
                return result

        def get_mod_emb(modulation_input, timestep):
            delta_emb = torch.zeros((batch_size, max_length, out_dim), dtype=dtype, device=device)
            delta_emb_pblock = torch.zeros((batch_size, max_length, N, out_dim), dtype=dtype, device=device)
            delta_emb_mask = torch.zeros((batch_size, max_length), dtype=torch.bool, device=device)
            delta_start_ends = None
            condition_latents = condition_ids = None
            text_cond_mask = None

            if modulation_input[0]["type"] == "adapter":
                num_inputs = len(modulation_input[0]["src_inputs"])
                src_prompts = [x["caption"] for x in modulation_input[0]["src_inputs"]]
                src_text_inputs = tokenize_t5_prompt(pipe, src_prompts, max_length)
                src_input_ids = unpad_input_ids(src_text_inputs.input_ids, src_text_inputs.attention_mask)
                tar_input_ids = unpad_input_ids(tar_text_inputs.input_ids, tar_text_inputs.attention_mask)
                src_prompt_embeds = pipe._get_t5_prompt_embeds(prompt=src_prompts, max_sequence_length=max_length, device=device)  # (M, 512, 4096)
                pil_images = [rgba_to_white_background(x["image_path"]) for x in modulation_input[0]["src_inputs"]]

                src_ds_scales = [x.get("downsample_scale", 1.0) for x in modulation_input[0]["src_inputs"]]
                resized_pil_images = []
                for img, ds_scale in zip(pil_images, src_ds_scales):
                    img = pad_to_square(img)
                    if ds_scale < 1.0:
                        assert ds_scale > 0
                        img = img.resize((int(224 * ds_scale), int(224 * ds_scale))).resize((224, 224))
                    resized_pil_images.append(img)
                pil_images = resized_pil_images

                img_encoded = encode_mod_image(pil_images)
                delta_start_ends = []
                text_cond_mask = torch.zeros(num_inputs, max_length, device=device, dtype=torch.bool)
                if config["model"]["modulation"]["pass_vae"]:
                    pil_images = [pad_to_square(img).resize((condition_size, condition_size)) for img in pil_images]
                    with torch.no_grad():
                        batch_tensor = torch.stack([pil2tensor(x) for x in pil_images])
                        x_0, img_ids = encode_vae_images(pipe, batch_tensor)  # (N, 256, 64)

                        condition_latents = x_0.clone().detach().reshape(1, -1, 64)  # (1, N256, 64)
                        condition_ids = img_ids.clone().detach()
                        condition_ids = condition_ids.unsqueeze(0).repeat_interleave(num_inputs, dim=0)  # (N, 256, 3)
                        for i in range(num_inputs):
                            condition_ids[i, :, 1] += 0 if pos_offset_type == "width" else -(batch_tensor.shape[-1] // 16) * (i + 1)
                            condition_ids[i, :, 2] += -(batch_tensor.shape[-1] // 16) * (i + 1)
                        condition_ids = condition_ids.reshape(-1, 3)  # (N256, 3)

                if config["model"]["modulation"]["use_dit"]:
                    raise NotImplementedError()
                else:
                    src_delta_embs = []  # [(512, 3072)]
                    src_delta_emb_pblock = []
                    for i in range(num_inputs):
                        if isinstance(img_encoded, dict):
                            _src_clip_outputs = {}
                            for key in img_encoded:
                                if torch.is_tensor(img_encoded[key]):
                                    _src_clip_outputs[key] = img_encoded[key][i:i+1]
                                else:
                                    _src_clip_outputs[key] = [x[i:i+1] for x in img_encoded[key]]
                            _img_encoded = _src_clip_outputs
                        else:
                            _img_encoded = img_encoded[i:i+1]

                        x1, x2 = pipe.modulation_adapters[0](timestep, src_prompt_embeds[i:i+1], _img_encoded)
                        src_delta_embs.append(x1[0])  # (512, 3072)
                        src_delta_emb_pblock.append(x2[0])  # (512, N, 3072)

                for input_args in modulation_input[0]["use_words"]:
                    src_word_count = 1
                    if len(input_args) == 3:
                        src_input_index, src_word, tar_word = input_args
                        tar_word_count = 1
                    else:
                        src_input_index, src_word, tar_word, tar_word_count = input_args[:4]
                    src_prompt = src_prompts[src_input_index]
                    tar_prompt = prompt

                    src_start, src_end = get_word_index(pipe, src_prompt, src_input_ids[src_input_index], src_word, src_word_count, max_length, verbose=False)
                    tar_start, tar_end = get_word_index(pipe, tar_prompt, tar_input_ids[0], tar_word, tar_word_count, max_length, verbose=False)
                    if delta_emb is not None:
                        delta_emb[:, tar_start:tar_end] = src_delta_embs[src_input_index][src_start:src_end]  # (B, 512, 3072)
                    if delta_emb_pblock is not None:
                        delta_emb_pblock[:, tar_start:tar_end] = src_delta_emb_pblock[src_input_index][src_start:src_end]  # (B, 512, N, 3072)
                    delta_emb_mask[:, tar_start:tar_end] = True
                    text_cond_mask[src_input_index, tar_start:tar_end] = True
                    delta_start_ends.append([0, src_input_index, src_start, src_end, tar_start, tar_end])
                text_cond_mask = text_cond_mask.transpose(0, 1).unsqueeze(0)
            else:
                raise NotImplementedError()
            return delta_emb, delta_emb_pblock, delta_emb_mask, \
                text_cond_mask, delta_start_ends, condition_latents, condition_ids

    num_inference_steps = 28  # FIXME: hardcoded here
    num_channels_latents = pipe.transformer.config.in_channels // 4
    # set timesteps
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
    mu = calculate_shift(
        num_channels_latents,
        pipe.scheduler.config.base_image_seq_len,
        pipe.scheduler.config.max_image_seq_len,
        pipe.scheduler.config.base_shift,
        pipe.scheduler.config.max_shift,
    )
    timesteps, num_inference_steps = retrieve_timesteps(
        pipe.scheduler,
        num_inference_steps,
        device,
        None,
        sigmas,
        mu=mu,
    )

    if modulation_input is not None:
        delta_embs = []
        delta_embs_pblock = []
        delta_embs_mask = []
        for i, t in enumerate(timesteps):
            t = t.expand(1).to(torch.bfloat16) / 1000
            (
                delta_emb, delta_emb_pblock, delta_emb_mask,
                text_cond_mask, delta_start_ends,
                condition_latents, condition_ids
            ) = get_mod_emb(modulation_input, t)
            delta_embs.append(delta_emb)
            delta_embs_pblock.append(delta_emb_pblock)
            delta_embs_mask.append(delta_emb_mask)

    if original_image is not None:
        raise NotImplementedError()
        (target_height, target_width), closest_ratio = get_closest_ratio(original_image.height, original_image.width, train_aspect_ratios)
    elif modulation_input is None or len(modulation_input) == 0:
        delta_emb = delta_emb_pblock = delta_emb_mask = None
    else:
        for i, t in enumerate(timesteps):
            t = t.expand(1).to(torch.bfloat16) / 1000
            (
                delta_emb, delta_emb_pblock, delta_emb_mask,
                text_cond_mask, delta_start_ends,
                condition_latents, condition_ids
            ) = get_mod_emb(modulation_input, t)
            delta_embs.append(delta_emb)
            delta_embs_pblock.append(delta_emb_pblock)
            delta_embs_mask.append(delta_emb_mask)

    if target_height is None or target_width is None:
        target_height = target_width = target_size

    if condition_pad_to == "square":
        condition_imgs = [pad_to_square(x) for x in condition_imgs]
    elif condition_pad_to == "target":
        condition_imgs = [pad_to_target(x, (target_size, target_size)) for x in condition_imgs]
    condition_imgs = [x.resize((condition_size, condition_size)).convert("RGB") for x in condition_imgs]
    # TODO: fix position_delta
    conditions = [
        Condition(
            condition_type=condition_type,
            condition=x,
            position_delta=position_delta,
        ) for x in condition_imgs
    ]

    # vlm_images = condition_imgs if config["model"]["use_vlm"] else []
    use_perblock_adapter = False
    try:
        if config["model"]["modulation"]["use_perblock_adapter"]:
            use_perblock_adapter = True
    except Exception as e:
        pass

    results = []
    for i in range(num_images):
        clear_attn_maps(pipe.transformer)
        generator = torch.Generator(device=device)
        generator.manual_seed(seed + i)

        if modulation_input is None or len(modulation_input) == 0:
            idips = None
        else:
            idips = ["human" in p["image_path"] for p in modulation_input[0]["src_inputs"]]
            if len(modulation_input[0]["use_words"][0]) == 5:
                print("use idips in use_words")
                idips = [x[-1] for x in modulation_input[0]["use_words"]]

        result_img = generate(
            pipe,
            prompt=prompt,
            max_sequence_length=max_length,
            vae_conditions=conditions,
            generator=generator,
            model_config=config["model"],
            height=target_height,
            width=target_width,
            condition_pad_to=condition_pad_to,
            condition_size=condition_size,
            text_cond_mask=text_cond_mask,
            delta_emb=delta_embs,
            delta_emb_pblock=delta_embs_pblock if use_perblock_adapter else None,
            delta_emb_mask=delta_embs_mask,
            delta_start_ends=delta_start_ends,
            condition_latents=condition_latents,
            condition_ids=condition_ids,
            mod_adapter=pipe.modulation_adapters[0] if config["model"]["modulation"]["use_dit"] else None,
            vae_skip_iter=vae_skip_iter,
            control_weight_lambda=control_weight_lambda,
            double_attention=double_attention,
            single_attention=single_attention,
            ip_scale=ip_scale,
            use_latent_sblora_control=use_latent_sblora_control,
            latent_sblora_scale=latent_sblora_scale,
            use_condition_sblora_control=use_condition_sblora_control,
            condition_sblora_scale=condition_sblora_scale,
            idips=idips if use_idip else None,
            **kargs,
        ).images[0]

        final_image = result_img
        results.append(final_image)

    if num_images == 1:
        return results[0]
    return results
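
# Usage sketch (illustrative only; the file paths, prompt and position_delta values
# below are assumptions, and `pipe`/`config` must come from the project's own
# loading code, e.g. a CustomFluxPipeline plus get_config()):
#
#   config = get_config("configs/example.yaml")
#   test_sample = {
#       "input_images": [Image.open("assets/subject.png")],
#       "position_delta": [0, -32],
#       "prompt": "a person standing on the beach",
#       # optional keys: "original_image", "condition_type", "modulation"
#   }
#   image = generate_from_test_sample(test_sample, pipe, config, num_images=1)
#   image.save("output.png")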
