awacke1 committed
Commit cf7cf37 · 1 Parent(s): 6451eba

Create app.py

Files changed (1)
  1. app.py +178 -0
app.py ADDED
@@ -0,0 +1,178 @@
import os

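# Clone the CLOOB latent diffusion repo and install its vendored dependencies
# (including the bundled CLIP package) at startup, before importing from them below.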
os.system("git clone --recursive https://github.com/JD-P/cloob-latent-diffusion")
os.system("cd cloob-latent-diffusion;pip install omegaconf pillow pytorch-lightning einops wandb ftfy regex ./CLIP")

import argparse
from functools import partial
from pathlib import Path
import sys

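# Make the cloned repo and its submodules importable.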
sys.path.append('./cloob-latent-diffusion')
sys.path.append('./cloob-latent-diffusion/cloob-training')
sys.path.append('./cloob-latent-diffusion/latent-diffusion')
sys.path.append('./cloob-latent-diffusion/taming-transformers')
sys.path.append('./cloob-latent-diffusion/v-diffusion-pytorch')

from omegaconf import OmegaConf
from PIL import Image

import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm import trange
from CLIP import clip
from cloob_training import model_pt, pretrained

import ldm.models.autoencoder
from diffusion import sampling, utils

import train_latent_diffusion as train
from huggingface_hub import hf_hub_url, cached_download

import random

# Download the model files
checkpoint = cached_download(hf_hub_url("huggan/distill-ccld-wa", filename="model_student.ckpt"))
ae_model_path = cached_download(hf_hub_url("huggan/ccld_wa", filename="ae_model.ckpt"))
ae_config_path = cached_download(hf_hub_url("huggan/ccld_wa", filename="ae_model.yaml"))
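# model_student.ckpt is the distilled diffusion model; ae_model.ckpt / ae_model.yaml are the
# weights and config of the KL autoencoder used to decode latents back to images.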

# Define a few utility functions

def parse_prompt(prompt, default_weight=3.):
    if prompt.startswith('http://') or prompt.startswith('https://'):
        vals = prompt.rsplit(':', 2)
        vals = [vals[0] + ':' + vals[1], *vals[2:]]
    else:
        vals = prompt.rsplit(':', 1)
    vals = vals + ['', default_weight][len(vals):]
    return vals[0], float(vals[1])
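# e.g. parse_prompt('a red circle:2.0') -> ('a red circle', 2.0); a prompt without an
# explicit ':weight' suffix falls back to default_weight, and the URL branch keeps the
# colon in 'http(s)://' out of the weight split.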


def resize_and_center_crop(image, size):
    fac = max(size[0] / image.size[0], size[1] / image.size[1])
    image = image.resize((int(fac * image.size[0]), int(fac * image.size[1])), Image.LANCZOS)
    return TF.center_crop(image, size[::-1])
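# size is (width, height) in PIL convention; torchvision's center_crop expects
# (height, width), hence size[::-1].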


# Load the models
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print('loading models')

# autoencoder
ae_config = OmegaConf.load(ae_config_path)
ae_model = ldm.models.autoencoder.AutoencoderKL(**ae_config.model.params)
ae_model.eval().requires_grad_(False).to(device)
ae_model.load_state_dict(torch.load(ae_model_path))
n_ch, side_y, side_x = 4, 32, 32
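# Diffusion runs in the autoencoder's 4-channel 32x32 latent space; the decoder maps
# each latent back to a full-resolution RGB image.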

# diffusion model
model = train.DiffusionModel(192, [1, 1, 2, 2], autoencoder_scale=torch.tensor(4.3084))
model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
model = model.to(device).eval().requires_grad_(False)
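# This is the distilled student checkpoint downloaded above; autoencoder_scale appears to be
# the latent-normalization constant baked in at training time.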

# CLOOB
cloob_config = pretrained.get_config('cloob_laion_400m_vit_b_16_16_epochs')
cloob = model_pt.get_pt_model(cloob_config)
checkpoint = pretrained.download_checkpoint(cloob_config)
cloob.load_state_dict(model_pt.get_pt_params(cloob_config, checkpoint))
cloob.eval().requires_grad_(False).to(device)
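# CLOOB provides the text and image embeddings that condition the diffusion model.
# (Note this rebinds `checkpoint`; the diffusion weights were already loaded above.)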


# The key function: returns a list of n PIL images
def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15,
             method='plms', eta=None):
    zero_embed = torch.zeros([1, cloob.config['d_embed']], device=device)
    target_embeds, weights = [zero_embed], []
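    # The zero embedding is the unconditional branch for classifier-free guidance;
    # each prompt below contributes a conditional embedding plus its weight.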

    for prompt in prompts:
        txt, weight = parse_prompt(prompt)
        target_embeds.append(cloob.text_encoder(cloob.tokenize(txt).to(device)).float())
        weights.append(weight)

    for prompt in images:
        path, weight = parse_prompt(prompt)
        img = Image.open(utils.fetch(path)).convert('RGB')
        clip_size = cloob.config['image_encoder']['image_size']
        img = resize_and_center_crop(img, (clip_size, clip_size))
        batch = TF.to_tensor(img)[None].to(device)
        embed = F.normalize(cloob.image_encoder(cloob.normalize(batch)).float(), dim=-1)
        target_embeds.append(embed)
        weights.append(weight)

    weights = torch.tensor([1 - sum(weights), *weights], device=device)
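    # The weights sum to 1: with a single prompt at the default weight of 3, the
    # unconditional branch gets weight -2, the usual classifier-free guidance arithmetic.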

    torch.manual_seed(seed)

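    # cfg_model_fn evaluates every condition (including the unconditional branch) in one
    # stacked forward pass and combines the predicted v's with the weights above.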
    def cfg_model_fn(x, t):
        n = x.shape[0]
        n_conds = len(target_embeds)
        x_in = x.repeat([n_conds, 1, 1, 1])
        t_in = t.repeat([n_conds])
        clip_embed_in = torch.cat([*target_embeds]).repeat_interleave(n, 0)
        vs = model(x_in, t_in, clip_embed_in).view([n_conds, n, *x.shape[1:]])
        v = vs.mul(weights[:, None, None, None, None]).sum(0)
        return v

    def run(x, steps):
        if method == 'ddpm':
            return sampling.sample(cfg_model_fn, x, steps, 1., {})
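        # 'ddim' is the only method that uses eta, so pass a numeric value (not the
        # default None) when selecting it.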
        if method == 'ddim':
            return sampling.sample(cfg_model_fn, x, steps, eta, {})
        if method == 'prk':
            return sampling.prk_sample(cfg_model_fn, x, steps, {})
        if method == 'plms':
            return sampling.plms_sample(cfg_model_fn, x, steps, {})
        if method == 'pie':
            return sampling.pie_sample(cfg_model_fn, x, steps, {})
        if method == 'plms2':
            return sampling.plms2_sample(cfg_model_fn, x, steps, {})
        raise ValueError(f'unsupported sampling method: {method}')

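    # Start from Gaussian noise in latent space, follow the spliced DDPM/cosine noise
    # schedule, then decode the final latents (rescaled by 2.55) with the autoencoder.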
    batch_size = n
    x = torch.randn([n, n_ch, side_y, side_x], device=device)
    t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
    steps = utils.get_spliced_ddpm_cosine_schedule(t)
    pil_ims = []
    for i in trange(0, n, batch_size):
        cur_batch_size = min(n - i, batch_size)
        out_latents = run(x[i:i+cur_batch_size], steps)
        outs = ae_model.decode(out_latents * torch.tensor(2.55).to(device))
        for j, out in enumerate(outs):
            pil_ims.append(utils.to_pil_image(out))

    return pil_ims


import gradio as gr
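# Gradio front end: a text prompt plus an optional image prompt, producing a single image.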

def gen_ims(prompt, im_prompt=None, seed=None, n_steps=10, method='plms'):
    if seed is None:
        seed = random.randint(0, 10000)
    print(prompt, im_prompt, seed, n_steps)
    prompts = [prompt]
    im_prompts = []
    if im_prompt is not None:
        im_prompts = [im_prompt]
    pil_ims = generate(n=1, prompts=prompts, images=im_prompts, seed=seed, steps=n_steps, method=method)
    return pil_ims[0]

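# Only the two prompt inputs are exposed in the interface below, so seed, n_steps and method
# keep their defaults here (a random seed, 10 steps, 'plms').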
iface = gr.Interface(fn=gen_ims,
                     inputs=[#gr.inputs.Slider(minimum=1, maximum=1, step=1, default=1, label="Number of images"),
                             #gr.inputs.Slider(minimum=0, maximum=200, step=1, label='Random seed', default=0),
                             gr.inputs.Textbox(label="Text prompt"),
                             gr.inputs.Image(optional=True, label="Image prompt", type='filepath'),
                             #gr.inputs.Slider(minimum=10, maximum=35, step=1, default=15, label="Number of steps")
                             ],
                     outputs=[gr.outputs.Image(type="pil", label="Generated Image")],
                     examples=[["Angels entertained unawares, oil painting"],
                               ["Pegasus in style of H.R. Giger"],
                               ["Dragon oil painting"],
                               ["Lighthouse reflections at sunrise"],
                               ["Archangel sculpture"]],
                     title='Art from Text or Image',
                     description="Sourced from the [WikiArt](https://huggingface.co/datasets/huggan/wikiart) dataset",
                     article='Distilled version of a CLOOB-conditioned latent diffusion model ([model card](https://huggingface.co/huggan/distill-ccld-wa))')
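# Note: gr.inputs / gr.outputs and enable_queue come from an older Gradio API; recent
# releases use gr.Textbox / gr.Image and queue() instead.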

iface.launch(enable_queue=True)  # add debug=True for Colab debugging