#!/usr/bin/env python3
"""Smoke-run the tiny Flax Stable Diffusion XL pipeline with fixed latents.

Loads the ``hf-internal-testing/tiny-stable-diffusion-xl-pipe`` checkpoint,
feeds it a deterministic latent ramp instead of random noise, runs a few
denoising steps and prints the absolute sums of the input latents and of the
resulting image so the two numbers can be compared across runs.
"""
from diffusers import FlaxStableDiffusionXLPipeline
import numpy as np
import jax.numpy as jnp
import jax

path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
pipe, params = FlaxStableDiffusionXLPipeline.from_pretrained(path)

prompt = "An astronaut riding a green horse on Mars"
steps = 3

batch_size, height, width, ch = 1, 32, 32, 4
num_elems = batch_size * height * width * ch

rng = jax.random.PRNGKey(0)

# Deterministic latents: an evenly spaced ramp 0, 1/N, ..., (N-1)/N laid out
# as (batch, channels, height, width). Height and width are equal here, so
# the axis order yields the same array as the earlier (.., width, height).
# NOTE(review): assumes the pipeline expects channels-first latents with
# height before width -- confirm against the diffusers implementation.
latents = (jnp.arange(num_elems) / num_elems).reshape(
    batch_size, ch, height, width
)
print("latents", np.abs(np.asarray(latents)).sum())

# NOTE(review): prepare_inputs presumably returns tokenized prompt ids rather
# than embeddings -- confirm against the diffusers documentation.
prompt_ids = pipe.prepare_inputs(prompt)

# Use the `steps` variable instead of repeating the literal so the step count
# is configured in exactly one place.
image = pipe(
    prompt_ids,
    params,
    rng,
    latents=latents,
    num_inference_steps=steps,
    output_type="np",
).images[0]
print(np.abs(np.asarray(image)).sum())