Spaces:
Running
Running
# app.py - InterFuser Self-Driving API Server | |
import uuid | |
import base64 | |
import cv2 | |
# import torch | |
# import numpy as np | |
# from fastapi import FastAPI, HTTPException | |
# from fastapi.responses import HTMLResponse | |
# from pydantic import BaseModel | |
# from torchvision import transforms | |
# from typing import List, Dict, Any, Optional | |
# import logging | |
# # استيراد من ملفاتنا المحلية | |
# from model_definition import InterfuserModel, load_and_prepare_model, create_model_config | |
# from simulation_modules import ( | |
# InterfuserController, ControllerConfig, Tracker, DisplayInterface, | |
# render, render_waypoints, render_self_car, WAYPOINT_SCALE_FACTOR, | |
# T1_FUTURE_TIME, T2_FUTURE_TIME | |
# ) | |
# # إعداد التسجيل | |
# logging.basicConfig(level=logging.INFO) | |
# logger = logging.getLogger(__name__) | |
# # ================== إعدادات عامة وتحميل النموذج ================== | |
# app = FastAPI( | |
# title="Baseer Self-Driving API", | |
# description="API للقيادة الذاتية باستخدام نموذج InterFuser", | |
# version="1.0.0" | |
# ) | |
# device = torch.device("cpu") | |
# logger.info(f"Using device: {device}") | |
# # تحميل النموذج باستخدام الدالة المحسنة | |
# try: | |
# # إنشاء إعدادات النموذج باستخدام الإعدادات الصحيحة من التدريب | |
# model_config = create_model_config( | |
# model_path="model/best_model.pth" | |
# # الإعدادات الصحيحة من التدريب ستطبق تلقائياً: | |
# # embed_dim=256, rgb_backbone_name='r50', waypoints_pred_head='gru' | |
# # with_lidar=False, with_right_left_sensors=False, with_center_sensor=False | |
# ) | |
# # تحميل النموذج مع الأوزان | |
# model = load_and_prepare_model(model_config, device) | |
# logger.info("✅ تم تحميل النموذج بنجاح") | |
# except Exception as e: | |
# logger.error(f"❌ خطأ في تحميل النموذج: {e}") | |
# logger.info("🔄 محاولة إنشاء نموذج بأوزان عشوائية...") | |
# try: | |
# model = InterfuserModel() | |
# model.to(device) | |
# model.eval() | |
# logger.warning("⚠️ تم إنشاء النموذج بأوزان عشوائية") | |
# except Exception as e2: | |
# logger.error(f"❌ فشل في إنشاء النموذج: {e2}") | |
# model = None | |
# # تهيئة واجهة العرض | |
# display = DisplayInterface() | |
# # قاموس لتخزين جلسات المستخدمين | |
# SESSIONS: Dict[str, Dict] = {} | |
# # ================== هياكل بيانات Pydantic ================== | |
# class Measurements(BaseModel): | |
# pos: List[float] = [0.0, 0.0] # [x, y] position | |
# theta: float = 0.0 # orientation angle | |
# speed: float = 0.0 # current speed | |
# steer: float = 0.0 # current steering | |
# throttle: float = 0.0 # current throttle | |
# brake: bool = False # brake status | |
# command: int = 4 # driving command (4 = FollowLane) | |
# target_point: List[float] = [0.0, 0.0] # target point [x, y] | |
# class ModelOutputs(BaseModel): | |
# traffic: List[List[List[float]]] # 20x20x7 grid | |
# waypoints: List[List[float]] # Nx2 waypoints | |
# is_junction: float | |
# traffic_light_state: float | |
# stop_sign: float | |
# class ControlCommands(BaseModel): | |
# steer: float | |
# throttle: float | |
# brake: bool | |
# class RunStepInput(BaseModel): | |
# session_id: str | |
# image_b64: str | |
# measurements: Measurements | |
# class RunStepOutput(BaseModel): | |
# model_outputs: ModelOutputs | |
# control_commands: ControlCommands | |
# dashboard_image_b64: str | |
# class SessionResponse(BaseModel): | |
# session_id: str | |
# message: str | |
# # ================== دوال المساعدة ================== | |
# def get_image_transform(): | |
# """إنشاء تحويلات الصورة كما في PDMDataset""" | |
# return transforms.Compose([ | |
# transforms.ToTensor(), | |
# transforms.Resize((224, 224), antialias=True), | |
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | |
# ]) | |
# # إنشاء كائن التحويل مرة واحدة | |
# image_transform = get_image_transform() | |
# def preprocess_input(frame_rgb: np.ndarray, measurements: Measurements, device: torch.device) -> Dict[str, torch.Tensor]: | |
# """ | |
# تحاكي ما يفعله PDMDataset.__getitem__ لإنشاء دفعة (batch) واحدة. | |
# """ | |
# # 1. معالجة الصورة الرئيسية | |
# from PIL import Image | |
# if isinstance(frame_rgb, np.ndarray): | |
# frame_rgb = Image.fromarray(frame_rgb) | |
# image_tensor = image_transform(frame_rgb).unsqueeze(0).to(device) # إضافة بُعد الدفعة | |
# # 2. إنشاء مدخلات الكاميرات الأخرى عن طريق الاستنساخ | |
# batch = { | |
# 'rgb': image_tensor, | |
# 'rgb_left': image_tensor.clone(), | |
# 'rgb_right': image_tensor.clone(), | |
# 'rgb_center': image_tensor.clone(), | |
# } | |
# # 3. إنشاء مدخل ليدار وهمي (أصفار) | |
# batch['lidar'] = torch.zeros(1, 3, 224, 224, dtype=torch.float32).to(device) | |
# # 4. تجميع القياسات بنفس ترتيب PDMDataset | |
# m = measurements | |
# measurements_tensor = torch.tensor([[ | |
# m.pos[0], m.pos[1], m.theta, | |
# m.steer, m.throttle, float(m.brake), | |
# m.speed, float(m.command) | |
# ]], dtype=torch.float32).to(device) | |
# batch['measurements'] = measurements_tensor | |
# # 5. إنشاء نقطة هدف | |
# batch['target_point'] = torch.tensor([m.target_point], dtype=torch.float32).to(device) | |
# # لا نحتاج إلى قيم ground truth (gt_*) أثناء التنبؤ | |
# return batch | |
# def decode_base64_image(image_b64: str) -> np.ndarray: | |
# """ | |
# فك تشفير صورة Base64 | |
# """ | |
# try: | |
# image_bytes = base64.b64decode(image_b64) | |
# nparr = np.frombuffer(image_bytes, np.uint8) | |
# image = cv2.imdecode(nparr, cv2.IMREAD_COLOR) | |
# return image | |
# except Exception as e: | |
# raise HTTPException(status_code=400, detail=f"Invalid image format: {str(e)}") | |
# def encode_image_to_base64(image: np.ndarray) -> str: | |
# """ | |
# تشفير صورة إلى Base64 | |
# """ | |
# _, buffer = cv2.imencode('.jpg', image, [cv2.IMWRITE_JPEG_QUALITY, 85]) | |
# return base64.b64encode(buffer).decode('utf-8') | |
# # ================== نقاط نهاية الـ API ================== | |
# @app.get("/", response_class=HTMLResponse) | |
# async def root(): | |
# """ | |
# الصفحة الرئيسية للـ API | |
# """ | |
# html_content = f""" | |
# <!DOCTYPE html> | |
# <html dir="rtl" lang="ar"> | |
# <head> | |
# <meta charset="UTF-8"> | |
# <meta name="viewport" content="width=device-width, initial-scale=1.0"> | |
# <title>🚗 Baseer Self-Driving API</title> | |
# <style> | |
# * {{ | |
# margin: 0; | |
# padding: 0; | |
# box-sizing: border-box; | |
# }} | |
# body {{ | |
# font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; | |
# background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); | |
# min-height: 100vh; | |
# display: flex; | |
# align-items: center; | |
# justify-content: center; | |
# padding: 20px; | |
# }} | |
# .container {{ | |
# background: rgba(255, 255, 255, 0.95); | |
# backdrop-filter: blur(10px); | |
# border-radius: 20px; | |
# padding: 40px; | |
# box-shadow: 0 20px 40px rgba(0, 0, 0, 0.1); | |
# text-align: center; | |
# max-width: 600px; | |
# width: 100%; | |
# }} | |
# .logo {{ | |
# font-size: 4rem; | |
# margin-bottom: 20px; | |
# animation: bounce 2s infinite; | |
# }} | |
# @keyframes bounce {{ | |
# 0%, 20%, 50%, 80%, 100% {{ transform: translateY(0); }} | |
# 40% {{ transform: translateY(-10px); }} | |
# 60% {{ transform: translateY(-5px); }} | |
# }} | |
# h1 {{ | |
# color: #333; | |
# margin-bottom: 10px; | |
# font-size: 2.5rem; | |
# }} | |
# .subtitle {{ | |
# color: #666; | |
# margin-bottom: 30px; | |
# font-size: 1.2rem; | |
# }} | |
# .status {{ | |
# display: inline-block; | |
# background: #4CAF50; | |
# color: white; | |
# padding: 8px 16px; | |
# border-radius: 20px; | |
# margin: 10px 0; | |
# font-weight: bold; | |
# }} | |
# .stats {{ | |
# display: grid; | |
# grid-template-columns: repeat(auto-fit, minmax(150px, 1fr)); | |
# gap: 20px; | |
# margin: 30px 0; | |
# }} | |
# .stat-card {{ | |
# background: #f8f9fa; | |
# padding: 20px; | |
# border-radius: 15px; | |
# border-left: 4px solid #667eea; | |
# }} | |
# .stat-number {{ | |
# font-size: 2rem; | |
# font-weight: bold; | |
# color: #667eea; | |
# }} | |
# .stat-label {{ | |
# color: #666; | |
# margin-top: 5px; | |
# }} | |
# .buttons {{ | |
# display: flex; | |
# gap: 15px; | |
# justify-content: center; | |
# flex-wrap: wrap; | |
# margin-top: 30px; | |
# }} | |
# .btn {{ | |
# display: inline-block; | |
# padding: 12px 24px; | |
# border-radius: 25px; | |
# text-decoration: none; | |
# font-weight: bold; | |
# transition: all 0.3s ease; | |
# border: none; | |
# cursor: pointer; | |
# }} | |
# .btn-primary {{ | |
# background: #667eea; | |
# color: white; | |
# }} | |
# .btn-secondary {{ | |
# background: #6c757d; | |
# color: white; | |
# }} | |
# .btn:hover {{ | |
# transform: translateY(-2px); | |
# box-shadow: 0 5px 15px rgba(0, 0, 0, 0.2); | |
# }} | |
# .features {{ | |
# text-align: right; | |
# margin-top: 30px; | |
# padding: 20px; | |
# background: #f8f9fa; | |
# border-radius: 15px; | |
# }} | |
# .features h3 {{ | |
# color: #333; | |
# margin-bottom: 15px; | |
# }} | |
# .features ul {{ | |
# list-style: none; | |
# padding: 0; | |
# }} | |
# .features li {{ | |
# padding: 5px 0; | |
# color: #666; | |
# }} | |
# .features li:before {{ | |
# content: "✅ "; | |
# margin-left: 10px; | |
# }} | |
# </style> | |
# </head> | |
# <body> | |
# <div class="container"> | |
# <div class="logo">🚗</div> | |
# <h1>Baseer Self-Driving API</h1> | |
# <p class="subtitle">نظام القيادة الذاتية المتقدم</p> | |
# <div class="status">🟢 يعمل بنجاح</div> | |
# <div class="stats"> | |
# <div class="stat-card"> | |
# <div class="stat-number">{len(SESSIONS)}</div> | |
# <div class="stat-label">الجلسات النشطة</div> | |
# </div> | |
# <div class="stat-card"> | |
# <div class="stat-number">v1.0</div> | |
# <div class="stat-label">الإصدار</div> | |
# </div> | |
# <div class="stat-card"> | |
# <div class="stat-number">FastAPI</div> | |
# <div class="stat-label">التقنية</div> | |
# </div> | |
# </div> | |
# <div class="buttons"> | |
# <a href="/docs" class="btn btn-primary">📚 توثيق API</a> | |
# <a href="/sessions" class="btn btn-secondary">📊 الجلسات</a> | |
# </div> | |
# <div class="features"> | |
# <h3>🌟 الميزات الرئيسية</h3> | |
# <ul> | |
# <li>نموذج InterFuser للقيادة الذاتية</li> | |
# <li>معالجة الصور في الوقت الفعلي</li> | |
# <li>اكتشاف الكائنات المرورية</li> | |
# <li>تحديد المسارات الذكية</li> | |
# <li>واجهة RESTful سهلة الاستخدام</li> | |
# <li>إدارة جلسات متعددة</li> | |
# </ul> | |
# </div> | |
# </div> | |
# </body> | |
# </html> | |
# """ | |
# return html_content | |
import base64
import logging
import uuid
from typing import Dict, List, Optional, Tuple

