# model_definition.py
# ============================================================================
# Core imports
# ============================================================================
import copy
import json
import logging
import math
import os
from collections import OrderedDict, deque
from functools import partial
from pathlib import Path
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
from torch.nn import MultiheadAttention
from torch.nn import TransformerDecoder, TransformerDecoderLayer
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

# Data-processing libraries
import numpy as np
import cv2
from PIL import Image

from timm.models.resnet import resnet18d, resnet26d, resnet50d

try:
    from timm.layers import trunc_normal_
except ImportError:  # older timm releases expose it under timm.models.layers
    from timm.models.layers import trunc_normal_

# Optional libraries (the module keeps working when they are missing)
try:
    import wandb

    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False

try:
    from tqdm import tqdm
except ImportError:

    def tqdm(iterable=None, *args, **kwargs):
        """Fallback for ``tqdm.tqdm`` when tqdm is not installed.

        Returns *iterable* unchanged (no progress bar). ``iterable`` is optional
        so keyword-only calls such as ``tqdm(total=n)`` do not raise TypeError.
        """
        return iterable


# ============================================================================
# Helper functions
# ============================================================================


def to_2tuple(x):
    """Return *x* as a tuple: lists/tuples are converted, anything else is duplicated.

    >>> to_2tuple(3)
    (3, 3)
    >>> to_2tuple([4, 5])
    (4, 5)
    """
    if isinstance(x, (list, tuple)):
        return tuple(x)
    return (x, x)
# ============================================================================
# Model building blocks
# ============================================================================


class HybridEmbed(nn.Module):
    """Embed an image with a CNN backbone and project it to ``embed_dim`` channels.

    The backbone's last feature map is projected with a 1x1 convolution.
    ``forward`` returns ``(feature_map, global_token)`` where ``global_token``
    is the spatial mean of the projected map with shape ``(B, embed_dim, 1)``.

    Args:
        backbone: CNN module; if it returns a list/tuple of feature maps the
            last one is used.
        img_size: Input size (int or pair) used to probe the backbone when
            ``feature_size`` is not given.
        patch_size: Stored on the module; not used by the projection.
        feature_size: Backbone output spatial size; inferred with a dummy
            forward pass when ``None``.
        in_chans: Channel count of the dummy probe image.
        embed_dim: Output channel count of the 1x1 projection.

    Raises:
        TypeError: If ``backbone`` is not an ``nn.Module``.
    """

    def __init__(
        self,
        backbone,
        img_size=224,
        patch_size=1,
        feature_size=None,
        in_chans=3,
        embed_dim=768,
    ):
        super().__init__()
        if not isinstance(backbone, nn.Module):
            raise TypeError("backbone must be an nn.Module")
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.backbone = backbone
        if feature_size is None:
            with torch.no_grad():
                # Probe in eval mode (restored afterwards) so the dummy pass is
                # run with inference-mode layer behaviour.
                training = backbone.training
                if training:
                    backbone.eval()
                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
                if isinstance(o, (list, tuple)):
                    o = o[-1]  # last feature if backbone outputs list/tuple of features
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(training)
        else:
            feature_size = to_2tuple(feature_size)
            if hasattr(self.backbone, "feature_info"):
                feature_dim = self.backbone.feature_info.channels()[-1]
            else:
                feature_dim = self.backbone.num_features
        self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=1, stride=1)

    def forward(self, x):
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features
        x = self.proj(x)
        global_x = torch.mean(x, [2, 3], keepdim=False)[:, :, None]
        return x, global_x


class PositionEmbeddingSine(nn.Module):
    """2-D sine/cosine positional encoding.

    This follows the sinusoidal encoding of "Attention is all you need",
    generalized to images. Only the shape and device of the input are used;
    its values are ignored. For an input of shape ``(B, *, H, W)`` the output
    has shape ``(B, 2 * num_pos_feats, H, W)``: y-encodings first, then
    x-encodings.

    Raises:
        ValueError: If ``scale`` is given while ``normalize`` is False.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor):
        x = tensor
        bs, _, h, w = x.shape
        not_mask = torch.ones((bs, h, w), device=x.device)
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos


class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` deep copies of ``encoder_layer`` with an optional final norm."""

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        output = src
        for layer in self.layers:
            output = layer(
                output,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                pos=pos,
            )
        if self.norm is not None:
            output = self.norm(output)
        return output


