Spaces:
Sleeping
Sleeping
# app.py - InterFuser Self-Driving API Server | |
import uuid | |
import base64 | |
import cv2 | |
import torch | |
import numpy as np | |
from fastapi import FastAPI, HTTPException | |
from fastapi.responses import HTMLResponse | |
from pydantic import BaseModel | |
from torchvision import transforms | |
from typing import List, Dict, Any, Optional | |
import logging | |
import uuid | |
import base64 | |
import cv2 | |
import torch | |
import numpy as np | |
import logging | |
from fastapi import FastAPI, HTTPException | |
from fastapi.responses import HTMLResponse | |
from pydantic import BaseModel, Field | |
from typing import List, Dict, Tuple | |
from model_definition import InterfuserModel, load_and_prepare_model, get_master_config | |
from simulation_modules import InterfuserController, Tracker | |
from simulation_modules import DisplayInterface, render_bev, unnormalize_image, DisplayConfig | |
# ==============================================================================
# 2. General settings and FastAPI application
# ==============================================================================
# Logging configuration for the whole process.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# FastAPI application instance; its `version` is also shown on the landing page.
app = FastAPI(
    title="Baseer Self-Driving API",
    description="An advanced API for the InterFuser self-driving model, providing real-time control commands and scene analysis.",
    version="1.1.0",
)
# Global state, populated at startup.
# The loaded model. It stays None until loading succeeds; handlers check for None
# (run_step answers 503, the landing page shows an error badge).
MODEL: Optional[InterfuserModel] = None
# Inference device: CUDA when available, otherwise CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Active sessions keyed by session id. Each value holds that session's
# 'tracker', 'controller' and 'frame_num' counter.
SESSIONS: Dict[str, Dict[str, Any]] = {}
# ============================================================================== | |
# 3. تعريف نماذج البيانات (Pydantic Models) للـ API | |
# ============================================================================== | |
class Measurements(BaseModel):
    """Ego-vehicle state sent with every step request."""
    # All four fields are required (default `...`).
    pos_global: Tuple[float, float] = Field(..., example=(0.0, 0.0), description="Global [X, Y] position of the vehicle.")
    theta: float = Field(..., example=0.0, description="Global orientation angle of the vehicle in radians.")
    speed: float = Field(..., example=0.0, description="Current speed in m/s.")
    target_point: Tuple[float, float] = Field(..., example=(10.0, 0.0), description="Target point relative to the vehicle.")
class RunStepRequest(BaseModel):
    """Payload for one control step inside an existing session."""
    # Identifier previously returned by start_session(); looked up in SESSIONS.
    session_id: str
    image_b64: str = Field(..., description="Base64 encoded string of the vehicle's front camera view (BGR format).")
    # Current ego-vehicle state for this frame.
    measurements: Measurements
class ControlCommands(BaseModel):
    """Actuation commands produced by the session controller for one step."""
    steer: float
    throttle: float
    # Brake is a binary decision, not a pedal amount.
    brake: bool
class SceneAnalysis(BaseModel):
    """Per-step scene scores from the model's classification heads.

    Each field is the softmax probability of class index 1 of the corresponding
    head. What class 1 denotes is defined by the model (TODO confirm), so
    `traffic_light_state` is a probability, not a discrete light state.
    """
    is_junction: float
    traffic_light_state: float
    stop_sign: float
class RunStepResponse(BaseModel):
    """Complete result of one step: controls, scene scores, planned path and dashboard."""
    control_commands: ControlCommands
    scene_analysis: SceneAnalysis
    # Waypoints predicted by the model as (x, y) pairs; coordinate frame and units
    # are model-defined (TODO confirm).
    predicted_waypoints: List[Tuple[float, float]]
    dashboard_b64: str = Field(..., description="Base64 encoded string of the comprehensive dashboard view.")
    reason: str = Field(..., description="The reason for the current control action (e.g., 'Following ID 12', 'Red Light').")
