Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	| import torch | |
| import math | |
| from diffusers_helper.k_diffusion.uni_pc_fm import sample_unipc | |
| from diffusers_helper.k_diffusion.wrapper import fm_wrapper | |
| from diffusers_helper.utils import repeat_to_batch_size | |
def flux_time_shift(t, mu=1.15, sigma=1.0):
    """Apply the exponential time shift to a timestep ``t``.

    Computes ``e^mu / (e^mu + (1/t - 1)^sigma)``.

    Args:
        t: Timestep value(s); anything supporting ``1 / t``, subtraction and
            ``**`` (a Python float or a tensor).
        mu: Shift strength; enters the formula as ``exp(mu)``.
        sigma: Exponent applied to the ``(1/t - 1)`` term.

    Returns:
        The shifted timestep, of the same kind as ``t``.
    """
    exp_mu = math.exp(mu)
    inverse_odds = (1 / t - 1) ** sigma
    return exp_mu / (exp_mu + inverse_odds)
| def calculate_flux_mu(context_length, x1=256, y1=0.5, x2=4096, y2=1.15, exp_max=7.0): | |
| k = (y2 - y1) / (x2 - x1) | |
| b = y1 - k * x1 | |
| mu = k * context_length + b | |
| mu = min(mu, math.log(exp_max)) | |
| return mu | |
def get_flux_sigmas_from_mu(n, mu):
    """Build a sigma schedule of ``n + 1`` values for a given ``mu``.

    A linear ramp from 1 down to 0 (``n + 1`` points, i.e. ``n`` intervals)
    is passed through ``flux_time_shift`` with the supplied ``mu``.

    Args:
        n: Number of intervals in the schedule.
        mu: Shift parameter forwarded to ``flux_time_shift``.

    Returns:
        A tensor with ``n + 1`` entries.
    """
    linear_ramp = torch.linspace(1, 0, steps=n + 1)
    return flux_time_shift(linear_ramp, mu=mu)