class SpatialSoftmax(nn.Module):
    """Turn per-channel heat maps into expected (x, y) keypoints.

    Each channel is softmax-normalised over its ``height * width`` cells. The
    expectation of a fixed [-1, 1] coordinate grid is taken, then rescaled:
    ``x * 12`` and ``(y - 1) * 12``. The units of that scaling are not
    established in this file.

    Args:
        height, width: Spatial size of the incoming maps.
        channel: Number of maps per sample.
        temperature: If truthy, a learnable softmax temperature initialised to
            this value; otherwise a fixed temperature of 1.0.
        data_format: ``"NCHW"`` (default) or ``"NHWC"``.

    Returns (from ``forward``):
        Tensor of shape ``(N, channel, 2)`` holding ``(x, y)`` per channel.
    """

    def __init__(self, height, width, channel, temperature=None, data_format="NCHW"):
        super().__init__()
        self.data_format = data_format
        self.height = height
        self.width = width
        self.channel = channel

        if temperature:
            # ``Parameter`` was previously unresolved here (NameError).
            self.temperature = nn.Parameter(torch.ones(1) * temperature)
        else:
            self.temperature = 1.0

        pos_x, pos_y = np.meshgrid(
            np.linspace(-1.0, 1.0, self.height), np.linspace(-1.0, 1.0, self.width)
        )
        pos_x = torch.from_numpy(pos_x.reshape(self.height * self.width)).float()
        pos_y = torch.from_numpy(pos_y.reshape(self.height * self.width)).float()
        self.register_buffer("pos_x", pos_x)
        self.register_buffer("pos_y", pos_y)

    def forward(self, feature):
        if self.data_format == "NHWC":
            # NHWC -> NCHW. Use reshape because the transposed tensor is not
            # contiguous, so .view would fail.
            feature = (
                feature.transpose(1, 3)
                .transpose(2, 3)
                .reshape(-1, self.height * self.width)
            )
        else:
            feature = feature.view(-1, self.height * self.width)

        weight = F.softmax(feature / self.temperature, dim=-1)
        expected_x = torch.sum(self.pos_x * weight, dim=1, keepdim=True)
        expected_y = torch.sum(self.pos_y * weight, dim=1, keepdim=True)
        keypoints = torch.cat([expected_x, expected_y], 1).view(-1, self.channel, 2)

        # Build the rescaled output out-of-place instead of writing into a
        # tensor that is part of the autograd graph.
        scaled_x = keypoints[:, :, 0] * 12
        scaled_y = (keypoints[:, :, 1] - 1) * 12
        return torch.stack([scaled_x, scaled_y], dim=-1)


class MultiPath_Generator(nn.Module):
    """Heat-map waypoint head with one output branch per command.

    Decoder tokens plus a velocity channel block are upsampled by transposed
    convolutions to 100x100 maps. Six command-specific branches produce maps
    that are selected with ``measurements[:, :6]`` and reduced to keypoints by
    ``SpatialSoftmax``. ``embed_dim`` is accepted but unused.
    """

    def __init__(self, in_channel, embed_dim, out_channel):
        super().__init__()
        self.spatial_softmax = SpatialSoftmax(100, 100, out_channel)
        self.tconv0 = nn.Sequential(
            nn.ConvTranspose2d(in_channel, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
        )
        self.tconv1 = nn.Sequential(
            nn.ConvTranspose2d(256, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
        )
        self.tconv2 = nn.Sequential(
            nn.ConvTranspose2d(256, 192, 4, 2, 1, bias=False),
            nn.BatchNorm2d(192),
            nn.ReLU(True),
        )
        self.tconv3 = nn.Sequential(
            nn.ConvTranspose2d(192, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
        )
        self.tconv4_list = torch.nn.ModuleList(
            [
                nn.Sequential(
                    nn.ConvTranspose2d(64, out_channel, 8, 2, 3, bias=False),
                    nn.Tanh(),
                )
                for _ in range(6)
            ]
        )

        self.upsample = nn.Upsample(size=(50, 50), mode="bilinear")

    def forward(self, x, measurements):
        # measurements[:, :6] selects the branch; measurements[:, 6] is the velocity.
        mask = measurements[:, :6]
        mask = mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 1, 100, 100)
        velocity = measurements[:, 6:7].unsqueeze(-1).unsqueeze(-1)
        velocity = velocity.repeat(1, 32, 2, 2)

        n, d, c = x.shape
        x = x.transpose(1, 2)
        x = x.view(n, -1, 2, 2)  # 4 tokens -> a 2x2 spatial grid
        x = torch.cat([x, velocity], dim=1)
        x = self.tconv0(x)
        x = self.tconv1(x)
        x = self.tconv2(x)
        x = self.tconv3(x)
        x = self.upsample(x)
        xs = [branch(x) for branch in self.tconv4_list]
        xs = torch.stack(xs, dim=1)
        x = torch.sum(xs * mask, dim=1)
        x = self.spatial_softmax(x)
        return x


class LinearWaypointsPredictor(nn.Module):
    """Per-token MLP waypoint head with six command-specific branches.

    Expects ``x`` of shape ``(bs, 10, input_dim)``; ``rank_embed`` fixes the
    token count at 10. ``measurements[:, :6]`` weights the branch outputs.
    Returns ``(bs, 10, 2)``; these are cumulative sums of the per-step
    offsets when ``cumsum`` is True.
    """

    def __init__(self, input_dim, cumsum=True):
        super().__init__()
        self.cumsum = cumsum
        self.rank_embed = nn.Parameter(torch.zeros(1, 10, input_dim))
        self.head_fc1_list = nn.ModuleList([nn.Linear(input_dim, 64) for _ in range(6)])
        self.head_relu = nn.ReLU(inplace=True)
        self.head_fc2_list = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)])

    def forward(self, x, measurements):
        bs, n, dim = x.shape
        x = x + self.rank_embed
        x = x.reshape(-1, dim)  # row index == b * n + i (batch-major)

        # Row b * n + i must use sample b's command weights, so repeat each
        # sample's mask n times consecutively. The previous
        # ``.repeat(n, 1, 2)`` tiled the whole batch instead and paired rows
        # with the wrong samples whenever bs > 1.
        mask = measurements[:, :6].unsqueeze(-1).repeat_interleave(n, dim=0).repeat(1, 1, 2)

        rs = []
        for fc1, fc2 in zip(self.head_fc1_list, self.head_fc2_list):
            rs.append(fc2(self.head_relu(fc1(x))))
        rs = torch.stack(rs, 1)
        x = torch.sum(rs * mask, dim=1)

        x = x.view(bs, n, 2)
        if self.cumsum:
            x = torch.cumsum(x, 1)
        return x


