Spaces:
Running
Running
# app.py - InterFuser Self-Driving API Server | |
import uuid | |
import base64 | |
import cv2 | |
import torch | |
import numpy as np | |
from fastapi import FastAPI, HTTPException | |
from fastapi.responses import HTMLResponse | |
from pydantic import BaseModel | |
from torchvision import transforms | |
from typing import List, Dict, Any, Optional | |
import logging | |
import uuid | |
import base64 | |
import cv2 | |
import torch | |
import numpy as np | |
import logging | |
from fastapi import FastAPI, HTTPException | |
from fastapi.responses import HTMLResponse | |
from pydantic import BaseModel, Field | |
from typing import List, Dict, Tuple | |
from model_definition import InterfuserHDPE , load_and_prepare_model, get_master_config | |
from simulation_modules import InterfuserController, Tracker | |
from simulation_modules import DisplayInterface, render_bev, unnormalize_image, DisplayConfig | |
# ==============================================================================
# 2. General settings and FastAPI application
# ==============================================================================
# Logging configuration
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# FastAPI application instance
app = FastAPI(
    title="Baseer Self-Driving API",
    description="An advanced API for the InterFuser self-driving model, providing real-time control commands and scene analysis.",
    version="1.1.0",
)

# Global state.
# MODEL is None until startup_event assigns the result of load_and_prepare_model;
# endpoints must check for None before using it.
MODEL: Optional[InterfuserHDPE] = None
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Active sessions keyed by the uuid4 string returned from start_session; each
# value holds that session's 'tracker', 'controller' and 'frame_num'.
SESSIONS: Dict[str, Dict[str, Any]] = {}
# ==============================================================================
# 3. API data models (Pydantic)
# ==============================================================================
# Vehicle state sent with every run_step request; prepare_model_input reads all
# four fields, and run_step passes pos_global/theta on as the ego pose.
# NOTE(review): `example=` is passed to Field as an extra keyword; Pydantic v2
# deprecates that in favour of `examples=[...]` - confirm the installed version.
class Measurements(BaseModel):
    pos_global: Tuple[float, float] = Field(..., example=(0.0, 0.0), description="Global [X, Y] position of the vehicle.")
    theta: float = Field(..., example=0.0, description="Global orientation angle of the vehicle in radians.")
    # run_step multiplies this by 3.6 for the dashboard readout.
    speed: float = Field(..., example=0.0, description="Current speed in m/s.")
    target_point: Tuple[float, float] = Field(..., example=(10.0, 0.0), description="Target point relative to the vehicle.")
# Request body for run_step: which session, the camera frame, and the vehicle state.
class RunStepRequest(BaseModel):
    session_id: str  # the id returned by start_session
    # NOTE(review): the payload is decoded with cv2.imdecode, so it must be an
    # encoded image file (e.g. JPEG/PNG); "BGR" describes the decoded array.
    image_b64: str = Field(..., description="Base64 encoded string of the vehicle's front camera view (BGR format).")
    measurements: Measurements
# Actuator commands produced by the session's controller for one step.
class ControlCommands(BaseModel):
    steer: float
    throttle: float
    brake: bool
# Scene classifier outputs. run_step fills each field with the softmax
# probability of class 1 from the matching model head, so every value is a
# probability (traffic_light_state included), not a discrete state.
class SceneAnalysis(BaseModel):
    is_junction: float
    traffic_light_state: float
    stop_sign: float
# Response body of run_step.
class RunStepResponse(BaseModel):
    control_commands: ControlCommands
    scene_analysis: SceneAnalysis
    # One (x, y) pair per waypoint predicted by the model for this frame.
    predicted_waypoints: List[Tuple[float, float]]
    dashboard_b64: str = Field(..., description="Base64 encoded string of the comprehensive dashboard view.")
    # Filled from the controller's 'brake_reason', defaulting to "Cruising".
    reason: str = Field(..., description="The reason for the current control action (e.g., 'Following ID 12', 'Red Light').")
