yuvalalaluf committed on
Commit
b058ece
·
1 Parent(s): d41f254

add run.py

Browse files
Files changed (1) hide show
  1. run.py +65 -0
run.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import pyrallis
6
+ import torch
7
+ from PIL import Image
8
+ from diffusers.training_utils import set_seed
9
+
10
+ sys.path.append(".")
11
+ sys.path.append("..")
12
+
13
+ from appearance_transfer_model import AppearanceTransferModel
14
+ from config import RunConfig, Range
15
+ from utils import latent_utils
16
+ from utils.latent_utils import load_latents_or_invert_images
17
+
18
+
19
@pyrallis.wrap()
def main(cfg: RunConfig):
    """Script entry point: hand the configuration supplied by ``pyrallis.wrap`` to ``run``.

    Args:
        cfg: The run configuration, injected by the ``pyrallis.wrap`` decorator.

    The images returned by ``run`` are discarded; this function returns ``None``.
    """
    run(cfg=cfg)
22
+
23
+
24
def run(cfg: RunConfig) -> List[Image.Image]:
    """Run one appearance-transfer job end to end for the given configuration.

    Steps, in order:
      1. Dump ``cfg`` to ``config.yaml`` under ``cfg.output_path``.
      2. Call ``set_seed(cfg.seed)``.
      3. Build an ``AppearanceTransferModel`` from ``cfg``.
      4. Obtain appearance/structure latents and noise via
         ``load_latents_or_invert_images`` and register them on the model.
      5. Run ``run_appearance_transfer`` and return its result.

    Args:
        cfg: Run configuration. This function itself reads ``cfg.output_path``
            and ``cfg.seed``; everything else is consumed by the callees.

    Returns:
        The list of images returned by ``run_appearance_transfer``.
    """
    # Use a context manager so the config file is flushed and closed right away
    # instead of leaking the handle for the duration of the (long) model run.
    with open(cfg.output_path / 'config.yaml', 'w') as config_file:
        pyrallis.dump(cfg, config_file)
    set_seed(cfg.seed)
    model = AppearanceTransferModel(cfg)
    latents_app, latents_struct, noise_app, noise_struct = load_latents_or_invert_images(model=model, cfg=cfg)
    model.set_latents(latents_app, latents_struct)
    model.set_noise(noise_app, noise_struct)
    print("Running appearance transfer...")
    images = run_appearance_transfer(model=model, cfg=cfg)
    print("Done.")
    return images
35
+
36
+
37
def run_appearance_transfer(model: AppearanceTransferModel, cfg: RunConfig) -> List[Image.Image]:
    """Run the model's pipeline on three prompts and save the resulting images.

    The pipeline is invoked with ``cfg.prompt`` repeated three times, the initial
    latents and noises from ``latent_utils.get_init_latents_and_noises``, and a
    single cross-image attention range spanning both ``cfg.cross_attn_32_range``
    and ``cfg.cross_attn_64_range``.

    Four PNG files are written under ``cfg.output_path``: the three individual
    outputs (``out_transfer``, ``out_style``, ``out_struct``) and a joined image.

    Args:
        model: The appearance-transfer model; its ``pipe`` attribute is called
            and its ``enable_edit`` flag is set to ``True``.
        cfg: Run configuration (prompt, seed, timesteps, ranges, output path).

    Returns:
        The ``images`` attribute of the pipeline output.
    """
    init_latents, init_zs = latent_utils.get_init_latents_and_noises(model=model, cfg=cfg)
    model.pipe.scheduler.set_timesteps(cfg.num_timesteps)
    model.enable_edit = True  # Activate our cross-image attention layers

    # Union of the two per-resolution ranges: earliest start through latest end.
    first_step = min(cfg.cross_attn_32_range.start, cfg.cross_attn_64_range.start)
    last_step = max(cfg.cross_attn_32_range.end, cfg.cross_attn_64_range.end)

    pipe_kwargs = dict(
        prompt=[cfg.prompt] * 3,
        latents=init_latents,
        guidance_scale=1.0,
        num_inference_steps=cfg.num_timesteps,
        swap_guidance_scale=cfg.swap_guidance_scale,
        callback=model.get_adain_callback(),
        eta=1,
        zs=init_zs,
        generator=torch.Generator('cuda').manual_seed(cfg.seed),
        cross_image_attention_range=Range(start=first_step, end=last_step),
    )
    images = model.pipe(**pipe_kwargs).images

    # Save images: index 0 -> transfer, 1 -> style, 2 -> struct
    output_names = (
        f"out_transfer---seed_{cfg.seed}.png",
        f"out_style---seed_{cfg.seed}.png",
        f"out_struct---seed_{cfg.seed}.png",
    )
    for index, filename in enumerate(output_names):
        images[index].save(cfg.output_path / filename)

    # Joined image uses the outputs in reverse order, concatenated along axis 1
    # (presumably the width axis of H x W x C arrays - confirm against pipeline output).
    side_by_side = np.concatenate(images[::-1], axis=1)
    Image.fromarray(side_by_side).save(cfg.output_path / f"out_joined---seed_{cfg.seed}.png")
    return images
62
+
63
+
64
# Only launch the pipeline when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()