class GRUWaypointsPredictor(nn.Module):
    """GRU waypoint head seeded by the target point.

    ``x`` must be ``(bs, waypoints, input_dim)``; the GRU's initial hidden
    state is a linear encoding of ``target_point`` (``(bs, 2)``). Returns
    cumulative waypoints of shape ``(bs, waypoints, 2)``.
    """

    def __init__(self, input_dim, waypoints=10):
        super().__init__()
        self.gru = torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True)
        self.encoder = nn.Linear(2, 64)
        self.decoder = nn.Linear(64, 2)
        self.waypoints = waypoints

    def forward(self, x, target_point):
        bs = x.shape[0]
        z = self.encoder(target_point).unsqueeze(0)
        output, _ = self.gru(x, z)
        output = output.reshape(bs * self.waypoints, -1)
        output = self.decoder(output).reshape(bs, self.waypoints, 2)
        output = torch.cumsum(output, 1)
        return output


class GRUWaypointsPredictorWithCommand(nn.Module):
    """Six command-specific GRU waypoint heads selected by ``measurements[:, :6]``.

    Same inputs and outputs as ``GRUWaypointsPredictor`` plus ``measurements``.
    """

    def __init__(self, input_dim, waypoints=10):
        super().__init__()
        self.grus = nn.ModuleList(
            [
                torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True)
                for _ in range(6)
            ]
        )
        self.encoder = nn.Linear(2, 64)
        self.decoders = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)])
        self.waypoints = waypoints

    def forward(self, x, target_point, measurements):
        bs = x.shape[0]
        mask = measurements[:, :6, None, None]
        mask = mask.repeat(1, 1, self.waypoints, 2)

        z = self.encoder(target_point).unsqueeze(0)
        outputs = []
        for gru, decoder in zip(self.grus, self.decoders):
            output, _ = gru(x, z)
            output = output.reshape(bs * self.waypoints, -1)
            output = decoder(output).reshape(bs, self.waypoints, 2)
            outputs.append(torch.cumsum(output, 1))
        outputs = torch.stack(outputs, 1)
        return torch.sum(outputs * mask, dim=1)


class TransformerDecoder(nn.Module):
    """Stack of ``num_layers`` decoder layers.

    Returns a tensor with a leading layer axis. The axis has length 1 unless
    ``return_intermediate`` is True, in which case it holds every layer's
    (normalised, if ``norm`` is set) output.
    """

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        output = tgt

        intermediate = []

        for layer in self.layers:
            output = layer(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos=pos,
                query_pos=query_pos,
            )
            if self.return_intermediate:
                # ``norm`` is optional; previously a None norm was called here.
                intermediate.append(self.norm(output) if self.norm is not None else output)

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)


