John6666 and r3gm committed (verified)
Commit 0c57d8c · 0 parents

Super-squash branch 'main' using huggingface_hub

Co-authored-by: r3gm <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,15 @@
+ ---
+ title: 🧩 DiffuseCraft Mod
+ emoji: 🧩🖼️📦
+ colorFrom: red
+ colorTo: pink
+ sdk: gradio
+ sdk_version: 4.37.2
+ app_file: app.py
+ pinned: false
+ header: mini
+ license: mit
+ short_description: Stunning images using Stable Diffusion.
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,1896 @@
+ import spaces
+ import os
+ from stablepy import Model_Diffusers
+ from stablepy.diffusers_vanilla.model import scheduler_names
+ from stablepy.diffusers_vanilla.style_prompt_config import STYLE_NAMES
+ import torch
+ import re
+ import shutil
+ import random
+ from stablepy import (
+     CONTROLNET_MODEL_IDS,
+     VALID_TASKS,
+     T2I_PREPROCESSOR_NAME,
+     FLASH_LORA,
+     SCHEDULER_CONFIG_MAP,
+     scheduler_names,
+     IP_ADAPTER_MODELS,
+     IP_ADAPTERS_SD,
+     IP_ADAPTERS_SDXL,
+     REPO_IMAGE_ENCODER,
+     ALL_PROMPT_WEIGHT_OPTIONS,
+     SD15_TASKS,
+     SDXL_TASKS,
+ )
+ import urllib.parse
+
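+ # stablepy wraps diffusers and supplies the task, scheduler, and IP-Adapter
+ # constants consumed by the GUI below.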
+ preprocessor_controlnet = {
+     "openpose": [
+         "Openpose",
+         "None",
+     ],
+     "scribble": [
+         "HED",
+         "Pidinet",
+         "None",
+     ],
+     "softedge": [
+         "Pidinet",
+         "HED",
+         "HED safe",
+         "Pidinet safe",
+         "None",
+     ],
+     "segmentation": [
+         "UPerNet",
+         "None",
+     ],
+     "depth": [
+         "DPT",
+         "Midas",
+         "None",
+     ],
+     "normalbae": [
+         "NormalBae",
+         "None",
+     ],
+     "lineart": [
+         "Lineart",
+         "Lineart coarse",
+         "Lineart (anime)",
+         "None",
+         "None (anime)",
+     ],
+     "shuffle": [
+         "ContentShuffle",
+         "None",
+     ],
+     "canny": [
+         "Canny"
+     ],
+     "mlsd": [
+         "MLSD"
+     ],
+     "ip2p": [
+         "ip2p"
+     ]
+ }
+
+ task_stablepy = {
+     'txt2img': 'txt2img',
+     'img2img': 'img2img',
+     'inpaint': 'inpaint',
+     # T2I Adapter tasks disabled: they expose no step-callback parameters,
+     # so they do not work with diffusers 0.29.0.
+     # 'canny T2I Adapter': 'sdxl_canny_t2i',
+     # 'sketch T2I Adapter': 'sdxl_sketch_t2i',
+     # 'lineart T2I Adapter': 'sdxl_lineart_t2i',
+     # 'depth-midas T2I Adapter': 'sdxl_depth-midas_t2i',
+     # 'openpose T2I Adapter': 'sdxl_openpose_t2i',
+     'openpose ControlNet': 'openpose',
+     'canny ControlNet': 'canny',
+     'mlsd ControlNet': 'mlsd',
+     'scribble ControlNet': 'scribble',
+     'softedge ControlNet': 'softedge',
+     'segmentation ControlNet': 'segmentation',
+     'depth ControlNet': 'depth',
+     'normalbae ControlNet': 'normalbae',
+     'lineart ControlNet': 'lineart',
+     # 'lineart_anime ControlNet': 'lineart_anime',
+     'shuffle ControlNet': 'shuffle',
+     'ip2p ControlNet': 'ip2p',
+     'optical pattern ControlNet': 'pattern',
+     'tile realistic': 'sdxl_tile_realistic',
+ }
+
+ task_model_list = list(task_stablepy.keys())
+
+
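+ # task_stablepy maps the GUI task labels to stablepy task names; its keys
+ # populate the "Task" dropdown via task_model_list.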
+ def download_things(directory, url, hf_token="", civitai_api_key=""):
+     url = url.strip()
+
+     if "drive.google.com" in url:
+         original_dir = os.getcwd()
+         os.chdir(directory)
+         os.system(f"gdown --fuzzy {url}")
+         os.chdir(original_dir)
+     elif "huggingface.co" in url:
+         url = url.replace("?download=true", "")
+         # url = urllib.parse.quote(url, safe=':/')  # fix encoding
+         if "/blob/" in url:
+             url = url.replace("/blob/", "/resolve/")
+         user_header = f'"Authorization: Bearer {hf_token}"'
+         if hf_token:
+             os.system(f"aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
+         else:
+             os.system(f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
+     elif "civitai.com" in url:
+         if "?" in url:
+             url = url.split("?")[0]
+         if civitai_api_key:
+             url = url + f"?token={civitai_api_key}"
+             os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
+         else:
+             print("\033[91mYou need an API key to download Civitai models.\033[0m")
+     else:
+         os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
+
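+ # NOTE: download_things shells out to external CLI tools that are assumed to be
+ # available in the environment: gdown for Google Drive links and aria2c for
+ # Hugging Face, Civitai, and direct URLs. Civitai downloads require an API key.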
+
+ def get_model_list(directory_path):
+     model_list = []
+     valid_extensions = {'.ckpt', '.pt', '.pth', '.safetensors', '.bin'}
+
+     for filename in os.listdir(directory_path):
+         if os.path.splitext(filename)[1] in valid_extensions:
+             name_without_extension = os.path.splitext(filename)[0]
+             file_path = os.path.join(directory_path, filename)
+             # model_list.append((name_without_extension, file_path))
+             model_list.append(file_path)
+             print('\033[34mFILE: ' + file_path + '\033[0m')
+     return model_list
+
+
+ def process_string(input_string):
+     parts = input_string.split('/')
+
+     if len(parts) == 2:
+         first_element = parts[1]
+         complete_string = input_string
+         result = (first_element, complete_string)
+         return result
+     else:
+         return None
+
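+ # e.g. process_string("author/model") -> ("model", "author/model");
+ # any other shape returns None.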
+
+ directory_models = 'models'
+ os.makedirs(directory_models, exist_ok=True)
+ directory_loras = 'loras'
+ os.makedirs(directory_loras, exist_ok=True)
+ directory_vaes = 'vaes'
+ os.makedirs(directory_vaes, exist_ok=True)
+
+
+ ## BEGIN MOD
+ from modutils import (
+     download_private_repo,
+     get_private_model_lists,
+     get_model_id_list,
+     list_uniq,
+     get_tupled_embed_list,
+     update_lora_dict,
+ )
+
+ # - **Download SD 1.5 Models**
+ # download_model = "https://huggingface.co/frankjoshua/toonyou_beta6/resolve/main/toonyou_beta6.safetensors"
+ download_model = ""
+ # - **Download VAEs**
+ download_vae_list = [
+     'https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl.vae.safetensors?download=true',
+     'https://huggingface.co/nubby/blessed-sdxl-vae-fp16-fix/resolve/main/sdxl_vae-fp16fix-c-1.1-b-0.5.safetensors?download=true',
+     'https://huggingface.co/nubby/blessed-sdxl-vae-fp16-fix/resolve/main/sdxl_vae-fp16fix-blessed.safetensors?download=true',
+ ]
+ download_vae = ", ".join(download_vae_list)
+ # - **Download LoRAs**
+ download_lora_list = []
+ download_lora = ", ".join(download_lora_list)
+
+ HF_LORA_PRIVATE_REPOS = ['John6666/loratest1', 'John6666/loratest3', 'John6666/loratest4', 'John6666/loratest']
+ HF_LORA_ESSENTIAL_PRIVATE_REPO = 'John6666/loratest1'
+ download_private_repo(HF_LORA_ESSENTIAL_PRIVATE_REPO, directory_loras, True)
+ download_private_repo('John6666/vaetest', directory_vaes, False)
+ load_diffusers_format_model = [
+     'stabilityai/stable-diffusion-xl-base-1.0',
+     'cagliostrolab/animagine-xl-3.1',
+     'misri/epicrealismXL_v7FinalDestination',
+     'misri/juggernautXL_juggernautX',
+     'misri/zavychromaxl_v80',
+     'SG161222/RealVisXL_V4.0',
+     'misri/newrealityxlAllInOne_Newreality40',
+     'eienmojiki/Anything-XL',
+     'eienmojiki/Starry-XL-v5.2',
+     'gsdf/CounterfeitXL',
+     'kitty7779/ponyDiffusionV6XL',
+     'John6666/ebara-mfcg-pony-mix-v12-sdxl',
+     'John6666/t-ponynai3-v51-sdxl',
+     'yodayo-ai/kivotos-xl-2.0',
+     'yodayo-ai/holodayo-xl-2.1',
+     'digiplay/majicMIX_sombre_v2',
+     'digiplay/majicMIX_realistic_v6',
+     'digiplay/majicMIX_realistic_v7',
+     'digiplay/DreamShaper_8',
+     'digiplay/BeautifulArt_v1',
+     'digiplay/DarkSushi2.5D_v1',
+     'digiplay/darkphoenix3D_v1.1',
+     'digiplay/BeenYouLiteL11_diffusers',
+     'rubbrband/revAnimated_v2Rebirth',
+     'youknownothing/cyberrealistic_v50',
+     'votepurchase/counterfeitV30_v30',
+     'Meina/MeinaMix_V11',
+     'Meina/MeinaUnreal_V5',
+     'Meina/MeinaPastel_V7',
+     'rubbrband/realcartoon3d_v16',
+     'rubbrband/realcartoonRealistic_v14',
+     'KBlueLeaf/Kohaku-XL-Epsilon-rev2',
+     'Raelina/Rae-Diffusion-XL-V2',
+ ]
+
+ load_diffusers_format_model = list_uniq(get_model_id_list() + load_diffusers_format_model)
+ ## END MOD
+
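+ # The final selectable model list merges the user's Hub model IDs
+ # (get_model_id_list) with the curated defaults above, de-duplicated in order.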
+ CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY")
+ hf_token = os.environ.get("HF_TOKEN")
+
+ # Download assets
+ for url in [url.strip() for url in download_model.split(',')]:
+     if not os.path.exists(f"./models/{url.split('/')[-1]}"):
+         download_things(directory_models, url, hf_token, CIVITAI_API_KEY)
+ for url in [url.strip() for url in download_vae.split(',')]:
+     if not os.path.exists(f"./vaes/{url.split('/')[-1]}"):
+         download_things(directory_vaes, url, hf_token, CIVITAI_API_KEY)
+ for url in [url.strip() for url in download_lora.split(',')]:
+     if not os.path.exists(f"./loras/{url.split('/')[-1]}"):
+         download_things(directory_loras, url, hf_token, CIVITAI_API_KEY)
+
+ # Download Embeddings
+ directory_embeds = 'embedings'
+ os.makedirs(directory_embeds, exist_ok=True)
+ download_embeds = [
+     'https://huggingface.co/datasets/Nerfgun3/bad_prompt/blob/main/bad_prompt_version2.pt',
+     'https://huggingface.co/embed/negative/resolve/main/EasyNegativeV2.safetensors',
+     'https://huggingface.co/embed/negative/resolve/main/bad-hands-5.pt',
+ ]
+
+ for url_embed in download_embeds:
+     if not os.path.exists(f"./embedings/{url_embed.split('/')[-1]}"):
+         download_things(directory_embeds, url_embed, hf_token, CIVITAI_API_KEY)
+
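+ # Each loop above skips files that already exist locally, so restarting the
+ # Space does not re-download models, VAEs, LoRAs, or embeddings.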
+ # Build model lists
+ embed_list = get_model_list(directory_embeds)
+ model_list = get_model_list(directory_models)
+ model_list = load_diffusers_format_model + model_list
+ lora_model_list = list_uniq(get_private_model_lists(HF_LORA_PRIVATE_REPOS, directory_loras) + get_model_list(directory_loras))
+ lora_model_list.insert(0, "None")
+ vae_model_list = get_model_list(directory_vaes)
+ vae_model_list.insert(0, "None")
+
+ ## BEGIN MOD
+ directory_embeds_sdxl = 'embedings_xl'
+ os.makedirs(directory_embeds_sdxl, exist_ok=True)
+ download_private_repo('John6666/embeddingstest', directory_embeds_sdxl, False)
+ directory_embeds_positive_sdxl = 'embedings_xl/positive'
+ os.makedirs(directory_embeds_positive_sdxl, exist_ok=True)
+ download_private_repo('John6666/embeddingspositivetest', directory_embeds_positive_sdxl, False)
+ embed_sdxl_list = get_model_list(directory_embeds_sdxl) + get_model_list(directory_embeds_positive_sdxl)
+
+
+ def get_embed_list(pipeline_name):
+     return get_tupled_embed_list(embed_sdxl_list if pipeline_name == "StableDiffusionXLPipeline" else embed_list)
+
+
+ def get_my_lora(link_url):
+     for url in [url.strip() for url in link_url.split(',')]:
+         if not os.path.exists(f"./loras/{url.split('/')[-1]}"):
+             download_things(directory_loras, url, hf_token, CIVITAI_API_KEY)
+             update_lora_dict(f"./loras/{url.split('/')[-1]}")
+     new_lora_model_list = list_uniq(get_private_model_lists(HF_LORA_PRIVATE_REPOS, directory_loras) + get_model_list(directory_loras))
+     new_lora_model_list.insert(0, "None")
+
+     return gr.update(
+         choices=get_lora_tupled_list(new_lora_model_list)
+     ), gr.update(
+         choices=get_lora_tupled_list(new_lora_model_list)
+     ), gr.update(
+         choices=get_lora_tupled_list(new_lora_model_list)
+     ), gr.update(
+         choices=get_lora_tupled_list(new_lora_model_list)
+     ), gr.update(
+         choices=get_lora_tupled_list(new_lora_model_list)
+     ),
+ ## END MOD
+
+ print('\033[33m🏁 Download and listing of valid models completed.\033[0m')
+
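+ # NOTE: get_my_lora above fetches comma-separated LoRA URLs and returns five
+ # gr.update objects, one per LoRA dropdown; the UI wires it to button_lora
+ # further down.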
+ upscaler_dict_gui = {
+     None: None,
+     "Lanczos": "Lanczos",
+     "Nearest": "Nearest",
+     "RealESRGAN_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
+     "RealESRNet_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth",
+     "RealESRGAN_x4plus_anime_6B": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
+     "RealESRGAN_x2plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
+     "realesr-animevideov3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
+     "realesr-general-x4v3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
+     "realesr-general-wdn-x4v3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
+     "4x-UltraSharp": "https://huggingface.co/Shandypur/ESRGAN-4x-UltraSharp/resolve/main/4x-UltraSharp.pth",
+     "4x_foolhardy_Remacri": "https://huggingface.co/FacehugmanIII/4x_foolhardy_Remacri/resolve/main/4x_foolhardy_Remacri.pth",
+     "Remacri4xExtraSmoother": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/Remacri%204x%20ExtraSmoother.pth",
+     "AnimeSharp4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/AnimeSharp%204x.pth",
+     "lollypop": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/lollypop.pth",
+     "RealisticRescaler4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/RealisticRescaler%204x.pth",
+     "NickelbackFS4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/NickelbackFS%204x.pth"
+ }
+
+
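+ # Values are either resampling method names passed through unchanged
+ # (Lanczos, Nearest) or checkpoint URLs fetched into ./upscalers on first use
+ # inside generate_pipeline.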
+ def extract_parameters(input_string):
+     parameters = {}
+     input_string = input_string.replace("\n", "")
+
+     if "Negative prompt:" not in input_string:
+         print("Negative prompt not detected")
+         parameters["prompt"] = input_string
+         return parameters
+
+     parm = input_string.split("Negative prompt:")
+     parameters["prompt"] = parm[0]
+     if "Steps:" not in parm[1]:
+         print("Steps not detected")
+         parameters["neg_prompt"] = parm[1]
+         return parameters
+     parm = parm[1].split("Steps:")
+     parameters["neg_prompt"] = parm[0]
+     input_string = "Steps:" + parm[1]
+
+     # Extracting Steps
+     steps_match = re.search(r'Steps: (\d+)', input_string)
+     if steps_match:
+         parameters['Steps'] = int(steps_match.group(1))
+
+     # Extracting Size
+     size_match = re.search(r'Size: (\d+x\d+)', input_string)
+     if size_match:
+         parameters['Size'] = size_match.group(1)
+         width, height = map(int, parameters['Size'].split('x'))
+         parameters['width'] = width
+         parameters['height'] = height
+
+     # Extracting other parameters
+     other_parameters = re.findall(r'(\w+): (.*?)(?=, \w+|$)', input_string)
+     for param in other_parameters:
+         parameters[param[0]] = param[1].strip('"')
+
+     return parameters
+
+
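+ # Illustrative input (A1111-style metadata):
+ #   "1girl Negative prompt: lowres Steps: 28, Sampler: Euler a, Size: 1024x1024"
+ # yields prompt/neg_prompt plus Steps, Sampler, Size, width and height. Note the
+ # generic key/value pass runs last, so typed entries such as Steps end up as
+ # strings again; run_set_params_gui casts them back below.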
+ #######################
+ # GUI
+ #######################
+ import gradio as gr
+ from PIL import Image
+ import IPython.display
+ import time, json
+ from IPython.utils import capture
+ import logging
+ logging.getLogger("diffusers").setLevel(logging.ERROR)
+ import diffusers
+ diffusers.utils.logging.set_verbosity(40)
+ import warnings
+ warnings.filterwarnings(action="ignore", category=FutureWarning, module="diffusers")
+ warnings.filterwarnings(action="ignore", category=UserWarning, module="diffusers")
+ warnings.filterwarnings(action="ignore", category=FutureWarning, module="transformers")
+ ## BEGIN MOD
+ from stablepy import logger
+ logger.setLevel(logging.CRITICAL)
+
+ from transformers.utils.hub import move_cache
+ move_cache()
+
+ from v2 import (
+     V2UI,
+     parse_upsampling_output,
+     V2_ALL_MODELS,
+ )
+ from utils import (
+     gradio_copy_text,
+     COPY_ACTION_JS,
+     V2_ASPECT_RATIO_OPTIONS,
+     V2_RATING_OPTIONS,
+     V2_LENGTH_OPTIONS,
+     V2_IDENTITY_OPTIONS,
+ )
+ from tagger import (
+     predict_tags_wd,
+     convert_danbooru_to_e621_prompt,
+     remove_specific_prompt,
+     insert_recom_prompt,
+     compose_prompt_to_copy,
+     translate_prompt,
+ )
+ from modutils import (
+     get_t2i_model_info,
+     get_tupled_model_list,
+     save_gallery_images,
+     upload_file_lora,
+     move_file_lora,
+     set_lora_trigger,
+     set_lora_prompt,
+     get_lora_tupled_list,
+     apply_lora_prompt,
+     set_textual_inversion_prompt,
+     get_model_pipeline,
+     set_optimization,
+     set_sampler_settings,
+     process_style_prompt,
+     optimization_list,
+     preset_styles,
+     preset_quality,
+     preset_sampler_setting,
+     set_quick_presets,
+ )
+
+
+ def description_ui():
+     gr.Markdown(
+         """
+         ## Danbooru Tags Transformer V2 Demo with WD Tagger
+         (Image =>) Prompt => Upsampled longer prompt
+         - Mod of p1atdev's [Danbooru Tags Transformer V2 Demo](https://huggingface.co/spaces/p1atdev/danbooru-tags-transformer-v2) and [WD Tagger with 🤗 transformers](https://huggingface.co/spaces/p1atdev/wd-tagger-transformers).
+         - Models: p1atdev's [wd-swinv2-tagger-v3-hf](https://huggingface.co/p1atdev/wd-swinv2-tagger-v3-hf), [dart-v2-moe-sft](https://huggingface.co/p1atdev/dart-v2-moe-sft)
+         """
+     )
+ ## END MOD
+
+
+ def info_html(json_data, title, subtitle):
+     return f"""
+     <div style='padding: 0; border-radius: 10px;'>
+         <p style='margin: 0; font-weight: bold;'>{title}</p>
+         <details>
+             <summary>Details</summary>
+             <p style='margin: 0; font-weight: bold;'>{subtitle}</p>
+         </details>
+     </div>
+     """
+
+
+ class GuiSD:
+     def __init__(self, stream=True):
+         self.model = None
+
+         print("Loading model...")
+         self.model = Model_Diffusers(
+             base_model_id="cagliostrolab/animagine-xl-3.1",
+             task_name="txt2img",
+             vae_model=None,
+             type_model_precision=torch.float16,
+             retain_task_model_in_cache=False,
+         )
+
+     def load_new_model(self, model_name, vae_model, task, progress=gr.Progress(track_tqdm=True)):
+
+         yield f"Loading model: {model_name}"
+
+         vae_model = vae_model if vae_model != "None" else None
+
+         if model_name in model_list:
+             model_is_xl = "xl" in model_name.lower()
+             sdxl_in_vae = vae_model and "sdxl" in vae_model.lower()
+             model_type = "SDXL" if model_is_xl else "SD 1.5"
+             incompatible_vae = (model_is_xl and vae_model and not sdxl_in_vae) or (not model_is_xl and sdxl_in_vae)
+
+             if incompatible_vae:
+                 vae_model = None
+
+         self.model.load_pipe(
+             model_name,
+             task_name=task_stablepy[task],
+             vae_model=vae_model if vae_model != "None" else None,
+             type_model_precision=torch.float16,
+             retain_task_model_in_cache=False,
+         )
+         yield f"Model loaded: {model_name}"
+
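+     # NOTE: load_new_model above is a generator: each yield is streamed to the
+     # UI as a status message, and an incompatible VAE is silently dropped in
+     # favor of the model's default.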
+     @spaces.GPU
+     def generate_pipeline(
+         self,
+         prompt,
+         neg_prompt,
+         num_images,
+         steps,
+         cfg,
+         clip_skip,
+         seed,
+         lora1,
+         lora_scale1,
+         lora2,
+         lora_scale2,
+         lora3,
+         lora_scale3,
+         lora4,
+         lora_scale4,
+         lora5,
+         lora_scale5,
+         sampler,
+         img_height,
+         img_width,
+         model_name,
+         vae_model,
+         task,
+         image_control,
+         preprocessor_name,
+         preprocess_resolution,
+         image_resolution,
+         style_prompt,  # list []
+         style_json_file,
+         image_mask,
+         strength,
+         low_threshold,
+         high_threshold,
+         value_threshold,
+         distance_threshold,
+         controlnet_output_scaling_in_unet,
+         controlnet_start_threshold,
+         controlnet_stop_threshold,
+         textual_inversion,
+         syntax_weights,
+         upscaler_model_path,
+         upscaler_increases_size,
+         esrgan_tile,
+         esrgan_tile_overlap,
+         hires_steps,
+         hires_denoising_strength,
+         hires_sampler,
+         hires_prompt,
+         hires_negative_prompt,
+         hires_before_adetailer,
+         hires_after_adetailer,
+         loop_generation,
+         leave_progress_bar,
+         disable_progress_bar,
+         image_previews,
+         display_images,
+         save_generated_images,
+         image_storage_location,
+         retain_compel_previous_load,
+         retain_detailfix_model_previous_load,
+         retain_hires_model_previous_load,
+         t2i_adapter_preprocessor,
+         t2i_adapter_conditioning_scale,
+         t2i_adapter_conditioning_factor,
+         xformers_memory_efficient_attention,
+         freeu,
+         generator_in_cpu,
+         adetailer_inpaint_only,
+         adetailer_verbose,
+         adetailer_sampler,
+         adetailer_active_a,
+         prompt_ad_a,
+         negative_prompt_ad_a,
+         strength_ad_a,
+         face_detector_ad_a,
+         person_detector_ad_a,
+         hand_detector_ad_a,
+         mask_dilation_a,
+         mask_blur_a,
+         mask_padding_a,
+         adetailer_active_b,
+         prompt_ad_b,
+         negative_prompt_ad_b,
+         strength_ad_b,
+         face_detector_ad_b,
+         person_detector_ad_b,
+         hand_detector_ad_b,
+         mask_dilation_b,
+         mask_blur_b,
+         mask_padding_b,
+         retain_task_cache_gui,
+         image_ip1,
+         mask_ip1,
+         model_ip1,
+         mode_ip1,
+         scale_ip1,
+         image_ip2,
+         mask_ip2,
+         model_ip2,
+         mode_ip2,
+         scale_ip2,
+     ):
+
+         vae_model = vae_model if vae_model != "None" else None
+         loras_list = [lora1, lora2, lora3, lora4, lora5]
+         vae_msg = f"VAE: {vae_model}" if vae_model else ""
+         msg_lora = []
+
+         if model_name in model_list:
+             model_is_xl = "xl" in model_name.lower()
+             sdxl_in_vae = vae_model and "sdxl" in vae_model.lower()
+             model_type = "SDXL" if model_is_xl else "SD 1.5"
+             incompatible_vae = (model_is_xl and vae_model and not sdxl_in_vae) or (not model_is_xl and sdxl_in_vae)
+
+             if incompatible_vae:
+                 msg_inc_vae = (
+                     f"The selected VAE is for a {'SD 1.5' if model_is_xl else 'SDXL'} model, but you"
+                     f" are using a {model_type} model. The default VAE "
+                     "will be used."
+                 )
+                 gr.Info(msg_inc_vae)
+                 vae_msg = msg_inc_vae
+                 vae_model = None
+
+             for la in loras_list:
+                 if la is not None and la != "None" and la in lora_model_list:
+                     print(la)
+                     lora_type = ("animetarot" in la.lower() or "hyper-sd15-8steps" in la.lower())
+                     if (model_is_xl and lora_type) or (not model_is_xl and not lora_type):
+                         msg_inc_lora = f"The LoRA {la} is for {'SD 1.5' if model_is_xl else 'SDXL'}, but you are using {model_type}."
+                         gr.Info(msg_inc_lora)
+                         msg_lora.append(msg_inc_lora)
+
+         task = task_stablepy[task]
+
+         params_ip_img = []
+         params_ip_msk = []
+         params_ip_model = []
+         params_ip_mode = []
+         params_ip_scale = []
+
+         all_adapters = [
+             (image_ip1, mask_ip1, model_ip1, mode_ip1, scale_ip1),
+             (image_ip2, mask_ip2, model_ip2, mode_ip2, scale_ip2),
+         ]
+
+         for imgip, mskip, modelip, modeip, scaleip in all_adapters:
+             if imgip:
+                 params_ip_img.append(imgip)
+                 if mskip:
+                     params_ip_msk.append(mskip)
+                 params_ip_model.append(modelip)
+                 params_ip_mode.append(modeip)
+                 params_ip_scale.append(scaleip)
+
+         # First load
+         model_precision = torch.float16
+         if not self.model:
+             from modelstream import Model_Diffusers2
+
+             print("Loading model...")
+             self.model = Model_Diffusers2(
+                 base_model_id=model_name,
+                 task_name=task,
+                 vae_model=vae_model if vae_model != "None" else None,
+                 type_model_precision=model_precision,
+                 retain_task_model_in_cache=retain_task_cache_gui,
+             )
+
+         if task != "txt2img" and not image_control:
+             raise ValueError("No control image found: To use this function, you have to upload an image in 'Image ControlNet/Inpaint/Img2img'")
+
+         if task == "inpaint" and not image_mask:
+             raise ValueError("No mask image found: Specify one in 'Image Mask'")
+
+         if upscaler_model_path in [None, "Lanczos", "Nearest"]:
+             upscaler_model = upscaler_model_path
+         else:
+             directory_upscalers = 'upscalers'
+             os.makedirs(directory_upscalers, exist_ok=True)
+
+             url_upscaler = upscaler_dict_gui[upscaler_model_path]
+
+             if not os.path.exists(f"./upscalers/{url_upscaler.split('/')[-1]}"):
+                 download_things(directory_upscalers, url_upscaler, hf_token)
+
+             upscaler_model = f"./upscalers/{url_upscaler.split('/')[-1]}"
+
+         logging.getLogger("ultralytics").setLevel(logging.INFO if adetailer_verbose else logging.ERROR)
+
+         print("Config model:", model_name, vae_model, loras_list)
+
+         self.model.load_pipe(
+             model_name,
+             task_name=task,
+             vae_model=vae_model if vae_model != "None" else None,
+             type_model_precision=model_precision,
+             retain_task_model_in_cache=retain_task_cache_gui,
+         )
+
+         ## BEGIN MOD
+         # if textual_inversion and self.model.class_name == "StableDiffusionXLPipeline":
+         #     print("No Textual inversion for SDXL")
+         ## END MOD
+
+         adetailer_params_A = {
+             "face_detector_ad": face_detector_ad_a,
+             "person_detector_ad": person_detector_ad_a,
+             "hand_detector_ad": hand_detector_ad_a,
+             "prompt": prompt_ad_a,
+             "negative_prompt": negative_prompt_ad_a,
+             "strength": strength_ad_a,
+             # "image_list_task": None,
+             "mask_dilation": mask_dilation_a,
+             "mask_blur": mask_blur_a,
+             "mask_padding": mask_padding_a,
+             "inpaint_only": adetailer_inpaint_only,
+             "sampler": adetailer_sampler,
+         }
+
+         adetailer_params_B = {
+             "face_detector_ad": face_detector_ad_b,
+             "person_detector_ad": person_detector_ad_b,
+             "hand_detector_ad": hand_detector_ad_b,
+             "prompt": prompt_ad_b,
+             "negative_prompt": negative_prompt_ad_b,
+             "strength": strength_ad_b,
+             # "image_list_task": None,
+             "mask_dilation": mask_dilation_b,
+             "mask_blur": mask_blur_b,
+             "mask_padding": mask_padding_b,
+         }
+         pipe_params = {
+             "prompt": prompt,
+             "negative_prompt": neg_prompt,
+             "img_height": img_height,
+             "img_width": img_width,
+             "num_images": num_images,
+             "num_steps": steps,
+             "guidance_scale": cfg,
+             "clip_skip": clip_skip,
+             "seed": seed,
+             "image": image_control,
+             "preprocessor_name": preprocessor_name,
+             "preprocess_resolution": preprocess_resolution,
+             "image_resolution": image_resolution,
+             "style_prompt": style_prompt if style_prompt else "",
+             "style_json_file": "",
+             "image_mask": image_mask,  # only for Inpaint
+             "strength": strength,  # only for Inpaint or ...
+             "low_threshold": low_threshold,
+             "high_threshold": high_threshold,
+             "value_threshold": value_threshold,
+             "distance_threshold": distance_threshold,
+             "lora_A": lora1 if lora1 != "None" else None,
+             "lora_scale_A": lora_scale1,
+             "lora_B": lora2 if lora2 != "None" else None,
+             "lora_scale_B": lora_scale2,
+             "lora_C": lora3 if lora3 != "None" else None,
+             "lora_scale_C": lora_scale3,
+             "lora_D": lora4 if lora4 != "None" else None,
+             "lora_scale_D": lora_scale4,
+             "lora_E": lora5 if lora5 != "None" else None,
+             "lora_scale_E": lora_scale5,
+             ## BEGIN MOD
+             "textual_inversion": get_embed_list(self.model.class_name) if textual_inversion else [],
+             ## END MOD
+             "syntax_weights": syntax_weights,  # "Classic"
+             "sampler": sampler,
+             "xformers_memory_efficient_attention": xformers_memory_efficient_attention,
+             "gui_active": True,
+             "loop_generation": loop_generation,
+             "controlnet_conditioning_scale": float(controlnet_output_scaling_in_unet),
+             "control_guidance_start": float(controlnet_start_threshold),
+             "control_guidance_end": float(controlnet_stop_threshold),
+             "generator_in_cpu": generator_in_cpu,
+             "FreeU": freeu,
+             "adetailer_A": adetailer_active_a,
+             "adetailer_A_params": adetailer_params_A,
+             "adetailer_B": adetailer_active_b,
+             "adetailer_B_params": adetailer_params_B,
+             "leave_progress_bar": leave_progress_bar,
+             "disable_progress_bar": disable_progress_bar,
+             "image_previews": image_previews,
+             "display_images": display_images,
+             "save_generated_images": save_generated_images,
+             "image_storage_location": image_storage_location,
+             "retain_compel_previous_load": retain_compel_previous_load,
+             "retain_detailfix_model_previous_load": retain_detailfix_model_previous_load,
+             "retain_hires_model_previous_load": retain_hires_model_previous_load,
+             "t2i_adapter_preprocessor": t2i_adapter_preprocessor,
+             "t2i_adapter_conditioning_scale": float(t2i_adapter_conditioning_scale),
+             "t2i_adapter_conditioning_factor": float(t2i_adapter_conditioning_factor),
+             "upscaler_model_path": upscaler_model,
+             "upscaler_increases_size": upscaler_increases_size,
+             "esrgan_tile": esrgan_tile,
+             "esrgan_tile_overlap": esrgan_tile_overlap,
+             "hires_steps": hires_steps,
+             "hires_denoising_strength": hires_denoising_strength,
+             "hires_prompt": hires_prompt,
+             "hires_negative_prompt": hires_negative_prompt,
+             "hires_sampler": hires_sampler,
+             "hires_before_adetailer": hires_before_adetailer,
+             "hires_after_adetailer": hires_after_adetailer,
+             "ip_adapter_image": params_ip_img,
+             "ip_adapter_mask": params_ip_msk,
+             "ip_adapter_model": params_ip_model,
+             "ip_adapter_mode": params_ip_mode,
+             "ip_adapter_scale": params_ip_scale,
+         }
+
+         # print(pipe_params)
+
+         random_number = random.randint(1, 100)
+         if random_number < 25 and num_images < 3:
+             if not upscaler_model and steps < 45 and task in ["txt2img", "img2img"] and not adetailer_active_a and not adetailer_active_b:
+                 num_images *= 2
+                 pipe_params["num_images"] = num_images
+                 gr.Info("Num images x 2 🎉")
+
+         # Maybe fixes the LoRA issue: "Cannot copy out of meta tensor; no data!"
+         self.model.pipe.to("cuda:0" if torch.cuda.is_available() else "cpu")
+
+         info_state = "PROCESSING "
+         for img, seed, data in self.model(**pipe_params):
+             info_state += ">"
+             if data:
+                 info_state = f"COMPLETED. Seeds: {str(seed)}"
+                 if vae_msg:
+                     info_state = info_state + "<br>" + vae_msg
+                 if msg_lora:
+                     info_state = info_state + "<br>" + "<br>".join(msg_lora)
+             yield img, info_state
+
+
+ sd_gen = GuiSD()
+
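+ # sd_gen is a single GuiSD instance shared by all requests. generate_pipeline
+ # streams previews through its yields, and with roughly 25% probability doubles
+ # num_images for light txt2img/img2img jobs as a small bonus.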
+ ## BEGIN MOD
+ CSS = """
+ .contain { display: flex; flex-direction: column; }
+ #component-0 { height: 100%; }
+ #gallery { flex-grow: 1; }
+ """
+ ## END MOD
+
+ sdxl_task = [k for k, v in task_stablepy.items() if v in SDXL_TASKS]
+ sd_task = [k for k, v in task_stablepy.items() if v in SD15_TASKS]
+
+
+ def update_task_options(model_name, task_name):
+     if model_name in model_list:
+         if "xl" in model_name.lower():
+             new_choices = sdxl_task
+         else:
+             new_choices = sd_task
+
+         if task_name not in new_choices:
+             task_name = "txt2img"
+
+         return gr.update(value=task_name, choices=new_choices)
+     else:
+         return gr.update(value=task_name, choices=task_model_list)
+
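+ # update_task_options narrows the Task dropdown to the model family (SDXL vs
+ # SD 1.5) and falls back to txt2img when the current task is unavailable.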
+ ## BEGIN MOD
+ with gr.Blocks(theme="NoCrypt/miku", elem_id="main", css=CSS) as app:
+     gr.Markdown("# 🧩 DiffuseCraft Mod")
+     gr.Markdown(
+         f"""
+         This space is a modification of [r3gm's DiffuseCraft](https://huggingface.co/spaces/r3gm/DiffuseCraft).
+         """
+     )
+     with gr.Row():
+         with gr.Tab("Generation"):
+             v2b = V2UI()
+             with gr.Column(scale=2):
+                 with gr.Accordion("Model and Task", open=False):
+                     task_gui = gr.Dropdown(label="Task", choices=sdxl_task, value=task_model_list[0])
+                     model_name_gui = gr.Dropdown(label="Model", info="You can enter a Hugging Face model repo_id you want to use.", choices=get_tupled_model_list(model_list), value="votepurchase/animagine-xl-3.1", allow_custom_value=True)
+                     model_info_gui = gr.Markdown()
+                     with gr.Row():
+                         quick_model_type_gui = gr.Radio(label="Model Type", choices=["None", "Animagine", "Pony"], value="None", interactive=True)
+                         quick_genre_gui = gr.Radio(label="Genre", choices=["Anime", "Photo"], value="Anime", interactive=True)
+                         quick_speed_gui = gr.Radio(label="Speed", choices=["Fast", "Standard", "Heavy"], value="Standard", interactive=True)
+                         quick_aspect_gui = gr.Radio(label="Aspect Ratio", choices=["1:1", "3:4"], value="3:4", interactive=True)
+                     quality_selector_gui = gr.Dropdown(label="Quality Tags Presets", interactive=True, choices=list(preset_quality.keys()), value="None")
+                     style_selector_gui = gr.Dropdown(label="Style Preset", interactive=True, choices=list(preset_styles.keys()), value="None")
+                     sampler_selector_gui = gr.Dropdown(label="Sampler Quick Settings", interactive=True, choices=list(preset_sampler_setting.keys()), value="None")
+                     optimization_gui = gr.Dropdown(label="Optimization for SDXL", choices=list(optimization_list.keys()), value="None", interactive=True)
+                 with gr.Accordion("Generate prompt from Image", open=False):
+                     input_image_gui = gr.Image(label="Input image", type="pil", sources=["upload", "clipboard"], height=256)
+                     with gr.Accordion(label="Advanced options", open=False):
+                         general_threshold_gui = gr.Slider(label="Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.01, interactive=True)
+                         character_threshold_gui = gr.Slider(label="Character threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.01, interactive=True)
+                         tag_type_gui = gr.Radio(label="Convert tags to", info="danbooru for Animagine, e621 for Pony.", choices=["danbooru", "e621"], value="danbooru")
+                         recom_prompt_gui = gr.Radio(label="Insert recommended prompt", choices=["None", "Animagine", "Pony"], value="None", interactive=True)
+                         keep_tags_gui = gr.Radio(label="Remove tags leaving only the following", choices=["body", "dress", "all"], value="all")
+                     image_algorithms = gr.CheckboxGroup(["Use WD Tagger", "Use Florence-2-SD3-Long-Captioner"], label="Algorithms", value=["Use WD Tagger"], visible=False)
+                     generate_from_image_btn_gui = gr.Button(value="GENERATE TAGS FROM IMAGE", size="lg", variant="primary")
+                 with gr.Group():
+                     prompt_gui = gr.Textbox(lines=6, placeholder="1girl, solo, ...", label="Prompt", show_copy_button=True)
+                     with gr.Accordion("Negative prompt, etc.", open=False):
+                         neg_prompt_gui = gr.Textbox(lines=3, placeholder="lowres, (bad), ...", label="Negative prompt", show_copy_button=True)
+                         translate_prompt_button = gr.Button(value="Translate prompt to English", size="sm", variant="secondary")
+                         insert_prompt_gui = gr.Radio(label="Insert recommended positive / negative prompt", choices=["None", "Animagine", "Pony"], value="None", interactive=True, scale=2)
+                         prompt_type_gui = gr.Radio(label="Convert tags to", choices=["danbooru", "e621"], value="e621", visible=False)
+                         prompt_type_button = gr.Button(value="Convert prompt to Pony e621 style", size="sm", variant="secondary")
+                         with gr.Row():
+                             character_dbt = gr.Textbox(lines=1, placeholder="kafuu chino, ...", label="Character names")
+                             series_dbt = gr.Textbox(lines=1, placeholder="Is the order a rabbit?, ...", label="Series names")
+                         generate_db_random_button = gr.Button(value="Generate random prompt from character", size="sm", variant="secondary")
+                         model_name_dbt = gr.Dropdown(label="Model", choices=list(V2_ALL_MODELS.keys()), value=list(V2_ALL_MODELS.keys())[0], visible=False)
+                         aspect_ratio_dbt = gr.Radio(label="Aspect ratio", choices=list(V2_ASPECT_RATIO_OPTIONS), value="square", visible=False)
+                         length_dbt = gr.Radio(label="Length", choices=list(V2_LENGTH_OPTIONS), value="very_long", visible=False)
+                         identity_dbt = gr.Radio(label="Keep identity", choices=list(V2_IDENTITY_OPTIONS), value="lax", visible=False)
+                         ban_tags_dbt = gr.Textbox(label="Ban tags", placeholder="alternate costume, ...", value="futanari, censored, furry, furrification", visible=False)
+                         elapsed_time_dbt = gr.Markdown(label="Elapsed time", value="", visible=False)
+                         copy_button_dbt = gr.Button(value="Copy to clipboard", visible=False)
+                         rating_dbt = gr.Radio(label="Rating", choices=list(V2_RATING_OPTIONS), value="explicit")
+                     with gr.Row():
+                         set_params_gui = gr.Button(value="↙️")
+                         clear_prompt_gui = gr.Button(value="🗑️")
+                         set_random_seed = gr.Button(value="🎲")
+
+                 generate_button = gr.Button(value="GENERATE IMAGE", size="lg", variant="primary")
+
+                 model_name_gui.change(
+                     update_task_options,
+                     [model_name_gui, task_gui],
+                     [task_gui],
+                 )
+
+                 load_model_gui = gr.HTML()
+
+                 result_images = gr.Gallery(
+                     label="Generated images",
+                     show_label=False,
+                     elem_id="gallery",
+                     columns=[2],
+                     rows=[2],
+                     object_fit="contain",
+                     # height="auto",
+                     interactive=False,
+                     preview=False,
+                     show_share_button=False,
+                     show_download_button=True,
+                     selected_index=50,
+                     format="png",
+                 )
+
+                 result_images_files = gr.Files(interactive=False, visible=False)
+
+                 actual_task_info = gr.HTML()
+
+                 with gr.Accordion("Generation settings", open=False, visible=True):
+                     steps_gui = gr.Slider(minimum=1, maximum=100, step=1, value=28, label="Steps")
+                     cfg_gui = gr.Slider(minimum=0, maximum=30, step=0.5, value=7.0, label="CFG")
+                     sampler_gui = gr.Dropdown(label="Sampler", choices=scheduler_names, value="Euler a")
+                     img_width_gui = gr.Slider(minimum=64, maximum=4096, step=8, value=1024, label="Img Width")
+                     img_height_gui = gr.Slider(minimum=64, maximum=4096, step=8, value=1024, label="Img Height")
+                     with gr.Row():
+                         clip_skip_gui = gr.Checkbox(value=False, label="Layer 2 Clip Skip")
+                         free_u_gui = gr.Checkbox(value=False, label="FreeU")
+                     seed_gui = gr.Number(minimum=-1, maximum=9999999999, value=-1, label="Seed")
+
+                 with gr.Row(equal_height=False):
+
+                     def run_set_params_gui(base_prompt):
+                         valid_receptors = {  # default values
+                             "prompt": gr.update(value=base_prompt),
+                             "neg_prompt": gr.update(value=""),
+                             "Steps": gr.update(value=30),
+                             "width": gr.update(value=1024),
+                             "height": gr.update(value=1024),
+                             "Seed": gr.update(value=-1),
+                             "Sampler": gr.update(value="Euler a"),
+                             "scale": gr.update(value=7.5),  # cfg
+                             "skip": gr.update(value=True),
+                         }
+                         valid_keys = list(valid_receptors.keys())
+
+                         parameters = extract_parameters(base_prompt)
+                         for key, val in parameters.items():
+                             # print(val)
+                             if key in valid_keys:
+                                 if key == "Sampler":
+                                     if val not in scheduler_names:
+                                         continue
+                                 elif key == "skip":
+                                     if int(val) >= 2:
+                                         val = True
+                                 if key == "prompt":
+                                     if ">" in val and "<" in val:
+                                         val = re.sub(r'<[^>]+>', '', val)
+                                         print("Removed LoRA written in the prompt")
+                                 if key in ["prompt", "neg_prompt"]:
+                                     val = val.strip()
+                                 if key in ["Steps", "width", "height", "Seed"]:
+                                     val = int(val)
+                                 if key == "scale":
+                                     val = float(val)
+                                 if key == "Seed":
+                                     continue
+                                 valid_receptors[key] = gr.update(value=val)
+                                 # print(val, type(val))
+                         # print(valid_receptors)
+                         return [value for value in valid_receptors.values()]
+
+                     set_params_gui.click(
+                         run_set_params_gui, [prompt_gui], [
+                             prompt_gui,
+                             neg_prompt_gui,
+                             steps_gui,
+                             img_width_gui,
+                             img_height_gui,
+                             seed_gui,
+                             sampler_gui,
+                             cfg_gui,
+                             clip_skip_gui,
+                         ],
+                     )
+
+                     def run_clear_prompt_gui():
+                         return gr.update(value=""), gr.update(value="")
+                     clear_prompt_gui.click(
+                         run_clear_prompt_gui, [], [prompt_gui, neg_prompt_gui]
+                     )
+
+                     def run_set_random_seed():
+                         return -1
+                     set_random_seed.click(
+                         run_set_random_seed, [], seed_gui
+                     )
+
+                     num_images_gui = gr.Slider(minimum=1, maximum=4, step=1, value=1, label="Images")
+                     prompt_s_options = [
+                         ("Compel format: (word)weight", "Compel"),
+                         ("Classic format: (word:weight)", "Classic"),
+                         ("Classic-original format: (word:weight)", "Classic-original"),
+                         ("Classic-no_norm format: (word:weight)", "Classic-no_norm"),
+                         ("Classic-ignore", "Classic-ignore"),
+                         ("None", "None"),
+                     ]
+                     prompt_syntax_gui = gr.Dropdown(label="Prompt Syntax", choices=prompt_s_options, value=prompt_s_options[1][1])
+                     vae_model_gui = gr.Dropdown(label="VAE Model", choices=vae_model_list)
+
+                     with gr.Accordion("Hires fix", open=False, visible=True):
+
+                         upscaler_keys = list(upscaler_dict_gui.keys())
+
+                         upscaler_model_path_gui = gr.Dropdown(label="Upscaler", choices=upscaler_keys, value=None)
+                         upscaler_increases_size_gui = gr.Slider(minimum=1.1, maximum=6., step=0.1, value=1.1, label="Upscale by")
+                         esrgan_tile_gui = gr.Slider(minimum=0, value=100, maximum=500, step=1, label="ESRGAN Tile")
+                         esrgan_tile_overlap_gui = gr.Slider(minimum=1, maximum=200, step=1, value=10, label="ESRGAN Tile Overlap")
+                         hires_steps_gui = gr.Slider(minimum=0, value=30, maximum=100, step=1, label="Hires Steps")
+                         hires_denoising_strength_gui = gr.Slider(minimum=0.1, maximum=1.0, step=0.01, value=0.55, label="Hires Denoising Strength")
+                         hires_sampler_gui = gr.Dropdown(label="Hires Sampler", choices=["Use same sampler"] + scheduler_names[:-1], value="Use same sampler")
+                         hires_prompt_gui = gr.Textbox(label="Hires Prompt", placeholder="Main prompt will be used", lines=3)
+                         hires_negative_prompt_gui = gr.Textbox(label="Hires Negative Prompt", placeholder="Main negative prompt will be used", lines=3)
+
+                     with gr.Accordion("LoRA", open=False, visible=True):
+                         lora1_gui = gr.Dropdown(label="Lora1", choices=get_lora_tupled_list(lora_model_list), allow_custom_value=True)
+                         lora_scale_1_gui = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="Lora Scale 1")
+                         with gr.Row():
+                             with gr.Group():
+                                 lora1_trigger_gui = gr.Textbox(label="Lora1 prompts", info="Example of prompt:", value="", show_copy_button=True, interactive=False, visible=False)
+                                 lora1_copy_button = gr.Button(value="Copy example to prompt", visible=False)
+                                 lora1_desc_gui = gr.Markdown(value="", visible=False)
+                         lora2_gui = gr.Dropdown(label="Lora2", choices=get_lora_tupled_list(lora_model_list), allow_custom_value=True)
+                         lora_scale_2_gui = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="Lora Scale 2")
+                         with gr.Row():
+                             with gr.Group():
+                                 lora2_trigger_gui = gr.Textbox(label="Lora2 prompts", info="Example of prompt:", value="", show_copy_button=True, interactive=False, visible=False)
+                                 lora2_copy_button = gr.Button(value="Copy example to prompt", visible=False)
+                                 lora2_desc_gui = gr.Markdown(value="", visible=False)
+                         lora3_gui = gr.Dropdown(label="Lora3", choices=get_lora_tupled_list(lora_model_list), allow_custom_value=True)
+                         lora_scale_3_gui = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="Lora Scale 3")
+                         with gr.Row():
+                             with gr.Group():
+                                 lora3_trigger_gui = gr.Textbox(label="Lora3 prompts", info="Example of prompt:", value="", show_copy_button=True, interactive=False, visible=False)
+                                 lora3_copy_button = gr.Button(value="Copy example to prompt", visible=False)
+                                 lora3_desc_gui = gr.Markdown(value="", visible=False)
+                         lora4_gui = gr.Dropdown(label="Lora4", choices=get_lora_tupled_list(lora_model_list), allow_custom_value=True)
+                         lora_scale_4_gui = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="Lora Scale 4")
+                         with gr.Row():
+                             with gr.Group():
+                                 lora4_trigger_gui = gr.Textbox(label="Lora4 prompts", info="Example of prompt:", value="", show_copy_button=True, interactive=False, visible=False)
+                                 lora4_copy_button = gr.Button(value="Copy example to prompt", visible=False)
+                                 lora4_desc_gui = gr.Markdown(value="", visible=False)
+                         lora5_gui = gr.Dropdown(label="Lora5", choices=get_lora_tupled_list(lora_model_list), allow_custom_value=True)
+                         lora_scale_5_gui = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="Lora Scale 5")
+                         with gr.Row():
+                             with gr.Group():
+                                 lora5_trigger_gui = gr.Textbox(label="Lora5 prompts", info="Example of prompt:", value="", show_copy_button=True, interactive=False, visible=False)
+                                 lora5_copy_button = gr.Button(value="Copy example to prompt", visible=False)
+                                 lora5_desc_gui = gr.Markdown(value="", visible=False)
+
+                         with gr.Accordion("From URL", open=True, visible=True):
+                             text_lora = gr.Textbox(label="URL", placeholder="http://...my_lora_url.safetensors", lines=1)
+                             button_lora = gr.Button("Get and update lists of LoRAs")
+                             button_lora.click(
+                                 get_my_lora,
+                                 [text_lora],
+                                 [lora1_gui, lora2_gui, lora3_gui, lora4_gui, lora5_gui]
+                             )
+
+                         with gr.Accordion("From Local", open=True, visible=True):
+                             file_output_lora = gr.File(label="Uploaded LoRA", file_types=['.ckpt', '.pt', '.pth', '.safetensors', '.bin'], file_count="multiple", interactive=False, visible=False)
+                             upload_button_lora = gr.UploadButton(label="Upload LoRA from your disk (very slow)", file_types=['.ckpt', '.pt', '.pth', '.safetensors', '.bin'], file_count="multiple")
+                             upload_button_lora.upload(upload_file_lora, upload_button_lora, file_output_lora).success(
+                                 move_file_lora,
+                                 [file_output_lora],
+                                 [lora1_gui, lora2_gui, lora3_gui, lora4_gui, lora5_gui]
+                             )
+                     ## END MOD
+                     with gr.Accordion("IP-Adapter", open=False, visible=True):
+
+                         IP_MODELS = sorted(list(set(IP_ADAPTERS_SD + IP_ADAPTERS_SDXL)))
+                         MODE_IP_OPTIONS = ["original", "style", "layout", "style+layout"]
+
+                         with gr.Accordion("IP-Adapter 1", open=False, visible=True):
+                             image_ip1 = gr.Image(label="IP Image", type="filepath")
+                             mask_ip1 = gr.Image(label="IP Mask", type="filepath")
+                             model_ip1 = gr.Dropdown(value="plus_face", label="Model", choices=IP_MODELS)
+                             mode_ip1 = gr.Dropdown(value="original", label="Mode", choices=MODE_IP_OPTIONS)
+                             scale_ip1 = gr.Slider(minimum=0., maximum=2., step=0.01, value=0.7, label="Scale")
+                         with gr.Accordion("IP-Adapter 2", open=False, visible=True):
+                             image_ip2 = gr.Image(label="IP Image", type="filepath")
+                             mask_ip2 = gr.Image(label="IP Mask (optional)", type="filepath")
+                             model_ip2 = gr.Dropdown(value="base", label="Model", choices=IP_MODELS)
+                             mode_ip2 = gr.Dropdown(value="style", label="Mode", choices=MODE_IP_OPTIONS)
+                             scale_ip2 = gr.Slider(minimum=0., maximum=2., step=0.01, value=0.7, label="Scale")
+
+                     with gr.Accordion("ControlNet / Img2img / Inpaint", open=False, visible=True):
+                         image_control = gr.Image(label="Image ControlNet/Inpaint/Img2img", type="filepath")
+                         image_mask_gui = gr.Image(label="Image Mask", type="filepath")
+                         strength_gui = gr.Slider(
+                             minimum=0.01, maximum=1.0, step=0.01, value=0.55, label="Strength",
+                             info="This option adjusts the level of changes for img2img and inpainting."
+                         )
+                         image_resolution_gui = gr.Slider(minimum=64, maximum=2048, step=64, value=1024, label="Image Resolution")
+                         preprocessor_name_gui = gr.Dropdown(label="Preprocessor Name", choices=preprocessor_controlnet["canny"])
+
+                         def change_preprocessor_choices(task):
+                             task = task_stablepy[task]
+                             if task in preprocessor_controlnet.keys():
+                                 choices_task = preprocessor_controlnet[task]
+                             else:
+                                 choices_task = preprocessor_controlnet["canny"]
+                             return gr.update(choices=choices_task, value=choices_task[0])
+
+                         task_gui.change(
+                             change_preprocessor_choices,
+                             [task_gui],
+                             [preprocessor_name_gui],
+                         )
+                         preprocess_resolution_gui = gr.Slider(minimum=64, maximum=2048, step=64, value=512, label="Preprocess Resolution")
+                         low_threshold_gui = gr.Slider(minimum=1, maximum=255, step=1, value=100, label="Canny low threshold")
+                         high_threshold_gui = gr.Slider(minimum=1, maximum=255, step=1, value=200, label="Canny high threshold")
+                         value_threshold_gui = gr.Slider(minimum=1, maximum=2.0, step=0.01, value=0.1, label="Hough value threshold (MLSD)")
+                         distance_threshold_gui = gr.Slider(minimum=1, maximum=20.0, step=0.01, value=0.1, label="Hough distance threshold (MLSD)")
+                         control_net_output_scaling_gui = gr.Slider(minimum=0, maximum=5.0, step=0.1, value=1, label="ControlNet Output Scaling in UNet")
+                         control_net_start_threshold_gui = gr.Slider(minimum=0, maximum=1, step=0.01, value=0, label="ControlNet Start Threshold (%)")
+                         control_net_stop_threshold_gui = gr.Slider(minimum=0, maximum=1, step=0.01, value=1, label="ControlNet Stop Threshold (%)")
+
+                     with gr.Accordion("T2I adapter", open=False, visible=True):
+                         t2i_adapter_preprocessor_gui = gr.Checkbox(value=True, label="T2i Adapter Preprocessor")
+                         adapter_conditioning_scale_gui = gr.Slider(minimum=0, maximum=5., step=0.1, value=1, label="Adapter Conditioning Scale")
+                         adapter_conditioning_factor_gui = gr.Slider(minimum=0, maximum=1., step=0.01, value=0.55, label="Adapter Conditioning Factor (%)")
+
+                     with gr.Accordion("Styles", open=False, visible=True):
+
+                         try:
+                             style_names_found = sd_gen.model.STYLE_NAMES
+                         except Exception:
+                             style_names_found = STYLE_NAMES
+
+                         style_prompt_gui = gr.Dropdown(
+                             style_names_found,
+                             multiselect=True,
+                             value=None,
+                             label="Style Prompt",
+                             interactive=True,
+                         )
+                         style_json_gui = gr.File(label="Style JSON File")
+                         style_button = gr.Button("Load styles")
+
+                         def load_json_style_file(json):
+                             if not sd_gen.model:
+                                 gr.Info("First load the model")
+                                 return gr.update(value=None, choices=STYLE_NAMES)
+
+                             sd_gen.model.load_style_file(json)
+                             gr.Info(f"{len(sd_gen.model.STYLE_NAMES)} styles loaded")
+                             return gr.update(value=None, choices=sd_gen.model.STYLE_NAMES)
+
+                         style_button.click(load_json_style_file, [style_json_gui], [style_prompt_gui])
+
+                     ## BEGIN MOD
+                     with gr.Accordion("Textual inversion", open=False, visible=True):
+                         active_textual_inversion_gui = gr.Checkbox(value=False, label="Active Textual Inversion in prompt")
+                         use_textual_inversion_gui = gr.CheckboxGroup(choices=get_embed_list(get_model_pipeline(model_name_gui.value)) if active_textual_inversion_gui.value else [], value=None, label="Use Textual Inversion in prompt")
+
+                         def update_textual_inversion_gui(active_textual_inversion_gui, model_name_gui):
+                             return gr.update(choices=get_embed_list(get_model_pipeline(model_name_gui)) if active_textual_inversion_gui else [])
+
+                         active_textual_inversion_gui.change(update_textual_inversion_gui, [active_textual_inversion_gui, model_name_gui], [use_textual_inversion_gui])
+                         model_name_gui.change(update_textual_inversion_gui, [active_textual_inversion_gui, model_name_gui], [use_textual_inversion_gui])
+                     ## END MOD
+
1207
+ with gr.Accordion("Detailfix", open=False, visible=True):
1208
+
1209
+ # Adetailer Inpaint Only
1210
+ adetailer_inpaint_only_gui = gr.Checkbox(label="Inpaint only", value=True)
1211
+
1212
+ # Adetailer Verbose
1213
+ adetailer_verbose_gui = gr.Checkbox(label="Verbose", value=False)
1214
+
1215
+ # Adetailer Sampler
1216
+ adetailer_sampler_options = ["Use same sampler"] + scheduler_names[:-1]
1217
+ adetailer_sampler_gui = gr.Dropdown(label="Adetailer sampler:", choices=adetailer_sampler_options, value="Use same sampler")
1218
+
1219
+ with gr.Accordion("Detailfix A", open=False, visible=True):
1220
+ # Adetailer A
1221
+ adetailer_active_a_gui = gr.Checkbox(label="Enable Adetailer A", value=False)
1222
+ prompt_ad_a_gui = gr.Textbox(label="Main prompt", placeholder="Main prompt will be used", lines=3)
1223
+ negative_prompt_ad_a_gui = gr.Textbox(label="Negative prompt", placeholder="Main negative prompt will be used", lines=3)
1224
+ strength_ad_a_gui = gr.Number(label="Strength:", value=0.35, step=0.01, minimum=0.01, maximum=1.0)
1225
+ face_detector_ad_a_gui = gr.Checkbox(label="Face detector", value=True)
1226
+ person_detector_ad_a_gui = gr.Checkbox(label="Person detector", value=True)
1227
+ hand_detector_ad_a_gui = gr.Checkbox(label="Hand detector", value=False)
1228
+ mask_dilation_a_gui = gr.Number(label="Mask dilation:", value=4, minimum=1)
1229
+ mask_blur_a_gui = gr.Number(label="Mask blur:", value=4, minimum=1)
1230
+ mask_padding_a_gui = gr.Number(label="Mask padding:", value=32, minimum=1)
1231
+
1232
+ with gr.Accordion("Detailfix B", open=False, visible=True):
1233
+ # Adetailer B
1234
+ adetailer_active_b_gui = gr.Checkbox(label="Enable Adetailer B", value=False)
1235
+ prompt_ad_b_gui = gr.Textbox(label="Main prompt", placeholder="Main prompt will be used", lines=3)
1236
+ negative_prompt_ad_b_gui = gr.Textbox(label="Negative prompt", placeholder="Main negative prompt will be used", lines=3)
1237
+ strength_ad_b_gui = gr.Number(label="Strength:", value=0.35, step=0.01, minimum=0.01, maximum=1.0)
1238
+ face_detector_ad_b_gui = gr.Checkbox(label="Face detector", value=True)
1239
+ person_detector_ad_b_gui = gr.Checkbox(label="Person detector", value=True)
1240
+ hand_detector_ad_b_gui = gr.Checkbox(label="Hand detector", value=False)
1241
+ mask_dilation_b_gui = gr.Number(label="Mask dilation:", value=4, minimum=1)
1242
+ mask_blur_b_gui = gr.Number(label="Mask blur:", value=4, minimum=1)
1243
+ mask_padding_b_gui = gr.Number(label="Mask padding:", value=32, minimum=1)
1244
+
1245
+ with gr.Accordion("Other settings", open=False, visible=True):
1246
+ image_previews_gui = gr.Checkbox(value=True, label="Image Previews")
1247
+ hires_before_adetailer_gui = gr.Checkbox(value=False, label="Hires Before Adetailer")
1248
+ hires_after_adetailer_gui = gr.Checkbox(value=True, label="Hires After Adetailer")
1249
+ generator_in_cpu_gui = gr.Checkbox(value=False, label="Generator in CPU")
1250
+
1251
+ with gr.Accordion("More settings", open=False, visible=False):
1252
+ loop_generation_gui = gr.Slider(minimum=1, value=1, label="Loop Generation")
1253
+ retain_task_cache_gui = gr.Checkbox(value=False, label="Retain task model in cache")
1254
+ leave_progress_bar_gui = gr.Checkbox(value=True, label="Leave Progress Bar")
1255
+ disable_progress_bar_gui = gr.Checkbox(value=False, label="Disable Progress Bar")
1256
+ display_images_gui = gr.Checkbox(value=True, label="Display Images")
1257
+ save_generated_images_gui = gr.Checkbox(value=False, label="Save Generated Images")
1258
+ image_storage_location_gui = gr.Textbox(value="./images", label="Image Storage Location")
1259
+ retain_compel_previous_load_gui = gr.Checkbox(value=False, label="Retain Compel Previous Load")
1260
+ retain_detailfix_model_previous_load_gui = gr.Checkbox(value=False, label="Retain Detailfix Model Previous Load")
1261
+ retain_hires_model_previous_load_gui = gr.Checkbox(value=False, label="Retain Hires Model Previous Load")
1262
+ xformers_memory_efficient_attention_gui = gr.Checkbox(value=False, label="Xformers Memory Efficient Attention")
1263
+
1264
+ ## BEGIN MOD
1265
+ with gr.Accordion("Examples and help", open=True, visible=True):
1266
+ gr.Examples(
1267
+ examples=[
1268
+ [
1269
+ "1girl, souryuu asuka langley, neon genesis evangelion, plugsuit, pilot suit, red bodysuit, sitting, crossing legs, black eye patch, cat hat, throne, symmetrical, looking down, from bottom, looking at viewer, outdoors, masterpiece, best quality, very aesthetic, absurdres",
1270
+ "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
1271
+ 1,
1272
+ 30,
1273
+ 7.5,
1274
+ True,
1275
+ -1,
1276
+ None,
1277
+ 1.0,
1278
+ None,
1279
+ 1.0,
1280
+ None,
1281
+ 1.0,
1282
+ None,
1283
+ 1.0,
1284
+ None,
1285
+ 1.0,
1286
+ "Euler a",
1287
+ 1152,
1288
+ 896,
1289
+ "votepurchase/animagine-xl-3.1",
1290
+ None, # vae
1291
+ "txt2img",
1292
+ None, # img control
1293
+ "Canny", # preprocessor
1294
+ 512, # preproc resolution
1295
+ 1024, # img resolution
1296
+ None, # Style prompt
1297
+ None, # Style json
1298
+ None, # img Mask
1299
+ 0.35, # strength
1300
+ 100, # low th canny
1301
+ 200, # high th canny
1302
+ 0.1, # value mlsd
1303
+ 0.1, # distance mlsd
1304
+ 1.0, # cn scale
1305
+ 0., # cn start
1306
+ 1., # cn end
1307
+ False, # ti
1308
+ "Classic",
1309
+ None,
1310
+ ],
1311
+ [
1312
+ "solo, princess Zelda OOT, score_9, score_8_up, score_8, medium breasts, cute, eyelashes, cute small face, long hair, crown braid, hairclip, pointy ears, soft curvy body, looking at viewer, smile, blush, white dress, medium body, (((holding the Master Sword))), standing, deep forest in the background",
1313
+ "score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white,",
1314
+ 1,
1315
+ 30,
1316
+ 5.,
1317
+ True,
1318
+ -1,
1319
+ None,
1320
+ 1.0,
1321
+ None,
1322
+ 1.0,
1323
+ None,
1324
+ 1.0,
1325
+ None,
1326
+ 1.0,
1327
+ None,
1328
+ 1.0,
1329
+ "Euler a",
1330
+ 1024,
1331
+ 1024,
1332
+ "votepurchase/ponyDiffusionV6XL",
1333
+ None, # vae
1334
+ "txt2img",
1335
+ None, # img control
1336
+ "Canny", # preprocessor
1337
+ 512, # preproc resolution
1338
+ 1024, # img resolution
1339
+ None, # Style prompt
1340
+ None, # Style json
1341
+ None, # img Mask
1342
+ 0.35, # strength
1343
+ 100, # low th canny
1344
+ 200, # high th canny
1345
+ 0.1, # value mlsd
1345
+ 0.1, # distance mlsd
1347
+ 1.0, # cn scale
1348
+ 0., # cn start
1349
+ 1., # cn end
1350
+ False, # ti
1351
+ "Classic",
1352
+ None,
1353
+ ],
1354
+ [
1355
+ "1girl, oomuro sakurako, yuru yuri, official art, school uniform, anime artwork, anime style, key visual, vibrant, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres",
1356
+ "photo, deformed, black and white, realism, disfigured, low contrast, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
1357
+ 1,
1358
+ 40,
1359
+ 7.0,
1360
+ True,
1361
+ -1,
1362
+ None,
1363
+ 1.0,
1364
+ None,
1365
+ 1.0,
1366
+ None,
1367
+ 1.0,
1368
+ None,
1369
+ 1.0,
1370
+ None,
1371
+ 1.0,
1372
+ "Euler a",
1373
+ 1024,
1374
+ 1024,
1375
+ "Raelina/Rae-Diffusion-XL-V2",
1376
+ "vaes/sdxl.vae.safetensors", # vae
1377
+ "txt2img",
1378
+ None, # img control
1379
+ "Canny", # preprocessor
1380
+ 512, # preproc resolution
1381
+ 1024, # img resolution
1382
+ None, # Style prompt
1383
+ None, # Style json
1384
+ None, # img Mask
1385
+ 0.35, # strength
1386
+ 100, # low th canny
1387
+ 200, # high th canny
1388
+ 0.1, # value mlsd
1389
+ 0.1, # distance mlsd
1390
+ 1.0, # cn scale
1391
+ 0., # cn start
1392
+ 1., # cn end
1393
+ False, # ti
1394
+ "Classic",
1395
+ None,
1396
+ ],
1397
+ [
1398
+ "yoshida yuuko, machikado mazoku, 1girl, solo, demon horns,horns, school uniform, long hair, open mouth, skirt, demon girl, ahoge, shiny, shiny hair, anime artwork",
1399
+ "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
1400
+ 1,
1401
+ 50,
1402
+ 7.,
1403
+ True,
1404
+ -1,
1405
+ None,
1406
+ 1.0,
1407
+ None,
1408
+ 1.0,
1409
+ None,
1410
+ 1.0,
1411
+ None,
1412
+ 1.0,
1413
+ None,
1414
+ 1.0,
1415
+ "Euler a",
1416
+ 1024,
1417
+ 1024,
1418
+ "votepurchase/animagine-xl-3.1",
1419
+ None, # vae
1420
+ "img2img",
1421
+ "color_image.png", # img conttol
1422
+ "Canny", # preprocessor
1423
+ 512, # preproc resolution
1424
+ 1024, # img resolution
1425
+ None, # Style prompt
1426
+ None, # Style json
1427
+ None, # img Mask
1428
+ 0.35, # strength
1429
+ 100, # low th canny
1430
+ 200, # high th canny
1431
+ 0.1, # value mlsd
1432
+ 0.1, # distance mlsd
1433
+ 1.0, # cn scale
1434
+ 0., # cn start
1435
+ 1., # cn end
1436
+ False, # ti
1437
+ "Classic",
1438
+ None,
1439
+ ],
1440
+ ],
1441
+ fn=sd_gen.generate_pipeline,
1442
+ inputs=[
1443
+ prompt_gui,
1444
+ neg_prompt_gui,
1445
+ num_images_gui,
1446
+ steps_gui,
1447
+ cfg_gui,
1448
+ clip_skip_gui,
1449
+ seed_gui,
1450
+ lora1_gui,
1451
+ lora_scale_1_gui,
1452
+ lora2_gui,
1453
+ lora_scale_2_gui,
1454
+ lora3_gui,
1455
+ lora_scale_3_gui,
1456
+ lora4_gui,
1457
+ lora_scale_4_gui,
1458
+ lora5_gui,
1459
+ lora_scale_5_gui,
1460
+ sampler_gui,
1461
+ img_height_gui,
1462
+ img_width_gui,
1463
+ model_name_gui,
1464
+ vae_model_gui,
1465
+ task_gui,
1466
+ image_control,
1467
+ preprocessor_name_gui,
1468
+ preprocess_resolution_gui,
1469
+ image_resolution_gui,
1470
+ style_prompt_gui,
1471
+ style_json_gui,
1472
+ image_mask_gui,
1473
+ strength_gui,
1474
+ low_threshold_gui,
1475
+ high_threshold_gui,
1476
+ value_threshold_gui,
1477
+ distance_threshold_gui,
1478
+ control_net_output_scaling_gui,
1479
+ control_net_start_threshold_gui,
1480
+ control_net_stop_threshold_gui,
1481
+ active_textual_inversion_gui,
1482
+ prompt_syntax_gui,
1483
+ upscaler_model_path_gui,
1484
+ ],
1485
+ outputs=[result_images],
1486
+ cache_examples=False,
1487
+ elem_id="examples",
1488
+ )
1489
+ ## END MOD
1490
+
1491
+ with gr.Tab("Inpaint mask maker", render=True):
1492
+
1493
+ def create_mask_now(img, invert):
1494
+ import numpy as np
1495
+ import time
1496
+
1497
+ time.sleep(0.5)
1498
+
1499
+ transparent_image = img["layers"][0]
1500
+
1501
+ # Extract the alpha channel
1502
+ alpha_channel = np.array(transparent_image)[:, :, 3]
1503
+
1504
+ # Create a binary mask by thresholding the alpha channel
1505
+ binary_mask = alpha_channel > 1
1506
+
1507
+ if invert:
1508
+ print("Invert")
1509
+ # Invert the binary mask so that the drawn shape is white and the rest is black
1510
+ binary_mask = np.invert(binary_mask)
1511
+
1512
+ # Convert the binary mask to a 3-channel RGB mask
1513
+ rgb_mask = np.stack((binary_mask,) * 3, axis=-1)
1514
+
1515
+ # Convert the mask to uint8
1516
+ rgb_mask = rgb_mask.astype(np.uint8) * 255
1517
+
1518
+ return img["background"], rgb_mask
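A self-contained check of the alpha-to-mask conversion above, run on a synthetic RGBA array:

import numpy as np

rgba = np.zeros((4, 4, 4), dtype=np.uint8)
rgba[1:3, 1:3, 3] = 255                            # "drawn" pixels carry alpha
binary_mask = rgba[:, :, 3] > 1                    # threshold the alpha channel
rgb_mask = np.stack((binary_mask,) * 3, axis=-1).astype(np.uint8) * 255
assert rgb_mask[1, 1].tolist() == [255, 255, 255]  # stroke -> white
assert rgb_mask[0, 0].tolist() == [0, 0, 0]        # untouched -> black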
1519
+
1520
+ with gr.Row():
1521
+ with gr.Column(scale=2):
1522
+ # image_base = gr.ImageEditor(label="Base image", show_label=True, brush=gr.Brush(colors=["#000000"]))
1523
+ image_base = gr.ImageEditor(
1524
+ sources=["upload", "clipboard"],
1525
+ # crop_size="1:1",
1526
+ # enable crop (or disable it)
1527
+ # transforms=["crop"],
1528
+ brush=gr.Brush(
1529
+ default_size="16", # or leave it as 'auto'
1530
+ color_mode="fixed", # 'fixed' hides the user swatches and colorpicker, 'defaults' shows it
1531
+ # default_color="black", # html names are supported
1532
+ colors=[
1533
+ "rgba(0, 0, 0, 1)", # rgb(a)
1534
+ "rgba(0, 0, 0, 0.1)",
1535
+ "rgba(255, 255, 255, 0.1)",
1536
+ # "hsl(360, 120, 120)" # in fact any valid colorstring
1537
+ ]
1538
+ ),
1539
+ eraser=gr.Eraser(default_size="16")
1540
+ )
1541
+ invert_mask = gr.Checkbox(value=False, label="Invert mask")
1542
+ btn = gr.Button("Create mask")
1543
+ with gr.Column(scale=1):
1544
+ img_source = gr.Image(interactive=False)
1545
+ img_result = gr.Image(label="Mask image", show_label=True, interactive=False)
1546
+ btn_send = gr.Button("Send to the first tab")
1547
+
1548
+ btn.click(create_mask_now, [image_base, invert_mask], [img_source, img_result])
1549
+
1550
+ def send_img(img_source, img_result):
1551
+ return img_source, img_result
1552
+ btn_send.click(send_img, [img_source, img_result], [image_control, image_mask_gui])
1553
+
1554
+ ## BEGIN MOD
1555
+ model_name_gui.change(get_t2i_model_info, [model_name_gui], [model_info_gui])
1556
+
1557
+ quick_model_type_gui.change(
1558
+ set_quick_presets,
1559
+ [quick_genre_gui, quick_model_type_gui, quick_speed_gui, quick_aspect_gui],
1560
+ [quality_selector_gui, style_selector_gui, sampler_selector_gui, optimization_gui],
1561
+ )
1562
+ quick_genre_gui.change(
1563
+ set_quick_presets,
1564
+ [quick_genre_gui, quick_model_type_gui, quick_speed_gui, quick_aspect_gui],
1565
+ [quality_selector_gui, style_selector_gui, sampler_selector_gui, optimization_gui],
1566
+ )
1567
+ quick_speed_gui.change(
1568
+ set_quick_presets,
1569
+ [quick_genre_gui, quick_model_type_gui, quick_speed_gui, quick_aspect_gui],
1570
+ [quality_selector_gui, style_selector_gui, sampler_selector_gui, optimization_gui],
1571
+ )
1572
+ quick_aspect_gui.change(
1573
+ set_quick_presets,
1574
+ [quick_genre_gui, quick_model_type_gui, quick_speed_gui, quick_aspect_gui],
1575
+ [quality_selector_gui, style_selector_gui, sampler_selector_gui, optimization_gui],
1576
+ )
1577
+
1578
+ quality_selector_gui.change(
1579
+ process_style_prompt,
1580
+ inputs=[prompt_gui, neg_prompt_gui, style_selector_gui, quality_selector_gui, insert_prompt_gui],
1581
+ outputs=[prompt_gui, neg_prompt_gui],
1582
+ )
1583
+ style_selector_gui.change(
1584
+ process_style_prompt,
1585
+ inputs=[prompt_gui, neg_prompt_gui, style_selector_gui, quality_selector_gui, insert_prompt_gui],
1586
+ outputs=[prompt_gui, neg_prompt_gui],
1587
+ )
1588
+ sampler_selector_gui.change(set_sampler_settings, [sampler_selector_gui], [sampler_gui, steps_gui, cfg_gui, clip_skip_gui, img_width_gui, img_height_gui, optimization_gui])
1589
+ optimization_gui.change(set_optimization, [optimization_gui, steps_gui, cfg_gui, sampler_gui, clip_skip_gui, lora5_gui, lora_scale_5_gui], [steps_gui, cfg_gui, sampler_gui, clip_skip_gui, lora5_gui, lora_scale_5_gui])
1590
+
1591
+ lora1_gui.change(set_lora_prompt, [prompt_gui, prompt_syntax_gui, lora1_gui, lora_scale_1_gui, lora2_gui, lora_scale_2_gui, lora3_gui, lora_scale_3_gui, lora4_gui, lora_scale_4_gui, lora5_gui, lora_scale_5_gui], [prompt_gui])\
1592
+ .then(set_lora_trigger, [lora1_gui], [lora1_trigger_gui, lora1_copy_button, lora1_desc_gui, lora1_gui])
1593
+ lora2_gui.change(set_lora_prompt, [prompt_gui, prompt_syntax_gui, lora1_gui, lora_scale_1_gui, lora2_gui, lora_scale_2_gui, lora3_gui, lora_scale_3_gui, lora4_gui, lora_scale_4_gui, lora5_gui, lora_scale_5_gui], [prompt_gui])\
1594
+ .then(set_lora_trigger, [lora2_gui], [lora2_trigger_gui, lora2_copy_button, lora2_desc_gui, lora2_gui])
1595
+ lora3_gui.change(set_lora_prompt, [prompt_gui, prompt_syntax_gui, lora1_gui, lora_scale_1_gui, lora2_gui, lora_scale_2_gui, lora3_gui, lora_scale_3_gui, lora4_gui, lora_scale_4_gui, lora5_gui, lora_scale_5_gui], [prompt_gui])\
1596
+ .then(set_lora_trigger, [lora3_gui], [lora3_trigger_gui, lora3_copy_button, lora3_desc_gui, lora3_gui])
1597
+ lora4_gui.change(set_lora_prompt, [prompt_gui, prompt_syntax_gui, lora1_gui, lora_scale_1_gui, lora2_gui, lora_scale_2_gui, lora3_gui, lora_scale_3_gui, lora4_gui, lora_scale_4_gui, lora5_gui, lora_scale_5_gui], [prompt_gui])\
1598
+ .then(set_lora_trigger, [lora4_gui], [lora4_trigger_gui, lora4_copy_button, lora4_desc_gui, lora4_gui])
1599
+ lora5_gui.change(set_lora_prompt, [prompt_gui, prompt_syntax_gui, lora1_gui, lora_scale_1_gui, lora2_gui, lora_scale_2_gui, lora3_gui, lora_scale_3_gui, lora4_gui, lora_scale_4_gui, lora5_gui, lora_scale_5_gui], [prompt_gui])\
1600
+ .then(set_lora_trigger, [lora5_gui], [lora5_trigger_gui, lora5_copy_button, lora5_desc_gui, lora5_gui])
1601
+ lora_scale_1_gui.change(set_lora_prompt, [prompt_gui, prompt_syntax_gui, lora1_gui, lora_scale_1_gui, lora2_gui, lora_scale_2_gui, lora3_gui, lora_scale_3_gui, lora4_gui, lora_scale_4_gui, lora5_gui, lora_scale_5_gui], [prompt_gui])
1602
+ lora_scale_2_gui.change(set_lora_prompt, [prompt_gui, prompt_syntax_gui, lora1_gui, lora_scale_1_gui, lora2_gui, lora_scale_2_gui, lora3_gui, lora_scale_3_gui, lora4_gui, lora_scale_4_gui, lora5_gui, lora_scale_5_gui], [prompt_gui])
1603
+ lora_scale_3_gui.change(set_lora_prompt, [prompt_gui, prompt_syntax_gui, lora1_gui, lora_scale_1_gui, lora2_gui, lora_scale_2_gui, lora3_gui, lora_scale_3_gui, lora4_gui, lora_scale_4_gui, lora5_gui, lora_scale_5_gui], [prompt_gui])
1604
+ lora_scale_4_gui.change(set_lora_prompt, [prompt_gui, prompt_syntax_gui, lora1_gui, lora_scale_1_gui, lora2_gui, lora_scale_2_gui, lora3_gui, lora_scale_3_gui, lora4_gui, lora_scale_4_gui, lora5_gui, lora_scale_5_gui], [prompt_gui])
1605
+ lora_scale_5_gui.change(set_lora_prompt, [prompt_gui, prompt_syntax_gui, lora1_gui, lora_scale_1_gui, lora2_gui, lora_scale_2_gui, lora3_gui, lora_scale_3_gui, lora4_gui, lora_scale_4_gui, lora5_gui, lora_scale_5_gui], [prompt_gui])
1606
+ prompt_syntax_gui.change(set_lora_prompt, [prompt_gui, prompt_syntax_gui, lora1_gui, lora_scale_1_gui, lora2_gui, lora_scale_2_gui, lora3_gui, lora_scale_3_gui, lora4_gui, lora_scale_4_gui, lora5_gui, lora_scale_5_gui], [prompt_gui])
1607
+ lora1_copy_button.click(apply_lora_prompt, [prompt_gui, lora1_trigger_gui], [prompt_gui])
1608
+ lora2_copy_button.click(apply_lora_prompt, [prompt_gui, lora2_trigger_gui], [prompt_gui])
1609
+ lora3_copy_button.click(apply_lora_prompt, [prompt_gui, lora3_trigger_gui], [prompt_gui])
1610
+ lora4_copy_button.click(apply_lora_prompt, [prompt_gui, lora4_trigger_gui], [prompt_gui])
1611
+ lora5_copy_button.click(apply_lora_prompt, [prompt_gui, lora5_trigger_gui], [prompt_gui])
1612
+
1613
+ use_textual_inversion_gui.change(set_textual_inversion_prompt, [use_textual_inversion_gui, prompt_gui, neg_prompt_gui, prompt_syntax_gui], [prompt_gui, neg_prompt_gui])
1614
+
1615
+ generate_from_image_btn_gui.click(
1616
+ predict_tags_wd,
1617
+ inputs=[input_image_gui, prompt_gui, image_algorithms, general_threshold_gui, character_threshold_gui],
1618
+ outputs=[
1619
+ series_dbt,
1620
+ character_dbt,
1621
+ prompt_gui,
1622
+ copy_button_dbt,
1623
+ ],
1624
+ ).then(
1625
+ compose_prompt_to_copy, inputs=[character_dbt, series_dbt, prompt_gui], outputs=[prompt_gui]
1626
+ ).then(
1627
+ remove_specific_prompt, inputs=[prompt_gui, keep_tags_gui], outputs=[prompt_gui],
1628
+ ).then(
1629
+ convert_danbooru_to_e621_prompt, inputs=[prompt_gui, tag_type_gui], outputs=[prompt_gui],
1630
+ ).then(
1631
+ insert_recom_prompt, inputs=[prompt_gui, neg_prompt_gui, recom_prompt_gui], outputs=[prompt_gui, neg_prompt_gui],
1632
+ )
1633
+
1634
+ v2b.input_components = [
1635
+ model_name_dbt,
1636
+ series_dbt,
1637
+ character_dbt,
1638
+ prompt_gui,
1639
+ rating_dbt,
1640
+ aspect_ratio_dbt,
1641
+ length_dbt,
1642
+ identity_dbt,
1643
+ ban_tags_dbt,
1644
+ ]
1645
+
1646
+ insert_prompt_gui.change(
1647
+ process_style_prompt,
1648
+ inputs=[prompt_gui, neg_prompt_gui, style_selector_gui, quality_selector_gui, insert_prompt_gui],
1649
+ outputs=[prompt_gui, neg_prompt_gui],
1650
+ )
1651
+
1652
+ prompt_type_button.click(
1653
+ convert_danbooru_to_e621_prompt,
1654
+ inputs=[prompt_gui, prompt_type_gui],
1655
+ outputs=[prompt_gui],
1656
+ )
1657
+
1658
+ generate_db_random_button.click(
1659
+ parse_upsampling_output(v2b.on_generate),
1660
+ inputs=[
1661
+ *v2b.input_components,
1662
+ ],
1663
+ outputs=[prompt_gui, elapsed_time_dbt, copy_button_dbt, copy_button_dbt],
1664
+ )
1665
+
1666
+ translate_prompt_button.click(translate_prompt, inputs=[prompt_gui], outputs=[prompt_gui])
1667
+ translate_prompt_button.click(translate_prompt, inputs=[character_dbt], outputs=[character_dbt])
1668
+ translate_prompt_button.click(translate_prompt, inputs=[series_dbt], outputs=[series_dbt])
1669
+
1670
+ generate_button.click(
1671
+ fn=sd_gen.load_new_model,
1672
+ inputs=[
1673
+ model_name_gui,
1674
+ vae_model_gui,
1675
+ task_gui
1676
+ ],
1677
+ outputs=[load_model_gui],
1678
+ queue=True,
1679
+ show_progress="minimal",
1680
+ ).success(
1681
+ fn=sd_gen.generate_pipeline,
1682
+ inputs=[
1683
+ prompt_gui,
1684
+ neg_prompt_gui,
1685
+ num_images_gui,
1686
+ steps_gui,
1687
+ cfg_gui,
1688
+ clip_skip_gui,
1689
+ seed_gui,
1690
+ lora1_gui,
1691
+ lora_scale_1_gui,
1692
+ lora2_gui,
1693
+ lora_scale_2_gui,
1694
+ lora3_gui,
1695
+ lora_scale_3_gui,
1696
+ lora4_gui,
1697
+ lora_scale_4_gui,
1698
+ lora5_gui,
1699
+ lora_scale_5_gui,
1700
+ sampler_gui,
1701
+ img_height_gui,
1702
+ img_width_gui,
1703
+ model_name_gui,
1704
+ vae_model_gui,
1705
+ task_gui,
1706
+ image_control,
1707
+ preprocessor_name_gui,
1708
+ preprocess_resolution_gui,
1709
+ image_resolution_gui,
1710
+ style_prompt_gui,
1711
+ style_json_gui,
1712
+ image_mask_gui,
1713
+ strength_gui,
1714
+ low_threshold_gui,
1715
+ high_threshold_gui,
1716
+ value_threshold_gui,
1717
+ distance_threshold_gui,
1718
+ control_net_output_scaling_gui,
1719
+ control_net_start_threshold_gui,
1720
+ control_net_stop_threshold_gui,
1721
+ active_textual_inversion_gui,
1722
+ prompt_syntax_gui,
1723
+ upscaler_model_path_gui,
1724
+ upscaler_increases_size_gui,
1725
+ esrgan_tile_gui,
1726
+ esrgan_tile_overlap_gui,
1727
+ hires_steps_gui,
1728
+ hires_denoising_strength_gui,
1729
+ hires_sampler_gui,
1730
+ hires_prompt_gui,
1731
+ hires_negative_prompt_gui,
1732
+ hires_before_adetailer_gui,
1733
+ hires_after_adetailer_gui,
1734
+ loop_generation_gui,
1735
+ leave_progress_bar_gui,
1736
+ disable_progress_bar_gui,
1737
+ image_previews_gui,
1738
+ display_images_gui,
1739
+ save_generated_images_gui,
1740
+ image_storage_location_gui,
1741
+ retain_compel_previous_load_gui,
1742
+ retain_detailfix_model_previous_load_gui,
1743
+ retain_hires_model_previous_load_gui,
1744
+ t2i_adapter_preprocessor_gui,
1745
+ adapter_conditioning_scale_gui,
1746
+ adapter_conditioning_factor_gui,
1747
+ xformers_memory_efficient_attention_gui,
1748
+ free_u_gui,
1749
+ generator_in_cpu_gui,
1750
+ adetailer_inpaint_only_gui,
1751
+ adetailer_verbose_gui,
1752
+ adetailer_sampler_gui,
1753
+ adetailer_active_a_gui,
1754
+ prompt_ad_a_gui,
1755
+ negative_prompt_ad_a_gui,
1756
+ strength_ad_a_gui,
1757
+ face_detector_ad_a_gui,
1758
+ person_detector_ad_a_gui,
1759
+ hand_detector_ad_a_gui,
1760
+ mask_dilation_a_gui,
1761
+ mask_blur_a_gui,
1762
+ mask_padding_a_gui,
1763
+ adetailer_active_b_gui,
1764
+ prompt_ad_b_gui,
1765
+ negative_prompt_ad_b_gui,
1766
+ strength_ad_b_gui,
1767
+ face_detector_ad_b_gui,
1768
+ person_detector_ad_b_gui,
1769
+ hand_detector_ad_b_gui,
1770
+ mask_dilation_b_gui,
1771
+ mask_blur_b_gui,
1772
+ mask_padding_b_gui,
1773
+ retain_task_cache_gui,
1774
+ image_ip1,
1775
+ mask_ip1,
1776
+ model_ip1,
1777
+ mode_ip1,
1778
+ scale_ip1,
1779
+ image_ip2,
1780
+ mask_ip2,
1781
+ model_ip2,
1782
+ mode_ip2,
1783
+ scale_ip2,
1784
+ ],
1785
+ outputs=[result_images, actual_task_info],
1786
+ queue=True,
1787
+ show_progress="minimal",
1788
+ ).success(save_gallery_images, [result_images], [result_images, result_images_files, result_images_files])
1789
+
1790
+ with gr.Tab("Danbooru Tags Transformer with WD Tagger", render=True):
1791
+ v2 = V2UI()
1792
+ with gr.Column(scale=2):
1793
+ with gr.Group():
1794
+ input_image = gr.Image(label="Input image", type="pil", sources=["upload", "clipboard"], height=256)
1795
+ with gr.Accordion(label="Advanced options", open=False):
1796
+ general_threshold = gr.Slider(label="Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.01, interactive=True)
1797
+ character_threshold = gr.Slider(label="Character threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.01, interactive=True)
1798
+ input_tag_type = gr.Radio(label="Convert tags to", info="danbooru for Animagine, e621 for Pony.", choices=["danbooru", "e621"], value="danbooru")
1799
+ recom_prompt = gr.Radio(label="Insert recommended prompt", choices=["None", "Animagine", "Pony"], value="None", interactive=True)
1800
+ image_algorithms = gr.CheckboxGroup(["Use WD Tagger", "Use Florence-2-SD3-Long-Captioner"], label="Algorithms", value=["Use WD Tagger"], visible=False)
1801
+ keep_tags = gr.Radio(label="Remove tags leaving only the following", choices=["body", "dress", "all"], value="all")
1802
+ generate_from_image_btn = gr.Button(value="GENERATE TAGS FROM IMAGE", size="lg", variant="primary")
1803
+
1804
+ with gr.Group():
1805
+ input_character = gr.Textbox(label="Character tags", placeholder="hatsune miku")
1806
+ input_copyright = gr.Textbox(label="Copyright tags", placeholder="vocaloid")
1807
+ input_general = gr.TextArea(label="General tags", lines=4, placeholder="1girl, ...", value="")
1808
+ input_tags_to_copy = gr.Textbox(value="", visible=False)
1809
+ copy_input_btn = gr.Button(value="Copy to clipboard", size="sm", interactive=False)
1810
+ translate_input_prompt_button = gr.Button(value="Translate prompt to English", size="sm", variant="secondary")
1811
+ tag_type = gr.Radio(label="Output tag conversion", info="danbooru for Animagine, e621 for Pony.", choices=["danbooru", "e621"], value="e621", visible=False)
1812
+ input_rating = gr.Radio(label="Rating", choices=list(V2_RATING_OPTIONS), value="explicit")
1813
+ with gr.Accordion(label="Advanced options", open=False):
1814
+ input_aspect_ratio = gr.Radio(label="Aspect ratio", info="The aspect ratio of the image.", choices=list(V2_ASPECT_RATIO_OPTIONS), value="square")
1815
+ input_length = gr.Radio(label="Length", info="The total length of the tags.", choices=list(V2_LENGTH_OPTIONS), value="very_long")
1816
+ input_identity = gr.Radio(label="Keep identity", info="How strictly to keep the identity of the character or subject. If you specify the details of the subject in the prompt, choose `strict`. Otherwise, choose `none` or `lax`. `none` is very creative but sometimes ignores the input prompt.", choices=list(V2_IDENTITY_OPTIONS), value="lax")
1817
+ input_ban_tags = gr.Textbox(label="Ban tags", info="Tags to ban from the output.", placeholder="alternate costume, ...", value="censored")
1818
+ model_name = gr.Dropdown(label="Model", choices=list(V2_ALL_MODELS.keys()), value=list(V2_ALL_MODELS.keys())[0])
1819
+ dummy_np = gr.Textbox(label="Negative prompt", value="", visible=False)
1820
+ recom_animagine = gr.Textbox(label="Animagine recommended prompt", value="Animagine", visible=False)
1821
+ recom_pony = gr.Textbox(label="Pony recommended prompt", value="Pony", visible=False)
1822
+
1823
+ generate_btn = gr.Button(value="GENERATE TAGS", size="lg", variant="primary")
1824
+
1825
+ with gr.Group():
1826
+ output_text = gr.TextArea(label="Output tags", interactive=False, show_copy_button=True)
1827
+ copy_btn = gr.Button(value="Copy to clipboard", size="sm", interactive=False)
1828
+ elapsed_time_md = gr.Markdown(label="Elapsed time", value="", visible=False)
1829
+
1830
+ with gr.Group():
1831
+ output_text_pony = gr.TextArea(label="Output tags (Pony e621 style)", interactive=False, show_copy_button=True)
1832
+ copy_btn_pony = gr.Button(value="Copy to clipboard", size="sm", interactive=False)
1833
+
1834
+ description_ui()
1835
+
1836
+ v2.input_components = [
1837
+ model_name,
1838
+ input_copyright,
1839
+ input_character,
1840
+ input_general,
1841
+ input_rating,
1842
+ input_aspect_ratio,
1843
+ input_length,
1844
+ input_identity,
1845
+ input_ban_tags,
1846
+ ]
1847
+
1848
+ translate_input_prompt_button.click(translate_prompt, inputs=[input_general], outputs=[input_general])
1849
+ translate_input_prompt_button.click(translate_prompt, inputs=[input_character], outputs=[input_character])
1850
+ translate_input_prompt_button.click(translate_prompt, inputs=[input_copyright], outputs=[input_copyright])
1851
+
1852
+ generate_from_image_btn.click(
1853
+ predict_tags_wd,
1854
+ inputs=[input_image, input_general, image_algorithms, general_threshold, character_threshold],
1855
+ outputs=[
1856
+ input_copyright,
1857
+ input_character,
1858
+ input_general,
1859
+ copy_input_btn,
1860
+ ],
1861
+ ).then(
1862
+ remove_specific_prompt, inputs=[input_general, keep_tags], outputs=[input_general],
1863
+ ).then(
1864
+ convert_danbooru_to_e621_prompt, inputs=[input_general, input_tag_type], outputs=[input_general],
1865
+ ).then(
1866
+ insert_recom_prompt, inputs=[input_general, dummy_np, recom_prompt], outputs=[input_general, dummy_np],
1867
+ )
1868
+ copy_input_btn.click(compose_prompt_to_copy, inputs=[input_character, input_copyright, input_general], outputs=[input_tags_to_copy]).then(
1869
+ gradio_copy_text, inputs=[input_tags_to_copy], js=COPY_ACTION_JS,
1870
+ )
1871
+
1872
+ generate_btn.click(
1873
+ parse_upsampling_output(v2.on_generate),
1874
+ inputs=[
1875
+ *v2.input_components,
1876
+ ],
1877
+ outputs=[output_text, elapsed_time_md, copy_btn, copy_btn_pony],
1878
+ ).then(
1879
+ convert_danbooru_to_e621_prompt, inputs=[output_text, tag_type], outputs=[output_text_pony],
1880
+ ).then(
1881
+ insert_recom_prompt, inputs=[output_text, dummy_np, recom_animagine], outputs=[output_text, dummy_np],
1882
+ ).then(
1883
+ insert_recom_prompt, inputs=[output_text_pony, dummy_np, recom_pony], outputs=[output_text_pony, dummy_np],
1884
+ )
1885
+ copy_btn.click(gradio_copy_text, inputs=[output_text], js=COPY_ACTION_JS)
1886
+ copy_btn_pony.click(gradio_copy_text, inputs=[output_text_pony], js=COPY_ACTION_JS)
1887
+
1888
+ gr.DuplicateButton(value="Duplicate Space for private use (This demo does not work on CPU. Requires Space with GPU available.)", elem_id="duplicate-button", visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1")
1889
+
1890
+ app.queue()
1891
+
1892
+ app.launch(
1893
+ show_error=False,
1894
+ debug=False,
1895
+ )
1896
+ ## END MOD
character_series_dict.csv ADDED
The diff for this file is too large to render. See raw diff
 
color_image.png ADDED
danbooru_e621.csv ADDED
The diff for this file is too large to render. See raw diff
 
image.webp ADDED
lora_dict.json ADDED
The diff for this file is too large to render. See raw diff
 
modutils.py ADDED
@@ -0,0 +1,751 @@
1
+ import os
2
+ import json
3
+ import gradio as gr
4
+ from PIL import Image
5
+
6
+ HF_LORA_PRIVATE_REPOS = ['John6666/loratest1', 'John6666/loratest3', 'John6666/loratest4', 'John6666/loratest']
7
+ directory_loras = 'loras'
8
+
9
+
10
+ def get_model_list(directory_path):
11
+ model_list = []
12
+ valid_extensions = {'.ckpt', '.pt', '.pth', '.safetensors', '.bin'}
13
+
14
+ for filename in os.listdir(directory_path):
15
+ if os.path.splitext(filename)[1] in valid_extensions:
16
+ name_without_extension = os.path.splitext(filename)[0]
17
+ file_path = os.path.join(directory_path, filename)
18
+ # model_list.append((name_without_extension, file_path))
19
+ model_list.append(file_path)
20
+ # print('\033[34mFILE: ' + file_path + '\033[0m')
21
+ return model_list
22
+
23
+
24
+ def list_uniq(l):
25
+ return sorted(set(l), key=l.index)
26
+
27
+
28
+ def list_sub(a, b):
29
+ return [e for e in a if e not in b]
30
+
31
+
32
+ def normalize_prompt_list(tags):
33
+ prompts = []
34
+ for tag in tags:
35
+ tag = str(tag).strip()
36
+ if tag:
37
+ prompts.append(tag)
38
+ return prompts
39
+
40
+
41
+ def escape_lora_basename(basename: str):
42
+ return basename.replace(".", "_").replace(" ", "_").replace(",", "")
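For reference, standalone checks of the three small helpers above (definitions repeated so the snippet runs on its own):

def list_uniq(l):
    return sorted(set(l), key=l.index)   # order-preserving dedup

def list_sub(a, b):
    return [e for e in a if e not in b]  # a minus b, order kept

def escape_lora_basename(basename: str):
    return basename.replace(".", "_").replace(" ", "_").replace(",", "")

assert list_uniq(["b", "a", "b", "c"]) == ["b", "a", "c"]
assert list_sub(["a", "b", "c"], ["b"]) == ["a", "c"]
assert escape_lora_basename("my lora, v1.0") == "my_lora_v1_0"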
43
+
44
+
45
+ def download_private_repo(repo_id, dir_path, is_replace):
46
+ from huggingface_hub import snapshot_download
47
+ hf_read_token = os.environ.get('HF_READ_TOKEN')
48
+ if not hf_read_token: return
49
+ snapshot_download(repo_id=repo_id, local_dir=dir_path, allow_patterns=['*.ckpt', '*.pt', '*.pth', '*.safetensors', '*.bin'], use_auth_token=hf_read_token)
50
+ if is_replace:
51
+ from pathlib import Path
52
+ for file in Path(dir_path).glob("*"):
53
+ if file.exists() and ("." in file.stem or " " in file.stem) and file.suffix in ['.ckpt', '.pt', '.pth', '.safetensors', '.bin']:
54
+ newpath = Path(f'{file.parent.name}/{escape_lora_basename(file.stem)}{file.suffix}')
55
+ file.rename(newpath)
56
+
57
+
58
+ private_model_path_repo_dict = {}
59
+
60
+
61
+ def get_private_model_list(repo_id, dir_path):
62
+ global private_model_path_repo_dict
63
+ from huggingface_hub import HfApi
64
+ api = HfApi()
65
+ hf_read_token = os.environ.get('HF_READ_TOKEN')
66
+ if not hf_read_token: return []
67
+ files = api.list_repo_files(repo_id, token=hf_read_token)
68
+ model_list = []
69
+ for file in files:
70
+ from pathlib import Path
71
+ path = Path(f"{dir_path}/{file}")
72
+ if path.suffix in ['.ckpt', '.pt', '.pth', '.safetensors', '.bin']:
73
+ model_list.append(str(path))
74
+ for model in model_list:
75
+ private_model_path_repo_dict[model] = repo_id
76
+ return model_list
77
+
78
+
79
+ def get_private_model_lists(repo_id_list, dir_path):
80
+ models = []
81
+ for repo in repo_id_list:
82
+ models.extend(get_private_model_list(repo, dir_path))
83
+ models = list_uniq(models)
84
+ return models
85
+
86
+
87
+ def download_private_file(repo_id, path, is_replace):
88
+ from huggingface_hub import hf_hub_download
89
+ from pathlib import Path
90
+ file = Path(path)
91
+ newpath = Path(f'{file.parent.name}/{escape_lora_basename(file.stem)}{file.suffix}') if is_replace else file
92
+ hf_read_token = os.environ.get('HF_READ_TOKEN')
93
+ if not hf_read_token or newpath.exists(): return
94
+ filename = file.name
95
+ dirname = file.parent.name
96
+ hf_hub_download(repo_id=repo_id, filename=filename, local_dir=dirname, use_auth_token=hf_read_token)
97
+ if is_replace:
98
+ file.rename(newpath)
99
+
100
+
101
+ def download_private_file_from_somewhere(path, is_replace):
102
+ if not path in private_model_path_repo_dict.keys(): return
103
+ repo_id = private_model_path_repo_dict.get(path, None)
104
+ download_private_file(repo_id, path, is_replace)
105
+
106
+
107
+ def get_model_id_list():
108
+ from huggingface_hub import HfApi
109
+ api = HfApi()
110
+ model_ids = []
111
+
112
+ models_vp = api.list_models(author="votepurchase", cardData=True, sort="likes")
113
+ models_john = api.list_models(author="John6666", cardData=True, sort="last_modified")
114
+
115
+ for model in models_vp:
116
+ if not model.private: model_ids.append(model.id)
117
+
118
+ anime_models = []
119
+ real_models = []
120
+ for model in models_john:
121
+ if not model.private:
122
+ (anime_models if 'anime' in model.tags else real_models).append(model.id)
123
+
124
+ model_ids.extend(anime_models)
125
+ model_ids.extend(real_models)
126
+
127
+ return model_ids
128
+
129
+
130
+ def get_t2i_model_info(repo_id: str):
131
+ from huggingface_hub import HfApi
132
+ api = HfApi()
133
+ if " " in repo_id or not api.repo_exists(repo_id): return ""
134
+ model = api.model_info(repo_id=repo_id)
135
+ if model.private or model.gated: return ""
136
+
137
+ tags = model.tags
138
+ info = []
139
+ url = f"https://huggingface.co/{repo_id}/"
140
+ if not 'diffusers' in tags: return ""
141
+ if 'diffusers:StableDiffusionXLPipeline' in tags:
142
+ info.append("SDXL")
143
+ elif 'diffusers:StableDiffusionPipeline' in tags:
144
+ info.append("SD1.5")
145
+ if model.card_data and model.card_data.tags:
146
+ info.extend(list_sub(model.card_data.tags, ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']))
147
+ info.append(f"DLs: {model.downloads}")
148
+ info.append(f"likes: {model.likes}")
149
+ info.append(model.last_modified.strftime("lastmod: %Y-%m-%d"))
150
+
151
+ md = f"Model Info: {', '.join(info)}, [Model Repo]({url})"
152
+ return gr.update(value=md)
153
+
154
+
155
+ def get_tupled_model_list(model_list):
156
+ if not model_list: return []
157
+ tupled_list = []
158
+ for repo_id in model_list:
159
+ from huggingface_hub import HfApi
160
+ api = HfApi()
161
+ if not api.repo_exists(repo_id): continue
162
+ model = api.model_info(repo_id=repo_id)
163
+ if model.private or model.gated: continue
164
+
165
+ tags = model.tags
166
+ info = []
167
+ if not 'diffusers' in tags: continue
168
+ if 'diffusers:StableDiffusionXLPipeline' in tags:
169
+ info.append("SDXL")
170
+ elif 'diffusers:StableDiffusionPipeline' in tags:
171
+ info.append("SD1.5")
172
+ if model.card_data and model.card_data.tags:
173
+ info.extend(list_sub(model.card_data.tags, ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']))
174
+
175
+ if "pony" in info:
176
+ info.remove("pony")
177
+ name = f"{repo_id} (Pony🐴, {', '.join(info)})"
178
+ else:
179
+ name = f"{repo_id} ({', '.join(info)})"
180
+ tupled_list.append((name, repo_id))
181
+ return tupled_list
182
+
183
+
184
+ def save_gallery_images(images):
185
+ from datetime import datetime, timezone, timedelta
186
+ japan_tz = timezone(timedelta(hours=9))
187
+ dt_now = datetime.now(timezone.utc).astimezone(japan_tz) # datetime.utcnow() is deprecated; this is equivalent
188
+ basename = dt_now.strftime('%Y%m%d_%H%M%S_')
189
+ i = 1
190
+ if not images: return images
191
+ output_images = []
192
+ output_paths = []
193
+ for image in images:
194
+ from pathlib import Path
195
+ filename = basename + str(i) + ".png"
196
+ oldpath = Path(image[0])
197
+ newpath = oldpath.rename(Path(filename))
198
+ output_paths.append(str(newpath))
199
+ output_images.append((str(newpath), str(filename)))
200
+ i += 1
201
+ return gr.update(value=output_images), gr.update(value=output_paths), gr.update(visible=True),
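Traced by hand, the renaming above moves each gallery image to a JST-stamped file in the working directory (timestamp below is illustrative):

# images = [("/tmp/abc.png", None), ("/tmp/def.png", None)]
#   -> renamed to 20240701_213045_1.png, 20240701_213045_2.png
#   -> returns (gallery update, file-paths update, visible=True update)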
202
+
203
+
204
+ optimization_list = {
205
+ "None": [28, 7., 'Euler a', False, None, 1.],
206
+ "Default": [28, 7., 'Euler a', False, None, 1.],
207
+ "SPO": [28, 7., 'Euler a', True, 'loras/spo_sdxl_10ep_4k-data_lora_diffusers.safetensors', 1.],
208
+ "DPO": [28, 7., 'Euler a', True, 'loras/sdxl-DPO-LoRA.safetensors', 1.],
209
+ "DPO Turbo": [8, 2.5, 'LCM', True, 'loras/sd_xl_dpo_turbo_lora_v1-128dim.safetensors', 1.],
210
+ "SDXL Turbo": [8, 2.5, 'LCM', True, 'loras/sd_xl_turbo_lora_v1.safetensors', 1.],
211
+ "Hyper-SDXL 12step": [12, 5., 'Euler a trailing', True, 'loras/Hyper-SDXL-12steps-CFG-lora.safetensors', 0.9],
212
+ "Hyper-SDXL 8step": [8, 5., 'Euler a trailing', True, 'loras/Hyper-SDXL-8steps-CFG-lora.safetensors', 0.9],
213
+ "Hyper-SDXL 4step": [4, 0, 'Euler a trailing', True, 'loras/Hyper-SDXL-4steps-lora.safetensors', 0.9],
214
+ "Hyper-SDXL 2step": [2, 0, 'Euler a trailing', True, 'loras/Hyper-SDXL-2steps-lora.safetensors', 0.9],
215
+ "Hyper-SDXL 1step": [1, 0, 'Euler a trailing', True, 'loras/Hyper-SDXL-1steps-lora.safetensors', 0.9],
216
+ "PCM 16step": [16, 4., 'Euler a trailing', True, 'loras/pcm_sdxl_normalcfg_16step_converted.safetensors', 1.],
217
+ "PCM 8step": [8, 4., 'Euler a trailing', True, 'loras/pcm_sdxl_normalcfg_8step_converted.safetensors', 1.],
218
+ "PCM 4step": [4, 2., 'Euler a trailing', True, 'loras/pcm_sdxl_smallcfg_4step_converted.safetensors', 1.],
219
+ "PCM 2step": [2, 1., 'Euler a trailing', True, 'loras/pcm_sdxl_smallcfg_2step_converted.safetensors', 1.],
220
+ }
221
+
222
+
223
+ def set_optimization(opt, steps_gui, cfg_gui, sampler_gui, clip_skip_gui, lora1_gui, lora_scale_1_gui):
224
+ if not opt in list(optimization_list.keys()): opt = "None"
225
+ def_steps_gui = 28
226
+ def_cfg_gui = 7.
227
+ steps = optimization_list.get(opt, "None")[0]
228
+ cfg = optimization_list.get(opt, "None")[1]
229
+ sampler = optimization_list.get(opt, "None")[2]
230
+ clip_skip = optimization_list.get(opt, "None")[3]
231
+ lora1 = optimization_list.get(opt, "None")[4]
232
+ lora_scale_1 = optimization_list.get(opt, "None")[5]
233
+ if opt == "None":
234
+ steps = max(steps_gui, def_steps_gui)
235
+ cfg = max(cfg_gui, def_cfg_gui)
236
+ clip_skip = clip_skip_gui
237
+ elif opt == "SPO" or opt == "DPO":
238
+ steps = max(steps_gui, def_steps_gui)
239
+ cfg = max(cfg_gui, def_cfg_gui)
240
+
241
+ return gr.update(value=steps), gr.update(value=cfg), gr.update(value=sampler),\
242
+ gr.update(value=clip_skip), gr.update(value=lora1), gr.update(value=lora_scale_1),
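Two hand-traced examples of how set_optimization reads optimization_list:

# set_optimization("Hyper-SDXL 8step", 28, 7., "Euler a", True, None, 1.)
#   -> steps=8, cfg=5., sampler='Euler a trailing', clip_skip=True,
#      lora1='loras/Hyper-SDXL-8steps-CFG-lora.safetensors', lora_scale_1=0.9
# set_optimization("None", 40, 9., "Euler a", True, None, 1.)
#   -> keeps the larger of current and default steps/cfg: steps=40, cfg=9.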
243
+
244
+
245
+ def set_lora_prompt(prompt_gui, prompt_syntax_gui, lora1_gui, lora_scale_1_gui, lora2_gui, lora_scale_2_gui, lora3_gui, lora_scale_3_gui, lora4_gui, lora_scale_4_gui, lora5_gui, lora_scale_5_gui):
246
+ import os
247
+ if not "Classic" in str(prompt_syntax_gui): return prompt_gui
248
+ loras = []
249
+ if lora1_gui:
250
+ basename = os.path.splitext(os.path.basename(lora1_gui))[0]
251
+ loras.append(f"<lora:{basename}:{lora_scale_1_gui:.2f}>")
252
+ if lora2_gui:
253
+ basename = os.path.splitext(os.path.basename(lora2_gui))[0]
254
+ loras.append(f"<lora:{basename}:{lora_scale_2_gui:.2f}>")
255
+ if lora3_gui:
256
+ basename = os.path.splitext(os.path.basename(lora3_gui))[0]
257
+ loras.append(f"<lora:{basename}:{lora_scale_3_gui:.2f}>")
258
+ if lora4_gui:
259
+ basename = os.path.splitext(os.path.basename(lora4_gui))[0]
260
+ loras.append(f"<lora:{basename}:{lora_scale_4_gui:.2f}>")
261
+ if lora5_gui:
262
+ basename = os.path.splitext(os.path.basename(lora5_gui))[0]
263
+ loras.append(f"<lora:{basename}:{lora_scale_5_gui:.2f}>")
264
+
265
+ tags = prompt_gui.split(",") if prompt_gui else []
266
+ prompts = []
267
+ for tag in tags:
268
+ tag = str(tag).strip()
269
+ if tag and not "<lora" in tag:
270
+ prompts.append(tag)
271
+
272
+ empty = [""]
273
+ prompt = ", ".join(prompts + loras + empty)
274
+
275
+ return gr.update(value=prompt)
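A hand-traced example of the "Classic" rewrite above: existing <lora:...> tags are stripped from the prompt and the active LoRAs re-appended.

# set_lora_prompt("1girl, <lora:old:0.5>, smile", "Classic",
#                 "loras/detail.safetensors", 0.8,
#                 None, 1., None, 1., None, 1., None, 1.)
#   -> value == "1girl, smile, <lora:detail:0.80>, "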
276
+
277
+
278
+ temp_dict = {}
279
+ lora_trigger_dict = {}
280
+ with open('lora_dict.json', encoding='utf-8') as f:
281
+ temp_dict = json.load(f)
282
+ for k, v in temp_dict.items():
283
+ lora_trigger_dict[escape_lora_basename(k)] = v
284
+
285
+
286
+ civitai_not_exists_list = []
287
+
288
+
289
+ def get_civitai_info(path):
290
+ global civitai_not_exists_list
291
+ import requests
292
+ if path in set(civitai_not_exists_list): return ["", "", "", "", ""]
293
+ from pathlib import Path
294
+ if not Path(path).exists(): return None
295
+ from fake_useragent import UserAgent
296
+ ua = UserAgent()
297
+ user_agent = ua.random
298
+ headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
299
+ base_url = 'https://civitai.com/api/v1/model-versions/by-hash/'
300
+ params = {}
301
+ import hashlib
302
+ with open(path, 'rb') as file:
303
+ file_data = file.read()
304
+ hash_sha256 = hashlib.sha256(file_data).hexdigest()
305
+ url = base_url + hash_sha256
306
+ r = requests.get(url, params=params, headers=headers, timeout=(3.0, 7.5))
307
+ if not r.ok: return None
308
+ data = r.json() # renamed from `json` to avoid shadowing the module imported at the top of the file
309
+ if 'baseModel' not in data:
310
+ civitai_not_exists_list.append(path)
311
+ return ["", "", "", "", ""]
312
+ items = []
313
+ items.append(" / ".join(data['trainedWords']))
314
+ items.append(data['baseModel'])
315
+ items.append(data['model']['name'])
316
+ items.append(f"https://civitai.com/models/{data['modelId']}")
317
+ items.append(data['images'][0]['url'])
318
+ return items
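The list returned above is positional; a hedged usage sketch (the file path is hypothetical):

# [0] trigger words joined by " / ", [1] base model, [2] model name,
# [3] model page URL, [4] first preview image URL
items = get_civitai_info("loras/detail.safetensors")
if items:  # None on lookup failure, ["", "", "", "", ""] if unknown to Civitai
    trigger_words, base_model, model_name, page_url, thumb_url = items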
319
+
320
+
321
+ def update_lora_dict(path):
322
+ global lora_trigger_dict
323
+ from pathlib import Path
324
+ key = escape_lora_basename(Path(path).stem)
325
+ if key in lora_trigger_dict.keys(): return
326
+ items = get_civitai_info(path)
327
+ if items is None: return
328
+ lora_trigger_dict[key] = items
329
+
330
+
331
+ def get_lora_tupled_list(lora_model_list):
332
+ global lora_trigger_dict
333
+ from pathlib import Path
334
+ if not lora_model_list: return []
335
+ tupled_list = []
336
+ local_models = set(get_model_list(directory_loras))
337
+ for model in lora_model_list:
338
+ basename = Path(model).stem
339
+ key = escape_lora_basename(basename)
340
+ items = None
341
+ if key in lora_trigger_dict.keys():
342
+ items = lora_trigger_dict.get(key, None)
343
+ elif model in local_models:
344
+ items = get_civitai_info(model)
345
+ if items is not None:
346
+ lora_trigger_dict[key] = items
347
+ name = basename
348
+ value = model
349
+ if items and items[2] != "":
350
+ if items[1] == "Pony":
351
+ name = f"{basename} (for {items[1]}🐴, {items[2]})"
352
+ else:
353
+ name = f"{basename} (for {items[1]}, {items[2]})"
354
+ tupled_list.append((name, value))
355
+ return tupled_list
356
+
357
+
358
+ def set_lora_trigger(lora_gui: str):
359
+ from pathlib import Path
360
+ if not lora_gui or lora_gui == "None": return gr.update(value="", visible=False), gr.update(visible=False), gr.update(value="", visible=False), gr.update(value="", visible=True)
361
+ path = Path(lora_gui)
362
+ new_path = Path(f'{path.parent.name}/{escape_lora_basename(path.stem)}{path.suffix}')
363
+ if not new_path.stem in lora_trigger_dict.keys() and not str(path) in set(get_private_model_lists(HF_LORA_PRIVATE_REPOS, directory_loras) + get_model_list(directory_loras)):
364
+ return gr.update(value="", visible=False), gr.update(visible=False), gr.update(value="", visible=False), gr.update(value="", visible=True)
365
+ if not new_path.exists():
366
+ download_private_file_from_somewhere(str(path), True)
367
+ basename = new_path.stem
368
+ tag = ""
369
+ label = f'Trigger: {basename} / Prompt:'
370
+ value = "None"
371
+ md = "None"
372
+ flag = False
373
+ items = lora_trigger_dict.get(basename, None)
374
+ if items is None:
375
+ items = get_civitai_info(str(new_path))
376
+ if items is not None:
377
+ lora_trigger_dict[basename] = items
378
+ flag = True
379
+ if items and items[2] != "":
380
+ tag = items[0]
381
+ label = f'Trigger: {basename} / Prompt:'
382
+ if items[1] == "Pony":
383
+ label = f'Trigger: {basename} / Prompt (for Pony🐴):'
384
+ if items[4]:
385
+ md = f'<img src="{items[4]}" alt="thumbnail" width="150" height="240"><br>[LoRA Model URL]({items[3]})'
386
+ elif items[3]:
387
+ md = f'[LoRA Model URL]({items[3]})'
388
+ if tag and flag:
389
+ new_lora_model_list = list_uniq(get_private_model_lists(HF_LORA_PRIVATE_REPOS, directory_loras) + get_model_list(directory_loras))
390
+ return gr.update(value=tag, label=label, visible=True), gr.update(visible=True),\
391
+ gr.update(value=md, visible=True), gr.update(value=str(new_path), choices=get_lora_tupled_list(new_lora_model_list))
392
+ elif tag:
393
+ return gr.update(value=tag, label=label, visible=True), gr.update(visible=True),\
394
+ gr.update(value=md, visible=True), gr.update(value=str(new_path))
395
+ else:
396
+ return gr.update(value=value, label=label, visible=True), gr.update(visible=True),\
397
+ gr.update(value=md, visible=True), gr.update(visible=True)
398
+
399
+
400
+ def apply_lora_prompt(prompt_gui: str, lora_trigger_gui: str):
401
+ if lora_trigger_gui == "None": return gr.update(value=prompt_gui)
402
+ tags = prompt_gui.split(",") if prompt_gui else []
403
+ prompts = normalize_prompt_list(tags)
404
+
405
+ lora_tag = lora_trigger_gui.replace("/",",")
406
+ lora_tags = lora_tag.split(",") if str(lora_trigger_gui) != "None" else []
407
+ lora_prompts = normalize_prompt_list(lora_tags)
408
+
409
+ empty = [""]
410
+ prompt = ", ".join(list_uniq(prompts + lora_prompts) + empty)
411
+ return gr.update(value=prompt)
412
+
413
+
414
+ def upload_file_lora(files):
415
+ file_paths = [file.name for file in files]
416
+ return gr.update(value=file_paths, visible=True)
417
+
418
+ def move_file_lora(filepaths):
419
+ import shutil
420
+ from pathlib import Path
421
+ for file in filepaths:
422
+ path = Path(shutil.move(Path(file).absolute(), Path(f"./{directory_loras}").absolute()))
423
+ newpath = Path(f'{path.parent.name}/{escape_lora_basename(path.stem)}{path.suffix}')
424
+ path.rename(newpath)
425
+ update_lora_dict(str(newpath))
426
+
427
+ new_lora_model_list = list_uniq(get_private_model_lists(HF_LORA_PRIVATE_REPOS, directory_loras) + get_model_list(directory_loras))
428
+ new_lora_model_list.insert(0, "None")
429
+
430
+ return gr.update(
431
+ choices=get_lora_tupled_list(new_lora_model_list)
432
+ ), gr.update(
433
+ choices=get_lora_tupled_list(new_lora_model_list)
434
+ ), gr.update(
435
+ choices=get_lora_tupled_list(new_lora_model_list)
436
+ ), gr.update(
437
+ choices=get_lora_tupled_list(new_lora_model_list)
438
+ ), gr.update(
439
+ choices=get_lora_tupled_list(new_lora_model_list)
440
+ ),
441
+
442
+
443
+ quality_prompt_list = [
444
+ {
445
+ "name": "None",
446
+ "prompt": "",
447
+ "negative_prompt": "lowres",
448
+ },
449
+ {
450
+ "name": "Animagine Common",
451
+ "prompt": "anime artwork, anime style, key visual, vibrant, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres",
452
+ "negative_prompt": "lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
453
+ },
454
+ {
455
+ "name": "Pony Anime Common",
456
+ "prompt": "source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres",
457
+ "negative_prompt": "source_pony, source_furry, source_cartoon, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends",
458
+ },
459
+ {
460
+ "name": "Pony Common",
461
+ "prompt": "source_anime, score_9, score_8_up, score_7_up",
462
+ "negative_prompt": "source_pony, source_furry, source_cartoon, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends",
463
+ },
464
+ {
465
+ "name": "Animagine Standard v3.0",
466
+ "prompt": "masterpiece, best quality",
467
+ "negative_prompt": "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name",
468
+ },
469
+ {
470
+ "name": "Animagine Standard v3.1",
471
+ "prompt": "masterpiece, best quality, very aesthetic, absurdres",
472
+ "negative_prompt": "lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
473
+ },
474
+ {
475
+ "name": "Animagine Light v3.1",
476
+ "prompt": "(masterpiece), best quality, very aesthetic, perfect face",
477
+ "negative_prompt": "(low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn",
478
+ },
479
+ {
480
+ "name": "Animagine Heavy v3.1",
481
+ "prompt": "(masterpiece), (best quality), (ultra-detailed), very aesthetic, illustration, disheveled hair, perfect composition, moist skin, intricate details",
482
+ "negative_prompt": "longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair, extra digit, fewer digits, cropped, worst quality, low quality, very displeasing",
483
+ },
484
+ ]
485
+
486
+
487
+ style_list = [
488
+ {
489
+ "name": "None",
490
+ "prompt": "",
491
+ "negative_prompt": "",
492
+ },
493
+ {
494
+ "name": "Cinematic",
495
+ "prompt": "cinematic still, emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
496
+ "negative_prompt": "cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
497
+ },
498
+ {
499
+ "name": "Photographic",
500
+ "prompt": "cinematic photo, 35mm photograph, film, bokeh, professional, 4k, highly detailed",
501
+ "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
502
+ },
503
+ {
504
+ "name": "Anime",
505
+ "prompt": "anime artwork, anime style, key visual, vibrant, studio anime, highly detailed",
506
+ "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
507
+ },
508
+ {
509
+ "name": "Manga",
510
+ "prompt": "manga style, vibrant, high-energy, detailed, iconic, Japanese comic style",
511
+ "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
512
+ },
513
+ {
514
+ "name": "Digital Art",
515
+ "prompt": "concept art, digital artwork, illustrative, painterly, matte painting, highly detailed",
516
+ "negative_prompt": "photo, photorealistic, realism, ugly",
517
+ },
518
+ {
519
+ "name": "Pixel art",
520
+ "prompt": "pixel-art, low-res, blocky, pixel art style, 8-bit graphics",
521
+ "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
522
+ },
523
+ {
524
+ "name": "Fantasy art",
525
+ "prompt": "ethereal fantasy concept art, magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
526
+ "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
527
+ },
528
+ {
529
+ "name": "Neonpunk",
530
+ "prompt": "neonpunk style, cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
531
+ "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
532
+ },
533
+ {
534
+ "name": "3D Model",
535
+ "prompt": "professional 3d model, octane render, highly detailed, volumetric, dramatic lighting",
536
+ "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
537
+ },
538
+ ]
539
+
540
+
541
+ # [sampler_gui, steps_gui, cfg_gui, clip_skip_gui, img_width_gui, img_height_gui, optimization_gui]
542
+ preset_sampler_setting = {
543
+ "None": ["Euler a", 28, 7., True, 1024, 1024, "None"],
544
+ "Anime 3:4 Fast": ["LCM", 8, 2.5, True, 896, 1152, "DPO Turbo"],
545
+ "Anime 3:4 Standard": ["Euler a", 28, 7., True, 896, 1152, "None"],
546
+ "Anime 3:4 Heavy": ["Euler a", 40, 7., True, 896, 1152, "None"],
547
+ "Anime 1:1 Fast": ["LCM", 8, 2.5, True, 1024, 1024, "DPO Turbo"],
548
+ "Anime 1:1 Standard": ["Euler a", 28, 7., True, 1024, 1024, "None"],
549
+ "Anime 1:1 Heavy": ["Euler a", 40, 7., True, 1024, 1024, "None"],
550
+ "Photo 3:4 Fast": ["LCM", 8, 2.5, False, 896, 1152, "DPO Turbo"],
551
+ "Photo 3:4 Standard": ["DPM++ 2M Karras", 28, 7., False, 896, 1152, "None"],
552
+ "Photo 3:4 Heavy": ["DPM++ 2M Karras", 40, 7., False, 896, 1152, "None"],
553
+ "Photo 1:1 Fast": ["LCM", 8, 2.5, False, 1024, 1024, "DPO Turbo"],
554
+ "Photo 1:1 Standard": ["DPM++ 2M Karras", 28, 7., False, 1024, 1024, "None"],
555
+ "Photo 1:1 Heavy": ["DPM++ 2M Karras", 40, 7., False, 1024, 1024, "None"],
556
+ }
557
+
558
+
559
+ def set_sampler_settings(sampler_setting):
560
+ if sampler_setting not in preset_sampler_setting or sampler_setting == "None":
561
+ return gr.update(value="Euler a"), gr.update(value=28), gr.update(value=7.), gr.update(value=True),\
562
+ gr.update(value=1024), gr.update(value=1024), gr.update(value="None")
563
+ v = preset_sampler_setting.get(sampler_setting, ["Euler a", 28, 7., True, 1024, 1024, "None"])
564
+ # sampler, steps, cfg, clip_skip, width, height, optimization
565
+ return gr.update(value=v[0]), gr.update(value=v[1]), gr.update(value=v[2]), gr.update(value=v[3]),\
566
+ gr.update(value=v[4]), gr.update(value=v[5]), gr.update(value=v[6])
567
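A quick illustration of how one preset row unpacks into the GUI fields named in the comment above (editor's snippet, not part of the diff):

    sampler, steps, cfg, clip_skip, width, height, opt = preset_sampler_setting["Anime 3:4 Fast"]
    assert (sampler, steps, cfg, clip_skip) == ("LCM", 8, 2.5, True)
    assert (width, height, opt) == (896, 1152, "DPO Turbo")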
+
568
+
569
+ preset_styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
570
+ preset_quality = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in quality_prompt_list}
571
+
572
+
573
+ def process_style_prompt(prompt: str, neg_prompt: str, styles_key: str = "None", quality_key: str = "None", type: str = "None"):
574
+ def to_list(s):
575
+ return [x.strip() for x in s.split(",") if x.strip() != ""]
576
+
577
+ def list_sub(a, b):
578
+ return [e for e in a if e not in b]
579
+
580
+ def list_uniq(l):
581
+ return sorted(set(l), key=l.index)
582
+
583
+ animagine_ps = to_list("anime artwork, anime style, key visual, vibrant, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres")
584
+ animagine_nps = to_list("lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
585
+ pony_ps = to_list("source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres")
586
+ pony_nps = to_list("source_pony, source_furry, source_cartoon, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends")
587
+ prompts = to_list(prompt)
588
+ neg_prompts = to_list(neg_prompt)
589
+
590
+ all_styles_ps = []
591
+ all_styles_nps = []
592
+ for d in style_list:
593
+ all_styles_ps.extend(to_list(str(d.get("prompt", ""))))
594
+ all_styles_nps.extend(to_list(str(d.get("negative_prompt", ""))))
595
+
596
+ all_quality_ps = []
597
+ all_quality_nps = []
598
+ for d in quality_prompt_list:
599
+ all_quality_ps.extend(to_list(str(d.get("prompt", ""))))
600
+ all_quality_nps.extend(to_list(str(d.get("negative_prompt", ""))))
601
+
602
+ quality_ps = to_list(preset_quality[quality_key][0])
603
+ quality_nps = to_list(preset_quality[quality_key][1])
604
+ styles_ps = to_list(preset_styles[styles_key][0])
605
+ styles_nps = to_list(preset_styles[styles_key][1])
606
+
607
+ prompts = list_sub(prompts, animagine_ps + pony_ps + all_styles_ps + all_quality_ps)
608
+ neg_prompts = list_sub(neg_prompts, animagine_nps + pony_nps + all_styles_nps + all_quality_nps)
609
+
610
+ last_empty_p = [""] if not prompts and type != "None" and styles_key != "None" and quality_key != "None" else []
611
+ last_empty_np = [""] if not neg_prompts and type != "None" and styles_key != "None" and quality_key != "None" else []
612
+
613
+ if type == "Animagine":
614
+ prompts = prompts + animagine_ps
615
+ neg_prompts = neg_prompts + animagine_nps
616
+ elif type == "Pony":
617
+ prompts = prompts + pony_ps
618
+ neg_prompts = neg_prompts + pony_nps
619
+
620
+ prompts = prompts + styles_ps + quality_ps
621
+ neg_prompts = neg_prompts + styles_nps + quality_nps
622
+
623
+ prompt = ", ".join(list_uniq(prompts) + last_empty_p)
624
+ neg_prompt = ", ".join(list_uniq(neg_prompts) + last_empty_np)
625
+
626
+ return prompt, neg_prompt
627
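A minimal sketch of the merge/dedup behavior (editor's example; assumes the "None" quality preset is empty, as the defaults imply):

    p, np = process_style_prompt("1girl, highly detailed", "photo", styles_key="Anime")
    # "highly detailed" already appears in a style preset, so it is stripped from the
    # user prompt and re-added exactly once via the "Anime" style tags (expected):
    # p  -> "1girl, anime artwork, anime style, key visual, vibrant, studio anime, highly detailed"
    # np -> "photo, deformed, black and white, realism, disfigured, low contrast"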
+
628
+
629
+ def set_quick_presets(genre:str = "None", type:str = "None", speed:str = "None", aspect:str = "None"):
630
+ quality = "None"
631
+ style = "None"
632
+ sampler = "None"
633
+ opt = "None"
634
+
635
+ if genre == "Anime":
636
+ style = "Anime"
637
+ if aspect == "1:1":
638
+ if speed == "Heavy":
639
+ sampler = "Anime 1:1 Heavy"
640
+ elif speed == "Fast":
641
+ sampler = "Anime 1:1 Fast"
642
+ else:
643
+ sampler = "Anime 1:1 Standard"
644
+ elif aspect == "3:4":
645
+ if speed == "Heavy":
646
+ sampler = "Anime 3:4 Heavy"
647
+ elif speed == "Fast":
648
+ sampler = "Anime 3:4 Fast"
649
+ else:
650
+ sampler = "Anime 3:4 Standard"
651
+ if type == "Pony":
652
+ quality = "Pony Anime Common"
653
+ else:
654
+ quality = "Animagine Common"
655
+ elif genre == "Photo":
656
+ style = "Photographic"
657
+ if aspect == "1:1":
658
+ if speed == "Heavy":
659
+ sampler = "Photo 1:1 Heavy"
660
+ elif speed == "Fast":
661
+ sampler = "Photo 1:1 Fast"
662
+ else:
663
+ sampler = "Photo 1:1 Standard"
664
+ elif aspect == "3:4":
665
+ if speed == "Heavy":
666
+ sampler = "Photo 3:4 Heavy"
667
+ elif speed == "Fast":
668
+ sampler = "Photo 3:4 Fast"
669
+ else:
670
+ sampler = "Photo 3:4 Standard"
671
+ if type == "Pony":
672
+ quality = "Pony Common"
673
+ else:
674
+ quality = "None"
675
+
676
+ if speed == "Fast":
677
+ opt = "DPO Turbo"
678
+ if genre == "Anime" and type != "Pony": quality = "Animagine Light v3.1"
679
+
680
+ return gr.update(value=quality), gr.update(value=style), gr.update(value=sampler), gr.update(value=opt)
681
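For reference, a sketch of how the four quick-preset dropdowns collapse into concrete preset names (editor's example):

    q, s, smp, o = set_quick_presets(genre="Anime", type="Pony", speed="Fast", aspect="3:4")
    # Each return value is a gr.update(...) payload:
    # quality="Pony Anime Common", style="Anime", sampler="Anime 3:4 Fast", optimization="DPO Turbo"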
+
682
+
683
+ textual_inversion_dict = {}
684
+ with open('textual_inversion_dict.json', encoding='utf-8') as f:
685
+ textual_inversion_dict = json.load(f)
686
+
687
+
688
+ textual_inversion_file_token_list = []
689
+
690
+
691
+ def get_tupled_embed_list(embed_list):
692
+ from pathlib import Path
693
+ global textual_inversion_file_token_list
694
+ tupled_list = []
695
+ for file in embed_list:
696
+ token = textual_inversion_dict.get(Path(file).name, [Path(file).stem.replace(",",""), False])[0]
697
+ tupled_list.append((token, file))
698
+ textual_inversion_file_token_list.append(token)
699
+ return tupled_list
700
+
701
+
702
+ def set_textual_inversion_prompt(textual_inversion_gui, prompt_gui, neg_prompt_gui, prompt_syntax_gui):
703
+ ti_tags = [v[0] for v in textual_inversion_dict.values()] + textual_inversion_file_token_list
704
+ tags = prompt_gui.split(",") if prompt_gui else []
705
+ prompts = []
706
+ for tag in tags:
707
+ tag = str(tag).strip()
708
+ if tag and tag not in ti_tags:
709
+ prompts.append(tag)
710
+
711
+ ntags = neg_prompt_gui.split(",") if neg_prompt_gui else []
712
+ neg_prompts = []
713
+ for tag in ntags:
714
+ tag = str(tag).strip()
715
+ if tag and tag not in ti_tags:
716
+ neg_prompts.append(tag)
717
+
718
+ ti_prompts = []
719
+ ti_neg_prompts = []
720
+ for ti in textual_inversion_gui:
721
+ from pathlib import Path
722
+ tokens = textual_inversion_dict.get(Path(ti).name, [Path(ti).stem.replace(",",""), False])
723
+ is_positive = tokens[1] is True or "positive" in Path(ti).parent.name
724
+ if is_positive: # positive prompt
725
+ ti_prompts.append(tokens[0])
726
+ else: # negative prompt (default)
727
+ ti_neg_prompts.append(tokens[0])
728
+
729
+ empty = [""]
730
+ prompt = ", ".join(prompts + ti_prompts + empty)
731
+ neg_prompt = ", ".join(neg_prompts + ti_neg_prompts + empty)
732
+
733
+ return gr.update(value=prompt), gr.update(value=neg_prompt),
734
+
735
+
736
+ def get_model_pipeline(repo_id: str):
737
+ from huggingface_hub import HfApi
738
+ api = HfApi()
739
+ default = "StableDiffusionPipeline"
740
+ if " " in repo_id or not api.repo_exists(repo_id): return default
741
+ model = api.model_info(repo_id=repo_id)
742
+ if model.private or model.gated: return default
743
+
744
+ tags = model.tags
745
+ if 'diffusers' not in tags: return default
746
+ if 'diffusers:StableDiffusionXLPipeline' in tags:
747
+ return "StableDiffusionXLPipeline"
748
+ elif 'diffusers:StableDiffusionPipeline' in tags:
749
+ return "StableDiffusionPipeline"
750
+ else:
751
+ return default
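Editor's note: a short usage sketch for get_model_pipeline (makes network calls to the Hub; the expected strings assume the repos keep their current diffusers tags):

    get_model_pipeline("stabilityai/stable-diffusion-xl-base-1.0")  # -> "StableDiffusionXLPipeline"
    get_model_pipeline("no such repo")                              # -> "StableDiffusionPipeline" (fallback)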
output.py ADDED
@@ -0,0 +1,16 @@
1
+ from dataclasses import dataclass
2
+
3
+
4
+ @dataclass
5
+ class UpsamplingOutput:
6
+ upsampled_tags: str
7
+
8
+ copyright_tags: str
9
+ character_tags: str
10
+ general_tags: str
11
+ rating_tag: str
12
+ aspect_ratio_tag: str
13
+ length_tag: str
14
+ identity_tag: str
15
+
16
+ elapsed_time: float = 0.0
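A minimal sketch of constructing the dataclass above (editor's example):

    out = UpsamplingOutput(
        upsampled_tags="smile, outdoors",
        copyright_tags="", character_tags="", general_tags="1girl",
        rating_tag="sfw", aspect_ratio_tag="tall",
        length_tag="long", identity_tag="none",
    )
    out.elapsed_time  # 0.0 unless set explicitly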
packages.txt ADDED
@@ -0,0 +1 @@
1
+ git-lfs aria2 ffmpeg
pre-requirements.txt ADDED
@@ -0,0 +1 @@
1
+ pip>=23.0.0
requirements.txt ADDED
@@ -0,0 +1,17 @@
1
+ git+https://github.com/R3gm/stablepy.git@test_stream
2
+ torch==2.2.0
3
+ gdown
4
+ opencv-python
5
+ pytorch-lightning
6
+ torchvision
7
+ accelerate
8
+ transformers
9
+ optimum[onnxruntime]
10
+ spaces
11
+ dartrs
12
+ huggingface_hub
13
+ httpx==0.13.3
14
+ httpcore
15
+ googletrans==4.0.0rc1
16
+ timm
17
+ fake-useragent
spiral_no_transparent.png ADDED
stablepy_model.py ADDED
The diff for this file is too large to render. See raw diff
tag_group.csv ADDED
The diff for this file is too large to render. See raw diff
tagger.py ADDED
@@ -0,0 +1,450 @@
1
+ from PIL import Image
2
+ import torch
3
+ import gradio as gr
4
+ import spaces # ZERO GPU
5
+
6
+ from transformers import (
7
+ AutoImageProcessor,
8
+ AutoModelForImageClassification,
9
+ )
10
+
11
+ WD_MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
12
+ WD_MODEL_NAME = WD_MODEL_NAMES[0]
13
+
14
+ wd_model = AutoModelForImageClassification.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
15
+ wd_model.to("cuda" if torch.cuda.is_available() else "cpu")
16
+ wd_processor = AutoImageProcessor.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
17
+
18
+
19
+ def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
20
+ return (
21
+ [f"1{noun}"]
22
+ + [f"{num}{noun}s" for num in range(minimum + 1, maximum + 1)]
23
+ + [f"{maximum+1}+{noun}s"]
24
+ )
25
+
26
+
27
+ PEOPLE_TAGS = (
28
+ _people_tag("girl") + _people_tag("boy") + _people_tag("other") + ["no humans"]
29
+ )
30
+
31
+
32
+ RATING_MAP = {
33
+ "general": "safe",
34
+ "sensitive": "sensitive",
35
+ "questionable": "nsfw",
36
+ "explicit": "explicit, nsfw",
37
+ }
38
+ DANBOORU_TO_E621_RATING_MAP = {
39
+ "safe": "rating_safe",
40
+ "sensitive": "rating_safe",
41
+ "nsfw": "rating_explicit",
42
+ "explicit, nsfw": "rating_explicit",
43
+ "explicit": "rating_explicit",
44
+ "rating:safe": "rating_safe",
45
+ "rating:general": "rating_safe",
46
+ "rating:sensitive": "rating_safe",
47
+ "rating:questionable, nsfw": "rating_explicit",
48
+ "rating:explicit, nsfw": "rating_explicit",
49
+ }
50
+
51
+
52
+ def load_dict_from_csv(filename):
53
+ with open(filename, 'r', encoding="utf-8") as f:
54
+ lines = f.readlines()
55
+ d = {}
56
+ for line in lines:
57
+ parts = line.strip().split(',')
58
+ if len(parts) >= 2: d[parts[0]] = parts[1]
59
+ return d
60
+
61
+
62
+ anime_series_dict = load_dict_from_csv('character_series_dict.csv')
63
+
64
+
65
+ def character_list_to_series_list(character_list):
66
+ output_series_tag = []
67
+ series_tag = ""
68
+ series_dict = anime_series_dict
69
+ for tag in character_list:
70
+ series_tag = series_dict.get(tag, "")
71
+ if tag.endswith(")"):
72
+ tags = tag.split("(")
73
+ character_tag = "(".join(tags[:-1])
74
+ if character_tag.endswith(" "):
75
+ character_tag = character_tag[:-1]
76
+ series_tag = tags[-1].replace(")", "")
77
+
78
+ if series_tag:
79
+ output_series_tag.append(series_tag)
80
+
81
+ return output_series_tag
82
+
83
+
84
+ def danbooru_to_e621(dtag, e621_dict):
85
+ def d_to_e(match, e621_dict):
86
+ dtag = match.group(0)
87
+ etag = e621_dict.get(dtag.strip().replace("_", " "), "")
88
+ if etag:
89
+ return etag
90
+ else:
91
+ return dtag
92
+
93
+ import re
94
+ tag = re.sub(r'[\w ]+', lambda wrapper: d_to_e(wrapper, e621_dict), dtag, count=2)
95
+
96
+ return tag
97
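A toy illustration of the substitution above (hypothetical mapping; the real table lives in danbooru_e621.csv). Note count=2: only the first two word runs are eligible for replacement:

    demo = {"thumbs up": "thumbs_up"}  # hypothetical entry
    danbooru_to_e621("thumbs up", demo)  # -> "thumbs_up"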
+
98
+
99
+ danbooru_to_e621_dict = load_dict_from_csv('danbooru_e621.csv')
100
+
101
+
102
+ def convert_danbooru_to_e621_prompt(input_prompt: str = "", prompt_type: str = "danbooru"):
103
+ if prompt_type == "danbooru": return input_prompt
104
+ tags = input_prompt.split(",") if input_prompt else []
105
+ people_tags: list[str] = []
106
+ other_tags: list[str] = []
107
+ rating_tags: list[str] = []
108
+
109
+ e621_dict = danbooru_to_e621_dict
110
+ for tag in tags:
111
+ tag = tag.strip().replace("_", " ")
112
+ tag = danbooru_to_e621(tag, e621_dict)
113
+ if tag in PEOPLE_TAGS:
114
+ people_tags.append(tag)
115
+ elif tag in DANBOORU_TO_E621_RATING_MAP.keys():
116
+ rating_tags.append(DANBOORU_TO_E621_RATING_MAP.get(tag, ""))
117
+ else:
118
+ other_tags.append(tag)
119
+
120
+ rating_tags = sorted(set(rating_tags), key=rating_tags.index)
121
+ rating_tags = [rating_tags[0]] if rating_tags else []
122
+ rating_tags = ["explicit, nsfw"] if rating_tags and rating_tags[0] == "explicit" else rating_tags
123
+
124
+ output_prompt = ", ".join(people_tags + other_tags + rating_tags)
125
+
126
+ return output_prompt
127
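A sketch of the reordering (editor's example; assumes neither "smile" nor "1girl" is remapped in danbooru_e621.csv):

    convert_danbooru_to_e621_prompt("safe, smile, 1girl", prompt_type="e621")
    # people tags first, mapped rating tag last -> "1girl, smile, rating_safe"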
+
128
+
129
+ def translate_prompt(prompt: str = ""):
130
+ def translate_to_english(prompt):
131
+ import httpcore
132
+ setattr(httpcore, 'SyncHTTPTransport', 'AsyncHTTPProxy')  # shim so googletrans 4.0.0rc1 imports against newer httpcore
133
+ from googletrans import Translator
134
+ translator = Translator()
135
+ try:
136
+ translated_prompt = translator.translate(prompt, src='auto', dest='en').text
137
+ return translated_prompt
138
+ except Exception as e:
139
+ return prompt
140
+
141
+ def is_japanese(s):
142
+ import unicodedata
143
+ for ch in s:
144
+ name = unicodedata.name(ch, "")
145
+ if "CJK UNIFIED" in name or "HIRAGANA" in name or "KATAKANA" in name:
146
+ return True
147
+ return False
148
+
149
+ def to_list(s):
150
+ return [x.strip() for x in s.split(",")]
151
+
152
+ prompts = to_list(prompt)
153
+ outputs = []
154
+ for p in prompts:
155
+ p = translate_to_english(p) if is_japanese(p) else p
156
+ outputs.append(p)
157
+
158
+ return ", ".join(outputs)
159
+
160
+
161
+ def translate_prompt_to_ja(prompt: str = ""):
162
+ def translate_to_japanese(prompt):
163
+ import httpcore
164
+ setattr(httpcore, 'SyncHTTPTransport', 'AsyncHTTPProxy')  # shim so googletrans 4.0.0rc1 imports against newer httpcore
165
+ from googletrans import Translator
166
+ translator = Translator()
167
+ try:
168
+ translated_prompt = translator.translate(prompt, src='en', dest='ja').text
169
+ return translated_prompt
170
+ except Exception as e:
171
+ return prompt
172
+
173
+ def is_japanese(s):
174
+ import unicodedata
175
+ for ch in s:
176
+ name = unicodedata.name(ch, "")
177
+ if "CJK UNIFIED" in name or "HIRAGANA" in name or "KATAKANA" in name:
178
+ return True
179
+ return False
180
+
181
+ def to_list(s):
182
+ return [x.strip() for x in s.split(",")]
183
+
184
+ prompts = to_list(prompt)
185
+ outputs = []
186
+ for p in prompts:
187
+ p = translate_to_japanese(p) if not is_japanese(p) else p
188
+ outputs.append(p)
189
+
190
+ return ", ".join(outputs)
191
+
192
+
193
+ def tags_to_ja(itag, d):
194
+ def t_to_j(match, d):
195
+ tag = match.group(0)
196
+ ja = d.get(tag.strip().replace("_", " "), "")
197
+ if ja:
198
+ return ja
199
+ else:
200
+ return tag
201
+
202
+ import re
203
+ tag = re.sub(r'[\w ]+', lambda wrapper: t_to_j(wrapper, d), itag, count=2)
204
+
205
+ return tag
206
+
207
+
208
+ def convert_tags_to_ja(input_prompt: str = ""):
209
+ tags = input_prompt.split(",") if input_prompt else []
210
+ out_tags = []
211
+
212
+ tags_to_ja_dict = load_dict_from_csv('all_tags_ja_ext.csv')
213
+ d = tags_to_ja_dict
214
+ for tag in tags:
215
+ tag = tag.strip().replace("_", " ")
216
+ tag = tags_to_ja(tag, d)
217
+ out_tags.append(tag)
218
+
219
+ return ", ".join(out_tags)
220
+
221
+
222
+ def insert_recom_prompt(prompt: str = "", neg_prompt: str = "", type: str = "None"):
223
+ def to_list(s):
224
+ return [x.strip() for x in s.split(",") if x.strip() != ""]
225
+
226
+ def list_sub(a, b):
227
+ return [e for e in a if e not in b]
228
+
229
+ def list_uniq(l):
230
+ return sorted(set(l), key=l.index)
231
+
232
+ animagine_ps = to_list("anime artwork, anime style, key visual, vibrant, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres")
233
+ animagine_nps = to_list("lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
234
+ pony_ps = to_list("source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres")
235
+ pony_nps = to_list("source_pony, source_furry, source_cartoon, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends")
236
+ prompts = to_list(prompt)
237
+ neg_prompts = to_list(neg_prompt)
238
+
239
+ prompts = list_sub(prompts, animagine_ps + pony_ps)
240
+ neg_prompts = list_sub(neg_prompts, animagine_nps + pony_nps)
241
+
242
+ last_empty_p = [""] if not prompts and type != "None" else []
243
+ last_empty_np = [""] if not neg_prompts and type != "None" else []
244
+
245
+ if type == "Animagine":
246
+ prompts = prompts + animagine_ps
247
+ neg_prompts = neg_prompts + animagine_nps
248
+ elif type == "Pony":
249
+ prompts = prompts + pony_ps
250
+ neg_prompts = neg_prompts + pony_nps
251
+
252
+ prompt = ", ".join(list_uniq(prompts) + last_empty_p)
253
+ neg_prompt = ", ".join(list_uniq(neg_prompts) + last_empty_np)
254
+
255
+ return prompt, neg_prompt
256
+
257
+
258
+ tag_group_dict = load_dict_from_csv('tag_group.csv')
259
+
260
+
261
+ def remove_specific_prompt(input_prompt: str = "", keep_tags: str = "all"):
262
+ def is_dressed(tag):
263
+ import re
264
+ p = re.compile(r'dress|cloth|uniform|costume|vest|sweater|coat|shirt|jacket|blazer|apron|leotard|hood|sleeve|skirt|shorts|pant|loafer|ribbon|necktie|bow|collar|glove|sock|shoe|boots|wear|emblem')
265
+ return p.search(tag)
266
+
267
+ def is_background(tag):
268
+ import re
269
+ p = re.compile(r'background|outline|light|sky|build|day|screen|tree|city')
270
+ return p.search(tag)
271
+
272
+ un_tags = ['solo']
273
+ group_list = ['groups', 'body_parts', 'attire', 'posture', 'objects', 'creatures', 'locations', 'disambiguation_pages', 'commonly_misused_tags', 'phrases', 'verbs_and_gerunds', 'subjective', 'nudity', 'sex_objects', 'sex', 'sex_acts', 'image_composition', 'artistic_license', 'text', 'year_tags', 'metatags']
274
+ keep_group_dict = {
275
+ "body": ['groups', 'body_parts'],
276
+ "dress": ['groups', 'body_parts', 'attire'],
277
+ "all": group_list,
278
+ }
279
+
280
+ def is_necessary(tag, keep_tags, group_dict):
281
+ if keep_tags == "all":
282
+ return True
283
+ elif tag in un_tags or group_dict.get(tag, "") in explicit_group:
284
+ return False
285
+ elif keep_tags == "body" and is_dressed(tag):
286
+ return False
287
+ elif is_background(tag):
288
+ return False
289
+ else:
290
+ return True
291
+
292
+ if keep_tags == "all": return input_prompt
293
+ keep_group = keep_group_dict.get(keep_tags, keep_group_dict["body"])
294
+ explicit_group = list(set(group_list) ^ set(keep_group))
295
+
296
+ tags = input_prompt.split(",") if input_prompt else []
297
+ people_tags: list[str] = []
298
+ other_tags: list[str] = []
299
+
300
+ group_dict = tag_group_dict
301
+ for tag in tags:
302
+ tag = tag.strip().replace("_", " ")
303
+ if tag in PEOPLE_TAGS:
304
+ people_tags.append(tag)
305
+ elif is_necessary(tag, keep_tags, group_dict):
306
+ other_tags.append(tag)
307
+
308
+ output_prompt = ", ".join(people_tags + other_tags)
309
+
310
+ return output_prompt
311
+
312
+
313
+ def sort_taglist(tags: list[str]):
314
+ if not tags: return []
315
+ character_tags: list[str] = []
316
+ series_tags: list[str] = []
317
+ people_tags: list[str] = []
318
+ group_list = ['groups', 'body_parts', 'attire', 'posture', 'objects', 'creatures', 'locations', 'disambiguation_pages', 'commonly_misused_tags', 'phrases', 'verbs_and_gerunds', 'subjective', 'nudity', 'sex_objects', 'sex', 'sex_acts', 'image_composition', 'artistic_license', 'text', 'year_tags', 'metatags']
319
+ group_tags = {}
320
+ other_tags: list[str] = []
321
+ rating_tags: list[str] = []
322
+
323
+ group_dict = tag_group_dict
324
+ group_set = set(group_dict.keys())
325
+ character_set = set(anime_series_dict.keys())
326
+ series_set = set(anime_series_dict.values())
327
+ rating_set = set(DANBOORU_TO_E621_RATING_MAP.keys()) | set(DANBOORU_TO_E621_RATING_MAP.values())
328
+
329
+ for tag in tags:
330
+ tag = tag.strip().replace("_", " ")
331
+ if tag in PEOPLE_TAGS:
332
+ people_tags.append(tag)
333
+ elif tag in rating_set:
334
+ rating_tags.append(tag)
335
+ elif tag in group_set:
336
+ elem = group_dict[tag]
337
+ group_tags[elem] = group_tags[elem] + [tag] if elem in group_tags else [tag]
338
+ elif tag in character_set:
339
+ character_tags.append(tag)
340
+ elif tag in series_set:
341
+ series_tags.append(tag)
342
+ else:
343
+ other_tags.append(tag)
344
+
345
+ output_group_tags: list[str] = []
346
+ for k in group_list:
347
+ output_group_tags.extend(group_tags.get(k, []))
348
+
349
+ rating_tags = [rating_tags[0]] if rating_tags else []
350
+ rating_tags = ["explicit, nsfw"] if rating_tags and rating_tags[0] == "explicit" else rating_tags
351
+
352
+ output_tags = character_tags + series_tags + people_tags + output_group_tags + other_tags + rating_tags
353
+
354
+ return output_tags
355
+
356
+
357
+ def sort_tags(tags: str):
358
+ if not tags: return ""
359
+ taglist: list[str] = []
360
+ for tag in tags.split(","):
361
+ taglist.append(tag.strip())
362
+ taglist = list(filter(lambda x: x != "", taglist))
363
+ return ", ".join(sort_taglist(taglist))
364
+
365
+
366
+ def postprocess_results(results: dict[str, float], general_threshold: float, character_threshold: float):
367
+ results = {
368
+ k: v for k, v in sorted(results.items(), key=lambda item: item[1], reverse=True)
369
+ }
370
+
371
+ rating = {}
372
+ character = {}
373
+ general = {}
374
+
375
+ for k, v in results.items():
376
+ if k.startswith("rating:"):
377
+ rating[k.replace("rating:", "")] = v
378
+ continue
379
+ elif k.startswith("character:"):
380
+ character[k.replace("character:", "")] = v
381
+ continue
382
+
383
+ general[k] = v
384
+
385
+ character = {k: v for k, v in character.items() if v >= character_threshold}
386
+ general = {k: v for k, v in general.items() if v >= general_threshold}
387
+
388
+ return rating, character, general
389
+
390
+
391
+ def gen_prompt(rating: list[str], character: list[str], general: list[str]):
392
+ people_tags: list[str] = []
393
+ other_tags: list[str] = []
394
+ rating_tag = RATING_MAP[rating[0]]
395
+
396
+ for tag in general:
397
+ if tag in PEOPLE_TAGS:
398
+ people_tags.append(tag)
399
+ else:
400
+ other_tags.append(tag)
401
+
402
+ all_tags = people_tags + other_tags
403
+
404
+ return ", ".join(all_tags)
405
+
406
+
407
+ @spaces.GPU()
408
+ def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8):
409
+ inputs = wd_processor.preprocess(image, return_tensors="pt")
410
+
411
+ outputs = wd_model(**inputs.to(wd_model.device, wd_model.dtype))
412
+ logits = torch.sigmoid(outputs.logits[0]) # take the first logits
413
+
414
+ # get probabilities
415
+ results = {
416
+ wd_model.config.id2label[i]: float(logit.float()) for i, logit in enumerate(logits)
417
+ }
418
+
419
+ # rating, character, general
420
+ rating, character, general = postprocess_results(
421
+ results, general_threshold, character_threshold
422
+ )
423
+
424
+ prompt = gen_prompt(
425
+ list(rating.keys()), list(character.keys()), list(general.keys())
426
+ )
427
+
428
+ output_series_tag = ""
429
+ output_series_list = character_list_to_series_list(character.keys())
430
+ if output_series_list:
431
+ output_series_tag = output_series_list[0]
432
+ else:
433
+ output_series_tag = ""
434
+
435
+ return output_series_tag, ", ".join(character.keys()), prompt, gr.update(interactive=True),
436
+
437
+
438
+ def predict_tags_wd(image: Image.Image, input_tags: str, algo: list[str], general_threshold: float = 0.3, character_threshold: float = 0.8):
439
+ if not "Use WD Tagger" in algo and len(algo) != 0:
440
+ return "", "", input_tags, gr.update(interactive=True),
441
+ return predict_tags(image, general_threshold, character_threshold)
442
+
443
+
444
+ def compose_prompt_to_copy(character: str, series: str, general: str):
445
+ characters = character.split(",") if character else []
446
+ serieses = series.split(",") if series else []
447
+ generals = general.split(",") if general else []
448
+ tags = characters + serieses + generals
449
+ cprompt = ",".join(tags) if tags else ""
450
+ return cprompt
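End-to-end sketch for this module (editor's example, untested here; downloads the WD tagger on first run, and "sample.png" is a hypothetical local file):

    from PIL import Image
    img = Image.open("sample.png").convert("RGB")
    series, characters, prompt, _ = predict_tags_wd(img, "", ["Use WD Tagger"])
    print(sort_tags(prompt))  # character/series/people tags ordered ahead of the rest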
textual_inversion_dict.json ADDED
@@ -0,0 +1,74 @@
1
+ {
2
+ "bad_prompt_version2.pt": [
3
+ "bad_prompt",
4
+ false
5
+ ],
6
+ "EasyNegativeV2.safetensors": [
7
+ "EasyNegative",
8
+ false
9
+ ],
10
+ "bad-hands-5.pt": [
11
+ "bad_hand",
12
+ false
13
+ ],
14
+ "negativeXL_A.safetensors": [
15
+ "negativeXL_A",
16
+ false
17
+ ],
18
+ "negativeXL_B.safetensors": [
19
+ "negativeXL_B",
20
+ false
21
+ ],
22
+ "negativeXL_C.safetensors": [
23
+ "negativeXL_C",
24
+ false
25
+ ],
26
+ "negativeXL_D.safetensors": [
27
+ "negativeXL_D",
28
+ false
29
+ ],
30
+ "unaestheticXL2v10.safetensors": [
31
+ "2v10",
32
+ false
33
+ ],
34
+ "unaestheticXL_AYv1.safetensors": [
35
+ "_AYv1",
36
+ false
37
+ ],
38
+ "unaestheticXL_Alb2.safetensors": [
39
+ "_Alb2",
40
+ false
41
+ ],
42
+ "unaestheticXL_Jug6.safetensors": [
43
+ "_Jug6",
44
+ false
45
+ ],
46
+ "unaestheticXL_bp5.safetensors": [
47
+ "_bp5",
48
+ false
49
+ ],
50
+ "unaestheticXL_hk1.safetensors": [
51
+ "_hk1",
52
+ false
53
+ ],
54
+ "unaestheticXLv1.safetensors": [
55
+ "v1.0",
56
+ false
57
+ ],
58
+ "unaestheticXLv13.safetensors": [
59
+ "v1.3",
60
+ false
61
+ ],
62
+ "unaestheticXLv31.safetensors": [
63
+ "v3.1",
64
+ false
65
+ ],
66
+ "unaestheticXL_Sky3.1.safetensors": [
67
+ "_Sky3.1",
68
+ false
69
+ ],
70
+ "SimplePositiveXLv2.safetensors": [
71
+ "SIMPLEPOSITIVEXLV2",
72
+ true
73
+ ]
74
+ }
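Each entry maps an embedding filename to [trigger token, is_positive]; a quick sketch of how app.py consumes it:

    token, is_positive = textual_inversion_dict["EasyNegativeV2.safetensors"]
    # -> ("EasyNegative", False): routed into the negative prompt by default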
utils.py ADDED
@@ -0,0 +1,45 @@
1
+ import gradio as gr
2
+ from dartrs.v2 import AspectRatioTag, LengthTag, RatingTag, IdentityTag
3
+
4
+
5
+ V2_ASPECT_RATIO_OPTIONS: list[AspectRatioTag] = [
6
+ "ultra_wide",
7
+ "wide",
8
+ "square",
9
+ "tall",
10
+ "ultra_tall",
11
+ ]
12
+ V2_RATING_OPTIONS: list[RatingTag] = [
13
+ "sfw",
14
+ "general",
15
+ "sensitive",
16
+ "nsfw",
17
+ "questionable",
18
+ "explicit",
19
+ ]
20
+ V2_LENGTH_OPTIONS: list[LengthTag] = [
21
+ "very_short",
22
+ "short",
23
+ "medium",
24
+ "long",
25
+ "very_long",
26
+ ]
27
+ V2_IDENTITY_OPTIONS: list[IdentityTag] = [
28
+ "none",
29
+ "lax",
30
+ "strict",
31
+ ]
32
+
33
+
34
+ # ref: https://qiita.com/tregu148/items/fccccbbc47d966dd2fc2
35
+ def gradio_copy_text(_text: str):
36
+ gr.Info("Copied!")
37
+
38
+
39
+ COPY_ACTION_JS = """\
40
+ (inputs, _outputs) => {
41
+ // inputs is the string value of the input_text
42
+ if (inputs.trim() !== "") {
43
+ navigator.clipboard.writeText(inputs);
44
+ }
45
+ }"""
v2.py ADDED
@@ -0,0 +1,214 @@
1
+ import time
2
+ import os
3
+ import torch
4
+ from typing import Callable
5
+
6
+ from dartrs.v2 import (
7
+ V2Model,
8
+ MixtralModel,
9
+ MistralModel,
10
+ compose_prompt,
11
+ LengthTag,
12
+ AspectRatioTag,
13
+ RatingTag,
14
+ IdentityTag,
15
+ )
16
+ from dartrs.dartrs import DartTokenizer
17
+ from dartrs.utils import get_generation_config
18
+
19
+
20
+ import gradio as gr
21
+ from gradio.components import Component
22
+
23
+ try:
24
+ import spaces
25
+ except ImportError:
26
+
27
+ class spaces:
28
+ def GPU(*args, **kwargs):
29
+ return lambda x: x
30
+
31
+
32
+ from output import UpsamplingOutput
33
+
34
+
35
+ HF_TOKEN = os.getenv("HF_TOKEN", None)
36
+
37
+ V2_ALL_MODELS = {
38
+ "dart-v2-moe-sft": {
39
+ "repo": "p1atdev/dart-v2-moe-sft",
40
+ "type": "sft",
41
+ "class": MixtralModel,
42
+ },
43
+ "dart-v2-sft": {
44
+ "repo": "p1atdev/dart-v2-sft",
45
+ "type": "sft",
46
+ "class": MistralModel,
47
+ },
48
+ }
49
+
50
+
51
+ def prepare_models(model_config: dict):
52
+ model_name = model_config["repo"]
53
+ tokenizer = DartTokenizer.from_pretrained(model_name, auth_token=HF_TOKEN)
54
+ model = model_config["class"].from_pretrained(model_name, auth_token=HF_TOKEN)
55
+
56
+ return {
57
+ "tokenizer": tokenizer,
58
+ "model": model,
59
+ }
60
+
61
+
62
+ def normalize_tags(tokenizer: DartTokenizer, tags: str):
63
+ """Just remove unk tokens."""
64
+ return ", ".join([tag for tag in tokenizer.tokenize(tags) if tag != "<|unk|>"])
65
+
66
+
67
+ @torch.no_grad()
68
+ def generate_tags(
69
+ model: V2Model,
70
+ tokenizer: DartTokenizer,
71
+ prompt: str,
72
+ ban_token_ids: list[int],
73
+ ):
74
+ output = model.generate(
75
+ get_generation_config(
76
+ prompt,
77
+ tokenizer=tokenizer,
78
+ temperature=1,
79
+ top_p=0.9,
80
+ top_k=100,
81
+ max_new_tokens=256,
82
+ ban_token_ids=ban_token_ids,
83
+ ),
84
+ )
85
+
86
+ return output
87
+
88
+
89
+ def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
90
+ return (
91
+ [f"1{noun}"]
92
+ + [f"{num}{noun}s" for num in range(minimum + 1, maximum + 1)]
93
+ + [f"{maximum+1}+{noun}s"]
94
+ )
95
+
96
+
97
+ PEOPLE_TAGS = (
98
+ _people_tag("girl") + _people_tag("boy") + _people_tag("other") + ["no humans"]
99
+ )
100
+
101
+
102
+ def gen_prompt_text(output: UpsamplingOutput):
103
+ # separate people tags (e.g. 1girl)
104
+ people_tags = []
105
+ other_general_tags = []
106
+
107
+ for tag in output.general_tags.split(","):
108
+ tag = tag.strip()
109
+ if tag in PEOPLE_TAGS:
110
+ people_tags.append(tag)
111
+ else:
112
+ other_general_tags.append(tag)
113
+
114
+ return ", ".join(
115
+ [
116
+ part.strip()
117
+ for part in [
118
+ *people_tags,
119
+ output.character_tags,
120
+ output.copyright_tags,
121
+ *other_general_tags,
122
+ output.upsampled_tags,
123
+ output.rating_tag,
124
+ ]
125
+ if part.strip() != ""
126
+ ]
127
+ )
128
+
129
+
130
+ def elapsed_time_format(elapsed_time: float) -> str:
131
+ return f"Elapsed: {elapsed_time:.2f} seconds"
132
+
133
+
134
+ def parse_upsampling_output(
135
+ upsampler: Callable[..., UpsamplingOutput],
136
+ ):
137
+ def _parse_upsampling_output(*args) -> tuple[str, str, dict]:
138
+ output = upsampler(*args)
139
+
140
+ return (
141
+ gen_prompt_text(output),
142
+ elapsed_time_format(output.elapsed_time),
143
+ gr.update(interactive=True),
144
+ gr.update(interactive=True),
145
+ )
146
+
147
+ return _parse_upsampling_output
148
+
149
+
150
+ class V2UI:
151
+ model_name: str | None = None
152
+ model: V2Model
153
+ tokenizer: DartTokenizer
154
+
155
+ input_components: list[Component] = []
156
+ generate_btn: gr.Button
157
+
158
+ def on_generate(
159
+ self,
160
+ model_name: str,
161
+ copyright_tags: str,
162
+ character_tags: str,
163
+ general_tags: str,
164
+ rating_tag: RatingTag,
165
+ aspect_ratio_tag: AspectRatioTag,
166
+ length_tag: LengthTag,
167
+ identity_tag: IdentityTag,
168
+ ban_tags: str,
169
+ *args,
170
+ ) -> UpsamplingOutput:
171
+ if self.model_name is None or self.model_name != model_name:
172
+ models = prepare_models(V2_ALL_MODELS[model_name])
173
+ self.model = models["model"]
174
+ self.tokenizer = models["tokenizer"]
175
+ self.model_name = model_name
176
+
177
+ # normalize tags
178
+ # copyright_tags = normalize_tags(self.tokenizer, copyright_tags)
179
+ # character_tags = normalize_tags(self.tokenizer, character_tags)
180
+ # general_tags = normalize_tags(self.tokenizer, general_tags)
181
+
182
+ ban_token_ids = self.tokenizer.encode(ban_tags.strip())
183
+
184
+ prompt = compose_prompt(
185
+ prompt=general_tags,
186
+ copyright=copyright_tags,
187
+ character=character_tags,
188
+ rating=rating_tag,
189
+ aspect_ratio=aspect_ratio_tag,
190
+ length=length_tag,
191
+ identity=identity_tag,
192
+ )
193
+
194
+ start = time.time()
195
+ upsampled_tags = generate_tags(
196
+ self.model,
197
+ self.tokenizer,
198
+ prompt,
199
+ ban_token_ids,
200
+ )
201
+ elapsed_time = time.time() - start
202
+
203
+ return UpsamplingOutput(
204
+ upsampled_tags=upsampled_tags,
205
+ copyright_tags=copyright_tags,
206
+ character_tags=character_tags,
207
+ general_tags=general_tags,
208
+ rating_tag=rating_tag,
209
+ aspect_ratio_tag=aspect_ratio_tag,
210
+ length_tag=length_tag,
211
+ identity_tag=identity_tag,
212
+ elapsed_time=elapsed_time,
213
+ )
214
+