# ============================================================================== | |
# 4. دوال مساعدة (Helpers) | |
# ============================================================================== | |
def b64_to_cv2(b64_string: str) -> np.ndarray:
    """Decode a Base64-encoded image file into an OpenCV image.

    Accepts raw Base64 text or a data URL of the form
    ``data:<mime>;base64,<payload>``.

    Args:
        b64_string: Base64 text of an encoded image file (e.g. JPEG or PNG).

    Returns:
        The image decoded by ``cv2.imdecode`` with ``cv2.IMREAD_COLOR``
        (3 channels, BGR order).

    Raises:
        HTTPException: 400 if the text is not valid Base64 or the bytes are not
            a decodable image.
    """
    # Drop a browser-style "data:...;base64," prefix. Plain Base64 never contains
    # ':' or ',', so raw input is unaffected.
    if b64_string.startswith('data:') and ',' in b64_string:
        b64_string = b64_string.split(',', 1)[1]
    try:
        img_bytes = base64.b64decode(b64_string)
        img_array = np.frombuffer(img_bytes, dtype=np.uint8)
        image = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
    except (ValueError, cv2.error) as err:
        # binascii.Error (bad Base64) subclasses ValueError; cv2.error covers e.g. an empty buffer.
        raise HTTPException(status_code=400, detail="Invalid Base64 image string.") from err
    # cv2.imdecode signals "not an image" by returning None rather than raising.
    if image is None:
        raise HTTPException(status_code=400, detail="Base64 string does not contain a decodable image.")
    return image
def cv2_to_b64(img: np.ndarray, ext: str = '.jpg') -> str:
    """Encode an OpenCV image and return it as Base64 text.

    Args:
        img: Image array accepted by ``cv2.imencode``.
        ext: File extension that selects the codec (default ``'.jpg'``).

    Returns:
        The encoded image bytes, Base64-encoded and decoded as UTF-8 text.

    Raises:
        RuntimeError: If OpenCV reports that encoding failed.
    """
    ok, buffer = cv2.imencode(ext, img)
    # imencode can report failure through its flag instead of raising; without this
    # check the failure would surface later as an obscure error.
    if not ok:
        raise RuntimeError(f"cv2.imencode failed to encode image as {ext!r}.")
    return base64.b64encode(buffer).decode('utf-8')
# Image preprocessing pipeline, built once at import time instead of on every request.
# ToTensor: HWC uint8 -> CHW float in [0, 1]; then resize to 224x224 and normalise
# with the standard ImageNet per-channel (R, G, B) mean/std values.
_IMAGE_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224), antialias=True),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def prepare_model_input(image: np.ndarray, measurements: Measurements) -> Dict[str, torch.Tensor]:
    """Build a batch-of-one input dict for the InterFuser model.

    Args:
        image: 3-channel camera frame in BGR order (as produced by OpenCV).
        measurements: Current ego-vehicle state from the request.

    Returns:
        Dict of tensors, all on ``DEVICE``:
          * ``rgb``: preprocessed front image, shape (1, 3, 224, 224).
          * ``rgb_left`` / ``rgb_right`` / ``rgb_center``: copies of ``rgb``,
            because only a single camera view is supplied.
          * ``measurements``: shape (1, 8) - x, y, theta, three zeroed control
            slots, speed, and the navigation command.
          * ``target_point``: shape (1, 2).
          * ``lidar``: zeros shaped like ``rgb``, a placeholder since no LiDAR is
            supplied. NOTE(review): confirm the model accepts this shape.
    """
    # OpenCV decodes to BGR; the normalisation constants above are in RGB order.
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image_tensor = _IMAGE_TRANSFORM(image_rgb).unsqueeze(0).to(DEVICE)
    measurements_tensor = torch.tensor([[
        measurements.pos_global[0], measurements.pos_global[1], measurements.theta,
        0.0, 0.0, 0.0,  # steer, throttle, brake (not used by the model)
        measurements.speed,
        4.0,  # navigation command (defaults to FollowLane)
    ]], dtype=torch.float32, device=DEVICE)
    target_point_tensor = torch.tensor([measurements.target_point], dtype=torch.float32, device=DEVICE)
    return {
        'rgb': image_tensor,
        'rgb_left': image_tensor.clone(),
        'rgb_right': image_tensor.clone(),
        'rgb_center': image_tensor.clone(),
        'measurements': measurements_tensor,
        'target_point': target_point_tensor,
        'lidar': torch.zeros_like(image_tensor),
    }
