"""Smoke-run a Flax Stable Diffusion XL pipeline on deterministic latents.

Loads the checkpoint named by ``path``, feeds the pipeline a fixed latent
tensor built from ``arange`` and prints the sum of absolute values of both the
latents and the first generated image, so that separate runs can be compared
numerically.
"""

import jax
import jax.numpy as jnp
import numpy as np
from diffusers import FlaxStableDiffusionXLPipeline

path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"

pipe, params = FlaxStableDiffusionXLPipeline.from_pretrained(path)

prompt = "An astronaut riding a green horse on Mars"
steps = 3

batch_size, height, width, ch = 1, 32, 32, 4
num_elems = batch_size * height * width * ch
rng = jax.random.PRNGKey(0)

# Evenly spaced ramp 0, 1/N, ..., (N-1)/N laid out as a 4-D tensor, used
# instead of random noise so the printed checksums are reproducible.
# NOTE(review): the spatial axes are given as (width, height); with
# height == width the order is not observable here -- confirm the intended
# layout before making the two sizes differ.
latents = (jnp.arange(num_elems) / num_elems).reshape(batch_size, ch, width, height)

print("latents", np.abs(np.asarray(latents)).sum())

# NOTE(review): despite the name, `prepare_inputs` presumably returns
# tokenized prompt ids rather than embeddings -- confirm against diffusers.
prompt_embeds = pipe.prepare_inputs(prompt)

# Use `steps` rather than repeating the literal so the step count is defined
# in exactly one place.
image = pipe(
    prompt_embeds,
    params,
    rng,
    latents=latents,
    num_inference_steps=steps,
    output_type="np",
).images[0]

print(np.abs(np.asarray(image)).sum())