File size: 3,546 Bytes
b3652fb
81def50
b461974
cd88622
02cc722
81def50
 
 
 
02cc722
bef39ed
81def50
02cc722
 
b461974
2d5e944
 
 
b461974
 
 
cd88622
9bb59fd
 
cd88622
9bb59fd
cd88622
 
9bb59fd
b461974
9bb59fd
2d5e944
cd88622
81def50
903cef3
cd88622
 
 
 
 
 
81def50
3f318c7
b3652fb
 
81def50
b3652fb
81def50
903cef3
 
 
 
 
cd88622
903cef3
 
 
2d5e944
81def50
2d5e944
 
 
 
903cef3
81def50
cd88622
 
 
81def50
2d5e944
 
81def50
cd88622
 
 
 
 
903cef3
cd88622
 
b3652fb
 
 
 
81def50
 
 
 
b3652fb
81def50
 
 
 
 
b3652fb
81def50
b3652fb
81def50
f6e7520
81def50
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import io
import os
import sys
import time
import numpy as np
import cv2
import torch
import torchvision
from fastapi import FastAPI, File, UploadFile
from PIL import Image
import uvicorn

app = FastAPI()

# 🟢 Clone FastDepth từ GitHub nếu chưa có
if not os.path.exists("fastdepth"):
    os.system("git clone https://github.com/dwofk/fast-depth.git fastdepth")

# 🟢 Thêm `fastdepth` vào `sys.path`
sys.path.append(os.path.abspath("fastdepth"))

# 🟢 Tải đúng file trọng số nếu chưa có
weights_path = "fastdepth/models/fastdepth_nyu.pt"
if not os.path.exists(weights_path):
    print("🔻 Trọng số chưa có, đang tải từ GitHub...")
    os.system(f"wget -O {weights_path} https://github.com/dwofk/fast-depth/raw/master/models/fastdepth_nyu.pt")
else:
    print("✅ Trọng số đã có sẵn.")

# 🟢 Import FastDepth
from fastdepth.models import MobileNetSkipAdd  

# 🟢 Load mô hình FastDepth đúng trọng số
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MobileNetSkipAdd(output_size=(224, 224))
try:
    model.load_state_dict(torch.load(weights_path, map_location=device))
    print("✅ Mô hình FastDepth đã được load thành công!")
except FileNotFoundError:
    print("❌ Không tìm thấy file trọng số! Kiểm tra lại đường dẫn.")

model.eval().to(device)

@app.post("/analyze_path/")
async def analyze_path(file: UploadFile = File(...)):
    # 🟢 Đọc file ảnh từ ESP32
    image_bytes = await file.read()
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    
    # 🟢 Chuyển ảnh sang NumPy để lật đúng chiều
    image_np = np.array(image)
    
    # 🟢 Lật ảnh trước khi tính toán Depth Map
    flipped_image = cv2.flip(image_np, -1)
    
    # 🟢 Chuyển đổi lại thành ảnh PIL để đưa vào FastDepth
    flipped_image_pil = Image.fromarray(flipped_image)

    # 🟢 Chuyển đổi ảnh thành tensor (chuẩn hóa cho FastDepth)
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
    ])
    img_tensor = transform(flipped_image_pil).unsqueeze(0).to(device)

    # 🟢 Bắt đầu đo thời gian dự đoán Depth Map
    start_time = time.time()
    
    # 🟢 Dự đoán Depth Map với FastDepth
    with torch.no_grad():
        depth_map = model(img_tensor).squeeze().cpu().numpy()

    end_time = time.time()
    print(f"⏳ FastDepth xử lý trong {end_time - start_time:.4f} giây")

    # 🟢 Đo thời gian xử lý đường đi
    start_detect_time = time.time()
    command = detect_path(depth_map)
    end_detect_time = time.time()
    print(f"⏳ detect_path() xử lý trong {end_detect_time - start_detect_time:.4f} giây")

    return {"command": command}

def detect_path(depth_map):
    """Phân tích đường đi từ ảnh Depth Map"""
    h, w = depth_map.shape
    center_x = w // 2
    scan_y = h - 20  # Quét dòng gần đáy ảnh

    left_region = np.mean(depth_map[scan_y, :center_x])
    right_region = np.mean(depth_map[scan_y, center_x:])
    center_region = np.mean(depth_map[scan_y, center_x - 20:center_x + 20])

    if center_region > 200:
        return "forward"
    elif left_region > right_region:
        return "left"
    elif right_region > left_region:
        return "right"
    else:
        return "backward"

# 🟢 Chạy server FastAPI
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)