# model_definition.py # ============================================================================ # الاستيرادات الأساسية # ============================================================================ import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.optim import AdamW from torch.optim.lr_scheduler import OneCycleLR from torch.utils.data import Dataset, DataLoader from torchvision import transforms from functools import partial from typing import Optional, List from torch import Tensor # مكتبات إضافية import os import json import logging import math import copy from pathlib import Path from collections import OrderedDict # مكتبات معالجة البيانات import numpy as np import cv2 # مكتبات اختيارية (يمكن تعطيلها إذا لم تكن متوفرة) try: import wandb WANDB_AVAILABLE = True except ImportError: WANDB_AVAILABLE = False try: from tqdm import tqdm except ImportError: # إذا لم تكن tqdm متوفرة، استخدم دالة بديلة def tqdm(iterable, *args, **kwargs): return iterable # ============================================================================ # دوال مساعدة # ============================================================================ def to_2tuple(x): """تحويل قيمة إلى tuple من عنصرين""" if isinstance(x, (list, tuple)): return tuple(x) return (x, x) # ============================================================================ # ============================================================================ class HybridEmbed(nn.Module): def __init__( self, backbone, img_size=224, patch_size=1, feature_size=None, in_chans=3, embed_dim=768, ): super().__init__() assert isinstance(backbone, nn.Module) img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) self.img_size = img_size self.patch_size = patch_size self.backbone = backbone if feature_size is None: with torch.no_grad(): training = backbone.training if training: backbone.eval() o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1])) if isinstance(o, (list, tuple)): o = o[-1] # last feature if backbone outputs list/tuple of features feature_size = o.shape[-2:] feature_dim = o.shape[1] backbone.train(training) else: feature_size = to_2tuple(feature_size) if hasattr(self.backbone, "feature_info"): feature_dim = self.backbone.feature_info.channels()[-1] else: feature_dim = self.backbone.num_features self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=1, stride=1) def forward(self, x): x = self.backbone(x) if isinstance(x, (list, tuple)): x = x[-1] # last feature if backbone outputs list/tuple of features x = self.proj(x) global_x = torch.mean(x, [2, 3], keepdim=False)[:, :, None] return x, global_x class PositionEmbeddingSine(nn.Module): """ This is a more standard version of the position embedding, very similar to the one used by the Attention is all you need paper, generalized to work on images. """ def __init__( self, num_pos_feats=64, temperature=10000, normalize=False, scale=None ): super().__init__() self.num_pos_feats = num_pos_feats self.temperature = temperature self.normalize = normalize if scale is not None and normalize is False: raise ValueError("normalize should be True if scale is passed") if scale is None: scale = 2 * math.pi self.scale = scale def forward(self, tensor): x = tensor bs, _, h, w = x.shape not_mask = torch.ones((bs, h, w), device=x.device) y_embed = not_mask.cumsum(1, dtype=torch.float32) x_embed = not_mask.cumsum(2, dtype=torch.float32) if self.normalize: eps = 1e-6 y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) pos_x = x_embed[:, :, :, None] / dim_t pos_y = y_embed[:, :, :, None] / dim_t pos_x = torch.stack( (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 ).flatten(3) pos_y = torch.stack( (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 ).flatten(3) pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) return pos class TransformerEncoder(nn.Module): def __init__(self, encoder_layer, num_layers, norm=None): super().__init__() self.layers = _get_clones(encoder_layer, num_layers) self.num_layers = num_layers self.norm = norm def forward( self, src, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, ): output = src for layer in self.layers: output = layer( output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos, ) if self.norm is not None: output = self.norm(output) return output class SpatialSoftmax(nn.Module): def __init__(self, height, width, channel, temperature=None, data_format="NCHW"): super().__init__() self.data_format = data_format self.height = height self.width = width self.channel = channel if temperature: self.temperature = Parameter(torch.ones(1) * temperature) else: self.temperature = 1.0 pos_x, pos_y = np.meshgrid( np.linspace(-1.0, 1.0, self.height), np.linspace(-1.0, 1.0, self.width) ) pos_x = torch.from_numpy(pos_x.reshape(self.height * self.width)).float() pos_y = torch.from_numpy(pos_y.reshape(self.height * self.width)).float() self.register_buffer("pos_x", pos_x) self.register_buffer("pos_y", pos_y) def forward(self, feature): # Output: # (N, C*2) x_0 y_0 ... if self.data_format == "NHWC": feature = ( feature.transpose(1, 3) .tranpose(2, 3) .view(-1, self.height * self.width) ) else: feature = feature.view(-1, self.height * self.width) weight = F.softmax(feature / self.temperature, dim=-1) expected_x = torch.sum( torch.autograd.Variable(self.pos_x) * weight, dim=1, keepdim=True ) expected_y = torch.sum( torch.autograd.Variable(self.pos_y) * weight, dim=1, keepdim=True ) expected_xy = torch.cat([expected_x, expected_y], 1) feature_keypoints = expected_xy.view(-1, self.channel, 2) feature_keypoints[:, :, 1] = (feature_keypoints[:, :, 1] - 1) * 12 feature_keypoints[:, :, 0] = feature_keypoints[:, :, 0] * 12 return feature_keypoints class MultiPath_Generator(nn.Module): def __init__(self, in_channel, embed_dim, out_channel): super().__init__() self.spatial_softmax = SpatialSoftmax(100, 100, out_channel) self.tconv0 = nn.Sequential( nn.ConvTranspose2d(in_channel, 256, 4, 2, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True), ) self.tconv1 = nn.Sequential( nn.ConvTranspose2d(256, 256, 4, 2, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True), ) self.tconv2 = nn.Sequential( nn.ConvTranspose2d(256, 192, 4, 2, 1, bias=False), nn.BatchNorm2d(192), nn.ReLU(True), ) self.tconv3 = nn.Sequential( nn.ConvTranspose2d(192, 64, 4, 2, 1, bias=False), nn.BatchNorm2d(64), nn.ReLU(True), ) self.tconv4_list = torch.nn.ModuleList( [ nn.Sequential( nn.ConvTranspose2d(64, out_channel, 8, 2, 3, bias=False), nn.Tanh(), ) for _ in range(6) ] ) self.upsample = nn.Upsample(size=(50, 50), mode="bilinear") def forward(self, x, measurements): mask = measurements[:, :6] mask = mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 1, 100, 100) velocity = measurements[:, 6:7].unsqueeze(-1).unsqueeze(-1) velocity = velocity.repeat(1, 32, 2, 2) n, d, c = x.shape x = x.transpose(1, 2) x = x.view(n, -1, 2, 2) x = torch.cat([x, velocity], dim=1) x = self.tconv0(x) x = self.tconv1(x) x = self.tconv2(x) x = self.tconv3(x) x = self.upsample(x) xs = [] for i in range(6): xt = self.tconv4_list[i](x) xs.append(xt) xs = torch.stack(xs, dim=1) x = torch.sum(xs * mask, dim=1) x = self.spatial_softmax(x) return x class LinearWaypointsPredictor(nn.Module): def __init__(self, input_dim, cumsum=True): super().__init__() self.cumsum = cumsum self.rank_embed = nn.Parameter(torch.zeros(1, 10, input_dim)) self.head_fc1_list = nn.ModuleList([nn.Linear(input_dim, 64) for _ in range(6)]) self.head_relu = nn.ReLU(inplace=True) self.head_fc2_list = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)]) def forward(self, x, measurements): # input shape: n 10 embed_dim bs, n, dim = x.shape x = x + self.rank_embed x = x.reshape(-1, dim) mask = measurements[:, :6] mask = torch.unsqueeze(mask, -1).repeat(n, 1, 2) rs = [] for i in range(6): res = self.head_fc1_list[i](x) res = self.head_relu(res) res = self.head_fc2_list[i](res) rs.append(res) rs = torch.stack(rs, 1) x = torch.sum(rs * mask, dim=1) x = x.view(bs, n, 2) if self.cumsum: x = torch.cumsum(x, 1) return x class GRUWaypointsPredictor(nn.Module): def __init__(self, input_dim, waypoints=10): super().__init__() # self.gru = torch.nn.GRUCell(input_size=input_dim, hidden_size=64) self.gru = torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True) self.encoder = nn.Linear(2, 64) self.decoder = nn.Linear(64, 2) self.waypoints = waypoints def forward(self, x, target_point): bs = x.shape[0] z = self.encoder(target_point).unsqueeze(0) output, _ = self.gru(x, z) output = output.reshape(bs * self.waypoints, -1) output = self.decoder(output).reshape(bs, self.waypoints, 2) output = torch.cumsum(output, 1) return output class GRUWaypointsPredictorWithCommand(nn.Module): def __init__(self, input_dim, waypoints=10): super().__init__() # self.gru = torch.nn.GRUCell(input_size=input_dim, hidden_size=64) self.grus = nn.ModuleList([torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True) for _ in range(6)]) self.encoder = nn.Linear(2, 64) self.decoders = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)]) self.waypoints = waypoints def forward(self, x, target_point, measurements): bs, n, dim = x.shape mask = measurements[:, :6, None, None] mask = mask.repeat(1, 1, self.waypoints, 2) z = self.encoder(target_point).unsqueeze(0) outputs = [] for i in range(6): output, _ = self.grus[i](x, z) output = output.reshape(bs * self.waypoints, -1) output = self.decoders[i](output).reshape(bs, self.waypoints, 2) output = torch.cumsum(output, 1) outputs.append(output) outputs = torch.stack(outputs, 1) output = torch.sum(outputs * mask, dim=1) return output class TransformerDecoder(nn.Module): def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): super().__init__() self.layers = _get_clones(decoder_layer, num_layers) self.num_layers = num_layers self.norm = norm self.return_intermediate = return_intermediate def forward( self, tgt, memory, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None, ): output = tgt intermediate = [] for layer in self.layers: output = layer( output, memory, tgt_mask=tgt_mask, memory_mask=memory_mask, tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask, pos=pos, query_pos=query_pos, ) if self.return_intermediate: intermediate.append(self.norm(output)) if self.norm is not None: output = self.norm(output) if self.return_intermediate: intermediate.pop() intermediate.append(output) if self.return_intermediate: return torch.stack(intermediate) return output.unsqueeze(0) class TransformerEncoderLayer(nn.Module): def __init__( self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=nn.ReLU(), normalize_before=False, ): super().__init__() self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) # Implementation of Feedforward model self.linear1 = nn.Linear(d_model, dim_feedforward) self.dropout = nn.Dropout(dropout) self.linear2 = nn.Linear(dim_feedforward, d_model) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) self.activation = activation() self.normalize_before = normalize_before def with_pos_embed(self, tensor, pos: Optional[Tensor]): return tensor if pos is None else tensor + pos def forward_post( self, src, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, ): q = k = self.with_pos_embed(src, pos) src2 = self.self_attn( q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask )[0] src = src + self.dropout1(src2) src = self.norm1(src) src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) src = src + self.dropout2(src2) src = self.norm2(src) return src def forward_pre( self, src, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, ): src2 = self.norm1(src) q = k = self.with_pos_embed(src2, pos) src2 = self.self_attn( q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask )[0] src = src + self.dropout1(src2) src2 = self.norm2(src) src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) src = src + self.dropout2(src2) return src def forward( self, src, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, ): if self.normalize_before: return self.forward_pre(src, src_mask, src_key_padding_mask, pos) return self.forward_post(src, src_mask, src_key_padding_mask, pos) class TransformerDecoderLayer(nn.Module): def __init__( self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=nn.ReLU(), normalize_before=False, ): super().__init__() self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) # Implementation of Feedforward model self.linear1 = nn.Linear(d_model, dim_feedforward) self.dropout = nn.Dropout(dropout) self.linear2 = nn.Linear(dim_feedforward, d_model) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.norm3 = nn.LayerNorm(d_model) self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) self.dropout3 = nn.Dropout(dropout) self.activation = activation() self.normalize_before = normalize_before def with_pos_embed(self, tensor, pos: Optional[Tensor]): return tensor if pos is None else tensor + pos def forward_post( self, tgt, memory, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None, ): q = k = self.with_pos_embed(tgt, query_pos) tgt2 = self.self_attn( q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask )[0] tgt = tgt + self.dropout1(tgt2) tgt = self.norm1(tgt) tgt2 = self.multihead_attn( query=self.with_pos_embed(tgt, query_pos), key=self.with_pos_embed(memory, pos), value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask, )[0] tgt = tgt + self.dropout2(tgt2) tgt = self.norm2(tgt) tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) tgt = tgt + self.dropout3(tgt2) tgt = self.norm3(tgt) return tgt def forward_pre( self, tgt, memory, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None, ): tgt2 = self.norm1(tgt) q = k = self.with_pos_embed(tgt2, query_pos) tgt2 = self.self_attn( q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask )[0] tgt = tgt + self.dropout1(tgt2) tgt2 = self.norm2(tgt) tgt2 = self.multihead_attn( query=self.with_pos_embed(tgt2, query_pos), key=self.with_pos_embed(memory, pos), value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask, )[0] tgt = tgt + self.dropout2(tgt2) tgt2 = self.norm3(tgt) tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) tgt = tgt + self.dropout3(tgt2) return tgt def forward( self, tgt, memory, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None, ): if self.normalize_before: return self.forward_pre( tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos, ) return self.forward_post( tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos, ) def _get_clones(module, N): return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) def _get_activation_fn(activation): """Return an activation function given a string""" if activation == "relu": return F.relu if activation == "gelu": return F.gelu if activation == "glu": return F.glu raise RuntimeError(f"activation should be relu/gelu, not {activation}.") def build_attn_mask(mask_type): mask = torch.ones((151, 151), dtype=torch.bool).cuda() if mask_type == "seperate_all": mask[:50, :50] = False mask[50:67, 50:67] = False mask[67:84, 67:84] = False mask[84:101, 84:101] = False mask[101:151, 101:151] = False elif mask_type == "seperate_view": mask[:50, :50] = False mask[50:67, 50:67] = False mask[67:84, 67:84] = False mask[84:101, 84:101] = False mask[101:151, :] = False mask[:, 101:151] = False return mask # class InterfuserModel(nn.Module): class InterfuserModel(nn.Module): def __init__( self, img_size=224, multi_view_img_size=112, patch_size=8, in_chans=3, embed_dim=768, enc_depth=6, dec_depth=6, dim_feedforward=2048, normalize_before=False, rgb_backbone_name="r50", lidar_backbone_name="r50", num_heads=8, norm_layer=None, dropout=0.1, end2end=False, direct_concat=False, separate_view_attention=False, separate_all_attention=False, act_layer=None, weight_init="", freeze_num=-1, with_lidar=False, with_right_left_sensors=False, with_center_sensor=False, traffic_pred_head_type="det", waypoints_pred_head="heatmap", reverse_pos=True, use_different_backbone=False, use_view_embed=False, use_mmad_pretrain=None, ): super().__init__() self.traffic_pred_head_type = traffic_pred_head_type self.num_features = ( self.embed_dim ) = embed_dim # num_features for consistency with other models norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) act_layer = act_layer or nn.GELU self.reverse_pos = reverse_pos self.waypoints_pred_head = waypoints_pred_head self.with_lidar = with_lidar self.with_right_left_sensors = with_right_left_sensors self.with_center_sensor = with_center_sensor self.direct_concat = direct_concat self.separate_view_attention = separate_view_attention self.separate_all_attention = separate_all_attention self.end2end = end2end self.use_view_embed = use_view_embed if self.direct_concat: in_chans = in_chans * 4 self.with_center_sensor = False self.with_right_left_sensors = False if self.separate_view_attention: self.attn_mask = build_attn_mask("seperate_view") elif self.separate_all_attention: self.attn_mask = build_attn_mask("seperate_all") else: self.attn_mask = None if use_different_backbone: if rgb_backbone_name == "r50": self.rgb_backbone = resnet50d( pretrained=True, in_chans=in_chans, features_only=True, out_indices=[4], ) elif rgb_backbone_name == "r26": self.rgb_backbone = resnet26d( pretrained=True, in_chans=in_chans, features_only=True, out_indices=[4], ) elif rgb_backbone_name == "r18": self.rgb_backbone = resnet18d( pretrained=True, in_chans=in_chans, features_only=True, out_indices=[4], ) if lidar_backbone_name == "r50": self.lidar_backbone = resnet50d( pretrained=False, in_chans=in_chans, features_only=True, out_indices=[4], ) elif lidar_backbone_name == "r26": self.lidar_backbone = resnet26d( pretrained=False, in_chans=in_chans, features_only=True, out_indices=[4], ) elif lidar_backbone_name == "r18": self.lidar_backbone = resnet18d( pretrained=False, in_chans=3, features_only=True, out_indices=[4] ) rgb_embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone) lidar_embed_layer = partial(HybridEmbed, backbone=self.lidar_backbone) if use_mmad_pretrain: params = torch.load(use_mmad_pretrain)["state_dict"] updated_params = OrderedDict() for key in params: if "backbone" in key: updated_params[key.replace("backbone.", "")] = params[key] self.rgb_backbone.load_state_dict(updated_params) self.rgb_patch_embed = rgb_embed_layer( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ) self.lidar_patch_embed = lidar_embed_layer( img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim, ) else: if rgb_backbone_name == "r50": self.rgb_backbone = resnet50d( pretrained=True, in_chans=3, features_only=True, out_indices=[4] ) elif rgb_backbone_name == "r101": self.rgb_backbone = resnet101d( pretrained=True, in_chans=3, features_only=True, out_indices=[4] ) elif rgb_backbone_name == "r26": self.rgb_backbone = resnet26d( pretrained=True, in_chans=3, features_only=True, out_indices=[4] ) elif rgb_backbone_name == "r18": self.rgb_backbone = resnet18d( pretrained=True, in_chans=3, features_only=True, out_indices=[4] ) embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone) self.rgb_patch_embed = embed_layer( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ) self.lidar_patch_embed = embed_layer( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ) self.global_embed = nn.Parameter(torch.zeros(1, embed_dim, 5)) self.view_embed = nn.Parameter(torch.zeros(1, embed_dim, 5, 1)) if self.end2end: self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 4)) self.query_embed = nn.Parameter(torch.zeros(4, 1, embed_dim)) elif self.waypoints_pred_head == "heatmap": self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 5)) self.query_embed = nn.Parameter(torch.zeros(400 + 5, 1, embed_dim)) else: self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 11)) self.query_embed = nn.Parameter(torch.zeros(400 + 11, 1, embed_dim)) if self.end2end: self.waypoints_generator = GRUWaypointsPredictor(embed_dim, 4) elif self.waypoints_pred_head == "heatmap": self.waypoints_generator = MultiPath_Generator( embed_dim + 32, embed_dim, 10 ) elif self.waypoints_pred_head == "gru": self.waypoints_generator = GRUWaypointsPredictor(embed_dim) elif self.waypoints_pred_head == "gru-command": self.waypoints_generator = GRUWaypointsPredictorWithCommand(embed_dim) elif self.waypoints_pred_head == "linear": self.waypoints_generator = LinearWaypointsPredictor(embed_dim) elif self.waypoints_pred_head == "linear-sum": self.waypoints_generator = LinearWaypointsPredictor(embed_dim, cumsum=True) self.junction_pred_head = nn.Linear(embed_dim, 2) self.traffic_light_pred_head = nn.Linear(embed_dim, 2) self.stop_sign_head = nn.Linear(embed_dim, 2) if self.traffic_pred_head_type == "det": self.traffic_pred_head = nn.Sequential( *[ nn.Linear(embed_dim + 32, 64), nn.ReLU(), nn.Linear(64, 7), # nn.Sigmoid(), ] ) elif self.traffic_pred_head_type == "seg": self.traffic_pred_head = nn.Sequential( *[nn.Linear(embed_dim, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid()] ) self.position_encoding = PositionEmbeddingSine(embed_dim // 2, normalize=True) encoder_layer = TransformerEncoderLayer( embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before ) self.encoder = TransformerEncoder(encoder_layer, enc_depth, None) decoder_layer = TransformerDecoderLayer( embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before ) decoder_norm = nn.LayerNorm(embed_dim) self.decoder = TransformerDecoder( decoder_layer, dec_depth, decoder_norm, return_intermediate=False ) self.reset_parameters() def reset_parameters(self): nn.init.uniform_(self.global_embed) nn.init.uniform_(self.view_embed) nn.init.uniform_(self.query_embed) nn.init.uniform_(self.query_pos_embed) def forward_features( self, front_image, left_image, right_image, front_center_image, lidar, measurements, ): features = [] # Front view processing front_image_token, front_image_token_global = self.rgb_patch_embed(front_image) if self.use_view_embed: front_image_token = ( front_image_token + self.view_embed[:, :, 0:1, :] + self.position_encoding(front_image_token) ) else: front_image_token = front_image_token + self.position_encoding( front_image_token ) front_image_token = front_image_token.flatten(2).permute(2, 0, 1) front_image_token_global = ( front_image_token_global + self.view_embed[:, :, 0, :] + self.global_embed[:, :, 0:1] ) front_image_token_global = front_image_token_global.permute(2, 0, 1) features.extend([front_image_token, front_image_token_global]) if self.with_right_left_sensors: # Left view processing left_image_token, left_image_token_global = self.rgb_patch_embed(left_image) if self.use_view_embed: left_image_token = ( left_image_token + self.view_embed[:, :, 1:2, :] + self.position_encoding(left_image_token) ) else: left_image_token = left_image_token + self.position_encoding( left_image_token ) left_image_token = left_image_token.flatten(2).permute(2, 0, 1) left_image_token_global = ( left_image_token_global + self.view_embed[:, :, 1, :] + self.global_embed[:, :, 1:2] ) left_image_token_global = left_image_token_global.permute(2, 0, 1) # Right view processing right_image_token, right_image_token_global = self.rgb_patch_embed( right_image ) if self.use_view_embed: right_image_token = ( right_image_token + self.view_embed[:, :, 2:3, :] + self.position_encoding(right_image_token) ) else: right_image_token = right_image_token + self.position_encoding( right_image_token ) right_image_token = right_image_token.flatten(2).permute(2, 0, 1) right_image_token_global = ( right_image_token_global + self.view_embed[:, :, 2, :] + self.global_embed[:, :, 2:3] ) right_image_token_global = right_image_token_global.permute(2, 0, 1) features.extend( [ left_image_token, left_image_token_global, right_image_token, right_image_token_global, ] ) if self.with_center_sensor: # Front center view processing ( front_center_image_token, front_center_image_token_global, ) = self.rgb_patch_embed(front_center_image) if self.use_view_embed: front_center_image_token = ( front_center_image_token + self.view_embed[:, :, 3:4, :] + self.position_encoding(front_center_image_token) ) else: front_center_image_token = ( front_center_image_token + self.position_encoding(front_center_image_token) ) front_center_image_token = front_center_image_token.flatten(2).permute( 2, 0, 1 ) front_center_image_token_global = ( front_center_image_token_global + self.view_embed[:, :, 3, :] + self.global_embed[:, :, 3:4] ) front_center_image_token_global = front_center_image_token_global.permute( 2, 0, 1 ) features.extend([front_center_image_token, front_center_image_token_global]) if self.with_lidar: lidar_token, lidar_token_global = self.lidar_patch_embed(lidar) if self.use_view_embed: lidar_token = ( lidar_token + self.view_embed[:, :, 4:5, :] + self.position_encoding(lidar_token) ) else: lidar_token = lidar_token + self.position_encoding(lidar_token) lidar_token = lidar_token.flatten(2).permute(2, 0, 1) lidar_token_global = ( lidar_token_global + self.view_embed[:, :, 4, :] + self.global_embed[:, :, 4:5] ) lidar_token_global = lidar_token_global.permute(2, 0, 1) features.extend([lidar_token, lidar_token_global]) features = torch.cat(features, 0) return features def forward(self, x): front_image = x["rgb"] left_image = x["rgb_left"] right_image = x["rgb_right"] front_center_image = x["rgb_center"] measurements = x["measurements"] target_point = x["target_point"] lidar = x["lidar"] if self.direct_concat: img_size = front_image.shape[-1] left_image = torch.nn.functional.interpolate( left_image, size=(img_size, img_size) ) right_image = torch.nn.functional.interpolate( right_image, size=(img_size, img_size) ) front_center_image = torch.nn.functional.interpolate( front_center_image, size=(img_size, img_size) ) front_image = torch.cat( [front_image, left_image, right_image, front_center_image], dim=1 ) features = self.forward_features( front_image, left_image, right_image, front_center_image, lidar, measurements, ) bs = front_image.shape[0] if self.end2end: tgt = self.query_pos_embed.repeat(bs, 1, 1) else: tgt = self.position_encoding( torch.ones((bs, 1, 20, 20), device=x["rgb"].device) ) tgt = tgt.flatten(2) tgt = torch.cat([tgt, self.query_pos_embed.repeat(bs, 1, 1)], 2) tgt = tgt.permute(2, 0, 1) memory = self.encoder(features, mask=self.attn_mask) hs = self.decoder(self.query_embed.repeat(1, bs, 1), memory, query_pos=tgt)[0] hs = hs.permute(1, 0, 2) # Batchsize , N, C if self.end2end: waypoints = self.waypoints_generator(hs, target_point) return waypoints if self.waypoints_pred_head != "heatmap": traffic_feature = hs[:, :400] is_junction_feature = hs[:, 400] traffic_light_state_feature = hs[:, 400] stop_sign_feature = hs[:, 400] waypoints_feature = hs[:, 401:411] else: traffic_feature = hs[:, :400] is_junction_feature = hs[:, 400] traffic_light_state_feature = hs[:, 400] stop_sign_feature = hs[:, 400] waypoints_feature = hs[:, 401:405] if self.waypoints_pred_head == "heatmap": waypoints = self.waypoints_generator(waypoints_feature, measurements) elif self.waypoints_pred_head == "gru": waypoints = self.waypoints_generator(waypoints_feature, target_point) elif self.waypoints_pred_head == "gru-command": waypoints = self.waypoints_generator(waypoints_feature, target_point, measurements) elif self.waypoints_pred_head == "linear": waypoints = self.waypoints_generator(waypoints_feature, measurements) elif self.waypoints_pred_head == "linear-sum": waypoints = self.waypoints_generator(waypoints_feature, measurements) is_junction = self.junction_pred_head(is_junction_feature) traffic_light_state = self.traffic_light_pred_head(traffic_light_state_feature) stop_sign = self.stop_sign_head(stop_sign_feature) velocity = measurements[:, 6:7].unsqueeze(-1) velocity = velocity.repeat(1, 400, 32) traffic_feature_with_vel = torch.cat([traffic_feature, velocity], dim=2) traffic = self.traffic_pred_head(traffic_feature_with_vel) return traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature def load_pretrained(self, model_path, strict=False): """ تحميل الأوزان المدربة مسبقاً - نسخة محسنة Args: model_path (str): مسار ملف الأوزان strict (bool): إذا كان True، يتطلب تطابق تام للمفاتيح """ if not model_path or not Path(model_path).exists(): logging.warning(f"ملف الأوزان غير موجود: {model_path}") logging.info("سيتم استخدام أوزان عشوائية") return False try: logging.info(f"محاولة تحميل الأوزان من: {model_path}") # تحميل الملف مع معالجة أنواع مختلفة من ملفات الحفظ checkpoint = torch.load(model_path, map_location='cpu', weights_only=False) # استخراج state_dict من أنواع مختلفة من ملفات الحفظ if isinstance(checkpoint, dict): if 'model_state_dict' in checkpoint: state_dict = checkpoint['model_state_dict'] logging.info("تم العثور على 'model_state_dict' في الملف") elif 'state_dict' in checkpoint: state_dict = checkpoint['state_dict'] logging.info("تم العثور على 'state_dict' في الملف") elif 'model' in checkpoint: state_dict = checkpoint['model'] logging.info("تم العثور على 'model' في الملف") else: state_dict = checkpoint logging.info("استخدام الملف كـ state_dict مباشرة") else: state_dict = checkpoint logging.info("استخدام الملف كـ state_dict مباشرة") # تنظيف أسماء المفاتيح (إزالة 'module.' إذا كانت موجودة) clean_state_dict = OrderedDict() for k, v in state_dict.items(): # إزالة 'module.' من بداية اسم المفتاح إذا كان موجوداً clean_key = k[7:] if k.startswith('module.') else k clean_state_dict[clean_key] = v # تحميل الأوزان missing_keys, unexpected_keys = self.load_state_dict(clean_state_dict, strict=strict) # تقرير حالة التحميل if missing_keys: logging.warning(f"مفاتيح مفقودة ({len(missing_keys)}): {missing_keys[:5]}..." if len(missing_keys) > 5 else f"مفاتيح مفقودة: {missing_keys}") if unexpected_keys: logging.warning(f"مفاتيح غير متوقعة ({len(unexpected_keys)}): {unexpected_keys[:5]}..." if len(unexpected_keys) > 5 else f"مفاتيح غير متوقعة: {unexpected_keys}") if not missing_keys and not unexpected_keys: logging.info("✅ تم تحميل جميع الأوزان بنجاح تام") elif not strict: logging.info("✅ تم تحميل الأوزان بنجاح (مع تجاهل عدم التطابق)") return True except Exception as e: logging.error(f"❌ خطأ في تحميل الأوزان: {str(e)}") logging.info("سيتم استخدام أوزان عشوائية") return False # ============================================================================ # دوال مساعدة لتحميل النموذج # ============================================================================ def load_and_prepare_model(config, device): """ يقوم بإنشاء النموذج وتحميل الأوزان المدربة مسبقًا. Args: config (dict): إعدادات النموذج والمسارات device (torch.device): الجهاز المستهدف (CPU/GPU) Returns: InterfuserModel: النموذج المحمل """ try: # إنشاء النموذج model = InterfuserModel(**config.get('model_params', {})).to(device) logging.info(f"تم إنشاء النموذج على الجهاز: {device}") # تحميل الأوزان إذا كان المسار محدد checkpoint_path = config.get('paths', {}).get('pretrained_weights') if checkpoint_path: success = model.load_pretrained(checkpoint_path, strict=False) if success: logging.info("✅ تم تحميل النموذج والأوزان بنجاح") else: logging.warning("⚠️ تم إنشاء النموذج بأوزان عشوائية") else: logging.info("لم يتم تحديد مسار الأوزان، سيتم استخدام أوزان عشوائية") # وضع النموذج في وضع التقييم model.eval() return model except Exception as e: logging.error(f"خطأ في إنشاء النموذج: {str(e)}") raise def create_model_config(model_path="model/best_model.pth", **model_params): """ إنشاء إعدادات النموذج باستخدام الإعدادات الصحيحة من التدريب Args: model_path (str): مسار ملف الأوزان **model_params: معاملات النموذج الإضافية Returns: dict: إعدادات النموذج """ # الإعدادات الصحيحة من كونفيج التدريب الأصلي training_config_params = { "img_size": 224, "embed_dim": 256, # مهم: هذه القيمة من التدريب الأصلي "enc_depth": 6, "dec_depth": 6, "rgb_backbone_name": 'r50', "lidar_backbone_name": 'r18', "waypoints_pred_head": 'gru', "use_different_backbone": True, "with_lidar": False, "with_right_left_sensors": False, "with_center_sensor": False, # إعدادات إضافية من الكونفيج الأصلي "multi_view_img_size": 112, "patch_size": 8, "in_chans": 3, "dim_feedforward": 2048, "normalize_before": False, "num_heads": 8, "dropout": 0.1, "end2end": False, "direct_concat": False, "separate_view_attention": False, "separate_all_attention": False, "freeze_num": -1, "traffic_pred_head_type": "det", "reverse_pos": True, "use_view_embed": False, "use_mmad_pretrain": None, } # دمج المعاملات المخصصة مع الإعدادات من التدريب training_config_params.update(model_params) config = { 'model_params': training_config_params, 'paths': { 'pretrained_weights': model_path }, # إضافة إعدادات الشبكة من التدريب 'grid_conf': { 'h': 20, 'w': 20, 'x_res': 1.0, 'y_res': 1.0, 'y_min': 0.0, 'y_max': 20.0, 'x_min': -10.0, 'x_max': 10.0, }, # معلومات إضافية عن التدريب 'training_info': { 'original_project': 'Interfuser_Finetuning', 'run_name': 'Finetune_Focus_on_Detection_v5', 'focus': 'traffic_detection_and_iou', 'backbone': 'ResNet50 + ResNet18', 'trained_on': 'PDM_Lite_Carla' } } return config def get_training_config(): """ إرجاع إعدادات التدريب الأصلية للمرجع هذه الإعدادات توضح كيف تم تدريب النموذج """ return { 'project_info': { 'project': 'Interfuser_Finetuning', 'entity': None, 'run_name': 'Finetune_Focus_on_Detection_v5' }, 'training': { 'epochs': 50, 'batch_size': 8, 'num_workers': 2, 'learning_rate': 1e-4, # معدل تعلم منخفض للـ Fine-tuning 'weight_decay': 1e-2, 'patience': 15, 'clip_grad_norm': 1.0, }, 'loss_weights': { 'iou': 2.0, # أولوية قصوى لدقة الصناديق 'traffic_map': 25.0, # تركيز عالي على اكتشاف الكائنات 'waypoints': 1.0, # مرجع أساسي 'junction': 0.25, # مهام متقنة بالفعل 'traffic_light': 0.5, 'stop_sign': 0.25, }, 'data_split': { 'strategy': 'interleaved', 'segment_length': 100, 'validation_frequency': 10, }, 'transforms': { 'use_data_augmentation': False, # معطل للتركيز على البيانات الأصلية } }