import cv2
import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from pydantic import BaseModel, Field
# ============================================================================== | |
# 1. استيراد كل مكونات المشروع التي قمنا بتطويرها | |
# (تأكد من أن هذه الملفات موجودة في نفس المجلد) | |
# ============================================================================== | |
# من ملف النموذج (يحتوي على كلاس Interfuser والدوال المساعدة) | |
from model_definition import InterfuserModel, load_and_prepare_model, create_model_config | |
# من ملفات التحكم والعرض | |
from simulation_modules import InterfuserController, Tracker | |
from simulation_modules import DisplayInterface, render_bev, unnormalize_image, DisplayConfig | |
# # استيراد من ملفاتنا المحلية | |
# from model_definition import InterfuserModel, load_and_prepare_model, create_model_config | |
# from simulation_modules import ( | |
# InterfuserController, ControllerConfig, Tracker, DisplayInterface, | |
# render, render_waypoints, render_self_car, WAYPOINT_SCALE_FACTOR, | |
# T1_FUTURE_TIME, T2_FUTURE_TIME | |
# ) | |
# ============================================================================== | |
# 2. إعدادات عامة وتطبيق FastAPI | |
# ============================================================================== | |
# إعداد التسجيل (Logging) | |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# FastAPI application instance.
app = FastAPI(
    title="Baseer Self-Driving API",
    description="An advanced API for the InterFuser self-driving model, providing real-time control commands and scene analysis.",
    version="1.1.0"
)

