adpro commited on
Commit
8ff8f64
·
verified ·
1 Parent(s): cd88622

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -37
app.py CHANGED
@@ -1,6 +1,5 @@
1
  import io
2
  import os
3
- import sys
4
  import time
5
  import numpy as np
6
  import cv2
@@ -12,34 +11,14 @@ import uvicorn
12
 
13
  app = FastAPI()
14
 
15
- # 🟢 Clone FastDepth từ GitHub nếu chưa
16
- if not os.path.exists("fastdepth"):
17
- os.system("git clone https://github.com/dwofk/fast-depth.git fastdepth")
18
-
19
- # 🟢 Thêm `fastdepth` vào `sys.path`
20
- sys.path.append(os.path.abspath("fastdepth"))
21
-
22
- # 🟢 Tải đúng file trọng số nếu chưa có
23
- weights_path = "fastdepth/models/fastdepth_nyu.pt"
24
- if not os.path.exists(weights_path):
25
- print("🔻 Trọng số chưa có, đang tải từ GitHub...")
26
- os.system(f"wget -O {weights_path} https://github.com/dwofk/fast-depth/raw/master/models/fastdepth_nyu.pt")
27
- else:
28
- print("✅ Trọng số đã có sẵn.")
29
-
30
- # 🟢 Import FastDepth
31
- from fastdepth.models import MobileNetSkipAdd
32
-
33
- # 🟢 Load mô hình FastDepth đúng trọng số
34
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
- model = MobileNetSkipAdd(output_size=(224, 224))
36
- try:
37
- model.load_state_dict(torch.load(weights_path, map_location=device))
38
- print("✅ Mô hình FastDepth đã được load thành công!")
39
- except FileNotFoundError:
40
- print("❌ Không tìm thấy file trọng số! Kiểm tra lại đường dẫn.")
41
 
42
- model.eval().to(device)
 
 
43
 
44
  @app.post("/analyze_path/")
45
  async def analyze_path(file: UploadFile = File(...)):
@@ -53,25 +32,24 @@ async def analyze_path(file: UploadFile = File(...)):
53
  # 🟢 Lật ảnh trước khi tính toán Depth Map
54
  flipped_image = cv2.flip(image_np, -1)
55
 
56
- # 🟢 Chuyển đổi lại thành ảnh PIL để đưa vào FastDepth
57
  flipped_image_pil = Image.fromarray(flipped_image)
58
 
59
- # 🟢 Chuyển đổi ảnh thành tensor (chuẩn hóa cho FastDepth)
60
- transform = torchvision.transforms.Compose([
61
- torchvision.transforms.Resize((224, 224)),
62
- torchvision.transforms.ToTensor(),
63
- ])
64
- img_tensor = transform(flipped_image_pil).unsqueeze(0).to(device)
65
 
66
  # 🟢 Bắt đầu đo thời gian dự đoán Depth Map
67
  start_time = time.time()
68
 
69
- # 🟢 Dự đoán Depth Map với FastDepth
70
  with torch.no_grad():
71
- depth_map = model(img_tensor).squeeze().cpu().numpy()
 
 
 
72
 
73
  end_time = time.time()
74
- print(f"⏳ FastDepth xử lý trong {end_time - start_time:.4f} giây")
75
 
76
  # 🟢 Đo thời gian xử lý đường đi
77
  start_detect_time = time.time()
 
1
  import io
2
  import os
 
3
  import time
4
  import numpy as np
5
  import cv2
 
11
 
12
  app = FastAPI()
13
 
14
+ # 🟢 Tải hình MiDaS từ PyTorch Hub
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+ model = torch.hub.load("intel-isl/MiDaS", "MiDaS_small").to(device) # 🟢 Dùng phiên bản nhẹ MiDaS_small
17
+ model.eval()
 
 
 
 
18
 
19
+ # 🟢 Load transform cho MiDaS
20
+ midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
21
+ transform = midas_transforms.small_transform # 🟢 Dùng transform phù hợp với MiDaS_small
22
 
23
  @app.post("/analyze_path/")
24
  async def analyze_path(file: UploadFile = File(...)):
 
32
  # 🟢 Lật ảnh trước khi tính toán Depth Map
33
  flipped_image = cv2.flip(image_np, -1)
34
 
35
+ # 🟢 Chuyển đổi lại thành ảnh PIL để đưa vào MiDaS
36
  flipped_image_pil = Image.fromarray(flipped_image)
37
 
38
+ # 🟢 Chuyển đổi ảnh thành tensor phù hợp với MiDaS
39
+ img_tensor = transform(flipped_image_pil).to(device)
 
 
 
 
40
 
41
  # 🟢 Bắt đầu đo thời gian dự đoán Depth Map
42
  start_time = time.time()
43
 
44
+ # 🟢 Dự đoán Depth Map với MiDaS
45
  with torch.no_grad():
46
+ depth_map = model(img_tensor)
47
+ depth_map = torch.nn.functional.interpolate(
48
+ depth_map.unsqueeze(1), size=flipped_image_pil.size[::-1], mode="bicubic", align_corners=False
49
+ ).squeeze().cpu().numpy()
50
 
51
  end_time = time.time()
52
+ print(f"⏳ MiDaS xử lý trong {end_time - start_time:.4f} giây")
53
 
54
  # 🟢 Đo thời gian xử lý đường đi
55
  start_detect_time = time.time()