from diffusers.pipelines import FluxPipeline
from diffusers.utils import logging
from diffusers.pipelines.flux.pipeline_flux import logger
from torch import Tensor


def encode_images(pipeline: FluxPipeline, images: Tensor):
    # Normalize the images and move them to the pipeline's device/dtype.
    images = pipeline.image_processor.preprocess(images)
    images = images.to(pipeline.device).to(pipeline.dtype)

    # Encode to VAE latents and apply Flux's shift/scale normalization.
    images = pipeline.vae.encode(images).latent_dist.sample()
    images = (
        images - pipeline.vae.config.shift_factor
    ) * pipeline.vae.config.scaling_factor

    # Pack the latents into the token sequence expected by the Flux transformer.
    images_tokens = pipeline._pack_latents(images, *images.shape)
    images_ids = pipeline._prepare_latent_image_ids(
        images.shape[0],
        images.shape[2],
        images.shape[3],
        pipeline.device,
        pipeline.dtype,
    )
    # Depending on the diffusers version, _prepare_latent_image_ids may expect
    # the already-halved latent height/width; rebuild the ids if the token
    # count does not match.
    if images_tokens.shape[1] != images_ids.shape[0]:
        images_ids = pipeline._prepare_latent_image_ids(
            images.shape[0],
            images.shape[2] // 2,
            images.shape[3] // 2,
            pipeline.device,
            pipeline.dtype,
        )
    return images_tokens, images_ids


def prepare_text_input(pipeline: FluxPipeline, prompts, max_sequence_length=512):
    # Turn off warnings (CLIP overflow)
    logger.setLevel(logging.ERROR)
    (
        prompt_embeds,
        pooled_prompt_embeds,
        text_ids,
    ) = pipeline.encode_prompt(
        prompt=prompts,
        prompt_2=None,
        prompt_embeds=None,
        pooled_prompt_embeds=None,
        device=pipeline.device,
        num_images_per_prompt=1,
        max_sequence_length=max_sequence_length,
        lora_scale=None,
    )
    # Turn on warnings
    logger.setLevel(logging.WARNING)
    return prompt_embeds, pooled_prompt_embeds, text_ids
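

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal, hedged example of
# how these helpers might be called. The checkpoint id, image size, and prompt
# below are assumptions for illustration, not something this file specifies.
if __name__ == "__main__":
    import torch

    # Assumption: any Flux checkpoint works; FLUX.1-dev is used here.
    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    ).to("cuda")

    # Encode a dummy batch of one 512x512 RGB image into packed latent tokens.
    dummy_images = torch.rand(1, 3, 512, 512)
    image_tokens, image_ids = encode_images(pipeline, dummy_images)
    print(image_tokens.shape, image_ids.shape)

    # Encode a text prompt into T5/CLIP embeddings and text ids.
    prompt_embeds, pooled_prompt_embeds, text_ids = prepare_text_input(
        pipeline, ["a photo of a cat"]
    )
    print(prompt_embeds.shape, pooled_prompt_embeds.shape, text_ids.shape)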