Spaces:
Runtime error
Runtime error
refactor
Browse files- scripts/generate_prompt.py +48 -119
- scripts/process_utils.py +3 -5
scripts/generate_prompt.py
CHANGED
|
@@ -10,145 +10,74 @@ from tensorflow.keras.layers import TFSMLayer
|
|
| 10 |
from huggingface_hub import hf_hub_download
|
| 11 |
from pathlib import Path
|
| 12 |
|
| 13 |
-
#
|
| 14 |
IMAGE_SIZE = 448
|
| 15 |
|
| 16 |
-
#
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
CSV_FILE =
|
| 22 |
|
| 23 |
def preprocess_image(image):
|
| 24 |
-
|
| 25 |
-
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
pad_y = size - image.shape[0]
|
| 31 |
-
pad_l = pad_x // 2
|
| 32 |
-
pad_t = pad_y // 2
|
| 33 |
-
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)
|
| 34 |
|
| 35 |
interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
image = image.astype(np.float32)
|
| 39 |
-
return image
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
def load_wd14_tagger_model():
|
|
|
|
| 43 |
model_dir = "wd14_tagger_model"
|
| 44 |
-
repo_id = DEFAULT_WD14_TAGGER_REPO
|
| 45 |
-
|
| 46 |
if not os.path.exists(model_dir):
|
| 47 |
-
|
| 48 |
-
for file in FILES:
|
| 49 |
-
hf_hub_download(repo_id, file, cache_dir=model_dir, force_download=True, force_filename=file)
|
| 50 |
-
for file in SUB_DIR_FILES:
|
| 51 |
-
hf_hub_download(
|
| 52 |
-
repo_id,
|
| 53 |
-
file,
|
| 54 |
-
subfolder=SUB_DIR,
|
| 55 |
-
cache_dir=model_dir + "/" + SUB_DIR,
|
| 56 |
-
force_download=True,
|
| 57 |
-
force_filename=file,
|
| 58 |
-
)
|
| 59 |
else:
|
| 60 |
-
print("
|
| 61 |
-
|
| 62 |
-
# モデルを読み込む
|
| 63 |
-
model = TFSMLayer(model_dir, call_endpoint='serving_default')
|
| 64 |
-
return model
|
| 65 |
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
with open(
|
| 69 |
reader = csv.reader(f)
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
assert header[
|
| 74 |
-
|
| 75 |
-
general_tags = [row[1] for row in rows[1:] if row[2] == "0"]
|
| 76 |
-
character_tags = [row[1] for row in rows[1:] if row[2] == "4"]
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
tag_freq = {}
|
| 79 |
-
undesired_tags =
|
| 80 |
-
'swimsuit',
|
| 81 |
-
'leotard',
|
| 82 |
-
'saitama_(one-punch_man)',
|
| 83 |
-
'1boy',
|
| 84 |
-
]
|
| 85 |
-
|
| 86 |
-
probs = model(images, training=False)
|
| 87 |
-
probs = probs['predictions_sigmoid'].numpy()
|
| 88 |
|
|
|
|
| 89 |
tag_text_list = []
|
|
|
|
| 90 |
for prob in probs:
|
| 91 |
-
|
| 92 |
-
general_tag_text = ""
|
| 93 |
-
character_tag_text = ""
|
| 94 |
-
thresh = 0.35
|
| 95 |
for i, p in enumerate(prob[4:]):
|
| 96 |
-
if i < len(general_tags)
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
combined_tags.append(tag_name)
|
| 102 |
-
elif i >= len(general_tags) and p >= thresh:
|
| 103 |
-
tag_name = character_tags[i - len(general_tags)]
|
| 104 |
-
if tag_name not in undesired_tags:
|
| 105 |
-
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
|
| 106 |
-
character_tag_text += ", " + tag_name
|
| 107 |
-
combined_tags.append(tag_name)
|
| 108 |
|
| 109 |
-
|
| 110 |
-
general_tag_text = general_tag_text[2:]
|
| 111 |
-
if len(character_tag_text) > 0:
|
| 112 |
-
character_tag_text = character_tag_text[2:]
|
| 113 |
-
|
| 114 |
-
tag_text = ", ".join(combined_tags)
|
| 115 |
-
tag_text_list.append(tag_text)
|
| 116 |
return tag_text_list
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
def generate_prompt_json(target_folder, prompt_file, model_dir, model):
    """Write a JSON-lines prompt file covering every file in ``target_folder``.

    For each regular file in ``target_folder`` the tags are obtained once via
    ``generate_tags`` and four records (suffixes ``_0`` .. ``_3``) are emitted,
    each holding a ``source`` path, a ``target`` path and the ``prompt``.

    Args:
        target_folder: Directory whose regular files are processed.
        prompt_file: Path of the output file; one JSON object per line.
        model_dir: Passed through to ``generate_tags``.
        model: Passed through to ``generate_tags``.
    """
    image_files = []
    for entry in os.listdir(target_folder):
        if os.path.isfile(os.path.join(target_folder, entry)):
            image_files.append(entry)
    image_count = len(image_files)

    records = []
    for i, filename in enumerate(image_files, 1):
        # Everything from the first '.' onwards is dropped before the
        # "_<n>.jpg" suffix is appended.
        source_stem = ("source/" + filename).split('.')[0]
        target_stem = ("target/" + filename).split('.')[0]

        # The tagger is handed the path of the file inside target_folder.
        prompt = generate_tags(os.path.join(target_folder, filename), model_dir, model)

        records.extend(
            {
                "source": f"{source_stem}_{variant}.jpg",
                "target": f"{target_stem}_{variant}.jpg",
                "prompt": prompt
            }
            for variant in range(4)
        )

        print(f"Processed Images: {i}/{image_count}", end="\r", flush=True)

    with open(prompt_file, "w") as file:
        for record in records:
            json.dump(record, file)
            file.write("\n")

    print(f"Processing completed. Total Images: {image_count}")
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
if __name__ == '__main__':
    # Script entry point: load the tagger model, then run tagging once.
    model = load_wd14_tagger_model()
    model_dir = "wd14_tagger_model"
    # NOTE(review): `target_path` is not assigned anywhere in the visible code,
    # so this call looks like it would raise NameError - confirm it is defined
    # elsewhere before relying on this entry point.
    prompt = generate_tags(target_path, model_dir, model)
|
|
|
|
| 10 |
from huggingface_hub import hf_hub_download
|
| 11 |
from pathlib import Path
|
| 12 |
|
| 13 |
+
# Side length, in pixels, of the square image the tagger input is resized to.
IMAGE_SIZE = 448

# Default tagger repository on the Hugging Face Hub and the files fetched from it.
DEFAULT_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
# Tag list CSV. Named explicitly (instead of being derived from its position in
# MODEL_FILES) so that reordering the download list cannot silently change it.
CSV_FILE = "selected_tags.csv"
MODEL_FILES = ["keras_metadata.pb", "saved_model.pb", CSV_FILE]
# Sub-directory of the repository and the files fetched from inside it.
VAR_DIR = "variables"
VAR_FILES = ["variables.data-00000-of-00001", "variables.index"]
|
| 22 |
|
| 23 |
def preprocess_image(image):
    """Pad an image to a white square and resize it to IMAGE_SIZE x IMAGE_SIZE.

    Args:
        image: Anything ``np.array`` can turn into a 3-D array; the indexing
            and padding below require three axes (assumes height x width x
            channels in RGB order - TODO confirm against callers).

    Returns:
        A ``float32`` array of shape (IMAGE_SIZE, IMAGE_SIZE, channels) with
        the last axis reversed relative to the input.
    """
    # Reverse the channel axis (RGB -> BGR when the input is RGB).
    flipped = np.array(image)[:, :, ::-1]

    height, width = flipped.shape[:2]
    side = max(height, width)

    # Split the padding around the content; when the amount is odd the extra
    # pixel goes to the bottom / right edge.
    top = (side - height) // 2
    bottom = (side - height) - top
    left = (side - width) // 2
    right = (side - width) - left
    squared = np.pad(
        flipped,
        ((top, bottom), (left, right), (0, 0)),
        mode="constant",
        constant_values=255,
    )

    # INTER_AREA when shrinking, INTER_LANCZOS4 otherwise.
    if side > IMAGE_SIZE:
        interp = cv2.INTER_AREA
    else:
        interp = cv2.INTER_LANCZOS4
    resized = cv2.resize(squared, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
    return resized.astype(np.float32)
|
|
|
|
|
|
|
| 34 |
|
| 35 |
+
def download_model_files(repo_id, model_dir, sub_dir, files, sub_files):
    """Download model files from the Hugging Face Hub.

    Args:
        repo_id: Hub repository to download from.
        model_dir: Directory passed as ``cache_dir`` for the top-level files.
        sub_dir: Repository sub-folder holding ``sub_files``; they are cached
            under ``model_dir/sub_dir``.
        files: File names located at the repository root.
        sub_files: File names located inside ``sub_dir``.
    """
    # NOTE(review): `force_filename` is reportedly deprecated in newer
    # huggingface_hub releases - confirm against the version this project pins.
    for name in files:
        hf_hub_download(
            repo_id,
            name,
            cache_dir=model_dir,
            force_download=True,
            force_filename=name,
        )

    for name in sub_files:
        hf_hub_download(
            repo_id,
            name,
            subfolder=sub_dir,
            cache_dir=os.path.join(model_dir, sub_dir),
            force_download=True,
            force_filename=name,
        )
|
| 41 |
|
| 42 |
def load_wd14_tagger_model():
    """Load the WD14 tagger model, downloading its files when they are absent.

    The model lives in the ``wd14_tagger_model`` directory relative to the
    current working directory. If that path does not exist the files are
    downloaded first; otherwise a short message is printed.

    Returns:
        A ``TFSMLayer`` wrapping the saved model's ``serving_default`` endpoint.
    """
    model_dir = "wd14_tagger_model"

    if os.path.exists(model_dir):
        print("Using existing model")
    else:
        download_model_files(DEFAULT_REPO, model_dir, VAR_DIR, MODEL_FILES, VAR_FILES)

    tagger = TFSMLayer(model_dir, call_endpoint='serving_default')
    return tagger
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
+
def read_tags_from_csv(csv_path):
    """Read the tag rows from a tag CSV file.

    The first row must be a header whose first three columns are
    ``tag_id``, ``name`` and ``category``; it is validated and then dropped.

    Args:
        csv_path: Path to a UTF-8 encoded CSV file.

    Returns:
        A list of rows (each a list of strings), header excluded.

    Raises:
        ValueError: If the file is empty or the header does not start with
            the expected columns. A real exception is raised instead of an
            ``assert`` so the check still runs under ``python -O``.
    """
    # newline="" lets the csv module do its own newline handling, as its
    # documentation recommends for file objects.
    with open(csv_path, "r", encoding="utf-8", newline="") as f:
        reader = csv.reader(f)
        header = next(reader, None)
        rows = list(reader)

    if header is None:
        raise ValueError(f"Empty CSV file: {csv_path}")
    if header[:3] != ["tag_id", "name", "category"]:
        raise ValueError(f"Unexpected CSV format: {header}")
    return rows
|
|
|
|
|
|
|
| 60 |
|
| 61 |
+
def generate_tags(images, model_dir, model):
    """Generate one comma-separated tag string per image in a batch.

    Args:
        images: Batch passed directly to ``model``.
        model_dir: Directory containing the tag CSV (``CSV_FILE``).
        model: Callable whose result holds a ``'predictions_sigmoid'`` entry
            exposing ``.numpy()``.

    Returns:
        A list of strings, one per prediction row, each joining the accepted
        tag names with ``", "``.
    """
    rows = read_tags_from_csv(os.path.join(model_dir, CSV_FILE))

    # Category "0" rows are general tags, category "4" rows are character tags.
    general_tags = []
    character_tags = []
    for row in rows:
        category = row[2]
        if category == "0":
            general_tags.append(row[1])
        elif category == "4":
            character_tags.append(row[1])

    # Score index i (after the skipped leading entries) maps to general tags
    # first, followed by character tags.
    ordered_tags = general_tags + character_tags

    # Local tally of accepted tags; it is not read or returned here.
    tag_freq = {}
    undesired_tags = {'one-piece_swimsuit', 'swimsuit', 'leotard', 'saitama_(one-punch_man)', '1boy'}

    outputs = model(images, training=False)
    probs = outputs['predictions_sigmoid'].numpy()

    tag_text_list = []
    for prob in probs:
        accepted = []
        # The first four scores are skipped (presumably the rows that precede
        # the general tags in the CSV - TODO confirm).
        for i, p in enumerate(prob[4:]):
            tag = ordered_tags[i]
            # Written as `not (p >= ...)` so a NaN score is rejected.
            if not (p >= 0.35) or tag in undesired_tags:
                continue
            tag_freq[tag] = tag_freq.get(tag, 0) + 1
            accepted.append(tag)

        tag_text_list.append(", ".join(accepted))
    return tag_text_list
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/process_utils.py
CHANGED
|
@@ -40,9 +40,9 @@ def initialize(_use_local=False, use_gpu=False, use_dotenv=False):
|
|
| 40 |
device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
|
| 41 |
torch_dtype = torch.float16 if device == "cuda" else torch.float32
|
| 42 |
use_local = _use_local
|
| 43 |
-
|
| 44 |
-
print(f"
|
| 45 |
-
|
| 46 |
init_model(use_local)
|
| 47 |
model = load_wd14_tagger_model()
|
| 48 |
sotai_gen_pipe = initialize_sotai_model()
|
|
@@ -59,7 +59,6 @@ def initialize_sotai_model():
|
|
| 59 |
controlnet_path1 = get_file_path(os.environ["controlnet_name1"], subfolder=os.environ["controlnet_dir2"])
|
| 60 |
# controlnet_path1 = get_file_path(os.environ["controlnet_name2"], subfolder=os.environ["controlnet_dir1"])
|
| 61 |
controlnet_path2 = get_file_path(os.environ["controlnet_name2"], subfolder=os.environ["controlnet_dir1"])
|
| 62 |
-
print(use_local, controlnet_path1)
|
| 63 |
|
| 64 |
# Load the Stable Diffusion model
|
| 65 |
sd_pipe = StableDiffusionPipeline.from_single_file(
|
|
@@ -294,7 +293,6 @@ def process_image(input_image, mode: str, weight1: float = 0.4, weight2: float =
|
|
| 294 |
image_np = np.array(ensure_rgb(input_image))
|
| 295 |
prompt = get_wd_tags([image_np])[0]
|
| 296 |
prompt = f"{prompt}"
|
| 297 |
-
print(prompt)
|
| 298 |
|
| 299 |
refined_image = generate_refined_image(prompt, input_image, output_width, output_height, weight1, weight2)
|
| 300 |
refined_image = refined_image.convert('RGB')
|
|
|
|
| 40 |
device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
|
| 41 |
torch_dtype = torch.float16 if device == "cuda" else torch.float32
|
| 42 |
use_local = _use_local
|
| 43 |
+
|
| 44 |
+
print(f"\nDevice: {device}, Local model: {_use_local}\n")
|
| 45 |
+
|
| 46 |
init_model(use_local)
|
| 47 |
model = load_wd14_tagger_model()
|
| 48 |
sotai_gen_pipe = initialize_sotai_model()
|
|
|
|
| 59 |
controlnet_path1 = get_file_path(os.environ["controlnet_name1"], subfolder=os.environ["controlnet_dir2"])
|
| 60 |
# controlnet_path1 = get_file_path(os.environ["controlnet_name2"], subfolder=os.environ["controlnet_dir1"])
|
| 61 |
controlnet_path2 = get_file_path(os.environ["controlnet_name2"], subfolder=os.environ["controlnet_dir1"])
|
|
|
|
| 62 |
|
| 63 |
# Load the Stable Diffusion model
|
| 64 |
sd_pipe = StableDiffusionPipeline.from_single_file(
|
|
|
|
| 293 |
image_np = np.array(ensure_rgb(input_image))
|
| 294 |
prompt = get_wd_tags([image_np])[0]
|
| 295 |
prompt = f"{prompt}"
|
|
|
|
| 296 |
|
| 297 |
refined_image = generate_refined_image(prompt, input_image, output_width, output_height, weight1, weight2)
|
| 298 |
refined_image = refined_image.convert('RGB')
|