# ============================================================================== | |
# 5. أحداث دورة حياة التطبيق (Startup/Shutdown) | |
# ============================================================================== | |
@app.on_event("startup")
async def startup_event():
    """Load the InterFuser model into the module-level ``MODEL`` at application startup.

    A loading failure is logged and leaves ``MODEL`` as ``None``. The server
    still starts, and request handlers report the model as unavailable.
    """
    global MODEL
    logging.info("🚗 Server starting up...")
    logging.info(f"Using device: {DEVICE}")
    try:
        MODEL = load_and_prepare_model(DEVICE)
    except Exception:
        # Top-level boundary: keep the server up in degraded mode (the handlers
        # already handle MODEL is None) instead of aborting startup.
        logging.exception("Model loading raised an exception.")
        MODEL = None
    # Explicit None check: module truthiness can depend on __len__ (e.g. containers).
    if MODEL is not None:
        logging.info("✅ Model loaded successfully. Server is ready!")
    else:
        logging.error("❌ CRITICAL: Model could not be loaded. The API will not function correctly.")
# ============================================================================== | |
# 6. نقاط النهاية الرئيسية (API Endpoints) | |
# ============================================================================== | |
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the API's interactive landing page.

    The page shows whether the model is loaded, the number of active sessions,
    and the API version (taken from ``app.version``). It also links to ``/docs``
    and ``/sessions``.
    NOTE(review): no ``/sessions`` handler is visible in this file - confirm it exists.
    """
    # Number of currently active sessions.
    active_sessions_count = len(SESSIONS)
    # Status badge: green "running" unless the model failed to load (red).
    status_color = "#4CAF50"  # green
    status_text = "يعمل بنجاح"
    if MODEL is None:
        status_color = "#F44336"  # red
        status_text = "خطأ: النموذج غير محمل"
    # This is an f-string, so literal CSS braces are doubled ({{ }}).
    html_content = f"""
<!DOCTYPE html>
<html dir="rtl" lang="ar">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>🚗 Baseer Self-Driving API</title>
    <style>
        @import url('https://fonts.googleapis.com/css2?family=Tajawal:wght@400;700&display=swap');
        :root {{
            --primary-color: #4a69bd;
            --secondary-color: #6a89cc;
            --text-color: #333;
            --bg-color: #f4f7f6;
            --panel-bg: #ffffff;
            --shadow: 0 10px 30px rgba(0, 0, 0, 0.1);
        }}
        body {{
            font-family: 'Tajawal', sans-serif;
            background-color: var(--bg-color);
            color: var(--text-color);
            margin: 0;
            display: flex;
            justify-content: center;
            align-items: center;
            min-height: 100vh;
            padding: 20px;
        }}
        .container {{
            background: var(--panel-bg);
            border-radius: 20px;
            padding: 40px;
            box-shadow: var(--shadow);
            text-align: center;
            max-width: 700px;
            width: 100%;
            border-top: 5px solid var(--primary-color);
        }}
        .logo {{
            font-size: 4rem;
            margin-bottom: 15px;
            animation: car-drive 3s ease-in-out infinite;
        }}
        @keyframes car-drive {{
            0% {{ transform: translateX(-20px); }}
            50% {{ transform: translateX(20px); }}
            100% {{ transform: translateX(-20px); }}
        }}
        h1 {{
            font-size: 2.8rem;
            font-weight: 700;
            color: var(--primary-color);
            margin-bottom: 10px;
        }}
        .subtitle {{
            font-size: 1.2rem;
            color: #777;
            margin-bottom: 25px;
        }}
        .status-badge {{
            display: inline-block;
            background-color: {status_color};
            color: white;
            padding: 10px 20px;
            border-radius: 25px;
            font-weight: bold;
            margin-bottom: 30px;
            font-size: 1rem;
        }}
        .stats-grid {{
            display: grid;
            grid-template-columns: 1fr 1fr;
            gap: 20px;
            margin-bottom: 30px;
        }}
        .stat-card {{
            background: #f8f9fa;
            padding: 20px;
            border-radius: 15px;
        }}
        .stat-number {{
            font-size: 2.5rem;
            font-weight: 700;
            color: var(--primary-color);
        }}
        .stat-label {{
            font-size: 1rem;
            color: #666;
            margin-top: 5px;
        }}
        .button-group {{
            display: flex;
            gap: 15px;
            justify-content: center;
        }}
        .btn {{
            padding: 14px 28px;
            border-radius: 30px;
            text-decoration: none;
            font-weight: bold;
            font-size: 1rem;
            transition: all 0.3s ease;
            border: 2px solid transparent;
            cursor: pointer;
        }}
        .btn-primary {{
            background: var(--primary-color);
            color: white;
        }}
        .btn-primary:hover {{
            background: var(--secondary-color);
            transform: translateY(-3px);
        }}
        .btn-secondary {{
            background: transparent;
            color: var(--primary-color);
            border-color: var(--primary-color);
        }}
        .btn-secondary:hover {{
            background: var(--primary-color);
            color: white;
            transform: translateY(-3px);
        }}
    </style>
