Spaces:
Running
Running
# app.py - InterFuser Self-Driving API Server | |
import uuid | |
import base64 | |
import cv2 | |
import torch | |
import numpy as np | |
from fastapi import FastAPI, HTTPException | |
from fastapi.responses import HTMLResponse | |
from pydantic import BaseModel | |
from torchvision import transforms | |
from typing import List, Dict, Any, Optional | |
import logging | |
# استيراد من ملفاتنا المحلية | |
from model_definition import InterfuserModel, load_and_prepare_model, create_model_config | |
from simulation_modules import ( | |
InterfuserController, ControllerConfig, Tracker, DisplayInterface, | |
render, render_waypoints, render_self_car, WAYPOINT_SCALE_FACTOR, | |
T1_FUTURE_TIME, T2_FUTURE_TIME | |
) | |
# ---- Logging configuration ----
# Root logger at INFO so the module's progress messages are emitted.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)
# ================== General settings and model loading ==================
app = FastAPI(
    title="Baseer Self-Driving API",
    description="API للقيادة الذاتية باستخدام نموذج InterFuser",
    version="1.0.0",
)

# Inference always runs on the CPU in this deployment.
device = torch.device("cpu")
logger.info(f"Using device: {device}")
# Load the trained model; on failure fall back to a randomly initialised one,
# and if even that fails leave `model = None` (run_step then answers 500).
try:
    # Build the model config from the checkpoint path. Per the original notes,
    # the training-time settings are applied automatically by
    # create_model_config:
    #   embed_dim=256, rgb_backbone_name='r50', waypoints_pred_head='gru',
    #   with_lidar=False, with_right_left_sensors=False, with_center_sensor=False
    # NOTE(review): create_model_config lives in model_definition and is not
    # visible here — confirm those defaults there.
    model_config = create_model_config(
        model_path="model/best_model.pth"
    )
    # Load the weights into the model on the selected device.
    model = load_and_prepare_model(model_config, device)
    logger.info("✅ تم تحميل النموذج بنجاح")
except Exception as e:
    # logger.exception keeps the traceback, which the old logger.error dropped.
    logger.exception(f"❌ خطأ في تحميل النموذج: {e}")
    logger.info("🔄 محاولة إنشاء نموذج بأوزان عشوائية...")
    try:
        # NOTE(review): a model with random weights will produce meaningless
        # driving commands; this fallback only keeps the API process alive.
        model = InterfuserModel()
        model.to(device)
        model.eval()
        logger.warning("⚠️ تم إنشاء النموذج بأوزان عشوائية")
    except Exception as e2:
        logger.exception(f"❌ فشل في إنشاء النموذج: {e2}")
        model = None
# Dashboard renderer; a single instance is shared by every session's run_step.
display = DisplayInterface()

# In-memory session store: session id (uuid4 string) -> per-session state
# (tracker, controller, frame counter, timestamps). Lost on process restart.
SESSIONS: Dict[str, Dict] = dict()
# ================== Pydantic data structures ==================
class Measurements(BaseModel):
    """Vehicle state sent by the client with each step.

    Every field has a default, so an empty object is accepted.
    """

    # [x, y] position
    pos: List[float] = [0.0, 0.0]
    # orientation angle
    theta: float = 0.0
    # current speed
    speed: float = 0.0
    # current steering
    steer: float = 0.0
    # current throttle
    throttle: float = 0.0
    # brake status
    brake: bool = False
    # driving command (4 = FollowLane)
    command: int = 4
    # target point [x, y]
    target_point: List[float] = [0.0, 0.0]
class ModelOutputs(BaseModel):
    """Model predictions returned to the client for one step.

    The three scalar fields carry the sigmoid-activated scores computed in
    run_step.
    """

    # 20x20x7 grid
    traffic: List[List[List[float]]]
    # Nx2 waypoints
    waypoints: List[List[float]]
    is_junction: float
    traffic_light_state: float
    stop_sign: float
class ControlCommands(BaseModel):
    """Actuator commands produced by the session's controller."""

    steer: float
    throttle: float
    brake: bool