# Process-wide state.
# The loaded model; None until it has been loaded successfully.
# (Module-level annotations are evaluated at import time, so the annotation must
# name a class that is actually in scope: InterfuserModel is imported above.)
MODEL: Optional[InterfuserModel] = None
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Active sessions keyed by session id; each value holds that session's
# 'tracker', 'controller' and 'frame_num'.
SESSIONS: Dict[str, Dict] = {}
# ============================================================================== | |
# 3. تعريف نماذج البيانات (Pydantic Models) للـ API | |
# ============================================================================== | |
class Measurements(BaseModel):
    # Vehicle state sent with each step request.
    # `examples=[...]` is the Pydantic v2 form; the former `example=` keyword is a
    # deprecated extra kwarg in v2 (and v1 passes `examples` through to the schema).
    pos_global: Tuple[float, float] = Field(..., examples=[[0.0, 0.0]], description="Global [X, Y] position of the vehicle.")
    theta: float = Field(..., examples=[0.0], description="Global orientation angle of the vehicle in radians.")
    speed: float = Field(..., examples=[0.0], description="Current speed in m/s.")
    target_point: Tuple[float, float] = Field(..., examples=[[10.0, 0.0]], description="Target point relative to the vehicle.")
class RunStepRequest(BaseModel):
    # Request body for one control step.
    # Id returned by start_session; it is looked up in SESSIONS.
    session_id: str
    # Base64 text of an encoded image file, decoded server-side with
    # cv2.imdecode(..., cv2.IMREAD_COLOR).
    # NOTE(review): the description says "BGR format", but the payload is an
    # encoded image (e.g. JPEG/PNG), not raw pixels; BGR is only the channel
    # order OpenCV produces after decoding. Consider rewording.
    image_b64: str = Field(..., description="Base64 encoded string of the vehicle's front camera view (BGR format).")
    # Current vehicle state used to build the model input and drive the controller.
    measurements: Measurements
