adpro commited on
Commit
4a9c7f0
·
verified ·
1 Parent(s): 425f8a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -7,7 +7,6 @@ from fastapi import FastAPI, File, UploadFile
7
  from PIL import Image
8
  import uvicorn
9
  from torchvision import transforms
10
- from midas.model_loader import load_model # Thư viện MiDaS
11
 
12
  # 🟢 Tạo FastAPI
13
  app = FastAPI()
@@ -15,9 +14,9 @@ app = FastAPI()
15
  # 🟢 Kiểm tra GPU
16
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
 
18
- # 🟢 Tải model MiDaS (DPT-Swin2 Large 384)
19
- model_path = "dpt_swin2_large_384.pt" # Đảm bảo đã tải file này từ GitHub
20
- midas = load_model(model_path, device)
21
  midas.eval()
22
 
23
  # 🟢 Chuẩn bị bộ tiền xử lý ảnh
@@ -32,9 +31,10 @@ async def analyze_path(file: UploadFile = File(...)):
32
  # 🟢 Đọc file ảnh từ ESP32
33
  image_bytes = await file.read()
34
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
35
-
 
36
  # 🔵 Resize và chuẩn hóa ảnh
37
- input_tensor = transform(image).unsqueeze(0).to(device)
38
 
39
  # 🟢 Dự đoán Depth Map với MiDaS
40
  start_time = time.time()
@@ -79,4 +79,4 @@ def detect_path(depth_map):
79
 
80
  # 🟢 Chạy server FastAPI
81
  if __name__ == "__main__":
82
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
7
  from PIL import Image
8
  import uvicorn
9
  from torchvision import transforms
 
10
 
11
  # 🟢 Tạo FastAPI
12
  app = FastAPI()
 
14
  # 🟢 Kiểm tra GPU
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
17
+ # 🟢 Tải model MiDaS
18
+ midas = torch.hub.load("isl-org/MiDaS", "DPT_Swin2_L_384")
19
+ midas.to(device)
20
  midas.eval()
21
 
22
  # 🟢 Chuẩn bị bộ tiền xử lý ảnh
 
31
  # 🟢 Đọc file ảnh từ ESP32
32
  image_bytes = await file.read()
33
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
34
+ image_np = np.array(image)
35
+ flipped_image = cv2.flip(image_np, -1)
36
  # 🔵 Resize và chuẩn hóa ảnh
37
+ input_tensor = transform(flipped_image).unsqueeze(0).to(device)
38
 
39
  # 🟢 Dự đoán Depth Map với MiDaS
40
  start_time = time.time()
 
79
 
80
  # 🟢 Chạy server FastAPI
81
  if __name__ == "__main__":
82
+ uvicorn.run(app, host="0.0.0.0", port=7860)