class RunStepInput(BaseModel):
    """Request body for one simulation step.

    ``image_b64`` is the base64-encoded camera frame; ``measurements`` is the
    current vehicle state.
    """

    session_id: str
    image_b64: str
    measurements: Measurements
class RunStepOutput(BaseModel):
    """Response for one simulation step.

    Contains the model outputs, the control commands and a base64 JPEG of the
    rendered dashboard.
    """

    model_outputs: ModelOutputs
    control_commands: ControlCommands
    dashboard_image_b64: str
class SessionResponse(BaseModel):
    """Acknowledgement returned when a session is started or ended."""

    session_id: str
    message: str
# ================== Helper functions ==================
def get_image_transform():
    """Build the preprocessing pipeline applied to each camera frame.

    The pipeline has three steps:
    - convert to a tensor;
    - resize to 224x224 with antialiasing;
    - normalise with the standard ImageNet channel mean/std.

    The original comment says this mirrors PDMDataset. NOTE(review): that
    class is not visible in this file.
    """
    steps = [
        transforms.ToTensor(),
        transforms.Resize((224, 224), antialias=True),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)
# Build the transform pipeline once at import time and reuse it for every request.
image_transform = get_image_transform()
def preprocess_input(frame_rgb: np.ndarray, measurements: Measurements, device: torch.device) -> Dict[str, torch.Tensor]:
    """Build a single-sample model batch from one RGB frame and the vehicle state.

    The original notes say this imitates ``PDMDataset.__getitem__``.
    NOTE(review): that class is not visible here.

    Args:
        frame_rgb: RGB frame as a numpy array; a PIL image is passed through
            unchanged.
        measurements: Current vehicle state from the request.
        device: Device on which every returned tensor is placed.

    Returns:
        Dict with these keys:
        - ``'rgb'``: the transformed frame with a leading batch dim of 1.
        - ``'rgb_left'``, ``'rgb_right'``, ``'rgb_center'``: clones of
          ``'rgb'`` (no extra cameras are available).
        - ``'lidar'``: an all-zero ``(1, 3, 224, 224)`` placeholder.
        - ``'measurements'``: a ``(1, 8)`` float tensor.
        - ``'target_point'``: the target point with a batch dim.

        No ground-truth (gt_*) entries are produced, since they are not
        needed at inference time.
    """
    from PIL import Image

    pil_frame = Image.fromarray(frame_rgb) if isinstance(frame_rgb, np.ndarray) else frame_rgb
    rgb = image_transform(pil_frame).unsqueeze(0).to(device)

    m = measurements
    # Order: x, y, theta, steer, throttle, brake, speed, command
    measurement_values = [
        m.pos[0], m.pos[1], m.theta,
        m.steer, m.throttle, float(m.brake),
        m.speed, float(m.command),
    ]

    return {
        'rgb': rgb,
        'rgb_left': rgb.clone(),
        'rgb_right': rgb.clone(),
        'rgb_center': rgb.clone(),
        'lidar': torch.zeros(1, 3, 224, 224, dtype=torch.float32).to(device),
        'measurements': torch.tensor([measurement_values], dtype=torch.float32).to(device),
        'target_point': torch.tensor([m.target_point], dtype=torch.float32).to(device),
    }
def decode_base64_image(image_b64: str) -> np.ndarray:
    """Decode a base64-encoded image into an OpenCV (BGR) array.

    A ``data:<mime>;base64,`` prefix, as produced by browsers, is stripped
    before decoding.

    Args:
        image_b64: Base64 text of an encoded image (JPEG, PNG, ...).

    Returns:
        The decoded image as returned by ``cv2.imdecode`` with
        ``IMREAD_COLOR``.

    Raises:
        HTTPException(400): the text is not valid base64, is empty, or does
            not decode to an image.
    """
    payload = image_b64.strip()
    # Without this, alphabet characters of a data-URL header (e.g. "image/jpeg",
    # "base64") would be prepended to the real payload and corrupt it.
    if payload.startswith("data:") and "," in payload:
        payload = payload.split(",", 1)[1]
    try:
        image_bytes = base64.b64decode(payload)
        nparr = np.frombuffer(image_bytes, np.uint8)
        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR) if nparr.size else None
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid image format: {str(e)}") from e
    # cv2.imdecode reports undecodable data by returning None, not by raising.
    if image is None:
        raise HTTPException(status_code=400, detail="Invalid image format: could not decode image data")
    return image
