from diffusers.pipelines import FluxPipeline
from diffusers.utils import logging
from diffusers.pipelines.flux.pipeline_flux import logger
from torch import Tensor


def encode_images(pipeline: FluxPipeline, images: Tensor):
    """Encode a batch of images into packed FLUX latent tokens and their position ids."""
    images = pipeline.image_processor.preprocess(images)
    images = images.to(pipeline.device).to(pipeline.dtype)
    images = pipeline.vae.encode(images).latent_dist.sample()
    images = (
        images - pipeline.vae.config.shift_factor
    ) * pipeline.vae.config.scaling_factor
    images_tokens = pipeline._pack_latents(images, *images.shape)
    images_ids = pipeline._prepare_latent_image_ids(
        images.shape[0],
        images.shape[2],
        images.shape[3],
        pipeline.device,
        pipeline.dtype,
    )
    # If the id count does not match the packed token count, recompute the ids
    # with the halved latent dimensions (the expected arguments of
    # _prepare_latent_image_ids differ across diffusers versions).
    if images_tokens.shape[1] != images_ids.shape[0]:
        images_ids = pipeline._prepare_latent_image_ids(
            images.shape[0],
            images.shape[2] // 2,
            images.shape[3] // 2,
            pipeline.device,
            pipeline.dtype,
        )
    return images_tokens, images_ids


def prepare_text_input(pipeline: FluxPipeline, prompts, max_sequence_length=512):
    # Turn off warnings (CLIP overflow)
    logger.setLevel(logging.ERROR)
    (
        prompt_embeds,
        pooled_prompt_embeds,
        text_ids,
    ) = pipeline.encode_prompt(
        prompt=prompts,
        prompt_2=None,
        prompt_embeds=None,
        pooled_prompt_embeds=None,
        device=pipeline.device,
        num_images_per_prompt=1,
        max_sequence_length=max_sequence_length,
        lora_scale=None,
    )
    # Turn on warnings
    logger.setLevel(logging.WARNING)
    return prompt_embeds, pooled_prompt_embeds, text_ids
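

# Usage sketch (not part of the original file): shows how the two helpers above
# might be called together. The checkpoint id, dtype, and input sizes below are
# illustrative assumptions, not values taken from this repository.
if __name__ == "__main__":
    import torch

    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    ).to("cuda")

    # A dummy batch of RGB images in [0, 1]; VaeImageProcessor.preprocess
    # accepts tensors as well as PIL images.
    images = torch.rand(1, 3, 512, 512)
    image_tokens, image_ids = encode_images(pipeline, images)

    prompt_embeds, pooled_prompt_embeds, text_ids = prepare_text_input(
        pipeline, ["a photo of a cat"]
    )
    print(image_tokens.shape, prompt_embeds.shape, pooled_prompt_embeds.shape)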