# model_definition.py
# ============================================================================
# Core imports
# ============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from functools import partial
from typing import Optional, List
from torch import Tensor
# Additional libraries
import os
import json
import logging
import math
import copy
from pathlib import Path
from collections import OrderedDict
# Data-processing libraries
import numpy as np
import cv2
# Backbone constructors used by InterfuserModel; assumed to come from timm,
# as in the original InterFuser implementation
from timm.models.resnet import resnet18d, resnet26d, resnet50d, resnet101d
# Optional libraries (can be disabled if not available)
try:
import wandb
WANDB_AVAILABLE = True
except ImportError:
WANDB_AVAILABLE = False
try:
from tqdm import tqdm
except ImportError:
    # If tqdm is not available, use a pass-through fallback
def tqdm(iterable, *args, **kwargs):
return iterable
# ============================================================================
# Helper functions
# ============================================================================
def to_2tuple(x):
"""تحويل قيمة إلى tuple من عنصرين"""
if isinstance(x, (list, tuple)):
return tuple(x)
return (x, x)
# ============================================================================
# Model building blocks
# ============================================================================
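# HybridEmbed wraps a CNN backbone (in this file, a timm ResNet created with
# features_only=True): the backbone's last feature map is projected to embed_dim
# with a 1x1 convolution, and forward() returns both the spatial token map of
# shape (B, embed_dim, H', W') and a globally average-pooled token of shape
# (B, embed_dim, 1).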
class HybridEmbed(nn.Module):
def __init__(
self,
backbone,
img_size=224,
patch_size=1,
feature_size=None,
in_chans=3,
embed_dim=768,
):
super().__init__()
assert isinstance(backbone, nn.Module)
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.backbone = backbone
if feature_size is None:
with torch.no_grad():
training = backbone.training
if training:
backbone.eval()
o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
if isinstance(o, (list, tuple)):
o = o[-1] # last feature if backbone outputs list/tuple of features
feature_size = o.shape[-2:]
feature_dim = o.shape[1]
backbone.train(training)
else:
feature_size = to_2tuple(feature_size)
if hasattr(self.backbone, "feature_info"):
feature_dim = self.backbone.feature_info.channels()[-1]
else:
feature_dim = self.backbone.num_features
self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=1, stride=1)
def forward(self, x):
x = self.backbone(x)
if isinstance(x, (list, tuple)):
x = x[-1] # last feature if backbone outputs list/tuple of features
x = self.proj(x)
global_x = torch.mean(x, [2, 3], keepdim=False)[:, :, None]
return x, global_x
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(
self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, tensor):
x = tensor
bs, _, h, w = x.shape
not_mask = torch.ones((bs, h, w), device=x.device)
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack(
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos_y = torch.stack(
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
class TransformerEncoder(nn.Module):
def __init__(self, encoder_layer, num_layers, norm=None):
super().__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
def forward(
self,
src,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
):
output = src
for layer in self.layers:
output = layer(
output,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
pos=pos,
)
if self.norm is not None:
output = self.norm(output)
return output
class SpatialSoftmax(nn.Module):
def __init__(self, height, width, channel, temperature=None, data_format="NCHW"):
super().__init__()
self.data_format = data_format
self.height = height
self.width = width
self.channel = channel
if temperature:
            self.temperature = nn.Parameter(torch.ones(1) * temperature)
else:
self.temperature = 1.0
pos_x, pos_y = np.meshgrid(
np.linspace(-1.0, 1.0, self.height), np.linspace(-1.0, 1.0, self.width)
)
pos_x = torch.from_numpy(pos_x.reshape(self.height * self.width)).float()
pos_y = torch.from_numpy(pos_y.reshape(self.height * self.width)).float()
self.register_buffer("pos_x", pos_x)
self.register_buffer("pos_y", pos_y)
def forward(self, feature):
        # Output: (N, C, 2) expected (x, y) keypoint coordinates per channel
if self.data_format == "NHWC":
feature = (
feature.transpose(1, 3)
                .transpose(2, 3)
.view(-1, self.height * self.width)
)
else:
feature = feature.view(-1, self.height * self.width)
weight = F.softmax(feature / self.temperature, dim=-1)
        expected_x = torch.sum(self.pos_x * weight, dim=1, keepdim=True)
        expected_y = torch.sum(self.pos_y * weight, dim=1, keepdim=True)
expected_xy = torch.cat([expected_x, expected_y], 1)
feature_keypoints = expected_xy.view(-1, self.channel, 2)
feature_keypoints[:, :, 1] = (feature_keypoints[:, :, 1] - 1) * 12
feature_keypoints[:, :, 0] = feature_keypoints[:, :, 0] * 12
return feature_keypoints
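# SpatialSoftmax above is a soft-argmax head: a per-channel softmax over the H*W
# grid followed by the expectation of the (x, y) grid coordinates, producing one
# expected keypoint per channel with output shape (N, channel, 2). The final
# scaling (a factor of 12, with a shift on y) appears to map the normalized
# coordinates into the metric waypoint range used by the rest of the model.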
class MultiPath_Generator(nn.Module):
def __init__(self, in_channel, embed_dim, out_channel):
super().__init__()
self.spatial_softmax = SpatialSoftmax(100, 100, out_channel)
self.tconv0 = nn.Sequential(
nn.ConvTranspose2d(in_channel, 256, 4, 2, 1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(True),
)
self.tconv1 = nn.Sequential(
nn.ConvTranspose2d(256, 256, 4, 2, 1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(True),
)
self.tconv2 = nn.Sequential(
nn.ConvTranspose2d(256, 192, 4, 2, 1, bias=False),
nn.BatchNorm2d(192),
nn.ReLU(True),
)
self.tconv3 = nn.Sequential(
nn.ConvTranspose2d(192, 64, 4, 2, 1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(True),
)
self.tconv4_list = torch.nn.ModuleList(
[
nn.Sequential(
nn.ConvTranspose2d(64, out_channel, 8, 2, 3, bias=False),
nn.Tanh(),
)
for _ in range(6)
]
)
self.upsample = nn.Upsample(size=(50, 50), mode="bilinear")
def forward(self, x, measurements):
mask = measurements[:, :6]
mask = mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 1, 100, 100)
velocity = measurements[:, 6:7].unsqueeze(-1).unsqueeze(-1)
velocity = velocity.repeat(1, 32, 2, 2)
n, d, c = x.shape
x = x.transpose(1, 2)
x = x.view(n, -1, 2, 2)
x = torch.cat([x, velocity], dim=1)
x = self.tconv0(x)
x = self.tconv1(x)
x = self.tconv2(x)
x = self.tconv3(x)
x = self.upsample(x)
xs = []
for i in range(6):
xt = self.tconv4_list[i](x)
xs.append(xt)
xs = torch.stack(xs, dim=1)
x = torch.sum(xs * mask, dim=1)
x = self.spatial_softmax(x)
return x
class LinearWaypointsPredictor(nn.Module):
def __init__(self, input_dim, cumsum=True):
super().__init__()
self.cumsum = cumsum
self.rank_embed = nn.Parameter(torch.zeros(1, 10, input_dim))
self.head_fc1_list = nn.ModuleList([nn.Linear(input_dim, 64) for _ in range(6)])
self.head_relu = nn.ReLU(inplace=True)
self.head_fc2_list = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)])
def forward(self, x, measurements):
# input shape: n 10 embed_dim
bs, n, dim = x.shape
x = x + self.rank_embed
x = x.reshape(-1, dim)
mask = measurements[:, :6]
mask = torch.unsqueeze(mask, -1).repeat(n, 1, 2)
rs = []
for i in range(6):
res = self.head_fc1_list[i](x)
res = self.head_relu(res)
res = self.head_fc2_list[i](res)
rs.append(res)
rs = torch.stack(rs, 1)
x = torch.sum(rs * mask, dim=1)
x = x.view(bs, n, 2)
if self.cumsum:
x = torch.cumsum(x, 1)
return x
class GRUWaypointsPredictor(nn.Module):
def __init__(self, input_dim, waypoints=10):
super().__init__()
# self.gru = torch.nn.GRUCell(input_size=input_dim, hidden_size=64)
self.gru = torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True)
self.encoder = nn.Linear(2, 64)
self.decoder = nn.Linear(64, 2)
self.waypoints = waypoints
def forward(self, x, target_point):
bs = x.shape[0]
z = self.encoder(target_point).unsqueeze(0)
output, _ = self.gru(x, z)
output = output.reshape(bs * self.waypoints, -1)
output = self.decoder(output).reshape(bs, self.waypoints, 2)
output = torch.cumsum(output, 1)
return output
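# GRUWaypointsPredictor shapes: x is (bs, waypoints, input_dim) and target_point
# is (bs, 2). The target point is encoded as the initial GRU hidden state, each
# GRU step is decoded to a 2-D offset, and the cumulative sum converts per-step
# offsets into absolute waypoint positions, giving an output of (bs, waypoints, 2).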
class GRUWaypointsPredictorWithCommand(nn.Module):
def __init__(self, input_dim, waypoints=10):
super().__init__()
# self.gru = torch.nn.GRUCell(input_size=input_dim, hidden_size=64)
self.grus = nn.ModuleList([torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True) for _ in range(6)])
self.encoder = nn.Linear(2, 64)
self.decoders = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)])
self.waypoints = waypoints
def forward(self, x, target_point, measurements):
bs, n, dim = x.shape
mask = measurements[:, :6, None, None]
mask = mask.repeat(1, 1, self.waypoints, 2)
z = self.encoder(target_point).unsqueeze(0)
outputs = []
for i in range(6):
output, _ = self.grus[i](x, z)
output = output.reshape(bs * self.waypoints, -1)
output = self.decoders[i](output).reshape(bs, self.waypoints, 2)
output = torch.cumsum(output, 1)
outputs.append(output)
outputs = torch.stack(outputs, 1)
output = torch.sum(outputs * mask, dim=1)
return output
class TransformerDecoder(nn.Module):
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
super().__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
self.return_intermediate = return_intermediate
def forward(
self,
tgt,
memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
output = tgt
intermediate = []
for layer in self.layers:
output = layer(
output,
memory,
tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
pos=pos,
query_pos=query_pos,
)
if self.return_intermediate:
intermediate.append(self.norm(output))
if self.norm is not None:
output = self.norm(output)
if self.return_intermediate:
intermediate.pop()
intermediate.append(output)
if self.return_intermediate:
return torch.stack(intermediate)
return output.unsqueeze(0)
class TransformerEncoderLayer(nn.Module):
def __init__(
self,
d_model,
nhead,
dim_feedforward=2048,
dropout=0.1,
        activation=nn.ReLU,
normalize_before=False,
):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = activation()
self.normalize_before = normalize_before
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(
self,
src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
):
q = k = self.with_pos_embed(src, pos)
src2 = self.self_attn(
q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
)[0]
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = src + self.dropout2(src2)
src = self.norm2(src)
return src
def forward_pre(
self,
src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
):
src2 = self.norm1(src)
q = k = self.with_pos_embed(src2, pos)
src2 = self.self_attn(
q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
)[0]
src = src + self.dropout1(src2)
src2 = self.norm2(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
src = src + self.dropout2(src2)
return src
def forward(
self,
src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
):
if self.normalize_before:
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
class TransformerDecoderLayer(nn.Module):
def __init__(
self,
d_model,
nhead,
dim_feedforward=2048,
dropout=0.1,
        activation=nn.ReLU,
normalize_before=False,
):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = activation()
self.normalize_before = normalize_before
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(
self,
tgt,
memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
q = k = self.with_pos_embed(tgt, query_pos)
tgt2 = self.self_attn(
q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
)[0]
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
tgt2 = self.multihead_attn(
query=self.with_pos_embed(tgt, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask,
)[0]
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout3(tgt2)
tgt = self.norm3(tgt)
return tgt
def forward_pre(
self,
tgt,
memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
tgt2 = self.self_attn(
q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
)[0]
tgt = tgt + self.dropout1(tgt2)
tgt2 = self.norm2(tgt)
tgt2 = self.multihead_attn(
query=self.with_pos_embed(tgt2, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask,
)[0]
tgt = tgt + self.dropout2(tgt2)
tgt2 = self.norm3(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout3(tgt2)
return tgt
def forward(
self,
tgt,
memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
if self.normalize_before:
return self.forward_pre(
tgt,
memory,
tgt_mask,
memory_mask,
tgt_key_padding_mask,
memory_key_padding_mask,
pos,
query_pos,
)
return self.forward_post(
tgt,
memory,
tgt_mask,
memory_mask,
tgt_key_padding_mask,
memory_key_padding_mask,
pos,
query_pos,
)
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
def build_attn_mask(mask_type):
mask = torch.ones((151, 151), dtype=torch.bool).cuda()
if mask_type == "seperate_all":
mask[:50, :50] = False
mask[50:67, 50:67] = False
mask[67:84, 67:84] = False
mask[84:101, 84:101] = False
mask[101:151, 101:151] = False
elif mask_type == "seperate_view":
mask[:50, :50] = False
mask[50:67, 50:67] = False
mask[67:84, 67:84] = False
mask[84:101, 84:101] = False
mask[101:151, :] = False
mask[:, 101:151] = False
return mask
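# build_attn_mask returns a boolean mask for nn.MultiheadAttention (True = attention
# blocked). The 151 rows/columns appear to correspond to the concatenated sensor
# tokens produced by forward_features when all sensors are enabled (assuming
# 224x224 front/lidar inputs and 112x112 side views): front view (50), left (17),
# right (17), front-center (17) and lidar (50), each block including its global
# token. "seperate_all" only allows attention within each sensor block, while
# "seperate_view" additionally keeps the lidar block fully connected to all tokens.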
class InterfuserModel(nn.Module):
def __init__(
self,
img_size=224,
multi_view_img_size=112,
patch_size=8,
in_chans=3,
embed_dim=768,
enc_depth=6,
dec_depth=6,
dim_feedforward=2048,
normalize_before=False,
rgb_backbone_name="r50",
lidar_backbone_name="r50",
num_heads=8,
norm_layer=None,
dropout=0.1,
end2end=False,
direct_concat=False,
separate_view_attention=False,
separate_all_attention=False,
act_layer=None,
weight_init="",
freeze_num=-1,
with_lidar=False,
with_right_left_sensors=False,
with_center_sensor=False,
traffic_pred_head_type="det",
waypoints_pred_head="heatmap",
reverse_pos=True,
use_different_backbone=False,
use_view_embed=False,
use_mmad_pretrain=None,
):
super().__init__()
self.traffic_pred_head_type = traffic_pred_head_type
self.num_features = (
self.embed_dim
) = embed_dim # num_features for consistency with other models
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
self.reverse_pos = reverse_pos
self.waypoints_pred_head = waypoints_pred_head
self.with_lidar = with_lidar
self.with_right_left_sensors = with_right_left_sensors
self.with_center_sensor = with_center_sensor
self.direct_concat = direct_concat
self.separate_view_attention = separate_view_attention
self.separate_all_attention = separate_all_attention
self.end2end = end2end
self.use_view_embed = use_view_embed
if self.direct_concat:
in_chans = in_chans * 4
self.with_center_sensor = False
self.with_right_left_sensors = False
if self.separate_view_attention:
self.attn_mask = build_attn_mask("seperate_view")
elif self.separate_all_attention:
self.attn_mask = build_attn_mask("seperate_all")
else:
self.attn_mask = None
if use_different_backbone:
if rgb_backbone_name == "r50":
self.rgb_backbone = resnet50d(
pretrained=True,
in_chans=in_chans,
features_only=True,
out_indices=[4],
)
elif rgb_backbone_name == "r26":
self.rgb_backbone = resnet26d(
pretrained=True,
in_chans=in_chans,
features_only=True,
out_indices=[4],
)
elif rgb_backbone_name == "r18":
self.rgb_backbone = resnet18d(
pretrained=True,
in_chans=in_chans,
features_only=True,
out_indices=[4],
)
if lidar_backbone_name == "r50":
self.lidar_backbone = resnet50d(
pretrained=False,
in_chans=in_chans,
features_only=True,
out_indices=[4],
)
elif lidar_backbone_name == "r26":
self.lidar_backbone = resnet26d(
pretrained=False,
in_chans=in_chans,
features_only=True,
out_indices=[4],
)
elif lidar_backbone_name == "r18":
self.lidar_backbone = resnet18d(
pretrained=False, in_chans=3, features_only=True, out_indices=[4]
)
rgb_embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone)
lidar_embed_layer = partial(HybridEmbed, backbone=self.lidar_backbone)
if use_mmad_pretrain:
params = torch.load(use_mmad_pretrain)["state_dict"]
updated_params = OrderedDict()
for key in params:
if "backbone" in key:
updated_params[key.replace("backbone.", "")] = params[key]
self.rgb_backbone.load_state_dict(updated_params)
self.rgb_patch_embed = rgb_embed_layer(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
)
self.lidar_patch_embed = lidar_embed_layer(
img_size=img_size,
patch_size=patch_size,
in_chans=3,
embed_dim=embed_dim,
)
else:
if rgb_backbone_name == "r50":
self.rgb_backbone = resnet50d(
pretrained=True, in_chans=3, features_only=True, out_indices=[4]
)
elif rgb_backbone_name == "r101":
self.rgb_backbone = resnet101d(
pretrained=True, in_chans=3, features_only=True, out_indices=[4]
)
elif rgb_backbone_name == "r26":
self.rgb_backbone = resnet26d(
pretrained=True, in_chans=3, features_only=True, out_indices=[4]
)
elif rgb_backbone_name == "r18":
self.rgb_backbone = resnet18d(
pretrained=True, in_chans=3, features_only=True, out_indices=[4]
)
embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone)
self.rgb_patch_embed = embed_layer(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
)
self.lidar_patch_embed = embed_layer(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
)
self.global_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
self.view_embed = nn.Parameter(torch.zeros(1, embed_dim, 5, 1))
if self.end2end:
self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 4))
self.query_embed = nn.Parameter(torch.zeros(4, 1, embed_dim))
elif self.waypoints_pred_head == "heatmap":
self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
self.query_embed = nn.Parameter(torch.zeros(400 + 5, 1, embed_dim))
else:
self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 11))
self.query_embed = nn.Parameter(torch.zeros(400 + 11, 1, embed_dim))
if self.end2end:
self.waypoints_generator = GRUWaypointsPredictor(embed_dim, 4)
elif self.waypoints_pred_head == "heatmap":
self.waypoints_generator = MultiPath_Generator(
embed_dim + 32, embed_dim, 10
)
elif self.waypoints_pred_head == "gru":
self.waypoints_generator = GRUWaypointsPredictor(embed_dim)
elif self.waypoints_pred_head == "gru-command":
self.waypoints_generator = GRUWaypointsPredictorWithCommand(embed_dim)
elif self.waypoints_pred_head == "linear":
self.waypoints_generator = LinearWaypointsPredictor(embed_dim)
elif self.waypoints_pred_head == "linear-sum":
self.waypoints_generator = LinearWaypointsPredictor(embed_dim, cumsum=True)
self.junction_pred_head = nn.Linear(embed_dim, 2)
self.traffic_light_pred_head = nn.Linear(embed_dim, 2)
self.stop_sign_head = nn.Linear(embed_dim, 2)
if self.traffic_pred_head_type == "det":
self.traffic_pred_head = nn.Sequential(
*[
nn.Linear(embed_dim + 32, 64),
nn.ReLU(),
nn.Linear(64, 7),
# nn.Sigmoid(),
]
)
elif self.traffic_pred_head_type == "seg":
self.traffic_pred_head = nn.Sequential(
*[nn.Linear(embed_dim, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid()]
)
self.position_encoding = PositionEmbeddingSine(embed_dim // 2, normalize=True)
encoder_layer = TransformerEncoderLayer(
embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before
)
self.encoder = TransformerEncoder(encoder_layer, enc_depth, None)
decoder_layer = TransformerDecoderLayer(
embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before
)
decoder_norm = nn.LayerNorm(embed_dim)
self.decoder = TransformerDecoder(
decoder_layer, dec_depth, decoder_norm, return_intermediate=False
)
self.reset_parameters()
def reset_parameters(self):
nn.init.uniform_(self.global_embed)
nn.init.uniform_(self.view_embed)
nn.init.uniform_(self.query_embed)
nn.init.uniform_(self.query_pos_embed)
def forward_features(
self,
front_image,
left_image,
right_image,
front_center_image,
lidar,
measurements,
):
features = []
# Front view processing
front_image_token, front_image_token_global = self.rgb_patch_embed(front_image)
if self.use_view_embed:
front_image_token = (
front_image_token
+ self.view_embed[:, :, 0:1, :]
+ self.position_encoding(front_image_token)
)
else:
front_image_token = front_image_token + self.position_encoding(
front_image_token
)
front_image_token = front_image_token.flatten(2).permute(2, 0, 1)
front_image_token_global = (
front_image_token_global
+ self.view_embed[:, :, 0, :]
+ self.global_embed[:, :, 0:1]
)
front_image_token_global = front_image_token_global.permute(2, 0, 1)
features.extend([front_image_token, front_image_token_global])
if self.with_right_left_sensors:
# Left view processing
left_image_token, left_image_token_global = self.rgb_patch_embed(left_image)
if self.use_view_embed:
left_image_token = (
left_image_token
+ self.view_embed[:, :, 1:2, :]
+ self.position_encoding(left_image_token)
)
else:
left_image_token = left_image_token + self.position_encoding(
left_image_token
)
left_image_token = left_image_token.flatten(2).permute(2, 0, 1)
left_image_token_global = (
left_image_token_global
+ self.view_embed[:, :, 1, :]
+ self.global_embed[:, :, 1:2]
)
left_image_token_global = left_image_token_global.permute(2, 0, 1)
# Right view processing
right_image_token, right_image_token_global = self.rgb_patch_embed(
right_image
)
if self.use_view_embed:
right_image_token = (
right_image_token
+ self.view_embed[:, :, 2:3, :]
+ self.position_encoding(right_image_token)
)
else:
right_image_token = right_image_token + self.position_encoding(
right_image_token
)
right_image_token = right_image_token.flatten(2).permute(2, 0, 1)
right_image_token_global = (
right_image_token_global
+ self.view_embed[:, :, 2, :]
+ self.global_embed[:, :, 2:3]
)
right_image_token_global = right_image_token_global.permute(2, 0, 1)
features.extend(
[
left_image_token,
left_image_token_global,
right_image_token,
right_image_token_global,
]
)
if self.with_center_sensor:
# Front center view processing
(
front_center_image_token,
front_center_image_token_global,
) = self.rgb_patch_embed(front_center_image)
if self.use_view_embed:
front_center_image_token = (
front_center_image_token
+ self.view_embed[:, :, 3:4, :]
+ self.position_encoding(front_center_image_token)
)
else:
front_center_image_token = (
front_center_image_token
+ self.position_encoding(front_center_image_token)
)
front_center_image_token = front_center_image_token.flatten(2).permute(
2, 0, 1
)
front_center_image_token_global = (
front_center_image_token_global
+ self.view_embed[:, :, 3, :]
+ self.global_embed[:, :, 3:4]
)
front_center_image_token_global = front_center_image_token_global.permute(
2, 0, 1
)
features.extend([front_center_image_token, front_center_image_token_global])
if self.with_lidar:
lidar_token, lidar_token_global = self.lidar_patch_embed(lidar)
if self.use_view_embed:
lidar_token = (
lidar_token
+ self.view_embed[:, :, 4:5, :]
+ self.position_encoding(lidar_token)
)
else:
lidar_token = lidar_token + self.position_encoding(lidar_token)
lidar_token = lidar_token.flatten(2).permute(2, 0, 1)
lidar_token_global = (
lidar_token_global
+ self.view_embed[:, :, 4, :]
+ self.global_embed[:, :, 4:5]
)
lidar_token_global = lidar_token_global.permute(2, 0, 1)
features.extend([lidar_token, lidar_token_global])
features = torch.cat(features, 0)
return features
def forward(self, x):
front_image = x["rgb"]
left_image = x["rgb_left"]
right_image = x["rgb_right"]
front_center_image = x["rgb_center"]
measurements = x["measurements"]
target_point = x["target_point"]
lidar = x["lidar"]
if self.direct_concat:
img_size = front_image.shape[-1]
left_image = torch.nn.functional.interpolate(
left_image, size=(img_size, img_size)
)
right_image = torch.nn.functional.interpolate(
right_image, size=(img_size, img_size)
)
front_center_image = torch.nn.functional.interpolate(
front_center_image, size=(img_size, img_size)
)
front_image = torch.cat(
[front_image, left_image, right_image, front_center_image], dim=1
)
features = self.forward_features(
front_image,
left_image,
right_image,
front_center_image,
lidar,
measurements,
)
bs = front_image.shape[0]
if self.end2end:
tgt = self.query_pos_embed.repeat(bs, 1, 1)
else:
tgt = self.position_encoding(
torch.ones((bs, 1, 20, 20), device=x["rgb"].device)
)
tgt = tgt.flatten(2)
tgt = torch.cat([tgt, self.query_pos_embed.repeat(bs, 1, 1)], 2)
tgt = tgt.permute(2, 0, 1)
memory = self.encoder(features, mask=self.attn_mask)
hs = self.decoder(self.query_embed.repeat(1, bs, 1), memory, query_pos=tgt)[0]
hs = hs.permute(1, 0, 2) # Batchsize , N, C
if self.end2end:
waypoints = self.waypoints_generator(hs, target_point)
return waypoints
if self.waypoints_pred_head != "heatmap":
traffic_feature = hs[:, :400]
is_junction_feature = hs[:, 400]
traffic_light_state_feature = hs[:, 400]
stop_sign_feature = hs[:, 400]
waypoints_feature = hs[:, 401:411]
else:
traffic_feature = hs[:, :400]
is_junction_feature = hs[:, 400]
traffic_light_state_feature = hs[:, 400]
stop_sign_feature = hs[:, 400]
waypoints_feature = hs[:, 401:405]
if self.waypoints_pred_head == "heatmap":
waypoints = self.waypoints_generator(waypoints_feature, measurements)
elif self.waypoints_pred_head == "gru":
waypoints = self.waypoints_generator(waypoints_feature, target_point)
elif self.waypoints_pred_head == "gru-command":
waypoints = self.waypoints_generator(waypoints_feature, target_point, measurements)
elif self.waypoints_pred_head == "linear":
waypoints = self.waypoints_generator(waypoints_feature, measurements)
elif self.waypoints_pred_head == "linear-sum":
waypoints = self.waypoints_generator(waypoints_feature, measurements)
is_junction = self.junction_pred_head(is_junction_feature)
traffic_light_state = self.traffic_light_pred_head(traffic_light_state_feature)
stop_sign = self.stop_sign_head(stop_sign_feature)
velocity = measurements[:, 6:7].unsqueeze(-1)
velocity = velocity.repeat(1, 400, 32)
traffic_feature_with_vel = torch.cat([traffic_feature, velocity], dim=2)
traffic = self.traffic_pred_head(traffic_feature_with_vel)
return traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature
def load_pretrained(self, model_path, strict=False):
"""
تحميل الأوزان المدربة مسبقاً - نسخة محسنة
Args:
model_path (str): مسار ملف الأوزان
strict (bool): إذا كان True، يتطلب تطابق تام للمفاتيح
"""
if not model_path or not Path(model_path).exists():
logging.warning(f"ملف الأوزان غير موجود: {model_path}")
logging.info("سيتم استخدام أوزان عشوائية")
return False
try:
logging.info(f"محاولة تحميل الأوزان من: {model_path}")
# تحميل الملف مع معالجة أنواع مختلفة من ملفات الحفظ
checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
            # Extract the state_dict from the different checkpoint formats
if isinstance(checkpoint, dict):
if 'model_state_dict' in checkpoint:
state_dict = checkpoint['model_state_dict']
logging.info("تم العثور على 'model_state_dict' في الملف")
elif 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
logging.info("تم العثور على 'state_dict' في الملف")
elif 'model' in checkpoint:
state_dict = checkpoint['model']
logging.info("تم العثور على 'model' في الملف")
else:
state_dict = checkpoint
logging.info("استخدام الملف كـ state_dict مباشرة")
else:
state_dict = checkpoint
logging.info("استخدام الملف كـ state_dict مباشرة")
            # Clean up key names (strip a leading 'module.' prefix if present)
clean_state_dict = OrderedDict()
for k, v in state_dict.items():
                # Strip 'module.' from the start of the key if present
clean_key = k[7:] if k.startswith('module.') else k
clean_state_dict[clean_key] = v
            # Load the weights
missing_keys, unexpected_keys = self.load_state_dict(clean_state_dict, strict=strict)
            # Report the loading status
if missing_keys:
logging.warning(f"مفاتيح مفقودة ({len(missing_keys)}): {missing_keys[:5]}..." if len(missing_keys) > 5 else f"مفاتيح مفقودة: {missing_keys}")
if unexpected_keys:
logging.warning(f"مفاتيح غير متوقعة ({len(unexpected_keys)}): {unexpected_keys[:5]}..." if len(unexpected_keys) > 5 else f"مفاتيح غير متوقعة: {unexpected_keys}")
if not missing_keys and not unexpected_keys:
logging.info("✅ تم تحميل جميع الأوزان بنجاح تام")
elif not strict:
logging.info("✅ تم تحميل الأوزان بنجاح (مع تجاهل عدم التطابق)")
return True
except Exception as e:
logging.error(f"❌ خطأ في تحميل الأوزان: {str(e)}")
logging.info("سيتم استخدام أوزان عشوائية")
return False
# ============================================================================
# Helper functions for building and loading the model
# ============================================================================
def load_and_prepare_model(config, device):
"""
يقوم بإنشاء النموذج وتحميل الأوزان المدربة مسبقًا.
Args:
config (dict): إعدادات النموذج والمسارات
device (torch.device): الجهاز المستهدف (CPU/GPU)
Returns:
InterfuserModel: النموذج المحمل
"""
try:
        # Create the model
        model = InterfuserModel(**config.get('model_params', {})).to(device)
        logging.info(f"Model created on device: {device}")
        # Load the weights if a path was provided
checkpoint_path = config.get('paths', {}).get('pretrained_weights')
if checkpoint_path:
success = model.load_pretrained(checkpoint_path, strict=False)
if success:
logging.info("✅ تم تحميل النموذج والأوزان بنجاح")
else:
logging.warning("⚠️ تم إنشاء النموذج بأوزان عشوائية")
else:
logging.info("لم يتم تحديد مسار الأوزان، سيتم استخدام أوزان عشوائية")
# وضع النموذج في وضع التقييم
model.eval()
return model
except Exception as e:
logging.error(f"خطأ في إنشاء النموذج: {str(e)}")
raise
def create_model_config(model_path="model/best_model.pth", **model_params):
"""
إنشاء إعدادات النموذج باستخدام الإعدادات الصحيحة من التدريب
Args:
model_path (str): مسار ملف الأوزان
**model_params: معاملات النموذج الإضافية
Returns:
dict: إعدادات النموذج
"""
# الإعدادات الصحيحة من كونفيج التدريب الأصلي
training_config_params = {
"img_size": 224,
"embed_dim": 256, # مهم: هذه القيمة من التدريب الأصلي
"enc_depth": 6,
"dec_depth": 6,
"rgb_backbone_name": 'r50',
"lidar_backbone_name": 'r18',
"waypoints_pred_head": 'gru',
"use_different_backbone": True,
"with_lidar": False,
"with_right_left_sensors": False,
"with_center_sensor": False,
        # Additional settings from the original config
"multi_view_img_size": 112,
"patch_size": 8,
"in_chans": 3,
"dim_feedforward": 2048,
"normalize_before": False,
"num_heads": 8,
"dropout": 0.1,
"end2end": False,
"direct_concat": False,
"separate_view_attention": False,
"separate_all_attention": False,
"freeze_num": -1,
"traffic_pred_head_type": "det",
"reverse_pos": True,
"use_view_embed": False,
"use_mmad_pretrain": None,
}
    # Merge custom parameters with the training settings
training_config_params.update(model_params)
config = {
'model_params': training_config_params,
'paths': {
'pretrained_weights': model_path
},
        # Grid settings from training
'grid_conf': {
'h': 20, 'w': 20,
'x_res': 1.0, 'y_res': 1.0,
'y_min': 0.0, 'y_max': 20.0,
'x_min': -10.0, 'x_max': 10.0,
},
        # Additional training metadata
'training_info': {
'original_project': 'Interfuser_Finetuning',
'run_name': 'Finetune_Focus_on_Detection_v5',
'focus': 'traffic_detection_and_iou',
'backbone': 'ResNet50 + ResNet18',
'trained_on': 'PDM_Lite_Carla'
}
}
return config
def get_training_config():
"""
إرجاع إعدادات التدريب الأصلية للمرجع
هذه الإعدادات توضح كيف تم تدريب النموذج
"""
return {
'project_info': {
'project': 'Interfuser_Finetuning',
'entity': None,
'run_name': 'Finetune_Focus_on_Detection_v5'
},
'training': {
'epochs': 50,
'batch_size': 8,
'num_workers': 2,
            'learning_rate': 1e-4,  # low learning rate for fine-tuning
'weight_decay': 1e-2,
'patience': 15,
'clip_grad_norm': 1.0,
},
'loss_weights': {
            'iou': 2.0,  # top priority for bounding-box accuracy
            'traffic_map': 25.0,  # strong focus on object detection
            'waypoints': 1.0,  # baseline reference
            'junction': 0.25,  # tasks that are already well learned
'traffic_light': 0.5,
'stop_sign': 0.25,
},
'data_split': {
'strategy': 'interleaved',
'segment_length': 100,
'validation_frequency': 10,
},
'transforms': {
            'use_data_augmentation': False,  # disabled to focus on the original data
}
}
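# ----------------------------------------------------------------------------
# Usage sketch (not part of the deployed API): build the config, instantiate the
# model and run a dummy forward pass. The shapes below are illustrative
# assumptions (224x224 front/lidar images, 112x112 side views, a 10-dim
# measurement vector with the speed at index 6); adapt them to the real
# preprocessing pipeline. Note that timm downloads ImageNet weights for the RGB
# backbone the first time the model is built with pretrained=True.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config = create_model_config(model_path="model/best_model.pth")
    model = load_and_prepare_model(config, device)
    # Dummy batch: with the default flags only the front camera is processed,
    # but forward() still reads every key from the input dict.
    dummy_input = {
        "rgb": torch.randn(1, 3, 224, 224, device=device),
        "rgb_left": torch.randn(1, 3, 112, 112, device=device),
        "rgb_right": torch.randn(1, 3, 112, 112, device=device),
        "rgb_center": torch.randn(1, 3, 112, 112, device=device),
        "lidar": torch.randn(1, 3, 224, 224, device=device),
        "measurements": torch.zeros(1, 10, device=device),
        "target_point": torch.zeros(1, 2, device=device),
    }
    with torch.no_grad():
        traffic, waypoints, is_junction, light_state, stop_sign, _ = model(dummy_input)
    # Expected shapes: traffic (1, 400, 7), waypoints (1, 10, 2), classification heads (1, 2)
    logging.info(f"traffic: {tuple(traffic.shape)}, waypoints: {tuple(waypoints.shape)}")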