File size: 794 Bytes
e402ae6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
#!/usr/bin/env python3
from diffusers import FlaxStableDiffusionXLPipeline
import numpy as np
import jax.numpy as jnp
import jax

# Tiny SDXL checkpoint used for fast smoke tests of the Flax pipeline.
MODEL_ID = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
PROMPT = "An astronaut riding a green horse on Mars"
NUM_INFERENCE_STEPS = 3


def ramp_latents(batch_size, channels, height, width):
    """Return a deterministic latent tensor of shape (batch_size, channels, height, width).

    The values are the linear ramp 0, 1/n, ..., (n-1)/n, where n is the
    total element count, laid out in row-major order. Using a fixed ramp
    instead of random noise makes the starting latents independent of any
    RNG, so the printed checksums are reproducible.
    """
    num_elems = batch_size * channels * height * width
    # int / int in jnp is true division, so this yields a float array.
    return (jnp.arange(num_elems) / num_elems).reshape(batch_size, channels, height, width)


def main(num_inference_steps=NUM_INFERENCE_STEPS):
    """Run the tiny Flax SDXL pipeline from fixed latents and print checksums.

    Prints the sum of absolute values of the input latents, then the same
    checksum for the first generated image (returned as a NumPy array).

    Args:
        num_inference_steps: Number of denoising steps passed to the pipeline.
    """
    pipe, params = FlaxStableDiffusionXLPipeline.from_pretrained(MODEL_ID)

    # Latent-space dimensions; assumed to match what the tiny checkpoint's
    # UNet/VAE expect for the default output size -- TODO confirm.
    batch_size, channels, height, width = 1, 4, 32, 32
    latents = ramp_latents(batch_size, channels, height, width)
    print("latents", np.abs(np.asarray(latents)).sum())

    # prepare_inputs tokenizes the prompt; the result is token ids, not
    # text embeddings (the encoders run inside the pipeline call).
    prompt_ids = pipe.prepare_inputs(PROMPT)

    rng = jax.random.PRNGKey(0)
    image = pipe(
        prompt_ids,
        params,
        rng,
        latents=latents,
        num_inference_steps=num_inference_steps,
        output_type="np",
    ).images[0]

    print(np.abs(np.asarray(image)).sum())


if __name__ == "__main__":
    main()