yiyang34 commited on
Commit
2c30e2d
·
1 Parent(s): 8eded48

parallel download from oss

Browse files
Files changed (2) hide show
  1. app_utils.py +94 -3
  2. requirements.txt +3 -1
app_utils.py CHANGED
@@ -14,6 +14,9 @@ import shutil
14
  from urllib.parse import urljoin
15
  import oss2
16
  from natsort import natsorted
 
 
 
17
 
18
  TMP_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
19
  os.makedirs(TMP_ROOT, exist_ok=True)
@@ -180,11 +183,65 @@ def list_oss_files(folder_path: str) -> List[str]:
180
  files.append(obj.key)
181
  return sorted(files, key=lambda x: os.path.splitext(x)[0])
182
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  def download_oss_file(oss_path: str, local_path: str) -> bool:
184
  """从OSS下载文件到本地,返回是否成功"""
 
185
  try:
186
  result = bucket.get_object_to_file(oss_path, local_path)
187
- print(f"下载: {oss_path}, {result.status}")
 
188
  return result.status == 200
189
  except Exception as e:
190
  print(f"下载失败: {e}")
@@ -254,12 +311,27 @@ def stream_simulation_results(result_folder: str, task_id: str, request: gr.Requ
254
  oss_files = list_oss_files(image_folder)
255
  new_files = [f for f in oss_files if f not in processed_files]
256
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
  for oss_path in new_files:
258
  try:
259
  # 下载文件到本地
260
  filename = os.path.basename(oss_path)
261
  local_path = os.path.join(local_image_dir, filename)
262
- download_oss_file(oss_path, local_path)
263
 
264
  # 读取图片
265
  frame = cv2.imread(local_path)
@@ -302,16 +374,30 @@ def create_video_segment(frames: List[np.ndarray], fps: int, width: int, height:
302
 
303
  def process_remaining_oss_images(oss_folder: str, local_dir: str, processed_files: set, frame_buffer: List[np.ndarray]):
304
  """处理OSS上剩余的图片"""
 
305
  try:
306
  oss_files = list_oss_files(oss_folder)
307
  new_files = [f for f in oss_files if f not in processed_files]
308
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
  for oss_path in new_files:
310
  try:
311
  # 下载文件到本地
312
  filename = os.path.basename(oss_path)
313
  local_path = os.path.join(local_dir, filename)
314
- download_oss_file(oss_path, local_path)
315
 
316
  # 读取图片
317
  frame = cv2.imread(local_path)
@@ -370,13 +456,18 @@ def get_task_status(task_id: str) -> dict:
370
  """
371
  查询任务状态
372
  """
 
373
  try:
374
  response = requests.get(
375
  f"{API_ENDPOINTS['query_status']}/{task_id}",
376
  timeout=5
377
  )
 
 
378
  return response.json()
379
  except Exception as e:
 
 
380
  return {"status": "error get_task_status", "message": str(e)}
381
 
382
  def terminate_task(task_id: str) -> Optional[dict]:
 
14
  from urllib.parse import urljoin
15
  import oss2
16
  from natsort import natsorted
17
+ from concurrent.futures import ThreadPoolExecutor, as_completed
18
+ from tqdm import tqdm
19
+ import hashlib
20
 
21
  TMP_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
22
  os.makedirs(TMP_ROOT, exist_ok=True)
 
183
  files.append(obj.key)
184
  return sorted(files, key=lambda x: os.path.splitext(x)[0])
185
 
186
+ def parallel_download_oss_files(
187
+ bucket,
188
+ oss_folder: str,
189
+ local_dir: str,
190
+ file_list: list[str],
191
+ max_workers: int = 5
192
+ ) -> bool:
193
+ """
194
+ 极简版并行下载指定文件列表
195
+ 参数:
196
+ bucket: OSS Bucket对象
197
+ oss_folder: OSS文件夹路径 (如 "path/to/folder/")
198
+ local_dir: 本地存储目录
199
+ file_list: 需要下载的文件相对路径列表 (如 ["img1.jpg", "sub/img2.png"])
200
+ max_workers: 最大并发数
201
+ """
202
+ def download_single_file(oss_path, local_path):
203
+ try:
204
+ bucket.get_object_to_file(oss_path, local_path)
205
+ return True
206
+ except Exception as e:
207
+ print(f"下载失败 {oss_path}: {str(e)}")
208
+ return False
209
+
210
+ # 确保本地目录存在
211
+ os.makedirs(local_dir, exist_ok=True)
212
+
213
+ # 准备下载任务
214
+ tasks = []
215
+ for file in file_list:
216
+ oss_path = f"{file.lstrip('/')}"
217
+ filename = os.path.basename(oss_path)
218
+ local_path = os.path.join(local_dir, filename)
219
+ # os.makedirs(os.path.dirname(local_path), exist_ok=True)
220
+ tasks.append((oss_path, local_path))
221
+
222
+ # 并行下载
223
+ success_count = 0
224
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
225
+ futures = []
226
+ for oss_path, local_path in tasks:
227
+ futures.append(executor.submit(download_single_file, oss_path, local_path))
228
+
229
+ # 进度条显示
230
+ for future in tqdm(as_completed(futures), total=len(tasks), desc="下载进度"):
231
+ if future.result():
232
+ success_count += 1
233
+
234
+ print(f"下载完成: {success_count}/{len(tasks)} 成功")
235
+ return success_count == len(tasks)
236
+
237
+
238
  def download_oss_file(oss_path: str, local_path: str) -> bool:
239
  """从OSS下载文件到本地,返回是否成功"""
240
+ start_time = time.time() # 记录开始时间
241
  try:
242
  result = bucket.get_object_to_file(oss_path, local_path)
243
+ download_time = time.time() - start_time # 计算下载耗时
244
+ print(f"下载: {oss_path}, 状态码: {result.status}, 耗时: {download_time:.2f}秒")
245
  return result.status == 200
246
  except Exception as e:
247
  print(f"下载失败: {e}")
 
311
  oss_files = list_oss_files(image_folder)
312
  new_files = [f for f in oss_files if f not in processed_files]
313
 
314
+ if len(new_files) != 0:
315
+ print(f"发现新文件: {len(new_files)} 个", new_files)
316
+ success = parallel_download_oss_files(
317
+ bucket=bucket,
318
+ oss_folder=image_folder + "/",
319
+ local_dir=local_image_dir + "/",
320
+ file_list=new_files,
321
+ max_workers=5 # 根据网络带宽调整
322
+ )
323
+ if not success:
324
+ raise gr.Error("无法从OSS同步图片文件")
325
+ # if not download_oss_files_with_ossutil(image_folder + "/", local_image_dir + "/"):
326
+ # raise gr.Error("无法从OSS同步图片文件")
327
+
328
+
329
  for oss_path in new_files:
330
  try:
331
  # 下载文件到本地
332
  filename = os.path.basename(oss_path)
333
  local_path = os.path.join(local_image_dir, filename)
334
+ # download_oss_file(oss_path, local_path)
335
 
336
  # 读取图片
337
  frame = cv2.imread(local_path)
 
374
 
375
  def process_remaining_oss_images(oss_folder: str, local_dir: str, processed_files: set, frame_buffer: List[np.ndarray]):
376
  """处理OSS上剩余的图片"""
377
+
378
  try:
379
  oss_files = list_oss_files(oss_folder)
380
  new_files = [f for f in oss_files if f not in processed_files]
381
 
382
+ if len(new_files) != 0:
383
+ print(f"发现新文件: {len(new_files)} 个", new_files)
384
+ success = parallel_download_oss_files(
385
+ bucket=bucket,
386
+ oss_folder=oss_folder + "/",
387
+ local_dir=local_dir + "/",
388
+ file_list=new_files,
389
+ max_workers=5 # 根据网络带宽调整
390
+ )
391
+ if not success:
392
+ raise gr.Error("无法从OSS同步图片文件")
393
+
394
+
395
  for oss_path in new_files:
396
  try:
397
  # 下载文件到本地
398
  filename = os.path.basename(oss_path)
399
  local_path = os.path.join(local_dir, filename)
400
+ # download_oss_file(oss_path, local_path)
401
 
402
  # 读取图片
403
  frame = cv2.imread(local_path)
 
456
  """
457
  查询任务状态
458
  """
459
+ start_time = time.time()
460
  try:
461
  response = requests.get(
462
  f"{API_ENDPOINTS['query_status']}/{task_id}",
463
  timeout=5
464
  )
465
+ elapsed_time = time.time() - start_time # 计算耗时
466
+ print(f"[查询任务状态] task_id: {task_id}, 耗时: {elapsed_time:.3f}s")
467
  return response.json()
468
  except Exception as e:
469
+ elapsed_time = time.time() - start_time # 计算失败耗时
470
+ print(f"[查询任务状态失败] task_id: {task_id}, 错误: {str(e)}, 耗时: {elapsed_time:.3f}s")
471
  return {"status": "error get_task_status", "message": str(e)}
472
 
473
  def terminate_task(task_id: str) -> Optional[dict]:
requirements.txt CHANGED
@@ -5,4 +5,6 @@ opencv-python
5
  numpy
6
  python-dateutil
7
  oss2
8
- natsort
 
 
 
5
  numpy
6
  python-dateutil
7
  oss2
8
+ natsort
9
+ concurrent-log-handler
10
+ tqdm