# model_definition.py
# ============================================================================
# Core imports
# ============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from functools import partial
from typing import Optional, List
from torch import Tensor
from PIL import Image
from collections import deque, OrderedDict
from timm.models.resnet import resnet50d, resnet26d, resnet18d, resnet101d
try:
    from timm.layers import trunc_normal_
except ImportError:
    from timm.models.layers import trunc_normal_
from huggingface_hub import hf_hub_download, HfApi
from huggingface_hub.utils import HfFolder
# Additional standard-library modules
import os
import json
import logging
import math
import copy
from pathlib import Path
# Data-processing libraries
import numpy as np
import cv2
# Optional libraries (can be disabled if unavailable)
try:
    from tqdm import tqdm
except ImportError:
    # If tqdm is unavailable, fall back to a no-op wrapper
    def tqdm(iterable, *args, **kwargs):
        return iterable
# ============================================================================
# Helper functions
# ============================================================================
def to_2tuple(x):
    """Convert a value to a 2-element tuple."""
    if isinstance(x, (list, tuple)):
        return tuple(x)
    return (x, x)
# ============================================================================
# ============================================================================
class HybridEmbed(nn.Module):
    def __init__(
        self,
        backbone,
        img_size=224,
        patch_size=1,
        feature_size=None,
        in_chans=3,
        embed_dim=768,
    ):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.backbone = backbone
        if feature_size is None:
            with torch.no_grad():
                training = backbone.training
                if training:
                    backbone.eval()
                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
                if isinstance(o, (list, tuple)):
                    o = o[-1]  # last feature if backbone outputs list/tuple of features
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(training)
        else:
            feature_size = to_2tuple(feature_size)
            if hasattr(self.backbone, "feature_info"):
                feature_dim = self.backbone.feature_info.channels()[-1]
            else:
                feature_dim = self.backbone.num_features
        self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=1, stride=1)
    def forward(self, x):
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features
        x = self.proj(x)
        global_x = torch.mean(x, [2, 3], keepdim=False)[:, :, None]
        return x, global_x
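# Usage sketch (illustrative helper, not called anywhere in this module):
# HybridEmbed wraps a timm backbone created with features_only=True and
# returns the projected feature map plus a per-image global token.
def _demo_hybrid_embed():
    backbone = resnet18d(pretrained=False, in_chans=3, features_only=True, out_indices=[4])
    embed = HybridEmbed(backbone, img_size=224, embed_dim=256)
    tokens, global_token = embed(torch.zeros(2, 3, 224, 224))
    # tokens: (2, 256, 7, 7) for a stride-32 backbone; global_token: (2, 256, 1)
    return tokens.shape, global_token.shape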
class HyperDimensionalPositionalEncoding(nn.Module):
    """
    [GCPE v1.1 - Professional & Corrected Implementation]
    A novel positional encoding scheme based on geometric centrality.
    This class is designed as a drop-in replacement for the standard
    PositionEmbeddingSine, accepting similar arguments and producing an
    output of the same shape. This version corrects a type error in the
    distance calculation.
    """
    def __init__(self, num_pos_feats=256, temperature=10000, normalize=True, scale=None):
        """
        Args:
            num_pos_feats (int): The desired number of output channels for the positional encoding.
                This must be an even number.
            temperature (int): A constant used to scale the frequencies.
            normalize (bool): If True, normalizes the coordinates to the range [0, scale].
            scale (float, optional): The scaling factor for normalization. Defaults to 2*pi.
        """
        super().__init__()
        if num_pos_feats % 2 != 0:
            raise ValueError(f"num_pos_feats must be an even number, but got {num_pos_feats}")
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and not normalize:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale
    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Args:
            tensor (torch.Tensor): A 4D tensor of shape (B, C, H, W). The content is not
                used, only its shape and device.
        Returns:
            torch.Tensor: A 4D tensor of positional encodings with shape (B, num_pos_feats, H, W).
        """
        batch_size, _, h, w = tensor.shape
        device = tensor.device
        # 1. Create coordinate grids
        y_embed = torch.arange(h, dtype=torch.float32, device=device).view(h, 1)
        x_embed = torch.arange(w, dtype=torch.float32, device=device).view(1, w)
        # 2. Calculate the normalized distance from the center
        # Use floating-point division for the center calculation
        center_y, center_x = (h - 1) / 2.0, (w - 1) / 2.0
        # Calculate the Euclidean distance of each pixel from the center
        dist_map = torch.sqrt(
            (y_embed - center_y) ** 2 + (x_embed - center_x) ** 2
        )
        # The maximum distance is a scalar; it is kept as a tensor here only for
        # consistency with the rest of the computation.
        max_dist_sq = torch.tensor(center_y**2 + center_x**2, device=device)
        max_dist = torch.sqrt(max_dist_sq)
        # Normalize the distance map to the range [0, 1]
        normalized_dist_map = dist_map / (max_dist + 1e-6)
        if self.normalize:
            normalized_dist_map = normalized_dist_map * self.scale
        pos_dist = normalized_dist_map.unsqueeze(0).repeat(batch_size, 1, 1)
        # 3. Create the frequency-based embedding
        dim_t = torch.arange(self.num_pos_feats // 2, dtype=torch.float32, device=device)
        dim_t = self.temperature ** (2 * dim_t / (self.num_pos_feats // 2))
        pos = pos_dist.unsqueeze(-1) / dim_t
        pos_sin = pos.sin()
        pos_cos = pos.cos()
        # 4. Concatenate and reshape to match the desired output format
        pos = torch.cat((pos_sin, pos_cos), dim=3)
        pos = pos.permute(0, 3, 1, 2)
        return pos
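# Shape sketch (illustrative helper): the encoding only uses the input's shape
# and device and returns (B, num_pos_feats, H, W), so it can be added directly
# to a feature map with a matching channel count.
def _demo_positional_encoding():
    pe = HyperDimensionalPositionalEncoding(num_pos_feats=256, normalize=True)
    feat = torch.zeros(2, 256, 20, 20)
    pos = pe(feat)
    assert pos.shape == feat.shape  # (2, 256, 20, 20)
    return pos.shape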
class TransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
    def forward(
        self,
        src,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        output = src
        for layer in self.layers:
            output = layer(
                output,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                pos=pos,
            )
        if self.norm is not None:
            output = self.norm(output)
        return output
class SpatialSoftmax(nn.Module):
    def __init__(self, height, width, channel, temperature=None, data_format="NCHW"):
        super().__init__()
        self.data_format = data_format
        self.height = height
        self.width = width
        self.channel = channel
        if temperature:
            self.temperature = nn.Parameter(torch.ones(1) * temperature)
        else:
            self.temperature = 1.0
        pos_x, pos_y = np.meshgrid(
            np.linspace(-1.0, 1.0, self.height), np.linspace(-1.0, 1.0, self.width)
        )
        pos_x = torch.from_numpy(pos_x.reshape(self.height * self.width)).float()
        pos_y = torch.from_numpy(pos_y.reshape(self.height * self.width)).float()
        self.register_buffer("pos_x", pos_x)
        self.register_buffer("pos_y", pos_y)
    def forward(self, feature):
        # Output:
        #   (N, C*2) x_0 y_0 ...
        if self.data_format == "NHWC":
            feature = (
                feature.transpose(1, 3)
                .transpose(2, 3)
                .view(-1, self.height * self.width)
            )
        else:
            feature = feature.view(-1, self.height * self.width)
        weight = F.softmax(feature / self.temperature, dim=-1)
        expected_x = torch.sum(self.pos_x * weight, dim=1, keepdim=True)
        expected_y = torch.sum(self.pos_y * weight, dim=1, keepdim=True)
        expected_xy = torch.cat([expected_x, expected_y], 1)
        feature_keypoints = expected_xy.view(-1, self.channel, 2)
        feature_keypoints[:, :, 1] = (feature_keypoints[:, :, 1] - 1) * 12
        feature_keypoints[:, :, 0] = feature_keypoints[:, :, 0] * 12
        return feature_keypoints
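# Sketch (illustrative helper): SpatialSoftmax turns each channel's activation
# map into one (x, y) expectation, i.e. one soft keypoint per channel.
def _demo_spatial_softmax():
    ss = SpatialSoftmax(height=50, width=50, channel=10)
    heatmaps = torch.randn(2, 10, 50, 50)
    keypoints = ss(heatmaps)  # (2, 10, 2), with the fixed *12 scaling applied
    return keypoints.shape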
class MultiPath_Generator(nn.Module):
    def __init__(self, in_channel, embed_dim, out_channel):
        super().__init__()
        self.spatial_softmax = SpatialSoftmax(100, 100, out_channel)
        self.tconv0 = nn.Sequential(
            nn.ConvTranspose2d(in_channel, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
        )
        self.tconv1 = nn.Sequential(
            nn.ConvTranspose2d(256, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
        )
        self.tconv2 = nn.Sequential(
            nn.ConvTranspose2d(256, 192, 4, 2, 1, bias=False),
            nn.BatchNorm2d(192),
            nn.ReLU(True),
        )
        self.tconv3 = nn.Sequential(
            nn.ConvTranspose2d(192, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
        )
        self.tconv4_list = torch.nn.ModuleList(
            [
                nn.Sequential(
                    nn.ConvTranspose2d(64, out_channel, 8, 2, 3, bias=False),
                    nn.Tanh(),
                )
                for _ in range(6)
            ]
        )
        self.upsample = nn.Upsample(size=(50, 50), mode="bilinear")
    def forward(self, x, measurements):
        mask = measurements[:, :6]
        mask = mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 1, 100, 100)
        velocity = measurements[:, 6:7].unsqueeze(-1).unsqueeze(-1)
        velocity = velocity.repeat(1, 32, 2, 2)
        n, d, c = x.shape
        x = x.transpose(1, 2)
        x = x.view(n, -1, 2, 2)
        x = torch.cat([x, velocity], dim=1)
        x = self.tconv0(x)
        x = self.tconv1(x)
        x = self.tconv2(x)
        x = self.tconv3(x)
        x = self.upsample(x)
        xs = []
        for i in range(6):
            xt = self.tconv4_list[i](x)
            xs.append(xt)
        xs = torch.stack(xs, dim=1)
        x = torch.sum(xs * mask, dim=1)
        x = self.spatial_softmax(x)
        return x
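# Sketch (illustrative helper): the heatmap head expects a flattened query
# slice of shape (N, tokens, C) plus a measurement vector whose first 6 entries
# one-hot select a command-specific deconvolution branch and whose entry 6 is
# the ego speed (as used in the forward pass above).
def _demo_multipath_generator():
    gen = MultiPath_Generator(in_channel=256 + 32, embed_dim=256, out_channel=10)
    queries = torch.randn(2, 4, 256)   # 4 tokens, reshaped internally to a 2x2 map
    measurements = torch.zeros(2, 7)
    measurements[:, 0] = 1.0           # select command branch 0
    keypoints = gen(queries, measurements)
    return keypoints.shape             # (2, 10, 2)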
class LinearWaypointsPredictor(nn.Module):
    def __init__(self, input_dim, cumsum=True):
        super().__init__()
        self.cumsum = cumsum
        self.rank_embed = nn.Parameter(torch.zeros(1, 10, input_dim))
        self.head_fc1_list = nn.ModuleList([nn.Linear(input_dim, 64) for _ in range(6)])
        self.head_relu = nn.ReLU(inplace=True)
        self.head_fc2_list = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)])
    def forward(self, x, measurements):
        # input shape: (bs, 10, embed_dim)
        bs, n, dim = x.shape
        x = x + self.rank_embed
        x = x.reshape(-1, dim)
        mask = measurements[:, :6]
        mask = torch.unsqueeze(mask, -1).repeat(n, 1, 2)
        rs = []
        for i in range(6):
            res = self.head_fc1_list[i](x)
            res = self.head_relu(res)
            res = self.head_fc2_list[i](res)
            rs.append(res)
        rs = torch.stack(rs, 1)
        x = torch.sum(rs * mask, dim=1)
        x = x.view(bs, n, 2)
        if self.cumsum:
            x = torch.cumsum(x, 1)
        return x
class GRUWaypointsPredictor(nn.Module):
    def __init__(self, input_dim, waypoints=10):
        super().__init__()
        # self.gru = torch.nn.GRUCell(input_size=input_dim, hidden_size=64)
        self.gru = torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True)
        self.encoder = nn.Linear(2, 64)
        self.decoder = nn.Linear(64, 2)
        self.waypoints = waypoints
    def forward(self, x, target_point):
        bs = x.shape[0]
        z = self.encoder(target_point).unsqueeze(0)
        output, _ = self.gru(x, z)
        output = output.reshape(bs * self.waypoints, -1)
        output = self.decoder(output).reshape(bs, self.waypoints, 2)
        output = torch.cumsum(output, 1)
        return output
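# Sketch (illustrative helper): the GRU head decodes one relative displacement
# per query token and the cumulative sum turns those deltas into waypoints in
# the ego frame; the encoded target point only initialises the hidden state.
def _demo_gru_waypoints():
    head = GRUWaypointsPredictor(input_dim=256, waypoints=10)
    queries = torch.randn(2, 10, 256)   # (B, num_waypoints, embed_dim)
    target_point = torch.randn(2, 2)    # (B, 2) goal in ego coordinates
    waypoints = head(queries, target_point)
    return waypoints.shape              # (2, 10, 2)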
class GRUWaypointsPredictorWithCommand(nn.Module):
    def __init__(self, input_dim, waypoints=10):
        super().__init__()
        # self.gru = torch.nn.GRUCell(input_size=input_dim, hidden_size=64)
        self.grus = nn.ModuleList([torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True) for _ in range(6)])
        self.encoder = nn.Linear(2, 64)
        self.decoders = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)])
        self.waypoints = waypoints
    def forward(self, x, target_point, measurements):
        bs, n, dim = x.shape
        mask = measurements[:, :6, None, None]
        mask = mask.repeat(1, 1, self.waypoints, 2)
        z = self.encoder(target_point).unsqueeze(0)
        outputs = []
        for i in range(6):
            output, _ = self.grus[i](x, z)
            output = output.reshape(bs * self.waypoints, -1)
            output = self.decoders[i](output).reshape(bs, self.waypoints, 2)
            output = torch.cumsum(output, 1)
            outputs.append(output)
        outputs = torch.stack(outputs, 1)
        output = torch.sum(outputs * mask, dim=1)
        return output
class TransformerDecoder(nn.Module):
    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate
    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        output = tgt
        intermediate = []
        for layer in self.layers:
            output = layer(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos=pos,
                query_pos=query_pos,
            )
            if self.return_intermediate:
                intermediate.append(self.norm(output))
        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)
        if self.return_intermediate:
            return torch.stack(intermediate)
        return output.unsqueeze(0)
class TransformerEncoderLayer(nn.Module):
    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation=nn.ReLU,
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of the feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        # `activation` is a module class (e.g. nn.ReLU, nn.GELU); instantiate it here
        self.activation = activation()
        self.normalize_before = normalize_before
    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos
    def forward_post(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(
            q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src
    def forward_pre(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(
            q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src
    def forward(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)
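# Sketch (illustrative helper): the layers operate on sequence-first tensors
# (S, B, C), matching nn.MultiheadAttention's default layout, and the
# positional embedding is re-added to queries/keys at every layer via
# with_pos_embed.
def _demo_encoder_layer():
    layer = TransformerEncoderLayer(d_model=256, nhead=8, activation=nn.GELU)
    src = torch.randn(400, 2, 256)   # 400 tokens, batch of 2
    pos = torch.randn(400, 2, 256)   # positional embedding, same shape
    out = layer(src, pos=pos)
    return out.shape                 # (400, 2, 256)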
class TransformerDecoderLayer(nn.Module):
    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation=nn.ReLU,
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of the feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        # `activation` is a module class (e.g. nn.ReLU, nn.GELU); instantiate it here
        self.activation = activation()
        self.normalize_before = normalize_before
    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos
    def forward_post(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(
            q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(
            query=self.with_pos_embed(tgt, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt
    def forward_pre(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(
            q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(
            query=self.with_pos_embed(tgt2, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt
    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        if self.normalize_before:
            return self.forward_pre(
                tgt,
                memory,
                tgt_mask,
                memory_mask,
                tgt_key_padding_mask,
                memory_key_padding_mask,
                pos,
                query_pos,
            )
        return self.forward_post(
            tgt,
            memory,
            tgt_mask,
            memory_mask,
            tgt_key_padding_mask,
            memory_key_padding_mask,
            pos,
            query_pos,
        )
def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
def _get_activation_fn(activation):
    """Return an activation function given a string."""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
def build_attn_mask(mask_type):
    mask = torch.ones((151, 151), dtype=torch.bool).cuda()
    if mask_type == "seperate_all":
        mask[:50, :50] = False
        mask[50:67, 50:67] = False
        mask[67:84, 67:84] = False
        mask[84:101, 84:101] = False
        mask[101:151, 101:151] = False
    elif mask_type == "seperate_view":
        mask[:50, :50] = False
        mask[50:67, 50:67] = False
        mask[67:84, 67:84] = False
        mask[84:101, 84:101] = False
        mask[101:151, :] = False
        mask[:, 101:151] = False
    return mask
class InterfuserHDPE(nn.Module):
    """Interfuser-style multi-sensor transformer in which the positional
    encoding is provided by HyperDimensionalPositionalEncoding (GCPE)."""
    def __init__(
        self,
        img_size=224,
        multi_view_img_size=112,
        patch_size=8,
        in_chans=3,
        embed_dim=768,
        enc_depth=6,
        dec_depth=6,
        dim_feedforward=2048,
        normalize_before=False,
        rgb_backbone_name="r50",
        lidar_backbone_name="r50",
        num_heads=8,
        norm_layer=None,
        dropout=0.1,
        end2end=False,
        direct_concat=False,
        separate_view_attention=False,
        separate_all_attention=False,
        act_layer=None,
        weight_init="",
        freeze_num=-1,
        with_lidar=False,
        with_right_left_sensors=False,
        with_center_sensor=False,
        traffic_pred_head_type="det",
        waypoints_pred_head="heatmap",
        reverse_pos=True,
        use_different_backbone=False,
        use_view_embed=False,
        use_mmad_pretrain=None,
    ):
        super().__init__()
        self.traffic_pred_head_type = traffic_pred_head_type
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU
        self.reverse_pos = reverse_pos
        self.waypoints_pred_head = waypoints_pred_head
        self.with_lidar = with_lidar
        self.with_right_left_sensors = with_right_left_sensors
        self.with_center_sensor = with_center_sensor
        self.direct_concat = direct_concat
        self.separate_view_attention = separate_view_attention
        self.separate_all_attention = separate_all_attention
        self.end2end = end2end
        self.use_view_embed = use_view_embed
        if self.direct_concat:
            in_chans = in_chans * 4
            self.with_center_sensor = False
            self.with_right_left_sensors = False
        if self.separate_view_attention:
            self.attn_mask = build_attn_mask("seperate_view")
        elif self.separate_all_attention:
            self.attn_mask = build_attn_mask("seperate_all")
        else:
            self.attn_mask = None
        if use_different_backbone:
            if rgb_backbone_name == "r50":
                self.rgb_backbone = resnet50d(
                    pretrained=True,
                    in_chans=in_chans,
                    features_only=True,
                    out_indices=[4],
                )
            elif rgb_backbone_name == "r26":
                self.rgb_backbone = resnet26d(
                    pretrained=True,
                    in_chans=in_chans,
                    features_only=True,
                    out_indices=[4],
                )
            elif rgb_backbone_name == "r18":
                self.rgb_backbone = resnet18d(
                    pretrained=True,
                    in_chans=in_chans,
                    features_only=True,
                    out_indices=[4],
                )
            if lidar_backbone_name == "r50":
                self.lidar_backbone = resnet50d(
                    pretrained=False,
                    in_chans=in_chans,
                    features_only=True,
                    out_indices=[4],
                )
            elif lidar_backbone_name == "r26":
                self.lidar_backbone = resnet26d(
                    pretrained=False,
                    in_chans=in_chans,
                    features_only=True,
                    out_indices=[4],
                )
            elif lidar_backbone_name == "r18":
                self.lidar_backbone = resnet18d(
                    pretrained=False, in_chans=3, features_only=True, out_indices=[4]
                )
            rgb_embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone)
            lidar_embed_layer = partial(HybridEmbed, backbone=self.lidar_backbone)
            if use_mmad_pretrain:
                params = torch.load(use_mmad_pretrain)["state_dict"]
                updated_params = OrderedDict()
                for key in params:
                    if "backbone" in key:
                        updated_params[key.replace("backbone.", "")] = params[key]
                self.rgb_backbone.load_state_dict(updated_params)
            self.rgb_patch_embed = rgb_embed_layer(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )
            self.lidar_patch_embed = lidar_embed_layer(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=3,
                embed_dim=embed_dim,
            )
        else:
            if rgb_backbone_name == "r50":
                self.rgb_backbone = resnet50d(
                    pretrained=True, in_chans=3, features_only=True, out_indices=[4]
                )
            elif rgb_backbone_name == "r101":
                self.rgb_backbone = resnet101d(
                    pretrained=True, in_chans=3, features_only=True, out_indices=[4]
                )
            elif rgb_backbone_name == "r26":
                self.rgb_backbone = resnet26d(
                    pretrained=True, in_chans=3, features_only=True, out_indices=[4]
                )
            elif rgb_backbone_name == "r18":
                self.rgb_backbone = resnet18d(
                    pretrained=True, in_chans=3, features_only=True, out_indices=[4]
                )
            embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone)
            self.rgb_patch_embed = embed_layer(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )
            self.lidar_patch_embed = embed_layer(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )
        self.global_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
        self.view_embed = nn.Parameter(torch.zeros(1, embed_dim, 5, 1))
        if self.end2end:
            self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 4))
            self.query_embed = nn.Parameter(torch.zeros(4, 1, embed_dim))
        elif self.waypoints_pred_head == "heatmap":
            self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
            self.query_embed = nn.Parameter(torch.zeros(400 + 5, 1, embed_dim))
        else:
            self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 11))
            self.query_embed = nn.Parameter(torch.zeros(400 + 11, 1, embed_dim))
        if self.end2end:
            self.waypoints_generator = GRUWaypointsPredictor(embed_dim, 4)
        elif self.waypoints_pred_head == "heatmap":
            self.waypoints_generator = MultiPath_Generator(
                embed_dim + 32, embed_dim, 10
            )
        elif self.waypoints_pred_head == "gru":
            self.waypoints_generator = GRUWaypointsPredictor(embed_dim)
        elif self.waypoints_pred_head == "gru-command":
            self.waypoints_generator = GRUWaypointsPredictorWithCommand(embed_dim)
        elif self.waypoints_pred_head == "linear":
            self.waypoints_generator = LinearWaypointsPredictor(embed_dim)
        elif self.waypoints_pred_head == "linear-sum":
            self.waypoints_generator = LinearWaypointsPredictor(embed_dim, cumsum=True)
        self.junction_pred_head = nn.Linear(embed_dim, 2)
        self.traffic_light_pred_head = nn.Linear(embed_dim, 2)
        self.stop_sign_head = nn.Linear(embed_dim, 2)
        if self.traffic_pred_head_type == "det":
            self.traffic_pred_head = nn.Sequential(
                *[
                    nn.Linear(embed_dim + 32, 64),
                    nn.ReLU(),
                    nn.Linear(64, 7),
                    # nn.Sigmoid(),
                ]
            )
        elif self.traffic_pred_head_type == "seg":
            self.traffic_pred_head = nn.Sequential(
                *[nn.Linear(embed_dim, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid()]
            )
        self.position_encoding = HyperDimensionalPositionalEncoding(embed_dim, normalize=True)
        encoder_layer = TransformerEncoderLayer(
            embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before
        )
        self.encoder = TransformerEncoder(encoder_layer, enc_depth, None)
        decoder_layer = TransformerDecoderLayer(
            embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before
        )
        decoder_norm = nn.LayerNorm(embed_dim)
        self.decoder = TransformerDecoder(
            decoder_layer, dec_depth, decoder_norm, return_intermediate=False
        )
        self.reset_parameters()
    def reset_parameters(self):
        nn.init.uniform_(self.global_embed)
        nn.init.uniform_(self.view_embed)
        nn.init.uniform_(self.query_embed)
        nn.init.uniform_(self.query_pos_embed)
    def forward_features(
        self,
        front_image,
        left_image,
        right_image,
        front_center_image,
        lidar,
        measurements,
    ):
        features = []
        # Front view processing
        front_image_token, front_image_token_global = self.rgb_patch_embed(front_image)
        if self.use_view_embed:
            front_image_token = (
                front_image_token
                + self.view_embed[:, :, 0:1, :]
                + self.position_encoding(front_image_token)
            )
        else:
            front_image_token = front_image_token + self.position_encoding(
                front_image_token
            )
        front_image_token = front_image_token.flatten(2).permute(2, 0, 1)
        front_image_token_global = (
            front_image_token_global
            + self.view_embed[:, :, 0, :]
            + self.global_embed[:, :, 0:1]
        )
        front_image_token_global = front_image_token_global.permute(2, 0, 1)
        features.extend([front_image_token, front_image_token_global])
        if self.with_right_left_sensors:
            # Left view processing
            left_image_token, left_image_token_global = self.rgb_patch_embed(left_image)
            if self.use_view_embed:
                left_image_token = (
                    left_image_token
                    + self.view_embed[:, :, 1:2, :]
                    + self.position_encoding(left_image_token)
                )
            else:
                left_image_token = left_image_token + self.position_encoding(
                    left_image_token
                )
            left_image_token = left_image_token.flatten(2).permute(2, 0, 1)
            left_image_token_global = (
                left_image_token_global
                + self.view_embed[:, :, 1, :]
                + self.global_embed[:, :, 1:2]
            )
            left_image_token_global = left_image_token_global.permute(2, 0, 1)
            # Right view processing
            right_image_token, right_image_token_global = self.rgb_patch_embed(
                right_image
            )
            if self.use_view_embed:
                right_image_token = (
                    right_image_token
                    + self.view_embed[:, :, 2:3, :]
                    + self.position_encoding(right_image_token)
                )
            else:
                right_image_token = right_image_token + self.position_encoding(
                    right_image_token
                )
            right_image_token = right_image_token.flatten(2).permute(2, 0, 1)
            right_image_token_global = (
                right_image_token_global
                + self.view_embed[:, :, 2, :]
                + self.global_embed[:, :, 2:3]
            )
            right_image_token_global = right_image_token_global.permute(2, 0, 1)
            features.extend(
                [
                    left_image_token,
                    left_image_token_global,
                    right_image_token,
                    right_image_token_global,
                ]
            )
        if self.with_center_sensor:
            # Front center view processing
            (
                front_center_image_token,
                front_center_image_token_global,
            ) = self.rgb_patch_embed(front_center_image)
            if self.use_view_embed:
                front_center_image_token = (
                    front_center_image_token
                    + self.view_embed[:, :, 3:4, :]
                    + self.position_encoding(front_center_image_token)
                )
            else:
                front_center_image_token = (
                    front_center_image_token
                    + self.position_encoding(front_center_image_token)
                )
            front_center_image_token = front_center_image_token.flatten(2).permute(
                2, 0, 1
            )
            front_center_image_token_global = (
                front_center_image_token_global
                + self.view_embed[:, :, 3, :]
                + self.global_embed[:, :, 3:4]
            )
            front_center_image_token_global = front_center_image_token_global.permute(
                2, 0, 1
            )
            features.extend([front_center_image_token, front_center_image_token_global])
        if self.with_lidar:
            lidar_token, lidar_token_global = self.lidar_patch_embed(lidar)
            if self.use_view_embed:
                lidar_token = (
                    lidar_token
                    + self.view_embed[:, :, 4:5, :]
                    + self.position_encoding(lidar_token)
                )
            else:
                lidar_token = lidar_token + self.position_encoding(lidar_token)
            lidar_token = lidar_token.flatten(2).permute(2, 0, 1)
            lidar_token_global = (
                lidar_token_global
                + self.view_embed[:, :, 4, :]
                + self.global_embed[:, :, 4:5]
            )
            lidar_token_global = lidar_token_global.permute(2, 0, 1)
            features.extend([lidar_token, lidar_token_global])
        features = torch.cat(features, 0)
        return features
    def forward(self, x):
        front_image = x["rgb"]
        left_image = x["rgb_left"]
        right_image = x["rgb_right"]
        front_center_image = x["rgb_center"]
        measurements = x["measurements"]
        target_point = x["target_point"]
        lidar = x["lidar"]
        if self.direct_concat:
            img_size = front_image.shape[-1]
            left_image = torch.nn.functional.interpolate(
                left_image, size=(img_size, img_size)
            )
            right_image = torch.nn.functional.interpolate(
                right_image, size=(img_size, img_size)
            )
            front_center_image = torch.nn.functional.interpolate(
                front_center_image, size=(img_size, img_size)
            )
            front_image = torch.cat(
                [front_image, left_image, right_image, front_center_image], dim=1
            )
        features = self.forward_features(
            front_image,
            left_image,
            right_image,
            front_center_image,
            lidar,
            measurements,
        )
        bs = front_image.shape[0]
        if self.end2end:
            tgt = self.query_pos_embed.repeat(bs, 1, 1)
        else:
            tgt = self.position_encoding(
                torch.ones((bs, 1, 20, 20), device=x["rgb"].device)
            )
            tgt = tgt.flatten(2)
            tgt = torch.cat([tgt, self.query_pos_embed.repeat(bs, 1, 1)], 2)
        tgt = tgt.permute(2, 0, 1)
        memory = self.encoder(features, mask=self.attn_mask)
        hs = self.decoder(self.query_embed.repeat(1, bs, 1), memory, query_pos=tgt)[0]
        hs = hs.permute(1, 0, 2)  # (batch_size, N, C)
        if self.end2end:
            waypoints = self.waypoints_generator(hs, target_point)
            return waypoints
        if self.waypoints_pred_head != "heatmap":
            traffic_feature = hs[:, :400]
            is_junction_feature = hs[:, 400]
            traffic_light_state_feature = hs[:, 400]
            stop_sign_feature = hs[:, 400]
            waypoints_feature = hs[:, 401:411]
        else:
            traffic_feature = hs[:, :400]
            is_junction_feature = hs[:, 400]
            traffic_light_state_feature = hs[:, 400]
            stop_sign_feature = hs[:, 400]
            waypoints_feature = hs[:, 401:405]
        if self.waypoints_pred_head == "heatmap":
            waypoints = self.waypoints_generator(waypoints_feature, measurements)
        elif self.waypoints_pred_head == "gru":
            waypoints = self.waypoints_generator(waypoints_feature, target_point)
        elif self.waypoints_pred_head == "gru-command":
            waypoints = self.waypoints_generator(waypoints_feature, target_point, measurements)
        elif self.waypoints_pred_head == "linear":
            waypoints = self.waypoints_generator(waypoints_feature, measurements)
        elif self.waypoints_pred_head == "linear-sum":
            waypoints = self.waypoints_generator(waypoints_feature, measurements)
        is_junction = self.junction_pred_head(is_junction_feature)
        traffic_light_state = self.traffic_light_pred_head(traffic_light_state_feature)
        stop_sign = self.stop_sign_head(stop_sign_feature)
        velocity = measurements[:, 6:7].unsqueeze(-1)
        velocity = velocity.repeat(1, 400, 32)
        traffic_feature_with_vel = torch.cat([traffic_feature, velocity], dim=2)
        traffic = self.traffic_pred_head(traffic_feature_with_vel)
        return traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature
    def load_pretrained(self, model_path, strict=False):
        """
        Load pretrained weights (robust version).
        Args:
            model_path (str): path to the weights file
            strict (bool): if True, require an exact key match
        """
        if not model_path or not Path(model_path).exists():
            logging.warning(f"Weights file not found: {model_path}")
            logging.info("Random weights will be used")
            return False
        try:
            logging.info(f"Attempting to load weights from: {model_path}")
            # Load the file, handling different checkpoint formats
            checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
            # Extract the state_dict from the various checkpoint layouts
            if isinstance(checkpoint, dict):
                if 'model_state_dict' in checkpoint:
                    state_dict = checkpoint['model_state_dict']
                    logging.info("Found 'model_state_dict' in the checkpoint")
                elif 'state_dict' in checkpoint:
                    state_dict = checkpoint['state_dict']
                    logging.info("Found 'state_dict' in the checkpoint")
                elif 'model' in checkpoint:
                    state_dict = checkpoint['model']
                    logging.info("Found 'model' in the checkpoint")
                else:
                    state_dict = checkpoint
                    logging.info("Using the checkpoint directly as a state_dict")
            else:
                state_dict = checkpoint
                logging.info("Using the checkpoint directly as a state_dict")
            # Clean key names (strip a leading 'module.' prefix if present)
            clean_state_dict = OrderedDict()
            for k, v in state_dict.items():
                clean_key = k[7:] if k.startswith('module.') else k
                clean_state_dict[clean_key] = v
            # Load the weights
            missing_keys, unexpected_keys = self.load_state_dict(clean_state_dict, strict=strict)
            # Report the loading status
            if missing_keys:
                logging.warning(f"Missing keys ({len(missing_keys)}): {missing_keys[:5]}..." if len(missing_keys) > 5 else f"Missing keys: {missing_keys}")
            if unexpected_keys:
                logging.warning(f"Unexpected keys ({len(unexpected_keys)}): {unexpected_keys[:5]}..." if len(unexpected_keys) > 5 else f"Unexpected keys: {unexpected_keys}")
            if not missing_keys and not unexpected_keys:
                logging.info("✅ All weights loaded successfully")
            elif not strict:
                logging.info("✅ Weights loaded successfully (mismatches ignored)")
            return True
        except Exception as e:
            logging.error(f"❌ Error loading weights: {str(e)}")
            logging.info("Random weights will be used")
            return False
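# Smoke-test sketch (illustrative helper, with assumed input shapes): builds the
# model using the defaults from get_master_config() (defined later in this
# module) and runs one dummy forward pass. The measurement vector is assumed to
# carry a 6-way command mask in its first entries and the ego speed at index 6,
# as used by the heads above. Note that constructing the model downloads the
# pretrained ImageNet backbone weights via timm.
def _demo_interfuser_forward():
    cfg = get_master_config()
    model = InterfuserHDPE(**cfg['model_params']).eval()
    dummy_img = torch.zeros(1, 3, 224, 224)
    batch = {
        "rgb": dummy_img, "rgb_left": dummy_img, "rgb_right": dummy_img,
        "rgb_center": dummy_img, "lidar": dummy_img,
        "measurements": torch.zeros(1, 10),
        "target_point": torch.zeros(1, 2),
    }
    with torch.no_grad():
        traffic, waypoints, is_junction, light, stop_sign, feat = model(batch)
    # traffic: (1, 400, 7), waypoints: (1, 10, 2), is_junction/light/stop_sign: (1, 2)
    return waypoints.shape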
# ==============================================================================
# Function 1: get_master_config
# ==============================================================================
def get_master_config():
    """
    [Production version]
    Returns a comprehensive dictionary holding all static application settings.
    This function is the single source of truth for configuration.
    """
    # --- Section 1: model repository on the Hugging Face Hub ---
    huggingface_repo = {
        'repo_id': "BaseerAI/Interfuser-Baseer-v1",  # replace with your own model repository
        'filename': "pytorch_model.bin"              # name of the weights file inside the repo
    }
    # --- Section 2: Interfuser model architecture settings ---
    model_params = {
        "img_size": 224, "embed_dim": 256, "enc_depth": 6, "dec_depth": 6,
        "rgb_backbone_name": 'r50', "lidar_backbone_name": 'r18',
        "waypoints_pred_head": 'gru', "use_different_backbone": True,
        "with_lidar": False, "with_right_left_sensors": False,
        "with_center_sensor": False, "multi_view_img_size": 112,
        "patch_size": 8, "in_chans": 3, "dim_feedforward": 2048,
        "normalize_before": False, "num_heads": 8, "dropout": 0.1,
        "end2end": False, "direct_concat": False, "separate_view_attention": False,
        "separate_all_attention": False, "freeze_num": -1,
        "traffic_pred_head_type": "det", "reverse_pos": True,
        "use_view_embed": False, "use_mmad_pretrain": None,
    }
    # --- Section 3: grid and bird's-eye-view (BEV) settings ---
    grid_conf = {
        'h': 20, 'w': 20, 'x_res': 1.0, 'y_res': 1.0,
        'y_min': 0.0, 'y_max': 20.0, 'x_min': -10.0, 'x_max': 10.0,
    }
    # --- Section 4: controller and tracker settings ---
    controller_params = {
        'turn_KP': 0.75, 'turn_KI': 0.05, 'turn_KD': 0.25, 'turn_n': 20,
        'speed_KP': 0.55, 'speed_KI': 0.05, 'speed_KD': 0.15, 'speed_n': 20,
        'max_speed': 8.0, 'max_throttle': 0.75, 'min_speed': 0.1,
        'brake_sensitivity': 0.3, 'light_threshold': 0.5, 'stop_threshold': 0.6,
        'stop_sign_duration': 20, 'max_stop_time': 250,
        'forced_move_duration': 20, 'forced_throttle': 0.5,
        'max_red_light_time': 150, 'red_light_block_duration': 80,
        'accel_rate': 0.1, 'decel_rate': 0.2, 'critical_distance': 4.0,
        'follow_distance': 10.0, 'speed_match_factor': 0.9,
        'tracker_match_thresh': 2.5, 'tracker_prune_age': 5,
        'follow_grace_period': 20
    }
    # --- Section 5: assemble everything into one master dictionary ---
    master_config = {
        'huggingface_repo': huggingface_repo,
        'model_params': model_params,
        'grid_conf': grid_conf,
        'controller_params': controller_params,
        'simulation': {
            'frequency': 10.0
        }
    }
    return master_config
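# Example (illustrative): the returned dictionary is the single source of truth,
# so downstream code reads every setting from it rather than hard-coding values.
#   _cfg = get_master_config()
#   _cfg['model_params']['embed_dim']        # -> 256
#   _cfg['controller_params']['max_speed']   # -> 8.0
#   _cfg['simulation']['frequency']          # -> 10.0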
# ==============================================================================
# Function 2: load_and_prepare_model
# ==============================================================================
def load_and_prepare_model(device: torch.device) -> InterfuserHDPE:
    """
    [Production version]
    Uses the master settings from `get_master_config` to build and load the model,
    resolving the Hugging Face Hub model identifier into a real local file path.
    Args:
        device (torch.device): target device (CPU/GPU)
    Returns:
        InterfuserHDPE: the loaded model, ready for inference.
    """
    try:
        logging.info("Initializing model loading process...")
        # 1. Fetch all settings from the single source of truth
        config = get_master_config()
        # 2. Download the weights file from the Hugging Face Hub
        repo_info = config['huggingface_repo']
        logging.info(f"Downloading model weights from repo: '{repo_info['repo_id']}'")
        # Use a token if the repository is private
        # token = HfFolder.get_token()  # or pass it directly
        actual_model_path = hf_hub_download(
            repo_id=repo_info['repo_id'],
            filename=repo_info['filename'],
            # token=token,  # uncomment if the repository is private
        )
        logging.info(f"Model weights are available at local path: {actual_model_path}")
        # 3. Instantiate the model with the correct settings
        logging.info("Instantiating model with specified parameters...")
        model = InterfuserHDPE(**config['model_params']).to(device)
        # 4. Load the downloaded weights into the model,
        #    using the helper method defined on the model class itself
        success = model.load_pretrained(actual_model_path, strict=False)
        if not success:
            logging.warning("⚠️ Model weights were not loaded successfully. The model will use random weights.")
        # 5. Switch the model to evaluation mode (critical step)
        model.eval()
        logging.info("✅ Model prepared and set to evaluation mode. Ready for inference.")
        return model
    except Exception as e:
        # Log the error in detail, then re-raise so it can be handled upstream
        logging.error(f"❌ CRITICAL ERROR during model initialization: {e}", exc_info=True)
        raise
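# ==============================================================================
# Usage example (illustrative): a typical entry point for an inference script.
# ==============================================================================
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _model = load_and_prepare_model(_device)
    print(f"Model ready on {_device} with {sum(p.numel() for p in _model.parameters()):,} parameters")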