Learner commited on
Commit
0ee1469
·
1 Parent(s): 12b6864

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -10
app.py CHANGED
@@ -32,15 +32,6 @@ from flax.training.common_utils import shard
32
  def create_key(seed=0):
33
  return jax.random.PRNGKey(seed)
34
 
35
-
36
- def image_grid(imgs, rows, cols):
37
- w, h = imgs[0].size
38
- grid = Image.new("RGB", size=(cols * w, rows * h))
39
- for i, img in enumerate(imgs):
40
- grid.paste(img, box=(i % cols * w, i // cols * h))
41
- return grid
42
-
43
-
44
  # load control net and stable diffusion v1-5
45
  controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
46
  "learner/jax-diffuser-event", from_flax=True, dtype=jnp.float32
@@ -62,7 +53,7 @@ def infer(prompts, negative_prompts, image):
62
  num_samples = 1 # jax.device_count()
63
  rng = create_key(0)
64
  rng = jax.random.split(rng, jax.device_count())
65
- battlemap_image = load_image(image)
66
 
67
  prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
68
  negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
 
32
  def create_key(seed=0):
33
  return jax.random.PRNGKey(seed)
34
 
 
 
 
 
 
 
 
 
 
35
  # load control net and stable diffusion v1-5
36
  controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
37
  "learner/jax-diffuser-event", from_flax=True, dtype=jnp.float32
 
53
  num_samples = 1 # jax.device_count()
54
  rng = create_key(0)
55
  rng = jax.random.split(rng, jax.device_count())
56
+ battlemap_image = Image.open(image)
57
 
58
  prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
59
  negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)