# ==============================================================================
# 4. Helper functions
# ==============================================================================
def b64_to_cv2(b64_string: str) -> np.ndarray:
    """Decode a Base64-encoded image file into an OpenCV array.

    Args:
        b64_string: Base64 text of an encoded image (any format that
            ``cv2.imdecode`` accepts, e.g. JPEG or PNG).

    Returns:
        The decoded image from ``cv2.imdecode(..., cv2.IMREAD_COLOR)``.

    Raises:
        HTTPException: 400 if the string cannot be Base64-decoded or the bytes
            do not decode to an image.
    """
    try:
        img_bytes = base64.b64decode(b64_string)
        img_array = np.frombuffer(img_bytes, dtype=np.uint8)
        image = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
    except Exception as exc:
        # Any decoding failure is a client error, not a server fault.
        raise HTTPException(status_code=400, detail="Invalid Base64 image string.") from exc
    # cv2.imdecode reports undecodable data by returning None instead of
    # raising. Without this check the None would only fail later, inside
    # cv2.cvtColor during preprocessing, as a 500.
    if image is None:
        raise HTTPException(status_code=400, detail="Base64 payload does not contain a decodable image.")
    return image
def cv2_to_b64(img: np.ndarray) -> str:
    """Encode an OpenCV image as JPEG and return it as a Base64 (UTF-8) string.

    Raises:
        ValueError: if ``cv2.imencode`` reports that encoding failed.
    """
    ok, buffer = cv2.imencode('.jpg', img)
    # imencode signals failure through its boolean flag; encoding an empty
    # buffer would otherwise silently yield an empty string.
    if not ok:
        raise ValueError("cv2.imencode failed to encode the image as JPEG.")
    return base64.b64encode(buffer).decode('utf-8')
# Camera preprocessing: to tensor, resize to 224x224, normalise with the
# standard ImageNet mean/std. The pipeline is stateless, so it is built once at
# import time instead of on every request.
_RGB_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224), antialias=True),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def prepare_model_input(image: np.ndarray, measurements: Measurements) -> Dict[str, torch.Tensor]:
    """Build a batch of size 1 for the model from one camera frame.

    Args:
        image: BGR image as returned by ``cv2.imdecode``.
        measurements: vehicle state from the request.

    Returns:
        Dict of tensors on DEVICE:
          - 'rgb': normalised frame, shape (1, 3, 224, 224);
          - 'rgb_left' / 'rgb_right' / 'rgb_center': copies of 'rgb', because
            the request carries only a single camera view;
          - 'measurements': shape (1, 8) =
            [x, y, theta, steer, throttle, brake, speed, command];
          - 'target_point': shape (1, 2);
          - 'lidar': all zeros with the same shape as 'rgb' (no lidar input).
    """
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image_tensor = _RGB_TRANSFORM(image_rgb).unsqueeze(0).to(DEVICE)
    measurements_tensor = torch.tensor([[
        measurements.pos_global[0], measurements.pos_global[1], measurements.theta,
        0.0, 0.0, 0.0,  # steer, throttle, brake (not used by the model)
        measurements.speed, 4.0,  # command: 4.0 = default "FollowLane"
    ]], dtype=torch.float32, device=DEVICE)
    target_point_tensor = torch.tensor([measurements.target_point], dtype=torch.float32, device=DEVICE)
    return {
        'rgb': image_tensor,
        'rgb_left': image_tensor.clone(),
        'rgb_right': image_tensor.clone(),
        'rgb_center': image_tensor.clone(),
        'measurements': measurements_tensor,
        'target_point': target_point_tensor,
        'lidar': torch.zeros_like(image_tensor),
    }
# ==============================================================================
# 5. Application lifecycle events (startup/shutdown)
# ==============================================================================
@app.on_event("startup")
async def startup_event():
    """Load the model onto DEVICE once, when the server starts.

    Registered as a startup handler so MODEL is populated before requests are
    served. If loading yields no model, the failure is logged and MODEL stays
    unusable; run_step then answers 503 and the landing page shows an error.
    """
    global MODEL
    logging.info("🚗 Server starting up...")
    logging.info(f"Using device: {DEVICE}")
    MODEL = load_and_prepare_model(DEVICE)
    if MODEL:
        logging.info("✅ Model loaded successfully. Server is ready!")
    else:
        logging.error("❌ CRITICAL: Model could not be loaded. The API will not function correctly.")
