Spaces: Running on Zero

Upload folder using huggingface_hub

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- .gitattributes +4 -35
- .gitignore +15 -0
- .gradio/certificate.pem +31 -0
- .vscode/settings.json +13 -0
- Compiler.py +106 -0
- HomeImage.png +3 -0
- LICENSE +674 -0
- README.md +139 -7
- _internal/ESRGAN/put_esrgan_and_other_upscale_models_here +0 -0
- _internal/checkpoints/put_checkpoints_here +0 -0
- _internal/clip/sd1_clip_config.json +25 -0
- _internal/embeddings/put_embeddings_or_textual_inversion_concepts_here +0 -0
- _internal/loras/put_loras_here +0 -0
- _internal/sd1_tokenizer/special_tokens_map.json +24 -0
- _internal/sd1_tokenizer/tokenizer_config.json +34 -0
- _internal/sd1_tokenizer/vocab.json +0 -0
- _internal/yolos/put_yolo_and_seg_files_here +0 -0
- app.py +171 -0
- modules/Attention/Attention.py +191 -0
- modules/Attention/AttentionMethods.py +120 -0
- modules/AutoDetailer/AD_util.py +245 -0
- modules/AutoDetailer/ADetailer.py +952 -0
- modules/AutoDetailer/SAM.py +300 -0
- modules/AutoDetailer/SEGS.py +95 -0
- modules/AutoDetailer/bbox.py +203 -0
- modules/AutoDetailer/mask_util.py +80 -0
- modules/AutoDetailer/tensor_util.py +253 -0
- modules/AutoEncoders/ResBlock.py +406 -0
- modules/AutoEncoders/VariationalAE.py +824 -0
- modules/AutoEncoders/taesd.py +310 -0
- modules/BlackForest/Flux.py +853 -0
- modules/Device/Device.py +1602 -0
- modules/FileManaging/Downloader.py +116 -0
- modules/FileManaging/ImageSaver.py +126 -0
- modules/FileManaging/Loader.py +138 -0
- modules/Model/LoRas.py +193 -0
- modules/Model/ModelBase.py +363 -0
- modules/Model/ModelPatcher.py +779 -0
- modules/NeuralNetwork/transformer.py +443 -0
- modules/NeuralNetwork/unet.py +1132 -0
- modules/Quantize/Quantizer.py +1012 -0
- modules/SD15/SD15.py +81 -0
- modules/SD15/SDClip.py +403 -0
- modules/SD15/SDToken.py +450 -0
- modules/StableFast/StableFast.py +274 -0
- modules/UltimateSDUpscale/RDRB.py +471 -0
- modules/UltimateSDUpscale/USDU_upscaler.py +182 -0
- modules/UltimateSDUpscale/USDU_util.py +173 -0
- modules/UltimateSDUpscale/UltimateSDUpscale.py +1019 -0
- modules/UltimateSDUpscale/image_util.py +265 -0
.gitattributes CHANGED
@@ -1,35 +1,4 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+# Auto detect text files and perform LF normalization
+* text=auto
+HomeImage.png filter=lfs diff=lfs merge=lfs -text
+stable_fast-1.0.5+torch222cu121-cp310-cp310-manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,15 @@

*.pyc
*.pth
*.pt
*.safetensors
*.gguf
*.png
/.idea
/htmlcov
.coverage
.toml
__pycache__
.venv
!HomeImage.png
*.txt
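One detail worth noting in the .gitignore above: `*.png` ignores every PNG file, while the later `!HomeImage.png` rule re-includes the home image that the README embeds. A small sketch (not part of the repository, and assuming the third-party `pathspec` package) that checks this gitignore-style matching:

```python
# Illustrative sketch only: evaluates gitignore-style patterns with "pathspec".
import pathspec

ignore_lines = ["*.png", "!HomeImage.png"]
spec = pathspec.PathSpec.from_lines("gitwildmatch", ignore_lines)

print(spec.match_file("HomeImage.png"))    # False: re-included by "!HomeImage.png"
print(spec.match_file("output/test.png"))  # True: ignored by "*.png"
```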
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
-----BEGIN CERTIFICATE-----
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
-----END CERTIFICATE-----
.vscode/settings.json ADDED
@@ -0,0 +1,13 @@
{
    "python.testing.unittestArgs": [
        "-v",
        "-s",
        ".",
        "-p",
        "*test.py"
    ],
    "python.testing.pytestEnabled": false,
    "python.testing.unittestEnabled": true,
    "python.analysis.autoImportCompletions": true,
    "python.analysis.typeCheckingMode": "off"
}
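For reference, the unittest settings above correspond to running test discovery from the repository root with the pattern `*test.py`, so only files whose names end in `test.py` are collected. A minimal, hypothetical test module that this configuration would pick up:

```python
# Hypothetical file: pipeline_test.py (the name ends in "test.py",
# so the discovery pattern configured above would collect it).
import unittest


class SmokeTest(unittest.TestCase):
    def test_discovery_pattern(self):
        # Trivial assertion, just to show a test the configured runner would find.
        self.assertTrue("pipeline_test.py".endswith("test.py"))


if __name__ == "__main__":
    unittest.main()
```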
Compiler.py ADDED
@@ -0,0 +1,106 @@
import os
import re

files_ordered = [
    "./modules/Utilities/util.py",
    "./modules/sample/sampling_util.py",
    "./modules/Device/Device.py",
    "./modules/cond/cond_util.py",
    "./modules/cond/cond.py",
    "./modules/sample/ksampler_util.py",
    "./modules/cond/cast.py",
    "./modules/Attention/AttentionMethods.py",
    "./modules/AutoEncoders/taesd.py",
    "./modules/cond/cond.py",
    "./modules/cond/Activation.py",
    "./modules/Attention/Attention.py",
    "./modules/sample/samplers.py",
    "./modules/sample/CFG.py",
    "./modules/NeuralNetwork/transformer.py",
    "./modules/sample/sampling.py",
    "./modules/clip/CLIPTextModel.py",
    "./modules/AutoEncoders/ResBlock.py",
    "./modules/AutoDetailer/mask_util.py",
    "./modules/NeuralNetwork/unet.py",
    "./modules/SD15/SDClip.py",
    "./modules/SD15/SDToken.py",
    "./modules/UltimateSDUpscale/USDU_util.py",
    "./modules/StableFast/SF_util.py",
    "./modules/Utilities/Latent.py",
    "./modules/AutoDetailer/SEGS.py",
    "./modules/AutoDetailer/tensor_util.py",
    "./modules/AutoDetailer/AD_util.py",
    "./modules/clip/FluxClip.py",
    "./modules/Model/ModelPatcher.py",
    "./modules/Model/ModelBase.py",
    "./modules/UltimateSDUpscale/image_util.py",
    "./modules/UltimateSDUpscale/RDRB.py",
    "./modules/StableFast/ModuleFactory.py",
    "./modules/AutoDetailer/bbox.py",
    "./modules/AutoEncoders/VariationalAE.py",
    "./modules/clip/Clip.py",
    "./modules/Model/LoRas.py",
    "./modules/BlackForest/Flux.py",
    "./modules/UltimateSDUpscale/USDU_upscaler.py",
    "./modules/StableFast/ModuleTracing.py",
    "./modules/hidiffusion/utils.py",
    "./modules/FileManaging/Downloader.py",
    "./modules/AutoDetailer/SAM.py",
    "./modules/AutoDetailer/ADetailer.py",
    "./modules/Quantize/Quantizer.py",
    "./modules/FileManaging/Loader.py",
    "./modules/SD15/SD15.py",
    "./modules/UltimateSDUpscale/UltimateSDUpscale.py",
    "./modules/StableFast/StableFast.py",
    "./modules/hidiffusion/msw_msa_attention.py",
    "./modules/FileManaging/ImageSaver.py",
    "./modules/Utilities/Enhancer.py",
    "./modules/Utilities/upscale.py",
    "./modules/user/pipeline.py",
]


def get_file_patterns():
    patterns = []
    seen = set()
    for path in files_ordered:
        filename = os.path.basename(path)
        name = os.path.splitext(filename)[0]
        if name not in seen:
            # Pattern 1: matches module name when not in brackets or after a dot
            pattern1 = rf'(?<![a-zA-Z0-9_\.])({name}\.)(?![)\]])'
            # Pattern 2: matches module name inside brackets while preserving them
            pattern2 = rf'(\[|\()({name}\.)([^\]\)]+?)(\]|\))'
            pattern3 = r'cond_util\.'
            patterns.extend([
                (pattern1, ''),  # Remove module name and dot outside brackets
                (pattern2, r'\1\3\4'),  # Keep brackets, remove only module name
                (pattern3, '')
            ])
            seen.add(name)
    return patterns


def remove_file_names(line):
    patterns = get_file_patterns()
    result = line
    for pattern, replacement in patterns:
        result = re.sub(pattern, replacement, result)
    return result


try:
    with open("./compiled.py", "w") as output_file:
        for file_path in files_ordered:
            try:
                with open(file_path, "r") as input_file:
                    for line in input_file:
                        if not line.lstrip().startswith("from modules."):
                            # Apply the file name removal before writing
                            modified_line = remove_file_names(line)
                            output_file.write(modified_line)
                output_file.write("\n\n")
                print(f"Processed: {file_path}")
            except FileNotFoundError:
                print(f"Error: Could not find file {file_path}")
            except Exception as e:
                print(f"Error processing {file_path}: {str(e)}")
except Exception as e:
    print(f"Error creating compiled.py: {str(e)}")
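To make the rewriting above concrete, here is a small, self-contained sketch (not part of the repository) that applies the same two substitution patterns built in `get_file_patterns`, specialised to the module name `Device`; the input line is invented for illustration:

```python
import re

# The two patterns from get_file_patterns(), with name fixed to "Device".
name = "Device"
pattern1 = rf'(?<![a-zA-Z0-9_\.])({name}\.)(?![)\]])'  # module prefix outside brackets
pattern2 = rf'(\[|\()({name}\.)([^\]\)]+?)(\]|\))'      # module prefix just inside (...) or [...]

line = "model = Device.load(Device.get_torch_device())"
line = re.sub(pattern1, '', line)          # drops both "Device." prefixes here
line = re.sub(pattern2, r'\1\3\4', line)   # handles prefixes right inside brackets (a no-op for this line)
print(line)  # model = load(get_torch_device())
```

Once every module is concatenated into `compiled.py`, qualified references such as `Device.get_torch_device()` resolve against the flattened namespace, which is presumably why the prefixes are stripped.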
HomeImage.png ADDED
[Binary image file, stored with Git LFS.]
LICENSE ADDED
@@ -0,0 +1,674 @@
GNU GENERAL PUBLIC LICENSE
Version 3, 29 June 2007

Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.

[Full, unmodified text of the GNU General Public License, version 3 (674 lines); see <https://www.gnu.org/licenses/gpl-3.0.txt>.]
README.md
CHANGED
@@ -1,12 +1,144 @@
|
|
1 |
---
|
2 |
-
title: LightDiffusion
|
3 |
-
|
4 |
-
colorFrom: gray
|
5 |
-
colorTo: blue
|
6 |
sdk: gradio
|
7 |
sdk_version: 5.14.0
|
8 |
-
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
---
|
11 |
|
12 |
-
|
|
|
1 |
---
|
2 |
+
title: LightDiffusion-Next
|
3 |
+
app_file: app.py
|
|
|
|
|
4 |
sdk: gradio
|
5 |
sdk_version: 5.14.0
|
6 |
+
---
|
<div align="center">

# Say hi to LightDiffusion-Next 👋

**LightDiffusion-Next** is the fastest AI-powered image generation GUI/CLI, combining speed, precision, and flexibility in one cohesive tool.
</br>
</br>
<a href="https://github.com/LightDiffusion/LightDiffusion-Next">
  <img src="./HomeImage.png" alt="Logo">
</a>
</br>
</div>

As a refactored and improved version of the original [LightDiffusion repository](https://github.com/Aatrick/LightDiffusion), this project enhances usability, maintainability, and functionality while introducing a host of new features to streamline your creative workflows.

## Motivation

**LightDiffusion** was originally meant to be written in Rust, but due to the lack of Rust support in the AI community it was written in Python instead, with the goal of being the simplest and fastest AI image generation tool.

The first version of LightDiffusion was born as just [3,000 lines of code](https://github.com/LightDiffusion/LightDiffusion-original) built directly on PyTorch. Over time the [project](https://github.com/Aatrick/LightDiffusion) grew more complex and the need for a refactor became evident. This is where **LightDiffusion-Next** comes in, with a more modular and maintainable codebase and a plethora of new features and optimizations.

📚 Learn more in the [official documentation](https://aatrick.github.io/LightDiffusion/).

---

## 🌟 Highlights



**LightDiffusion-Next** offers a powerful suite of tools for creators at every level. At its core it supports **Text-to-Image** (Txt2Img) and **Image-to-Image** (Img2Img) generation, with a variety of upscale methods and samplers that make it easy to create stunning images with minimal effort.

Advanced users can take advantage of features like **attention syntax**, **Hires-Fix**, and **ADetailer**. These tools provide better quality and more flexibility for generating complex, high-resolution outputs.
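As a quick illustration of how these options combine, the sketch below calls the `pipeline` entry point the same way `app.py` does. The `(term:weight)` emphasis grammar in the prompt is an assumption (it follows the convention common to Stable Diffusion front-ends); the keyword arguments mirror the call in `app.py`.

```python
# Minimal sketch: weight prompt terms (grammar assumed), then let Hires-Fix
# and ADetailer refine the result. Mirrors the pipeline call used by app.py.
from modules.user.pipeline import pipeline

pipeline(
    prompt="portrait of an astronaut, (dramatic rim lighting:1.3), (blurry:0.7)",
    w=768,
    h=768,
    number=1,
    batch=1,
    hires_fix=True,   # upscale and re-denoise for extra detail
    adetailer=True,   # run the automatic face/body detailing pass
)
```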
**LightDiffusion-Next** is fine-tuned for **performance**. Features such as **Xformers** acceleration, **BFloat16** precision support, **WaveSpeed** dynamic caching, and **Stable-Fast** model compilation (which offers up to a 70% speed boost) ensure smooth and efficient operation, even on demanding workloads.

---

## ✨ Feature Showcase

Here's what makes LightDiffusion-Next stand out:

- **Speed and Efficiency**:
  Enjoy industry-leading performance with built-in Xformers, PyTorch, WaveSpeed and Stable-Fast optimizations, achieving up to 30% faster speeds than other AI image generation backends on SD1.5, and up to 2x faster on Flux.

- **Automatic Detailing**:
  Effortlessly enhance faces and body details with AI-driven tools based on the [Impact Pack](https://github.com/ltdrdata/ComfyUI-Impact-Pack).

- **State Preservation**:
  Save and resume your progress with saved states, ensuring seamless transitions between sessions.

- **Advanced GUI and CLI**:
  Work through a user-friendly graphical interface or leverage the streamlined pipeline for CLI-based workflows.

- **Integration-Ready**:
  Collaborate and create directly in Discord with [Boubou](https://github.com/Aatrick/Boubou), or preview images dynamically with the optional **TAESD preview mode**.

- **Image Previewing**:
  Get a real-time preview of your generated images with TAESD, allowing for user-friendly and interactive workflows.

- **Image Upscaling**:
  Enhance your images with advanced upscaling options like UltimateSDUpscaling, ensuring high-quality results every time.

- **Prompt Refinement**:
  Use the Ollama-powered automatic prompt enhancer to refine your prompts and generate more accurate and detailed outputs.

- **LoRa and Textual Inversion Embeddings**:
  Leverage LoRa and textual inversion embeddings for highly customized and nuanced results, adding a new dimension to your creative process.

- **Low-End Device Support**:
  Run LightDiffusion-Next on low-end devices with as little as 2GB of VRAM, or even without a GPU, ensuring accessibility for all users.

---

## ⚡ Performance Benchmarks

**LightDiffusion-Next** dominates in performance:

| **Tool**                            | **Speed (it/s)** |
|-------------------------------------|------------------|
| **LightDiffusion with Stable-Fast** | 2.8              |
| **LightDiffusion**                  | 1.8              |
| **ComfyUI**                         | 1.4              |
| **SDForge**                         | 1.3              |
| **SDWebUI**                         | 0.9              |

(All benchmarks use a 1024x1024 resolution with a batch size of 1 and BFloat16 precision, without tweaking the installations, on a 3060 mobile GPU running SD1.5.)

With its unmatched speed and efficiency, LightDiffusion-Next sets the benchmark for AI image generation tools.

---

## 🛠 Installation

### Quick Start

1. Download a release or clone this repository.
2. Run `run.bat` in a terminal.
3. Start creating!

### Command-Line Pipeline

For a GUI-free experience, use the pipeline:
```bash
pipeline.bat <prompt> <width> <height> <num_images> <batch_size>
```
Use `pipeline.bat -h` for more options.
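For example, to render four 768x768 images of a single prompt in two batches (the prompt and values here are purely illustrative of the argument order above):

```bash
pipeline.bat "a cozy cabin in a snowy forest, golden hour" 768 768 4 2
```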
---

### Advanced Setup

- **Install from Source**:
  Install dependencies via:
  ```bash
  pip install -r requirements.txt
  ```
  Add your SD1/1.5 safetensors model to the `checkpoints` directory, then launch the application.

- **⚡ Stable-Fast Optimization**:
  Follow [this guide](https://github.com/chengzeyi/stable-fast?tab=readme-ov-file#installation) to enable Stable-Fast mode for optimal performance.

- **🦙 Prompt Enhancer**:
  Refine your prompts with Ollama:
  ```bash
  pip install ollama
  ollama run deepseek-r1
  ```
  See the [Ollama guide](https://github.com/ollama/ollama?tab=readme-ov-file) for details.

- **🤖 Discord Integration**:
  Set up the Discord bot by following the [Boubou installation guide](https://github.com/Aatrick/Boubou).

---

🎨 Enjoy exploring the powerful features of LightDiffusion-Next!
_internal/ESRGAN/put_esrgan_and_other_upscale_models_here
ADDED
File without changes
|
_internal/checkpoints/put_checkpoints_here
ADDED
File without changes
|
_internal/clip/sd1_clip_config.json
ADDED
@@ -0,0 +1,25 @@
{
  "_name_or_path": "openai/clip-vit-large-patch14",
  "architectures": [
    "CLIPTextModel"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "dropout": 0.0,
  "eos_token_id": 2,
  "hidden_act": "quick_gelu",
  "hidden_size": 768,
  "initializer_factor": 1.0,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 77,
  "model_type": "clip_text_model",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "projection_dim": 768,
  "torch_dtype": "float32",
  "transformers_version": "4.24.0",
  "vocab_size": 49408
}
_internal/embeddings/put_embeddings_or_textual_inversion_concepts_here
ADDED
File without changes
|
_internal/loras/put_loras_here
ADDED
File without changes
|
_internal/sd1_tokenizer/special_tokens_map.json
ADDED
@@ -0,0 +1,24 @@
{
  "bos_token": {
    "content": "<|startoftext|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": "<|endoftext|>",
  "unk_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  }
}
_internal/sd1_tokenizer/tokenizer_config.json
ADDED
@@ -0,0 +1,34 @@
{
  "add_prefix_space": false,
  "bos_token": {
    "__type": "AddedToken",
    "content": "<|startoftext|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "do_lower_case": true,
  "eos_token": {
    "__type": "AddedToken",
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "errors": "replace",
  "model_max_length": 77,
  "name_or_path": "openai/clip-vit-large-patch14",
  "pad_token": "<|endoftext|>",
  "special_tokens_map_file": "./special_tokens_map.json",
  "tokenizer_class": "CLIPTokenizer",
  "unk_token": {
    "__type": "AddedToken",
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  }
}
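These files follow the Hugging Face tokenizer layout, so they can be sanity-checked with `transformers` directly. A minimal sketch, assuming `transformers` is installed and that the usual `merges.txt` sits alongside `vocab.json` (this truncated diff view does not show it):

```python
from transformers import CLIPTokenizer

# Load the SD1.x CLIP tokenizer shipped under _internal/ (path relative to the repo root).
tok = CLIPTokenizer.from_pretrained("_internal/sd1_tokenizer")
ids = tok("a photo of an astronaut riding a horse")["input_ids"]
print(len(ids), tok.model_max_length)  # token count, and the 77-token context limit
```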
_internal/sd1_tokenizer/vocab.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
_internal/yolos/put_yolo_and_seg_files_here
ADDED
File without changes
|
app.py
ADDED
@@ -0,0 +1,171 @@
import glob
import gradio as gr
import sys
import os
from PIL import Image
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))

from modules.user.pipeline import pipeline
import torch


def load_generated_images():
    """Load the most recent batch of generated images from disk."""
    image_files = glob.glob("./_internal/output/*")

    # If there are no image files, return an empty gallery
    if not image_files:
        return []

    # Sort files by modification time in descending order
    image_files.sort(key=os.path.getmtime, reverse=True)

    # Get the most recent timestamp
    latest_time = os.path.getmtime(image_files[0])

    # Get all images from the same batch (within 1 second of the most recent)
    batch_images = []
    for file in image_files:
        if abs(os.path.getmtime(file) - latest_time) < 1.0:
            try:
                img = Image.open(file)
                batch_images.append(img)
            except Exception:
                continue

    if not batch_images:
        return []
    return batch_images


def generate_images(
    prompt: str,
    width: int = 512,
    height: int = 512,
    num_images: int = 1,
    batch_size: int = 1,
    hires_fix: bool = False,
    adetailer: bool = False,
    enhance_prompt: bool = False,
    img2img_enabled: bool = False,
    img2img_image: str = None,
    stable_fast: bool = False,
    reuse_seed: bool = False,
    flux_enabled: bool = False,
    prio_speed: bool = False,
    progress=gr.Progress(),
):
    """Generate images using the LightDiffusion pipeline."""
    try:
        if img2img_enabled and img2img_image is not None:
            # Save the uploaded image temporarily and pass its path to the pipeline
            img2img_image.save("temp_img2img.png")
            prompt = "temp_img2img.png"

        # Run the pipeline; generated images are written to ./_internal/output/
        with torch.inference_mode():
            pipeline(
                prompt=prompt,
                w=width,
                h=height,
                number=num_images,
                batch=batch_size,
                hires_fix=hires_fix,
                adetailer=adetailer,
                enhance_prompt=enhance_prompt,
                img2img=img2img_enabled,
                stable_fast=stable_fast,
                reuse_seed=reuse_seed,
                flux_enabled=flux_enabled,
                prio_speed=prio_speed,
            )

        # Read the freshly written batch back from disk for the gallery
        return load_generated_images()

    except Exception:
        import traceback
        print(traceback.format_exc())
        return [Image.new("RGB", (512, 512), color="black")]


# Create the Gradio interface
with gr.Blocks(title="LightDiffusion Web UI") as demo:
    gr.Markdown("# LightDiffusion Web UI")
    gr.Markdown("Generate AI images using LightDiffusion")

    with gr.Row():
        with gr.Column():
            # Input components
            prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")

            with gr.Row():
                width = gr.Slider(minimum=64, maximum=2048, value=512, step=64, label="Width")
                height = gr.Slider(minimum=64, maximum=2048, value=512, step=64, label="Height")

            with gr.Row():
                num_images = gr.Slider(minimum=1, maximum=10, value=1, step=1, label="Number of Images")
                batch_size = gr.Slider(minimum=1, maximum=4, value=1, step=1, label="Batch Size")

            with gr.Row():
                hires_fix = gr.Checkbox(label="HiRes Fix")
                adetailer = gr.Checkbox(label="Auto Face/Body Enhancement")
                enhance_prompt = gr.Checkbox(label="Enhance Prompt")
                stable_fast = gr.Checkbox(label="Stable Fast Mode")

            with gr.Row():
                reuse_seed = gr.Checkbox(label="Reuse Seed")
                flux_enabled = gr.Checkbox(label="Flux Mode")
                prio_speed = gr.Checkbox(label="Prioritize Speed")

            with gr.Row():
                img2img_enabled = gr.Checkbox(label="Image to Image Mode")
                img2img_image = gr.Image(label="Input Image for img2img", visible=False)

            # Make the input image visible only when img2img is enabled
            img2img_enabled.change(
                fn=lambda x: gr.update(visible=x),
                inputs=[img2img_enabled],
                outputs=[img2img_image],
            )

            generate_btn = gr.Button("Generate")

        # Output gallery
        gallery = gr.Gallery(
            label="Generated Images",
            show_label=True,
            elem_id="gallery",
            columns=[2],
            rows=[2],
            object_fit="contain",
            height="auto",
        )

    # Connect the generate button to the pipeline
    generate_btn.click(
        fn=generate_images,
        inputs=[
            prompt,
            width,
            height,
            num_images,
            batch_size,
            hires_fix,
            adetailer,
            enhance_prompt,
            img2img_enabled,
            img2img_image,
            stable_fast,
            reuse_seed,
            flux_enabled,
            prio_speed,
        ],
        outputs=gallery,
    )

# For local testing
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=8000,
        auth=None,
        share=True,
        debug=True,
    )
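The same handler can also be driven without the web UI, which is handy for quick experiments. A minimal sketch (the prompt and sizes are illustrative; `generate_images` returns the PIL images of the latest batch, as defined above):

```python
# Sketch: reuse the Gradio handler directly from a Python shell.
from app import generate_images

images = generate_images(
    "a lighthouse at dusk, volumetric fog",
    width=512,
    height=512,
    num_images=2,
)
for i, img in enumerate(images):
    img.save(f"preview_{i}.png")
```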
|
modules/Attention/Attention.py
ADDED
@@ -0,0 +1,191 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import logging
|
4 |
+
|
5 |
+
from modules.Utilities import util
|
6 |
+
from modules.Attention import AttentionMethods
|
7 |
+
from modules.Device import Device
|
8 |
+
from modules.cond import cast
|
9 |
+
|
10 |
+
|
11 |
+
def Normalize(
|
12 |
+
in_channels: int, dtype: torch.dtype = None, device: torch.device = None
|
13 |
+
) -> torch.nn.GroupNorm:
|
14 |
+
"""#### Normalize the input channels.
|
15 |
+
|
16 |
+
#### Args:
|
17 |
+
- `in_channels` (int): The input channels.
|
18 |
+
- `dtype` (torch.dtype, optional): The data type. Defaults to `None`.
|
19 |
+
- `device` (torch.device, optional): The device. Defaults to `None`.
|
20 |
+
|
21 |
+
#### Returns:
|
22 |
+
- `torch.nn.GroupNorm`: The normalized input channels
|
23 |
+
"""
|
24 |
+
return torch.nn.GroupNorm(
|
25 |
+
num_groups=32,
|
26 |
+
num_channels=in_channels,
|
27 |
+
eps=1e-6,
|
28 |
+
affine=True,
|
29 |
+
dtype=dtype,
|
30 |
+
device=device,
|
31 |
+
)
|
32 |
+
|
33 |
+
|
34 |
+
if Device.xformers_enabled():
|
35 |
+
logging.info("Using xformers cross attention")
|
36 |
+
optimized_attention = AttentionMethods.attention_xformers
|
37 |
+
else:
|
38 |
+
logging.info("Using pytorch cross attention")
|
39 |
+
optimized_attention = AttentionMethods.attention_pytorch
|
40 |
+
|
41 |
+
optimized_attention_masked = optimized_attention
|
42 |
+
|
43 |
+
|
44 |
+
def optimized_attention_for_device() -> AttentionMethods.attention_pytorch:
|
45 |
+
"""#### Get the optimized attention for a device.
|
46 |
+
|
47 |
+
#### Returns:
|
48 |
+
- `function`: The optimized attention function.
|
49 |
+
"""
|
50 |
+
return AttentionMethods.attention_pytorch
|
51 |
+
|
52 |
+
|
53 |
+
class CrossAttention(nn.Module):
|
54 |
+
"""#### Cross attention module, which applies attention across the query and context.
|
55 |
+
|
56 |
+
#### Args:
|
57 |
+
- `query_dim` (int): The query dimension.
|
58 |
+
- `context_dim` (int, optional): The context dimension. Defaults to `None`.
|
59 |
+
- `heads` (int, optional): The number of heads. Defaults to `8`.
|
60 |
+
- `dim_head` (int, optional): The head dimension. Defaults to `64`.
|
61 |
+
- `dropout` (float, optional): The dropout rate. Defaults to `0.0`.
|
62 |
+
- `dtype` (torch.dtype, optional): The data type. Defaults to `None`.
|
63 |
+
- `device` (torch.device, optional): The device. Defaults to `None`.
|
64 |
+
- `operations` (cast.disable_weight_init, optional): The operations. Defaults to `cast.disable_weight_init`.
|
65 |
+
"""
|
66 |
+
|
67 |
+
def __init__(
|
68 |
+
self,
|
69 |
+
query_dim: int,
|
70 |
+
context_dim: int = None,
|
71 |
+
heads: int = 8,
|
72 |
+
dim_head: int = 64,
|
73 |
+
dropout: float = 0.0,
|
74 |
+
dtype: torch.dtype = None,
|
75 |
+
device: torch.device = None,
|
76 |
+
operations: cast.disable_weight_init = cast.disable_weight_init,
|
77 |
+
):
|
78 |
+
super().__init__()
|
79 |
+
inner_dim = dim_head * heads
|
80 |
+
context_dim = util.default(context_dim, query_dim)
|
81 |
+
|
82 |
+
self.heads = heads
|
83 |
+
self.dim_head = dim_head
|
84 |
+
|
85 |
+
self.to_q = operations.Linear(
|
86 |
+
query_dim, inner_dim, bias=False, dtype=dtype, device=device
|
87 |
+
)
|
88 |
+
self.to_k = operations.Linear(
|
89 |
+
context_dim, inner_dim, bias=False, dtype=dtype, device=device
|
90 |
+
)
|
91 |
+
self.to_v = operations.Linear(
|
92 |
+
context_dim, inner_dim, bias=False, dtype=dtype, device=device
|
93 |
+
)
|
94 |
+
|
95 |
+
self.to_out = nn.Sequential(
|
96 |
+
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
|
97 |
+
nn.Dropout(dropout),
|
98 |
+
)
|
99 |
+
|
100 |
+
def forward(
|
101 |
+
self,
|
102 |
+
x: torch.Tensor,
|
103 |
+
context: torch.Tensor = None,
|
104 |
+
value: torch.Tensor = None,
|
105 |
+
mask: torch.Tensor = None,
|
106 |
+
) -> torch.Tensor:
|
107 |
+
"""#### Forward pass of the cross attention module.
|
108 |
+
|
109 |
+
#### Args:
|
110 |
+
- `x` (torch.Tensor): The input tensor.
|
111 |
+
- `context` (torch.Tensor, optional): The context tensor. Defaults to `None`.
|
112 |
+
- `value` (torch.Tensor, optional): The value tensor. Defaults to `None`.
|
113 |
+
- `mask` (torch.Tensor, optional): The mask tensor. Defaults to `None`.
|
114 |
+
|
115 |
+
#### Returns:
|
116 |
+
- `torch.Tensor`: The output tensor.
|
117 |
+
"""
|
118 |
+
q = self.to_q(x)
|
119 |
+
context = util.default(context, x)
|
120 |
+
k = self.to_k(context)
|
121 |
+
v = self.to_v(context)
|
122 |
+
|
123 |
+
out = optimized_attention(q, k, v, self.heads)
|
124 |
+
return self.to_out(out)
|
125 |
+
|
126 |
+
|
127 |
+
class AttnBlock(nn.Module):
|
128 |
+
"""#### Attention block, which applies attention to the input tensor.
|
129 |
+
|
130 |
+
#### Args:
|
131 |
+
- `in_channels` (int): The input channels.
|
132 |
+
"""
|
133 |
+
|
134 |
+
def __init__(self, in_channels: int):
|
135 |
+
super().__init__()
|
136 |
+
self.in_channels = in_channels
|
137 |
+
|
138 |
+
self.norm = Normalize(in_channels)
|
139 |
+
self.q = cast.disable_weight_init.Conv2d(
|
140 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
141 |
+
)
|
142 |
+
self.k = cast.disable_weight_init.Conv2d(
|
143 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
144 |
+
)
|
145 |
+
self.v = cast.disable_weight_init.Conv2d(
|
146 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
147 |
+
)
|
148 |
+
self.proj_out = cast.disable_weight_init.Conv2d(
|
149 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
150 |
+
)
|
151 |
+
|
152 |
+
if Device.xformers_enabled_vae():
|
153 |
+
logging.info("Using xformers attention in VAE")
|
154 |
+
self.optimized_attention = AttentionMethods.xformers_attention
|
155 |
+
else:
|
156 |
+
logging.info("Using pytorch attention in VAE")
|
157 |
+
self.optimized_attention = AttentionMethods.pytorch_attention
|
158 |
+
|
159 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
160 |
+
"""#### Forward pass of the attention block.
|
161 |
+
|
162 |
+
#### Args:
|
163 |
+
- `x` (torch.Tensor): The input tensor.
|
164 |
+
|
165 |
+
#### Returns:
|
166 |
+
- `torch.Tensor`: The output tensor.
|
167 |
+
"""
|
168 |
+
h_ = x
|
169 |
+
h_ = self.norm(h_)
|
170 |
+
q = self.q(h_)
|
171 |
+
k = self.k(h_)
|
172 |
+
v = self.v(h_)
|
173 |
+
|
174 |
+
h_ = self.optimized_attention(q, k, v)
|
175 |
+
|
176 |
+
h_ = self.proj_out(h_)
|
177 |
+
|
178 |
+
return x + h_
|
179 |
+
|
180 |
+
|
181 |
+
def make_attn(in_channels: int, attn_type: str = "vanilla") -> AttnBlock:
|
182 |
+
"""#### Make an attention block.
|
183 |
+
|
184 |
+
#### Args:
|
185 |
+
- `in_channels` (int): The input channels.
|
186 |
+
- `attn_type` (str, optional): The attention type. Defaults to "vanilla".
|
187 |
+
|
188 |
+
#### Returns:
|
189 |
+
- `AttnBlock`: A class instance of the attention block.
|
190 |
+
"""
|
191 |
+
return AttnBlock(in_channels)
|
modules/Attention/AttentionMethods.py
ADDED
@@ -0,0 +1,120 @@
try:
    import xformers
except ImportError:
    pass
import torch


def attention_xformers(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask=None
) -> torch.Tensor:
    """#### Make an attention call using xformers. Fastest attention implementation.

    #### Args:
        - `q` (torch.Tensor): The query tensor.
        - `k` (torch.Tensor): The key tensor, must have the same shape as `q`.
        - `v` (torch.Tensor): The value tensor, must have the same shape as `q`.
        - `heads` (int): The number of heads, must be a divisor of the hidden dimension.
        - `mask` (torch.Tensor, optional): The mask tensor. Defaults to `None`.

    #### Returns:
        - `torch.Tensor`: The output tensor.
    """
    b, _, dim_head = q.shape
    dim_head //= heads

    q, k, v = map(
        lambda t: t.unsqueeze(3)
        .reshape(b, -1, heads, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b * heads, -1, dim_head)
        .contiguous(),
        (q, k, v),
    )

    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)

    out = (
        out.unsqueeze(0)
        .reshape(b, heads, -1, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b, -1, heads * dim_head)
    )
    return out


def attention_pytorch(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask=None
) -> torch.Tensor:
    """#### Make an attention call using PyTorch.

    #### Args:
        - `q` (torch.Tensor): The query tensor.
        - `k` (torch.Tensor): The key tensor, must have the same shape as `q`.
        - `v` (torch.Tensor): The value tensor, must have the same shape as `q`.
        - `heads` (int): The number of heads, must be a divisor of the hidden dimension.
        - `mask` (torch.Tensor, optional): The mask tensor. Defaults to `None`.

    #### Returns:
        - `torch.Tensor`: The output tensor.
    """
    b, _, dim_head = q.shape
    dim_head //= heads
    q, k, v = map(
        lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
        (q, k, v),
    )

    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
    )
    out = out.transpose(1, 2).reshape(b, -1, heads * dim_head)
    return out


def xformers_attention(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> torch.Tensor:
    """#### Compute attention using xformers.

    #### Args:
        - `q` (torch.Tensor): The query tensor.
        - `k` (torch.Tensor): The key tensor, must have the same shape as `q`.
        - `v` (torch.Tensor): The value tensor, must have the same shape as `q`.

    #### Returns:
        - `torch.Tensor`: The output tensor.
    """
    B, C, H, W = q.shape
    q, k, v = map(
        lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
        (q, k, v),
    )
    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
    out = out.transpose(1, 2).reshape(B, C, H, W)
    return out


def pytorch_attention(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> torch.Tensor:
    """#### Compute attention using PyTorch.

    #### Args:
        - `q` (torch.Tensor): The query tensor.
        - `k` (torch.Tensor): The key tensor, must have the same shape as `q`.
        - `v` (torch.Tensor): The value tensor, must have the same shape as `q`.

    #### Returns:
        - `torch.Tensor`: The output tensor.
    """
    B, C, H, W = q.shape
    q, k, v = map(
        lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
        (q, k, v),
    )
    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
    )
    out = out.transpose(2, 3).reshape(B, C, H, W)
    return out
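As a quick shape check, the sequence-style helpers expect tensors of shape `(batch, tokens, heads * dim_head)` and return the same layout. A minimal smoke test (tensor sizes are arbitrary):

```python
# Sketch: verify that attention_pytorch preserves the (batch, tokens, hidden) layout.
import torch
from modules.Attention.AttentionMethods import attention_pytorch

q = torch.randn(2, 16, 8 * 64)  # (batch, tokens, heads * dim_head)
k = torch.randn(2, 16, 8 * 64)
v = torch.randn(2, 16, 8 * 64)

out = attention_pytorch(q, k, v, heads=8)
print(out.shape)  # torch.Size([2, 16, 512]) -- same layout as the inputs
```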
modules/AutoDetailer/AD_util.py
ADDED
@@ -0,0 +1,245 @@
1 |
+
from typing import List
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from ultralytics import YOLO
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
orig_torch_load = torch.load
|
9 |
+
|
10 |
+
# importing YOLO breaking original torch.load capabilities
|
11 |
+
torch.load = orig_torch_load
|
12 |
+
|
13 |
+
|
14 |
+
def load_yolo(model_path: str) -> YOLO:
|
15 |
+
"""#### Load YOLO model.
|
16 |
+
|
17 |
+
#### Args:
|
18 |
+
- `model_path` (str): The path to the YOLO model.
|
19 |
+
|
20 |
+
#### Returns:
|
21 |
+
- `YOLO`: The YOLO model initialized with the specified model path.
|
22 |
+
"""
|
23 |
+
try:
|
24 |
+
return YOLO(model_path)
|
25 |
+
except ModuleNotFoundError:
|
26 |
+
print("please download yolo model")
|
27 |
+
|
28 |
+
|
29 |
+
def inference_bbox(
|
30 |
+
model: YOLO,
|
31 |
+
image: Image.Image,
|
32 |
+
confidence: float = 0.3,
|
33 |
+
device: str = "",
|
34 |
+
) -> List:
|
35 |
+
"""#### Perform inference on an image and return bounding boxes.
|
36 |
+
|
37 |
+
#### Args:
|
38 |
+
- `model` (YOLO): The YOLO model.
|
39 |
+
- `image` (Image.Image): The image to perform inference on.
|
40 |
+
- `confidence` (float): The confidence threshold for the bounding boxes.
|
41 |
+
- `device` (str): The device to run the model on.
|
42 |
+
|
43 |
+
#### Returns:
|
44 |
+
- `List[List[str, List[int], np.ndarray, float]]`: The list of bounding boxes.
|
45 |
+
"""
|
46 |
+
pred = model(image, conf=confidence, device=device)
|
47 |
+
|
48 |
+
bboxes = pred[0].boxes.xyxy.cpu().numpy()
|
49 |
+
cv2_image = np.array(image)
|
50 |
+
cv2_image = cv2_image[:, :, ::-1].copy() # Convert RGB to BGR for cv2 processing
|
51 |
+
cv2_gray = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2GRAY)
|
52 |
+
|
53 |
+
segms = []
|
54 |
+
for x0, y0, x1, y1 in bboxes:
|
55 |
+
cv2_mask = np.zeros(cv2_gray.shape, np.uint8)
|
56 |
+
cv2.rectangle(cv2_mask, (int(x0), int(y0)), (int(x1), int(y1)), 255, -1)
|
57 |
+
cv2_mask_bool = cv2_mask.astype(bool)
|
58 |
+
segms.append(cv2_mask_bool)
|
59 |
+
|
60 |
+
results = [[], [], [], []]
|
61 |
+
for i in range(len(bboxes)):
|
62 |
+
results[0].append(pred[0].names[int(pred[0].boxes[i].cls.item())])
|
63 |
+
results[1].append(bboxes[i])
|
64 |
+
results[2].append(segms[i])
|
65 |
+
results[3].append(pred[0].boxes[i].conf.cpu().numpy())
|
66 |
+
|
67 |
+
return results
|
68 |
+
|
69 |
+
|
70 |
+
def create_segmasks(results: List) -> List:
|
71 |
+
"""#### Create segmentation masks from the results of the inference.
|
72 |
+
|
73 |
+
#### Args:
|
74 |
+
- `results` (List[List[str, List[int], np.ndarray, float]]): The results of the inference.
|
75 |
+
|
76 |
+
#### Returns:
|
77 |
+
- `List[List[int], np.ndarray, float]`: The list of segmentation masks.
|
78 |
+
"""
|
79 |
+
bboxs = results[1]
|
80 |
+
segms = results[2]
|
81 |
+
confidence = results[3]
|
82 |
+
|
83 |
+
results = []
|
84 |
+
for i in range(len(segms)):
|
85 |
+
item = (bboxs[i], segms[i].astype(np.float32), confidence[i])
|
86 |
+
results.append(item)
|
87 |
+
return results
|
88 |
+
|
89 |
+
|
90 |
+
def dilate_masks(segmasks: List, dilation_factor: int, iter: int = 1) -> List:
|
91 |
+
"""#### Dilate the segmentation masks.
|
92 |
+
|
93 |
+
#### Args:
|
94 |
+
- `segmasks` (List[List[int], np.ndarray, float]): The segmentation masks.
|
95 |
+
- `dilation_factor` (int): The dilation factor.
|
96 |
+
- `iter` (int): The number of iterations.
|
97 |
+
|
98 |
+
#### Returns:
|
99 |
+
- `List[List[int], np.ndarray, float]`: The dilated segmentation masks.
|
100 |
+
"""
|
101 |
+
dilated_masks = []
|
102 |
+
kernel = np.ones((abs(dilation_factor), abs(dilation_factor)), np.uint8)
|
103 |
+
|
104 |
+
for i in range(len(segmasks)):
|
105 |
+
cv2_mask = segmasks[i][1]
|
106 |
+
|
107 |
+
dilated_mask = cv2.dilate(cv2_mask, kernel, iter)
|
108 |
+
|
109 |
+
item = (segmasks[i][0], dilated_mask, segmasks[i][2])
|
110 |
+
dilated_masks.append(item)
|
111 |
+
|
112 |
+
return dilated_masks
|
113 |
+
|
114 |
+
|
115 |
+
def normalize_region(limit: int, startp: int, size: int) -> List:
|
116 |
+
"""#### Normalize the region.
|
117 |
+
|
118 |
+
#### Args:
|
119 |
+
- `limit` (int): The limit.
|
120 |
+
- `startp` (int): The start point.
|
121 |
+
- `size` (int): The size.
|
122 |
+
|
123 |
+
#### Returns:
|
124 |
+
- `List[int]`: The normalized start and end points.
|
125 |
+
"""
|
126 |
+
if startp < 0:
|
127 |
+
new_endp = min(limit, size)
|
128 |
+
new_startp = 0
|
129 |
+
elif startp + size > limit:
|
130 |
+
new_startp = max(0, limit - size)
|
131 |
+
new_endp = limit
|
132 |
+
else:
|
133 |
+
new_startp = startp
|
134 |
+
new_endp = min(limit, startp + size)
|
135 |
+
|
136 |
+
return int(new_startp), int(new_endp)
|
137 |
+
|
138 |
+
|
139 |
+
def make_crop_region(w: int, h: int, bbox: List, crop_factor: float) -> List:
|
140 |
+
"""#### Make the crop region.
|
141 |
+
|
142 |
+
#### Args:
|
143 |
+
- `w` (int): The width.
|
144 |
+
- `h` (int): The height.
|
145 |
+
- `bbox` (List[int]): The bounding box.
|
146 |
+
- `crop_factor` (float): The crop factor.
|
147 |
+
|
148 |
+
#### Returns:
|
149 |
+
- `List[x1: int, y1: int, x2: int, y2: int]`: The crop region.
|
150 |
+
"""
|
151 |
+
x1 = bbox[0]
|
152 |
+
y1 = bbox[1]
|
153 |
+
x2 = bbox[2]
|
154 |
+
y2 = bbox[3]
|
155 |
+
|
156 |
+
bbox_w = x2 - x1
|
157 |
+
bbox_h = y2 - y1
|
158 |
+
|
159 |
+
crop_w = bbox_w * crop_factor
|
160 |
+
crop_h = bbox_h * crop_factor
|
161 |
+
|
162 |
+
kernel_x = x1 + bbox_w / 2
|
163 |
+
kernel_y = y1 + bbox_h / 2
|
164 |
+
|
165 |
+
new_x1 = int(kernel_x - crop_w / 2)
|
166 |
+
new_y1 = int(kernel_y - crop_h / 2)
|
167 |
+
|
168 |
+
# make sure position in (w,h)
|
169 |
+
new_x1, new_x2 = normalize_region(w, new_x1, crop_w)
|
170 |
+
new_y1, new_y2 = normalize_region(h, new_y1, crop_h)
|
171 |
+
|
172 |
+
return [new_x1, new_y1, new_x2, new_y2]
|
173 |
+
|
174 |
+
|
175 |
+
def crop_ndarray2(npimg: np.ndarray, crop_region: List) -> np.ndarray:
|
176 |
+
"""#### Crop the ndarray in 2 dimensions.
|
177 |
+
|
178 |
+
#### Args:
|
179 |
+
- `npimg` (np.ndarray): The ndarray to crop.
|
180 |
+
- `crop_region` (List[int]): The crop region.
|
181 |
+
|
182 |
+
#### Returns:
|
183 |
+
- `np.ndarray`: The cropped ndarray.
|
184 |
+
"""
|
185 |
+
x1 = crop_region[0]
|
186 |
+
y1 = crop_region[1]
|
187 |
+
x2 = crop_region[2]
|
188 |
+
y2 = crop_region[3]
|
189 |
+
|
190 |
+
cropped = npimg[y1:y2, x1:x2]
|
191 |
+
|
192 |
+
return cropped
|
193 |
+
|
194 |
+
|
195 |
+
def crop_ndarray4(npimg: np.ndarray, crop_region: List) -> np.ndarray:
|
196 |
+
"""#### Crop the ndarray in 4 dimensions.
|
197 |
+
|
198 |
+
#### Args:
|
199 |
+
- `npimg` (np.ndarray): The ndarray to crop.
|
200 |
+
- `crop_region` (List[int]): The crop region.
|
201 |
+
|
202 |
+
#### Returns:
|
203 |
+
- `np.ndarray`: The cropped ndarray.
|
204 |
+
"""
|
205 |
+
x1 = crop_region[0]
|
206 |
+
y1 = crop_region[1]
|
207 |
+
x2 = crop_region[2]
|
208 |
+
y2 = crop_region[3]
|
209 |
+
|
210 |
+
cropped = npimg[:, y1:y2, x1:x2, :]
|
211 |
+
|
212 |
+
return cropped
|
213 |
+
|
214 |
+
|
215 |
+
def crop_image(image: Image.Image, crop_region: List) -> Image.Image:
|
216 |
+
"""#### Crop the image.
|
217 |
+
|
218 |
+
#### Args:
|
219 |
+
- `image` (Image.Image): The image to crop.
|
220 |
+
- `crop_region` (List[int]): The crop region.
|
221 |
+
|
222 |
+
#### Returns:
|
223 |
+
- `Image.Image`: The cropped image.
|
224 |
+
"""
|
225 |
+
return crop_ndarray4(image, crop_region)
|
226 |
+
|
227 |
+
|
228 |
+
def segs_scale_match(segs: List[np.ndarray], target_shape: List) -> List:
|
229 |
+
"""#### Match the scale of the segmentation masks.
|
230 |
+
|
231 |
+
#### Args:
|
232 |
+
- `segs` (List[np.ndarray]): The segmentation masks.
|
233 |
+
- `target_shape` (List[int]): The target shape.
|
234 |
+
|
235 |
+
#### Returns:
|
236 |
+
- `List[np.ndarray]`: The matched segmentation masks.
|
237 |
+
"""
|
238 |
+
h = segs[0][0]
|
239 |
+
w = segs[0][1]
|
240 |
+
|
241 |
+
th = target_shape[1]
|
242 |
+
tw = target_shape[2]
|
243 |
+
|
244 |
+
if (h == th and w == tw) or h == 0 or w == 0:
|
245 |
+
return segs
|
modules/AutoDetailer/ADetailer.py
ADDED
@@ -0,0 +1,952 @@
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from typing import Any, Dict, Optional, Tuple
|
4 |
+
|
5 |
+
from modules.AutoDetailer import AD_util, bbox, tensor_util
|
6 |
+
from modules.AutoDetailer import SEGS
|
7 |
+
from modules.Utilities import util
|
8 |
+
from modules.AutoEncoders import VariationalAE
|
9 |
+
from modules.Device import Device
|
10 |
+
from modules.sample import ksampler_util, samplers, sampling, sampling_util
|
11 |
+
|
12 |
+
# FIXME: Improve slow inference times
|
13 |
+
|
14 |
+
|
15 |
+
class DifferentialDiffusion:
|
16 |
+
"""#### Class for applying differential diffusion to a model."""
|
17 |
+
|
18 |
+
def apply(self, model: torch.nn.Module) -> Tuple[torch.nn.Module]:
|
19 |
+
"""#### Apply differential diffusion to a model.
|
20 |
+
|
21 |
+
#### Args:
|
22 |
+
- `model` (torch.nn.Module): The input model.
|
23 |
+
|
24 |
+
#### Returns:
|
25 |
+
- `Tuple[torch.nn.Module]`: The modified model.
|
26 |
+
"""
|
27 |
+
model = model.clone()
|
28 |
+
model.set_model_denoise_mask_function(self.forward)
|
29 |
+
return (model,)
|
30 |
+
|
31 |
+
def forward(
|
32 |
+
self,
|
33 |
+
sigma: torch.Tensor,
|
34 |
+
denoise_mask: torch.Tensor,
|
35 |
+
extra_options: Dict[str, Any],
|
36 |
+
) -> torch.Tensor:
|
37 |
+
"""#### Forward function for differential diffusion.
|
38 |
+
|
39 |
+
#### Args:
|
40 |
+
- `sigma` (torch.Tensor): The sigma tensor.
|
41 |
+
- `denoise_mask` (torch.Tensor): The denoise mask tensor.
|
42 |
+
- `extra_options` (Dict[str, Any]): Additional options.
|
43 |
+
|
44 |
+
#### Returns:
|
45 |
+
- `torch.Tensor`: The processed denoise mask tensor.
|
46 |
+
"""
|
47 |
+
model = extra_options["model"]
|
48 |
+
step_sigmas = extra_options["sigmas"]
|
49 |
+
sigma_to = model.inner_model.model_sampling.sigma_min
|
50 |
+
sigma_from = step_sigmas[0]
|
51 |
+
|
52 |
+
ts_from = model.inner_model.model_sampling.timestep(sigma_from)
|
53 |
+
ts_to = model.inner_model.model_sampling.timestep(sigma_to)
|
54 |
+
current_ts = model.inner_model.model_sampling.timestep(sigma[0])
|
55 |
+
|
56 |
+
threshold = (current_ts - ts_to) / (ts_from - ts_to)
|
57 |
+
|
58 |
+
return (denoise_mask >= threshold).to(denoise_mask.dtype)
|
59 |
+
|
60 |
+
|
61 |
+
def to_latent_image(pixels: torch.Tensor, vae: VariationalAE.VAE) -> torch.Tensor:
|
62 |
+
"""#### Convert pixels to a latent image using a VAE.
|
63 |
+
|
64 |
+
#### Args:
|
65 |
+
- `pixels` (torch.Tensor): The input pixel tensor.
|
66 |
+
- `vae` (VariationalAE.VAE): The VAE model.
|
67 |
+
|
68 |
+
#### Returns:
|
69 |
+
- `torch.Tensor`: The latent image tensor.
|
70 |
+
"""
|
71 |
+
pixels.shape[1]
|
72 |
+
pixels.shape[2]
|
73 |
+
return VariationalAE.VAEEncode().encode(vae, pixels)[0]
|
74 |
+
|
75 |
+
|
76 |
+
def calculate_sigmas2(
|
77 |
+
model: torch.nn.Module, sampler: str, scheduler: str, steps: int
|
78 |
+
) -> torch.Tensor:
|
79 |
+
"""#### Calculate sigmas for a model.
|
80 |
+
|
81 |
+
#### Args:
|
82 |
+
- `model` (torch.nn.Module): The input model.
|
83 |
+
- `sampler` (str): The sampler name.
|
84 |
+
- `scheduler` (str): The scheduler name.
|
85 |
+
- `steps` (int): The number of steps.
|
86 |
+
|
87 |
+
#### Returns:
|
88 |
+
- `torch.Tensor`: The calculated sigmas.
|
89 |
+
"""
|
90 |
+
return ksampler_util.calculate_sigmas(
|
91 |
+
model.get_model_object("model_sampling"), scheduler, steps
|
92 |
+
)
|
93 |
+
|
94 |
+
|
95 |
+
def get_noise_sampler(
|
96 |
+
x: torch.Tensor, cpu: bool, total_sigmas: torch.Tensor, **kwargs
|
97 |
+
) -> Optional[sampling_util.BrownianTreeNoiseSampler]:
|
98 |
+
"""#### Get a noise sampler.
|
99 |
+
|
100 |
+
#### Args:
|
101 |
+
- `x` (torch.Tensor): The input tensor.
|
102 |
+
- `cpu` (bool): Whether to use CPU.
|
103 |
+
- `total_sigmas` (torch.Tensor): The total sigmas tensor.
|
104 |
+
- `kwargs` (dict): Additional arguments.
|
105 |
+
|
106 |
+
#### Returns:
|
107 |
+
- `Optional[sampling_util.BrownianTreeNoiseSampler]`: The noise sampler.
|
108 |
+
"""
|
109 |
+
if "extra_args" in kwargs and "seed" in kwargs["extra_args"]:
|
110 |
+
sigma_min, sigma_max = total_sigmas[total_sigmas > 0].min(), total_sigmas.max()
|
111 |
+
seed = kwargs["extra_args"].get("seed", None)
|
112 |
+
return sampling_util.BrownianTreeNoiseSampler(
|
113 |
+
x, sigma_min, sigma_max, seed=seed, cpu=cpu
|
114 |
+
)
|
115 |
+
return None
|
116 |
+
|
117 |
+
|
118 |
+
def ksampler2(
|
119 |
+
sampler_name: str,
|
120 |
+
total_sigmas: torch.Tensor,
|
121 |
+
extra_options: Dict[str, Any] = {},
|
122 |
+
inpaint_options: Dict[str, Any] = {},
|
123 |
+
pipeline: bool = False,
|
124 |
+
) -> sampling.KSAMPLER:
|
125 |
+
"""#### Get a ksampler.
|
126 |
+
|
127 |
+
#### Args:
|
128 |
+
- `sampler_name` (str): The sampler name.
|
129 |
+
- `total_sigmas` (torch.Tensor): The total sigmas tensor.
|
130 |
+
- `extra_options` (Dict[str, Any], optional): Additional options. Defaults to {}.
|
131 |
+
- `inpaint_options` (Dict[str, Any], optional): Inpaint options. Defaults to {}.
|
132 |
+
- `pipeline` (bool, optional): Whether to use pipeline. Defaults to False.
|
133 |
+
|
134 |
+
#### Returns:
|
135 |
+
- `sampling.KSAMPLER`: The ksampler.
|
136 |
+
"""
|
137 |
+
if sampler_name == "dpmpp_2m_sde":
|
138 |
+
|
139 |
+
def sample_dpmpp_sde(model, x, sigmas, pipeline, **kwargs):
|
140 |
+
noise_sampler = get_noise_sampler(x, True, total_sigmas, **kwargs)
|
141 |
+
if noise_sampler is not None:
|
142 |
+
kwargs["noise_sampler"] = noise_sampler
|
143 |
+
|
144 |
+
return samplers.sample_dpmpp_2m_sde(
|
145 |
+
model, x, sigmas, pipeline=pipeline, **kwargs
|
146 |
+
)
|
147 |
+
|
148 |
+
sampler_function = sample_dpmpp_sde
|
149 |
+
|
150 |
+
else:
|
151 |
+
return sampling.sampler_object(sampler_name, pipeline=pipeline)
|
152 |
+
|
153 |
+
return sampling.KSAMPLER(sampler_function, extra_options, inpaint_options)
|
154 |
+
|
155 |
+
|
156 |
+
class Noise_RandomNoise:
|
157 |
+
"""#### Class for generating random noise."""
|
158 |
+
|
159 |
+
def __init__(self, seed: int):
|
160 |
+
"""#### Initialize the Noise_RandomNoise class.
|
161 |
+
|
162 |
+
#### Args:
|
163 |
+
- `seed` (int): The seed for random noise.
|
164 |
+
"""
|
165 |
+
self.seed = seed
|
166 |
+
|
167 |
+
def generate_noise(self, input_latent: Dict[str, torch.Tensor]) -> torch.Tensor:
|
168 |
+
"""#### Generate random noise.
|
169 |
+
|
170 |
+
#### Args:
|
171 |
+
- `input_latent` (Dict[str, torch.Tensor]): The input latent tensor.
|
172 |
+
|
173 |
+
#### Returns:
|
174 |
+
- `torch.Tensor`: The generated noise tensor.
|
175 |
+
"""
|
176 |
+
latent_image = input_latent["samples"]
|
177 |
+
batch_inds = (
|
178 |
+
input_latent["batch_index"] if "batch_index" in input_latent else None
|
179 |
+
)
|
180 |
+
return ksampler_util.prepare_noise(latent_image, self.seed, batch_inds)
|
181 |
+
|
182 |
+
|
183 |
+
def sample_with_custom_noise(
|
184 |
+
model: torch.nn.Module,
|
185 |
+
add_noise: bool,
|
186 |
+
noise_seed: int,
|
187 |
+
cfg: int,
|
188 |
+
positive: Any,
|
189 |
+
negative: Any,
|
190 |
+
sampler: Any,
|
191 |
+
sigmas: torch.Tensor,
|
192 |
+
latent_image: Dict[str, torch.Tensor],
|
193 |
+
noise: Optional[torch.Tensor] = None,
|
194 |
+
callback: Optional[callable] = None,
|
195 |
+
pipeline: bool = False,
|
196 |
+
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
|
197 |
+
"""#### Sample with custom noise.
|
198 |
+
|
199 |
+
#### Args:
|
200 |
+
- `model` (torch.nn.Module): The input model.
|
201 |
+
- `add_noise` (bool): Whether to add noise.
|
202 |
+
- `noise_seed` (int): The noise seed.
|
203 |
+
- `cfg` (int): Classifier-Free Guidance Scale
|
204 |
+
- `positive` (Any): The positive prompt.
|
205 |
+
- `negative` (Any): The negative prompt.
|
206 |
+
- `sampler` (Any): The sampler.
|
207 |
+
- `sigmas` (torch.Tensor): The sigmas tensor.
|
208 |
+
- `latent_image` (Dict[str, torch.Tensor]): The latent image tensor.
|
209 |
+
- `noise` (Optional[torch.Tensor], optional): The noise tensor. Defaults to None.
|
210 |
+
- `callback` (Optional[callable], optional): The callback function. Defaults to None.
|
211 |
+
- `pipeline` (bool, optional): Whether to use pipeline. Defaults to False.
|
212 |
+
|
213 |
+
#### Returns:
|
214 |
+
- `Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]`: The sampled and denoised tensors.
|
215 |
+
"""
|
216 |
+
latent = latent_image
|
217 |
+
latent_image = latent["samples"]
|
218 |
+
|
219 |
+
out = latent.copy()
|
220 |
+
out["samples"] = latent_image
|
221 |
+
|
222 |
+
if noise is None:
|
223 |
+
noise = Noise_RandomNoise(noise_seed).generate_noise(out)
|
224 |
+
|
225 |
+
noise_mask = None
|
226 |
+
if "noise_mask" in latent:
|
227 |
+
noise_mask = latent["noise_mask"]
|
228 |
+
|
229 |
+
disable_pbar = not util.PROGRESS_BAR_ENABLED
|
230 |
+
|
231 |
+
device = Device.get_torch_device()
|
232 |
+
|
233 |
+
noise = noise.to(device)
|
234 |
+
latent_image = latent_image.to(device)
|
235 |
+
if noise_mask is not None:
|
236 |
+
noise_mask = noise_mask.to(device)
|
237 |
+
|
238 |
+
samples = sampling.sample_custom(
|
239 |
+
model,
|
240 |
+
noise,
|
241 |
+
cfg,
|
242 |
+
sampler,
|
243 |
+
sigmas,
|
244 |
+
positive,
|
245 |
+
negative,
|
246 |
+
latent_image,
|
247 |
+
noise_mask=noise_mask,
|
248 |
+
disable_pbar=disable_pbar,
|
249 |
+
seed=noise_seed,
|
250 |
+
pipeline=pipeline,
|
251 |
+
)
|
252 |
+
|
253 |
+
samples = samples.to(Device.intermediate_device())
|
254 |
+
|
255 |
+
out["samples"] = samples
|
256 |
+
out_denoised = out
|
257 |
+
return out, out_denoised
|
258 |
+
|
259 |
+
|
260 |
+
def separated_sample(
|
261 |
+
model: torch.nn.Module,
|
262 |
+
add_noise: bool,
|
263 |
+
seed: int,
|
264 |
+
steps: int,
|
265 |
+
cfg: int,
|
266 |
+
sampler_name: str,
|
267 |
+
scheduler: str,
|
268 |
+
positive: Any,
|
269 |
+
negative: Any,
|
270 |
+
latent_image: Dict[str, torch.Tensor],
|
271 |
+
start_at_step: Optional[int],
|
272 |
+
end_at_step: Optional[int],
|
273 |
+
return_with_leftover_noise: bool,
|
274 |
+
sigma_ratio: float = 1.0,
|
275 |
+
sampler_opt: Optional[Dict[str, Any]] = None,
|
276 |
+
noise: Optional[torch.Tensor] = None,
|
277 |
+
callback: Optional[callable] = None,
|
278 |
+
scheduler_func: Optional[callable] = None,
|
279 |
+
pipeline: bool = False,
|
280 |
+
) -> Dict[str, torch.Tensor]:
|
281 |
+
"""#### Perform separated sampling.
|
282 |
+
|
283 |
+
#### Args:
|
284 |
+
- `model` (torch.nn.Module): The input model.
|
285 |
+
- `add_noise` (bool): Whether to add noise.
|
286 |
+
- `seed` (int): The seed for random noise.
|
287 |
+
- `steps` (int): The number of steps.
|
288 |
+
- `cfg` (int): Classifier-Free Guidance Scale
|
289 |
+
- `sampler_name` (str): The sampler name.
|
290 |
+
- `scheduler` (str): The scheduler name.
|
291 |
+
- `positive` (Any): The positive prompt.
|
292 |
+
- `negative` (Any): The negative prompt.
|
293 |
+
- `latent_image` (Dict[str, torch.Tensor]): The latent image tensor.
|
294 |
+
- `start_at_step` (Optional[int]): The step to start at.
|
295 |
+
- `end_at_step` (Optional[int]): The step to end at.
|
296 |
+
- `return_with_leftover_noise` (bool): Whether to return with leftover noise.
|
297 |
+
- `sigma_ratio` (float, optional): The sigma ratio. Defaults to 1.0.
|
298 |
+
- `sampler_opt` (Optional[Dict[str, Any]], optional): The sampler options. Defaults to None.
|
299 |
+
- `noise` (Optional[torch.Tensor], optional): The noise tensor. Defaults to None.
|
300 |
+
- `callback` (Optional[callable], optional): The callback function. Defaults to None.
|
301 |
+
- `scheduler_func` (Optional[callable], optional): The scheduler function. Defaults to None.
|
302 |
+
- `pipeline` (bool, optional): Whether to use pipeline. Defaults to False.
|
303 |
+
|
304 |
+
#### Returns:
|
305 |
+
- `Dict[str, torch.Tensor]`: The sampled tensor.
|
306 |
+
"""
|
307 |
+
total_sigmas = calculate_sigmas2(model, sampler_name, scheduler, steps)
|
308 |
+
|
309 |
+
sigmas = total_sigmas
|
310 |
+
|
311 |
+
if start_at_step is not None:
|
312 |
+
sigmas = sigmas[start_at_step:] * sigma_ratio
|
313 |
+
|
314 |
+
impact_sampler = ksampler2(sampler_name, total_sigmas, pipeline=pipeline)
|
315 |
+
|
316 |
+
res = sample_with_custom_noise(
|
317 |
+
model,
|
318 |
+
add_noise,
|
319 |
+
seed,
|
320 |
+
cfg,
|
321 |
+
positive,
|
322 |
+
negative,
|
323 |
+
impact_sampler,
|
324 |
+
sigmas,
|
325 |
+
latent_image,
|
326 |
+
noise=noise,
|
327 |
+
callback=callback,
|
328 |
+
pipeline=pipeline,
|
329 |
+
)
|
330 |
+
|
331 |
+
return res[1]
|
332 |
+
|
333 |
+
|
334 |
+
def ksampler_wrapper(
|
335 |
+
model: torch.nn.Module,
|
336 |
+
seed: int,
|
337 |
+
steps: int,
|
338 |
+
cfg: int,
|
339 |
+
sampler_name: str,
|
340 |
+
scheduler: str,
|
341 |
+
positive: Any,
|
342 |
+
negative: Any,
|
343 |
+
latent_image: Dict[str, torch.Tensor],
|
344 |
+
denoise: float,
|
345 |
+
refiner_ratio: Optional[float] = None,
|
346 |
+
refiner_model: Optional[torch.nn.Module] = None,
|
347 |
+
refiner_clip: Optional[Any] = None,
|
348 |
+
refiner_positive: Optional[Any] = None,
|
349 |
+
refiner_negative: Optional[Any] = None,
|
350 |
+
sigma_factor: float = 1.0,
|
351 |
+
noise: Optional[torch.Tensor] = None,
|
352 |
+
scheduler_func: Optional[callable] = None,
|
353 |
+
pipeline: bool = False,
|
354 |
+
) -> Dict[str, torch.Tensor]:
|
355 |
+
"""#### Wrapper for ksampler.
|
356 |
+
|
357 |
+
#### Args:
|
358 |
+
- `model` (torch.nn.Module): The input model.
|
359 |
+
- `seed` (int): The seed for random noise.
|
360 |
+
- `steps` (int): The number of steps.
|
361 |
+
- `cfg` (int): Classifier-Free Guidance Scale
|
362 |
+
- `sampler_name` (str): The sampler name.
|
363 |
+
- `scheduler` (str): The scheduler name.
|
364 |
+
- `positive` (Any): The positive prompt.
|
365 |
+
- `negative` (Any): The negative prompt.
|
366 |
+
- `latent_image` (Dict[str, torch.Tensor]): The latent image tensor.
|
367 |
+
- `denoise` (float): The denoise factor.
|
368 |
+
- `refiner_ratio` (Optional[float], optional): The refiner ratio. Defaults to None.
|
369 |
+
- `refiner_model` (Optional[torch.nn.Module], optional): The refiner model. Defaults to None.
|
370 |
+
- `refiner_clip` (Optional[Any], optional): The refiner clip. Defaults to None.
|
371 |
+
- `refiner_positive` (Optional[Any], optional): The refiner positive prompt. Defaults to None.
|
372 |
+
- `refiner_negative` (Optional[Any], optional): The refiner negative prompt. Defaults to None.
|
373 |
+
- `sigma_factor` (float, optional): The sigma factor. Defaults to 1.0.
|
374 |
+
- `noise` (Optional[torch.Tensor], optional): The noise tensor. Defaults to None.
|
375 |
+
- `scheduler_func` (Optional[callable], optional): The scheduler function. Defaults to None.
|
376 |
+
- `pipeline` (bool, optional): Whether to use pipeline. Defaults to False.
|
377 |
+
|
378 |
+
#### Returns:
|
379 |
+
- `Dict[str, torch.Tensor]`: The refined latent tensor.
|
380 |
+
"""
|
381 |
+
advanced_steps = math.floor(steps / denoise)
|
382 |
+
start_at_step = advanced_steps - steps
|
383 |
+
end_at_step = start_at_step + steps
|
384 |
+
refined_latent = separated_sample(
|
385 |
+
model,
|
386 |
+
True,
|
387 |
+
seed,
|
388 |
+
advanced_steps,
|
389 |
+
cfg,
|
390 |
+
sampler_name,
|
391 |
+
scheduler,
|
392 |
+
positive,
|
393 |
+
negative,
|
394 |
+
latent_image,
|
395 |
+
start_at_step,
|
396 |
+
end_at_step,
|
397 |
+
False,
|
398 |
+
sigma_ratio=sigma_factor,
|
399 |
+
noise=noise,
|
400 |
+
scheduler_func=scheduler_func,
|
401 |
+
pipeline=pipeline,
|
402 |
+
)
|
403 |
+
|
404 |
+
return refined_latent
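# NOTE (illustrative sketch, not part of the uploaded file): ksampler_wrapper
# turns a denoise factor into a window over a longer schedule. The helper name
# below is hypothetical; it only reproduces the arithmetic used above.
import math

def denoise_to_step_window(steps: int, denoise: float) -> tuple:
    advanced_steps = math.floor(steps / denoise)
    start_at_step = advanced_steps - steps
    end_at_step = start_at_step + steps
    return advanced_steps, start_at_step, end_at_step

# steps=20, denoise=0.5 -> a 40-step schedule of which only steps 20..40 are run
assert denoise_to_step_window(20, 0.5) == (40, 20, 40)
assert denoise_to_step_window(20, 1.0) == (20, 0, 20)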
|
405 |
+
|
406 |
+
|
407 |
+
def enhance_detail(
|
408 |
+
image: torch.Tensor,
|
409 |
+
model: torch.nn.Module,
|
410 |
+
clip: Any,
|
411 |
+
vae: VariationalAE.VAE,
|
412 |
+
guide_size: int,
|
413 |
+
guide_size_for_bbox: bool,
|
414 |
+
max_size: int,
|
415 |
+
bbox: Tuple[int, int, int, int],
|
416 |
+
seed: int,
|
417 |
+
steps: int,
|
418 |
+
cfg: int,
|
419 |
+
sampler_name: str,
|
420 |
+
scheduler: str,
|
421 |
+
positive: Any,
|
422 |
+
negative: Any,
|
423 |
+
denoise: float,
|
424 |
+
noise_mask: Optional[torch.Tensor],
|
425 |
+
force_inpaint: bool,
|
426 |
+
wildcard_opt: Optional[Any] = None,
|
427 |
+
wildcard_opt_concat_mode: Optional[Any] = None,
|
428 |
+
detailer_hook: Optional[callable] = None,
|
429 |
+
refiner_ratio: Optional[float] = None,
|
430 |
+
refiner_model: Optional[torch.nn.Module] = None,
|
431 |
+
refiner_clip: Optional[Any] = None,
|
432 |
+
refiner_positive: Optional[Any] = None,
|
433 |
+
refiner_negative: Optional[Any] = None,
|
434 |
+
control_net_wrapper: Optional[Any] = None,
|
435 |
+
cycle: int = 1,
|
436 |
+
inpaint_model: bool = False,
|
437 |
+
noise_mask_feather: int = 0,
|
438 |
+
scheduler_func: Optional[callable] = None,
|
439 |
+
pipeline: bool = False,
|
440 |
+
) -> Tuple[torch.Tensor, Optional[Any]]:
|
441 |
+
"""#### Enhance detail of an image.
|
442 |
+
|
443 |
+
#### Args:
|
444 |
+
- `image` (torch.Tensor): The input image tensor.
|
445 |
+
- `model` (torch.nn.Module): The model.
|
446 |
+
- `clip` (Any): The clip model.
|
447 |
+
- `vae` (VariationalAE.VAE): The VAE model.
|
448 |
+
- `guide_size` (int): The guide size.
|
449 |
+
- `guide_size_for_bbox` (bool): Whether to use guide size for bbox.
|
450 |
+
- `max_size` (int): The maximum size.
|
451 |
+
- `bbox` (Tuple[int, int, int, int]): The bounding box.
|
452 |
+
- `seed` (int): The seed for random noise.
|
453 |
+
- `steps` (int): The number of steps.
|
454 |
+
- `cfg` (int): Classifier-Free Guidance Scale
|
455 |
+
- `sampler_name` (str): The sampler name.
|
456 |
+
- `scheduler` (str): The scheduler name.
|
457 |
+
- `positive` (Any): The positive prompt.
|
458 |
+
- `negative` (Any): The negative prompt.
|
459 |
+
- `denoise` (float): The denoise factor.
|
460 |
+
- `noise_mask` (Optional[torch.Tensor]): The noise mask tensor.
|
461 |
+
- `force_inpaint` (bool): Whether to force inpaint.
|
462 |
+
- `wildcard_opt` (Optional[Any], optional): The wildcard options. Defaults to None.
|
463 |
+
- `wildcard_opt_concat_mode` (Optional[Any], optional): The wildcard concat mode. Defaults to None.
|
464 |
+
- `detailer_hook` (Optional[callable], optional): The detailer hook. Defaults to None.
|
465 |
+
- `refiner_ratio` (Optional[float], optional): The refiner ratio. Defaults to None.
|
466 |
+
- `refiner_model` (Optional[torch.nn.Module], optional): The refiner model. Defaults to None.
|
467 |
+
- `refiner_clip` (Optional[Any], optional): The refiner clip. Defaults to None.
|
468 |
+
- `refiner_positive` (Optional[Any], optional): The refiner positive prompt. Defaults to None.
|
469 |
+
- `refiner_negative` (Optional[Any], optional): The refiner negative prompt. Defaults to None.
|
470 |
+
- `control_net_wrapper` (Optional[Any], optional): The control net wrapper. Defaults to None.
|
471 |
+
- `cycle` (int, optional): The number of cycles. Defaults to 1.
|
472 |
+
- `inpaint_model` (bool, optional): Whether to use inpaint model. Defaults to False.
|
473 |
+
- `noise_mask_feather` (int, optional): The noise mask feather. Defaults to 0.
|
474 |
+
- `scheduler_func` (Optional[callable], optional): The scheduler function. Defaults to None.
|
475 |
+
- `pipeline` (bool, optional): Whether to use pipeline. Defaults to False.
|
476 |
+
|
477 |
+
#### Returns:
|
478 |
+
- `Tuple[torch.Tensor, Optional[Any]]`: The refined image tensor and optional cnet_pils.
|
479 |
+
"""
|
480 |
+
if noise_mask is not None:
|
481 |
+
noise_mask = tensor_util.tensor_gaussian_blur_mask(
|
482 |
+
noise_mask, noise_mask_feather
|
483 |
+
)
|
484 |
+
noise_mask = noise_mask.squeeze(3)
|
485 |
+
|
486 |
+
h = image.shape[1]
|
487 |
+
w = image.shape[2]
|
488 |
+
|
489 |
+
bbox_h = bbox[3] - bbox[1]
|
490 |
+
bbox_w = bbox[2] - bbox[0]
|
491 |
+
|
492 |
+
# for cropped_size
|
493 |
+
upscale = guide_size / min(w, h)
|
494 |
+
|
495 |
+
new_w = int(w * upscale)
|
496 |
+
new_h = int(h * upscale)
|
497 |
+
|
498 |
+
if new_w > max_size or new_h > max_size:
|
499 |
+
upscale *= max_size / max(new_w, new_h)
|
500 |
+
new_w = int(w * upscale)
|
501 |
+
new_h = int(h * upscale)
|
502 |
+
|
503 |
+
if upscale <= 1.0 or new_w == 0 or new_h == 0:
|
504 |
+
print("Detailer: force inpaint")
|
505 |
+
upscale = 1.0
|
506 |
+
new_w = w
|
507 |
+
new_h = h
|
508 |
+
|
509 |
+
print(
|
510 |
+
f"Detailer: segment upscale for ({bbox_w, bbox_h}) | crop region {w, h} x {upscale} -> {new_w, new_h}"
|
511 |
+
)
|
512 |
+
|
513 |
+
# upscale
|
514 |
+
upscaled_image = tensor_util.tensor_resize(image, new_w, new_h)
|
515 |
+
|
516 |
+
cnet_pils = None
|
517 |
+
|
518 |
+
# prepare mask
|
519 |
+
latent_image = to_latent_image(upscaled_image, vae)
|
520 |
+
if noise_mask is not None:
|
521 |
+
latent_image["noise_mask"] = noise_mask
|
522 |
+
|
523 |
+
refined_latent = latent_image
|
524 |
+
|
525 |
+
# ksampler
|
526 |
+
for i in range(0, cycle):
|
527 |
+
(
|
528 |
+
model2,
|
529 |
+
seed2,
|
530 |
+
steps2,
|
531 |
+
cfg2,
|
532 |
+
sampler_name2,
|
533 |
+
scheduler2,
|
534 |
+
positive2,
|
535 |
+
negative2,
|
536 |
+
_upscaled_latent2,
|
537 |
+
denoise2,
|
538 |
+
) = (
|
539 |
+
model,
|
540 |
+
seed + i,
|
541 |
+
steps,
|
542 |
+
cfg,
|
543 |
+
sampler_name,
|
544 |
+
scheduler,
|
545 |
+
positive,
|
546 |
+
negative,
|
547 |
+
latent_image,
|
548 |
+
denoise,
|
549 |
+
)
|
550 |
+
noise = None
|
551 |
+
|
552 |
+
refined_latent = ksampler_wrapper(
|
553 |
+
model2,
|
554 |
+
seed2,
|
555 |
+
steps2,
|
556 |
+
cfg2,
|
557 |
+
sampler_name2,
|
558 |
+
scheduler2,
|
559 |
+
positive2,
|
560 |
+
negative2,
|
561 |
+
refined_latent,
|
562 |
+
denoise2,
|
563 |
+
refiner_ratio,
|
564 |
+
refiner_model,
|
565 |
+
refiner_clip,
|
566 |
+
refiner_positive,
|
567 |
+
refiner_negative,
|
568 |
+
noise=noise,
|
569 |
+
scheduler_func=scheduler_func,
|
570 |
+
pipeline=pipeline,
|
571 |
+
)
|
572 |
+
|
573 |
+
# non-latent downscale - latent downscale causes bad quality
|
574 |
+
try:
|
575 |
+
# try to decode image normally
|
576 |
+
refined_image = vae.decode(refined_latent["samples"])
|
577 |
+
except Exception:
|
578 |
+
# usually an out-of-memory exception from the decode, so try a tiled approach
|
579 |
+
refined_image = vae.decode_tiled(
|
580 |
+
refined_latent["samples"],
|
581 |
+
tile_x=64,
|
582 |
+
tile_y=64,
|
583 |
+
)
|
584 |
+
|
585 |
+
# downscale
|
586 |
+
refined_image = tensor_util.tensor_resize(refined_image, w, h)
|
587 |
+
|
588 |
+
# prevent mixing of device
|
589 |
+
refined_image = refined_image.cpu()
|
590 |
+
|
591 |
+
# don't convert back to latent - the round trip degrades the image
|
592 |
+
# preserving the PIL image is much better
|
593 |
+
return refined_image, cnet_pils
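# NOTE (illustrative sketch, not part of the uploaded file): the sizing rule
# enhance_detail applies to a crop before resampling it - scale the smaller
# side up to guide_size, cap the result at max_size, and fall back to plain
# inpainting when no real upscale is possible. The function name is hypothetical.
def compute_detail_upscale(w: int, h: int, guide_size: int, max_size: int):
    upscale = guide_size / min(w, h)
    new_w, new_h = int(w * upscale), int(h * upscale)
    if new_w > max_size or new_h > max_size:
        upscale *= max_size / max(new_w, new_h)
        new_w, new_h = int(w * upscale), int(h * upscale)
    if upscale <= 1.0 or new_w == 0 or new_h == 0:
        upscale, new_w, new_h = 1.0, w, h  # the "Detailer: force inpaint" path
    return upscale, new_w, new_h

# a 256x192 crop with guide_size=512 and max_size=768 is upscaled ~2.67x
assert compute_detail_upscale(256, 192, 512, 768) == (512 / 192, 682, 512)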
|
594 |
+
|
595 |
+
|
596 |
+
class DetailerForEach:
|
597 |
+
"""#### Class for detailing each segment of an image."""
|
598 |
+
|
599 |
+
@staticmethod
|
600 |
+
def do_detail(
|
601 |
+
image: torch.Tensor,
|
602 |
+
segs: Tuple[torch.Tensor, Any],
|
603 |
+
model: torch.nn.Module,
|
604 |
+
clip: Any,
|
605 |
+
vae: VariationalAE.VAE,
|
606 |
+
guide_size: int,
|
607 |
+
guide_size_for_bbox: bool,
|
608 |
+
max_size: int,
|
609 |
+
seed: int,
|
610 |
+
steps: int,
|
611 |
+
cfg: int,
|
612 |
+
sampler_name: str,
|
613 |
+
scheduler: str,
|
614 |
+
positive: Any,
|
615 |
+
negative: Any,
|
616 |
+
denoise: float,
|
617 |
+
feather: int,
|
618 |
+
noise_mask: Optional[torch.Tensor],
|
619 |
+
force_inpaint: bool,
|
620 |
+
wildcard_opt: Optional[Any] = None,
|
621 |
+
detailer_hook: Optional[callable] = None,
|
622 |
+
refiner_ratio: Optional[float] = None,
|
623 |
+
refiner_model: Optional[torch.nn.Module] = None,
|
624 |
+
refiner_clip: Optional[Any] = None,
|
625 |
+
refiner_positive: Optional[Any] = None,
|
626 |
+
refiner_negative: Optional[Any] = None,
|
627 |
+
cycle: int = 1,
|
628 |
+
inpaint_model: bool = False,
|
629 |
+
noise_mask_feather: int = 0,
|
630 |
+
scheduler_func_opt: Optional[callable] = None,
|
631 |
+
pipeline: bool = False,
|
632 |
+
) -> Tuple[torch.Tensor, list, list, list, list, Tuple[torch.Tensor, list]]:
|
633 |
+
"""#### Perform detailing on each segment of an image.
|
634 |
+
|
635 |
+
#### Args:
|
636 |
+
- `image` (torch.Tensor): The input image tensor.
|
637 |
+
- `segs` (Tuple[torch.Tensor, Any]): The segments.
|
638 |
+
- `model` (torch.nn.Module): The model.
|
639 |
+
- `clip` (Any): The clip model.
|
640 |
+
- `vae` (VariationalAE.VAE): The VAE model.
|
641 |
+
- `guide_size` (int): The guide size.
|
642 |
+
- `guide_size_for_bbox` (bool): Whether to use guide size for bbox.
|
643 |
+
- `max_size` (int): The maximum size.
|
644 |
+
- `seed` (int): The seed for random noise.
|
645 |
+
- `steps` (int): The number of steps.
|
646 |
+
- `cfg` (int): Classifier-Free Guidance Scale.
|
647 |
+
- `sampler_name` (str): The sampler name.
|
648 |
+
- `scheduler` (str): The scheduler name.
|
649 |
+
- `positive` (Any): The positive prompt.
|
650 |
+
- `negative` (Any): The negative prompt.
|
651 |
+
- `denoise` (float): The denoise factor.
|
652 |
+
- `feather` (int): The feather value.
|
653 |
+
- `noise_mask` (Optional[torch.Tensor]): The noise mask tensor.
|
654 |
+
- `force_inpaint` (bool): Whether to force inpaint.
|
655 |
+
- `wildcard_opt` (Optional[Any], optional): The wildcard options. Defaults to None.
|
656 |
+
- `detailer_hook` (Optional[callable], optional): The detailer hook. Defaults to None.
|
657 |
+
- `refiner_ratio` (Optional[float], optional): The refiner ratio. Defaults to None.
|
658 |
+
- `refiner_model` (Optional[torch.nn.Module], optional): The refiner model. Defaults to None.
|
659 |
+
- `refiner_clip` (Optional[Any], optional): The refiner clip. Defaults to None.
|
660 |
+
- `refiner_positive` (Optional[Any], optional): The refiner positive prompt. Defaults to None.
|
661 |
+
- `refiner_negative` (Optional[Any], optional): The refiner negative prompt. Defaults to None.
|
662 |
+
- `cycle` (int, optional): The number of cycles. Defaults to 1.
|
663 |
+
- `inpaint_model` (bool, optional): Whether to use inpaint model. Defaults to False.
|
664 |
+
- `noise_mask_feather` (int, optional): The noise mask feather. Defaults to 0.
|
665 |
+
- `scheduler_func_opt` (Optional[callable], optional): The scheduler function. Defaults to None.
|
666 |
+
- `pipeline` (bool, optional): Whether to use pipeline. Defaults to False.
|
667 |
+
|
668 |
+
#### Returns:
|
669 |
+
- `Tuple[torch.Tensor, list, list, list, list, Tuple[torch.Tensor, list]]`: The detailed image tensor, cropped list, enhanced list, enhanced alpha list, cnet PIL list, and new segments.
|
670 |
+
"""
|
671 |
+
image = image.clone()
|
672 |
+
enhanced_alpha_list = []
|
673 |
+
enhanced_list = []
|
674 |
+
cropped_list = []
|
675 |
+
cnet_pil_list = []
|
676 |
+
|
677 |
+
segs = AD_util.segs_scale_match(segs, image.shape)
|
678 |
+
new_segs = []
|
679 |
+
|
680 |
+
wildcard_concat_mode = None
|
681 |
+
wmode, wildcard_chooser = bbox.process_wildcard_for_segs(wildcard_opt)
|
682 |
+
|
683 |
+
ordered_segs = segs[1]
|
684 |
+
|
685 |
+
if (
|
686 |
+
noise_mask_feather > 0
|
687 |
+
and "denoise_mask_function" not in model.model_options
|
688 |
+
):
|
689 |
+
model = DifferentialDiffusion().apply(model)[0]
|
690 |
+
|
691 |
+
for i, seg in enumerate(ordered_segs):
|
692 |
+
cropped_image = AD_util.crop_ndarray4(
|
693 |
+
image.cpu().numpy(), seg.crop_region
|
694 |
+
) # Never use seg.cropped_image to handle overlapping area
|
695 |
+
cropped_image = tensor_util.to_tensor(cropped_image)
|
696 |
+
mask = tensor_util.to_tensor(seg.cropped_mask)
|
697 |
+
mask = tensor_util.tensor_gaussian_blur_mask(mask, feather)
|
698 |
+
|
699 |
+
is_mask_all_zeros = (seg.cropped_mask == 0).all().item()
|
700 |
+
if is_mask_all_zeros:
|
701 |
+
print("Detailer: segment skip [empty mask]")
|
702 |
+
continue
|
703 |
+
|
704 |
+
cropped_mask = seg.cropped_mask
|
705 |
+
|
706 |
+
seg_seed, wildcard_item = wildcard_chooser.get(seg)
|
707 |
+
|
708 |
+
seg_seed = seed + i if seg_seed is None else seg_seed
|
709 |
+
|
710 |
+
cropped_positive = [
|
711 |
+
[
|
712 |
+
condition,
|
713 |
+
{
|
714 |
+
k: (
|
715 |
+
crop_condition_mask(v, image, seg.crop_region)
|
716 |
+
if k == "mask"
|
717 |
+
else v
|
718 |
+
)
|
719 |
+
for k, v in details.items()
|
720 |
+
},
|
721 |
+
]
|
722 |
+
for condition, details in positive
|
723 |
+
]
|
724 |
+
|
725 |
+
cropped_negative = [
|
726 |
+
[
|
727 |
+
condition,
|
728 |
+
{
|
729 |
+
k: (
|
730 |
+
crop_condition_mask(v, image, seg.crop_region)
|
731 |
+
if k == "mask"
|
732 |
+
else v
|
733 |
+
)
|
734 |
+
for k, v in details.items()
|
735 |
+
},
|
736 |
+
]
|
737 |
+
for condition, details in negative
|
738 |
+
]
|
739 |
+
|
740 |
+
orig_cropped_image = cropped_image.clone()
|
741 |
+
enhanced_image, cnet_pils = enhance_detail(
|
742 |
+
cropped_image,
|
743 |
+
model,
|
744 |
+
clip,
|
745 |
+
vae,
|
746 |
+
guide_size,
|
747 |
+
guide_size_for_bbox,
|
748 |
+
max_size,
|
749 |
+
seg.bbox,
|
750 |
+
seg_seed,
|
751 |
+
steps,
|
752 |
+
cfg,
|
753 |
+
sampler_name,
|
754 |
+
scheduler,
|
755 |
+
cropped_positive,
|
756 |
+
cropped_negative,
|
757 |
+
denoise,
|
758 |
+
cropped_mask,
|
759 |
+
force_inpaint,
|
760 |
+
wildcard_opt=wildcard_item,
|
761 |
+
wildcard_opt_concat_mode=wildcard_concat_mode,
|
762 |
+
detailer_hook=detailer_hook,
|
763 |
+
refiner_ratio=refiner_ratio,
|
764 |
+
refiner_model=refiner_model,
|
765 |
+
refiner_clip=refiner_clip,
|
766 |
+
refiner_positive=refiner_positive,
|
767 |
+
refiner_negative=refiner_negative,
|
768 |
+
control_net_wrapper=seg.control_net_wrapper,
|
769 |
+
cycle=cycle,
|
770 |
+
inpaint_model=inpaint_model,
|
771 |
+
noise_mask_feather=noise_mask_feather,
|
772 |
+
scheduler_func=scheduler_func_opt,
|
773 |
+
pipeline=pipeline,
|
774 |
+
)
|
775 |
+
|
776 |
+
if enhanced_image is not None:
|
777 |
+
# don't use latent compositing -> converting to latent caused poor quality
|
778 |
+
# use image paste
|
779 |
+
image = image.cpu()
|
780 |
+
enhanced_image = enhanced_image.cpu()
|
781 |
+
tensor_util.tensor_paste(
|
782 |
+
image,
|
783 |
+
enhanced_image,
|
784 |
+
(seg.crop_region[0], seg.crop_region[1]),
|
785 |
+
mask,
|
786 |
+
) # note: this paste also affects `cropped_image`.
|
787 |
+
enhanced_list.append(enhanced_image)
|
788 |
+
|
789 |
+
# Convert enhanced_pil_alpha to RGBA mode
|
790 |
+
enhanced_image_alpha = tensor_util.tensor_convert_rgba(enhanced_image)
|
791 |
+
new_seg_image = (
|
792 |
+
enhanced_image.numpy()
|
793 |
+
) # alpha should not be applied to seg_image
|
794 |
+
# Apply the mask
|
795 |
+
mask = tensor_util.tensor_resize(
|
796 |
+
mask, *tensor_util.tensor_get_size(enhanced_image)
|
797 |
+
)
|
798 |
+
tensor_util.tensor_putalpha(enhanced_image_alpha, mask)
|
799 |
+
enhanced_alpha_list.append(enhanced_image_alpha)
|
800 |
+
|
801 |
+
cropped_list.append(orig_cropped_image) # NOTE: Don't use `cropped_image`
|
802 |
+
|
803 |
+
new_seg = SEGS.SEG(
|
804 |
+
new_seg_image,
|
805 |
+
seg.cropped_mask,
|
806 |
+
seg.confidence,
|
807 |
+
seg.crop_region,
|
808 |
+
seg.bbox,
|
809 |
+
seg.label,
|
810 |
+
seg.control_net_wrapper,
|
811 |
+
)
|
812 |
+
new_segs.append(new_seg)
|
813 |
+
|
814 |
+
image_tensor = tensor_util.tensor_convert_rgb(image)
|
815 |
+
|
816 |
+
cropped_list.sort(key=lambda x: x.shape, reverse=True)
|
817 |
+
enhanced_list.sort(key=lambda x: x.shape, reverse=True)
|
818 |
+
enhanced_alpha_list.sort(key=lambda x: x.shape, reverse=True)
|
819 |
+
|
820 |
+
return (
|
821 |
+
image_tensor,
|
822 |
+
cropped_list,
|
823 |
+
enhanced_list,
|
824 |
+
enhanced_alpha_list,
|
825 |
+
cnet_pil_list,
|
826 |
+
(segs[0], new_segs),
|
827 |
+
)
|
828 |
+
|
829 |
+
|
830 |
+
def empty_pil_tensor(w: int = 64, h: int = 64) -> torch.Tensor:
|
831 |
+
"""#### Create an empty PIL tensor.
|
832 |
+
|
833 |
+
#### Args:
|
834 |
+
- `w` (int, optional): The width of the tensor. Defaults to 64.
|
835 |
+
- `h` (int, optional): The height of the tensor. Defaults to 64.
|
836 |
+
|
837 |
+
#### Returns:
|
838 |
+
- `torch.Tensor`: The empty tensor.
|
839 |
+
"""
|
840 |
+
return torch.zeros((1, h, w, 3), dtype=torch.float32)
|
841 |
+
|
842 |
+
|
843 |
+
class DetailerForEachTest(DetailerForEach):
|
844 |
+
"""#### Test class for DetailerForEach."""
|
845 |
+
|
846 |
+
def doit(
|
847 |
+
self,
|
848 |
+
image: torch.Tensor,
|
849 |
+
segs: Any,
|
850 |
+
model: torch.nn.Module,
|
851 |
+
clip: Any,
|
852 |
+
vae: VariationalAE.VAE,
|
853 |
+
guide_size: int,
|
854 |
+
guide_size_for: bool,
|
855 |
+
max_size: int,
|
856 |
+
seed: int,
|
857 |
+
steps: int,
|
858 |
+
cfg: Any,
|
859 |
+
sampler_name: str,
|
860 |
+
scheduler: str,
|
861 |
+
positive: Any,
|
862 |
+
negative: Any,
|
863 |
+
denoise: float,
|
864 |
+
feather: int,
|
865 |
+
noise_mask: Optional[torch.Tensor],
|
866 |
+
force_inpaint: bool,
|
867 |
+
wildcard: Optional[Any],
|
868 |
+
detailer_hook: Optional[callable] = None,
|
869 |
+
cycle: int = 1,
|
870 |
+
inpaint_model: bool = False,
|
871 |
+
noise_mask_feather: int = 0,
|
872 |
+
scheduler_func_opt: Optional[callable] = None,
|
873 |
+
pipeline: bool = False,
|
874 |
+
) -> Tuple[torch.Tensor, list, list, list, list]:
|
875 |
+
"""#### Perform detail enhancement for testing.
|
876 |
+
|
877 |
+
#### Args:
|
878 |
+
- `image` (torch.Tensor): The input image tensor.
|
879 |
+
- `segs` (Any): The segments.
|
880 |
+
- `model` (torch.nn.Module): The model.
|
881 |
+
- `clip` (Any): The clip model.
|
882 |
+
- `vae` (VariationalAE.VAE): The VAE model.
|
883 |
+
- `guide_size` (int): The guide size.
|
884 |
+
- `guide_size_for` (bool): Whether to apply the guide size to the bounding box (passed through as `guide_size_for_bbox`).
|
885 |
+
- `max_size` (int): The maximum size.
|
886 |
+
- `seed` (int): The seed for random noise.
|
887 |
+
- `steps` (int): The number of steps.
|
888 |
+
- `cfg` (Any): Classifier-Free Guidance Scale.
|
889 |
+
- `sampler_name` (str): The sampler name.
|
890 |
+
- `scheduler` (str): The scheduler name.
|
891 |
+
- `positive` (Any): The positive prompt.
|
892 |
+
- `negative` (Any): The negative prompt.
|
893 |
+
- `denoise` (float): The denoise factor.
|
894 |
+
- `feather` (int): The feather value.
|
895 |
+
- `noise_mask` (Optional[torch.Tensor]): The noise mask tensor.
|
896 |
+
- `force_inpaint` (bool): Whether to force inpaint.
|
897 |
+
- `wildcard` (Optional[Any]): The wildcard options.
|
898 |
+
- `detailer_hook` (Optional[callable], optional): The detailer hook. Defaults to None.
|
899 |
+
- `cycle` (int, optional): The number of cycles. Defaults to 1.
|
900 |
+
- `inpaint_model` (bool, optional): Whether to use inpaint model. Defaults to False.
|
901 |
+
- `noise_mask_feather` (int, optional): The noise mask feather. Defaults to 0.
|
902 |
+
- `scheduler_func_opt` (Optional[callable], optional): The scheduler function. Defaults to None.
|
903 |
+
- `pipeline` (bool, optional): Whether to use pipeline. Defaults to False.
|
904 |
+
|
905 |
+
#### Returns:
|
906 |
+
- `Tuple[torch.Tensor, list, list, list, list]`: The enhanced image tensor, cropped list, cropped enhanced list, cropped enhanced alpha list, and cnet PIL list.
|
907 |
+
"""
|
908 |
+
(
|
909 |
+
enhanced_img,
|
910 |
+
cropped,
|
911 |
+
cropped_enhanced,
|
912 |
+
cropped_enhanced_alpha,
|
913 |
+
cnet_pil_list,
|
914 |
+
new_segs,
|
915 |
+
) = DetailerForEach.do_detail(
|
916 |
+
image,
|
917 |
+
segs,
|
918 |
+
model,
|
919 |
+
clip,
|
920 |
+
vae,
|
921 |
+
guide_size,
|
922 |
+
guide_size_for,
|
923 |
+
max_size,
|
924 |
+
seed,
|
925 |
+
steps,
|
926 |
+
cfg,
|
927 |
+
sampler_name,
|
928 |
+
scheduler,
|
929 |
+
positive,
|
930 |
+
negative,
|
931 |
+
denoise,
|
932 |
+
feather,
|
933 |
+
noise_mask,
|
934 |
+
force_inpaint,
|
935 |
+
wildcard,
|
936 |
+
detailer_hook,
|
937 |
+
cycle=cycle,
|
938 |
+
inpaint_model=inpaint_model,
|
939 |
+
noise_mask_feather=noise_mask_feather,
|
940 |
+
scheduler_func_opt=scheduler_func_opt,
|
941 |
+
pipeline=pipeline,
|
942 |
+
)
|
943 |
+
|
944 |
+
cnet_pil_list = [empty_pil_tensor()]
|
945 |
+
|
946 |
+
return (
|
947 |
+
enhanced_img,
|
948 |
+
cropped,
|
949 |
+
cropped_enhanced,
|
950 |
+
cropped_enhanced_alpha,
|
951 |
+
cnet_pil_list,
|
952 |
+
)
|
modules/AutoDetailer/SAM.py
ADDED
@@ -0,0 +1,300 @@
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
from segment_anything import SamPredictor, sam_model_registry
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from modules.AutoDetailer import mask_util
|
7 |
+
from modules.Device import Device
|
8 |
+
|
9 |
+
|
10 |
+
def sam_predict(
|
11 |
+
predictor: SamPredictor, points: list, plabs: list, bbox: list, threshold: float
|
12 |
+
) -> list:
|
13 |
+
"""#### Predict masks using SAM.
|
14 |
+
|
15 |
+
#### Args:
|
16 |
+
- `predictor` (SamPredictor): The SAM predictor.
|
17 |
+
- `points` (list): List of points.
|
18 |
+
- `plabs` (list): List of point labels.
|
19 |
+
- `bbox` (list): Bounding box.
|
20 |
+
- `threshold` (float): Threshold for mask selection.
|
21 |
+
|
22 |
+
#### Returns:
|
23 |
+
- `list`: List of predicted masks.
|
24 |
+
"""
|
25 |
+
point_coords = None if not points else np.array(points)
|
26 |
+
point_labels = None if not plabs else np.array(plabs)
|
27 |
+
|
28 |
+
box = np.array([bbox]) if bbox is not None else None
|
29 |
+
|
30 |
+
cur_masks, scores, _ = predictor.predict(
|
31 |
+
point_coords=point_coords, point_labels=point_labels, box=box
|
32 |
+
)
|
33 |
+
|
34 |
+
total_masks = []
|
35 |
+
|
36 |
+
selected = False
|
37 |
+
max_score = 0
|
38 |
+
max_mask = None
|
39 |
+
for idx in range(len(scores)):
|
40 |
+
if scores[idx] > max_score:
|
41 |
+
max_score = scores[idx]
|
42 |
+
max_mask = cur_masks[idx]
|
43 |
+
|
44 |
+
if scores[idx] >= threshold:
|
45 |
+
selected = True
|
46 |
+
total_masks.append(cur_masks[idx])
|
47 |
+
else:
|
48 |
+
pass
|
49 |
+
|
50 |
+
if not selected and max_mask is not None:
|
51 |
+
total_masks.append(max_mask)
|
52 |
+
|
53 |
+
return total_masks
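# NOTE (illustrative sketch, not part of the uploaded file): the same
# score-threshold selection as sam_predict above, on dummy data, to make the
# fallback to the single best-scoring mask explicit. Names are hypothetical.
import numpy as np

def select_masks(masks: list, scores: list, threshold: float) -> list:
    selected = [m for m, s in zip(masks, scores) if s >= threshold]
    if not selected and masks:
        selected = [masks[int(np.argmax(scores))]]  # keep the best mask anyway
    return selected

dummy_masks = [np.zeros((4, 4)), np.ones((4, 4))]
assert len(select_masks(dummy_masks, [0.2, 0.4], threshold=0.9)) == 1   # fallback case
assert len(select_masks(dummy_masks, [0.95, 0.4], threshold=0.9)) == 1  # above threshold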
|
54 |
+
|
55 |
+
|
56 |
+
def is_same_device(a: torch.device, b: torch.device) -> bool:
|
57 |
+
"""#### Check if two devices are the same.
|
58 |
+
|
59 |
+
#### Args:
|
60 |
+
- `a` (torch.device): The first device.
|
61 |
+
- `b` (torch.device): The second device.
|
62 |
+
|
63 |
+
#### Returns:
|
64 |
+
- `bool`: Whether the devices are the same.
|
65 |
+
"""
|
66 |
+
a_device = torch.device(a) if isinstance(a, str) else a
|
67 |
+
b_device = torch.device(b) if isinstance(b, str) else b
|
68 |
+
return a_device.type == b_device.type and a_device.index == b_device.index
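# NOTE (illustrative usage, not part of the uploaded file): is_same_device
# normalizes strings and torch.device objects before comparing type and index.
assert is_same_device("cpu", torch.device("cpu"))
assert not is_same_device("cpu", torch.device("cuda", 0))  # different device type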
|
69 |
+
|
70 |
+
|
71 |
+
class SafeToGPU:
|
72 |
+
"""#### Class to safely move objects to GPU."""
|
73 |
+
|
74 |
+
def __init__(self, size: int):
|
75 |
+
self.size = size
|
76 |
+
|
77 |
+
def to_device(self, obj: torch.nn.Module, device: torch.device) -> None:
|
78 |
+
"""#### Move an object to a device.
|
79 |
+
|
80 |
+
#### Args:
|
81 |
+
- `obj` (torch.nn.Module): The object to move.
|
82 |
+
- `device` (torch.device): The target device.
|
83 |
+
"""
|
84 |
+
if is_same_device(device, "cpu"):
|
85 |
+
obj.to(device)
|
86 |
+
else:
|
87 |
+
if is_same_device(obj.device, "cpu"): # cpu to gpu
|
88 |
+
Device.free_memory(self.size * 1.3, device)
|
89 |
+
if Device.get_free_memory(device) > self.size * 1.3:
|
90 |
+
try:
|
91 |
+
obj.to(device)
|
92 |
+
except Exception:
|
93 |
+
print(
|
94 |
+
f"WARN: The model is not moved to the '{device}' due to insufficient memory. [1]"
|
95 |
+
)
|
96 |
+
else:
|
97 |
+
print(
|
98 |
+
f"WARN: The model is not moved to the '{device}' due to insufficient memory. [2]"
|
99 |
+
)
|
100 |
+
|
101 |
+
|
102 |
+
class SAMWrapper:
|
103 |
+
"""#### Wrapper class for SAM model."""
|
104 |
+
|
105 |
+
def __init__(
|
106 |
+
self, model: torch.nn.Module, is_auto_mode: bool, safe_to_gpu: SafeToGPU = None
|
107 |
+
):
|
108 |
+
self.model = model
|
109 |
+
self.safe_to_gpu = safe_to_gpu if safe_to_gpu is not None else SafeToGPU()
|
110 |
+
self.is_auto_mode = is_auto_mode
|
111 |
+
|
112 |
+
def prepare_device(self) -> None:
|
113 |
+
"""#### Prepare the device for the model."""
|
114 |
+
if self.is_auto_mode:
|
115 |
+
device = Device.get_torch_device()
|
116 |
+
self.safe_to_gpu.to_device(self.model, device=device)
|
117 |
+
|
118 |
+
def release_device(self) -> None:
|
119 |
+
"""#### Release the device from the model."""
|
120 |
+
if self.is_auto_mode:
|
121 |
+
self.model.to(device="cpu")
|
122 |
+
|
123 |
+
def predict(
|
124 |
+
self, image: np.ndarray, points: list, plabs: list, bbox: list, threshold: float
|
125 |
+
) -> list:
|
126 |
+
"""#### Predict masks using the SAM model.
|
127 |
+
|
128 |
+
#### Args:
|
129 |
+
- `image` (np.ndarray): The input image.
|
130 |
+
- `points` (list): List of points.
|
131 |
+
- `plabs` (list): List of point labels.
|
132 |
+
- `bbox` (list): Bounding box.
|
133 |
+
- `threshold` (float): Threshold for mask selection.
|
134 |
+
|
135 |
+
#### Returns:
|
136 |
+
- `list`: List of predicted masks.
|
137 |
+
"""
|
138 |
+
predictor = SamPredictor(self.model)
|
139 |
+
predictor.set_image(image, "RGB")
|
140 |
+
|
141 |
+
return sam_predict(predictor, points, plabs, bbox, threshold)
|
142 |
+
|
143 |
+
|
144 |
+
class SAMLoader:
|
145 |
+
"""#### Class to load SAM models."""
|
146 |
+
|
147 |
+
def load_model(self, model_name: str, device_mode: str = "auto") -> tuple:
|
148 |
+
"""#### Load a SAM model.
|
149 |
+
|
150 |
+
#### Args:
|
151 |
+
- `model_name` (str): The name of the model.
|
152 |
+
- `device_mode` (str, optional): The device mode. Defaults to "auto".
|
153 |
+
|
154 |
+
#### Returns:
|
155 |
+
- `tuple`: The loaded SAM model.
|
156 |
+
"""
|
157 |
+
modelname = "./_internal/yolos/" + model_name
|
158 |
+
|
159 |
+
if "vit_h" in model_name:
|
160 |
+
model_kind = "vit_h"
|
161 |
+
elif "vit_l" in model_name:
|
162 |
+
model_kind = "vit_l"
|
163 |
+
else:
|
164 |
+
model_kind = "vit_b"
|
165 |
+
|
166 |
+
sam = sam_model_registry[model_kind](checkpoint=modelname)
|
167 |
+
size = os.path.getsize(modelname)
|
168 |
+
safe_to = SafeToGPU(size)
|
169 |
+
|
170 |
+
# Unless user explicitly wants to use CPU, we use GPU
|
171 |
+
device = Device.get_torch_device() if device_mode == "Prefer GPU" else "CPU"
|
172 |
+
|
173 |
+
if device_mode == "Prefer GPU":
|
174 |
+
safe_to.to_device(sam, device)
|
175 |
+
|
176 |
+
is_auto_mode = device_mode == "AUTO"
|
177 |
+
|
178 |
+
sam_obj = SAMWrapper(sam, is_auto_mode=is_auto_mode, safe_to_gpu=safe_to)
|
179 |
+
sam.sam_wrapper = sam_obj
|
180 |
+
|
181 |
+
print(f"Loaded SAM model: {modelname} (device: {device_mode})")
|
182 |
+
return (sam,)
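# NOTE (illustrative sketch, not part of the uploaded file): only the
# checkpoint-name -> SAM variant mapping used by load_model above, isolated so
# it can be checked without downloading any weights. Filenames are examples.
def infer_sam_kind(model_name: str) -> str:
    if "vit_h" in model_name:
        return "vit_h"
    elif "vit_l" in model_name:
        return "vit_l"
    return "vit_b"

assert infer_sam_kind("sam_vit_l_0b3195.pth") == "vit_l"
assert infer_sam_kind("mobile_sam.pt") == "vit_b"  # anything else falls back to vit_b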
|
183 |
+
|
184 |
+
|
185 |
+
def make_sam_mask(
|
186 |
+
sam: SAMWrapper,
|
187 |
+
segs: tuple,
|
188 |
+
image: torch.Tensor,
|
189 |
+
detection_hint: bool,
|
190 |
+
dilation: int,
|
191 |
+
threshold: float,
|
192 |
+
bbox_expansion: int,
|
193 |
+
mask_hint_threshold: float,
|
194 |
+
mask_hint_use_negative: bool,
|
195 |
+
) -> torch.Tensor:
|
196 |
+
"""#### Create a SAM mask.
|
197 |
+
|
198 |
+
#### Args:
|
199 |
+
- `sam` (SAMWrapper): The SAM wrapper.
|
200 |
+
- `segs` (tuple): Segmentation information.
|
201 |
+
- `image` (torch.Tensor): The input image.
|
202 |
+
- `detection_hint` (bool): Whether to use detection hint.
|
203 |
+
- `dilation` (int): Dilation value.
|
204 |
+
- `threshold` (float): Threshold for mask selection.
|
205 |
+
- `bbox_expansion` (int): Bounding box expansion value.
|
206 |
+
- `mask_hint_threshold` (float): Mask hint threshold.
|
207 |
+
- `mask_hint_use_negative` (bool): Whether to use negative mask hint.
|
208 |
+
|
209 |
+
#### Returns:
|
210 |
+
- `torch.Tensor`: The created SAM mask.
|
211 |
+
"""
|
212 |
+
sam_obj = sam.sam_wrapper
|
213 |
+
sam_obj.prepare_device()
|
214 |
+
|
215 |
+
try:
|
216 |
+
image = np.clip(255.0 * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)
|
217 |
+
|
218 |
+
total_masks = []
|
219 |
+
# seg_shape = segs[0]
|
220 |
+
segs = segs[1]
|
221 |
+
for i in range(len(segs)):
|
222 |
+
bbox = segs[i].bbox
|
223 |
+
center = mask_util.center_of_bbox(bbox)
|
224 |
+
x1 = max(bbox[0] - bbox_expansion, 0)
|
225 |
+
y1 = max(bbox[1] - bbox_expansion, 0)
|
226 |
+
x2 = min(bbox[2] + bbox_expansion, image.shape[1])
|
227 |
+
y2 = min(bbox[3] + bbox_expansion, image.shape[0])
|
228 |
+
dilated_bbox = [x1, y1, x2, y2]
|
229 |
+
points = []
|
230 |
+
plabs = []
|
231 |
+
points.append(center)
|
232 |
+
plabs = [1] # 1 = foreground point, 0 = background point
|
233 |
+
detected_masks = sam_obj.predict(
|
234 |
+
image, points, plabs, dilated_bbox, threshold
|
235 |
+
)
|
236 |
+
total_masks += detected_masks
|
237 |
+
|
238 |
+
# merge all collected masks
|
239 |
+
mask = mask_util.combine_masks2(total_masks)
|
240 |
+
|
241 |
+
finally:
|
242 |
+
sam_obj.release_device()
|
243 |
+
|
244 |
+
if mask is not None:
|
245 |
+
mask = mask.float()
|
246 |
+
mask = mask_util.dilate_mask(mask.cpu().numpy(), dilation)
|
247 |
+
mask = torch.from_numpy(mask)
|
248 |
+
|
249 |
+
mask = mask_util.make_3d_mask(mask)
|
250 |
+
return mask
|
251 |
+
else:
|
252 |
+
return None
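# NOTE (illustrative sketch, not part of the uploaded file): the bbox dilation
# and clipping step from make_sam_mask above, isolated. `image_hw` is (H, W);
# the function name is hypothetical.
def expand_bbox(bbox: list, bbox_expansion: int, image_hw: tuple) -> list:
    h, w = image_hw
    x1 = max(bbox[0] - bbox_expansion, 0)
    y1 = max(bbox[1] - bbox_expansion, 0)
    x2 = min(bbox[2] + bbox_expansion, w)
    y2 = min(bbox[3] + bbox_expansion, h)
    return [x1, y1, x2, y2]

assert expand_bbox([10, 10, 50, 60], 16, (64, 128)) == [0, 0, 66, 64]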
|
253 |
+
|
254 |
+
|
255 |
+
class SAMDetectorCombined:
|
256 |
+
"""#### Class to combine SAM detection."""
|
257 |
+
|
258 |
+
def doit(
|
259 |
+
self,
|
260 |
+
sam_model: SAMWrapper,
|
261 |
+
segs: tuple,
|
262 |
+
image: torch.Tensor,
|
263 |
+
detection_hint: bool,
|
264 |
+
dilation: int,
|
265 |
+
threshold: float,
|
266 |
+
bbox_expansion: int,
|
267 |
+
mask_hint_threshold: float,
|
268 |
+
mask_hint_use_negative: bool,
|
269 |
+
) -> tuple:
|
270 |
+
"""#### Combine SAM detection.
|
271 |
+
|
272 |
+
#### Args:
|
273 |
+
- `sam_model` (SAMWrapper): The SAM wrapper.
|
274 |
+
- `segs` (tuple): Segmentation information.
|
275 |
+
- `image` (torch.Tensor): The input image.
|
276 |
+
- `detection_hint` (bool): Whether to use detection hint.
|
277 |
+
- `dilation` (int): Dilation value.
|
278 |
+
- `threshold` (float): Threshold for mask selection.
|
279 |
+
- `bbox_expansion` (int): Bounding box expansion value.
|
280 |
+
- `mask_hint_threshold` (float): Mask hint threshold.
|
281 |
+
- `mask_hint_use_negative` (bool): Whether to use negative mask hint.
|
282 |
+
|
283 |
+
#### Returns:
|
284 |
+
- `tuple`: The combined SAM detection result.
|
285 |
+
"""
|
286 |
+
sam = make_sam_mask(
|
287 |
+
sam_model,
|
288 |
+
segs,
|
289 |
+
image,
|
290 |
+
detection_hint,
|
291 |
+
dilation,
|
292 |
+
threshold,
|
293 |
+
bbox_expansion,
|
294 |
+
mask_hint_threshold,
|
295 |
+
mask_hint_use_negative,
|
296 |
+
)
|
297 |
+
if sam is not None:
|
298 |
+
return (sam,)
|
299 |
+
else:
|
300 |
+
return None
|
modules/AutoDetailer/SEGS.py
ADDED
@@ -0,0 +1,95 @@
1 |
+
from collections import namedtuple
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from modules.AutoDetailer import mask_util
|
5 |
+
|
6 |
+
SEG = namedtuple(
|
7 |
+
"SEG",
|
8 |
+
[
|
9 |
+
"cropped_image",
|
10 |
+
"cropped_mask",
|
11 |
+
"confidence",
|
12 |
+
"crop_region",
|
13 |
+
"bbox",
|
14 |
+
"label",
|
15 |
+
"control_net_wrapper",
|
16 |
+
],
|
17 |
+
defaults=[None],
|
18 |
+
)
|
19 |
+
|
20 |
+
|
21 |
+
def segs_bitwise_and_mask(segs: tuple, mask: torch.Tensor) -> tuple:
|
22 |
+
"""#### Apply bitwise AND operation between segmentation masks and a given mask.
|
23 |
+
|
24 |
+
#### Args:
|
25 |
+
- `segs` (tuple): A tuple containing segmentation information.
|
26 |
+
- `mask` (torch.Tensor): The mask tensor.
|
27 |
+
|
28 |
+
#### Returns:
|
29 |
+
- `tuple`: A tuple containing the original segmentation and the updated items.
|
30 |
+
"""
|
31 |
+
mask = mask_util.make_2d_mask(mask)
|
32 |
+
items = []
|
33 |
+
|
34 |
+
mask = (mask.cpu().numpy() * 255).astype(np.uint8)
|
35 |
+
|
36 |
+
for seg in segs[1]:
|
37 |
+
cropped_mask = (seg.cropped_mask * 255).astype(np.uint8)
|
38 |
+
crop_region = seg.crop_region
|
39 |
+
|
40 |
+
cropped_mask2 = mask[
|
41 |
+
crop_region[1] : crop_region[3], crop_region[0] : crop_region[2]
|
42 |
+
]
|
43 |
+
|
44 |
+
new_mask = np.bitwise_and(cropped_mask.astype(np.uint8), cropped_mask2)
|
45 |
+
new_mask = new_mask.astype(np.float32) / 255.0
|
46 |
+
|
47 |
+
item = SEG(
|
48 |
+
seg.cropped_image,
|
49 |
+
new_mask,
|
50 |
+
seg.confidence,
|
51 |
+
seg.crop_region,
|
52 |
+
seg.bbox,
|
53 |
+
seg.label,
|
54 |
+
None,
|
55 |
+
)
|
56 |
+
items.append(item)
|
57 |
+
|
58 |
+
return segs[0], items
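# NOTE (illustrative sketch, not part of the uploaded file): the per-segment
# combination rule used above, on tiny arrays - the segment mask is ANDed with
# the matching crop of the global mask, then rescaled back to [0, 1].
import numpy as np

seg_mask = np.array([[1.0, 1.0], [0.0, 1.0]])             # a SEG.cropped_mask
global_crop = np.array([[1, 0], [1, 1]], dtype=np.uint8)  # global mask, already cropped
combined = np.bitwise_and((seg_mask * 255).astype(np.uint8), global_crop * 255)
combined = combined.astype(np.float32) / 255.0
assert combined.tolist() == [[1.0, 0.0], [0.0, 1.0]]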
|
59 |
+
|
60 |
+
|
61 |
+
class SegsBitwiseAndMask:
|
62 |
+
"""#### Class to apply bitwise AND operation between segmentation masks and a given mask."""
|
63 |
+
|
64 |
+
def doit(self, segs: tuple, mask: torch.Tensor) -> tuple:
|
65 |
+
"""#### Apply bitwise AND operation between segmentation masks and a given mask.
|
66 |
+
|
67 |
+
#### Args:
|
68 |
+
- `segs` (tuple): A tuple containing segmentation information.
|
69 |
+
- `mask` (torch.Tensor): The mask tensor.
|
70 |
+
|
71 |
+
#### Returns:
|
72 |
+
- `tuple`: A tuple containing the original segmentation and the updated items.
|
73 |
+
"""
|
74 |
+
return (segs_bitwise_and_mask(segs, mask),)
|
75 |
+
|
76 |
+
|
77 |
+
class SEGSLabelFilter:
|
78 |
+
"""#### Class to filter segmentation labels."""
|
79 |
+
|
80 |
+
@staticmethod
|
81 |
+
def filter(segs: tuple, labels: list) -> tuple:
|
82 |
+
"""#### Filter segmentation labels.
|
83 |
+
|
84 |
+
#### Args:
|
85 |
+
- `segs` (tuple): A tuple containing segmentation information.
|
86 |
+
- `labels` (list): A list of labels to filter.
|
87 |
+
|
88 |
+
#### Returns:
|
89 |
+
- `tuple`: A tuple containing the original segmentation and an empty list.
|
90 |
+
"""
|
91 |
+
labels = set([label.strip() for label in labels])
|
92 |
+
return (
|
93 |
+
segs,
|
94 |
+
(segs[0], []),
|
95 |
+
)
|
modules/AutoDetailer/bbox.py
ADDED
@@ -0,0 +1,203 @@
1 |
+
import torch
|
2 |
+
from ultralytics import YOLO
|
3 |
+
from modules.AutoDetailer import SEGS, AD_util, tensor_util
|
4 |
+
from typing import List, Tuple, Optional
|
5 |
+
|
6 |
+
|
7 |
+
class UltraBBoxDetector:
|
8 |
+
"""#### Class to detect bounding boxes using a YOLO model."""
|
9 |
+
|
10 |
+
bbox_model: Optional[YOLO] = None
|
11 |
+
|
12 |
+
def __init__(self, bbox_model: YOLO):
|
13 |
+
"""#### Initialize the UltraBBoxDetector with a YOLO model.
|
14 |
+
|
15 |
+
#### Args:
|
16 |
+
- `bbox_model` (YOLO): The YOLO model to use for detection.
|
17 |
+
"""
|
18 |
+
self.bbox_model = bbox_model
|
19 |
+
|
20 |
+
def detect(
|
21 |
+
self,
|
22 |
+
image: torch.Tensor,
|
23 |
+
threshold: float,
|
24 |
+
dilation: int,
|
25 |
+
crop_factor: float,
|
26 |
+
drop_size: int = 1,
|
27 |
+
detailer_hook: Optional[callable] = None,
|
28 |
+
) -> Tuple[Tuple[int, int], List[SEGS.SEG]]:
|
29 |
+
"""#### Detect bounding boxes in an image.
|
30 |
+
|
31 |
+
#### Args:
|
32 |
+
- `image` (torch.Tensor): The input image tensor.
|
33 |
+
- `threshold` (float): The detection threshold.
|
34 |
+
- `dilation` (int): The dilation factor for masks.
|
35 |
+
- `crop_factor` (float): The crop factor for bounding boxes.
|
36 |
+
- `drop_size` (int, optional): The minimum size of bounding boxes to keep. Defaults to 1.
|
37 |
+
- `detailer_hook` (callable, optional): A hook function for additional processing. Defaults to None.
|
38 |
+
|
39 |
+
#### Returns:
|
40 |
+
- `Tuple[Tuple[int, int], List[SEGS.SEG]]`: The shape of the image and a list of detected segments.
|
41 |
+
"""
|
42 |
+
drop_size = max(drop_size, 1)
|
43 |
+
detected_results = AD_util.inference_bbox(
|
44 |
+
self.bbox_model, tensor_util.tensor2pil(image), threshold
|
45 |
+
)
|
46 |
+
segmasks = AD_util.create_segmasks(detected_results)
|
47 |
+
|
48 |
+
if dilation > 0:
|
49 |
+
segmasks = AD_util.dilate_masks(segmasks, dilation)
|
50 |
+
|
51 |
+
items = []
|
52 |
+
h = image.shape[1]
|
53 |
+
w = image.shape[2]
|
54 |
+
|
55 |
+
for x, label in zip(segmasks, detected_results[0]):
|
56 |
+
item_bbox = x[0]
|
57 |
+
item_mask = x[1]
|
58 |
+
|
59 |
+
y1, x1, y2, x2 = item_bbox
|
60 |
+
|
61 |
+
if (
|
62 |
+
x2 - x1 > drop_size and y2 - y1 > drop_size
|
63 |
+
): # minimum dimension must be (2,2) to avoid squeeze issue
|
64 |
+
crop_region = AD_util.make_crop_region(w, h, item_bbox, crop_factor)
|
65 |
+
|
66 |
+
cropped_image = AD_util.crop_image(image, crop_region)
|
67 |
+
cropped_mask = AD_util.crop_ndarray2(item_mask, crop_region)
|
68 |
+
confidence = x[2]
|
69 |
+
|
70 |
+
item = SEGS.SEG(
|
71 |
+
cropped_image,
|
72 |
+
cropped_mask,
|
73 |
+
confidence,
|
74 |
+
crop_region,
|
75 |
+
item_bbox,
|
76 |
+
label,
|
77 |
+
None,
|
78 |
+
)
|
79 |
+
|
80 |
+
items.append(item)
|
81 |
+
|
82 |
+
shape = image.shape[1], image.shape[2]
|
83 |
+
segs = shape, items
|
84 |
+
|
85 |
+
return segs
|
86 |
+
|
87 |
+
|
88 |
+
class UltraSegmDetector:
|
89 |
+
"""#### Class to detect segments using a YOLO model."""
|
90 |
+
|
91 |
+
bbox_model: Optional[YOLO] = None
|
92 |
+
|
93 |
+
def __init__(self, bbox_model: YOLO):
|
94 |
+
"""#### Initialize the UltraSegmDetector with a YOLO model.
|
95 |
+
|
96 |
+
#### Args:
|
97 |
+
- `bbox_model` (YOLO): The YOLO model to use for detection.
|
98 |
+
"""
|
99 |
+
self.bbox_model = bbox_model
|
100 |
+
|
101 |
+
|
102 |
+
class NO_SEGM_DETECTOR:
|
103 |
+
"""#### Placeholder class for no segment detector."""
|
104 |
+
|
105 |
+
pass
|
106 |
+
|
107 |
+
|
108 |
+
class UltralyticsDetectorProvider:
|
109 |
+
"""#### Class to provide YOLO models for detection."""
|
110 |
+
|
111 |
+
def doit(self, model_name: str) -> Tuple[UltraBBoxDetector, UltraSegmDetector]:
|
112 |
+
"""#### Load a YOLO model and return detectors.
|
113 |
+
|
114 |
+
#### Args:
|
115 |
+
- `model_name` (str): The name of the YOLO model to load.
|
116 |
+
|
117 |
+
#### Returns:
|
118 |
+
- `Tuple[UltraBBoxDetector, UltraSegmDetector]`: The bounding box and segment detectors.
|
119 |
+
"""
|
120 |
+
model = AD_util.load_yolo("./_internal/yolos/" + model_name)
|
121 |
+
return UltraBBoxDetector(model), UltraSegmDetector(model)
|
122 |
+
|
123 |
+
|
124 |
+
class BboxDetectorForEach:
|
125 |
+
"""#### Class to detect bounding boxes for each segment."""
|
126 |
+
|
127 |
+
def doit(
|
128 |
+
self,
|
129 |
+
bbox_detector: UltraBBoxDetector,
|
130 |
+
image: torch.Tensor,
|
131 |
+
threshold: float,
|
132 |
+
dilation: int,
|
133 |
+
crop_factor: float,
|
134 |
+
drop_size: int,
|
135 |
+
labels: Optional[str] = None,
|
136 |
+
detailer_hook: Optional[callable] = None,
|
137 |
+
) -> Tuple[Tuple[int, int], List[SEGS.SEG]]:
|
138 |
+
"""#### Detect bounding boxes for each segment in an image.
|
139 |
+
|
140 |
+
#### Args:
|
141 |
+
- `bbox_detector` (UltraBBoxDetector): The bounding box detector.
|
142 |
+
- `image` (torch.Tensor): The input image tensor.
|
143 |
+
- `threshold` (float): The detection threshold.
|
144 |
+
- `dilation` (int): The dilation factor for masks.
|
145 |
+
- `crop_factor` (float): The crop factor for bounding boxes.
|
146 |
+
- `drop_size` (int): The minimum size of bounding boxes to keep.
|
147 |
+
- `labels` (str, optional): The labels to filter. Defaults to None.
|
148 |
+
- `detailer_hook` (callable, optional): A hook function for additional processing. Defaults to None.
|
149 |
+
|
150 |
+
#### Returns:
|
151 |
+
- `Tuple[Tuple[int, int], List[SEGS.SEG]]`: The shape of the image and a list of detected segments.
|
152 |
+
"""
|
153 |
+
segs = bbox_detector.detect(
|
154 |
+
image, threshold, dilation, crop_factor, drop_size, detailer_hook
|
155 |
+
)
|
156 |
+
|
157 |
+
if labels is not None and labels != "":
|
158 |
+
labels = labels.split(",")
|
159 |
+
if len(labels) > 0:
|
160 |
+
segs, _ = SEGS.SEGSLabelFilter.filter(segs, labels)
|
161 |
+
|
162 |
+
return segs
|
163 |
+
|
164 |
+
|
165 |
+
class WildcardChooser:
|
166 |
+
"""#### Class to choose wildcards for segments."""
|
167 |
+
|
168 |
+
def __init__(self, items: List[Tuple[None, str]], randomize_when_exhaust: bool):
|
169 |
+
"""#### Initialize the WildcardChooser.
|
170 |
+
|
171 |
+
#### Args:
|
172 |
+
- `items` (List[Tuple[None, str]]): The list of items to choose from.
|
173 |
+
- `randomize_when_exhaust` (bool): Whether to randomize when the list is exhausted.
|
174 |
+
"""
|
175 |
+
self.i = 0
|
176 |
+
self.items = items
|
177 |
+
self.randomize_when_exhaust = randomize_when_exhaust
|
178 |
+
|
179 |
+
def get(self, seg: SEGS.SEG) -> Tuple[None, str]:
|
180 |
+
"""#### Get the next item from the list.
|
181 |
+
|
182 |
+
#### Args:
|
183 |
+
- `seg` (SEGS.SEG): The segment.
|
184 |
+
|
185 |
+
#### Returns:
|
186 |
+
- `Tuple[None, str]`: The next item from the list.
|
187 |
+
"""
|
188 |
+
item = self.items[self.i]
|
189 |
+
self.i += 1
|
190 |
+
|
191 |
+
return item
|
192 |
+
|
193 |
+
|
194 |
+
def process_wildcard_for_segs(wildcard: str) -> Tuple[None, WildcardChooser]:
|
195 |
+
"""#### Process a wildcard for segments.
|
196 |
+
|
197 |
+
#### Args:
|
198 |
+
- `wildcard` (str): The wildcard.
|
199 |
+
|
200 |
+
#### Returns:
|
201 |
+
- `Tuple[None, WildcardChooser]`: The processed wildcard and a WildcardChooser.
|
202 |
+
"""
|
203 |
+
return None, WildcardChooser([(None, wildcard)], False)
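# NOTE (illustrative sketch, not part of the uploaded file): the drop_size
# filter from UltraBBoxDetector.detect above. Detections are unpacked as
# (y1, x1, y2, x2), so both width and height must exceed drop_size to be kept.
def keep_detection(item_bbox: tuple, drop_size: int = 1) -> bool:
    y1, x1, y2, x2 = item_bbox
    return (x2 - x1) > drop_size and (y2 - y1) > drop_size

assert keep_detection((0, 0, 40, 30), drop_size=8)
assert not keep_detection((0, 0, 5, 30), drop_size=8)  # too short vertically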
|
modules/AutoDetailer/mask_util.py
ADDED
@@ -0,0 +1,80 @@
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
def center_of_bbox(bbox: list) -> tuple[float, float]:
|
6 |
+
"""#### Calculate the center of a bounding box.
|
7 |
+
|
8 |
+
#### Args:
|
9 |
+
- `bbox` (list): The bounding box coordinates [x1, y1, x2, y2].
|
10 |
+
|
11 |
+
#### Returns:
|
12 |
+
- `tuple[float, float]`: The center coordinates (x, y).
|
13 |
+
"""
|
14 |
+
w, h = bbox[2] - bbox[0], bbox[3] - bbox[1]
|
15 |
+
return bbox[0] + w / 2, bbox[1] + h / 2
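# NOTE (illustrative usage, not part of the uploaded file): bboxes follow the
# [x1, y1, x2, y2] convention, so the center is just the box midpoint.
assert center_of_bbox([10, 20, 30, 60]) == (20.0, 40.0)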
|
16 |
+
|
17 |
+
|
18 |
+
def make_2d_mask(mask: torch.Tensor) -> torch.Tensor:
|
19 |
+
"""#### Convert a mask to 2D.
|
20 |
+
|
21 |
+
#### Args:
|
22 |
+
- `mask` (torch.Tensor): The input mask tensor.
|
23 |
+
|
24 |
+
#### Returns:
|
25 |
+
- `torch.Tensor`: The 2D mask tensor.
|
26 |
+
"""
|
27 |
+
if len(mask.shape) == 4:
|
28 |
+
return mask.squeeze(0).squeeze(0)
|
29 |
+
elif len(mask.shape) == 3:
|
30 |
+
return mask.squeeze(0)
|
31 |
+
return mask
|
32 |
+
|
33 |
+
|
34 |
+
def combine_masks2(masks: list) -> torch.Tensor | None:
|
35 |
+
"""#### Combine multiple masks into one.
|
36 |
+
|
37 |
+
#### Args:
|
38 |
+
- `masks` (list): A list of mask tensors.
|
39 |
+
|
40 |
+
#### Returns:
|
41 |
+
- `torch.Tensor | None`: The combined mask tensor or None if no masks are provided.
|
42 |
+
"""
|
43 |
+
try:
|
44 |
+
mask = torch.from_numpy(np.array(masks[0]).astype(np.uint8))
|
45 |
+
except Exception:
|
46 |
+
print("No Human Detected")
|
47 |
+
return None
|
48 |
+
return mask
|
49 |
+
|
50 |
+
|
51 |
+
def dilate_mask(
|
52 |
+
mask: torch.Tensor, dilation_factor: int, iter: int = 1
|
53 |
+
) -> torch.Tensor:
|
54 |
+
"""#### Dilate a mask.
|
55 |
+
|
56 |
+
#### Args:
|
57 |
+
- `mask` (torch.Tensor): The input mask tensor.
|
58 |
+
- `dilation_factor` (int): The dilation factor.
|
59 |
+
- `iter` (int, optional): The number of iterations. Defaults to 1.
|
60 |
+
|
61 |
+
#### Returns:
|
62 |
+
- `torch.Tensor`: The dilated mask tensor.
|
63 |
+
"""
|
64 |
+
return make_2d_mask(mask)
|
65 |
+
|
66 |
+
|
67 |
+
def make_3d_mask(mask: torch.Tensor) -> torch.Tensor:
|
68 |
+
"""#### Convert a mask to 3D.
|
69 |
+
|
70 |
+
#### Args:
|
71 |
+
- `mask` (torch.Tensor): The input mask tensor.
|
72 |
+
|
73 |
+
#### Returns:
|
74 |
+
- `torch.Tensor`: The 3D mask tensor.
|
75 |
+
"""
|
76 |
+
if len(mask.shape) == 4:
|
77 |
+
return mask.squeeze(0)
|
78 |
+
elif len(mask.shape) == 2:
|
79 |
+
return mask.unsqueeze(0)
|
80 |
+
return mask
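# NOTE (illustrative usage, not part of the uploaded file): the shape
# conventions handled by make_2d_mask / make_3d_mask above.
assert make_2d_mask(torch.zeros(1, 1, 8, 8)).shape == (8, 8)
assert make_3d_mask(torch.zeros(8, 8)).shape == (1, 8, 8)
assert make_3d_mask(torch.zeros(1, 1, 8, 8)).shape == (1, 8, 8)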
|
modules/AutoDetailer/tensor_util.py
ADDED
@@ -0,0 +1,253 @@
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from PIL import Image
|
4 |
+
import torchvision
|
5 |
+
|
6 |
+
from modules.Device import Device
|
7 |
+
|
8 |
+
|
9 |
+
def _tensor_check_image(image: torch.Tensor) -> None:
|
10 |
+
"""#### Check if the input is a valid tensor image.
|
11 |
+
|
12 |
+
#### Args:
|
13 |
+
- `image` (torch.Tensor): The input tensor image.
|
14 |
+
"""
|
15 |
+
return
|
16 |
+
|
17 |
+
|
18 |
+
def tensor2pil(image: torch.Tensor) -> Image.Image:
|
19 |
+
"""#### Convert a tensor to a PIL image.
|
20 |
+
|
21 |
+
#### Args:
|
22 |
+
- `image` (torch.Tensor): The input tensor.
|
23 |
+
|
24 |
+
#### Returns:
|
25 |
+
- `Image.Image`: The converted PIL image.
|
26 |
+
"""
|
27 |
+
_tensor_check_image(image)
|
28 |
+
return Image.fromarray(
|
29 |
+
np.clip(255.0 * image.cpu().numpy().squeeze(0), 0, 255).astype(np.uint8)
|
30 |
+
)
|
31 |
+
|
32 |
+
|
33 |
+
def general_tensor_resize(image: torch.Tensor, w: int, h: int) -> torch.Tensor:
|
34 |
+
"""#### Resize a tensor image using bilinear interpolation.
|
35 |
+
|
36 |
+
#### Args:
|
37 |
+
- `image` (torch.Tensor): The input tensor image.
|
38 |
+
- `w` (int): The target width.
|
39 |
+
- `h` (int): The target height.
|
40 |
+
|
41 |
+
#### Returns:
|
42 |
+
- `torch.Tensor`: The resized tensor image.
|
43 |
+
"""
|
44 |
+
_tensor_check_image(image)
|
45 |
+
image = image.permute(0, 3, 1, 2)
|
46 |
+
image = torch.nn.functional.interpolate(image, size=(h, w), mode="bilinear")
|
47 |
+
image = image.permute(0, 2, 3, 1)
|
48 |
+
return image
|
49 |
+
|
50 |
+
|
51 |
+
def pil2tensor(image: Image.Image) -> torch.Tensor:
|
52 |
+
"""#### Convert a PIL image to a tensor.
|
53 |
+
|
54 |
+
#### Args:
|
55 |
+
- `image` (Image.Image): The input PIL image.
|
56 |
+
|
57 |
+
#### Returns:
|
58 |
+
- `torch.Tensor`: The converted tensor.
|
59 |
+
"""
|
60 |
+
return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
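# NOTE (illustrative sketch, not part of the uploaded file): the tensor <-> PIL
# round trip used throughout this module, assuming the (1, H, W, C) float
# layout in [0, 1] that tensor2pil expects.
img = torch.rand(1, 32, 48, 3)            # batch of one, H=32, W=48, RGB
pil_img = tensor2pil(img)                 # 48x32 PIL image
back = pil2tensor(pil_img)                # back to a (1, 32, 48, 3) float tensor
assert back.shape == (1, 32, 48, 3)
assert (back - img).abs().max() <= 1 / 255 + 1e-6  # only 8-bit rounding error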
|
61 |
+
|
62 |
+
|
63 |
+
class TensorBatchBuilder:
|
64 |
+
"""#### Class for building a batch of tensors."""
|
65 |
+
|
66 |
+
def __init__(self):
|
67 |
+
self.tensor: torch.Tensor | None = None
|
68 |
+
|
69 |
+
def concat(self, new_tensor: torch.Tensor) -> None:
|
70 |
+
"""#### Concatenate a new tensor to the batch.
|
71 |
+
|
72 |
+
#### Args:
|
73 |
+
- `new_tensor` (torch.Tensor): The new tensor to concatenate.
|
74 |
+
"""
|
75 |
+
self.tensor = new_tensor
|
76 |
+
|
77 |
+
|
78 |
+
LANCZOS = Image.Resampling.LANCZOS if hasattr(Image, "Resampling") else Image.LANCZOS
|
79 |
+
|
80 |
+
|
81 |
+
def tensor_resize(image: torch.Tensor, w: int, h: int) -> torch.Tensor:
|
82 |
+
"""#### Resize a tensor image.
|
83 |
+
|
84 |
+
#### Args:
|
85 |
+
- `image` (torch.Tensor): The input tensor image.
|
86 |
+
- `w` (int): The target width.
|
87 |
+
- `h` (int): The target height.
|
88 |
+
|
89 |
+
#### Returns:
|
90 |
+
- `torch.Tensor`: The resized tensor image.
|
91 |
+
"""
|
92 |
+
_tensor_check_image(image)
|
93 |
+
if image.shape[3] >= 3:
|
94 |
+
scaled_images = TensorBatchBuilder()
|
95 |
+
for single_image in image:
|
96 |
+
single_image = single_image.unsqueeze(0)
|
97 |
+
single_pil = tensor2pil(single_image)
|
98 |
+
scaled_pil = single_pil.resize((w, h), resample=LANCZOS)
|
99 |
+
|
100 |
+
single_image = pil2tensor(scaled_pil)
|
101 |
+
scaled_images.concat(single_image)
|
102 |
+
|
103 |
+
return scaled_images.tensor
|
104 |
+
else:
|
105 |
+
return general_tensor_resize(image, w, h)
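# NOTE (illustrative usage, not part of the uploaded file): tensor_resize keeps
# the (B, H, W, C) layout; RGB images go through the PIL/LANCZOS path, while
# masks and other <3-channel tensors use the bilinear general_tensor_resize.
assert tensor_resize(torch.rand(1, 64, 64, 3), 32, 16).shape == (1, 16, 32, 3)
assert tensor_resize(torch.rand(1, 64, 64, 1), 32, 16).shape == (1, 16, 32, 1)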
|
106 |
+
|
107 |
+
|
108 |
+
def tensor_paste(
|
109 |
+
image1: torch.Tensor,
|
110 |
+
image2: torch.Tensor,
|
111 |
+
left_top: tuple[int, int],
|
112 |
+
mask: torch.Tensor,
|
113 |
+
) -> None:
|
114 |
+
"""#### Paste one tensor image onto another using a mask.
|
115 |
+
|
116 |
+
#### Args:
|
117 |
+
- `image1` (torch.Tensor): The base tensor image.
|
118 |
+
- `image2` (torch.Tensor): The tensor image to paste.
|
119 |
+
- `left_top` (tuple[int, int]): The top-left corner where the image2 will be pasted.
|
120 |
+
- `mask` (torch.Tensor): The mask tensor.
|
121 |
+
"""
|
122 |
+
_tensor_check_image(image1)
|
123 |
+
_tensor_check_image(image2)
|
124 |
+
_tensor_check_mask(mask)
|
125 |
+
|
126 |
+
x, y = left_top
|
127 |
+
_, h1, w1, _ = image1.shape
|
128 |
+
_, h2, w2, _ = image2.shape
|
129 |
+
|
130 |
+
# calculate image patch size
|
131 |
+
w = min(w1, x + w2) - x
|
132 |
+
h = min(h1, y + h2) - y
|
133 |
+
|
134 |
+
mask = mask[:, :h, :w, :]
|
135 |
+
image1[:, y : y + h, x : x + w, :] = (1 - mask) * image1[
|
136 |
+
:, y : y + h, x : x + w, :
|
137 |
+
] + mask * image2[:, :h, :w, :]
|
138 |
+
return
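# NOTE (illustrative usage, not part of the uploaded file): tensor_paste blends
# image2 into image1 in place - inside the pasted patch, mask=1 takes image2
# and mask=0 keeps image1. All tensors use the (B, H, W, C) layout.
base = torch.zeros(1, 4, 4, 3)
patch = torch.ones(1, 2, 2, 3)
patch_mask = torch.ones(1, 2, 2, 1)
tensor_paste(base, patch, (1, 1), patch_mask)     # paste the 2x2 patch at x=1, y=1
assert base[0, 1, 1].tolist() == [1.0, 1.0, 1.0]  # inside the patch
assert base[0, 0, 0].tolist() == [0.0, 0.0, 0.0]  # outside the patch untouched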
|
139 |
+
|
140 |
+
|
141 |
+
def tensor_convert_rgba(image: torch.Tensor, prefer_copy: bool = True) -> torch.Tensor:
|
142 |
+
"""#### Convert a tensor image to RGBA format.
|
143 |
+
|
144 |
+
#### Args:
|
145 |
+
- `image` (torch.Tensor): The input tensor image.
|
146 |
+
- `prefer_copy` (bool, optional): Whether to prefer copying the tensor. Defaults to True.
|
147 |
+
|
148 |
+
#### Returns:
|
149 |
+
- `torch.Tensor`: The converted RGBA tensor image.
|
150 |
+
"""
|
151 |
+
_tensor_check_image(image)
|
152 |
+
alpha = torch.ones((*image.shape[:-1], 1))
|
153 |
+
return torch.cat((image, alpha), axis=-1)
|
154 |
+
|
155 |
+
|
156 |
+
def tensor_convert_rgb(image: torch.Tensor, prefer_copy: bool = True) -> torch.Tensor:
|
157 |
+
"""#### Convert a tensor image to RGB format.
|
158 |
+
|
159 |
+
#### Args:
|
160 |
+
- `image` (torch.Tensor): The input tensor image.
|
161 |
+
- `prefer_copy` (bool, optional): Whether to prefer copying the tensor. Defaults to True.
|
162 |
+
|
163 |
+
#### Returns:
|
164 |
+
- `torch.Tensor`: The converted RGB tensor image.
|
165 |
+
"""
|
166 |
+
_tensor_check_image(image)
|
167 |
+
return image
|
168 |
+
|
169 |
+
|
170 |
+
def tensor_get_size(image: torch.Tensor) -> tuple[int, int]:
|
171 |
+
"""#### Get the size of a tensor image.
|
172 |
+
|
173 |
+
#### Args:
|
174 |
+
- `image` (torch.Tensor): The input tensor image.
|
175 |
+
|
176 |
+
#### Returns:
|
177 |
+
- `tuple[int, int]`: The width and height of the tensor image.
|
178 |
+
"""
|
179 |
+
_tensor_check_image(image)
|
180 |
+
_, h, w, _ = image.shape
|
181 |
+
return (w, h)
|
182 |
+
|
183 |
+
|
184 |
+
def tensor_putalpha(image: torch.Tensor, mask: torch.Tensor) -> None:
|
185 |
+
"""#### Add an alpha channel to a tensor image using a mask.
|
186 |
+
|
187 |
+
#### Args:
|
188 |
+
- `image` (torch.Tensor): The input tensor image.
|
189 |
+
- `mask` (torch.Tensor): The mask tensor.
|
190 |
+
"""
|
191 |
+
_tensor_check_image(image)
|
192 |
+
_tensor_check_mask(mask)
|
193 |
+
image[..., -1] = mask[..., 0]
|
194 |
+
|
195 |
+
|
196 |
+
def _tensor_check_mask(mask: torch.Tensor) -> None:
|
197 |
+
"""#### Check if the input is a valid tensor mask.
|
198 |
+
|
199 |
+
#### Args:
|
200 |
+
- `mask` (torch.Tensor): The input tensor mask.
|
201 |
+
"""
|
202 |
+
return
|
203 |
+
|
204 |
+
|
205 |
+
def tensor_gaussian_blur_mask(
|
206 |
+
mask: torch.Tensor | np.ndarray, kernel_size: int, sigma: float = 10.0
|
207 |
+
) -> torch.Tensor:
|
208 |
+
"""#### Apply Gaussian blur to a tensor mask.
|
209 |
+
|
210 |
+
#### Args:
|
211 |
+
- `mask` (torch.Tensor | np.ndarray): The input tensor mask.
|
212 |
+
- `kernel_size` (int): The size of the Gaussian kernel.
|
213 |
+
- `sigma` (float, optional): The standard deviation of the Gaussian kernel. Defaults to 10.0.
|
214 |
+
|
215 |
+
#### Returns:
|
216 |
+
- `torch.Tensor`: The blurred tensor mask.
|
217 |
+
"""
|
218 |
+
if isinstance(mask, np.ndarray):
|
219 |
+
mask = torch.from_numpy(mask)
|
220 |
+
|
221 |
+
if mask.ndim == 2:
|
222 |
+
mask = mask[None, ..., None]
|
223 |
+
|
224 |
+
_tensor_check_mask(mask)
|
225 |
+
|
226 |
+
kernel_size = kernel_size * 2 + 1
|
227 |
+
|
228 |
+
prev_device = mask.device
|
229 |
+
device = Device.get_torch_device()
|
230 |
+
mask = mask.to(device)  # Tensor.to is not in-place; keep the returned tensor
|
231 |
+
|
232 |
+
# apply gaussian blur
|
233 |
+
mask = mask[:, None, ..., 0]
|
234 |
+
blurred_mask = torchvision.transforms.GaussianBlur(
|
235 |
+
kernel_size=kernel_size, sigma=sigma
|
236 |
+
)(mask)
|
237 |
+
blurred_mask = blurred_mask[:, 0, ..., None]
|
238 |
+
|
239 |
+
blurred_mask = blurred_mask.to(prev_device)  # move back to the original device
|
240 |
+
|
241 |
+
return blurred_mask
|
242 |
+
|
243 |
+
|
244 |
+
def to_tensor(image: np.ndarray) -> torch.Tensor:
|
245 |
+
"""#### Convert a numpy array to a tensor.
|
246 |
+
|
247 |
+
#### Args:
|
248 |
+
- `image` (np.ndarray): The input numpy array.
|
249 |
+
|
250 |
+
#### Returns:
|
251 |
+
- `torch.Tensor`: The converted tensor.
|
252 |
+
"""
|
253 |
+
return torch.from_numpy(image)
|
modules/AutoEncoders/ResBlock.py
ADDED
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import abstractmethod
|
2 |
+
from typing import Optional, Any, Dict
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from modules.NeuralNetwork import transformer
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
from modules.Attention import Attention
|
10 |
+
from modules.cond import cast
|
11 |
+
from modules.sample import sampling_util
|
12 |
+
|
13 |
+
|
14 |
+
oai_ops = cast.disable_weight_init
|
15 |
+
|
16 |
+
|
17 |
+
class TimestepBlock1(nn.Module):
|
18 |
+
"""#### Abstract class representing a timestep block."""
|
19 |
+
|
20 |
+
@abstractmethod
|
21 |
+
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
|
22 |
+
"""#### Forward pass for the timestep block.
|
23 |
+
|
24 |
+
#### Args:
|
25 |
+
- `x` (torch.Tensor): The input tensor.
|
26 |
+
- `emb` (torch.Tensor): The embedding tensor.
|
27 |
+
|
28 |
+
#### Returns:
|
29 |
+
- `torch.Tensor`: The output tensor.
|
30 |
+
"""
|
31 |
+
pass
|
32 |
+
|
33 |
+
|
34 |
+
def forward_timestep_embed1(
|
35 |
+
ts: nn.ModuleList,
|
36 |
+
x: torch.Tensor,
|
37 |
+
emb: torch.Tensor,
|
38 |
+
context: Optional[torch.Tensor] = None,
|
39 |
+
transformer_options: Optional[Dict[str, Any]] = {},
|
40 |
+
output_shape: Optional[torch.Size] = None,
|
41 |
+
time_context: Optional[torch.Tensor] = None,
|
42 |
+
num_video_frames: Optional[int] = None,
|
43 |
+
image_only_indicator: Optional[bool] = None,
|
44 |
+
) -> torch.Tensor:
|
45 |
+
"""#### Forward pass for timestep embedding.
|
46 |
+
|
47 |
+
#### Args:
|
48 |
+
- `ts` (nn.ModuleList): The list of timestep blocks.
|
49 |
+
- `x` (torch.Tensor): The input tensor.
|
50 |
+
- `emb` (torch.Tensor): The embedding tensor.
|
51 |
+
- `context` (torch.Tensor, optional): The context tensor. Defaults to None.
|
52 |
+
- `transformer_options` (dict, optional): The transformer options. Defaults to {}.
|
53 |
+
- `output_shape` (torch.Size, optional): The output shape. Defaults to None.
|
54 |
+
- `time_context` (torch.Tensor, optional): The time context tensor. Defaults to None.
|
55 |
+
- `num_video_frames` (int, optional): The number of video frames. Defaults to None.
|
56 |
+
- `image_only_indicator` (bool, optional): The image only indicator. Defaults to None.
|
57 |
+
|
58 |
+
#### Returns:
|
59 |
+
- `torch.Tensor`: The output tensor.
|
60 |
+
"""
|
61 |
+
for layer in ts:
|
62 |
+
if isinstance(layer, TimestepBlock1):
|
63 |
+
x = layer(x, emb)
|
64 |
+
elif isinstance(layer, transformer.SpatialTransformer):
|
65 |
+
x = layer(x, context, transformer_options)
|
66 |
+
if "transformer_index" in transformer_options:
|
67 |
+
transformer_options["transformer_index"] += 1
|
68 |
+
elif isinstance(layer, Upsample1):
|
69 |
+
x = layer(x, output_shape=output_shape)
|
70 |
+
else:
|
71 |
+
x = layer(x)
|
72 |
+
return x
|
73 |
+
|
74 |
+
|
75 |
+
class Upsample1(nn.Module):
|
76 |
+
"""#### Class representing an upsample layer."""
|
77 |
+
|
78 |
+
def __init__(
|
79 |
+
self,
|
80 |
+
channels: int,
|
81 |
+
use_conv: bool,
|
82 |
+
dims: int = 2,
|
83 |
+
out_channels: Optional[int] = None,
|
84 |
+
padding: int = 1,
|
85 |
+
dtype: Optional[torch.dtype] = None,
|
86 |
+
device: Optional[torch.device] = None,
|
87 |
+
operations: Any = oai_ops,
|
88 |
+
):
|
89 |
+
"""#### Initialize the upsample layer.
|
90 |
+
|
91 |
+
#### Args:
|
92 |
+
- `channels` (int): The number of input channels.
|
93 |
+
- `use_conv` (bool): Whether to use convolution.
|
94 |
+
- `dims` (int, optional): The number of dimensions. Defaults to 2.
|
95 |
+
- `out_channels` (int, optional): The number of output channels. Defaults to None.
|
96 |
+
- `padding` (int, optional): The padding size. Defaults to 1.
|
97 |
+
- `dtype` (torch.dtype, optional): The data type. Defaults to None.
|
98 |
+
- `device` (torch.device, optional): The device. Defaults to None.
|
99 |
+
- `operations` (any, optional): The operations. Defaults to oai_ops.
|
100 |
+
"""
|
101 |
+
super().__init__()
|
102 |
+
self.channels = channels
|
103 |
+
self.out_channels = out_channels or channels
|
104 |
+
self.use_conv = use_conv
|
105 |
+
self.dims = dims
|
106 |
+
if use_conv:
|
107 |
+
self.conv = operations.conv_nd(
|
108 |
+
dims,
|
109 |
+
self.channels,
|
110 |
+
self.out_channels,
|
111 |
+
3,
|
112 |
+
padding=padding,
|
113 |
+
dtype=dtype,
|
114 |
+
device=device,
|
115 |
+
)
|
116 |
+
|
117 |
+
def forward(
|
118 |
+
self, x: torch.Tensor, output_shape: Optional[torch.Size] = None
|
119 |
+
) -> torch.Tensor:
|
120 |
+
"""#### Forward pass for the upsample layer.
|
121 |
+
|
122 |
+
#### Args:
|
123 |
+
- `x` (torch.Tensor): The input tensor.
|
124 |
+
- `output_shape` (torch.Size, optional): The output shape. Defaults to None.
|
125 |
+
|
126 |
+
#### Returns:
|
127 |
+
- `torch.Tensor`: The output tensor.
|
128 |
+
"""
|
129 |
+
assert x.shape[1] == self.channels
|
130 |
+
shape = [x.shape[2] * 2, x.shape[3] * 2]
|
131 |
+
if output_shape is not None:
|
132 |
+
shape[0] = output_shape[2]
|
133 |
+
shape[1] = output_shape[3]
|
134 |
+
|
135 |
+
x = F.interpolate(x, size=shape, mode="nearest")
|
136 |
+
if self.use_conv:
|
137 |
+
x = self.conv(x)
|
138 |
+
return x
|
139 |
+
|
140 |
+
|
141 |
+
class Downsample1(nn.Module):
|
142 |
+
"""#### Class representing a downsample layer."""
|
143 |
+
|
144 |
+
def __init__(
|
145 |
+
self,
|
146 |
+
channels: int,
|
147 |
+
use_conv: bool,
|
148 |
+
dims: int = 2,
|
149 |
+
out_channels: Optional[int] = None,
|
150 |
+
padding: int = 1,
|
151 |
+
dtype: Optional[torch.dtype] = None,
|
152 |
+
device: Optional[torch.device] = None,
|
153 |
+
operations: Any = oai_ops,
|
154 |
+
):
|
155 |
+
"""#### Initialize the downsample layer.
|
156 |
+
|
157 |
+
#### Args:
|
158 |
+
- `channels` (int): The number of input channels.
|
159 |
+
- `use_conv` (bool): Whether to use convolution.
|
160 |
+
- `dims` (int, optional): The number of dimensions. Defaults to 2.
|
161 |
+
- `out_channels` (int, optional): The number of output channels. Defaults to None.
|
162 |
+
- `padding` (int, optional): The padding size. Defaults to 1.
|
163 |
+
- `dtype` (torch.dtype, optional): The data type. Defaults to None.
|
164 |
+
- `device` (torch.device, optional): The device. Defaults to None.
|
165 |
+
- `operations` (any, optional): The operations. Defaults to oai_ops.
|
166 |
+
"""
|
167 |
+
super().__init__()
|
168 |
+
self.channels = channels
|
169 |
+
self.out_channels = out_channels or channels
|
170 |
+
self.use_conv = use_conv
|
171 |
+
self.dims = dims
|
172 |
+
stride = 2 if dims != 3 else (1, 2, 2)
|
173 |
+
self.op = operations.conv_nd(
|
174 |
+
dims,
|
175 |
+
self.channels,
|
176 |
+
self.out_channels,
|
177 |
+
3,
|
178 |
+
stride=stride,
|
179 |
+
padding=padding,
|
180 |
+
dtype=dtype,
|
181 |
+
device=device,
|
182 |
+
)
|
183 |
+
|
184 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
185 |
+
"""#### Forward pass for the downsample layer.
|
186 |
+
|
187 |
+
#### Args:
|
188 |
+
- `x` (torch.Tensor): The input tensor.
|
189 |
+
|
190 |
+
#### Returns:
|
191 |
+
- `torch.Tensor`: The output tensor.
|
192 |
+
"""
|
193 |
+
assert x.shape[1] == self.channels
|
194 |
+
return self.op(x)
|
195 |
+
|
196 |
+
|
197 |
+
class ResBlock1(TimestepBlock1):
|
198 |
+
"""#### Class representing a residual block layer."""
|
199 |
+
|
200 |
+
def __init__(
|
201 |
+
self,
|
202 |
+
channels: int,
|
203 |
+
emb_channels: int,
|
204 |
+
dropout: float,
|
205 |
+
out_channels: Optional[int] = None,
|
206 |
+
use_conv: bool = False,
|
207 |
+
use_scale_shift_norm: bool = False,
|
208 |
+
dims: int = 2,
|
209 |
+
use_checkpoint: bool = False,
|
210 |
+
up: bool = False,
|
211 |
+
down: bool = False,
|
212 |
+
kernel_size: int = 3,
|
213 |
+
exchange_temb_dims: bool = False,
|
214 |
+
skip_t_emb: bool = False,
|
215 |
+
dtype: Optional[torch.dtype] = None,
|
216 |
+
device: Optional[torch.device] = None,
|
217 |
+
operations: Any = oai_ops,
|
218 |
+
):
|
219 |
+
"""#### Initialize the residual block layer.
|
220 |
+
|
221 |
+
#### Args:
|
222 |
+
- `channels` (int): The number of input channels.
|
223 |
+
- `emb_channels` (int): The number of embedding channels.
|
224 |
+
- `dropout` (float): The dropout rate.
|
225 |
+
- `out_channels` (int, optional): The number of output channels. Defaults to None.
|
226 |
+
- `use_conv` (bool, optional): Whether to use convolution. Defaults to False.
|
227 |
+
- `use_scale_shift_norm` (bool, optional): Whether to use scale shift normalization. Defaults to False.
|
228 |
+
- `dims` (int, optional): The number of dimensions. Defaults to 2.
|
229 |
+
- `use_checkpoint` (bool, optional): Whether to use checkpointing. Defaults to False.
|
230 |
+
- `up` (bool, optional): Whether to use upsampling. Defaults to False.
|
231 |
+
- `down` (bool, optional): Whether to use downsampling. Defaults to False.
|
232 |
+
- `kernel_size` (int, optional): The kernel size. Defaults to 3.
|
233 |
+
- `exchange_temb_dims` (bool, optional): Whether to exchange embedding dimensions. Defaults to False.
|
234 |
+
- `skip_t_emb` (bool, optional): Whether to skip embedding. Defaults to False.
|
235 |
+
- `dtype` (torch.dtype, optional): The data type. Defaults to None.
|
236 |
+
- `device` (torch.device, optional): The device. Defaults to None.
|
237 |
+
- `operations` (any, optional): The operations. Defaults to oai_ops.
|
238 |
+
"""
|
239 |
+
super().__init__()
|
240 |
+
self.channels = channels
|
241 |
+
self.emb_channels = emb_channels
|
242 |
+
self.dropout = dropout
|
243 |
+
self.out_channels = out_channels or channels
|
244 |
+
self.use_conv = use_conv
|
245 |
+
self.use_checkpoint = use_checkpoint
|
246 |
+
self.use_scale_shift_norm = use_scale_shift_norm
|
247 |
+
self.exchange_temb_dims = exchange_temb_dims
|
248 |
+
|
249 |
+
padding = kernel_size // 2
|
250 |
+
|
251 |
+
self.in_layers = nn.Sequential(
|
252 |
+
operations.GroupNorm(32, channels, dtype=dtype, device=device),
|
253 |
+
nn.SiLU(),
|
254 |
+
operations.conv_nd(
|
255 |
+
dims,
|
256 |
+
channels,
|
257 |
+
self.out_channels,
|
258 |
+
kernel_size,
|
259 |
+
padding=padding,
|
260 |
+
dtype=dtype,
|
261 |
+
device=device,
|
262 |
+
),
|
263 |
+
)
|
264 |
+
|
265 |
+
self.updown = up or down
|
266 |
+
|
267 |
+
self.h_upd = self.x_upd = nn.Identity()
|
268 |
+
|
269 |
+
self.skip_t_emb = skip_t_emb
|
270 |
+
self.emb_layers = nn.Sequential(
|
271 |
+
nn.SiLU(),
|
272 |
+
operations.Linear(
|
273 |
+
emb_channels,
|
274 |
+
(2 * self.out_channels if use_scale_shift_norm else self.out_channels),
|
275 |
+
dtype=dtype,
|
276 |
+
device=device,
|
277 |
+
),
|
278 |
+
)
|
279 |
+
self.out_layers = nn.Sequential(
|
280 |
+
operations.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
|
281 |
+
nn.SiLU(),
|
282 |
+
nn.Dropout(p=dropout),
|
283 |
+
operations.conv_nd(
|
284 |
+
dims,
|
285 |
+
self.out_channels,
|
286 |
+
self.out_channels,
|
287 |
+
kernel_size,
|
288 |
+
padding=padding,
|
289 |
+
dtype=dtype,
|
290 |
+
device=device,
|
291 |
+
),
|
292 |
+
)
|
293 |
+
|
294 |
+
if self.out_channels == channels:
|
295 |
+
self.skip_connection = nn.Identity()
|
296 |
+
else:
|
297 |
+
self.skip_connection = operations.conv_nd(
|
298 |
+
dims, channels, self.out_channels, 1, dtype=dtype, device=device
|
299 |
+
)
|
300 |
+
|
301 |
+
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
|
302 |
+
"""#### Forward pass for the residual block layer.
|
303 |
+
|
304 |
+
#### Args:
|
305 |
+
- `x` (torch.Tensor): The input tensor.
|
306 |
+
- `emb` (torch.Tensor): The embedding tensor.
|
307 |
+
|
308 |
+
#### Returns:
|
309 |
+
- `torch.Tensor`: The output tensor.
|
310 |
+
"""
|
311 |
+
return sampling_util.checkpoint(
|
312 |
+
self._forward, (x, emb), self.parameters(), self.use_checkpoint
|
313 |
+
)
|
314 |
+
|
315 |
+
def _forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
|
316 |
+
"""#### Internal forward pass for the residual block layer.
|
317 |
+
|
318 |
+
#### Args:
|
319 |
+
- `x` (torch.Tensor): The input tensor.
|
320 |
+
- `emb` (torch.Tensor): The embedding tensor.
|
321 |
+
|
322 |
+
#### Returns:
|
323 |
+
- `torch.Tensor`: The output tensor.
|
324 |
+
"""
|
325 |
+
h = self.in_layers(x)
|
326 |
+
|
327 |
+
emb_out = None
|
328 |
+
if not self.skip_t_emb:
|
329 |
+
emb_out = self.emb_layers(emb).type(h.dtype)
|
330 |
+
while len(emb_out.shape) < len(h.shape):
|
331 |
+
emb_out = emb_out[..., None]
|
332 |
+
if emb_out is not None:
|
333 |
+
h = h + emb_out
|
334 |
+
h = self.out_layers(h)
|
335 |
+
return self.skip_connection(x) + h
|
336 |
+
|
337 |
+
|
338 |
+
ops = cast.disable_weight_init
|
339 |
+
|
340 |
+
|
341 |
+
class ResnetBlock(nn.Module):
|
342 |
+
"""#### Class representing a ResNet block layer."""
|
343 |
+
|
344 |
+
def __init__(
|
345 |
+
self,
|
346 |
+
*,
|
347 |
+
in_channels: int,
|
348 |
+
out_channels: Optional[int] = None,
|
349 |
+
conv_shortcut: bool = False,
|
350 |
+
dropout: float,
|
351 |
+
temb_channels: int = 512,
|
352 |
+
):
|
353 |
+
"""#### Initialize the ResNet block layer.
|
354 |
+
|
355 |
+
#### Args:
|
356 |
+
- `in_channels` (int): The number of input channels.
|
357 |
+
- `out_channels` (int, optional): The number of output channels. Defaults to None.
|
358 |
+
- `conv_shortcut` (bool, optional): Whether to use convolution shortcut. Defaults to False.
|
359 |
+
- `dropout` (float): The dropout rate.
|
360 |
+
- `temb_channels` (int, optional): The number of embedding channels. Defaults to 512.
|
361 |
+
"""
|
362 |
+
super().__init__()
|
363 |
+
self.in_channels = in_channels
|
364 |
+
out_channels = in_channels if out_channels is None else out_channels
|
365 |
+
self.out_channels = out_channels
|
366 |
+
self.use_conv_shortcut = conv_shortcut
|
367 |
+
|
368 |
+
self.swish = torch.nn.SiLU(inplace=True)
|
369 |
+
self.norm1 = Attention.Normalize(in_channels)
|
370 |
+
self.conv1 = ops.Conv2d(
|
371 |
+
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
372 |
+
)
|
373 |
+
self.norm2 = Attention.Normalize(out_channels)
|
374 |
+
self.dropout = torch.nn.Dropout(dropout, inplace=True)
|
375 |
+
self.conv2 = ops.Conv2d(
|
376 |
+
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
377 |
+
)
|
378 |
+
if self.in_channels != self.out_channels:
|
379 |
+
self.nin_shortcut = ops.Conv2d(
|
380 |
+
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
381 |
+
)
|
382 |
+
|
383 |
+
def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
|
384 |
+
"""#### Forward pass for the ResNet block layer.
|
385 |
+
|
386 |
+
#### Args:
|
387 |
+
- `x` (torch.Tensor): The input tensor.
|
388 |
+
- `temb` (torch.Tensor): The embedding tensor.
|
389 |
+
|
390 |
+
#### Returns:
|
391 |
+
- `torch.Tensor`: The output tensor.
|
392 |
+
"""
|
393 |
+
h = x
|
394 |
+
h = self.norm1(h)
|
395 |
+
h = self.swish(h)
|
396 |
+
h = self.conv1(h)
|
397 |
+
|
398 |
+
h = self.norm2(h)
|
399 |
+
h = self.swish(h)
|
400 |
+
h = self.dropout(h)
|
401 |
+
h = self.conv2(h)
|
402 |
+
|
403 |
+
if self.in_channels != self.out_channels:
|
404 |
+
x = self.nin_shortcut(x)
|
405 |
+
|
406 |
+
return x + h
|
modules/AutoEncoders/VariationalAE.py
ADDED
@@ -0,0 +1,824 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from typing import Dict, Optional, Tuple, Union
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from modules.Model import ModelPatcher
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from modules.Attention import Attention
|
9 |
+
from modules.AutoEncoders import ResBlock
|
10 |
+
from modules.Device import Device
|
11 |
+
from modules.Utilities import util
|
12 |
+
from modules.cond import cast
|
13 |
+
|
14 |
+
|
15 |
+
class DiagonalGaussianDistribution(object):
|
16 |
+
"""#### Represents a diagonal Gaussian distribution parameterized by mean and log-variance.
|
17 |
+
|
18 |
+
#### Attributes:
|
19 |
+
- `parameters` (torch.Tensor): The concatenated mean and log-variance of the distribution.
|
20 |
+
- `mean` (torch.Tensor): The mean of the distribution.
|
21 |
+
- `logvar` (torch.Tensor): The log-variance of the distribution, clamped between -30.0 and 20.0.
|
22 |
+
- `std` (torch.Tensor): The standard deviation of the distribution, computed as exp(0.5 * logvar).
|
23 |
+
- `var` (torch.Tensor): The variance of the distribution, computed as exp(logvar).
|
24 |
+
- `deterministic` (bool): If True, the distribution is deterministic.
|
25 |
+
|
26 |
+
#### Methods:
|
27 |
+
- `sample() -> torch.Tensor`:
|
28 |
+
Samples from the distribution using the reparameterization trick.
|
29 |
+
- `kl(other: DiagonalGaussianDistribution = None) -> torch.Tensor`:
|
30 |
+
Computes the Kullback-Leibler divergence between this distribution and a standard normal distribution.
|
31 |
+
If `other` is provided, computes the KL divergence between this distribution and `other`.
|
32 |
+
"""
|
33 |
+
|
34 |
+
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
|
35 |
+
self.parameters = parameters
|
36 |
+
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
37 |
+
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
38 |
+
self.deterministic = deterministic
|
39 |
+
self.std = torch.exp(0.5 * self.logvar)
|
40 |
+
self.var = torch.exp(self.logvar)
|
41 |
+
|
42 |
+
def sample(self) -> torch.Tensor:
|
43 |
+
"""#### Samples from the distribution using the reparameterization trick.
|
44 |
+
|
45 |
+
#### Returns:
|
46 |
+
- `torch.Tensor`: A sample from the distribution.
|
47 |
+
"""
|
48 |
+
x = self.mean + self.std * torch.randn(self.mean.shape).to(
|
49 |
+
device=self.parameters.device
|
50 |
+
)
|
51 |
+
return x
|
52 |
+
|
53 |
+
def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
|
54 |
+
"""#### Computes the Kullback-Leibler divergence between this distribution and a standard normal distribution.
|
55 |
+
|
56 |
+
If `other` is provided, computes the KL divergence between this distribution and `other`.
|
57 |
+
|
58 |
+
#### Args:
|
59 |
+
- `other` (DiagonalGaussianDistribution, optional): Another distribution to compute the KL divergence with.
|
60 |
+
|
61 |
+
#### Returns:
|
62 |
+
- `torch.Tensor`: The KL divergence.
|
63 |
+
"""
|
64 |
+
return 0.5 * torch.sum(
|
65 |
+
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
|
66 |
+
dim=[1, 2, 3],
|
67 |
+
)
|
68 |
+
|
69 |
+
|
70 |
+
class DiagonalGaussianRegularizer(torch.nn.Module):
|
71 |
+
"""#### Regularizer for diagonal Gaussian distributions."""
|
72 |
+
|
73 |
+
def __init__(self, sample: bool = True):
|
74 |
+
"""#### Initialize the regularizer.
|
75 |
+
|
76 |
+
#### Args:
|
77 |
+
- `sample` (bool, optional): Whether to sample from the distribution. Defaults to True.
|
78 |
+
"""
|
79 |
+
super().__init__()
|
80 |
+
self.sample = sample
|
81 |
+
|
82 |
+
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
|
83 |
+
"""#### Forward pass for the regularizer.
|
84 |
+
|
85 |
+
#### Args:
|
86 |
+
- `z` (torch.Tensor): The input tensor.
|
87 |
+
|
88 |
+
#### Returns:
|
89 |
+
- `Tuple[torch.Tensor, dict]`: The regularized tensor and a log dictionary.
|
90 |
+
"""
|
91 |
+
log = dict()
|
92 |
+
posterior = DiagonalGaussianDistribution(z)
|
93 |
+
if self.sample:
|
94 |
+
z = posterior.sample()
|
95 |
+
else:
|
96 |
+
z = posterior.mode()
|
97 |
+
kl_loss = posterior.kl()
|
98 |
+
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
|
99 |
+
log["kl_loss"] = kl_loss
|
100 |
+
return z, log
|
101 |
+
|
102 |
+
|
103 |
+
class AutoencodingEngine(nn.Module):
|
104 |
+
"""#### Class representing an autoencoding engine."""
|
105 |
+
|
106 |
+
def __init__(self, encoder: nn.Module, decoder: nn.Module, regularizer: nn.Module, flux: bool = False):
|
107 |
+
"""#### Initialize the autoencoding engine.
|
108 |
+
|
109 |
+
#### Args:
|
110 |
+
- `encoder` (nn.Module): The encoder module.
|
111 |
+
- `decoder` (nn.Module): The decoder module.
|
112 |
+
- `regularizer` (nn.Module): The regularizer module.
|
113 |
+
"""
|
114 |
+
super().__init__()
|
115 |
+
self.encoder = encoder
|
116 |
+
self.decoder = decoder
|
117 |
+
self.regularization = regularizer
|
118 |
+
if not flux:
|
119 |
+
self.post_quant_conv = cast.disable_weight_init.Conv2d(4, 4, 1)
|
120 |
+
self.quant_conv = cast.disable_weight_init.Conv2d(8, 8, 1)
|
121 |
+
|
122 |
+
def get_last_layer(self):
|
123 |
+
"""#### Get the last layer of the decoder.
|
124 |
+
|
125 |
+
Returns:
|
126 |
+
- `nn.Module`: The last layer of the decoder.
|
127 |
+
"""
|
128 |
+
return self.decoder.get_last_layer()
|
129 |
+
|
130 |
+
def decode(self, z: torch.Tensor, flux:bool = False, **kwargs) -> torch.Tensor:
|
131 |
+
"""#### Decode the latent tensor.
|
132 |
+
|
133 |
+
#### Args:
|
134 |
+
- `z` (torch.Tensor): The latent tensor.
|
135 |
+
- `decoder_kwargs` (dict): Additional arguments for the decoder.
|
136 |
+
|
137 |
+
#### Returns:
|
138 |
+
- `torch.Tensor`: The decoded tensor.
|
139 |
+
"""
|
140 |
+
if flux:
|
141 |
+
x = self.decoder(z, **kwargs)
|
142 |
+
return x
|
143 |
+
dec = self.post_quant_conv(z)
|
144 |
+
dec = self.decoder(dec, **kwargs)
|
145 |
+
return dec
|
146 |
+
|
147 |
+
|
148 |
+
def encode(
|
149 |
+
self,
|
150 |
+
x: torch.Tensor,
|
151 |
+
return_reg_log: bool = False,
|
152 |
+
unregularized: bool = False,
|
153 |
+
flux: bool = False,
|
154 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
155 |
+
"""#### Encode the input tensor.
|
156 |
+
|
157 |
+
#### Args:
|
158 |
+
- `x` (torch.Tensor): The input tensor.
|
159 |
+
- `return_reg_log` (bool, optional): Whether to return the regularization log. Defaults to False.
|
160 |
+
|
161 |
+
#### Returns:
|
162 |
+
- `Union[torch.Tensor, Tuple[torch.Tensor, dict]]`: The encoded tensor and optionally the regularization log.
|
163 |
+
"""
|
164 |
+
z = self.encoder(x)
|
165 |
+
if not flux:
|
166 |
+
z = self.quant_conv(z)
|
167 |
+
if unregularized:
|
168 |
+
return z, dict()
|
169 |
+
z, reg_log = self.regularization(z)
|
170 |
+
if return_reg_log:
|
171 |
+
return z, reg_log
|
172 |
+
return z
|
173 |
+
|
174 |
+
ops = cast.disable_weight_init
|
175 |
+
|
176 |
+
if Device.xformers_enabled_vae():
|
177 |
+
pass
|
178 |
+
|
179 |
+
|
180 |
+
def nonlinearity(x: torch.Tensor) -> torch.Tensor:
|
181 |
+
"""#### Apply the swish nonlinearity.
|
182 |
+
|
183 |
+
#### Args:
|
184 |
+
- `x` (torch.Tensor): The input tensor.
|
185 |
+
|
186 |
+
#### Returns:
|
187 |
+
- `torch.Tensor`: The output tensor.
|
188 |
+
"""
|
189 |
+
return x * torch.sigmoid(x)
|
190 |
+
|
191 |
+
|
192 |
+
class Upsample(nn.Module):
|
193 |
+
"""#### Class representing an upsample layer."""
|
194 |
+
|
195 |
+
def __init__(self, in_channels: int, with_conv: bool):
|
196 |
+
"""#### Initialize the upsample layer.
|
197 |
+
|
198 |
+
#### Args:
|
199 |
+
- `in_channels` (int): The number of input channels.
|
200 |
+
- `with_conv` (bool): Whether to use convolution.
|
201 |
+
"""
|
202 |
+
super().__init__()
|
203 |
+
self.with_conv = with_conv
|
204 |
+
if self.with_conv:
|
205 |
+
self.conv = ops.Conv2d(
|
206 |
+
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
207 |
+
)
|
208 |
+
|
209 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
210 |
+
"""#### Forward pass for the upsample layer.
|
211 |
+
|
212 |
+
#### Args:
|
213 |
+
- `x` (torch.Tensor): The input tensor.
|
214 |
+
|
215 |
+
#### Returns:
|
216 |
+
- `torch.Tensor`: The output tensor.
|
217 |
+
"""
|
218 |
+
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
219 |
+
if self.with_conv:
|
220 |
+
x = self.conv(x)
|
221 |
+
return x
|
222 |
+
|
223 |
+
|
224 |
+
class Downsample(nn.Module):
|
225 |
+
"""#### Class representing a downsample layer."""
|
226 |
+
|
227 |
+
def __init__(self, in_channels: int, with_conv: bool):
|
228 |
+
"""#### Initialize the downsample layer.
|
229 |
+
|
230 |
+
#### Args:
|
231 |
+
- `in_channels` (int): The number of input channels.
|
232 |
+
- `with_conv` (bool): Whether to use convolution.
|
233 |
+
"""
|
234 |
+
super().__init__()
|
235 |
+
self.with_conv = with_conv
|
236 |
+
if self.with_conv:
|
237 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
238 |
+
self.conv = ops.Conv2d(
|
239 |
+
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
240 |
+
)
|
241 |
+
|
242 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
243 |
+
"""#### Forward pass for the downsample layer.
|
244 |
+
|
245 |
+
#### Args:
|
246 |
+
- `x` (torch.Tensor): The input tensor.
|
247 |
+
|
248 |
+
#### Returns:
|
249 |
+
- `torch.Tensor`: The output tensor.
|
250 |
+
"""
|
251 |
+
pad = (0, 1, 0, 1)
|
252 |
+
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
253 |
+
x = self.conv(x)
|
254 |
+
return x
|
255 |
+
|
256 |
+
|
257 |
+
class Encoder(nn.Module):
|
258 |
+
"""#### Class representing an encoder."""
|
259 |
+
|
260 |
+
def __init__(
|
261 |
+
self,
|
262 |
+
*,
|
263 |
+
ch: int,
|
264 |
+
out_ch: int,
|
265 |
+
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
|
266 |
+
num_res_blocks: int,
|
267 |
+
attn_resolutions: Tuple[int, ...],
|
268 |
+
dropout: float = 0.0,
|
269 |
+
resamp_with_conv: bool = True,
|
270 |
+
in_channels: int,
|
271 |
+
resolution: int,
|
272 |
+
z_channels: int,
|
273 |
+
double_z: bool = True,
|
274 |
+
use_linear_attn: bool = False,
|
275 |
+
attn_type: str = "vanilla",
|
276 |
+
**ignore_kwargs,
|
277 |
+
):
|
278 |
+
"""#### Initialize the encoder.
|
279 |
+
|
280 |
+
#### Args:
|
281 |
+
- `ch` (int): The base number of channels.
|
282 |
+
- `out_ch` (int): The number of output channels.
|
283 |
+
- `ch_mult` (Tuple[int, ...], optional): Channel multiplier at each resolution. Defaults to (1, 2, 4, 8).
|
284 |
+
- `num_res_blocks` (int): The number of residual blocks.
|
285 |
+
- `attn_resolutions` (Tuple[int, ...]): The resolutions at which to apply attention.
|
286 |
+
- `dropout` (float, optional): The dropout rate. Defaults to 0.0.
|
287 |
+
- `resamp_with_conv` (bool, optional): Whether to use convolution for resampling. Defaults to True.
|
288 |
+
- `in_channels` (int): The number of input channels.
|
289 |
+
- `resolution` (int): The resolution of the input.
|
290 |
+
- `z_channels` (int): The number of latent channels.
|
291 |
+
- `double_z` (bool, optional): Whether to double the latent channels. Defaults to True.
|
292 |
+
- `use_linear_attn` (bool, optional): Whether to use linear attention. Defaults to False.
|
293 |
+
- `attn_type` (str, optional): The type of attention. Defaults to "vanilla".
|
294 |
+
"""
|
295 |
+
super().__init__()
|
296 |
+
if use_linear_attn:
|
297 |
+
attn_type = "linear"
|
298 |
+
self.ch = ch
|
299 |
+
self.temb_ch = 0
|
300 |
+
self.num_resolutions = len(ch_mult)
|
301 |
+
self.num_res_blocks = num_res_blocks
|
302 |
+
self.resolution = resolution
|
303 |
+
self.in_channels = in_channels
|
304 |
+
|
305 |
+
# downsampling
|
306 |
+
self.conv_in = ops.Conv2d(
|
307 |
+
in_channels, self.ch, kernel_size=3, stride=1, padding=1
|
308 |
+
)
|
309 |
+
|
310 |
+
curr_res = resolution
|
311 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
312 |
+
self.in_ch_mult = in_ch_mult
|
313 |
+
self.down = nn.ModuleList()
|
314 |
+
for i_level in range(self.num_resolutions):
|
315 |
+
block = nn.ModuleList()
|
316 |
+
attn = nn.ModuleList()
|
317 |
+
block_in = ch * in_ch_mult[i_level]
|
318 |
+
block_out = ch * ch_mult[i_level]
|
319 |
+
for i_block in range(self.num_res_blocks):
|
320 |
+
block.append(
|
321 |
+
ResBlock.ResnetBlock(
|
322 |
+
in_channels=block_in,
|
323 |
+
out_channels=block_out,
|
324 |
+
temb_channels=self.temb_ch,
|
325 |
+
dropout=dropout,
|
326 |
+
)
|
327 |
+
)
|
328 |
+
block_in = block_out
|
329 |
+
down = nn.Module()
|
330 |
+
down.block = block
|
331 |
+
down.attn = attn
|
332 |
+
if i_level != self.num_resolutions - 1:
|
333 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
334 |
+
curr_res = curr_res // 2
|
335 |
+
self.down.append(down)
|
336 |
+
|
337 |
+
# middle
|
338 |
+
self.mid = nn.Module()
|
339 |
+
self.mid.block_1 = ResBlock.ResnetBlock(
|
340 |
+
in_channels=block_in,
|
341 |
+
out_channels=block_in,
|
342 |
+
temb_channels=self.temb_ch,
|
343 |
+
dropout=dropout,
|
344 |
+
)
|
345 |
+
self.mid.attn_1 = Attention.make_attn(block_in, attn_type=attn_type)
|
346 |
+
self.mid.block_2 = ResBlock.ResnetBlock(
|
347 |
+
in_channels=block_in,
|
348 |
+
out_channels=block_in,
|
349 |
+
temb_channels=self.temb_ch,
|
350 |
+
dropout=dropout,
|
351 |
+
)
|
352 |
+
|
353 |
+
# end
|
354 |
+
self.norm_out = Attention.Normalize(block_in)
|
355 |
+
self.conv_out = ops.Conv2d(
|
356 |
+
block_in,
|
357 |
+
2 * z_channels if double_z else z_channels,
|
358 |
+
kernel_size=3,
|
359 |
+
stride=1,
|
360 |
+
padding=1,
|
361 |
+
)
|
362 |
+
self._device = torch.device("cpu")
|
363 |
+
self._dtype = torch.float32
|
364 |
+
|
365 |
+
def to(self, device=None, dtype=None):
|
366 |
+
"""#### Move the encoder to a device and data type.
|
367 |
+
|
368 |
+
#### Args:
|
369 |
+
- `device` (torch.device, optional): The device to move to. Defaults to None.
|
370 |
+
- `dtype` (torch.dtype, optional): The data type to move to. Defaults to None.
|
371 |
+
|
372 |
+
#### Returns:
|
373 |
+
- `nn.Module`: The encoder.
|
374 |
+
"""
|
375 |
+
if device is not None:
|
376 |
+
self._device = device
|
377 |
+
if dtype is not None:
|
378 |
+
self._dtype = dtype
|
379 |
+
return super().to(device=device, dtype=dtype)
|
380 |
+
|
381 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
382 |
+
"""#### Forward pass for the encoder.
|
383 |
+
|
384 |
+
#### Args:
|
385 |
+
- `x` (torch.Tensor): The input tensor.
|
386 |
+
|
387 |
+
#### Returns:
|
388 |
+
- `torch.Tensor`: The encoded tensor.
|
389 |
+
"""
|
390 |
+
if x.device != self._device or x.dtype != self._dtype:
|
391 |
+
self.to(device=x.device, dtype=x.dtype)
|
392 |
+
# timestep embedding
|
393 |
+
temb = None
|
394 |
+
# downsampling
|
395 |
+
h = self.conv_in(x)
|
396 |
+
for i_level in range(self.num_resolutions):
|
397 |
+
for i_block in range(self.num_res_blocks):
|
398 |
+
h = self.down[i_level].block[i_block](h, temb)
|
399 |
+
if len(self.down[i_level].attn) > 0:
|
400 |
+
h = self.down[i_level].attn[i_block](h)
|
401 |
+
if i_level != self.num_resolutions - 1:
|
402 |
+
h = self.down[i_level].downsample(h)
|
403 |
+
|
404 |
+
# middle
|
405 |
+
h = self.mid.block_1(h, temb)
|
406 |
+
h = self.mid.attn_1(h)
|
407 |
+
h = self.mid.block_2(h, temb)
|
408 |
+
|
409 |
+
# end
|
410 |
+
h = self.norm_out(h)
|
411 |
+
h = nonlinearity(h)
|
412 |
+
h = self.conv_out(h)
|
413 |
+
return h
|
414 |
+
|
415 |
+
|
416 |
+
class Decoder(nn.Module):
|
417 |
+
"""#### Class representing a decoder."""
|
418 |
+
|
419 |
+
def __init__(
|
420 |
+
self,
|
421 |
+
*,
|
422 |
+
ch: int,
|
423 |
+
out_ch: int,
|
424 |
+
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
|
425 |
+
num_res_blocks: int,
|
426 |
+
attn_resolutions: Tuple[int, ...],
|
427 |
+
dropout: float = 0.0,
|
428 |
+
resamp_with_conv: bool = True,
|
429 |
+
in_channels: int,
|
430 |
+
resolution: int,
|
431 |
+
z_channels: int,
|
432 |
+
give_pre_end: bool = False,
|
433 |
+
tanh_out: bool = False,
|
434 |
+
use_linear_attn: bool = False,
|
435 |
+
conv_out_op: nn.Module = ops.Conv2d,
|
436 |
+
resnet_op: nn.Module = ResBlock.ResnetBlock,
|
437 |
+
attn_op: nn.Module = Attention.AttnBlock,
|
438 |
+
**ignorekwargs,
|
439 |
+
):
|
440 |
+
"""#### Initialize the decoder.
|
441 |
+
|
442 |
+
#### Args:
|
443 |
+
- `ch` (int): The base number of channels.
|
444 |
+
- `out_ch` (int): The number of output channels.
|
445 |
+
- `ch_mult` (Tuple[int, ...], optional): Channel multiplier at each resolution. Defaults to (1, 2, 4, 8).
|
446 |
+
- `num_res_blocks` (int): The number of residual blocks.
|
447 |
+
- `attn_resolutions` (Tuple[int, ...]): The resolutions at which to apply attention.
|
448 |
+
- `dropout` (float, optional): The dropout rate. Defaults to 0.0.
|
449 |
+
- `resamp_with_conv` (bool, optional): Whether to use convolution for resampling. Defaults to True.
|
450 |
+
- `in_channels` (int): The number of input channels.
|
451 |
+
- `resolution` (int): The resolution of the input.
|
452 |
+
- `z_channels` (int): The number of latent channels.
|
453 |
+
- `give_pre_end` (bool, optional): Whether to give pre-end. Defaults to False.
|
454 |
+
- `tanh_out` (bool, optional): Whether to use tanh activation at the output. Defaults to False.
|
455 |
+
- `use_linear_attn` (bool, optional): Whether to use linear attention. Defaults to False.
|
456 |
+
- `conv_out_op` (nn.Module, optional): The convolution output operation. Defaults to ops.Conv2d.
|
457 |
+
- `resnet_op` (nn.Module, optional): The residual block operation. Defaults to ResBlock.ResnetBlock.
|
458 |
+
- `attn_op` (nn.Module, optional): The attention block operation. Defaults to Attention.AttnBlock.
|
459 |
+
"""
|
460 |
+
super().__init__()
|
461 |
+
self.ch = ch
|
462 |
+
self.temb_ch = 0
|
463 |
+
self.num_resolutions = len(ch_mult)
|
464 |
+
self.num_res_blocks = num_res_blocks
|
465 |
+
self.resolution = resolution
|
466 |
+
self.in_channels = in_channels
|
467 |
+
self.give_pre_end = give_pre_end
|
468 |
+
self.tanh_out = tanh_out
|
469 |
+
|
470 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
471 |
+
(1,) + tuple(ch_mult)
|
472 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
473 |
+
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
474 |
+
self.z_shape = (1, z_channels, curr_res, curr_res)
|
475 |
+
logging.debug(
|
476 |
+
"Working with z of shape {} = {} dimensions.".format(
|
477 |
+
self.z_shape, np.prod(self.z_shape)
|
478 |
+
)
|
479 |
+
)
|
480 |
+
|
481 |
+
# z to block_in
|
482 |
+
self.conv_in = ops.Conv2d(
|
483 |
+
z_channels, block_in, kernel_size=3, stride=1, padding=1
|
484 |
+
)
|
485 |
+
|
486 |
+
# middle
|
487 |
+
self.mid = nn.Module()
|
488 |
+
self.mid.block_1 = resnet_op(
|
489 |
+
in_channels=block_in,
|
490 |
+
out_channels=block_in,
|
491 |
+
temb_channels=self.temb_ch,
|
492 |
+
dropout=dropout,
|
493 |
+
)
|
494 |
+
self.mid.attn_1 = attn_op(block_in)
|
495 |
+
self.mid.block_2 = resnet_op(
|
496 |
+
in_channels=block_in,
|
497 |
+
out_channels=block_in,
|
498 |
+
temb_channels=self.temb_ch,
|
499 |
+
dropout=dropout,
|
500 |
+
)
|
501 |
+
|
502 |
+
# upsampling
|
503 |
+
self.up = nn.ModuleList()
|
504 |
+
for i_level in reversed(range(self.num_resolutions)):
|
505 |
+
block = nn.ModuleList()
|
506 |
+
attn = nn.ModuleList()
|
507 |
+
block_out = ch * ch_mult[i_level]
|
508 |
+
for i_block in range(self.num_res_blocks + 1):
|
509 |
+
block.append(
|
510 |
+
resnet_op(
|
511 |
+
in_channels=block_in,
|
512 |
+
out_channels=block_out,
|
513 |
+
temb_channels=self.temb_ch,
|
514 |
+
dropout=dropout,
|
515 |
+
)
|
516 |
+
)
|
517 |
+
block_in = block_out
|
518 |
+
up = nn.Module()
|
519 |
+
up.block = block
|
520 |
+
up.attn = attn
|
521 |
+
if i_level != 0:
|
522 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
523 |
+
curr_res = curr_res * 2
|
524 |
+
self.up.insert(0, up) # prepend to get consistent order
|
525 |
+
|
526 |
+
# end
|
527 |
+
self.norm_out = Attention.Normalize(block_in)
|
528 |
+
self.conv_out = conv_out_op(
|
529 |
+
block_in, out_ch, kernel_size=3, stride=1, padding=1
|
530 |
+
)
|
531 |
+
|
532 |
+
def forward(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
|
533 |
+
"""#### Forward pass for the decoder.
|
534 |
+
|
535 |
+
#### Args:
|
536 |
+
- `z` (torch.Tensor): The input tensor.
|
537 |
+
- `**kwargs`: Additional arguments.
|
538 |
+
|
539 |
+
#### Returns:
|
540 |
+
- `torch.Tensor`: The output tensor.
|
541 |
+
|
542 |
+
"""
|
543 |
+
# assert z.shape[1:] == self.z_shape[1:]
|
544 |
+
self.last_z_shape = z.shape
|
545 |
+
|
546 |
+
# timestep embedding
|
547 |
+
temb = None
|
548 |
+
|
549 |
+
# z to block_in
|
550 |
+
h = self.conv_in(z)
|
551 |
+
|
552 |
+
# middle
|
553 |
+
h = self.mid.block_1(h, temb, **kwargs)
|
554 |
+
h = self.mid.attn_1(h, **kwargs)
|
555 |
+
h = self.mid.block_2(h, temb, **kwargs)
|
556 |
+
|
557 |
+
# upsampling
|
558 |
+
for i_level in reversed(range(self.num_resolutions)):
|
559 |
+
for i_block in range(self.num_res_blocks + 1):
|
560 |
+
h = self.up[i_level].block[i_block](h, temb, **kwargs)
|
561 |
+
if i_level != 0:
|
562 |
+
h = self.up[i_level].upsample(h)
|
563 |
+
|
564 |
+
h = self.norm_out(h)
|
565 |
+
h = nonlinearity(h)
|
566 |
+
h = self.conv_out(h, **kwargs)
|
567 |
+
return h
|
568 |
+
|
569 |
+
|
570 |
+
class VAE:
|
571 |
+
"""#### Class representing a Variational Autoencoder (VAE)."""
|
572 |
+
|
573 |
+
def __init__(
|
574 |
+
self,
|
575 |
+
sd: Optional[dict] = None,
|
576 |
+
device: Optional[torch.device] = None,
|
577 |
+
config: Optional[dict] = None,
|
578 |
+
dtype: Optional[torch.dtype] = None,
|
579 |
+
flux: Optional[bool] = False,
|
580 |
+
):
|
581 |
+
"""#### Initialize the VAE.
|
582 |
+
|
583 |
+
#### Args:
|
584 |
+
- `sd` (dict, optional): The state dictionary. Defaults to None.
|
585 |
+
- `device` (torch.device, optional): The device to use. Defaults to None.
|
586 |
+
- `config` (dict, optional): The configuration dictionary. Defaults to None.
|
587 |
+
- `dtype` (torch.dtype, optional): The data type. Defaults to None.
|
588 |
+
"""
|
589 |
+
self.memory_used_encode = lambda shape, dtype: (
|
590 |
+
1767 * shape[2] * shape[3]
|
591 |
+
) * Device.dtype_size(
|
592 |
+
dtype
|
593 |
+
) # These are for AutoencoderKL and need tweaking (should be lower)
|
594 |
+
self.memory_used_decode = lambda shape, dtype: (
|
595 |
+
2178 * shape[2] * shape[3] * 64
|
596 |
+
) * Device.dtype_size(dtype)
|
597 |
+
self.downscale_ratio = 8
|
598 |
+
self.upscale_ratio = 8
|
599 |
+
self.latent_channels = 4
|
600 |
+
self.output_channels = 3
|
601 |
+
self.process_input = lambda image: image * 2.0 - 1.0
|
602 |
+
self.process_output = lambda image: torch.clamp(
|
603 |
+
(image + 1.0) / 2.0, min=0.0, max=1.0
|
604 |
+
)
|
605 |
+
self.working_dtypes = [torch.bfloat16, torch.float32]
|
606 |
+
|
607 |
+
if config is None:
|
608 |
+
if "decoder.conv_in.weight" in sd:
|
609 |
+
# default SD1.x/SD2.x VAE parameters
|
610 |
+
ddconfig = {
|
611 |
+
"double_z": True,
|
612 |
+
"z_channels": 4,
|
613 |
+
"resolution": 256,
|
614 |
+
"in_channels": 3,
|
615 |
+
"out_ch": 3,
|
616 |
+
"ch": 128,
|
617 |
+
"ch_mult": [1, 2, 4, 4],
|
618 |
+
"num_res_blocks": 2,
|
619 |
+
"attn_resolutions": [],
|
620 |
+
"dropout": 0.0,
|
621 |
+
}
|
622 |
+
|
623 |
+
if (
|
624 |
+
"encoder.down.2.downsample.conv.weight" not in sd
|
625 |
+
and "decoder.up.3.upsample.conv.weight" not in sd
|
626 |
+
): # Stable diffusion x4 upscaler VAE
|
627 |
+
ddconfig["ch_mult"] = [1, 2, 4]
|
628 |
+
self.downscale_ratio = 4
|
629 |
+
self.upscale_ratio = 4
|
630 |
+
|
631 |
+
self.latent_channels = ddconfig["z_channels"] = sd[
|
632 |
+
"decoder.conv_in.weight"
|
633 |
+
].shape[1]
|
634 |
+
# Initialize model
|
635 |
+
self.first_stage_model = AutoencodingEngine(
|
636 |
+
Encoder(**ddconfig),
|
637 |
+
Decoder(**ddconfig),
|
638 |
+
DiagonalGaussianRegularizer(),
|
639 |
+
flux=flux
|
640 |
+
)
|
641 |
+
else:
|
642 |
+
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
643 |
+
self.first_stage_model = None
|
644 |
+
return
|
645 |
+
|
646 |
+
self.first_stage_model = self.first_stage_model.eval()
|
647 |
+
|
648 |
+
m, u = self.first_stage_model.load_state_dict(sd, strict=False)
|
649 |
+
if len(m) > 0:
|
650 |
+
logging.warning("Missing VAE keys {}".format(m))
|
651 |
+
|
652 |
+
if len(u) > 0:
|
653 |
+
logging.debug("Leftover VAE keys {}".format(u))
|
654 |
+
|
655 |
+
if device is None:
|
656 |
+
device = Device.vae_device()
|
657 |
+
self.device = device
|
658 |
+
offload_device = Device.vae_offload_device()
|
659 |
+
if dtype is None:
|
660 |
+
dtype = Device.vae_dtype()
|
661 |
+
self.vae_dtype = dtype
|
662 |
+
self.first_stage_model.to(self.vae_dtype)
|
663 |
+
self.output_device = Device.intermediate_device()
|
664 |
+
|
665 |
+
self.patcher = ModelPatcher.ModelPatcher(
|
666 |
+
self.first_stage_model,
|
667 |
+
load_device=self.device,
|
668 |
+
offload_device=offload_device,
|
669 |
+
)
|
670 |
+
logging.debug(
|
671 |
+
"VAE load device: {}, offload device: {}, dtype: {}".format(
|
672 |
+
self.device, offload_device, self.vae_dtype
|
673 |
+
)
|
674 |
+
)
|
675 |
+
|
676 |
+
|
677 |
+
def vae_encode_crop_pixels(self, pixels: torch.Tensor) -> torch.Tensor:
|
678 |
+
"""#### Crop the input pixels to be compatible with the VAE.
|
679 |
+
|
680 |
+
#### Args:
|
681 |
+
- `pixels` (torch.Tensor): The input pixel tensor.
|
682 |
+
|
683 |
+
#### Returns:
|
684 |
+
- `torch.Tensor`: The cropped pixel tensor.
|
685 |
+
"""
|
686 |
+
(pixels.shape[1] // self.downscale_ratio) * self.downscale_ratio
|
687 |
+
(pixels.shape[2] // self.downscale_ratio) * self.downscale_ratio
|
688 |
+
return pixels
|
689 |
+
|
690 |
+
def decode(self, samples_in: torch.Tensor, flux:bool = False) -> torch.Tensor:
|
691 |
+
"""#### Decode the latent samples to pixel samples.
|
692 |
+
|
693 |
+
#### Args:
|
694 |
+
- `samples_in` (torch.Tensor): The input latent samples.
|
695 |
+
|
696 |
+
#### Returns:
|
697 |
+
- `torch.Tensor`: The decoded pixel samples.
|
698 |
+
"""
|
699 |
+
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
700 |
+
Device.load_models_gpu([self.patcher], memory_required=memory_used)
|
701 |
+
free_memory = Device.get_free_memory(self.device)
|
702 |
+
batch_number = int(free_memory / memory_used)
|
703 |
+
batch_number = max(1, batch_number)
|
704 |
+
|
705 |
+
pixel_samples = torch.empty(
|
706 |
+
(
|
707 |
+
samples_in.shape[0],
|
708 |
+
3,
|
709 |
+
round(samples_in.shape[2] * self.upscale_ratio),
|
710 |
+
round(samples_in.shape[3] * self.upscale_ratio),
|
711 |
+
),
|
712 |
+
device=self.output_device,
|
713 |
+
)
|
714 |
+
for x in range(0, samples_in.shape[0], batch_number):
|
715 |
+
samples = (
|
716 |
+
samples_in[x : x + batch_number].to(self.vae_dtype).to(self.device)
|
717 |
+
)
|
718 |
+
pixel_samples[x : x + batch_number] = self.process_output(
|
719 |
+
self.first_stage_model.decode(samples, flux=flux).to(self.output_device).float()
|
720 |
+
)
|
721 |
+
pixel_samples = pixel_samples.to(self.output_device).movedim(1, -1)
|
722 |
+
return pixel_samples
|
723 |
+
|
724 |
+
|
725 |
+
def encode(self, pixel_samples: torch.Tensor, flux:bool = False) -> torch.Tensor:
|
726 |
+
"""#### Encode the pixel samples to latent samples.
|
727 |
+
|
728 |
+
#### Args:
|
729 |
+
- `pixel_samples` (torch.Tensor): The input pixel samples.
|
730 |
+
|
731 |
+
#### Returns:
|
732 |
+
- `torch.Tensor`: The encoded latent samples.
|
733 |
+
"""
|
734 |
+
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
735 |
+
pixel_samples = pixel_samples.movedim(-1, 1)
|
736 |
+
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
737 |
+
Device.load_models_gpu([self.patcher], memory_required=memory_used)
|
738 |
+
free_memory = Device.get_free_memory(self.device)
|
739 |
+
batch_number = int(free_memory / memory_used)
|
740 |
+
batch_number = max(1, batch_number)
|
741 |
+
samples = torch.empty(
|
742 |
+
(
|
743 |
+
pixel_samples.shape[0],
|
744 |
+
self.latent_channels,
|
745 |
+
round(pixel_samples.shape[2] // self.downscale_ratio),
|
746 |
+
round(pixel_samples.shape[3] // self.downscale_ratio),
|
747 |
+
),
|
748 |
+
device=self.output_device,
|
749 |
+
)
|
750 |
+
for x in range(0, pixel_samples.shape[0], batch_number):
|
751 |
+
pixels_in = (
|
752 |
+
self.process_input(pixel_samples[x : x + batch_number])
|
753 |
+
.to(self.vae_dtype)
|
754 |
+
.to(self.device)
|
755 |
+
)
|
756 |
+
samples[x : x + batch_number] = (
|
757 |
+
self.first_stage_model.encode(pixels_in, flux=flux).to(self.output_device).float()
|
758 |
+
)
|
759 |
+
|
760 |
+
return samples
|
761 |
+
|
762 |
+
def get_sd(self):
|
763 |
+
"""#### Get the state dictionary.
|
764 |
+
|
765 |
+
#### Returns:
|
766 |
+
- `dict`: The state dictionary.
|
767 |
+
"""
|
768 |
+
return self.first_stage_model.state_dict()
|
769 |
+
|
770 |
+
|
771 |
+
class VAEDecode:
|
772 |
+
"""#### Class for decoding VAE samples."""
|
773 |
+
|
774 |
+
def decode(self, vae: VAE, samples: dict, flux:bool = False) -> Tuple[torch.Tensor]:
|
775 |
+
"""#### Decode the VAE samples.
|
776 |
+
|
777 |
+
#### Args:
|
778 |
+
- `vae` (VAE): The VAE instance.
|
779 |
+
- `samples` (dict): The samples dictionary.
|
780 |
+
|
781 |
+
#### Returns:
|
782 |
+
- `Tuple[torch.Tensor]`: The decoded samples.
|
783 |
+
"""
|
784 |
+
return (vae.decode(samples["samples"], flux=flux),)
|
785 |
+
|
786 |
+
|
787 |
+
class VAEEncode:
|
788 |
+
"""#### Class for encoding VAE samples."""
|
789 |
+
|
790 |
+
def encode(self, vae: VAE, pixels: torch.Tensor, flux:bool = False) -> Tuple[dict]:
|
791 |
+
"""#### Encode the VAE samples.
|
792 |
+
|
793 |
+
#### Args:
|
794 |
+
- `vae` (VAE): The VAE instance.
|
795 |
+
- `pixels` (torch.Tensor): The input pixel tensor.
|
796 |
+
|
797 |
+
#### Returns:
|
798 |
+
- `Tuple[dict]`: The encoded samples dictionary.
|
799 |
+
"""
|
800 |
+
t = vae.encode(pixels[:, :, :, :3], flux=flux)
|
801 |
+
return ({"samples": t},)
|
802 |
+
|
803 |
+
|
804 |
+
class VAELoader:
|
805 |
+
"""#### Class for loading VAEs."""
|
806 |
+
# TODO: scale factor?
|
807 |
+
def load_vae(self, vae_name):
|
808 |
+
"""#### Load the VAE.
|
809 |
+
|
810 |
+
#### Args:
|
811 |
+
- `vae_name`: The name of the VAE.
|
812 |
+
|
813 |
+
#### Returns:
|
814 |
+
- `Tuple[VAE]`: The VAE instance.
|
815 |
+
"""
|
816 |
+
if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
|
817 |
+
sd = self.load_taesd(vae_name)
|
818 |
+
else:
|
819 |
+
vae_path = "./_internal/vae/" + vae_name
|
820 |
+
sd = util.load_torch_file(vae_path)
|
821 |
+
vae = VAE(sd=sd)
|
822 |
+
return (vae,)
|
823 |
+
|
824 |
+
|
modules/AutoEncoders/taesd.py
ADDED
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Tiny AutoEncoder for Stable Diffusion
|
3 |
+
(DNN for encoding / decoding SD's latent space)
|
4 |
+
"""
|
5 |
+
|
6 |
+
# TODO: Check if multiprocessing is possible for this module
|
7 |
+
from PIL import Image
|
8 |
+
import numpy as np
|
9 |
+
from sympy import im
|
10 |
+
import torch
|
11 |
+
from modules.Utilities import util
|
12 |
+
import torch.nn as nn
|
13 |
+
|
14 |
+
from modules.cond import cast
|
15 |
+
from modules.user import app_instance
|
16 |
+
|
17 |
+
|
18 |
+
def conv(n_in: int, n_out: int, **kwargs) -> cast.disable_weight_init.Conv2d:
|
19 |
+
"""#### Create a convolutional layer.
|
20 |
+
|
21 |
+
#### Args:
|
22 |
+
- `n_in` (int): The number of input channels.
|
23 |
+
- `n_out` (int): The number of output channels.
|
24 |
+
|
25 |
+
#### Returns:
|
26 |
+
- `torch.nn.Module`: The convolutional layer.
|
27 |
+
"""
|
28 |
+
return cast.disable_weight_init.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
|
29 |
+
|
30 |
+
|
31 |
+
class Clamp(nn.Module):
|
32 |
+
"""#### Class representing a clamping layer."""
|
33 |
+
|
34 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
35 |
+
"""#### Forward pass of the clamping layer.
|
36 |
+
|
37 |
+
#### Args:
|
38 |
+
- `x` (torch.Tensor): The input tensor.
|
39 |
+
|
40 |
        #### Returns:
        - `torch.Tensor`: The clamped tensor.
        """
        return torch.tanh(x / 3) * 3


class Block(nn.Module):
    """#### Class representing a block layer."""

    def __init__(self, n_in: int, n_out: int):
        """#### Initialize the block layer.

        #### Args:
        - `n_in` (int): The number of input channels.
        - `n_out` (int): The number of output channels.

        #### Returns:
        - `Block`: The block layer.
        """
        super().__init__()
        self.conv = nn.Sequential(
            conv(n_in, n_out),
            nn.ReLU(),
            conv(n_out, n_out),
            nn.ReLU(),
            conv(n_out, n_out),
        )
        self.skip = (
            cast.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False)
            if n_in != n_out
            else nn.Identity()
        )
        self.fuse = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fuse(self.conv(x) + self.skip(x))


def Encoder2(latent_channels: int = 4) -> nn.Sequential:
    """#### Create an encoder.

    #### Args:
    - `latent_channels` (int, optional): The number of latent channels. Defaults to 4.

    #### Returns:
    - `torch.nn.Module`: The encoder.
    """
    return nn.Sequential(
        conv(3, 64),
        Block(64, 64),
        conv(64, 64, stride=2, bias=False),
        Block(64, 64),
        Block(64, 64),
        Block(64, 64),
        conv(64, 64, stride=2, bias=False),
        Block(64, 64),
        Block(64, 64),
        Block(64, 64),
        conv(64, 64, stride=2, bias=False),
        Block(64, 64),
        Block(64, 64),
        Block(64, 64),
        conv(64, latent_channels),
    )


def Decoder2(latent_channels: int = 4) -> nn.Sequential:
    """#### Create a decoder.

    #### Args:
    - `latent_channels` (int, optional): The number of latent channels. Defaults to 4.

    #### Returns:
    - `torch.nn.Module`: The decoder.
    """
    return nn.Sequential(
        Clamp(),
        conv(latent_channels, 64),
        nn.ReLU(),
        Block(64, 64),
        Block(64, 64),
        Block(64, 64),
        nn.Upsample(scale_factor=2),
        conv(64, 64, bias=False),
        Block(64, 64),
        Block(64, 64),
        Block(64, 64),
        nn.Upsample(scale_factor=2),
        conv(64, 64, bias=False),
        Block(64, 64),
        Block(64, 64),
        Block(64, 64),
        nn.Upsample(scale_factor=2),
        conv(64, 64, bias=False),
        Block(64, 64),
        conv(64, 3),
    )


class TAESD(nn.Module):
    """#### Class representing a Tiny AutoEncoder for Stable Diffusion.

    #### Attributes:
    - `latent_magnitude` (float): Magnitude of the latent space.
    - `latent_shift` (float): Shift value for the latent space.
    - `vae_shift` (torch.nn.Parameter): Shift parameter for the VAE.
    - `vae_scale` (torch.nn.Parameter): Scale parameter for the VAE.
    - `taesd_encoder` (Encoder2): Encoder network for the TAESD.
    - `taesd_decoder` (Decoder2): Decoder network for the TAESD.

    #### Args:
    - `encoder_path` (str, optional): Path to the encoder model file. Defaults to None.
    - `decoder_path` (str, optional): Path to the decoder model file. Defaults to "./_internal/vae_approx/taesd_decoder.safetensors".
    - `latent_channels` (int, optional): Number of channels in the latent space. Defaults to 4.

    #### Methods:
    - `scale_latents(x)`:
        Scales raw latents to the range [0, 1].
    - `unscale_latents(x)`:
        Unscales latents from the range [0, 1] to raw latents.
    - `decode(x)`:
        Decodes the given latent representation to the original space.
    - `encode(x)`:
        Encodes the given input to the latent space.
    """

    latent_magnitude = 3
    latent_shift = 0.5

    def __init__(
        self,
        encoder_path: str = None,
        decoder_path: str = None,
        latent_channels: int = 4,
    ):
        """#### Initialize the TAESD model.

        #### Args:
        - `encoder_path` (str, optional): Path to the encoder model file. Defaults to None.
        - `decoder_path` (str, optional): Path to the decoder model file. Defaults to "./_internal/vae_approx/taesd_decoder.safetensors".
        - `latent_channels` (int, optional): Number of channels in the latent space. Defaults to 4.
        """
        super().__init__()
        self.vae_shift = torch.nn.Parameter(torch.tensor(0.0))
        self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
        self.taesd_encoder = Encoder2(latent_channels)
        self.taesd_decoder = Decoder2(latent_channels)
        decoder_path = (
            "./_internal/vae_approx/taesd_decoder.safetensors"
            if decoder_path is None
            else decoder_path
        )
        if encoder_path is not None:
            self.taesd_encoder.load_state_dict(
                util.load_torch_file(encoder_path, safe_load=True)
            )
        if decoder_path is not None:
            self.taesd_decoder.load_state_dict(
                util.load_torch_file(decoder_path, safe_load=True)
            )

    @staticmethod
    def scale_latents(x: torch.Tensor) -> torch.Tensor:
        """#### Scales raw latents to the range [0, 1].

        #### Args:
        - `x` (torch.Tensor): The raw latents.

        #### Returns:
        - `torch.Tensor`: The scaled latents.
        """
        return x.div(2 * TAESD.latent_magnitude).add(TAESD.latent_shift).clamp(0, 1)

    @staticmethod
    def unscale_latents(x: torch.Tensor) -> torch.Tensor:
        """#### Unscales latents from the range [0, 1] to raw latents.

        #### Args:
        - `x` (torch.Tensor): The scaled latents.

        #### Returns:
        - `torch.Tensor`: The raw latents.
        """
        return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """#### Decodes the given latent representation to the original space.

        #### Args:
        - `x` (torch.Tensor): The latent representation.

        #### Returns:
        - `torch.Tensor`: The decoded representation.
        """
        device = next(self.taesd_decoder.parameters()).device
        x = x.to(device)
        x_sample = self.taesd_decoder((x - self.vae_shift) * self.vae_scale)
        x_sample = x_sample.sub(0.5).mul(2)
        return x_sample

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """#### Encodes the given input to the latent space.

        #### Args:
        - `x` (torch.Tensor): The input.

        #### Returns:
        - `torch.Tensor`: The latent representation.
        """
        device = next(self.taesd_encoder.parameters()).device
        x = x.to(device)
        return (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift


def taesd_preview(x: torch.Tensor, flux: bool = False):
    """#### Preview the batched latent tensors as images.

    #### Args:
    - `x` (torch.Tensor): Input latent tensor with shape [B,C,H,W]
    - `flux` (bool, optional): Whether using flux model (for channel ordering). Defaults to False.
    """
    if app_instance.app.previewer_var.get() is True:
        taesd_instance = TAESD()

        # Handle channel dimension
        if x.shape[1] != 4:
            desired_channels = 4
            current_channels = x.shape[1]

            if current_channels > desired_channels:
                x = x[:, :desired_channels, :, :]
            else:
                padding = torch.zeros(x.shape[0], desired_channels - current_channels,
                                      x.shape[2], x.shape[3], device=x.device)
                x = torch.cat([x, padding], dim=1)

        # Process entire batch at once
        decoded_batch = taesd_instance.decode(x)

        images = []

        # Convert each image in batch
        for decoded in decoded_batch:
            # Handle channel dimension
            if decoded.shape[0] == 1:
                decoded = decoded.repeat(3, 1, 1)

            # Apply different normalization for flux vs standard mode
            if flux:
                # For flux: Assume BGR ordering and different normalization
                decoded = decoded[[2, 1, 0], :, :]  # BGR -> RGB
                # Adjust normalization for flux model range
                decoded = decoded.clamp(-1, 1)
                decoded = (decoded + 1.0) * 0.5  # Scale from [-1,1] to [0,1]
            else:
                # Standard normalization
                decoded = (decoded + 1.0) / 2.0

            # Convert to numpy and uint8
            image_np = (decoded.cpu().detach().numpy() * 255.0)
            image_np = np.transpose(image_np, (1, 2, 0))
            image_np = np.clip(image_np, 0, 255).astype(np.uint8)

            # Create PIL Image
            img = Image.fromarray(image_np, mode='RGB')
            images.append(img)

        # Update display with all images
        app_instance.app.update_image(images)
    else:
        pass
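
A minimal sketch of how the latent scaling pair above behaves, separate from the uploaded file: the tensor shape is invented for illustration, and constructing `TAESD` itself is skipped here because its constructor always tries to load decoder weights from `./_internal/vae_approx/taesd_decoder.safetensors`.

import torch

# Hypothetical latent batch in [B, C, H, W]; sizes chosen only for illustration.
latents = torch.randn(2, 4, 64, 64)

packed = TAESD.scale_latents(latents)      # x / (2 * latent_magnitude) + latent_shift, clamped to [0, 1]
restored = TAESD.unscale_latents(packed)   # inverse affine map back to the raw latent range

# The round trip is exact only where |x| <= latent_magnitude; values outside
# that window are lost to the clamp in scale_latents.
assert torch.allclose(latents.clamp(-3, 3), restored, atol=1e-5)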
modules/BlackForest/Flux.py
ADDED
@@ -0,0 +1,853 @@
# Original code can be found on: https://github.com/black-forest-labs/flux


from dataclasses import dataclass
from einops import rearrange, repeat
import torch
import torch.nn as nn

from modules.Attention import Attention
from modules.Device import Device
from modules.Model import ModelBase
from modules.Utilities import Latent
from modules.cond import cast, cond
from modules.sample import sampling, sampling_util


# Define the attention mechanism
def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, pe: torch.Tensor) -> torch.Tensor:
    """#### Compute the attention mechanism.

    #### Args:
    - `q` (Tensor): The query tensor.
    - `k` (Tensor): The key tensor.
    - `v` (Tensor): The value tensor.
    - `pe` (Tensor): The positional encoding tensor.

    #### Returns:
    - `Tensor`: The attention tensor.
    """
    q, k = apply_rope(q, k, pe)
    heads = q.shape[1]
    x = Attention.optimized_attention(q, k, v, heads, skip_reshape=True)
    return x

# Define the rotary positional encoding (RoPE)
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    """#### Compute the rotary positional encoding.

    #### Args:
    - `pos` (Tensor): The position tensor.
    - `dim` (int): The dimension of the tensor.
    - `theta` (int): The theta value for scaling.

    #### Returns:
    - `Tensor`: The rotary positional encoding tensor.
    """
    assert dim % 2 == 0
    if Device.is_device_mps(pos.device) or Device.is_intel_xpu():
        device = torch.device("cpu")
    else:
        device = pos.device

    scale = torch.linspace(
        0, (dim - 2) / dim, steps=dim // 2, dtype=torch.float64, device=device
    )
    omega = 1.0 / (theta**scale)
    out = torch.einsum(
        "...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega
    )
    out = torch.stack(
        [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
    )
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.to(dtype=torch.float32, device=pos.device)

# Apply the rotary positional encoding to the query and key tensors
def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> tuple:
    """#### Apply the rotary positional encoding to the query and key tensors.

    #### Args:
    - `xq` (Tensor): The query tensor.
    - `xk` (Tensor): The key tensor.
    - `freqs_cis` (Tensor): The frequency tensor.

    #### Returns:
    - `tuple`: The modified query and key tensors.
    """
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

# Define the embedding class
class EmbedND(nn.Module):
    def __init__(self, dim: int, theta: int, axes_dim: list):
        """#### Initialize the EmbedND class.

        #### Args:
        - `dim` (int): The dimension of the tensor.
        - `theta` (int): The theta value for scaling.
        - `axes_dim` (list): The list of axis dimensions.
        """
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        """#### Forward pass for the EmbedND class.

        #### Args:
        - `ids` (Tensor): The input tensor.

        #### Returns:
        - `Tensor`: The embedded tensor.
        """
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )
        return emb.unsqueeze(1)

# Define the MLP embedder class
class MLPEmbedder(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
        """#### Initialize the MLPEmbedder class.

        #### Args:
        - `in_dim` (int): The input dimension.
        - `hidden_dim` (int): The hidden dimension.
        - `dtype` (optional): The data type.
        - `device` (optional): The device.
        - `operations` (optional): The operations module.
        """
        super().__init__()
        self.in_layer = operations.Linear(
            in_dim, hidden_dim, bias=True, dtype=dtype, device=device
        )
        self.silu = nn.SiLU()
        self.out_layer = operations.Linear(
            hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """#### Forward pass for the MLPEmbedder class.

        #### Args:
        - `x` (Tensor): The input tensor.

        #### Returns:
        - `Tensor`: The output tensor.
        """
        return self.out_layer(self.silu(self.in_layer(x)))

# Define the RMS normalization class
class RMSNorm(nn.Module):
    def __init__(self, dim: int, dtype=None, device=None, operations=None):
        """#### Initialize the RMSNorm class.

        #### Args:
        - `dim` (int): The dimension of the tensor.
        - `dtype` (optional): The data type.
        - `device` (optional): The device.
        - `operations` (optional): The operations module.
        """
        super().__init__()
        self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """#### Forward pass for the RMSNorm class.

        #### Args:
        - `x` (Tensor): The input tensor.

        #### Returns:
        - `Tensor`: The normalized tensor.
        """
        return rms_norm(x, self.scale, 1e-6)

# Define the query-key normalization class
class QKNorm(nn.Module):
    def __init__(self, dim: int, dtype=None, device=None, operations=None):
        """#### Initialize the QKNorm class.

        #### Args:
        - `dim` (int): The dimension of the tensor.
        - `dtype` (optional): The data type.
        - `device` (optional): The device.
        - `operations` (optional): The operations module.
        """
        super().__init__()
        self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
        self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> tuple:
        """#### Forward pass for the QKNorm class.

        #### Args:
        - `q` (Tensor): The query tensor.
        - `k` (Tensor): The key tensor.
        - `v` (Tensor): The value tensor.

        #### Returns:
        - `tuple`: The normalized query and key tensors.
        """
        q = self.query_norm(q)
        k = self.key_norm(k)
        return q.to(v), k.to(v)

# Define the self-attention class
class SelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None):
        """#### Initialize the SelfAttention class.

        #### Args:
        - `dim` (int): The dimension of the tensor.
        - `num_heads` (int, optional): The number of attention heads. Defaults to 8.
        - `qkv_bias` (bool, optional): Whether to use bias in QKV projection. Defaults to False.
        - `dtype` (optional): The data type.
        - `device` (optional): The device.
        - `operations` (optional): The operations module.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
        self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)

# Define the modulation output dataclass
@dataclass
class ModulationOut:
    shift: torch.Tensor
    scale: torch.Tensor
    gate: torch.Tensor

# Define the modulation class
class Modulation(nn.Module):
    def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None):
        """#### Initialize the Modulation class.

        #### Args:
        - `dim` (int): The dimension of the tensor.
        - `double` (bool): Whether to use double modulation.
        - `dtype` (optional): The data type.
        - `device` (optional): The device.
        - `operations` (optional): The operations module.
        """
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)

    def forward(self, vec: torch.Tensor) -> tuple:
        """#### Forward pass for the Modulation class.

        #### Args:
        - `vec` (Tensor): The input tensor.

        #### Returns:
        - `tuple`: The modulation output.
        """
        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
        return (ModulationOut(*out[:3]), ModulationOut(*out[3:]) if self.is_double else None)

# Define the double stream block class
class DoubleStreamBlock(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, dtype=None, device=None, operations=None):
        """#### Initialize the DoubleStreamBlock class.

        #### Args:
        - `hidden_size` (int): The hidden size.
        - `num_heads` (int): The number of attention heads.
        - `mlp_ratio` (float): The MLP ratio.
        - `qkv_bias` (bool, optional): Whether to use bias in QKV projection. Defaults to False.
        - `dtype` (optional): The data type.
        - `device` (optional): The device.
        - `operations` (optional): The operations module.
        """
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
        self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
        self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.img_mlp = nn.Sequential(
            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
            nn.GELU(approximate="tanh"),
            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )

        self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
        self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
        self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.txt_mlp = nn.Sequential(
            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
            nn.GELU(approximate="tanh"),
            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )

    def forward(self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, pe: torch.Tensor) -> tuple:
        """#### Forward pass for the DoubleStreamBlock class.

        #### Args:
        - `img` (Tensor): The image tensor.
        - `txt` (Tensor): The text tensor.
        - `vec` (Tensor): The vector tensor.
        - `pe` (Tensor): The positional encoding tensor.

        #### Returns:
        - `tuple`: The modified image and text tensors.
        """
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        attn = attention(
            torch.cat((txt_q, img_q), dim=2),
            torch.cat((txt_k, img_k), dim=2),
            torch.cat((txt_v, img_v), dim=2),
            pe=pe,
        )

        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # calculate the img blocks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)

        # calculate the txt blocks
        txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)

        if txt.dtype == torch.float16:
            txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)

        return img, txt

# Define the single stream block class
class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, qk_scale: float = None, dtype=None, device=None, operations=None):
        """#### Initialize the SingleStreamBlock class.

        #### Args:
        - `hidden_size` (int): The hidden size.
        - `num_heads` (int): The number of attention heads.
        - `mlp_ratio` (float, optional): The MLP ratio. Defaults to 4.0.
        - `qk_scale` (float, optional): The QK scale. Defaults to None.
        - `dtype` (optional): The data type.
        - `device` (optional): The device.
        - `operations` (optional): The operations module.
        """
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
        # proj and mlp_out
        self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)

        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)

        self.hidden_size = hidden_size
        self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)

    def forward(self, x: torch.Tensor, vec: torch.Tensor, pe: torch.Tensor) -> torch.Tensor:
        """#### Forward pass for the SingleStreamBlock class.

        #### Args:
        - `x` (Tensor): The input tensor.
        - `vec` (Tensor): The vector tensor.
        - `pe` (Tensor): The positional encoding tensor.

        #### Returns:
        - `Tensor`: The modified tensor.
        """
        mod, _ = self.modulation(vec)
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(
            self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
        )

        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(
            2, 0, 3, 1, 4
        )
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        x += mod.gate * output
        if x.dtype == torch.float16:
            x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
        return x

class LastLayer(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        patch_size: int,
        out_channels: int,
        dtype=None,
        device=None,
        operations=None,
    ):
        """#### Initialize the LastLayer class.

        #### Args:
        - `hidden_size` (int): The hidden size.
        - `patch_size` (int): The patch size.
        - `out_channels` (int): The number of output channels.
        - `dtype` (optional): The data type.
        - `device` (optional): The device.
        - `operations` (optional): The operations module.
        """
        super().__init__()
        self.norm_final = operations.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
        )
        self.linear = operations.Linear(
            hidden_size,
            patch_size * patch_size * out_channels,
            bias=True,
            dtype=dtype,
            device=device,
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            operations.Linear(
                hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device
            ),
        )

    def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
        """#### Forward pass for the LastLayer class.

        #### Args:
        - `x` (torch.Tensor): The input tensor.
        - `vec` (torch.Tensor): The vector tensor.

        #### Returns:
        - `torch.Tensor`: The output tensor.
        """
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x


def pad_to_patch_size(img: torch.Tensor, patch_size: tuple = (2, 2), padding_mode: str = "circular") -> torch.Tensor:
    """#### Pad the image to the specified patch size.

    #### Args:
    - `img` (torch.Tensor): The input image tensor.
    - `patch_size` (tuple, optional): The patch size. Defaults to (2, 2).
    - `padding_mode` (str, optional): The padding mode. Defaults to "circular".

    #### Returns:
    - `torch.Tensor`: The padded image tensor.
    """
    if (
        padding_mode == "circular"
        and torch.jit.is_tracing()
        or torch.jit.is_scripting()
    ):
        padding_mode = "reflect"
    pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
    pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
    return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)


try:
    rms_norm_torch = torch.nn.functional.rms_norm
except Exception:
    rms_norm_torch = None


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """#### Apply RMS normalization to the input tensor.

    #### Args:
    - `x` (torch.Tensor): The input tensor.
    - `weight` (torch.Tensor): The weight tensor.
    - `eps` (float, optional): The epsilon value for numerical stability. Defaults to 1e-6.

    #### Returns:
    - `torch.Tensor`: The normalized tensor.
    """
    if rms_norm_torch is not None and not (
        torch.jit.is_tracing() or torch.jit.is_scripting()
    ):
        return rms_norm_torch(
            x,
            weight.shape,
            weight=cast.cast_to(weight, dtype=x.dtype, device=x.device),
            eps=eps,
        )
    else:
        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
        return (x * rrms) * cast.cast_to(weight, dtype=x.dtype, device=x.device)


@dataclass
class FluxParams:
    in_channels: int
    vec_in_dim: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    axes_dim: list
    theta: int
    qkv_bias: bool
    guidance_embed: bool


class Flux3(nn.Module):
    """
    Transformer model for flow matching on sequences.
    """

    def __init__(
        self,
        image_model=None,
        final_layer: bool = True,
        dtype=None,
        device=None,
        operations=None,
        **kwargs,
    ):
        """#### Initialize the Flux3 class.

        #### Args:
        - `image_model` (optional): The image model.
        - `final_layer` (bool, optional): Whether to include the final layer. Defaults to True.
        - `dtype` (optional): The data type.
        - `device` (optional): The device.
        - `operations` (optional): The operations module.
        - `**kwargs`: Additional keyword arguments.
        """
        super().__init__()
        self.dtype = dtype
        params = FluxParams(**kwargs)
        self.params = params
        self.in_channels = params.in_channels * 2 * 2
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(
                f"Got {params.axes_dim} but expected positional dim {pe_dim}"
            )
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(
            dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
        )
        self.img_in = operations.Linear(
            self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device
        )
        self.time_in = MLPEmbedder(
            in_dim=256,
            hidden_dim=self.hidden_size,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.vector_in = MLPEmbedder(
            params.vec_in_dim,
            self.hidden_size,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.guidance_in = (
            MLPEmbedder(
                in_dim=256,
                hidden_dim=self.hidden_size,
                dtype=dtype,
                device=device,
                operations=operations,
            )
            if params.guidance_embed
            else nn.Identity()
        )
        self.txt_in = operations.Linear(
            params.context_in_dim, self.hidden_size, dtype=dtype, device=device
        )

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                    dtype=dtype,
                    device=device,
                    operations=operations,
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    dtype=dtype,
                    device=device,
                    operations=operations,
                )
                for _ in range(params.depth_single_blocks)
            ]
        )

        if final_layer:
            self.final_layer = LastLayer(
                self.hidden_size,
                1,
                self.out_channels,
                dtype=dtype,
                device=device,
                operations=operations,
            )

    def forward_orig(
        self,
        img: torch.Tensor,
        img_ids: torch.Tensor,
        txt: torch.Tensor,
        txt_ids: torch.Tensor,
        timesteps: torch.Tensor,
        y: torch.Tensor,
        guidance: torch.Tensor = None,
        control=None,
    ) -> torch.Tensor:
        """#### Original forward pass for the Flux3 class.

        #### Args:
        - `img` (torch.Tensor): The image tensor.
        - `img_ids` (torch.Tensor): The image IDs tensor.
        - `txt` (torch.Tensor): The text tensor.
        - `txt_ids` (torch.Tensor): The text IDs tensor.
        - `timesteps` (torch.Tensor): The timesteps tensor.
        - `y` (torch.Tensor): The vector tensor.
        - `guidance` (torch.Tensor, optional): The guidance tensor. Defaults to None.
        - `control` (optional): The control tensor. Defaults to None.

        #### Returns:
        - `torch.Tensor`: The output tensor.
        """
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(sampling_util.timestep_embedding_flux(timesteps, 256).to(img.dtype))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError(
                    "Didn't get guidance strength for guidance distilled model."
                )
            vec = vec + self.guidance_in(
                sampling_util.timestep_embedding_flux(guidance, 256).to(img.dtype)
            )

        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        for i, block in enumerate(self.double_blocks):
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

            if control is not None:  # Controlnet
                control_i = control.get("input")
                if i < len(control_i):
                    add = control_i[i]
                    if add is not None:
                        img += add

        img = torch.cat((txt, img), 1)

        for i, block in enumerate(self.single_blocks):
            img = block(img, vec=vec, pe=pe)

            if control is not None:  # Controlnet
                control_o = control.get("output")
                if i < len(control_o):
                    add = control_o[i]
                    if add is not None:
                        img[:, txt.shape[1] :, ...] += add

        img = img[:, txt.shape[1] :, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img

    def forward(self, x: torch.Tensor, timestep: torch.Tensor, context: torch.Tensor, y: torch.Tensor, guidance: torch.Tensor, control=None, **kwargs) -> torch.Tensor:
        """#### Forward pass for the Flux3 class.

        #### Args:
        - `x` (torch.Tensor): The input tensor.
        - `timestep` (torch.Tensor): The timestep tensor.
        - `context` (torch.Tensor): The context tensor.
        - `y` (torch.Tensor): The vector tensor.
        - `guidance` (torch.Tensor): The guidance tensor.
        - `control` (optional): The control tensor. Defaults to None.
        - `**kwargs`: Additional keyword arguments.

        #### Returns:
        - `torch.Tensor`: The output tensor.
        """
        bs, c, h, w = x.shape
        patch_size = 2
        x = pad_to_patch_size(x, (patch_size, patch_size))

        img = rearrange(
            x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size
        )

        h_len = (h + (patch_size // 2)) // patch_size
        w_len = (w + (patch_size // 2)) // patch_size
        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
        img_ids[..., 1] = (
            img_ids[..., 1]
            + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[
                :, None
            ]
        )
        img_ids[..., 2] = (
            img_ids[..., 2]
            + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[
                None, :
            ]
        )
        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
        out = self.forward_orig(
            img, img_ids, context, txt_ids, timestep, y, guidance, control
        )
        return rearrange(
            out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2
        )[:, :, :h, :w]


class Flux2(ModelBase.BaseModel):
    def __init__(self, model_config: dict, model_type=sampling.ModelType.FLUX, device=None):
        """#### Initialize the Flux2 class.

        #### Args:
        - `model_config` (dict): The model configuration.
        - `model_type` (sampling.ModelType, optional): The model type. Defaults to sampling.ModelType.FLUX.
        - `device` (optional): The device.
        """
        super().__init__(model_config, model_type, device=device, unet_model=Flux3, flux=True)

    def encode_adm(self, **kwargs) -> torch.Tensor:
        """#### Encode the ADM.

        #### Args:
        - `**kwargs`: Additional keyword arguments.

        #### Returns:
        - `torch.Tensor`: The encoded ADM tensor.
        """
        return kwargs["pooled_output"]

    def extra_conds(self, **kwargs) -> dict:
        """#### Get extra conditions.

        #### Args:
        - `**kwargs`: Additional keyword arguments.

        #### Returns:
        - `dict`: The extra conditions.
        """
        out = super().extra_conds(**kwargs)
        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out["c_crossattn"] = cond.CONDRegular(cross_attn)
        out["guidance"] = cond.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
        return out


class Flux(ModelBase.BASE):
    unet_config = {
        "image_model": "flux",
        "guidance_embed": True,
    }

    sampling_settings = {}

    unet_extra_config = {}
    latent_format = Latent.Flux1

    memory_usage_factor = 2.8

    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict: dict, prefix: str = "", device=None) -> Flux2:
        """#### Get the model.

        #### Args:
        - `state_dict` (dict): The state dictionary.
        - `prefix` (str, optional): The prefix. Defaults to "".
        - `device` (optional): The device.

        #### Returns:
        - `Flux2`: The Flux2 model.
        """
        out = Flux2(self, device=device)
        return out


models = [Flux]
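
A minimal shape check for the RoPE helpers defined above (`rope` / `apply_rope`), separate from the uploaded file: the batch, head and sequence sizes are invented, and it assumes the module's imports (notably `Device`) resolve in the surrounding project.

import torch

B, H, N, HEAD_DIM = 1, 2, 8, 16            # hypothetical sizes; head_dim must be even

pos = torch.arange(N)[None, :]             # (B, N) positions along a single axis
pe = rope(pos, HEAD_DIM, theta=10000)      # (B, N, HEAD_DIM // 2, 2, 2)
pe = pe.unsqueeze(1)                       # (B, 1, N, HEAD_DIM // 2, 2, 2), matching EmbedND.forward

q = torch.randn(B, H, N, HEAD_DIM)
k = torch.randn(B, H, N, HEAD_DIM)
q_rot, k_rot = apply_rope(q, k, pe)        # rotated queries/keys, same shapes as the inputs
assert q_rot.shape == q.shape and k_rot.shape == k.shape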
modules/Device/Device.py
ADDED
@@ -0,0 +1,1602 @@
import logging
import platform
import sys
from enum import Enum
from typing import Tuple, Union
import packaging.version

import psutil
import torch

if packaging.version.parse(torch.__version__) >= packaging.version.parse("1.12.0"):
    torch.backends.cuda.matmul.allow_tf32 = True


class VRAMState(Enum):
    """#### Enum for VRAM states.
    """
    DISABLED = 0  # No vram present: no need to move _internal to vram
    NO_VRAM = 1  # Very low vram: enable all the options to save vram
    LOW_VRAM = 2
    NORMAL_VRAM = 3
    HIGH_VRAM = 4
    SHARED = 5  # No dedicated vram: memory shared between CPU and GPU but _internal still need to be moved between both.


class CPUState(Enum):
    """#### Enum for CPU states.
    """
    GPU = 0
    CPU = 1
    MPS = 2


# Determine VRAM State
vram_state = VRAMState.NORMAL_VRAM
set_vram_to = VRAMState.NORMAL_VRAM
cpu_state = CPUState.GPU

total_vram = 0

lowvram_available = True
xpu_available = False

directml_enabled = False
try:
    if torch.xpu.is_available():
        xpu_available = True
except:
    pass

try:
    if torch.backends.mps.is_available():
        cpu_state = CPUState.MPS
        import torch.mps
except:
    pass


def is_intel_xpu() -> bool:
    """#### Check if Intel XPU is available.

    #### Returns:
    - `bool`: Whether Intel XPU is available.
    """
    global cpu_state
    global xpu_available
    if cpu_state == CPUState.GPU:
        if xpu_available:
            return True
    return False


def get_torch_device() -> torch.device:
    """#### Get the torch device.

    #### Returns:
    - `torch.device`: The torch device.
    """
    global directml_enabled
    global cpu_state
    if directml_enabled:
        global directml_device
        return directml_device
    if cpu_state == CPUState.MPS:
        return torch.device("mps")
    if cpu_state == CPUState.CPU:
        return torch.device("cpu")
    else:
        if is_intel_xpu():
            return torch.device("xpu", torch.xpu.current_device())
        else:
            return torch.device(torch.cuda.current_device())


def get_total_memory(dev: torch.device = None, torch_total_too: bool = False) -> int:
    """#### Get the total memory.

    #### Args:
    - `dev` (torch.device, optional): The device. Defaults to None.
    - `torch_total_too` (bool, optional): Whether to get the total memory in PyTorch. Defaults to False.

    #### Returns:
    - `int`: The total memory.
    """
    global directml_enabled
    if dev is None:
        dev = get_torch_device()

    if hasattr(dev, "type") and (dev.type == "cpu" or dev.type == "mps"):
        mem_total = psutil.virtual_memory().total
        mem_total_torch = mem_total
    else:
        if directml_enabled:
            mem_total = 1024 * 1024 * 1024
            mem_total_torch = mem_total
        elif is_intel_xpu():
            stats = torch.xpu.memory_stats(dev)
            mem_reserved = stats["reserved_bytes.all.current"]
            mem_total_torch = mem_reserved
            mem_total = torch.xpu.get_device_properties(dev).total_memory
        else:
            stats = torch.cuda.memory_stats(dev)
            mem_reserved = stats["reserved_bytes.all.current"]
            _, mem_total_cuda = torch.cuda.mem_get_info(dev)
            mem_total_torch = mem_reserved
            mem_total = mem_total_cuda

    if torch_total_too:
        return (mem_total, mem_total_torch)
    else:
        return mem_total


total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024)
logging.info(
    "Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram)
)
try:
    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
except:
    OOM_EXCEPTION = Exception

XFORMERS_VERSION = ""
XFORMERS_ENABLED_VAE = True
try:
    import xformers
    import xformers.ops

    XFORMERS_IS_AVAILABLE = True
    try:
        XFORMERS_IS_AVAILABLE = xformers._has_cpp_library
    except:
        pass
    try:
        XFORMERS_VERSION = xformers.version.__version__
        logging.info("xformers version: {}".format(XFORMERS_VERSION))
        if XFORMERS_VERSION.startswith("0.0.18"):
            logging.warning(
                "\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images."
            )
            logging.warning(
                "Please downgrade or upgrade xformers to a different version.\n"
            )
            XFORMERS_ENABLED_VAE = False
    except:
        pass
except:
    XFORMERS_IS_AVAILABLE = False


def is_nvidia() -> bool:
    """#### Checks if user has an Nvidia GPU

    #### Returns
    - `bool`: Whether the GPU is Nvidia
    """
    global cpu_state
    if cpu_state == CPUState.GPU:
        if torch.version.cuda:
            return True
    return False
183 |
+
|
184 |
+
|
185 |
+
ENABLE_PYTORCH_ATTENTION = False
|
186 |
+
|
187 |
+
VAE_DTYPE = torch.float32
|
188 |
+
|
189 |
+
try:
|
190 |
+
if is_nvidia():
|
191 |
+
torch_version = torch.version.__version__
|
192 |
+
if int(torch_version[0]) >= 2:
|
193 |
+
if ENABLE_PYTORCH_ATTENTION is False:
|
194 |
+
ENABLE_PYTORCH_ATTENTION = True
|
195 |
+
if (
|
196 |
+
torch.cuda.is_bf16_supported()
|
197 |
+
and torch.cuda.get_device_properties(torch.cuda.current_device()).major
|
198 |
+
>= 8
|
199 |
+
):
|
200 |
+
VAE_DTYPE = torch.bfloat16
|
201 |
+
except:
|
202 |
+
pass
|
203 |
+
|
204 |
+
if is_intel_xpu():
|
205 |
+
VAE_DTYPE = torch.bfloat16
|
206 |
+
|
207 |
+
if ENABLE_PYTORCH_ATTENTION:
|
208 |
+
torch.backends.cuda.enable_math_sdp(True)
|
209 |
+
torch.backends.cuda.enable_flash_sdp(True)
|
210 |
+
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
211 |
+
|
212 |
+
|
213 |
+
FORCE_FP32 = False
|
214 |
+
FORCE_FP16 = False
|
215 |
+
|
216 |
+
if lowvram_available:
|
217 |
+
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
|
218 |
+
vram_state = set_vram_to
|
219 |
+
|
220 |
+
if cpu_state != CPUState.GPU:
|
221 |
+
vram_state = VRAMState.DISABLED
|
222 |
+
|
223 |
+
if cpu_state == CPUState.MPS:
|
224 |
+
vram_state = VRAMState.SHARED
|
225 |
+
|
226 |
+
logging.info(f"Set vram state to: {vram_state.name}")
|
227 |
+
|
228 |
+
DISABLE_SMART_MEMORY = False
|
229 |
+
|
230 |
+
if DISABLE_SMART_MEMORY:
|
231 |
+
logging.info("Disabling smart memory management")
|
232 |
+
|
233 |
+
|
234 |
+
def get_torch_device_name(device: torch.device) -> str:
|
235 |
+
"""#### Get the name of the torch compatible device
|
236 |
+
|
237 |
+
#### Args:
|
238 |
+
- `device` (torch.device): the device
|
239 |
+
|
240 |
+
#### Returns:
|
241 |
+
- `str`: the name of the device
|
242 |
+
"""
|
243 |
+
if hasattr(device, "type"):
|
244 |
+
if device.type == "cuda":
|
245 |
+
try:
|
246 |
+
allocator_backend = torch.cuda.get_allocator_backend()
|
247 |
+
except:
|
248 |
+
allocator_backend = ""
|
249 |
+
return "{} {} : {}".format(
|
250 |
+
device, torch.cuda.get_device_name(device), allocator_backend
|
251 |
+
)
|
252 |
+
else:
|
253 |
+
return "{}".format(device.type)
|
254 |
+
elif is_intel_xpu():
|
255 |
+
return "{} {}".format(device, torch.xpu.get_device_name(device))
|
256 |
+
else:
|
257 |
+
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
|
258 |
+
|
259 |
+
|
260 |
+
try:
|
261 |
+
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
|
262 |
+
except:
|
263 |
+
logging.warning("Could not pick default device.")
|
264 |
+
|
265 |
+
logging.info("VAE dtype: {}".format(VAE_DTYPE))
|
266 |
+
|
267 |
+
current_loaded_models = []
|
268 |
+
|
269 |
+
|
270 |
+
def module_size(module: torch.nn.Module) -> int:
|
271 |
+
"""#### Get the size of a module
|
272 |
+
|
273 |
+
#### Args:
|
274 |
+
- `module` (torch.nn.Module): The module
|
275 |
+
|
276 |
+
#### Returns:
|
277 |
+
- `int`: The size of the module
|
278 |
+
"""
|
279 |
+
module_mem = 0
|
280 |
+
sd = module.state_dict()
|
281 |
+
for k in sd:
|
282 |
+
t = sd[k]
|
283 |
+
module_mem += t.nelement() * t.element_size()
|
284 |
+
return module_mem
|
285 |
+
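# Illustrative sketch (not part of the original module): module_size sums
# nelement() * element_size() over a module's state_dict, so a small fp32 Linear layer
# reports roughly (128 * 128 + 128) * 4 bytes.
def _module_size_example() -> int:
    layer = torch.nn.Linear(128, 128)
    return module_size(layer)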
|
286 |
+
|
287 |
+
class LoadedModel:
|
288 |
+
"""#### Class to load a model
|
289 |
+
"""
|
290 |
+
def __init__(self, model: torch.nn.Module):
|
291 |
+
"""#### Initialize the class
|
292 |
+
|
293 |
+
#### Args:
|
294 |
+
- `model`: The model
|
295 |
+
"""
|
296 |
+
self.model = model
|
297 |
+
self.device = model.load_device
|
298 |
+
self.weights_loaded = False
|
299 |
+
self.real_model = None
|
300 |
+
|
301 |
+
def model_memory(self):
|
302 |
+
"""#### Get the model memory
|
303 |
+
|
304 |
+
#### Returns:
|
305 |
+
- `int`: The model memory
|
306 |
+
"""
|
307 |
+
return self.model.model_size()
|
308 |
+
|
309 |
+
|
310 |
+
def model_offloaded_memory(self):
|
311 |
+
"""#### Get the offloaded model memory
|
312 |
+
|
313 |
+
#### Returns:
|
314 |
+
- `int`: The offloaded model memory
|
315 |
+
"""
|
316 |
+
return self.model.model_size() - self.model.loaded_size()
|
317 |
+
|
318 |
+
def model_memory_required(self, device: torch.device) -> int:
|
319 |
+
"""#### Get the required model memory
|
320 |
+
|
321 |
+
#### Args:
|
322 |
+
- `device`: The device
|
323 |
+
|
324 |
+
#### Returns:
|
325 |
+
- `int`: The required model memory
|
326 |
+
"""
|
327 |
+
if hasattr(self.model, 'current_loaded_device') and device == self.model.current_loaded_device():
|
328 |
+
return self.model_offloaded_memory()
|
329 |
+
else:
|
330 |
+
return self.model_memory()
|
331 |
+
|
332 |
+
def model_load(self, lowvram_model_memory: int = 0, force_patch_weights: bool = False) -> torch.nn.Module:
|
333 |
+
"""#### Load the model
|
334 |
+
|
335 |
+
#### Args:
|
336 |
+
- `lowvram_model_memory` (int, optional): The low VRAM model memory. Defaults to 0.
|
337 |
+
- `force_patch_weights` (bool, optional): Whether to force patch the weights. Defaults to False.
|
338 |
+
|
339 |
+
#### Returns:
|
340 |
+
- `torch.nn.Module`: The real model
|
341 |
+
"""
|
342 |
+
patch_model_to = self.device
|
343 |
+
|
344 |
+
self.model.model_patches_to(self.device)
|
345 |
+
self.model.model_patches_to(self.model.model_dtype())
|
346 |
+
|
347 |
+
load_weights = not self.weights_loaded
|
348 |
+
|
349 |
+
try:
|
350 |
+
if hasattr(self.model, "patch_model_lowvram") and lowvram_model_memory > 0 and load_weights:
|
351 |
+
self.real_model = self.model.patch_model_lowvram(
|
352 |
+
device_to=patch_model_to,
|
353 |
+
lowvram_model_memory=lowvram_model_memory,
|
354 |
+
force_patch_weights=force_patch_weights,
|
355 |
+
)
|
356 |
+
else:
|
357 |
+
self.real_model = self.model.patch_model(
|
358 |
+
device_to=patch_model_to, patch_weights=load_weights
|
359 |
+
)
|
360 |
+
except Exception as e:
|
361 |
+
self.model.unpatch_model(self.model.offload_device)
|
362 |
+
self.model_unload()
|
363 |
+
raise e
|
364 |
+
self.weights_loaded = True
|
365 |
+
return self.real_model
|
366 |
+
|
367 |
+
def model_load_flux(self, lowvram_model_memory: int = 0, force_patch_weights: bool = False) -> torch.nn.Module:
|
368 |
+
"""#### Load the model
|
369 |
+
|
370 |
+
#### Args:
|
371 |
+
- `lowvram_model_memory` (int, optional): The low VRAM model memory. Defaults to 0.
|
372 |
+
- `force_patch_weights` (bool, optional): Whether to force patch the weights. Defaults to False.
|
373 |
+
|
374 |
+
#### Returns:
|
375 |
+
- `torch.nn.Module`: The real model
|
376 |
+
"""
|
377 |
+
patch_model_to = self.device
|
378 |
+
|
379 |
+
self.model.model_patches_to(self.device)
|
380 |
+
self.model.model_patches_to(self.model.model_dtype())
|
381 |
+
|
382 |
+
load_weights = not self.weights_loaded
|
383 |
+
|
384 |
+
if self.model.loaded_size() > 0:
|
385 |
+
use_more_vram = lowvram_model_memory
|
386 |
+
if use_more_vram == 0:
|
387 |
+
use_more_vram = 1e32
|
388 |
+
self.model_use_more_vram(use_more_vram)
|
389 |
+
else:
|
390 |
+
try:
|
391 |
+
self.real_model = self.model.patch_model_flux(
|
392 |
+
device_to=patch_model_to,
|
393 |
+
lowvram_model_memory=lowvram_model_memory,
|
394 |
+
load_weights=load_weights,
|
395 |
+
force_patch_weights=force_patch_weights,
|
396 |
+
)
|
397 |
+
except Exception as e:
|
398 |
+
self.model.unpatch_model(self.model.offload_device)
|
399 |
+
self.model_unload()
|
400 |
+
raise e
|
401 |
+
|
402 |
+
if (
|
403 |
+
is_intel_xpu()
|
404 |
+
and "ipex" in globals()
|
405 |
+
and self.real_model is not None
|
406 |
+
):
|
407 |
+
import ipex
|
408 |
+
with torch.no_grad():
|
409 |
+
self.real_model = ipex.optimize(
|
410 |
+
self.real_model.eval(),
|
411 |
+
inplace=True,
|
412 |
+
graph_mode=True,
|
413 |
+
concat_linear=True,
|
414 |
+
)
|
415 |
+
|
416 |
+
self.weights_loaded = True
|
417 |
+
return self.real_model
|
418 |
+
|
419 |
+
def should_reload_model(self, force_patch_weights: bool = False) -> bool:
|
420 |
+
"""#### Checks if the model should be reloaded
|
421 |
+
|
422 |
+
#### Args:
|
423 |
+
- `force_patch_weights` (bool, optional): If model reloading should be enforced. Defaults to False.
|
424 |
+
|
425 |
+
#### Returns:
|
426 |
+
- `bool`: Whether the model should be reloaded
|
427 |
+
"""
|
428 |
+
if force_patch_weights and self.model.lowvram_patch_counter > 0:
|
429 |
+
return True
|
430 |
+
return False
|
431 |
+
|
432 |
+
def model_unload(self, unpatch_weights: bool = True) -> None:
|
433 |
+
"""#### Unloads the patched model
|
434 |
+
|
435 |
+
#### Args:
|
436 |
+
- `unpatch_weights` (bool, optional): Whether the weights should be unpatched. Defaults to True.
|
437 |
+
"""
|
438 |
+
self.model.unpatch_model(
|
439 |
+
self.model.offload_device, unpatch_weights=unpatch_weights
|
440 |
+
)
|
441 |
+
self.model.model_patches_to(self.model.offload_device)
|
442 |
+
self.weights_loaded = self.weights_loaded and not unpatch_weights
|
443 |
+
self.real_model = None
|
444 |
+
|
445 |
+
def model_use_more_vram(self, extra_memory: int) -> int:
|
446 |
+
"""#### Use more VRAM
|
447 |
+
|
448 |
+
#### Args:
|
449 |
+
- `extra_memory`: The extra memory
|
450 |
+
"""
|
451 |
+
return self.model.partially_load(self.device, extra_memory)
|
452 |
+
|
453 |
+
def __eq__(self, other: torch.nn.Module) -> bool:
|
454 |
+
"""#### Verify if the model is equal to another
|
455 |
+
|
456 |
+
#### Args:
|
457 |
+
- `other` (torch.nn.Module): the other model
|
458 |
+
|
459 |
+
#### Returns:
|
460 |
+
- `bool`: Whether the two models are equal
|
461 |
+
"""
|
462 |
+
return self.model is other.model
|
463 |
+
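# Illustrative sketch (not part of the original module): LoadedModel is normally managed
# by load_models_gpu below; `patcher` here is a hypothetical ModelPatcher-style object
# exposing load_device, offload_device and the patch/unpatch methods used above.
def _loaded_model_example(patcher) -> None:
    lm = LoadedModel(patcher)
    lm.model_load()    # patches weights onto patcher.load_device
    lm.model_unload()  # moves them back to patcher.offload_device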
|
464 |
+
|
465 |
+
def minimum_inference_memory() -> int:
|
466 |
+
"""#### The minimum memory requirement for inference, equals to 1024^3
|
467 |
+
|
468 |
+
#### Returns:
|
469 |
+
- `int`: the memory requirement
|
470 |
+
"""
|
471 |
+
return 1024 * 1024 * 1024
|
472 |
+
|
473 |
+
|
474 |
+
def unload_model_clones(model: torch.nn.Module, unload_weights_only:bool = True, force_unload: bool = True) -> bool:
|
475 |
+
"""#### Unloads the model clones
|
476 |
+
|
477 |
+
#### Args:
|
478 |
+
- `model` (torch.nn.Module): The model
|
479 |
+
- `unload_weights_only` (bool, optional): Whether to unload only the weights. Defaults to True.
|
480 |
+
- `force_unload` (bool, optional): Whether to force the unload. Defaults to True.
|
481 |
+
|
482 |
+
#### Returns:
|
483 |
+
- `bool`: Whether the model was unloaded
|
484 |
+
"""
|
485 |
+
to_unload = []
|
486 |
+
for i in range(len(current_loaded_models)):
|
487 |
+
if model.is_clone(current_loaded_models[i].model):
|
488 |
+
to_unload = [i] + to_unload
|
489 |
+
|
490 |
+
if len(to_unload) == 0:
|
491 |
+
return True
|
492 |
+
|
493 |
+
same_weights = 0
|
494 |
+
|
495 |
+
if same_weights == len(to_unload):
|
496 |
+
unload_weight = False
|
497 |
+
else:
|
498 |
+
unload_weight = True
|
499 |
+
|
500 |
+
if not force_unload:
|
501 |
+
if unload_weights_only and unload_weight is False:
|
502 |
+
return None
|
503 |
+
|
504 |
+
for i in to_unload:
|
505 |
+
logging.debug("unload clone {} {}".format(i, unload_weight))
|
506 |
+
current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)
|
507 |
+
|
508 |
+
return unload_weight
|
509 |
+
|
510 |
+
|
511 |
+
def free_memory(memory_required: int, device: torch.device, keep_loaded: list = []) -> None:
|
512 |
+
"""#### Free memory
|
513 |
+
|
514 |
+
#### Args:
|
515 |
+
- `memory_required` (int): The required memory
|
516 |
+
- `device` (torch.device): The device
|
517 |
+
- `keep_loaded` (list, optional): The list of loaded models to keep. Defaults to [].
|
518 |
+
"""
|
519 |
+
unloaded_model = []
|
520 |
+
can_unload = []
|
521 |
+
|
522 |
+
for i in range(len(current_loaded_models) - 1, -1, -1):
|
523 |
+
shift_model = current_loaded_models[i]
|
524 |
+
if shift_model.device == device:
|
525 |
+
if shift_model not in keep_loaded:
|
526 |
+
can_unload.append(
|
527 |
+
(sys.getrefcount(shift_model.model), shift_model.model_memory(), i)
|
528 |
+
)
|
529 |
+
|
530 |
+
for x in sorted(can_unload):
|
531 |
+
i = x[-1]
|
532 |
+
if not DISABLE_SMART_MEMORY:
|
533 |
+
if get_free_memory(device) > memory_required:
|
534 |
+
break
|
535 |
+
current_loaded_models[i].model_unload()
|
536 |
+
unloaded_model.append(i)
|
537 |
+
|
538 |
+
for i in sorted(unloaded_model, reverse=True):
|
539 |
+
current_loaded_models.pop(i)
|
540 |
+
|
541 |
+
if len(unloaded_model) > 0:
|
542 |
+
soft_empty_cache()
|
543 |
+
else:
|
544 |
+
if vram_state != VRAMState.HIGH_VRAM:
|
545 |
+
mem_free_total, mem_free_torch = get_free_memory(
|
546 |
+
device, torch_free_too=True
|
547 |
+
)
|
548 |
+
if mem_free_torch > mem_free_total * 0.25:
|
549 |
+
soft_empty_cache()
|
550 |
+
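# Illustrative sketch (not part of the original module): callers typically request
# "model size plus working room" on the compute device before loading; here 4 GiB.
def _free_memory_example(required_bytes: int = 4 * 1024 * 1024 * 1024) -> None:
    dev = get_torch_device()
    # Unloads loaded models (in the order computed above) until the requested amount is
    # free or nothing more can be unloaded, then soft-empties the allocator cache.
    free_memory(required_bytes, dev)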
|
551 |
+
def use_more_memory(extra_memory: int, loaded_models: list, device: torch.device) -> None:
|
552 |
+
"""#### Use more memory
|
553 |
+
|
554 |
+
#### Args:
|
555 |
+
- `extra_memory` (int): The extra memory
|
556 |
+
- `loaded_models` (list): The loaded models
|
557 |
+
- `device` (torch.device): The device
|
558 |
+
"""
|
559 |
+
for m in loaded_models:
|
560 |
+
if m.device == device:
|
561 |
+
extra_memory -= m.model_use_more_vram(extra_memory)
|
562 |
+
if extra_memory <= 0:
|
563 |
+
break
|
564 |
+
|
565 |
+
WINDOWS = any(platform.win32_ver())
|
566 |
+
|
567 |
+
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
|
568 |
+
if WINDOWS:
|
569 |
+
EXTRA_RESERVED_VRAM = (
|
570 |
+
600 * 1024 * 1024
|
571 |
+
) # Windows is higher because of the shared vram issue
|
572 |
+
|
573 |
+
def extra_reserved_memory() -> int:
|
574 |
+
"""#### Extra reserved memory
|
575 |
+
|
576 |
+
#### Returns:
|
577 |
+
- `int`: The extra reserved memory
|
578 |
+
"""
|
579 |
+
return EXTRA_RESERVED_VRAM
|
580 |
+
|
581 |
+
def offloaded_memory(loaded_models: list, device: torch.device) -> int:
|
582 |
+
"""#### Offloaded memory
|
583 |
+
|
584 |
+
#### Args:
|
585 |
+
- `loaded_models` (list): The loaded models
|
586 |
+
- `device` (torch.device): The device
|
587 |
+
|
588 |
+
#### Returns:
|
589 |
+
- `int`: The offloaded memory
|
590 |
+
"""
|
591 |
+
offloaded_mem = 0
|
592 |
+
for m in loaded_models:
|
593 |
+
if m.device == device:
|
594 |
+
offloaded_mem += m.model_offloaded_memory()
|
595 |
+
return offloaded_mem
|
596 |
+
|
597 |
+
def load_models_gpu(models: list, memory_required: int = 0, force_patch_weights: bool = False, minimum_memory_required=None, force_full_load=False, flux_enabled: bool = False) -> None:
|
598 |
+
"""#### Load models on the GPU
|
599 |
+
|
600 |
+
#### Args:
|
601 |
+
- `models`(list): The models
|
602 |
+
- `memory_required` (int, optional): The required memory. Defaults to 0.
|
603 |
+
- `force_patch_weights` (bool, optional): Whether to force patch the weights. Defaults to False.
|
604 |
+
- `minimum_memory_required` (int, optional): The minimum memory required. Defaults to None.
|
605 |
+
- `force_full_load` (bool, optional): Whether to force a full model load instead of a partial low-VRAM load. Defaults to False.
|
606 |
+
- `flux_enabled` (bool, optional): Whether flux is enabled. Defaults to False.
|
607 |
+
"""
|
608 |
+
global vram_state
|
609 |
+
if not flux_enabled:
|
610 |
+
|
611 |
+
inference_memory = minimum_inference_memory()
|
612 |
+
extra_mem = max(inference_memory, memory_required)
|
613 |
+
|
614 |
+
models = set(models)
|
615 |
+
|
616 |
+
models_to_load = []
|
617 |
+
models_already_loaded = []
|
618 |
+
for x in models:
|
619 |
+
loaded_model = LoadedModel(x)
|
620 |
+
loaded = None
|
621 |
+
|
622 |
+
try:
|
623 |
+
loaded_model_index = current_loaded_models.index(loaded_model)
|
624 |
+
except:
|
625 |
+
loaded_model_index = None
|
626 |
+
|
627 |
+
if loaded_model_index is not None:
|
628 |
+
loaded = current_loaded_models[loaded_model_index]
|
629 |
+
if loaded.should_reload_model(force_patch_weights=force_patch_weights):
|
630 |
+
current_loaded_models.pop(loaded_model_index).model_unload(
|
631 |
+
unpatch_weights=True
|
632 |
+
)
|
633 |
+
loaded = None
|
634 |
+
else:
|
635 |
+
models_already_loaded.append(loaded)
|
636 |
+
|
637 |
+
if loaded is None:
|
638 |
+
if hasattr(x, "model"):
|
639 |
+
logging.info(f"Requested to load {x.model.__class__.__name__}")
|
640 |
+
models_to_load.append(loaded_model)
|
641 |
+
|
642 |
+
if len(models_to_load) == 0:
|
643 |
+
devs = set(map(lambda a: a.device, models_already_loaded))
|
644 |
+
for d in devs:
|
645 |
+
if d != torch.device("cpu"):
|
646 |
+
free_memory(extra_mem, d, models_already_loaded)
|
647 |
+
return
|
648 |
+
|
649 |
+
logging.info(
|
650 |
+
f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}"
|
651 |
+
)
|
652 |
+
|
653 |
+
total_memory_required = {}
|
654 |
+
for loaded_model in models_to_load:
|
655 |
+
if (
|
656 |
+
unload_model_clones(
|
657 |
+
loaded_model.model, unload_weights_only=True, force_unload=False
|
658 |
+
)
|
659 |
+
is True
|
660 |
+
): # unload clones where the weights are different
|
661 |
+
total_memory_required[loaded_model.device] = total_memory_required.get(
|
662 |
+
loaded_model.device, 0
|
663 |
+
) + loaded_model.model_memory_required(loaded_model.device)
|
664 |
+
|
665 |
+
for device in total_memory_required:
|
666 |
+
if device != torch.device("cpu"):
|
667 |
+
free_memory(
|
668 |
+
total_memory_required[device] * 1.3 + extra_mem,
|
669 |
+
device,
|
670 |
+
models_already_loaded,
|
671 |
+
)
|
672 |
+
|
673 |
+
for loaded_model in models_to_load:
|
674 |
+
weights_unloaded = unload_model_clones(
|
675 |
+
loaded_model.model, unload_weights_only=False, force_unload=False
|
676 |
+
) # unload the rest of the clones where the weights can stay loaded
|
677 |
+
if weights_unloaded is not None:
|
678 |
+
loaded_model.weights_loaded = not weights_unloaded
|
679 |
+
|
680 |
+
for loaded_model in models_to_load:
|
681 |
+
model = loaded_model.model
|
682 |
+
torch_dev = model.load_device
|
683 |
+
if is_device_cpu(torch_dev):
|
684 |
+
vram_set_state = VRAMState.DISABLED
|
685 |
+
else:
|
686 |
+
vram_set_state = vram_state
|
687 |
+
lowvram_model_memory = 0
|
688 |
+
if lowvram_available and (
|
689 |
+
vram_set_state == VRAMState.LOW_VRAM
|
690 |
+
or vram_set_state == VRAMState.NORMAL_VRAM
|
691 |
+
):
|
692 |
+
model_size = loaded_model.model_memory_required(torch_dev)
|
693 |
+
current_free_mem = get_free_memory(torch_dev)
|
694 |
+
lowvram_model_memory = int(
|
695 |
+
max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3)
|
696 |
+
)
|
697 |
+
if model_size > (
|
698 |
+
current_free_mem - inference_memory
|
699 |
+
): # only switch to lowvram if really necessary
|
700 |
+
vram_set_state = VRAMState.LOW_VRAM
|
701 |
+
else:
|
702 |
+
lowvram_model_memory = 0
|
703 |
+
|
704 |
+
if vram_set_state == VRAMState.NO_VRAM:
|
705 |
+
lowvram_model_memory = 64 * 1024 * 1024
|
706 |
+
|
707 |
+
loaded_model.model_load(
|
708 |
+
lowvram_model_memory, force_patch_weights=force_patch_weights
|
709 |
+
)
|
710 |
+
current_loaded_models.insert(0, loaded_model)
|
711 |
+
return
|
712 |
+
else:
|
713 |
+
inference_memory = minimum_inference_memory()
|
714 |
+
extra_mem = max(inference_memory, memory_required + extra_reserved_memory())
|
715 |
+
if minimum_memory_required is None:
|
716 |
+
minimum_memory_required = extra_mem
|
717 |
+
else:
|
718 |
+
minimum_memory_required = max(
|
719 |
+
inference_memory, minimum_memory_required + extra_reserved_memory()
|
720 |
+
)
|
721 |
+
|
722 |
+
models = set(models)
|
723 |
+
|
724 |
+
models_to_load = []
|
725 |
+
models_already_loaded = []
|
726 |
+
for x in models:
|
727 |
+
loaded_model = LoadedModel(x)
|
728 |
+
loaded = None
|
729 |
+
|
730 |
+
try:
|
731 |
+
loaded_model_index = current_loaded_models.index(loaded_model)
|
732 |
+
except:
|
733 |
+
loaded_model_index = None
|
734 |
+
|
735 |
+
if loaded_model_index is not None:
|
736 |
+
loaded = current_loaded_models[loaded_model_index]
|
737 |
+
if loaded.should_reload_model(
|
738 |
+
force_patch_weights=force_patch_weights
|
739 |
+
): # TODO: cleanup this model reload logic
|
740 |
+
current_loaded_models.pop(loaded_model_index).model_unload(
|
741 |
+
unpatch_weights=True
|
742 |
+
)
|
743 |
+
loaded = None
|
744 |
+
else:
|
745 |
+
loaded.currently_used = True
|
746 |
+
models_already_loaded.append(loaded)
|
747 |
+
|
748 |
+
if loaded is None:
|
749 |
+
if hasattr(x, "model"):
|
750 |
+
logging.info(f"Requested to load {x.model.__class__.__name__}")
|
751 |
+
models_to_load.append(loaded_model)
|
752 |
+
|
753 |
+
if len(models_to_load) == 0:
|
754 |
+
devs = set(map(lambda a: a.device, models_already_loaded))
|
755 |
+
for d in devs:
|
756 |
+
if d != torch.device("cpu"):
|
757 |
+
free_memory(
|
758 |
+
extra_mem + offloaded_memory(models_already_loaded, d),
|
759 |
+
d,
|
760 |
+
models_already_loaded,
|
761 |
+
)
|
762 |
+
free_mem = get_free_memory(d)
|
763 |
+
if free_mem < minimum_memory_required:
|
764 |
+
logging.info(
|
765 |
+
"Unloading models for lowram load."
|
766 |
+
) # TODO: partial model unloading when this case happens, also handle the opposite case where models can be unlowvramed.
|
767 |
+
models_to_load = free_memory(minimum_memory_required, d)
|
768 |
+
logging.info("{} models unloaded.".format(len(models_to_load)))
|
769 |
+
else:
|
770 |
+
use_more_memory(
|
771 |
+
free_mem - minimum_memory_required, models_already_loaded, d
|
772 |
+
)
|
773 |
+
if len(models_to_load) == 0:
|
774 |
+
return
|
775 |
+
|
776 |
+
logging.info(
|
777 |
+
f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}"
|
778 |
+
)
|
779 |
+
|
780 |
+
total_memory_required = {}
|
781 |
+
for loaded_model in models_to_load:
|
782 |
+
unload_model_clones(
|
783 |
+
loaded_model.model, unload_weights_only=True, force_unload=False
|
784 |
+
) # unload clones where the weights are different
|
785 |
+
total_memory_required[loaded_model.device] = total_memory_required.get(
|
786 |
+
loaded_model.device, 0
|
787 |
+
) + loaded_model.model_memory_required(loaded_model.device)
|
788 |
+
|
789 |
+
for loaded_model in models_already_loaded:
|
790 |
+
total_memory_required[loaded_model.device] = total_memory_required.get(
|
791 |
+
loaded_model.device, 0
|
792 |
+
) + loaded_model.model_memory_required(loaded_model.device)
|
793 |
+
|
794 |
+
for loaded_model in models_to_load:
|
795 |
+
weights_unloaded = unload_model_clones(
|
796 |
+
loaded_model.model, unload_weights_only=False, force_unload=False
|
797 |
+
) # unload the rest of the clones where the weights can stay loaded
|
798 |
+
if weights_unloaded is not None:
|
799 |
+
loaded_model.weights_loaded = not weights_unloaded
|
800 |
+
|
801 |
+
for device in total_memory_required:
|
802 |
+
if device != torch.device("cpu"):
|
803 |
+
free_memory(
|
804 |
+
total_memory_required[device] * 1.1 + extra_mem,
|
805 |
+
device,
|
806 |
+
models_already_loaded,
|
807 |
+
)
|
808 |
+
|
809 |
+
for loaded_model in models_to_load:
|
810 |
+
model = loaded_model.model
|
811 |
+
torch_dev = model.load_device
|
812 |
+
if is_device_cpu(torch_dev):
|
813 |
+
vram_set_state = VRAMState.DISABLED
|
814 |
+
else:
|
815 |
+
vram_set_state = vram_state
|
816 |
+
lowvram_model_memory = 0
|
817 |
+
if (
|
818 |
+
lowvram_available
|
819 |
+
and (
|
820 |
+
vram_set_state == VRAMState.LOW_VRAM
|
821 |
+
or vram_set_state == VRAMState.NORMAL_VRAM
|
822 |
+
)
|
823 |
+
and not force_full_load
|
824 |
+
):
|
825 |
+
model_size = loaded_model.model_memory_required(torch_dev)
|
826 |
+
current_free_mem = get_free_memory(torch_dev)
|
827 |
+
lowvram_model_memory = max(
|
828 |
+
64 * (1024 * 1024),
|
829 |
+
(current_free_mem - minimum_memory_required),
|
830 |
+
min(
|
831 |
+
current_free_mem * 0.4,
|
832 |
+
current_free_mem - minimum_inference_memory(),
|
833 |
+
),
|
834 |
+
)
|
835 |
+
if (
|
836 |
+
model_size <= lowvram_model_memory
|
837 |
+
): # only switch to lowvram if really necessary
|
838 |
+
lowvram_model_memory = 0
|
839 |
+
|
840 |
+
if vram_set_state == VRAMState.NO_VRAM:
|
841 |
+
lowvram_model_memory = 64 * 1024 * 1024
|
842 |
+
|
843 |
+
loaded_model.model_load_flux(
|
844 |
+
lowvram_model_memory, force_patch_weights=force_patch_weights
|
845 |
+
)
|
846 |
+
current_loaded_models.insert(0, loaded_model)
|
847 |
+
|
848 |
+
devs = set(map(lambda a: a.device, models_already_loaded))
|
849 |
+
for d in devs:
|
850 |
+
if d != torch.device("cpu"):
|
851 |
+
free_mem = get_free_memory(d)
|
852 |
+
if free_mem > minimum_memory_required:
|
853 |
+
use_more_memory(
|
854 |
+
free_mem - minimum_memory_required, models_already_loaded, d
|
855 |
+
)
|
856 |
+
return
|
857 |
+
|
858 |
+
def load_model_gpu(model: torch.nn.Module, flux_enabled:bool = False) -> None:
|
859 |
+
"""#### Load a model on the GPU
|
860 |
+
|
861 |
+
#### Args:
|
862 |
+
- `model` (torch.nn.Module): The model
|
863 |
+
- `flux_enabled` (bool, optional): Whether flux is enabled. Defaults to False.
|
864 |
+
"""
|
865 |
+
return load_models_gpu([model], flux_enabled=flux_enabled)
|
866 |
+
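# Illustrative sketch (not part of the original module): a single-model convenience call;
# `patcher` is a hypothetical ModelPatcher-style object as used throughout this module.
def _load_model_gpu_example(patcher) -> None:
    load_model_gpu(patcher)
    # current_loaded_models[0] is now the LoadedModel wrapper for `patcher`.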
|
867 |
+
|
868 |
+
def cleanup_models(keep_clone_weights_loaded:bool = False):
|
869 |
+
"""#### Cleanup the models
|
870 |
+
|
871 |
+
#### Args:
|
872 |
+
- `keep_clone_weights_loaded` (bool, optional): Whether to keep the clone weights loaded. Defaults to False.
|
873 |
+
"""
|
874 |
+
to_delete = []
|
875 |
+
for i in range(len(current_loaded_models)):
|
876 |
+
if sys.getrefcount(current_loaded_models[i].model) <= 2:
|
877 |
+
if not keep_clone_weights_loaded:
|
878 |
+
to_delete = [i] + to_delete
|
879 |
+
elif (
|
880 |
+
sys.getrefcount(current_loaded_models[i].real_model) <= 3
|
881 |
+
): # references from .real_model + the .model
|
882 |
+
to_delete = [i] + to_delete
|
883 |
+
|
884 |
+
for i in to_delete:
|
885 |
+
x = current_loaded_models.pop(i)
|
886 |
+
x.model_unload()
|
887 |
+
del x
|
888 |
+
|
889 |
+
|
890 |
+
def dtype_size(dtype: torch.dtype) -> int:
|
891 |
+
"""#### Get the size of a dtype
|
892 |
+
|
893 |
+
#### Args:
|
894 |
+
- `dtype` (torch.dtype): The dtype
|
895 |
+
|
896 |
+
#### Returns:
|
897 |
+
- `int`: The size of the dtype
|
898 |
+
"""
|
899 |
+
dtype_size = 4
|
900 |
+
if dtype == torch.float16 or dtype == torch.bfloat16:
|
901 |
+
dtype_size = 2
|
902 |
+
elif dtype == torch.float32:
|
903 |
+
dtype_size = 4
|
904 |
+
else:
|
905 |
+
try:
|
906 |
+
dtype_size = dtype.itemsize
|
907 |
+
except: # Old pytorch doesn't have .itemsize
|
908 |
+
pass
|
909 |
+
return dtype_size
|
910 |
+
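# Illustrative sketch (not part of the original module): dtype_size feeds the rough
# "parameter count * bytes per element" estimates used by the device pickers below.
def _dtype_size_example(parameters: int = 1_000_000_000) -> int:
    return parameters * dtype_size(torch.float16)  # ~2 GB of fp16 weights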
|
911 |
+
|
912 |
+
def unet_offload_device() -> torch.device:
|
913 |
+
"""#### Get the offload device for UNet
|
914 |
+
|
915 |
+
#### Returns:
|
916 |
+
- `torch.device`: The offload device
|
917 |
+
"""
|
918 |
+
if vram_state == VRAMState.HIGH_VRAM:
|
919 |
+
return get_torch_device()
|
920 |
+
else:
|
921 |
+
return torch.device("cpu")
|
922 |
+
|
923 |
+
|
924 |
+
def unet_inital_load_device(parameters, dtype) -> torch.device:
|
925 |
+
"""#### Get the initial load device for UNet
|
926 |
+
|
927 |
+
#### Args:
|
928 |
+
- `parameters` (int): The parameters
|
929 |
+
- `dtype` (torch.dtype): The dtype
|
930 |
+
|
931 |
+
#### Returns:
|
932 |
+
- `torch.device`: The initial load device
|
933 |
+
"""
|
934 |
+
torch_dev = get_torch_device()
|
935 |
+
if vram_state == VRAMState.HIGH_VRAM:
|
936 |
+
return torch_dev
|
937 |
+
|
938 |
+
cpu_dev = torch.device("cpu")
|
939 |
+
if DISABLE_SMART_MEMORY:
|
940 |
+
return cpu_dev
|
941 |
+
|
942 |
+
model_size = dtype_size(dtype) * parameters
|
943 |
+
|
944 |
+
mem_dev = get_free_memory(torch_dev)
|
945 |
+
mem_cpu = get_free_memory(cpu_dev)
|
946 |
+
if mem_dev > mem_cpu and model_size < mem_dev:
|
947 |
+
return torch_dev
|
948 |
+
else:
|
949 |
+
return cpu_dev
|
950 |
+
|
951 |
+
|
952 |
+
def unet_dtype(
|
953 |
+
device: torch.device = None,
|
954 |
+
model_params: int = 0,
|
955 |
+
supported_dtypes: list = [torch.float16, torch.bfloat16, torch.float32],
|
956 |
+
) -> torch.dtype:
|
957 |
+
"""#### Get the dtype for UNet
|
958 |
+
|
959 |
+
#### Args:
|
960 |
+
- `device` (torch.device, optional): The device. Defaults to None.
|
961 |
+
- `model_params` (int, optional): The model parameters. Defaults to 0.
|
962 |
+
- `supported_dtypes` (list, optional): The supported dtypes. Defaults to [torch.float16, torch.bfloat16, torch.float32].
|
963 |
+
|
964 |
+
#### Returns:
|
965 |
+
- `torch.dtype`: The dtype
|
966 |
+
"""
|
967 |
+
if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
|
968 |
+
if torch.float16 in supported_dtypes:
|
969 |
+
return torch.float16
|
970 |
+
if should_use_bf16(device, model_params=model_params, manual_cast=True):
|
971 |
+
if torch.bfloat16 in supported_dtypes:
|
972 |
+
return torch.bfloat16
|
973 |
+
return torch.float32
|
974 |
+
|
975 |
+
|
976 |
+
# None means no manual cast
|
977 |
+
def unet_manual_cast(
|
978 |
+
weight_dtype: torch.dtype,
|
979 |
+
inference_device: torch.device,
|
980 |
+
supported_dtypes: list = [torch.float16, torch.bfloat16, torch.float32],
|
981 |
+
) -> torch.dtype:
|
982 |
+
"""#### Manual cast for UNet
|
983 |
+
|
984 |
+
#### Args:
|
985 |
+
- `weight_dtype` (torch.dtype): The dtype of the weights
|
986 |
+
- `inference_device` (torch.device): The device used for inference
|
987 |
+
- `supported_dtypes` (list, optional): The supported dtypes. Defaults to [torch.float16, torch.bfloat16, torch.float32].
|
988 |
+
|
989 |
+
#### Returns:
|
990 |
+
- `torch.dtype`: The dtype
|
991 |
+
"""
|
992 |
+
if weight_dtype == torch.float32:
|
993 |
+
return None
|
994 |
+
|
995 |
+
fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
|
996 |
+
if fp16_supported and weight_dtype == torch.float16:
|
997 |
+
return None
|
998 |
+
|
999 |
+
bf16_supported = should_use_bf16(inference_device)
|
1000 |
+
if bf16_supported and weight_dtype == torch.bfloat16:
|
1001 |
+
return None
|
1002 |
+
|
1003 |
+
if fp16_supported and torch.float16 in supported_dtypes:
|
1004 |
+
return torch.float16
|
1005 |
+
|
1006 |
+
elif bf16_supported and torch.bfloat16 in supported_dtypes:
|
1007 |
+
return torch.bfloat16
|
1008 |
+
else:
|
1009 |
+
return torch.float32
|
1010 |
+
|
1011 |
+
|
1012 |
+
def text_encoder_offload_device() -> torch.device:
|
1013 |
+
"""#### Get the offload device for the text encoder
|
1014 |
+
|
1015 |
+
#### Returns:
|
1016 |
+
- `torch.device`: The offload device
|
1017 |
+
"""
|
1018 |
+
return torch.device("cpu")
|
1019 |
+
|
1020 |
+
|
1021 |
+
def text_encoder_device() -> torch.device:
|
1022 |
+
"""#### Get the device for the text encoder
|
1023 |
+
|
1024 |
+
#### Returns:
|
1025 |
+
- `torch.device`: The device
|
1026 |
+
"""
|
1027 |
+
if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM:
|
1028 |
+
if should_use_fp16(prioritize_performance=False):
|
1029 |
+
return get_torch_device()
|
1030 |
+
else:
|
1031 |
+
return torch.device("cpu")
|
1032 |
+
else:
|
1033 |
+
return torch.device("cpu")
|
1034 |
+
|
1035 |
+
def text_encoder_initial_device(load_device: torch.device, offload_device: torch.device, model_size: int = 0) -> torch.device:
|
1036 |
+
"""#### Get the initial device for the text encoder
|
1037 |
+
|
1038 |
+
#### Args:
|
1039 |
+
- `load_device` (torch.device): The load device
|
1040 |
+
- `offload_device` (torch.device): The offload device
|
1041 |
+
- `model_size` (int, optional): The model size. Defaults to 0.
|
1042 |
+
|
1043 |
+
#### Returns:
|
1044 |
+
- `torch.device`: The initial device
|
1045 |
+
"""
|
1046 |
+
if load_device == offload_device or model_size <= 1024 * 1024 * 1024:
|
1047 |
+
return offload_device
|
1048 |
+
|
1049 |
+
if is_device_mps(load_device):
|
1050 |
+
return offload_device
|
1051 |
+
|
1052 |
+
mem_l = get_free_memory(load_device)
|
1053 |
+
mem_o = get_free_memory(offload_device)
|
1054 |
+
if mem_l > (mem_o * 0.5) and model_size * 1.2 < mem_l:
|
1055 |
+
return load_device
|
1056 |
+
else:
|
1057 |
+
return offload_device
|
1058 |
+
|
1059 |
+
|
1060 |
+
def text_encoder_dtype(device: torch.device = None) -> torch.dtype:
|
1061 |
+
"""#### Get the dtype for the text encoder
|
1062 |
+
|
1063 |
+
#### Args:
|
1064 |
+
- `device` (torch.device, optional): The device used by the text encoder. Defaults to None.
|
1065 |
+
|
1066 |
+
#### Returns:
|
1067 |
+
- `torch.dtype`: The dtype
|
1068 |
+
"""
|
1069 |
+
if is_device_cpu(device):
|
1070 |
+
return torch.float16
|
1071 |
+
|
1072 |
+
return torch.float16
|
1073 |
+
|
1074 |
+
|
1075 |
+
def intermediate_device() -> torch.device:
|
1076 |
+
"""#### Get the intermediate device
|
1077 |
+
|
1078 |
+
#### Returns:
|
1079 |
+
- `torch.device`: The intermediate device
|
1080 |
+
"""
|
1081 |
+
return torch.device("cpu")
|
1082 |
+
|
1083 |
+
|
1084 |
+
def vae_device() -> torch.device:
|
1085 |
+
"""#### Get the VAE device
|
1086 |
+
|
1087 |
+
#### Returns:
|
1088 |
+
- `torch.device`: The VAE device
|
1089 |
+
"""
|
1090 |
+
return get_torch_device()
|
1091 |
+
|
1092 |
+
|
1093 |
+
def vae_offload_device() -> torch.device:
|
1094 |
+
"""#### Get the offload device for VAE
|
1095 |
+
|
1096 |
+
#### Returns:
|
1097 |
+
- `torch.device`: The offload device
|
1098 |
+
"""
|
1099 |
+
return torch.device("cpu")
|
1100 |
+
|
1101 |
+
|
1102 |
+
def vae_dtype():
|
1103 |
+
"""#### Get the dtype for VAE
|
1104 |
+
|
1105 |
+
#### Returns:
|
1106 |
+
- `torch.dtype`: The dtype
|
1107 |
+
"""
|
1108 |
+
global VAE_DTYPE
|
1109 |
+
return VAE_DTYPE
|
1110 |
+
|
1111 |
+
|
1112 |
+
def get_autocast_device(dev: torch.device) -> str:
|
1113 |
+
"""#### Get the autocast device
|
1114 |
+
|
1115 |
+
#### Args:
|
1116 |
+
- `dev` (torch.device): The device
|
1117 |
+
|
1118 |
+
#### Returns:
|
1119 |
+
- `str`: The autocast device type
|
1120 |
+
"""
|
1121 |
+
if hasattr(dev, "type"):
|
1122 |
+
return dev.type
|
1123 |
+
return "cuda"
|
1124 |
+
|
1125 |
+
|
1126 |
+
def supports_dtype(device: torch.device, dtype: torch.dtype) -> bool:
|
1127 |
+
"""#### Check if the device supports the dtype
|
1128 |
+
|
1129 |
+
#### Args:
|
1130 |
+
- `device` (torch.device): The device to check
|
1131 |
+
- `dtype` (torch.dtype): The dtype to check support
|
1132 |
+
|
1133 |
+
#### Returns:
|
1134 |
+
- `bool`: Whether the dtype is supported by the device
|
1135 |
+
"""
|
1136 |
+
if dtype == torch.float32:
|
1137 |
+
return True
|
1138 |
+
if is_device_cpu(device):
|
1139 |
+
return False
|
1140 |
+
if dtype == torch.float16:
|
1141 |
+
return True
|
1142 |
+
if dtype == torch.bfloat16:
|
1143 |
+
return True
|
1144 |
+
return False
|
1145 |
+
|
1146 |
+
|
1147 |
+
def device_supports_non_blocking(device: torch.device) -> bool:
|
1148 |
+
"""#### Check if the device supports non-blocking
|
1149 |
+
|
1150 |
+
#### Args:
|
1151 |
+
- `device` (torch.device): The device to check
|
1152 |
+
|
1153 |
+
#### Returns:
|
1154 |
+
- `bool`: Whether the device supports non-blocking
|
1155 |
+
"""
|
1156 |
+
if is_device_mps(device):
|
1157 |
+
return False # pytorch bug? mps doesn't support non blocking
|
1158 |
+
return True
|
1159 |
+
|
1160 |
+
def supports_cast(device: torch.device, dtype: torch.dtype): # TODO
|
1161 |
+
"""#### Check if the device supports casting
|
1162 |
+
|
1163 |
+
#### Args:
|
1164 |
+
- `device`: The device
|
1165 |
+
- `dtype`: The dtype
|
1166 |
+
|
1167 |
+
#### Returns:
|
1168 |
+
- `bool`: Whether the device supports casting
|
1169 |
+
"""
|
1170 |
+
if dtype == torch.float32:
|
1171 |
+
return True
|
1172 |
+
if dtype == torch.float16:
|
1173 |
+
return True
|
1174 |
+
if directml_enabled:
|
1175 |
+
return False
|
1176 |
+
if dtype == torch.bfloat16:
|
1177 |
+
return True
|
1178 |
+
if is_device_mps(device):
|
1179 |
+
return False
|
1180 |
+
if dtype == torch.float8_e4m3fn:
|
1181 |
+
return True
|
1182 |
+
if dtype == torch.float8_e5m2:
|
1183 |
+
return True
|
1184 |
+
return False
|
1185 |
+
|
1186 |
+
def cast_to_device(tensor: torch.Tensor, device: torch.device, dtype: torch.dtype, copy: bool = False) -> torch.Tensor:
|
1187 |
+
"""#### Cast a tensor to a device
|
1188 |
+
|
1189 |
+
#### Args:
|
1190 |
+
- `tensor` (torch.Tensor): The tensor to cast
|
1191 |
+
- `device` (torch.device): The device to cast the tensor to
|
1192 |
+
- `dtype` (torch.dtype): The dtype precision to cast to
|
1193 |
+
- `copy` (bool, optional): Whether to copy the tensor. Defaults to False.
|
1194 |
+
|
1195 |
+
#### Returns:
|
1196 |
+
- `torch.Tensor`: The tensor cast to the device
|
1197 |
+
"""
|
1198 |
+
device_supports_cast = False
|
1199 |
+
if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
|
1200 |
+
device_supports_cast = True
|
1201 |
+
elif tensor.dtype == torch.bfloat16:
|
1202 |
+
if hasattr(device, "type") and device.type.startswith("cuda"):
|
1203 |
+
device_supports_cast = True
|
1204 |
+
elif is_intel_xpu():
|
1205 |
+
device_supports_cast = True
|
1206 |
+
|
1207 |
+
non_blocking = device_supports_non_blocking(device)
|
1208 |
+
|
1209 |
+
if device_supports_cast:
|
1210 |
+
if copy:
|
1211 |
+
if tensor.device == device:
|
1212 |
+
return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
|
1213 |
+
return tensor.to(device, copy=copy, non_blocking=non_blocking).to(
|
1214 |
+
dtype, non_blocking=non_blocking
|
1215 |
+
)
|
1216 |
+
else:
|
1217 |
+
return tensor.to(device, non_blocking=non_blocking).to(
|
1218 |
+
dtype, non_blocking=non_blocking
|
1219 |
+
)
|
1220 |
+
else:
|
1221 |
+
return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
|
1222 |
+
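# Illustrative sketch (not part of the original module): moving an fp32 tensor to the
# compute device while downcasting it to fp16 in a single call.
def _cast_to_device_example(t: torch.Tensor) -> torch.Tensor:
    return cast_to_device(t, get_torch_device(), torch.float16)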
|
1223 |
+
def pick_weight_dtype(dtype: torch.dtype, fallback_dtype: torch.dtype, device: torch.device) -> torch.dtype:
|
1224 |
+
"""#### Pick the weight dtype
|
1225 |
+
|
1226 |
+
#### Args:
|
1227 |
+
- `dtype`: The dtype
|
1228 |
+
- `fallback_dtype`: The fallback dtype
|
1229 |
+
- `device`: The device
|
1230 |
+
|
1231 |
+
#### Returns:
|
1232 |
+
- `torch.dtype`: The weight dtype
|
1233 |
+
"""
|
1234 |
+
if dtype is None:
|
1235 |
+
dtype = fallback_dtype
|
1236 |
+
elif dtype_size(dtype) > dtype_size(fallback_dtype):
|
1237 |
+
dtype = fallback_dtype
|
1238 |
+
|
1239 |
+
if not supports_cast(device, dtype):
|
1240 |
+
dtype = fallback_dtype
|
1241 |
+
|
1242 |
+
return dtype
|
1243 |
+
|
1244 |
+
def xformers_enabled() -> bool:
|
1245 |
+
"""#### Check if xformers is enabled
|
1246 |
+
|
1247 |
+
#### Returns:
|
1248 |
+
- `bool`: Whether xformers is enabled
|
1249 |
+
"""
|
1250 |
+
global directml_enabled
|
1251 |
+
global cpu_state
|
1252 |
+
if cpu_state != CPUState.GPU:
|
1253 |
+
return False
|
1254 |
+
if is_intel_xpu():
|
1255 |
+
return False
|
1256 |
+
if directml_enabled:
|
1257 |
+
return False
|
1258 |
+
return XFORMERS_IS_AVAILABLE
|
1259 |
+
|
1260 |
+
|
1261 |
+
def xformers_enabled_vae() -> bool:
|
1262 |
+
"""#### Check if xformers is enabled for VAE
|
1263 |
+
|
1264 |
+
#### Returns:
|
1265 |
+
- `bool`: Whether xformers is enabled for VAE
|
1266 |
+
"""
|
1267 |
+
enabled = xformers_enabled()
|
1268 |
+
if not enabled:
|
1269 |
+
return False
|
1270 |
+
|
1271 |
+
return XFORMERS_ENABLED_VAE
|
1272 |
+
|
1273 |
+
|
1274 |
+
def pytorch_attention_enabled() -> bool:
|
1275 |
+
"""#### Check if PyTorch attention is enabled
|
1276 |
+
|
1277 |
+
#### Returns:
|
1278 |
+
- `bool`: Whether PyTorch attention is enabled
|
1279 |
+
"""
|
1280 |
+
global ENABLE_PYTORCH_ATTENTION
|
1281 |
+
return ENABLE_PYTORCH_ATTENTION
|
1282 |
+
|
1283 |
+
def pytorch_attention_flash_attention() -> bool:
|
1284 |
+
"""#### Check if PyTorch flash attention is enabled and supported.
|
1285 |
+
|
1286 |
+
#### Returns:
|
1287 |
+
- `bool`: True if PyTorch flash attention is enabled and supported, False otherwise.
|
1288 |
+
"""
|
1289 |
+
global ENABLE_PYTORCH_ATTENTION
|
1290 |
+
if ENABLE_PYTORCH_ATTENTION:
|
1291 |
+
if is_nvidia(): # pytorch flash attention only works on Nvidia
|
1292 |
+
return True
|
1293 |
+
return False
|
1294 |
+
|
1295 |
+
|
1296 |
+
def get_free_memory(dev: torch.device = None, torch_free_too: bool = False) -> Union[int, Tuple[int, int]]:
|
1297 |
+
"""#### Get the free memory available on the device.
|
1298 |
+
|
1299 |
+
#### Args:
|
1300 |
+
- `dev` (torch.device, optional): The device to check memory for. Defaults to None.
|
1301 |
+
- `torch_free_too` (bool, optional): Whether to return both total and torch free memory. Defaults to False.
|
1302 |
+
|
1303 |
+
#### Returns:
|
1304 |
+
- `int` or `Tuple[int, int]`: The free memory available. If `torch_free_too` is True, returns a tuple of total and torch free memory.
|
1305 |
+
"""
|
1306 |
+
global directml_enabled
|
1307 |
+
if dev is None:
|
1308 |
+
dev = get_torch_device()
|
1309 |
+
|
1310 |
+
if hasattr(dev, "type") and (dev.type == "cpu" or dev.type == "mps"):
|
1311 |
+
mem_free_total = psutil.virtual_memory().available
|
1312 |
+
mem_free_torch = mem_free_total
|
1313 |
+
else:
|
1314 |
+
if directml_enabled:
|
1315 |
+
mem_free_total = 1024 * 1024 * 1024
|
1316 |
+
mem_free_torch = mem_free_total
|
1317 |
+
elif is_intel_xpu():
|
1318 |
+
stats = torch.xpu.memory_stats(dev)
|
1319 |
+
mem_active = stats["active_bytes.all.current"]
|
1320 |
+
mem_reserved = stats["reserved_bytes.all.current"]
|
1321 |
+
mem_free_torch = mem_reserved - mem_active
|
1322 |
+
mem_free_xpu = (
|
1323 |
+
torch.xpu.get_device_properties(dev).total_memory - mem_reserved
|
1324 |
+
)
|
1325 |
+
mem_free_total = mem_free_xpu + mem_free_torch
|
1326 |
+
else:
|
1327 |
+
stats = torch.cuda.memory_stats(dev)
|
1328 |
+
mem_active = stats["active_bytes.all.current"]
|
1329 |
+
mem_reserved = stats["reserved_bytes.all.current"]
|
1330 |
+
mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
|
1331 |
+
mem_free_torch = mem_reserved - mem_active
|
1332 |
+
mem_free_total = mem_free_cuda + mem_free_torch
|
1333 |
+
|
1334 |
+
if torch_free_too:
|
1335 |
+
return (mem_free_total, mem_free_torch)
|
1336 |
+
else:
|
1337 |
+
return mem_free_total
|
1338 |
+
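# Illustrative sketch (not part of the original module): reporting both the total free
# memory on the default device and the slice already held by the torch allocator cache.
def _free_memory_report() -> str:
    total_free, torch_free = get_free_memory(get_torch_device(), torch_free_too=True)
    return "free: {:0.0f} MB (torch cache: {:0.0f} MB)".format(
        total_free / (1024 * 1024), torch_free / (1024 * 1024)
    )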
|
1339 |
+
|
1340 |
+
def cpu_mode() -> bool:
|
1341 |
+
"""#### Check if the current mode is CPU.
|
1342 |
+
|
1343 |
+
#### Returns:
|
1344 |
+
- `bool`: True if the current mode is CPU, False otherwise.
|
1345 |
+
"""
|
1346 |
+
global cpu_state
|
1347 |
+
return cpu_state == CPUState.CPU
|
1348 |
+
|
1349 |
+
|
1350 |
+
def mps_mode() -> bool:
|
1351 |
+
"""#### Check if the current mode is MPS.
|
1352 |
+
|
1353 |
+
#### Returns:
|
1354 |
+
- `bool`: True if the current mode is MPS, False otherwise.
|
1355 |
+
"""
|
1356 |
+
global cpu_state
|
1357 |
+
return cpu_state == CPUState.MPS
|
1358 |
+
|
1359 |
+
|
1360 |
+
def is_device_type(device: torch.device, type: str) -> bool:
|
1361 |
+
"""#### Check if the device is of a specific type.
|
1362 |
+
|
1363 |
+
#### Args:
|
1364 |
+
- `device` (torch.device): The device to check.
|
1365 |
+
- `type` (str): The type to check for.
|
1366 |
+
|
1367 |
+
#### Returns:
|
1368 |
+
- `bool`: True if the device is of the specified type, False otherwise.
|
1369 |
+
"""
|
1370 |
+
if hasattr(device, "type"):
|
1371 |
+
if device.type == type:
|
1372 |
+
return True
|
1373 |
+
return False
|
1374 |
+
|
1375 |
+
|
1376 |
+
def is_device_cpu(device: torch.device) -> bool:
|
1377 |
+
"""#### Check if the device is a CPU.
|
1378 |
+
|
1379 |
+
#### Args:
|
1380 |
+
- `device` (torch.device): The device to check.
|
1381 |
+
|
1382 |
+
#### Returns:
|
1383 |
+
- `bool`: True if the device is a CPU, False otherwise.
|
1384 |
+
"""
|
1385 |
+
return is_device_type(device, "cpu")
|
1386 |
+
|
1387 |
+
|
1388 |
+
def is_device_mps(device: torch.device) -> bool:
|
1389 |
+
"""#### Check if the device is an MPS.
|
1390 |
+
|
1391 |
+
#### Args:
|
1392 |
+
- `device` (torch.device): The device to check.
|
1393 |
+
|
1394 |
+
#### Returns:
|
1395 |
+
- `bool`: True if the device is an MPS, False otherwise.
|
1396 |
+
"""
|
1397 |
+
return is_device_type(device, "mps")
|
1398 |
+
|
1399 |
+
|
1400 |
+
def is_device_cuda(device: torch.device) -> bool:
|
1401 |
+
"""#### Check if the device is a CUDA device.
|
1402 |
+
|
1403 |
+
#### Args:
|
1404 |
+
- `device` (torch.device): The device to check.
|
1405 |
+
|
1406 |
+
#### Returns:
|
1407 |
+
- `bool`: True if the device is a CUDA device, False otherwise.
|
1408 |
+
"""
|
1409 |
+
return is_device_type(device, "cuda")
|
1410 |
+
|
1411 |
+
|
1412 |
+
def should_use_fp16(
|
1413 |
+
device: torch.device = None, model_params: int = 0, prioritize_performance: bool = True, manual_cast: bool = False
|
1414 |
+
) -> bool:
|
1415 |
+
"""#### Determine if FP16 should be used.
|
1416 |
+
|
1417 |
+
#### Args:
|
1418 |
+
- `device` (torch.device, optional): The device to check. Defaults to None.
|
1419 |
+
- `model_params` (int, optional): The number of model parameters. Defaults to 0.
|
1420 |
+
- `prioritize_performance` (bool, optional): Whether to prioritize performance. Defaults to True.
|
1421 |
+
- `manual_cast` (bool, optional): Whether to manually cast. Defaults to False.
|
1422 |
+
|
1423 |
+
#### Returns:
|
1424 |
+
- `bool`: True if FP16 should be used, False otherwise.
|
1425 |
+
"""
|
1426 |
+
global directml_enabled
|
1427 |
+
|
1428 |
+
if device is not None:
|
1429 |
+
if is_device_cpu(device):
|
1430 |
+
return False
|
1431 |
+
|
1432 |
+
if FORCE_FP16:
|
1433 |
+
return True
|
1434 |
+
|
1435 |
+
if device is not None:
|
1436 |
+
if is_device_mps(device):
|
1437 |
+
return True
|
1438 |
+
|
1439 |
+
if FORCE_FP32:
|
1440 |
+
return False
|
1441 |
+
|
1442 |
+
if directml_enabled:
|
1443 |
+
return False
|
1444 |
+
|
1445 |
+
if mps_mode():
|
1446 |
+
return True
|
1447 |
+
|
1448 |
+
if cpu_mode():
|
1449 |
+
return False
|
1450 |
+
|
1451 |
+
if is_intel_xpu():
|
1452 |
+
return True
|
1453 |
+
|
1454 |
+
if torch.version.hip:
|
1455 |
+
return True
|
1456 |
+
|
1457 |
+
props = torch.cuda.get_device_properties("cuda")
|
1458 |
+
if props.major >= 8:
|
1459 |
+
return True
|
1460 |
+
|
1461 |
+
if props.major < 6:
|
1462 |
+
return False
|
1463 |
+
|
1464 |
+
fp16_works = False
|
1465 |
+
nvidia_10_series = [
|
1466 |
+
"1080",
|
1467 |
+
"1070",
|
1468 |
+
"titan x",
|
1469 |
+
"p3000",
|
1470 |
+
"p3200",
|
1471 |
+
"p4000",
|
1472 |
+
"p4200",
|
1473 |
+
"p5000",
|
1474 |
+
"p5200",
|
1475 |
+
"p6000",
|
1476 |
+
"1060",
|
1477 |
+
"1050",
|
1478 |
+
"p40",
|
1479 |
+
"p100",
|
1480 |
+
"p6",
|
1481 |
+
"p4",
|
1482 |
+
]
|
1483 |
+
for x in nvidia_10_series:
|
1484 |
+
if x in props.name.lower():
|
1485 |
+
fp16_works = True
|
1486 |
+
|
1487 |
+
if fp16_works or manual_cast:
|
1488 |
+
free_model_memory = get_free_memory() * 0.9 - minimum_inference_memory()
|
1489 |
+
if (not prioritize_performance) or model_params * 4 > free_model_memory:
|
1490 |
+
return True
|
1491 |
+
|
1492 |
+
if props.major < 7:
|
1493 |
+
return False
|
1494 |
+
|
1495 |
+
nvidia_16_series = [
|
1496 |
+
"1660",
|
1497 |
+
"1650",
|
1498 |
+
"1630",
|
1499 |
+
"T500",
|
1500 |
+
"T550",
|
1501 |
+
"T600",
|
1502 |
+
"MX550",
|
1503 |
+
"MX450",
|
1504 |
+
"CMP 30HX",
|
1505 |
+
"T2000",
|
1506 |
+
"T1000",
|
1507 |
+
"T1200",
|
1508 |
+
]
|
1509 |
+
for x in nvidia_16_series:
|
1510 |
+
if x in props.name:
|
1511 |
+
return False
|
1512 |
+
|
1513 |
+
return True
|
1514 |
+
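# Illustrative sketch (not part of the original module): the fp16/bf16 checks above are
# normally driven through unet_dtype with a parameter count, e.g. a ~860M-parameter UNet.
def _pick_unet_dtype_example() -> torch.dtype:
    return unet_dtype(device=get_torch_device(), model_params=860_000_000)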
|
1515 |
+
|
1516 |
+
def should_use_bf16(
|
1517 |
+
device: torch.device = None, model_params: int = 0, prioritize_performance: bool = True, manual_cast: bool = False
|
1518 |
+
) -> bool:
|
1519 |
+
"""#### Determine if BF16 should be used.
|
1520 |
+
|
1521 |
+
#### Args:
|
1522 |
+
- `device` (torch.device, optional): The device to check. Defaults to None.
|
1523 |
+
- `model_params` (int, optional): The number of model parameters. Defaults to 0.
|
1524 |
+
- `prioritize_performance` (bool, optional): Whether to prioritize performance. Defaults to True.
|
1525 |
+
- `manual_cast` (bool, optional): Whether to manually cast. Defaults to False.
|
1526 |
+
|
1527 |
+
#### Returns:
|
1528 |
+
- `bool`: True if BF16 should be used, False otherwise.
|
1529 |
+
"""
|
1530 |
+
if device is not None:
|
1531 |
+
if is_device_cpu(device):
|
1532 |
+
return False
|
1533 |
+
|
1534 |
+
if device is not None:
|
1535 |
+
if is_device_mps(device):
|
1536 |
+
return False
|
1537 |
+
|
1538 |
+
if FORCE_FP32:
|
1539 |
+
return False
|
1540 |
+
|
1541 |
+
if directml_enabled:
|
1542 |
+
return False
|
1543 |
+
|
1544 |
+
if cpu_mode() or mps_mode():
|
1545 |
+
return False
|
1546 |
+
|
1547 |
+
if is_intel_xpu():
|
1548 |
+
return True
|
1549 |
+
|
1550 |
+
if device is None:
|
1551 |
+
device = torch.device("cuda")
|
1552 |
+
|
1553 |
+
props = torch.cuda.get_device_properties(device)
|
1554 |
+
if props.major >= 8:
|
1555 |
+
return True
|
1556 |
+
|
1557 |
+
bf16_works = torch.cuda.is_bf16_supported()
|
1558 |
+
|
1559 |
+
if bf16_works or manual_cast:
|
1560 |
+
free_model_memory = get_free_memory() * 0.9 - minimum_inference_memory()
|
1561 |
+
if (not prioritize_performance) or model_params * 4 > free_model_memory:
|
1562 |
+
return True
|
1563 |
+
|
1564 |
+
return False
|
1565 |
+
|
1566 |
+
|
1567 |
+
def soft_empty_cache(force: bool = False) -> None:
|
1568 |
+
"""#### Softly empty the cache.
|
1569 |
+
|
1570 |
+
#### Args:
|
1571 |
+
- `force` (bool, optional): Whether to force emptying the cache. Defaults to False.
|
1572 |
+
"""
|
1573 |
+
global cpu_state
|
1574 |
+
if cpu_state == CPUState.MPS:
|
1575 |
+
torch.mps.empty_cache()
|
1576 |
+
elif is_intel_xpu():
|
1577 |
+
torch.xpu.empty_cache()
|
1578 |
+
elif torch.cuda.is_available():
|
1579 |
+
if (
|
1580 |
+
force or is_nvidia()
|
1581 |
+
): # This seems to make things worse on ROCm so I only do it for cuda
|
1582 |
+
torch.cuda.empty_cache()
|
1583 |
+
torch.cuda.ipc_collect()
|
1584 |
+
|
1585 |
+
|
1586 |
+
def unload_all_models() -> None:
|
1587 |
+
"""#### Unload all models."""
|
1588 |
+
free_memory(1e30, get_torch_device())
|
1589 |
+
|
1590 |
+
|
1591 |
+
def resolve_lowvram_weight(weight: torch.Tensor, model: torch.nn.Module, key: str) -> torch.Tensor:
|
1592 |
+
"""#### Resolve low VRAM weight.
|
1593 |
+
|
1594 |
+
#### Args:
|
1595 |
+
- `weight` (torch.Tensor): The weight tensor.
|
1596 |
+
- `model` (torch.nn.Module): The model.
|
1597 |
+
- `key` (str): The key.
|
1598 |
+
|
1599 |
+
#### Returns:
|
1600 |
+
- `torch.Tensor`: The resolved weight tensor.
|
1601 |
+
"""
|
1602 |
+
return weight
|
modules/FileManaging/Downloader.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
from huggingface_hub import hf_hub_download
|
3 |
+
|
4 |
+
|
5 |
+
def CheckAndDownload():
|
6 |
+
"""#### Check and download all the necessary safetensors and checkpoints models"""
|
7 |
+
if glob.glob("./_internal/checkpoints/*.safetensors") == []:
|
8 |
+
|
9 |
+
hf_hub_download(
|
10 |
+
repo_id="Meina/MeinaMix",
|
11 |
+
filename="Meina V10 - baked VAE.safetensors",
|
12 |
+
local_dir="./_internal/checkpoints/",
|
13 |
+
)
|
14 |
+
hf_hub_download(
|
15 |
+
repo_id="Lykon/DreamShaper",
|
16 |
+
filename="DreamShaper_8_pruned.safetensors",
|
17 |
+
local_dir="./_internal/checkpoints/",
|
18 |
+
)
|
19 |
+
if glob.glob("./_internal/yolos/*.pt") == []:
|
20 |
+
|
21 |
+
hf_hub_download(
|
22 |
+
repo_id="Bingsu/adetailer",
|
23 |
+
filename="hand_yolov9c.pt",
|
24 |
+
local_dir="./_internal/yolos/",
|
25 |
+
)
|
26 |
+
hf_hub_download(
|
27 |
+
repo_id="Bingsu/adetailer",
|
28 |
+
filename="face_yolov9c.pt",
|
29 |
+
local_dir="./_internal/yolos/",
|
30 |
+
)
|
31 |
+
hf_hub_download(
|
32 |
+
repo_id="Bingsu/adetailer",
|
33 |
+
filename="person_yolov8m-seg.pt",
|
34 |
+
local_dir="./_internal/yolos/",
|
35 |
+
)
|
36 |
+
hf_hub_download(
|
37 |
+
repo_id="segments-arnaud/sam_vit_b",
|
38 |
+
filename="sam_vit_b_01ec64.pth",
|
39 |
+
local_dir="./_internal/yolos/",
|
40 |
+
)
|
41 |
+
if glob.glob("./_internal/ESRGAN/*.pth") == []:
|
42 |
+
|
43 |
+
hf_hub_download(
|
44 |
+
repo_id="lllyasviel/Annotators",
|
45 |
+
filename="RealESRGAN_x4plus.pth",
|
46 |
+
local_dir="./_internal/ESRGAN/",
|
47 |
+
)
|
48 |
+
if glob.glob("./_internal/loras/*.safetensors") == []:
|
49 |
+
|
50 |
+
hf_hub_download(
|
51 |
+
repo_id="EvilEngine/add_detail",
|
52 |
+
filename="add_detail.safetensors",
|
53 |
+
local_dir="./_internal/loras/",
|
54 |
+
)
|
55 |
+
if glob.glob("./_internal/embeddings/*.pt") == []:
|
56 |
+
|
57 |
+
hf_hub_download(
|
58 |
+
repo_id="EvilEngine/badhandv4",
|
59 |
+
filename="badhandv4.pt",
|
60 |
+
local_dir="./_internal/embeddings/",
|
61 |
+
)
|
62 |
+
# hf_hub_download(
|
63 |
+
# repo_id="segments-arnaud/sam_vit_b",
|
64 |
+
# filename="EasyNegative.safetensors",
|
65 |
+
# local_dir="./_internal/embeddings/",
|
66 |
+
# )
|
67 |
+
if glob.glob("./_internal/vae_approx/*.pth") == []:
|
68 |
+
|
69 |
+
hf_hub_download(
|
70 |
+
repo_id="madebyollin/taesd",
|
71 |
+
filename="taesd_decoder.safetensors",
|
72 |
+
local_dir="./_internal/vae_approx/",
|
73 |
+
)
|
74 |
+
|
75 |
+
def CheckAndDownloadFlux():
|
76 |
+
"""#### Check and download all the necessary safetensors and checkpoints models for FLUX"""
|
77 |
+
if glob.glob("./_internal/embeddings/*.pt") == []:
|
78 |
+
hf_hub_download(
|
79 |
+
repo_id="EvilEngine/badhandv4",
|
80 |
+
filename="badhandv4.pt",
|
81 |
+
local_dir="./_internal/embeddings",
|
82 |
+
)
|
83 |
+
if glob.glob("./_internal/unet/*.gguf") == []:
|
84 |
+
|
85 |
+
hf_hub_download(
|
86 |
+
repo_id="city96/FLUX.1-dev-gguf",
|
87 |
+
filename="flux1-dev-Q8_0.gguf",
|
88 |
+
local_dir="./_internal/unet",
|
89 |
+
)
|
90 |
+
if glob.glob("./_internal/clip/*.gguf") == []:
|
91 |
+
|
92 |
+
hf_hub_download(
|
93 |
+
repo_id="city96/t5-v1_1-xxl-encoder-gguf",
|
94 |
+
filename="t5-v1_1-xxl-encoder-Q8_0.gguf",
|
95 |
+
local_dir="./_internal/clip",
|
96 |
+
)
|
97 |
+
hf_hub_download(
|
98 |
+
repo_id="comfyanonymous/flux_text_encoders",
|
99 |
+
filename="clip_l.safetensors",
|
100 |
+
local_dir="./_internal/clip",
|
101 |
+
)
|
102 |
+
if glob.glob("./_internal/vae/*.safetensors") == []:
|
103 |
+
|
104 |
+
hf_hub_download(
|
105 |
+
repo_id="black-forest-labs/FLUX.1-schnell",
|
106 |
+
filename="ae.safetensors",
|
107 |
+
local_dir="./_internal/vae",
|
108 |
+
)
|
109 |
+
|
110 |
+
if glob.glob("./_internal/vae_approx/*.pth") == []:
|
111 |
+
|
112 |
+
hf_hub_download(
|
113 |
+
repo_id="madebyollin/taef1",
|
114 |
+
filename="diffusion_pytorch_model.safetensors",
|
115 |
+
local_dir="./_internal/vae_approx/",
|
116 |
+
)
|
modules/FileManaging/ImageSaver.py
ADDED
@@ -0,0 +1,126 @@
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image
|
4 |
+
|
5 |
+
output_directory = "./_internal/output"
|
6 |
+
|
7 |
+
|
8 |
+
def get_output_directory() -> str:
|
9 |
+
"""#### Get the output directory.
|
10 |
+
|
11 |
+
#### Returns:
|
12 |
+
- `str`: The output directory.
|
13 |
+
"""
|
14 |
+
global output_directory
|
15 |
+
return output_directory
|
16 |
+
|
17 |
+
|
18 |
+
def get_save_image_path(
|
19 |
+
filename_prefix: str, output_dir: str, image_width: int = 0, image_height: int = 0
|
20 |
+
) -> tuple:
|
21 |
+
"""#### Get the save image path.
|
22 |
+
|
23 |
+
#### Args:
|
24 |
+
- `filename_prefix` (str): The filename prefix.
|
25 |
+
- `output_dir` (str): The output directory.
|
26 |
+
- `image_width` (int, optional): The image width. Defaults to 0.
|
27 |
+
- `image_height` (int, optional): The image height. Defaults to 0.
|
28 |
+
|
29 |
+
#### Returns:
|
30 |
+
- `tuple`: The full output folder, filename, counter, subfolder, and filename prefix.
|
31 |
+
"""
|
32 |
+
|
33 |
+
def map_filename(filename: str) -> tuple:
|
34 |
+
prefix_len = len(os.path.basename(filename_prefix))
|
35 |
+
prefix = filename[: prefix_len + 1]
|
36 |
+
try:
|
37 |
+
digits = int(filename[prefix_len + 1 :].split("_")[0])
|
38 |
+
except:
|
39 |
+
digits = 0
|
40 |
+
return (digits, prefix)
|
41 |
+
|
42 |
+
def compute_vars(input: str, image_width: int, image_height: int) -> str:
|
43 |
+
input = input.replace("%width%", str(image_width))
|
44 |
+
input = input.replace("%height%", str(image_height))
|
45 |
+
return input
|
46 |
+
|
47 |
+
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
|
48 |
+
|
49 |
+
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
|
50 |
+
filename = os.path.basename(os.path.normpath(filename_prefix))
|
51 |
+
|
52 |
+
full_output_folder = os.path.join(output_dir, subfolder)
|
53 |
+
try:
|
54 |
+
counter = (
|
55 |
+
max(
|
56 |
+
filter(
|
57 |
+
lambda a: a[1][:-1] == filename and a[1][-1] == "_",
|
58 |
+
map(map_filename, os.listdir(full_output_folder)),
|
59 |
+
)
|
60 |
+
)[0]
|
61 |
+
+ 1
|
62 |
+
)
|
63 |
+
except ValueError:
|
64 |
+
counter = 1
|
65 |
+
except FileNotFoundError:
|
66 |
+
os.makedirs(full_output_folder, exist_ok=True)
|
67 |
+
counter = 1
|
68 |
+
return full_output_folder, filename, counter, subfolder, filename_prefix
|
69 |
+
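# Illustrative sketch (not part of the original module): resolving where the next image
# with the default "LD" prefix would land under the default output directory.
def _save_path_example() -> str:
    folder, name, counter, _subfolder, _prefix = get_save_image_path(
        "LD", get_output_directory(), image_width=512, image_height=512
    )
    return os.path.join(folder, f"{name}_{counter:05}_.png")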
|
70 |
+
|
71 |
+
MAX_RESOLUTION = 16384
|
72 |
+
|
73 |
+
|
74 |
+
class SaveImage:
|
75 |
+
"""#### Class for saving images."""
|
76 |
+
|
77 |
+
def __init__(self):
|
78 |
+
"""#### Initialize the SaveImage class."""
|
79 |
+
self.output_dir = get_output_directory()
|
80 |
+
self.type = "output"
|
81 |
+
self.prefix_append = ""
|
82 |
+
self.compress_level = 4
|
83 |
+
|
84 |
+
def save_images(
|
85 |
+
self,
|
86 |
+
images: list,
|
87 |
+
filename_prefix: str = "LD",
|
88 |
+
prompt: str = None,
|
89 |
+
extra_pnginfo: dict = None,
|
90 |
+
) -> dict:
|
91 |
+
"""#### Save images to the output directory.
|
92 |
+
|
93 |
+
#### Args:
|
94 |
+
- `images` (list): The list of images.
|
95 |
+
- `filename_prefix` (str, optional): The filename prefix. Defaults to "LD".
|
96 |
+
- `prompt` (str, optional): The prompt. Defaults to None.
|
97 |
+
- `extra_pnginfo` (dict, optional): Additional PNG info. Defaults to None.
|
98 |
+
|
99 |
+
#### Returns:
|
100 |
+
- `dict`: The saved images information.
|
101 |
+
"""
|
102 |
+
filename_prefix += self.prefix_append
|
103 |
+
full_output_folder, filename, counter, subfolder, filename_prefix = (
|
104 |
+
get_save_image_path(
|
105 |
+
filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]
|
106 |
+
)
|
107 |
+
)
|
108 |
+
results = list()
|
109 |
+
for batch_number, image in enumerate(images):
|
110 |
+
i = 255.0 * image.cpu().numpy()
|
111 |
+
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
|
112 |
+
metadata = None
|
113 |
+
|
114 |
+
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
115 |
+
file = f"{filename_with_batch_num}_{counter:05}_.png"
|
116 |
+
img.save(
|
117 |
+
os.path.join(full_output_folder, file),
|
118 |
+
pnginfo=metadata,
|
119 |
+
compress_level=self.compress_level,
|
120 |
+
)
|
121 |
+
results.append(
|
122 |
+
{"filename": file, "subfolder": subfolder, "type": self.type}
|
123 |
+
)
|
124 |
+
counter += 1
|
125 |
+
|
126 |
+
return {"ui": {"images": results}}
|
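A minimal, illustrative usage sketch for the SaveImage class above (not part of the uploaded files); the random HWC tensors stand in for decoded images and are hypothetical placeholders:

# Illustrative sketch only; assumes the module path from the file list above.
import torch
from modules.FileManaging.ImageSaver import SaveImage

# Hypothetical input: two 512x512 RGB images as float tensors in [0, 1].
images = [torch.rand(512, 512, 3) for _ in range(2)]

saver = SaveImage()
# Writes LD_00001_.png, LD_00002_.png, ... under ./_internal/output and
# returns {"ui": {"images": [...]}} describing the saved files.
info = saver.save_images(images, filename_prefix="LD")
print(info["ui"]["images"])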
modules/FileManaging/Loader.py
ADDED
@@ -0,0 +1,138 @@
import logging
import torch
from modules.Utilities import util
from modules.AutoEncoders import VariationalAE
from modules.Device import Device
from modules.Model import ModelPatcher
from modules.NeuralNetwork import unet
from modules.clip import Clip


def load_checkpoint_guess_config(
    ckpt_path: str,
    output_vae: bool = True,
    output_clip: bool = True,
    output_clipvision: bool = False,
    embedding_directory: str = None,
    output_model: bool = True,
) -> tuple:
    """#### Load a checkpoint and guess the configuration.

    #### Args:
        - `ckpt_path` (str): The path to the checkpoint file.
        - `output_vae` (bool, optional): Whether to output the VAE. Defaults to True.
        - `output_clip` (bool, optional): Whether to output the CLIP. Defaults to True.
        - `output_clipvision` (bool, optional): Whether to output the CLIP vision. Defaults to False.
        - `embedding_directory` (str, optional): The embedding directory. Defaults to None.
        - `output_model` (bool, optional): Whether to output the model. Defaults to True.

    #### Returns:
        - `tuple`: The model patcher, CLIP, VAE, and CLIP vision.
    """
    sd = util.load_torch_file(ckpt_path)
    sd.keys()
    clip = None
    clipvision = None
    vae = None
    model = None
    model_patcher = None
    clip_target = None

    parameters = util.calculate_parameters(sd, "model.diffusion_model.")
    load_device = Device.get_torch_device()

    model_config = unet.model_config_from_unet(sd, "model.diffusion_model.")
    unet_dtype = unet.unet_dtype1(
        model_params=parameters,
        supported_dtypes=model_config.supported_inference_dtypes,
    )
    manual_cast_dtype = Device.unet_manual_cast(
        unet_dtype, load_device, model_config.supported_inference_dtypes
    )
    model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)

    if output_model:
        inital_load_device = Device.unet_inital_load_device(parameters, unet_dtype)
        Device.unet_offload_device()
        model = model_config.get_model(
            sd, "model.diffusion_model.", device=inital_load_device
        )
        model.load_model_weights(sd, "model.diffusion_model.")

    if output_vae:
        vae_sd = util.state_dict_prefix_replace(
            sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True
        )
        vae_sd = model_config.process_vae_state_dict(vae_sd)
        vae = VariationalAE.VAE(sd=vae_sd)

    if output_clip:
        clip_target = model_config.clip_target()
        if clip_target is not None:
            clip_sd = model_config.process_clip_state_dict(sd)
            if len(clip_sd) > 0:
                clip = Clip.CLIP(clip_target, embedding_directory=embedding_directory)
                m, u = clip.load_sd(clip_sd, full_model=True)
                if len(m) > 0:
                    m_filter = list(
                        filter(
                            lambda a: ".logit_scale" not in a
                            and ".transformer.text_projection.weight" not in a,
                            m,
                        )
                    )
                    if len(m_filter) > 0:
                        logging.warning("clip missing: {}".format(m))
                    else:
                        logging.debug("clip missing: {}".format(m))

                if len(u) > 0:
                    logging.debug("clip unexpected {}:".format(u))
            else:
                logging.warning(
                    "no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded."
                )

    left_over = sd.keys()
    if len(left_over) > 0:
        logging.debug("left over keys: {}".format(left_over))

    if output_model:
        model_patcher = ModelPatcher.ModelPatcher(
            model,
            load_device=load_device,
            offload_device=Device.unet_offload_device(),
            current_device=inital_load_device,
        )
        if inital_load_device != torch.device("cpu"):
            logging.info("loaded straight to GPU")
            Device.load_model_gpu(model_patcher)

    return (model_patcher, clip, vae, clipvision)


class CheckpointLoaderSimple:
    """#### Class for loading checkpoints."""

    def load_checkpoint(
        self, ckpt_name: str, output_vae: bool = True, output_clip: bool = True
    ) -> tuple:
        """#### Load a checkpoint.

        #### Args:
            - `ckpt_name` (str): The name of the checkpoint.
            - `output_vae` (bool, optional): Whether to output the VAE. Defaults to True.
            - `output_clip` (bool, optional): Whether to output the CLIP. Defaults to True.

        #### Returns:
            - `tuple`: The model patcher, CLIP, and VAE.
        """
        ckpt_path = f"{ckpt_name}"
        out = load_checkpoint_guess_config(
            ckpt_path,
            output_vae=output_vae,
            output_clip=output_clip,
            embedding_directory="./_internal/embeddings/",
        )
        print("loading", ckpt_path)
        return out[:3]
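A minimal sketch of how CheckpointLoaderSimple above might be used; the checkpoint path is a hypothetical example, and the returned tuple matches the first three outputs of load_checkpoint_guess_config:

# Illustrative sketch only; "model.safetensors" is a hypothetical checkpoint file.
from modules.FileManaging.Loader import CheckpointLoaderSimple

loader = CheckpointLoaderSimple()
model_patcher, clip, vae = loader.load_checkpoint(
    "./_internal/checkpoints/model.safetensors"
)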
modules/Model/LoRas.py
ADDED
@@ -0,0 +1,193 @@
import torch
from modules.Utilities import util
from modules.NeuralNetwork import unet

LORA_CLIP_MAP = {
    "mlp.fc1": "mlp_fc1",
    "mlp.fc2": "mlp_fc2",
    "self_attn.k_proj": "self_attn_k_proj",
    "self_attn.q_proj": "self_attn_q_proj",
    "self_attn.v_proj": "self_attn_v_proj",
    "self_attn.out_proj": "self_attn_out_proj",
}


def load_lora(lora: dict, to_load: dict) -> dict:
    """#### Load a LoRA model.

    #### Args:
        - `lora` (dict): The LoRA model state dictionary.
        - `to_load` (dict): The keys to load from the LoRA model.

    #### Returns:
        - `dict`: The loaded LoRA model.
    """
    patch_dict = {}
    loaded_keys = set()
    for x in to_load:
        alpha_name = "{}.alpha".format(x)
        alpha = None
        if alpha_name in lora.keys():
            alpha = lora[alpha_name].item()
            loaded_keys.add(alpha_name)

        "{}.dora_scale".format(x)
        dora_scale = None

        regular_lora = "{}.lora_up.weight".format(x)
        "{}_lora.up.weight".format(x)
        "{}.lora_linear_layer.up.weight".format(x)
        A_name = None

        if regular_lora in lora.keys():
            A_name = regular_lora
            B_name = "{}.lora_down.weight".format(x)
            "{}.lora_mid.weight".format(x)

        if A_name is not None:
            mid = None
            patch_dict[to_load[x]] = (
                "lora",
                (lora[A_name], lora[B_name], alpha, mid, dora_scale),
            )
            loaded_keys.add(A_name)
            loaded_keys.add(B_name)
    return patch_dict


def model_lora_keys_clip(model: torch.nn.Module, key_map: dict = {}) -> dict:
    """#### Get the keys for a LoRA model's CLIP component.

    #### Args:
        - `model` (torch.nn.Module): The LoRA model.
        - `key_map` (dict, optional): The key map. Defaults to {}.

    #### Returns:
        - `dict`: The keys for the CLIP component.
    """
    sdk = model.state_dict().keys()

    text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
    for b in range(32):
        for c in LORA_CLIP_MAP:
            k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k
                lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(
                    b, LORA_CLIP_MAP[c]
                )  # SDXL base
                key_map[lora_key] = k
                lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(
                    b, c
                )  # diffusers lora
                key_map[lora_key] = k
    return key_map


def model_lora_keys_unet(model: torch.nn.Module, key_map: dict = {}) -> dict:
    """#### Get the keys for a LoRA model's UNet component.

    #### Args:
        - `model` (torch.nn.Module): The LoRA model.
        - `key_map` (dict, optional): The key map. Defaults to {}.

    #### Returns:
        - `dict`: The keys for the UNet component.
    """
    sdk = model.state_dict().keys()

    for k in sdk:
        if k.startswith("diffusion_model.") and k.endswith(".weight"):
            key_lora = k[len("diffusion_model.") : -len(".weight")].replace(".", "_")
            key_map["lora_unet_{}".format(key_lora)] = k
            key_map["lora_prior_unet_{}".format(key_lora)] = k  # cascade lora:

    diffusers_keys = unet.unet_to_diffusers(model.model_config.unet_config)
    for k in diffusers_keys:
        if k.endswith(".weight"):
            unet_key = "diffusion_model.{}".format(diffusers_keys[k])
            key_lora = k[: -len(".weight")].replace(".", "_")
            key_map["lora_unet_{}".format(key_lora)] = unet_key

            diffusers_lora_prefix = ["", "unet."]
            for p in diffusers_lora_prefix:
                diffusers_lora_key = "{}{}".format(
                    p, k[: -len(".weight")].replace(".to_", ".processor.to_")
                )
                if diffusers_lora_key.endswith(".to_out.0"):
                    diffusers_lora_key = diffusers_lora_key[:-2]
                key_map[diffusers_lora_key] = unet_key
    return key_map


def load_lora_for_models(
    model: object, clip: object, lora: dict, strength_model: float, strength_clip: float
) -> tuple:
    """#### Load a LoRA model for the given models.

    #### Args:
        - `model` (object): The model.
        - `clip` (object): The CLIP model.
        - `lora` (dict): The LoRA model state dictionary.
        - `strength_model` (float): The strength of the model.
        - `strength_clip` (float): The strength of the CLIP model.

    #### Returns:
        - `tuple`: The new model patcher and CLIP model.
    """
    key_map = {}
    if model is not None:
        key_map = model_lora_keys_unet(model.model, key_map)
    if clip is not None:
        key_map = model_lora_keys_clip(clip.cond_stage_model, key_map)

    loaded = load_lora(lora, key_map)
    new_modelpatcher = model.clone()
    k = new_modelpatcher.add_patches(loaded, strength_model)

    new_clip = clip.clone()
    k1 = new_clip.add_patches(loaded, strength_clip)
    k = set(k)
    k1 = set(k1)

    return (new_modelpatcher, new_clip)


class LoraLoader:
    """#### Class for loading LoRA models."""

    def __init__(self):
        """#### Initialize the LoraLoader class."""
        self.loaded_lora = None

    def load_lora(
        self,
        model: object,
        clip: object,
        lora_name: str,
        strength_model: float,
        strength_clip: float,
    ) -> tuple:
        """#### Load a LoRA model.

        #### Args:
            - `model` (object): The model.
            - `clip` (object): The CLIP model.
            - `lora_name` (str): The name of the LoRA model.
            - `strength_model` (float): The strength of the model.
            - `strength_clip` (float): The strength of the CLIP model.

        #### Returns:
            - `tuple`: The new model patcher and CLIP model.
        """
        lora_path = util.get_full_path("loras", lora_name)
        lora = None
        if lora is None:
            lora = util.load_torch_file(lora_path, safe_load=True)
            self.loaded_lora = (lora_path, lora)

        model_lora, clip_lora = load_lora_for_models(
            model, clip, lora, strength_model, strength_clip
        )
        return (model_lora, clip_lora)
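A minimal sketch of applying a LoRA with LoraLoader above, assuming `model_patcher` and `clip` come from a checkpoint loader and that the LoRA filename below is a hypothetical file under ./_internal/loras/:

# Illustrative sketch only; "style.safetensors" is a hypothetical LoRA file
# resolved via util.get_full_path("loras", ...).
from modules.Model.LoRas import LoraLoader

lora_loader = LoraLoader()
patched_model, patched_clip = lora_loader.load_lora(
    model_patcher, clip, "style.safetensors",
    strength_model=0.8, strength_clip=0.8,
)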
modules/Model/ModelBase.py
ADDED
@@ -0,0 +1,363 @@
import logging
import math
import torch

from modules.Utilities import Latent
from modules.Device import Device
from modules.NeuralNetwork import unet
from modules.cond import cast, cond
from modules.sample import sampling


class BaseModel(torch.nn.Module):
    """#### Base class for models."""

    def __init__(
        self,
        model_config: object,
        model_type: sampling.ModelType = sampling.ModelType.EPS,
        device: torch.device = None,
        unet_model: object = unet.UNetModel1,
        flux: bool = False,
    ):
        """#### Initialize the BaseModel class.

        #### Args:
            - `model_config` (object): The model configuration.
            - `model_type` (sampling.ModelType, optional): The model type. Defaults to sampling.ModelType.EPS.
            - `device` (torch.device, optional): The device to use. Defaults to None.
            - `unet_model` (object, optional): The UNet model. Defaults to unet.UNetModel1.
        """
        super().__init__()

        unet_config = model_config.unet_config
        self.latent_format = model_config.latent_format
        self.model_config = model_config
        self.manual_cast_dtype = model_config.manual_cast_dtype
        self.device = device
        if flux:
            if not unet_config.get("disable_unet_model_creation", False):
                operations = model_config.custom_operations
                self.diffusion_model = unet_model(
                    **unet_config, device=device, operations=operations
                )
                logging.info(
                    "model weight dtype {}, manual cast: {}".format(
                        self.get_dtype(), self.manual_cast_dtype
                    )
                )
        else:
            if not unet_config.get("disable_unet_model_creation", False):
                if self.manual_cast_dtype is not None:
                    operations = cast.manual_cast
                else:
                    operations = cast.disable_weight_init
                self.diffusion_model = unet_model(
                    **unet_config, device=device, operations=operations
                )
        self.model_type = model_type
        self.model_sampling = sampling.model_sampling(model_config, model_type, flux=flux)

        self.adm_channels = unet_config.get("adm_in_channels", None)
        if self.adm_channels is None:
            self.adm_channels = 0

        self.concat_keys = ()
        logging.info("model_type {}".format(model_type.name))
        logging.debug("adm {}".format(self.adm_channels))
        self.memory_usage_factor = model_config.memory_usage_factor if flux else 2.0

    def apply_model(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        c_concat: torch.Tensor = None,
        c_crossattn: torch.Tensor = None,
        control: torch.Tensor = None,
        transformer_options: dict = {},
        **kwargs,
    ) -> torch.Tensor:
        """#### Apply the model to the input tensor.

        #### Args:
            - `x` (torch.Tensor): The input tensor.
            - `t` (torch.Tensor): The timestep tensor.
            - `c_concat` (torch.Tensor, optional): The concatenated condition tensor. Defaults to None.
            - `c_crossattn` (torch.Tensor, optional): The cross-attention condition tensor. Defaults to None.
            - `control` (torch.Tensor, optional): The control tensor. Defaults to None.
            - `transformer_options` (dict, optional): The transformer options. Defaults to {}.
            - `**kwargs`: Additional keyword arguments.

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        sigma = t
        xc = self.model_sampling.calculate_input(sigma, x)
        if c_concat is not None:
            xc = torch.cat([xc] + [c_concat], dim=1)

        context = c_crossattn
        dtype = self.get_dtype()

        if self.manual_cast_dtype is not None:
            dtype = self.manual_cast_dtype

        xc = xc.to(dtype)
        t = self.model_sampling.timestep(t).float()
        context = context.to(dtype)
        extra_conds = {}
        for o in kwargs:
            extra = kwargs[o]
            if hasattr(extra, "dtype"):
                if extra.dtype != torch.int and extra.dtype != torch.long:
                    extra = extra.to(dtype)
            extra_conds[o] = extra

        model_output = self.diffusion_model(
            xc,
            t,
            context=context,
            control=control,
            transformer_options=transformer_options,
            **extra_conds,
        ).float()
        return self.model_sampling.calculate_denoised(sigma, model_output, x)

    def get_dtype(self) -> torch.dtype:
        """#### Get the data type of the model.

        #### Returns:
            - `torch.dtype`: The data type.
        """
        return self.diffusion_model.dtype

    def encode_adm(self, **kwargs) -> None:
        """#### Encode the ADM.

        #### Args:
            - `**kwargs`: Additional keyword arguments.

        #### Returns:
            - `None`: The encoded ADM.
        """
        return None

    def extra_conds(self, **kwargs) -> dict:
        """#### Get the extra conditions.

        #### Args:
            - `**kwargs`: Additional keyword arguments.

        #### Returns:
            - `dict`: The extra conditions.
        """
        out = {}
        adm = self.encode_adm(**kwargs)
        if adm is not None:
            out["y"] = cond.CONDRegular(adm)

        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out["c_crossattn"] = cond.CONDCrossAttn(cross_attn)

        cross_attn_cnet = kwargs.get("cross_attn_controlnet", None)
        if cross_attn_cnet is not None:
            out["crossattn_controlnet"] = cond.CONDCrossAttn(cross_attn_cnet)

        return out

    def load_model_weights(self, sd: dict, unet_prefix: str = "") -> "BaseModel":
        """#### Load the model weights.

        #### Args:
            - `sd` (dict): The state dictionary.
            - `unet_prefix` (str, optional): The UNet prefix. Defaults to "".

        #### Returns:
            - `BaseModel`: The model with loaded weights.
        """
        to_load = {}
        keys = list(sd.keys())
        for k in keys:
            if k.startswith(unet_prefix):
                to_load[k[len(unet_prefix) :]] = sd.pop(k)

        to_load = self.model_config.process_unet_state_dict(to_load)
        m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
        if len(m) > 0:
            logging.warning("unet missing: {}".format(m))

        if len(u) > 0:
            logging.warning("unet unexpected: {}".format(u))
        del to_load
        return self

    def process_latent_in(self, latent: torch.Tensor) -> torch.Tensor:
        """#### Process the latent input.

        #### Args:
            - `latent` (torch.Tensor): The latent tensor.

        #### Returns:
            - `torch.Tensor`: The processed latent tensor.
        """
        return self.latent_format.process_in(latent)

    def process_latent_out(self, latent: torch.Tensor) -> torch.Tensor:
        """#### Process the latent output.

        #### Args:
            - `latent` (torch.Tensor): The latent tensor.

        #### Returns:
            - `torch.Tensor`: The processed latent tensor.
        """
        return self.latent_format.process_out(latent)

    def memory_required(self, input_shape: tuple) -> float:
        """#### Calculate the memory required for the model.

        #### Args:
            - `input_shape` (tuple): The input shape.

        #### Returns:
            - `float`: The memory required.
        """
        dtype = self.get_dtype()
        if self.manual_cast_dtype is not None:
            dtype = self.manual_cast_dtype
        # TODO: this needs to be tweaked
        area = input_shape[0] * math.prod(input_shape[2:])
        return (area * Device.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (
            1024 * 1024
        )


class BASE:
    """#### Base class for model configurations."""

    unet_config = {}
    unet_extra_config = {
        "num_heads": -1,
        "num_head_channels": 64,
    }

    required_keys = {}

    clip_prefix = []
    clip_vision_prefix = None
    noise_aug_config = None
    sampling_settings = {}
    latent_format = Latent.LatentFormat
    vae_key_prefix = ["first_stage_model."]
    text_encoder_key_prefix = ["cond_stage_model."]
    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]

    memory_usage_factor = 2.0

    manual_cast_dtype = None
    custom_operations = None

    @classmethod
    def matches(cls, unet_config: dict, state_dict: dict = None) -> bool:
        """#### Check if the UNet configuration matches.

        #### Args:
            - `unet_config` (dict): The UNet configuration.
            - `state_dict` (dict, optional): The state dictionary. Defaults to None.

        #### Returns:
            - `bool`: Whether the configuration matches.
        """
        for k in cls.unet_config:
            if k not in unet_config or cls.unet_config[k] != unet_config[k]:
                return False
        if state_dict is not None:
            for k in cls.required_keys:
                if k not in state_dict:
                    return False
        return True

    def model_type(self, state_dict: dict, prefix: str = "") -> sampling.ModelType:
        """#### Get the model type.

        #### Args:
            - `state_dict` (dict): The state dictionary.
            - `prefix` (str, optional): The prefix. Defaults to "".

        #### Returns:
            - `sampling.ModelType`: The model type.
        """
        return sampling.ModelType.EPS

    def inpaint_model(self) -> bool:
        """#### Check if the model is an inpaint model.

        #### Returns:
            - `bool`: Whether the model is an inpaint model.
        """
        return self.unet_config["in_channels"] > 4

    def __init__(self, unet_config: dict):
        """#### Initialize the BASE class.

        #### Args:
            - `unet_config` (dict): The UNet configuration.
        """
        self.unet_config = unet_config.copy()
        self.sampling_settings = self.sampling_settings.copy()
        self.latent_format = self.latent_format()
        for x in self.unet_extra_config:
            self.unet_config[x] = self.unet_extra_config[x]

    def get_model(
        self, state_dict: dict, prefix: str = "", device: torch.device = None
    ) -> BaseModel:
        """#### Get the model.

        #### Args:
            - `state_dict` (dict): The state dictionary.
            - `prefix` (str, optional): The prefix. Defaults to "".
            - `device` (torch.device, optional): The device to use. Defaults to None.

        #### Returns:
            - `BaseModel`: The model.
        """
        out = BaseModel(
            self, model_type=self.model_type(state_dict, prefix), device=device
        )
        return out

    def process_unet_state_dict(self, state_dict: dict) -> dict:
        """#### Process the UNet state dictionary.

        #### Args:
            - `state_dict` (dict): The state dictionary.

        #### Returns:
            - `dict`: The processed state dictionary.
        """
        return state_dict

    def process_vae_state_dict(self, state_dict: dict) -> dict:
        """#### Process the VAE state dictionary.

        #### Args:
            - `state_dict` (dict): The state dictionary.

        #### Returns:
            - `dict`: The processed state dictionary.
        """
        return state_dict

    def set_inference_dtype(
        self, dtype: torch.dtype, manual_cast_dtype: torch.dtype
    ) -> None:
        """#### Set the inference data type.

        #### Args:
            - `dtype` (torch.dtype): The data type.
            - `manual_cast_dtype` (torch.dtype): The manual cast data type.
        """
        self.unet_config["dtype"] = dtype
        self.manual_cast_dtype = manual_cast_dtype
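To make the memory_required heuristic above concrete, here is the same arithmetic worked out as a standalone sketch for a hypothetical 1x4x64x64 latent (a 512x512 generation) with fp16 weights; the dtype size and usage factor are assumed values mirroring Device.dtype_size and memory_usage_factor:

# Illustrative re-statement of BaseModel.memory_required, not part of the repository.
import math

input_shape = (1, 4, 64, 64)   # hypothetical latent batch
dtype_size = 2                 # bytes per element for float16 (stands in for Device.dtype_size)
memory_usage_factor = 2.0      # default non-flux factor from BaseModel

area = input_shape[0] * math.prod(input_shape[2:])   # 1 * 64 * 64 = 4096
estimate = area * dtype_size * 0.01 * memory_usage_factor * (1024 * 1024)
print(f"{estimate / (1024 * 1024):.1f} MiB")         # ~163.8 MiB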
modules/Model/ModelPatcher.py
ADDED
@@ -0,0 +1,779 @@
1 |
+
import copy
|
2 |
+
import logging
|
3 |
+
import uuid
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from modules.NeuralNetwork import unet
|
8 |
+
from modules.Utilities import util
|
9 |
+
from modules.Device import Device
|
10 |
+
|
11 |
+
def wipe_lowvram_weight(m):
|
12 |
+
if hasattr(m, "prev_comfy_cast_weights"):
|
13 |
+
m.comfy_cast_weights = m.prev_comfy_cast_weights
|
14 |
+
del m.prev_comfy_cast_weights
|
15 |
+
m.weight_function = None
|
16 |
+
m.bias_function = None
|
17 |
+
|
18 |
+
class ModelPatcher:
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
model: torch.nn.Module,
|
22 |
+
load_device: torch.device,
|
23 |
+
offload_device: torch.device,
|
24 |
+
size: int = 0,
|
25 |
+
current_device: torch.device = None,
|
26 |
+
weight_inplace_update: bool = False,
|
27 |
+
):
|
28 |
+
"""#### Initialize the ModelPatcher class.
|
29 |
+
|
30 |
+
#### Args:
|
31 |
+
- `model` (torch.nn.Module): The model.
|
32 |
+
- `load_device` (torch.device): The device to load the model on.
|
33 |
+
- `offload_device` (torch.device): The device to offload the model to.
|
34 |
+
- `size` (int, optional): The size of the model. Defaults to 0.
|
35 |
+
- `current_device` (torch.device, optional): The current device. Defaults to None.
|
36 |
+
- `weight_inplace_update` (bool, optional): Whether to update weights in place. Defaults to False.
|
37 |
+
"""
|
38 |
+
self.size = size
|
39 |
+
self.model = model
|
40 |
+
self.patches = {}
|
41 |
+
self.backup = {}
|
42 |
+
self.object_patches = {}
|
43 |
+
self.object_patches_backup = {}
|
44 |
+
self.model_options = {"transformer_options": {}}
|
45 |
+
self.model_size()
|
46 |
+
self.load_device = load_device
|
47 |
+
self.offload_device = offload_device
|
48 |
+
if current_device is None:
|
49 |
+
self.current_device = self.offload_device
|
50 |
+
else:
|
51 |
+
self.current_device = current_device
|
52 |
+
|
53 |
+
self.weight_inplace_update = weight_inplace_update
|
54 |
+
self.model_lowvram = False
|
55 |
+
self.lowvram_patch_counter = 0
|
56 |
+
self.patches_uuid = uuid.uuid4()
|
57 |
+
|
58 |
+
if not hasattr(self.model, "model_loaded_weight_memory"):
|
59 |
+
self.model.model_loaded_weight_memory = 0
|
60 |
+
|
61 |
+
if not hasattr(self.model, "model_lowvram"):
|
62 |
+
self.model.model_lowvram = False
|
63 |
+
|
64 |
+
if not hasattr(self.model, "lowvram_patch_counter"):
|
65 |
+
self.model.lowvram_patch_counter = 0
|
66 |
+
|
67 |
+
def loaded_size(self) -> int:
|
68 |
+
"""#### Get the loaded size
|
69 |
+
|
70 |
+
#### Returns:
|
71 |
+
- `int`: The loaded size
|
72 |
+
"""
|
73 |
+
return self.model.model_loaded_weight_memory
|
74 |
+
|
75 |
+
def model_size(self) -> int:
|
76 |
+
"""#### Get the size of the model.
|
77 |
+
|
78 |
+
#### Returns:
|
79 |
+
- `int`: The size of the model.
|
80 |
+
"""
|
81 |
+
if self.size > 0:
|
82 |
+
return self.size
|
83 |
+
model_sd = self.model.state_dict()
|
84 |
+
self.size = Device.module_size(self.model)
|
85 |
+
self.model_keys = set(model_sd.keys())
|
86 |
+
return self.size
|
87 |
+
|
88 |
+
def clone(self) -> "ModelPatcher":
|
89 |
+
"""#### Clone the ModelPatcher object.
|
90 |
+
|
91 |
+
#### Returns:
|
92 |
+
- `ModelPatcher`: The cloned ModelPatcher object.
|
93 |
+
"""
|
94 |
+
n = ModelPatcher(
|
95 |
+
self.model,
|
96 |
+
self.load_device,
|
97 |
+
self.offload_device,
|
98 |
+
self.size,
|
99 |
+
self.current_device,
|
100 |
+
weight_inplace_update=self.weight_inplace_update,
|
101 |
+
)
|
102 |
+
n.patches = {}
|
103 |
+
for k in self.patches:
|
104 |
+
n.patches[k] = self.patches[k][:]
|
105 |
+
n.patches_uuid = self.patches_uuid
|
106 |
+
|
107 |
+
n.object_patches = self.object_patches.copy()
|
108 |
+
n.model_options = copy.deepcopy(self.model_options)
|
109 |
+
n.model_keys = self.model_keys
|
110 |
+
n.backup = self.backup
|
111 |
+
n.object_patches_backup = self.object_patches_backup
|
112 |
+
return n
|
113 |
+
|
114 |
+
def is_clone(self, other: object) -> bool:
|
115 |
+
"""#### Check if the object is a clone.
|
116 |
+
|
117 |
+
#### Args:
|
118 |
+
- `other` (object): The other object.
|
119 |
+
|
120 |
+
#### Returns:
|
121 |
+
- `bool`: Whether the object is a clone.
|
122 |
+
"""
|
123 |
+
if hasattr(other, "model") and self.model is other.model:
|
124 |
+
return True
|
125 |
+
return False
|
126 |
+
|
127 |
+
def memory_required(self, input_shape: tuple) -> float:
|
128 |
+
"""#### Calculate the memory required for the model.
|
129 |
+
|
130 |
+
#### Args:
|
131 |
+
- `input_shape` (tuple): The input shape.
|
132 |
+
|
133 |
+
#### Returns:
|
134 |
+
- `float`: The memory required.
|
135 |
+
"""
|
136 |
+
return self.model.memory_required(input_shape=input_shape)
|
137 |
+
|
138 |
+
def set_model_unet_function_wrapper(self, unet_wrapper_function: callable) -> None:
|
139 |
+
"""#### Set the UNet function wrapper for the model.
|
140 |
+
|
141 |
+
#### Args:
|
142 |
+
- `unet_wrapper_function` (callable): The UNet function wrapper.
|
143 |
+
"""
|
144 |
+
self.model_options["model_function_wrapper"] = unet_wrapper_function
|
145 |
+
|
146 |
+
def set_model_denoise_mask_function(self, denoise_mask_function: callable) -> None:
|
147 |
+
"""#### Set the denoise mask function for the model.
|
148 |
+
|
149 |
+
#### Args:
|
150 |
+
- `denoise_mask_function` (callable): The denoise mask function.
|
151 |
+
"""
|
152 |
+
self.model_options["denoise_mask_function"] = denoise_mask_function
|
153 |
+
|
154 |
+
def get_model_object(self, name: str) -> object:
|
155 |
+
"""#### Get an object from the model.
|
156 |
+
|
157 |
+
#### Args:
|
158 |
+
- `name` (str): The name of the object.
|
159 |
+
|
160 |
+
#### Returns:
|
161 |
+
- `object`: The object.
|
162 |
+
"""
|
163 |
+
return util.get_attr(self.model, name)
|
164 |
+
|
165 |
+
def model_patches_to(self, device: torch.device) -> None:
|
166 |
+
"""#### Move model patches to a device.
|
167 |
+
|
168 |
+
#### Args:
|
169 |
+
- `device` (torch.device): The device.
|
170 |
+
"""
|
171 |
+
self.model_options["transformer_options"]
|
172 |
+
if "model_function_wrapper" in self.model_options:
|
173 |
+
wrap_func = self.model_options["model_function_wrapper"]
|
174 |
+
if hasattr(wrap_func, "to"):
|
175 |
+
self.model_options["model_function_wrapper"] = wrap_func.to(device)
|
176 |
+
|
177 |
+
def model_dtype(self) -> torch.dtype:
|
178 |
+
"""#### Get the data type of the model.
|
179 |
+
|
180 |
+
#### Returns:
|
181 |
+
- `torch.dtype`: The data type.
|
182 |
+
"""
|
183 |
+
if hasattr(self.model, "get_dtype"):
|
184 |
+
return self.model.get_dtype()
|
185 |
+
|
186 |
+
def add_patches(
|
187 |
+
self, patches: dict, strength_patch: float = 1.0, strength_model: float = 1.0
|
188 |
+
) -> list:
|
189 |
+
"""#### Add patches to the model.
|
190 |
+
|
191 |
+
#### Args:
|
192 |
+
- `patches` (dict): The patches to add.
|
193 |
+
- `strength_patch` (float, optional): The strength of the patches. Defaults to 1.0.
|
194 |
+
- `strength_model` (float, optional): The strength of the model. Defaults to 1.0.
|
195 |
+
|
196 |
+
#### Returns:
|
197 |
+
- `list`: The list of patched keys.
|
198 |
+
"""
|
199 |
+
p = set()
|
200 |
+
for k in patches:
|
201 |
+
if k in self.model_keys:
|
202 |
+
p.add(k)
|
203 |
+
current_patches = self.patches.get(k, [])
|
204 |
+
current_patches.append((strength_patch, patches[k], strength_model))
|
205 |
+
self.patches[k] = current_patches
|
206 |
+
|
207 |
+
self.patches_uuid = uuid.uuid4()
|
208 |
+
return list(p)
|
209 |
+
|
210 |
+
def set_model_patch(self, patch: list, name: str):
|
211 |
+
"""#### Set a patch for the model.
|
212 |
+
|
213 |
+
#### Args:
|
214 |
+
- `patch` (list): The patch.
|
215 |
+
- `name` (str): The name of the patch.
|
216 |
+
"""
|
217 |
+
to = self.model_options["transformer_options"]
|
218 |
+
if "patches" not in to:
|
219 |
+
to["patches"] = {}
|
220 |
+
to["patches"][name] = to["patches"].get(name, []) + [patch]
|
221 |
+
|
222 |
+
def set_model_attn1_patch(self, patch: list):
|
223 |
+
"""#### Set the attention 1 patch for the model.
|
224 |
+
|
225 |
+
#### Args:
|
226 |
+
- `patch` (list): The patch.
|
227 |
+
"""
|
228 |
+
self.set_model_patch(patch, "attn1_patch")
|
229 |
+
|
230 |
+
def set_model_attn2_patch(self, patch: list):
|
231 |
+
"""#### Set the attention 2 patch for the model.
|
232 |
+
|
233 |
+
#### Args:
|
234 |
+
- `patch` (list): The patch.
|
235 |
+
"""
|
236 |
+
self.set_model_patch(patch, "attn2_patch")
|
237 |
+
|
238 |
+
def set_model_attn1_output_patch(self, patch: list):
|
239 |
+
"""#### Set the attention 1 output patch for the model.
|
240 |
+
|
241 |
+
#### Args:
|
242 |
+
- `patch` (list): The patch.
|
243 |
+
"""
|
244 |
+
self.set_model_patch(patch, "attn1_output_patch")
|
245 |
+
|
246 |
+
def set_model_attn2_output_patch(self, patch: list):
|
247 |
+
"""#### Set the attention 2 output patch for the model.
|
248 |
+
|
249 |
+
#### Args:
|
250 |
+
- `patch` (list): The patch.
|
251 |
+
"""
|
252 |
+
self.set_model_patch(patch, "attn2_output_patch")
|
253 |
+
|
254 |
+
def model_state_dict(self, filter_prefix: str = None) -> dict:
|
255 |
+
"""#### Get the state dictionary of the model.
|
256 |
+
|
257 |
+
#### Args:
|
258 |
+
- `filter_prefix` (str, optional): The prefix to filter. Defaults to None.
|
259 |
+
|
260 |
+
#### Returns:
|
261 |
+
- `dict`: The state dictionary.
|
262 |
+
"""
|
263 |
+
sd = self.model.state_dict()
|
264 |
+
list(sd.keys())
|
265 |
+
return sd
|
266 |
+
|
267 |
+
def patch_weight_to_device(self, key: str, device_to: torch.device = None) -> None:
|
268 |
+
"""#### Patch the weight of a key to a device.
|
269 |
+
|
270 |
+
#### Args:
|
271 |
+
- `key` (str): The key.
|
272 |
+
- `device_to` (torch.device, optional): The device to patch to. Defaults to None.
|
273 |
+
"""
|
274 |
+
if key not in self.patches:
|
275 |
+
return
|
276 |
+
|
277 |
+
weight = util.get_attr(self.model, key)
|
278 |
+
|
279 |
+
inplace_update = self.weight_inplace_update
|
280 |
+
|
281 |
+
if key not in self.backup:
|
282 |
+
self.backup[key] = weight.to(
|
283 |
+
device=self.offload_device, copy=inplace_update
|
284 |
+
)
|
285 |
+
|
286 |
+
if device_to is not None:
|
287 |
+
temp_weight = Device.cast_to_device(
|
288 |
+
weight, device_to, torch.float32, copy=True
|
289 |
+
)
|
290 |
+
else:
|
291 |
+
temp_weight = weight.to(torch.float32, copy=True)
|
292 |
+
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(
|
293 |
+
weight.dtype
|
294 |
+
)
|
295 |
+
if inplace_update:
|
296 |
+
util.copy_to_param(self.model, key, out_weight)
|
297 |
+
else:
|
298 |
+
util.set_attr_param(self.model, key, out_weight)
|
299 |
+
|
300 |
+
def load(
|
301 |
+
self,
|
302 |
+
device_to: torch.device = None,
|
303 |
+
lowvram_model_memory: int = 0,
|
304 |
+
force_patch_weights: bool = False,
|
305 |
+
full_load: bool = False,
|
306 |
+
):
|
307 |
+
"""#### Load the model.
|
308 |
+
|
309 |
+
#### Args:
|
310 |
+
- `device_to` (torch.device, optional): The device to load to. Defaults to None.
|
311 |
+
- `lowvram_model_memory` (int, optional): The low VRAM model memory. Defaults to 0.
|
312 |
+
- `force_patch_weights` (bool, optional): Whether to force patch weights. Defaults to False.
|
313 |
+
- `full_load` (bool, optional): Whether to fully load the model. Defaults to False.
|
314 |
+
"""
|
315 |
+
mem_counter = 0
|
316 |
+
patch_counter = 0
|
317 |
+
lowvram_counter = 0
|
318 |
+
loading = []
|
319 |
+
for n, m in self.model.named_modules():
|
320 |
+
if hasattr(m, "comfy_cast_weights") or hasattr(m, "weight"):
|
321 |
+
loading.append((Device.module_size(m), n, m))
|
322 |
+
|
323 |
+
load_completely = []
|
324 |
+
loading.sort(reverse=True)
|
325 |
+
for x in loading:
|
326 |
+
n = x[1]
|
327 |
+
m = x[2]
|
328 |
+
module_mem = x[0]
|
329 |
+
|
330 |
+
lowvram_weight = False
|
331 |
+
|
332 |
+
if not full_load and hasattr(m, "comfy_cast_weights"):
|
333 |
+
if mem_counter + module_mem >= lowvram_model_memory:
|
334 |
+
lowvram_weight = True
|
335 |
+
lowvram_counter += 1
|
336 |
+
if hasattr(m, "prev_comfy_cast_weights"): # Already lowvramed
|
337 |
+
continue
|
338 |
+
|
339 |
+
weight_key = "{}.weight".format(n)
|
340 |
+
bias_key = "{}.bias".format(n)
|
341 |
+
|
342 |
+
if lowvram_weight:
|
343 |
+
if weight_key in self.patches:
|
344 |
+
if force_patch_weights:
|
345 |
+
self.patch_weight_to_device(weight_key)
|
346 |
+
if bias_key in self.patches:
|
347 |
+
if force_patch_weights:
|
348 |
+
self.patch_weight_to_device(bias_key)
|
349 |
+
|
350 |
+
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
351 |
+
m.comfy_cast_weights = True
|
352 |
+
else:
|
353 |
+
if hasattr(m, "comfy_cast_weights"):
|
354 |
+
if m.comfy_cast_weights:
|
355 |
+
wipe_lowvram_weight(m)
|
356 |
+
|
357 |
+
if hasattr(m, "weight"):
|
358 |
+
mem_counter += module_mem
|
359 |
+
load_completely.append((module_mem, n, m))
|
360 |
+
|
361 |
+
load_completely.sort(reverse=True)
|
362 |
+
for x in load_completely:
|
363 |
+
n = x[1]
|
364 |
+
m = x[2]
|
365 |
+
weight_key = "{}.weight".format(n)
|
366 |
+
bias_key = "{}.bias".format(n)
|
367 |
+
if hasattr(m, "comfy_patched_weights"):
|
368 |
+
if m.comfy_patched_weights is True:
|
369 |
+
continue
|
370 |
+
|
371 |
+
self.patch_weight_to_device(weight_key, device_to=device_to)
|
372 |
+
self.patch_weight_to_device(bias_key, device_to=device_to)
|
373 |
+
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
374 |
+
m.comfy_patched_weights = True
|
375 |
+
|
376 |
+
for x in load_completely:
|
377 |
+
x[2].to(device_to)
|
378 |
+
|
379 |
+
if lowvram_counter > 0:
|
380 |
+
logging.info(
|
381 |
+
"loaded partially {} {} {}".format(
|
382 |
+
lowvram_model_memory / (1024 * 1024),
|
383 |
+
mem_counter / (1024 * 1024),
|
384 |
+
patch_counter,
|
385 |
+
)
|
386 |
+
)
|
387 |
+
self.model.model_lowvram = True
|
388 |
+
else:
|
389 |
+
logging.info(
|
390 |
+
"loaded completely {} {} {}".format(
|
391 |
+
lowvram_model_memory / (1024 * 1024),
|
392 |
+
mem_counter / (1024 * 1024),
|
393 |
+
full_load,
|
394 |
+
)
|
395 |
+
)
|
396 |
+
self.model.model_lowvram = False
|
397 |
+
if full_load:
|
398 |
+
self.model.to(device_to)
|
399 |
+
mem_counter = self.model_size()
|
400 |
+
|
401 |
+
|
402 |
+
self.model.lowvram_patch_counter += patch_counter
|
403 |
+
self.model.device = device_to
|
404 |
+
self.model.model_loaded_weight_memory = mem_counter
|
405 |
+
|
406 |
+
def patch_model_flux(
|
407 |
+
self,
|
408 |
+
device_to: torch.device = None,
|
409 |
+
lowvram_model_memory: int =0,
|
410 |
+
load_weights: bool = True,
|
411 |
+
force_patch_weights: bool = False,
|
412 |
+
):
|
413 |
+
"""#### Patch the model.
|
414 |
+
|
415 |
+
#### Args:
|
416 |
+
- `device_to` (torch.device, optional): The device to patch to. Defaults to None.
|
417 |
+
- `lowvram_model_memory` (int, optional): The low VRAM model memory. Defaults to 0.
|
418 |
+
- `load_weights` (bool, optional): Whether to load weights. Defaults to True.
|
419 |
+
- `force_patch_weights` (bool, optional): Whether to force patch weights. Defaults to False.
|
420 |
+
|
421 |
+
#### Returns:
|
422 |
+
- `torch.nn.Module`: The patched model.
|
423 |
+
"""
|
424 |
+
for k in self.object_patches:
|
425 |
+
old = util.set_attr(self.model, k, self.object_patches[k])
|
426 |
+
if k not in self.object_patches_backup:
|
427 |
+
self.object_patches_backup[k] = old
|
428 |
+
|
429 |
+
if lowvram_model_memory == 0:
|
430 |
+
full_load = True
|
431 |
+
else:
|
432 |
+
full_load = False
|
433 |
+
|
434 |
+
if load_weights:
|
435 |
+
self.load(
|
436 |
+
device_to,
|
437 |
+
lowvram_model_memory=lowvram_model_memory,
|
438 |
+
force_patch_weights=force_patch_weights,
|
439 |
+
full_load=full_load,
|
440 |
+
)
|
441 |
+
return self.model
|
442 |
+
|
443 |
+
def patch_model_lowvram_flux(
|
444 |
+
self,
|
445 |
+
device_to: torch.device = None,
|
446 |
+
lowvram_model_memory: int = 0,
|
447 |
+
force_patch_weights: bool = False,
|
448 |
+
) -> torch.nn.Module:
|
449 |
+
"""#### Patch the model for low VRAM.
|
450 |
+
|
451 |
+
#### Args:
|
452 |
+
- `device_to` (torch.device, optional): The device to patch to. Defaults to None.
|
453 |
+
- `lowvram_model_memory` (int, optional): The low VRAM model memory. Defaults to 0.
|
454 |
+
- `force_patch_weights` (bool, optional): Whether to force patch weights. Defaults to False.
|
455 |
+
|
456 |
+
#### Returns:
|
457 |
+
- `torch.nn.Module`: The patched model.
|
458 |
+
"""
|
459 |
+
self.patch_model(device_to)
|
460 |
+
|
461 |
+
logging.info(
|
462 |
+
"loading in lowvram mode {}".format(lowvram_model_memory / (1024 * 1024))
|
463 |
+
)
|
464 |
+
|
465 |
+
class LowVramPatch:
|
466 |
+
def __init__(self, key: str, model_patcher: "ModelPatcher"):
|
467 |
+
self.key = key
|
468 |
+
self.model_patcher = model_patcher
|
469 |
+
|
470 |
+
def __call__(self, weight: torch.Tensor) -> torch.Tensor:
|
471 |
+
return self.model_patcher.calculate_weight(
|
472 |
+
self.model_patcher.patches[self.key], weight, self.key
|
473 |
+
)
|
474 |
+
|
475 |
+
mem_counter = 0
|
476 |
+
patch_counter = 0
|
477 |
+
for n, m in self.model.named_modules():
|
478 |
+
lowvram_weight = False
|
479 |
+
if hasattr(m, "comfy_cast_weights"):
|
480 |
+
module_mem = Device.module_size(m)
|
481 |
+
if mem_counter + module_mem >= lowvram_model_memory:
|
482 |
+
lowvram_weight = True
|
483 |
+
|
484 |
+
weight_key = "{}.weight".format(n)
|
485 |
+
bias_key = "{}.bias".format(n)
|
486 |
+
|
487 |
+
if lowvram_weight:
|
488 |
+
if weight_key in self.patches:
|
489 |
+
if force_patch_weights:
|
490 |
+
self.patch_weight_to_device(weight_key)
|
491 |
+
else:
|
492 |
+
m.weight_function = LowVramPatch(weight_key, self)
|
493 |
+
patch_counter += 1
|
494 |
+
if bias_key in self.patches:
|
495 |
+
if force_patch_weights:
|
496 |
+
self.patch_weight_to_device(bias_key)
|
497 |
+
else:
|
498 |
+
m.bias_function = LowVramPatch(bias_key, self)
|
499 |
+
patch_counter += 1
|
500 |
+
|
501 |
+
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
502 |
+
m.comfy_cast_weights = True
|
503 |
+
else:
|
504 |
+
if hasattr(m, "weight"):
|
505 |
+
self.patch_weight_to_device(weight_key, device_to)
|
506 |
+
self.patch_weight_to_device(bias_key, device_to)
|
507 |
+
m.to(device_to)
|
508 |
+
mem_counter += Device.module_size(m)
|
509 |
+
logging.debug("lowvram: loaded module regularly {}".format(m))
|
510 |
+
|
511 |
+
self.model_lowvram = True
|
512 |
+
self.lowvram_patch_counter = patch_counter
|
513 |
+
return self.model
|
514 |
+
|
515 |
+
def patch_model(
|
516 |
+
self, device_to: torch.device = None, patch_weights: bool = True
|
517 |
+
) -> torch.nn.Module:
|
518 |
+
"""#### Patch the model.
|
519 |
+
|
520 |
+
#### Args:
|
521 |
+
- `device_to` (torch.device, optional): The device to patch to. Defaults to None.
|
522 |
+
- `patch_weights` (bool, optional): Whether to patch weights. Defaults to True.
|
523 |
+
|
524 |
+
#### Returns:
|
525 |
+
- `torch.nn.Module`: The patched model.
|
526 |
+
"""
|
527 |
+
for k in self.object_patches:
|
528 |
+
old = util.set_attr(self.model, k, self.object_patches[k])
|
529 |
+
if k not in self.object_patches_backup:
|
530 |
+
self.object_patches_backup[k] = old
|
531 |
+
|
532 |
+
if patch_weights:
|
533 |
+
model_sd = self.model_state_dict()
|
534 |
+
for key in self.patches:
|
535 |
+
if key not in model_sd:
|
536 |
+
logging.warning(
|
537 |
+
"could not patch. key doesn't exist in model: {}".format(key)
|
538 |
+
)
|
539 |
+
continue
|
540 |
+
|
541 |
+
self.patch_weight_to_device(key, device_to)
|
542 |
+
|
543 |
+
if device_to is not None:
|
544 |
+
self.model.to(device_to)
|
545 |
+
self.current_device = device_to
|
546 |
+
|
547 |
+
return self.model
|
548 |
+
|
549 |
+
def patch_model_lowvram(
|
550 |
+
self,
|
551 |
+
device_to: torch.device = None,
|
552 |
+
lowvram_model_memory: int = 0,
|
553 |
+
force_patch_weights: bool = False,
|
554 |
+
) -> torch.nn.Module:
|
555 |
+
"""#### Patch the model for low VRAM.
|
556 |
+
|
557 |
+
#### Args:
|
558 |
+
- `device_to` (torch.device, optional): The device to patch to. Defaults to None.
|
559 |
+
- `lowvram_model_memory` (int, optional): The low VRAM model memory. Defaults to 0.
|
560 |
+
- `force_patch_weights` (bool, optional): Whether to force patch weights. Defaults to False.
|
561 |
+
|
562 |
+
#### Returns:
|
563 |
+
- `torch.nn.Module`: The patched model.
|
564 |
+
"""
|
565 |
+
self.patch_model(device_to, patch_weights=False)
|
566 |
+
|
567 |
+
logging.info(
|
568 |
+
"loading in lowvram mode {}".format(lowvram_model_memory / (1024 * 1024))
|
569 |
+
)
|
570 |
+
|
571 |
+
class LowVramPatch:
|
572 |
+
def __init__(self, key: str, model_patcher: "ModelPatcher"):
|
573 |
+
self.key = key
|
574 |
+
self.model_patcher = model_patcher
|
575 |
+
|
576 |
+
def __call__(self, weight: torch.Tensor) -> torch.Tensor:
|
577 |
+
return self.model_patcher.calculate_weight(
|
578 |
+
self.model_patcher.patches[self.key], weight, self.key
|
579 |
+
)
|
580 |
+
|
581 |
+
mem_counter = 0
|
582 |
+
patch_counter = 0
|
583 |
+
for n, m in self.model.named_modules():
|
584 |
+
lowvram_weight = False
|
585 |
+
if hasattr(m, "comfy_cast_weights"):
|
586 |
+
module_mem = Device.module_size(m)
|
587 |
+
if mem_counter + module_mem >= lowvram_model_memory:
|
588 |
+
lowvram_weight = True
|
589 |
+
|
590 |
+
weight_key = "{}.weight".format(n)
|
591 |
+
bias_key = "{}.bias".format(n)
|
592 |
+
|
593 |
+
if lowvram_weight:
|
594 |
+
if weight_key in self.patches:
|
595 |
+
if force_patch_weights:
|
596 |
+
self.patch_weight_to_device(weight_key)
|
597 |
+
else:
|
598 |
+
m.weight_function = LowVramPatch(weight_key, self)
|
599 |
+
patch_counter += 1
|
600 |
+
if bias_key in self.patches:
|
601 |
+
if force_patch_weights:
|
602 |
+
self.patch_weight_to_device(bias_key)
|
603 |
+
else:
|
604 |
+
m.bias_function = LowVramPatch(bias_key, self)
|
605 |
+
patch_counter += 1
|
606 |
+
|
607 |
+
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
608 |
+
m.comfy_cast_weights = True
|
609 |
+
else:
|
610 |
+
if hasattr(m, "weight"):
|
611 |
+
self.patch_weight_to_device(weight_key, device_to)
|
612 |
+
self.patch_weight_to_device(bias_key, device_to)
|
613 |
+
m.to(device_to)
|
614 |
+
mem_counter += Device.module_size(m)
|
615 |
+
logging.debug("lowvram: loaded module regularly {}".format(m))
|
616 |
+
|
617 |
+
self.model_lowvram = True
|
618 |
+
self.lowvram_patch_counter = patch_counter
|
619 |
+
return self.model
|
620 |
+
|
621 |
+
def calculate_weight(
|
622 |
+
self, patches: list, weight: torch.Tensor, key: str
|
623 |
+
) -> torch.Tensor:
|
624 |
+
"""#### Calculate the weight of a key.
|
625 |
+
|
626 |
+
#### Args:
|
627 |
+
- `patches` (list): The list of patches.
|
628 |
+
- `weight` (torch.Tensor): The weight tensor.
|
629 |
+
- `key` (str): The key.
|
630 |
+
|
631 |
+
#### Returns:
|
632 |
+
- `torch.Tensor`: The calculated weight.
|
633 |
+
"""
|
634 |
+
for p in patches:
|
635 |
+
alpha = p[0]
|
636 |
+
v = p[1]
|
637 |
+
p[2]
|
638 |
+
v[0]
|
639 |
+
v = v[1]
|
640 |
+
mat1 = Device.cast_to_device(v[0], weight.device, torch.float32)
|
641 |
+
mat2 = Device.cast_to_device(v[1], weight.device, torch.float32)
|
642 |
+
v[4]
|
643 |
+
if v[2] is not None:
|
644 |
+
alpha *= v[2] / mat2.shape[0]
|
645 |
+
weight += (
|
646 |
+
(alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)))
|
647 |
+
.reshape(weight.shape)
|
648 |
+
.type(weight.dtype)
|
649 |
+
)
|
650 |
+
return weight
|
651 |
+
|
    def unpatch_model(
        self, device_to: torch.device = None, unpatch_weights: bool = True
    ) -> None:
        """#### Unpatch the model.

        #### Args:
            - `device_to` (torch.device, optional): The device to unpatch to. Defaults to None.
            - `unpatch_weights` (bool, optional): Whether to unpatch weights. Defaults to True.
        """
        if unpatch_weights:
            keys = list(self.backup.keys())
            for k in keys:
                util.set_attr_param(self.model, k, self.backup[k])
            self.backup.clear()
            if device_to is not None:
                self.model.to(device_to)
                self.current_device = device_to

        keys = list(self.object_patches_backup.keys())
        self.object_patches_backup.clear()

    def partially_load(self, device_to: torch.device, extra_memory: int = 0) -> int:
        """#### Partially load the model.

        #### Args:
            - `device_to` (torch.device): The device to load to.
            - `extra_memory` (int, optional): The extra memory. Defaults to 0.

        #### Returns:
            - `int`: The memory loaded.
        """
        self.unpatch_model(unpatch_weights=False)
        self.patch_model(patch_weights=False)
        full_load = False
        if self.model.model_lowvram is False:
            return 0
        if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
            full_load = True
        current_used = self.model.model_loaded_weight_memory
        self.load(
            device_to,
            lowvram_model_memory=current_used + extra_memory,
            full_load=full_load,
        )
        return self.model.model_loaded_weight_memory - current_used

    def add_object_patch(self, name, obj):
        self.object_patches[name] = obj


def unet_prefix_from_state_dict(state_dict: dict) -> str:
    """#### Get the UNet prefix from the state dictionary.

    #### Args:
        - `state_dict` (dict): The state dictionary.

    #### Returns:
        - `str`: The UNet prefix.
    """
    candidates = [
        "model.diffusion_model.",  # ldm/sgm models
        "model.model.",  # audio models
    ]
    counts = {k: 0 for k in candidates}
    for k in state_dict:
        for c in candidates:
            if k.startswith(c):
                counts[c] += 1
                break

    top = max(counts, key=counts.get)
    if counts[top] > 5:
        return top
    else:
        return "model."  # aura flow and others


def load_diffusion_model_state_dict(
    sd, model_options={}
) -> ModelPatcher:
    """#### Load the diffusion model state dictionary.

    #### Args:
        - `sd`: The state dictionary.
        - `model_options` (dict, optional): The model options. Defaults to {}.

    #### Returns:
        - `ModelPatcher`: The model patcher.
    """
    # load unet in diffusers or regular format
    dtype = model_options.get("dtype", None)

    # Allow loading unets from checkpoint files
    diffusion_model_prefix = unet_prefix_from_state_dict(sd)
    temp_sd = util.state_dict_prefix_replace(
        sd, {diffusion_model_prefix: ""}, filter_keys=True
    )
    if len(temp_sd) > 0:
        sd = temp_sd

    parameters = util.calculate_parameters(sd)
    load_device = Device.get_torch_device()
    model_config = unet.model_config_from_unet(sd, "")

    if model_config is not None:
        new_sd = sd

    offload_device = Device.unet_offload_device()
    if dtype is None:
        unet_dtype2 = Device.unet_dtype(
            model_params=parameters,
            supported_dtypes=model_config.supported_inference_dtypes,
        )
    else:
        unet_dtype2 = dtype

    manual_cast_dtype = Device.unet_manual_cast(
        unet_dtype2, load_device, model_config.supported_inference_dtypes
    )
    model_config.set_inference_dtype(unet_dtype2, manual_cast_dtype)
    model_config.custom_operations = model_options.get(
        "custom_operations", model_config.custom_operations
    )
    model = model_config.get_model(new_sd, "")
    model = model.to(offload_device)
    model.load_model_weights(new_sd, "")
    left_over = sd.keys()
    if len(left_over) > 0:
        logging.info("left over keys in unet: {}".format(left_over))
    return ModelPatcher(model, load_device=load_device, offload_device=offload_device)
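A hedged usage sketch for the loader above; the checkpoint filename and the safetensors import are assumptions for illustration only, and the `partially_load` call is optional once a patcher exists:

from safetensors.torch import load_file  # assumed to be available in the environment

sd = load_file("unet.safetensors")        # hypothetical diffusion-model checkpoint
patcher = load_diffusion_model_state_dict(sd, model_options={"dtype": None})

# Optionally stream more weights onto the GPU under an extra memory budget (bytes):
# patcher.partially_load(Device.get_torch_device(), extra_memory=512 * 1024**2)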
modules/NeuralNetwork/transformer.py
ADDED
@@ -0,0 +1,443 @@
from einops import rearrange
import torch
from modules.Utilities import util
import torch.nn as nn
from modules.Attention import Attention
from modules.Device import Device
from modules.cond import Activation
from modules.cond import cast
from modules.sample import sampling_util

if Device.xformers_enabled():
    pass

ops = cast.disable_weight_init

_ATTN_PRECISION = "fp32"


class FeedForward(nn.Module):
    """#### FeedForward neural network module.

    #### Args:
        - `dim` (int): The input dimension.
        - `dim_out` (int, optional): The output dimension. Defaults to None.
        - `mult` (int, optional): The multiplier for the inner dimension. Defaults to 4.
        - `glu` (bool, optional): Whether to use Gated Linear Units. Defaults to False.
        - `dropout` (float, optional): The dropout rate. Defaults to 0.0.
        - `dtype` (torch.dtype, optional): The data type. Defaults to None.
        - `device` (torch.device, optional): The device. Defaults to None.
        - `operations` (object, optional): The operations module. Defaults to `ops`.
    """

    def __init__(
        self,
        dim: int,
        dim_out: int = None,
        mult: int = 4,
        glu: bool = False,
        dropout: float = 0.0,
        dtype: torch.dtype = None,
        device: torch.device = None,
        operations: object = ops,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = util.default(dim_out, dim)
        project_in = (
            nn.Sequential(
                operations.Linear(dim, inner_dim, dtype=dtype, device=device), nn.GELU()
            )
            if not glu
            else Activation.GEGLU(dim, inner_dim)
        )

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            operations.Linear(inner_dim, dim_out, dtype=dtype, device=device),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """#### Forward pass of the FeedForward network.

        #### Args:
            - `x` (torch.Tensor): The input tensor.

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        return self.net(x)

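A quick shape check for `FeedForward`, assuming the project's `operations.Linear` and `Activation.GEGLU` behave like their standard counterparts (both are constructed here exactly as the class itself does):

import torch

ff = FeedForward(dim=320, mult=4, glu=True)   # GEGLU-gated variant
x = torch.randn(2, 77, 320)                   # (batch, tokens, channels)
print(ff(x).shape)                            # torch.Size([2, 77, 320])
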
class BasicTransformerBlock(nn.Module):
    """#### Basic Transformer block.

    #### Args:
        - `dim` (int): The input dimension.
        - `n_heads` (int): The number of attention heads.
        - `d_head` (int): The dimension of each attention head.
        - `dropout` (float, optional): The dropout rate. Defaults to 0.0.
        - `context_dim` (int, optional): The context dimension. Defaults to None.
        - `gated_ff` (bool, optional): Whether to use Gated FeedForward. Defaults to True.
        - `checkpoint` (bool, optional): Whether to use checkpointing. Defaults to True.
        - `ff_in` (bool, optional): Whether to use FeedForward input. Defaults to False.
        - `inner_dim` (int, optional): The inner dimension. Defaults to None.
        - `disable_self_attn` (bool, optional): Whether to disable self-attention. Defaults to False.
        - `disable_temporal_crossattention` (bool, optional): Whether to disable temporal cross-attention. Defaults to False.
        - `switch_temporal_ca_to_sa` (bool, optional): Whether to switch temporal cross-attention to self-attention. Defaults to False.
        - `dtype` (torch.dtype, optional): The data type. Defaults to None.
        - `device` (torch.device, optional): The device. Defaults to None.
        - `operations` (object, optional): The operations module. Defaults to `ops`.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        d_head: int,
        dropout: float = 0.0,
        context_dim: int = None,
        gated_ff: bool = True,
        checkpoint: bool = True,
        ff_in: bool = False,
        inner_dim: int = None,
        disable_self_attn: bool = False,
        disable_temporal_crossattention: bool = False,
        switch_temporal_ca_to_sa: bool = False,
        dtype: torch.dtype = None,
        device: torch.device = None,
        operations: object = ops,
    ):
        super().__init__()

        self.ff_in = ff_in or inner_dim is not None
        if inner_dim is None:
            inner_dim = dim

        self.is_res = inner_dim == dim
        self.disable_self_attn = disable_self_attn
        self.attn1 = Attention.CrossAttention(
            query_dim=inner_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=context_dim if self.disable_self_attn else None,
            dtype=dtype,
            device=device,
            operations=operations,
        )  # is a self-attention if not self.disable_self_attn
        self.ff = FeedForward(
            inner_dim,
            dim_out=dim,
            dropout=dropout,
            glu=gated_ff,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        context_dim_attn2 = None
        if not switch_temporal_ca_to_sa:
            context_dim_attn2 = context_dim

        self.attn2 = Attention.CrossAttention(
            query_dim=inner_dim,
            context_dim=context_dim_attn2,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            dtype=dtype,
            device=device,
            operations=operations,
        )  # is self-attn if context is none
        self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)

        self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
        self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
        self.checkpoint = checkpoint
        self.n_heads = n_heads
        self.d_head = d_head
        self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa

    def forward(
        self,
        x: torch.Tensor,
        context: torch.Tensor = None,
        transformer_options: dict = {},
    ) -> torch.Tensor:
        """#### Forward pass of the Basic Transformer block.

        #### Args:
            - `x` (torch.Tensor): The input tensor.
            - `context` (torch.Tensor, optional): The context tensor. Defaults to None.
            - `transformer_options` (dict, optional): Additional transformer options. Defaults to {}.

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        return sampling_util.checkpoint(
            self._forward,
            (x, context, transformer_options),
            self.parameters(),
            self.checkpoint,
        )

    def _forward(
        self,
        x: torch.Tensor,
        context: torch.Tensor = None,
        transformer_options: dict = {},
    ) -> torch.Tensor:
        """#### Internal forward pass of the Basic Transformer block.

        #### Args:
            - `x` (torch.Tensor): The input tensor.
            - `context` (torch.Tensor, optional): The context tensor. Defaults to None.
            - `transformer_options` (dict, optional): Additional transformer options. Defaults to {}.

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        extra_options = {}
        block = transformer_options.get("block", None)
        block_index = transformer_options.get("block_index", 0)
        transformer_patches_replace = {}

        for k in transformer_options:
            extra_options[k] = transformer_options[k]

        extra_options["n_heads"] = self.n_heads
        extra_options["dim_head"] = self.d_head

        n = self.norm1(x)
        context_attn1 = None
        value_attn1 = None

        transformer_block = (block[0], block[1], block_index)
        attn1_replace_patch = transformer_patches_replace.get("attn1", {})
        block_attn1 = transformer_block
        if block_attn1 not in attn1_replace_patch:
            block_attn1 = block

        n = self.attn1(n, context=context_attn1, value=value_attn1)

        x += n

        if self.attn2 is not None:
            n = self.norm2(x)
            context_attn2 = context
            value_attn2 = None

            attn2_replace_patch = transformer_patches_replace.get("attn2", {})
            block_attn2 = transformer_block
            if block_attn2 not in attn2_replace_patch:
                block_attn2 = block
            n = self.attn2(n, context=context_attn2, value=value_attn2)

        x += n
        if self.is_res:
            x_skip = x
        x = self.ff(self.norm3(x))
        if self.is_res:
            x += x_skip

        return x

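The block above is the usual pre-norm residual layout: self-attention, optional cross-attention over the conditioning, then a gated feed-forward. A hedged instantiation sketch; the shapes and the ("input", 0) block tag are illustrative, and checkpoint=False sidesteps gradient checkpointing:

import torch

block = BasicTransformerBlock(dim=320, n_heads=8, d_head=40, context_dim=768, checkpoint=False)
x = torch.randn(1, 64, 320)                    # (batch, tokens, channels)
ctx = torch.randn(1, 77, 768)                  # e.g. text-encoder embeddings
out = block(x, context=ctx, transformer_options={"block": ("input", 0)})
print(out.shape)                               # torch.Size([1, 64, 320])
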
class SpatialTransformer(nn.Module):
    """#### Spatial Transformer module.

    #### Args:
        - `in_channels` (int): The number of input channels.
        - `n_heads` (int): The number of attention heads.
        - `d_head` (int): The dimension of each attention head.
        - `depth` (int, optional): The depth of the transformer. Defaults to 1.
        - `dropout` (float, optional): The dropout rate. Defaults to 0.0.
        - `context_dim` (int, optional): The context dimension. Defaults to None.
        - `disable_self_attn` (bool, optional): Whether to disable self-attention. Defaults to False.
        - `use_linear` (bool, optional): Whether to use linear projections. Defaults to False.
        - `use_checkpoint` (bool, optional): Whether to use checkpointing. Defaults to True.
        - `dtype` (torch.dtype, optional): The data type. Defaults to None.
        - `device` (torch.device, optional): The device. Defaults to None.
        - `operations` (object, optional): The operations module. Defaults to `ops`.
    """

    def __init__(
        self,
        in_channels: int,
        n_heads: int,
        d_head: int,
        depth: int = 1,
        dropout: float = 0.0,
        context_dim: int = None,
        disable_self_attn: bool = False,
        use_linear: bool = False,
        use_checkpoint: bool = True,
        dtype: torch.dtype = None,
        device: torch.device = None,
        operations: object = ops,
    ):
        super().__init__()
        if util.exists(context_dim) and not isinstance(context_dim, list):
            context_dim = [context_dim] * depth
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = operations.GroupNorm(
            num_groups=32,
            num_channels=in_channels,
            eps=1e-6,
            affine=True,
            dtype=dtype,
            device=device,
        )
        if not use_linear:
            self.proj_in = operations.Conv2d(
                in_channels,
                inner_dim,
                kernel_size=1,
                stride=1,
                padding=0,
                dtype=dtype,
                device=device,
            )
        else:
            self.proj_in = operations.Linear(
                in_channels, inner_dim, dtype=dtype, device=device
            )

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    n_heads,
                    d_head,
                    dropout=dropout,
                    context_dim=context_dim[d],
                    disable_self_attn=disable_self_attn,
                    checkpoint=use_checkpoint,
                    dtype=dtype,
                    device=device,
                    operations=operations,
                )
                for d in range(depth)
            ]
        )
        if not use_linear:
            self.proj_out = operations.Conv2d(
                inner_dim,
                in_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                dtype=dtype,
                device=device,
            )
        else:
            self.proj_out = operations.Linear(
                in_channels, inner_dim, dtype=dtype, device=device
            )
        self.use_linear = use_linear

    def forward(
        self,
        x: torch.Tensor,
        context: torch.Tensor = None,
        transformer_options: dict = {},
    ) -> torch.Tensor:
        """#### Forward pass of the Spatial Transformer.

        #### Args:
            - `x` (torch.Tensor): The input tensor.
            - `context` (torch.Tensor, optional): The context tensor. Defaults to None.
            - `transformer_options` (dict, optional): Additional transformer options. Defaults to {}.

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            context = [context] * len(self.transformer_blocks)
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, "b c h w -> b (h w) c").contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            transformer_options["block_index"] = i
            x = block(x, context=context[i], transformer_options=transformer_options)
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in

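A matching sketch for `SpatialTransformer`, which flattens a feature map to tokens, runs the blocks above, and adds the result back to the input (again with illustrative shapes and block tag):

import torch

st = SpatialTransformer(
    in_channels=320, n_heads=8, d_head=40, depth=1, context_dim=768, use_linear=True
)
feat = torch.randn(1, 320, 32, 32)             # (batch, channels, height, width)
ctx = torch.randn(1, 77, 768)
out = st(feat, context=ctx, transformer_options={"block": ("input", 1)})
print(out.shape)                               # torch.Size([1, 320, 32, 32])
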
def count_blocks(state_dict_keys: list, prefix_string: str) -> int:
    """#### Count the number of blocks in a state dictionary.

    #### Args:
        - `state_dict_keys` (list): The list of state dictionary keys.
        - `prefix_string` (str): The prefix string to match.

    #### Returns:
        - `int`: The number of blocks.
    """
    count = 0
    while True:
        c = False
        for k in state_dict_keys:
            if k.startswith(prefix_string.format(count)):
                c = True
                break
        if c is False:
            break
        count += 1
    return count


def calculate_transformer_depth(
    prefix: str, state_dict_keys: list, state_dict: dict
) -> tuple:
    """#### Calculate the depth of a transformer.

    #### Args:
        - `prefix` (str): The prefix string.
        - `state_dict_keys` (list): The list of state dictionary keys.
        - `state_dict` (dict): The state dictionary.

    #### Returns:
        - `tuple`: The transformer depth, context dimension, use of linear in transformer, and time stack.
    """
    context_dim = None
    use_linear_in_transformer = False

    transformer_prefix = prefix + "1.transformer_blocks."
    transformer_keys = sorted(
        list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys))
    )
    if len(transformer_keys) > 0:
        last_transformer_depth = count_blocks(
            state_dict_keys, transformer_prefix + "{}"
        )
        context_dim = state_dict[
            "{}0.attn2.to_k.weight".format(transformer_prefix)
        ].shape[1]
        use_linear_in_transformer = (
            len(state_dict["{}1.proj_in.weight".format(prefix)].shape) == 2
        )
        time_stack = (
            "{}1.time_stack.0.attn1.to_q.weight".format(prefix) in state_dict
            or "{}1.time_mix_blocks.0.attn1.to_q.weight".format(prefix) in state_dict
        )
        return (
            last_transformer_depth,
            context_dim,
            use_linear_in_transformer,
            time_stack,
        )
    return None
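For reference, `count_blocks` simply probes numbered prefixes until one is missing; with a couple of made-up keys:

keys = [
    "input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight",
    "input_blocks.1.1.transformer_blocks.1.attn1.to_q.weight",
]
print(count_blocks(keys, "input_blocks.1.1.transformer_blocks.{}."))  # 2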
modules/NeuralNetwork/unet.py
ADDED
@@ -0,0 +1,1132 @@
1 |
+
import logging
|
2 |
+
import math
|
3 |
+
from typing import Any, Dict, List, Optional
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch as th
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from modules.Utilities import util
|
9 |
+
from modules.AutoEncoders import ResBlock
|
10 |
+
from modules.NeuralNetwork import transformer
|
11 |
+
from modules.cond import cast
|
12 |
+
from modules.sample import sampling, sampling_util
|
13 |
+
|
14 |
+
UNET_MAP_ATTENTIONS = {
|
15 |
+
"proj_in.weight",
|
16 |
+
"proj_in.bias",
|
17 |
+
"proj_out.weight",
|
18 |
+
"proj_out.bias",
|
19 |
+
"norm.weight",
|
20 |
+
"norm.bias",
|
21 |
+
}
|
22 |
+
|
23 |
+
TRANSFORMER_BLOCKS = {
|
24 |
+
"norm1.weight",
|
25 |
+
"norm1.bias",
|
26 |
+
"norm2.weight",
|
27 |
+
"norm2.bias",
|
28 |
+
"norm3.weight",
|
29 |
+
"norm3.bias",
|
30 |
+
"attn1.to_q.weight",
|
31 |
+
"attn1.to_k.weight",
|
32 |
+
"attn1.to_v.weight",
|
33 |
+
"attn1.to_out.0.weight",
|
34 |
+
"attn1.to_out.0.bias",
|
35 |
+
"attn2.to_q.weight",
|
36 |
+
"attn2.to_k.weight",
|
37 |
+
"attn2.to_v.weight",
|
38 |
+
"attn2.to_out.0.weight",
|
39 |
+
"attn2.to_out.0.bias",
|
40 |
+
"ff.net.0.proj.weight",
|
41 |
+
"ff.net.0.proj.bias",
|
42 |
+
"ff.net.2.weight",
|
43 |
+
"ff.net.2.bias",
|
44 |
+
}
|
45 |
+
|
46 |
+
UNET_MAP_RESNET = {
|
47 |
+
"in_layers.2.weight": "conv1.weight",
|
48 |
+
"in_layers.2.bias": "conv1.bias",
|
49 |
+
"emb_layers.1.weight": "time_emb_proj.weight",
|
50 |
+
"emb_layers.1.bias": "time_emb_proj.bias",
|
51 |
+
"out_layers.3.weight": "conv2.weight",
|
52 |
+
"out_layers.3.bias": "conv2.bias",
|
53 |
+
"skip_connection.weight": "conv_shortcut.weight",
|
54 |
+
"skip_connection.bias": "conv_shortcut.bias",
|
55 |
+
"in_layers.0.weight": "norm1.weight",
|
56 |
+
"in_layers.0.bias": "norm1.bias",
|
57 |
+
"out_layers.0.weight": "norm2.weight",
|
58 |
+
"out_layers.0.bias": "norm2.bias",
|
59 |
+
}
|
60 |
+
|
61 |
+
UNET_MAP_BASIC = {
|
62 |
+
("label_emb.0.0.weight", "class_embedding.linear_1.weight"),
|
63 |
+
("label_emb.0.0.bias", "class_embedding.linear_1.bias"),
|
64 |
+
("label_emb.0.2.weight", "class_embedding.linear_2.weight"),
|
65 |
+
("label_emb.0.2.bias", "class_embedding.linear_2.bias"),
|
66 |
+
("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
|
67 |
+
("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
|
68 |
+
("label_emb.0.2.weight", "add_embedding.linear_2.weight"),
|
69 |
+
("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
|
70 |
+
("input_blocks.0.0.weight", "conv_in.weight"),
|
71 |
+
("input_blocks.0.0.bias", "conv_in.bias"),
|
72 |
+
("out.0.weight", "conv_norm_out.weight"),
|
73 |
+
("out.0.bias", "conv_norm_out.bias"),
|
74 |
+
("out.2.weight", "conv_out.weight"),
|
75 |
+
("out.2.bias", "conv_out.bias"),
|
76 |
+
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
77 |
+
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
78 |
+
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
79 |
+
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
80 |
+
}
|
81 |
+
|
82 |
+
# taken from https://github.com/TencentARC/T2I-Adapter
|
83 |
+
|
84 |
+
|
85 |
+
def unet_to_diffusers(unet_config: dict) -> dict:
|
86 |
+
"""#### Convert a UNet configuration to a diffusers configuration.
|
87 |
+
|
88 |
+
#### Args:
|
89 |
+
- `unet_config` (dict): The UNet configuration.
|
90 |
+
|
91 |
+
#### Returns:
|
92 |
+
- `dict`: The diffusers configuration.
|
93 |
+
"""
|
94 |
+
if "num_res_blocks" not in unet_config:
|
95 |
+
return {}
|
96 |
+
num_res_blocks = unet_config["num_res_blocks"]
|
97 |
+
channel_mult = unet_config["channel_mult"]
|
98 |
+
transformer_depth = unet_config["transformer_depth"][:]
|
99 |
+
transformer_depth_output = unet_config["transformer_depth_output"][:]
|
100 |
+
num_blocks = len(channel_mult)
|
101 |
+
|
102 |
+
transformers_mid = unet_config.get("transformer_depth_middle", None)
|
103 |
+
|
104 |
+
diffusers_unet_map = {}
|
105 |
+
for x in range(num_blocks):
|
106 |
+
n = 1 + (num_res_blocks[x] + 1) * x
|
107 |
+
for i in range(num_res_blocks[x]):
|
108 |
+
for b in UNET_MAP_RESNET:
|
109 |
+
diffusers_unet_map[
|
110 |
+
"down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])
|
111 |
+
] = "input_blocks.{}.0.{}".format(n, b)
|
112 |
+
num_transformers = transformer_depth.pop(0)
|
113 |
+
if num_transformers > 0:
|
114 |
+
for b in UNET_MAP_ATTENTIONS:
|
115 |
+
diffusers_unet_map[
|
116 |
+
"down_blocks.{}.attentions.{}.{}".format(x, i, b)
|
117 |
+
] = "input_blocks.{}.1.{}".format(n, b)
|
118 |
+
for t in range(num_transformers):
|
119 |
+
for b in TRANSFORMER_BLOCKS:
|
120 |
+
diffusers_unet_map[
|
121 |
+
"down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(
|
122 |
+
x, i, t, b
|
123 |
+
)
|
124 |
+
] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
|
125 |
+
n += 1
|
126 |
+
for k in ["weight", "bias"]:
|
127 |
+
diffusers_unet_map["down_blocks.{}.downsamplers.0.conv.{}".format(x, k)] = (
|
128 |
+
"input_blocks.{}.0.op.{}".format(n, k)
|
129 |
+
)
|
130 |
+
|
131 |
+
i = 0
|
132 |
+
for b in UNET_MAP_ATTENTIONS:
|
133 |
+
diffusers_unet_map["mid_block.attentions.{}.{}".format(i, b)] = (
|
134 |
+
"middle_block.1.{}".format(b)
|
135 |
+
)
|
136 |
+
for t in range(transformers_mid):
|
137 |
+
for b in TRANSFORMER_BLOCKS:
|
138 |
+
diffusers_unet_map[
|
139 |
+
"mid_block.attentions.{}.transformer_blocks.{}.{}".format(i, t, b)
|
140 |
+
] = "middle_block.1.transformer_blocks.{}.{}".format(t, b)
|
141 |
+
|
142 |
+
for i, n in enumerate([0, 2]):
|
143 |
+
for b in UNET_MAP_RESNET:
|
144 |
+
diffusers_unet_map[
|
145 |
+
"mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])
|
146 |
+
] = "middle_block.{}.{}".format(n, b)
|
147 |
+
|
148 |
+
num_res_blocks = list(reversed(num_res_blocks))
|
149 |
+
for x in range(num_blocks):
|
150 |
+
n = (num_res_blocks[x] + 1) * x
|
151 |
+
length = num_res_blocks[x] + 1
|
152 |
+
for i in range(length):
|
153 |
+
c = 0
|
154 |
+
for b in UNET_MAP_RESNET:
|
155 |
+
diffusers_unet_map[
|
156 |
+
"up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])
|
157 |
+
] = "output_blocks.{}.0.{}".format(n, b)
|
158 |
+
c += 1
|
159 |
+
num_transformers = transformer_depth_output.pop()
|
160 |
+
if num_transformers > 0:
|
161 |
+
c += 1
|
162 |
+
for b in UNET_MAP_ATTENTIONS:
|
163 |
+
diffusers_unet_map[
|
164 |
+
"up_blocks.{}.attentions.{}.{}".format(x, i, b)
|
165 |
+
] = "output_blocks.{}.1.{}".format(n, b)
|
166 |
+
for t in range(num_transformers):
|
167 |
+
for b in TRANSFORMER_BLOCKS:
|
168 |
+
diffusers_unet_map[
|
169 |
+
"up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(
|
170 |
+
x, i, t, b
|
171 |
+
)
|
172 |
+
] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(
|
173 |
+
n, t, b
|
174 |
+
)
|
175 |
+
if i == length - 1:
|
176 |
+
for k in ["weight", "bias"]:
|
177 |
+
diffusers_unet_map[
|
178 |
+
"up_blocks.{}.upsamplers.0.conv.{}".format(x, k)
|
179 |
+
] = "output_blocks.{}.{}.conv.{}".format(n, c, k)
|
180 |
+
n += 1
|
181 |
+
|
182 |
+
for k in UNET_MAP_BASIC:
|
183 |
+
diffusers_unet_map[k[1]] = k[0]
|
184 |
+
|
185 |
+
return diffusers_unet_map
|
186 |
+
|
187 |
+
|
188 |
+
def apply_control1(h: th.Tensor, control: any, name: str) -> th.Tensor:
|
189 |
+
"""#### Apply control to a tensor.
|
190 |
+
|
191 |
+
#### Args:
|
192 |
+
- `h` (torch.Tensor): The input tensor.
|
193 |
+
- `control` (any): The control to apply.
|
194 |
+
- `name` (str): The name of the control.
|
195 |
+
|
196 |
+
#### Returns:
|
197 |
+
- `torch.Tensor`: The controlled tensor.
|
198 |
+
"""
|
199 |
+
return h
|
200 |
+
|
201 |
+
|
202 |
+
oai_ops = cast.disable_weight_init
|
203 |
+
|
204 |
+
|
205 |
+
class UNetModel1(nn.Module):
|
206 |
+
"""#### UNet Model class."""
|
207 |
+
|
208 |
+
def __init__(
|
209 |
+
self,
|
210 |
+
image_size: int,
|
211 |
+
in_channels: int,
|
212 |
+
model_channels: int,
|
213 |
+
out_channels: int,
|
214 |
+
num_res_blocks: list,
|
215 |
+
dropout: float = 0,
|
216 |
+
channel_mult: tuple = (1, 2, 4, 8),
|
217 |
+
conv_resample: bool = True,
|
218 |
+
dims: int = 2,
|
219 |
+
num_classes: int = None,
|
220 |
+
use_checkpoint: bool = False,
|
221 |
+
dtype: th.dtype = th.float32,
|
222 |
+
num_heads: int = -1,
|
223 |
+
num_head_channels: int = -1,
|
224 |
+
num_heads_upsample: int = -1,
|
225 |
+
use_scale_shift_norm: bool = False,
|
226 |
+
resblock_updown: bool = False,
|
227 |
+
use_new_attention_order: bool = False,
|
228 |
+
use_spatial_transformer: bool = False, # custom transformer support
|
229 |
+
transformer_depth: int = 1, # custom transformer support
|
230 |
+
context_dim: int = None, # custom transformer support
|
231 |
+
n_embed: int = None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
232 |
+
legacy: bool = True,
|
233 |
+
disable_self_attentions: list = None,
|
234 |
+
num_attention_blocks: list = None,
|
235 |
+
disable_middle_self_attn: bool = False,
|
236 |
+
use_linear_in_transformer: bool = False,
|
237 |
+
adm_in_channels: int = None,
|
238 |
+
transformer_depth_middle: int = None,
|
239 |
+
transformer_depth_output: list = None,
|
240 |
+
use_temporal_resblock: bool = False,
|
241 |
+
use_temporal_attention: bool = False,
|
242 |
+
time_context_dim: int = None,
|
243 |
+
extra_ff_mix_layer: bool = False,
|
244 |
+
use_spatial_context: bool = False,
|
245 |
+
merge_strategy: any = None,
|
246 |
+
merge_factor: float = 0.0,
|
247 |
+
video_kernel_size: int = None,
|
248 |
+
disable_temporal_crossattention: bool = False,
|
249 |
+
max_ddpm_temb_period: int = 10000,
|
250 |
+
device: th.device = None,
|
251 |
+
operations: any = oai_ops,
|
252 |
+
):
|
253 |
+
"""#### Initialize the UNetModel1 class.
|
254 |
+
|
255 |
+
#### Args:
|
256 |
+
- `image_size` (int): The size of the input image.
|
257 |
+
- `in_channels` (int): The number of input channels.
|
258 |
+
- `model_channels` (int): The number of model channels.
|
259 |
+
- `out_channels` (int): The number of output channels.
|
260 |
+
- `num_res_blocks` (list): The number of residual blocks.
|
261 |
+
- `dropout` (float, optional): The dropout rate. Defaults to 0.
|
262 |
+
- `channel_mult` (tuple, optional): The channel multiplier. Defaults to (1, 2, 4, 8).
|
263 |
+
- `conv_resample` (bool, optional): Whether to use convolutional resampling. Defaults to True.
|
264 |
+
- `dims` (int, optional): The number of dimensions. Defaults to 2.
|
265 |
+
- `num_classes` (int, optional): The number of classes. Defaults to None.
|
266 |
+
- `use_checkpoint` (bool, optional): Whether to use checkpointing. Defaults to False.
|
267 |
+
- `dtype` (torch.dtype, optional): The data type. Defaults to torch.float32.
|
268 |
+
- `num_heads` (int, optional): The number of heads. Defaults to -1.
|
269 |
+
- `num_head_channels` (int, optional): The number of head channels. Defaults to -1.
|
270 |
+
- `num_heads_upsample` (int, optional): The number of heads for upsampling. Defaults to -1.
|
271 |
+
- `use_scale_shift_norm` (bool, optional): Whether to use scale-shift normalization. Defaults to False.
|
272 |
+
- `resblock_updown` (bool, optional): Whether to use residual blocks for up/down sampling. Defaults to False.
|
273 |
+
- `use_new_attention_order` (bool, optional): Whether to use a new attention order. Defaults to False.
|
274 |
+
- `use_spatial_transformer` (bool, optional): Whether to use a spatial transformer. Defaults to False.
|
275 |
+
- `transformer_depth` (int, optional): The depth of the transformer. Defaults to 1.
|
276 |
+
- `context_dim` (int, optional): The context dimension. Defaults to None.
|
277 |
+
- `n_embed` (int, optional): The number of embeddings. Defaults to None.
|
278 |
+
- `legacy` (bool, optional): Whether to use legacy mode. Defaults to True.
|
279 |
+
- `disable_self_attentions` (list, optional): The list of self-attentions to disable. Defaults to None.
|
280 |
+
- `num_attention_blocks` (list, optional): The number of attention blocks. Defaults to None.
|
281 |
+
- `disable_middle_self_attn` (bool, optional): Whether to disable middle self-attention. Defaults to False.
|
282 |
+
- `use_linear_in_transformer` (bool, optional): Whether to use linear in transformer. Defaults to False.
|
283 |
+
- `adm_in_channels` (int, optional): The number of ADM input channels. Defaults to None.
|
284 |
+
- `transformer_depth_middle` (int, optional): The depth of the middle transformer. Defaults to None.
|
285 |
+
- `transformer_depth_output` (list, optional): The depth of the output transformer. Defaults to None.
|
286 |
+
- `use_temporal_resblock` (bool, optional): Whether to use temporal residual blocks. Defaults to False.
|
287 |
+
- `use_temporal_attention` (bool, optional): Whether to use temporal attention. Defaults to False.
|
288 |
+
- `time_context_dim` (int, optional): The time context dimension. Defaults to None.
|
289 |
+
- `extra_ff_mix_layer` (bool, optional): Whether to use an extra feed-forward mix layer. Defaults to False.
|
290 |
+
- `use_spatial_context` (bool, optional): Whether to use spatial context. Defaults to False.
|
291 |
+
- `merge_strategy` (any, optional): The merge strategy. Defaults to None.
|
292 |
+
- `merge_factor` (float, optional): The merge factor. Defaults to 0.0.
|
293 |
+
- `video_kernel_size` (int, optional): The video kernel size. Defaults to None.
|
294 |
+
- `disable_temporal_crossattention` (bool, optional): Whether to disable temporal cross-attention. Defaults to False.
|
295 |
+
- `max_ddpm_temb_period` (int, optional): The maximum DDPM temporal embedding period. Defaults to 10000.
|
296 |
+
- `device` (torch.device, optional): The device to use. Defaults to None.
|
297 |
+
- `operations` (any, optional): The operations to use. Defaults to oai_ops.
|
298 |
+
"""
|
299 |
+
super().__init__()
|
300 |
+
|
301 |
+
if context_dim is not None:
|
302 |
+
self.context_dim = context_dim
|
303 |
+
|
304 |
+
if num_heads_upsample == -1:
|
305 |
+
num_heads_upsample = num_heads
|
306 |
+
if num_head_channels == -1:
|
307 |
+
assert num_heads != -1, "Either num_heads or num_head_channels has to be set"
|
308 |
+
|
309 |
+
self.in_channels = in_channels
|
310 |
+
self.model_channels = model_channels
|
311 |
+
self.out_channels = out_channels
|
312 |
+
self.num_res_blocks = num_res_blocks
|
313 |
+
|
314 |
+
transformer_depth = transformer_depth[:]
|
315 |
+
transformer_depth_output = transformer_depth_output[:]
|
316 |
+
|
317 |
+
self.dropout = dropout
|
318 |
+
self.channel_mult = channel_mult
|
319 |
+
self.conv_resample = conv_resample
|
320 |
+
self.num_classes = num_classes
|
321 |
+
self.use_checkpoint = use_checkpoint
|
322 |
+
self.dtype = dtype
|
323 |
+
self.num_heads = num_heads
|
324 |
+
self.num_head_channels = num_head_channels
|
325 |
+
self.num_heads_upsample = num_heads_upsample
|
326 |
+
self.use_temporal_resblocks = use_temporal_resblock
|
327 |
+
self.predict_codebook_ids = n_embed is not None
|
328 |
+
|
329 |
+
self.default_num_video_frames = None
|
330 |
+
|
331 |
+
time_embed_dim = model_channels * 4
|
332 |
+
self.time_embed = nn.Sequential(
|
333 |
+
operations.Linear(
|
334 |
+
model_channels, time_embed_dim, dtype=self.dtype, device=device
|
335 |
+
),
|
336 |
+
nn.SiLU(),
|
337 |
+
operations.Linear(
|
338 |
+
time_embed_dim, time_embed_dim, dtype=self.dtype, device=device
|
339 |
+
),
|
340 |
+
)
|
341 |
+
|
342 |
+
self.input_blocks = nn.ModuleList(
|
343 |
+
[
|
344 |
+
sampling.TimestepEmbedSequential1(
|
345 |
+
operations.conv_nd(
|
346 |
+
dims,
|
347 |
+
in_channels,
|
348 |
+
model_channels,
|
349 |
+
3,
|
350 |
+
padding=1,
|
351 |
+
dtype=self.dtype,
|
352 |
+
device=device,
|
353 |
+
)
|
354 |
+
)
|
355 |
+
]
|
356 |
+
)
|
357 |
+
self._feature_size = model_channels
|
358 |
+
input_block_chans = [model_channels]
|
359 |
+
ch = model_channels
|
360 |
+
ds = 1
|
361 |
+
|
362 |
+
def get_attention_layer(
|
363 |
+
ch: int,
|
364 |
+
num_heads: int,
|
365 |
+
dim_head: int,
|
366 |
+
depth: int = 1,
|
367 |
+
context_dim: int = None,
|
368 |
+
use_checkpoint: bool = False,
|
369 |
+
disable_self_attn: bool = False,
|
370 |
+
) -> transformer.SpatialTransformer:
|
371 |
+
"""#### Get an attention layer.
|
372 |
+
|
373 |
+
#### Args:
|
374 |
+
- `ch` (int): The number of channels.
|
375 |
+
- `num_heads` (int): The number of heads.
|
376 |
+
- `dim_head` (int): The dimension of each head.
|
377 |
+
- `depth` (int, optional): The depth of the transformer. Defaults to 1.
|
378 |
+
- `context_dim` (int, optional): The context dimension. Defaults to None.
|
379 |
+
- `use_checkpoint` (bool, optional): Whether to use checkpointing. Defaults to False.
|
380 |
+
- `disable_self_attn` (bool, optional): Whether to disable self-attention. Defaults to False.
|
381 |
+
|
382 |
+
#### Returns:
|
383 |
+
- `transformer.SpatialTransformer`: The attention layer.
|
384 |
+
"""
|
385 |
+
return transformer.SpatialTransformer(
|
386 |
+
ch,
|
387 |
+
num_heads,
|
388 |
+
dim_head,
|
389 |
+
depth=depth,
|
390 |
+
context_dim=context_dim,
|
391 |
+
disable_self_attn=disable_self_attn,
|
392 |
+
use_linear=use_linear_in_transformer,
|
393 |
+
use_checkpoint=use_checkpoint,
|
394 |
+
dtype=self.dtype,
|
395 |
+
device=device,
|
396 |
+
operations=operations,
|
397 |
+
)
|
398 |
+
|
399 |
+
def get_resblock(
|
400 |
+
merge_factor: float,
|
401 |
+
merge_strategy: any,
|
402 |
+
video_kernel_size: int,
|
403 |
+
ch: int,
|
404 |
+
time_embed_dim: int,
|
405 |
+
dropout: float,
|
406 |
+
out_channels: int,
|
407 |
+
dims: int,
|
408 |
+
use_checkpoint: bool,
|
409 |
+
use_scale_shift_norm: bool,
|
410 |
+
down: bool = False,
|
411 |
+
up: bool = False,
|
412 |
+
dtype: th.dtype = None,
|
413 |
+
device: th.device = None,
|
414 |
+
operations: any = oai_ops,
|
415 |
+
) -> ResBlock.ResBlock1:
|
416 |
+
"""#### Get a residual block.
|
417 |
+
|
418 |
+
#### Args:
|
419 |
+
- `merge_factor` (float): The merge factor.
|
420 |
+
- `merge_strategy` (any): The merge strategy.
|
421 |
+
- `video_kernel_size` (int): The video kernel size.
|
422 |
+
- `ch` (int): The number of channels.
|
423 |
+
- `time_embed_dim` (int): The time embedding dimension.
|
424 |
+
- `dropout` (float): The dropout rate.
|
425 |
+
- `out_channels` (int): The number of output channels.
|
426 |
+
- `dims` (int): The number of dimensions.
|
427 |
+
- `use_checkpoint` (bool): Whether to use checkpointing.
|
428 |
+
- `use_scale_shift_norm` (bool): Whether to use scale-shift normalization.
|
429 |
+
- `down` (bool, optional): Whether to use downsampling. Defaults to False.
|
430 |
+
- `up` (bool, optional): Whether to use upsampling. Defaults to False.
|
431 |
+
- `dtype` (torch.dtype, optional): The data type. Defaults to None.
|
432 |
+
- `device` (torch.device, optional): The device. Defaults to None.
|
433 |
+
- `operations` (any, optional): The operations to use. Defaults to oai_ops.
|
434 |
+
|
435 |
+
#### Returns:
|
436 |
+
- `ResBlock.ResBlock1`: The residual block.
|
437 |
+
"""
|
438 |
+
return ResBlock.ResBlock1(
|
439 |
+
channels=ch,
|
440 |
+
emb_channels=time_embed_dim,
|
441 |
+
dropout=dropout,
|
442 |
+
out_channels=out_channels,
|
443 |
+
use_checkpoint=use_checkpoint,
|
444 |
+
dims=dims,
|
445 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
446 |
+
down=down,
|
447 |
+
up=up,
|
448 |
+
dtype=dtype,
|
449 |
+
device=device,
|
450 |
+
operations=operations,
|
451 |
+
)
|
452 |
+
|
453 |
+
self.double_blocks = nn.ModuleList()
|
454 |
+
for level, mult in enumerate(channel_mult):
|
455 |
+
for nr in range(self.num_res_blocks[level]):
|
456 |
+
layers = [
|
457 |
+
get_resblock(
|
458 |
+
merge_factor=merge_factor,
|
459 |
+
merge_strategy=merge_strategy,
|
460 |
+
video_kernel_size=video_kernel_size,
|
461 |
+
ch=ch,
|
462 |
+
time_embed_dim=time_embed_dim,
|
463 |
+
dropout=dropout,
|
464 |
+
out_channels=mult * model_channels,
|
465 |
+
dims=dims,
|
466 |
+
use_checkpoint=use_checkpoint,
|
467 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
468 |
+
dtype=self.dtype,
|
469 |
+
device=device,
|
470 |
+
operations=operations,
|
471 |
+
)
|
472 |
+
]
|
473 |
+
ch = mult * model_channels
|
474 |
+
num_transformers = transformer_depth.pop(0)
|
475 |
+
if num_transformers > 0:
|
476 |
+
dim_head = ch // num_heads
|
477 |
+
disabled_sa = False
|
478 |
+
|
479 |
+
if (
|
480 |
+
not util.exists(num_attention_blocks)
|
481 |
+
or nr < num_attention_blocks[level]
|
482 |
+
):
|
483 |
+
layers.append(
|
484 |
+
get_attention_layer(
|
485 |
+
ch,
|
486 |
+
num_heads,
|
487 |
+
dim_head,
|
488 |
+
depth=num_transformers,
|
489 |
+
context_dim=context_dim,
|
490 |
+
disable_self_attn=disabled_sa,
|
491 |
+
use_checkpoint=use_checkpoint,
|
492 |
+
)
|
493 |
+
)
|
494 |
+
self.input_blocks.append(sampling.TimestepEmbedSequential1(*layers))
|
495 |
+
self._feature_size += ch
|
496 |
+
input_block_chans.append(ch)
|
497 |
+
if level != len(channel_mult) - 1:
|
498 |
+
out_ch = ch
|
499 |
+
self.input_blocks.append(
|
500 |
+
sampling.TimestepEmbedSequential1(
|
501 |
+
get_resblock(
|
502 |
+
merge_factor=merge_factor,
|
503 |
+
merge_strategy=merge_strategy,
|
504 |
+
video_kernel_size=video_kernel_size,
|
505 |
+
ch=ch,
|
506 |
+
time_embed_dim=time_embed_dim,
|
507 |
+
dropout=dropout,
|
508 |
+
out_channels=out_ch,
|
509 |
+
dims=dims,
|
510 |
+
use_checkpoint=use_checkpoint,
|
511 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
512 |
+
down=True,
|
513 |
+
dtype=self.dtype,
|
514 |
+
device=device,
|
515 |
+
operations=operations,
|
516 |
+
)
|
517 |
+
if resblock_updown
|
518 |
+
else ResBlock.Downsample1(
|
519 |
+
ch,
|
520 |
+
conv_resample,
|
521 |
+
dims=dims,
|
522 |
+
out_channels=out_ch,
|
523 |
+
dtype=self.dtype,
|
524 |
+
device=device,
|
525 |
+
operations=operations,
|
526 |
+
)
|
527 |
+
)
|
528 |
+
)
|
529 |
+
ch = out_ch
|
530 |
+
input_block_chans.append(ch)
|
531 |
+
ds *= 2
|
532 |
+
self._feature_size += ch
|
533 |
+
|
534 |
+
dim_head = ch // num_heads
|
535 |
+
mid_block = [
|
536 |
+
get_resblock(
|
537 |
+
merge_factor=merge_factor,
|
538 |
+
merge_strategy=merge_strategy,
|
539 |
+
video_kernel_size=video_kernel_size,
|
540 |
+
ch=ch,
|
541 |
+
time_embed_dim=time_embed_dim,
|
542 |
+
dropout=dropout,
|
543 |
+
out_channels=None,
|
544 |
+
dims=dims,
|
545 |
+
use_checkpoint=use_checkpoint,
|
546 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
547 |
+
dtype=self.dtype,
|
548 |
+
device=device,
|
549 |
+
operations=operations,
|
550 |
+
)
|
551 |
+
]
|
552 |
+
|
553 |
+
self.middle_block = None
|
554 |
+
if transformer_depth_middle >= -1:
|
555 |
+
if transformer_depth_middle >= 0:
|
556 |
+
mid_block += [
|
557 |
+
get_attention_layer( # always uses a self-attn
|
558 |
+
ch,
|
559 |
+
num_heads,
|
560 |
+
dim_head,
|
561 |
+
depth=transformer_depth_middle,
|
562 |
+
context_dim=context_dim,
|
563 |
+
disable_self_attn=disable_middle_self_attn,
|
564 |
+
use_checkpoint=use_checkpoint,
|
565 |
+
),
|
566 |
+
get_resblock(
|
567 |
+
merge_factor=merge_factor,
|
568 |
+
merge_strategy=merge_strategy,
|
569 |
+
video_kernel_size=video_kernel_size,
|
570 |
+
ch=ch,
|
571 |
+
time_embed_dim=time_embed_dim,
|
572 |
+
dropout=dropout,
|
573 |
+
out_channels=None,
|
574 |
+
dims=dims,
|
575 |
+
use_checkpoint=use_checkpoint,
|
576 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
577 |
+
dtype=self.dtype,
|
578 |
+
device=device,
|
579 |
+
operations=operations,
|
580 |
+
),
|
581 |
+
]
|
582 |
+
self.middle_block = sampling.TimestepEmbedSequential1(*mid_block)
|
583 |
+
self._feature_size += ch
|
584 |
+
|
585 |
+
self.output_blocks = nn.ModuleList([])
|
586 |
+
for level, mult in list(enumerate(channel_mult))[::-1]:
|
587 |
+
for i in range(self.num_res_blocks[level] + 1):
|
588 |
+
ich = input_block_chans.pop()
|
589 |
+
layers = [
|
590 |
+
get_resblock(
|
591 |
+
merge_factor=merge_factor,
|
592 |
+
merge_strategy=merge_strategy,
|
593 |
+
video_kernel_size=video_kernel_size,
|
594 |
+
ch=ch + ich,
|
595 |
+
time_embed_dim=time_embed_dim,
|
596 |
+
dropout=dropout,
|
597 |
+
out_channels=model_channels * mult,
|
598 |
+
dims=dims,
|
599 |
+
use_checkpoint=use_checkpoint,
|
600 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
601 |
+
dtype=self.dtype,
|
602 |
+
device=device,
|
603 |
+
operations=operations,
|
604 |
+
)
|
605 |
+
]
|
606 |
+
ch = model_channels * mult
|
607 |
+
num_transformers = transformer_depth_output.pop()
|
608 |
+
if num_transformers > 0:
|
609 |
+
dim_head = ch // num_heads
|
610 |
+
disabled_sa = False
|
611 |
+
|
612 |
+
if (
|
613 |
+
not util.exists(num_attention_blocks)
|
614 |
+
or i < num_attention_blocks[level]
|
615 |
+
):
|
616 |
+
layers.append(
|
617 |
+
get_attention_layer(
|
618 |
+
ch,
|
619 |
+
num_heads,
|
620 |
+
dim_head,
|
621 |
+
depth=num_transformers,
|
622 |
+
context_dim=context_dim,
|
623 |
+
disable_self_attn=disabled_sa,
|
624 |
+
use_checkpoint=use_checkpoint,
|
625 |
+
)
|
626 |
+
)
|
627 |
+
if level and i == self.num_res_blocks[level]:
|
628 |
+
out_ch = ch
|
629 |
+
layers.append(
|
630 |
+
get_resblock(
|
631 |
+
merge_factor=merge_factor,
|
632 |
+
merge_strategy=merge_strategy,
|
633 |
+
video_kernel_size=video_kernel_size,
|
634 |
+
ch=ch,
|
635 |
+
time_embed_dim=time_embed_dim,
|
636 |
+
dropout=dropout,
|
637 |
+
out_channels=out_ch,
|
638 |
+
dims=dims,
|
639 |
+
use_checkpoint=use_checkpoint,
|
640 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
641 |
+
up=True,
|
642 |
+
dtype=self.dtype,
|
643 |
+
device=device,
|
644 |
+
operations=operations,
|
645 |
+
)
|
646 |
+
if resblock_updown
|
647 |
+
else ResBlock.Upsample1(
|
648 |
+
ch,
|
649 |
+
conv_resample,
|
650 |
+
dims=dims,
|
651 |
+
out_channels=out_ch,
|
652 |
+
dtype=self.dtype,
|
653 |
+
device=device,
|
654 |
+
operations=operations,
|
655 |
+
)
|
656 |
+
)
|
657 |
+
ds //= 2
|
658 |
+
self.output_blocks.append(sampling.TimestepEmbedSequential1(*layers))
|
659 |
+
self._feature_size += ch
|
660 |
+
|
661 |
+
self.out = nn.Sequential(
|
662 |
+
operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
|
663 |
+
nn.SiLU(),
|
664 |
+
util.zero_module(
|
665 |
+
operations.conv_nd(
|
666 |
+
dims,
|
667 |
+
model_channels,
|
668 |
+
out_channels,
|
669 |
+
3,
|
670 |
+
padding=1,
|
671 |
+
dtype=self.dtype,
|
672 |
+
device=device,
|
673 |
+
)
|
674 |
+
),
|
675 |
+
)
|
676 |
+
|
677 |
+
def forward(
|
678 |
+
self,
|
679 |
+
x: torch.Tensor,
|
680 |
+
timesteps: Optional[torch.Tensor] = None,
|
681 |
+
context: Optional[torch.Tensor] = None,
|
682 |
+
y: Optional[torch.Tensor] = None,
|
683 |
+
control: Optional[torch.Tensor] = None,
|
684 |
+
transformer_options: Dict[str, Any] = {},
|
685 |
+
**kwargs: Any,
|
686 |
+
) -> torch.Tensor:
|
687 |
+
"""#### Forward pass of the UNet model.
|
688 |
+
|
689 |
+
#### Args:
|
690 |
+
- `x` (torch.Tensor): The input tensor.
|
691 |
+
- `timesteps` (Optional[torch.Tensor], optional): The timesteps tensor. Defaults to None.
|
692 |
+
- `context` (Optional[torch.Tensor], optional): The context tensor. Defaults to None.
|
693 |
+
- `y` (Optional[torch.Tensor], optional): The class labels tensor. Defaults to None.
|
694 |
+
- `control` (Optional[torch.Tensor], optional): The control tensor. Defaults to None.
|
695 |
+
- `transformer_options` (Dict[str, Any], optional): Options for the transformer. Defaults to {}.
|
696 |
+
- `**kwargs` (Any): Additional keyword arguments.
|
697 |
+
|
698 |
+
#### Returns:
|
699 |
+
- `torch.Tensor`: The output tensor.
|
700 |
+
"""
|
701 |
+
transformer_options["original_shape"] = list(x.shape)
|
702 |
+
transformer_options["transformer_index"] = 0
|
703 |
+
transformer_patches = transformer_options.get("patches", {})
|
704 |
+
|
705 |
+
num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
|
706 |
+
image_only_indicator = kwargs.get("image_only_indicator", None)
|
707 |
+
time_context = kwargs.get("time_context", None)
|
708 |
+
|
709 |
+
assert (y is not None) == (
|
710 |
+
self.num_classes is not None
|
711 |
+
), "must specify y if and only if the model is class-conditional"
|
712 |
+
hs = []
|
713 |
+
t_emb = sampling_util.timestep_embedding(
|
714 |
+
timesteps, self.model_channels
|
715 |
+
).to(x.dtype)
|
716 |
+
emb = self.time_embed(t_emb)
|
717 |
+
h = x
|
718 |
+
for id, module in enumerate(self.input_blocks):
|
719 |
+
transformer_options["block"] = ("input", id)
|
720 |
+
h = ResBlock.forward_timestep_embed1(
|
721 |
+
module,
|
722 |
+
h,
|
723 |
+
emb,
|
724 |
+
context,
|
725 |
+
transformer_options,
|
726 |
+
time_context=time_context,
|
727 |
+
num_video_frames=num_video_frames,
|
728 |
+
image_only_indicator=image_only_indicator,
|
729 |
+
)
|
730 |
+
h = apply_control1(h, control, "input")
|
731 |
+
hs.append(h)
|
732 |
+
|
733 |
+
transformer_options["block"] = ("middle", 0)
|
734 |
+
if self.middle_block is not None:
|
735 |
+
h = ResBlock.forward_timestep_embed1(
|
736 |
+
self.middle_block,
|
737 |
+
h,
|
738 |
+
emb,
|
739 |
+
context,
|
740 |
+
transformer_options,
|
741 |
+
time_context=time_context,
|
742 |
+
num_video_frames=num_video_frames,
|
743 |
+
image_only_indicator=image_only_indicator,
|
744 |
+
)
|
745 |
+
h = apply_control1(h, control, "middle")
|
746 |
+
|
747 |
+
for id, module in enumerate(self.output_blocks):
|
748 |
+
transformer_options["block"] = ("output", id)
|
749 |
+
hsp = hs.pop()
|
750 |
+
hsp = apply_control1(hsp, control, "output")
|
751 |
+
|
752 |
+
h = torch.cat([h, hsp], dim=1)
|
753 |
+
del hsp
|
754 |
+
if len(hs) > 0:
|
755 |
+
output_shape = hs[-1].shape
|
756 |
+
else:
|
757 |
+
output_shape = None
|
758 |
+
h = ResBlock.forward_timestep_embed1(
|
759 |
+
module,
|
760 |
+
h,
|
761 |
+
emb,
|
762 |
+
context,
|
763 |
+
transformer_options,
|
764 |
+
output_shape,
|
765 |
+
time_context=time_context,
|
766 |
+
num_video_frames=num_video_frames,
|
767 |
+
image_only_indicator=image_only_indicator,
|
768 |
+
)
|
769 |
+
h = h.type(x.dtype)
|
770 |
+
return self.out(h)
|
771 |
+
|
772 |
+
|
773 |
+
def detect_unet_config(state_dict: Dict[str, torch.Tensor], key_prefix: str) -> Dict[str, Any]:
|
774 |
+
"""#### Detect the UNet configuration from a state dictionary.
|
775 |
+
|
776 |
+
#### Args:
|
777 |
+
- `state_dict` (Dict[str, torch.Tensor]): The state dictionary.
|
778 |
+
- `key_prefix` (str): The key prefix.
|
779 |
+
|
780 |
+
#### Returns:
|
781 |
+
- `Dict[str, Any]`: The detected UNet configuration.
|
782 |
+
"""
|
783 |
+
state_dict_keys = list(state_dict.keys())
|
784 |
+
|
785 |
+
if (
|
786 |
+
"{}joint_blocks.0.context_block.attn.qkv.weight".format(key_prefix)
|
787 |
+
in state_dict_keys
|
788 |
+
): # mmdit model
|
789 |
+
unet_config = {}
|
790 |
+
unet_config["in_channels"] = state_dict[
|
791 |
+
"{}x_embedder.proj.weight".format(key_prefix)
|
792 |
+
].shape[1]
|
793 |
+
patch_size = state_dict["{}x_embedder.proj.weight".format(key_prefix)].shape[2]
|
794 |
+
unet_config["patch_size"] = patch_size
|
795 |
+
final_layer = "{}final_layer.linear.weight".format(key_prefix)
|
796 |
+
if final_layer in state_dict:
|
797 |
+
unet_config["out_channels"] = state_dict[final_layer].shape[0] // (
|
798 |
+
patch_size * patch_size
|
799 |
+
)
|
800 |
+
|
801 |
+
unet_config["depth"] = (
|
802 |
+
state_dict["{}x_embedder.proj.weight".format(key_prefix)].shape[0] // 64
|
803 |
+
)
|
804 |
+
unet_config["input_size"] = None
|
805 |
+
y_key = "{}y_embedder.mlp.0.weight".format(key_prefix)
|
806 |
+
if y_key in state_dict_keys:
|
807 |
+
unet_config["adm_in_channels"] = state_dict[y_key].shape[1]
|
808 |
+
|
809 |
+
context_key = "{}context_embedder.weight".format(key_prefix)
|
810 |
+
if context_key in state_dict_keys:
|
811 |
+
in_features = state_dict[context_key].shape[1]
|
812 |
+
out_features = state_dict[context_key].shape[0]
|
813 |
+
unet_config["context_embedder_config"] = {
|
814 |
+
"target": "torch.nn.Linear",
|
815 |
+
"params": {"in_features": in_features, "out_features": out_features},
|
816 |
+
}
|
817 |
+
num_patches_key = "{}pos_embed".format(key_prefix)
|
818 |
+
if num_patches_key in state_dict_keys:
|
819 |
+
num_patches = state_dict[num_patches_key].shape[1]
|
820 |
+
unet_config["num_patches"] = num_patches
|
821 |
+
unet_config["pos_embed_max_size"] = round(math.sqrt(num_patches))
|
822 |
+
|
823 |
+
rms_qk = "{}joint_blocks.0.context_block.attn.ln_q.weight".format(key_prefix)
|
824 |
+
if rms_qk in state_dict_keys:
|
825 |
+
unet_config["qk_norm"] = "rms"
|
826 |
+
|
827 |
+
unet_config["pos_embed_scaling_factor"] = None # unused for inference
|
828 |
+
context_processor = "{}context_processor.layers.0.attn.qkv.weight".format(
|
829 |
+
key_prefix
|
830 |
+
)
|
831 |
+
if context_processor in state_dict_keys:
|
832 |
+
unet_config["context_processor_layers"] = transformer.count_blocks(
|
833 |
+
state_dict_keys,
|
834 |
+
"{}context_processor.layers.".format(key_prefix) + "{}.",
|
835 |
+
)
|
836 |
+
return unet_config
|
837 |
+
|
838 |
+
if "{}clf.1.weight".format(key_prefix) in state_dict_keys: # stable cascade
|
839 |
+
unet_config = {}
|
840 |
+
text_mapper_name = "{}clip_txt_mapper.weight".format(key_prefix)
|
841 |
+
if text_mapper_name in state_dict_keys:
|
842 |
+
unet_config["stable_cascade_stage"] = "c"
|
843 |
+
w = state_dict[text_mapper_name]
|
844 |
+
if w.shape[0] == 1536: # stage c lite
|
845 |
+
unet_config["c_cond"] = 1536
|
846 |
+
unet_config["c_hidden"] = [1536, 1536]
|
847 |
+
unet_config["nhead"] = [24, 24]
|
848 |
+
unet_config["blocks"] = [[4, 12], [12, 4]]
|
849 |
+
elif w.shape[0] == 2048: # stage c full
|
850 |
+
unet_config["c_cond"] = 2048
|
851 |
+
elif "{}clip_mapper.weight".format(key_prefix) in state_dict_keys:
|
852 |
+
unet_config["stable_cascade_stage"] = "b"
|
853 |
+
w = state_dict["{}down_blocks.1.0.channelwise.0.weight".format(key_prefix)]
|
854 |
+
if w.shape[-1] == 640:
|
855 |
+
unet_config["c_hidden"] = [320, 640, 1280, 1280]
|
856 |
+
unet_config["nhead"] = [-1, -1, 20, 20]
|
857 |
+
unet_config["blocks"] = [[2, 6, 28, 6], [6, 28, 6, 2]]
|
858 |
+
unet_config["block_repeat"] = [[1, 1, 1, 1], [3, 3, 2, 2]]
|
859 |
+
elif w.shape[-1] == 576: # stage b lite
|
860 |
+
unet_config["c_hidden"] = [320, 576, 1152, 1152]
|
861 |
+
unet_config["nhead"] = [-1, 9, 18, 18]
|
862 |
+
unet_config["blocks"] = [[2, 4, 14, 4], [4, 14, 4, 2]]
|
863 |
+
unet_config["block_repeat"] = [[1, 1, 1, 1], [2, 2, 2, 2]]
|
864 |
+
return unet_config
|
865 |
+
|
866 |
+
if (
|
867 |
+
"{}transformer.rotary_pos_emb.inv_freq".format(key_prefix) in state_dict_keys
|
868 |
+
): # stable audio dit
|
869 |
+
unet_config = {}
|
870 |
+
unet_config["audio_model"] = "dit1.0"
|
871 |
+
return unet_config
|
872 |
+
|
873 |
+
if (
|
874 |
+
"{}double_layers.0.attn.w1q.weight".format(key_prefix) in state_dict_keys
|
875 |
+
): # aura flow dit
|
876 |
+
unet_config = {}
|
877 |
+
unet_config["max_seq"] = state_dict[
|
878 |
+
"{}positional_encoding".format(key_prefix)
|
879 |
+
].shape[1]
|
880 |
+
unet_config["cond_seq_dim"] = state_dict[
|
881 |
+
"{}cond_seq_linear.weight".format(key_prefix)
|
882 |
+
].shape[1]
|
883 |
+
double_layers = transformer.count_blocks(
|
884 |
+
state_dict_keys, "{}double_layers.".format(key_prefix) + "{}."
|
885 |
+
)
|
886 |
+
single_layers = transformer.count_blocks(
|
887 |
+
state_dict_keys, "{}single_layers.".format(key_prefix) + "{}."
|
888 |
+
)
|
889 |
+
unet_config["n_double_layers"] = double_layers
|
890 |
+
unet_config["n_layers"] = double_layers + single_layers
|
891 |
+
return unet_config
|
892 |
+
|
893 |
+
if "{}mlp_t5.0.weight".format(key_prefix) in state_dict_keys: # Hunyuan DiT
|
894 |
+
unet_config = {}
|
895 |
+
unet_config["image_model"] = "hydit"
|
896 |
+
unet_config["depth"] = transformer.count_blocks(
|
897 |
+
state_dict_keys, "{}blocks.".format(key_prefix) + "{}."
|
898 |
+
)
|
899 |
+
unet_config["hidden_size"] = state_dict[
|
900 |
+
"{}x_embedder.proj.weight".format(key_prefix)
|
901 |
+
].shape[0]
|
902 |
+
if unet_config["hidden_size"] == 1408 and unet_config["depth"] == 40: # DiT-g/2
|
903 |
+
unet_config["mlp_ratio"] = 4.3637
|
904 |
+
if state_dict["{}extra_embedder.0.weight".format(key_prefix)].shape[1] == 3968:
|
905 |
+
unet_config["size_cond"] = True
|
906 |
+
unet_config["use_style_cond"] = True
|
907 |
+
unet_config["image_model"] = "hydit1"
|
908 |
+
return unet_config
|
909 |
+
|
910 |
+
if (
|
911 |
+
"{}double_blocks.0.img_attn.norm.key_norm.scale".format(key_prefix)
|
912 |
+
in state_dict_keys
|
913 |
+
): # Flux
|
914 |
+
dit_config = {}
|
915 |
+
dit_config["image_model"] = "flux"
|
916 |
+
dit_config["in_channels"] = 16
|
917 |
+
dit_config["vec_in_dim"] = 768
|
918 |
+
dit_config["context_in_dim"] = 4096
|
919 |
+
dit_config["hidden_size"] = 3072
|
920 |
+
dit_config["mlp_ratio"] = 4.0
|
921 |
+
dit_config["num_heads"] = 24
|
922 |
+
dit_config["depth"] = transformer.count_blocks(
|
923 |
+
state_dict_keys, "{}double_blocks.".format(key_prefix) + "{}."
|
924 |
+
)
|
925 |
+
dit_config["depth_single_blocks"] = transformer.count_blocks(
|
926 |
+
state_dict_keys, "{}single_blocks.".format(key_prefix) + "{}."
|
927 |
+
)
|
928 |
+
dit_config["axes_dim"] = [16, 56, 56]
|
929 |
+
dit_config["theta"] = 10000
|
930 |
+
dit_config["qkv_bias"] = True
|
931 |
+
dit_config["guidance_embed"] = (
|
932 |
+
"{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
933 |
+
)
|
934 |
+
return dit_config
|
935 |
+
|
936 |
+
if "{}input_blocks.0.0.weight".format(key_prefix) not in state_dict_keys:
|
937 |
+
return None
|
938 |
+
|
939 |
+
unet_config = {
|
940 |
+
"use_checkpoint": False,
|
941 |
+
"image_size": 32,
|
942 |
+
"use_spatial_transformer": True,
|
943 |
+
"legacy": False,
|
944 |
+
}
|
945 |
+
|
946 |
+
y_input = "{}label_emb.0.0.weight".format(key_prefix)
|
947 |
+
if y_input in state_dict_keys:
|
948 |
+
unet_config["num_classes"] = "sequential"
|
949 |
+
unet_config["adm_in_channels"] = state_dict[y_input].shape[1]
|
950 |
+
else:
|
951 |
+
unet_config["adm_in_channels"] = None
|
952 |
+
|
953 |
+
model_channels = state_dict["{}input_blocks.0.0.weight".format(key_prefix)].shape[0]
|
954 |
+
in_channels = state_dict["{}input_blocks.0.0.weight".format(key_prefix)].shape[1]
|
955 |
+
|
956 |
+
out_key = "{}out.2.weight".format(key_prefix)
|
957 |
+
if out_key in state_dict:
|
958 |
+
out_channels = state_dict[out_key].shape[0]
|
959 |
+
else:
|
960 |
+
out_channels = 4
|
961 |
+
|
962 |
+
num_res_blocks = []
|
963 |
+
channel_mult = []
|
964 |
+
transformer_depth = []
|
965 |
+
transformer_depth_output = []
|
966 |
+
context_dim = None
|
967 |
+
use_linear_in_transformer = False
|
968 |
+
|
969 |
+
video_model = False
|
970 |
+
video_model_cross = False
|
971 |
+
|
972 |
+
current_res = 1
|
973 |
+
count = 0
|
974 |
+
|
975 |
+
last_res_blocks = 0
|
976 |
+
last_channel_mult = 0
|
977 |
+
|
978 |
+
input_block_count = transformer.count_blocks(
|
979 |
+
state_dict_keys, "{}input_blocks".format(key_prefix) + ".{}."
|
980 |
+
)
|
981 |
+
for count in range(input_block_count):
|
982 |
+
prefix = "{}input_blocks.{}.".format(key_prefix, count)
|
983 |
+
prefix_output = "{}output_blocks.{}.".format(
|
984 |
+
key_prefix, input_block_count - count - 1
|
985 |
+
)
|
986 |
+
|
987 |
+
block_keys = sorted(
|
988 |
+
list(filter(lambda a: a.startswith(prefix), state_dict_keys))
|
989 |
+
)
|
990 |
+
if len(block_keys) == 0:
|
991 |
+
break
|
992 |
+
|
993 |
+
block_keys_output = sorted(
|
994 |
+
list(filter(lambda a: a.startswith(prefix_output), state_dict_keys))
|
995 |
+
)
|
996 |
+
|
997 |
+
if "{}0.op.weight".format(prefix) in block_keys: # new layer
|
998 |
+
num_res_blocks.append(last_res_blocks)
|
999 |
+
channel_mult.append(last_channel_mult)
|
1000 |
+
|
1001 |
+
current_res *= 2
|
1002 |
+
last_res_blocks = 0
|
1003 |
+
last_channel_mult = 0
|
1004 |
+
out = transformer.calculate_transformer_depth(
|
1005 |
+
prefix_output, state_dict_keys, state_dict
|
1006 |
+
)
|
1007 |
+
if out is not None:
|
1008 |
+
transformer_depth_output.append(out[0])
|
1009 |
+
else:
|
1010 |
+
transformer_depth_output.append(0)
|
1011 |
+
else:
|
1012 |
+
res_block_prefix = "{}0.in_layers.0.weight".format(prefix)
|
1013 |
+
if res_block_prefix in block_keys:
|
1014 |
+
last_res_blocks += 1
|
1015 |
+
last_channel_mult = (
|
1016 |
+
state_dict["{}0.out_layers.3.weight".format(prefix)].shape[0]
|
1017 |
+
// model_channels
|
1018 |
+
)
|
1019 |
+
|
1020 |
+
out = transformer.calculate_transformer_depth(prefix, state_dict_keys, state_dict)
|
1021 |
+
if out is not None:
|
1022 |
+
transformer_depth.append(out[0])
|
1023 |
+
if context_dim is None:
|
1024 |
+
context_dim = out[1]
|
1025 |
+
use_linear_in_transformer = out[2]
|
1026 |
+
video_model = out[3]  # out[3] flags a temporal/video block for this input level
|
1027 |
+
else:
|
1028 |
+
transformer_depth.append(0)
|
1029 |
+
|
1030 |
+
res_block_prefix = "{}0.in_layers.0.weight".format(prefix_output)
|
1031 |
+
if res_block_prefix in block_keys_output:
|
1032 |
+
out = transformer.calculate_transformer_depth(
|
1033 |
+
prefix_output, state_dict_keys, state_dict
|
1034 |
+
)
|
1035 |
+
if out is not None:
|
1036 |
+
transformer_depth_output.append(out[0])
|
1037 |
+
else:
|
1038 |
+
transformer_depth_output.append(0)
|
1039 |
+
|
1040 |
+
num_res_blocks.append(last_res_blocks)
|
1041 |
+
channel_mult.append(last_channel_mult)
|
1042 |
+
if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys:
|
1043 |
+
transformer_depth_middle = transformer.count_blocks(
|
1044 |
+
state_dict_keys,
|
1045 |
+
"{}middle_block.1.transformer_blocks.".format(key_prefix) + "{}",
|
1046 |
+
)
|
1047 |
+
elif "{}middle_block.0.in_layers.0.weight".format(key_prefix) in state_dict_keys:
|
1048 |
+
transformer_depth_middle = -1
|
1049 |
+
else:
|
1050 |
+
transformer_depth_middle = -2
|
1051 |
+
|
1052 |
+
unet_config["in_channels"] = in_channels
|
1053 |
+
unet_config["out_channels"] = out_channels
|
1054 |
+
unet_config["model_channels"] = model_channels
|
1055 |
+
unet_config["num_res_blocks"] = num_res_blocks
|
1056 |
+
unet_config["transformer_depth"] = transformer_depth
|
1057 |
+
unet_config["transformer_depth_output"] = transformer_depth_output
|
1058 |
+
unet_config["channel_mult"] = channel_mult
|
1059 |
+
unet_config["transformer_depth_middle"] = transformer_depth_middle
|
1060 |
+
unet_config["use_linear_in_transformer"] = use_linear_in_transformer
|
1061 |
+
unet_config["context_dim"] = context_dim
|
1062 |
+
|
1063 |
+
if video_model:
|
1064 |
+
unet_config["extra_ff_mix_layer"] = True
|
1065 |
+
unet_config["use_spatial_context"] = True
|
1066 |
+
unet_config["merge_strategy"] = "learned_with_images"
|
1067 |
+
unet_config["merge_factor"] = 0.0
|
1068 |
+
unet_config["video_kernel_size"] = [3, 1, 1]
|
1069 |
+
unet_config["use_temporal_resblock"] = True
|
1070 |
+
unet_config["use_temporal_attention"] = True
|
1071 |
+
unet_config["disable_temporal_crossattention"] = not video_model_cross
|
1072 |
+
else:
|
1073 |
+
unet_config["use_temporal_resblock"] = False
|
1074 |
+
unet_config["use_temporal_attention"] = False
|
1075 |
+
|
1076 |
+
return unet_config
|
1077 |
+
|
1078 |
+
|
1079 |
+
def model_config_from_unet_config(unet_config: Dict[str, Any], state_dict: Optional[Dict[str, torch.Tensor]] = None) -> Any:
|
1080 |
+
"""#### Get the model configuration from a UNet configuration.
|
1081 |
+
|
1082 |
+
#### Args:
|
1083 |
+
- `unet_config` (Dict[str, Any]): The UNet configuration.
|
1084 |
+
- `state_dict` (Optional[Dict[str, torch.Tensor]], optional): The state dictionary. Defaults to None.
|
1085 |
+
|
1086 |
+
#### Returns:
|
1087 |
+
- `Any`: The model configuration.
|
1088 |
+
"""
|
1089 |
+
from modules.SD15 import SD15
|
1090 |
+
|
1091 |
+
for model_config in SD15.models:
|
1092 |
+
if model_config.matches(unet_config, state_dict):
|
1093 |
+
return model_config(unet_config)
|
1094 |
+
|
1095 |
+
logging.error("no match {}".format(unet_config))
|
1096 |
+
return None
|
1097 |
+
|
1098 |
+
|
1099 |
+
def model_config_from_unet(state_dict: Dict[str, torch.Tensor], unet_key_prefix: str, use_base_if_no_match: bool = False) -> Any:
|
1100 |
+
"""#### Get the model configuration from a UNet state dictionary.
|
1101 |
+
|
1102 |
+
#### Args:
|
1103 |
+
- `state_dict` (Dict[str, torch.Tensor]): The state dictionary.
|
1104 |
+
- `unet_key_prefix` (str): The UNet key prefix.
|
1105 |
+
- `use_base_if_no_match` (bool, optional): Whether to use the base configuration if no match is found. Defaults to False.
|
1106 |
+
|
1107 |
+
#### Returns:
|
1108 |
+
- `Any`: The model configuration.
|
1109 |
+
"""
|
1110 |
+
unet_config = detect_unet_config(state_dict, unet_key_prefix)
|
1111 |
+
if unet_config is None:
|
1112 |
+
return None
|
1113 |
+
model_config = model_config_from_unet_config(unet_config, state_dict)
|
1114 |
+
return model_config
|
1115 |
+
|
1116 |
+
|
1117 |
+
def unet_dtype1(
|
1118 |
+
device: Optional[torch.device] = None,
|
1119 |
+
model_params: int = 0,
|
1120 |
+
supported_dtypes: List[torch.dtype] = [torch.float16, torch.bfloat16, torch.float32],
|
1121 |
+
) -> torch.dtype:
|
1122 |
+
"""#### Get the dtype for the UNet model.
|
1123 |
+
|
1124 |
+
#### Args:
|
1125 |
+
- `device` (Optional[torch.device], optional): The device. Defaults to None.
|
1126 |
+
- `model_params` (int, optional): The model parameters. Defaults to 0.
|
1127 |
+
- `supported_dtypes` (List[torch.dtype], optional): The supported dtypes. Defaults to [torch.float16, torch.bfloat16, torch.float32].
|
1128 |
+
|
1129 |
+
#### Returns:
|
1130 |
+
- `torch.dtype`: The dtype for the UNet model.
|
1131 |
+
"""
|
1132 |
+
return torch.float16  # hardcoded: this port always runs the UNet in fp16 regardless of the arguments
|
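Taken together, `detect_unet_config`, `model_config_from_unet_config` and `model_config_from_unet` let the loader identify an architecture from nothing but the checkpoint's tensor names and shapes. A minimal usage sketch follows; the import path and checkpoint location are assumptions (the module is assumed to be importable as `modules.NeuralNetwork.unet`, and the `.safetensors` path is a placeholder for any SD1.5-style checkpoint).

```python
import safetensors.torch

from modules.NeuralNetwork import unet  # assumed import path for this module

# Hypothetical checkpoint; any SD1.5-style file containing a full state dict works.
state_dict = safetensors.torch.load_file("./_internal/checkpoints/sd15.safetensors")

# In full checkpoints the diffusion model usually lives under this prefix.
prefix = "model.diffusion_model."

unet_config = unet.detect_unet_config(state_dict, prefix)
print(unet_config["context_dim"], unet_config["model_channels"], unet_config["channel_mult"])

model_config = unet.model_config_from_unet(state_dict, prefix)
if model_config is None:
    raise RuntimeError("checkpoint layout not recognised by the registered model configs")
```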
modules/Quantize/Quantizer.py
ADDED
@@ -0,0 +1,1012 @@
1 |
+
import copy
|
2 |
+
import logging
|
3 |
+
import gguf
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from modules.Device import Device
|
7 |
+
from modules.Model import ModelPatcher
|
8 |
+
from modules.Utilities import util
|
9 |
+
from modules.clip import Clip
|
10 |
+
from modules.cond import cast
|
11 |
+
|
12 |
+
# Constants for torch-compatible quantization types
|
13 |
+
TORCH_COMPATIBLE_QTYPES = {
|
14 |
+
None,
|
15 |
+
gguf.GGMLQuantizationType.F32,
|
16 |
+
gguf.GGMLQuantizationType.F16,
|
17 |
+
}
|
18 |
+
|
19 |
+
|
20 |
+
def is_torch_compatible(tensor: torch.Tensor) -> bool:
|
21 |
+
"""#### Check if a tensor is compatible with PyTorch operations.
|
22 |
+
|
23 |
+
#### Args:
|
24 |
+
- `tensor` (torch.Tensor): The tensor to check.
|
25 |
+
|
26 |
+
#### Returns:
|
27 |
+
- `bool`: Whether the tensor is torch-compatible.
|
28 |
+
"""
|
29 |
+
return (
|
30 |
+
tensor is None
|
31 |
+
or getattr(tensor, "tensor_type", None) in TORCH_COMPATIBLE_QTYPES
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
def is_quantized(tensor: torch.Tensor) -> bool:
|
36 |
+
"""#### Check if a tensor is quantized.
|
37 |
+
|
38 |
+
#### Args:
|
39 |
+
- `tensor` (torch.Tensor): The tensor to check.
|
40 |
+
|
41 |
+
#### Returns:
|
42 |
+
- `bool`: Whether the tensor is quantized.
|
43 |
+
"""
|
44 |
+
return not is_torch_compatible(tensor)
|
45 |
+
|
46 |
+
|
47 |
+
def dequantize(
|
48 |
+
data: torch.Tensor,
|
49 |
+
qtype: gguf.GGMLQuantizationType,
|
50 |
+
oshape: tuple,
|
51 |
+
dtype: torch.dtype = None,
|
52 |
+
) -> torch.Tensor:
|
53 |
+
"""#### Dequantize tensor back to usable shape/dtype.
|
54 |
+
|
55 |
+
#### Args:
|
56 |
+
- `data` (torch.Tensor): The quantized data.
|
57 |
+
- `qtype` (gguf.GGMLQuantizationType): The quantization type.
|
58 |
+
- `oshape` (tuple): The output shape.
|
59 |
+
- `dtype` (torch.dtype, optional): The output dtype. Defaults to None.
|
60 |
+
|
61 |
+
#### Returns:
|
62 |
+
- `torch.Tensor`: The dequantized tensor.
|
63 |
+
"""
|
64 |
+
# Get block size and type size for quantization format
|
65 |
+
block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
|
66 |
+
dequantize_blocks = dequantize_functions[qtype]
|
67 |
+
|
68 |
+
# Reshape data into blocks
|
69 |
+
rows = data.reshape((-1, data.shape[-1])).view(torch.uint8)
|
70 |
+
n_blocks = rows.numel() // type_size
|
71 |
+
blocks = rows.reshape((n_blocks, type_size))
|
72 |
+
|
73 |
+
# Dequantize blocks and reshape to target shape
|
74 |
+
blocks = dequantize_blocks(blocks, block_size, type_size, dtype)
|
75 |
+
return blocks.reshape(oshape)
|
76 |
+
|
77 |
+
|
78 |
+
def split_block_dims(blocks: torch.Tensor, *args) -> list:
|
79 |
+
"""#### Split blocks into dimensions.
|
80 |
+
|
81 |
+
#### Args:
|
82 |
+
- `blocks` (torch.Tensor): The blocks to split.
|
83 |
+
- `*args`: The dimensions to split into.
|
84 |
+
|
85 |
+
#### Returns:
|
86 |
+
- `list`: The split blocks.
|
87 |
+
"""
|
88 |
+
n_max = blocks.shape[1]
|
89 |
+
dims = list(args) + [n_max - sum(args)]
|
90 |
+
return torch.split(blocks, dims, dim=1)
|
91 |
+
|
92 |
+
|
93 |
+
# Legacy Quantization Functions
|
94 |
+
def dequantize_blocks_Q8_0(
|
95 |
+
blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype = None
|
96 |
+
) -> torch.Tensor:
|
97 |
+
"""#### Dequantize Q8_0 quantized blocks.
|
98 |
+
|
99 |
+
#### Args:
|
100 |
+
- `blocks` (torch.Tensor): The quantized blocks.
|
101 |
+
- `block_size` (int): The block size.
|
102 |
+
- `type_size` (int): The type size.
|
103 |
+
- `dtype` (torch.dtype, optional): The output dtype. Defaults to None.
|
104 |
+
|
105 |
+
#### Returns:
|
106 |
+
- `torch.Tensor`: The dequantized blocks.
|
107 |
+
"""
|
108 |
+
# Split blocks into scale and quantized values
|
109 |
+
d, x = split_block_dims(blocks, 2)
|
110 |
+
d = d.view(torch.float16).to(dtype)
|
111 |
+
x = x.view(torch.int8)
|
112 |
+
return d * x
|
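For reference, `dequantize_blocks_Q8_0` relies on the fixed GGML Q8_0 layout: each 34-byte block is a float16 scale `d` followed by 32 signed int8 values `x`, and the dequantized weights are simply `d * x`. The standalone round trip below illustrates that layout; the quantization step shown is only an illustration, not something this module performs.

```python
import torch

# 32 weights -> one Q8_0 block of 34 bytes (2-byte fp16 scale + 32 int8 quants),
# matching gguf.GGML_QUANT_SIZES[Q8_0] == (32, 34).
weights = torch.randn(32)

scale = (weights.abs().max() / 127.0).to(torch.float16)           # illustrative quantization
q = torch.clamp(torch.round(weights / scale), -127, 127).to(torch.int8)

block = torch.cat([scale.reshape(1).view(torch.uint8), q.view(torch.uint8)])
assert block.numel() == 34

# Dequantize exactly as dequantize_blocks_Q8_0 does: d * x.
d = block[:2].view(torch.float16).to(torch.float32)
x = block[2:].view(torch.int8)
restored = d * x
print((weights - restored).abs().max())  # small quantization error, roughly scale / 2
```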
113 |
+
|
114 |
+
|
115 |
+
# K Quants #
|
116 |
+
QK_K = 256
|
117 |
+
K_SCALE_SIZE = 12
|
118 |
+
|
119 |
+
# Mapping of quantization types to dequantization functions
|
120 |
+
dequantize_functions = {
|
121 |
+
gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
|
122 |
+
}
|
123 |
+
|
124 |
+
|
125 |
+
def dequantize_tensor(
|
126 |
+
tensor: torch.Tensor, dtype: torch.dtype = None, dequant_dtype: torch.dtype = None
|
127 |
+
) -> torch.Tensor:
|
128 |
+
"""#### Dequantize a potentially quantized tensor.
|
129 |
+
|
130 |
+
#### Args:
|
131 |
+
- `tensor` (torch.Tensor): The tensor to dequantize.
|
132 |
+
- `dtype` (torch.dtype, optional): Target dtype. Defaults to None.
|
133 |
+
- `dequant_dtype` (torch.dtype, optional): Intermediate dequantization dtype. Defaults to None.
|
134 |
+
|
135 |
+
#### Returns:
|
136 |
+
- `torch.Tensor`: The dequantized tensor.
|
137 |
+
"""
|
138 |
+
qtype = getattr(tensor, "tensor_type", None)
|
139 |
+
oshape = getattr(tensor, "tensor_shape", tensor.shape)
|
140 |
+
|
141 |
+
if qtype in TORCH_COMPATIBLE_QTYPES:
|
142 |
+
return tensor.to(dtype)
|
143 |
+
elif qtype in dequantize_functions:
|
144 |
+
dequant_dtype = dtype if dequant_dtype == "target" else dequant_dtype
|
145 |
+
return dequantize(tensor.data, qtype, oshape, dtype=dequant_dtype).to(dtype)
|
146 |
+
|
147 |
+
|
148 |
+
class GGMLLayer(torch.nn.Module):
|
149 |
+
"""#### Base class for GGML quantized layers.
|
150 |
+
|
151 |
+
Handles dynamic dequantization of weights during forward pass.
|
152 |
+
"""
|
153 |
+
|
154 |
+
comfy_cast_weights: bool = True
|
155 |
+
dequant_dtype: torch.dtype = None
|
156 |
+
patch_dtype: torch.dtype = None
|
157 |
+
torch_compatible_tensor_types: set = {
|
158 |
+
None,
|
159 |
+
gguf.GGMLQuantizationType.F32,
|
160 |
+
gguf.GGMLQuantizationType.F16,
|
161 |
+
}
|
162 |
+
|
163 |
+
def is_ggml_quantized(
|
164 |
+
self, *, weight: torch.Tensor = None, bias: torch.Tensor = None
|
165 |
+
) -> bool:
|
166 |
+
"""#### Check if layer weights are GGML quantized.
|
167 |
+
|
168 |
+
#### Args:
|
169 |
+
- `weight` (torch.Tensor, optional): Weight tensor to check. Defaults to self.weight.
|
170 |
+
- `bias` (torch.Tensor, optional): Bias tensor to check. Defaults to self.bias.
|
171 |
+
|
172 |
+
#### Returns:
|
173 |
+
- `bool`: Whether weights are quantized.
|
174 |
+
"""
|
175 |
+
if weight is None:
|
176 |
+
weight = self.weight
|
177 |
+
if bias is None:
|
178 |
+
bias = self.bias
|
179 |
+
return is_quantized(weight) or is_quantized(bias)
|
180 |
+
|
181 |
+
def _load_from_state_dict(
|
182 |
+
self, state_dict: dict, prefix: str, *args, **kwargs
|
183 |
+
) -> None:
|
184 |
+
"""#### Load quantized weights from state dict.
|
185 |
+
|
186 |
+
#### Args:
|
187 |
+
- `state_dict` (dict): State dictionary.
|
188 |
+
- `prefix` (str): Key prefix.
|
189 |
+
- `*args`: Additional arguments.
|
190 |
+
- `**kwargs`: Additional keyword arguments.
|
191 |
+
"""
|
192 |
+
weight = state_dict.get(f"{prefix}weight")
|
193 |
+
bias = state_dict.get(f"{prefix}bias")
|
194 |
+
# Use modified loader for quantized or linear layers
|
195 |
+
if self.is_ggml_quantized(weight=weight, bias=bias) or isinstance(
|
196 |
+
self, torch.nn.Linear
|
197 |
+
):
|
198 |
+
return self.ggml_load_from_state_dict(state_dict, prefix, *args, **kwargs)
|
199 |
+
return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
|
200 |
+
|
201 |
+
def ggml_load_from_state_dict(
|
202 |
+
self,
|
203 |
+
state_dict: dict,
|
204 |
+
prefix: str,
|
205 |
+
local_metadata: dict,
|
206 |
+
strict: bool,
|
207 |
+
missing_keys: list,
|
208 |
+
unexpected_keys: list,
|
209 |
+
error_msgs: list,
|
210 |
+
) -> None:
|
211 |
+
"""#### Load GGML quantized weights from state dict.
|
212 |
+
|
213 |
+
#### Args:
|
214 |
+
- `state_dict` (dict): State dictionary.
|
215 |
+
- `prefix` (str): Key prefix.
|
216 |
+
- `local_metadata` (dict): Local metadata.
|
217 |
+
- `strict` (bool): Strict loading mode.
|
218 |
+
- `missing_keys` (list): Keys missing from state dict.
|
219 |
+
- `unexpected_keys` (list): Unexpected keys found.
|
220 |
+
- `error_msgs` (list): Error messages.
|
221 |
+
"""
|
222 |
+
prefix_len = len(prefix)
|
223 |
+
for k, v in state_dict.items():
|
224 |
+
if k[prefix_len:] == "weight":
|
225 |
+
self.weight = torch.nn.Parameter(v, requires_grad=False)
|
226 |
+
elif k[prefix_len:] == "bias" and v is not None:
|
227 |
+
self.bias = torch.nn.Parameter(v, requires_grad=False)
|
228 |
+
else:
|
229 |
+
missing_keys.append(k)
|
230 |
+
|
231 |
+
def _save_to_state_dict(self, *args, **kwargs) -> None:
|
232 |
+
"""#### Save layer state to state dict.
|
233 |
+
|
234 |
+
#### Args:
|
235 |
+
- `*args`: Additional arguments.
|
236 |
+
- `**kwargs`: Additional keyword arguments.
|
237 |
+
"""
|
238 |
+
if self.is_ggml_quantized():
|
239 |
+
return self.ggml_save_to_state_dict(*args, **kwargs)
|
240 |
+
return super()._save_to_state_dict(*args, **kwargs)
|
241 |
+
|
242 |
+
def ggml_save_to_state_dict(
|
243 |
+
self, destination: dict, prefix: str, keep_vars: bool
|
244 |
+
) -> None:
|
245 |
+
"""#### Save GGML layer state to state dict.
|
246 |
+
|
247 |
+
#### Args:
|
248 |
+
- `destination` (dict): Destination dictionary.
|
249 |
+
- `prefix` (str): Key prefix.
|
250 |
+
- `keep_vars` (bool): Whether to keep variables.
|
251 |
+
"""
|
252 |
+
# Create fake tensors for VRAM estimation
|
253 |
+
weight = torch.zeros_like(self.weight, device=torch.device("meta"))
|
254 |
+
destination[prefix + "weight"] = weight
|
255 |
+
if self.bias is not None:
|
256 |
+
bias = torch.zeros_like(self.bias, device=torch.device("meta"))
|
257 |
+
destination[prefix + "bias"] = bias
|
258 |
+
return
|
259 |
+
|
260 |
+
def get_weight(self, tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
261 |
+
"""#### Get dequantized weight tensor.
|
262 |
+
|
263 |
+
#### Args:
|
264 |
+
- `tensor` (torch.Tensor): Input tensor.
|
265 |
+
- `dtype` (torch.dtype): Target dtype.
|
266 |
+
|
267 |
+
#### Returns:
|
268 |
+
- `torch.Tensor`: Dequantized tensor.
|
269 |
+
"""
|
270 |
+
if tensor is None:
|
271 |
+
return
|
272 |
+
|
273 |
+
# Consolidate and load patches to GPU asynchronously
|
274 |
+
patch_list = []
|
275 |
+
device = tensor.device
|
276 |
+
for function, patches, key in getattr(tensor, "patches", []):
|
277 |
+
patch_list += move_patch_to_device(patches, device)
|
278 |
+
|
279 |
+
# Dequantize tensor while patches load
|
280 |
+
weight = dequantize_tensor(tensor, dtype, self.dequant_dtype)
|
281 |
+
|
282 |
+
# Apply patches
|
283 |
+
if patch_list:
|
284 |
+
if self.patch_dtype is None:
|
285 |
+
weight = function(patch_list, weight, key)
|
286 |
+
else:
|
287 |
+
# For testing, may degrade image quality
|
288 |
+
patch_dtype = (
|
289 |
+
dtype if self.patch_dtype == "target" else self.patch_dtype
|
290 |
+
)
|
291 |
+
weight = function(patch_list, weight, key, patch_dtype)
|
292 |
+
return weight
|
293 |
+
|
294 |
+
def cast_bias_weight(
|
295 |
+
self,
|
296 |
+
input: torch.Tensor = None,
|
297 |
+
dtype: torch.dtype = None,
|
298 |
+
device: torch.device = None,
|
299 |
+
bias_dtype: torch.dtype = None,
|
300 |
+
) -> tuple:
|
301 |
+
"""#### Cast layer weights and bias to target dtype/device.
|
302 |
+
|
303 |
+
#### Args:
|
304 |
+
- `input` (torch.Tensor, optional): Input tensor for type/device inference.
|
305 |
+
- `dtype` (torch.dtype, optional): Target dtype.
|
306 |
+
- `device` (torch.device, optional): Target device.
|
307 |
+
- `bias_dtype` (torch.dtype, optional): Target bias dtype.
|
308 |
+
|
309 |
+
#### Returns:
|
310 |
+
- `tuple`: (cast_weight, cast_bias)
|
311 |
+
"""
|
312 |
+
if input is not None:
|
313 |
+
if dtype is None:
|
314 |
+
dtype = getattr(input, "dtype", torch.float32)
|
315 |
+
if bias_dtype is None:
|
316 |
+
bias_dtype = dtype
|
317 |
+
if device is None:
|
318 |
+
device = input.device
|
319 |
+
|
320 |
+
bias = None
|
321 |
+
non_blocking = Device.device_supports_non_blocking(device)
|
322 |
+
if self.bias is not None:
|
323 |
+
bias = self.get_weight(self.bias.to(device), dtype)
|
324 |
+
bias = cast.cast_to(
|
325 |
+
bias, bias_dtype, device, non_blocking=non_blocking, copy=False
|
326 |
+
)
|
327 |
+
|
328 |
+
weight = self.get_weight(self.weight.to(device), dtype)
|
329 |
+
weight = cast.cast_to(
|
330 |
+
weight, dtype, device, non_blocking=non_blocking, copy=False
|
331 |
+
)
|
332 |
+
return weight, bias
|
333 |
+
|
334 |
+
def forward_comfy_cast_weights(
|
335 |
+
self, input: torch.Tensor, *args, **kwargs
|
336 |
+
) -> torch.Tensor:
|
337 |
+
"""#### Forward pass with weight casting.
|
338 |
+
|
339 |
+
#### Args:
|
340 |
+
- `input` (torch.Tensor): Input tensor.
|
341 |
+
- `*args`: Additional arguments.
|
342 |
+
- `**kwargs`: Additional keyword arguments.
|
343 |
+
|
344 |
+
#### Returns:
|
345 |
+
- `torch.Tensor`: Output tensor.
|
346 |
+
"""
|
347 |
+
if self.is_ggml_quantized():
|
348 |
+
return self.forward_ggml_cast_weights(input, *args, **kwargs)
|
349 |
+
return super().forward_comfy_cast_weights(input, *args, **kwargs)
|
350 |
+
|
351 |
+
|
352 |
+
class GGMLOps(cast.manual_cast):
|
353 |
+
"""
|
354 |
+
Dequantize weights on the fly before doing the compute
|
355 |
+
"""
|
356 |
+
|
357 |
+
class Linear(GGMLLayer, cast.manual_cast.Linear):
|
358 |
+
def __init__(
|
359 |
+
self, in_features, out_features, bias=True, device=None, dtype=None
|
360 |
+
):
|
361 |
+
"""
|
362 |
+
Initialize the Linear layer.
|
363 |
+
|
364 |
+
Args:
|
365 |
+
in_features (int): Number of input features.
|
366 |
+
out_features (int): Number of output features.
|
367 |
+
bias (bool, optional): If set to False, the layer will not learn an additive bias. Defaults to True.
|
368 |
+
device (torch.device, optional): The device to store the layer's parameters. Defaults to None.
|
369 |
+
dtype (torch.dtype, optional): The data type of the layer's parameters. Defaults to None.
|
370 |
+
"""
|
371 |
+
torch.nn.Module.__init__(self)
|
372 |
+
# TODO: better workaround for reserved memory spike on windows
|
373 |
+
# Issue is with `torch.empty` still reserving the full memory for the layer
|
374 |
+
# Windows doesn't over-commit memory so without this 24GB+ of pagefile is used
|
375 |
+
self.in_features = in_features
|
376 |
+
self.out_features = out_features
|
377 |
+
self.weight = None
|
378 |
+
self.bias = None
|
379 |
+
|
380 |
+
def forward_ggml_cast_weights(self, input: torch.Tensor) -> torch.Tensor:
|
381 |
+
"""
|
382 |
+
Forward pass with GGML cast weights.
|
383 |
+
|
384 |
+
Args:
|
385 |
+
input (torch.Tensor): The input tensor.
|
386 |
+
|
387 |
+
Returns:
|
388 |
+
torch.Tensor: The output tensor.
|
389 |
+
"""
|
390 |
+
weight, bias = self.cast_bias_weight(input)
|
391 |
+
return torch.nn.functional.linear(input, weight, bias)
|
392 |
+
|
393 |
+
class Embedding(GGMLLayer, cast.manual_cast.Embedding):
|
394 |
+
def forward_ggml_cast_weights(
|
395 |
+
self, input: torch.Tensor, out_dtype: torch.dtype = None
|
396 |
+
) -> torch.Tensor:
|
397 |
+
"""
|
398 |
+
Forward pass with GGML cast weights for embedding.
|
399 |
+
|
400 |
+
Args:
|
401 |
+
input (torch.Tensor): The input tensor.
|
402 |
+
out_dtype (torch.dtype, optional): The output data type. Defaults to None.
|
403 |
+
|
404 |
+
Returns:
|
405 |
+
torch.Tensor: The output tensor.
|
406 |
+
"""
|
407 |
+
output_dtype = out_dtype
|
408 |
+
if (
|
409 |
+
self.weight.dtype == torch.float16
|
410 |
+
or self.weight.dtype == torch.bfloat16
|
411 |
+
):
|
412 |
+
out_dtype = None
|
413 |
+
weight, _bias = self.cast_bias_weight(
|
414 |
+
self, device=input.device, dtype=out_dtype
|
415 |
+
)
|
416 |
+
return torch.nn.functional.embedding(
|
417 |
+
input,
|
418 |
+
weight,
|
419 |
+
self.padding_idx,
|
420 |
+
self.max_norm,
|
421 |
+
self.norm_type,
|
422 |
+
self.scale_grad_by_freq,
|
423 |
+
self.sparse,
|
424 |
+
).to(dtype=output_dtype)
|
425 |
+
|
426 |
+
|
427 |
+
def gguf_sd_loader_get_orig_shape(
|
428 |
+
reader: gguf.GGUFReader, tensor_name: str
|
429 |
+
) -> torch.Size:
|
430 |
+
"""#### Get the original shape of a tensor from a GGUF reader.
|
431 |
+
|
432 |
+
#### Args:
|
433 |
+
- `reader` (gguf.GGUFReader): The GGUF reader.
|
434 |
+
- `tensor_name` (str): The name of the tensor.
|
435 |
+
|
436 |
+
#### Returns:
|
437 |
+
- `torch.Size`: The original shape of the tensor.
|
438 |
+
"""
|
439 |
+
field_key = f"comfy.gguf.orig_shape.{tensor_name}"
|
440 |
+
field = reader.get_field(field_key)
|
441 |
+
if field is None:
|
442 |
+
return None
|
443 |
+
# Has original shape metadata, so we try to decode it.
|
444 |
+
if (
|
445 |
+
len(field.types) != 2
|
446 |
+
or field.types[0] != gguf.GGUFValueType.ARRAY
|
447 |
+
or field.types[1] != gguf.GGUFValueType.INT32
|
448 |
+
):
|
449 |
+
raise TypeError(
|
450 |
+
f"Bad original shape metadata for {field_key}: Expected ARRAY of INT32, got {field.types}"
|
451 |
+
)
|
452 |
+
return torch.Size(tuple(int(field.parts[part_idx][0]) for part_idx in field.data))
|
453 |
+
|
454 |
+
|
455 |
+
class GGMLTensor(torch.Tensor):
|
456 |
+
"""
|
457 |
+
Main tensor-like class for storing quantized weights
|
458 |
+
"""
|
459 |
+
|
460 |
+
def __init__(self, *args, tensor_type, tensor_shape, patches=[], **kwargs):
|
461 |
+
"""
|
462 |
+
Initialize the GGMLTensor.
|
463 |
+
|
464 |
+
Args:
|
465 |
+
*args: Variable length argument list.
|
466 |
+
tensor_type: The type of the tensor.
|
467 |
+
tensor_shape: The shape of the tensor.
|
468 |
+
patches (list, optional): List of patches. Defaults to [].
|
469 |
+
**kwargs: Arbitrary keyword arguments.
|
470 |
+
"""
|
471 |
+
super().__init__()
|
472 |
+
self.tensor_type = tensor_type
|
473 |
+
self.tensor_shape = tensor_shape
|
474 |
+
self.patches = patches
|
475 |
+
|
476 |
+
def __new__(cls, *args, tensor_type, tensor_shape, patches=[], **kwargs):
|
477 |
+
"""
|
478 |
+
Create a new instance of GGMLTensor.
|
479 |
+
|
480 |
+
Args:
|
481 |
+
*args: Variable length argument list.
|
482 |
+
tensor_type: The type of the tensor.
|
483 |
+
tensor_shape: The shape of the tensor.
|
484 |
+
patches (list, optional): List of patches. Defaults to [].
|
485 |
+
**kwargs: Arbitrary keyword arguments.
|
486 |
+
|
487 |
+
Returns:
|
488 |
+
GGMLTensor: A new instance of GGMLTensor.
|
489 |
+
"""
|
490 |
+
return super().__new__(cls, *args, **kwargs)
|
491 |
+
|
492 |
+
def to(self, *args, **kwargs):
|
493 |
+
"""
|
494 |
+
Convert the tensor to a specified device and/or dtype.
|
495 |
+
|
496 |
+
Args:
|
497 |
+
*args: Variable length argument list.
|
498 |
+
**kwargs: Arbitrary keyword arguments.
|
499 |
+
|
500 |
+
Returns:
|
501 |
+
GGMLTensor: The converted tensor.
|
502 |
+
"""
|
503 |
+
new = super().to(*args, **kwargs)
|
504 |
+
new.tensor_type = getattr(self, "tensor_type", None)
|
505 |
+
new.tensor_shape = getattr(self, "tensor_shape", new.data.shape)
|
506 |
+
new.patches = getattr(self, "patches", []).copy()
|
507 |
+
return new
|
508 |
+
|
509 |
+
def clone(self, *args, **kwargs):
|
510 |
+
"""
|
511 |
+
Clone the tensor.
|
512 |
+
|
513 |
+
Args:
|
514 |
+
*args: Variable length argument list.
|
515 |
+
**kwargs: Arbitrary keyword arguments.
|
516 |
+
|
517 |
+
Returns:
|
518 |
+
GGMLTensor: The cloned tensor.
|
519 |
+
"""
|
520 |
+
return self
|
521 |
+
|
522 |
+
def detach(self, *args, **kwargs):
|
523 |
+
"""
|
524 |
+
Detach the tensor from the computation graph.
|
525 |
+
|
526 |
+
Args:
|
527 |
+
*args: Variable length argument list.
|
528 |
+
**kwargs: Arbitrary keyword arguments.
|
529 |
+
|
530 |
+
Returns:
|
531 |
+
GGMLTensor: The detached tensor.
|
532 |
+
"""
|
533 |
+
return self
|
534 |
+
|
535 |
+
def copy_(self, *args, **kwargs):
|
536 |
+
"""
|
537 |
+
Copy the values from another tensor into this tensor.
|
538 |
+
|
539 |
+
Args:
|
540 |
+
*args: Variable length argument list.
|
541 |
+
**kwargs: Arbitrary keyword arguments.
|
542 |
+
|
543 |
+
Returns:
|
544 |
+
GGMLTensor: The tensor with copied values.
|
545 |
+
"""
|
546 |
+
try:
|
547 |
+
return super().copy_(*args, **kwargs)
|
548 |
+
except Exception as e:
|
549 |
+
print(f"ignoring 'copy_' on tensor: {e}")
|
550 |
+
|
551 |
+
def __deepcopy__(self, *args, **kwargs):
|
552 |
+
"""
|
553 |
+
Create a deep copy of the tensor.
|
554 |
+
|
555 |
+
Args:
|
556 |
+
*args: Variable length argument list.
|
557 |
+
**kwargs: Arbitrary keyword arguments.
|
558 |
+
|
559 |
+
Returns:
|
560 |
+
GGMLTensor: The deep copied tensor.
|
561 |
+
"""
|
562 |
+
new = super().__deepcopy__(*args, **kwargs)
|
563 |
+
new.tensor_type = getattr(self, "tensor_type", None)
|
564 |
+
new.tensor_shape = getattr(self, "tensor_shape", new.data.shape)
|
565 |
+
new.patches = getattr(self, "patches", []).copy()
|
566 |
+
return new
|
567 |
+
|
568 |
+
@property
|
569 |
+
def shape(self):
|
570 |
+
"""
|
571 |
+
Get the shape of the tensor.
|
572 |
+
|
573 |
+
Returns:
|
574 |
+
torch.Size: The shape of the tensor.
|
575 |
+
"""
|
576 |
+
if not hasattr(self, "tensor_shape"):
|
577 |
+
self.tensor_shape = self.size()
|
578 |
+
return self.tensor_shape
|
579 |
+
|
580 |
+
|
581 |
+
def gguf_sd_loader(path: str, handle_prefix: str = "model.diffusion_model."):
|
582 |
+
"""#### Load a GGUF file into a state dict.
|
583 |
+
|
584 |
+
#### Args:
|
585 |
+
- `path` (str): The path to the GGUF file.
|
586 |
+
- `handle_prefix` (str, optional): The prefix to handle. Defaults to "model.diffusion_model.".
|
587 |
+
|
588 |
+
#### Returns:
|
589 |
+
- `dict`: The loaded state dict.
|
590 |
+
"""
|
591 |
+
reader = gguf.GGUFReader(path)
|
592 |
+
|
593 |
+
# filter and strip prefix
|
594 |
+
has_prefix = False
|
595 |
+
if handle_prefix is not None:
|
596 |
+
prefix_len = len(handle_prefix)
|
597 |
+
tensor_names = set(tensor.name for tensor in reader.tensors)
|
598 |
+
has_prefix = any(s.startswith(handle_prefix) for s in tensor_names)
|
599 |
+
|
600 |
+
tensors = []
|
601 |
+
for tensor in reader.tensors:
|
602 |
+
sd_key = tensor_name = tensor.name
|
603 |
+
if has_prefix:
|
604 |
+
if not tensor_name.startswith(handle_prefix):
|
605 |
+
continue
|
606 |
+
sd_key = tensor_name[prefix_len:]
|
607 |
+
tensors.append((sd_key, tensor))
|
608 |
+
|
609 |
+
# detect and verify architecture
|
610 |
+
compat = None
|
611 |
+
arch_str = None
|
612 |
+
arch_field = reader.get_field("general.architecture")
|
613 |
+
if arch_field is not None:
|
614 |
+
if (
|
615 |
+
len(arch_field.types) != 1
|
616 |
+
or arch_field.types[0] != gguf.GGUFValueType.STRING
|
617 |
+
):
|
618 |
+
raise TypeError(
|
619 |
+
f"Bad type for GGUF general.architecture key: expected string, got {arch_field.types!r}"
|
620 |
+
)
|
621 |
+
arch_str = str(arch_field.parts[arch_field.data[-1]], encoding="utf-8")
|
622 |
+
if arch_str not in {"flux", "sd1", "sdxl", "t5", "t5encoder"}:
|
623 |
+
raise ValueError(
|
624 |
+
f"Unexpected architecture type in GGUF file, expected one of flux, sd1, sdxl, t5encoder but got {arch_str!r}"
|
625 |
+
)
|
626 |
+
|
627 |
+
# main loading loop
|
628 |
+
state_dict = {}
|
629 |
+
qtype_dict = {}
|
630 |
+
for sd_key, tensor in tensors:
|
631 |
+
tensor_name = tensor.name
|
632 |
+
tensor_type_str = str(tensor.tensor_type)
|
633 |
+
torch_tensor = torch.from_numpy(tensor.data) # mmap
|
634 |
+
|
635 |
+
shape = gguf_sd_loader_get_orig_shape(reader, tensor_name)
|
636 |
+
if shape is None:
|
637 |
+
shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
|
638 |
+
# Workaround for stable-diffusion.cpp SDXL detection.
|
639 |
+
if compat == "sd.cpp" and arch_str == "sdxl":
|
640 |
+
if any(
|
641 |
+
[
|
642 |
+
tensor_name.endswith(x)
|
643 |
+
for x in (".proj_in.weight", ".proj_out.weight")
|
644 |
+
]
|
645 |
+
):
|
646 |
+
while len(shape) > 2 and shape[-1] == 1:
|
647 |
+
shape = shape[:-1]
|
648 |
+
|
649 |
+
# add to state dict
|
650 |
+
if tensor.tensor_type in {
|
651 |
+
gguf.GGMLQuantizationType.F32,
|
652 |
+
gguf.GGMLQuantizationType.F16,
|
653 |
+
}:
|
654 |
+
torch_tensor = torch_tensor.view(*shape)
|
655 |
+
state_dict[sd_key] = GGMLTensor(
|
656 |
+
torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape
|
657 |
+
)
|
658 |
+
qtype_dict[tensor_type_str] = qtype_dict.get(tensor_type_str, 0) + 1
|
659 |
+
|
660 |
+
# sanity check debug print
|
661 |
+
print("\nggml_sd_loader:")
|
662 |
+
for k, v in qtype_dict.items():
|
663 |
+
print(f" {k:30}{v:3}")
|
664 |
+
|
665 |
+
return state_dict
|
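In practice `gguf_sd_loader` is the entry point for anything stored as GGUF: it memory-maps the tensors, wraps them in `GGMLTensor` so the quantization type and logical shape survive `.to()` and cloning, and prints the per-qtype summary above. A short sketch, assuming a hypothetical `flux1-dev-Q8_0.gguf` placed under `./_internal/unet/`:

```python
sd = gguf_sd_loader("./_internal/unet/flux1-dev-Q8_0.gguf")  # hypothetical file name

key, tensor = next(iter(sd.items()))
print(key, tensor.tensor_type, tuple(tensor.shape))  # GGML qtype and logical shape

# Quantized entries stay packed until a GGMLLayer dequantizes them on the fly;
# they can also be expanded explicitly (only F16/F32/Q8_0 are handled by this module):
weight = dequantize_tensor(tensor, dtype=torch.float16)
print(weight.dtype, tuple(weight.shape))
```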
666 |
+
|
667 |
+
|
668 |
+
class GGUFModelPatcher(ModelPatcher.ModelPatcher):
|
669 |
+
patch_on_device = False
|
670 |
+
|
671 |
+
def unpatch_model(self, device_to=None, unpatch_weights=True):
|
672 |
+
"""
|
673 |
+
Unpatch the model.
|
674 |
+
|
675 |
+
Args:
|
676 |
+
device_to (torch.device, optional): The device to move the model to. Defaults to None.
|
677 |
+
unpatch_weights (bool, optional): Whether to unpatch the weights. Defaults to True.
|
678 |
+
|
679 |
+
Returns:
|
680 |
+
GGUFModelPatcher: The unpatched model.
|
681 |
+
"""
|
682 |
+
if unpatch_weights:
|
683 |
+
for p in self.model.parameters():
|
684 |
+
if is_torch_compatible(p):
|
685 |
+
continue
|
686 |
+
patches = getattr(p, "patches", [])
|
687 |
+
if len(patches) > 0:
|
688 |
+
p.patches = []
|
689 |
+
self.object_patches = {}
|
690 |
+
# TODO: Find another way to not unload after patches
|
691 |
+
return super().unpatch_model(
|
692 |
+
device_to=device_to, unpatch_weights=unpatch_weights
|
693 |
+
)
|
694 |
+
|
695 |
+
mmap_released = False
|
696 |
+
|
697 |
+
def load(self, *args, force_patch_weights=False, **kwargs):
|
698 |
+
"""
|
699 |
+
Load the model.
|
700 |
+
|
701 |
+
Args:
|
702 |
+
*args: Variable length argument list.
|
703 |
+
force_patch_weights (bool, optional): Whether to force patch weights. Defaults to False.
|
704 |
+
**kwargs: Arbitrary keyword arguments.
|
705 |
+
"""
|
706 |
+
super().load(*args, force_patch_weights=True, **kwargs)
|
707 |
+
|
708 |
+
# make sure nothing stays linked to mmap after first load
|
709 |
+
if not self.mmap_released:
|
710 |
+
linked = []
|
711 |
+
if kwargs.get("lowvram_model_memory", 0) > 0:
|
712 |
+
for n, m in self.model.named_modules():
|
713 |
+
if hasattr(m, "weight"):
|
714 |
+
device = getattr(m.weight, "device", None)
|
715 |
+
if device == self.offload_device:
|
716 |
+
linked.append((n, m))
|
717 |
+
continue
|
718 |
+
if hasattr(m, "bias"):
|
719 |
+
device = getattr(m.bias, "device", None)
|
720 |
+
if device == self.offload_device:
|
721 |
+
linked.append((n, m))
|
722 |
+
continue
|
723 |
+
if linked:
|
724 |
+
print(f"Attempting to release mmap ({len(linked)})")
|
725 |
+
for n, m in linked:
|
726 |
+
# TODO: possible to OOM, find better way to detach
|
727 |
+
m.to(self.load_device).to(self.offload_device)
|
728 |
+
self.mmap_released = True
|
729 |
+
|
730 |
+
def add_object_patch(self, name, obj):
|
731 |
+
self.object_patches[name] = obj
|
732 |
+
|
733 |
+
def clone(self, *args, **kwargs):
|
734 |
+
"""
|
735 |
+
Clone the model patcher.
|
736 |
+
|
737 |
+
Args:
|
738 |
+
*args: Variable length argument list.
|
739 |
+
**kwargs: Arbitrary keyword arguments.
|
740 |
+
|
741 |
+
Returns:
|
742 |
+
GGUFModelPatcher: The cloned model patcher.
|
743 |
+
"""
|
744 |
+
n = GGUFModelPatcher(
|
745 |
+
self.model,
|
746 |
+
self.load_device,
|
747 |
+
self.offload_device,
|
748 |
+
self.size,
|
749 |
+
weight_inplace_update=self.weight_inplace_update,
|
750 |
+
)
|
751 |
+
n.patches = {}
|
752 |
+
for k in self.patches:
|
753 |
+
n.patches[k] = self.patches[k][:]
|
754 |
+
n.patches_uuid = self.patches_uuid
|
755 |
+
|
756 |
+
n.object_patches = self.object_patches.copy()
|
757 |
+
n.model_options = copy.deepcopy(self.model_options)
|
758 |
+
n.backup = self.backup
|
759 |
+
n.object_patches_backup = self.object_patches_backup
|
760 |
+
n.patch_on_device = getattr(self, "patch_on_device", False)
|
761 |
+
return n
|
762 |
+
|
763 |
+
|
764 |
+
class UnetLoaderGGUF:
|
765 |
+
def load_unet(
|
766 |
+
self,
|
767 |
+
unet_name: str,
|
768 |
+
dequant_dtype: str = None,
|
769 |
+
patch_dtype: str = None,
|
770 |
+
patch_on_device: bool = None,
|
771 |
+
) -> tuple:
|
772 |
+
"""
|
773 |
+
Load the UNet model.
|
774 |
+
|
775 |
+
Args:
|
776 |
+
unet_name (str): The name of the UNet model.
|
777 |
+
dequant_dtype (str, optional): The dequantization data type. Defaults to None.
|
778 |
+
patch_dtype (str, optional): The patch data type. Defaults to None.
|
779 |
+
patch_on_device (bool, optional): Whether to patch on device. Defaults to None.
|
780 |
+
|
781 |
+
Returns:
|
782 |
+
tuple: The loaded model.
|
783 |
+
"""
|
784 |
+
ops = GGMLOps()
|
785 |
+
|
786 |
+
if dequant_dtype in ("default", None):
|
787 |
+
ops.Linear.dequant_dtype = None
|
788 |
+
elif dequant_dtype in ["target"]:
|
789 |
+
ops.Linear.dequant_dtype = dequant_dtype
|
790 |
+
else:
|
791 |
+
ops.Linear.dequant_dtype = getattr(torch, dequant_dtype)
|
792 |
+
|
793 |
+
if patch_dtype in ("default", None):
|
794 |
+
ops.Linear.patch_dtype = None
|
795 |
+
elif patch_dtype in ["target"]:
|
796 |
+
ops.Linear.patch_dtype = patch_dtype
|
797 |
+
else:
|
798 |
+
ops.Linear.patch_dtype = getattr(torch, patch_dtype)
|
799 |
+
|
800 |
+
unet_path = "./_internal/unet/" + unet_name
|
801 |
+
sd = gguf_sd_loader(unet_path)
|
802 |
+
model = ModelPatcher.load_diffusion_model_state_dict(
|
803 |
+
sd, model_options={"custom_operations": ops}
|
804 |
+
)
|
805 |
+
if model is None:
|
806 |
+
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
|
807 |
+
raise RuntimeError(
|
808 |
+
"ERROR: Could not detect model type of: {}".format(unet_path)
|
809 |
+
)
|
810 |
+
model = GGUFModelPatcher.clone(model)
|
811 |
+
model.patch_on_device = patch_on_device
|
812 |
+
return (model,)
|
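`UnetLoaderGGUF.load_unet` ties the pieces together: it configures `GGMLOps` with the requested dequant/patch dtypes, loads the quantized state dict from `./_internal/unet/`, and wraps the detected model in a `GGUFModelPatcher`. A hedged usage sketch, with a placeholder file name:

```python
loader = UnetLoaderGGUF()
(model_patcher,) = loader.load_unet(
    "flux1-dev-Q8_0.gguf",    # placeholder; resolved against ./_internal/unet/
    dequant_dtype="default",  # keep weights packed, dequantize at compute time
    patch_dtype="default",
    patch_on_device=False,
)
# model_patcher behaves like any other ModelPatcher, but LoRA patches are applied
# to the dequantized weights at forward time instead of being baked into the file.
```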
813 |
+
|
814 |
+
|
815 |
+
clip_sd_map = {
|
816 |
+
"enc.": "encoder.",
|
817 |
+
".blk.": ".block.",
|
818 |
+
"token_embd": "shared",
|
819 |
+
"output_norm": "final_layer_norm",
|
820 |
+
"attn_q": "layer.0.SelfAttention.q",
|
821 |
+
"attn_k": "layer.0.SelfAttention.k",
|
822 |
+
"attn_v": "layer.0.SelfAttention.v",
|
823 |
+
"attn_o": "layer.0.SelfAttention.o",
|
824 |
+
"attn_norm": "layer.0.layer_norm",
|
825 |
+
"attn_rel_b": "layer.0.SelfAttention.relative_attention_bias",
|
826 |
+
"ffn_up": "layer.1.DenseReluDense.wi_1",
|
827 |
+
"ffn_down": "layer.1.DenseReluDense.wo",
|
828 |
+
"ffn_gate": "layer.1.DenseReluDense.wi_0",
|
829 |
+
"ffn_norm": "layer.1.layer_norm",
|
830 |
+
}
|
831 |
+
|
832 |
+
clip_name_dict = {
|
833 |
+
"stable_diffusion": Clip.CLIPType.STABLE_DIFFUSION,
|
834 |
+
"sdxl": Clip.CLIPType.STABLE_DIFFUSION,
|
835 |
+
"sd3": Clip.CLIPType.SD3,
|
836 |
+
"flux": Clip.CLIPType.FLUX,
|
837 |
+
}
|
838 |
+
|
839 |
+
|
840 |
+
def gguf_clip_loader(path: str) -> dict:
|
841 |
+
"""#### Load a CLIP model from a GGUF file.
|
842 |
+
|
843 |
+
#### Args:
|
844 |
+
- `path` (str): The path to the GGUF file.
|
845 |
+
|
846 |
+
#### Returns:
|
847 |
+
- `dict`: The loaded CLIP model.
|
848 |
+
"""
|
849 |
+
raw_sd = gguf_sd_loader(path)
|
850 |
+
assert "enc.blk.23.ffn_up.weight" in raw_sd, "Invalid Text Encoder!"
|
851 |
+
sd = {}
|
852 |
+
for k, v in raw_sd.items():
|
853 |
+
for s, d in clip_sd_map.items():
|
854 |
+
k = k.replace(s, d)
|
855 |
+
sd[k] = v
|
856 |
+
return sd
|
857 |
+
|
858 |
+
|
859 |
+
class CLIPLoaderGGUF:
|
860 |
+
def load_data(self, ckpt_paths: list) -> list:
|
861 |
+
"""
|
862 |
+
Load data from checkpoint paths.
|
863 |
+
|
864 |
+
Args:
|
865 |
+
ckpt_paths (list): List of checkpoint paths.
|
866 |
+
|
867 |
+
Returns:
|
868 |
+
list: List of loaded data.
|
869 |
+
"""
|
870 |
+
clip_data = []
|
871 |
+
for p in ckpt_paths:
|
872 |
+
if p.endswith(".gguf"):
|
873 |
+
clip_data.append(gguf_clip_loader(p))
|
874 |
+
else:
|
875 |
+
sd = util.load_torch_file(p, safe_load=True)
|
876 |
+
clip_data.append(
|
877 |
+
{
|
878 |
+
k: GGMLTensor(
|
879 |
+
v,
|
880 |
+
tensor_type=gguf.GGMLQuantizationType.F16,
|
881 |
+
tensor_shape=v.shape,
|
882 |
+
)
|
883 |
+
for k, v in sd.items()
|
884 |
+
}
|
885 |
+
)
|
886 |
+
return clip_data
|
887 |
+
|
888 |
+
def load_patcher(self, clip_paths: list, clip_type: str, clip_data: list) -> Clip:
|
889 |
+
"""
|
890 |
+
Load the model patcher.
|
891 |
+
|
892 |
+
Args:
|
893 |
+
clip_paths (list): List of clip paths.
|
894 |
+
clip_type (str): The type of the clip.
|
895 |
+
clip_data (list): List of clip data.
|
896 |
+
|
897 |
+
Returns:
|
898 |
+
Clip: The loaded clip.
|
899 |
+
"""
|
900 |
+
clip = Clip.load_text_encoder_state_dicts(
|
901 |
+
clip_type=clip_type,
|
902 |
+
state_dicts=clip_data,
|
903 |
+
model_options={
|
904 |
+
"custom_operations": GGMLOps,
|
905 |
+
"initial_device": Device.text_encoder_offload_device(),
|
906 |
+
},
|
907 |
+
embedding_directory="models/embeddings",
|
908 |
+
)
|
909 |
+
clip.patcher = GGUFModelPatcher.clone(clip.patcher)
|
910 |
+
|
911 |
+
# for some reason this is just missing in some SAI checkpoints
|
912 |
+
if getattr(clip.cond_stage_model, "clip_l", None) is not None:
|
913 |
+
if (
|
914 |
+
getattr(
|
915 |
+
clip.cond_stage_model.clip_l.transformer.text_projection.weight,
|
916 |
+
"tensor_shape",
|
917 |
+
None,
|
918 |
+
)
|
919 |
+
is None
|
920 |
+
):
|
921 |
+
clip.cond_stage_model.clip_l.transformer.text_projection = (
|
922 |
+
cast.manual_cast.Linear(768, 768)
|
923 |
+
)
|
924 |
+
if getattr(clip.cond_stage_model, "clip_g", None) is not None:
|
925 |
+
if (
|
926 |
+
getattr(
|
927 |
+
clip.cond_stage_model.clip_g.transformer.text_projection.weight,
|
928 |
+
"tensor_shape",
|
929 |
+
None,
|
930 |
+
)
|
931 |
+
is None
|
932 |
+
):
|
933 |
+
clip.cond_stage_model.clip_g.transformer.text_projection = (
|
934 |
+
cast.manual_cast.Linear(1280, 1280)
|
935 |
+
)
|
936 |
+
|
937 |
+
return clip
|
938 |
+
|
939 |
+
|
940 |
+
class DualCLIPLoaderGGUF(CLIPLoaderGGUF):
|
941 |
+
def load_clip(self, clip_name1: str, clip_name2: str, type: str) -> tuple:
|
942 |
+
"""
|
943 |
+
Load dual clips.
|
944 |
+
|
945 |
+
Args:
|
946 |
+
clip_name1 (str): The name of the first clip.
|
947 |
+
clip_name2 (str): The name of the second clip.
|
948 |
+
type (str): The type of the clip.
|
949 |
+
|
950 |
+
Returns:
|
951 |
+
tuple: The loaded clips.
|
952 |
+
"""
|
953 |
+
clip_path1 = "./_internal/clip/" + clip_name1
|
954 |
+
clip_path2 = "./_internal/clip/" + clip_name2
|
955 |
+
clip_paths = (clip_path1, clip_path2)
|
956 |
+
clip_type = clip_name_dict.get(type, Clip.CLIPType.STABLE_DIFFUSION)
|
957 |
+
return (self.load_patcher(clip_paths, clip_type, self.load_data(clip_paths)),)
|
958 |
+
|
959 |
+
|
960 |
+
class CLIPTextEncodeFlux:
|
961 |
+
def encode(
|
962 |
+
self,
|
963 |
+
clip: Clip,
|
964 |
+
clip_l: str,
|
965 |
+
t5xxl: str,
|
966 |
+
guidance: str,
|
967 |
+
flux_enabled: bool = False,
|
968 |
+
) -> tuple:
|
969 |
+
"""
|
970 |
+
Encode text using CLIP and T5XXL.
|
971 |
+
|
972 |
+
Args:
|
973 |
+
clip (Clip): The clip object.
|
974 |
+
clip_l (str): The clip text.
|
975 |
+
t5xxl (str): The T5XXL text.
|
976 |
+
guidance (str): The guidance text.
|
977 |
+
flux_enabled (bool, optional): Whether flux is enabled. Defaults to False.
|
978 |
+
|
979 |
+
Returns:
|
980 |
+
tuple: The encoded text.
|
981 |
+
"""
|
982 |
+
tokens = clip.tokenize(clip_l)
|
983 |
+
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
|
984 |
+
|
985 |
+
output = clip.encode_from_tokens(
|
986 |
+
tokens, return_pooled=True, return_dict=True, flux_enabled=flux_enabled
|
987 |
+
)
|
988 |
+
cond = output.pop("cond")
|
989 |
+
output["guidance"] = guidance
|
990 |
+
return ([[cond, output]],)
|
991 |
+
|
992 |
+
|
993 |
+
class ConditioningZeroOut:
|
994 |
+
def zero_out(self, conditioning: list) -> list:
|
995 |
+
"""
|
996 |
+
Zero out the conditioning.
|
997 |
+
|
998 |
+
Args:
|
999 |
+
conditioning (list): The conditioning list.
|
1000 |
+
|
1001 |
+
Returns:
|
1002 |
+
list: The zeroed out conditioning.
|
1003 |
+
"""
|
1004 |
+
c = []
|
1005 |
+
for t in conditioning:
|
1006 |
+
d = t[1].copy()
|
1007 |
+
pooled_output = d.get("pooled_output", None)
|
1008 |
+
if pooled_output is not None:
|
1009 |
+
d["pooled_output"] = torch.zeros_like(pooled_output)
|
1010 |
+
n = [torch.zeros_like(t[0]), d]
|
1011 |
+
c.append(n)
|
1012 |
+
return (c,)
|
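The text-encoder side mirrors the UNet path: `DualCLIPLoaderGGUF` accepts a mix of `.gguf` and regular checkpoints from `./_internal/clip/`, and `CLIPTextEncodeFlux` then produces Flux-style conditioning with a guidance value attached. A sketch with placeholder file names:

```python
clip_loader = DualCLIPLoaderGGUF()
(clip,) = clip_loader.load_clip(
    "clip_l.safetensors",   # placeholder names, resolved against ./_internal/clip/
    "t5xxl-Q8_0.gguf",
    type="flux",
)

(conditioning,) = CLIPTextEncodeFlux().encode(
    clip,
    clip_l="a photo of a cat",
    t5xxl="a photo of a cat sitting on a windowsill, soft morning light",
    guidance=3.5,
    flux_enabled=True,
)

(negative,) = ConditioningZeroOut().zero_out(conditioning)  # zeroed-out negative conditioning
```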
modules/SD15/SD15.py
ADDED
@@ -0,0 +1,81 @@
import torch
from modules.BlackForest import Flux
from modules.Utilities import util
from modules.Model import ModelBase
from modules.SD15 import SDClip, SDToken
from modules.Utilities import Latent
from modules.clip import Clip


class sm_SD15(ModelBase.BASE):
    """#### Class representing the SD15 model.

    #### Args:
        - `ModelBase.BASE` (ModelBase.BASE): The base model class.
    """

    unet_config: dict = {
        "context_dim": 768,
        "model_channels": 320,
        "use_linear_in_transformer": False,
        "adm_in_channels": None,
        "use_temporal_attention": False,
    }

    unet_extra_config: dict = {
        "num_heads": 8,
        "num_head_channels": -1,
    }

    latent_format: Latent.SD15 = Latent.SD15

    def process_clip_state_dict(self, state_dict: dict) -> dict:
        """#### Process the state dictionary for the CLIP model.

        #### Args:
            - `state_dict` (dict): The state dictionary.

        #### Returns:
            - `dict`: The processed state dictionary.
        """
        k = list(state_dict.keys())
        for x in k:
            if x.startswith("cond_stage_model.transformer.") and not x.startswith(
                "cond_stage_model.transformer.text_model."
            ):
                y = x.replace(
                    "cond_stage_model.transformer.",
                    "cond_stage_model.transformer.text_model.",
                )
                state_dict[y] = state_dict.pop(x)

        if (
            "cond_stage_model.transformer.text_model.embeddings.position_ids"
            in state_dict
        ):
            ids = state_dict[
                "cond_stage_model.transformer.text_model.embeddings.position_ids"
            ]
            if ids.dtype == torch.float32:
                state_dict[
                    "cond_stage_model.transformer.text_model.embeddings.position_ids"
                ] = ids.round()

        replace_prefix = {}
        replace_prefix["cond_stage_model."] = "clip_l."
        state_dict = util.state_dict_prefix_replace(
            state_dict, replace_prefix, filter_keys=True
        )
        return state_dict

    def clip_target(self) -> Clip.ClipTarget:
        """#### Get the target CLIP model.

        #### Returns:
            - `Clip.ClipTarget`: The target CLIP model.
        """
        return Clip.ClipTarget(SDToken.SD1Tokenizer, SDClip.SD1ClipModel)


models = [
    sm_SD15, Flux.Flux
]
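The `models` list above presumably serves as the registry scanned when matching a checkpoint's UNet config to a model class. As a small standalone illustration of the key remapping done by `process_clip_state_dict` (no checkpoint needed; the key name is a toy example):

sd = {"cond_stage_model.transformer.text_projection.weight": 0}
old_prefix = "cond_stage_model.transformer."
new_prefix = "cond_stage_model.transformer.text_model."
for k in list(sd):
    if k.startswith(old_prefix) and not k.startswith(new_prefix):
        sd[new_prefix + k[len(old_prefix):]] = sd.pop(k)
# sd is now {"cond_stage_model.transformer.text_model.text_projection.weight": 0}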
modules/SD15/SDClip.py
ADDED
@@ -0,0 +1,403 @@
import json
import logging
import numbers
import torch
from modules.Device import Device
from modules.cond import cast
from modules.clip.CLIPTextModel import CLIPTextModel


def gen_empty_tokens(special_tokens: dict, length: int) -> list:
    """#### Generate a list of empty tokens.

    #### Args:
        - `special_tokens` (dict): The special tokens.
        - `length` (int): The length of the token list.

    #### Returns:
        - `list`: The list of empty tokens.
    """
    start_token = special_tokens.get("start", None)
    end_token = special_tokens.get("end", None)
    pad_token = special_tokens.get("pad")
    output = []
    if start_token is not None:
        output.append(start_token)
    if end_token is not None:
        output.append(end_token)
    output += [pad_token] * (length - len(output))
    return output


class ClipTokenWeightEncoder:
    """#### Class representing a CLIP token weight encoder."""

    def encode_token_weights(self, token_weight_pairs: list) -> tuple:
        """#### Encode token weights.

        #### Args:
            - `token_weight_pairs` (list): The token weight pairs.

        #### Returns:
            - `tuple`: The encoded tokens and the pooled output.
        """
        to_encode = list()
        max_token_len = 0
        has_weights = False
        for x in token_weight_pairs:
            tokens = list(map(lambda a: a[0], x))
            max_token_len = max(len(tokens), max_token_len)
            has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
            to_encode.append(tokens)

        sections = len(to_encode)
        if has_weights or sections == 0:
            to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))

        o = self.encode(to_encode)
        out, pooled = o[:2]

        if pooled is not None:
            first_pooled = pooled[0:1].to(Device.intermediate_device())
        else:
            first_pooled = pooled

        output = []
        for k in range(0, sections):
            z = out[k : k + 1]
            if has_weights:
                z_empty = out[-1]
                for i in range(len(z)):
                    for j in range(len(z[i])):
                        weight = token_weight_pairs[k][j][1]
                        if weight != 1.0:
                            z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
            output.append(z)

        if len(output) == 0:
            r = (out[-1:].to(Device.intermediate_device()), first_pooled)
        else:
            r = (torch.cat(output, dim=-2).to(Device.intermediate_device()), first_pooled)

        if len(o) > 2:
            extra = {}
            for k in o[2]:
                v = o[2][k]
                if k == "attention_mask":
                    v = (
                        v[:sections]
                        .flatten()
                        .unsqueeze(dim=0)
                        .to(Device.intermediate_device())
                    )
                extra[k] = v

            r = r + (extra,)
        return r

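# Note on the weighting loop above: for a token with weight w, the encoder output is blended
# against the "empty prompt" encoding rather than simply scaled:
#   z[i][j] = (z[i][j] - z_empty[j]) * w + z_empty[j]
# so w = 1.0 leaves the embedding untouched and w = 0.0 collapses it to the empty-prompt value.
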
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
    """#### Uses the CLIP transformer encoder for text (from huggingface)."""

    LAYERS = ["last", "pooled", "hidden"]

    def __init__(
        self,
        version: str = "openai/clip-vit-large-patch14",
        device: str = "cpu",
        max_length: int = 77,
        freeze: bool = True,
        layer: str = "last",
        layer_idx: int = None,
        textmodel_json_config: str = None,
        dtype: torch.dtype = None,
        model_class: type = CLIPTextModel,
        special_tokens: dict = {"start": 49406, "end": 49407, "pad": 49407},
        layer_norm_hidden_state: bool = True,
        enable_attention_masks: bool = False,
        zero_out_masked: bool = False,
        return_projected_pooled: bool = True,
        return_attention_masks: bool = False,
        model_options={},
    ):
        """#### Initialize the SDClipModel.

        #### Args:
            - `version` (str, optional): The version of the model. Defaults to "openai/clip-vit-large-patch14".
            - `device` (str, optional): The device to use. Defaults to "cpu".
            - `max_length` (int, optional): The maximum length of the input. Defaults to 77.
            - `freeze` (bool, optional): Whether to freeze the model parameters. Defaults to True.
            - `layer` (str, optional): The layer to use. Defaults to "last".
            - `layer_idx` (int, optional): The index of the layer. Defaults to None.
            - `textmodel_json_config` (str, optional): The path to the JSON config file. Defaults to None.
            - `dtype` (torch.dtype, optional): The data type. Defaults to None.
            - `model_class` (type, optional): The model class. Defaults to CLIPTextModel.
            - `special_tokens` (dict, optional): The special tokens. Defaults to {"start": 49406, "end": 49407, "pad": 49407}.
            - `layer_norm_hidden_state` (bool, optional): Whether to normalize the hidden state. Defaults to True.
            - `enable_attention_masks` (bool, optional): Whether to enable attention masks. Defaults to False.
            - `zero_out_masked` (bool, optional): Whether to zero out masked tokens. Defaults to False.
            - `return_projected_pooled` (bool, optional): Whether to return the projected pooled output. Defaults to True.
            - `return_attention_masks` (bool, optional): Whether to return the attention masks. Defaults to False.
            - `model_options` (dict, optional): Additional model options. Defaults to {}.
        """
        super().__init__()
        assert layer in self.LAYERS

        if textmodel_json_config is None:
            textmodel_json_config = "./_internal/clip/sd1_clip_config.json"

        with open(textmodel_json_config) as f:
            config = json.load(f)

        operations = model_options.get("custom_operations", None)
        if operations is None:
            operations = cast.manual_cast

        self.operations = operations
        self.transformer = model_class(config, dtype, device, self.operations)
        self.num_layers = self.transformer.num_layers

        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = None
        self.special_tokens = special_tokens

        self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
        self.enable_attention_masks = enable_attention_masks
        self.zero_out_masked = zero_out_masked

        self.layer_norm_hidden_state = layer_norm_hidden_state
        self.return_projected_pooled = return_projected_pooled
        self.return_attention_masks = return_attention_masks

        if layer == "hidden":
            assert layer_idx is not None
            assert abs(layer_idx) < self.num_layers
            self.set_clip_options({"layer": layer_idx})
        self.options_default = (
            self.layer,
            self.layer_idx,
            self.return_projected_pooled,
        )

    def freeze(self) -> None:
        """#### Freeze the model parameters."""
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def set_clip_options(self, options: dict) -> None:
        """#### Set the CLIP options.

        #### Args:
            - `options` (dict): The options to set.
        """
        layer_idx = options.get("layer", self.layer_idx)
        self.return_projected_pooled = options.get(
            "projected_pooled", self.return_projected_pooled
        )
        if layer_idx is None or abs(layer_idx) > self.num_layers:
            self.layer = "last"
        else:
            self.layer = "hidden"
            self.layer_idx = layer_idx

    def reset_clip_options(self) -> None:
        """#### Reset the CLIP options to default."""
        self.layer = self.options_default[0]
        self.layer_idx = self.options_default[1]
        self.return_projected_pooled = self.options_default[2]

    def set_up_textual_embeddings(self, tokens: list, current_embeds: torch.nn.Embedding) -> list:
        """#### Set up the textual embeddings.

        #### Args:
            - `tokens` (list): The input tokens.
            - `current_embeds` (torch.nn.Embedding): The current embeddings.

        #### Returns:
            - `list`: The processed tokens.
        """
        out_tokens = []
        next_new_token = token_dict_size = current_embeds.weight.shape[0]
        embedding_weights = []

        for x in tokens:
            tokens_temp = []
            for y in x:
                if isinstance(y, numbers.Integral):
                    tokens_temp += [int(y)]
                else:
                    if y.shape[0] == current_embeds.weight.shape[1]:
                        embedding_weights += [y]
                        tokens_temp += [next_new_token]
                        next_new_token += 1
                    else:
                        logging.warning(
                            "WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(
                                y.shape[0], current_embeds.weight.shape[1]
                            )
                        )
            while len(tokens_temp) < len(x):
                tokens_temp += [self.special_tokens["pad"]]
            out_tokens += [tokens_temp]

        n = token_dict_size
        if len(embedding_weights) > 0:
            new_embedding = self.operations.Embedding(
                next_new_token + 1,
                current_embeds.weight.shape[1],
                device=current_embeds.weight.device,
                dtype=current_embeds.weight.dtype,
            )
            new_embedding.weight[:token_dict_size] = current_embeds.weight
            for x in embedding_weights:
                new_embedding.weight[n] = x
                n += 1
            self.transformer.set_input_embeddings(new_embedding)

        processed_tokens = []
        for x in out_tokens:
            processed_tokens += [
                list(map(lambda a: n if a == -1 else a, x))
            ]  # The EOS token should always be the largest one

        return processed_tokens

    def forward(self, tokens: list) -> tuple:
        """#### Forward pass of the model.

        #### Args:
            - `tokens` (list): The input tokens.

        #### Returns:
            - `tuple`: The output and the pooled output.
        """
        backup_embeds = self.transformer.get_input_embeddings()
        device = backup_embeds.weight.device
        tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
        tokens = torch.LongTensor(tokens).to(device)

        attention_mask = None
        if (
            self.enable_attention_masks
            or self.zero_out_masked
            or self.return_attention_masks
        ):
            attention_mask = torch.zeros_like(tokens)
            end_token = self.special_tokens.get("end", -1)
            for x in range(attention_mask.shape[0]):
                for y in range(attention_mask.shape[1]):
                    attention_mask[x, y] = 1
                    if tokens[x, y] == end_token:
                        break

        attention_mask_model = None
        if self.enable_attention_masks:
            attention_mask_model = attention_mask

        outputs = self.transformer(
            tokens,
            attention_mask_model,
            intermediate_output=self.layer_idx,
            final_layer_norm_intermediate=self.layer_norm_hidden_state,
            dtype=torch.float32,
        )
        self.transformer.set_input_embeddings(backup_embeds)

        if self.layer == "last":
            z = outputs[0].float()
        else:
            z = outputs[1].float()

        if self.zero_out_masked:
            z *= attention_mask.unsqueeze(-1).float()

        pooled_output = None
        if len(outputs) >= 3:
            if (
                not self.return_projected_pooled
                and len(outputs) >= 4
                and outputs[3] is not None
            ):
                pooled_output = outputs[3].float()
            elif outputs[2] is not None:
                pooled_output = outputs[2].float()

        extra = {}
        if self.return_attention_masks:
            extra["attention_mask"] = attention_mask

        if len(extra) > 0:
            return z, pooled_output, extra

        return z, pooled_output

    def encode(self, tokens: list) -> tuple:
        """#### Encode the input tokens.

        #### Args:
            - `tokens` (list): The input tokens.

        #### Returns:
            - `tuple`: The encoded tokens and the pooled output.
        """
        return self(tokens)

    def load_sd(self, sd: dict) -> None:
        """#### Load the state dictionary.

        #### Args:
            - `sd` (dict): The state dictionary.
        """
        return self.transformer.load_state_dict(sd, strict=False)

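# Usage note (illustrative): a "clip skip" of 2 is commonly implemented by reading hidden
# layer -2 instead of the final output; with the class above that is just:
#   sd_clip.set_clip_options({"layer": -2})   # sd_clip is an SDClipModel instance (hypothetical name)
#   sd_clip.reset_clip_options()              # restore the defaults captured at construction
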
class SD1ClipModel(torch.nn.Module):
    """#### Class representing the SD1ClipModel."""

    def __init__(
        self, device: str = "cpu", dtype: torch.dtype = None, clip_name: str = "l", clip_model: type = SDClipModel, **kwargs
    ):
        """#### Initialize the SD1ClipModel.

        #### Args:
            - `device` (str, optional): The device to use. Defaults to "cpu".
            - `dtype` (torch.dtype, optional): The data type. Defaults to None.
            - `clip_name` (str, optional): The name of the CLIP model. Defaults to "l".
            - `clip_model` (type, optional): The CLIP model class. Defaults to SDClipModel.
            - `**kwargs`: Additional keyword arguments.
        """
        super().__init__()
        self.clip_name = clip_name
        self.clip = "clip_{}".format(self.clip_name)
        self.lowvram_patch_counter = 0
        self.model_loaded_weight_memory = 0
        setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))

    def set_clip_options(self, options: dict) -> None:
        """#### Set the CLIP options.

        #### Args:
            - `options` (dict): The options to set.
        """
        getattr(self, self.clip).set_clip_options(options)

    def reset_clip_options(self) -> None:
        """#### Reset the CLIP options to default."""
        getattr(self, self.clip).reset_clip_options()

    def encode_token_weights(self, token_weight_pairs: dict) -> tuple:
        """#### Encode token weights.

        #### Args:
            - `token_weight_pairs` (dict): The token weight pairs.

        #### Returns:
            - `tuple`: The encoded tokens and the pooled output.
        """
        token_weight_pairs = token_weight_pairs[self.clip_name]
        out, pooled = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
        return out, pooled
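A quick sanity check (not part of the upload) of the `gen_empty_tokens` helper defined at the top of this file, using the SD1.5 special tokens from `SDClipModel`:

assert gen_empty_tokens({"start": 49406, "end": 49407, "pad": 49407}, 5) == [
    49406,
    49407,
    49407,
    49407,
    49407,
]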
modules/SD15/SDToken.py
ADDED
@@ -0,0 +1,450 @@
import logging
import os
import traceback
import torch
from transformers import CLIPTokenizerFast


def model_options_long_clip(sd, tokenizer_data, model_options):
    w = sd.get("clip_l.text_model.embeddings.position_embedding.weight", None)
    if w is None:
        w = sd.get("text_model.embeddings.position_embedding.weight", None)
    return tokenizer_data, model_options


def parse_parentheses(string: str) -> list:
    """#### Parse a string with nested parentheses.

    #### Args:
        - `string` (str): The input string.

    #### Returns:
        - `list`: The parsed list of strings.
    """
    result = []
    current_item = ""
    nesting_level = 0
    for char in string:
        if char == "(":
            if nesting_level == 0:
                if current_item:
                    result.append(current_item)
                    current_item = "("
                else:
                    current_item = "("
            else:
                current_item += char
            nesting_level += 1
        elif char == ")":
            nesting_level -= 1
            if nesting_level == 0:
                result.append(current_item + ")")
                current_item = ""
            else:
                current_item += char
        else:
            current_item += char
    if current_item:
        result.append(current_item)
    return result


def token_weights(string: str, current_weight: float) -> list:
    """#### Parse a string into tokens with weights.

    #### Args:
        - `string` (str): The input string.
        - `current_weight` (float): The current weight.

    #### Returns:
        - `list`: The list of token-weight pairs.
    """
    a = parse_parentheses(string)
    out = []
    for x in a:
        weight = current_weight
        if len(x) >= 2 and x[-1] == ")" and x[0] == "(":
            x = x[1:-1]
            xx = x.rfind(":")
            weight *= 1.1
            if xx > 0:
                try:
                    weight = float(x[xx + 1 :])
                    x = x[:xx]
                except:
                    pass
            out += token_weights(x, weight)
        else:
            out += [(x, current_weight)]
    return out

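# Doctest-style illustration of the prompt-weight grammar handled by token_weights above:
#   >>> token_weights("a (red:1.3) car", 1.0)
#   [('a ', 1.0), ('red', 1.3), (' car', 1.0)]
# Unannotated parentheses multiply the running weight by 1.1 per nesting level instead of
# setting an explicit value.
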
def escape_important(text: str) -> str:
    """#### Escape important characters in a string.

    #### Args:
        - `text` (str): The input text.

    #### Returns:
        - `str`: The escaped text.
    """
    text = text.replace("\\)", "\0\1")
    text = text.replace("\\(", "\0\2")
    return text


def unescape_important(text: str) -> str:
    """#### Unescape important characters in a string.

    #### Args:
        - `text` (str): The input text.

    #### Returns:
        - `str`: The unescaped text.
    """
    text = text.replace("\0\1", ")")
    text = text.replace("\0\2", "(")
    return text


def expand_directory_list(directories: list) -> list:
    """#### Expand a list of directories to include all subdirectories.

    #### Args:
        - `directories` (list): The list of directories.

    #### Returns:
        - `list`: The expanded list of directories.
    """
    dirs = set()
    for x in directories:
        dirs.add(x)
        for root, subdir, file in os.walk(x, followlinks=True):
            dirs.add(root)
    return list(dirs)


def load_embed(embedding_name: str, embedding_directory: list, embedding_size: int, embed_key: str = None) -> torch.Tensor:
    """#### Load an embedding from a directory.

    #### Args:
        - `embedding_name` (str): The name of the embedding.
        - `embedding_directory` (list): The list of directories to search.
        - `embedding_size` (int): The size of the embedding.
        - `embed_key` (str, optional): The key for the embedding. Defaults to None.

    #### Returns:
        - `torch.Tensor`: The loaded embedding.
    """
    if isinstance(embedding_directory, str):
        embedding_directory = [embedding_directory]

    embedding_directory = expand_directory_list(embedding_directory)

    valid_file = None
    for embed_dir in embedding_directory:
        embed_path = os.path.abspath(os.path.join(embed_dir, embedding_name))
        embed_dir = os.path.abspath(embed_dir)
        try:
            if os.path.commonpath((embed_dir, embed_path)) != embed_dir:
                continue
        except:
            continue
        if not os.path.isfile(embed_path):
            extensions = [".safetensors", ".pt", ".bin"]
            for x in extensions:
                t = embed_path + x
                if os.path.isfile(t):
                    valid_file = t
                    break
        else:
            valid_file = embed_path
        if valid_file is not None:
            break

    if valid_file is None:
        return None

    embed_path = valid_file

    embed_out = None

    try:
        if embed_path.lower().endswith(".safetensors"):
            import safetensors.torch

            embed = safetensors.torch.load_file(embed_path, device="cpu")
        else:
            if "weights_only" in torch.load.__code__.co_varnames:
                embed = torch.load(embed_path, weights_only=True, map_location="cpu")
            else:
                embed = torch.load(embed_path, map_location="cpu")
    except Exception:
        logging.warning(
            "{}\n\nerror loading embedding, skipping loading: {}".format(
                traceback.format_exc(), embedding_name
            )
        )
        return None

    if embed_out is None:
        if "string_to_param" in embed:
            values = embed["string_to_param"].values()
            embed_out = next(iter(values))
        elif isinstance(embed, list):
            out_list = []
            for x in range(len(embed)):
                for k in embed[x]:
                    t = embed[x][k]
                    if t.shape[-1] != embedding_size:
                        continue
                    out_list.append(t.reshape(-1, t.shape[-1]))
            embed_out = torch.cat(out_list, dim=0)
        elif embed_key is not None and embed_key in embed:
            embed_out = embed[embed_key]
        else:
            values = embed.values()
            embed_out = next(iter(values))
    return embed_out

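# Usage note (illustrative): textual-inversion embeddings are referenced in prompts as
# "embedding:<name>" (see SDTokenizer.embedding_identifier below). load_embed() resolves the
# name against the configured directories, e.g. _internal/embeddings/, trying the bare file
# name and then the .safetensors/.pt/.bin extensions.
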
class SDTokenizer:
    """#### Class representing a Stable Diffusion tokenizer."""

    def __init__(
        self,
        tokenizer_path: str = None,
        max_length: int = 77,
        pad_with_end: bool = True,
        embedding_directory: str = None,
        embedding_size: int = 768,
        embedding_key: str = "clip_l",
        tokenizer_class: type = CLIPTokenizerFast,
        has_start_token: bool = True,
        pad_to_max_length: bool = True,
        min_length: int = None,
    ):
        """#### Initialize the SDTokenizer.

        #### Args:
            - `tokenizer_path` (str, optional): The path to the tokenizer. Defaults to None.
            - `max_length` (int, optional): The maximum length of the input. Defaults to 77.
            - `pad_with_end` (bool, optional): Whether to pad with the end token. Defaults to True.
            - `embedding_directory` (str, optional): The directory for embeddings. Defaults to None.
            - `embedding_size` (int, optional): The size of the embeddings. Defaults to 768.
            - `embedding_key` (str, optional): The key for the embeddings. Defaults to "clip_l".
            - `tokenizer_class` (type, optional): The tokenizer class. Defaults to CLIPTokenizerFast.
            - `has_start_token` (bool, optional): Whether the tokenizer has a start token. Defaults to True.
            - `pad_to_max_length` (bool, optional): Whether to pad to the maximum length. Defaults to True.
            - `min_length` (int, optional): The minimum length of the input. Defaults to None.
        """
        if tokenizer_path is None:
            tokenizer_path = "_internal/sd1_tokenizer/"
        self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
        self.max_length = max_length
        self.min_length = min_length

        empty = self.tokenizer("")["input_ids"]
        if has_start_token:
            self.tokens_start = 1
            self.start_token = empty[0]
            self.end_token = empty[1]
        else:
            self.tokens_start = 0
            self.start_token = None
            self.end_token = empty[0]
        self.pad_with_end = pad_with_end
        self.pad_to_max_length = pad_to_max_length

        vocab = self.tokenizer.get_vocab()
        self.inv_vocab = {v: k for k, v in vocab.items()}
        self.embedding_directory = embedding_directory
        self.max_word_length = 8
        self.embedding_identifier = "embedding:"
        self.embedding_size = embedding_size
        self.embedding_key = embedding_key

    def _try_get_embedding(self, embedding_name: str) -> tuple:
        """#### Try to get an embedding.

        #### Args:
            - `embedding_name` (str): The name of the embedding.

        #### Returns:
            - `tuple`: The embedding and any leftover text.
        """
        embed = load_embed(
            embedding_name,
            self.embedding_directory,
            self.embedding_size,
            self.embedding_key,
        )
        if embed is None:
            stripped = embedding_name.strip(",")
            if len(stripped) < len(embedding_name):
                embed = load_embed(
                    stripped,
                    self.embedding_directory,
                    self.embedding_size,
                    self.embedding_key,
                )
                return (embed, embedding_name[len(stripped) :])
        return (embed, "")

    def tokenize_with_weights(self, text: str, return_word_ids: bool = False) -> list:
        """#### Tokenize text with weights.

        #### Args:
            - `text` (str): The input text.
            - `return_word_ids` (bool, optional): Whether to return word IDs. Defaults to False.

        #### Returns:
            - `list`: The tokenized text with weights.
        """
        if self.pad_with_end:
            pad_token = self.end_token
        else:
            pad_token = 0

        text = escape_important(text)
        parsed_weights = token_weights(text, 1.0)

        # tokenize words
        tokens = []
        for weighted_segment, weight in parsed_weights:
            to_tokenize = (
                unescape_important(weighted_segment).replace("\n", " ").split(" ")
            )
            to_tokenize = [x for x in to_tokenize if x != ""]
            for word in to_tokenize:
                # if we find an embedding, deal with the embedding
                if (
                    word.startswith(self.embedding_identifier)
                    and self.embedding_directory is not None
                ):
                    embedding_name = word[len(self.embedding_identifier) :].strip("\n")
                    embed, leftover = self._try_get_embedding(embedding_name)
                    if embed is None:
                        logging.warning(
                            f"warning, embedding:{embedding_name} does not exist, ignoring"
                        )
                    else:
                        if len(embed.shape) == 1:
                            tokens.append([(embed, weight)])
                        else:
                            tokens.append(
                                [(embed[x], weight) for x in range(embed.shape[0])]
                            )
                        print("loading ", embedding_name)
                    # if we accidentally have leftover text, continue parsing using leftover, else move on to next word
                    if leftover != "":
                        word = leftover
                    else:
                        continue
                # parse word
                tokens.append(
                    [
                        (t, weight)
                        for t in self.tokenizer(word)["input_ids"][
                            self.tokens_start : -1
                        ]
                    ]
                )

        # reshape token array to CLIP input size
        batched_tokens = []
        batch = []
        if self.start_token is not None:
            batch.append((self.start_token, 1.0, 0))
        batched_tokens.append(batch)
        for i, t_group in enumerate(tokens):
            # determine if we're going to try and keep the tokens in a single batch
            is_large = len(t_group) >= self.max_word_length

            while len(t_group) > 0:
                if len(t_group) + len(batch) > self.max_length - 1:
                    remaining_length = self.max_length - len(batch) - 1
                    # break word in two and add end token
                    if is_large:
                        batch.extend(
                            [(t, w, i + 1) for t, w in t_group[:remaining_length]]
                        )
                        batch.append((self.end_token, 1.0, 0))
                        t_group = t_group[remaining_length:]
                    # add end token and pad
                    else:
                        batch.append((self.end_token, 1.0, 0))
                        if self.pad_to_max_length:
                            batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
                    # start new batch
                    batch = []
                    if self.start_token is not None:
                        batch.append((self.start_token, 1.0, 0))
                    batched_tokens.append(batch)
                else:
                    batch.extend([(t, w, i + 1) for t, w in t_group])
                    t_group = []

        # fill last batch
        batch.append((self.end_token, 1.0, 0))
        if self.pad_to_max_length:
            batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
        if self.min_length is not None and len(batch) < self.min_length:
            batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch)))

        if not return_word_ids:
            batched_tokens = [[(t, w) for t, w, _ in x] for x in batched_tokens]

        return batched_tokens

    def untokenize(self, token_weight_pair: list) -> list:
        """#### Untokenize a list of token-weight pairs.

        #### Args:
            - `token_weight_pair` (list): The list of token-weight pairs.

        #### Returns:
            - `list`: The untokenized list.
        """
        return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))


class SD1Tokenizer:
    """#### Class representing the SD1Tokenizer."""

    def __init__(self, embedding_directory: str = None, clip_name: str = "l", tokenizer: type = SDTokenizer):
        """#### Initialize the SD1Tokenizer.

        #### Args:
            - `embedding_directory` (str, optional): The directory for embeddings. Defaults to None.
            - `clip_name` (str, optional): The name of the CLIP model. Defaults to "l".
            - `tokenizer` (type, optional): The tokenizer class. Defaults to SDTokenizer.
        """
        self.clip_name = clip_name
        self.clip = "clip_{}".format(self.clip_name)
        setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory))

    def tokenize_with_weights(self, text: str, return_word_ids: bool = False) -> dict:
        """#### Tokenize text with weights.

        #### Args:
            - `text` (str): The input text.
            - `return_word_ids` (bool, optional): Whether to return word IDs. Defaults to False.

        #### Returns:
            - `dict`: The tokenized text with weights.
        """
        out = {}
        out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(
            text, return_word_ids
        )
        return out

    def untokenize(self, token_weight_pair: list) -> list:
        """#### Untokenize a list of token-weight pairs.

        #### Args:
            - `token_weight_pair` (list): The list of token-weight pairs.

        #### Returns:
            - `list`: The untokenized list.
        """
        return getattr(self, self.clip).untokenize(token_weight_pair)
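A rough end-to-end sketch of the tokenizer and encoder working together; it assumes the bundled assets under `_internal/sd1_tokenizer/` and `_internal/clip/` are present, and it is not necessarily how app.py wires things up:

from modules.SD15.SDClip import SD1ClipModel

tokenizer = SD1Tokenizer(embedding_directory="_internal/embeddings/")
clip_model = SD1ClipModel(device="cpu")

tokens = tokenizer.tokenize_with_weights("a photo of a (red:1.2) car")
cond, pooled = clip_model.encode_token_weights(tokens)
# cond is the per-token text embedding fed to the UNet's cross-attention;
# pooled is the projected EOS embedding (or None, depending on the transformer outputs).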
modules/StableFast/StableFast.py
ADDED
@@ -0,0 +1,274 @@
import contextlib
import functools
import logging
from dataclasses import dataclass

import torch

try:
    from sfast.compilers.diffusion_pipeline_compiler import CompilationConfig
    from sfast.compilers.diffusion_pipeline_compiler import (
        _enable_xformers,
        _modify_model,
    )
    from sfast.cuda.graphs import make_dynamic_graphed_callable
    from sfast.jit import utils as jit_utils
    from sfast.jit.trace_helper import trace_with_kwargs
except:
    pass


def hash_arg(arg):
    # micro optimization: bool obj is an instance of int
    if isinstance(arg, (str, int, float, bytes)):
        return arg
    if isinstance(arg, (tuple, list)):
        return tuple(map(hash_arg, arg))
    if isinstance(arg, dict):
        return tuple(
            sorted(
                ((hash_arg(k), hash_arg(v)) for k, v in arg.items()), key=lambda x: x[0]
            )
        )
    return type(arg)

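# Doctest-style illustration of hash_arg: nested containers are flattened into hashable
# tuples so they can participate in the trace cache key built in
# BaseModelApplyModelModuleFactory.gen_cache_key() below:
#   >>> hash_arg({"b": 2, "a": [1, True]})
#   (('a', (1, True)), ('b', 2))
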
class ModuleFactory:
    def get_converted_kwargs(self):
        return self.converted_kwargs


import torch as th
import torch.nn as nn
import copy


class BaseModelApplyModelModule(torch.nn.Module):
    def __init__(self, func, module):
        super().__init__()
        self.func = func
        self.module = module

    def forward(
        self,
        input_x,
        timestep,
        c_concat=None,
        c_crossattn=None,
        y=None,
        control=None,
        transformer_options={},
    ):
        kwargs = {"y": y}

        new_transformer_options = {}

        return self.func(
            input_x,
            timestep,
            c_concat=c_concat,
            c_crossattn=c_crossattn,
            control=control,
            transformer_options=new_transformer_options,
            **kwargs,
        )


class BaseModelApplyModelModuleFactory(ModuleFactory):
    kwargs_name = (
        "input_x",
        "timestep",
        "c_concat",
        "c_crossattn",
        "y",
        "control",
    )

    def __init__(self, callable, kwargs) -> None:
        self.callable = callable
        self.unet_config = callable.__self__.model_config.unet_config
        self.kwargs = kwargs
        self.patch_module = {}
        self.patch_module_parameter = {}
        self.converted_kwargs = self.gen_converted_kwargs()

    def gen_converted_kwargs(self):
        converted_kwargs = {}
        for arg_name, arg in self.kwargs.items():
            if arg_name in self.kwargs_name:
                converted_kwargs[arg_name] = arg

        transformer_options = self.kwargs.get("transformer_options", {})
        patches = transformer_options.get("patches", {})

        patch_module = {}
        patch_module_parameter = {}

        new_transformer_options = {}
        new_transformer_options["patches"] = patch_module_parameter

        self.patch_module = patch_module
        self.patch_module_parameter = patch_module_parameter
        return converted_kwargs

    def gen_cache_key(self):
        key_kwargs = {}
        for k, v in self.converted_kwargs.items():
            key_kwargs[k] = v

        patch_module_cache_key = {}
        return (
            self.callable.__class__.__qualname__,
            hash_arg(self.unet_config),
            hash_arg(key_kwargs),
            hash_arg(patch_module_cache_key),
        )

    @contextlib.contextmanager
    def converted_module_context(self):
        module = BaseModelApplyModelModule(self.callable, self.callable.__self__)
        yield (module, self.converted_kwargs)


logger = logging.getLogger()


@dataclass
class TracedModuleCacheItem:
    module: object
    patch_id: int
    device: str


class LazyTraceModule:
    traced_modules = {}

    def __init__(self, config=None, patch_id=None, **kwargs_) -> None:
        self.config = config
        self.patch_id = patch_id
        self.kwargs_ = kwargs_
        self.modify_model = functools.partial(
            _modify_model,
            enable_cnn_optimization=config.enable_cnn_optimization,
            prefer_lowp_gemm=config.prefer_lowp_gemm,
            enable_triton=config.enable_triton,
            enable_triton_reshape=config.enable_triton,
            memory_format=config.memory_format,
        )
        self.cuda_graph_modules = {}

    def ts_compiler(
        self,
        m,
    ):
        with torch.jit.optimized_execution(True):
            if self.config.enable_jit_freeze:
                # raw freeze causes Tensor reference leak
                # because the constant Tensors in the GraphFunction of
                # the compilation unit are never freed.
                m.eval()
                m = jit_utils.better_freeze(m)
            self.modify_model(m)

        if self.config.enable_cuda_graph:
            m = make_dynamic_graphed_callable(m)
        return m

    def __call__(self, model_function, /, **kwargs):
        module_factory = BaseModelApplyModelModuleFactory(model_function, kwargs)
        kwargs = module_factory.get_converted_kwargs()
        key = module_factory.gen_cache_key()

        traced_module = self.cuda_graph_modules.get(key)
        if traced_module is None:
            with module_factory.converted_module_context() as (m_model, m_kwargs):
                logger.info(
                    f'Tracing {getattr(m_model, "__name__", m_model.__class__.__name__)}'
                )
                traced_m, call_helper = trace_with_kwargs(
                    m_model, None, m_kwargs, **self.kwargs_
                )

            traced_m = self.ts_compiler(traced_m)
            traced_module = call_helper(traced_m)
            self.cuda_graph_modules[key] = traced_module

        return traced_module(**kwargs)


def build_lazy_trace_module(config, device, patch_id):
    config.enable_cuda_graph = config.enable_cuda_graph and device.type == "cuda"

    if config.enable_xformers:
        _enable_xformers(None)

    return LazyTraceModule(
        config=config,
        patch_id=patch_id,
        check_trace=True,
        strict=True,
    )


def gen_stable_fast_config():
    config = CompilationConfig.Default()
    try:
        import xformers

        config.enable_xformers = True
    except ImportError:
        print("xformers not installed, skip")

    # CUDA Graph is suggested for small batch sizes.
    # After capturing, the model only accepts one fixed image size.
    # If you want the model to be dynamic, don't enable it.
    config.enable_cuda_graph = False
    # config.enable_jit_freeze = False
    return config


class StableFastPatch:
    def __init__(self, model, config):
        self.model = model
        self.config = config
        self.stable_fast_model = None

    def __call__(self, model_function, params):
        input_x = params.get("input")
        timestep_ = params.get("timestep")
        c = params.get("c")

        if self.stable_fast_model is None:
            self.stable_fast_model = build_lazy_trace_module(
                self.config,
                input_x.device,
                id(self),
            )

        return self.stable_fast_model(
            model_function, input_x=input_x, timestep=timestep_, **c
        )

    def to(self, device):
        if type(device) == torch.device:
            if self.config.enable_cuda_graph or self.config.enable_jit_freeze:
                if device.type == "cpu":
                    del self.stable_fast_model
                    self.stable_fast_model = None
                    print(
                        "\33[93mWarning: Your graphics card doesn't have enough video memory to keep the model. If you experience a noticeable delay every time you start sampling, please consider disable enable_cuda_graph.\33[0m"
                    )
        return self


class ApplyStableFastUnet:
    def apply_stable_fast(self, model, enable_cuda_graph):
        config = gen_stable_fast_config()

        if config.memory_format is not None:
            model.model.to(memory_format=config.memory_format)

        patch = StableFastPatch(model, config)
        model_stable_fast = model.clone()
        model_stable_fast.set_model_unet_function_wrapper(patch)
        return (model_stable_fast,)
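A rough usage sketch of the stable-fast wrapper above; `model` is assumed to be a ModelPatcher-style object exposing `clone()` and `set_model_unet_function_wrapper()` (as used in `apply_stable_fast`), and the `sfast` package must be importable:

(model_fast,) = ApplyStableFastUnet().apply_stable_fast(model, enable_cuda_graph=False)
# Tracing and compilation happen lazily inside StableFastPatch on the first sampling call.
# Note: the enable_cuda_graph argument is accepted but gen_stable_fast_config() currently
# hard-codes config.enable_cuda_graph = False.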
modules/UltimateSDUpscale/RDRB.py
ADDED
@@ -0,0 +1,471 @@
from collections import OrderedDict
import functools
import math
import re
from typing import Union, Dict
import torch
import torch.nn as nn
from modules.UltimateSDUpscale import USDU_util


class RRDB(nn.Module):
    """#### Residual in Residual Dense Block (RRDB) class.

    #### Args:
        - `nf` (int): Number of filters.
        - `kernel_size` (int, optional): Kernel size. Defaults to 3.
        - `gc` (int, optional): Growth channel. Defaults to 32.
        - `stride` (int, optional): Stride. Defaults to 1.
        - `bias` (bool, optional): Whether to use bias. Defaults to True.
        - `pad_type` (str, optional): Padding type. Defaults to "zero".
        - `norm_type` (str, optional): Normalization type. Defaults to None.
        - `act_type` (str, optional): Activation type. Defaults to "leakyrelu".
        - `mode` (USDU_util.ConvMode, optional): Convolution mode. Defaults to "CNA".
        - `_convtype` (str, optional): Convolution type. Defaults to "Conv2D".
        - `_spectral_norm` (bool, optional): Whether to use spectral normalization. Defaults to False.
        - `plus` (bool, optional): Whether to use the plus variant. Defaults to False.
        - `c2x2` (bool, optional): Whether to use 2x2 convolution. Defaults to False.
    """

    def __init__(
        self,
        nf: int,
        kernel_size: int = 3,
        gc: int = 32,
        stride: int = 1,
        bias: bool = True,
        pad_type: str = "zero",
        norm_type: str = None,
        act_type: str = "leakyrelu",
        mode: USDU_util.ConvMode = "CNA",
        _convtype: str = "Conv2D",
        _spectral_norm: bool = False,
        plus: bool = False,
        c2x2: bool = False,
    ) -> None:
        super(RRDB, self).__init__()
        self.RDB1 = ResidualDenseBlock_5C(
            nf,
            kernel_size,
            gc,
            stride,
            bias,
            pad_type,
            norm_type,
            act_type,
            mode,
            plus=plus,
            c2x2=c2x2,
        )
        self.RDB2 = ResidualDenseBlock_5C(
            nf,
            kernel_size,
            gc,
            stride,
            bias,
            pad_type,
            norm_type,
            act_type,
            mode,
            plus=plus,
            c2x2=c2x2,
        )
        self.RDB3 = ResidualDenseBlock_5C(
            nf,
            kernel_size,
            gc,
            stride,
            bias,
            pad_type,
            norm_type,
            act_type,
            mode,
            plus=plus,
            c2x2=c2x2,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """#### Forward pass of the RRDB.

        #### Args:
            - `x` (torch.Tensor): Input tensor.

        #### Returns:
            - `torch.Tensor`: Output tensor.
        """
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out * 0.2 + x


class ResidualDenseBlock_5C(nn.Module):
    """#### Residual Dense Block with 5 Convolutions (ResidualDenseBlock_5C) class.

    #### Args:
        - `nf` (int, optional): Number of filters. Defaults to 64.
        - `kernel_size` (int, optional): Kernel size. Defaults to 3.
        - `gc` (int, optional): Growth channel. Defaults to 32.
        - `stride` (int, optional): Stride. Defaults to 1.
        - `bias` (bool, optional): Whether to use bias. Defaults to True.
        - `pad_type` (str, optional): Padding type. Defaults to "zero".
        - `norm_type` (str, optional): Normalization type. Defaults to None.
        - `act_type` (str, optional): Activation type. Defaults to "leakyrelu".
        - `mode` (USDU_util.ConvMode, optional): Convolution mode. Defaults to "CNA".
        - `plus` (bool, optional): Whether to use the plus variant. Defaults to False.
        - `c2x2` (bool, optional): Whether to use 2x2 convolution. Defaults to False.
    """

    def __init__(
        self,
        nf: int = 64,
        kernel_size: int = 3,
        gc: int = 32,
        stride: int = 1,
        bias: bool = True,
        pad_type: str = "zero",
        norm_type: str = None,
        act_type: str = "leakyrelu",
        mode: USDU_util.ConvMode = "CNA",
        plus: bool = False,
        c2x2: bool = False,
    ) -> None:
        super(ResidualDenseBlock_5C, self).__init__()

        self.conv1x1 = None

        self.conv1 = USDU_util.conv_block(
            nf,
            gc,
            kernel_size,
            stride,
            bias=bias,
            pad_type=pad_type,
            norm_type=norm_type,
            act_type=act_type,
            mode=mode,
            c2x2=c2x2,
        )
        self.conv2 = USDU_util.conv_block(
            nf + gc,
            gc,
            kernel_size,
            stride,
            bias=bias,
            pad_type=pad_type,
            norm_type=norm_type,
            act_type=act_type,
            mode=mode,
            c2x2=c2x2,
        )
        self.conv3 = USDU_util.conv_block(
            nf + 2 * gc,
            gc,
            kernel_size,
            stride,
            bias=bias,
            pad_type=pad_type,
            norm_type=norm_type,
            act_type=act_type,
            mode=mode,
            c2x2=c2x2,
        )
        self.conv4 = USDU_util.conv_block(
            nf + 3 * gc,
            gc,
            kernel_size,
            stride,
            bias=bias,
            pad_type=pad_type,
            norm_type=norm_type,
            act_type=act_type,
            mode=mode,
            c2x2=c2x2,
        )
        last_act = None
        self.conv5 = USDU_util.conv_block(
            nf + 4 * gc,
            nf,
            3,
            stride,
            bias=bias,
            pad_type=pad_type,
            norm_type=norm_type,
            act_type=last_act,
            mode=mode,
            c2x2=c2x2,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """#### Forward pass of the ResidualDenseBlock_5C.

        #### Args:
            - `x` (torch.Tensor): Input tensor.

        #### Returns:
            - `torch.Tensor`: Output tensor.
        """
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x

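# Channel bookkeeping for the dense block above (with the nf=64, gc=32 defaults):
#   conv1: 64  -> 32     conv2: 96  -> 32     conv3: 128 -> 32
#   conv4: 160 -> 32     conv5: 192 -> 64  (back to nf, no activation)
# Each conv sees the concatenation of the block input and all previous outputs, and the
# block output is residually scaled: x5 * 0.2 + x.
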
class RRDBNet(nn.Module):
    """#### Residual in Residual Dense Block Network (RRDBNet) class.

    #### Args:
        - `state_dict` (dict): State dictionary.
        - `norm` (str, optional): Normalization type. Defaults to None.
        - `act` (str, optional): Activation type. Defaults to "leakyrelu".
        - `upsampler` (str, optional): Upsampler type. Defaults to "upconv".
        - `mode` (USDU_util.ConvMode, optional): Convolution mode. Defaults to "CNA".
    """

    def __init__(
        self,
        state_dict: Dict[str, torch.Tensor],
        norm: str = None,
        act: str = "leakyrelu",
        upsampler: str = "upconv",
        mode: USDU_util.ConvMode = "CNA",
    ) -> None:
        super(RRDBNet, self).__init__()
        self.model_arch = "ESRGAN"
        self.sub_type = "SR"

        self.state = state_dict
        self.norm = norm
        self.act = act
        self.upsampler = upsampler
        self.mode = mode

        self.state_map = {
            # currently supports old, new, and newer RRDBNet arch _internal
            # ESRGAN, BSRGAN/RealSR, Real-ESRGAN
            "model.0.weight": ("conv_first.weight",),
            "model.0.bias": ("conv_first.bias",),
            "model.1.sub./NB/.weight": ("trunk_conv.weight", "conv_body.weight"),
            "model.1.sub./NB/.bias": ("trunk_conv.bias", "conv_body.bias"),
            r"model.1.sub.\1.RDB\2.conv\3.0.\4": (
                r"RRDB_trunk\.(\d+)\.RDB(\d)\.conv(\d+)\.(weight|bias)",
                r"body\.(\d+)\.rdb(\d)\.conv(\d+)\.(weight|bias)",
            ),
        }
        self.num_blocks = self.get_num_blocks()
        self.plus = any("conv1x1" in k for k in self.state.keys())

        self.state = self.new_to_old_arch(self.state)

        self.key_arr = list(self.state.keys())

        self.in_nc: int = self.state[self.key_arr[0]].shape[1]
        self.out_nc: int = self.state[self.key_arr[-1]].shape[0]

        self.scale: int = self.get_scale()
        self.num_filters: int = self.state[self.key_arr[0]].shape[0]

        c2x2 = False

        self.supports_fp16 = True
        self.supports_bfp16 = True
        self.min_size_restriction = None

        self.shuffle_factor = None

        upsample_block = {
            "upconv": USDU_util.upconv_block,
        }.get(self.upsampler)
        upsample_blocks = [
            upsample_block(
                in_nc=self.num_filters,
                out_nc=self.num_filters,
                act_type=self.act,
                c2x2=c2x2,
            )
            for _ in range(int(math.log(self.scale, 2)))
        ]

        self.model = USDU_util.sequential(
            # fea conv
            USDU_util.conv_block(
                in_nc=self.in_nc,
                out_nc=self.num_filters,
                kernel_size=3,
                norm_type=None,
                act_type=None,
                c2x2=c2x2,
            ),
            USDU_util.ShortcutBlock(
                USDU_util.sequential(
                    # rrdb blocks
                    *[
                        RRDB(
                            nf=self.num_filters,
                            kernel_size=3,
                            gc=32,
                            stride=1,
                            bias=True,
                            pad_type="zero",
                            norm_type=self.norm,
                            act_type=self.act,
                            mode="CNA",
                            plus=self.plus,
                            c2x2=c2x2,
                        )
                        for _ in range(self.num_blocks)
                    ],
                    # lr conv
                    USDU_util.conv_block(
                        in_nc=self.num_filters,
                        out_nc=self.num_filters,
                        kernel_size=3,
                        norm_type=self.norm,
                        act_type=None,
                        mode=self.mode,
                        c2x2=c2x2,
                    ),
                )
            ),
            *upsample_blocks,
            # hr_conv0
            USDU_util.conv_block(
                in_nc=self.num_filters,
                out_nc=self.num_filters,
                kernel_size=3,
                norm_type=None,
                act_type=self.act,
                c2x2=c2x2,
            ),
            # hr_conv1
            USDU_util.conv_block(
                in_nc=self.num_filters,
                out_nc=self.out_nc,
                kernel_size=3,
                norm_type=None,
                act_type=None,
                c2x2=c2x2,
            ),
        )

        self.load_state_dict(self.state, strict=False)

    def new_to_old_arch(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """#### Convert new architecture state dictionary to old architecture.

        #### Args:
            - `state` (dict): State dictionary.

        #### Returns:
            - `dict`: Converted state dictionary.
        """
        # add nb to state keys
        for kind in ("weight", "bias"):
            self.state_map[f"model.1.sub.{self.num_blocks}.{kind}"] = self.state_map[
                f"model.1.sub./NB/.{kind}"
            ]
            del self.state_map[f"model.1.sub./NB/.{kind}"]

        old_state = OrderedDict()
        for old_key, new_keys in self.state_map.items():
            for new_key in new_keys:
                if r"\1" in old_key:
                    for k, v in state.items():
                        sub = re.sub(new_key, old_key, k)
                        if sub != k:
                            old_state[sub] = v
                else:
                    if new_key in state:
                        old_state[old_key] = state[new_key]

        # upconv layers
        max_upconv = 0
        for key in state.keys():
            match = re.match(r"(upconv|conv_up)(\d)\.(weight|bias)", key)
            if match is not None:
                _, key_num, key_type = match.groups()
                old_state[f"model.{int(key_num) * 3}.{key_type}"] = state[key]
old_state[f"model.{int(key_num) * 3}.{key_type}"] = state[key]
|
390 |
+
max_upconv = max(max_upconv, int(key_num) * 3)
|
391 |
+
|
392 |
+
# final layers
|
393 |
+
for key in state.keys():
|
394 |
+
if key in ("HRconv.weight", "conv_hr.weight"):
|
395 |
+
old_state[f"model.{max_upconv + 2}.weight"] = state[key]
|
396 |
+
elif key in ("HRconv.bias", "conv_hr.bias"):
|
397 |
+
old_state[f"model.{max_upconv + 2}.bias"] = state[key]
|
398 |
+
elif key in ("conv_last.weight",):
|
399 |
+
old_state[f"model.{max_upconv + 4}.weight"] = state[key]
|
400 |
+
elif key in ("conv_last.bias",):
|
401 |
+
old_state[f"model.{max_upconv + 4}.bias"] = state[key]
|
402 |
+
|
403 |
+
# Sort by first numeric value of each layer
|
404 |
+
def compare(item1: str, item2: str) -> int:
|
405 |
+
parts1 = item1.split(".")
|
406 |
+
parts2 = item2.split(".")
|
407 |
+
int1 = int(parts1[1])
|
408 |
+
int2 = int(parts2[1])
|
409 |
+
return int1 - int2
|
410 |
+
|
411 |
+
sorted_keys = sorted(old_state.keys(), key=functools.cmp_to_key(compare))
|
412 |
+
|
413 |
+
# Rebuild the output dict in the right order
|
414 |
+
out_dict = OrderedDict((k, old_state[k]) for k in sorted_keys)
|
415 |
+
|
416 |
+
return out_dict
|
417 |
+
|
418 |
+
def get_scale(self, min_part: int = 6) -> int:
|
419 |
+
"""#### Get the scale factor.
|
420 |
+
|
421 |
+
#### Args:
|
422 |
+
- `min_part` (int, optional): Minimum part. Defaults to 6.
|
423 |
+
|
424 |
+
#### Returns:
|
425 |
+
- `int`: Scale factor.
|
426 |
+
"""
|
427 |
+
n = 0
|
428 |
+
for part in list(self.state):
|
429 |
+
parts = part.split(".")[1:]
|
430 |
+
if len(parts) == 2:
|
431 |
+
part_num = int(parts[0])
|
432 |
+
if part_num > min_part and parts[1] == "weight":
|
433 |
+
n += 1
|
434 |
+
return 2**n
|
435 |
+
|
436 |
+
def get_num_blocks(self) -> int:
|
437 |
+
"""#### Get the number of blocks.
|
438 |
+
|
439 |
+
#### Returns:
|
440 |
+
- `int`: Number of blocks.
|
441 |
+
"""
|
442 |
+
nbs = []
|
443 |
+
state_keys = self.state_map[r"model.1.sub.\1.RDB\2.conv\3.0.\4"] + (
|
444 |
+
r"model\.\d+\.sub\.(\d+)\.RDB(\d+)\.conv(\d+)\.0\.(weight|bias)",
|
445 |
+
)
|
446 |
+
for state_key in state_keys:
|
447 |
+
for k in self.state:
|
448 |
+
m = re.search(state_key, k)
|
449 |
+
if m:
|
450 |
+
nbs.append(int(m.group(1)))
|
451 |
+
if nbs:
|
452 |
+
break
|
453 |
+
return max(*nbs) + 1
|
454 |
+
|
455 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
456 |
+
"""#### Forward pass of the RRDBNet.
|
457 |
+
|
458 |
+
#### Args:
|
459 |
+
- `x` (torch.Tensor): Input tensor.
|
460 |
+
|
461 |
+
#### Returns:
|
462 |
+
- `torch.Tensor`: Output tensor.
|
463 |
+
"""
|
464 |
+
return self.model(x)
|
465 |
+
|
466 |
+
|
467 |
+
PyTorchSRModels = (RRDBNet,)
|
468 |
+
PyTorchSRModel = Union[RRDBNet,]
|
469 |
+
|
470 |
+
PyTorchModels = (*PyTorchSRModels,)
|
471 |
+
PyTorchModel = Union[PyTorchSRModel]
|
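The class above is normally driven through the loader in the next file, but a minimal standalone sketch may help; the checkpoint filename below is only an example and is not shipped with the repo.

# Hypothetical standalone use of RDRB.RRDBNet (the .pth filename is an example).
import torch
from modules.UltimateSDUpscale import RDRB

sd = torch.load("_internal/ESRGAN/RealESRGAN_x4plus.pth", map_location="cpu")
if "params_ema" in sd:  # some releases nest the weights under params_ema
    sd = sd["params_ema"]
model = RDRB.RRDBNet(sd).eval()
print(model.scale, model.num_blocks, model.in_nc, model.out_nc)

with torch.no_grad():
    lr = torch.rand(1, model.in_nc, 64, 64)  # NCHW input in [0, 1]
    sr = model(lr)                           # -> (1, out_nc, 64 * scale, 64 * scale)
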
modules/UltimateSDUpscale/USDU_upscaler.py
ADDED
@@ -0,0 +1,182 @@
import logging as logger
import torch
from PIL import Image

from modules.Device import Device
from modules.UltimateSDUpscale import RDRB
from modules.UltimateSDUpscale import image_util
from modules.Utilities import util


def load_state_dict(state_dict: dict) -> RDRB.PyTorchModel:
    """#### Load a state dictionary into a PyTorch model.

    #### Args:
        - `state_dict` (dict): The state dictionary.

    #### Returns:
        - `RDRB.PyTorchModel`: The loaded PyTorch model.
    """
    logger.debug("Loading state dict into pytorch model arch")
    state_dict_keys = list(state_dict.keys())
    if "params_ema" in state_dict_keys:
        state_dict = state_dict["params_ema"]
    model = RDRB.RRDBNet(state_dict)
    return model


class UpscaleModelLoader:
    """#### Class for loading upscale models."""

    def load_model(self, model_name: str) -> tuple:
        """#### Load an upscale model.

        #### Args:
            - `model_name` (str): The name of the model.

        #### Returns:
            - `tuple`: The loaded model.
        """
        model_path = f"_internal/ESRGAN/{model_name}"
        sd = util.load_torch_file(model_path, safe_load=True)
        if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
            sd = util.state_dict_prefix_replace(sd, {"module.": ""})
        out = load_state_dict(sd).eval()
        return (out,)


class ImageUpscaleWithModel:
    """#### Class for upscaling images with a model."""

    def upscale(self, upscale_model: torch.nn.Module, image: torch.Tensor) -> tuple:
        """#### Upscale an image using a model.

        #### Args:
            - `upscale_model` (torch.nn.Module): The upscale model.
            - `image` (torch.Tensor): The input image tensor.

        #### Returns:
            - `tuple`: The upscaled image tensor.
        """
        device = torch.device(torch.cuda.current_device())
        upscale_model.to(device)
        in_img = image.movedim(-1, -3).to(device)
        Device.get_free_memory(device)

        tile = 512
        overlap = 32

        oom = True
        while oom:
            steps = in_img.shape[0] * image_util.get_tiled_scale_steps(
                in_img.shape[3],
                in_img.shape[2],
                tile_x=tile,
                tile_y=tile,
                overlap=overlap,
            )
            pbar = util.ProgressBar(steps)
            s = image_util.tiled_scale(
                in_img,
                lambda a: upscale_model(a),
                tile_x=tile,
                tile_y=tile,
                overlap=overlap,
                upscale_amount=upscale_model.scale,
                pbar=pbar,
            )
            oom = False

        upscale_model.cpu()
        s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
        return (s,)


def torch_gc() -> None:
    """#### Perform garbage collection for PyTorch."""
    pass


class Script:
    """#### Class representing a script."""
    pass


class Options:
    """#### Class representing options."""

    img2img_background_color: str = "#ffffff"  # Set to white for now


class State:
    """#### Class representing the state."""

    interrupted: bool = False

    def begin(self) -> None:
        """#### Begin the state."""
        pass

    def end(self) -> None:
        """#### End the state."""
        pass


opts = Options()
state = State()

# Will only ever hold 1 upscaler
sd_upscalers = [None]
actual_upscaler = None

# Batch of images to upscale
batch = None


if not hasattr(Image, "Resampling"):  # For older versions of Pillow
    Image.Resampling = Image


class Upscaler:
    """#### Class for upscaling images."""

    def _upscale(self, img: Image.Image, scale: float) -> Image.Image:
        """#### Upscale an image.

        #### Args:
            - `img` (Image.Image): The input image.
            - `scale` (float): The scale factor.

        #### Returns:
            - `Image.Image`: The upscaled image.
        """
        global actual_upscaler
        tensor = image_util.pil_to_tensor(img)
        image_upscale_node = ImageUpscaleWithModel()
        (upscaled,) = image_upscale_node.upscale(actual_upscaler, tensor)
        return image_util.tensor_to_pil(upscaled)

    def upscale(self, img: Image.Image, scale: float, selected_model: str = None) -> Image.Image:
        """#### Upscale an image with a selected model.

        #### Args:
            - `img` (Image.Image): The input image.
            - `scale` (float): The scale factor.
            - `selected_model` (str, optional): The selected model. Defaults to None.

        #### Returns:
            - `Image.Image`: The upscaled image.
        """
        global batch
        batch = [self._upscale(img, scale) for img in batch]
        return batch[0]


class UpscalerData:
    """#### Class for storing upscaler data."""

    name: str = ""
    data_path: str = ""

    def __init__(self):
        self.scaler = Upscaler()

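Taken together, a hedged sketch of how this module can be driven on its own; the checkpoint filename is a placeholder that must exist under `_internal/ESRGAN/`, and `ImageUpscaleWithModel.upscale` assumes a CUDA device is available.

# Illustrative only; the model name below is an example.
import torch
from modules.UltimateSDUpscale import USDU_upscaler

(esrgan,) = USDU_upscaler.UpscaleModelLoader().load_model("RealESRGAN_x4plus.pth")
image = torch.rand(1, 512, 512, 3)  # NHWC tensor in [0, 1], as upscale() expects
(upscaled,) = USDU_upscaler.ImageUpscaleWithModel().upscale(esrgan, image)
print(upscaled.shape)  # (1, 512 * scale, 512 * scale, 3)
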
modules/UltimateSDUpscale/USDU_util.py
ADDED
@@ -0,0 +1,173 @@
from typing import Literal
import torch
import torch.nn as nn

ConvMode = Literal["CNA", "NAC", "CNAC"]


def act(act_type: str, inplace: bool = True, neg_slope: float = 0.2, n_prelu: int = 1) -> nn.Module:
    """#### Get the activation layer.

    #### Args:
        - `act_type` (str): The type of activation.
        - `inplace` (bool, optional): Whether to perform the operation in-place. Defaults to True.
        - `neg_slope` (float, optional): The negative slope for LeakyReLU. Defaults to 0.2.
        - `n_prelu` (int, optional): The number of PReLU parameters. Defaults to 1.

    #### Returns:
        - `nn.Module`: The activation layer.
    """
    act_type = act_type.lower()
    layer = nn.LeakyReLU(neg_slope, inplace)
    return layer


def get_valid_padding(kernel_size: int, dilation: int) -> int:
    """#### Get the valid padding for a convolutional layer.

    #### Args:
        - `kernel_size` (int): The size of the kernel.
        - `dilation` (int): The dilation rate.

    #### Returns:
        - `int`: The valid padding.
    """
    kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
    padding = (kernel_size - 1) // 2
    return padding


def sequential(*args: nn.Module) -> nn.Sequential:
    """#### Create a sequential container.

    #### Args:
        - `*args` (nn.Module): The modules to include in the sequential container.

    #### Returns:
        - `nn.Sequential`: The sequential container.
    """
    modules = []
    for module in args:
        if isinstance(module, nn.Sequential):
            for submodule in module.children():
                modules.append(submodule)
        elif isinstance(module, nn.Module):
            modules.append(module)
    return nn.Sequential(*modules)


def conv_block(
    in_nc: int,
    out_nc: int,
    kernel_size: int,
    stride: int = 1,
    dilation: int = 1,
    groups: int = 1,
    bias: bool = True,
    pad_type: str = "zero",
    norm_type: str | None = None,
    act_type: str | None = "relu",
    mode: ConvMode = "CNA",
    c2x2: bool = False,
) -> nn.Sequential:
    """#### Create a convolutional block.

    #### Args:
        - `in_nc` (int): The number of input channels.
        - `out_nc` (int): The number of output channels.
        - `kernel_size` (int): The size of the kernel.
        - `stride` (int, optional): The stride of the convolution. Defaults to 1.
        - `dilation` (int, optional): The dilation rate. Defaults to 1.
        - `groups` (int, optional): The number of groups. Defaults to 1.
        - `bias` (bool, optional): Whether to include a bias term. Defaults to True.
        - `pad_type` (str, optional): The type of padding. Defaults to "zero".
        - `norm_type` (str | None, optional): The type of normalization. Defaults to None.
        - `act_type` (str | None, optional): The type of activation. Defaults to "relu".
        - `mode` (ConvMode, optional): The mode of the convolution. Defaults to "CNA".
        - `c2x2` (bool, optional): Whether to use 2x2 convolutions. Defaults to False.

    #### Returns:
        - `nn.Sequential`: The convolutional block.
    """
    assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode)
    padding = get_valid_padding(kernel_size, dilation)
    padding = padding if pad_type == "zero" else 0

    c = nn.Conv2d(
        in_nc,
        out_nc,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=bias,
        groups=groups,
    )
    a = act(act_type) if act_type else None
    if mode in ("CNA", "CNAC"):
        return sequential(None, c, None, a)


def upconv_block(
    in_nc: int,
    out_nc: int,
    upscale_factor: int = 2,
    kernel_size: int = 3,
    stride: int = 1,
    bias: bool = True,
    pad_type: str = "zero",
    norm_type: str | None = None,
    act_type: str = "relu",
    mode: str = "nearest",
    c2x2: bool = False,
) -> nn.Sequential:
    """#### Create an upsampling convolutional block.

    #### Args:
        - `in_nc` (int): The number of input channels.
        - `out_nc` (int): The number of output channels.
        - `upscale_factor` (int, optional): The upscale factor. Defaults to 2.
        - `kernel_size` (int, optional): The size of the kernel. Defaults to 3.
        - `stride` (int, optional): The stride of the convolution. Defaults to 1.
        - `bias` (bool, optional): Whether to include a bias term. Defaults to True.
        - `pad_type` (str, optional): The type of padding. Defaults to "zero".
        - `norm_type` (str | None, optional): The type of normalization. Defaults to None.
        - `act_type` (str, optional): The type of activation. Defaults to "relu".
        - `mode` (str, optional): The mode of upsampling. Defaults to "nearest".
        - `c2x2` (bool, optional): Whether to use 2x2 convolutions. Defaults to False.

    #### Returns:
        - `nn.Sequential`: The upsampling convolutional block.
    """
    upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
    conv = conv_block(
        in_nc,
        out_nc,
        kernel_size,
        stride,
        bias=bias,
        pad_type=pad_type,
        norm_type=norm_type,
        act_type=act_type,
        c2x2=c2x2,
    )
    return sequential(upsample, conv)


class ShortcutBlock(nn.Module):
    """#### Elementwise sum the output of a submodule to its input."""

    def __init__(self, submodule: nn.Module):
        """#### Initialize the ShortcutBlock.

        #### Args:
            - `submodule` (nn.Module): The submodule to apply.
        """
        super(ShortcutBlock, self).__init__()
        self.sub = submodule

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """#### Forward pass.

        #### Args:
            - `x` (torch.Tensor): The input tensor.

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        output = x + self.sub(x)
        return output

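A small sketch of how these helpers compose into a network fragment, using only what is defined above; note that this trimmed `conv_block` only returns a module for the "CNA"/"CNAC" modes.

# Illustrative composition of conv_block, ShortcutBlock and upconv_block.
import torch
from modules.UltimateSDUpscale import USDU_util

net = USDU_util.sequential(
    USDU_util.conv_block(3, 64, kernel_size=3, act_type="leakyrelu"),      # conv + LeakyReLU
    USDU_util.ShortcutBlock(USDU_util.conv_block(64, 64, kernel_size=3)),  # x + conv(x)
    USDU_util.upconv_block(64, 3, upscale_factor=2),                       # nearest upsample + conv
)
x = torch.rand(1, 3, 32, 32)
print(net(x).shape)  # torch.Size([1, 3, 64, 64])
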
modules/UltimateSDUpscale/UltimateSDUpscale.py
ADDED
@@ -0,0 +1,1019 @@
from modules.AutoEncoders import VariationalAE
from modules.sample import sampling
from modules.UltimateSDUpscale import USDU_upscaler, image_util
import torch
from PIL import ImageFilter, ImageDraw, Image
from enum import Enum
import math

# taken from https://github.com/ssitu/ComfyUI_UltimateSDUpscale

state = USDU_upscaler.state


class UnsupportedModel(Exception):
    """#### Exception raised for unsupported models."""
    pass


class StableDiffusionProcessing:
    """#### Class representing the processing of Stable Diffusion images."""

    def __init__(
        self,
        init_img: Image.Image,
        model: torch.nn.Module,
        positive: str,
        negative: str,
        vae: VariationalAE.VAE,
        seed: int,
        steps: int,
        cfg: float,
        sampler_name: str,
        scheduler: str,
        denoise: float,
        upscale_by: float,
        uniform_tile_mode: bool,
    ):
        """
        #### Initialize the StableDiffusionProcessing class.

        #### Args:
            - `init_img` (Image.Image): The initial image.
            - `model` (torch.nn.Module): The model.
            - `positive` (str): The positive prompt.
            - `negative` (str): The negative prompt.
            - `vae` (VariationalAE.VAE): The variational autoencoder.
            - `seed` (int): The seed.
            - `steps` (int): The number of steps.
            - `cfg` (float): The CFG scale.
            - `sampler_name` (str): The sampler name.
            - `scheduler` (str): The scheduler.
            - `denoise` (float): The denoise strength.
            - `upscale_by` (float): The upscale factor.
            - `uniform_tile_mode` (bool): Whether to use uniform tile mode.
        """
        # Variables used by the USDU script
        self.init_images = [init_img]
        self.image_mask = None
        self.mask_blur = 0
        self.inpaint_full_res_padding = 0
        self.width = init_img.width
        self.height = init_img.height

        self.model = model
        self.positive = positive
        self.negative = negative
        self.vae = vae
        self.seed = seed
        self.steps = steps
        self.cfg = cfg
        self.sampler_name = sampler_name
        self.scheduler = scheduler
        self.denoise = denoise

        # Variables used only by this script
        self.init_size = init_img.width, init_img.height
        self.upscale_by = upscale_by
        self.uniform_tile_mode = uniform_tile_mode

        # Other required A1111 variables for the USDU script that is currently unused in this script
        self.extra_generation_params = {}


class Processed:
    """#### Class representing the processed images."""

    def __init__(
        self, p: StableDiffusionProcessing, images: list, seed: int, info: str
    ):
        """
        #### Initialize the Processed class.

        #### Args:
            - `p` (StableDiffusionProcessing): The processing object.
            - `images` (list): The list of images.
            - `seed` (int): The seed.
            - `info` (str): The information string.
        """
        self.images = images
        self.seed = seed
        self.info = info

    def infotext(self, p: StableDiffusionProcessing, index: int) -> str:
        """
        #### Get the information text.

        #### Args:
            - `p` (StableDiffusionProcessing): The processing object.
            - `index` (int): The index.

        #### Returns:
            - `str`: The information text.
        """
        return None


def fix_seed(p: StableDiffusionProcessing) -> None:
    """
    #### Fix the seed for reproducibility.

    #### Args:
        - `p` (StableDiffusionProcessing): The processing object.
    """
    pass


def process_images(p: StableDiffusionProcessing, pipeline: bool = False) -> Processed:
    """
    #### Process the images.

    #### Args:
        - `p` (StableDiffusionProcessing): The processing object.

    #### Returns:
        - `Processed`: The processed images.
    """
    # Where the main image generation happens in A1111

    # Setup
    image_mask = p.image_mask.convert("L")
    init_image = p.init_images[0]

    # Locate the white region of the mask outlining the tile and add padding
    crop_region = image_util.get_crop_region(image_mask, p.inpaint_full_res_padding)

    x1, y1, x2, y2 = crop_region
    crop_width = x2 - x1
    crop_height = y2 - y1
    crop_ratio = crop_width / crop_height
    p_ratio = p.width / p.height
    if crop_ratio > p_ratio:
        target_width = crop_width
        target_height = round(crop_width / p_ratio)
    else:
        target_width = round(crop_height * p_ratio)
        target_height = crop_height
    crop_region, _ = image_util.expand_crop(
        crop_region,
        image_mask.width,
        image_mask.height,
        target_width,
        target_height,
    )
    tile_size = p.width, p.height

    # Blur the mask
    if p.mask_blur > 0:
        image_mask = image_mask.filter(ImageFilter.GaussianBlur(p.mask_blur))

    # Crop the images to get the tiles that will be used for generation
    tiles = [img.crop(crop_region) for img in USDU_upscaler.batch]

    # Assume the same size for all images in the batch
    initial_tile_size = tiles[0].size

    # Resize if necessary
    for i, tile in enumerate(tiles):
        if tile.size != tile_size:
            tiles[i] = tile.resize(tile_size, Image.Resampling.LANCZOS)

    # Crop conditioning
    positive_cropped = image_util.crop_cond(
        p.positive, crop_region, p.init_size, init_image.size, tile_size
    )
    negative_cropped = image_util.crop_cond(
        p.negative, crop_region, p.init_size, init_image.size, tile_size
    )

    # Encode the image
    vae_encoder = VariationalAE.VAEEncode()
    batched_tiles = torch.cat([image_util.pil_to_tensor(tile) for tile in tiles], dim=0)
    (latent,) = vae_encoder.encode(p.vae, batched_tiles)

    # Generate samples
    (samples,) = sampling.common_ksampler(
        p.model,
        p.seed,
        p.steps,
        p.cfg,
        p.sampler_name,
        p.scheduler,
        positive_cropped,
        negative_cropped,
        latent,
        denoise=p.denoise,
        pipeline=pipeline
    )

    # Decode the sample
    vae_decoder = VariationalAE.VAEDecode()
    (decoded,) = vae_decoder.decode(p.vae, samples)

    # Convert the sample to a PIL image
    tiles_sampled = [image_util.tensor_to_pil(decoded, i) for i in range(len(decoded))]

    for i, tile_sampled in enumerate(tiles_sampled):
        init_image = USDU_upscaler.batch[i]

        # Resize back to the original size
        if tile_sampled.size != initial_tile_size:
            tile_sampled = tile_sampled.resize(
                initial_tile_size, Image.Resampling.LANCZOS
            )

        # Put the tile into position
        image_tile_only = Image.new("RGBA", init_image.size)
        image_tile_only.paste(tile_sampled, crop_region[:2])

        # Add the mask as an alpha channel
        # Must make a copy due to the possibility of an edge becoming black
        temp = image_tile_only.copy()
        image_mask = image_mask.resize(temp.size)
        temp.putalpha(image_mask)
        temp.putalpha(image_mask)
        image_tile_only.paste(temp, image_tile_only)

        # Add back the tile to the initial image according to the mask in the alpha channel
        result = init_image.convert("RGBA")
        result.alpha_composite(image_tile_only)

        # Convert back to RGB
        result = result.convert("RGB")
        USDU_upscaler.batch[i] = result

    processed = Processed(p, [USDU_upscaler.batch[0]], p.seed, None)
    return processed

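To make the aspect-ratio branch in `process_images` concrete, here is a worked example with illustrative numbers (plain arithmetic, not project code):

# Suppose the masked tile region is 300x200 and the processing size is 512x512.
crop_width, crop_height = 300, 200
p_ratio = 512 / 512                           # 1.0
crop_ratio = crop_width / crop_height         # 1.5 > p_ratio, so the height is grown
target_width = crop_width                     # 300
target_height = round(crop_width / p_ratio)   # 300
# expand_crop then enlarges the crop region to 300x300 so the tile matches p's aspect ratio.
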
class USDUMode(Enum):
    """#### Enum representing the modes for Ultimate SD Upscale."""
    LINEAR = 0
    CHESS = 1
    NONE = 2


class USDUSFMode(Enum):
    """#### Enum representing the seam fix modes for Ultimate SD Upscale."""
    NONE = 0
    BAND_PASS = 1
    HALF_TILE = 2
    HALF_TILE_PLUS_INTERSECTIONS = 3


class USDUpscaler:
    """#### Class representing the Ultimate SD Upscaler."""

    def __init__(
        self,
        p: StableDiffusionProcessing,
        image: Image.Image,
        upscaler_index: int,
        save_redraw: bool,
        save_seams_fix: bool,
        tile_width: int,
        tile_height: int,
    ) -> None:
        """
        #### Initialize the USDUpscaler class.

        #### Args:
            - `p` (StableDiffusionProcessing): The processing object.
            - `image` (Image.Image): The image.
            - `upscaler_index` (int): The upscaler index.
            - `save_redraw` (bool): Whether to save the redraw.
            - `save_seams_fix` (bool): Whether to save the seams fix.
            - `tile_width` (int): The tile width.
            - `tile_height` (int): The tile height.
        """
        self.p: StableDiffusionProcessing = p
        self.image: Image = image
        self.scale_factor = math.ceil(
            max(p.width, p.height) / max(image.width, image.height)
        )
        self.upscaler = USDU_upscaler.sd_upscalers[upscaler_index]
        self.redraw = USDURedraw()
        self.redraw.save = save_redraw
        self.redraw.tile_width = tile_width if tile_width > 0 else tile_height
        self.redraw.tile_height = tile_height if tile_height > 0 else tile_width
        self.seams_fix = USDUSeamsFix()
        self.seams_fix.save = save_seams_fix
        self.seams_fix.tile_width = tile_width if tile_width > 0 else tile_height
        self.seams_fix.tile_height = tile_height if tile_height > 0 else tile_width
        self.initial_info = None
        self.rows = math.ceil(self.p.height / self.redraw.tile_height)
        self.cols = math.ceil(self.p.width / self.redraw.tile_width)

    def get_factor(self, num: int) -> int:
        """
        #### Get the factor for a given number.

        #### Args:
            - `num` (int): The number.

        #### Returns:
            - `int`: The factor.
        """
        if num == 1:
            return 2
        if num % 4 == 0:
            return 4
        if num % 3 == 0:
            return 3
        if num % 2 == 0:
            return 2
        return 0

    def get_factors(self) -> None:
        """
        #### Get the list of scale factors.
        """
        scales = []
        current_scale = 1
        current_scale_factor = self.get_factor(self.scale_factor)
        while current_scale < self.scale_factor:
            current_scale_factor = self.get_factor(self.scale_factor // current_scale)
            scales.append(current_scale_factor)
            current_scale = current_scale * current_scale_factor
        self.scales = enumerate(scales)

+
def upscale(self) -> None:
|
340 |
+
"""
|
341 |
+
#### Upscale the image.
|
342 |
+
"""
|
343 |
+
# Log info
|
344 |
+
print(f"Canva size: {self.p.width}x{self.p.height}")
|
345 |
+
print(f"Image size: {self.image.width}x{self.image.height}")
|
346 |
+
print(f"Scale factor: {self.scale_factor}")
|
347 |
+
# Get list with scale factors
|
348 |
+
self.get_factors()
|
349 |
+
# Upscaling image over all factors
|
350 |
+
for index, value in self.scales:
|
351 |
+
print(f"Upscaling iteration {index + 1} with scale factor {value}")
|
352 |
+
self.image = self.upscaler.scaler.upscale(
|
353 |
+
self.image, value, self.upscaler.data_path
|
354 |
+
)
|
355 |
+
# Resize image to set values
|
356 |
+
self.image = self.image.resize(
|
357 |
+
(self.p.width, self.p.height), resample=Image.LANCZOS
|
358 |
+
)
|
359 |
+
|
360 |
+
def setup_redraw(self, redraw_mode: int, padding: int, mask_blur: int) -> None:
|
361 |
+
"""
|
362 |
+
#### Set up the redraw.
|
363 |
+
|
364 |
+
#### Args:
|
365 |
+
- `redraw_mode` (int): The redraw mode.
|
366 |
+
- `padding` (int): The padding.
|
367 |
+
- `mask_blur` (int): The mask blur.
|
368 |
+
"""
|
369 |
+
self.redraw.mode = USDUMode(redraw_mode)
|
370 |
+
self.redraw.enabled = self.redraw.mode != USDUMode.NONE
|
371 |
+
self.redraw.padding = padding
|
372 |
+
self.p.mask_blur = mask_blur
|
373 |
+
|
374 |
+
def setup_seams_fix(
|
375 |
+
self, padding: int, denoise: float, mask_blur: int, width: int, mode: int
|
376 |
+
) -> None:
|
377 |
+
"""
|
378 |
+
#### Set up the seams fix.
|
379 |
+
|
380 |
+
#### Args:
|
381 |
+
- `padding` (int): The padding.
|
382 |
+
- `denoise` (float): The denoise strength.
|
383 |
+
- `mask_blur` (int): The mask blur.
|
384 |
+
- `width` (int): The width.
|
385 |
+
- `mode` (int): The mode.
|
386 |
+
"""
|
387 |
+
self.seams_fix.padding = padding
|
388 |
+
self.seams_fix.denoise = denoise
|
389 |
+
self.seams_fix.mask_blur = mask_blur
|
390 |
+
self.seams_fix.width = width
|
391 |
+
self.seams_fix.mode = USDUSFMode(mode)
|
392 |
+
self.seams_fix.enabled = self.seams_fix.mode != USDUSFMode.NONE
|
393 |
+
|
394 |
+
def calc_jobs_count(self) -> None:
|
395 |
+
"""
|
396 |
+
#### Calculate the number of jobs.
|
397 |
+
"""
|
398 |
+
redraw_job_count = (self.rows * self.cols) if self.redraw.enabled else 0
|
399 |
+
seams_job_count = self.rows * (self.cols - 1) + (self.rows - 1) * self.cols
|
400 |
+
global state
|
401 |
+
state.job_count = redraw_job_count + seams_job_count
|
402 |
+
|
403 |
+
def print_info(self) -> None:
|
404 |
+
"""
|
405 |
+
#### Print the information.
|
406 |
+
"""
|
407 |
+
print(f"Tile size: {self.redraw.tile_width}x{self.redraw.tile_height}")
|
408 |
+
print(f"Tiles amount: {self.rows * self.cols}")
|
409 |
+
print(f"Grid: {self.rows}x{self.cols}")
|
410 |
+
print(f"Redraw enabled: {self.redraw.enabled}")
|
411 |
+
print(f"Seams fix mode: {self.seams_fix.mode.name}")
|
412 |
+
|
413 |
+
def add_extra_info(self) -> None:
|
414 |
+
"""
|
415 |
+
#### Add extra information.
|
416 |
+
"""
|
417 |
+
self.p.extra_generation_params["Ultimate SD upscale upscaler"] = (
|
418 |
+
self.upscaler.name
|
419 |
+
)
|
420 |
+
self.p.extra_generation_params["Ultimate SD upscale tile_width"] = (
|
421 |
+
self.redraw.tile_width
|
422 |
+
)
|
423 |
+
self.p.extra_generation_params["Ultimate SD upscale tile_height"] = (
|
424 |
+
self.redraw.tile_height
|
425 |
+
)
|
426 |
+
self.p.extra_generation_params["Ultimate SD upscale mask_blur"] = (
|
427 |
+
self.p.mask_blur
|
428 |
+
)
|
429 |
+
self.p.extra_generation_params["Ultimate SD upscale padding"] = (
|
430 |
+
self.redraw.padding
|
431 |
+
)
|
432 |
+
|
433 |
+
def process(self, pipeline) -> None:
|
434 |
+
"""
|
435 |
+
#### Process the image.
|
436 |
+
"""
|
437 |
+
USDU_upscaler.state.begin()
|
438 |
+
self.calc_jobs_count()
|
439 |
+
self.result_images = []
|
440 |
+
if self.redraw.enabled:
|
441 |
+
self.image = self.redraw.start(self.p, self.image, self.rows, self.cols, pipeline)
|
442 |
+
self.initial_info = self.redraw.initial_info
|
443 |
+
self.result_images.append(self.image)
|
444 |
+
|
445 |
+
if self.seams_fix.enabled:
|
446 |
+
self.image = self.seams_fix.start(self.p, self.image, self.rows, self.cols, pipeline)
|
447 |
+
self.initial_info = self.seams_fix.initial_info
|
448 |
+
self.result_images.append(self.image)
|
449 |
+
USDU_upscaler.state.end()
|
450 |
+
|
451 |
+
|
452 |
+
class USDURedraw:
|
453 |
+
"""#### Class representing the redraw functionality for Ultimate SD Upscale."""
|
454 |
+
|
455 |
+
def init_draw(self, p: StableDiffusionProcessing, width: int, height: int) -> tuple:
|
456 |
+
"""
|
457 |
+
#### Initialize the draw.
|
458 |
+
|
459 |
+
#### Args:
|
460 |
+
- `p` (StableDiffusionProcessing): The processing object.
|
461 |
+
- `width` (int): The width.
|
462 |
+
- `height` (int): The height.
|
463 |
+
|
464 |
+
#### Returns:
|
465 |
+
- `tuple`: The mask and draw objects.
|
466 |
+
"""
|
467 |
+
p.inpaint_full_res = True
|
468 |
+
p.inpaint_full_res_padding = self.padding
|
469 |
+
p.width = math.ceil((self.tile_width + self.padding) / 64) * 64
|
470 |
+
p.height = math.ceil((self.tile_height + self.padding) / 64) * 64
|
471 |
+
mask = Image.new("L", (width, height), "black")
|
472 |
+
draw = ImageDraw.Draw(mask)
|
473 |
+
return mask, draw
|
474 |
+
|
475 |
+
def calc_rectangle(self, xi: int, yi: int) -> tuple:
|
476 |
+
"""
|
477 |
+
#### Calculate the rectangle coordinates.
|
478 |
+
|
479 |
+
#### Args:
|
480 |
+
- `xi` (int): The x index.
|
481 |
+
- `yi` (int): The y index.
|
482 |
+
|
483 |
+
#### Returns:
|
484 |
+
- `tuple`: The rectangle coordinates.
|
485 |
+
"""
|
486 |
+
x1 = xi * self.tile_width
|
487 |
+
y1 = yi * self.tile_height
|
488 |
+
x2 = xi * self.tile_width + self.tile_width
|
489 |
+
y2 = yi * self.tile_height + self.tile_height
|
490 |
+
|
491 |
+
return x1, y1, x2, y2
|
492 |
+
|
493 |
+
def linear_process(
|
494 |
+
self, p: StableDiffusionProcessing, image: Image.Image, rows: int, cols: int, pipeline: bool = False
|
495 |
+
) -> Image.Image:
|
496 |
+
"""
|
497 |
+
#### Perform linear processing.
|
498 |
+
|
499 |
+
#### Args:
|
500 |
+
- `p` (StableDiffusionProcessing): The processing object.
|
501 |
+
- `image` (Image.Image): The image.
|
502 |
+
- `rows` (int): The number of rows.
|
503 |
+
- `cols` (int): The number of columns.
|
504 |
+
|
505 |
+
#### Returns:
|
506 |
+
- `Image.Image`: The processed image.
|
507 |
+
"""
|
508 |
+
global state
|
509 |
+
mask, draw = self.init_draw(p, image.width, image.height)
|
510 |
+
for yi in range(rows):
|
511 |
+
for xi in range(cols):
|
512 |
+
if state.interrupted:
|
513 |
+
break
|
514 |
+
draw.rectangle(self.calc_rectangle(xi, yi), fill="white")
|
515 |
+
p.init_images = [image]
|
516 |
+
p.image_mask = mask
|
517 |
+
processed = process_images(p, pipeline)
|
518 |
+
draw.rectangle(self.calc_rectangle(xi, yi), fill="black")
|
519 |
+
if len(processed.images) > 0:
|
520 |
+
image = processed.images[0]
|
521 |
+
|
522 |
+
p.width = image.width
|
523 |
+
p.height = image.height
|
524 |
+
self.initial_info = processed.infotext(p, 0)
|
525 |
+
|
526 |
+
return image
|
527 |
+
|
528 |
+
def start(self, p: StableDiffusionProcessing, image: Image.Image, rows: int, cols: int, pipeline: bool = False) -> Image.Image:
|
529 |
+
"""#### Start the redraw.
|
530 |
+
|
531 |
+
#### Args:
|
532 |
+
- `p` (StableDiffusionProcessing): The processing object.
|
533 |
+
- `image` (Image.Image): The image.
|
534 |
+
- `rows` (int): The number of rows.
|
535 |
+
- `cols` (int): The number of columns.
|
536 |
+
|
537 |
+
#### Returns:
|
538 |
+
- `Image.Image`: The processed image.
|
539 |
+
"""
|
540 |
+
self.initial_info = None
|
541 |
+
return self.linear_process(p, image, rows, cols, pipeline=pipeline)
|
542 |
+
|
543 |
+
|
544 |
+
class USDUSeamsFix:
|
545 |
+
"""#### Class representing the seams fix functionality for Ultimate SD Upscale."""
|
546 |
+
|
547 |
+
def init_draw(self, p: StableDiffusionProcessing) -> None:
|
548 |
+
"""#### Initialize the draw.
|
549 |
+
|
550 |
+
#### Args:
|
551 |
+
- `p` (StableDiffusionProcessing): The processing object.
|
552 |
+
"""
|
553 |
+
self.initial_info = None
|
554 |
+
p.width = math.ceil((self.tile_width + self.padding) / 64) * 64
|
555 |
+
p.height = math.ceil((self.tile_height + self.padding) / 64) * 64
|
556 |
+
|
557 |
+
def half_tile_process(
|
558 |
+
self, p: StableDiffusionProcessing, image: Image.Image, rows: int, cols: int, pipeline: bool = False
|
559 |
+
) -> Image.Image:
|
560 |
+
"""#### Perform half-tile processing.
|
561 |
+
|
562 |
+
#### Args:
|
563 |
+
- `p` (StableDiffusionProcessing): The processing object.
|
564 |
+
- `image` (Image.Image): The image.
|
565 |
+
- `rows` (int): The number of rows.
|
566 |
+
- `cols` (int): The number of columns.
|
567 |
+
|
568 |
+
#### Returns:
|
569 |
+
- `Image.Image`: The processed image.
|
570 |
+
"""
|
571 |
+
global state
|
572 |
+
self.init_draw(p)
|
573 |
+
processed = None
|
574 |
+
|
575 |
+
gradient = Image.linear_gradient("L")
|
576 |
+
row_gradient = Image.new("L", (self.tile_width, self.tile_height), "black")
|
577 |
+
row_gradient.paste(
|
578 |
+
gradient.resize(
|
579 |
+
(self.tile_width, self.tile_height // 2), resample=Image.BICUBIC
|
580 |
+
),
|
581 |
+
(0, 0),
|
582 |
+
)
|
583 |
+
row_gradient.paste(
|
584 |
+
gradient.rotate(180).resize(
|
585 |
+
(self.tile_width, self.tile_height // 2), resample=Image.BICUBIC
|
586 |
+
),
|
587 |
+
(0, self.tile_height // 2),
|
588 |
+
)
|
589 |
+
col_gradient = Image.new("L", (self.tile_width, self.tile_height), "black")
|
590 |
+
col_gradient.paste(
|
591 |
+
gradient.rotate(90).resize(
|
592 |
+
(self.tile_width // 2, self.tile_height), resample=Image.BICUBIC
|
593 |
+
),
|
594 |
+
(0, 0),
|
595 |
+
)
|
596 |
+
col_gradient.paste(
|
597 |
+
gradient.rotate(270).resize(
|
598 |
+
(self.tile_width // 2, self.tile_height), resample=Image.BICUBIC
|
599 |
+
),
|
600 |
+
(self.tile_width // 2, 0),
|
601 |
+
)
|
602 |
+
|
603 |
+
p.denoising_strength = self.denoise
|
604 |
+
p.mask_blur = self.mask_blur
|
605 |
+
|
606 |
+
for yi in range(rows - 1):
|
607 |
+
for xi in range(cols):
|
608 |
+
p.width = self.tile_width
|
609 |
+
p.height = self.tile_height
|
610 |
+
p.inpaint_full_res = True
|
611 |
+
p.inpaint_full_res_padding = self.padding
|
612 |
+
mask = Image.new("L", (image.width, image.height), "black")
|
613 |
+
mask.paste(
|
614 |
+
row_gradient,
|
615 |
+
(
|
616 |
+
xi * self.tile_width,
|
617 |
+
yi * self.tile_height + self.tile_height // 2,
|
618 |
+
),
|
619 |
+
)
|
620 |
+
|
621 |
+
p.init_images = [image]
|
622 |
+
p.image_mask = mask
|
623 |
+
processed = process_images(p, pipeline)
|
624 |
+
if len(processed.images) > 0:
|
625 |
+
image = processed.images[0]
|
626 |
+
|
627 |
+
for yi in range(rows):
|
628 |
+
for xi in range(cols - 1):
|
629 |
+
p.width = self.tile_width
|
630 |
+
p.height = self.tile_height
|
631 |
+
p.inpaint_full_res = True
|
632 |
+
p.inpaint_full_res_padding = self.padding
|
633 |
+
mask = Image.new("L", (image.width, image.height), "black")
|
634 |
+
mask.paste(
|
635 |
+
col_gradient,
|
636 |
+
(
|
637 |
+
xi * self.tile_width + self.tile_width // 2,
|
638 |
+
yi * self.tile_height,
|
639 |
+
),
|
640 |
+
)
|
641 |
+
|
642 |
+
p.init_images = [image]
|
643 |
+
p.image_mask = mask
|
644 |
+
processed = process_images(p, pipeline)
|
645 |
+
if len(processed.images) > 0:
|
646 |
+
image = processed.images[0]
|
647 |
+
|
648 |
+
p.width = image.width
|
649 |
+
p.height = image.height
|
650 |
+
if processed is not None:
|
651 |
+
self.initial_info = processed.infotext(p, 0)
|
652 |
+
|
653 |
+
return image
|
654 |
+
|
655 |
+
def start(
|
656 |
+
self, p: StableDiffusionProcessing, image: Image.Image, rows: int, cols: int, pipeline: bool = False
|
657 |
+
) -> Image.Image:
|
658 |
+
"""#### Start the seams fix process.
|
659 |
+
|
660 |
+
#### Args:
|
661 |
+
- `p` (StableDiffusionProcessing): The processing object.
|
662 |
+
- `image` (Image.Image): The image.
|
663 |
+
- `rows` (int): The number of rows.
|
664 |
+
- `cols` (int): The number of columns.
|
665 |
+
|
666 |
+
#### Returns:
|
667 |
+
- `Image.Image`: The processed image.
|
668 |
+
"""
|
669 |
+
return self.half_tile_process(p, image, rows, cols, pipeline=pipeline)
|
670 |
+
|
671 |
+
|
672 |
+
class Script(USDU_upscaler.Script):
|
673 |
+
"""#### Class representing the script for Ultimate SD Upscale."""
|
674 |
+
|
675 |
+
def run(
|
676 |
+
self,
|
677 |
+
p: StableDiffusionProcessing,
|
678 |
+
_: None,
|
679 |
+
tile_width: int,
|
680 |
+
tile_height: int,
|
681 |
+
mask_blur: int,
|
682 |
+
padding: int,
|
683 |
+
seams_fix_width: int,
|
684 |
+
seams_fix_denoise: float,
|
685 |
+
seams_fix_padding: int,
|
686 |
+
upscaler_index: int,
|
687 |
+
save_upscaled_image: bool,
|
688 |
+
redraw_mode: int,
|
689 |
+
save_seams_fix_image: bool,
|
690 |
+
seams_fix_mask_blur: int,
|
691 |
+
seams_fix_type: int,
|
692 |
+
target_size_type: int,
|
693 |
+
custom_width: int,
|
694 |
+
custom_height: int,
|
695 |
+
custom_scale: float,
|
696 |
+
pipeline: bool = False,
|
697 |
+
) -> Processed:
|
698 |
+
"""#### Run the script.
|
699 |
+
|
700 |
+
#### Args:
|
701 |
+
- `p` (StableDiffusionProcessing): The processing object.
|
702 |
+
- `_` (None): Unused parameter.
|
703 |
+
- `tile_width` (int): The tile width.
|
704 |
+
- `tile_height` (int): The tile height.
|
705 |
+
- `mask_blur` (int): The mask blur.
|
706 |
+
- `padding` (int): The padding.
|
707 |
+
- `seams_fix_width` (int): The seams fix width.
|
708 |
+
- `seams_fix_denoise` (float): The seams fix denoise strength.
|
709 |
+
- `seams_fix_padding` (int): The seams fix padding.
|
710 |
+
- `upscaler_index` (int): The upscaler index.
|
711 |
+
- `save_upscaled_image` (bool): Whether to save the upscaled image.
|
712 |
+
- `redraw_mode` (int): The redraw mode.
|
713 |
+
- `save_seams_fix_image` (bool): Whether to save the seams fix image.
|
714 |
+
- `seams_fix_mask_blur` (int): The seams fix mask blur.
|
715 |
+
- `seams_fix_type` (int): The seams fix type.
|
716 |
+
- `target_size_type` (int): The target size type.
|
717 |
+
- `custom_width` (int): The custom width.
|
718 |
+
- `custom_height` (int): The custom height.
|
719 |
+
- `custom_scale` (float): The custom scale.
|
720 |
+
|
721 |
+
#### Returns:
|
722 |
+
- `Processed`: The processed images.
|
723 |
+
"""
|
724 |
+
# Init
|
725 |
+
fix_seed(p)
|
726 |
+
USDU_upscaler.torch_gc()
|
727 |
+
|
728 |
+
p.do_not_save_grid = True
|
729 |
+
p.do_not_save_samples = True
|
730 |
+
p.inpaint_full_res = False
|
731 |
+
|
732 |
+
p.inpainting_fill = 1
|
733 |
+
p.n_iter = 1
|
734 |
+
p.batch_size = 1
|
735 |
+
|
736 |
+
seed = p.seed
|
737 |
+
|
738 |
+
# Init image
|
739 |
+
init_img = p.init_images[0]
|
740 |
+
init_img = image_util.flatten(
|
741 |
+
init_img, USDU_upscaler.opts.img2img_background_color
|
742 |
+
)
|
743 |
+
|
744 |
+
p.width = math.ceil((init_img.width * custom_scale) / 64) * 64
|
745 |
+
p.height = math.ceil((init_img.height * custom_scale) / 64) * 64
|
746 |
+
|
747 |
+
# Upscaling
|
748 |
+
upscaler = USDUpscaler(
|
749 |
+
p,
|
750 |
+
init_img,
|
751 |
+
upscaler_index,
|
752 |
+
save_upscaled_image,
|
753 |
+
save_seams_fix_image,
|
754 |
+
tile_width,
|
755 |
+
tile_height,
|
756 |
+
)
|
757 |
+
upscaler.upscale()
|
758 |
+
|
759 |
+
# Drawing
|
760 |
+
upscaler.setup_redraw(redraw_mode, padding, mask_blur)
|
761 |
+
upscaler.setup_seams_fix(
|
762 |
+
seams_fix_padding,
|
763 |
+
seams_fix_denoise,
|
764 |
+
seams_fix_mask_blur,
|
765 |
+
seams_fix_width,
|
766 |
+
seams_fix_type,
|
767 |
+
)
|
768 |
+
upscaler.print_info()
|
769 |
+
upscaler.add_extra_info()
|
770 |
+
upscaler.process(pipeline=pipeline)
|
771 |
+
result_images = upscaler.result_images
|
772 |
+
|
773 |
+
return Processed(
|
774 |
+
p,
|
775 |
+
result_images,
|
776 |
+
seed,
|
777 |
+
upscaler.initial_info if upscaler.initial_info is not None else "",
|
778 |
+
)
|
779 |
+
|
780 |
+
|
781 |
+
# Upscaler
|
782 |
+
old_init = USDUpscaler.__init__
|
783 |
+
|
784 |
+
|
785 |
+
def new_init(
|
786 |
+
self: USDUpscaler,
|
787 |
+
p: StableDiffusionProcessing,
|
788 |
+
image: Image.Image,
|
789 |
+
upscaler_index: int,
|
790 |
+
save_redraw: bool,
|
791 |
+
save_seams_fix: bool,
|
792 |
+
tile_width: int,
|
793 |
+
tile_height: int,
|
794 |
+
) -> None:
|
795 |
+
"""#### Initialize the USDUpscaler class with new settings.
|
796 |
+
|
797 |
+
#### Args:
|
798 |
+
- `self` (USDUpscaler): The USDUpscaler instance.
|
799 |
+
- `p` (StableDiffusionProcessing): The processing object.
|
800 |
+
- `image` (Image.Image): The image.
|
801 |
+
- `upscaler_index` (int): The upscaler index.
|
802 |
+
- `save_redraw` (bool): Whether to save the redraw.
|
803 |
+
- `save_seams_fix` (bool): Whether to save the seams fix.
|
804 |
+
- `tile_width` (int): The tile width.
|
805 |
+
- `tile_height` (int): The tile height.
|
806 |
+
"""
|
807 |
+
p.width = math.ceil((image.width * p.upscale_by) / 8) * 8
|
808 |
+
p.height = math.ceil((image.height * p.upscale_by) / 8) * 8
|
809 |
+
old_init(
|
810 |
+
self,
|
811 |
+
p,
|
812 |
+
image,
|
813 |
+
upscaler_index,
|
814 |
+
save_redraw,
|
815 |
+
save_seams_fix,
|
816 |
+
tile_width,
|
817 |
+
tile_height,
|
818 |
+
)
|
819 |
+
|
820 |
+
|
821 |
+
USDUpscaler.__init__ = new_init
|
822 |
+
|
823 |
+
# Redraw
|
824 |
+
old_setup_redraw = USDURedraw.init_draw
|
825 |
+
|
826 |
+
|
827 |
+
def new_setup_redraw(
|
828 |
+
self: USDURedraw, p: StableDiffusionProcessing, width: int, height: int
|
829 |
+
) -> tuple:
|
830 |
+
"""#### Set up the redraw with new settings.
|
831 |
+
|
832 |
+
#### Args:
|
833 |
+
- `self` (USDURedraw): The USDURedraw instance.
|
834 |
+
- `p` (StableDiffusionProcessing): The processing object.
|
835 |
+
- `width` (int): The width.
|
836 |
+
- `height` (int): The height.
|
837 |
+
|
838 |
+
#### Returns:
|
839 |
+
- `tuple`: The mask and draw objects.
|
840 |
+
"""
|
841 |
+
mask, draw = old_setup_redraw(self, p, width, height)
|
842 |
+
p.width = math.ceil((self.tile_width + self.padding) / 8) * 8
|
843 |
+
p.height = math.ceil((self.tile_height + self.padding) / 8) * 8
|
844 |
+
return mask, draw
|
845 |
+
|
846 |
+
|
847 |
+
USDURedraw.init_draw = new_setup_redraw
|
848 |
+
|
849 |
+
# Seams fix
|
850 |
+
old_setup_seams_fix = USDUSeamsFix.init_draw
|
851 |
+
|
852 |
+
|
853 |
+
def new_setup_seams_fix(self: USDUSeamsFix, p: StableDiffusionProcessing) -> None:
|
854 |
+
"""#### Set up the seams fix with new settings.
|
855 |
+
|
856 |
+
#### Args:
|
857 |
+
- `self` (USDUSeamsFix): The USDUSeamsFix instance.
|
858 |
+
- `p` (StableDiffusionProcessing): The processing object.
|
859 |
+
"""
|
860 |
+
old_setup_seams_fix(self, p)
|
861 |
+
p.width = math.ceil((self.tile_width + self.padding) / 8) * 8
|
862 |
+
p.height = math.ceil((self.tile_height + self.padding) / 8) * 8
|
863 |
+
|
864 |
+
|
865 |
+
USDUSeamsFix.init_draw = new_setup_seams_fix
|
866 |
+
|
867 |
+
# Make the script upscale on a batch of images instead of one image
|
868 |
+
old_upscale = USDUpscaler.upscale
|
869 |
+
|
870 |
+
|
871 |
+
def new_upscale(self: USDUpscaler) -> None:
|
872 |
+
"""#### Upscale a batch of images.
|
873 |
+
|
874 |
+
#### Args:
|
875 |
+
- `self` (USDUpscaler): The USDUpscaler instance.
|
876 |
+
"""
|
877 |
+
old_upscale(self)
|
878 |
+
USDU_upscaler.batch = [self.image] + [
|
879 |
+
img.resize((self.p.width, self.p.height), resample=Image.LANCZOS)
|
880 |
+
for img in USDU_upscaler.batch[1:]
|
881 |
+
]
|
882 |
+
|
883 |
+
|
884 |
+
USDUpscaler.upscale = new_upscale
|
885 |
+
MAX_RESOLUTION = 8192
|
886 |
+
# The modes available for Ultimate SD Upscale
|
887 |
+
MODES = {
|
888 |
+
"Linear": USDUMode.LINEAR,
|
889 |
+
"Chess": USDUMode.CHESS,
|
890 |
+
"None": USDUMode.NONE,
|
891 |
+
}
|
892 |
+
# The seam fix modes
|
893 |
+
SEAM_FIX_MODES = {
|
894 |
+
"None": USDUSFMode.NONE,
|
895 |
+
"Band Pass": USDUSFMode.BAND_PASS,
|
896 |
+
"Half Tile": USDUSFMode.HALF_TILE,
|
897 |
+
"Half Tile + Intersections": USDUSFMode.HALF_TILE_PLUS_INTERSECTIONS,
|
898 |
+
}
|
899 |
+
|
900 |
+
|
901 |
+
class UltimateSDUpscale:
|
902 |
+
"""#### Class representing the Ultimate SD Upscale functionality."""
|
903 |
+
|
904 |
+
def upscale(
|
905 |
+
self,
|
906 |
+
image: torch.Tensor,
|
907 |
+
model: torch.nn.Module,
|
908 |
+
positive: str,
|
909 |
+
negative: str,
|
910 |
+
vae: VariationalAE.VAE,
|
911 |
+
upscale_by: float,
|
912 |
+
seed: int,
|
913 |
+
steps: int,
|
914 |
+
cfg: float,
|
915 |
+
sampler_name: str,
|
916 |
+
scheduler: str,
|
917 |
+
denoise: float,
|
918 |
+
upscale_model: any,
|
919 |
+
mode_type: str,
|
920 |
+
tile_width: int,
|
921 |
+
tile_height: int,
|
922 |
+
mask_blur: int,
|
923 |
+
tile_padding: int,
|
924 |
+
seam_fix_mode: str,
|
925 |
+
seam_fix_denoise: float,
|
926 |
+
seam_fix_mask_blur: int,
|
927 |
+
seam_fix_width: int,
|
928 |
+
seam_fix_padding: int,
|
929 |
+
force_uniform_tiles: bool,
|
930 |
+
pipeline: bool = False,
|
931 |
+
) -> tuple:
|
932 |
+
"""#### Upscale the image.
|
933 |
+
|
934 |
+
#### Args:
|
935 |
+
- `image` (torch.Tensor): The image tensor.
|
936 |
+
- `model` (torch.nn.Module): The model.
|
937 |
+
- `positive` (str): The positive prompt.
|
938 |
+
- `negative` (str): The negative prompt.
|
939 |
+
- `vae` (VariationalAE.VAE): The variational autoencoder.
|
940 |
+
- `upscale_by` (float): The upscale factor.
|
941 |
+
- `seed` (int): The seed.
|
942 |
+
- `steps` (int): The number of steps.
|
943 |
+
- `cfg` (float): The CFG scale.
|
944 |
+
- `sampler_name` (str): The sampler name.
|
945 |
+
- `scheduler` (str): The scheduler.
|
946 |
+
- `denoise` (float): The denoise strength.
|
947 |
+
- `upscale_model` (any): The upscale model.
|
948 |
+
- `mode_type` (str): The mode type.
|
949 |
+
- `tile_width` (int): The tile width.
|
950 |
+
- `tile_height` (int): The tile height.
|
951 |
+
- `mask_blur` (int): The mask blur.
|
952 |
+
- `tile_padding` (int): The tile padding.
|
953 |
+
- `seam_fix_mode` (str): The seam fix mode.
|
954 |
+
- `seam_fix_denoise` (float): The seam fix denoise strength.
|
955 |
+
- `seam_fix_mask_blur` (int): The seam fix mask blur.
|
956 |
+
- `seam_fix_width` (int): The seam fix width.
|
957 |
+
- `seam_fix_padding` (int): The seam fix padding.
|
958 |
+
- `force_uniform_tiles` (bool): Whether to force uniform tiles.
|
959 |
+
|
960 |
+
#### Returns:
|
961 |
+
- `tuple`: The resulting tensor.
|
962 |
+
"""
|
963 |
+
# Set up A1111 patches
|
964 |
+
|
965 |
+
# Upscaler
|
966 |
+
# An object that the script works with
|
967 |
+
USDU_upscaler.sd_upscalers[0] = USDU_upscaler.UpscalerData()
|
968 |
+
# Where the actual upscaler is stored, will be used when the script upscales using the Upscaler in UpscalerData
|
969 |
+
USDU_upscaler.actual_upscaler = upscale_model
|
970 |
+
|
971 |
+
# Set the batch of images
|
972 |
+
USDU_upscaler.batch = [image_util.tensor_to_pil(image, i) for i in range(len(image))]
|
973 |
+
|
974 |
+
# Processing
|
975 |
+
sdprocessing = StableDiffusionProcessing(
|
976 |
+
image_util.tensor_to_pil(image),
|
977 |
+
model,
|
978 |
+
positive,
|
979 |
+
negative,
|
980 |
+
vae,
|
981 |
+
seed,
|
982 |
+
steps,
|
983 |
+
cfg,
|
984 |
+
sampler_name,
|
985 |
+
scheduler,
|
986 |
+
denoise,
|
987 |
+
upscale_by,
|
988 |
+
force_uniform_tiles,
|
989 |
+
)
|
990 |
+
|
991 |
+
# Running the script
|
992 |
+
script = Script()
|
993 |
+
script.run(
|
994 |
+
p=sdprocessing,
|
995 |
+
_=None,
|
996 |
+
tile_width=tile_width,
|
997 |
+
tile_height=tile_height,
|
998 |
+
mask_blur=mask_blur,
|
999 |
+
padding=tile_padding,
|
1000 |
+
seams_fix_width=seam_fix_width,
|
1001 |
+
seams_fix_denoise=seam_fix_denoise,
|
1002 |
+
seams_fix_padding=seam_fix_padding,
|
1003 |
+
upscaler_index=0,
|
1004 |
+
save_upscaled_image=False,
|
1005 |
+
redraw_mode=MODES[mode_type],
|
1006 |
+
save_seams_fix_image=False,
|
1007 |
+
seams_fix_mask_blur=seam_fix_mask_blur,
|
1008 |
+
seams_fix_type=SEAM_FIX_MODES[seam_fix_mode],
|
1009 |
+
target_size_type=2,
|
1010 |
+
custom_width=None,
|
1011 |
+
custom_height=None,
|
1012 |
+
custom_scale=upscale_by,
|
1013 |
+
pipeline=pipeline,
|
1014 |
+
)
|
1015 |
+
|
1016 |
+
# Return the resulting images
|
1017 |
+
images = [image_util.pil_to_tensor(img) for img in USDU_upscaler.batch]
|
1018 |
+
tensor = torch.cat(images, dim=0)
|
1019 |
+
return (tensor,)
|
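
A minimal usage sketch of the `UltimateSDUpscale` class above (illustration only; the image tensor, diffusion model, conditioning, VAE and upscale model are assumed to come from the loaders elsewhere in this repository, and the sampler/scheduler names and numeric settings are placeholders):

    usd = UltimateSDUpscale()
    (upscaled,) = usd.upscale(
        image=image,              # (B, H, W, C) tensor in [0, 1]
        model=model,
        positive=positive,
        negative=negative,
        vae=vae,
        upscale_by=2.0,
        seed=42,
        steps=20,
        cfg=7.0,
        sampler_name="euler",     # placeholder sampler name
        scheduler="normal",       # placeholder scheduler name
        denoise=0.3,
        upscale_model=upscale_model,
        mode_type="Linear",
        tile_width=512,
        tile_height=512,
        mask_blur=8,
        tile_padding=32,
        seam_fix_mode="None",
        seam_fix_denoise=0.2,
        seam_fix_mask_blur=8,
        seam_fix_width=64,
        seam_fix_padding=16,
        force_uniform_tiles=True,
    )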
modules/UltimateSDUpscale/image_util.py
ADDED
@@ -0,0 +1,265 @@
import math
import numpy as np
import torch
from PIL import Image


def get_tiled_scale_steps(width: int, height: int, tile_x: int, tile_y: int, overlap: int) -> int:
    """#### Calculate the number of steps required for tiled scaling.

    #### Args:
        - `width` (int): The width of the image.
        - `height` (int): The height of the image.
        - `tile_x` (int): The width of each tile.
        - `tile_y` (int): The height of each tile.
        - `overlap` (int): The overlap between tiles.

    #### Returns:
        - `int`: The number of steps required for tiled scaling.
    """
    return math.ceil((height / (tile_y - overlap))) * math.ceil(
        (width / (tile_x - overlap))
    )
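
# Worked example (illustrative, hypothetical numbers): for a 512x512 image
# scanned with 64x64 tiles and an overlap of 8, each axis needs
# math.ceil(512 / (64 - 8)) = 10 tile positions, so
# get_tiled_scale_steps(512, 512, 64, 64, 8) returns 10 * 10 = 100.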


@torch.inference_mode()
def tiled_scale(
    samples: torch.Tensor,
    function: callable,
    tile_x: int = 64,
    tile_y: int = 64,
    overlap: int = 8,
    upscale_amount: float = 4,
    out_channels: int = 3,
    pbar: any = None,
) -> torch.Tensor:
    """#### Perform tiled scaling on a batch of samples.

    #### Args:
        - `samples` (torch.Tensor): The input samples.
        - `function` (callable): The function to apply to each tile.
        - `tile_x` (int, optional): The width of each tile. Defaults to 64.
        - `tile_y` (int, optional): The height of each tile. Defaults to 64.
        - `overlap` (int, optional): The overlap between tiles. Defaults to 8.
        - `upscale_amount` (float, optional): The upscale amount. Defaults to 4.
        - `out_channels` (int, optional): The number of output channels. Defaults to 3.
        - `pbar` (any, optional): The progress bar. Defaults to None.

    #### Returns:
        - `torch.Tensor`: The scaled output tensor.
    """
    output = torch.empty(
        (
            samples.shape[0],
            out_channels,
            round(samples.shape[2] * upscale_amount),
            round(samples.shape[3] * upscale_amount),
        ),
        device="cpu",
    )
    for b in range(samples.shape[0]):
        s = samples[b : b + 1]
        out = torch.zeros(
            (
                s.shape[0],
                out_channels,
                round(s.shape[2] * upscale_amount),
                round(s.shape[3] * upscale_amount),
            ),
            device="cpu",
        )
        out_div = torch.zeros(
            (
                s.shape[0],
                out_channels,
                round(s.shape[2] * upscale_amount),
                round(s.shape[3] * upscale_amount),
            ),
            device="cpu",
        )
        for y in range(0, s.shape[2], tile_y - overlap):
            for x in range(0, s.shape[3], tile_x - overlap):
                s_in = s[:, :, y : y + tile_y, x : x + tile_x]

                ps = function(s_in).cpu()
                mask = torch.ones_like(ps)
                feather = round(overlap * upscale_amount)
                for t in range(feather):
                    mask[:, :, t : 1 + t, :] *= (1.0 / feather) * (t + 1)
                    mask[:, :, mask.shape[2] - 1 - t : mask.shape[2] - t, :] *= (
                        1.0 / feather
                    ) * (t + 1)
                    mask[:, :, :, t : 1 + t] *= (1.0 / feather) * (t + 1)
                    mask[:, :, :, mask.shape[3] - 1 - t : mask.shape[3] - t] *= (
                        1.0 / feather
                    ) * (t + 1)
                out[
                    :,
                    :,
                    round(y * upscale_amount) : round((y + tile_y) * upscale_amount),
                    round(x * upscale_amount) : round((x + tile_x) * upscale_amount),
                ] += ps * mask
                out_div[
                    :,
                    :,
                    round(y * upscale_amount) : round((y + tile_y) * upscale_amount),
                    round(x * upscale_amount) : round((x + tile_x) * upscale_amount),
                ] += mask

        output[b : b + 1] = out / out_div
    return output
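
# Usage sketch (illustrative, hypothetical values). Any per-tile callable
# works; here a plain nearest-neighbour interpolation stands in for a real
# upscale model:
#
#   latent = torch.randn(1, 4, 64, 64)
#   upscaled = tiled_scale(
#       latent,
#       function=lambda t: torch.nn.functional.interpolate(t, scale_factor=4),
#       tile_x=32, tile_y=32, overlap=8, upscale_amount=4, out_channels=4,
#   )
#   # upscaled.shape == torch.Size([1, 4, 256, 256])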


def flatten(img: Image.Image, bgcolor: str) -> Image.Image:
    """#### Replace transparency with a background color.

    #### Args:
        - `img` (Image.Image): The input image.
        - `bgcolor` (str): The background color.

    #### Returns:
        - `Image.Image`: The image with transparency replaced by the background color.
    """
    if img.mode in ("RGB",):
        return img
    return Image.alpha_composite(Image.new("RGBA", img.size, bgcolor), img).convert(
        "RGB"
    )


BLUR_KERNEL_SIZE = 15


def tensor_to_pil(img_tensor: torch.Tensor, batch_index: int = 0) -> Image.Image:
    """#### Convert a tensor to a PIL image.

    #### Args:
        - `img_tensor` (torch.Tensor): The input tensor.
        - `batch_index` (int, optional): The batch index. Defaults to 0.

    #### Returns:
        - `Image.Image`: The converted PIL image.
    """
    img_tensor = img_tensor[batch_index].unsqueeze(0)
    i = 255.0 * img_tensor.cpu().numpy()
    img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8).squeeze())
    return img


def pil_to_tensor(image: Image.Image) -> torch.Tensor:
    """#### Convert a PIL image to a tensor.

    #### Args:
        - `image` (Image.Image): The input PIL image.

    #### Returns:
        - `torch.Tensor`: The converted tensor.
    """
    image = np.array(image).astype(np.float32) / 255.0
    image = torch.from_numpy(image).unsqueeze(0)
    return image
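
# Round-trip sketch (illustrative, hypothetical values): pil_to_tensor and
# tensor_to_pil convert between PIL images and the (1, H, W, C) float tensors
# in [0, 1] used throughout this module:
#
#   img = Image.new("RGB", (64, 64), "red")
#   t = pil_to_tensor(img)      # torch.Size([1, 64, 64, 3]), float32 in [0, 1]
#   back = tensor_to_pil(t, 0)  # 64x64 PIL image in "RGB" mode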


def get_crop_region(mask: Image.Image, pad: int = 0) -> tuple:
    """#### Get the coordinates of the white rectangular mask region.

    #### Args:
        - `mask` (Image.Image): The input mask image in 'L' mode.
        - `pad` (int, optional): The padding to apply. Defaults to 0.

    #### Returns:
        - `tuple`: The coordinates of the crop region.
    """
    coordinates = mask.getbbox()
    if coordinates is not None:
        x1, y1, x2, y2 = coordinates
    else:
        x1, y1, x2, y2 = mask.width, mask.height, 0, 0
    # Apply padding
    x1 = max(x1 - pad, 0)
    y1 = max(y1 - pad, 0)
    x2 = min(x2 + pad, mask.width)
    y2 = min(y2 + pad, mask.height)
    return fix_crop_region((x1, y1, x2, y2), (mask.width, mask.height))


def fix_crop_region(region: tuple, image_size: tuple) -> tuple:
    """#### Remove the extra pixel added by the get_crop_region function.

    #### Args:
        - `region` (tuple): The crop region coordinates.
        - `image_size` (tuple): The size of the image.

    #### Returns:
        - `tuple`: The fixed crop region coordinates.
    """
    image_width, image_height = image_size
    x1, y1, x2, y2 = region
    if x2 < image_width:
        x2 -= 1
    if y2 < image_height:
        y2 -= 1
    return x1, y1, x2, y2


def expand_crop(region: tuple, width: int, height: int, target_width: int, target_height: int) -> tuple:
    """#### Expand a crop region to a specified target size.

    #### Args:
        - `region` (tuple): The crop region coordinates.
        - `width` (int): The width of the image.
        - `height` (int): The height of the image.
        - `target_width` (int): The desired width of the crop region.
        - `target_height` (int): The desired height of the crop region.

    #### Returns:
        - `tuple`: The expanded crop region coordinates and the target size.
    """
    x1, y1, x2, y2 = region
    actual_width = x2 - x1
    actual_height = y2 - y1

    # Try to expand the region to the right by half the difference
    width_diff = target_width - actual_width
    x2 = min(x2 + width_diff // 2, width)
    # Expand the region to the left by the remaining difference, including pixels that could not be added on the right
    width_diff = target_width - (x2 - x1)
    x1 = max(x1 - width_diff, 0)
    # Try the right again
    width_diff = target_width - (x2 - x1)
    x2 = min(x2 + width_diff, width)

    # Try to expand the region to the bottom by half the difference
    height_diff = target_height - actual_height
    y2 = min(y2 + height_diff // 2, height)
    # Expand the region to the top by the remaining difference, including pixels that could not be added at the bottom
    height_diff = target_height - (y2 - y1)
    y1 = max(y1 - height_diff, 0)
    # Try the bottom again
    height_diff = target_height - (y2 - y1)
    y2 = min(y2 + height_diff, height)

    return (x1, y1, x2, y2), (target_width, target_height)
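
# Worked example (illustrative, hypothetical values): expanding a 100x100
# region inside a 512x512 image to a 192x192 target grows the box roughly
# symmetrically while staying inside the image bounds:
#
#   region, size = expand_crop((200, 200, 300, 300), 512, 512, 192, 192)
#   # region == (154, 154, 346, 346), size == (192, 192)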


def crop_cond(cond: list, region: tuple, init_size: tuple, canvas_size: tuple, tile_size: tuple, w_pad: int = 0, h_pad: int = 0) -> list:
    """#### Crop conditioning data to match a specific region.

    #### Args:
        - `cond` (list): The conditioning data.
        - `region` (tuple): The crop region coordinates.
        - `init_size` (tuple): The initial size of the image.
        - `canvas_size` (tuple): The size of the canvas.
        - `tile_size` (tuple): The size of the tile.
        - `w_pad` (int, optional): The width padding. Defaults to 0.
        - `h_pad` (int, optional): The height padding. Defaults to 0.

    #### Returns:
        - `list`: The cropped conditioning data.
    """
    cropped = []
    for emb, x in cond:
        cond_dict = x.copy()
        n = [emb, cond_dict]
        cropped.append(n)
    return cropped