midas / app.py
adpro's picture
Update app.py
8ff8f64 verified
raw
history blame
2.8 kB
import io
import os
import time
import numpy as np
import cv2
import torch
import torchvision
from fastapi import FastAPI, File, UploadFile
from PIL import Image
import uvicorn
app = FastAPI()
# 🟢 Tải mô hình MiDaS từ PyTorch Hub
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.hub.load("intel-isl/MiDaS", "MiDaS_small").to(device) # 🟢 Dùng phiên bản nhẹ MiDaS_small
model.eval()
# 🟢 Load transform cho MiDaS
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
transform = midas_transforms.small_transform # 🟢 Dùng transform phù hợp với MiDaS_small
@app.post("/analyze_path/")
async def analyze_path(file: UploadFile = File(...)):
# 🟢 Đọc file ảnh từ ESP32
image_bytes = await file.read()
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
# 🟢 Chuyển ảnh sang NumPy để lật đúng chiều
image_np = np.array(image)
# 🟢 Lật ảnh trước khi tính toán Depth Map
flipped_image = cv2.flip(image_np, -1)
# 🟢 Chuyển đổi lại thành ảnh PIL để đưa vào MiDaS
flipped_image_pil = Image.fromarray(flipped_image)
# 🟢 Chuyển đổi ảnh thành tensor phù hợp với MiDaS
img_tensor = transform(flipped_image_pil).to(device)
# 🟢 Bắt đầu đo thời gian dự đoán Depth Map
start_time = time.time()
# 🟢 Dự đoán Depth Map với MiDaS
with torch.no_grad():
depth_map = model(img_tensor)
depth_map = torch.nn.functional.interpolate(
depth_map.unsqueeze(1), size=flipped_image_pil.size[::-1], mode="bicubic", align_corners=False
).squeeze().cpu().numpy()
end_time = time.time()
print(f"⏳ MiDaS xử lý trong {end_time - start_time:.4f} giây")
# 🟢 Đo thời gian xử lý đường đi
start_detect_time = time.time()
command = detect_path(depth_map)
end_detect_time = time.time()
print(f"⏳ detect_path() xử lý trong {end_detect_time - start_detect_time:.4f} giây")
return {"command": command}
def detect_path(depth_map):
"""Phân tích đường đi từ ảnh Depth Map"""
h, w = depth_map.shape
center_x = w // 2
scan_y = h - 20 # Quét dòng gần đáy ảnh
left_region = np.mean(depth_map[scan_y, :center_x])
right_region = np.mean(depth_map[scan_y, center_x:])
center_region = np.mean(depth_map[scan_y, center_x - 20:center_x + 20])
if center_region > 200:
return "forward"
elif left_region > right_region:
return "left"
elif right_region > left_region:
return "right"
else:
return "backward"
# 🟢 Chạy server FastAPI
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860)