adpro commited on
Commit
81def50
·
verified ·
1 Parent(s): 2d5e944

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -27
app.py CHANGED
@@ -1,60 +1,72 @@
1
- from fastapi import FastAPI, File, UploadFile
2
  import io
 
3
  import numpy as np
 
 
 
 
4
  from PIL import Image
5
  import uvicorn
6
- import cv2
7
- from fastdepth import FastDepth
8
- model = FastDepth(pretrained=True)
9
- model.eval()
10
  app = FastAPI()
11
 
12
- import os
13
  if not os.path.exists("fastdepth"):
14
  os.system("git clone https://github.com/dwofk/fast-depth.git fastdepth")
15
 
16
- from fastdepth import FastDepth # Import sau khi clone
17
-
18
- app = FastAPI()
19
 
20
  # 🟢 Load mô hình FastDepth
21
- model = FastDepth(pretrained=True)
22
- model.eval()
 
 
23
 
24
- def analyzepath(image):
25
- depth_map = model(image).squeeze().cpu().numpy()
26
- return detect_path(depth_map) # Xử lý đường đi nhanh hơn
27
  @app.post("/analyze_path/")
28
  async def analyze_path(file: UploadFile = File(...)):
 
29
  image_bytes = await file.read()
30
- image = Image.open(io.BytesIO(image_bytes)).convert("L")
31
- depth_map = np.array(image)
32
 
33
- # 🟢 Lật ảnh (nếu cần)
34
- flipped_depth_map = cv2.flip(depth_map, -1)
35
  transform = torchvision.transforms.Compose([
36
  torchvision.transforms.Resize((224, 224)),
37
  torchvision.transforms.ToTensor(),
38
  ])
39
  img_tensor = transform(image).unsqueeze(0).to(device)
 
 
40
  with torch.no_grad():
41
  depth_map = model(img_tensor).squeeze().cpu().numpy()
 
 
 
 
42
  # 🟢 Phân tích đường đi
43
  command = detect_path(flipped_depth_map)
44
 
45
  return {"command": command}
46
 
47
  def detect_path(depth_map):
48
- _, thresh = cv2.threshold(depth_map, 200, 255, cv2.THRESH_BINARY)
49
- contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
 
 
50
 
51
- if len(contours) == 0:
 
 
 
 
52
  return "forward"
53
-
54
- left_region = np.mean(depth_map[:, :depth_map.shape[1]//3])
55
- right_region = np.mean(depth_map[:, 2*depth_map.shape[1]//3:])
56
-
57
- if left_region > right_region:
58
  return "left"
59
- else:
60
  return "right"
 
 
 
 
 
 
 
 
1
  import io
2
+ import os
3
  import numpy as np
4
+ import cv2
5
+ import torch
6
+ import torchvision
7
+ from fastapi import FastAPI, File, UploadFile
8
  from PIL import Image
9
  import uvicorn
10
+
 
 
 
11
  app = FastAPI()
12
 
13
+ # 🟢 Clone FastDepth từ GitHub (chỉ cần làm 1 lần)
14
  if not os.path.exists("fastdepth"):
15
  os.system("git clone https://github.com/dwofk/fast-depth.git fastdepth")
16
 
17
+ # 🟢 Import FastDepth sau khi clone
18
+ from fastdepth.models import MobileNetSkipAdd # Model chính của FastDepth
 
19
 
20
  # 🟢 Load mô hình FastDepth
21
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
+ model = MobileNetSkipAdd()
23
+ model.load_state_dict(torch.load("fastdepth/models/fastdepth_nyu.pt", map_location=device))
24
+ model.eval().to(device)
25
 
 
 
 
26
  @app.post("/analyze_path/")
27
  async def analyze_path(file: UploadFile = File(...)):
28
+ # 🟢 Đọc file ảnh từ ESP32
29
  image_bytes = await file.read()
30
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
 
31
 
32
+ # 🟢 Chuyển đổi ảnh thành tensor (chuẩn hóa cho FastDepth)
 
33
  transform = torchvision.transforms.Compose([
34
  torchvision.transforms.Resize((224, 224)),
35
  torchvision.transforms.ToTensor(),
36
  ])
37
  img_tensor = transform(image).unsqueeze(0).to(device)
38
+
39
+ # 🟢 Dự đoán Depth Map với FastDepth
40
  with torch.no_grad():
41
  depth_map = model(img_tensor).squeeze().cpu().numpy()
42
+
43
+ # 🟢 Lật ngược ảnh (nếu cần)
44
+ flipped_depth_map = cv2.flip(depth_map, -1)
45
+
46
  # 🟢 Phân tích đường đi
47
  command = detect_path(flipped_depth_map)
48
 
49
  return {"command": command}
50
 
51
  def detect_path(depth_map):
52
+ """Phân tích đường đi từ ảnh Depth Map"""
53
+ h, w = depth_map.shape
54
+ center_x = w // 2
55
+ scan_y = h - 20 # Quét dòng gần đáy ảnh
56
 
57
+ left_region = np.mean(depth_map[scan_y, :center_x])
58
+ right_region = np.mean(depth_map[scan_y, center_x:])
59
+ center_region = np.mean(depth_map[scan_y, center_x - 20:center_x + 20])
60
+
61
+ if center_region > 200:
62
  return "forward"
63
+ elif left_region > right_region:
 
 
 
 
64
  return "left"
65
+ elif right_region > left_region:
66
  return "right"
67
+ else:
68
+ return "backward"
69
+
70
+ # 🟢 Chạy server FastAPI
71
+ if __name__ == "__main__":
72
+ uvicorn.run(app, host="0.0.0.0", port=7860)