class TransformerEncoderLayer(nn.Module):
    """DETR-style encoder layer with positional embeddings added to queries/keys.

    ``activation`` may be a module class (instantiated here, e.g. ``nn.GELU``)
    or a ready callable/module instance. The former default ``nn.ReLU()`` was
    an instance that got called with no input and raised TypeError.
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation=nn.ReLU,
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = activation() if isinstance(activation, type) else activation
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(
            q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(
            q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


class TransformerDecoderLayer(nn.Module):
    """DETR-style decoder layer (self-attention, cross-attention, feed-forward).

    ``activation`` may be a module class (instantiated here) or a ready
    callable/module instance.
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation=nn.ReLU,
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = activation() if isinstance(activation, type) else activation
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(
            q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(
            query=self.with_pos_embed(tgt, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(
            q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(
            query=self.with_pos_embed(tgt2, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        if self.normalize_before:
            return self.forward_pre(
                tgt,
                memory,
                tgt_mask,
                memory_mask,
                tgt_key_padding_mask,
                memory_key_padding_mask,
                pos,
                query_pos,
            )
        return self.forward_post(
            tgt,
            memory,
            tgt_mask,
            memory_mask,
            tgt_key_padding_mask,
            memory_key_padding_mask,
            pos,
            query_pos,
        )


def _get_clones(module, N):
    """Return an ``nn.ModuleList`` of ``N`` independent deep copies of ``module``."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")


def build_attn_mask(mask_type, device=None):
    """Build a 151x151 boolean self-attention mask for the encoder.

    ``True`` entries are positions that may NOT be attended to, following the
    boolean ``attn_mask`` convention of ``nn.MultiheadAttention``. Tokens are
    split into the index groups [0:50], [50:67], [67:84], [84:101] and
    [101:151]. Presumably these are the front, left, right and center camera
    and LiDAR tokens in ``forward_features`` order; confirm for your input
    sizes.

    Args:
        mask_type: ``"seperate_all"`` lets each group attend only to itself.
            ``"seperate_view"`` does the same, except that the last group
            attends to, and is attended by, every token.
        device: Target device. The default ``None`` builds on the default
            (CPU) device instead of the former hard-coded ``.cuda()``, which
            failed on CPU-only machines.

    Raises:
        ValueError: For an unknown ``mask_type``. Previously this returned a
            fully masked matrix.
    """
    mask = torch.ones((151, 151), dtype=torch.bool, device=device)
    groups = [(0, 50), (50, 67), (67, 84), (84, 101)]
    if mask_type == "seperate_all":
        for start, end in groups + [(101, 151)]:
            mask[start:end, start:end] = False
    elif mask_type == "seperate_view":
        for start, end in groups:
            mask[start:end, start:end] = False
        mask[101:151, :] = False
        mask[:, 101:151] = False
    else:
        raise ValueError(f"Unknown attention mask type: {mask_type!r}")
    return mask


def _create_resnet_backbone(name, pretrained, in_chans):
    """Build a timm ResNet-D feature extractor that returns only its last stage.

    Args:
        name: One of ``"r18"``, ``"r26"``, ``"r50"`` or ``"r101"``.
        pretrained: Whether timm should load its pretrained weights.
        in_chans: Channel count of the first convolution.

    Raises:
        ValueError: If ``name`` is not a supported backbone.
    """
    if name == "r101":
        # Imported lazily: the file-level timm import only provides r18/r26/r50,
        # and ``resnet101d`` was previously referenced without any import.
        from timm.models.resnet import resnet101d as factory
    else:
        factories = {"r50": resnet50d, "r26": resnet26d, "r18": resnet18d}
        try:
            factory = factories[name]
        except KeyError:
            raise ValueError(f"Unsupported backbone name: {name!r}") from None
    return factory(
        pretrained=pretrained, in_chans=in_chans, features_only=True, out_indices=[4]
    )


_WAYPOINT_HEADS = ("heatmap", "gru", "gru-command", "linear", "linear-sum")


class InterfuserModel(nn.Module):
    """Multi-view transformer for traffic perception and waypoint prediction.

    Camera images (and optionally LiDAR) are embedded by ResNet backbones
    (``HybridEmbed``) and fused by a transformer encoder. A transformer
    decoder then produces query tokens consumed by several heads:

    * tokens ``[0:400]`` form a 20x20 traffic map (``traffic_pred_head``);
    * token ``400`` feeds the junction, traffic-light and stop-sign heads;
    * the following tokens feed the waypoint head (4 for ``"heatmap"``,
      10 otherwise). With ``end2end`` only 4 GRU waypoint queries are used.

    ``multi_view_img_size``, ``norm_layer``, ``weight_init``, ``freeze_num``
    and ``reverse_pos`` are accepted but not otherwise used in this file.

    Raises:
        ValueError: For an unsupported backbone name, ``waypoints_pred_head``
            or ``traffic_pred_head_type``.
    """

    def __init__(
        self,
        img_size=224,
        multi_view_img_size=112,
        patch_size=8,
        in_chans=3,
        embed_dim=768,
        enc_depth=6,
        dec_depth=6,
        dim_feedforward=2048,
        normalize_before=False,
        rgb_backbone_name="r50",
        lidar_backbone_name="r50",
        num_heads=8,
        norm_layer=None,
        dropout=0.1,
        end2end=False,
        direct_concat=False,
        separate_view_attention=False,
        separate_all_attention=False,
        act_layer=None,
        weight_init="",
        freeze_num=-1,
        with_lidar=False,
        with_right_left_sensors=False,
        with_center_sensor=False,
        traffic_pred_head_type="det",
        waypoints_pred_head="heatmap",
        reverse_pos=True,
        use_different_backbone=False,
        use_view_embed=False,
        use_mmad_pretrain=None,
    ):
        super().__init__()
        self.traffic_pred_head_type = traffic_pred_head_type
        # num_features for consistency with other models
        self.num_features = self.embed_dim = embed_dim
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.reverse_pos = reverse_pos
        self.waypoints_pred_head = waypoints_pred_head
        self.with_lidar = with_lidar
        self.with_right_left_sensors = with_right_left_sensors
        self.with_center_sensor = with_center_sensor

        self.direct_concat = direct_concat
        self.separate_view_attention = separate_view_attention
        self.separate_all_attention = separate_all_attention
        self.end2end = end2end
        self.use_view_embed = use_view_embed

        if self.direct_concat:
            # forward() stacks the four camera views channel-wise into the front image.
            in_chans = in_chans * 4
            self.with_center_sensor = False
            self.with_right_left_sensors = False

        if self.separate_view_attention:
            attn_mask = build_attn_mask("seperate_view")
        elif self.separate_all_attention:
            attn_mask = build_attn_mask("seperate_all")
        else:
            attn_mask = None
        # Non-persistent buffer: moves with .to(device) but stays out of
        # state_dict, so checkpoints saved without it still load.
        self.register_buffer("attn_mask", attn_mask, persistent=False)

        if not self.end2end and waypoints_pred_head not in _WAYPOINT_HEADS:
            raise ValueError(f"Unsupported waypoints_pred_head: {waypoints_pred_head!r}")
        if traffic_pred_head_type not in ("det", "seg"):
            raise ValueError(f"Unsupported traffic_pred_head_type: {traffic_pred_head_type!r}")

        # The RGB backbone takes ``in_chans`` channels to match the RGB patch
        # embedding's probe (previously it was hard-coded to 3 in the shared
        # branch, which crashed with direct_concat).
        self.rgb_backbone = _create_resnet_backbone(
            rgb_backbone_name, pretrained=True, in_chans=in_chans
        )
        if use_different_backbone:
            # The LiDAR embedding probes with 3 channels, so its backbone must
            # also take 3.
            self.lidar_backbone = _create_resnet_backbone(
                lidar_backbone_name, pretrained=False, in_chans=3
            )
            if use_mmad_pretrain:
                # NOTE(review): torch.load unpickles arbitrary objects; only
                # point this at trusted checkpoint files.
                params = torch.load(use_mmad_pretrain, map_location="cpu")["state_dict"]
                updated_params = OrderedDict(
                    (key.replace("backbone.", ""), value)
                    for key, value in params.items()
                    if "backbone" in key
                )
                self.rgb_backbone.load_state_dict(updated_params)
            lidar_backbone, lidar_in_chans = self.lidar_backbone, 3
        else:
            # LiDAR reuses (and therefore shares weights with) the RGB backbone.
            lidar_backbone, lidar_in_chans = self.rgb_backbone, in_chans

        self.rgb_patch_embed = HybridEmbed(
            backbone=self.rgb_backbone,
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        self.lidar_patch_embed = HybridEmbed(
            backbone=lidar_backbone,
            img_size=img_size,
            patch_size=patch_size,
            in_chans=lidar_in_chans,
            embed_dim=embed_dim,
        )

        self.global_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
        self.view_embed = nn.Parameter(torch.zeros(1, embed_dim, 5, 1))

        if self.end2end:
            self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 4))
            self.query_embed = nn.Parameter(torch.zeros(4, 1, embed_dim))
        elif self.waypoints_pred_head == "heatmap":
            self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
            self.query_embed = nn.Parameter(torch.zeros(400 + 5, 1, embed_dim))
        else:
            self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 11))
            self.query_embed = nn.Parameter(torch.zeros(400 + 11, 1, embed_dim))

        if self.end2end:
            self.waypoints_generator = GRUWaypointsPredictor(embed_dim, 4)
        elif self.waypoints_pred_head == "heatmap":
            self.waypoints_generator = MultiPath_Generator(embed_dim + 32, embed_dim, 10)
        elif self.waypoints_pred_head == "gru":
            self.waypoints_generator = GRUWaypointsPredictor(embed_dim)
        elif self.waypoints_pred_head == "gru-command":
            self.waypoints_generator = GRUWaypointsPredictorWithCommand(embed_dim)
        elif self.waypoints_pred_head == "linear":
            self.waypoints_generator = LinearWaypointsPredictor(embed_dim)
        elif self.waypoints_pred_head == "linear-sum":
            self.waypoints_generator = LinearWaypointsPredictor(embed_dim, cumsum=True)

        self.junction_pred_head = nn.Linear(embed_dim, 2)
        self.traffic_light_pred_head = nn.Linear(embed_dim, 2)
        self.stop_sign_head = nn.Linear(embed_dim, 2)

        if self.traffic_pred_head_type == "det":
            # Input: decoder feature + 32 velocity channels (see forward).
            self.traffic_pred_head = nn.Sequential(
                nn.Linear(embed_dim + 32, 64),
                nn.ReLU(),
                nn.Linear(64, 7),
            )
        else:  # "seg"
            self.traffic_pred_head = nn.Sequential(
                nn.Linear(embed_dim, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid()
            )

        self.position_encoding = PositionEmbeddingSine(embed_dim // 2, normalize=True)

        encoder_layer = TransformerEncoderLayer(
            embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before
        )
        self.encoder = TransformerEncoder(encoder_layer, enc_depth, None)

        decoder_layer = TransformerDecoderLayer(
            embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before
        )
        decoder_norm = nn.LayerNorm(embed_dim)
        self.decoder = TransformerDecoder(
            decoder_layer, dec_depth, decoder_norm, return_intermediate=False
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the learned embeddings and queries uniformly in [0, 1)."""
        nn.init.uniform_(self.global_embed)
        nn.init.uniform_(self.view_embed)
        nn.init.uniform_(self.query_embed)
        nn.init.uniform_(self.query_pos_embed)

    def _embed_view(self, patch_embed, image, view_idx):
        """Embed one sensor view into encoder tokens.

        Returns ``(tokens, global_token)`` shaped ``(H*W, B, C)`` and
        ``(1, B, C)``. ``view_idx`` selects the view/global embeddings:
        0 = front, 1 = left, 2 = right, 3 = center, 4 = LiDAR.
        """
        token, token_global = patch_embed(image)
        if self.use_view_embed:
            token = (
                token
                + self.view_embed[:, :, view_idx : view_idx + 1, :]
                + self.position_encoding(token)
            )
        else:
            token = token + self.position_encoding(token)
        token = token.flatten(2).permute(2, 0, 1)
        token_global = (
            token_global
            + self.view_embed[:, :, view_idx, :]
            + self.global_embed[:, :, view_idx : view_idx + 1]
        )
        token_global = token_global.permute(2, 0, 1)
        return token, token_global

    def forward_features(
        self,
        front_image,
        left_image,
        right_image,
        front_center_image,
        lidar,
        measurements,
    ):
        """Concatenate the enabled views' tokens along the sequence axis.

        Order: front, [left, right], [center], [lidar]; each view contributes
        its spatial tokens followed by its global token. ``measurements`` is
        accepted but unused here.
        """
        features = list(self._embed_view(self.rgb_patch_embed, front_image, 0))

        if self.with_right_left_sensors:
            features.extend(self._embed_view(self.rgb_patch_embed, left_image, 1))
            features.extend(self._embed_view(self.rgb_patch_embed, right_image, 2))

        if self.with_center_sensor:
            features.extend(self._embed_view(self.rgb_patch_embed, front_center_image, 3))

        if self.with_lidar:
            features.extend(self._embed_view(self.lidar_patch_embed, lidar, 4))

        return torch.cat(features, 0)

    def forward(self, x):
        """Run the full model on a batch dict.

        ``x`` must contain ``"rgb"``, ``"rgb_left"``, ``"rgb_right"``,
        ``"rgb_center"``, ``"measurements"``, ``"target_point"`` and
        ``"lidar"``. ``measurements[:, :6]`` is used as a command mask and
        ``measurements[:, 6]`` as velocity; presumably the mask is a one-hot
        command, so confirm against the data pipeline.

        Returns:
            ``waypoints`` alone when ``end2end``; otherwise
            ``(traffic, waypoints, is_junction, traffic_light_state,
            stop_sign, traffic_feature)``.
        """
        front_image = x["rgb"]
        left_image = x["rgb_left"]
        right_image = x["rgb_right"]
        front_center_image = x["rgb_center"]
        measurements = x["measurements"]
        target_point = x["target_point"]
        lidar = x["lidar"]

        if self.direct_concat:
            img_size = front_image.shape[-1]
            left_image = F.interpolate(left_image, size=(img_size, img_size))
            right_image = F.interpolate(right_image, size=(img_size, img_size))
            front_center_image = F.interpolate(
                front_center_image, size=(img_size, img_size)
            )
            front_image = torch.cat(
                [front_image, left_image, right_image, front_center_image], dim=1
            )
        features = self.forward_features(
            front_image,
            left_image,
            right_image,
            front_center_image,
            lidar,
            measurements,
        )

        bs = front_image.shape[0]

        if self.end2end:
            tgt = self.query_pos_embed.repeat(bs, 1, 1)
        else:
            # 20x20 grid of positional encodings for the traffic-map queries,
            # followed by the learned positions of the remaining queries.
            tgt = self.position_encoding(
                torch.ones((bs, 1, 20, 20), device=x["rgb"].device)
            )
            tgt = tgt.flatten(2)
            tgt = torch.cat([tgt, self.query_pos_embed.repeat(bs, 1, 1)], 2)
        tgt = tgt.permute(2, 0, 1)

        memory = self.encoder(features, mask=self.attn_mask)
        hs = self.decoder(self.query_embed.repeat(1, bs, 1), memory, query_pos=tgt)[0]

        hs = hs.permute(1, 0, 2)  # Batchsize , N, C
        if self.end2end:
            return self.waypoints_generator(hs, target_point)

        traffic_feature = hs[:, :400]
        # Junction, traffic-light and stop-sign heads all read query token 400.
        is_junction_feature = hs[:, 400]
        traffic_light_state_feature = hs[:, 400]
        stop_sign_feature = hs[:, 400]
        if self.waypoints_pred_head == "heatmap":
            waypoints_feature = hs[:, 401:405]
        else:
            waypoints_feature = hs[:, 401:411]

        if self.waypoints_pred_head == "gru":
            waypoints = self.waypoints_generator(waypoints_feature, target_point)
        elif self.waypoints_pred_head == "gru-command":
            waypoints = self.waypoints_generator(
                waypoints_feature, target_point, measurements
            )
        else:  # "heatmap", "linear", "linear-sum" (validated in __init__)
            waypoints = self.waypoints_generator(waypoints_feature, measurements)

        is_junction = self.junction_pred_head(is_junction_feature)
        traffic_light_state = self.traffic_light_pred_head(traffic_light_state_feature)
        stop_sign = self.stop_sign_head(stop_sign_feature)

        if self.traffic_pred_head_type == "det":
            # The detection head expects measurements[:, 6] broadcast to 32
            # extra channels; the "seg" head takes the raw feature (it was
            # previously fed embed_dim + 32 channels and crashed).
            velocity = measurements[:, 6:7].unsqueeze(-1).repeat(1, 400, 32)
            traffic_input = torch.cat([traffic_feature, velocity], dim=2)
        else:
            traffic_input = traffic_feature
        traffic = self.traffic_pred_head(traffic_input)
        return traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature

    def load_pretrained(self, model_path, strict=False):
        """Load trained weights into this model (best effort).

        Accepts a raw state dict, a dict holding it under
        ``"model_state_dict"`` / ``"state_dict"`` / ``"model"``, or a pickled
        ``nn.Module``. A leading ``"module."`` prefix (from DataParallel-style
        wrappers) is stripped from every key.

        Args:
            model_path (str): Path to the weights file.
            strict (bool): Passed to ``load_state_dict``; if True, keys must
                match exactly.

        Returns:
            bool: True if weights were loaded. False if the file is missing or
            loading failed; the error is logged and the model keeps its
            current weights.
        """
        if not model_path or not Path(model_path).exists():
            logging.warning(f"ملف الأوزان غير موجود: {model_path}")
            logging.info("سيتم استخدام أوزان عشوائية")
            return False

        try:
            logging.info(f"محاولة تحميل الأوزان من: {model_path}")

            # NOTE(review): weights_only=False unpickles arbitrary objects;
            # only load checkpoints from trusted sources.
            checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)

            # Extract the state_dict from the different checkpoint layouts.
            if isinstance(checkpoint, dict):
                if 'model_state_dict' in checkpoint:
                    state_dict = checkpoint['model_state_dict']
                    logging.info("تم العثور على 'model_state_dict' في الملف")
                elif 'state_dict' in checkpoint:
                    state_dict = checkpoint['state_dict']
                    logging.info("تم العثور على 'state_dict' في الملف")
                elif 'model' in checkpoint:
                    state_dict = checkpoint['model']
                    logging.info("تم العثور على 'model' في الملف")
                else:
                    state_dict = checkpoint
                    logging.info("استخدام الملف كـ state_dict مباشرة")
            elif isinstance(checkpoint, nn.Module):
                # A whole pickled model rather than a state dict.
                state_dict = checkpoint.state_dict()
            else:
                state_dict = checkpoint
                logging.info("استخدام الملف كـ state_dict مباشرة")

            # Strip a leading 'module.' prefix from every key.
            clean_state_dict = OrderedDict(
                (k[7:] if k.startswith('module.') else k, v) for k, v in state_dict.items()
            )

            missing_keys, unexpected_keys = self.load_state_dict(clean_state_dict, strict=strict)

            if missing_keys:
                logging.warning(f"مفاتيح مفقودة ({len(missing_keys)}): {missing_keys[:5]}..." if len(missing_keys) > 5 else f"مفاتيح مفقودة: {missing_keys}")

            if unexpected_keys:
                logging.warning(f"مفاتيح غير متوقعة ({len(unexpected_keys)}): {unexpected_keys[:5]}..." if len(unexpected_keys) > 5 else f"مفاتيح غير متوقعة: {unexpected_keys}")

            if not missing_keys and not unexpected_keys:
                logging.info("✅ تم تحميل جميع الأوزان بنجاح تام")
            elif not strict:
                logging.info("✅ تم تحميل الأوزان بنجاح (مع تجاهل عدم التطابق)")

            return True

        except Exception as e:
            # Deliberate best effort: report and fall back to current weights.
            logging.error(f"❌ خطأ في تحميل الأوزان: {str(e)}")
            logging.info("سيتم استخدام أوزان عشوائية")
            return False


# ============================================================================
# Model loading helpers
# ============================================================================


def load_and_prepare_model(config, device):
    """Create an ``InterfuserModel``, optionally load weights, and switch to eval mode.

    Args:
        config (dict): ``config['model_params']`` holds constructor kwargs and
            ``config['paths']['pretrained_weights']`` an optional weights path.
        device (torch.device): Target device (CPU/GPU).

    Returns:
        InterfuserModel: The model on ``device`` in eval mode.

    Raises:
        Exception: Any error raised while constructing the model is logged
            and re-raised. Weight-loading failures are only logged.
    """
    try:
        model = InterfuserModel(**config.get('model_params', {})).to(device)
        logging.info(f"تم إنشاء النموذج على الجهاز: {device}")

        checkpoint_path = config.get('paths', {}).get('pretrained_weights')
        if checkpoint_path:
            if model.load_pretrained(checkpoint_path, strict=False):
                logging.info("✅ تم تحميل النموذج والأوزان بنجاح")
            else:
                logging.warning("⚠️ تم إنشاء النموذج بأوزان عشوائية")
        else:
            logging.info("لم يتم تحديد مسار الأوزان، سيتم استخدام أوزان عشوائية")

        model.eval()
        return model

    except Exception as e:
        logging.error(f"خطأ في إنشاء النموذج: {str(e)}")
        raise


def create_model_config(model_path="model/best_model.pth", **model_params):
    """Build a config dict matching the settings the model was trained with.

    Args:
        model_path (str): Path to the weights file.
        **model_params: Constructor overrides applied on top of the training
            defaults.

    Returns:
        dict: Keys ``'model_params'``, ``'paths'``, ``'grid_conf'`` and
        ``'training_info'``.
    """
    # Settings taken from the original training configuration.
    training_config_params = {
        "img_size": 224,
        "embed_dim": 256,  # important: this value comes from the original training run
        "enc_depth": 6,
        "dec_depth": 6,
        "rgb_backbone_name": 'r50',
        "lidar_backbone_name": 'r18',
        "waypoints_pred_head": 'gru',
        "use_different_backbone": True,
        "with_lidar": False,
        "with_right_left_sensors": False,
        "with_center_sensor": False,

        # Additional settings from the original config
        "multi_view_img_size": 112,
        "patch_size": 8,
        "in_chans": 3,
        "dim_feedforward": 2048,
        "normalize_before": False,
        "num_heads": 8,
        "dropout": 0.1,
        "end2end": False,
        "direct_concat": False,
        "separate_view_attention": False,
        "separate_all_attention": False,
        "freeze_num": -1,
        "traffic_pred_head_type": "det",
        "reverse_pos": True,
        "use_view_embed": False,
        "use_mmad_pretrain": None,
    }

    # Caller overrides win over the training defaults.
    training_config_params.update(model_params)

    return {
        'model_params': training_config_params,
        'paths': {
            'pretrained_weights': model_path
        },
        # Grid settings used during training
        'grid_conf': {
            'h': 20, 'w': 20, 'x_res': 1.0, 'y_res': 1.0,
            'y_min': 0.0, 'y_max': 20.0, 'x_min': -10.0, 'x_max': 10.0,
        },
        # Additional metadata about the training run
        'training_info': {
            'original_project': 'Interfuser_Finetuning',
            'run_name': 'Finetune_Focus_on_Detection_v5',
            'focus': 'traffic_detection_and_iou',
            'backbone': 'ResNet50 + ResNet18',
            'trained_on': 'PDM_Lite_Carla'
        }
    }


def get_training_config():
    """Return the original training settings, for reference only."""
    return {
        'project_info': {
            'project': 'Interfuser_Finetuning',
            'entity': None,
            'run_name': 'Finetune_Focus_on_Detection_v5'
        },
        'training': {
            'epochs': 50,
            'batch_size': 8,
            'num_workers': 2,
            'learning_rate': 1e-4,  # low learning rate for fine-tuning
            'weight_decay': 1e-2,
            'patience': 15,
            'clip_grad_norm': 1.0,
        },
        'loss_weights': {
            'iou': 2.0,            # top priority: box accuracy
            'traffic_map': 25.0,   # strong focus on object detection
            'waypoints': 1.0,      # baseline reference
            'junction': 0.25,      # tasks already learned well
            'traffic_light': 0.5,
            'stop_sign': 0.25,
        },
        'data_split': {
            'strategy': 'interleaved',
            'segment_length': 100,
            'validation_frequency': 10,
        },
        'transforms': {
            'use_data_augmentation': False,  # disabled to focus on the original data
        }
    }