# app.py - InterFuser Self-Driving API Server
import uuid
import base64
import logging
import cv2
import torch
import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from pydantic import BaseModel, Field
from torchvision import transforms  # needed by prepare_model_input below
from typing import List, Dict, Tuple, Optional
# ==============================================================================
# 1. Import the project components developed for this project
#    (make sure these files live in the same directory)
# ==============================================================================
# From the model file (contains the InterfuserModel class and its helper functions)
from model_definition import InterfuserModel, load_and_prepare_model, create_model_config
# From the control and visualization modules
from simulation_modules import InterfuserController, Tracker
from simulation_modules import DisplayInterface, render_bev, unnormalize_image, DisplayConfig
# ==============================================================================
# 2. Global settings and the FastAPI application
# ==============================================================================
# Logging setup
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Initialize the FastAPI application
app = FastAPI(
title="Baseer Self-Driving API",
description="An advanced API for the InterFuser self-driving model, providing real-time control commands and scene analysis.",
version="1.1.0"
)
# Global state, initialized on startup
MODEL: Optional[InterfuserModel] = None
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SESSIONS: Dict[str, Dict] = {}  # Holds the state of each active session
# ==============================================================================
# 3. API data models (Pydantic)
# ==============================================================================
class Measurements(BaseModel):
pos_global: Tuple[float, float] = Field(..., example=(0.0, 0.0), description="Global [X, Y] position of the vehicle.")
theta: float = Field(..., example=0.0, description="Global orientation angle of the vehicle in radians.")
speed: float = Field(..., example=0.0, description="Current speed in m/s.")
target_point: Tuple[float, float] = Field(..., example=(10.0, 0.0), description="Target point relative to the vehicle.")
class RunStepRequest(BaseModel):
session_id: str
image_b64: str = Field(..., description="Base64 encoded string of the vehicle's front camera view (BGR format).")
measurements: Measurements
class ControlCommands(BaseModel):
steer: float
throttle: float
brake: bool
class SceneAnalysis(BaseModel):
is_junction: float
traffic_light_state: float
stop_sign: float
class RunStepResponse(BaseModel):
control_commands: ControlCommands
scene_analysis: SceneAnalysis
predicted_waypoints: List[Tuple[float, float]]
dashboard_b64: str = Field(..., description="Base64 encoded string of the comprehensive dashboard view.")
reason: str = Field(..., description="The reason for the current control action (e.g., 'Following ID 12', 'Red Light').")
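# For reference, a RunStepRequest body sent to /run_step looks roughly like the
# sketch below (illustrative values only; the Base64 image string is truncated):
#
#   {
#     "session_id": "9a1f2c3e-....",
#     "image_b64": "/9j/4AAQSkZJRg...",
#     "measurements": {
#       "pos_global": [105.2, -42.7],
#       "theta": 1.57,
#       "speed": 6.3,
#       "target_point": [10.0, 0.0]
#     }
#   }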
# ==============================================================================
# 4. Helper functions
# ==============================================================================
def b64_to_cv2(b64_string: str) -> np.ndarray:
    """Decode a Base64 string into a BGR OpenCV image."""
    try:
        img_bytes = base64.b64decode(b64_string)
        img_array = np.frombuffer(img_bytes, dtype=np.uint8)
        image = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid Base64 image string.")
    if image is None:
        # cv2.imdecode returns None instead of raising when the bytes are not a decodable image.
        raise HTTPException(status_code=400, detail="Invalid Base64 image string.")
    return image
def cv2_to_b64(img: np.ndarray) -> str:
    """Encode a BGR OpenCV image to a Base64 JPEG string."""
    _, buffer = cv2.imencode('.jpg', img)
    return base64.b64encode(buffer).decode('utf-8')
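# Quick sanity sketch for the two helpers above (illustrative only): encoding an
# image and decoding it back should return an array of the same shape, which is
# the round trip the /run_step endpoint performs on frames and dashboards.
#
#   frame = b64_to_cv2(cv2_to_b64(np.zeros((720, 1280, 3), dtype=np.uint8)))
#   assert frame.shape == (720, 1280, 3)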
def prepare_model_input(image: np.ndarray, measurements: Measurements) -> Dict[str, torch.Tensor]:
"""
إعداد دفعة (batch of 1) لتمريرها إلى النموذج.
"""
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Resize((224, 224), antialias=True),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image_tensor = transform(image_rgb).unsqueeze(0).to(DEVICE)
measurements_tensor = torch.tensor([[
measurements.pos_global[0], measurements.pos_global[1], measurements.theta,
0.0, 0.0, 0.0, # Steer, throttle, brake (not used by model)
measurements.speed, 4.0 # Command (default to FollowLane)
]], dtype=torch.float32).to(DEVICE)
target_point_tensor = torch.tensor([measurements.target_point], dtype=torch.float32).to(DEVICE)
return {
'rgb': image_tensor,
'rgb_left': image_tensor.clone(), 'rgb_right': image_tensor.clone(), 'rgb_center': image_tensor.clone(),
'measurements': measurements_tensor,
'target_point': target_point_tensor,
'lidar': torch.zeros_like(image_tensor)
}
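# For orientation, the batch returned above has the following shapes (derived
# from the code, batch size 1):
#   rgb / rgb_left / rgb_right / rgb_center : [1, 3, 224, 224]  (same frame cloned)
#   lidar                                   : [1, 3, 224, 224]  (all zeros, dummy input)
#   measurements                            : [1, 8]
#   target_point                            : [1, 2]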
# ==============================================================================
# 5. Application lifecycle events (startup/shutdown)
# ==============================================================================
@app.on_event("startup")
async def startup_event():
global MODEL
logging.info("🚗 Server starting up...")
logging.info(f"Using device: {DEVICE}")
MODEL = load_and_prepare_model(DEVICE)
if MODEL:
logging.info("✅ Model loaded successfully. Server is ready!")
else:
logging.error("❌ CRITICAL: Model could not be loaded. The API will not function correctly.")
# ==============================================================================
# 6. Main API endpoints
# ==============================================================================
@app.get("/", response_class=HTMLResponse, include_in_schema=False)
async def root():
    # Serve a simple landing page for users.
return """
<html>
<head><title>Baseer API</title></head>
<body style='font-family: sans-serif; text-align: center; padding-top: 50px;'>
<h1>🚗 Baseer Self-Driving API</h1>
<p>Welcome! The API is running.</p>
<p>Navigate to <a href="/docs">/docs</a> for the interactive API documentation.</p>
</body>
</html>
"""
@app.post("/start_session", summary="Start a new driving session", tags=["Session Management"])
def start_session():
session_id = str(uuid.uuid4())
config = create_model_config()
controller_params = config.get('controller_params', {})
controller_params.update({'frequency': 10.0}) # Set default frequency
SESSIONS[session_id] = {
'tracker': Tracker(grid_conf=config['grid_conf']),
'controller': InterfuserController({'controller_params': controller_params, 'grid_conf': config['grid_conf']}),
'frame_num': 0
}
logging.info(f"New session started: {session_id}")
return {"session_id": session_id}
@app.post("/run_step", response_model=RunStepResponse, summary="Process a single simulation step", tags=["Core"])
@torch.no_grad()
def run_step(request: RunStepRequest):
if MODEL is None:
raise HTTPException(status_code=503, detail="Model is not available.")
session = SESSIONS.get(request.session_id)
if not session:
raise HTTPException(status_code=404, detail="Session ID not found.")
    # --- 1. Perception ---
image = b64_to_cv2(request.image_b64)
model_input = prepare_model_input(image, request.measurements)
traffic, waypoints, junc, light, stop, _ = MODEL(model_input)
    # --- 2. Post-process the model outputs ---
    # Channel 0 of the traffic grid is turned into an objectness probability with
    # a sigmoid; the junction / traffic-light / stop-sign heads are treated as
    # 2-class logits, so softmax index 1 gives the positive-class probability.
traffic_processed = torch.cat([torch.sigmoid(traffic[0][:, 0:1]), traffic[0][:, 1:]], dim=1)
traffic_np = traffic_processed.cpu().numpy().reshape(20, 20, -1)
waypoints_np = waypoints[0].cpu().numpy()
junction_prob = torch.softmax(junc, dim=1)[0, 1].item()
light_prob = torch.softmax(light, dim=1)[0, 1].item()
stop_prob = torch.softmax(stop, dim=1)[0, 1].item()
    # --- 3. Tracking and control ---
ego_pos = np.array(request.measurements.pos_global)
ego_theta = request.measurements.theta
frame_num = session['frame_num']
active_tracks = session['tracker'].process_frame(traffic_np, ego_pos, ego_theta, frame_num)
steer, throttle, brake, ctrl_info = session['controller'].run_step(
speed=request.measurements.speed, waypoints=torch.from_numpy(waypoints_np),
junction=junction_prob, traffic_light=light_prob, stop_sign=stop_prob,
bev_map=traffic_np, ego_pos=ego_pos, ego_theta=ego_theta, frame_num=frame_num
)
    # --- 4. Build the dashboard visualization ---
display_iface = DisplayInterface(DisplayConfig(width=1280, height=720))
bev_maps = render_bev(active_tracks, waypoints_np, ego_pos, ego_theta)
display_data = {
'camera_view': image, 'map_t0': bev_maps['t0'], 'map_t1': bev_maps['t1'], 'map_t2': bev_maps['t2'],
'frame_num': frame_num, 'speed': request.measurements.speed * 3.6,
'target_speed': ctrl_info.get('target_speed', 0) * 3.6,
'steer': steer, 'throttle': throttle, 'brake': brake,
'light_prob': light_prob, 'stop_prob': stop_prob,
'object_counts': {'car': len(active_tracks)}
}
dashboard = display_iface.run_interface(display_data)
    # --- 5. Update the session and return the response ---
session['frame_num'] += 1
return RunStepResponse(
control_commands=ControlCommands(steer=steer, throttle=throttle, brake=brake),
scene_analysis=SceneAnalysis(is_junction=junction_prob, traffic_light_state=light_prob, stop_sign=stop_prob),
predicted_waypoints=[tuple(wp) for wp in waypoints_np.tolist()],
dashboard_b64=cv2_to_b64(dashboard),
reason=ctrl_info.get('brake_reason', "Cruising")
)
@app.post("/end_session", summary="End and clean up a session", tags=["Session Management"])
def end_session(session_id: str):
if session_id in SESSIONS:
del SESSIONS[session_id]
logging.info(f"Session ended: {session_id}")
return {"message": f"Session {session_id} ended."}
raise HTTPException(status_code=404, detail="Session not found.")
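# ------------------------------------------------------------------------------
# Client-side usage sketch (not executed by the server). It assumes the API is
# reachable at http://localhost:7860 and that the `requests` package is
# installed; adapt the base URL and the frame source to your setup.
#
#   import base64, requests
#
#   BASE = "http://localhost:7860"
#   sid = requests.post(f"{BASE}/start_session").json()["session_id"]
#
#   with open("front_camera.jpg", "rb") as f:
#       img_b64 = base64.b64encode(f.read()).decode("utf-8")
#
#   step = requests.post(f"{BASE}/run_step", json={
#       "session_id": sid,
#       "image_b64": img_b64,
#       "measurements": {"pos_global": [0.0, 0.0], "theta": 0.0,
#                        "speed": 0.0, "target_point": [10.0, 0.0]},
#   }).json()
#   print(step["control_commands"], step["reason"])
#
#   requests.post(f"{BASE}/end_session", params={"session_id": sid})
# ------------------------------------------------------------------------------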
# ================== Run the server ==================
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)