midas / app.py
adpro's picture
Update app.py
4a9c7f0 verified
raw
history blame
2.61 kB
import io
import time
import torch
import numpy as np
import cv2
from fastapi import FastAPI, File, UploadFile
from PIL import Image
import uvicorn
from torchvision import transforms
# 🟢 Tạo FastAPI
app = FastAPI()
# 🟢 Kiểm tra GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 🟢 Tải model MiDaS
midas = torch.hub.load("isl-org/MiDaS", "DPT_Swin2_L_384")
midas.to(device)
midas.eval()
# 🟢 Chuẩn bị bộ tiền xử lý ảnh
transform = transforms.Compose([
transforms.Resize((384, 384)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
@app.post("/analyze_path/")
async def analyze_path(file: UploadFile = File(...)):
# 🟢 Đọc file ảnh từ ESP32
image_bytes = await file.read()
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
image_np = np.array(image)
flipped_image = cv2.flip(image_np, -1)
# 🔵 Resize và chuẩn hóa ảnh
input_tensor = transform(flipped_image).unsqueeze(0).to(device)
# 🟢 Dự đoán Depth Map với MiDaS
start_time = time.time()
with torch.no_grad():
depth_map = midas(input_tensor)
end_time = time.time()
print(f"⏳ MiDaS xử lý trong {end_time - start_time:.4f} giây")
# 🟢 Chuẩn hóa ảnh Depth Map
depth_map = depth_map.squeeze().cpu().numpy()
depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min()) * 255
depth_map = depth_map.astype("uint8")
# 🟢 Xử lý phát hiện đường đi
start_detect_time = time.time()
command = detect_path(depth_map)
end_detect_time = time.time()
print(f"⏳ detect_path() xử lý trong {end_detect_time - start_detect_time:.4f} giây")
return {"command": command}
def detect_path(depth_map):
"""Phân tích đường đi từ ảnh Depth Map"""
h, w = depth_map.shape
center_x = w // 2
scan_y = int(h * 0.8) # Quét dòng 80% từ trên xuống
left_region = np.mean(depth_map[scan_y, :center_x])
right_region = np.mean(depth_map[scan_y, center_x:])
center_region = np.mean(depth_map[scan_y, center_x - 40:center_x + 40])
# 🟢 Cải thiện logic xử lý
threshold = 100 # Ngưỡng phân biệt vật cản
if center_region > threshold:
return "forward"
elif left_region > right_region:
return "left"
elif right_region > left_region:
return "right"
else:
return "backward"
# 🟢 Chạy server FastAPI
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860)