Adam committed on
Commit 7b0dd2f · 1 Parent(s): d80d18d

Deploy Baseer Self-Driving API v1.0

Files changed (9)
  1. Dockerfile +41 -0
  2. LICENSE +21 -0
  3. app.py +419 -0
  4. app_config.yaml +17 -0
  5. health_check.py +127 -0
  6. model/README.md +29 -0
  7. model_definition.py +1318 -0
  8. requirements.txt +22 -0
  9. simulation_modules.py +336 -0
Dockerfile ADDED
@@ -0,0 +1,41 @@
+ # Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
+ # Baseer Self-Driving API - Hugging Face Space
+
+ FROM python:3.9
+
+ # Create a non-root user (Hugging Face requirement)
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV PATH="/home/user/.local/bin:$PATH"
+
+ # Set the working directory
+ WORKDIR /app
+
+ # Install system dependencies as root
+ USER root
+ RUN apt-get update && apt-get install -y \
+     libglib2.0-0 \
+     libsm6 \
+     libxext6 \
+     libxrender-dev \
+     libgomp1 \
+     libgtk-3-0 \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Switch back to the non-root user
+ USER user
+
+ # Copy the requirements file
+ COPY --chown=user ./requirements.txt requirements.txt
+
+ # Install the Python requirements
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
+
+ # Copy the application code
+ COPY --chown=user . /app
+
+ # Set environment variables
+ ENV PYTHONPATH=/app
+
+ # Run the app on port 7860 (Hugging Face requirement)
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 Baseer Team
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
app.py ADDED
@@ -0,0 +1,419 @@
+ # app.py - InterFuser Self-Driving API Server
+
+ import uuid
+ import base64
+ import cv2
+ import torch
+ import numpy as np
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+ from torchvision import transforms
+ from typing import List, Dict, Any, Optional
+ import logging
+
+ # Imports from our local modules
+ from model_definition import InterfuserModel, load_and_prepare_model, create_model_config
+ from simulation_modules import (
+     InterfuserController, ControllerConfig, Tracker, DisplayInterface,
+     render, render_waypoints, render_self_car, WAYPOINT_SCALE_FACTOR,
+     T1_FUTURE_TIME, T2_FUTURE_TIME
+ )
+
+ # Logging setup
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # ================== Global settings and model loading ==================
+ app = FastAPI(
+     title="Baseer Self-Driving API",
+     description="Self-driving API powered by the InterFuser model",
+     version="1.0.0"
+ )
+
+ device = torch.device("cpu")
+ logger.info(f"Using device: {device}")
+
+ # Load the model using the improved helper
+ try:
+     # Build the model config with the correct settings from training
+     model_config = create_model_config(
+         model_path="model/best_model.pth"
+         # The correct training settings are applied automatically:
+         # embed_dim=256, rgb_backbone_name='r50', waypoints_pred_head='gru'
+         # with_lidar=False, with_right_left_sensors=False, with_center_sensor=False
+     )
+
+     # Load the model together with its weights
+     model = load_and_prepare_model(model_config, device)
+     logger.info("✅ Model loaded successfully")
+
+ except Exception as e:
+     logger.error(f"❌ Failed to load the model: {e}")
+     logger.info("🔄 Trying to create a model with random weights...")
+     try:
+         model = InterfuserModel()
+         model.to(device)
+         model.eval()
+         logger.warning("⚠️ Model created with random weights")
+     except Exception as e2:
+         logger.error(f"❌ Failed to create the model: {e2}")
+         model = None
+
+ # Initialize the display interface
+ display = DisplayInterface()
+
+ # Dictionary holding the user sessions
+ SESSIONS: Dict[str, Dict] = {}
+
+ # ================== Pydantic data models ==================
+ class Measurements(BaseModel):
+     pos: List[float] = [0.0, 0.0]  # [x, y] position
+     theta: float = 0.0  # orientation angle
+     speed: float = 0.0  # current speed
+     steer: float = 0.0  # current steering
+     throttle: float = 0.0  # current throttle
+     brake: bool = False  # brake status
+     command: int = 4  # driving command (4 = FollowLane)
+     target_point: List[float] = [0.0, 0.0]  # target point [x, y]
+
+ class ModelOutputs(BaseModel):
+     traffic: List[List[List[float]]]  # 20x20x7 grid
+     waypoints: List[List[float]]  # Nx2 waypoints
+     is_junction: float
+     traffic_light_state: float
+     stop_sign: float
+
+ class ControlCommands(BaseModel):
+     steer: float
+     throttle: float
+     brake: bool
+
+ class RunStepInput(BaseModel):
+     session_id: str
+     image_b64: str
+     measurements: Measurements
+
+ class RunStepOutput(BaseModel):
+     model_outputs: ModelOutputs
+     control_commands: ControlCommands
+     dashboard_image_b64: str
+
+ class SessionResponse(BaseModel):
+     session_id: str
+     message: str
+
+ # ================== Helper functions ==================
+ def get_image_transform():
+     """Create the image transforms, exactly as in PDMDataset"""
+     return transforms.Compose([
+         transforms.ToTensor(),
+         transforms.Resize((224, 224), antialias=True),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+     ])
+
+ # Create the transform object once
+ image_transform = get_image_transform()
+
+ def preprocess_input(frame_rgb: np.ndarray, measurements: Measurements, device: torch.device) -> Dict[str, torch.Tensor]:
+     """
+     Mimics what PDMDataset.__getitem__ does in order to build a single batch.
+     """
+     # 1. Process the main image
+     from PIL import Image
+     if isinstance(frame_rgb, np.ndarray):
+         frame_rgb = Image.fromarray(frame_rgb)
+
+     image_tensor = image_transform(frame_rgb).unsqueeze(0).to(device)  # add the batch dimension
+
+     # 2. Create the other camera inputs by cloning
+     batch = {
+         'rgb': image_tensor,
+         'rgb_left': image_tensor.clone(),
+         'rgb_right': image_tensor.clone(),
+         'rgb_center': image_tensor.clone(),
+     }
+
+     # 3. Create a dummy (all-zeros) lidar input
+     batch['lidar'] = torch.zeros(1, 3, 224, 224, dtype=torch.float32).to(device)
+
+     # 4. Assemble the measurements in the same order as PDMDataset
+     m = measurements
+     measurements_tensor = torch.tensor([[
+         m.pos[0], m.pos[1], m.theta,
+         m.steer, m.throttle, float(m.brake),
+         m.speed, float(m.command)
+     ]], dtype=torch.float32).to(device)
+     batch['measurements'] = measurements_tensor
+
+     # 5. Create the target point
+     batch['target_point'] = torch.tensor([m.target_point], dtype=torch.float32).to(device)
+
+     # Ground-truth (gt_*) values are not needed during inference
+     return batch
+
+ def decode_base64_image(image_b64: str) -> np.ndarray:
+     """
+     Decode a Base64-encoded image
+     """
+     try:
+         image_bytes = base64.b64decode(image_b64)
+         nparr = np.frombuffer(image_bytes, np.uint8)
+         image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+         return image
+     except Exception as e:
+         raise HTTPException(status_code=400, detail=f"Invalid image format: {str(e)}")
+
+ def encode_image_to_base64(image: np.ndarray) -> str:
+     """
+     Encode an image to Base64
+     """
+     _, buffer = cv2.imencode('.jpg', image, [cv2.IMWRITE_JPEG_QUALITY, 85])
+     return base64.b64encode(buffer).decode('utf-8')
+
+ # ================== API endpoints ==================
+ @app.get("/")
+ async def root():
+     """
+     API root endpoint
+     """
+     return {
+         "message": "InterFuser Self-Driving API",
+         "version": "1.0.0",
+         "status": "running",
+         "active_sessions": len(SESSIONS)
+     }
+
+ @app.post("/start_session", response_model=SessionResponse)
+ async def start_session():
+     """
+     Start a new simulation session
+     """
+     session_id = str(uuid.uuid4())
+
+     # Create a new session
+     SESSIONS[session_id] = {
+         'tracker': Tracker(frequency=10),
+         'controller': InterfuserController(ControllerConfig()),
+         'frame_num': 0,
+         'created_at': np.datetime64('now'),
+         'last_activity': np.datetime64('now')
+     }
+
+     logger.info(f"New session created: {session_id}")
+
+     return SessionResponse(
+         session_id=session_id,
+         message="Session started successfully"
+     )
+
+ @app.post("/run_step", response_model=RunStepOutput)
+ async def run_step(data: RunStepInput):
+     """
+     Run one complete simulation step
+     """
+     # Check that the session exists
+     if data.session_id not in SESSIONS:
+         raise HTTPException(status_code=404, detail="Session not found")
+
+     session = SESSIONS[data.session_id]
+     tracker = session['tracker']
+     controller = session['controller']
+
+     # Update the activity timestamp
+     session['last_activity'] = np.datetime64('now')
+
+     try:
+         # 1. Decode the image
+         frame_bgr = decode_base64_image(data.image_b64)
+         frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
+
+         # 2. Preprocess the inputs
+         inputs = preprocess_input(frame_rgb, data.measurements, device)
+
+         # 3. Run the model
+         if model is None:
+             raise HTTPException(status_code=500, detail="Model not loaded")
+
+         with torch.no_grad():
+             traffic, waypoints, is_junction, traffic_light, stop_sign, _ = model(inputs)
+
+         # 4. Post-process the model outputs
+         traffic_np = traffic.cpu().numpy()[0]  # take the first element of the batch
+         waypoints_np = waypoints.cpu().numpy()[0]
+         is_junction_prob = torch.sigmoid(is_junction)[0, 1].item()
+         traffic_light_prob = torch.sigmoid(traffic_light)[0, 0].item()
+         stop_sign_prob = torch.sigmoid(stop_sign)[0, 1].item()
+
+         # 5. Update the tracker
+         # Convert the traffic grid into detections for tracking
+         detections = []
+         h, w, c = traffic_np.shape
+         for y in range(h):
+             for x in range(w):
+                 for ch in range(c):
+                     if traffic_np[y, x, ch] > 0.2:  # detection threshold
+                         world_x = (x / w - 0.5) * 64  # convert to world coordinates
+                         world_y = (y / h - 0.5) * 64
+                         detections.append({
+                             'position': [world_x, world_y],
+                             'feature': traffic_np[y, x, ch]
+                         })
+
+         updated_traffic = tracker.update_and_predict(detections, session['frame_num'])
+
+         # 6. Run the controller
+         steer, throttle, brake, metadata = controller.run_step(
+             current_speed=data.measurements.speed,
+             waypoints=waypoints_np,
+             junction=is_junction_prob,
+             traffic_light_state=traffic_light_prob,
+             stop_sign=stop_sign_prob,
+             meta_data={'frame': session['frame_num']}
+         )
+
+         # 7. Render the surround-view maps
+         surround_t0, counts_t0 = render(updated_traffic, t=0)
+         surround_t1, counts_t1 = render(updated_traffic, t=T1_FUTURE_TIME)
+         surround_t2, counts_t2 = render(updated_traffic, t=T2_FUTURE_TIME)
+
+         # Overlay the predicted path
+         wp_map = render_waypoints(waypoints_np)
+         map_t0 = cv2.add(surround_t0, wp_map)
+
+         # Overlay the ego vehicle
+         map_t0 = render_self_car(map_t0)
+         map_t1 = render_self_car(surround_t1)
+         map_t2 = render_self_car(surround_t2)
+
+         # 8. Build the final dashboard view
+         interface_data = {
+             'camera_view': frame_bgr,
+             'map_t0': map_t0,
+             'map_t1': map_t1,
+             'map_t2': map_t2,
+             'text_info': {
+                 'Frame': f"Frame: {session['frame_num']}",
+                 'Control': f"Steer: {steer:.2f}, Throttle: {throttle:.2f}, Brake: {brake}",
+                 'Speed': f"Speed: {data.measurements.speed:.1f} km/h",
+                 'Junction': f"Junction: {is_junction_prob:.2f}",
+                 'Traffic Light': f"Red Light: {traffic_light_prob:.2f}",
+                 'Stop Sign': f"Stop Sign: {stop_sign_prob:.2f}",
+                 'Metadata': metadata
+             },
+             'object_counts': {
+                 't0': counts_t0,
+                 't1': counts_t1,
+                 't2': counts_t2
+             }
+         }
+
+         dashboard_image = display.run_interface(interface_data)
+         dashboard_b64 = encode_image_to_base64(dashboard_image)
+
+         # 9. Assemble the final response
+         response = RunStepOutput(
+             model_outputs=ModelOutputs(
+                 traffic=traffic_np.tolist(),
+                 waypoints=waypoints_np.tolist(),
+                 is_junction=is_junction_prob,
+                 traffic_light_state=traffic_light_prob,
+                 stop_sign=stop_sign_prob
+             ),
+             control_commands=ControlCommands(
+                 steer=float(steer),
+                 throttle=float(throttle),
+                 brake=bool(brake)
+             ),
+             dashboard_image_b64=dashboard_b64
+         )
+
+         # Advance the frame counter
+         session['frame_num'] += 1
+
+         logger.info(f"Step completed for session {data.session_id}, frame {session['frame_num']}")
+
+         return response
+
+     except Exception as e:
+         logger.error(f"Error in run_step: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
+
+ @app.post("/end_session", response_model=SessionResponse)
+ async def end_session(session_id: str):
+     """
+     End a simulation session
+     """
+     if session_id not in SESSIONS:
+         raise HTTPException(status_code=404, detail="Session not found")
+
+     # Delete the session
+     del SESSIONS[session_id]
+
+     logger.info(f"Session ended: {session_id}")
+
+     return SessionResponse(
+         session_id=session_id,
+         message="Session ended successfully"
+     )
+
+ @app.get("/sessions")
+ async def list_sessions():
+     """
+     List the active sessions
+     """
+     active_sessions = []
+     current_time = np.datetime64('now')
+
+     for session_id, session_data in SESSIONS.items():
+         time_diff = current_time - session_data['last_activity']
+         active_sessions.append({
+             'session_id': session_id,
+             'frame_count': session_data['frame_num'],
+             'created_at': str(session_data['created_at']),
+             'last_activity': str(session_data['last_activity']),
+             'inactive_minutes': float(time_diff / np.timedelta64(1, 'm'))
+         })
+
+     return {
+         'total_sessions': len(active_sessions),
+         'sessions': active_sessions
+     }
+
+ @app.delete("/sessions/cleanup")
+ async def cleanup_inactive_sessions(max_inactive_minutes: int = 30):
+     """
+     Clean up inactive sessions
+     """
+     current_time = np.datetime64('now')
+     cleaned_sessions = []
+
+     for session_id in list(SESSIONS.keys()):
+         session = SESSIONS[session_id]
+         time_diff = current_time - session['last_activity']
+         inactive_minutes = float(time_diff / np.timedelta64(1, 'm'))
+
+         if inactive_minutes > max_inactive_minutes:
+             del SESSIONS[session_id]
+             cleaned_sessions.append(session_id)
+
+     logger.info(f"Cleaned up {len(cleaned_sessions)} inactive sessions")
+
+     return {
+         'message': f"Cleaned up {len(cleaned_sessions)} inactive sessions",
+         'cleaned_sessions': cleaned_sessions,
+         'remaining_sessions': len(SESSIONS)
+     }
+
+ # ================== Error handler ==================
+ @app.exception_handler(Exception)
+ async def global_exception_handler(request, exc):
+     logger.error(f"Global exception: {str(exc)}")
+     return {
+         "error": "Internal server error",
+         "detail": str(exc)
+     }
+
+ # ================== Run the server ==================
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=7860)
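For reference, here is a minimal client sketch for the endpoints above. It is an illustration rather than part of the commit: it assumes the Space is reachable at http://localhost:7860, that a hypothetical test frame `frame.jpg` exists locally, and that the `requests` package is installed (none of these are guaranteed by requirements.txt).

```python
# Hypothetical client for the Baseer API defined in app.py above.
# Assumptions: server on localhost:7860, a local test image frame.jpg, `requests` installed.
import base64
import requests

BASE_URL = "http://localhost:7860"

# 1. Open a session
session_id = requests.post(f"{BASE_URL}/start_session").json()["session_id"]

# 2. Send one camera frame plus the current vehicle measurements
with open("frame.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "session_id": session_id,
    "image_b64": image_b64,
    "measurements": {
        "pos": [0.0, 0.0], "theta": 0.0, "speed": 5.0,
        "steer": 0.0, "throttle": 0.0, "brake": False,
        "command": 4, "target_point": [5.0, 0.0],
    },
}
step = requests.post(f"{BASE_URL}/run_step", json=payload).json()
print(step["control_commands"])  # steer / throttle / brake returned by the controller

# 3. Close the session (session_id is passed as a query parameter)
requests.post(f"{BASE_URL}/end_session", params={"session_id": session_id})
```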
app_config.yaml ADDED
@@ -0,0 +1,17 @@
+ title: Baseer Self-Driving API
+ emoji: 🚗
+ colorFrom: blue
+ colorTo: green
+ sdk: docker
+ app_port: 7860
+ pinned: false
+ license: mit
+ short_description: Self-driving API powered by the Baseer InterFuser model
+ tags:
+ - computer-vision
+ - autonomous-driving
+ - deep-learning
+ - fastapi
+ - pytorch
+ - interfuser
+ - graduation-project
health_check.py ADDED
@@ -0,0 +1,127 @@
+ #!/usr/bin/env python3
+ # health_check.py - system health check before deployment
+
+ import os
+ import sys
+ import torch
+ import logging
+ from pathlib import Path
+
+ def check_python_version():
+     """Check the Python version"""
+     version = sys.version_info
+     if version.major == 3 and version.minor >= 9:
+         print(f"✅ Python {version.major}.{version.minor}.{version.micro}")
+         return True
+     else:
+         print(f"❌ Python {version.major}.{version.minor}.{version.micro} - Python 3.9+ is required")
+         return False
+
+ def check_pytorch():
+     """Check PyTorch"""
+     try:
+         print(f"✅ PyTorch {torch.__version__}")
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         print(f"✅ Device: {device}")
+         return True
+     except Exception as e:
+         print(f"❌ PyTorch Error: {e}")
+         return False
+
+ def check_required_files():
+     """Check that the required files exist"""
+     required_files = [
+         "app.py",
+         "model_definition.py",
+         "simulation_modules.py",
+         "requirements.txt",
+         "Dockerfile",
+         "app_config.yaml",
+         "model/best_model.pth"
+     ]
+
+     missing_files = []
+     for file in required_files:
+         if Path(file).exists():
+             size = Path(file).stat().st_size
+             print(f"✅ {file} ({size:,} bytes)")
+         else:
+             print(f"❌ {file} - missing")
+             missing_files.append(file)
+
+     return len(missing_files) == 0
+
+ def check_model_loading():
+     """Check that the model loads"""
+     try:
+         from model_definition import InterfuserModel, create_model_config, load_and_prepare_model
+
+         # Build the model config
+         config = create_model_config("model/best_model.pth")
+         print("✅ Model config created")
+
+         # Load the model
+         device = torch.device("cpu")
+         model = load_and_prepare_model(config, device)
+         print("✅ Model loaded successfully")
+
+         return True
+
+     except Exception as e:
+         print(f"❌ Failed to load the model: {e}")
+         return False
+
+ def check_api_imports():
+     """Check that the API components can be imported"""
+     try:
+         from app import app
+         print("✅ FastAPI app imported")
+
+         from simulation_modules import DisplayInterface, InterfuserController
+         print("✅ Simulation modules imported")
+
+         return True
+
+     except Exception as e:
+         print(f"❌ Failed to import the API: {e}")
+         return False
+
+ def main():
+     """Run the full system check"""
+     print("🔍 Baseer Self-Driving API health check")
+     print("=" * 50)
+
+     checks = [
+         ("Python Version", check_python_version),
+         ("PyTorch", check_pytorch),
+         ("Required Files", check_required_files),
+         ("API Imports", check_api_imports),
+         ("Model Loading", check_model_loading),
+     ]
+
+     passed = 0
+     total = len(checks)
+
+     for name, check_func in checks:
+         print(f"\n🔍 {name}:")
+         try:
+             if check_func():
+                 passed += 1
+             else:
+                 print(f"❌ {name} check failed")
+         except Exception as e:
+             print(f"❌ Error during the {name} check: {e}")
+
+     print("\n" + "=" * 50)
+     print(f"📊 Final result: {passed}/{total} checks passed")
+
+     if passed == total:
+         print("🎉 The system is ready to deploy!")
+         return True
+     else:
+         print("⚠️ Fix the issues above before deploying")
+         return False
+
+ if __name__ == "__main__":
+     success = main()
+     sys.exit(0 if success else 1)
model/README.md ADDED
@@ -0,0 +1,29 @@
+ # InterFuser Model Directory
+
+ ## Required Files
+
+ Place your trained InterFuser model files in this directory:
+
+ 1. **`interfuser_model.pth`** - The main model weights file
+ 2. **`config.json`** (optional) - Model configuration file
+
+ ## Model Format
+
+ The model should be a PyTorch state dict saved with:
+ ```python
+ torch.save(model.state_dict(), 'interfuser_model.pth')
+ ```
+
+ ## Loading in Code
+
+ The model is loaded in `model_definition.py`:
+ ```python
+ model = InterfuserModel()
+ model.load_state_dict(torch.load('model/interfuser_model.pth', map_location='cpu'))
+ ```
+
+ ## Note
+
+ - The current implementation uses a dummy model for testing
+ - Replace with your actual trained InterFuser weights
+ - Ensure the model architecture matches the one defined in `model_definition.py`
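Note that `app.py` and `health_check.py` in this commit actually look for the checkpoint at `model/best_model.pth` and load it through the helpers in `model_definition.py`. A minimal sketch of that loading path, based on the calls that appear in those files:

```python
# Sketch of the loading path used by app.py (assumes model/best_model.pth is present).
import torch
from model_definition import create_model_config, load_and_prepare_model

config = create_model_config(model_path="model/best_model.pth")
device = torch.device("cpu")
model = load_and_prepare_model(config, device)  # expected to return the model ready for inference
```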
model_definition.py ADDED
@@ -0,0 +1,1318 @@
1
+ # model_definition.py
2
+ # ============================================================================
3
+ # Core imports
4
+ # ============================================================================
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torch.optim as optim
9
+ from torch.optim import AdamW
10
+ from torch.optim.lr_scheduler import OneCycleLR
11
+ from torch.utils.data import Dataset, DataLoader
12
+ from torchvision import transforms
13
+ from functools import partial
14
+ from typing import Optional, List
15
+ from torch import Tensor
16
+
17
+ # Additional libraries
18
+ import os
19
+ import json
20
+ import logging
21
+ import math
22
+ import copy
23
+ from pathlib import Path
24
+ from collections import OrderedDict
25
+
26
+ # Data-processing libraries
27
+ import numpy as np
28
+ import cv2
29
+
30
+ # Optional libraries (can be disabled if they are unavailable)
31
+ try:
32
+ import wandb
33
+ WANDB_AVAILABLE = True
34
+ except ImportError:
35
+ WANDB_AVAILABLE = False
36
+
37
+ try:
38
+ from tqdm import tqdm
39
+ except ImportError:
40
+ # If tqdm is not available, fall back to a pass-through wrapper
41
+ def tqdm(iterable, *args, **kwargs):
42
+ return iterable
43
+
44
+ # ============================================================================
45
+ # Helper functions
46
+ # ============================================================================
47
+ def to_2tuple(x):
48
+ """تحويل قيمة إلى tuple من عنصرين"""
49
+ if isinstance(x, (list, tuple)):
50
+ return tuple(x)
51
+ return (x, x)
52
+ # ============================================================================
53
+ # ============================================================================
54
+
55
+ class HybridEmbed(nn.Module):
56
+ def __init__(
57
+ self,
58
+ backbone,
59
+ img_size=224,
60
+ patch_size=1,
61
+ feature_size=None,
62
+ in_chans=3,
63
+ embed_dim=768,
64
+ ):
65
+ super().__init__()
66
+ assert isinstance(backbone, nn.Module)
67
+ img_size = to_2tuple(img_size)
68
+ patch_size = to_2tuple(patch_size)
69
+ self.img_size = img_size
70
+ self.patch_size = patch_size
71
+ self.backbone = backbone
72
+ if feature_size is None:
73
+ with torch.no_grad():
74
+ training = backbone.training
75
+ if training:
76
+ backbone.eval()
77
+ o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
78
+ if isinstance(o, (list, tuple)):
79
+ o = o[-1] # last feature if backbone outputs list/tuple of features
80
+ feature_size = o.shape[-2:]
81
+ feature_dim = o.shape[1]
82
+ backbone.train(training)
83
+ else:
84
+ feature_size = to_2tuple(feature_size)
85
+ if hasattr(self.backbone, "feature_info"):
86
+ feature_dim = self.backbone.feature_info.channels()[-1]
87
+ else:
88
+ feature_dim = self.backbone.num_features
89
+
90
+ self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=1, stride=1)
91
+
92
+ def forward(self, x):
93
+ x = self.backbone(x)
94
+ if isinstance(x, (list, tuple)):
95
+ x = x[-1] # last feature if backbone outputs list/tuple of features
96
+ x = self.proj(x)
97
+ global_x = torch.mean(x, [2, 3], keepdim=False)[:, :, None]
98
+ return x, global_x
99
+
100
+
101
+ class PositionEmbeddingSine(nn.Module):
102
+ """
103
+ This is a more standard version of the position embedding, very similar to the one
104
+ used by the Attention is all you need paper, generalized to work on images.
105
+ """
106
+
107
+ def __init__(
108
+ self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
109
+ ):
110
+ super().__init__()
111
+ self.num_pos_feats = num_pos_feats
112
+ self.temperature = temperature
113
+ self.normalize = normalize
114
+ if scale is not None and normalize is False:
115
+ raise ValueError("normalize should be True if scale is passed")
116
+ if scale is None:
117
+ scale = 2 * math.pi
118
+ self.scale = scale
119
+
120
+ def forward(self, tensor):
121
+ x = tensor
122
+ bs, _, h, w = x.shape
123
+ not_mask = torch.ones((bs, h, w), device=x.device)
124
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
125
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
126
+ if self.normalize:
127
+ eps = 1e-6
128
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
129
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
130
+
131
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
132
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
133
+
134
+ pos_x = x_embed[:, :, :, None] / dim_t
135
+ pos_y = y_embed[:, :, :, None] / dim_t
136
+ pos_x = torch.stack(
137
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
138
+ ).flatten(3)
139
+ pos_y = torch.stack(
140
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
141
+ ).flatten(3)
142
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
143
+ return pos
144
+
145
+
146
+ class TransformerEncoder(nn.Module):
147
+ def __init__(self, encoder_layer, num_layers, norm=None):
148
+ super().__init__()
149
+ self.layers = _get_clones(encoder_layer, num_layers)
150
+ self.num_layers = num_layers
151
+ self.norm = norm
152
+
153
+ def forward(
154
+ self,
155
+ src,
156
+ mask: Optional[Tensor] = None,
157
+ src_key_padding_mask: Optional[Tensor] = None,
158
+ pos: Optional[Tensor] = None,
159
+ ):
160
+ output = src
161
+
162
+ for layer in self.layers:
163
+ output = layer(
164
+ output,
165
+ src_mask=mask,
166
+ src_key_padding_mask=src_key_padding_mask,
167
+ pos=pos,
168
+ )
169
+
170
+ if self.norm is not None:
171
+ output = self.norm(output)
172
+
173
+ return output
174
+
175
+
176
+ class SpatialSoftmax(nn.Module):
177
+ def __init__(self, height, width, channel, temperature=None, data_format="NCHW"):
178
+ super().__init__()
179
+
180
+ self.data_format = data_format
181
+ self.height = height
182
+ self.width = width
183
+ self.channel = channel
184
+
185
+ if temperature:
186
+ self.temperature = nn.Parameter(torch.ones(1) * temperature)
187
+ else:
188
+ self.temperature = 1.0
189
+
190
+ pos_x, pos_y = np.meshgrid(
191
+ np.linspace(-1.0, 1.0, self.height), np.linspace(-1.0, 1.0, self.width)
192
+ )
193
+ pos_x = torch.from_numpy(pos_x.reshape(self.height * self.width)).float()
194
+ pos_y = torch.from_numpy(pos_y.reshape(self.height * self.width)).float()
195
+ self.register_buffer("pos_x", pos_x)
196
+ self.register_buffer("pos_y", pos_y)
197
+
198
+ def forward(self, feature):
199
+ # Output:
200
+ # (N, C*2) x_0 y_0 ...
201
+
202
+ if self.data_format == "NHWC":
203
+ feature = (
204
+ feature.transpose(1, 3)
205
+ .transpose(2, 3)
206
+ .view(-1, self.height * self.width)
207
+ )
208
+ else:
209
+ feature = feature.view(-1, self.height * self.width)
210
+
211
+ weight = F.softmax(feature / self.temperature, dim=-1)
212
+ expected_x = torch.sum(
213
+ torch.autograd.Variable(self.pos_x) * weight, dim=1, keepdim=True
214
+ )
215
+ expected_y = torch.sum(
216
+ torch.autograd.Variable(self.pos_y) * weight, dim=1, keepdim=True
217
+ )
218
+ expected_xy = torch.cat([expected_x, expected_y], 1)
219
+ feature_keypoints = expected_xy.view(-1, self.channel, 2)
220
+ feature_keypoints[:, :, 1] = (feature_keypoints[:, :, 1] - 1) * 12
221
+ feature_keypoints[:, :, 0] = feature_keypoints[:, :, 0] * 12
222
+ return feature_keypoints
223
+
224
+
225
+ class MultiPath_Generator(nn.Module):
226
+ def __init__(self, in_channel, embed_dim, out_channel):
227
+ super().__init__()
228
+ self.spatial_softmax = SpatialSoftmax(100, 100, out_channel)
229
+ self.tconv0 = nn.Sequential(
230
+ nn.ConvTranspose2d(in_channel, 256, 4, 2, 1, bias=False),
231
+ nn.BatchNorm2d(256),
232
+ nn.ReLU(True),
233
+ )
234
+ self.tconv1 = nn.Sequential(
235
+ nn.ConvTranspose2d(256, 256, 4, 2, 1, bias=False),
236
+ nn.BatchNorm2d(256),
237
+ nn.ReLU(True),
238
+ )
239
+ self.tconv2 = nn.Sequential(
240
+ nn.ConvTranspose2d(256, 192, 4, 2, 1, bias=False),
241
+ nn.BatchNorm2d(192),
242
+ nn.ReLU(True),
243
+ )
244
+ self.tconv3 = nn.Sequential(
245
+ nn.ConvTranspose2d(192, 64, 4, 2, 1, bias=False),
246
+ nn.BatchNorm2d(64),
247
+ nn.ReLU(True),
248
+ )
249
+ self.tconv4_list = torch.nn.ModuleList(
250
+ [
251
+ nn.Sequential(
252
+ nn.ConvTranspose2d(64, out_channel, 8, 2, 3, bias=False),
253
+ nn.Tanh(),
254
+ )
255
+ for _ in range(6)
256
+ ]
257
+ )
258
+
259
+ self.upsample = nn.Upsample(size=(50, 50), mode="bilinear")
260
+
261
+ def forward(self, x, measurements):
262
+ mask = measurements[:, :6]
263
+ mask = mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 1, 100, 100)
264
+ velocity = measurements[:, 6:7].unsqueeze(-1).unsqueeze(-1)
265
+ velocity = velocity.repeat(1, 32, 2, 2)
266
+
267
+ n, d, c = x.shape
268
+ x = x.transpose(1, 2)
269
+ x = x.view(n, -1, 2, 2)
270
+ x = torch.cat([x, velocity], dim=1)
271
+ x = self.tconv0(x)
272
+ x = self.tconv1(x)
273
+ x = self.tconv2(x)
274
+ x = self.tconv3(x)
275
+ x = self.upsample(x)
276
+ xs = []
277
+ for i in range(6):
278
+ xt = self.tconv4_list[i](x)
279
+ xs.append(xt)
280
+ xs = torch.stack(xs, dim=1)
281
+ x = torch.sum(xs * mask, dim=1)
282
+ x = self.spatial_softmax(x)
283
+ return x
284
+
285
+
286
+ class LinearWaypointsPredictor(nn.Module):
287
+ def __init__(self, input_dim, cumsum=True):
288
+ super().__init__()
289
+ self.cumsum = cumsum
290
+ self.rank_embed = nn.Parameter(torch.zeros(1, 10, input_dim))
291
+ self.head_fc1_list = nn.ModuleList([nn.Linear(input_dim, 64) for _ in range(6)])
292
+ self.head_relu = nn.ReLU(inplace=True)
293
+ self.head_fc2_list = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)])
294
+
295
+ def forward(self, x, measurements):
296
+ # input shape: n 10 embed_dim
297
+ bs, n, dim = x.shape
298
+ x = x + self.rank_embed
299
+ x = x.reshape(-1, dim)
300
+
301
+ mask = measurements[:, :6]
302
+ mask = torch.unsqueeze(mask, -1).repeat(n, 1, 2)
303
+
304
+ rs = []
305
+ for i in range(6):
306
+ res = self.head_fc1_list[i](x)
307
+ res = self.head_relu(res)
308
+ res = self.head_fc2_list[i](res)
309
+ rs.append(res)
310
+ rs = torch.stack(rs, 1)
311
+ x = torch.sum(rs * mask, dim=1)
312
+
313
+ x = x.view(bs, n, 2)
314
+ if self.cumsum:
315
+ x = torch.cumsum(x, 1)
316
+ return x
317
+
318
+
319
+ class GRUWaypointsPredictor(nn.Module):
320
+ def __init__(self, input_dim, waypoints=10):
321
+ super().__init__()
322
+ # self.gru = torch.nn.GRUCell(input_size=input_dim, hidden_size=64)
323
+ self.gru = torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True)
324
+ self.encoder = nn.Linear(2, 64)
325
+ self.decoder = nn.Linear(64, 2)
326
+ self.waypoints = waypoints
327
+
328
+ def forward(self, x, target_point):
329
+ bs = x.shape[0]
330
+ z = self.encoder(target_point).unsqueeze(0)
331
+ output, _ = self.gru(x, z)
332
+ output = output.reshape(bs * self.waypoints, -1)
333
+ output = self.decoder(output).reshape(bs, self.waypoints, 2)
334
+ output = torch.cumsum(output, 1)
335
+ return output
336
+
337
+ class GRUWaypointsPredictorWithCommand(nn.Module):
338
+ def __init__(self, input_dim, waypoints=10):
339
+ super().__init__()
340
+ # self.gru = torch.nn.GRUCell(input_size=input_dim, hidden_size=64)
341
+ self.grus = nn.ModuleList([torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True) for _ in range(6)])
342
+ self.encoder = nn.Linear(2, 64)
343
+ self.decoders = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)])
344
+ self.waypoints = waypoints
345
+
346
+ def forward(self, x, target_point, measurements):
347
+ bs, n, dim = x.shape
348
+ mask = measurements[:, :6, None, None]
349
+ mask = mask.repeat(1, 1, self.waypoints, 2)
350
+
351
+ z = self.encoder(target_point).unsqueeze(0)
352
+ outputs = []
353
+ for i in range(6):
354
+ output, _ = self.grus[i](x, z)
355
+ output = output.reshape(bs * self.waypoints, -1)
356
+ output = self.decoders[i](output).reshape(bs, self.waypoints, 2)
357
+ output = torch.cumsum(output, 1)
358
+ outputs.append(output)
359
+ outputs = torch.stack(outputs, 1)
360
+ output = torch.sum(outputs * mask, dim=1)
361
+ return output
362
+
363
+
364
+ class TransformerDecoder(nn.Module):
365
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
366
+ super().__init__()
367
+ self.layers = _get_clones(decoder_layer, num_layers)
368
+ self.num_layers = num_layers
369
+ self.norm = norm
370
+ self.return_intermediate = return_intermediate
371
+
372
+ def forward(
373
+ self,
374
+ tgt,
375
+ memory,
376
+ tgt_mask: Optional[Tensor] = None,
377
+ memory_mask: Optional[Tensor] = None,
378
+ tgt_key_padding_mask: Optional[Tensor] = None,
379
+ memory_key_padding_mask: Optional[Tensor] = None,
380
+ pos: Optional[Tensor] = None,
381
+ query_pos: Optional[Tensor] = None,
382
+ ):
383
+ output = tgt
384
+
385
+ intermediate = []
386
+
387
+ for layer in self.layers:
388
+ output = layer(
389
+ output,
390
+ memory,
391
+ tgt_mask=tgt_mask,
392
+ memory_mask=memory_mask,
393
+ tgt_key_padding_mask=tgt_key_padding_mask,
394
+ memory_key_padding_mask=memory_key_padding_mask,
395
+ pos=pos,
396
+ query_pos=query_pos,
397
+ )
398
+ if self.return_intermediate:
399
+ intermediate.append(self.norm(output))
400
+
401
+ if self.norm is not None:
402
+ output = self.norm(output)
403
+ if self.return_intermediate:
404
+ intermediate.pop()
405
+ intermediate.append(output)
406
+
407
+ if self.return_intermediate:
408
+ return torch.stack(intermediate)
409
+
410
+ return output.unsqueeze(0)
411
+
412
+
413
+ class TransformerEncoderLayer(nn.Module):
414
+ def __init__(
415
+ self,
416
+ d_model,
417
+ nhead,
418
+ dim_feedforward=2048,
419
+ dropout=0.1,
420
+ activation=nn.ReLU,
421
+ normalize_before=False,
422
+ ):
423
+ super().__init__()
424
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
425
+ # Implementation of Feedforward model
426
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
427
+ self.dropout = nn.Dropout(dropout)
428
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
429
+
430
+ self.norm1 = nn.LayerNorm(d_model)
431
+ self.norm2 = nn.LayerNorm(d_model)
432
+ self.dropout1 = nn.Dropout(dropout)
433
+ self.dropout2 = nn.Dropout(dropout)
434
+
435
+ self.activation = activation()
436
+ self.normalize_before = normalize_before
437
+
438
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
439
+ return tensor if pos is None else tensor + pos
440
+
441
+ def forward_post(
442
+ self,
443
+ src,
444
+ src_mask: Optional[Tensor] = None,
445
+ src_key_padding_mask: Optional[Tensor] = None,
446
+ pos: Optional[Tensor] = None,
447
+ ):
448
+ q = k = self.with_pos_embed(src, pos)
449
+ src2 = self.self_attn(
450
+ q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
451
+ )[0]
452
+ src = src + self.dropout1(src2)
453
+ src = self.norm1(src)
454
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
455
+ src = src + self.dropout2(src2)
456
+ src = self.norm2(src)
457
+ return src
458
+
459
+ def forward_pre(
460
+ self,
461
+ src,
462
+ src_mask: Optional[Tensor] = None,
463
+ src_key_padding_mask: Optional[Tensor] = None,
464
+ pos: Optional[Tensor] = None,
465
+ ):
466
+ src2 = self.norm1(src)
467
+ q = k = self.with_pos_embed(src2, pos)
468
+ src2 = self.self_attn(
469
+ q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
470
+ )[0]
471
+ src = src + self.dropout1(src2)
472
+ src2 = self.norm2(src)
473
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
474
+ src = src + self.dropout2(src2)
475
+ return src
476
+
477
+ def forward(
478
+ self,
479
+ src,
480
+ src_mask: Optional[Tensor] = None,
481
+ src_key_padding_mask: Optional[Tensor] = None,
482
+ pos: Optional[Tensor] = None,
483
+ ):
484
+ if self.normalize_before:
485
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
486
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
487
+
488
+
489
+ class TransformerDecoderLayer(nn.Module):
490
+ def __init__(
491
+ self,
492
+ d_model,
493
+ nhead,
494
+ dim_feedforward=2048,
495
+ dropout=0.1,
496
+ activation=nn.ReLU,
497
+ normalize_before=False,
498
+ ):
499
+ super().__init__()
500
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
501
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
502
+ # Implementation of Feedforward model
503
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
504
+ self.dropout = nn.Dropout(dropout)
505
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
506
+
507
+ self.norm1 = nn.LayerNorm(d_model)
508
+ self.norm2 = nn.LayerNorm(d_model)
509
+ self.norm3 = nn.LayerNorm(d_model)
510
+ self.dropout1 = nn.Dropout(dropout)
511
+ self.dropout2 = nn.Dropout(dropout)
512
+ self.dropout3 = nn.Dropout(dropout)
513
+
514
+ self.activation = activation()
515
+ self.normalize_before = normalize_before
516
+
517
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
518
+ return tensor if pos is None else tensor + pos
519
+
520
+ def forward_post(
521
+ self,
522
+ tgt,
523
+ memory,
524
+ tgt_mask: Optional[Tensor] = None,
525
+ memory_mask: Optional[Tensor] = None,
526
+ tgt_key_padding_mask: Optional[Tensor] = None,
527
+ memory_key_padding_mask: Optional[Tensor] = None,
528
+ pos: Optional[Tensor] = None,
529
+ query_pos: Optional[Tensor] = None,
530
+ ):
531
+ q = k = self.with_pos_embed(tgt, query_pos)
532
+ tgt2 = self.self_attn(
533
+ q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
534
+ )[0]
535
+ tgt = tgt + self.dropout1(tgt2)
536
+ tgt = self.norm1(tgt)
537
+ tgt2 = self.multihead_attn(
538
+ query=self.with_pos_embed(tgt, query_pos),
539
+ key=self.with_pos_embed(memory, pos),
540
+ value=memory,
541
+ attn_mask=memory_mask,
542
+ key_padding_mask=memory_key_padding_mask,
543
+ )[0]
544
+ tgt = tgt + self.dropout2(tgt2)
545
+ tgt = self.norm2(tgt)
546
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
547
+ tgt = tgt + self.dropout3(tgt2)
548
+ tgt = self.norm3(tgt)
549
+ return tgt
550
+
551
+ def forward_pre(
552
+ self,
553
+ tgt,
554
+ memory,
555
+ tgt_mask: Optional[Tensor] = None,
556
+ memory_mask: Optional[Tensor] = None,
557
+ tgt_key_padding_mask: Optional[Tensor] = None,
558
+ memory_key_padding_mask: Optional[Tensor] = None,
559
+ pos: Optional[Tensor] = None,
560
+ query_pos: Optional[Tensor] = None,
561
+ ):
562
+ tgt2 = self.norm1(tgt)
563
+ q = k = self.with_pos_embed(tgt2, query_pos)
564
+ tgt2 = self.self_attn(
565
+ q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
566
+ )[0]
567
+ tgt = tgt + self.dropout1(tgt2)
568
+ tgt2 = self.norm2(tgt)
569
+ tgt2 = self.multihead_attn(
570
+ query=self.with_pos_embed(tgt2, query_pos),
571
+ key=self.with_pos_embed(memory, pos),
572
+ value=memory,
573
+ attn_mask=memory_mask,
574
+ key_padding_mask=memory_key_padding_mask,
575
+ )[0]
576
+ tgt = tgt + self.dropout2(tgt2)
577
+ tgt2 = self.norm3(tgt)
578
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
579
+ tgt = tgt + self.dropout3(tgt2)
580
+ return tgt
581
+
582
+ def forward(
583
+ self,
584
+ tgt,
585
+ memory,
586
+ tgt_mask: Optional[Tensor] = None,
587
+ memory_mask: Optional[Tensor] = None,
588
+ tgt_key_padding_mask: Optional[Tensor] = None,
589
+ memory_key_padding_mask: Optional[Tensor] = None,
590
+ pos: Optional[Tensor] = None,
591
+ query_pos: Optional[Tensor] = None,
592
+ ):
593
+ if self.normalize_before:
594
+ return self.forward_pre(
595
+ tgt,
596
+ memory,
597
+ tgt_mask,
598
+ memory_mask,
599
+ tgt_key_padding_mask,
600
+ memory_key_padding_mask,
601
+ pos,
602
+ query_pos,
603
+ )
604
+ return self.forward_post(
605
+ tgt,
606
+ memory,
607
+ tgt_mask,
608
+ memory_mask,
609
+ tgt_key_padding_mask,
610
+ memory_key_padding_mask,
611
+ pos,
612
+ query_pos,
613
+ )
614
+
615
+
616
+ def _get_clones(module, N):
617
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
618
+
619
+
620
+ def _get_activation_fn(activation):
621
+ """Return an activation function given a string"""
622
+ if activation == "relu":
623
+ return F.relu
624
+ if activation == "gelu":
625
+ return F.gelu
626
+ if activation == "glu":
627
+ return F.glu
628
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
629
+
630
+
631
+ def build_attn_mask(mask_type):
632
+ mask = torch.ones((151, 151), dtype=torch.bool).cuda()
633
+ if mask_type == "seperate_all":
634
+ mask[:50, :50] = False
635
+ mask[50:67, 50:67] = False
636
+ mask[67:84, 67:84] = False
637
+ mask[84:101, 84:101] = False
638
+ mask[101:151, 101:151] = False
639
+ elif mask_type == "seperate_view":
640
+ mask[:50, :50] = False
641
+ mask[50:67, 50:67] = False
642
+ mask[67:84, 67:84] = False
643
+ mask[84:101, 84:101] = False
644
+ mask[101:151, :] = False
645
+ mask[:, 101:151] = False
646
+ return mask
647
+ # class InterfuserModel(nn.Module):
648
+
649
+ class InterfuserModel(nn.Module):
650
+ def __init__(
651
+ self,
652
+ img_size=224,
653
+ multi_view_img_size=112,
654
+ patch_size=8,
655
+ in_chans=3,
656
+ embed_dim=768,
657
+ enc_depth=6,
658
+ dec_depth=6,
659
+ dim_feedforward=2048,
660
+ normalize_before=False,
661
+ rgb_backbone_name="r50",
662
+ lidar_backbone_name="r50",
663
+ num_heads=8,
664
+ norm_layer=None,
665
+ dropout=0.1,
666
+ end2end=False,
667
+ direct_concat=False,
668
+ separate_view_attention=False,
669
+ separate_all_attention=False,
670
+ act_layer=None,
671
+ weight_init="",
672
+ freeze_num=-1,
673
+ with_lidar=False,
674
+ with_right_left_sensors=False,
675
+ with_center_sensor=False,
676
+ traffic_pred_head_type="det",
677
+ waypoints_pred_head="heatmap",
678
+ reverse_pos=True,
679
+ use_different_backbone=False,
680
+ use_view_embed=False,
681
+ use_mmad_pretrain=None,
682
+ ):
683
+ super().__init__()
684
+ self.traffic_pred_head_type = traffic_pred_head_type
685
+ self.num_features = (
686
+ self.embed_dim
687
+ ) = embed_dim # num_features for consistency with other models
688
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
689
+ act_layer = act_layer or nn.GELU
690
+
691
+ self.reverse_pos = reverse_pos
692
+ self.waypoints_pred_head = waypoints_pred_head
693
+ self.with_lidar = with_lidar
694
+ self.with_right_left_sensors = with_right_left_sensors
695
+ self.with_center_sensor = with_center_sensor
696
+
697
+ self.direct_concat = direct_concat
698
+ self.separate_view_attention = separate_view_attention
699
+ self.separate_all_attention = separate_all_attention
700
+ self.end2end = end2end
701
+ self.use_view_embed = use_view_embed
702
+
703
+ if self.direct_concat:
704
+ in_chans = in_chans * 4
705
+ self.with_center_sensor = False
706
+ self.with_right_left_sensors = False
707
+
708
+ if self.separate_view_attention:
709
+ self.attn_mask = build_attn_mask("seperate_view")
710
+ elif self.separate_all_attention:
711
+ self.attn_mask = build_attn_mask("seperate_all")
712
+ else:
713
+ self.attn_mask = None
714
+
715
+ if use_different_backbone:
716
+ if rgb_backbone_name == "r50":
717
+ self.rgb_backbone = resnet50d(
718
+ pretrained=True,
719
+ in_chans=in_chans,
720
+ features_only=True,
721
+ out_indices=[4],
722
+ )
723
+ elif rgb_backbone_name == "r26":
724
+ self.rgb_backbone = resnet26d(
725
+ pretrained=True,
726
+ in_chans=in_chans,
727
+ features_only=True,
728
+ out_indices=[4],
729
+ )
730
+ elif rgb_backbone_name == "r18":
731
+ self.rgb_backbone = resnet18d(
732
+ pretrained=True,
733
+ in_chans=in_chans,
734
+ features_only=True,
735
+ out_indices=[4],
736
+ )
737
+ if lidar_backbone_name == "r50":
738
+ self.lidar_backbone = resnet50d(
739
+ pretrained=False,
740
+ in_chans=in_chans,
741
+ features_only=True,
742
+ out_indices=[4],
743
+ )
744
+ elif lidar_backbone_name == "r26":
745
+ self.lidar_backbone = resnet26d(
746
+ pretrained=False,
747
+ in_chans=in_chans,
748
+ features_only=True,
749
+ out_indices=[4],
750
+ )
751
+ elif lidar_backbone_name == "r18":
752
+ self.lidar_backbone = resnet18d(
753
+ pretrained=False, in_chans=3, features_only=True, out_indices=[4]
754
+ )
755
+ rgb_embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone)
756
+ lidar_embed_layer = partial(HybridEmbed, backbone=self.lidar_backbone)
757
+
758
+ if use_mmad_pretrain:
759
+ params = torch.load(use_mmad_pretrain)["state_dict"]
760
+ updated_params = OrderedDict()
761
+ for key in params:
762
+ if "backbone" in key:
763
+ updated_params[key.replace("backbone.", "")] = params[key]
764
+ self.rgb_backbone.load_state_dict(updated_params)
765
+
766
+ self.rgb_patch_embed = rgb_embed_layer(
767
+ img_size=img_size,
768
+ patch_size=patch_size,
769
+ in_chans=in_chans,
770
+ embed_dim=embed_dim,
771
+ )
772
+ self.lidar_patch_embed = lidar_embed_layer(
773
+ img_size=img_size,
774
+ patch_size=patch_size,
775
+ in_chans=3,
776
+ embed_dim=embed_dim,
777
+ )
778
+ else:
779
+ if rgb_backbone_name == "r50":
780
+ self.rgb_backbone = resnet50d(
781
+ pretrained=True, in_chans=3, features_only=True, out_indices=[4]
782
+ )
783
+ elif rgb_backbone_name == "r101":
784
+ self.rgb_backbone = resnet101d(
785
+ pretrained=True, in_chans=3, features_only=True, out_indices=[4]
786
+ )
787
+ elif rgb_backbone_name == "r26":
788
+ self.rgb_backbone = resnet26d(
789
+ pretrained=True, in_chans=3, features_only=True, out_indices=[4]
790
+ )
791
+ elif rgb_backbone_name == "r18":
792
+ self.rgb_backbone = resnet18d(
793
+ pretrained=True, in_chans=3, features_only=True, out_indices=[4]
794
+ )
795
+ embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone)
796
+
797
+ self.rgb_patch_embed = embed_layer(
798
+ img_size=img_size,
799
+ patch_size=patch_size,
800
+ in_chans=in_chans,
801
+ embed_dim=embed_dim,
802
+ )
803
+ self.lidar_patch_embed = embed_layer(
804
+ img_size=img_size,
805
+ patch_size=patch_size,
806
+ in_chans=in_chans,
807
+ embed_dim=embed_dim,
808
+ )
809
+
810
+ self.global_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
811
+ self.view_embed = nn.Parameter(torch.zeros(1, embed_dim, 5, 1))
812
+
813
+ if self.end2end:
814
+ self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 4))
815
+ self.query_embed = nn.Parameter(torch.zeros(4, 1, embed_dim))
816
+ elif self.waypoints_pred_head == "heatmap":
817
+ self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
818
+ self.query_embed = nn.Parameter(torch.zeros(400 + 5, 1, embed_dim))
819
+ else:
820
+ self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 11))
821
+ self.query_embed = nn.Parameter(torch.zeros(400 + 11, 1, embed_dim))
822
+
823
+ if self.end2end:
824
+ self.waypoints_generator = GRUWaypointsPredictor(embed_dim, 4)
825
+ elif self.waypoints_pred_head == "heatmap":
826
+ self.waypoints_generator = MultiPath_Generator(
827
+ embed_dim + 32, embed_dim, 10
828
+ )
829
+ elif self.waypoints_pred_head == "gru":
830
+ self.waypoints_generator = GRUWaypointsPredictor(embed_dim)
831
+ elif self.waypoints_pred_head == "gru-command":
832
+ self.waypoints_generator = GRUWaypointsPredictorWithCommand(embed_dim)
833
+ elif self.waypoints_pred_head == "linear":
834
+ self.waypoints_generator = LinearWaypointsPredictor(embed_dim)
835
+ elif self.waypoints_pred_head == "linear-sum":
836
+ self.waypoints_generator = LinearWaypointsPredictor(embed_dim, cumsum=True)
837
+
838
+ self.junction_pred_head = nn.Linear(embed_dim, 2)
839
+ self.traffic_light_pred_head = nn.Linear(embed_dim, 2)
840
+ self.stop_sign_head = nn.Linear(embed_dim, 2)
841
+
842
+ if self.traffic_pred_head_type == "det":
843
+ self.traffic_pred_head = nn.Sequential(
844
+ *[
845
+ nn.Linear(embed_dim + 32, 64),
846
+ nn.ReLU(),
847
+ nn.Linear(64, 7),
848
+ # nn.Sigmoid(),
849
+ ]
850
+ )
851
+ elif self.traffic_pred_head_type == "seg":
852
+ self.traffic_pred_head = nn.Sequential(
853
+ *[nn.Linear(embed_dim, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid()]
854
+ )
855
+
856
+ self.position_encoding = PositionEmbeddingSine(embed_dim // 2, normalize=True)
857
+
858
+ encoder_layer = TransformerEncoderLayer(
859
+ embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before
860
+ )
861
+ self.encoder = TransformerEncoder(encoder_layer, enc_depth, None)
862
+
863
+ decoder_layer = TransformerDecoderLayer(
864
+ embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before
865
+ )
866
+ decoder_norm = nn.LayerNorm(embed_dim)
867
+ self.decoder = TransformerDecoder(
868
+ decoder_layer, dec_depth, decoder_norm, return_intermediate=False
869
+ )
870
+ self.reset_parameters()
871
+
872
+ def reset_parameters(self):
873
+ nn.init.uniform_(self.global_embed)
874
+ nn.init.uniform_(self.view_embed)
875
+ nn.init.uniform_(self.query_embed)
876
+ nn.init.uniform_(self.query_pos_embed)
877
+
878
+ def forward_features(
879
+ self,
880
+ front_image,
881
+ left_image,
882
+ right_image,
883
+ front_center_image,
884
+ lidar,
885
+ measurements,
886
+ ):
887
+ features = []
888
+
889
+ # Front view processing
890
+ front_image_token, front_image_token_global = self.rgb_patch_embed(front_image)
891
+ if self.use_view_embed:
892
+ front_image_token = (
893
+ front_image_token
894
+ + self.view_embed[:, :, 0:1, :]
895
+ + self.position_encoding(front_image_token)
896
+ )
897
+ else:
898
+ front_image_token = front_image_token + self.position_encoding(
899
+ front_image_token
900
+ )
901
+ front_image_token = front_image_token.flatten(2).permute(2, 0, 1)
902
+ front_image_token_global = (
903
+ front_image_token_global
904
+ + self.view_embed[:, :, 0, :]
905
+ + self.global_embed[:, :, 0:1]
906
+ )
907
+ front_image_token_global = front_image_token_global.permute(2, 0, 1)
908
+ features.extend([front_image_token, front_image_token_global])
909
+
910
+ if self.with_right_left_sensors:
911
+ # Left view processing
912
+ left_image_token, left_image_token_global = self.rgb_patch_embed(left_image)
913
+ if self.use_view_embed:
914
+ left_image_token = (
915
+ left_image_token
916
+ + self.view_embed[:, :, 1:2, :]
917
+ + self.position_encoding(left_image_token)
918
+ )
919
+ else:
920
+ left_image_token = left_image_token + self.position_encoding(
921
+ left_image_token
922
+ )
923
+ left_image_token = left_image_token.flatten(2).permute(2, 0, 1)
924
+ left_image_token_global = (
925
+ left_image_token_global
926
+ + self.view_embed[:, :, 1, :]
927
+ + self.global_embed[:, :, 1:2]
928
+ )
929
+ left_image_token_global = left_image_token_global.permute(2, 0, 1)
930
+
931
+ # Right view processing
932
+ right_image_token, right_image_token_global = self.rgb_patch_embed(
933
+ right_image
934
+ )
935
+ if self.use_view_embed:
936
+ right_image_token = (
937
+ right_image_token
938
+ + self.view_embed[:, :, 2:3, :]
939
+ + self.position_encoding(right_image_token)
940
+ )
941
+ else:
942
+ right_image_token = right_image_token + self.position_encoding(
943
+ right_image_token
944
+ )
945
+ right_image_token = right_image_token.flatten(2).permute(2, 0, 1)
946
+ right_image_token_global = (
947
+ right_image_token_global
948
+ + self.view_embed[:, :, 2, :]
949
+ + self.global_embed[:, :, 2:3]
950
+ )
951
+ right_image_token_global = right_image_token_global.permute(2, 0, 1)
952
+
953
+ features.extend(
954
+ [
955
+ left_image_token,
956
+ left_image_token_global,
957
+ right_image_token,
958
+ right_image_token_global,
959
+ ]
960
+ )
961
+
962
+ if self.with_center_sensor:
963
+ # Front center view processing
964
+ (
965
+ front_center_image_token,
966
+ front_center_image_token_global,
967
+ ) = self.rgb_patch_embed(front_center_image)
968
+ if self.use_view_embed:
969
+ front_center_image_token = (
970
+ front_center_image_token
971
+ + self.view_embed[:, :, 3:4, :]
972
+ + self.position_encoding(front_center_image_token)
973
+ )
974
+ else:
975
+ front_center_image_token = (
976
+ front_center_image_token
977
+ + self.position_encoding(front_center_image_token)
978
+ )
979
+
980
+ front_center_image_token = front_center_image_token.flatten(2).permute(
981
+ 2, 0, 1
982
+ )
983
+ front_center_image_token_global = (
984
+ front_center_image_token_global
985
+ + self.view_embed[:, :, 3, :]
986
+ + self.global_embed[:, :, 3:4]
987
+ )
988
+ front_center_image_token_global = front_center_image_token_global.permute(
989
+ 2, 0, 1
990
+ )
991
+ features.extend([front_center_image_token, front_center_image_token_global])
992
+
993
+ if self.with_lidar:
994
+ lidar_token, lidar_token_global = self.lidar_patch_embed(lidar)
995
+ if self.use_view_embed:
996
+ lidar_token = (
997
+ lidar_token
998
+ + self.view_embed[:, :, 4:5, :]
999
+ + self.position_encoding(lidar_token)
1000
+ )
1001
+ else:
1002
+ lidar_token = lidar_token + self.position_encoding(lidar_token)
1003
+ lidar_token = lidar_token.flatten(2).permute(2, 0, 1)
1004
+ lidar_token_global = (
1005
+ lidar_token_global
1006
+ + self.view_embed[:, :, 4, :]
1007
+ + self.global_embed[:, :, 4:5]
1008
+ )
1009
+ lidar_token_global = lidar_token_global.permute(2, 0, 1)
1010
+ features.extend([lidar_token, lidar_token_global])
1011
+
1012
+ features = torch.cat(features, 0)
1013
+ return features
1014
+
1015
+ def forward(self, x):
1016
+ front_image = x["rgb"]
1017
+ left_image = x["rgb_left"]
1018
+ right_image = x["rgb_right"]
1019
+ front_center_image = x["rgb_center"]
1020
+ measurements = x["measurements"]
1021
+ target_point = x["target_point"]
1022
+ lidar = x["lidar"]
1023
+
1024
+ if self.direct_concat:
1025
+ img_size = front_image.shape[-1]
1026
+ left_image = torch.nn.functional.interpolate(
1027
+ left_image, size=(img_size, img_size)
1028
+ )
1029
+ right_image = torch.nn.functional.interpolate(
1030
+ right_image, size=(img_size, img_size)
1031
+ )
1032
+ front_center_image = torch.nn.functional.interpolate(
1033
+ front_center_image, size=(img_size, img_size)
1034
+ )
1035
+ front_image = torch.cat(
1036
+ [front_image, left_image, right_image, front_center_image], dim=1
1037
+ )
1038
+ features = self.forward_features(
1039
+ front_image,
1040
+ left_image,
1041
+ right_image,
1042
+ front_center_image,
1043
+ lidar,
1044
+ measurements,
1045
+ )
1046
+
1047
+ bs = front_image.shape[0]
1048
+
1049
+ if self.end2end:
1050
+ tgt = self.query_pos_embed.repeat(bs, 1, 1)
1051
+ else:
1052
+ tgt = self.position_encoding(
1053
+ torch.ones((bs, 1, 20, 20), device=x["rgb"].device)
1054
+ )
1055
+ tgt = tgt.flatten(2)
1056
+ tgt = torch.cat([tgt, self.query_pos_embed.repeat(bs, 1, 1)], 2)
1057
+ tgt = tgt.permute(2, 0, 1)
1058
+
1059
+ memory = self.encoder(features, mask=self.attn_mask)
1060
+ hs = self.decoder(self.query_embed.repeat(1, bs, 1), memory, query_pos=tgt)[0]
1061
+
1062
+ hs = hs.permute(1, 0, 2) # (batch_size, N, C)
1063
+ if self.end2end:
1064
+ waypoints = self.waypoints_generator(hs, target_point)
1065
+ return waypoints
1066
+
1067
+ if self.waypoints_pred_head != "heatmap":
1068
+ traffic_feature = hs[:, :400]
1069
+ is_junction_feature = hs[:, 400]
1070
+ traffic_light_state_feature = hs[:, 400]
1071
+ stop_sign_feature = hs[:, 400]
1072
+ waypoints_feature = hs[:, 401:411]
1073
+ else:
1074
+ traffic_feature = hs[:, :400]
1075
+ is_junction_feature = hs[:, 400]
1076
+ traffic_light_state_feature = hs[:, 400]
1077
+ stop_sign_feature = hs[:, 400]
1078
+ waypoints_feature = hs[:, 401:405]
1079
+
1080
+ if self.waypoints_pred_head == "heatmap":
1081
+ waypoints = self.waypoints_generator(waypoints_feature, measurements)
1082
+ elif self.waypoints_pred_head == "gru":
1083
+ waypoints = self.waypoints_generator(waypoints_feature, target_point)
1084
+ elif self.waypoints_pred_head == "gru-command":
1085
+ waypoints = self.waypoints_generator(waypoints_feature, target_point, measurements)
1086
+ elif self.waypoints_pred_head == "linear":
1087
+ waypoints = self.waypoints_generator(waypoints_feature, measurements)
1088
+ elif self.waypoints_pred_head == "linear-sum":
1089
+ waypoints = self.waypoints_generator(waypoints_feature, measurements)
1090
+
1091
+ is_junction = self.junction_pred_head(is_junction_feature)
1092
+ traffic_light_state = self.traffic_light_pred_head(traffic_light_state_feature)
1093
+ stop_sign = self.stop_sign_head(stop_sign_feature)
1094
+
1095
+ velocity = measurements[:, 6:7].unsqueeze(-1)
1096
+ velocity = velocity.repeat(1, 400, 32)
1097
+ traffic_feature_with_vel = torch.cat([traffic_feature, velocity], dim=2)
1098
+ traffic = self.traffic_pred_head(traffic_feature_with_vel)
1099
+ return traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature
1100
+ def load_pretrained(self, model_path, strict=False):
1101
+ """
1102
+ Load pretrained weights - improved version
1103
+
1104
+ Args:
1105
+ model_path (str): path to the weights file
1106
+ strict (bool): if True, require an exact key match
1107
+ """
1108
+ if not model_path or not Path(model_path).exists():
1109
+ logging.warning(f"ملف الأوزان غير موجود: {model_path}")
1110
+ logging.info("سيتم استخدام أوزان عشوائية")
1111
+ return False
1112
+
1113
+ try:
1114
+ logging.info(f"محاولة تحميل الأوزان من: {model_path}")
1115
+
1116
+ # Load the checkpoint, handling several common save formats
1117
+ checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
1118
+
1119
+ # Extract the state_dict from the different checkpoint formats
1120
+ if isinstance(checkpoint, dict):
1121
+ if 'model_state_dict' in checkpoint:
1122
+ state_dict = checkpoint['model_state_dict']
1123
+ logging.info("تم العثور على 'model_state_dict' في الملف")
1124
+ elif 'state_dict' in checkpoint:
1125
+ state_dict = checkpoint['state_dict']
1126
+ logging.info("تم العثور على 'state_dict' في الملف")
1127
+ elif 'model' in checkpoint:
1128
+ state_dict = checkpoint['model']
1129
+ logging.info("تم العثور على 'model' في الملف")
1130
+ else:
1131
+ state_dict = checkpoint
1132
+ logging.info("استخدام الملف كـ state_dict مباشرة")
1133
+ else:
1134
+ state_dict = checkpoint
1135
+ logging.info("استخدام الملف كـ state_dict مباشرة")
1136
+
1137
+ # Clean up key names (strip a leading 'module.' if present)
1138
+ clean_state_dict = OrderedDict()
1139
+ for k, v in state_dict.items():
1140
+ # Strip 'module.' from the start of the key name if present
1141
+ clean_key = k[7:] if k.startswith('module.') else k
1142
+ clean_state_dict[clean_key] = v
1143
+
1144
+ # Load the weights
1145
+ missing_keys, unexpected_keys = self.load_state_dict(clean_state_dict, strict=strict)
1146
+
1147
+ # Report the loading status
1148
+ if missing_keys:
1149
+ logging.warning(f"مفاتيح مفقودة ({len(missing_keys)}): {missing_keys[:5]}..." if len(missing_keys) > 5 else f"مفاتيح مفقودة: {missing_keys}")
1150
+
1151
+ if unexpected_keys:
1152
+ logging.warning(f"مفاتيح غير متوقعة ({len(unexpected_keys)}): {unexpected_keys[:5]}..." if len(unexpected_keys) > 5 else f"مفاتيح غير متوقعة: {unexpected_keys}")
1153
+
1154
+ if not missing_keys and not unexpected_keys:
1155
+ logging.info("✅ تم تحميل جميع الأوزان بنجاح تام")
1156
+ elif not strict:
1157
+ logging.info("✅ تم تحميل الأوزان بنجاح (مع تجاهل عدم التطابق)")
1158
+
1159
+ return True
1160
+
1161
+ except Exception as e:
1162
+ logging.error(f"❌ خطأ في تحميل الأوزان: {str(e)}")
1163
+ logging.info("سيتم استخدام أوزان عشوائية")
1164
+ return False
1165
+
1166
+
1167
+ # ============================================================================
1168
+ # Helper functions for building and loading the model
1169
+ # ============================================================================
1170
+
1171
+ def load_and_prepare_model(config, device):
1172
+ """
1173
+ Creates the model and loads the pretrained weights.
1174
+
1175
+ Args:
1176
+ config (dict): model settings and file paths
1177
+ device (torch.device): target device (CPU/GPU)
1178
+
1179
+ Returns:
1180
+ InterfuserModel: the loaded model
1181
+ """
1182
+ try:
1183
+ # Create the model
1184
+ model = InterfuserModel(**config.get('model_params', {})).to(device)
1185
+ logging.info(f"تم إنشاء النموذج على الجهاز: {device}")
1186
+
1187
+ # Load the weights if a path was provided
1188
+ checkpoint_path = config.get('paths', {}).get('pretrained_weights')
1189
+ if checkpoint_path:
1190
+ success = model.load_pretrained(checkpoint_path, strict=False)
1191
+ if success:
1192
+ logging.info("✅ تم تحميل النموذج والأوزان بنجاح")
1193
+ else:
1194
+ logging.warning("⚠️ تم إنشاء النموذج بأوزان عشوائية")
1195
+ else:
1196
+ logging.info("لم يتم تحديد مسار الأوزان، سيتم استخدام أوزان عشوائية")
1197
+
1198
+ # Put the model in evaluation mode
1199
+ model.eval()
1200
+
1201
+ return model
1202
+
1203
+ except Exception as e:
1204
+ logging.error(f"خطأ في إنشاء النموذج: {str(e)}")
1205
+ raise
1206
+
1207
+
1208
+ def create_model_config(model_path="model/best_model.pth", **model_params):
1209
+ """
1210
+ Builds the model configuration using the settings from the original training run.
1211
+
1212
+ Args:
1213
+ model_path (str): path to the weights file
1214
+ **model_params: additional model parameters (override the training defaults)
1215
+
1216
+ Returns:
1217
+ dict: model configuration
1218
+ """
1219
+ # Correct settings taken from the original training config
1220
+ training_config_params = {
1221
+ "img_size": 224,
1222
+ "embed_dim": 256, # مهم: هذه القيمة من التدريب الأصلي
1223
+ "enc_depth": 6,
1224
+ "dec_depth": 6,
1225
+ "rgb_backbone_name": 'r50',
1226
+ "lidar_backbone_name": 'r18',
1227
+ "waypoints_pred_head": 'gru',
1228
+ "use_different_backbone": True,
1229
+ "with_lidar": False,
1230
+ "with_right_left_sensors": False,
1231
+ "with_center_sensor": False,
1232
+
1233
+ # Additional settings from the original config
1234
+ "multi_view_img_size": 112,
1235
+ "patch_size": 8,
1236
+ "in_chans": 3,
1237
+ "dim_feedforward": 2048,
1238
+ "normalize_before": False,
1239
+ "num_heads": 8,
1240
+ "dropout": 0.1,
1241
+ "end2end": False,
1242
+ "direct_concat": False,
1243
+ "separate_view_attention": False,
1244
+ "separate_all_attention": False,
1245
+ "freeze_num": -1,
1246
+ "traffic_pred_head_type": "det",
1247
+ "reverse_pos": True,
1248
+ "use_view_embed": False,
1249
+ "use_mmad_pretrain": None,
1250
+ }
1251
+
1252
+ # Merge custom parameters over the training defaults
1253
+ training_config_params.update(model_params)
1254
+
1255
+ config = {
1256
+ 'model_params': training_config_params,
1257
+ 'paths': {
1258
+ 'pretrained_weights': model_path
1259
+ },
1260
+
1261
+ # Add the grid settings from training
1262
+ 'grid_conf': {
1263
+ 'h': 20, 'w': 20,
1264
+ 'x_res': 1.0, 'y_res': 1.0,
1265
+ 'y_min': 0.0, 'y_max': 20.0,
1266
+ 'x_min': -10.0, 'x_max': 10.0,
1267
+ },
1268
+
1269
+ # Additional information about the training run
1270
+ 'training_info': {
1271
+ 'original_project': 'Interfuser_Finetuning',
1272
+ 'run_name': 'Finetune_Focus_on_Detection_v5',
1273
+ 'focus': 'traffic_detection_and_iou',
1274
+ 'backbone': 'ResNet50 + ResNet18',
1275
+ 'trained_on': 'PDM_Lite_Carla'
1276
+ }
1277
+ }
1278
+
1279
+ return config
1280
+
1281
+
1282
+ def get_training_config():
1283
+ """
1284
+ Returns the original training settings for reference.
1285
+ These settings document how the model was trained.
1286
+ """
1287
+ return {
1288
+ 'project_info': {
1289
+ 'project': 'Interfuser_Finetuning',
1290
+ 'entity': None,
1291
+ 'run_name': 'Finetune_Focus_on_Detection_v5'
1292
+ },
1293
+ 'training': {
1294
+ 'epochs': 50,
1295
+ 'batch_size': 8,
1296
+ 'num_workers': 2,
1297
+ 'learning_rate': 1e-4, # low learning rate for fine-tuning
1298
+ 'weight_decay': 1e-2,
1299
+ 'patience': 15,
1300
+ 'clip_grad_norm': 1.0,
1301
+ },
1302
+ 'loss_weights': {
1303
+ 'iou': 2.0, # top priority: bounding-box accuracy
1304
+ 'traffic_map': 25.0, # strong focus on object detection
1305
+ 'waypoints': 1.0, # baseline reference
1306
+ 'junction': 0.25, # tasks that are already well learned
1307
+ 'traffic_light': 0.5,
1308
+ 'stop_sign': 0.25,
1309
+ },
1310
+ 'data_split': {
1311
+ 'strategy': 'interleaved',
1312
+ 'segment_length': 100,
1313
+ 'validation_frequency': 10,
1314
+ },
1315
+ 'transforms': {
1316
+ 'use_data_augmentation': False, # disabled to focus on the original data
1317
+ }
1318
+ }
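
For reference, here is a minimal usage sketch of the helper functions above (it is not part of the committed files): it builds the training-time configuration with `create_model_config`, loads the checkpoint via `load_and_prepare_model`, and runs one dummy forward pass. The tensor shapes and the 10-element measurement vector are illustrative assumptions; the input keys match those read in `InterfuserModel.forward`.

```python
# Usage sketch only - not part of the deployed Space. Shapes are assumptions
# (224x224 front view, 112x112 side views, 10-dim measurements).
import logging
import torch

from model_definition import create_model_config, load_and_prepare_model

logging.basicConfig(level=logging.INFO)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the config with the training-time defaults and load the checkpoint.
config = create_model_config(model_path="model/best_model.pth")
model = load_and_prepare_model(config, device)

# Dummy batch with the keys consumed by InterfuserModel.forward.
dummy_batch = {
    "rgb": torch.zeros(1, 3, 224, 224, device=device),
    "rgb_left": torch.zeros(1, 3, 112, 112, device=device),
    "rgb_right": torch.zeros(1, 3, 112, 112, device=device),
    "rgb_center": torch.zeros(1, 3, 112, 112, device=device),
    "lidar": torch.zeros(1, 3, 112, 112, device=device),  # unused: with_lidar=False
    "measurements": torch.zeros(1, 10, device=device),
    "target_point": torch.zeros(1, 2, device=device),
}

with torch.no_grad():
    traffic, waypoints, is_junction, light, stop_sign, feature = model(dummy_batch)
print(traffic.shape, waypoints.shape)
```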
requirements.txt ADDED
@@ -0,0 +1,22 @@
1
+ # Core server libraries
2
+ fastapi==0.104.1
3
+ uvicorn[standard]==0.24.0
4
+ python-multipart==0.0.6
5
+ pydantic==2.4.2
6
+
7
+ # Deep learning libraries (CPU wheels)
8
+ --extra-index-url https://download.pytorch.org/whl/cpu
9
+ torch==2.1.0
10
+ torchvision==0.16.0
11
+ # Data processing libraries
12
+ numpy==1.24.3
13
+ opencv-python-headless==4.8.1.78
14
+ Pillow==10.0.1
15
+
16
+ # Additional libraries for the improved model
17
+ tqdm==4.66.1
18
+ pathlib2==2.3.7
19
+
20
+ # Optional libraries (install as needed)
21
+ # wandb==0.15.12 # for experiment tracking and monitoring
22
+ # timm==0.9.7 # for computer-vision backbone models
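
A quick way to confirm the pinned environment inside the container is a small import/version check; this is a sketch that only assumes the packages listed above are installed.

```python
# Environment smoke test for the dependencies pinned above (sketch only).
import cv2
import fastapi
import numpy
import pydantic
import torch
import torchvision

for name, module in [("fastapi", fastapi), ("pydantic", pydantic), ("numpy", numpy),
                     ("opencv", cv2), ("torch", torch), ("torchvision", torchvision)]:
    print(f"{name}: {module.__version__}")

# The CPU-only wheels are expected here, so CUDA should report as unavailable.
print("CUDA available:", torch.cuda.is_available())
```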
simulation_modules.py ADDED
@@ -0,0 +1,336 @@
1
+ # simulation_modules.py
2
+
3
+ import torch
4
+ import numpy as np
5
+ import cv2
6
+ import math
7
+ from collections import deque
8
+ from typing import List, Tuple, Dict, Any, Optional
9
+
10
+ # ================== Constants ==================
11
+ WAYPOINT_SCALE_FACTOR = 5.0
12
+ T1_FUTURE_TIME = 1.0
13
+ T2_FUTURE_TIME = 2.0
14
+ PIXELS_PER_METER = 8
15
+ MAX_DISTANCE = 32
16
+ IMG_SIZE = MAX_DISTANCE * PIXELS_PER_METER * 2
17
+ EGO_CAR_X = IMG_SIZE // 2
18
+ EGO_CAR_Y = IMG_SIZE - (4.0 * PIXELS_PER_METER)
19
+
20
+ COLORS = {
21
+ 'vehicle': [255, 0, 0],
22
+ 'pedestrian': [0, 255, 0],
23
+ 'cyclist': [0, 0, 255],
24
+ 'waypoint': [255, 255, 0],
25
+ 'ego_car': [255, 255, 255]
26
+ }
27
+
28
+ # ================== PID Controller ==================
29
+ class PIDController:
30
+ def __init__(self, K_P=1.0, K_I=0.0, K_D=0.0, n=20):
31
+ self._K_P = K_P
32
+ self._K_I = K_I
33
+ self._K_D = K_D
34
+ self._window = deque([0 for _ in range(n)], maxlen=n)
35
+
36
+ def step(self, error):
37
+ self._window.append(error)
38
+ if len(self._window) >= 2:
39
+ integral = np.mean(self._window)
40
+ derivative = self._window[-1] - self._window[-2]
41
+ else:
42
+ integral = derivative = 0.0
43
+ return self._K_P * error + self._K_I * integral + self._K_D * derivative
44
+
45
+ # ================== Helper Functions ==================
46
+ def ensure_rgb(image):
47
+ if len(image.shape) == 2:
48
+ return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
49
+ elif image.shape[2] == 1:
50
+ return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
51
+ return image
52
+
53
+ def add_rect(img, loc, ori, box, value, color):
54
+ center_x = int(loc[0] * PIXELS_PER_METER + MAX_DISTANCE * PIXELS_PER_METER)
55
+ center_y = int(loc[1] * PIXELS_PER_METER + MAX_DISTANCE * PIXELS_PER_METER)
56
+ size_px = (int(box[0] * PIXELS_PER_METER), int(box[1] * PIXELS_PER_METER))
57
+ angle_deg = -np.degrees(math.atan2(ori[1], ori[0]))
58
+ box_points = cv2.boxPoints(((center_x, center_y), size_px, angle_deg))
59
+ box_points = np.int32(box_points)
60
+ adjusted_color = [int(c * value) for c in color]
61
+ cv2.fillConvexPoly(img, box_points, adjusted_color)
62
+ return img
63
+
64
+ def render(traffic_grid, t=0):
65
+ img = np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.uint8)
66
+ counts = {'vehicles': 0, 'pedestrians': 0, 'cyclists': 0}
67
+
68
+ if isinstance(traffic_grid, torch.Tensor):
69
+ traffic_grid = traffic_grid.cpu().numpy()
70
+
71
+ h, w, c = traffic_grid.shape
72
+ for y in range(h):
73
+ for x in range(w):
74
+ for ch in range(c):
75
+ if traffic_grid[y, x, ch] > 0.1:
76
+ world_x = (x / w - 0.5) * MAX_DISTANCE * 2
77
+ world_y = (y / h - 0.5) * MAX_DISTANCE * 2
78
+
79
+ if ch < 3:
80
+ color = COLORS['vehicle']
81
+ counts['vehicles'] += 1
82
+ box_size = [2.0, 4.0]
83
+ elif ch < 5:
84
+ color = COLORS['pedestrian']
85
+ counts['pedestrians'] += 1
86
+ box_size = [0.8, 0.8]
87
+ else:
88
+ color = COLORS['cyclist']
89
+ counts['cyclists'] += 1
90
+ box_size = [1.2, 2.0]
91
+
92
+ img = add_rect(img, [world_x, world_y], [1.0, 0.0],
93
+ box_size, traffic_grid[y, x, ch], color)
94
+
95
+ return img, counts
96
+
97
+ def render_waypoints(waypoints, scale_factor=WAYPOINT_SCALE_FACTOR):
98
+ img = np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.uint8)
99
+
100
+ if isinstance(waypoints, torch.Tensor):
101
+ waypoints = waypoints.cpu().numpy()
102
+
103
+ scaled_waypoints = waypoints * scale_factor
104
+
105
+ for i, wp in enumerate(scaled_waypoints):
106
+ px = int(wp[0] * PIXELS_PER_METER + IMG_SIZE // 2)
107
+ py = int(wp[1] * PIXELS_PER_METER + IMG_SIZE // 2)
108
+
109
+ if 0 <= px < IMG_SIZE and 0 <= py < IMG_SIZE:
110
+ radius = max(3, 8 - i)
111
+ cv2.circle(img, (px, py), radius, COLORS['waypoint'], -1)
112
+
113
+ if i > 0:
114
+ prev_px = int(scaled_waypoints[i-1][0] * PIXELS_PER_METER + IMG_SIZE // 2)
115
+ prev_py = int(scaled_waypoints[i-1][1] * PIXELS_PER_METER + IMG_SIZE // 2)
116
+ if 0 <= prev_px < IMG_SIZE and 0 <= prev_py < IMG_SIZE:
117
+ cv2.line(img, (prev_px, prev_py), (px, py), COLORS['waypoint'], 2)
118
+
119
+ return img
120
+
121
+ def render_self_car(img):
122
+ car_pos = [0, -4.0]
123
+ car_ori = [1.0, 0.0]
124
+ car_size = [2.0, 4.5]
125
+ return add_rect(img, car_pos, car_ori, car_size, 1.0, COLORS['ego_car'])
126
+
127
+ # ================== Tracker Classes ==================
128
+ class TrackedObject:
129
+ def __init__(self, obj_id: int):
130
+ self.id = obj_id
131
+ self.last_step = 0
132
+ self.last_pos = [0.0, 0.0]
133
+ self.historical_pos = []
134
+ self.historical_steps = []
135
+ self.velocity = [0.0, 0.0]
136
+ self.confidence = 1.0
137
+
138
+ def update(self, step: int, obj_info: List[float]):
139
+ self.last_step = step
140
+ self.last_pos = obj_info[:2]
141
+ self.historical_pos.append(obj_info[:2])
142
+ self.historical_steps.append(step)
143
+
144
+ if len(self.historical_pos) >= 2:
145
+ dt = self.historical_steps[-1] - self.historical_steps[-2]
146
+ if dt > 0:
147
+ dx = self.historical_pos[-1][0] - self.historical_pos[-2][0]
148
+ dy = self.historical_pos[-1][1] - self.historical_pos[-2][1]
149
+ self.velocity = [dx/dt, dy/dt]
150
+
151
+ def predict_position(self, future_time: float) -> List[float]:
152
+ predicted_x = self.last_pos[0] + self.velocity[0] * future_time
153
+ predicted_y = self.last_pos[1] + self.velocity[1] * future_time
154
+ return [predicted_x, predicted_y]
155
+
156
+ def is_alive(self, current_step: int, max_age: int = 5) -> bool:
157
+ return (current_step - self.last_step) <= max_age
158
+
159
+ class Tracker:
160
+ def __init__(self, frequency: int = 10):
161
+ self.tracks: List[TrackedObject] = []
162
+ self.frequency = frequency
163
+ self.next_id = 0
164
+ self.current_step = 0
165
+
166
+ def update_and_predict(self, detections: List[Dict], step: int) -> np.ndarray:
167
+ self.current_step = step
168
+
169
+ for detection in detections:
170
+ pos = detection.get('position', [0, 0])
171
+ feature = detection.get('feature', 0.5)
172
+
173
+ best_match = None
174
+ min_distance = float('inf')
175
+
176
+ for track in self.tracks:
177
+ if track.is_alive(step):
178
+ distance = np.linalg.norm(np.array(pos) - np.array(track.last_pos))
179
+ if distance < min_distance and distance < 2.0:
180
+ min_distance = distance
181
+ best_match = track
182
+
183
+ if best_match:
184
+ best_match.update(step, pos + [feature])
185
+ else:
186
+ new_track = TrackedObject(self.next_id)
187
+ new_track.update(step, pos + [feature])
188
+ self.tracks.append(new_track)
189
+ self.next_id += 1
190
+
191
+ self.tracks = [t for t in self.tracks if t.is_alive(step)]
192
+ return self._generate_prediction_grid()
193
+
194
+ def _generate_prediction_grid(self) -> np.ndarray:
195
+ grid = np.zeros((20, 20, 7), dtype=np.float32)
196
+
197
+ for track in self.tracks:
198
+ if track.is_alive(self.current_step):
199
+ current_pos = track.last_pos
200
+ future_pos_t1 = track.predict_position(T1_FUTURE_TIME)
201
+ future_pos_t2 = track.predict_position(T2_FUTURE_TIME)
202
+
203
+ for pos in [current_pos, future_pos_t1, future_pos_t2]:
204
+ grid_x = int((pos[0] / (MAX_DISTANCE * 2) + 0.5) * 20)
205
+ grid_y = int((pos[1] / (MAX_DISTANCE * 2) + 0.5) * 20)
206
+
207
+ if 0 <= grid_x < 20 and 0 <= grid_y < 20:
208
+ channel = 0
209
+ grid[grid_y, grid_x, channel] = max(grid[grid_y, grid_x, channel], track.confidence)
210
+
211
+ return grid
212
+
213
+ # ================== Controller Classes ==================
214
+ class ControllerConfig:
215
+ def __init__(self):
216
+ self.turn_KP = 1.0
217
+ self.turn_KI = 0.1
218
+ self.turn_KD = 0.1
219
+ self.turn_n = 20
220
+
221
+ self.speed_KP = 0.5
222
+ self.speed_KI = 0.05
223
+ self.speed_KD = 0.1
224
+ self.speed_n = 20
225
+
226
+ self.max_speed = 6.0
227
+ self.max_throttle = 0.75
228
+ self.clip_delta = 0.25
229
+
230
+ self.brake_speed = 0.4
231
+ self.brake_ratio = 1.1
232
+
233
+ class InterfuserController:
234
+ def __init__(self, config: ControllerConfig):
235
+ self.config = config
236
+ self.turn_controller = PIDController(config.turn_KP, config.turn_KI, config.turn_KD, config.turn_n)
237
+ self.speed_controller = PIDController(config.speed_KP, config.speed_KI, config.speed_KD, config.speed_n)
238
+ self.last_steer = 0.0
239
+ self.last_throttle = 0.0
240
+ self.target_speed = 3.0
241
+
242
+ def run_step(self, current_speed: float, waypoints: np.ndarray,
243
+ junction: float, traffic_light_state: float,
244
+ stop_sign: float, meta_data: Dict) -> Tuple[float, float, bool, str]:
245
+
246
+ if isinstance(waypoints, torch.Tensor):
247
+ waypoints = waypoints.cpu().numpy()
248
+
249
+ if len(waypoints) > 1:
250
+ dx = waypoints[1][0] - waypoints[0][0]
251
+ dy = waypoints[1][1] - waypoints[0][1]
252
+ target_yaw = math.atan2(dy, dx)
253
+ steer = self.turn_controller.step(target_yaw)
254
+ else:
255
+ steer = 0.0
256
+
257
+ steer = np.clip(steer, -1.0, 1.0)
258
+
259
+ target_speed = self.target_speed
260
+ if junction > 0.5:
261
+ target_speed *= 0.7
262
+ if abs(steer) > 0.3:
263
+ target_speed *= 0.8
264
+
265
+ speed_error = target_speed - current_speed
266
+ throttle = self.speed_controller.step(speed_error)
267
+ throttle = np.clip(throttle, 0.0, self.config.max_throttle)
268
+
269
+ brake = False
270
+ if traffic_light_state > 0.5 or stop_sign > 0.5 or current_speed > self.config.max_speed:
271
+ brake = True
272
+ throttle = 0.0
273
+
274
+ self.last_steer = steer
275
+ self.last_throttle = throttle
276
+
277
+ metadata = f"Speed:{current_speed:.1f} Target:{target_speed:.1f} Junction:{junction:.2f}"
278
+
279
+ return steer, throttle, brake, metadata
280
+
281
+ # ================== Display Interface ==================
282
+ class DisplayInterface:
283
+ def __init__(self, width: int = 1200, height: int = 600):
284
+ self._width = width
285
+ self._height = height
286
+ self.camera_width = width // 2
287
+ self.camera_height = height
288
+ self.map_width = width // 2
289
+ self.map_height = height // 3
290
+
291
+ def run_interface(self, data: Dict[str, Any]) -> np.ndarray:
292
+ dashboard = np.zeros((self._height, self._width, 3), dtype=np.uint8)
293
+
294
+ # Camera view
295
+ camera_view = data.get('camera_view')
296
+ if camera_view is not None:
297
+ camera_resized = cv2.resize(camera_view, (self.camera_width, self.camera_height))
298
+ dashboard[:, :self.camera_width] = camera_resized
299
+
300
+ # Maps
301
+ map_start_x = self.camera_width
302
+
303
+ map_t0 = data.get('map_t0')
304
+ if map_t0 is not None:
305
+ map_resized = cv2.resize(map_t0, (self.map_width, self.map_height))
306
+ dashboard[:self.map_height, map_start_x:] = map_resized
307
+ cv2.putText(dashboard, "Current (t=0)", (map_start_x + 10, 30),
308
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
309
+
310
+ map_t1 = data.get('map_t1')
311
+ if map_t1 is not None:
312
+ map_resized = cv2.resize(map_t1, (self.map_width, self.map_height))
313
+ y_start = self.map_height
314
+ dashboard[y_start:y_start + self.map_height, map_start_x:] = map_resized
315
+ cv2.putText(dashboard, f"Future (t={T1_FUTURE_TIME}s)",
316
+ (map_start_x + 10, y_start + 30), cv2.FONT_HERSHEY_SIMPLEX,
317
+ 0.7, (255, 255, 255), 2)
318
+
319
+ map_t2 = data.get('map_t2')
320
+ if map_t2 is not None:
321
+ map_resized = cv2.resize(map_t2, (self.map_width, self.map_height))
322
+ y_start = self.map_height * 2
323
+ dashboard[y_start:, map_start_x:] = map_resized
324
+ cv2.putText(dashboard, f"Future (t={T2_FUTURE_TIME}s)",
325
+ (map_start_x + 10, y_start + 30), cv2.FONT_HERSHEY_SIMPLEX,
326
+ 0.7, (255, 255, 255), 2)
327
+
328
+ # Text info
329
+ text_info = data.get('text_info', {})
330
+ y_offset = 50
331
+ for key, value in text_info.items():
332
+ cv2.putText(dashboard, value, (10, y_offset), cv2.FONT_HERSHEY_SIMPLEX,
333
+ 0.6, (0, 255, 0), 2)
334
+ y_offset += 30
335
+
336
+ return dashboard
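
To show how these pieces fit together, here is a short, illustrative driving-loop sketch (not part of the committed files): dummy detections and waypoints stand in for real InterfuserModel outputs, and one step produces a control command plus a dashboard frame.

```python
# Illustrative only: wire the tracker, controller and dashboard together for one step.
import numpy as np

from simulation_modules import (ControllerConfig, DisplayInterface, InterfuserController,
                                Tracker, render, render_self_car, render_waypoints)

tracker = Tracker(frequency=10)
controller = InterfuserController(ControllerConfig())
display = DisplayInterface(width=1200, height=600)

# Dummy stand-ins for model outputs: one detection and three ego-frame waypoints.
detections = [{"position": [1.0, 5.0], "feature": 0.9}]
traffic_grid = tracker.update_and_predict(detections, step=0)
waypoints = np.array([[0.0, 1.0], [0.0, 2.0], [0.0, 3.0]])

# One control step from the predicted waypoints and scene flags.
steer, throttle, brake, meta = controller.run_step(
    current_speed=2.5, waypoints=waypoints,
    junction=0.1, traffic_light_state=0.0, stop_sign=0.0, meta_data={})

# Render the bird's-eye maps and compose the dashboard.
map_t0, counts = render(traffic_grid)
map_t0 = render_self_car(map_t0)
frame = display.run_interface({
    "camera_view": np.zeros((600, 800, 3), dtype=np.uint8),  # placeholder camera frame
    "map_t0": map_t0,
    "map_t1": render_waypoints(waypoints),
    "text_info": {"control": f"steer={steer:.2f} throttle={throttle:.2f} brake={brake}",
                  "meta": meta},
})
print(frame.shape, counts)
```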