class ControlCommands(BaseModel):
    # Actuation values returned by the session controller's run_step for this frame.
    steer: float
    throttle: float
    brake: bool
class SceneAnalysis(BaseModel):
    # Scene-level outputs of the model's classification heads. Each value is
    # softmax(head_logits, dim=1)[0, 1], i.e. the probability of class index 1;
    # traffic_light_state is therefore a probability, not a discrete state id.
    is_junction: float
    traffic_light_state: float
    stop_sign: float
class RunStepResponse(BaseModel):
    # Result of one run_step call.
    control_commands: ControlCommands
    scene_analysis: SceneAnalysis
    # Waypoints predicted for the single batch element, one tuple per waypoint.
    # NOTE(review): coordinate frame and units are not established here; confirm.
    predicted_waypoints: List[Tuple[float, float]]
    # JPEG-encoded dashboard rendering, as Base64 text.
    dashboard_b64: str = Field(..., description="Base64 encoded string of the comprehensive dashboard view.")
    # The controller's 'brake_reason' when it supplies one, otherwise "Cruising".
    reason: str = Field(..., description="The reason for the current control action (e.g., 'Following ID 12', 'Red Light').")
# ============================================================================== | |
# 4. دوال مساعدة (Helpers) | |
# ============================================================================== | |
def b64_to_cv2(b64_string: str) -> np.ndarray:
    """
    Decode a Base64-encoded image file into an OpenCV image.

    A ``data:<mime>;base64,`` URL prefix, if present, is stripped first.

    Args:
        b64_string: Base64 text of an encoded image file (e.g. JPEG/PNG).

    Returns:
        The decoded 3-channel image (read with ``cv2.IMREAD_COLOR``, so in
        OpenCV's BGR channel order).

    Raises:
        HTTPException: 400 if the text is not valid Base64 or the decoded bytes
            are not an image OpenCV can decode.
    """
    # ':' is not in the Base64 alphabet, so a "data:" prefix can only be a data URL header.
    if b64_string.startswith("data:"):
        _, _, b64_string = b64_string.partition(",")

    try:
        img_bytes = base64.b64decode(b64_string)
    except ValueError as exc:  # binascii.Error (bad padding etc.) subclasses ValueError
        raise HTTPException(status_code=400, detail="Invalid Base64 image string.") from exc

    img_array = np.frombuffer(img_bytes, dtype=np.uint8)
    image = None
    if img_array.size:  # cv2.imdecode asserts on an empty buffer
        try:
            image = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
        except cv2.error:
            image = None
    # cv2.imdecode reports undecodable data by returning None rather than raising.
    if image is None:
        raise HTTPException(status_code=400, detail="Base64 string does not contain a decodable image.")
    return image
