Baseer_Server / app.py
altawil
Update app.py
8176c5b verified
raw
history blame
21.2 kB
# app.py - InterFuser Self-Driving API Server
import uuid
import base64
import cv2
import torch
import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from torchvision import transforms
from typing import List, Dict, Any, Optional
import logging
# استيراد من ملفاتنا المحلية
from model_definition import InterfuserModel, load_and_prepare_model, create_model_config
from simulation_modules import (
InterfuserController, ControllerConfig, Tracker, DisplayInterface,
render, render_waypoints, render_self_car, WAYPOINT_SCALE_FACTOR,
T1_FUTURE_TIME, T2_FUTURE_TIME
)
# إعداد التسجيل
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ================== إعدادات عامة وتحميل النموذج ==================
app = FastAPI(
title="Baseer Self-Driving API",
description="API للقيادة الذاتية باستخدام نموذج InterFuser",
version="1.0.0"
)
device = torch.device("cpu")
logger.info(f"Using device: {device}")
# تحميل النموذج باستخدام الدالة المحسنة
try:
# إنشاء إعدادات النموذج باستخدام الإعدادات الصحيحة من التدريب
model_config = create_model_config(
model_path="model/best_model.pth"
# الإعدادات الصحيحة من التدريب ستطبق تلقائياً:
# embed_dim=256, rgb_backbone_name='r50', waypoints_pred_head='gru'
# with_lidar=False, with_right_left_sensors=False, with_center_sensor=False
)
# تحميل النموذج مع الأوزان
model = load_and_prepare_model(model_config, device)
logger.info("✅ تم تحميل النموذج بنجاح")
except Exception as e:
logger.error(f"❌ خطأ في تحميل النموذج: {e}")
logger.info("🔄 محاولة إنشاء نموذج بأوزان عشوائية...")
try:
model = InterfuserModel()
model.to(device)
model.eval()
logger.warning("⚠️ تم إنشاء النموذج بأوزان عشوائية")
except Exception as e2:
logger.error(f"❌ فشل في إنشاء النموذج: {e2}")
model = None
# تهيئة واجهة العرض
display = DisplayInterface()
# قاموس لتخزين جلسات المستخدمين
SESSIONS: Dict[str, Dict] = {}
# ================== هياكل بيانات Pydantic ==================
class Measurements(BaseModel):
pos: List[float] = [0.0, 0.0] # [x, y] position
theta: float = 0.0 # orientation angle
speed: float = 0.0 # current speed
steer: float = 0.0 # current steering
throttle: float = 0.0 # current throttle
brake: bool = False # brake status
command: int = 4 # driving command (4 = FollowLane)
target_point: List[float] = [0.0, 0.0] # target point [x, y]
class ModelOutputs(BaseModel):
traffic: List[List[List[float]]] # 20x20x7 grid
waypoints: List[List[float]] # Nx2 waypoints
is_junction: float
traffic_light_state: float
stop_sign: float
class ControlCommands(BaseModel):
steer: float
throttle: float
brake: bool
class RunStepInput(BaseModel):
session_id: str
image_b64: str
measurements: Measurements
class RunStepOutput(BaseModel):
model_outputs: ModelOutputs
control_commands: ControlCommands
dashboard_image_b64: str
class SessionResponse(BaseModel):
session_id: str
message: str
# ================== دوال المساعدة ==================
def get_image_transform():
"""إنشاء تحويلات الصورة كما في PDMDataset"""
return transforms.Compose([
transforms.ToTensor(),
transforms.Resize((224, 224), antialias=True),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# إنشاء كائن التحويل مرة واحدة
image_transform = get_image_transform()
def preprocess_input(frame_rgb: np.ndarray, measurements: Measurements, device: torch.device) -> Dict[str, torch.Tensor]:
"""
تحاكي ما يفعله PDMDataset.__getitem__ لإنشاء دفعة (batch) واحدة.
"""
# 1. معالجة الصورة الرئيسية
from PIL import Image
if isinstance(frame_rgb, np.ndarray):
frame_rgb = Image.fromarray(frame_rgb)
image_tensor = image_transform(frame_rgb).unsqueeze(0).to(device) # إضافة بُعد الدفعة
# 2. إنشاء مدخلات الكاميرات الأخرى عن طريق الاستنساخ
batch = {
'rgb': image_tensor,
'rgb_left': image_tensor.clone(),
'rgb_right': image_tensor.clone(),
'rgb_center': image_tensor.clone(),
}
# 3. إنشاء مدخل ليدار وهمي (أصفار)
batch['lidar'] = torch.zeros(1, 3, 224, 224, dtype=torch.float32).to(device)
# 4. تجميع القياسات بنفس ترتيب PDMDataset
m = measurements
measurements_tensor = torch.tensor([[
m.pos[0], m.pos[1], m.theta,
m.steer, m.throttle, float(m.brake),
m.speed, float(m.command)
]], dtype=torch.float32).to(device)
batch['measurements'] = measurements_tensor
# 5. إنشاء نقطة هدف
batch['target_point'] = torch.tensor([m.target_point], dtype=torch.float32).to(device)
# لا نحتاج إلى قيم ground truth (gt_*) أثناء التنبؤ
return batch
def decode_base64_image(image_b64: str) -> np.ndarray:
"""
فك تشفير صورة Base64
"""
try:
image_bytes = base64.b64decode(image_b64)
nparr = np.frombuffer(image_bytes, np.uint8)
image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
return image
except Exception as e:
raise HTTPException(status_code=400, detail=f"Invalid image format: {str(e)}")
def encode_image_to_base64(image: np.ndarray) -> str:
"""
تشفير صورة إلى Base64
"""
_, buffer = cv2.imencode('.jpg', image, [cv2.IMWRITE_JPEG_QUALITY, 85])
return base64.b64encode(buffer).decode('utf-8')
# ================== نقاط نهاية الـ API ==================
@app.get("/", response_class=HTMLResponse)
async def root():
"""
الصفحة الرئيسية للـ API
"""
html_content = f"""
<!DOCTYPE html>
<html dir="rtl" lang="ar">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>🚗 Baseer Self-Driving API</title>
<style>
* {{
margin: 0;
padding: 0;
box-sizing: border-box;
}}
body {{
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
display: flex;
align-items: center;
justify-content: center;
padding: 20px;
}}
.container {{
background: rgba(255, 255, 255, 0.95);
backdrop-filter: blur(10px);
border-radius: 20px;
padding: 40px;
box-shadow: 0 20px 40px rgba(0, 0, 0, 0.1);
text-align: center;
max-width: 600px;
width: 100%;
}}
.logo {{
font-size: 4rem;
margin-bottom: 20px;
animation: bounce 2s infinite;
}}
@keyframes bounce {{
0%, 20%, 50%, 80%, 100% {{ transform: translateY(0); }}
40% {{ transform: translateY(-10px); }}
60% {{ transform: translateY(-5px); }}
}}
h1 {{
color: #333;
margin-bottom: 10px;
font-size: 2.5rem;
}}
.subtitle {{
color: #666;
margin-bottom: 30px;
font-size: 1.2rem;
}}
.status {{
display: inline-block;
background: #4CAF50;
color: white;
padding: 8px 16px;
border-radius: 20px;
margin: 10px 0;
font-weight: bold;
}}
.stats {{
display: grid;
grid-template-columns: repeat(auto-fit, minmax(150px, 1fr));
gap: 20px;
margin: 30px 0;
}}
.stat-card {{
background: #f8f9fa;
padding: 20px;
border-radius: 15px;
border-left: 4px solid #667eea;
}}
.stat-number {{
font-size: 2rem;
font-weight: bold;
color: #667eea;
}}
.stat-label {{
color: #666;
margin-top: 5px;
}}
.buttons {{
display: flex;
gap: 15px;
justify-content: center;
flex-wrap: wrap;
margin-top: 30px;
}}
.btn {{
display: inline-block;
padding: 12px 24px;
border-radius: 25px;
text-decoration: none;
font-weight: bold;
transition: all 0.3s ease;
border: none;
cursor: pointer;
}}
.btn-primary {{
background: #667eea;
color: white;
}}
.btn-secondary {{
background: #6c757d;
color: white;
}}
.btn:hover {{
transform: translateY(-2px);
box-shadow: 0 5px 15px rgba(0, 0, 0, 0.2);
}}
.features {{
text-align: right;
margin-top: 30px;
padding: 20px;
background: #f8f9fa;
border-radius: 15px;
}}
.features h3 {{
color: #333;
margin-bottom: 15px;
}}
.features ul {{
list-style: none;
padding: 0;
}}
.features li {{
padding: 5px 0;
color: #666;
}}
.features li:before {{
content: "✅ ";
margin-left: 10px;
}}
</style>
</head>
<body>
<div class="container">
<div class="logo">🚗</div>
<h1>Baseer Self-Driving API</h1>
<p class="subtitle">نظام القيادة الذاتية المتقدم</p>
<div class="status">🟢 يعمل بنجاح</div>
<div class="stats">
<div class="stat-card">
<div class="stat-number">{len(SESSIONS)}</div>
<div class="stat-label">الجلسات النشطة</div>
</div>
<div class="stat-card">
<div class="stat-number">v1.0</div>
<div class="stat-label">الإصدار</div>
</div>
<div class="stat-card">
<div class="stat-number">FastAPI</div>
<div class="stat-label">التقنية</div>
</div>
</div>
<div class="buttons">
<a href="/docs" class="btn btn-primary">📚 توثيق API</a>
<a href="/sessions" class="btn btn-secondary">📊 الجلسات</a>
</div>
<div class="features">
<h3>🌟 الميزات الرئيسية</h3>
<ul>
<li>نموذج InterFuser للقيادة الذاتية</li>
<li>معالجة الصور في الوقت الفعلي</li>
<li>اكتشاف الكائنات المرورية</li>
<li>تحديد المسارات الذكية</li>
<li>واجهة RESTful سهلة الاستخدام</li>
<li>إدارة جلسات متعددة</li>
</ul>
</div>
</div>
</body>
</html>
"""
return html_content
@app.post("/start_session", response_model=SessionResponse)
async def start_session():
"""
بدء جلسة جديدة للمحاكاة
"""
session_id = str(uuid.uuid4())
# إنشاء جلسة جديدة
SESSIONS[session_id] = {
'tracker': Tracker(frequency=10),
'controller': InterfuserController(ControllerConfig()),
'frame_num': 0,
'created_at': np.datetime64('now'),
'last_activity': np.datetime64('now')
}
logger.info(f"New session created: {session_id}")
return SessionResponse(
session_id=session_id,
message="Session started successfully"
)
@app.post("/run_step", response_model=RunStepOutput)
async def run_step(data: RunStepInput):
"""
تنفيذ خطوة محاكاة كاملة
"""
# التحقق من وجود الجلسة
if data.session_id not in SESSIONS:
raise HTTPException(status_code=404, detail="Session not found")
session = SESSIONS[data.session_id]
tracker = session['tracker']
controller = session['controller']
# تحديث وقت النشاط
session['last_activity'] = np.datetime64('now')
try:
# 1. فك تشفير الصورة
frame_bgr = decode_base64_image(data.image_b64)
frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
# 2. معالجة المدخلات
inputs = preprocess_input(frame_rgb, data.measurements, device)
# 3. تشغيل النموذج
if model is None:
raise HTTPException(status_code=500, detail="Model not loaded")
with torch.no_grad():
traffic, waypoints, is_junction, traffic_light, stop_sign, _ = model(inputs)
# 4. معالجة مخرجات النموذج
traffic_np = traffic.cpu().numpy()[0] # أخذ أول عنصر من الـ batch
waypoints_np = waypoints.cpu().numpy()[0]
is_junction_prob = torch.sigmoid(is_junction)[0, 1].item()
traffic_light_prob = torch.sigmoid(traffic_light)[0, 0].item()
stop_sign_prob = torch.sigmoid(stop_sign)[0, 1].item()
# 5. تحديث التتبع
# تحويل traffic grid إلى detections للتتبع
detections = []
h, w, c = traffic_np.shape
for y in range(h):
for x in range(w):
for ch in range(c):
if traffic_np[y, x, ch] > 0.2: # عتبة الكشف
world_x = (x / w - 0.5) * 64 # تحويل إلى إحداثيات العالم
world_y = (y / h - 0.5) * 64
detections.append({
'position': [world_x, world_y],
'feature': traffic_np[y, x, ch]
})
updated_traffic = tracker.update_and_predict(detections, session['frame_num'])
# 6. تشغيل المتحكم
steer, throttle, brake, metadata = controller.run_step(
current_speed=data.measurements.speed,
waypoints=waypoints_np,
junction=is_junction_prob,
traffic_light_state=traffic_light_prob,
stop_sign=stop_sign_prob,
meta_data={'frame': session['frame_num']}
)
# 7. إنشاء خرائط العرض
surround_t0, counts_t0 = render(updated_traffic, t=0)
surround_t1, counts_t1 = render(updated_traffic, t=T1_FUTURE_TIME)
surround_t2, counts_t2 = render(updated_traffic, t=T2_FUTURE_TIME)
# إضافة المسار المقترح
wp_map = render_waypoints(waypoints_np)
map_t0 = cv2.add(surround_t0, wp_map)
# إضافة السيارة الذاتية
map_t0 = render_self_car(map_t0)
map_t1 = render_self_car(surround_t1)
map_t2 = render_self_car(surround_t2)
# 8. إنشاء لوحة العرض النهائية
interface_data = {
'camera_view': frame_bgr,
'map_t0': map_t0,
'map_t1': map_t1,
'map_t2': map_t2,
'text_info': {
'Frame': f"Frame: {session['frame_num']}",
'Control': f"Steer: {steer:.2f}, Throttle: {throttle:.2f}, Brake: {brake}",
'Speed': f"Speed: {data.measurements.speed:.1f} km/h",
'Junction': f"Junction: {is_junction_prob:.2f}",
'Traffic Light': f"Red Light: {traffic_light_prob:.2f}",
'Stop Sign': f"Stop Sign: {stop_sign_prob:.2f}",
'Metadata': metadata
},
'object_counts': {
't0': counts_t0,
't1': counts_t1,
't2': counts_t2
}
}
dashboard_image = display.run_interface(interface_data)
dashboard_b64 = encode_image_to_base64(dashboard_image)
# 9. تجميع المخرجات النهائية
response = RunStepOutput(
model_outputs=ModelOutputs(
traffic=traffic_np.tolist(),
waypoints=waypoints_np.tolist(),
is_junction=is_junction_prob,
traffic_light_state=traffic_light_prob,
stop_sign=stop_sign_prob
),
control_commands=ControlCommands(
steer=float(steer),
throttle=float(throttle),
brake=bool(brake)
),
dashboard_image_b64=dashboard_b64
)
# تحديث رقم الإطار
session['frame_num'] += 1
logger.info(f"Step completed for session {data.session_id}, frame {session['frame_num']}")
return response
except Exception as e:
logger.error(f"Error in run_step: {str(e)}")
raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
@app.post("/end_session", response_model=SessionResponse)
async def end_session(session_id: str):
"""
إنهاء جلسة المحاكاة
"""
if session_id not in SESSIONS:
raise HTTPException(status_code=404, detail="Session not found")
# حذف الجلسة
del SESSIONS[session_id]
logger.info(f"Session ended: {session_id}")
return SessionResponse(
session_id=session_id,
message="Session ended successfully"
)
@app.get("/sessions")
async def list_sessions():
"""
عرض قائمة الجلسات النشطة
"""
active_sessions = []
current_time = np.datetime64('now')
for session_id, session_data in SESSIONS.items():
time_diff = current_time - session_data['last_activity']
active_sessions.append({
'session_id': session_id,
'frame_count': session_data['frame_num'],
'created_at': str(session_data['created_at']),
'last_activity': str(session_data['last_activity']),
'inactive_minutes': float(time_diff / np.timedelta64(1, 'm'))
})
return {
'total_sessions': len(active_sessions),
'sessions': active_sessions
}
@app.delete("/sessions/cleanup")
async def cleanup_inactive_sessions(max_inactive_minutes: int = 30):
"""
تنظيف الجلسات غير النشطة
"""
current_time = np.datetime64('now')
cleaned_sessions = []
for session_id in list(SESSIONS.keys()):
session = SESSIONS[session_id]
time_diff = current_time - session['last_activity']
inactive_minutes = float(time_diff / np.timedelta64(1, 'm'))
if inactive_minutes > max_inactive_minutes:
del SESSIONS[session_id]
cleaned_sessions.append(session_id)
logger.info(f"Cleaned up {len(cleaned_sessions)} inactive sessions")
return {
'message': f"Cleaned up {len(cleaned_sessions)} inactive sessions",
'cleaned_sessions': cleaned_sessions,
'remaining_sessions': len(SESSIONS)
}
# ================== معالج الأخطاء ==================
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
logger.error(f"Global exception: {str(exc)}")
return {
"error": "Internal server error",
"detail": str(exc)
}
# ================== تشغيل الخادم ==================
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)