</head>
<body>
    <div class="container">
        <div class="logo">🚗</div>
        <h1>Baseer Self-Driving API</h1>
        <p class="subtitle">واجهة برمجية متقدمة للقيادة الذاتية</p>
        <div class="status-badge">{status_text}</div>
        <div class="stats-grid">
            <div class="stat-card">
                <div class="stat-number">{active_sessions_count}</div>
                <div class="stat-label">الجلسات النشطة</div>
            </div>
            <div class="stat-card">
                <div class="stat-number">{app.version}</div>
                <div class="stat-label">إصدار الـ API</div>
            </div>
        </div>
        <div class="button-group">
            <a href="/docs" target="_blank" class="btn btn-primary">📚 التوثيق التفاعلي (Docs)</a>
            <a href="/sessions" target="_blank" class="btn btn-secondary">📊 عرض الجلسات</a>
        </div>
    </div>
</body>
</html>
"""
    return HTMLResponse(content=html_content)
def start_session():
    """Open a new driving session with its own tracker and controller.

    All components are configured from ``get_master_config()``. The new session
    is registered in ``SESSIONS`` with its frame counter set to 0.

    Returns:
        dict: ``{"session_id": <new UUID4 string>}``.
    """
    # NOTE(review): no route decorator is visible on this function, so it is not
    # exposed by `app` unless it is registered elsewhere - confirm.
    new_session_id = str(uuid.uuid4())

    # Pull every setting from the single configuration source up front.
    master_config = get_master_config()
    grid_settings = master_config['grid_conf']
    ctrl_params = master_config['controller_params']
    frequency = master_config['simulation']['frequency']

    # Tracker tuning falls back to 2.5 / 5 when the controller params omit it.
    session_tracker = Tracker(
        grid_conf=grid_settings,
        match_threshold=ctrl_params.get('tracker_match_thresh', 2.5),
        prune_age=ctrl_params.get('tracker_prune_age', 5),
    )

    controller_config = {
        'controller_params': ctrl_params,
        'grid_conf': grid_settings,
        'frequency': frequency,
    }
    session_controller = InterfuserController(controller_config)

    SESSIONS[new_session_id] = dict(
        tracker=session_tracker,
        controller=session_controller,
        frame_num=0,
    )
    logging.info(f"New session started: {new_session_id}")
    return {"session_id": new_session_id}
def _decode_model_outputs(traffic, waypoints, junc, light, stop):
    """Convert raw model head outputs (batch item 0) to NumPy arrays and Python floats.

    Returns:
        Tuple ``(traffic_np, waypoints_np, junction_prob, light_prob, stop_prob)``:
        ``traffic_np`` is reshaped to (20, 20, -1) with a sigmoid applied to column 0
        only, ``waypoints_np`` is the first batch item, and each ``*_prob`` is the
        softmax probability of class index 1 of its head.
    """
    # Only column 0 is passed through a sigmoid (presumably an occupancy logit - confirm);
    # the remaining columns are kept as-is.
    traffic_processed = torch.cat([torch.sigmoid(traffic[0][:, 0:1]), traffic[0][:, 1:]], dim=1)
    traffic_np = traffic_processed.cpu().numpy().reshape(20, 20, -1)
    waypoints_np = waypoints[0].cpu().numpy()
    junction_prob = torch.softmax(junc, dim=1)[0, 1].item()
    light_prob = torch.softmax(light, dim=1)[0, 1].item()
    stop_prob = torch.softmax(stop, dim=1)[0, 1].item()
    return traffic_np, waypoints_np, junction_prob, light_prob, stop_prob


def run_step(request: RunStepRequest):
    """Run one perception -> tracking -> control step for an existing session.

    Args:
        request: Session id, Base64 camera frame and current ego measurements.

    Returns:
        RunStepResponse with control commands, scene scores, predicted waypoints,
        a Base64 dashboard image and the controller's reason string.

    Raises:
        HTTPException: 503 if the model is not loaded, 404 for an unknown session,
            400 if the image cannot be decoded.
    """
    # NOTE(review): no route decorator is visible on this function, so it is not
    # exposed by `app` unless it is registered elsewhere - confirm.
    if MODEL is None:
        raise HTTPException(status_code=503, detail="Model is not available.")
    session = SESSIONS.get(request.session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session ID not found.")
    measurements = request.measurements

    # --- 1. Perception ---
    image = b64_to_cv2(request.image_b64)
    if image is None:
        # cv2.imdecode returns None (rather than raising) for undecodable bytes.
        raise HTTPException(status_code=400, detail="Invalid Base64 image string.")
    model_input = prepare_model_input(image, measurements)
    # Inference only: no_grad skips building an autograd graph, and it keeps
    # `.numpy()` legal on the outputs even if model parameters require grad.
    with torch.no_grad():
        traffic, waypoints, junc, light, stop, _ = MODEL(model_input)

    # --- 2. Post-process model outputs ---
    traffic_np, waypoints_np, junction_prob, light_prob, stop_prob = _decode_model_outputs(
        traffic, waypoints, junc, light, stop
    )

    # --- 3. Tracking and control ---
    ego_pos = np.array(measurements.pos_global)
    ego_theta = measurements.theta
    frame_num = session['frame_num']
    active_tracks = session['tracker'].process_frame(traffic_np, ego_pos, ego_theta, frame_num)
    steer, throttle, brake, ctrl_info = session['controller'].run_step(
        speed=measurements.speed, waypoints=torch.from_numpy(waypoints_np),
        junction=junction_prob, traffic_light=light_prob, stop_sign=stop_prob,
        bev_map=traffic_np, ego_pos=ego_pos, ego_theta=ego_theta, frame_num=frame_num,
    )

    # --- 4. Dashboard rendering ---
    display_iface = DisplayInterface(DisplayConfig(width=1280, height=720))
    bev_maps = render_bev(active_tracks, waypoints_np, ego_pos, ego_theta)
    display_data = {
        'camera_view': image,
        'map_t0': bev_maps['t0'], 'map_t1': bev_maps['t1'], 'map_t2': bev_maps['t2'],
        'frame_num': frame_num,
        'speed': measurements.speed * 3.6,  # m/s -> km/h
        'target_speed': ctrl_info.get('target_speed', 0) * 3.6,  # presumably m/s -> km/h
        'steer': steer, 'throttle': throttle, 'brake': brake,
        'light_prob': light_prob, 'stop_prob': stop_prob,
        'object_counts': {'car': len(active_tracks)},
    }
    dashboard = display_iface.run_interface(display_data)

    # --- 5. Update session state and build the response ---
    session['frame_num'] += 1
    return RunStepResponse(
        # Cast to plain Python types in case the controller returns NumPy/torch scalars,
        # which pydantic's float/bool validation may reject.
        control_commands=ControlCommands(steer=float(steer), throttle=float(throttle), brake=bool(brake)),
        scene_analysis=SceneAnalysis(is_junction=junction_prob, traffic_light_state=light_prob, stop_sign=stop_prob),
        predicted_waypoints=[tuple(wp) for wp in waypoints_np.tolist()],
        dashboard_b64=cv2_to_b64(dashboard),
        # `reason` is a required str: fall back when the key is missing OR set to None/"".
        reason=ctrl_info.get('brake_reason') or "Cruising",
    )
def end_session(session_id: str):
    """Close a session and discard its tracker/controller state.

    Args:
        session_id: Identifier returned by ``start_session``.

    Returns:
        dict: Confirmation message.

    Raises:
        HTTPException: 404 if no such session exists.
    """
    # NOTE(review): no route decorator is visible on this function, so it is not
    # exposed by `app` unless it is registered elsewhere - confirm.
    # A single pop() replaces check-then-del. Two overlapping calls for the same id
    # can no longer make the second `del` raise KeyError. Stored sessions are dicts,
    # never None, so None means "not found".
    if SESSIONS.pop(session_id, None) is None:
        raise HTTPException(status_code=404, detail="Session not found.")
    logging.info(f"Session ended: {session_id}")
    return {"message": f"Session {session_id} ended."}
# ================== Run the server ==================
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when this file is executed directly.
    from uvicorn import run as serve_app

    # Listen on all interfaces, port 7860.
    serve_app(app, host="0.0.0.0", port=7860)