# ==============================================================================
# 6. Main API endpoints
# ==============================================================================
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the (Arabic, right-to-left) landing page with live server status.

    The page shows whether the model is loaded, the number of active sessions
    and the API version. It links to the interactive docs, the sessions view
    and the Baseer_Simulation Space, and allows vertical scrolling when the
    content overflows.
    """
    active_sessions_count = len(SESSIONS)
    status_color = "#00ff7f"  # SpringGreen
    status_text = "متصل ويعمل"
    if MODEL is None:
        status_color = "#ff4757"  # Red
        status_text = "خطأ: النموذج غير متاح"
    # NOTE: a CSS custom property cannot take a hex-alpha suffix
    # (`var(--secondary-accent)44` is an invalid value that drops the whole
    # declaration), so the stat-card hover glow uses the literal #e9456044,
    # i.e. --secondary-accent at ~27% opacity.
    # NOTE(review): the "/sessions" link below has no matching route in this file.
    html_content = f"""
    <!DOCTYPE html>
    <html dir="rtl" lang="ar">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>🚗 Baseer - واجهة القيادة الذاتية</title>
        <style>
            @import url('https://fonts.googleapis.com/css2?family=Cairo:wght@400;700;900&display=swap');
            :root {{
                --bg-dark: #1a1a2e;
                --panel-dark: #16213e;
                --primary-accent: #0f3460;
                --secondary-accent: #e94560;
                --glow-color: #537895;
                --text-light: #e0e0e0;
                --text-header: #ffffff;
            }}
            body {{
                font-family: 'Cairo', sans-serif;
                background: var(--bg-dark);
                background-image: linear-gradient(to right top, #1a1a2e, #1c1d32, #1f2037, #22233b, #252640);
                color: var(--text-light);
                margin: 0;
                padding: 40px 20px; /* إضافة padding علوي وسفلي للسماح بالمساحة */
                min-height: 100vh;
                box-sizing: border-box; /* لضمان أن padding لا يضيف إلى الارتفاع */
                /* --- [التصحيح هنا] --- */
                overflow-x: hidden; /* إخفاء التمرير الأفقي غير المرغوب فيه */
                overflow-y: auto; /* السماح بالتمرير العمودي عند الحاجة */
            }}
            .main-content {{
                display: flex;
                justify-content: center;
                align-items: center;
                width: 100%;
            }}
            .container {{
                background: rgba(22, 33, 62, 0.85);
                backdrop-filter: blur(15px);
                border-radius: 25px;
                padding: 40px 50px;
                box-shadow: 0 25px 50px rgba(0, 0, 0, 0.3);
                text-align: center;
                max-width: 750px;
                width: 100%;
                border: 1px solid rgba(255, 255, 255, 0.1);
                position: relative;
                z-index: 1;
            }}
            /* ... باقي كود CSS يبقى كما هو تمامًا ... */
            .logo {{
                font-size: 5rem;
                margin-bottom: 10px;
                text-shadow: 0 0 15px var(--glow-color);
                animation: float 4s ease-in-out infinite;
            }}
            @keyframes float {{
                0% {{ transform: translateY(0px); }}
                50% {{ transform: translateY(-15px); }}
                100% {{ transform: translateY(0px); }}
            }}
            h1 {{
                font-size: 3rem;
                font-weight: 900;
                color: var(--text-header);
                margin-bottom: 5px;
                letter-spacing: 1px;
            }}
            .subtitle {{
                font-size: 1.3rem;
                color: #a7a9be;
                margin-bottom: 25px;
            }}
            .status-badge {{
                display: inline-flex;
                align-items: center;
                gap: 10px;
                background-color: rgba(255, 255, 255, 0.05);
                border: 1px solid {status_color};
                color: {status_color};
                padding: 10px 22px;
                border-radius: 50px;
                font-weight: bold;
                margin-bottom: 35px;
                font-size: 1.1rem;
                box-shadow: 0 0 15px {status_color}33;
            }}
            .stats-grid {{
                display: grid;
                grid-template-columns: 1fr 1fr;
                gap: 25px;
                margin-bottom: 40px;
            }}
            .stat-card {{
                background: var(--primary-accent);
                padding: 25px;
                border-radius: 20px;
                transition: transform 0.3s ease, box-shadow 0.3s ease;
            }}
            .stat-card:hover {{
                transform: scale(1.05);
                box-shadow: 0 0 25px #e9456044;
            }}
            .stat-number {{
                font-size: 3rem;
                font-weight: 700;
                color: var(--secondary-accent);
            }}
            .stat-label {{
                font-size: 1rem;
                color: #a7a9be;
                margin-top: 5px;
            }}
            .button-group {{
                display: flex;
                gap: 20px;
                justify-content: center;
            }}
            .btn {{
                padding: 15px 35px;
                border-radius: 50px;
                text-decoration: none;
                font-weight: 700;
                font-size: 1.1rem;
                transition: all 0.3s ease;
                border: none;
                cursor: pointer;
                position: relative;
                overflow: hidden;
            }}
            .btn-primary {{
                background: var(--secondary-accent);
                color: var(--text-header);
                box-shadow: 0 5px 15px {status_color}44;
            }}
            .btn-primary:hover {{
                box-shadow: 0 8px 25px {status_color}66;
                transform: translateY(-3px);
            }}
            .btn-secondary {{
                background: transparent;
                color: var(--text-light);
                border: 2px solid var(--glow-color);
            }}
            .btn-secondary:hover {{
                background: var(--glow-color);
                color: var(--text-header);
                border-color: var(--glow-color);
            }}
        </style>
    </head>
    <body>
        <div class="main-content">
            <div class="container">
                <div class="logo">🚀</div>
                <h1>بصيـر API</h1>
                <p class="subtitle">مستقبل القيادة الذاتية بين يديك</p>
                <div class="status-badge">
                    <span style="width: 12px; height: 12px; background-color: {status_color}; border-radius: 50%;"></span>
                    <span>{status_text}</span>
                </div>
                <div class="stats-grid">
                    <div class="stat-card">
                        <div class="stat-number">{active_sessions_count}</div>
                        <div class="stat-label">الجلسات النشطة</div>
                    </div>
                    <div class="stat-card">
                        <div class="stat-number">{app.version}</div>
                        <div class="stat-label">الإصدار</div>
                    </div>
                </div>
                <div class="button-group">
                    <a href="/docs" target="_blank" class="btn btn-primary">📚 التوثيق التفاعلي</a>
                    <a href="/sessions" target="_blank" class="btn btn-secondary">📊 عرض الجلسات</a>
                    <a href="https://huggingface.co/spaces/mohammed-aljafry/Baseer_Simulation" target="_blank" class="btn btn-primary"> التفاعل</a>
                </div>
            </div>
        </div>
    </body>
    </html>
    """
    return HTMLResponse(content=html_content)
@app.post("/start_session")
def start_session():
    """Create a new driving session with its own tracker and controller.

    Returns:
        ``{"session_id": <uuid4 string>}``; pass this id to run_step and
        end_session.
    """
    session_id = str(uuid.uuid4())

    # 1. Fetch the full configuration from its single source.
    config = get_master_config()

    # 2. Pull out the settings each component needs.
    grid_conf = config['grid_conf']
    controller_params = config['controller_params']
    simulation_freq = config['simulation']['frequency']

    # 3. Tracker; its matching/pruning settings come from controller_params,
    #    with the defaults below used when a key is absent.
    tracker = Tracker(
        grid_conf=grid_conf,
        match_threshold=controller_params.get('tracker_match_thresh', 2.5),
        prune_age=controller_params.get('tracker_prune_age', 5),
    )

    # 4. Controller
    controller = InterfuserController({
        'controller_params': controller_params,
        'grid_conf': grid_conf,
        'frequency': simulation_freq,
    })

    # 5. Register the session; frame_num counts the steps processed so far
    #    (run_step increments it after each step).
    SESSIONS[session_id] = {
        'tracker': tracker,
        'controller': controller,
        'frame_num': 0,
    }
    logging.info("New session started: %s", session_id)
    return {"session_id": session_id}
@app.post("/run_step", response_model=RunStepResponse)
def run_step(request: RunStepRequest):
    """Process one simulation step for an existing session.

    The step decodes the camera frame, runs the model, feeds its outputs to the
    session's tracker and controller, renders the dashboard, and then advances
    the session's frame counter.

    Args:
        request: session id, Base64 camera frame and vehicle measurements.

    Returns:
        RunStepResponse with the control commands, scene probabilities,
        predicted waypoints, Base64 JPEG dashboard and controller reason.

    Raises:
        HTTPException: 503 if the model is not loaded, 404 for an unknown
            session id, 400 (raised by b64_to_cv2) for an invalid image payload.
    """
    if MODEL is None:
        raise HTTPException(status_code=503, detail="Model is not available.")
    session = SESSIONS.get(request.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session ID not found.")

    # --- 1. Perception ---
    image = b64_to_cv2(request.image_b64)
    model_input = prepare_model_input(image, request.measurements)

    # Inference only: run without autograd. This avoids building a graph on
    # every request, and Tensor.numpy() raises on tensors that require grad,
    # which the outputs would if the model's parameters require grad.
    with torch.no_grad():
        traffic, waypoints, junc, light, stop, _ = MODEL(model_input)

        # --- 2. Model output post-processing ---
        # For the first batch item: sigmoid on channel 0, other channels unchanged.
        traffic_processed = torch.cat([torch.sigmoid(traffic[0][:, 0:1]), traffic[0][:, 1:]], dim=1)
        # Laid out as a 20 x 20 grid, channels last.
        traffic_np = traffic_processed.cpu().numpy().reshape(20, 20, -1)
        waypoints_np = waypoints[0].cpu().numpy()
        # Class-1 probability from each two-way scene head.
        junction_prob = torch.softmax(junc, dim=1)[0, 1].item()
        light_prob = torch.softmax(light, dim=1)[0, 1].item()
        stop_prob = torch.softmax(stop, dim=1)[0, 1].item()

    # --- 3. Tracking and control ---
    ego_pos = np.array(request.measurements.pos_global)
    ego_theta = request.measurements.theta
    frame_num = session['frame_num']
    active_tracks = session['tracker'].process_frame(traffic_np, ego_pos, ego_theta, frame_num)
    steer, throttle, brake, ctrl_info = session['controller'].run_step(
        speed=request.measurements.speed, waypoints=torch.from_numpy(waypoints_np),
        junction=junction_prob, traffic_light=light_prob, stop_sign=stop_prob,
        bev_map=traffic_np, ego_pos=ego_pos, ego_theta=ego_theta, frame_num=frame_num,
    )

    # --- 4. Dashboard rendering ---
    display_iface = DisplayInterface(DisplayConfig(width=1600, height=900))
    bev_maps = render_bev(active_tracks, waypoints_np, ego_pos, ego_theta)
    display_data = {
        'camera_view': image,
        'map_t0': bev_maps['t0'], 'map_t1': bev_maps['t1'], 'map_t2': bev_maps['t2'],
        'frame_num': frame_num,
        'speed': request.measurements.speed * 3.6,  # m/s -> km/h
        'target_speed': ctrl_info.get('target_speed', 0) * 3.6,
        'steer': steer, 'throttle': throttle, 'brake': brake,
        'light_prob': light_prob, 'stop_prob': stop_prob,
        'object_counts': {'car': len(active_tracks)},
    }
    dashboard = display_iface.run_interface(display_data)

    # --- 5. Advance the session and build the response ---
    session['frame_num'] += 1
    return RunStepResponse(
        # The controller may return NumPy/torch scalars; convert them to plain
        # Python types so the response model validates them reliably.
        control_commands=ControlCommands(steer=float(steer), throttle=float(throttle), brake=bool(brake)),
        scene_analysis=SceneAnalysis(is_junction=junction_prob, traffic_light_state=light_prob, stop_sign=stop_prob),
        predicted_waypoints=[tuple(wp) for wp in waypoints_np.tolist()],
        dashboard_b64=cv2_to_b64(dashboard),
        reason=ctrl_info.get('brake_reason', "Cruising"),
    )
@app.post("/end_session")
def end_session(session_id: str):
    """Discard all state held for ``session_id``.

    Returns:
        ``{"message": "Session <id> ended."}``

    Raises:
        HTTPException: 404 if no such session exists.
    """
    # pop() looks up and removes in one dict operation. Session values are
    # dicts and never None, so a None result means the id was unknown.
    if SESSIONS.pop(session_id, None) is None:
        raise HTTPException(status_code=404, detail="Session not found.")
    logging.info("Session ended: %s", session_id)
    return {"message": f"Session {session_id} ended."}
# ================== Server entry point ==================
if __name__ == "__main__":
    from uvicorn import run as serve_app

    # Serve the app on all network interfaces (0.0.0.0), port 7860.
    serve_app(app, host="0.0.0.0", port=7860)