Update app.py
Browse files
app.py
CHANGED
@@ -1,59 +1,48 @@
|
|
1 |
import io
|
2 |
-
import os # 🟢 Thêm dòng này để tránh lỗi NameError
|
3 |
-
import sys
|
4 |
import time
|
5 |
import numpy as np
|
6 |
import cv2
|
7 |
import torch
|
8 |
-
import
|
9 |
from fastapi import FastAPI, File, UploadFile
|
10 |
from PIL import Image
|
11 |
import uvicorn
|
12 |
|
13 |
app = FastAPI()
|
14 |
|
15 |
-
# 🟢
|
16 |
-
fastdepth_path = "FastDepth"
|
17 |
-
if not os.path.exists(fastdepth_path):
|
18 |
-
os.system("git clone https://github.com/dwofk/fast-depth.git FastDepth")
|
19 |
-
|
20 |
-
# 🟢 Thêm FastDepth vào sys.path để import được
|
21 |
-
sys.path.append(fastdepth_path)
|
22 |
-
|
23 |
-
# 🟢 Import FastDepth sau khi đã tải về
|
24 |
-
from FastDepth.models import MobileNetSkipAdd
|
25 |
-
|
26 |
-
# 🟢 Load mô hình FastDepth
|
27 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
28 |
-
|
29 |
-
model.
|
30 |
-
model.eval()
|
31 |
|
32 |
@app.post("/analyze_path/")
|
33 |
async def analyze_path(file: UploadFile = File(...)):
|
34 |
# 🟢 Đọc file ảnh từ ESP32
|
35 |
image_bytes = await file.read()
|
36 |
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
37 |
-
|
38 |
-
# 🟢
|
39 |
-
|
40 |
-
torchvision.transforms.Resize((224, 224)), # FastDepth yêu cầu ảnh 224x224
|
41 |
-
torchvision.transforms.ToTensor(),
|
42 |
-
])
|
43 |
-
img_tensor = transform(image).unsqueeze(0).to(device)
|
44 |
|
45 |
# 🟢 Bắt đầu đo thời gian dự đoán Depth Map
|
46 |
start_time = time.time()
|
47 |
-
|
48 |
-
# 🟢 Dự đoán Depth Map với
|
49 |
with torch.no_grad():
|
50 |
-
|
51 |
|
52 |
-
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
-
|
56 |
-
print(f"
|
57 |
|
58 |
# 🟢 Đo thời gian xử lý đường đi
|
59 |
start_detect_time = time.time()
|
@@ -65,9 +54,6 @@ async def analyze_path(file: UploadFile = File(...)):
|
|
65 |
|
66 |
def detect_path(depth_map):
|
67 |
"""Phân tích đường đi từ ảnh Depth Map"""
|
68 |
-
if len(depth_map.shape) != 2: # 🟢 Kiểm tra nếu depth_map không phải 2D
|
69 |
-
raise ValueError("Depth map không phải ảnh 2D hợp lệ!")
|
70 |
-
|
71 |
h, w = depth_map.shape
|
72 |
center_x = w // 2
|
73 |
scan_y = h - 20 # Quét dòng gần đáy ảnh
|
|
|
1 |
import io
|
|
|
|
|
2 |
import time
|
3 |
import numpy as np
|
4 |
import cv2
|
5 |
import torch
|
6 |
+
from transformers import AutoImageProcessor, ZoeDepthForDepthEstimation
|
7 |
from fastapi import FastAPI, File, UploadFile
|
8 |
from PIL import Image
|
9 |
import uvicorn
|
10 |
|
11 |
app = FastAPI()
|
12 |
|
13 |
+
# 🟢 Tải mô hình ZoeDepth từ Hugging Face
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
15 |
+
image_processor = AutoImageProcessor.from_pretrained("Intel/zoedepth-nyu-kitti")
|
16 |
+
model = ZoeDepthForDepthEstimation.from_pretrained("Intel/zoedepth-nyu-kitti").to(device)
|
17 |
+
model.eval()
|
18 |
|
19 |
@app.post("/analyze_path/")
|
20 |
async def analyze_path(file: UploadFile = File(...)):
|
21 |
# 🟢 Đọc file ảnh từ ESP32
|
22 |
image_bytes = await file.read()
|
23 |
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
24 |
+
|
25 |
+
# 🟢 Chuẩn bị ảnh cho mô hình ZoeDepth
|
26 |
+
inputs = image_processor(images=image, return_tensors="pt").to(device)
|
|
|
|
|
|
|
|
|
27 |
|
28 |
# 🟢 Bắt đầu đo thời gian dự đoán Depth Map
|
29 |
start_time = time.time()
|
30 |
+
|
31 |
+
# 🟢 Dự đoán Depth Map với ZoeDepth
|
32 |
with torch.no_grad():
|
33 |
+
outputs = model(**inputs)
|
34 |
|
35 |
+
# 🟢 Xử lý ảnh sau khi dự đoán
|
36 |
+
post_processed_output = image_processor.post_process_depth_estimation(
|
37 |
+
outputs,
|
38 |
+
source_sizes=[(image.height, image.width)],
|
39 |
+
)
|
40 |
+
predicted_depth = post_processed_output[0]["predicted_depth"]
|
41 |
+
depth_map = predicted_depth * 255 / predicted_depth.max()
|
42 |
+
depth_map = depth_map.detach().cpu().numpy().astype("uint8")
|
43 |
|
44 |
+
end_time = time.time()
|
45 |
+
print(f"⏳ ZoeDepth xử lý trong {end_time - start_time:.4f} giây")
|
46 |
|
47 |
# 🟢 Đo thời gian xử lý đường đi
|
48 |
start_detect_time = time.time()
|
|
|
54 |
|
55 |
def detect_path(depth_map):
|
56 |
"""Phân tích đường đi từ ảnh Depth Map"""
|
|
|
|
|
|
|
57 |
h, w = depth_map.shape
|
58 |
center_x = w // 2
|
59 |
scan_y = h - 20 # Quét dòng gần đáy ảnh
|