adpro commited on
Commit
64c49d0
·
verified ·
1 Parent(s): e8e141f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -25
app.py CHANGED
@@ -8,33 +8,39 @@ from fastapi import FastAPI, File, UploadFile
8
  from PIL import Image
9
  import uvicorn
10
 
 
11
  app = FastAPI()
12
 
13
  # 🟢 Chọn thiết bị xử lý (GPU nếu có)
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
- # 🟢 Tải model DPT-Hybrid thay cho ZoeDepth để tăng tốc
17
  feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-swinv2-tiny-256")
18
  model = DPTForDepthEstimation.from_pretrained("Intel/dpt-swinv2-tiny-256").to(device)
19
  model.eval()
20
 
 
 
 
21
  @app.post("/analyze_path/")
22
  async def analyze_path(file: UploadFile = File(...)):
23
- # 🟢 Bắt đầu đo thời gian dự đoán Depth Map
 
 
24
  start_time = time.time()
 
25
  # 🟢 Đọc file ảnh từ ESP32
26
  image_bytes = await file.read()
27
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
28
 
29
  # 🔵 Resize ảnh để xử lý nhanh hơn
30
- image = image.resize((192, 192)) # Giảm kích thước giúp tăng tốc độ xử lý
31
  image_np = np.array(image)
32
  flipped_image = cv2.flip(image_np, -1)
 
33
  # 🟢 Chuẩn bị ảnh cho mô hình
34
  inputs = feature_extractor(images=flipped_image, return_tensors="pt").to(device)
35
 
36
-
37
-
38
  # 🟢 Dự đoán Depth Map với DPT-Hybrid
39
  with torch.no_grad():
40
  outputs = model(**inputs)
@@ -43,6 +49,13 @@ async def analyze_path(file: UploadFile = File(...)):
43
  predicted_depth = outputs.predicted_depth.squeeze().cpu().numpy()
44
  depth_map = (predicted_depth * 255 / predicted_depth.max()).astype("uint8")
45
 
 
 
 
 
 
 
 
46
  end_time = time.time()
47
  print(f"⏳ DPT xử lý trong {end_time - start_time:.4f} giây")
48
 
@@ -50,9 +63,9 @@ async def analyze_path(file: UploadFile = File(...)):
50
  start_detect_time = time.time()
51
  command = detect_path(depth_map)
52
  end_detect_time = time.time()
53
- print(f"⏳ detect_path() xử lý trong {end_detect_time - start_detect_time:.4f} giây Lệnh: {command}")
54
 
55
- return {"command": command}
56
 
57
  def detect_path(depth_map):
58
  """Phân tích đường đi từ ảnh Depth Map"""
@@ -60,31 +73,21 @@ def detect_path(depth_map):
60
  center_x = w // 2
61
  scan_y = int(h * 0.8) # Quét dòng 80% từ trên xuống
62
 
63
- # 🟢 Chia ảnh thành 3 vùng: trái, giữa, phải
64
- left_region = np.mean(depth_map[scan_y, :center_x - 40])
65
- right_region = np.mean(depth_map[scan_y, center_x + 40:])
66
  center_region = np.mean(depth_map[scan_y, center_x - 40:center_x + 40])
67
 
68
- # 🟢 Ngưỡng phát hiện vật cản (càng thấp, càng nhạy)
69
- threshold = 80
70
-
71
- # 🟢 Không có vật cản ở cả 3 vùng → đi thẳng
72
- if left_region > threshold and center_region > threshold and right_region > threshold:
73
- return "forward"
74
-
75
- # 🟢 Nếu chỉ có giữa trống → đi thẳng
76
  if center_region > threshold:
77
  return "forward"
78
-
79
- # 🟢 Nếu chỉ có trái hoặc phải trống → chọn hướng có vùng trống lớn nhất
80
- if left_region > right_region:
81
  return "left"
82
  elif right_region > left_region:
83
  return "right"
