fffiloni committed
Commit ca441ab · verified · 1 Parent(s): 052f125

Create gradio_app.py

Files changed (1)
  1. gradio_app.py +316 -0
gradio_app.py ADDED
@@ -0,0 +1,316 @@
+ import yaml
+ import tempfile
+ import gradio as gr
+ import os
+ import torch
+ import imageio
+ import argparse
+ from types import MethodType
+ import safetensors.torch as sf
+ import torch.nn.functional as F
+ from omegaconf import OmegaConf
+ from transformers import CLIPTextModel, CLIPTokenizer
+ from diffusers import MotionAdapter, EulerAncestralDiscreteScheduler, AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
+ from diffusers.models.attention_processor import AttnProcessor2_0
+ from torch.hub import download_url_to_file
+
+ from src.ic_light import BGSource
+ from src.animatediff_pipe import AnimateDiffVideoToVideoPipeline
+ from src.ic_light_pipe import StableDiffusionImg2ImgPipeline
+ # set_all_seed is called in main() below; it is assumed to live in utils.tools
+ from utils.tools import read_video, set_all_seed
+
+ from huggingface_hub import snapshot_download, hf_hub_download
+
+ hf_hub_download(
+     repo_id='lllyasviel/ic-light',
+     filename='iclight_sd15_fc.safetensors',
+     local_dir='./models'
+ )
+
+ snapshot_download(
+     repo_id="stablediffusionapi/realistic-vision-v51",
+     local_dir="./models/stablediffusionapi/realistic-vision-v51"
+ )
+
+ snapshot_download(
+     repo_id="guoyww/animatediff-motion-adapter-v1-5-3",
+     local_dir="./models/guoyww/animatediff-motion-adapter-v1-5-3"
+ )
+
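+ # The downloads above populate ./models so that the hard-coded paths in the Args
+ # class inside infer() (defined further down) resolve without extra configuration.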
+ def main(args):
+
+     config = OmegaConf.load(args.config)
+     device = torch.device('cuda')
+     adopted_dtype = torch.float16
+     set_all_seed(42)
+
+     ## vdm model
+     adapter = MotionAdapter.from_pretrained(args.motion_adapter_model)
+
+     ## pipeline
+     pipe = AnimateDiffVideoToVideoPipeline.from_pretrained(args.sd_model, motion_adapter=adapter)
+     eul_scheduler = EulerAncestralDiscreteScheduler.from_pretrained(
+         args.sd_model,
+         subfolder="scheduler",
+         beta_schedule="linear",
+     )
+
+     pipe.scheduler = eul_scheduler
+     pipe.enable_vae_slicing()
+     pipe = pipe.to(device=device, dtype=adopted_dtype)
+     pipe.vae.requires_grad_(False)
+     pipe.unet.requires_grad_(False)
+
+     ## ic-light model
+     tokenizer = CLIPTokenizer.from_pretrained(args.sd_model, subfolder="tokenizer")
+     text_encoder = CLIPTextModel.from_pretrained(args.sd_model, subfolder="text_encoder")
+     vae = AutoencoderKL.from_pretrained(args.sd_model, subfolder="vae")
+     unet = UNet2DConditionModel.from_pretrained(args.sd_model, subfolder="unet")
+     with torch.no_grad():
+         new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
+         new_conv_in.weight.zero_()  # torch.Size([320, 8, 3, 3])
+         new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
+         new_conv_in.bias = unet.conv_in.bias
+         unet.conv_in = new_conv_in
+     unet_original_forward = unet.forward
+
+     def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
+
+         c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample)
+         c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0)
+         new_sample = torch.cat([sample, c_concat], dim=1)
+         kwargs['cross_attention_kwargs'] = {}
+         return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)
+     unet.forward = hooked_unet_forward
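+     # NOTE: IC-Light (fc) expects an 8-channel latent input: the usual 4 noisy-latent
+     # channels plus 4 channels of encoded conditioning image. The widened conv_in above
+     # and the forward hook (which concatenates `concat_conds` along the channel dim and
+     # repeats it to match the batch) adapt the stock SD1.5 UNet to that interface.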
+
+     ## ic-light model loader
+     if not os.path.exists(args.ic_light_model):
+         download_url_to_file(url='https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fc.safetensors',
+                              dst=args.ic_light_model)
+
+     sd_offset = sf.load_file(args.ic_light_model)
+     sd_origin = unet.state_dict()
+     sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()}
+     unet.load_state_dict(sd_merged, strict=True)
+     del sd_offset, sd_origin, sd_merged
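+     # The IC-Light checkpoint is distributed as a weight offset: each tensor is added
+     # on top of the corresponding base SD1.5 UNet weight rather than replacing it.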
+     text_encoder = text_encoder.to(device=device, dtype=adopted_dtype)
+     vae = vae.to(device=device, dtype=adopted_dtype)
+     unet = unet.to(device=device, dtype=adopted_dtype)
+     unet.set_attn_processor(AttnProcessor2_0())
+     vae.set_attn_processor(AttnProcessor2_0())
+
+     # Consistent light attention
+     @torch.inference_mode()
+     def custom_forward_CLA(self,
+                            hidden_states,
+                            gamma=config.get("gamma", 0.5),
+                            encoder_hidden_states=None,
+                            attention_mask=None,
+                            cross_attention_kwargs=None
+                            ):
+
+         batch_size, sequence_length, channel = hidden_states.shape
+
+         residual = hidden_states
+         input_ndim = hidden_states.ndim
+         if input_ndim == 4:
+             batch_size, channel, height, width = hidden_states.shape
+             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+         if attention_mask is not None:
+             # pad/repeat the mask against the (pre-projection) sequence length
+             if attention_mask.shape[-1] != hidden_states.shape[1]:
+                 target_length = hidden_states.shape[1]
+                 attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+                 attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+         if self.group_norm is not None:
+             hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+
+         query = self.to_q(hidden_states)
+         key = self.to_k(encoder_hidden_states)
+         value = self.to_v(encoder_hidden_states)
+         inner_dim = key.shape[-1]
+         head_dim = inner_dim // self.heads
+         query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+         key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+         value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+
+         hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
+         shape = query.shape
+
+         # addition key and value
+         mean_key = key.reshape(2, -1, shape[1], shape[2], shape[3]).mean(dim=1, keepdim=True)
+         mean_value = value.reshape(2, -1, shape[1], shape[2], shape[3]).mean(dim=1, keepdim=True)
+         mean_key = mean_key.expand(-1, shape[0] // 2, -1, -1, -1).reshape(shape[0], shape[1], shape[2], shape[3])
+         mean_value = mean_value.expand(-1, shape[0] // 2, -1, -1, -1).reshape(shape[0], shape[1], shape[2], shape[3])
+         add_hidden_state = F.scaled_dot_product_attention(query, mean_key, mean_value, attn_mask=None, dropout_p=0.0, is_causal=False)
+
+         # mix
+         hidden_states = (1 - gamma) * hidden_states + gamma * add_hidden_state
+
+         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
+         hidden_states = hidden_states.to(query.dtype)
+         hidden_states = self.to_out[0](hidden_states)
+         hidden_states = self.to_out[1](hidden_states)
+
+         if input_ndim == 4:
+             hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+         if self.residual_connection:
+             hidden_states = hidden_states + residual
+
+         hidden_states = hidden_states / self.rescale_output_factor
+         return hidden_states
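+     # The function above implements Consistent Light Attention (CLA): keys/values are
+     # reshaped to (2, num_frames, heads, tokens, head_dim) -- the leading 2 presumably
+     # being the two classifier-free-guidance halves -- and averaged over the frame axis.
+     # Attention against these frame-averaged K/V is then blended with ordinary per-frame
+     # self-attention via `gamma`, encouraging lighting that stays consistent across frames.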
+
+     ### attention
+     @torch.inference_mode()
+     def prep_unet_self_attention(unet):
+         for name, module in unet.named_modules():
+             module_name = type(module).__name__
+
+             name_split_list = name.split(".")
+             cond_1 = name_split_list[0] in "up_blocks"
+             cond_2 = name_split_list[-1] in ('attn1')
+
+             if "Attention" in module_name and cond_1 and cond_2:
+                 cond_3 = name_split_list[1]
+                 if cond_3 not in "3":
+                     module.forward = MethodType(custom_forward_CLA, module)
+
+         return unet
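+     # Only the self-attention (attn1) modules of up_blocks.0-2 are patched with
+     # custom_forward_CLA; up_blocks.3, the final highest-resolution decoder block, keeps
+     # the default attention. (Note the `in`-based checks above are substring tests,
+     # which happens to work for these module names.)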
+
+     ## consistency light attention
+     unet = prep_unet_self_attention(unet)
+
+     ## ic-light-scheduler
+     ic_light_scheduler = DPMSolverMultistepScheduler(
+         num_train_timesteps=1000,
+         beta_start=0.00085,
+         beta_end=0.012,
+         algorithm_type="sde-dpmsolver++",
+         use_karras_sigmas=True,
+         steps_offset=1
+     )
+     ic_light_pipe = StableDiffusionImg2ImgPipeline(
+         vae=vae,
+         text_encoder=text_encoder,
+         tokenizer=tokenizer,
+         unet=unet,
+         scheduler=ic_light_scheduler,
+         safety_checker=None,
+         requires_safety_checker=False,
+         feature_extractor=None,
+         image_encoder=None
+     )
+     ic_light_pipe = ic_light_pipe.to(device)
+
+     ############################# params ######################################
+     strength = config.get("strength", 0.5)
+     num_step = config.get("num_step", 25)
+     text_guide_scale = config.get("text_guide_scale", 2)
+     seed = config.get("seed")
+     image_width = config.get("width", 512)
+     image_height = config.get("height", 512)
+     n_prompt = config.get("n_prompt", "")
+     relight_prompt = config.get("relight_prompt", "")
+     video_path = config.get("video_path", "")
+     bg_source = BGSource[config.get("bg_source")]
+     save_path = config.get("save_path")
+
+     ############################## infer #####################################
+     generator = torch.manual_seed(seed)
+     video_name = os.path.basename(video_path)
+     video_list, video_name = read_video(video_path, image_width, image_height)
+
+     print("################## begin ##################")
+     with torch.no_grad():
+         num_inference_steps = int(round(num_step / strength))
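+         # With the defaults num_step=25 and strength=0.5 this schedules round(25 / 0.5) = 50
+         # steps; the video-to-video strength then skips the early part of the schedule, so
+         # roughly num_step denoising steps are actually applied to the noised input video.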
+
+         output = pipe(
+             ic_light_pipe=ic_light_pipe,
+             relight_prompt=relight_prompt,
+             bg_source=bg_source,
+             video=video_list,
+             prompt=relight_prompt,
+             strength=strength,
+             negative_prompt=n_prompt,
+             guidance_scale=text_guide_scale,
+             num_inference_steps=num_inference_steps,
+             height=image_height,
+             width=image_width,
+             generator=generator,
+         )
+
+     frames = output.frames[0]
+     results_path = f"{save_path}/relight_{video_name}"
+     imageio.mimwrite(results_path, frames, fps=8)
+     print(f"relight with bg generation! prompt:{relight_prompt}, light:{bg_source.value}, save in {results_path}.")
+
+ def infer(n_prompt, relight_prompt, video_path, bg_source, save_path,
+           width, height, strength, gamma, num_step, text_guide_scale, seed):
+
+     config_data = {
+         "n_prompt": n_prompt,
+         "relight_prompt": relight_prompt,
+         "video_path": video_path,
+         "bg_source": bg_source,
+         "save_path": save_path,
+         "width": int(width),    # gr.Number returns floats; the pipeline expects ints
+         "height": int(height),
+         "strength": strength,
+         "gamma": gamma,
+         "num_step": int(num_step),
+         "text_guide_scale": text_guide_scale,
+         "seed": int(seed)       # torch.manual_seed requires an int
+     }
+
+     temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".yaml")
+     with open(temp_file.name, 'w') as file:
+         yaml.dump(config_data, file, default_flow_style=False)
+
+     config_path = temp_file.name
+
+     class Args:
+         def __init__(self):
+             self.sd_model = "./models/stablediffusionapi/realistic-vision-v51"
+             self.motion_adapter_model = "./models/guoyww/animatediff-motion-adapter-v1-5-3"
+             self.ic_light_model = "./models/iclight_sd15_fc.safetensors"
+             self.config = config_path
+
+     args = Args()
+     main(args)
+
+     video_name = os.path.basename(video_path)
+     results_path = f"{save_path}/relight_{video_name}"
+     os.remove(config_path)
+
+     return results_path
+
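+ # infer() round-trips the UI values through a temporary YAML config and a hard-coded
+ # Args object so that main() can be reused unchanged; it returns the path main() writes
+ # the relit video to.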
+ with gr.Blocks() as demo:
+     with gr.Row():
+         n_prompt = gr.Textbox(label="Negative Prompt")
+         relight_prompt = gr.Textbox(label="Relight Prompt")
+     with gr.Row():
+         video_path = gr.Textbox(label="Video Path")
+         # choices must match the BGSource enum member names used by main()
+         bg_source = gr.Dropdown(["NONE", "LEFT", "RIGHT", "BOTTOM", "TOP"], label="Background Source")
+     with gr.Row():
+         save_path = gr.Textbox(label="Save Path")
+         width = gr.Number(label="Width", value=512)
+         height = gr.Number(label="Height", value=512)
+     with gr.Row():
+         strength = gr.Slider(minimum=0.0, maximum=1.0, label="Strength", value=0.5)
+         gamma = gr.Slider(minimum=0.0, maximum=1.0, label="Gamma", value=0.5)
+     with gr.Row():
+         num_step = gr.Number(label="Number of Steps", value=25)
+         text_guide_scale = gr.Number(label="Text Guide Scale", value=2)
+         seed = gr.Number(label="Seed", value=2060)
+
+     output = gr.Textbox(label="Results Path")
+     submit = gr.Button("Run")
+     submit.click(infer, inputs=[n_prompt, relight_prompt, video_path, bg_source, save_path,
+                                 width, height, strength, gamma, num_step, text_guide_scale, seed],
+                  outputs=output)
+
+ demo.launch()