def cv2_to_b64(img: np.ndarray) -> str:
    """
    Encode an OpenCV image as JPEG and return it as Base64 text.

    Args:
        img: Image array accepted by ``cv2.imencode``.

    Returns:
        UTF-8 Base64 string of the JPEG bytes.

    Raises:
        HTTPException: 500 if OpenCV reports that encoding failed.
    """
    ok, buffer = cv2.imencode('.jpg', img)
    # imencode signals failure through its first return value; ignoring it would
    # silently yield an empty/invalid image string.
    if not ok:
        raise HTTPException(status_code=500, detail="Failed to encode image as JPEG.")
    return base64.b64encode(buffer.tobytes()).decode('utf-8')
def prepare_model_input(
    image: np.ndarray,
    measurements: "Measurements",
    device: Optional[torch.device] = None,
) -> Dict[str, torch.Tensor]:
    """
    Build a batch of one sample for the InterFuser model.

    Image preprocessing reproduces the torchvision pipeline
    ``ToTensor() -> Resize((224, 224), antialias=True) -> Normalize(ImageNet mean/std)``
    using plain torch operations, so this module does not depend on ``torchvision``.

    Args:
        image: Decoded camera frame of shape (H, W, 3) in OpenCV BGR channel order,
            expected to be ``uint8`` (as produced by ``cv2.imdecode``).
        measurements: Vehicle state; ``pos_global``, ``theta``, ``speed`` and
            ``target_point`` are read.
        device: Device for the returned tensors. Defaults to the module-level ``DEVICE``.

    Returns:
        Dict with keys ``'rgb'``, ``'rgb_left'``, ``'rgb_right'``, ``'rgb_center'``
        (each of shape (1, 3, 224, 224); the side/center views are copies of the
        front view), ``'measurements'`` (1, 8), ``'target_point'`` (1, 2) and
        ``'lidar'`` (zeros shaped like ``'rgb'``).

    Raises:
        ValueError: If ``image`` is not an (H, W, 3) array.
    """
    if device is None:
        device = DEVICE
    if image is None or image.ndim != 3 or image.shape[2] != 3:
        raise ValueError(f"Expected an (H, W, 3) BGR image, got shape {getattr(image, 'shape', None)}.")

    # BGR -> RGB. The reversed view has a negative stride, which torch.from_numpy
    # rejects, so materialise a contiguous copy.
    rgb = np.ascontiguousarray(image[..., ::-1])

    # ToTensor: HWC -> CHW, and scale uint8 [0, 255] to float [0, 1].
    pixels = torch.from_numpy(rgb).permute(2, 0, 1).float()
    if rgb.dtype == np.uint8:
        pixels = pixels / 255.0

    # Resize: bilinear with align_corners=False and antialiasing, matching
    # torchvision's Resize(..., antialias=True) on tensors.
    pixels = torch.nn.functional.interpolate(
        pixels.unsqueeze(0), size=(224, 224), mode="bilinear", align_corners=False, antialias=True
    )

    # Normalize with ImageNet channel statistics.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    image_tensor = ((pixels - mean) / std).to(device)

    measurements_tensor = torch.tensor([[
        measurements.pos_global[0], measurements.pos_global[1], measurements.theta,
        0.0, 0.0, 0.0,  # steer, throttle, brake: not part of the request, zero-filled
        measurements.speed, 4.0,  # command: always 4.0 (FollowLane)
    ]], dtype=torch.float32, device=device)
    target_point_tensor = torch.tensor([list(measurements.target_point)], dtype=torch.float32, device=device)

    return {
        'rgb': image_tensor,
        'rgb_left': image_tensor.clone(),
        'rgb_right': image_tensor.clone(),
        'rgb_center': image_tensor.clone(),
        'measurements': measurements_tensor,
        'target_point': target_point_tensor,
        'lidar': torch.zeros_like(image_tensor),
    }
