Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1 |
import io
|
2 |
import os
|
3 |
-
import sys
|
4 |
import time
|
5 |
import numpy as np
|
6 |
import cv2
|
@@ -12,34 +11,14 @@ import uvicorn
|
|
12 |
|
13 |
app = FastAPI()
|
14 |
|
15 |
-
# 🟢
|
16 |
-
if not os.path.exists("fastdepth"):
|
17 |
-
os.system("git clone https://github.com/dwofk/fast-depth.git fastdepth")
|
18 |
-
|
19 |
-
# 🟢 Thêm `fastdepth` vào `sys.path`
|
20 |
-
sys.path.append(os.path.abspath("fastdepth"))
|
21 |
-
|
22 |
-
# 🟢 Tải đúng file trọng số nếu chưa có
|
23 |
-
weights_path = "fastdepth/models/fastdepth_nyu.pt"
|
24 |
-
if not os.path.exists(weights_path):
|
25 |
-
print("🔻 Trọng số chưa có, đang tải từ GitHub...")
|
26 |
-
os.system(f"wget -O {weights_path} https://github.com/dwofk/fast-depth/raw/master/models/fastdepth_nyu.pt")
|
27 |
-
else:
|
28 |
-
print("✅ Trọng số đã có sẵn.")
|
29 |
-
|
30 |
-
# 🟢 Import FastDepth
|
31 |
-
from fastdepth.models import MobileNetSkipAdd
|
32 |
-
|
33 |
-
# 🟢 Load mô hình FastDepth đúng trọng số
|
34 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
35 |
-
model =
|
36 |
-
|
37 |
-
model.load_state_dict(torch.load(weights_path, map_location=device))
|
38 |
-
print("✅ Mô hình FastDepth đã được load thành công!")
|
39 |
-
except FileNotFoundError:
|
40 |
-
print("❌ Không tìm thấy file trọng số! Kiểm tra lại đường dẫn.")
|
41 |
|
42 |
-
|
|
|
|
|
43 |
|
44 |
@app.post("/analyze_path/")
|
45 |
async def analyze_path(file: UploadFile = File(...)):
|
@@ -53,25 +32,24 @@ async def analyze_path(file: UploadFile = File(...)):
|
|
53 |
# 🟢 Lật ảnh trước khi tính toán Depth Map
|
54 |
flipped_image = cv2.flip(image_np, -1)
|
55 |
|
56 |
-
# 🟢 Chuyển đổi lại thành ảnh PIL để đưa vào
|
57 |
flipped_image_pil = Image.fromarray(flipped_image)
|
58 |
|
59 |
-
# 🟢 Chuyển đổi ảnh thành tensor
|
60 |
-
|
61 |
-
torchvision.transforms.Resize((224, 224)),
|
62 |
-
torchvision.transforms.ToTensor(),
|
63 |
-
])
|
64 |
-
img_tensor = transform(flipped_image_pil).unsqueeze(0).to(device)
|
65 |
|
66 |
# 🟢 Bắt đầu đo thời gian dự đoán Depth Map
|
67 |
start_time = time.time()
|
68 |
|
69 |
-
# 🟢 Dự đoán Depth Map với
|
70 |
with torch.no_grad():
|
71 |
-
depth_map = model(img_tensor)
|
|
|
|
|
|
|
72 |
|
73 |
end_time = time.time()
|
74 |
-
print(f"⏳
|
75 |
|
76 |
# 🟢 Đo thời gian xử lý đường đi
|
77 |
start_detect_time = time.time()
|
|
|
1 |
import io
|
2 |
import os
|
|
|
3 |
import time
|
4 |
import numpy as np
|
5 |
import cv2
|
|
|
11 |
|
12 |
app = FastAPI()
|
13 |
|
14 |
+
# 🟢 Tải mô hình MiDaS từ PyTorch Hub
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
16 |
+
model = torch.hub.load("intel-isl/MiDaS", "MiDaS_small").to(device) # 🟢 Dùng phiên bản nhẹ MiDaS_small
|
17 |
+
model.eval()
|
|
|
|
|
|
|
|
|
18 |
|
19 |
+
# 🟢 Load transform cho MiDaS
|
20 |
+
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
|
21 |
+
transform = midas_transforms.small_transform # 🟢 Dùng transform phù hợp với MiDaS_small
|
22 |
|
23 |
@app.post("/analyze_path/")
|
24 |
async def analyze_path(file: UploadFile = File(...)):
|
|
|
32 |
# 🟢 Lật ảnh trước khi tính toán Depth Map
|
33 |
flipped_image = cv2.flip(image_np, -1)
|
34 |
|
35 |
+
# 🟢 Chuyển đổi lại thành ảnh PIL để đưa vào MiDaS
|
36 |
flipped_image_pil = Image.fromarray(flipped_image)
|
37 |
|
38 |
+
# 🟢 Chuyển đổi ảnh thành tensor phù hợp với MiDaS
|
39 |
+
img_tensor = transform(flipped_image_pil).to(device)
|
|
|
|
|
|
|
|
|
40 |
|
41 |
# 🟢 Bắt đầu đo thời gian dự đoán Depth Map
|
42 |
start_time = time.time()
|
43 |
|
44 |
+
# 🟢 Dự đoán Depth Map với MiDaS
|
45 |
with torch.no_grad():
|
46 |
+
depth_map = model(img_tensor)
|
47 |
+
depth_map = torch.nn.functional.interpolate(
|
48 |
+
depth_map.unsqueeze(1), size=flipped_image_pil.size[::-1], mode="bicubic", align_corners=False
|
49 |
+
).squeeze().cpu().numpy()
|
50 |
|
51 |
end_time = time.time()
|
52 |
+
print(f"⏳ MiDaS xử lý trong {end_time - start_time:.4f} giây")
|
53 |
|
54 |
# 🟢 Đo thời gian xử lý đường đi
|
55 |
start_detect_time = time.time()
|