84
-
85
- # 🟢 Nếu tất cả đều có vật cản → lùi lại
86
- return "backward"
87
 
88
  # 🟢 Chạy server FastAPI
89
  if __name__ == "__main__":
90
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
8
  from PIL import Image
9
  import uvicorn
10
 
11
+ # 🟢 Tạo FastAPI
12
  app = FastAPI()
13
 
14
  # 🟢 Chọn thiết bị xử lý (GPU nếu có)
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
17
+ # 🟢 Tải model DPT-Hybrid để tăng tốc
18
  feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-swinv2-tiny-256")
19
  model = DPTForDepthEstimation.from_pretrained("Intel/dpt-swinv2-tiny-256").to(device)
20
  model.eval()
21
 
22
+ # 🟢 Biến lưu ảnh Depth Map để hiển thị trên Gradio
23
+ depth_map_global = None
24
+
25
  @app.post("/analyze_path/")
26
  async def analyze_path(file: UploadFile = File(...)):
27
+ """Xử ảnh Depth Map trả về lệnh điều hướng"""
28
+ global depth_map_global # Dùng biến toàn cục để hiển thị trên Gradio
29
+
30
  start_time = time.time()
31
+
32
  # 🟢 Đọc file ảnh từ ESP32
33
  image_bytes = await file.read()
34
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
35
 
36
  # 🔵 Resize ảnh để xử lý nhanh hơn
37
+ image = image.resize((256, 256))
38
  image_np = np.array(image)
39
  flipped_image = cv2.flip(image_np, -1)
40
+
41
  # 🟢 Chuẩn bị ảnh cho mô hình
42
  inputs = feature_extractor(images=flipped_image, return_tensors="pt").to(device)
43
 
 
 
44
  # 🟢 Dự đoán Depth Map với DPT-Hybrid
45
  with torch.no_grad():
46
  outputs = model(**inputs)
 
49
  predicted_depth = outputs.predicted_depth.squeeze().cpu().numpy()
50
  depth_map = (predicted_depth * 255 / predicted_depth.max()).astype("uint8")
51
 
52
+ # 🔵 Chuyển depth_map thành ảnh có thể hiển thị
53
+ depth_colored = cv2.applyColorMap(depth_map, cv2.COLORMAP_INFERNO)
54
+ depth_pil = Image.fromarray(depth_colored)
55
+
56
+ # 🟢 Lưu ảnh Depth Map để hiển thị trên Gradio
57
+ depth_map_global = depth_pil
58
+
59
  end_time = time.time()
60
  print(f"⏳ DPT xử lý trong {end_time - start_time:.4f} giây")
61
 
 
63
  start_detect_time = time.time()
64
  command = detect_path(depth_map)
65
  end_detect_time = time.time()
66
+ print(f"⏳ detect_path() xử lý trong {end_detect_time - start_detect_time:.4f} giây")
67
 
68
+ return command # Trả về lệnh điều hướng (không kèm ảnh)
69
 
70
  def detect_path(depth_map):
71
  """Phân tích đường đi từ ảnh Depth Map"""
 
73
  center_x = w // 2
74
  scan_y = int(h * 0.8) # Quét dòng 80% từ trên xuống
75
 
76
+ left_region = np.mean(depth_map[scan_y, :center_x])
77
+ right_region = np.mean(depth_map[scan_y, center_x:])
 
78
  center_region = np.mean(depth_map[scan_y, center_x - 40:center_x + 40])
79
 
80
+ # 🟢 Cải thiện logic xử
81
+ threshold = 100 # Ngưỡng phân biệt vật cản
 
 
 
 
 
 
82
  if center_region > threshold:
83
  return "forward"
84
+ elif left_region > right_region:
 
 
85
  return "left"
86
  elif right_region > left_region:
87
  return "right"
88
+ else:
89
+ return "backward"
 
90
 
91
  # 🟢 Chạy server FastAPI
92
  if __name__ == "__main__":
93
+ uvicorn.run(app, host="0.0.0.0", port=7860)