Update app.py
Browse files
app.py
CHANGED
@@ -1,60 +1,72 @@
|
|
1 |
-
from fastapi import FastAPI, File, UploadFile
|
2 |
import io
|
|
|
3 |
import numpy as np
|
|
|
|
|
|
|
|
|
4 |
from PIL import Image
|
5 |
import uvicorn
|
6 |
-
|
7 |
-
from fastdepth import FastDepth
|
8 |
-
model = FastDepth(pretrained=True)
|
9 |
-
model.eval()
|
10 |
app = FastAPI()
|
11 |
|
12 |
-
|
13 |
if not os.path.exists("fastdepth"):
|
14 |
os.system("git clone https://github.com/dwofk/fast-depth.git fastdepth")
|
15 |
|
16 |
-
|
17 |
-
|
18 |
-
app = FastAPI()
|
19 |
|
20 |
# 🟢 Load mô hình FastDepth
|
21 |
-
|
22 |
-
model
|
|
|
|
|
23 |
|
24 |
-
def analyzepath(image):
|
25 |
-
depth_map = model(image).squeeze().cpu().numpy()
|
26 |
-
return detect_path(depth_map) # Xử lý đường đi nhanh hơn
|
27 |
@app.post("/analyze_path/")
|
28 |
async def analyze_path(file: UploadFile = File(...)):
|
|
|
29 |
image_bytes = await file.read()
|
30 |
-
image = Image.open(io.BytesIO(image_bytes)).convert("
|
31 |
-
depth_map = np.array(image)
|
32 |
|
33 |
-
# 🟢
|
34 |
-
flipped_depth_map = cv2.flip(depth_map, -1)
|
35 |
transform = torchvision.transforms.Compose([
|
36 |
torchvision.transforms.Resize((224, 224)),
|
37 |
torchvision.transforms.ToTensor(),
|
38 |
])
|
39 |
img_tensor = transform(image).unsqueeze(0).to(device)
|
|
|
|
|
40 |
with torch.no_grad():
|
41 |
depth_map = model(img_tensor).squeeze().cpu().numpy()
|
|
|
|
|
|
|
|
|
42 |
# 🟢 Phân tích đường đi
|
43 |
command = detect_path(flipped_depth_map)
|
44 |
|
45 |
return {"command": command}
|
46 |
|
47 |
def detect_path(depth_map):
|
48 |
-
|
49 |
-
|
|
|
|
|
50 |
|
51 |
-
|
|
|
|
|
|
|
|
|
52 |
return "forward"
|
53 |
-
|
54 |
-
left_region = np.mean(depth_map[:, :depth_map.shape[1]//3])
|
55 |
-
right_region = np.mean(depth_map[:, 2*depth_map.shape[1]//3:])
|
56 |
-
|
57 |
-
if left_region > right_region:
|
58 |
return "left"
|
59 |
-
|
60 |
return "right"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import io
|
2 |
+
import os
|
3 |
import numpy as np
|
4 |
+
import cv2
|
5 |
+
import torch
|
6 |
+
import torchvision
|
7 |
+
from fastapi import FastAPI, File, UploadFile
|
8 |
from PIL import Image
|
9 |
import uvicorn
|
10 |
+
|
|
|
|
|
|
|
11 |
app = FastAPI()
|
12 |
|
13 |
+
# 🟢 Clone FastDepth từ GitHub (chỉ cần làm 1 lần)
|
14 |
if not os.path.exists("fastdepth"):
|
15 |
os.system("git clone https://github.com/dwofk/fast-depth.git fastdepth")
|
16 |
|
17 |
+
# 🟢 Import FastDepth sau khi clone
|
18 |
+
from fastdepth.models import MobileNetSkipAdd # Model chính của FastDepth
|
|
|
19 |
|
20 |
# 🟢 Load mô hình FastDepth
|
21 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
22 |
+
model = MobileNetSkipAdd()
|
23 |
+
model.load_state_dict(torch.load("fastdepth/models/fastdepth_nyu.pt", map_location=device))
|
24 |
+
model.eval().to(device)
|
25 |
|
|
|
|
|
|
|
26 |
@app.post("/analyze_path/")
|
27 |
async def analyze_path(file: UploadFile = File(...)):
|
28 |
+
# 🟢 Đọc file ảnh từ ESP32
|
29 |
image_bytes = await file.read()
|
30 |
+
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
|
|
31 |
|
32 |
+
# 🟢 Chuyển đổi ảnh thành tensor (chuẩn hóa cho FastDepth)
|
|
|
33 |
transform = torchvision.transforms.Compose([
|
34 |
torchvision.transforms.Resize((224, 224)),
|
35 |
torchvision.transforms.ToTensor(),
|
36 |
])
|
37 |
img_tensor = transform(image).unsqueeze(0).to(device)
|
38 |
+
|
39 |
+
# 🟢 Dự đoán Depth Map với FastDepth
|
40 |
with torch.no_grad():
|
41 |
depth_map = model(img_tensor).squeeze().cpu().numpy()
|
42 |
+
|
43 |
+
# 🟢 Lật ngược ảnh (nếu cần)
|
44 |
+
flipped_depth_map = cv2.flip(depth_map, -1)
|
45 |
+
|
46 |
# 🟢 Phân tích đường đi
|
47 |
command = detect_path(flipped_depth_map)
|
48 |
|
49 |
return {"command": command}
|
50 |
|
51 |
def detect_path(depth_map):
|
52 |
+
"""Phân tích đường đi từ ảnh Depth Map"""
|
53 |
+
h, w = depth_map.shape
|
54 |
+
center_x = w // 2
|
55 |
+
scan_y = h - 20 # Quét dòng gần đáy ảnh
|
56 |
|
57 |
+
left_region = np.mean(depth_map[scan_y, :center_x])
|
58 |
+
right_region = np.mean(depth_map[scan_y, center_x:])
|
59 |
+
center_region = np.mean(depth_map[scan_y, center_x - 20:center_x + 20])
|
60 |
+
|
61 |
+
if center_region > 200:
|
62 |
return "forward"
|
63 |
+
elif left_region > right_region:
|
|
|
|
|
|
|
|
|
64 |
return "left"
|
65 |
+
elif right_region > left_region:
|
66 |
return "right"
|
67 |
+
else:
|
68 |
+
return "backward"
|
69 |
+
|
70 |
+
# 🟢 Chạy server FastAPI
|
71 |
+
if __name__ == "__main__":
|
72 |
+
uvicorn.run(app, host="0.0.0.0", port=7860)
|