def sample_hunyuan(
    transformer,
    sampler='unipc',
    initial_latent=None,
    concat_latent=None,
    strength=1.0,
    width=512,
    height=512,
    frames=16,
    real_guidance_scale=1.0,
    distilled_guidance_scale=6.0,
    guidance_rescale=0.0,
    shift=None,
    num_inference_steps=25,
    batch_size=None,
    generator=None,
    prompt_embeds=None,
    prompt_embeds_mask=None,
    prompt_poolers=None,
    negative_prompt_embeds=None,
    negative_prompt_embeds_mask=None,
    negative_prompt_poolers=None,
    dtype=torch.bfloat16,
    device=None,
    negative_kwargs=None,
    callback=None,
    **kwargs,
):
    """Draw starting noise, build the sigma schedule and run the sampler.

    Args:
        transformer: Model handed to ``fm_wrapper``; its ``.device`` is used
            when ``device`` is not given.
        sampler: Sampler name. Only ``'unipc'`` is implemented.
        initial_latent: Optional latent blended with the noise using the
            first sigma of the (strength-scaled) schedule.
        concat_latent: Optional tensor moved to the latents' device/dtype and
            forwarded to the sampler as ``concat_latent``.
        strength: Factor applied to the sigma schedule; only used when
            ``initial_latent`` is given.
        width, height, frames: Determine the noise shape
            ``(batch, 16, (frames + 3) // 4, height // 8, width // 8)``.
        real_guidance_scale: Forwarded to the sampler as ``cfg_scale``.
        distilled_guidance_scale: Multiplied by 1000 and passed to the model
            as ``guidance`` for both the positive and the negative branch.
        guidance_rescale: Forwarded to the sampler as ``cfg_rescale``.
        shift: If given, ``mu = log(shift)``; otherwise ``mu`` is computed by
            ``calculate_flux_mu`` from the latent sequence length.
        num_inference_steps: Number of intervals in the sigma schedule.
        batch_size: Batch size; defaults to ``prompt_embeds.shape[0]``.
        generator: Optional ``torch.Generator`` for the initial noise. When
            omitted, the global RNG of the target device is used.
        prompt_embeds, prompt_embeds_mask, prompt_poolers: Positive
            conditioning, passed through ``repeat_to_batch_size``.
        negative_prompt_embeds, negative_prompt_embeds_mask,
        negative_prompt_poolers: Negative conditioning, handled the same way.
        dtype: Stored in the sampler arguments and used for the guidance
            tensor.
        device: Target device; defaults to ``transformer.device``.
        negative_kwargs: Extra model kwargs for the negative branch; they
            override same-named entries of ``kwargs``.
        callback: Forwarded to ``sample_unipc``.
        **kwargs: Extra model kwargs for both branches.

    Returns:
        Whatever ``sample_unipc`` returns.

    Raises:
        NotImplementedError: If ``sampler`` is not ``'unipc'``.
    """
    device = device or transformer.device

    if batch_size is None:
        batch_size = int(prompt_embeds.shape[0])

    # Noise is drawn on the generator's own device so a seeded generator
    # reproduces the same values wherever the model lives. Without a
    # generator (the default) there is no such device to read, so draw
    # directly on the target device instead of failing on `None.device`.
    noise_device = generator.device if generator is not None else device
    latents = torch.randn(
        (batch_size, 16, (frames + 3) // 4, height // 8, width // 8),
        generator=generator,
        device=noise_device,
    ).to(device=device, dtype=torch.float32)

    latent_frames, latent_height, latent_width = latents.shape[2:]
    seq_length = latent_frames * latent_height * latent_width // 4

    if shift is None:
        mu = calculate_flux_mu(seq_length, exp_max=7.0)
    else:
        mu = math.log(shift)

    sigmas = get_flux_sigmas_from_mu(num_inference_steps, mu).to(device)

    k_model = fm_wrapper(transformer)

    if initial_latent is not None:
        # Shrink the schedule by `strength` and start from a mix of the
        # supplied latent and noise, weighted by the first sigma.
        sigmas = sigmas * strength
        first_sigma = sigmas[0].to(device=device, dtype=torch.float32)
        initial_latent = initial_latent.to(device=device, dtype=torch.float32)
        latents = initial_latent * (1.0 - first_sigma) + latents * first_sigma

    if concat_latent is not None:
        concat_latent = concat_latent.to(latents)

    distilled_guidance = torch.tensor(
        [distilled_guidance_scale * 1000.0] * batch_size
    ).to(device=device, dtype=dtype)

    prompt_embeds = repeat_to_batch_size(prompt_embeds, batch_size)
    prompt_embeds_mask = repeat_to_batch_size(prompt_embeds_mask, batch_size)
    prompt_poolers = repeat_to_batch_size(prompt_poolers, batch_size)
    negative_prompt_embeds = repeat_to_batch_size(negative_prompt_embeds, batch_size)
    negative_prompt_embeds_mask = repeat_to_batch_size(negative_prompt_embeds_mask, batch_size)
    negative_prompt_poolers = repeat_to_batch_size(negative_prompt_poolers, batch_size)
    concat_latent = repeat_to_batch_size(concat_latent, batch_size)

    sampler_kwargs = dict(
        dtype=dtype,
        cfg_scale=real_guidance_scale,
        cfg_rescale=guidance_rescale,
        concat_latent=concat_latent,
        positive=dict(
            pooled_projections=prompt_poolers,
            encoder_hidden_states=prompt_embeds,
            encoder_attention_mask=prompt_embeds_mask,
            guidance=distilled_guidance,
            **kwargs,
        ),
        negative=dict(
            pooled_projections=negative_prompt_poolers,
            encoder_hidden_states=negative_prompt_embeds,
            encoder_attention_mask=negative_prompt_embeds_mask,
            guidance=distilled_guidance,
            **(kwargs if negative_kwargs is None else {**kwargs, **negative_kwargs}),
        ),
    )

    if sampler == 'unipc':
        results = sample_unipc(k_model, latents, sigmas, extra_args=sampler_kwargs, disable=False, callback=callback)
    else:
        raise NotImplementedError(f'Sampler {sampler} is not supported.')

    return results
 
			