def encode_image_to_base64(image: np.ndarray, quality: int = 85) -> str:
    """Encode an image as JPEG and return it as a base64 ASCII string.

    Args:
        image: Image array accepted by ``cv2.imencode`` (OpenCV treats
            3-channel input as BGR).
        quality: JPEG quality passed as ``cv2.IMWRITE_JPEG_QUALITY``. The
            default of 85 is the value previously hard-coded.

    Returns:
        Base64 text of the JPEG bytes.

    Raises:
        ValueError: OpenCV reported that encoding failed.
    """
    ok, buffer = cv2.imencode('.jpg', image, [cv2.IMWRITE_JPEG_QUALITY, int(quality)])
    # The success flag used to be discarded, which could yield an empty/garbage payload.
    if not ok:
        raise ValueError("Failed to encode image as JPEG")
    return base64.b64encode(buffer.tobytes()).decode('utf-8')
# ================== API endpoints ==================
# The route decorator was missing, so this landing page was never served even
# though HTMLResponse is imported for exactly this purpose.
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the HTML landing page of the API.

    The page shows the number of entries currently held in SESSIONS and links
    to /docs and /sessions.
    """
    html_content = f"""
<!DOCTYPE html>
<html dir="rtl" lang="ar">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>🚗 Baseer Self-Driving API</title>
<style>
* {{
margin: 0;
padding: 0;
box-sizing: border-box;
}}
body {{
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
display: flex;
align-items: center;
justify-content: center;
padding: 20px;
}}
.container {{
background: rgba(255, 255, 255, 0.95);
backdrop-filter: blur(10px);
border-radius: 20px;
padding: 40px;
box-shadow: 0 20px 40px rgba(0, 0, 0, 0.1);
text-align: center;
max-width: 600px;
width: 100%;
}}
.logo {{
font-size: 4rem;
margin-bottom: 20px;
animation: bounce 2s infinite;
}}
@keyframes bounce {{
0%, 20%, 50%, 80%, 100% {{ transform: translateY(0); }}
40% {{ transform: translateY(-10px); }}
60% {{ transform: translateY(-5px); }}
}}
h1 {{
color: #333;
margin-bottom: 10px;
font-size: 2.5rem;
}}
.subtitle {{
color: #666;
margin-bottom: 30px;
font-size: 1.2rem;
}}
.status {{
display: inline-block;
background: #4CAF50;
color: white;
padding: 8px 16px;
border-radius: 20px;
margin: 10px 0;
font-weight: bold;
}}
.stats {{
display: grid;
grid-template-columns: repeat(auto-fit, minmax(150px, 1fr));
gap: 20px;
margin: 30px 0;
}}
.stat-card {{
background: #f8f9fa;
padding: 20px;
border-radius: 15px;
border-left: 4px solid #667eea;
}}
.stat-number {{
font-size: 2rem;
font-weight: bold;
color: #667eea;
}}
.stat-label {{
color: #666;
margin-top: 5px;
}}
.buttons {{
display: flex;
gap: 15px;
justify-content: center;
flex-wrap: wrap;
margin-top: 30px;
}}
.btn {{
display: inline-block;
padding: 12px 24px;
border-radius: 25px;
text-decoration: none;
font-weight: bold;
transition: all 0.3s ease;
border: none;
cursor: pointer;
}}
.btn-primary {{
background: #667eea;
color: white;
}}
.btn-secondary {{
background: #6c757d;
color: white;
}}
.btn:hover {{
transform: translateY(-2px);
box-shadow: 0 5px 15px rgba(0, 0, 0, 0.2);
}}
.features {{
text-align: right;
margin-top: 30px;
padding: 20px;
background: #f8f9fa;
border-radius: 15px;
}}
.features h3 {{
color: #333;
margin-bottom: 15px;
}}
.features ul {{
list-style: none;
padding: 0;
}}
.features li {{
padding: 5px 0;
color: #666;
}}
.features li:before {{
content: "✅ ";
margin-left: 10px;
}}
</style>
</head>
<body>
<div class="container">
<div class="logo">🚗</div>
<h1>Baseer Self-Driving API</h1>
<p class="subtitle">نظام القيادة الذاتية المتقدم</p>
<div class="status">🟢 يعمل بنجاح</div>
<div class="stats">
<div class="stat-card">
<div class="stat-number">{len(SESSIONS)}</div>
<div class="stat-label">الجلسات النشطة</div>
</div>
<div class="stat-card">
<div class="stat-number">v1.0</div>
<div class="stat-label">الإصدار</div>
</div>
<div class="stat-card">
<div class="stat-number">FastAPI</div>
<div class="stat-label">التقنية</div>
</div>
</div>
<div class="buttons">
<a href="/docs" class="btn btn-primary">📚 توثيق API</a>
<a href="/sessions" class="btn btn-secondary">📊 الجلسات</a>
</div>
<div class="features">
<h3>🌟 الميزات الرئيسية</h3>
<ul>
<li>نموذج InterFuser للقيادة الذاتية</li>
<li>معالجة الصور في الوقت الفعلي</li>
<li>اكتشاف الكائنات المرورية</li>
<li>تحديد المسارات الذكية</li>
<li>واجهة RESTful سهلة الاستخدام</li>
<li>إدارة جلسات متعددة</li>
</ul>
</div>
</div>
</body>
</html>
"""
    return html_content
async def start_session():
    """Create a new simulation session and register it in SESSIONS.

    Each session gets its own state:
    - a Tracker (frequency=10);
    - an InterfuserController with the default ControllerConfig;
    - a frame counter starting at 0;
    - creation and last-activity timestamps.

    NOTE(review): no route decorator is visible for this coroutine, so FastAPI
    does not expose it as written; confirm the intended path.
    """
    new_id = str(uuid.uuid4())
    state = {
        'tracker': Tracker(frequency=10),
        'controller': InterfuserController(ControllerConfig()),
        'frame_num': 0,
        'created_at': np.datetime64('now'),
        'last_activity': np.datetime64('now'),
    }
    SESSIONS[new_id] = state
    logger.info(f"New session created: {new_id}")
    return SessionResponse(session_id=new_id, message="Session started successfully")
def _traffic_grid_to_detections(
    traffic_np: np.ndarray, threshold: float = 0.2, world_extent: float = 64.0
) -> List[Dict[str, Any]]:
    """Turn a dense (H, W, C) traffic grid into tracker detection dicts.

    One detection is emitted per (row, column, channel) cell whose value
    exceeds ``threshold``. Cells are visited in row-major (y, x, ch) order,
    the same order as the nested loops this replaces.

    Grid indices are mapped to centred world coordinates spanning
    ``world_extent`` units.
    NOTE(review): several channels of the same cell can each yield a detection
    at the same position; confirm that is intended for the tracker.
    """
    h, w, _ = traffic_np.shape
    detections = []
    for y, x, ch in np.argwhere(traffic_np > threshold):
        y, x, ch = int(y), int(x), int(ch)
        detections.append({
            'position': [(x / w - 0.5) * world_extent, (y / h - 0.5) * world_extent],
            'feature': traffic_np[y, x, ch],
        })
    return detections


async def run_step(data: RunStepInput):
    """Run one full simulation step for an existing session.

    Pipeline:
    1. Decode the camera frame.
    2. Build the model batch.
    3. Run the model.
    4. Convert the traffic grid into tracker detections.
    5. Run the session's controller.
    6. Render the bird's-eye maps and the dashboard.
    7. Return everything as a RunStepOutput.

    Raises:
        HTTPException(404): the session id is unknown.
        HTTPException(500): the model failed to load at startup.
        HTTPException(400): the image could not be decoded.
        HTTPException(500): any other processing failure.

    NOTE(review): no route decorator is visible for this coroutine, so FastAPI
    does not expose it as written; confirm the intended path.
    """
    if data.session_id not in SESSIONS:
        raise HTTPException(status_code=404, detail="Session not found")
    # Fail fast before doing any decoding/preprocessing work.
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    session = SESSIONS[data.session_id]
    tracker = session['tracker']
    controller = session['controller']
    # Refresh the activity timestamp used by the session cleanup logic.
    session['last_activity'] = np.datetime64('now')

    try:
        # 1. Decode the image.
        frame_bgr = decode_base64_image(data.image_b64)
        if frame_bgr is None:
            # cv2.imdecode signals undecodable data by returning None.
            raise HTTPException(status_code=400, detail="Invalid image format: could not decode image data")
        frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)

        # 2. Build the model inputs.
        inputs = preprocess_input(frame_rgb, data.measurements, device)

        # 3. Run the model.
        with torch.no_grad():
            traffic, waypoints, is_junction, traffic_light, stop_sign, _ = model(inputs)

        # 4. Post-process the outputs (take the first element of the batch).
        traffic_np = traffic.cpu().numpy()[0]
        waypoints_np = waypoints.cpu().numpy()[0]
        is_junction_prob = torch.sigmoid(is_junction)[0, 1].item()
        traffic_light_prob = torch.sigmoid(traffic_light)[0, 0].item()
        stop_sign_prob = torch.sigmoid(stop_sign)[0, 1].item()

        # 5. Update the tracker with detections extracted from the traffic grid.
        detections = _traffic_grid_to_detections(traffic_np)
        updated_traffic = tracker.update_and_predict(detections, session['frame_num'])

        # 6. Run the controller.
        steer, throttle, brake, metadata = controller.run_step(
            current_speed=data.measurements.speed,
            waypoints=waypoints_np,
            junction=is_junction_prob,
            traffic_light_state=traffic_light_prob,
            stop_sign=stop_sign_prob,
            meta_data={'frame': session['frame_num']}
        )

        # 7. Render the maps at t=0 and the two future horizons.
        surround_t0, counts_t0 = render(updated_traffic, t=0)
        surround_t1, counts_t1 = render(updated_traffic, t=T1_FUTURE_TIME)
        surround_t2, counts_t2 = render(updated_traffic, t=T2_FUTURE_TIME)
        # Overlay the predicted waypoints on the t=0 map.
        wp_map = render_waypoints(waypoints_np)
        map_t0 = cv2.add(surround_t0, wp_map)
        # Draw the ego vehicle on every map.
        map_t0 = render_self_car(map_t0)
        map_t1 = render_self_car(surround_t1)
        map_t2 = render_self_car(surround_t2)

        # 8. Build the final dashboard.
        interface_data = {
            'camera_view': frame_bgr,
            'map_t0': map_t0,
            'map_t1': map_t1,
            'map_t2': map_t2,
            'text_info': {
                'Frame': f"Frame: {session['frame_num']}",
                'Control': f"Steer: {steer:.2f}, Throttle: {throttle:.2f}, Brake: {brake}",
                'Speed': f"Speed: {data.measurements.speed:.1f} km/h",
                'Junction': f"Junction: {is_junction_prob:.2f}",
                'Traffic Light': f"Red Light: {traffic_light_prob:.2f}",
                'Stop Sign': f"Stop Sign: {stop_sign_prob:.2f}",
                'Metadata': metadata
            },
            'object_counts': {
                't0': counts_t0,
                't1': counts_t1,
                't2': counts_t2
            }
        }
        dashboard_image = display.run_interface(interface_data)
        dashboard_b64 = encode_image_to_base64(dashboard_image)

        # 9. Assemble the response.
        response = RunStepOutput(
            model_outputs=ModelOutputs(
                traffic=traffic_np.tolist(),
                waypoints=waypoints_np.tolist(),
                is_junction=is_junction_prob,
                traffic_light_state=traffic_light_prob,
                stop_sign=stop_sign_prob
            ),
            control_commands=ControlCommands(
                steer=float(steer),
                throttle=float(throttle),
                brake=bool(brake)
            ),
            dashboard_image_b64=dashboard_b64
        )

        # Advance the frame counter only after a fully successful step.
        session['frame_num'] += 1
        logger.info(f"Step completed for session {data.session_id}, frame {session['frame_num']}")
        return response
    except HTTPException:
        # Propagate deliberate HTTP errors (e.g. 400 for a bad image) unchanged
        # instead of re-wrapping them as a generic 500 "Processing error".
        raise
    except Exception as e:
        logger.exception(f"Error in run_step: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}") from e
async def end_session(session_id: str):
    """Terminate a simulation session and discard its state.

    Raises:
        HTTPException(404): ``session_id`` is not in SESSIONS.

    NOTE(review): no route decorator is visible for this coroutine, so FastAPI
    does not expose it as written; confirm the intended path.
    """
    try:
        SESSIONS.pop(session_id)
    except KeyError:
        raise HTTPException(status_code=404, detail="Session not found") from None
    logger.info(f"Session ended: {session_id}")
    return SessionResponse(session_id=session_id, message="Session ended successfully")
# Route restored: the landing page links to /sessions, but without a decorator
# this coroutine was never registered with the app.
@app.get("/sessions")
async def list_sessions():
    """Return a summary of every session currently held in SESSIONS.

    Returns:
        A dict with ``'total_sessions'`` and ``'sessions'``. Each session entry
        contains:
        - the session id;
        - the frame count;
        - the creation and last-activity timestamps as strings;
        - the minutes elapsed since the last activity.
    """
    now = np.datetime64('now')
    one_minute = np.timedelta64(1, 'm')
    active_sessions = [
        {
            'session_id': session_id,
            'frame_count': session_data['frame_num'],
            'created_at': str(session_data['created_at']),
            'last_activity': str(session_data['last_activity']),
            'inactive_minutes': float((now - session_data['last_activity']) / one_minute),
        }
        for session_id, session_data in SESSIONS.items()
    ]
    return {
        'total_sessions': len(active_sessions),
        'sessions': active_sessions
    }
async def cleanup_inactive_sessions(max_inactive_minutes: int = 30):
    """Remove sessions idle for longer than ``max_inactive_minutes``.

    Idle time is measured from each session's ``last_activity`` timestamp.

    Returns:
        A dict with a summary message, the removed session ids and the number
        of sessions left.

    NOTE(review): no route decorator is visible for this coroutine, so FastAPI
    does not expose it as written; confirm the intended path.
    """
    now = np.datetime64('now')
    one_minute = np.timedelta64(1, 'm')
    stale_ids = [
        sid
        for sid, state in SESSIONS.items()
        if float((now - state['last_activity']) / one_minute) > max_inactive_minutes
    ]
    for sid in stale_ids:
        del SESSIONS[sid]

    logger.info(f"Cleaned up {len(stale_ids)} inactive sessions")
    return {
        'message': f"Cleaned up {len(stale_ids)} inactive sessions",
        'cleaned_sessions': stale_ids,
        'remaining_sessions': len(SESSIONS),
    }
# ================== Error handler ==================
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    """Turn any unhandled exception into a JSON 500 response.

    Starlette exception handlers must return a Response object; the previous
    version returned a plain dict and was never registered with the app.

    NOTE(review): ``detail`` echoes ``str(exc)`` to the client (kept for
    compatibility); consider hiding it in production.
    """
    from fastapi.responses import JSONResponse

    logger.error(f"Global exception: {str(exc)}", exc_info=exc)
    return JSONResponse(
        status_code=500,
        content={
            "error": "Internal server error",
            "detail": str(exc)
        },
    )
# ================== Run the server ==================
if __name__ == "__main__":
    import uvicorn

    # Serve on all interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")