# Baseer_Server / app.py
# BaseerAI's picture
# Update app.py
# e42aa92 verified
# app.py - InterFuser Self-Driving API Server
import uuid
import base64
import cv2
import torch
import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from torchvision import transforms
from typing import List, Dict, Any, Optional
import logging
import uuid
import base64
import cv2
import torch
import numpy as np
import logging
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from pydantic import BaseModel, Field
from typing import List, Dict, Tuple
from model_definition import InterfuserHDPE , load_and_prepare_model, get_master_config
from simulation_modules import InterfuserController, Tracker
from simulation_modules import DisplayInterface, render_bev, unnormalize_image, DisplayConfig
# ==============================================================================
# 2. General settings and FastAPI application
# ==============================================================================
# Logging configuration
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# FastAPI application instance
app = FastAPI(
    title="Baseer Self-Driving API",
    description="An advanced API for the InterFuser self-driving model, providing real-time control commands and scene analysis.",
    version="1.1.0"
)

# Module-level state initialised at startup.
# MODEL is None until the startup handler assigns the loaded model, hence Optional.
MODEL: Optional[InterfuserHDPE] = None
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SESSIONS: Dict[str, Dict] = {}  # state of the active sessions, keyed by session id
# ==============================================================================
# 3. API data model definitions (Pydantic models)
# ==============================================================================
class Measurements(BaseModel):
    # Vehicle state supplied by the client with every step request.
    # (Plain comments are used instead of a docstring so the generated
    # schema description stays exactly as it was.)

    # Position in the global frame.
    pos_global: Tuple[float, float] = Field(
        ...,
        example=(0.0, 0.0),
        description="Global [X, Y] position of the vehicle.",
    )
    # Heading in the global frame.
    theta: float = Field(
        ...,
        example=0.0,
        description="Global orientation angle of the vehicle in radians.",
    )
    # Scalar speed.
    speed: float = Field(
        ...,
        example=0.0,
        description="Current speed in m/s.",
    )
    # Navigation target expressed relative to the vehicle.
    target_point: Tuple[float, float] = Field(
        ...,
        example=(10.0, 0.0),
        description="Target point relative to the vehicle.",
    )
class RunStepRequest(BaseModel):
    # Request body of POST /run_step.

    # Identifier previously returned by /start_session.
    session_id: str
    # Encoded camera frame, transported as Base64 text.
    image_b64: str = Field(
        ...,
        description="Base64 encoded string of the vehicle's front camera view (BGR format).",
    )
    # Vehicle state for this frame.
    measurements: Measurements
class ControlCommands(BaseModel):
    # Actuation values produced by the controller for one step.

    steer: float      # steering command
    throttle: float   # throttle command
    brake: bool       # whether the brake is applied
class SceneAnalysis(BaseModel):
    # Scene-level scores reported alongside the control commands.
    # In run_step each value is filled from a softmax output of the model.

    is_junction: float
    traffic_light_state: float
    stop_sign: float
class RunStepResponse(BaseModel):
    # Response body of POST /run_step.

    # Controller output for this frame.
    control_commands: ControlCommands
    # Scene-level scores for this frame.
    scene_analysis: SceneAnalysis
    # Waypoints predicted by the model, one (x, y) pair per waypoint.
    predicted_waypoints: List[Tuple[float, float]]
    # Rendered dashboard image, transported as Base64 text.
    dashboard_b64: str = Field(
        ...,
        description="Base64 encoded string of the comprehensive dashboard view.",
    )
    # Human-readable explanation of the control decision.
    reason: str = Field(
        ...,
        description="The reason for the current control action (e.g., 'Following ID 12', 'Red Light').",
    )
