adpro commited on
Commit
3412511
·
verified ·
1 Parent(s): c046173

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -9
app.py CHANGED
@@ -10,10 +10,23 @@ import uvicorn
10
 
11
  app = FastAPI()
12
 
13
- # 🟢 Tải hình MobileNetDepth (MobileNet v3 Large)
 
 
 
 
 
 
 
 
 
 
 
 
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
- model = torchvision.models.mobilenet_v3_large(pretrained=True).to(device)
16
- model.eval()
 
17
 
18
  @app.post("/analyze_path/")
19
  async def analyze_path(file: UploadFile = File(...)):
@@ -21,9 +34,9 @@ async def analyze_path(file: UploadFile = File(...)):
21
  image_bytes = await file.read()
22
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
23
 
24
- # 🟢 Chuyển đổi ảnh thành tensor phù hợp với MobileNetDepth
25
  transform = torchvision.transforms.Compose([
26
- torchvision.transforms.Resize((224, 224)), # MobileNetDepth yêu cầu ảnh 224x224
27
  torchvision.transforms.ToTensor(),
28
  ])
29
  img_tensor = transform(image).unsqueeze(0).to(device)
@@ -31,13 +44,15 @@ async def analyze_path(file: UploadFile = File(...)):
31
  # 🟢 Bắt đầu đo thời gian dự đoán Depth Map
32
  start_time = time.time()
33
 
34
- # 🟢 Dự đoán Depth Map với MobileNetDepth
35
  with torch.no_grad():
36
- depth_map = model(img_tensor)
37
- depth_map = depth_map.squeeze().cpu().numpy()
38
 
39
  end_time = time.time()
40
- print(f"⏳ MobileNetDepth xử lý trong {end_time - start_time:.4f} giây")
 
 
 
41
 
42
  # 🟢 Đo thời gian xử lý đường đi
43
  start_detect_time = time.time()
@@ -49,6 +64,9 @@ async def analyze_path(file: UploadFile = File(...)):
49
 
50
  def detect_path(depth_map):
51
  """Phân tích đường đi từ ảnh Depth Map"""
 
 
 
52
  h, w = depth_map.shape
53
  center_x = w // 2
54
  scan_y = h - 20 # Quét dòng gần đáy ảnh
 
10
 
11
  app = FastAPI()
12
 
13
+ # 🟢 Clone FastDepth nếu chưa
14
+ fastdepth_path = "FastDepth"
15
+ if not os.path.exists(fastdepth_path):
16
+ os.system("git clone https://github.com/dwofk/fast-depth.git FastDepth")
17
+
18
+ # 🟢 Thêm FastDepth vào sys.path để import được
19
+ import sys
20
+ sys.path.append(fastdepth_path)
21
+
22
+ # 🟢 Import FastDepth sau khi đã tải về
23
+ from FastDepth.models import MobileNetSkipAdd
24
+
25
+ # 🟢 Load mô hình FastDepth
26
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
+ model = MobileNetSkipAdd(output_size=(224, 224)) # 🟢 FastDepth hỗ trợ đầu ra 224x224
28
+ model.load_state_dict(torch.load(f"{fastdepth_path}/models/fastdepth_nyu.pt", map_location=device))
29
+ model.eval().to(device)
30
 
31
  @app.post("/analyze_path/")
32
  async def analyze_path(file: UploadFile = File(...)):
 
34
  image_bytes = await file.read()
35
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
36
 
37
+ # 🟢 Chuyển đổi ảnh thành tensor phù hợp với FastDepth
38
  transform = torchvision.transforms.Compose([
39
+ torchvision.transforms.Resize((224, 224)), # FastDepth yêu cầu ảnh 224x224
40
  torchvision.transforms.ToTensor(),
41
  ])
42
  img_tensor = transform(image).unsqueeze(0).to(device)
 
44
  # 🟢 Bắt đầu đo thời gian dự đoán Depth Map
45
  start_time = time.time()
46
 
47
+ # 🟢 Dự đoán Depth Map với FastDepth
48
  with torch.no_grad():
49
+ depth_map = model(img_tensor).squeeze().cpu().numpy()
 
50
 
51
  end_time = time.time()
52
+ print(f"⏳ FastDepth xử lý trong {end_time - start_time:.4f} giây")
53
+
54
+ # 🟢 Kiểm tra kích thước Depth Map
55
+ print(f"📏 Depth Map Shape: {depth_map.shape}")
56
 
57
  # 🟢 Đo thời gian xử lý đường đi
58
  start_detect_time = time.time()
 
64
 
65
  def detect_path(depth_map):
66
  """Phân tích đường đi từ ảnh Depth Map"""
67
+ if len(depth_map.shape) != 2: # 🟢 Kiểm tra nếu depth_map không phải 2D
68
+ raise ValueError("Depth map không phải ảnh 2D hợp lệ!")
69
+
70
  h, w = depth_map.shape
71
  center_x = w // 2
72
  scan_y = h - 20 # Quét dòng gần đáy ảnh