Alex-23 commited on
Commit
1df7f10
·
1 Parent(s): a7e58f3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -0
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import DDPMPipeline
2
+ image_pipe = DDPMPipeline.from_pretrained("google/ddpm-celebahq-256")
3
+ image_pipe.to("cuda")
4
+ images = image_pipe().images
5
+ image_pipe
6
+ from diffusers import UNet2DModel
7
+ repo_id = "google/ddpm-church-256"
8
+ model = UNet2DModel.from_pretrained(repo_id)
9
+ model
10
+ model.config
11
+ model_random = UNet2DModel(**model.config)
12
+ model_random.save_pretrained("my_model")
13
+ model_random = UNet2DModel.from_pretrained("my_model")
14
+ import torch
15
+ torch.manual_seed(0)
16
+ noisy_sample = torch.randn(
17
+ 1, model.config.in_channels, model.config.sample_size, model.config.sample_size
18
+ )
19
+ noisy_sample.shape
20
+ with torch.no_grad():
21
+ noisy_residual = model(sample=noisy_sample, timestep=2).sample
22
+ noisy_residual.shape
23
+ from diffusers import DDPMScheduler
24
+ scheduler = DDPMScheduler.from_config(repo_id)
25
+ scheduler.config
26
+ scheduler.save_config("my_scheduler")
27
+ new_scheduler = DDPMScheduler.from_config("my_scheduler")
28
+ less_noisy_sample = scheduler.step(
29
+ model_output=noisy_residual, timestep=2, sample=noisy_sample
30
+ ).prev_sample
31
+ less_noisy_sample.shape
32
+ import PIL.Image
33
+ import numpy as np
34
+ def display_sample(sample, i):
35
+ image_processed = sample.cpu().permute(0, 2, 3, 1)
36
+ image_processed = (image_processed + 1.0) * 127.5
37
+ image_processed = image_processed.numpy().astype(np.uint8)
38
+ image_pil = PIL.Image.fromarray(image_processed[0])
39
+ display(f"Image at step {i}")
40
+ display(image_pil)
41
+ model.to("cuda")
42
+ noisy_sample = noisy_sample.to("cuda")
43
+ import tqdm
44
+ sample = noisy_sample
45
+ for i, t in enumerate(tqdm.tqdm(scheduler.timesteps)):
46
+ with torch.no_grad():
47
+ residual = model(sample, t).sample
48
+ sample = scheduler.step(residual, t, sample).prev_sample
49
+ if (i + 1) % 50 == 0:
50
+ display_sample(sample, i + 1)
51
+ from diffusers import DDIMScheduler
52
+ scheduler = DDIMScheduler.from_config(repo_id)
53
+ scheduler.set_timesteps(num_inference_steps=50)
54
+ import tqdm
55
+ sample = noisy_sample
56
+ for i, t in enumerate(tqdm.tqdm(scheduler.timesteps)):
57
+ with torch.no_grad():
58
+ residual = model(sample, t).sample
59
+ sample = scheduler.step(residual, t, sample).prev_sample
60
+ if (i + 1) % 10 == 0:
61
+ display_sample(sample, i + 1)