# ============================================================================== | |
# 5. أحداث دورة حياة التطبيق (Startup/Shutdown) | |
# ============================================================================== | |
@app.on_event("startup")
async def startup_event():
    """
    Load the model into the global ``MODEL`` when the server starts.

    A failure is logged and leaves ``MODEL`` as None, so the server still boots;
    ``run_step`` then answers 503 instead of the whole app failing to start.
    """
    global MODEL
    logging.info("🚗 Server starting up...")
    logging.info(f"Using device: {DEVICE}")
    try:
        # NOTE(review): the commented-out legacy loader at the top of this file calls
        # load_and_prepare_model(model_config, device); confirm the current signature.
        MODEL = load_and_prepare_model(DEVICE)
    except Exception:
        logging.exception("Exception while loading the model.")
        MODEL = None
    if MODEL is not None:
        logging.info("✅ Model loaded successfully. Server is ready!")
    else:
        logging.error("❌ CRITICAL: Model could not be loaded. The API will not function correctly.")
# ============================================================================== | |
# 6. نقاط النهاية الرئيسية (API Endpoints) | |
# ============================================================================== | |
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve a small HTML landing page that links to the interactive docs at /docs."""
    # response_class=HTMLResponse makes FastAPI send the string as text/html
    # instead of JSON-encoding it.
    return """