# ==============================================================================
# 4. Helper functions
# ==============================================================================
def b64_to_cv2(b64_string: str) -> np.ndarray:
    """Decode a Base64-encoded image file into an OpenCV colour image.

    Args:
        b64_string: Base64 text containing the bytes of an encoded image.

    Returns:
        The image decoded by ``cv2.imdecode`` with ``cv2.IMREAD_COLOR``.

    Raises:
        HTTPException: status 400 when the text is not valid Base64 or the
            decoded bytes are not a readable image.
    """
    try:
        img_bytes = base64.b64decode(b64_string)
        img_array = np.frombuffer(img_bytes, dtype=np.uint8)
        image = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid Base64 image string.") from exc
    # cv2.imdecode reports undecodable data by returning None instead of
    # raising, so it has to be checked explicitly; otherwise the None would
    # reach the model pipeline and surface as an internal server error.
    if image is None:
        raise HTTPException(status_code=400, detail="Invalid Base64 image string.")
    return image
def cv2_to_b64(img: np.ndarray) -> str:
    """Encode an image as JPEG and return the bytes as Base64 text.

    Args:
        img: Image array accepted by ``cv2.imencode``.

    Returns:
        The JPEG bytes, Base64-encoded and decoded to a UTF-8 string.

    Raises:
        ValueError: if ``cv2.imencode`` reports that encoding failed.
    """
    success, buffer = cv2.imencode('.jpg', img)
    # imencode returns a success flag; do not hand back a buffer it has
    # flagged as invalid.
    if not success:
        raise ValueError("Failed to encode image as JPEG.")
    return base64.b64encode(buffer).decode('utf-8')
# The preprocessing pipeline does not depend on the request, so it is built
# once at import time instead of being reconstructed on every call.
_IMAGE_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224), antialias=True),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def prepare_model_input(image: np.ndarray, measurements: Measurements) -> Dict[str, torch.Tensor]:
    """Build a batch of size 1 to feed to the model.

    Args:
        image: Camera frame in BGR channel order (it is converted to RGB here).
        measurements: Vehicle state for the same frame.

    Returns:
        A dict of tensors on ``DEVICE`` with the keys ``rgb``, ``rgb_left``,
        ``rgb_right``, ``rgb_center``, ``measurements``, ``target_point`` and
        ``lidar``. The three extra camera entries are copies of the single
        front image, and ``lidar`` is a zero tensor shaped like the image.
    """
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image_tensor = _IMAGE_TRANSFORM(image_rgb).unsqueeze(0).to(DEVICE)

    measurements_tensor = torch.tensor(
        [[
            measurements.pos_global[0], measurements.pos_global[1], measurements.theta,
            0.0, 0.0, 0.0,  # Steer, throttle, brake (not used by model)
            measurements.speed, 4.0,  # Command (default to FollowLane)
        ]],
        dtype=torch.float32,
        device=DEVICE,
    )
    target_point_tensor = torch.tensor(
        [measurements.target_point], dtype=torch.float32, device=DEVICE
    )

    return {
        'rgb': image_tensor,
        'rgb_left': image_tensor.clone(),
        'rgb_right': image_tensor.clone(),
        'rgb_center': image_tensor.clone(),
        'measurements': measurements_tensor,
        'target_point': target_point_tensor,
        'lidar': torch.zeros_like(image_tensor),
    }
# ==============================================================================
# 5. Application lifecycle events (startup/shutdown)
# ==============================================================================
@app.on_event("startup")
async def startup_event() -> None:
    """Load the model into the module-level ``MODEL`` when the server starts.

    Logs the selected device and whether loading succeeded. When the loader
    returns ``None`` the server keeps running and only logs an error.
    """
    global MODEL
    logging.info("🚗 Server starting up...")
    logging.info(f"Using device: {DEVICE}")
    MODEL = load_and_prepare_model(DEVICE)
    # Explicit None test (same check the endpoints use) rather than relying on
    # the truthiness of the model object.
    if MODEL is not None:
        logging.info("✅ Model loaded successfully. Server is ready!")
    else:
        logging.error("❌ CRITICAL: Model could not be loaded. The API will not function correctly.")
