"""Run anime instance segmentation on a single image or a folder of images.

Usage:
    python seg_script.py Genshin_Impact_Images Genshin_Impact_Images_Seg

For each input image in which instances are detected, the following files
are written to the output directory:

    <stem>_drawed.png         input image with boxes and masks overlaid
    <stem>_cropped_<i>.png    bounding-box crop of instance i
    <stem>_segmented_<i>.png  the same crop as RGBA, alpha taken from the mask
"""
import os
import cv2
import argparse
from PIL import Image
import numpy as np
from tqdm import tqdm
from pathlib import Path

from animeinsseg import AnimeInsSeg, AnimeInstances
from animeinsseg.anime_instances import get_color

# Model checkpoint path.
ckpt = r'models/AnimeInstanceSegmentation/rtmdetl_e60.ckpt'

mask_thres = 0.3
instance_thres = 0.3
refine_kwargs = {'refine_method': 'refinenet_isnet'}
# To run without the refinement network, set refine_kwargs to None.
# refine_kwargs = None

# Initialise the model once at import time; process_image() reuses it.
net = AnimeInsSeg(ckpt, mask_thr=mask_thres, refine_kwargs=refine_kwargs)


def process_image(image_path, output_dir):
    """Segment one image and write the overlay, crops and RGBA cut-outs.

    Args:
        image_path: Path (str or ``pathlib.Path``) of the image to process.
        output_dir: Directory that receives the results; created if missing.

    Returns:
        None. Nothing is written when the image cannot be read or when the
        model reports no instances.
    """
    # Older OpenCV builds reject pathlib.Path, so always hand over a str.
    img = cv2.imread(str(image_path))
    if img is None:
        # cv2.imread signals an unreadable/unsupported file by returning None.
        print(f"Skipping unreadable image: {image_path}")
        return
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Inference.
    instances: AnimeInstances = net.infer(
        img,
        output_type='numpy',
        pred_score_thr=instance_thres
    )

    # Nothing detected: leave without writing any output.
    # NOTE(review): the masks check is defensive; confirm whether the model
    # can return boxes without masks.
    if instances.bboxes is None or instances.masks is None:
        return

    drawed = img.copy()
    im_h, im_w = img.shape[:2]

    # Drawing parameters do not depend on the instance, so compute them once.
    mask_alpha = 0.5
    linewidth = max(round(sum(img.shape) / 2 * 0.003), 2)

    base_name = Path(image_path).stem
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    for ii, (xywh, mask) in enumerate(zip(instances.bboxes, instances.masks)):
        color = get_color(ii)

        # Boxes are used as (x, y, width, height); convert to corner points.
        x1, y1 = int(xywh[0]), int(xywh[1])
        x2, y2 = int(xywh[0] + xywh[2]), int(xywh[1] + xywh[3])

        # Draw the bounding box on the overlay.
        cv2.rectangle(drawed, (x1, y1), (x2, y2), color, thickness=linewidth, lineType=cv2.LINE_AA)

        # Alpha-blend the instance colour into the overlay where the mask is set.
        p = mask.astype(np.float32)
        blend_mask = np.full((im_h, im_w, 3), color, dtype=np.float32)
        alpha_msk = (mask_alpha * p)[..., None]
        alpha_ori = 1 - alpha_msk
        drawed = drawed * alpha_ori + alpha_msk * blend_mask
        drawed = drawed.astype(np.uint8)

        # Clamp the crop window to the image: a negative start index would
        # otherwise wrap around in NumPy slicing and yield a wrong/empty crop.
        cx1, cy1 = max(x1, 0), max(y1, 0)
        cx2, cy2 = min(x2, im_w), min(y2, im_h)
        if cx2 <= cx1 or cy2 <= cy1:
            # Degenerate box entirely outside the image; nothing to crop.
            continue

        cropped_img = img[cy1:cy2, cx1:cx2]
        cropped_mask = mask[cy1:cy2, cx1:cx2]

        # Use the mask as the alpha channel of the cut-out.
        alpha_channel = (cropped_mask * 255).astype(np.uint8)
        rgba_image = np.dstack((cropped_img, alpha_channel))

        # Save the crop and the segmented cut-out (file names carry the instance index).
        Image.fromarray(cropped_img).save(output_path / f"{base_name}_cropped_{ii}.png")
        Image.fromarray(rgba_image, 'RGBA').save(output_path / f"{base_name}_segmented_{ii}.png")

    # Save the overlay once, after every instance has been drawed onto it.
    Image.fromarray(drawed).save(output_path / f"{base_name}_drawed.png")


def main():
    """Parse the command line and process a single image or a whole folder."""
    parser = argparse.ArgumentParser(description="Anime Instance Segmentation")
    parser.add_argument("input_path", type=str, help="Path to the input image or folder")
    parser.add_argument("output_dir", type=str, help="Path to the output directory")
    args = parser.parse_args()

    input_path = Path(args.input_path)
    output_dir = Path(args.output_dir)

    if input_path.is_file():
        process_image(input_path, output_dir)
    elif input_path.is_dir():
        # Folders are searched recursively for PNG and JPG files.
        image_paths = list(input_path.rglob("*.png")) + list(input_path.rglob("*.jpg"))
        for image_path in tqdm(image_paths, desc="Processing images"):
            process_image(image_path, output_dir)
    else:
        print("Invalid input path. Please provide a valid image or folder path.")


if __name__ == "__main__":
    main()