svjack commited on
Commit
47f51f0
·
verified ·
1 Parent(s): abb9a80

Upload mask_app.py

Browse files
Files changed (1) hide show
  1. mask_app.py +125 -0
mask_app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from gradio_imageslider import ImageSlider
4
+ from loadimg import load_img
5
+ #import spaces
6
+ from transformers import AutoModelForImageSegmentation
7
+ import torch
8
+ from torchvision import transforms
9
+ import numpy as np
10
+ from PIL import Image
11
+
12
+ # 检查 CUDA 是否可用
13
+ if torch.cuda.is_available():
14
+ device = "cuda"
15
+ else:
16
+ device = "cpu"
17
+
18
+ torch.set_float32_matmul_precision(["high", "highest"][0])
19
+
20
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
21
+ "briaai/RMBG-2.0", trust_remote_code=True
22
+ )
23
+ birefnet.to(device)
24
+ transform_image = transforms.Compose(
25
+ [
26
+ transforms.Resize((1024, 1024)),
27
+ transforms.ToTensor(),
28
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
29
+ ]
30
+ )
31
+
32
+ output_folder = 'output_images'
33
+ if not os.path.exists(output_folder):
34
+ os.makedirs(output_folder)
35
+
36
+ # 定义颜色列表,每个颜色对应一个 mask
37
+ colors = [
38
+ '#000000', # 背景色
39
+ '#2692F3', # 蓝色
40
+ '#F89E12', # 橙色
41
+ '#16C232', # 绿色
42
+ '#F92F6C', # 粉色
43
+ '#AC6AEB', # 紫色
44
+ ]
45
+
46
+ # 将颜色转换为 RGB 值
47
+ palette = np.array([
48
+ tuple(int(s[i + 1:i + 3], 16) for i in (0, 2, 4))
49
+ for s in colors[1:] # 跳过背景色
50
+ ]) # (N, 3)
51
+
52
+ def fn(image, mask_color):
53
+ im = load_img(image, output_type="pil")
54
+ im = im.convert("RGB")
55
+ origin = im.copy()
56
+ image, mask = process(im, mask_color)
57
+ image_path = os.path.join(output_folder, "no_bg_image.png")
58
+ mask_path = os.path.join(output_folder, "mask_image.png")
59
+ image.save(image_path)
60
+ mask.save(mask_path)
61
+ return (image, origin), image_path, mask
62
+
63
+ #@spaces.GPU
64
+ def process(image, mask_color):
65
+ image_size = image.size
66
+ input_images = transform_image(image).unsqueeze(0).to(device)
67
+ # Prediction
68
+ with torch.no_grad():
69
+ preds = birefnet(input_images)[-1].sigmoid().cpu()
70
+ pred = preds[0].squeeze()
71
+ pred_pil = transforms.ToPILImage()(pred)
72
+ mask = pred_pil.resize(image_size)
73
+
74
+ # 创建一个新的透明背景图像
75
+ transparent_image = Image.new("RGBA", image_size, (0, 0, 0, 0))
76
+ transparent_image.paste(image, (0, 0), mask)
77
+
78
+ # 创建一个带有颜色的 mask 图像
79
+ mask_color_rgb = tuple(int(mask_color[i + 1:i + 3], 16) for i in (0, 2, 4))
80
+ colored_mask = Image.new("RGBA", image_size, mask_color_rgb + (255,))
81
+ colored_mask.putalpha(mask)
82
+
83
+ return transparent_image, colored_mask
84
+
85
+ # 示例数据
86
+ example_image = "giraffe.jpg" # 确保该文件存在于当前目录
87
+ example_url = "http://farm9.staticflickr.com/8488/8228323072_76eeddfea3_z.jpg"
88
+
89
+ # 定义 Gradio 组件
90
+ with gr.Blocks() as demo:
91
+ gr.Markdown("# 🖼️ RMBG-2.0 for Background Removal")
92
+ with gr.Row():
93
+ # 左侧列:输入
94
+ with gr.Column():
95
+ gr.Markdown("## Input")
96
+ image_input = gr.Image(label="Upload an image")
97
+ text_input = gr.Textbox(label="Paste an image URL")
98
+ color_input = gr.Dropdown(label="Mask Color", choices=colors[1:], value=colors[1])
99
+ run_button = gr.Button("Run")
100
+
101
+ # 右侧列:输出
102
+ with gr.Column():
103
+ gr.Markdown("## Output")
104
+ slider_output = ImageSlider(label="RMBG-2.0", type="pil")
105
+ file_output = gr.File(label="Output PNG File")
106
+ mask_output = gr.Image(label="Mask Image")
107
+
108
+ # 示例数据
109
+ gr.Examples(
110
+ examples=[[example_image, colors[1]], [example_url, colors[1]]],
111
+ inputs=[image_input, color_input],
112
+ outputs=[slider_output, file_output, mask_output], # 添加 outputs 参数
113
+ fn=fn,
114
+ cache_examples=True
115
+ )
116
+
117
+ # 绑定事件
118
+ run_button.click(
119
+ fn=fn,
120
+ inputs=[image_input, color_input],
121
+ outputs=[slider_output, file_output, mask_output]
122
+ )
123
+
124
+ if __name__ == "__main__":
125
+ demo.launch(share=True, show_error=True)