# ==============================================================================
# 6. Main API endpoints
# ==============================================================================
@app.get("/", response_class=HTMLResponse, include_in_schema=False, tags=["General"])
async def root():
    """Render the HTML landing page.

    The page shows whether the model is loaded (status badge), the number of
    currently active sessions, the version label, and links to the interactive
    docs, the sessions view and the simulation space. Vertical scrolling is
    enabled on the page body.

    Returns:
        An ``HTMLResponse`` containing the rendered page.
    """
    active_sessions_count = len(SESSIONS)
    status_color = "#00ff7f"  # SpringGreen
    status_text = "متصل ويعمل"
    if MODEL is None:
        status_color = "#ff4757"  # Red
        status_text = "خطأ: النموذج غير متاح"
    # NOTE: literal CSS braces are doubled because this is an f-string.
    # The .stat-card:hover shadow uses an 8-digit hex colour (#e94560 + alpha 44):
    # appending "44" to var(--secondary-accent) is not a valid CSS colour and
    # made the browser drop the whole declaration.
    html_content = f"""
<!DOCTYPE html>
<html dir="rtl" lang="ar">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>🚗 Baseer - واجهة القيادة الذاتية</title>
<style>
@import url('https://fonts.googleapis.com/css2?family=Cairo:wght@400;700;900&display=swap');
:root {{
--bg-dark: #1a1a2e;
--panel-dark: #16213e;
--primary-accent: #0f3460;
--secondary-accent: #e94560;
--glow-color: #537895;
--text-light: #e0e0e0;
--text-header: #ffffff;
}}
body {{
font-family: 'Cairo', sans-serif;
background: var(--bg-dark);
background-image: linear-gradient(to right top, #1a1a2e, #1c1d32, #1f2037, #22233b, #252640);
color: var(--text-light);
margin: 0;
padding: 40px 20px; /* إضافة padding علوي وسفلي للسماح بالمساحة */
min-height: 100vh;
box-sizing: border-box; /* لضمان أن padding لا يضيف إلى الارتفاع */
/* --- [التصحيح هنا] --- */
overflow-x: hidden; /* إخفاء التمرير الأفقي غير المرغوب فيه */
overflow-y: auto; /* السماح بالتمرير العمودي عند الحاجة */
}}
.main-content {{
display: flex;
justify-content: center;
align-items: center;
width: 100%;
}}
.container {{
background: rgba(22, 33, 62, 0.85);
backdrop-filter: blur(15px);
border-radius: 25px;
padding: 40px 50px;
box-shadow: 0 25px 50px rgba(0, 0, 0, 0.3);
text-align: center;
max-width: 750px;
width: 100%;
border: 1px solid rgba(255, 255, 255, 0.1);
position: relative;
z-index: 1;
}}
/* ... باقي كود CSS يبقى كما هو تمامًا ... */
.logo {{
font-size: 5rem;
margin-bottom: 10px;
text-shadow: 0 0 15px var(--glow-color);
animation: float 4s ease-in-out infinite;
}}
@keyframes float {{
0% {{ transform: translateY(0px); }}
50% {{ transform: translateY(-15px); }}
100% {{ transform: translateY(0px); }}
}}
h1 {{
font-size: 3rem;
font-weight: 900;
color: var(--text-header);
margin-bottom: 5px;
letter-spacing: 1px;
}}
.subtitle {{
font-size: 1.3rem;
color: #a7a9be;
margin-bottom: 25px;
}}
.status-badge {{
display: inline-flex;
align-items: center;
gap: 10px;
background-color: rgba(255, 255, 255, 0.05);
border: 1px solid {status_color};
color: {status_color};
padding: 10px 22px;
border-radius: 50px;
font-weight: bold;
margin-bottom: 35px;
font-size: 1.1rem;
box-shadow: 0 0 15px {status_color}33;
}}
.stats-grid {{
display: grid;
grid-template-columns: 1fr 1fr;
gap: 25px;
margin-bottom: 40px;
}}
.stat-card {{
background: var(--primary-accent);
padding: 25px;
border-radius: 20px;
transition: transform 0.3s ease, box-shadow 0.3s ease;
}}
.stat-card:hover {{
transform: scale(1.05);
box-shadow: 0 0 25px #e9456044;
}}
.stat-number {{
font-size: 3rem;
font-weight: 700;
color: var(--secondary-accent);
}}
.stat-label {{
font-size: 1rem;
color: #a7a9be;
margin-top: 5px;
}}
.button-group {{
display: flex;
gap: 20px;
justify-content: center;
}}
.btn {{
padding: 15px 35px;
border-radius: 50px;
text-decoration: none;
font-weight: 700;
font-size: 1.1rem;
transition: all 0.3s ease;
border: none;
cursor: pointer;
position: relative;
overflow: hidden;
}}
.btn-primary {{
background: var(--secondary-accent);
color: var(--text-header);
box-shadow: 0 5px 15px {status_color}44;
}}
.btn-primary:hover {{
box-shadow: 0 8px 25px {status_color}66;
transform: translateY(-3px);
}}
.btn-secondary {{
background: transparent;
color: var(--text-light);
border: 2px solid var(--glow-color);
}}
.btn-secondary:hover {{
background: var(--glow-color);
color: var(--text-header);
border-color: var(--glow-color);
}}
</style>
</head>
<body>
<div class="main-content">
<div class="container">
<div class="logo">🚀</div>
<h1>بصيـر API</h1>
<p class="subtitle">مستقبل القيادة الذاتية بين يديك</p>
<div class="status-badge">
<span style="width: 12px; height: 12px; background-color: {status_color}; border-radius: 50%;"></span>
<span>{status_text}</span>
</div>
<div class="stats-grid">
<div class="stat-card">
<div class="stat-number">{active_sessions_count}</div>
<div class="stat-label">الجلسات النشطة</div>
</div>
<div class="stat-card">
<div class="stat-number">1.1</div>
<div class="stat-label">الإصدار</div>
</div>
</div>
<div class="button-group">
<a href="/docs" target="_blank" class="btn btn-primary">📚 التوثيق التفاعلي</a>
<a href="/sessions" target="_blank" class="btn btn-secondary">📊 عرض الجلسات</a>
<a href="https://huggingface.co/spaces/mohammed-aljafry/Baseer_Simulation" target="_blank" class="btn btn-primary"> التفاعل</a>
</div>
</div>
</div>
</body>
</html>
"""
    return HTMLResponse(content=html_content)
