|
|
|
import hf_image_uploader as hiu |
|
import torch |
|
from diffusers import WuerstchenDecoderPipeline, WuerstchenPriorPipeline |
|
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS |
|
|
|
def main() -> None:
    """Generate images with the two-stage Wuerstchen pipeline and upload them.

    Stage C (the prior) maps the text prompt to image embeddings; stage B
    (the decoder) renders pixels from those embeddings. Both denoising
    networks are torch.compile'd before inference, and every resulting PIL
    image is uploaded to a Hugging Face image repository.
    """
    device = "cuda"
    dtype = torch.float16
    num_images_per_prompt = 2

    # Stage C: text prompt -> image embeddings.
    prior = WuerstchenPriorPipeline.from_pretrained(
        "warp-ai/wuerstchen-prior", torch_dtype=dtype
    ).to(device)

    # Stage B: image embeddings -> pixels.
    decoder = WuerstchenDecoderPipeline.from_pretrained(
        "warp-ai/wuerstchen", torch_dtype=dtype
    ).to(device)

    caption = "Anthropomorphic cat dressed as a fire fighter"
    negative_prompt = ""

    # Compile the heavy denoiser modules; "reduce-overhead" trades extra
    # warm-up compilation for lower per-call launch overhead (via CUDA
    # graphs, per the torch.compile docs), which pays off on repeated runs.
    prior.prior = torch.compile(
        prior.prior, mode="reduce-overhead", fullgraph=True
    )
    decoder.decoder = torch.compile(
        decoder.decoder, mode="reduce-overhead", fullgraph=True
    )

    prior_output = prior(
        prompt=caption,
        height=1024,
        width=1536,
        timesteps=DEFAULT_STAGE_C_TIMESTEPS,
        negative_prompt=negative_prompt,
        guidance_scale=4.0,
        num_images_per_prompt=num_images_per_prompt,
    )

    # Decoder stage runs without classifier-free guidance (scale 0.0),
    # matching the reference Wuerstchen recipe.
    images = decoder(
        image_embeddings=prior_output.image_embeddings,
        prompt=caption,
        negative_prompt=negative_prompt,
        guidance_scale=0.0,
        output_type="pil",
    ).images

    for image in images:
        hiu.upload(image, repo_id="patrickvonplaten/images")


main()
|
|