svjack commited on
Commit
123094d
·
verified ·
1 Parent(s): edf2c31

Upload 2 files

Browse files
Files changed (2) hide show
  1. green_process.py +106 -0
  2. transparent_process.py +90 -0
green_process.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ python green_process.py Star_Rail_Tribbie_MMD_Videos_sp30 Star_Rail_Tribbie_MMD_Videos_30s_Green --fast_mode --max_workers=10
3
+ '''
4
+
5
+ import os
6
+ import sys
7
+ import argparse
8
+ import torch
9
+ from torchvision import transforms
10
+ from moviepy import VideoFileClip, vfx, concatenate_videoclips, ImageSequenceClip
11
+ from PIL import Image
12
+ import numpy as np
13
+ from concurrent.futures import ThreadPoolExecutor
14
+ from transformers import AutoModelForImageSegmentation
15
+
16
+ # Set up device
17
+ device = "cuda" if torch.cuda.is_available() else "cpu"
18
+
19
+ # Load both BiRefNet models
20
+ birefnet = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet", trust_remote_code=True)
21
+ birefnet.to(device)
22
+ birefnet_lite = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet_lite", trust_remote_code=True)
23
+ birefnet_lite.to(device)
24
+
25
+ # Image transformation pipeline
26
+ transform_image = transforms.Compose([
27
+ transforms.Resize((768, 768)),
28
+ transforms.ToTensor(),
29
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
30
+ ])
31
+
32
+ def process_frame(frame, fast_mode):
33
+ try:
34
+ pil_image = Image.fromarray(frame)
35
+ processed_image = process(pil_image, "#00FF00", fast_mode)
36
+ return np.array(processed_image)
37
+ except Exception as e:
38
+ print(f"Error processing frame: {e}")
39
+ return frame
40
+
41
+ def process(image, bg, fast_mode=False):
42
+ image_size = image.size
43
+ input_images = transform_image(image).unsqueeze(0).to(device)
44
+ model = birefnet_lite if fast_mode else birefnet
45
+
46
+ with torch.no_grad():
47
+ preds = model(input_images)[-1].sigmoid().cpu()
48
+ pred = preds[0].squeeze()
49
+ pred_pil = transforms.ToPILImage()(pred)
50
+ mask = pred_pil.resize(image_size)
51
+
52
+ if isinstance(bg, str) and bg.startswith("#"):
53
+ color_rgb = tuple(int(bg[i:i+2], 16) for i in (1, 3, 5))
54
+ background = Image.new("RGBA", image_size, color_rgb + (255,))
55
+ elif isinstance(bg, Image.Image):
56
+ background = bg.convert("RGBA").resize(image_size)
57
+ else:
58
+ background = Image.open(bg).convert("RGBA").resize(image_size)
59
+
60
+ image = Image.composite(image, background, mask)
61
+ return image
62
+
63
+ def process_video(video_path, output_path, fast_mode=True, max_workers=10):
64
+ try:
65
+ video = VideoFileClip(video_path)
66
+ fps = video.fps
67
+ audio = video.audio
68
+ frames = list(video.iter_frames(fps=fps))
69
+
70
+ processed_frames = []
71
+
72
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
73
+ futures = [executor.submit(process_frame, frames[i], fast_mode) for i in range(len(frames))]
74
+ for future in futures:
75
+ result = future.result()
76
+ processed_frames.append(result)
77
+
78
+ processed_video = ImageSequenceClip(processed_frames, fps=fps)
79
+ processed_video = processed_video.with_audio(audio)
80
+ processed_video.write_videofile(output_path, codec="libx264")
81
+
82
+ except Exception as e:
83
+ print(f"Error processing video {video_path}: {e}")
84
+
85
+ def main(input_folder, output_folder, fast_mode=True, max_workers=10):
86
+ if not os.path.exists(output_folder):
87
+ os.makedirs(output_folder)
88
+
89
+ for video_file in os.listdir(input_folder):
90
+ if video_file.endswith((".mp4", ".avi", ".mov")):
91
+ video_path = os.path.join(input_folder, video_file)
92
+ output_path = os.path.join(output_folder, video_file)
93
+ print(f"Processing {video_path}...")
94
+ process_video(video_path, output_path, fast_mode, max_workers)
95
+ print(f"Finished processing {video_path}")
96
+
97
+ if __name__ == "__main__":
98
+ parser = argparse.ArgumentParser(description="Process videos to replace background with green.")
99
+ parser.add_argument("input_folder", type=str, help="Path to the folder containing input videos.")
100
+ parser.add_argument("output_folder", type=str, help="Path to the folder where processed videos will be saved.")
101
+ parser.add_argument("--fast_mode", action="store_true", help="Use BiRefNet_lite for faster processing.")
102
+ parser.add_argument("--max_workers", type=int, default=10, help="Number of workers for parallel processing.")
103
+
104
+ args = parser.parse_args()
105
+
106
+ main(args.input_folder, args.output_folder, args.fast_mode, args.max_workers)
transparent_process.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ python transparent_process.py Star_Rail_Tribbie_MMD_Videos_30s_Green Star_Rail_Tribbie_MMD_Videos_30s_Transparent --max_workers=10
3
+ '''
4
+
5
+ import os
6
+ import sys
7
+ import argparse
8
+ from moviepy.editor import VideoFileClip, ImageSequenceClip
9
+ from PIL import Image
10
+ import numpy as np
11
+ from concurrent.futures import ThreadPoolExecutor
12
+
13
+ def green_to_transparent(frame):
14
+ """
15
+ 将绿色背景转化为透明背景
16
+ """
17
+ try:
18
+ pil_image = Image.fromarray(frame)
19
+ # 将图像转换为RGBA模式
20
+ pil_image = pil_image.convert("RGBA")
21
+ # 获取图像的像素数据
22
+ data = pil_image.getdata()
23
+ # 创建一个新的像素列表
24
+ new_data = []
25
+ for item in data:
26
+ # 判断是否为绿色(这里使用简单的RGB范围判断)
27
+ if item[0] < 100 and item[1] > 200 and item[2] < 100:
28
+ # 将绿色像素设置为透明
29
+ new_data.append((item[0], item[1], item[2], 0))
30
+ else:
31
+ # 保留其他像素
32
+ new_data.append(item)
33
+ # 更新图像的像素数据
34
+ pil_image.putdata(new_data)
35
+ return np.array(pil_image)
36
+ except Exception as e:
37
+ print(f"Error processing frame: {e}")
38
+ return frame
39
+
40
+ def process_video(video_path, output_path, max_workers=10):
41
+ """
42
+ 处理视频,将绿色背景转化为透明背景(多线程版本)
43
+ """
44
+ try:
45
+ video = VideoFileClip(video_path)
46
+ fps = video.fps
47
+ audio = video.audio
48
+ frames = list(video.iter_frames(fps=fps))
49
+
50
+ processed_frames = []
51
+
52
+ # 使用线程池处理帧
53
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
54
+ futures = [executor.submit(green_to_transparent, frame) for frame in frames]
55
+ for future in futures:
56
+ processed_frame = future.result()
57
+ processed_frames.append(processed_frame)
58
+
59
+ # 将处理后的帧保存为视频
60
+ processed_video = ImageSequenceClip(processed_frames, fps=fps)
61
+ processed_video = processed_video.with_audio(audio)
62
+ processed_video.write_videofile(output_path, codec="png") # 使用PNG编码支持透明度
63
+ print(f"Video saved to {output_path}")
64
+ except Exception as e:
65
+ print(f"Error processing video {video_path}: {e}")
66
+
67
+ def main(input_folder, output_folder, max_workers=10):
68
+ """
69
+ 遍历输入文件夹中的所有视频,处理并保存到输出文件夹(多线程版本)
70
+ """
71
+ if not os.path.exists(output_folder):
72
+ os.makedirs(output_folder)
73
+
74
+ for video_file in os.listdir(input_folder):
75
+ if video_file.endswith((".mp4", ".avi", ".mov")):
76
+ video_path = os.path.join(input_folder, video_file)
77
+ output_path = os.path.join(output_folder, video_file)
78
+ print(f"Processing {video_path}...")
79
+ process_video(video_path, output_path, max_workers)
80
+ print(f"Finished processing {video_path}")
81
+
82
+ if __name__ == "__main__":
83
+ parser = argparse.ArgumentParser(description="Convert green background to transparent in videos.")
84
+ parser.add_argument("input_folder", type=str, help="Path to the folder containing input videos.")
85
+ parser.add_argument("output_folder", type=str, help="Path to the folder where processed videos will be saved.")
86
+ parser.add_argument("--max_workers", type=int, default=10, help="Number of workers for parallel processing.")
87
+
88
+ args = parser.parse_args()
89
+
90
+ main(args.input_folder, args.output_folder, args.max_workers)