RedbeardNZ and MaxwellMeyer committed
Commit 30b3307 · verified · 0 parent(s)

Duplicate from PramaLLC/BEN2


Co-authored-by: Maxwell Meyer <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,41 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ BEN2_demo_pictures/grid_example1.png filter=lfs diff=lfs merge=lfs -text
+ BEN2_demo_pictures/grid_example2.png filter=lfs diff=lfs merge=lfs -text
+ BEN2_demo_pictures/grid_example3.png filter=lfs diff=lfs merge=lfs -text
+ BEN2_demo_pictures/grid_example6.png filter=lfs diff=lfs merge=lfs -text
+ BEN2_demo_pictures/grid_example7.png filter=lfs diff=lfs merge=lfs -text
+ BEN2_demo_pictures/model_comparison.png filter=lfs diff=lfs merge=lfs -text
BEN2.py ADDED
@@ -0,0 +1,1401 @@
1
+
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+ import torch.utils.checkpoint as checkpoint
8
+ import numpy as np
9
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
10
+ from PIL import Image, ImageOps
11
+ from torchvision import transforms
13
+ import random
14
+ import cv2
15
+ import os
16
+ import subprocess
17
+ import time
18
+ import tempfile
19
+
20
+
21
+
22
+
23
+ def set_random_seed(seed):
24
+ random.seed(seed)
25
+ np.random.seed(seed)
26
+ torch.manual_seed(seed)
27
+ torch.cuda.manual_seed(seed)
28
+ torch.cuda.manual_seed_all(seed)
29
+ torch.backends.cudnn.deterministic = True
30
+ torch.backends.cudnn.benchmark = False
31
+ set_random_seed(9)
32
+
33
+
34
+ torch.set_float32_matmul_precision('highest')
35
+
36
+
37
+
38
+ class Mlp(nn.Module):
39
+ """ Multilayer perceptron."""
40
+
41
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
42
+ super().__init__()
43
+ out_features = out_features or in_features
44
+ hidden_features = hidden_features or in_features
45
+ self.fc1 = nn.Linear(in_features, hidden_features)
46
+ self.act = act_layer()
47
+ self.fc2 = nn.Linear(hidden_features, out_features)
48
+ self.drop = nn.Dropout(drop)
49
+
50
+ def forward(self, x):
51
+ x = self.fc1(x)
52
+ x = self.act(x)
53
+ x = self.drop(x)
54
+ x = self.fc2(x)
55
+ x = self.drop(x)
56
+ return x
57
+
58
+
59
+ def window_partition(x, window_size):
60
+ """
61
+ Args:
62
+ x: (B, H, W, C)
63
+ window_size (int): window size
64
+ Returns:
65
+ windows: (num_windows*B, window_size, window_size, C)
66
+ """
67
+ B, H, W, C = x.shape
68
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
69
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
70
+ return windows
71
+
72
+
73
+ def window_reverse(windows, window_size, H, W):
74
+ """
75
+ Args:
76
+ windows: (num_windows*B, window_size, window_size, C)
77
+ window_size (int): Window size
78
+ H (int): Height of image
79
+ W (int): Width of image
80
+ Returns:
81
+ x: (B, H, W, C)
82
+ """
83
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
84
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
85
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
86
+ return x
87
+
88
+
89
+ class WindowAttention(nn.Module):
90
+ """ Window based multi-head self attention (W-MSA) module with relative position bias.
91
+ It supports both shifted and non-shifted windows.
92
+ Args:
93
+ dim (int): Number of input channels.
94
+ window_size (tuple[int]): The height and width of the window.
95
+ num_heads (int): Number of attention heads.
96
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
97
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
98
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
99
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
100
+ """
101
+
102
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
103
+
104
+ super().__init__()
105
+ self.dim = dim
106
+ self.window_size = window_size # Wh, Ww
107
+ self.num_heads = num_heads
108
+ head_dim = dim // num_heads
109
+ self.scale = qk_scale or head_dim ** -0.5
110
+
111
+ # define a parameter table of relative position bias
112
+ self.relative_position_bias_table = nn.Parameter(
113
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
114
+
115
+ # get pair-wise relative position index for each token inside the window
116
+ coords_h = torch.arange(self.window_size[0])
117
+ coords_w = torch.arange(self.window_size[1])
118
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
119
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
120
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
121
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
122
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
123
+ relative_coords[:, :, 1] += self.window_size[1] - 1
124
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
125
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
126
+ self.register_buffer("relative_position_index", relative_position_index)
127
+
128
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
129
+ self.attn_drop = nn.Dropout(attn_drop)
130
+ self.proj = nn.Linear(dim, dim)
131
+ self.proj_drop = nn.Dropout(proj_drop)
132
+
133
+ trunc_normal_(self.relative_position_bias_table, std=.02)
134
+ self.softmax = nn.Softmax(dim=-1)
135
+
136
+ def forward(self, x, mask=None):
137
+ """ Forward function.
138
+ Args:
139
+ x: input features with shape of (num_windows*B, N, C)
140
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
141
+ """
142
+ B_, N, C = x.shape
143
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
144
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
145
+
146
+ q = q * self.scale
147
+ attn = (q @ k.transpose(-2, -1))
148
+
149
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
150
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
151
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
152
+ attn = attn + relative_position_bias.unsqueeze(0)
153
+
154
+ if mask is not None:
155
+ nW = mask.shape[0]
156
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
157
+ attn = attn.view(-1, self.num_heads, N, N)
158
+ attn = self.softmax(attn)
159
+ else:
160
+ attn = self.softmax(attn)
161
+
162
+ attn = self.attn_drop(attn)
163
+
164
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
165
+ x = self.proj(x)
166
+ x = self.proj_drop(x)
167
+ return x
168
+
169
+
170
+ class SwinTransformerBlock(nn.Module):
171
+ """ Swin Transformer Block.
172
+ Args:
173
+ dim (int): Number of input channels.
174
+ num_heads (int): Number of attention heads.
175
+ window_size (int): Window size.
176
+ shift_size (int): Shift size for SW-MSA.
177
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
178
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
179
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
180
+ drop (float, optional): Dropout rate. Default: 0.0
181
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
182
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
183
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
184
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
185
+ """
186
+
187
+ def __init__(self, dim, num_heads, window_size=7, shift_size=0,
188
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
189
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
190
+ super().__init__()
191
+ self.dim = dim
192
+ self.num_heads = num_heads
193
+ self.window_size = window_size
194
+ self.shift_size = shift_size
195
+ self.mlp_ratio = mlp_ratio
196
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
197
+
198
+ self.norm1 = norm_layer(dim)
199
+ self.attn = WindowAttention(
200
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
201
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
202
+
203
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
204
+ self.norm2 = norm_layer(dim)
205
+ mlp_hidden_dim = int(dim * mlp_ratio)
206
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
207
+
208
+ self.H = None
209
+ self.W = None
210
+
211
+ def forward(self, x, mask_matrix):
212
+ """ Forward function.
213
+ Args:
214
+ x: Input feature, tensor size (B, H*W, C).
215
+ H, W: Spatial resolution of the input feature.
216
+ mask_matrix: Attention mask for cyclic shift.
217
+ """
218
+ B, L, C = x.shape
219
+ H, W = self.H, self.W
220
+ assert L == H * W, "input feature has wrong size"
221
+
222
+ shortcut = x
223
+ x = self.norm1(x)
224
+ x = x.view(B, H, W, C)
225
+
226
+ # pad feature maps to multiples of window size
227
+ pad_l = pad_t = 0
228
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
229
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
230
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
231
+ _, Hp, Wp, _ = x.shape
232
+
233
+ # cyclic shift
234
+ if self.shift_size > 0:
235
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
236
+ attn_mask = mask_matrix
237
+ else:
238
+ shifted_x = x
239
+ attn_mask = None
240
+
241
+ # partition windows
242
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
243
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
244
+
245
+ # W-MSA/SW-MSA
246
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
247
+
248
+ # merge windows
249
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
250
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
251
+
252
+ # reverse cyclic shift
253
+ if self.shift_size > 0:
254
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
255
+ else:
256
+ x = shifted_x
257
+
258
+ if pad_r > 0 or pad_b > 0:
259
+ x = x[:, :H, :W, :].contiguous()
260
+
261
+ x = x.view(B, H * W, C)
262
+
263
+ # FFN
264
+ x = shortcut + self.drop_path(x)
265
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
266
+
267
+ return x
268
+
269
+
270
+ class PatchMerging(nn.Module):
271
+ """ Patch Merging Layer
272
+ Args:
273
+ dim (int): Number of input channels.
274
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
275
+ """
276
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
277
+ super().__init__()
278
+ self.dim = dim
279
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
280
+ self.norm = norm_layer(4 * dim)
281
+
282
+ def forward(self, x, H, W):
283
+ """ Forward function.
284
+ Args:
285
+ x: Input feature, tensor size (B, H*W, C).
286
+ H, W: Spatial resolution of the input feature.
287
+ """
288
+ B, L, C = x.shape
289
+ assert L == H * W, "input feature has wrong size"
290
+
291
+ x = x.view(B, H, W, C)
292
+
293
+ # padding
294
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
295
+ if pad_input:
296
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
297
+
298
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
299
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
300
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
301
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
302
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
303
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
304
+
305
+ x = self.norm(x)
306
+ x = self.reduction(x)
307
+
308
+ return x
309
+
310
+
311
+ class BasicLayer(nn.Module):
312
+ """ A basic Swin Transformer layer for one stage.
313
+ Args:
314
+ dim (int): Number of feature channels
315
+ depth (int): Depth of this stage.
316
+ num_heads (int): Number of attention heads.
317
+ window_size (int): Local window size. Default: 7.
318
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
319
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
320
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
321
+ drop (float, optional): Dropout rate. Default: 0.0
322
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
323
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
324
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
325
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
326
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
327
+ """
328
+
329
+ def __init__(self,
330
+ dim,
331
+ depth,
332
+ num_heads,
333
+ window_size=7,
334
+ mlp_ratio=4.,
335
+ qkv_bias=True,
336
+ qk_scale=None,
337
+ drop=0.,
338
+ attn_drop=0.,
339
+ drop_path=0.,
340
+ norm_layer=nn.LayerNorm,
341
+ downsample=None,
342
+ use_checkpoint=False):
343
+ super().__init__()
344
+ self.window_size = window_size
345
+ self.shift_size = window_size // 2
346
+ self.depth = depth
347
+ self.use_checkpoint = use_checkpoint
348
+
349
+ # build blocks
350
+ self.blocks = nn.ModuleList([
351
+ SwinTransformerBlock(
352
+ dim=dim,
353
+ num_heads=num_heads,
354
+ window_size=window_size,
355
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
356
+ mlp_ratio=mlp_ratio,
357
+ qkv_bias=qkv_bias,
358
+ qk_scale=qk_scale,
359
+ drop=drop,
360
+ attn_drop=attn_drop,
361
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
362
+ norm_layer=norm_layer)
363
+ for i in range(depth)])
364
+
365
+ # patch merging layer
366
+ if downsample is not None:
367
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
368
+ else:
369
+ self.downsample = None
370
+
371
+ def forward(self, x, H, W):
372
+ """ Forward function.
373
+ Args:
374
+ x: Input feature, tensor size (B, H*W, C).
375
+ H, W: Spatial resolution of the input feature.
376
+ """
377
+
378
+ # calculate attention mask for SW-MSA
379
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
380
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
381
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
382
+ h_slices = (slice(0, -self.window_size),
383
+ slice(-self.window_size, -self.shift_size),
384
+ slice(-self.shift_size, None))
385
+ w_slices = (slice(0, -self.window_size),
386
+ slice(-self.window_size, -self.shift_size),
387
+ slice(-self.shift_size, None))
388
+ cnt = 0
389
+ for h in h_slices:
390
+ for w in w_slices:
391
+ img_mask[:, h, w, :] = cnt
392
+ cnt += 1
393
+
394
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
395
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
396
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
397
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
398
+
399
+ for blk in self.blocks:
400
+ blk.H, blk.W = H, W
401
+ if self.use_checkpoint:
402
+ x = checkpoint.checkpoint(blk, x, attn_mask)
403
+ else:
404
+ x = blk(x, attn_mask)
405
+ if self.downsample is not None:
406
+ x_down = self.downsample(x, H, W)
407
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
408
+ return x, H, W, x_down, Wh, Ww
409
+ else:
410
+ return x, H, W, x, H, W
411
+
412
+
413
+ class PatchEmbed(nn.Module):
414
+ """ Image to Patch Embedding
415
+ Args:
416
+ patch_size (int): Patch token size. Default: 4.
417
+ in_chans (int): Number of input image channels. Default: 3.
418
+ embed_dim (int): Number of linear projection output channels. Default: 96.
419
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
420
+ """
421
+
422
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
423
+ super().__init__()
424
+ patch_size = to_2tuple(patch_size)
425
+ self.patch_size = patch_size
426
+
427
+ self.in_chans = in_chans
428
+ self.embed_dim = embed_dim
429
+
430
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
431
+ if norm_layer is not None:
432
+ self.norm = norm_layer(embed_dim)
433
+ else:
434
+ self.norm = None
435
+
436
+ def forward(self, x):
437
+ """Forward function."""
438
+ # padding
439
+ _, _, H, W = x.size()
440
+ if W % self.patch_size[1] != 0:
441
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
442
+ if H % self.patch_size[0] != 0:
443
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
444
+
445
+ x = self.proj(x) # B C Wh Ww
446
+ if self.norm is not None:
447
+ Wh, Ww = x.size(2), x.size(3)
448
+ x = x.flatten(2).transpose(1, 2)
449
+ x = self.norm(x)
450
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
451
+
452
+ return x
453
+
454
+
455
+ class SwinTransformer(nn.Module):
456
+ """ Swin Transformer backbone.
457
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
458
+ https://arxiv.org/pdf/2103.14030
459
+ Args:
460
+ pretrain_img_size (int): Input image size for training the pretrained model,
461
+ used in absolute position embedding. Default: 224.
462
+ patch_size (int | tuple(int)): Patch size. Default: 4.
463
+ in_chans (int): Number of input image channels. Default: 3.
464
+ embed_dim (int): Number of linear projection output channels. Default: 96.
465
+ depths (tuple[int]): Depths of each Swin Transformer stage.
466
+ num_heads (tuple[int]): Number of attention heads of each stage.
467
+ window_size (int): Window size. Default: 7.
468
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
469
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
470
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
471
+ drop_rate (float): Dropout rate.
472
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
473
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
474
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
475
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
476
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
477
+ out_indices (Sequence[int]): Output from which stages.
478
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
479
+ -1 means not freezing any parameters.
480
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
481
+ """
482
+
483
+ def __init__(self,
484
+ pretrain_img_size=224,
485
+ patch_size=4,
486
+ in_chans=3,
487
+ embed_dim=96,
488
+ depths=[2, 2, 6, 2],
489
+ num_heads=[3, 6, 12, 24],
490
+ window_size=7,
491
+ mlp_ratio=4.,
492
+ qkv_bias=True,
493
+ qk_scale=None,
494
+ drop_rate=0.,
495
+ attn_drop_rate=0.,
496
+ drop_path_rate=0.2,
497
+ norm_layer=nn.LayerNorm,
498
+ ape=False,
499
+ patch_norm=True,
500
+ out_indices=(0, 1, 2, 3),
501
+ frozen_stages=-1,
502
+ use_checkpoint=False):
503
+ super().__init__()
504
+
505
+ self.pretrain_img_size = pretrain_img_size
506
+ self.num_layers = len(depths)
507
+ self.embed_dim = embed_dim
508
+ self.ape = ape
509
+ self.patch_norm = patch_norm
510
+ self.out_indices = out_indices
511
+ self.frozen_stages = frozen_stages
512
+
513
+ # split image into non-overlapping patches
514
+ self.patch_embed = PatchEmbed(
515
+ patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
516
+ norm_layer=norm_layer if self.patch_norm else None)
517
+
518
+ # absolute position embedding
519
+ if self.ape:
520
+ pretrain_img_size = to_2tuple(pretrain_img_size)
521
+ patch_size = to_2tuple(patch_size)
522
+ patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
523
+
524
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
525
+ trunc_normal_(self.absolute_pos_embed, std=.02)
526
+
527
+ self.pos_drop = nn.Dropout(p=drop_rate)
528
+
529
+ # stochastic depth
530
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
531
+
532
+ # build layers
533
+ self.layers = nn.ModuleList()
534
+ for i_layer in range(self.num_layers):
535
+ layer = BasicLayer(
536
+ dim=int(embed_dim * 2 ** i_layer),
537
+ depth=depths[i_layer],
538
+ num_heads=num_heads[i_layer],
539
+ window_size=window_size,
540
+ mlp_ratio=mlp_ratio,
541
+ qkv_bias=qkv_bias,
542
+ qk_scale=qk_scale,
543
+ drop=drop_rate,
544
+ attn_drop=attn_drop_rate,
545
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
546
+ norm_layer=norm_layer,
547
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
548
+ use_checkpoint=use_checkpoint)
549
+ self.layers.append(layer)
550
+
551
+ num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
552
+ self.num_features = num_features
553
+
554
+ # add a norm layer for each output
555
+ for i_layer in out_indices:
556
+ layer = norm_layer(num_features[i_layer])
557
+ layer_name = f'norm{i_layer}'
558
+ self.add_module(layer_name, layer)
559
+
560
+ self._freeze_stages()
561
+
562
+ def _freeze_stages(self):
563
+ if self.frozen_stages >= 0:
564
+ self.patch_embed.eval()
565
+ for param in self.patch_embed.parameters():
566
+ param.requires_grad = False
567
+
568
+ if self.frozen_stages >= 1 and self.ape:
569
+ self.absolute_pos_embed.requires_grad = False
570
+
571
+ if self.frozen_stages >= 2:
572
+ self.pos_drop.eval()
573
+ for i in range(0, self.frozen_stages - 1):
574
+ m = self.layers[i]
575
+ m.eval()
576
+ for param in m.parameters():
577
+ param.requires_grad = False
578
+
579
+
580
+ def forward(self, x):
581
+
582
+ x = self.patch_embed(x)
583
+
584
+ Wh, Ww = x.size(2), x.size(3)
585
+ if self.ape:
586
+ # interpolate the position embedding to the corresponding size
587
+ absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
588
+ x = (x + absolute_pos_embed) # B Wh*Ww C
589
+
590
+ outs = [x.contiguous()]
591
+ x = x.flatten(2).transpose(1, 2)
592
+ x = self.pos_drop(x)
593
+
594
+
595
+ for i in range(self.num_layers):
596
+ layer = self.layers[i]
597
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
598
+
599
+
600
+ if i in self.out_indices:
601
+ norm_layer = getattr(self, f'norm{i}')
602
+ x_out = norm_layer(x_out)
603
+
604
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
605
+ outs.append(out)
606
+
607
+
608
+
609
+ return tuple(outs)
610
+
611
+
612
+
613
+
614
+
615
+
616
+
617
+
618
+ def get_activation_fn(activation):
619
+ """Return an activation function given a string"""
620
+ if activation == "gelu":
621
+ return F.gelu
622
+
623
+ raise RuntimeError(F"activation should be gelu, not {activation}.")
624
+
625
+
626
+ def make_cbr(in_dim, out_dim):
627
+ return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1), nn.InstanceNorm2d(out_dim), nn.GELU())
628
+
629
+
630
+ def make_cbg(in_dim, out_dim):
631
+ return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1), nn.InstanceNorm2d(out_dim), nn.GELU())
632
+
633
+
634
+ def rescale_to(x, scale_factor: float = 2, interpolation='nearest'):
635
+ return F.interpolate(x, scale_factor=scale_factor, mode=interpolation)
636
+
637
+
638
+ def resize_as(x, y, interpolation='bilinear'):
639
+ return F.interpolate(x, size=y.shape[-2:], mode=interpolation)
640
+
641
+
642
+ def image2patches(x):
643
+ """b c (hg h) (wg w) -> (hg wg b) c h w"""
644
+ x = rearrange(x, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2 )
645
+ return x
646
+
647
+
648
+ def patches2image(x):
649
+ """(hg wg b) c h w -> b c (hg h) (wg w)"""
650
+ x = rearrange(x, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
651
+ return x
652
+
653
+
654
+
655
+ class PositionEmbeddingSine:
656
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
657
+ super().__init__()
658
+ self.num_pos_feats = num_pos_feats
659
+ self.temperature = temperature
660
+ self.normalize = normalize
661
+ if scale is not None and normalize is False:
662
+ raise ValueError("normalize should be True if scale is passed")
663
+ if scale is None:
664
+ scale = 2 * math.pi
665
+ self.scale = scale
666
+ self.dim_t = torch.arange(0, self.num_pos_feats, dtype=torch.float32)
667
+
668
+ def __call__(self, b, h, w):
669
+ device = self.dim_t.device
670
+ mask = torch.zeros([b, h, w], dtype=torch.bool, device=device)
671
+ assert mask is not None
672
+ not_mask = ~mask
673
+ y_embed = not_mask.cumsum(dim=1, dtype=torch.float32)
674
+ x_embed = not_mask.cumsum(dim=2, dtype=torch.float32)
675
+ if self.normalize:
676
+ eps = 1e-6
677
+ y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
678
+ x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
679
+
680
+ dim_t = self.temperature ** (2 * (self.dim_t.to(device) // 2) / self.num_pos_feats)
681
+ pos_x = x_embed[:, :, :, None] / dim_t
682
+ pos_y = y_embed[:, :, :, None] / dim_t
683
+
684
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
685
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
686
+
687
+ return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
688
+
689
+
690
+
724
+
725
+
726
+ class MCLM(nn.Module):
727
+ def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
728
+ super(MCLM, self).__init__()
729
+ self.attention = nn.ModuleList([
730
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
731
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
732
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
733
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
734
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
735
+ ])
736
+
737
+ self.linear1 = nn.Linear(d_model, d_model * 2)
738
+ self.linear2 = nn.Linear(d_model * 2, d_model)
739
+ self.linear3 = nn.Linear(d_model, d_model * 2)
740
+ self.linear4 = nn.Linear(d_model * 2, d_model)
741
+ self.norm1 = nn.LayerNorm(d_model)
742
+ self.norm2 = nn.LayerNorm(d_model)
743
+ self.dropout = nn.Dropout(0.1)
744
+ self.dropout1 = nn.Dropout(0.1)
745
+ self.dropout2 = nn.Dropout(0.1)
746
+ self.activation = get_activation_fn('gelu')
747
+ self.pool_ratios = pool_ratios
748
+ self.p_poses = []
749
+ self.g_pos = None
750
+ self.positional_encoding = PositionEmbeddingSine(num_pos_feats=d_model // 2, normalize=True)
751
+
752
+ def forward(self, l, g):
753
+ """
754
+ l: 4,c,h,w
755
+ g: 1,c,h,w
756
+ """
757
+ self.p_poses = []
758
+ self.g_pos = None
759
+ b, c, h, w = l.size()
760
+ # 4,c,h,w -> 1,c,2h,2w
761
+ concated_locs = rearrange(l, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
762
+
763
+ pools = []
764
+ for pool_ratio in self.pool_ratios:
765
+ # b,c,h,w
766
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
767
+ pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
768
+ pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
769
+ if self.g_pos is None:
770
+ pos_emb = self.positional_encoding(pool.shape[0], pool.shape[2], pool.shape[3])
771
+ pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
772
+ self.p_poses.append(pos_emb)
773
+ pools = torch.cat(pools, 0)
774
+ if self.g_pos is None:
775
+ self.p_poses = torch.cat(self.p_poses, dim=0)
776
+ pos_emb = self.positional_encoding(g.shape[0], g.shape[2], g.shape[3])
777
+ self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
778
+
779
+ device = pools.device
780
+ self.p_poses = self.p_poses.to(device)
781
+ self.g_pos = self.g_pos.to(device)
782
+
783
+
784
+ # attention between glb (q) & multi-scale concatenated locs (k, v)
785
+ g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')
786
+
787
+
788
+ g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
789
+ g_hw_b_c = self.norm1(g_hw_b_c)
790
+ g_hw_b_c = g_hw_b_c + self.dropout2(self.linear2(self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
791
+ g_hw_b_c = self.norm2(g_hw_b_c)
792
+
793
+ # attention between original locs (q) & refreshed glb (k, v)
794
+ l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
795
+ _g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
796
+ _g_hw_b_c = rearrange(_g_hw_b_c, "(ng h) (nw w) b c -> (h w) (ng nw b) c", ng=2, nw=2)
797
+ outputs_re = []
798
+ for i, (_l, _g) in enumerate(zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
799
+ outputs_re.append(self.attention[i + 1](_l, _g, _g)[0]) # (h w) 1 c
800
+ outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
801
+
802
+ l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
803
+ l_hw_b_c = self.norm1(l_hw_b_c)
804
+ l_hw_b_c = l_hw_b_c + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
805
+ l_hw_b_c = self.norm2(l_hw_b_c)
806
+
807
+ l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
808
+ return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) ## (5,c,h*w)
809
+
810
+
811
+
812
+
813
+
814
+
815
+
816
+
817
+
818
+ class MCRM(nn.Module):
819
+ def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
820
+ super(MCRM, self).__init__()
821
+ self.attention = nn.ModuleList([
822
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
823
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
824
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
825
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
826
+ ])
827
+ self.linear3 = nn.Linear(d_model, d_model * 2)
828
+ self.linear4 = nn.Linear(d_model * 2, d_model)
829
+ self.norm1 = nn.LayerNorm(d_model)
830
+ self.norm2 = nn.LayerNorm(d_model)
831
+ self.dropout = nn.Dropout(0.1)
832
+ self.dropout1 = nn.Dropout(0.1)
833
+ self.dropout2 = nn.Dropout(0.1)
834
+ self.sigmoid = nn.Sigmoid()
835
+ self.activation = get_activation_fn('gelu')
836
+ self.sal_conv = nn.Conv2d(d_model, 1, 1)
837
+ self.pool_ratios = pool_ratios
838
+
839
+ def forward(self, x):
840
+ device = x.device
841
+ b, c, h, w = x.size()
842
+ loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
843
+
844
+ patched_glb = rearrange(glb, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
845
+
846
+ token_attention_map = self.sigmoid(self.sal_conv(glb))
847
+ token_attention_map = F.interpolate(token_attention_map, size=patches2image(loc).shape[-2:], mode='nearest')
848
+ loc = loc * rearrange(token_attention_map, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
849
+
850
+ pools = []
851
+ for pool_ratio in self.pool_ratios:
852
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
853
+ pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
854
+ pools.append(rearrange(pool, 'nl c h w -> nl c (h w)')) # nl(4),c,hw
855
+
856
+ pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
857
+ loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
858
+
859
+ outputs = []
860
+ for i, q in enumerate(loc_.unbind(dim=0)): # traverse all local patches
861
+ v = pools[i]
862
+ k = v
863
+ outputs.append(self.attention[i](q, k, v)[0])
864
+
865
+ outputs = torch.cat(outputs, 1)
866
+ src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
867
+ src = self.norm1(src)
868
+ src = src + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(src)).clone())))
869
+ src = self.norm2(src)
870
+ src = src.permute(1, 2, 0).reshape(4, c, h, w) # refreshed loc
871
+ glb = glb + F.interpolate(patches2image(src), size=glb.shape[-2:], mode='nearest') # refreshed glb
872
+
873
+ return torch.cat((src, glb), 0), token_attention_map
874
+
875
+
876
+
877
+ class BEN_Base(nn.Module):
878
+ def __init__(self):
879
+ super().__init__()
880
+
881
+ self.backbone = SwinTransformer(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12)
882
+ emb_dim = 128
883
+ self.sideout5 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
884
+ self.sideout4 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
885
+ self.sideout3 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
886
+ self.sideout2 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
887
+ self.sideout1 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
888
+
889
+ self.output5 = make_cbr(1024, emb_dim)
890
+ self.output4 = make_cbr(512, emb_dim)
891
+ self.output3 = make_cbr(256, emb_dim)
892
+ self.output2 = make_cbr(128, emb_dim)
893
+ self.output1 = make_cbr(128, emb_dim)
894
+
895
+ self.multifieldcrossatt = MCLM(emb_dim, 1, [1, 4, 8])
896
+ self.conv1 = make_cbr(emb_dim, emb_dim)
897
+ self.conv2 = make_cbr(emb_dim, emb_dim)
898
+ self.conv3 = make_cbr(emb_dim, emb_dim)
899
+ self.conv4 = make_cbr(emb_dim, emb_dim)
900
+ self.dec_blk1 = MCRM(emb_dim, 1, [2, 4, 8])
901
+ self.dec_blk2 = MCRM(emb_dim, 1, [2, 4, 8])
902
+ self.dec_blk3 = MCRM(emb_dim, 1, [2, 4, 8])
903
+ self.dec_blk4 = MCRM(emb_dim, 1, [2, 4, 8])
904
+
905
+ self.insmask_head = nn.Sequential(
906
+ nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
907
+ nn.InstanceNorm2d(384),
908
+ nn.GELU(),
909
+ nn.Conv2d(384, 384, kernel_size=3, padding=1),
910
+ nn.InstanceNorm2d(384),
911
+ nn.GELU(),
912
+ nn.Conv2d(384, emb_dim, kernel_size=3, padding=1)
913
+ )
914
+
915
+ self.shallow = nn.Sequential(nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
916
+ self.upsample1 = make_cbg(emb_dim, emb_dim)
917
+ self.upsample2 = make_cbg(emb_dim, emb_dim)
918
+ self.output = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
919
+
920
+ for m in self.modules():
921
+ if isinstance(m, nn.GELU) or isinstance(m, nn.Dropout):
922
+ m.inplace = True
923
+
924
+
925
+
926
+ @torch.inference_mode()
927
+ @torch.autocast(device_type="cuda",dtype=torch.float16)
928
+ def forward(self, x):
929
+ real_batch = x.size(0)
930
+
931
+ shallow_batch = self.shallow(x)
932
+ glb_batch = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
933
+
934
+
935
+
936
+ final_input = None
937
+ for i in range(real_batch):
938
+ start = i * 4
939
+ end = (i + 1) * 4
940
+ loc_batch = image2patches(x[i,:,:,:].unsqueeze(dim=0))
941
+ input_ = torch.cat((loc_batch, glb_batch[i,:,:,:].unsqueeze(dim=0)), dim=0)
942
+
943
+
944
+ if final_input is None:
945
+ final_input= input_
946
+ else: final_input = torch.cat((final_input, input_), dim=0)
947
+
948
+ features = self.backbone(final_input)
949
+ outputs = []
950
+
951
+ for i in range(real_batch):
952
+
953
+ start = i * 5
954
+ end = (i + 1) * 5
955
+
956
+ f4 = features[4][start:end, :, :, :] # shape: [5, C, H, W]
957
+ f3 = features[3][start:end, :, :, :]
958
+ f2 = features[2][start:end, :, :, :]
959
+ f1 = features[1][start:end, :, :, :]
960
+ f0 = features[0][start:end, :, :, :]
961
+ e5 = self.output5(f4)
962
+ e4 = self.output4(f3)
963
+ e3 = self.output3(f2)
964
+ e2 = self.output2(f1)
965
+ e1 = self.output1(f0)
966
+ loc_e5, glb_e5 = e5.split([4, 1], dim=0)
967
+ e5 = self.multifieldcrossatt(loc_e5, glb_e5) # (4,128,16,16)
968
+
969
+
970
+ e4, tokenattmap4 = self.dec_blk4(e4 + resize_as(e5, e4))
971
+ e4 = self.conv4(e4)
972
+ e3, tokenattmap3 = self.dec_blk3(e3 + resize_as(e4, e3))
973
+ e3 = self.conv3(e3)
974
+ e2, tokenattmap2 = self.dec_blk2(e2 + resize_as(e3, e2))
975
+ e2 = self.conv2(e2)
976
+ e1, tokenattmap1 = self.dec_blk1(e1 + resize_as(e2, e1))
977
+ e1 = self.conv1(e1)
978
+
979
+ loc_e1, glb_e1 = e1.split([4, 1], dim=0)
980
+
981
+ output1_cat = patches2image(loc_e1) # (1,128,256,256)
982
+
983
+ # add glb feat in
984
+ output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
985
+ # merge
986
+ final_output = self.insmask_head(output1_cat) # (1,128,256,256)
987
+ # shallow feature merge
988
+ shallow = shallow_batch[i,:,:,:].unsqueeze(dim=0)
989
+ final_output = final_output + resize_as(shallow, final_output)
990
+ final_output = self.upsample1(rescale_to(final_output))
991
+ final_output = rescale_to(final_output + resize_as(shallow, final_output))
992
+ final_output = self.upsample2(final_output)
993
+ final_output = self.output(final_output)
994
+ mask = final_output.sigmoid()
995
+ outputs.append(mask)
996
+
997
+ return torch.cat(outputs, dim=0)
998
+
999
+
1000
+
1001
+
1002
+ def loadcheckpoints(self,model_path):
1003
+ model_dict = torch.load(model_path, map_location="cpu", weights_only=True)
1004
+ self.load_state_dict(model_dict['model_state_dict'], strict=True)
1005
+ del model_dict
1006
+
1007
+ def inference(self,image,refine_foreground=False):
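+ """Segment the foreground of a PIL.Image (or an iterable of PIL.Images).
+
+ Returns the input image with the predicted alpha matte attached (RGBA); when an
+ iterable is passed, a list of such images is returned. If refine_foreground is True,
+ the foreground colors are additionally refined via refine_foreground_process before
+ the alpha channel is applied.
+ """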
1008
+
1009
+ set_random_seed(9)
1010
+ # image = ImageOps.exif_transpose(image)
1011
+ if isinstance(image, Image.Image):
1012
+ image, h, w,original_image = rgb_loader_refiner(image)
1013
+ if torch.cuda.is_available():
1014
+
1015
+ img_tensor = img_transform(image).unsqueeze(0).to(next(self.parameters()).device)
1016
+ else:
1017
+ img_tensor = img_transform32(image).unsqueeze(0).to(next(self.parameters()).device)
1018
+
1019
+
1020
+ with torch.no_grad():
1021
+ res = self.forward(img_tensor)
1022
+
1023
+ # Show Results
1024
+ if refine_foreground:
1025
+
1026
+ pred_pil = transforms.ToPILImage()(res.squeeze())
1027
+ image_masked = refine_foreground_process(original_image, pred_pil)
1028
+
1029
+ image_masked.putalpha(pred_pil.resize(original_image.size))
1030
+ return image_masked
1031
+
1032
+ else:
1033
+ alpha = postprocess_image(res, im_size=[w,h])
1034
+ pred_pil = transforms.ToPILImage()(alpha)
1035
+ mask = pred_pil.resize(original_image.size)
1036
+ original_image.putalpha(mask)
1037
+ # mask = Image.fromarray(alpha)
1038
+
1039
+ return original_image
1040
+
1041
+
1042
+ else:
1043
+ foregrounds = []
1044
+ for batch in image:
1045
+ image, h, w,original_image = rgb_loader_refiner(batch)
1046
+ if torch.cuda.is_available():
1047
+
1048
+ img_tensor = img_transform(image).unsqueeze(0).to(next(self.parameters()).device)
1049
+ else:
1050
+ img_tensor = img_transform32(image).unsqueeze(0).to(next(self.parameters()).device)
1051
+
1052
+ with torch.no_grad():
1053
+ res = self.forward(img_tensor)
1054
+
1055
+ if refine_foreground:
1056
+
1057
+ pred_pil = transforms.ToPILImage()(res.squeeze())
1058
+ image_masked = refine_foreground_process(original_image, pred_pil)
1059
+
1060
+ image_masked.putalpha(pred_pil.resize(original_image.size))
1061
+
1062
+ foregrounds.append(image_masked)
1063
+ else:
1064
+ alpha = postprocess_image(res, im_size=[w,h])
1065
+ pred_pil = transforms.ToPILImage()(alpha)
1066
+ mask = pred_pil.resize(original_image.size)
1067
+ original_image.putalpha(mask)
1068
+ # mask = Image.fromarray(alpha)
1069
+ foregrounds.append(original_image)
1070
+
1071
+ return foregrounds
1072
+
1073
+
1074
+
1075
+
1076
+ def segment_video(self, video_path, output_path="./", fps=0, refine_foreground=False, batch=1, print_frames_processed=True, webm = False, rgb_value= (0, 255, 0)):
1077
+
1078
+ """
1079
+ Segments the given video to extract the foreground (with alpha) from each frame
1080
+ and saves the result as either a WebM video (with alpha channel) or MP4 (with a
1081
+ color background).
1082
+
1083
+ Args:
1084
+ video_path (str):
1085
+ Path to the input video file.
1086
+
1087
+ output_path (str, optional):
1088
+ Directory (or full path) where the output video and/or files will be saved.
1089
+ Defaults to "./".
1090
+
1091
+ fps (int, optional):
1092
+ The frames per second (FPS) to use for the output video. If 0 (default), the
1093
+ original FPS of the input video is used. Otherwise, overrides it.
1094
+
1095
+ refine_foreground (bool, optional):
1096
+ Whether to run an additional “refine foreground” process on each frame.
1097
+ Defaults to False.
1098
+
1099
+ batch (int, optional):
1100
+ Number of frames to process at once (inference batch size). Large batch sizes
1101
+ may require more GPU memory. Defaults to 1.
1102
+
1103
+ print_frames_processed (bool, optional):
1104
+ If True (default), prints progress (how many frames have been processed) to
1105
+ the console.
1106
+
1107
+ webm (bool, optional):
1108
+ If True, exports a WebM video with alpha channel (VP9 / yuva420p).
1109
+ If False (the default), exports an MP4 video composited over a solid color background.
1110
+
1111
+ rgb_value (tuple, optional):
1112
+ The RGB background color (e.g., green screen) used to composite frames when
1113
+ saving to MP4. Defaults to (0, 255, 0).
1114
+
1115
+ Returns:
1116
+ None. Writes the output video(s) to disk in the specified format.
1117
+ """
1118
+
1119
+
1120
+ cap = cv2.VideoCapture(video_path)
1121
+ if not cap.isOpened():
1122
+ raise IOError(f"Cannot open video: {video_path}")
1123
+
1124
+ original_fps = cap.get(cv2.CAP_PROP_FPS)
1125
+ original_fps = 30 if original_fps == 0 else original_fps
1126
+ fps = original_fps if fps == 0 else fps
1127
+
1128
+ ret, first_frame = cap.read()
1129
+ if not ret:
1130
+ raise ValueError("No frames found in the video.")
1131
+ height, width = first_frame.shape[:2]
1132
+ cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
1133
+
1134
+ foregrounds = []
1135
+ frame_idx = 0
1136
+ processed_count = 0
1137
+ batch_frames = []
1138
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
1139
+
1140
+ while True:
1141
+ ret, frame = cap.read()
1142
+ if not ret:
1143
+ if batch_frames:
1144
+ batch_results = self.inference(batch_frames, refine_foreground)
1145
+ if isinstance(batch_results, Image.Image):
1146
+ foregrounds.append(batch_results)
1147
+ else:
1148
+ foregrounds.extend(batch_results)
1149
+ if print_frames_processed:
1150
+ print(f"Processed frames {frame_idx-len(batch_frames)+1} to {frame_idx} of {total_frames}")
1151
+ break
1152
+
1153
+ # Process every frame instead of using intervals
1154
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
1155
+ pil_frame = Image.fromarray(frame_rgb)
1156
+ batch_frames.append(pil_frame)
1157
+
1158
+ if len(batch_frames) == batch:
1159
+ batch_results = self.inference(batch_frames, refine_foreground)
1160
+ if isinstance(batch_results, Image.Image):
1161
+ foregrounds.append(batch_results)
1162
+ else:
1163
+ foregrounds.extend(batch_results)
1164
+ if print_frames_processed:
1165
+ print(f"Processed frames {frame_idx-batch+1} to {frame_idx} of {total_frames}")
1166
+ batch_frames = []
1167
+ processed_count += batch
1168
+
1169
+ frame_idx += 1
1170
+
1171
+
1172
+ if webm:
1173
+ alpha_webm_path = os.path.join(output_path, "foreground.webm")
1174
+ pil_images_to_webm_alpha(foregrounds, alpha_webm_path, fps=fps)  # honor the fps override (falls back to the original fps when fps == 0)
1175
+
1176
+ else:
1177
+ cap.release()
1178
+ fg_output = os.path.join(output_path, 'foreground.mp4')
1179
+
1180
+ pil_images_to_mp4(foregrounds, fg_output, fps=fps, rgb_value=rgb_value)  # honor the fps override
1181
+ cv2.destroyAllWindows()
1182
+
1183
+ try:
1184
+ fg_audio_output = os.path.join(output_path, 'foreground_output_with_audio.mp4')
1185
+ add_audio_to_video(fg_output, video_path, fg_audio_output)
1186
+ except Exception as e:
1187
+ print("No audio found in the original video")
1188
+ print(e)
1189
+
1190
+
1191
+
1192
+
1193
+
1194
+ def rgb_loader_refiner( original_image):
1195
+ h, w = original_image.size
1196
+
1197
+ image = original_image
1198
+ # Convert to RGB if necessary
1199
+ if image.mode != 'RGB':
1200
+ image = image.convert('RGB')
1201
+
1202
+ # Resize the image
1203
+ image = image.resize((1024, 1024), resample=Image.LANCZOS)
1204
+
1205
+ return image.convert('RGB'), h, w,original_image
1206
+
1207
+ # Define the image transformation
1208
+ img_transform = transforms.Compose([
1209
+ transforms.ToTensor(),
1210
+ transforms.ConvertImageDtype(torch.float16),
1211
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
1212
+ ])
1213
+
1214
+ img_transform32 = transforms.Compose([
1215
+ transforms.ToTensor(),
1216
+ transforms.ConvertImageDtype(torch.float32),
1217
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
1218
+ ])
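+
+ # Note: inference() applies img_transform (float16 tensors) when CUDA is available and
+ # img_transform32 (float32) otherwise, matching the float16 autocast used in BEN_Base.forward.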
1219
+
1220
+
1221
+
1222
+
1223
+
1224
+ def pil_images_to_mp4(images, output_path, fps=24, rgb_value=(0, 255, 0)):
1225
+ """
1226
+ Converts an array of PIL images to an MP4 video.
1227
+
1228
+ Args:
1229
+ images: List of PIL images
1230
+ output_path: Path to save the MP4 file
1231
+ fps: Frames per second (default: 24)
1232
+ rgb_value: Background RGB color tuple (default: green (0, 255, 0))
1233
+ """
1234
+ if not images:
1235
+ raise ValueError("No images provided to convert to MP4.")
1236
+
1237
+ width, height = images[0].size
1238
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
1239
+ video_writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
1240
+
1241
+ for image in images:
1242
+ # If image has alpha channel, composite onto the specified background color
1243
+ if image.mode == 'RGBA':
1244
+ # Create background image with specified RGB color
1245
+ background = Image.new('RGB', image.size, rgb_value)
1246
+ background = background.convert('RGBA')
1247
+ # Composite the image onto the background
1248
+ image = Image.alpha_composite(background, image)
1249
+ image = image.convert('RGB')
1250
+ else:
1251
+ # Ensure RGB format for non-alpha images
1252
+ image = image.convert('RGB')
1253
+
1254
+ # Convert to OpenCV format and write
1255
+ open_cv_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
1256
+ video_writer.write(open_cv_image)
1257
+
1258
+ video_writer.release()
1259
+
1260
+ def pil_images_to_webm_alpha(images, output_path, fps=30):
1261
+ """
1262
+ Converts a list of PIL RGBA images to a VP9 .webm video with alpha channel.
1263
+
1264
+ NOTE: Not all players will display alpha in WebM.
1265
+ Browsers like Chrome/Firefox typically do support VP9 alpha.
1266
+ """
1267
+ if not images:
1268
+ raise ValueError("No images provided for WebM with alpha.")
1269
+
1270
+ # Ensure output directory exists
1271
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
1272
+
1273
+ with tempfile.TemporaryDirectory() as tmpdir:
1274
+ # Save frames as PNG (with alpha)
1275
+ for idx, img in enumerate(images):
1276
+ if img.mode != "RGBA":
1277
+ img = img.convert("RGBA")
1278
+ out_path = os.path.join(tmpdir, f"{idx:06d}.png")
1279
+ img.save(out_path, "PNG")
1280
+
1281
+ # Construct ffmpeg command
1282
+ # -c:v libvpx-vp9 => VP9 encoder
1283
+ # -pix_fmt yuva420p => alpha-enabled pixel format
1284
+ # -auto-alt-ref 0 => helps preserve alpha frames (libvpx quirk)
1285
+ ffmpeg_cmd = [
1286
+ "ffmpeg", "-y",
1287
+ "-framerate", str(fps),
1288
+ "-i", os.path.join(tmpdir, "%06d.png"),
1289
+ "-c:v", "libvpx-vp9",
1290
+ "-pix_fmt", "yuva420p",
1291
+ "-auto-alt-ref", "0",
1292
+ output_path
1293
+ ]
1294
+
1295
+ subprocess.run(ffmpeg_cmd, check=True)
1296
+
1297
+ print(f"WebM with alpha saved to {output_path}")
1298
+
1299
+ def add_audio_to_video(video_without_audio_path, original_video_path, output_path):
1300
+ """
1301
+ Check if the original video has an audio stream. If yes, add it. If not, skip.
1302
+ """
1303
+ # 1) Probe original video for audio streams
1304
+ probe_command = [
1305
+ 'ffprobe', '-v', 'error',
1306
+ '-select_streams', 'a:0',
1307
+ '-show_entries', 'stream=index',
1308
+ '-of', 'csv=p=0',
1309
+ original_video_path
1310
+ ]
1311
+ result = subprocess.run(probe_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
1312
+
1313
+ # result.stdout is empty if no audio stream found
1314
+ if not result.stdout.strip():
1315
+ print("No audio track found in original video, skipping audio addition.")
1316
+ return
1317
+
1318
+ print("Audio track detected; proceeding to mux audio.")
1319
+ # 2) If audio found, run ffmpeg to add it
1320
+ command = [
1321
+ 'ffmpeg', '-y',
1322
+ '-i', video_without_audio_path,
1323
+ '-i', original_video_path,
1324
+ '-c', 'copy',
1325
+ '-map', '0:v:0',
1326
+ '-map', '1:a:0', # we know there's an audio track now
1327
+ output_path
1328
+ ]
1329
+ subprocess.run(command, check=True)
1330
+ print(f"Audio added successfully => {output_path}")
1331
+
1332
+
1333
+
1334
+
1335
+
1336
+ ### Thanks to the source: https://huggingface.co/ZhengPeng7/BiRefNet/blob/main/handler.py
1337
+ def refine_foreground_process(image, mask, r=90):
1338
+ if mask.size != image.size:
1339
+ mask = mask.resize(image.size)
1340
+ image = np.array(image) / 255.0
1341
+ mask = np.array(mask) / 255.0
1342
+ estimated_foreground = FB_blur_fusion_foreground_estimator_2(image, mask, r=r)
1343
+ image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8))
1344
+ return image_masked
1345
+
1346
+
1347
+ def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
1348
+ # Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation
1349
+ alpha = alpha[:, :, None]
1350
+ F, blur_B = FB_blur_fusion_foreground_estimator(image, image, image, alpha, r)
1351
+ return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]
1352
+
1353
+
1354
+ def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
1355
+ if isinstance(image, Image.Image):
1356
+ image = np.array(image) / 255.0
1357
+ blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]
1358
+
1359
+ blurred_FA = cv2.blur(F * alpha, (r, r))
1360
+ blurred_F = blurred_FA / (blurred_alpha + 1e-5)
1361
+
1362
+ blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
1363
+ blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
1364
+ F = blurred_F + alpha * \
1365
+ (image - alpha * blurred_F - (1 - alpha) * blurred_B)
1366
+ F = np.clip(F, 0, 1)
1367
+ return F, blurred_B
1368
+
1369
+
1370
+
1371
+ def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
1372
+ result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
1373
+ ma = torch.max(result)
1374
+ mi = torch.min(result)
1375
+ result = (result - mi) / (ma - mi)
1376
+ im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
1377
+ im_array = np.squeeze(im_array)
1378
+ return im_array
1379
+
1380
+
1381
+
1382
+
1383
+ def rgb_loader_refiner(original_image):
+     h, w = original_image.size  # note: PIL's .size returns (width, height)
+
+     # Apply EXIF orientation
+     image = ImageOps.exif_transpose(original_image)
+
+     # Convert to RGB if necessary
+     if original_image.mode != 'RGB':
+         original_image = original_image.convert('RGB')
+
+     image = original_image
+
+     # Resize to the model's 1024x1024 input resolution
+     image = image.resize((1024, 1024), resample=Image.LANCZOS)
+
+     return image, h, w, original_image
+
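+
+ # Illustrative usage sketch for the blur-fusion refinement helper above. The paths are
+ # placeholders, and "./mask.png" is assumed to be a single-channel matte previously
+ # predicted for "./image.png".
+ if __name__ == "__main__":
+     example_image = Image.open("./image.png").convert("RGB")
+     example_mask = Image.open("./mask.png").convert("L")
+     refined = refine_foreground_process(example_image, example_mask, r=90)  # coarse pass r=90, then fine pass r=6
+     refined.save("./refined_foreground.png")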
BEN2_Base.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:22cea62108ff53b7ccc20f7a008bf30494228d84b1687f29ecbe76936a998101
+ size 222932053
BEN2_Base.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:926144a876bda06f125555b4f5a239ece89dc6eb838a863700ca9bf192161a1c
+ size 1134584206
BEN2_demo_pictures/grid_example1.png ADDED

Git LFS Details

  • SHA256: 49df5808df57c1db87f1bdf94ff0687ba436f2c377378799dd3ce49be85e0973
  • Pointer size: 132 Bytes
  • Size of remote file: 1.51 MB
BEN2_demo_pictures/grid_example2.png ADDED

Git LFS Details

  • SHA256: 0899c57bdb592ccf3b04ed0316d3f8d4b23f337ff00f6807b18b46b74b8e91bf
  • Pointer size: 132 Bytes
  • Size of remote file: 1.76 MB
BEN2_demo_pictures/grid_example3.png ADDED

Git LFS Details

  • SHA256: f0e2cb53afd4ad04daa223525f688cad835826890eb4ababb1e0bf0e629800e5
  • Pointer size: 132 Bytes
  • Size of remote file: 8.59 MB
BEN2_demo_pictures/grid_example6.png ADDED

Git LFS Details

  • SHA256: 327eca743beef0cd452e40015b0695d67806cd6b59bd3a7759cfd2be260c5cae
  • Pointer size: 132 Bytes
  • Size of remote file: 2.37 MB
BEN2_demo_pictures/grid_example7.png ADDED

Git LFS Details

  • SHA256: 0f758d617b3266d2fb540bd5088c58a3a69d1cb5e54dcdf0a4d5ea6b74c6e7e2
  • Pointer size: 132 Bytes
  • Size of remote file: 5.27 MB
BEN2_demo_pictures/model_comparison.png ADDED

Git LFS Details

  • SHA256: b7b666c9f0b2c40fa471c21a3eeb6ff34f045082b3253d05392c7c58caff8621
  • Pointer size: 131 Bytes
  • Size of remote file: 328 kB
README.md ADDED
@@ -0,0 +1,154 @@
+ ---
+ license: mit
+ pipeline_tag: image-segmentation
+ library_name: ben2
+ tags:
+ - BEN2
+ - background-remove
+ - mask-generation
+ - Dichotomous image segmentation
+ - background remove
+ - foreground
+ - background
+ - remove background
+ - pytorch
+ - model_hub_mixin
+ - pytorch_model_hub_mixin
+ - background removal
+ - background-removal
+ ---
+
+ # BEN2: Background Erase Network
+
+ [![arXiv](https://img.shields.io/badge/arXiv-2501.06230-b31b1b.svg)](https://arxiv.org/abs/2501.06230)
+ [![GitHub](https://img.shields.io/badge/GitHub-BEN2-black.svg)](https://github.com/PramaLLC/BEN2/)
+ [![Website](https://img.shields.io/badge/Website-backgrounderase.net-104233)](https://backgrounderase.net)
+
+ ## Overview
+ BEN2 (Background Erase Network) approaches foreground segmentation with a Confidence Guided Matting (CGM) pipeline: a refiner network reprocesses the pixels where the base model is least confident, producing more precise and reliable mattes. This model is built on BEN:
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/ben-using-confidence-guided-matting-for/dichotomous-image-segmentation-on-dis-vd)](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-vd?p=ben-using-confidence-guided-matting-for)
+
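+ To make the idea above concrete, here is a purely illustrative sketch of confidence-guided refinement. The function names (`base_model`, `refiner`) and the confidence measure are stand-ins, not the BEN2 implementation.
+
+ ```python
+ import torch
+
+ def confidence_guided_matting(image, base_model, refiner, threshold=0.2):
+     """Illustrative only: refine the matte where the base model is least confident."""
+     alpha = base_model(image)             # coarse alpha matte in [0, 1]
+     confidence = (alpha - 0.5).abs() * 2  # ~1 for confident fg/bg, ~0 near the decision boundary
+     uncertain = confidence < threshold    # low-confidence pixels handed to the refiner
+     refined = refiner(image, alpha)       # refiner reprocesses the hard regions
+     return torch.where(uncertain, refined, alpha)
+ ```
+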
+ ## BEN2 access
+ BEN2 was trained on DIS5k and our 22K proprietary segmentation dataset. Our enhanced model delivers superior performance in hair matting, 4K processing, object segmentation, and edge refinement. Our Base model is open source. To try the full model through our free web demo or to integrate BEN2 into your project with our API:
+ - 🌐 [backgrounderase.net](https://backgrounderase.net)
+
+ ## Contact us
+ - For access to our commercial model, email us at [email protected]
+ - Our website: https://prama.llc/
+ - Follow us on X: https://x.com/PramaResearch/
+
+ ## Installation
+
+ ```
+ pip install -e "git+https://github.com/PramaLLC/BEN2.git#egg=ben2"
+ ```
+
+ ## Quick start code
+
+ ```python
+ from ben2 import BEN_Base
+ from PIL import Image
+ import torch
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ file = "./image.png"  # input image
+
+ model = BEN_Base.from_pretrained("PramaLLC/BEN2")
+ model.to(device).eval()
+
+ image = Image.open(file)
+ foreground = model.inference(image, refine_foreground=False)  # refine_foreground is an extra postprocessing step that increases inference time but can improve matting edges. The default value is False.
+
+ foreground.save("./foreground.png")
+ ```
+
+ ## Batch image processing
+
+ ```python
+ from ben2 import BEN_Base
+ from PIL import Image
+ import torch
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ model = BEN_Base.from_pretrained("PramaLLC/BEN2")
+ model.to(device).eval()
+
+ file1 = "./image1.png"  # input image 1
+ file2 = "./image2.png"  # input image 2
+ image1 = Image.open(file1)
+ image2 = Image.open(file2)
+
+ foregrounds = model.inference([image1, image2])  # We recommend a batch size of at most 3 on consumer GPUs; larger batches give minimal inference gains because of our custom batch processing for the MVANet decoding steps.
+ foregrounds[0].save("./foreground1.png")
+ foregrounds[1].save("./foreground2.png")
+ ```
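+
+ For larger collections, one way to respect the batch-size guidance above is to process a folder in chunks of at most three images. This is only a sketch: the folder paths are placeholders, and it reuses the same list-based `model.inference` call shown above.
+
+ ```python
+ from pathlib import Path
+
+ from ben2 import BEN_Base
+ from PIL import Image
+ import torch
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ model = BEN_Base.from_pretrained("PramaLLC/BEN2")
+ model.to(device).eval()
+
+ input_dir = Path("./images")        # placeholder: folder of input images
+ output_dir = Path("./foregrounds")  # placeholder: folder for the results
+ output_dir.mkdir(exist_ok=True)
+
+ batch_size = 3  # recommended upper bound for consumer GPUs
+ files = sorted(input_dir.glob("*.png"))
+
+ for i in range(0, len(files), batch_size):
+     chunk = files[i:i + batch_size]
+     images = [Image.open(f) for f in chunk]
+     foregrounds = model.inference(images)
+     for f, fg in zip(chunk, foregrounds):
+         fg.save(output_dir / f"{f.stem}_foreground.png")
+ ```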
+
+
+ # BEN2 video segmentation
+ [![BEN2 Demo](https://img.youtube.com/vi/skEXiIHQcys/0.jpg)](https://www.youtube.com/watch?v=skEXiIHQcys)
+
+ ## Video Segmentation
+
+ ```bash
+ sudo apt update
+ sudo apt install ffmpeg
+ ```
+
+ ```python
+ from ben2 import BEN_Base
+ from PIL import Image
+ import torch
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ video_path = "/path_to_your_video.mp4"  # input video
+
+ model = BEN_Base.from_pretrained("PramaLLC/BEN2")
+ model.to(device).eval()
+
+ model.segment_video(
+     video_path=video_path,
+     output_path="./",             # Outputs are saved as foreground.webm or foreground.mp4. The default value is "./".
+     fps=0,                        # 0 lets CV2 detect the fps of the original video. The default value is 0.
+     refine_foreground=False,      # Extra postprocessing step that increases inference time but can improve matting edges. The default value is False.
+     batch=1,                      # We recommend a batch size of at most 3 on consumer GPUs; larger batches give minimal inference gains. The default value is 1.
+     print_frames_processed=True,  # Report which frame is being processed. The default value is True.
+     webm=False,                   # True outputs a .webm video with an alpha layer; False outputs an .mp4. The default value is False.
+     rgb_value=(0, 255, 0)         # Background RGB color used when webm is False. The default is a green background (0, 255, 0).
+ )
+ ```
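+
+ `segment_video` writes `foreground.webm` or `foreground.mp4` into `output_path`, and, as in the `ffprobe` check in `BEN2.py`, the original audio track is only muxed back when the source clip actually has one. A minimal sketch for verifying that the result kept its audio, assuming the default output name:
+
+ ```python
+ import subprocess
+
+ # Same ffprobe invocation BEN2.py uses to detect an audio stream
+ probe = [
+     'ffprobe', '-v', 'error',
+     '-select_streams', 'a:0',
+     '-show_entries', 'stream=index',
+     '-of', 'csv=p=0',
+     './foreground.mp4',  # adjust if you changed output_path or set webm=True
+ ]
+ result = subprocess.run(probe, capture_output=True, text=True)
+ print("audio track present" if result.stdout.strip() else "no audio track")
+ ```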
+
+
+ # BEN2 evaluation
+ ![Model Comparison](BEN2_demo_pictures/model_comparison.png)
+
+ RMBG 2.0 did not preserve the DIS 5k validation dataset:
+
+ ![Example 1](BEN2_demo_pictures/grid_example1.png)
+ ![Example 2](BEN2_demo_pictures/grid_example2.png)
+ ![Example 3](BEN2_demo_pictures/grid_example3.png)
+ ![Example 6](BEN2_demo_pictures/grid_example6.png)
+ ![Example 7](BEN2_demo_pictures/grid_example7.png)
config.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "_name_or_path": "PramaLLC/BEN2",
+   "architectures": ["PramaBEN_Base"],
+   "version": "1.0",
+   "torch_dtype": "float32"
+ }
inference.py ADDED
@@ -0,0 +1,16 @@
+ import BEN2
+ from PIL import Image
+ import torch
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ file = "./image.png"  # input image
+
+ model = BEN2.BEN_Base().to(device).eval()  # init pipeline
+
+ model.loadcheckpoints("./BEN2_Base.pth")
+ image = Image.open(file)
+ foreground = model.inference(image)
+
+ foreground.save("./foreground.png")
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ea8b7907176a09667c86343dc7d00de6a6d871076cb90bb5f753618fd6fb3ebb
+ size 380577976
onnx_run.py ADDED
@@ -0,0 +1,69 @@
+ import onnxruntime
+ import numpy as np
+ from PIL import Image
+ import torchvision.transforms as transforms
+ import torch
+ import torch.nn.functional as F
+
+
+ session = onnxruntime.InferenceSession("./BEN2_Base.onnx", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+
+
+ def postprocess_image(result_np: np.ndarray, im_size: list) -> np.ndarray:
+     result = torch.from_numpy(result_np)
+
+     if len(result.shape) == 3:
+         result = result.unsqueeze(0)
+
+     # Resize back to the original image size and min-max normalize to [0, 255]
+     result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
+
+     ma = torch.max(result)
+     mi = torch.min(result)
+     result = (result - mi) / (ma - mi)
+
+     im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
+     im_array = np.squeeze(im_array)
+     return im_array
+
+
+ def preprocess_image(image):
+     original_size = image.size
+     transform = transforms.Compose([
+         transforms.Resize((1024, 1024)),
+         transforms.ToTensor(),
+     ])
+     img_tensor = transform(image)
+
+     img_tensor = img_tensor.unsqueeze(0)
+     return img_tensor.numpy(), image, original_size
+
+
+ def run_inference(image):
+     input_data, original_image, (w, h) = preprocess_image(image)
+
+     input_name = session.get_inputs()[0].name
+     outputs = session.run(None, {input_name: input_data})
+
+     alpha = postprocess_image(outputs[0], im_size=[w, h])
+
+     mask = Image.fromarray(alpha)
+     mask = mask.resize((w, h))
+
+     original_image.putalpha(mask)
+     return original_image
+
+
+ # Example usage
+ image_path = "image.png"
+ output_path = "output.png"
+
+ image = Image.open(image_path)
+
+ result_image = run_inference(image)
+ result_image.save(output_path)
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ numpy>=1.21.0
+ torch>=1.9.0
+ einops>=0.6.0
+ Pillow>=9.0.0
+ timm>=1.0.10
+ torchvision>=0.10.0
+ onnxruntime