<html>
<head><title>Baseer API</title></head>
<body style='font-family: sans-serif; text-align: center; padding-top: 50px;'>
<h1>🚗 Baseer Self-Driving API</h1>
<p>Welcome! The API is running.</p>
<p>Navigate to <a href="/docs">/docs</a> for the interactive API documentation.</p>
</body>
</html>
"""
# NOTE(review): route path and method are assumed from the function name; confirm against clients.
@app.post("/start_session")
def start_session():
    """
    Create a new driving session with its own tracker, controller and frame counter.

    Returns:
        ``{"session_id": <uuid4 string>}``; pass this id to later step/end calls.
    """
    session_id = str(uuid.uuid4())
    config = create_model_config()
    # Work on a copy so the dict held by ``config`` is not mutated.
    controller_params = dict(config.get('controller_params', {}))
    # 'frequency' is forced to 10.0 for every session, overriding any configured value.
    controller_params['frequency'] = 10.0
    grid_conf = config['grid_conf']
    SESSIONS[session_id] = {
        'tracker': Tracker(grid_conf=grid_conf),
        'controller': InterfuserController({'controller_params': controller_params, 'grid_conf': grid_conf}),
        'frame_num': 0,
    }
    logging.info(f"New session started: {session_id}")
    return {"session_id": session_id}
# NOTE(review): route path is assumed from the function name; confirm against clients.
@app.post("/run_step", response_model=RunStepResponse)
def run_step(request: RunStepRequest):
    """
    Run one perception -> tracking -> control step for an active session.

    Decodes the camera frame, runs the model, feeds its outputs to the session's
    tracker and controller, renders a dashboard image and advances the session's
    frame counter.

    Args:
        request: Session id, Base64 camera frame and current vehicle measurements.

    Returns:
        RunStepResponse with control commands, scene probabilities, predicted
        waypoints, the Base64 dashboard image and the controller's reason string.

    Raises:
        HTTPException: 503 if the model is not loaded, 404 if the session id is
            unknown; ``b64_to_cv2`` raises 400 for an invalid Base64 payload.
    """
    if MODEL is None:
        raise HTTPException(status_code=503, detail="Model is not available.")
    session = SESSIONS.get(request.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session ID not found.")

    # --- 1. Perception ---
    image = b64_to_cv2(request.image_b64)
    model_input = prepare_model_input(image, request.measurements)
    # Inference only: without no_grad the outputs carry autograd history (wasting
    # memory) and .numpy() below raises on tensors that require grad.
    with torch.no_grad():
        traffic, waypoints, junc, light, stop, _ = MODEL(model_input)

        # --- 2. Post-process model outputs ---
        # Column 0 of each traffic cell goes through a sigmoid; other columns are kept raw.
        traffic_processed = torch.cat([torch.sigmoid(traffic[0][:, 0:1]), traffic[0][:, 1:]], dim=1)
        # Assumes 400 cells, laid out as a 20x20 grid.
        traffic_np = traffic_processed.cpu().numpy().reshape(20, 20, -1)
        waypoints_np = waypoints[0].cpu().numpy()
        # Probability of class index 1 for each classification head.
        junction_prob = torch.softmax(junc, dim=1)[0, 1].item()
        light_prob = torch.softmax(light, dim=1)[0, 1].item()
        stop_prob = torch.softmax(stop, dim=1)[0, 1].item()

    # --- 3. Tracking and control ---
    ego_pos = np.array(request.measurements.pos_global)
    ego_theta = request.measurements.theta
    frame_num = session['frame_num']
    active_tracks = session['tracker'].process_frame(traffic_np, ego_pos, ego_theta, frame_num)
    steer, throttle, brake, ctrl_info = session['controller'].run_step(
        speed=request.measurements.speed, waypoints=torch.from_numpy(waypoints_np),
        junction=junction_prob, traffic_light=light_prob, stop_sign=stop_prob,
        bev_map=traffic_np, ego_pos=ego_pos, ego_theta=ego_theta, frame_num=frame_num
    )

    # --- 4. Dashboard rendering ---
    display_iface = DisplayInterface(DisplayConfig(width=1280, height=720))
    bev_maps = render_bev(active_tracks, waypoints_np, ego_pos, ego_theta)
    display_data = {
        'camera_view': image, 'map_t0': bev_maps['t0'], 'map_t1': bev_maps['t1'], 'map_t2': bev_maps['t2'],
        'frame_num': frame_num, 'speed': request.measurements.speed * 3.6,  # m/s -> km/h
        'target_speed': ctrl_info.get('target_speed', 0) * 3.6,
        'steer': steer, 'throttle': throttle, 'brake': brake,
        'light_prob': light_prob, 'stop_prob': stop_prob,
        'object_counts': {'car': len(active_tracks)}
    }
    dashboard = display_iface.run_interface(display_data)

    # --- 5. Update session state and respond ---
    session['frame_num'] += 1
    return RunStepResponse(
        # Cast to plain Python types: controller outputs may be numpy/torch scalars,
        # which Pydantic does not always accept for float/bool fields.
        control_commands=ControlCommands(steer=float(steer), throttle=float(throttle), brake=bool(brake)),
        scene_analysis=SceneAnalysis(is_junction=junction_prob, traffic_light_state=light_prob, stop_sign=stop_prob),
        predicted_waypoints=[tuple(wp) for wp in waypoints_np.tolist()],
        dashboard_b64=cv2_to_b64(dashboard),
        reason=ctrl_info.get('brake_reason', "Cruising")
    )
# NOTE(review): route path and method are assumed from the function name; confirm against clients.
@app.post("/end_session")
def end_session(session_id: str):
    """
    Discard a session and its tracker/controller state.

    Args:
        session_id: Id previously returned by start_session.

    Returns:
        ``{"message": "Session <id> ended."}``.

    Raises:
        HTTPException: 404 if no session with that id exists.
    """
    # Session entries are dicts, never None, so None means "not found".
    if SESSIONS.pop(session_id, None) is None:
        raise HTTPException(status_code=404, detail="Session not found.")
    logging.info(f"Session ended: {session_id}")
    return {"message": f"Session {session_id} ended."}
# ================== تشغيل الخادم ================== | |
# if __name__ == "__main__": | |
# import uvicorn | |
# uvicorn.run(app, host="0.0.0.0", port=7860) | |