import logging from safetensors.torch import load_file from animatediff import get_dir from animatediff.utils.lora_diffusers import (LoRANetwork, create_network_from_weights) logger = logging.getLogger(__name__) data_dir = get_dir("data") def merge_safetensors_lora(text_encoder, unet, lora_path, alpha=0.75, is_animatediff=True): def dump(loaded): for a in loaded: logger.info(f"{a} {loaded[a].shape}") sd = load_file(lora_path) if False: dump(sd) print(f"create LoRA network") lora_network: LoRANetwork = create_network_from_weights(text_encoder, unet, sd, multiplier=alpha, is_animatediff=is_animatediff) print(f"load LoRA network weights") lora_network.load_state_dict(sd, False) lora_network.merge_to(alpha) def load_lora_map(pipe, lora_map_config, video_length, is_sdxl=False): new_map = {} for item in lora_map_config: lora_path = data_dir.joinpath(item) if type(lora_map_config[item]) in (float,int): te_en = [pipe.text_encoder, pipe.text_encoder_2] if is_sdxl else pipe.text_encoder merge_safetensors_lora(te_en, pipe.unet, lora_path, lora_map_config[item], not is_sdxl) else: new_map[lora_path] = lora_map_config[item] lora_map = LoraMap(pipe, new_map, video_length, is_sdxl) pipe.lora_map = lora_map if lora_map.is_valid else None def load_lcm_lora(pipe, lcm_map, is_sdxl=False, is_merge=False): if is_sdxl: lora_path = data_dir.joinpath("models/lcm_lora/sdxl/pytorch_lora_weights.safetensors") else: lora_path = data_dir.joinpath("models/lcm_lora/sd15/pytorch_lora_weights.safetensors") logger.info(f"{lora_path=}") if is_merge: te_en = [pipe.text_encoder, pipe.text_encoder_2] if is_sdxl else pipe.text_encoder merge_safetensors_lora(te_en, pipe.unet, lora_path, 1.0, not is_sdxl) pipe.lcm = None return lcm = LcmLora(pipe, is_sdxl, lora_path, lcm_map) pipe.lcm = lcm if lcm.is_valid else None class LcmLora: def __init__( self, pipe, is_sdxl, lora_path, lcm_map ): self.is_valid = False sd = load_file(lora_path) if not sd: return te_en = [pipe.text_encoder, pipe.text_encoder_2] if is_sdxl else pipe.text_encoder lora_network: LoRANetwork = create_network_from_weights(te_en, pipe.unet, sd, multiplier=1.0, is_animatediff=not is_sdxl) lora_network.load_state_dict(sd, False) lora_network.apply_to(1.0) self.network = lora_network self.is_valid = True self.start_scale = lcm_map["start_scale"] self.end_scale = lcm_map["end_scale"] self.gradient_start = lcm_map["gradient_start"] self.gradient_end = lcm_map["gradient_end"] def to( self, device, dtype, ): self.network.to(device=device, dtype=dtype) def apply( self, step, total_steps, ): step += 1 progress = step / total_steps if progress < self.gradient_start: scale = self.start_scale elif progress > self.gradient_end: scale = self.end_scale else: if (self.gradient_end - self.gradient_start) < 1e-4: progress = 0 else: progress = (progress - self.gradient_start) / (self.gradient_end - self.gradient_start) scale = (self.end_scale - self.start_scale) * progress scale += self.start_scale self.network.active( scale ) def unapply( self, ): self.network.deactive( ) class LoraMap: def __init__( self, pipe, lora_map, video_length, is_sdxl, ): self.networks = [] def create_schedule(scales, length): scales = { int(i):scales[i] for i in scales } keys = sorted(scales.keys()) if len(keys) == 1: return { i:scales[keys[0]] for i in range(length) } keys = keys + [keys[0]] schedule={} def calc(rate,start_v,end_v): return start_v + (rate * rate)*(end_v - start_v) for key_prev,key_next in zip(keys[:-1],keys[1:]): v1 = scales[key_prev] v2 = scales[key_next] if key_prev > key_next: key_next += length for i in range(key_prev,key_next): dist = i-key_prev if i >= length: i -= length schedule[i] = calc( dist/(key_next-key_prev), v1, v2 ) return schedule for lora_path in lora_map: sd = load_file(lora_path) if not sd: continue te_en = [pipe.text_encoder, pipe.text_encoder_2] if is_sdxl else pipe.text_encoder lora_network: LoRANetwork = create_network_from_weights(te_en, pipe.unet, sd, multiplier=0.75, is_animatediff=not is_sdxl) lora_network.load_state_dict(sd, False) lora_network.apply_to(0.75) self.networks.append( { "network":lora_network, "region":lora_map[lora_path]["region"], "schedule": create_schedule(lora_map[lora_path]["scale"], video_length ) } ) def region_convert(i): if i == "background": return 0 else: return int(i) + 1 for net in self.networks: net["region"] = [ region_convert(i) for i in net["region"] ] # for n in self.networks: # logger.info(f"{n['region']=}") # logger.info(f"{n['schedule']=}") if self.networks: self.is_valid = True else: self.is_valid = False def to( self, device, dtype, ): for net in self.networks: net["network"].to(device=device, dtype=dtype) def apply( self, cond_index, cond_nums, frame_no, ): ''' neg 0 (bg) neg 1 neg 2 pos 0 (bg) pos 1 pos 2 ''' region_index = cond_index if cond_index < cond_nums//2 else cond_index - cond_nums//2 # logger.info(f"{cond_index=}") # logger.info(f"{cond_nums=}") # logger.info(f"{region_index=}") for i,net in enumerate(self.networks): if region_index in net["region"]: scale = net["schedule"][frame_no] if scale > 0: net["network"].active( scale ) # logger.info(f"{i=} active {scale=}") else: net["network"].deactive( ) # logger.info(f"{i=} DEactive") else: net["network"].deactive( ) # logger.info(f"{i=} DEactive") def unapply( self, ): for net in self.networks: net["network"].deactive( )