BaseerAI commited on
Commit
e5c9178
·
verified ·
1 Parent(s): 3e2aeec

Update model_definition.py

Browse files
Files changed (1) hide show
  1. model_definition.py +81 -122
model_definition.py CHANGED
@@ -115,48 +115,93 @@ class HybridEmbed(nn.Module):
115
  return x, global_x
116
 
117
 
118
- class PositionEmbeddingSine(nn.Module):
119
- """
120
- This is a more standard version of the position embedding, very similar to the one
121
- used by the Attention is all you need paper, generalized to work on images.
122
  """
 
 
123
 
124
- def __init__(
125
- self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
126
- ):
 
 
 
 
 
 
 
 
 
 
 
127
  super().__init__()
 
 
 
 
128
  self.num_pos_feats = num_pos_feats
129
  self.temperature = temperature
130
  self.normalize = normalize
131
- if scale is not None and normalize is False:
 
132
  raise ValueError("normalize should be True if scale is passed")
133
  if scale is None:
134
  scale = 2 * math.pi
135
  self.scale = scale
136
 
137
- def forward(self, tensor):
138
- x = tensor
139
- bs, _, h, w = x.shape
140
- not_mask = torch.ones((bs, h, w), device=x.device)
141
- y_embed = not_mask.cumsum(1, dtype=torch.float32)
142
- x_embed = not_mask.cumsum(2, dtype=torch.float32)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  if self.normalize:
144
- eps = 1e-6
145
- y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
146
- x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
147
-
148
- dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
149
- dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
150
-
151
- pos_x = x_embed[:, :, :, None] / dim_t
152
- pos_y = y_embed[:, :, :, None] / dim_t
153
- pos_x = torch.stack(
154
- (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
155
- ).flatten(3)
156
- pos_y = torch.stack(
157
- (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
158
- ).flatten(3)
159
- pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
 
 
160
  return pos
161
 
162
 
@@ -663,7 +708,7 @@ def build_attn_mask(mask_type):
663
  return mask
664
  # class InterfuserModel(nn.Module):
665
 
666
- class InterfuserModel(nn.Module):
667
  def __init__(
668
  self,
669
  img_size=224,
@@ -870,7 +915,7 @@ class InterfuserModel(nn.Module):
870
  *[nn.Linear(embed_dim, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid()]
871
  )
872
 
873
- self.position_encoding = PositionEmbeddingSine(embed_dim // 2, normalize=True)
874
 
875
  encoder_layer = TransformerEncoderLayer(
876
  embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before
@@ -1114,6 +1159,8 @@ class InterfuserModel(nn.Module):
1114
  traffic_feature_with_vel = torch.cat([traffic_feature, velocity], dim=2)
1115
  traffic = self.traffic_pred_head(traffic_feature_with_vel)
1116
  return traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature
 
 
1117
  def load_pretrained(self, model_path, strict=False):
1118
  """
1119
  تحميل الأوزان المدربة مسبقاً - نسخة محسنة
@@ -1181,94 +1228,6 @@ class InterfuserModel(nn.Module):
1181
  return False
1182
 
1183
 
1184
- # ============================================================================
1185
- # دوال مساعدة لتحميل النموذج
1186
- # ============================================================================
1187
- # ==============================================================================
1188
- # ملف: config_and_loader.py
1189
- # هذا هو المصدر الوحيد للحقيقة لجميع الإعدادات وعملية تحميل النموذج.
1190
- # ==============================================================================
1191
-
1192
-
1193
-
1194
- # def get_master_config(model_path="model/best_model.pth"):
1195
- # """
1196
- # [النسخة الكاملة والنهائية]
1197
- # ينشئ ويدمج كل الإعدادات المطلوبة للتطبيق (النموذج، المتتبع، المتحكم).
1198
- # """
1199
- # model_params = {
1200
- # "img_size": 224, "embed_dim": 256, "enc_depth": 6, "dec_depth": 6,
1201
- # "rgb_backbone_name": 'r50', "lidar_backbone_name": 'r18',
1202
- # "waypoints_pred_head": 'gru', "use_different_backbone": True,
1203
- # "with_lidar": False, "with_right_left_sensors": False,
1204
- # "with_center_sensor": False, "multi_view_img_size": 112,
1205
- # "patch_size": 8, "in_chans": 3, "dim_feedforward": 2048,
1206
- # "normalize_before": False, "num_heads": 8, "dropout": 0.1,
1207
- # "end2end": False, "direct_concat": False, "separate_view_attention": False,
1208
- # "separate_all_attention": False, "freeze_num": -1,
1209
- # "traffic_pred_head_type": "det", "reverse_pos": True,
1210
- # "use_view_embed": False, "use_mmad_pretrain": None,
1211
- # }
1212
-
1213
- # grid_conf = {
1214
- # 'h': 20, 'w': 20, 'x_res': 1.0, 'y_res': 1.0,
1215
- # 'y_min': 0.0, 'y_max': 20.0, 'x_min': -10.0, 'x_max': 10.0,
1216
- # }
1217
-
1218
- # controller_params = {
1219
- # 'turn_KP': 0.75, 'turn_KI': 0.05, 'turn_KD': 0.25, 'turn_n': 20,
1220
- # 'speed_KP': 0.55, 'speed_KI': 0.05, 'speed_KD': 0.15, 'speed_n': 20,
1221
- # 'max_speed': 8.0, 'max_throttle': 0.75, 'min_speed': 0.1,
1222
- # 'brake_sensitivity': 0.3, 'light_threshold': 0.5, 'stop_threshold': 0.6,
1223
- # 'stop_sign_duration': 20, 'max_stop_time': 250,
1224
- # 'forced_move_duration': 20, 'forced_throttle': 0.5,
1225
- # 'max_red_light_time': 150, 'red_light_block_duration': 80,
1226
- # 'accel_rate': 0.1, 'decel_rate': 0.2, 'critical_distance': 4.0,
1227
- # 'follow_distance': 10.0, 'speed_match_factor': 0.9,
1228
- # 'tracker_match_thresh': 2.5, 'tracker_prune_age': 5,
1229
- # 'follow_grace_period': 20
1230
- # }
1231
-
1232
- # master_config = {
1233
- # 'model_params': model_params,
1234
- # 'grid_conf': grid_conf,
1235
- # 'controller_params': controller_params,
1236
- # 'paths': {'pretrained_weights': model_path},
1237
- # 'simulation': {'frequency': 10.0}
1238
- # }
1239
-
1240
- # return master_config
1241
-
1242
-
1243
- # def load_and_prepare_model(device: torch.device) -> InterfuserModel:
1244
- # """
1245
- # [النسخة النهائية الصحيحة - تستقبل مدخلاً واحدًا فقط]
1246
- # تستخدم دالة الإعدادات الرئيسية لإنشاء وتحميل النموذج.
1247
- # """
1248
- # try:
1249
- # logging.info("Attempting to load model using master config...")
1250
- # # 1. الحصول على كل الإعدادات من المصدر الوحيد للحقيقة
1251
- # config = get_master_config()
1252
-
1253
- # # 2. إنشاء النموذج باستخدام إعدادات النموذج فقط
1254
- # model = InterfuserModel(**config['model_params']).to(device)
1255
- # logging.info(f"Model instantiated on device: {device}")
1256
-
1257
- # # 3. تحميل الأوزان باستخدام الدالة الداخلية للنموذج
1258
- # checkpoint_path = config['paths']['pretrained_weights']
1259
- # model.load_pretrained(checkpoint_path, strict=False)
1260
-
1261
- # # 4. وضع النموذج في وضع التقييم
1262
- # model.eval()
1263
- # logging.info("✅ Model prepared and set to evaluation mode.")
1264
-
1265
- # return model
1266
-
1267
- # except Exception as e:
1268
- # logging.error(f"❌ CRITICAL ERROR in load_and_prepare_model: {e}", exc_info=True)
1269
- # raise
1270
-
1271
-
1272
 
1273
  # ==============================================================================
1274
  # الدالة الأولى: get_master_config
@@ -1341,7 +1300,7 @@ def get_master_config():
1341
  # الدالة الثانية: load_and_prepare_model
1342
  # ==============================================================================
1343
 
1344
- def load_and_prepare_model(device: torch.device) -> InterfuserModel:
1345
  """
1346
  [النسخة الاحترافية]
1347
  تستخدم الإعدادات الرئيسية من `get_master_config` لإنشاء وتحميل النموذج.
@@ -1374,7 +1333,7 @@ def load_and_prepare_model(device: torch.device) -> InterfuserModel:
1374
 
1375
  # 3. إنشاء نسخة من النموذج باستخدام الإعدادات الصحيحة
1376
  logging.info("Instantiating model with specified parameters...")
1377
- model = InterfuserModel(**config['model_params']).to(device)
1378
 
1379
  # 4. تحميل الأوزان التي تم تنزيلها إلى النموذج
1380
  # نستخدم الدالة المساعدة الموجودة داخل كلاس النموذج نفسه
 
115
  return x, global_x
116
 
117
 
118
+ class HyperDimensionalPositionalEncoding(nn.Module):
 
 
 
119
  """
120
+ [GCPE v1.1 - Professional & Corrected Implementation]
121
+ A novel positional encoding scheme based on geometric centrality.
122
 
123
+ This class is designed as a drop-in replacement for the standard
124
+ PositionEmbeddingSine, accepting similar arguments and producing an
125
+ output of the same shape. This version corrects a type error in the
126
+ distance calculation.
127
+ """
128
+ def __init__(self, num_pos_feats=256, temperature=10000, normalize=True, scale=None):
129
+ """
130
+ Args:
131
+ num_pos_feats (int): The desired number of output channels for the positional encoding.
132
+ This must be an even number.
133
+ temperature (int): A constant used to scale the frequencies.
134
+ normalize (bool): If True, normalizes the coordinates to the range [0, scale].
135
+ scale (float, optional): The scaling factor for normalization. Defaults to 2*pi.
136
+ """
137
  super().__init__()
138
+
139
+ if num_pos_feats % 2 != 0:
140
+ raise ValueError(f"num_pos_feats must be an even number, but got {num_pos_feats}")
141
+
142
  self.num_pos_feats = num_pos_feats
143
  self.temperature = temperature
144
  self.normalize = normalize
145
+
146
+ if scale is not None and not normalize:
147
  raise ValueError("normalize should be True if scale is passed")
148
  if scale is None:
149
  scale = 2 * math.pi
150
  self.scale = scale
151
 
152
+ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
153
+ """
154
+ Args:
155
+ tensor (torch.Tensor): A 4D tensor of shape (B, C, H, W). The content is not
156
+ used, only its shape and device.
157
+
158
+ Returns:
159
+ torch.Tensor: A 4D tensor of positional encodings with shape (B, num_pos_feats, H, W).
160
+ """
161
+ batch_size, _, h, w = tensor.shape
162
+ device = tensor.device
163
+
164
+ # 1. Create coordinate grids
165
+ y_embed = torch.arange(h, dtype=torch.float32, device=device).view(h, 1)
166
+ x_embed = torch.arange(w, dtype=torch.float32, device=device).view(1, w)
167
+
168
+ # 2. Calculate normalized distance from the center
169
+ # Use floating point division for center calculation
170
+ center_y, center_x = (h - 1) / 2.0, (w - 1) / 2.0
171
+
172
+ # Calculate the Euclidean distance for each pixel from the center
173
+ dist_map = torch.sqrt(
174
+ (y_embed - center_y)**2 + (x_embed - center_x)**2
175
+ )
176
+
177
+ # ✅ CORRECTION: The max distance is a scalar, no need for torch.sqrt on a float.
178
+ # We can calculate it with math.sqrt or just compute the squared value.
179
+ # To keep everything in tensors for consistency, we can do this:
180
+ max_dist_sq = torch.tensor(center_y**2 + center_x**2, device=device)
181
+ max_dist = torch.sqrt(max_dist_sq)
182
+
183
+ # Normalize the distance map to the range [0, 1]
184
+ normalized_dist_map = dist_map / (max_dist + 1e-6)
185
+
186
  if self.normalize:
187
+ normalized_dist_map = normalized_dist_map * self.scale
188
+
189
+ pos_dist = normalized_dist_map.unsqueeze(0).repeat(batch_size, 1, 1)
190
+
191
+ # 3. Create the frequency-based embedding
192
+ # This part remains the same as it operates on tensors correctly.
193
+ dim_t = torch.arange(self.num_pos_feats // 2, dtype=torch.float32, device=device)
194
+ dim_t = self.temperature ** (2 * dim_t / (self.num_pos_feats // 2))
195
+
196
+ pos = pos_dist.unsqueeze(-1) / dim_t
197
+
198
+ pos_sin = pos.sin()
199
+ pos_cos = pos.cos()
200
+
201
+ # 4. Concatenate and reshape to match the desired output format
202
+ pos = torch.cat((pos_sin, pos_cos), dim=3)
203
+ pos = pos.permute(0, 3, 1, 2)
204
+
205
  return pos
206
 
207
 
 
708
  return mask
709
  # class InterfuserModel(nn.Module):
710
 
711
+ class InterfuserHDPE(nn.Module):
712
  def __init__(
713
  self,
714
  img_size=224,
 
915
  *[nn.Linear(embed_dim, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid()]
916
  )
917
 
918
+ self.position_encoding = HyperDimensionalPositionalEncoding(embed_dim , normalize=True)
919
 
920
  encoder_layer = TransformerEncoderLayer(
921
  embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before
 
1159
  traffic_feature_with_vel = torch.cat([traffic_feature, velocity], dim=2)
1160
  traffic = self.traffic_pred_head(traffic_feature_with_vel)
1161
  return traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature
1162
+
1163
+
1164
  def load_pretrained(self, model_path, strict=False):
1165
  """
1166
  تحميل الأوزان المدربة مسبقاً - نسخة محسنة
 
1228
  return False
1229
 
1230
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1231
 
1232
  # ==============================================================================
1233
  # الدالة الأولى: get_master_config
 
1300
  # الدالة الثانية: load_and_prepare_model
1301
  # ==============================================================================
1302
 
1303
+ def load_and_prepare_model(device: torch.device) -> InterfuserHDPE:
1304
  """
1305
  [النسخة الاحترافية]
1306
  تستخدم الإعدادات الرئيسية من `get_master_config` لإنشاء وتحميل النموذج.
 
1333
 
1334
  # 3. إنشاء نسخة من النموذج باستخدام الإعدادات الصحيحة
1335
  logging.info("Instantiating model with specified parameters...")
1336
+ model = InterfuserHDPE(**config['model_params']).to(device)
1337
 
1338
  # 4. تحميل الأوزان التي تم تنزيلها إلى النموذج
1339
  # نستخدم الدالة المساعدة الموجودة داخل كلاس النموذج نفسه