File size: 794 Bytes
e402ae6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
#!/usr/bin/env python3
from diffusers import FlaxStableDiffusionXLPipeline
import numpy as np
import jax.numpy as jnp
import jax
# Deterministic smoke run of the tiny SDXL Flax pipeline: fixed latents and a
# fixed PRNG key, printing absolute-sum checksums of the input latents and the
# output image so results can be compared across runs/versions.
path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
pipe, params = FlaxStableDiffusionXLPipeline.from_pretrained(path)

prompt = "An astronaut riding a green horse on Mars"
steps = 3

# Latent-space dimensions (the latents are passed straight to the pipeline).
# NOTE(review): 32x32x4 is assumed to match what the tiny checkpoint expects
# for its default output size — TODO confirm against the UNet/VAE config.
batch_size, height, width, ch = 1, 32, 32, 4
num_elems = batch_size * height * width * ch
rng = jax.random.PRNGKey(0)

# Deterministic ramp in [0, 1) instead of random noise, laid out NCHW:
# (batch, channels, height, width). The previous code swapped height/width,
# which was only harmless because they are equal.
latents = (jnp.arange(num_elems) / num_elems).reshape(batch_size, ch, height, width)
print("latents", np.abs(np.asarray(latents)).sum())

# prepare_inputs takes the raw prompt string; its result is the pipeline's
# `prompt_ids` argument (token ids), not text embeddings.
prompt_ids = pipe.prepare_inputs(prompt)
image = pipe(
    prompt_ids,
    params,
    rng,
    latents=latents,
    num_inference_steps=steps,  # previously hard-coded to 3, ignoring `steps`
    output_type="np",
).images[0]
print(np.abs(np.asarray(image)).sum())
|