SettW commited on
Commit
10307dd
·
verified ·
1 Parent(s): 967e6ad

Create loosecontrol.py

Browse files
Files changed (1) hide show
  1. loosecontrol.py +135 -0
loosecontrol.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import (
2
+ ControlNetModel,
3
+ StableDiffusionControlNetPipeline,
4
+ UniPCMultistepScheduler,
5
+ )
6
+ import torch
7
+ import PIL
8
+ import PIL.Image
9
+ from diffusers.loaders import UNet2DConditionLoadersMixin
10
+ from typing import Dict
11
+ from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
12
+ import functools
13
+ from cross_frame_attention import CrossFrameAttnProcessor
14
+
15
+ TEXT_ENCODER_NAME = "text_encoder"
16
+ UNET_NAME = "unet"
17
+ NEGATIVE_PROMPT = "blurry, text, caption, lowquality, lowresolution, low res, grainy, ugly"
18
+
19
+ def attach_loaders_mixin(model):
20
+ # hacky way to make ControlNet work with LoRA. This may not be required in future versions of diffusers.
21
+ model.text_encoder_name = TEXT_ENCODER_NAME
22
+ model.unet_name = UNET_NAME
23
+ r"""
24
+ Attach the [`UNet2DConditionLoadersMixin`] to a model. This will add the
25
+ all the methods from the mixin 'UNet2DConditionLoadersMixin' to the model.
26
+ """
27
+ # mixin_instance = UNet2DConditionLoadersMixin()
28
+ for attr_name, attr_value in vars(UNet2DConditionLoadersMixin).items():
29
+ # print(attr_name)
30
+ if callable(attr_value):
31
+ # setattr(model, attr_name, functools.partialmethod(attr_value, model).__get__(model, model.__class__))
32
+ setattr(model, attr_name, functools.partial(attr_value, model))
33
+ return model
34
+
35
+ def set_attn_processor(module, processor, _remove_lora=False):
36
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
37
+ if hasattr(module, "set_processor"):
38
+ if not isinstance(processor, dict):
39
+ module.set_processor(processor, _remove_lora=_remove_lora)
40
+ else:
41
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
42
+
43
+ for sub_name, child in module.named_children():
44
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
45
+
46
+ for name, module in module.named_children():
47
+ fn_recursive_attn_processor(name, module, processor)
48
+
49
+
50
+
51
+ class ControlNetX(ControlNetModel, UNet2DConditionLoadersMixin):
52
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
53
+ # This may not be required in future versions of diffusers.
54
+ @property
55
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
56
+ r"""
57
+ Returns:
58
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
59
+ indexed by its weight name.
60
+ """
61
+ # set recursively
62
+ processors = {}
63
+
64
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
65
+ if hasattr(module, "get_processor"):
66
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
67
+
68
+ for sub_name, child in module.named_children():
69
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
70
+
71
+ return processors
72
+
73
+ for name, module in self.named_children():
74
+ fn_recursive_add_processors(name, module, processors)
75
+
76
+ return processors
77
+
78
+ class ControlNetPipeline:
79
+ def __init__(self, checkpoint="lllyasviel/control_v11f1p_sd15_depth", sd_checkpoint="runwayml/stable-diffusion-v1-5") -> None:
80
+ controlnet = ControlNetX.from_pretrained(checkpoint)
81
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
82
+ sd_checkpoint, controlnet=controlnet, requires_safety_checker=False, safety_checker=None,
83
+ torch_dtype=torch.float16)
84
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
85
+
86
+ @torch.no_grad()
87
+ def __call__(self,
88
+ prompt: str="",
89
+ height=512,
90
+ width=512,
91
+ control_image=None,
92
+ controlnet_conditioning_scale=1.0,
93
+ num_inference_steps: int=20,
94
+ **kwargs) -> PIL.Image.Image:
95
+
96
+ out = self.pipe(prompt, control_image,
97
+ height=height, width=width,
98
+ num_inference_steps=num_inference_steps,
99
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
100
+ **kwargs).images
101
+
102
+ return out[0] if len(out) == 1 else out
103
+
104
+ def to(self, *args, **kwargs):
105
+ self.pipe.to(*args, **kwargs)
106
+ return self
107
+
108
+
109
+ class LooseControlNet(ControlNetPipeline):
110
+ def __init__(self, loose_control_weights="shariqfarooq/loose-control-3dbox", cn_checkpoint="lllyasviel/control_v11f1p_sd15_depth", sd_checkpoint="runwayml/stable-diffusion-v1-5") -> None:
111
+ super().__init__(cn_checkpoint, sd_checkpoint)
112
+ self.pipe.controlnet = attach_loaders_mixin(self.pipe.controlnet)
113
+ self.pipe.controlnet.load_attn_procs(loose_control_weights)
114
+
115
+ def set_normal_attention(self):
116
+ self.pipe.unet.set_attn_processor(AttnProcessor())
117
+
118
+ def set_cf_attention(self, _remove_lora=False):
119
+ for upblocks in self.pipe.unet.up_blocks[-2:]:
120
+ set_attn_processor(upblocks, CrossFrameAttnProcessor(), _remove_lora=_remove_lora)
121
+
122
+ def edit(self, depth, depth_edit, prompt, prompt_edit=None, seed=42, seed_edit=None, negative_prompt=NEGATIVE_PROMPT, controlnet_conditioning_scale=1.0, num_inference_steps=20, **kwargs):
123
+ if prompt_edit is None:
124
+ prompt_edit = prompt
125
+
126
+ if seed_edit is None:
127
+ seed_edit = seed
128
+
129
+ seed = int(seed)
130
+ seed_edit = int(seed_edit)
131
+ control_image = [depth, depth_edit]
132
+ prompt = [prompt, prompt_edit]
133
+ generator = [torch.Generator().manual_seed(seed), torch.Generator().manual_seed(seed_edit)]
134
+ gen = self.pipe(prompt, control_image=control_image, controlnet_conditioning_scale=controlnet_conditioning_scale, generator=generator, num_inference_steps=num_inference_steps, negative_prompt=negative_prompt, **kwargs)[-1]
135
+ return gen