@app.post("/start_session", summary="Start a new driving session", tags=["Session Management"])
def start_session():
    # Creates a new session entry (tracker, controller, frame counter) in
    # SESSIONS and returns its freshly generated id.
    session_id = str(uuid.uuid4())

    # Everything is derived from the single master configuration.
    master_cfg = get_master_config()
    grid = master_cfg['grid_conf']
    ctrl_params = master_cfg['controller_params']
    frequency = master_cfg['simulation']['frequency']

    # The tracker is built first, then the controller; both are stored with a
    # frame counter that starts at zero.
    SESSIONS[session_id] = {
        'tracker': Tracker(
            grid_conf=grid,
            match_threshold=ctrl_params.get('tracker_match_thresh', 2.5),
            prune_age=ctrl_params.get('tracker_prune_age', 5),
        ),
        'controller': InterfuserController({
            'controller_params': ctrl_params,
            'grid_conf': grid,
            'frequency': frequency,
        }),
        'frame_num': 0,
    }

    logging.info(f"New session started: {session_id}")
    return {"session_id": session_id}
@app.post("/run_step", response_model=RunStepResponse, summary="Process a single simulation step", tags=["Core"])
@torch.no_grad()
def run_step(request: RunStepRequest):
    """Process one camera frame for a session and return the control decision.

    The frame is decoded and passed through the model, the model outputs are
    fed to the session's tracker and controller, a dashboard image is
    rendered, and the session's frame counter is advanced.

    Responds with 503 when the model is not loaded, 404 when the session id
    is unknown, and 400 when the image cannot be decoded.
    """
    if MODEL is None:
        raise HTTPException(status_code=503, detail="Model is not available.")
    session = SESSIONS.get(request.session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session ID not found.")

    # --- 1. Perception ---
    image = b64_to_cv2(request.image_b64)
    # cv2.imdecode yields None for bytes that are not an image; reject that as
    # a client error here instead of failing later inside preprocessing.
    if image is None:
        raise HTTPException(status_code=400, detail="Invalid Base64 image string.")
    model_input = prepare_model_input(image, request.measurements)
    traffic, waypoints, junc, light, stop, _ = MODEL(model_input)

    # --- 2. Post-process the model outputs ---
    # Sigmoid is applied to the first traffic channel only; the remaining
    # channels are passed through unchanged.
    traffic_processed = torch.cat([torch.sigmoid(traffic[0][:, 0:1]), traffic[0][:, 1:]], dim=1)
    # The flat traffic output is viewed as a 20 x 20 grid of feature vectors.
    traffic_np = traffic_processed.cpu().numpy().reshape(20, 20, -1)
    waypoints_np = waypoints[0].cpu().numpy()
    # Each classification head is reduced to the softmax score of index 1.
    junction_prob = torch.softmax(junc, dim=1)[0, 1].item()
    light_prob = torch.softmax(light, dim=1)[0, 1].item()
    stop_prob = torch.softmax(stop, dim=1)[0, 1].item()

    # --- 3. Tracking and control ---
    ego_pos = np.array(request.measurements.pos_global)
    ego_theta = request.measurements.theta
    frame_num = session['frame_num']
    active_tracks = session['tracker'].process_frame(traffic_np, ego_pos, ego_theta, frame_num)
    steer, throttle, brake, ctrl_info = session['controller'].run_step(
        speed=request.measurements.speed, waypoints=torch.from_numpy(waypoints_np),
        junction=junction_prob, traffic_light=light_prob, stop_sign=stop_prob,
        bev_map=traffic_np, ego_pos=ego_pos, ego_theta=ego_theta, frame_num=frame_num
    )

    # --- 4. Build the visual dashboard ---
    display_iface = DisplayInterface(DisplayConfig(width=1600, height=900))
    bev_maps = render_bev(active_tracks, waypoints_np, ego_pos, ego_theta)
    display_data = {
        'camera_view': image, 'map_t0': bev_maps['t0'], 'map_t1': bev_maps['t1'], 'map_t2': bev_maps['t2'],
        # Speeds arrive in m/s; the factor 3.6 converts them to km/h for display.
        'frame_num': frame_num, 'speed': request.measurements.speed * 3.6,
        'target_speed': ctrl_info.get('target_speed', 0) * 3.6,
        'steer': steer, 'throttle': throttle, 'brake': brake,
        'light_prob': light_prob, 'stop_prob': stop_prob,
        'object_counts': {'car': len(active_tracks)}
    }
    dashboard = display_iface.run_interface(display_data)

    # --- 5. Update the session and return the response ---
    session['frame_num'] += 1
    return RunStepResponse(
        control_commands=ControlCommands(steer=steer, throttle=throttle, brake=brake),
        scene_analysis=SceneAnalysis(is_junction=junction_prob, traffic_light_state=light_prob, stop_sign=stop_prob),
        predicted_waypoints=[tuple(wp) for wp in waypoints_np.tolist()],
        dashboard_b64=cv2_to_b64(dashboard),
        reason=ctrl_info.get('brake_reason', "Cruising")
    )
@app.post("/end_session", summary="End and clean up a session", tags=["Session Management"])
def end_session(session_id: str):
    """Remove a session and discard its state.

    Responds with 404 when no session with the given id exists.
    """
    # pop() performs the lookup and the removal in a single dict operation, so
    # two requests ending the same session cannot hit a KeyError between a
    # membership check and the delete.
    if SESSIONS.pop(session_id, None) is None:
        raise HTTPException(status_code=404, detail="Session not found.")
    logging.info(f"Session ended: {session_id}")
    return {"message": f"Session {session_id} ended."}
# ================== Run the server ==================
if __name__ == "__main__":
    import uvicorn

    # Listen on all interfaces, port 7860.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)