mohammed-aljafry committed on
Commit
ebf512b
·
verified ·
1 Parent(s): 28770b3

Upload model_definition.py with huggingface_hub

Files changed (1)
  1. model_definition.py +1394 -0
model_definition.py ADDED
@@ -0,0 +1,1394 @@
+ # model_definition.py
+ # ============================================================================
+ # Core imports
+ # ============================================================================
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.optim as optim
+ from torch.optim import AdamW
+ from torch.optim.lr_scheduler import OneCycleLR
+ from torch.utils.data import Dataset, DataLoader
+ from torchvision import transforms
+ from functools import partial
+ from typing import Optional, List
+ from torch import Tensor
+ from torch.nn import Parameter  # used by SpatialSoftmax's learnable temperature
+ import os
+ import json
+ import numpy as np
+ import cv2
+ from PIL import Image
+ from collections import deque, OrderedDict
+ import math
+ from timm.models.resnet import resnet50d, resnet26d, resnet18d, resnet101d
+ try:
+ from timm.layers import trunc_normal_
+ except ImportError:
+ from timm.models.layers import trunc_normal_
+ from huggingface_hub import hf_hub_download, HfApi
+ from huggingface_hub.utils import HfFolder
+
+ # Additional utilities
+ import logging
+ import copy
+ from pathlib import Path
+
+ # Optional libraries (can be disabled if unavailable)
+ try:
+ from tqdm import tqdm
+ except ImportError:
+ # If tqdm is unavailable, fall back to a pass-through replacement
+ def tqdm(iterable, *args, **kwargs):
+ return iterable
+
+ # ============================================================================
+ # Helper functions
+ # ============================================================================
+ def to_2tuple(x):
+ """Convert a value into a 2-tuple."""
+ if isinstance(x, (list, tuple)):
+ return tuple(x)
+ return (x, x)
69
+ # ============================================================================
70
+ # ============================================================================
71
+
72
+ class HybridEmbed(nn.Module):
73
+ def __init__(
74
+ self,
75
+ backbone,
76
+ img_size=224,
77
+ patch_size=1,
78
+ feature_size=None,
79
+ in_chans=3,
80
+ embed_dim=768,
81
+ ):
82
+ super().__init__()
83
+ assert isinstance(backbone, nn.Module)
84
+ img_size = to_2tuple(img_size)
85
+ patch_size = to_2tuple(patch_size)
86
+ self.img_size = img_size
87
+ self.patch_size = patch_size
88
+ self.backbone = backbone
89
+ if feature_size is None:
90
+ with torch.no_grad():
91
+ training = backbone.training
92
+ if training:
93
+ backbone.eval()
94
+ o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
95
+ if isinstance(o, (list, tuple)):
96
+ o = o[-1] # last feature if backbone outputs list/tuple of features
97
+ feature_size = o.shape[-2:]
98
+ feature_dim = o.shape[1]
99
+ backbone.train(training)
100
+ else:
101
+ feature_size = to_2tuple(feature_size)
102
+ if hasattr(self.backbone, "feature_info"):
103
+ feature_dim = self.backbone.feature_info.channels()[-1]
104
+ else:
105
+ feature_dim = self.backbone.num_features
106
+
107
+ self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=1, stride=1)
108
+
109
+ def forward(self, x):
110
+ x = self.backbone(x)
111
+ if isinstance(x, (list, tuple)):
112
+ x = x[-1] # last feature if backbone outputs list/tuple of features
113
+ x = self.proj(x)
114
+ global_x = torch.mean(x, [2, 3], keepdim=False)[:, :, None]
115
+ return x, global_x
116
+
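+ # Editor note (hedged usage sketch, not part of the original upload): HybridEmbed wraps a
+ # CNN backbone and projects its last feature map to the transformer embedding size.
+ # Assuming a timm ResNet with features_only=True, as InterfuserModel builds below:
+ #
+ #   backbone = resnet50d(pretrained=False, in_chans=3, features_only=True, out_indices=[4])
+ #   embed = HybridEmbed(backbone, img_size=224, embed_dim=256)
+ #   tokens, global_token = embed(torch.zeros(2, 3, 224, 224))
+ #   # tokens:       (2, 256, 7, 7)  spatial token map (stride-32 feature grid)
+ #   # global_token: (2, 256, 1)     globally averaged token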
117
+
118
+ class PositionEmbeddingSine(nn.Module):
119
+ """
120
+ This is a more standard version of the position embedding, very similar to the one
121
+ used by the Attention is all you need paper, generalized to work on images.
122
+ """
123
+
124
+ def __init__(
125
+ self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
126
+ ):
127
+ super().__init__()
128
+ self.num_pos_feats = num_pos_feats
129
+ self.temperature = temperature
130
+ self.normalize = normalize
131
+ if scale is not None and normalize is False:
132
+ raise ValueError("normalize should be True if scale is passed")
133
+ if scale is None:
134
+ scale = 2 * math.pi
135
+ self.scale = scale
136
+
137
+ def forward(self, tensor):
138
+ x = tensor
139
+ bs, _, h, w = x.shape
140
+ not_mask = torch.ones((bs, h, w), device=x.device)
141
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
142
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
143
+ if self.normalize:
144
+ eps = 1e-6
145
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
146
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
147
+
148
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
149
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
150
+
151
+ pos_x = x_embed[:, :, :, None] / dim_t
152
+ pos_y = y_embed[:, :, :, None] / dim_t
153
+ pos_x = torch.stack(
154
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
155
+ ).flatten(3)
156
+ pos_y = torch.stack(
157
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
158
+ ).flatten(3)
159
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
160
+ return pos
161
+
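+ # Editor note (hedged shape sketch): for an input of shape (B, C, H, W) this module returns
+ # a positional map of shape (B, 2 * num_pos_feats, H, W). InterfuserModel constructs it with
+ # num_pos_feats = embed_dim // 2, so the encoding can be added directly to the token maps.
+ #
+ #   pe = PositionEmbeddingSine(num_pos_feats=128, normalize=True)
+ #   pos = pe(torch.zeros(2, 256, 7, 7))   # -> (2, 256, 7, 7)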
162
+
163
+ class TransformerEncoder(nn.Module):
164
+ def __init__(self, encoder_layer, num_layers, norm=None):
165
+ super().__init__()
166
+ self.layers = _get_clones(encoder_layer, num_layers)
167
+ self.num_layers = num_layers
168
+ self.norm = norm
169
+
170
+ def forward(
171
+ self,
172
+ src,
173
+ mask: Optional[Tensor] = None,
174
+ src_key_padding_mask: Optional[Tensor] = None,
175
+ pos: Optional[Tensor] = None,
176
+ ):
177
+ output = src
178
+
179
+ for layer in self.layers:
180
+ output = layer(
181
+ output,
182
+ src_mask=mask,
183
+ src_key_padding_mask=src_key_padding_mask,
184
+ pos=pos,
185
+ )
186
+
187
+ if self.norm is not None:
188
+ output = self.norm(output)
189
+
190
+ return output
191
+
192
+
193
+ class SpatialSoftmax(nn.Module):
194
+ def __init__(self, height, width, channel, temperature=None, data_format="NCHW"):
195
+ super().__init__()
196
+
197
+ self.data_format = data_format
198
+ self.height = height
199
+ self.width = width
200
+ self.channel = channel
201
+
202
+ if temperature:
203
+ self.temperature = Parameter(torch.ones(1) * temperature)
204
+ else:
205
+ self.temperature = 1.0
206
+
207
+ pos_x, pos_y = np.meshgrid(
208
+ np.linspace(-1.0, 1.0, self.height), np.linspace(-1.0, 1.0, self.width)
209
+ )
210
+ pos_x = torch.from_numpy(pos_x.reshape(self.height * self.width)).float()
211
+ pos_y = torch.from_numpy(pos_y.reshape(self.height * self.width)).float()
212
+ self.register_buffer("pos_x", pos_x)
213
+ self.register_buffer("pos_y", pos_y)
214
+
215
+ def forward(self, feature):
216
+ # Output:
217
+ # (N, C*2) x_0 y_0 ...
218
+
219
+ if self.data_format == "NHWC":
220
+ feature = (
221
+ feature.transpose(1, 3)
222
+ .transpose(2, 3)
223
+ .view(-1, self.height * self.width)
224
+ )
225
+ else:
226
+ feature = feature.view(-1, self.height * self.width)
227
+
228
+ weight = F.softmax(feature / self.temperature, dim=-1)
229
+ expected_x = torch.sum(
230
+ self.pos_x * weight, dim=1, keepdim=True
231
+ )
232
+ expected_y = torch.sum(
233
+ self.pos_y * weight, dim=1, keepdim=True
234
+ )
235
+ expected_xy = torch.cat([expected_x, expected_y], 1)
236
+ feature_keypoints = expected_xy.view(-1, self.channel, 2)
237
+ feature_keypoints[:, :, 1] = (feature_keypoints[:, :, 1] - 1) * 12
238
+ feature_keypoints[:, :, 0] = feature_keypoints[:, :, 0] * 12
239
+ return feature_keypoints
240
+
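+ # Editor note (hedged shape sketch): SpatialSoftmax computes a soft-argmax over each channel,
+ # so a (N, C, H, W) activation map yields (N, C, 2) expected (x, y) coordinates; the trailing
+ # scaling by 12 (plus the y offset) maps the normalized [-1, 1] range into the coordinate
+ # frame used by the heatmap waypoint head (MultiPath_Generator below).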
241
+
242
+ class MultiPath_Generator(nn.Module):
243
+ def __init__(self, in_channel, embed_dim, out_channel):
244
+ super().__init__()
245
+ self.spatial_softmax = SpatialSoftmax(100, 100, out_channel)
246
+ self.tconv0 = nn.Sequential(
247
+ nn.ConvTranspose2d(in_channel, 256, 4, 2, 1, bias=False),
248
+ nn.BatchNorm2d(256),
249
+ nn.ReLU(True),
250
+ )
251
+ self.tconv1 = nn.Sequential(
252
+ nn.ConvTranspose2d(256, 256, 4, 2, 1, bias=False),
253
+ nn.BatchNorm2d(256),
254
+ nn.ReLU(True),
255
+ )
256
+ self.tconv2 = nn.Sequential(
257
+ nn.ConvTranspose2d(256, 192, 4, 2, 1, bias=False),
258
+ nn.BatchNorm2d(192),
259
+ nn.ReLU(True),
260
+ )
261
+ self.tconv3 = nn.Sequential(
262
+ nn.ConvTranspose2d(192, 64, 4, 2, 1, bias=False),
263
+ nn.BatchNorm2d(64),
264
+ nn.ReLU(True),
265
+ )
266
+ self.tconv4_list = torch.nn.ModuleList(
267
+ [
268
+ nn.Sequential(
269
+ nn.ConvTranspose2d(64, out_channel, 8, 2, 3, bias=False),
270
+ nn.Tanh(),
271
+ )
272
+ for _ in range(6)
273
+ ]
274
+ )
275
+
276
+ self.upsample = nn.Upsample(size=(50, 50), mode="bilinear")
277
+
278
+ def forward(self, x, measurements):
279
+ mask = measurements[:, :6]
280
+ mask = mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 1, 100, 100)
281
+ velocity = measurements[:, 6:7].unsqueeze(-1).unsqueeze(-1)
282
+ velocity = velocity.repeat(1, 32, 2, 2)
283
+
284
+ n, d, c = x.shape
285
+ x = x.transpose(1, 2)
286
+ x = x.view(n, -1, 2, 2)
287
+ x = torch.cat([x, velocity], dim=1)
288
+ x = self.tconv0(x)
289
+ x = self.tconv1(x)
290
+ x = self.tconv2(x)
291
+ x = self.tconv3(x)
292
+ x = self.upsample(x)
293
+ xs = []
294
+ for i in range(6):
295
+ xt = self.tconv4_list[i](x)
296
+ xs.append(xt)
297
+ xs = torch.stack(xs, dim=1)
298
+ x = torch.sum(xs * mask, dim=1)
299
+ x = self.spatial_softmax(x)
300
+ return x
301
+
302
+
303
+ class LinearWaypointsPredictor(nn.Module):
304
+ def __init__(self, input_dim, cumsum=True):
305
+ super().__init__()
306
+ self.cumsum = cumsum
307
+ self.rank_embed = nn.Parameter(torch.zeros(1, 10, input_dim))
308
+ self.head_fc1_list = nn.ModuleList([nn.Linear(input_dim, 64) for _ in range(6)])
309
+ self.head_relu = nn.ReLU(inplace=True)
310
+ self.head_fc2_list = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)])
311
+
312
+ def forward(self, x, measurements):
313
+ # input shape: n 10 embed_dim
314
+ bs, n, dim = x.shape
315
+ x = x + self.rank_embed
316
+ x = x.reshape(-1, dim)
317
+
318
+ mask = measurements[:, :6]
319
+ mask = torch.unsqueeze(mask, -1).repeat(n, 1, 2)
320
+
321
+ rs = []
322
+ for i in range(6):
323
+ res = self.head_fc1_list[i](x)
324
+ res = self.head_relu(res)
325
+ res = self.head_fc2_list[i](res)
326
+ rs.append(res)
327
+ rs = torch.stack(rs, 1)
328
+ x = torch.sum(rs * mask, dim=1)
329
+
330
+ x = x.view(bs, n, 2)
331
+ if self.cumsum:
332
+ x = torch.cumsum(x, 1)
333
+ return x
334
+
335
+
336
+ class GRUWaypointsPredictor(nn.Module):
337
+ def __init__(self, input_dim, waypoints=10):
338
+ super().__init__()
339
+ # self.gru = torch.nn.GRUCell(input_size=input_dim, hidden_size=64)
340
+ self.gru = torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True)
341
+ self.encoder = nn.Linear(2, 64)
342
+ self.decoder = nn.Linear(64, 2)
343
+ self.waypoints = waypoints
344
+
345
+ def forward(self, x, target_point):
346
+ bs = x.shape[0]
347
+ z = self.encoder(target_point).unsqueeze(0)
348
+ output, _ = self.gru(x, z)
349
+ output = output.reshape(bs * self.waypoints, -1)
350
+ output = self.decoder(output).reshape(bs, self.waypoints, 2)
351
+ output = torch.cumsum(output, 1)
352
+ return output
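+
+ # Editor note (hedged usage sketch): x is expected as (B, waypoints, input_dim); the target
+ # point (B, 2) is encoded into the initial GRU hidden state, and the decoder emits per-step
+ # (dx, dy) offsets that are cumulatively summed into absolute waypoints of shape (B, waypoints, 2).
+ #
+ #   wp_head = GRUWaypointsPredictor(input_dim=256, waypoints=10)
+ #   wp = wp_head(torch.zeros(2, 10, 256), torch.zeros(2, 2))   # -> (2, 10, 2)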
353
+
354
+ class GRUWaypointsPredictorWithCommand(nn.Module):
355
+ def __init__(self, input_dim, waypoints=10):
356
+ super().__init__()
357
+ # self.gru = torch.nn.GRUCell(input_size=input_dim, hidden_size=64)
358
+ self.grus = nn.ModuleList([torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True) for _ in range(6)])
359
+ self.encoder = nn.Linear(2, 64)
360
+ self.decoders = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)])
361
+ self.waypoints = waypoints
362
+
363
+ def forward(self, x, target_point, measurements):
364
+ bs, n, dim = x.shape
365
+ mask = measurements[:, :6, None, None]
366
+ mask = mask.repeat(1, 1, self.waypoints, 2)
367
+
368
+ z = self.encoder(target_point).unsqueeze(0)
369
+ outputs = []
370
+ for i in range(6):
371
+ output, _ = self.grus[i](x, z)
372
+ output = output.reshape(bs * self.waypoints, -1)
373
+ output = self.decoders[i](output).reshape(bs, self.waypoints, 2)
374
+ output = torch.cumsum(output, 1)
375
+ outputs.append(output)
376
+ outputs = torch.stack(outputs, 1)
377
+ output = torch.sum(outputs * mask, dim=1)
378
+ return output
379
+
380
+
381
+ class TransformerDecoder(nn.Module):
382
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
383
+ super().__init__()
384
+ self.layers = _get_clones(decoder_layer, num_layers)
385
+ self.num_layers = num_layers
386
+ self.norm = norm
387
+ self.return_intermediate = return_intermediate
388
+
389
+ def forward(
390
+ self,
391
+ tgt,
392
+ memory,
393
+ tgt_mask: Optional[Tensor] = None,
394
+ memory_mask: Optional[Tensor] = None,
395
+ tgt_key_padding_mask: Optional[Tensor] = None,
396
+ memory_key_padding_mask: Optional[Tensor] = None,
397
+ pos: Optional[Tensor] = None,
398
+ query_pos: Optional[Tensor] = None,
399
+ ):
400
+ output = tgt
401
+
402
+ intermediate = []
403
+
404
+ for layer in self.layers:
405
+ output = layer(
406
+ output,
407
+ memory,
408
+ tgt_mask=tgt_mask,
409
+ memory_mask=memory_mask,
410
+ tgt_key_padding_mask=tgt_key_padding_mask,
411
+ memory_key_padding_mask=memory_key_padding_mask,
412
+ pos=pos,
413
+ query_pos=query_pos,
414
+ )
415
+ if self.return_intermediate:
416
+ intermediate.append(self.norm(output))
417
+
418
+ if self.norm is not None:
419
+ output = self.norm(output)
420
+ if self.return_intermediate:
421
+ intermediate.pop()
422
+ intermediate.append(output)
423
+
424
+ if self.return_intermediate:
425
+ return torch.stack(intermediate)
426
+
427
+ return output.unsqueeze(0)
428
+
429
+
430
+ class TransformerEncoderLayer(nn.Module):
431
+ def __init__(
432
+ self,
433
+ d_model,
434
+ nhead,
435
+ dim_feedforward=2048,
436
+ dropout=0.1,
437
+ activation=nn.ReLU,  # activation class (instantiated in __init__), not an instance
438
+ normalize_before=False,
439
+ ):
440
+ super().__init__()
441
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
442
+ # Implementation of Feedforward model
443
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
444
+ self.dropout = nn.Dropout(dropout)
445
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
446
+
447
+ self.norm1 = nn.LayerNorm(d_model)
448
+ self.norm2 = nn.LayerNorm(d_model)
449
+ self.dropout1 = nn.Dropout(dropout)
450
+ self.dropout2 = nn.Dropout(dropout)
451
+
452
+ self.activation = activation()
453
+ self.normalize_before = normalize_before
454
+
455
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
456
+ return tensor if pos is None else tensor + pos
457
+
458
+ def forward_post(
459
+ self,
460
+ src,
461
+ src_mask: Optional[Tensor] = None,
462
+ src_key_padding_mask: Optional[Tensor] = None,
463
+ pos: Optional[Tensor] = None,
464
+ ):
465
+ q = k = self.with_pos_embed(src, pos)
466
+ src2 = self.self_attn(
467
+ q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
468
+ )[0]
469
+ src = src + self.dropout1(src2)
470
+ src = self.norm1(src)
471
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
472
+ src = src + self.dropout2(src2)
473
+ src = self.norm2(src)
474
+ return src
475
+
476
+ def forward_pre(
477
+ self,
478
+ src,
479
+ src_mask: Optional[Tensor] = None,
480
+ src_key_padding_mask: Optional[Tensor] = None,
481
+ pos: Optional[Tensor] = None,
482
+ ):
483
+ src2 = self.norm1(src)
484
+ q = k = self.with_pos_embed(src2, pos)
485
+ src2 = self.self_attn(
486
+ q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
487
+ )[0]
488
+ src = src + self.dropout1(src2)
489
+ src2 = self.norm2(src)
490
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
491
+ src = src + self.dropout2(src2)
492
+ return src
493
+
494
+ def forward(
495
+ self,
496
+ src,
497
+ src_mask: Optional[Tensor] = None,
498
+ src_key_padding_mask: Optional[Tensor] = None,
499
+ pos: Optional[Tensor] = None,
500
+ ):
501
+ if self.normalize_before:
502
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
503
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
504
+
505
+
506
+ class TransformerDecoderLayer(nn.Module):
507
+ def __init__(
508
+ self,
509
+ d_model,
510
+ nhead,
511
+ dim_feedforward=2048,
512
+ dropout=0.1,
513
+ activation=nn.ReLU,  # activation class (instantiated in __init__), not an instance
514
+ normalize_before=False,
515
+ ):
516
+ super().__init__()
517
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
518
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
519
+ # Implementation of Feedforward model
520
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
521
+ self.dropout = nn.Dropout(dropout)
522
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
523
+
524
+ self.norm1 = nn.LayerNorm(d_model)
525
+ self.norm2 = nn.LayerNorm(d_model)
526
+ self.norm3 = nn.LayerNorm(d_model)
527
+ self.dropout1 = nn.Dropout(dropout)
528
+ self.dropout2 = nn.Dropout(dropout)
529
+ self.dropout3 = nn.Dropout(dropout)
530
+
531
+ self.activation = activation()
532
+ self.normalize_before = normalize_before
533
+
534
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
535
+ return tensor if pos is None else tensor + pos
536
+
537
+ def forward_post(
538
+ self,
539
+ tgt,
540
+ memory,
541
+ tgt_mask: Optional[Tensor] = None,
542
+ memory_mask: Optional[Tensor] = None,
543
+ tgt_key_padding_mask: Optional[Tensor] = None,
544
+ memory_key_padding_mask: Optional[Tensor] = None,
545
+ pos: Optional[Tensor] = None,
546
+ query_pos: Optional[Tensor] = None,
547
+ ):
548
+ q = k = self.with_pos_embed(tgt, query_pos)
549
+ tgt2 = self.self_attn(
550
+ q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
551
+ )[0]
552
+ tgt = tgt + self.dropout1(tgt2)
553
+ tgt = self.norm1(tgt)
554
+ tgt2 = self.multihead_attn(
555
+ query=self.with_pos_embed(tgt, query_pos),
556
+ key=self.with_pos_embed(memory, pos),
557
+ value=memory,
558
+ attn_mask=memory_mask,
559
+ key_padding_mask=memory_key_padding_mask,
560
+ )[0]
561
+ tgt = tgt + self.dropout2(tgt2)
562
+ tgt = self.norm2(tgt)
563
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
564
+ tgt = tgt + self.dropout3(tgt2)
565
+ tgt = self.norm3(tgt)
566
+ return tgt
567
+
568
+ def forward_pre(
569
+ self,
570
+ tgt,
571
+ memory,
572
+ tgt_mask: Optional[Tensor] = None,
573
+ memory_mask: Optional[Tensor] = None,
574
+ tgt_key_padding_mask: Optional[Tensor] = None,
575
+ memory_key_padding_mask: Optional[Tensor] = None,
576
+ pos: Optional[Tensor] = None,
577
+ query_pos: Optional[Tensor] = None,
578
+ ):
579
+ tgt2 = self.norm1(tgt)
580
+ q = k = self.with_pos_embed(tgt2, query_pos)
581
+ tgt2 = self.self_attn(
582
+ q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
583
+ )[0]
584
+ tgt = tgt + self.dropout1(tgt2)
585
+ tgt2 = self.norm2(tgt)
586
+ tgt2 = self.multihead_attn(
587
+ query=self.with_pos_embed(tgt2, query_pos),
588
+ key=self.with_pos_embed(memory, pos),
589
+ value=memory,
590
+ attn_mask=memory_mask,
591
+ key_padding_mask=memory_key_padding_mask,
592
+ )[0]
593
+ tgt = tgt + self.dropout2(tgt2)
594
+ tgt2 = self.norm3(tgt)
595
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
596
+ tgt = tgt + self.dropout3(tgt2)
597
+ return tgt
598
+
599
+ def forward(
600
+ self,
601
+ tgt,
602
+ memory,
603
+ tgt_mask: Optional[Tensor] = None,
604
+ memory_mask: Optional[Tensor] = None,
605
+ tgt_key_padding_mask: Optional[Tensor] = None,
606
+ memory_key_padding_mask: Optional[Tensor] = None,
607
+ pos: Optional[Tensor] = None,
608
+ query_pos: Optional[Tensor] = None,
609
+ ):
610
+ if self.normalize_before:
611
+ return self.forward_pre(
612
+ tgt,
613
+ memory,
614
+ tgt_mask,
615
+ memory_mask,
616
+ tgt_key_padding_mask,
617
+ memory_key_padding_mask,
618
+ pos,
619
+ query_pos,
620
+ )
621
+ return self.forward_post(
622
+ tgt,
623
+ memory,
624
+ tgt_mask,
625
+ memory_mask,
626
+ tgt_key_padding_mask,
627
+ memory_key_padding_mask,
628
+ pos,
629
+ query_pos,
630
+ )
631
+
632
+
633
+ def _get_clones(module, N):
634
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
635
+
636
+
637
+ def _get_activation_fn(activation):
638
+ """Return an activation function given a string"""
639
+ if activation == "relu":
640
+ return F.relu
641
+ if activation == "gelu":
642
+ return F.gelu
643
+ if activation == "glu":
644
+ return F.glu
645
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
646
+
647
+
648
+ def build_attn_mask(mask_type):
649
+ mask = torch.ones((151, 151), dtype=torch.bool).cuda()
650
+ if mask_type == "seperate_all":
651
+ mask[:50, :50] = False
652
+ mask[50:67, 50:67] = False
653
+ mask[67:84, 67:84] = False
654
+ mask[84:101, 84:101] = False
655
+ mask[101:151, 101:151] = False
656
+ elif mask_type == "seperate_view":
657
+ mask[:50, :50] = False
658
+ mask[50:67, 50:67] = False
659
+ mask[67:84, 67:84] = False
660
+ mask[84:101, 84:101] = False
661
+ mask[101:151, :] = False
662
+ mask[:, 101:151] = False
663
+ return mask
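+
+ # Editor note (hedged): the slicing above partitions the 151 attention positions into blocks
+ # of 50 / 17 / 17 / 17 / 50, which matches the token counts produced by forward_features
+ # (49 + 1 global tokens for a 224-px view, 16 + 1 for a 112-px side view, 49 + 1 for lidar)
+ # under one plausible reading. "seperate_all" restricts attention to within each block, while
+ # "seperate_view" also lets the last 50 positions attend globally. The mask is created with
+ # .cuda(), so these masked variants assume a GPU is available.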
665
+
666
+ class InterfuserModel(nn.Module):
667
+ def __init__(
668
+ self,
669
+ img_size=224,
670
+ multi_view_img_size=112,
671
+ patch_size=8,
672
+ in_chans=3,
673
+ embed_dim=768,
674
+ enc_depth=6,
675
+ dec_depth=6,
676
+ dim_feedforward=2048,
677
+ normalize_before=False,
678
+ rgb_backbone_name="r50",
679
+ lidar_backbone_name="r50",
680
+ num_heads=8,
681
+ norm_layer=None,
682
+ dropout=0.1,
683
+ end2end=False,
684
+ direct_concat=False,
685
+ separate_view_attention=False,
686
+ separate_all_attention=False,
687
+ act_layer=None,
688
+ weight_init="",
689
+ freeze_num=-1,
690
+ with_lidar=False,
691
+ with_right_left_sensors=False,
692
+ with_center_sensor=False,
693
+ traffic_pred_head_type="det",
694
+ waypoints_pred_head="heatmap",
695
+ reverse_pos=True,
696
+ use_different_backbone=False,
697
+ use_view_embed=False,
698
+ use_mmad_pretrain=None,
699
+ ):
700
+ super().__init__()
701
+ self.traffic_pred_head_type = traffic_pred_head_type
702
+ self.num_features = (
703
+ self.embed_dim
704
+ ) = embed_dim # num_features for consistency with other models
705
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
706
+ act_layer = act_layer or nn.GELU
707
+
708
+ self.reverse_pos = reverse_pos
709
+ self.waypoints_pred_head = waypoints_pred_head
710
+ self.with_lidar = with_lidar
711
+ self.with_right_left_sensors = with_right_left_sensors
712
+ self.with_center_sensor = with_center_sensor
713
+
714
+ self.direct_concat = direct_concat
715
+ self.separate_view_attention = separate_view_attention
716
+ self.separate_all_attention = separate_all_attention
717
+ self.end2end = end2end
718
+ self.use_view_embed = use_view_embed
719
+
720
+ if self.direct_concat:
721
+ in_chans = in_chans * 4
722
+ self.with_center_sensor = False
723
+ self.with_right_left_sensors = False
724
+
725
+ if self.separate_view_attention:
726
+ self.attn_mask = build_attn_mask("seperate_view")
727
+ elif self.separate_all_attention:
728
+ self.attn_mask = build_attn_mask("seperate_all")
729
+ else:
730
+ self.attn_mask = None
731
+
732
+ if use_different_backbone:
733
+ if rgb_backbone_name == "r50":
734
+ self.rgb_backbone = resnet50d(
735
+ pretrained=True,
736
+ in_chans=in_chans,
737
+ features_only=True,
738
+ out_indices=[4],
739
+ )
740
+ elif rgb_backbone_name == "r26":
741
+ self.rgb_backbone = resnet26d(
742
+ pretrained=True,
743
+ in_chans=in_chans,
744
+ features_only=True,
745
+ out_indices=[4],
746
+ )
747
+ elif rgb_backbone_name == "r18":
748
+ self.rgb_backbone = resnet18d(
749
+ pretrained=True,
750
+ in_chans=in_chans,
751
+ features_only=True,
752
+ out_indices=[4],
753
+ )
754
+ if lidar_backbone_name == "r50":
755
+ self.lidar_backbone = resnet50d(
756
+ pretrained=False,
757
+ in_chans=in_chans,
758
+ features_only=True,
759
+ out_indices=[4],
760
+ )
761
+ elif lidar_backbone_name == "r26":
762
+ self.lidar_backbone = resnet26d(
763
+ pretrained=False,
764
+ in_chans=in_chans,
765
+ features_only=True,
766
+ out_indices=[4],
767
+ )
768
+ elif lidar_backbone_name == "r18":
769
+ self.lidar_backbone = resnet18d(
770
+ pretrained=False, in_chans=3, features_only=True, out_indices=[4]
771
+ )
772
+ rgb_embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone)
773
+ lidar_embed_layer = partial(HybridEmbed, backbone=self.lidar_backbone)
774
+
775
+ if use_mmad_pretrain:
776
+ params = torch.load(use_mmad_pretrain)["state_dict"]
777
+ updated_params = OrderedDict()
778
+ for key in params:
779
+ if "backbone" in key:
780
+ updated_params[key.replace("backbone.", "")] = params[key]
781
+ self.rgb_backbone.load_state_dict(updated_params)
782
+
783
+ self.rgb_patch_embed = rgb_embed_layer(
784
+ img_size=img_size,
785
+ patch_size=patch_size,
786
+ in_chans=in_chans,
787
+ embed_dim=embed_dim,
788
+ )
789
+ self.lidar_patch_embed = lidar_embed_layer(
790
+ img_size=img_size,
791
+ patch_size=patch_size,
792
+ in_chans=3,
793
+ embed_dim=embed_dim,
794
+ )
795
+ else:
796
+ if rgb_backbone_name == "r50":
797
+ self.rgb_backbone = resnet50d(
798
+ pretrained=True, in_chans=3, features_only=True, out_indices=[4]
799
+ )
800
+ elif rgb_backbone_name == "r101":
801
+ self.rgb_backbone = resnet101d(
802
+ pretrained=True, in_chans=3, features_only=True, out_indices=[4]
803
+ )
804
+ elif rgb_backbone_name == "r26":
805
+ self.rgb_backbone = resnet26d(
806
+ pretrained=True, in_chans=3, features_only=True, out_indices=[4]
807
+ )
808
+ elif rgb_backbone_name == "r18":
809
+ self.rgb_backbone = resnet18d(
810
+ pretrained=True, in_chans=3, features_only=True, out_indices=[4]
811
+ )
812
+ embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone)
813
+
814
+ self.rgb_patch_embed = embed_layer(
815
+ img_size=img_size,
816
+ patch_size=patch_size,
817
+ in_chans=in_chans,
818
+ embed_dim=embed_dim,
819
+ )
820
+ self.lidar_patch_embed = embed_layer(
821
+ img_size=img_size,
822
+ patch_size=patch_size,
823
+ in_chans=in_chans,
824
+ embed_dim=embed_dim,
825
+ )
826
+
827
+ self.global_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
828
+ self.view_embed = nn.Parameter(torch.zeros(1, embed_dim, 5, 1))
829
+
830
+ if self.end2end:
831
+ self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 4))
832
+ self.query_embed = nn.Parameter(torch.zeros(4, 1, embed_dim))
833
+ elif self.waypoints_pred_head == "heatmap":
834
+ self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
835
+ self.query_embed = nn.Parameter(torch.zeros(400 + 5, 1, embed_dim))
836
+ else:
837
+ self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 11))
838
+ self.query_embed = nn.Parameter(torch.zeros(400 + 11, 1, embed_dim))
839
+
840
+ if self.end2end:
841
+ self.waypoints_generator = GRUWaypointsPredictor(embed_dim, 4)
842
+ elif self.waypoints_pred_head == "heatmap":
843
+ self.waypoints_generator = MultiPath_Generator(
844
+ embed_dim + 32, embed_dim, 10
845
+ )
846
+ elif self.waypoints_pred_head == "gru":
847
+ self.waypoints_generator = GRUWaypointsPredictor(embed_dim)
848
+ elif self.waypoints_pred_head == "gru-command":
849
+ self.waypoints_generator = GRUWaypointsPredictorWithCommand(embed_dim)
850
+ elif self.waypoints_pred_head == "linear":
851
+ self.waypoints_generator = LinearWaypointsPredictor(embed_dim)
852
+ elif self.waypoints_pred_head == "linear-sum":
853
+ self.waypoints_generator = LinearWaypointsPredictor(embed_dim, cumsum=True)
854
+
855
+ self.junction_pred_head = nn.Linear(embed_dim, 2)
856
+ self.traffic_light_pred_head = nn.Linear(embed_dim, 2)
857
+ self.stop_sign_head = nn.Linear(embed_dim, 2)
858
+
859
+ if self.traffic_pred_head_type == "det":
860
+ self.traffic_pred_head = nn.Sequential(
861
+ *[
862
+ nn.Linear(embed_dim + 32, 64),
863
+ nn.ReLU(),
864
+ nn.Linear(64, 7),
865
+ # nn.Sigmoid(),
866
+ ]
867
+ )
868
+ elif self.traffic_pred_head_type == "seg":
869
+ self.traffic_pred_head = nn.Sequential(
870
+ *[nn.Linear(embed_dim, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid()]
871
+ )
872
+
873
+ self.position_encoding = PositionEmbeddingSine(embed_dim // 2, normalize=True)
874
+
875
+ encoder_layer = TransformerEncoderLayer(
876
+ embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before
877
+ )
878
+ self.encoder = TransformerEncoder(encoder_layer, enc_depth, None)
879
+
880
+ decoder_layer = TransformerDecoderLayer(
881
+ embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before
882
+ )
883
+ decoder_norm = nn.LayerNorm(embed_dim)
884
+ self.decoder = TransformerDecoder(
885
+ decoder_layer, dec_depth, decoder_norm, return_intermediate=False
886
+ )
887
+ self.reset_parameters()
888
+
889
+ def reset_parameters(self):
890
+ nn.init.uniform_(self.global_embed)
891
+ nn.init.uniform_(self.view_embed)
892
+ nn.init.uniform_(self.query_embed)
893
+ nn.init.uniform_(self.query_pos_embed)
894
+
895
+ def forward_features(
896
+ self,
897
+ front_image,
898
+ left_image,
899
+ right_image,
900
+ front_center_image,
901
+ lidar,
902
+ measurements,
903
+ ):
904
+ features = []
905
+
906
+ # Front view processing
907
+ front_image_token, front_image_token_global = self.rgb_patch_embed(front_image)
908
+ if self.use_view_embed:
909
+ front_image_token = (
910
+ front_image_token
911
+ + self.view_embed[:, :, 0:1, :]
912
+ + self.position_encoding(front_image_token)
913
+ )
914
+ else:
915
+ front_image_token = front_image_token + self.position_encoding(
916
+ front_image_token
917
+ )
918
+ front_image_token = front_image_token.flatten(2).permute(2, 0, 1)
919
+ front_image_token_global = (
920
+ front_image_token_global
921
+ + self.view_embed[:, :, 0, :]
922
+ + self.global_embed[:, :, 0:1]
923
+ )
924
+ front_image_token_global = front_image_token_global.permute(2, 0, 1)
925
+ features.extend([front_image_token, front_image_token_global])
926
+
927
+ if self.with_right_left_sensors:
928
+ # Left view processing
929
+ left_image_token, left_image_token_global = self.rgb_patch_embed(left_image)
930
+ if self.use_view_embed:
931
+ left_image_token = (
932
+ left_image_token
933
+ + self.view_embed[:, :, 1:2, :]
934
+ + self.position_encoding(left_image_token)
935
+ )
936
+ else:
937
+ left_image_token = left_image_token + self.position_encoding(
938
+ left_image_token
939
+ )
940
+ left_image_token = left_image_token.flatten(2).permute(2, 0, 1)
941
+ left_image_token_global = (
942
+ left_image_token_global
943
+ + self.view_embed[:, :, 1, :]
944
+ + self.global_embed[:, :, 1:2]
945
+ )
946
+ left_image_token_global = left_image_token_global.permute(2, 0, 1)
947
+
948
+ # Right view processing
949
+ right_image_token, right_image_token_global = self.rgb_patch_embed(
950
+ right_image
951
+ )
952
+ if self.use_view_embed:
953
+ right_image_token = (
954
+ right_image_token
955
+ + self.view_embed[:, :, 2:3, :]
956
+ + self.position_encoding(right_image_token)
957
+ )
958
+ else:
959
+ right_image_token = right_image_token + self.position_encoding(
960
+ right_image_token
961
+ )
962
+ right_image_token = right_image_token.flatten(2).permute(2, 0, 1)
963
+ right_image_token_global = (
964
+ right_image_token_global
965
+ + self.view_embed[:, :, 2, :]
966
+ + self.global_embed[:, :, 2:3]
967
+ )
968
+ right_image_token_global = right_image_token_global.permute(2, 0, 1)
969
+
970
+ features.extend(
971
+ [
972
+ left_image_token,
973
+ left_image_token_global,
974
+ right_image_token,
975
+ right_image_token_global,
976
+ ]
977
+ )
978
+
979
+ if self.with_center_sensor:
980
+ # Front center view processing
981
+ (
982
+ front_center_image_token,
983
+ front_center_image_token_global,
984
+ ) = self.rgb_patch_embed(front_center_image)
985
+ if self.use_view_embed:
986
+ front_center_image_token = (
987
+ front_center_image_token
988
+ + self.view_embed[:, :, 3:4, :]
989
+ + self.position_encoding(front_center_image_token)
990
+ )
991
+ else:
992
+ front_center_image_token = (
993
+ front_center_image_token
994
+ + self.position_encoding(front_center_image_token)
995
+ )
996
+
997
+ front_center_image_token = front_center_image_token.flatten(2).permute(
998
+ 2, 0, 1
999
+ )
1000
+ front_center_image_token_global = (
1001
+ front_center_image_token_global
1002
+ + self.view_embed[:, :, 3, :]
1003
+ + self.global_embed[:, :, 3:4]
1004
+ )
1005
+ front_center_image_token_global = front_center_image_token_global.permute(
1006
+ 2, 0, 1
1007
+ )
1008
+ features.extend([front_center_image_token, front_center_image_token_global])
1009
+
1010
+ if self.with_lidar:
1011
+ lidar_token, lidar_token_global = self.lidar_patch_embed(lidar)
1012
+ if self.use_view_embed:
1013
+ lidar_token = (
1014
+ lidar_token
1015
+ + self.view_embed[:, :, 4:5, :]
1016
+ + self.position_encoding(lidar_token)
1017
+ )
1018
+ else:
1019
+ lidar_token = lidar_token + self.position_encoding(lidar_token)
1020
+ lidar_token = lidar_token.flatten(2).permute(2, 0, 1)
1021
+ lidar_token_global = (
1022
+ lidar_token_global
1023
+ + self.view_embed[:, :, 4, :]
1024
+ + self.global_embed[:, :, 4:5]
1025
+ )
1026
+ lidar_token_global = lidar_token_global.permute(2, 0, 1)
1027
+ features.extend([lidar_token, lidar_token_global])
1028
+
1029
+ features = torch.cat(features, 0)
1030
+ return features
1031
+
1032
+ def forward(self, x):
1033
+ front_image = x["rgb"]
1034
+ left_image = x["rgb_left"]
1035
+ right_image = x["rgb_right"]
1036
+ front_center_image = x["rgb_center"]
1037
+ measurements = x["measurements"]
1038
+ target_point = x["target_point"]
1039
+ lidar = x["lidar"]
1040
+
1041
+ if self.direct_concat:
1042
+ img_size = front_image.shape[-1]
1043
+ left_image = torch.nn.functional.interpolate(
1044
+ left_image, size=(img_size, img_size)
1045
+ )
1046
+ right_image = torch.nn.functional.interpolate(
1047
+ right_image, size=(img_size, img_size)
1048
+ )
1049
+ front_center_image = torch.nn.functional.interpolate(
1050
+ front_center_image, size=(img_size, img_size)
1051
+ )
1052
+ front_image = torch.cat(
1053
+ [front_image, left_image, right_image, front_center_image], dim=1
1054
+ )
1055
+ features = self.forward_features(
1056
+ front_image,
1057
+ left_image,
1058
+ right_image,
1059
+ front_center_image,
1060
+ lidar,
1061
+ measurements,
1062
+ )
1063
+
1064
+ bs = front_image.shape[0]
1065
+
1066
+ if self.end2end:
1067
+ tgt = self.query_pos_embed.repeat(bs, 1, 1)
1068
+ else:
1069
+ tgt = self.position_encoding(
1070
+ torch.ones((bs, 1, 20, 20), device=x["rgb"].device)
1071
+ )
1072
+ tgt = tgt.flatten(2)
1073
+ tgt = torch.cat([tgt, self.query_pos_embed.repeat(bs, 1, 1)], 2)
1074
+ tgt = tgt.permute(2, 0, 1)
1075
+
1076
+ memory = self.encoder(features, mask=self.attn_mask)
1077
+ hs = self.decoder(self.query_embed.repeat(1, bs, 1), memory, query_pos=tgt)[0]
1078
+
1079
+ hs = hs.permute(1, 0, 2) # Batchsize , N, C
1080
+ if self.end2end:
1081
+ waypoints = self.waypoints_generator(hs, target_point)
1082
+ return waypoints
1083
+
1084
+ if self.waypoints_pred_head != "heatmap":
1085
+ traffic_feature = hs[:, :400]
1086
+ is_junction_feature = hs[:, 400]
1087
+ traffic_light_state_feature = hs[:, 400]
1088
+ stop_sign_feature = hs[:, 400]
1089
+ waypoints_feature = hs[:, 401:411]
1090
+ else:
1091
+ traffic_feature = hs[:, :400]
1092
+ is_junction_feature = hs[:, 400]
1093
+ traffic_light_state_feature = hs[:, 400]
1094
+ stop_sign_feature = hs[:, 400]
1095
+ waypoints_feature = hs[:, 401:405]
1096
+
1097
+ if self.waypoints_pred_head == "heatmap":
1098
+ waypoints = self.waypoints_generator(waypoints_feature, measurements)
1099
+ elif self.waypoints_pred_head == "gru":
1100
+ waypoints = self.waypoints_generator(waypoints_feature, target_point)
1101
+ elif self.waypoints_pred_head == "gru-command":
1102
+ waypoints = self.waypoints_generator(waypoints_feature, target_point, measurements)
1103
+ elif self.waypoints_pred_head == "linear":
1104
+ waypoints = self.waypoints_generator(waypoints_feature, measurements)
1105
+ elif self.waypoints_pred_head == "linear-sum":
1106
+ waypoints = self.waypoints_generator(waypoints_feature, measurements)
1107
+
1108
+ is_junction = self.junction_pred_head(is_junction_feature)
1109
+ traffic_light_state = self.traffic_light_pred_head(traffic_light_state_feature)
1110
+ stop_sign = self.stop_sign_head(stop_sign_feature)
1111
+
1112
+ velocity = measurements[:, 6:7].unsqueeze(-1)
1113
+ velocity = velocity.repeat(1, 400, 32)
1114
+ traffic_feature_with_vel = torch.cat([traffic_feature, velocity], dim=2)
1115
+ traffic = self.traffic_pred_head(traffic_feature_with_vel)
1116
+ return traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature
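+
+ # Editor note (hedged shape summary), assuming the default config below (embed_dim=256,
+ # waypoints_pred_head='gru', traffic_pred_head_type='det'): traffic is (B, 400, 7) detection
+ # outputs over the 20x20 query grid, waypoints is (B, 10, 2), is_junction / traffic_light_state /
+ # stop_sign are (B, 2) logits, and traffic_feature is (B, 400, embed_dim).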
1117
+
+ def load_pretrained(self, model_path, strict=False):
+ """
+ Load the pretrained weights - improved version.
+
+ Args:
+ model_path (str): path to the weights file
+ strict (bool): if True, require an exact key match
+ """
+ if not model_path or not Path(model_path).exists():
+ logging.warning(f"Weights file not found: {model_path}")
+ logging.info("Random weights will be used")
+ return False
+
+ try:
+ logging.info(f"Attempting to load weights from: {model_path}")
+
+ # Load the file, handling several checkpoint formats
+ checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
+
+ # Extract the state_dict from the different checkpoint layouts
+ if isinstance(checkpoint, dict):
+ if 'model_state_dict' in checkpoint:
+ state_dict = checkpoint['model_state_dict']
+ logging.info("Found 'model_state_dict' in the file")
+ elif 'state_dict' in checkpoint:
+ state_dict = checkpoint['state_dict']
+ logging.info("Found 'state_dict' in the file")
+ elif 'model' in checkpoint:
+ state_dict = checkpoint['model']
+ logging.info("Found 'model' in the file")
+ else:
+ state_dict = checkpoint
+ logging.info("Using the file directly as a state_dict")
+ else:
+ state_dict = checkpoint
+ logging.info("Using the file directly as a state_dict")
+
+ # Clean up key names (strip a leading 'module.' if present)
+ clean_state_dict = OrderedDict()
+ for k, v in state_dict.items():
+ # Remove 'module.' from the start of the key if it exists
+ clean_key = k[7:] if k.startswith('module.') else k
+ clean_state_dict[clean_key] = v
+
+ # Load the weights
+ missing_keys, unexpected_keys = self.load_state_dict(clean_state_dict, strict=strict)
+
+ # Report the loading status
+ if missing_keys:
+ logging.warning(f"Missing keys ({len(missing_keys)}): {missing_keys[:5]}..." if len(missing_keys) > 5 else f"Missing keys: {missing_keys}")
+
+ if unexpected_keys:
+ logging.warning(f"Unexpected keys ({len(unexpected_keys)}): {unexpected_keys[:5]}..." if len(unexpected_keys) > 5 else f"Unexpected keys: {unexpected_keys}")
+
+ if not missing_keys and not unexpected_keys:
+ logging.info("✅ All weights loaded successfully")
+ elif not strict:
+ logging.info("✅ Weights loaded successfully (mismatches ignored)")
+
+ return True
+
+ except Exception as e:
+ logging.error(f"❌ Error while loading weights: {str(e)}")
+ logging.info("Random weights will be used")
+ return False
1182
+
1183
+
1184
+ # ============================================================================
1185
+ # Helper functions for loading the model
1186
+ # ============================================================================
1187
+ # ==============================================================================
1188
+ # File: config_and_loader.py
+ # This is the single source of truth for all settings and for the model-loading process.
1190
+ # ==============================================================================
1191
+
1192
1272
+
1273
+ # ==============================================================================
1274
+ # Function 1: get_master_config
1275
+ # ==============================================================================
1276
+
1277
+ def get_master_config():
1278
+ """
1279
+ [النسخة الاحترافية]
1280
+ يعيد قاموسًا شاملاً يحتوي على جميع إعدادات التطبيق الثابتة.
1281
+ هذه الدالة هي المصدر الوحيد للحقيقة للإعدادات.
1282
+ """
1283
+
1284
+ # --- Section 1: model repository info on the Hugging Face Hub ---
1285
+ huggingface_repo = {
1286
+ 'repo_id': "BaseerAI/Interfuser-Baseer-v1", # replace with your own model repository name
+ 'filename': "best_model.pth" # name of the weights file inside the repository
1288
+ }
1289
+
1290
+ # --- Section 2: Interfuser model architecture settings ---
1291
+ model_params = {
1292
+ "img_size": 224, "embed_dim": 256, "enc_depth": 6, "dec_depth": 6,
1293
+ "rgb_backbone_name": 'r50', "lidar_backbone_name": 'r18',
1294
+ "waypoints_pred_head": 'gru', "use_different_backbone": True,
1295
+ "with_lidar": False, "with_right_left_sensors": False,
1296
+ "with_center_sensor": False, "multi_view_img_size": 112,
1297
+ "patch_size": 8, "in_chans": 3, "dim_feedforward": 2048,
1298
+ "normalize_before": False, "num_heads": 8, "dropout": 0.1,
1299
+ "end2end": False, "direct_concat": False, "separate_view_attention": False,
1300
+ "separate_all_attention": False, "freeze_num": -1,
1301
+ "traffic_pred_head_type": "det", "reverse_pos": True,
1302
+ "use_view_embed": False, "use_mmad_pretrain": None,
1303
+ }
1304
+
1305
+ # --- Section 3: grid and bird's-eye-view (BEV) settings ---
1306
+ grid_conf = {
1307
+ 'h': 20, 'w': 20, 'x_res': 1.0, 'y_res': 1.0,
1308
+ 'y_min': 0.0, 'y_max': 20.0, 'x_min': -10.0, 'x_max': 10.0,
1309
+ }
1310
+
1311
+ # --- Section 4: controller and tracker settings ---
1312
+ controller_params = {
1313
+ 'turn_KP': 0.75, 'turn_KI': 0.05, 'turn_KD': 0.25, 'turn_n': 20,
1314
+ 'speed_KP': 0.55, 'speed_KI': 0.05, 'speed_KD': 0.15, 'speed_n': 20,
1315
+ 'max_speed': 8.0, 'max_throttle': 0.75, 'min_speed': 0.1,
1316
+ 'brake_sensitivity': 0.3, 'light_threshold': 0.5, 'stop_threshold': 0.6,
1317
+ 'stop_sign_duration': 20, 'max_stop_time': 250,
1318
+ 'forced_move_duration': 20, 'forced_throttle': 0.5,
1319
+ 'max_red_light_time': 150, 'red_light_block_duration': 80,
1320
+ 'accel_rate': 0.1, 'decel_rate': 0.2, 'critical_distance': 4.0,
1321
+ 'follow_distance': 10.0, 'speed_match_factor': 0.9,
1322
+ 'tracker_match_thresh': 2.5, 'tracker_prune_age': 5,
1323
+ 'follow_grace_period': 20
1324
+ }
1325
+
1326
+ # --- Section 5: assemble everything into a single master dictionary ---
1327
+ master_config = {
1328
+ 'huggingface_repo': huggingface_repo,
1329
+ 'model_params': model_params,
1330
+ 'grid_conf': grid_conf,
1331
+ 'controller_params': controller_params,
1332
+ 'simulation': {
1333
+ 'frequency': 10.0
1334
+ }
1335
+ }
1336
+
1337
+ return master_config
1338
+
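+ # Editor note (hedged): the returned dictionary is consumed piecewise, e.g.
+ # get_master_config()['model_params'] feeds InterfuserModel(**...) in load_and_prepare_model
+ # below, while grid_conf and controller_params are intended for the external tracker and
+ # controller code that accompanies this model.
+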
1339
+
1340
+ # ==============================================================================
1341
+ # Function 2: load_and_prepare_model
1342
+ # ==============================================================================
1343
+
1344
+ def load_and_prepare_model(device: torch.device) -> InterfuserModel:
1345
+ """
1346
+ [النسخة الاحترافية]
1347
+ تستخدم الإعدادات الرئيسية من `get_master_config` لإنشاء وتحميل النموذج.
1348
+ تقوم بتحويل معرّف النموذج من Hugging Face Hub إلى مسار ملف حقيقي.
1349
+
1350
+ Args:
1351
+ device (torch.device): الجهاز المستهدف (CPU/GPU)
1352
+
1353
+ Returns:
1354
+ Interfuser: النموذج المحمل وجاهز للاستدلال.
1355
+ """
1356
+ try:
1357
+ logging.info("Initializing model loading process...")
1358
+
1359
+ # 1. الحصول على جميع الإعدادات من المصدر الوحيد للحقيقة
1360
+ config = get_master_config()
1361
+
1362
+ # 2. تحميل ملف الأوزان من Hugging Face Hub
1363
+ repo_info = config['huggingface_repo']
1364
+ logging.info(f"Downloading model weights from repo: '{repo_info['repo_id']}'")
1365
+
1366
+ # استخدام token إذا كان المستودع خاصًا
1367
+ # token = HfFolder.get_token() # أو يمكن تمريره مباشرة
1368
+ actual_model_path = hf_hub_download(
1369
+ repo_id=repo_info['repo_id'],
1370
+ filename=repo_info['filename'],
1371
+ # token=token, # قم بإلغاء التعليق إذا كان المستودع خاصًا
1372
+ )
1373
+ logging.info(f"Model weights are available at local path: {actual_model_path}")
1374
+
1375
+ # 3. إنشاء نسخة من النموذج باستخدام الإعدادات الصحيحة
1376
+ logging.info("Instantiating model with specified parameters...")
1377
+ model = InterfuserModel(**config['model_params']).to(device)
1378
+
1379
+ # 4. تحميل الأوزان التي تم تنزيلها إلى النموذج
1380
+ # نستخدم الدالة المساعدة الموجودة داخل كلاس النموذج نفسه
1381
+ success = model.load_pretrained(actual_model_path, strict=False)
1382
+ if not success:
1383
+ logging.warning("⚠️ Model weights were not loaded successfully. The model will use random weights.")
1384
+
1385
+ # 5. وضع النموذج في وضع التقييم (خطوة حاسمة)
1386
+ model.eval()
1387
+ logging.info("✅ Model prepared and set to evaluation mode. Ready for inference.")
1388
+
1389
+ return model
1390
+
1391
+ except Exception as e:
1392
+ # تسجيل الخطأ بالتفصيل ثم إطلاقه مرة أخرى ليتم التعامل معه في مستوى أعلى
1393
+ logging.error(f"❌ CRITICAL ERROR during model initialization: {e}", exc_info=True)
1394
+ raise
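+
+
+ # ============================================================================
+ # Editor addition: minimal smoke-test sketch (not part of the original upload).
+ # It builds the architecture from get_master_config() without downloading the
+ # Interfuser checkpoint; note that the timm backbone still fetches ImageNet
+ # weights on first use unless they are already cached.
+ # ============================================================================
+ if __name__ == "__main__":
+     logging.basicConfig(level=logging.INFO)
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     # Build the model from the master config (random Interfuser weights).
+     cfg = get_master_config()
+     model = InterfuserModel(**cfg['model_params']).to(device).eval()
+
+     # Dummy batch matching the input dict expected by forward(); measurements[:, :6]
+     # acts as a one-hot command mask and measurements[:, 6] as the current velocity.
+     dummy = {
+         "rgb": torch.zeros(1, 3, 224, 224, device=device),
+         "rgb_left": torch.zeros(1, 3, 112, 112, device=device),
+         "rgb_right": torch.zeros(1, 3, 112, 112, device=device),
+         "rgb_center": torch.zeros(1, 3, 128, 128, device=device),
+         "lidar": torch.zeros(1, 3, 224, 224, device=device),
+         "measurements": torch.zeros(1, 7, device=device),
+         "target_point": torch.zeros(1, 2, device=device),
+     }
+     with torch.no_grad():
+         traffic, waypoints, is_junction, light, stop_sign, feat = model(dummy)
+     print(traffic.shape, waypoints.shape, is_junction.shape)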