liuyiyang01 commited on
Commit
8eded48
·
1 Parent(s): 460372d

fix history video display

Browse files
Files changed (2) hide show
  1. app_utils.py +93 -17
  2. requirements.txt +2 -1
app_utils.py CHANGED
@@ -13,6 +13,7 @@ from collections import defaultdict
13
  import shutil
14
  from urllib.parse import urljoin
15
  import oss2
 
16
 
17
  TMP_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
18
  os.makedirs(TMP_ROOT, exist_ok=True)
@@ -168,7 +169,7 @@ OSS_CONFIG = {
168
  }
169
 
170
  auth = oss2.Auth(OSS_CONFIG["access_key_id"], OSS_CONFIG["access_key_secret"])
171
- bucket = oss2.Bucket(auth, OSS_CONFIG["endpoint"], OSS_CONFIG["bucket_name"])
172
 
173
 
174
  def list_oss_files(folder_path: str) -> List[str]:
@@ -179,9 +180,15 @@ def list_oss_files(folder_path: str) -> List[str]:
179
  files.append(obj.key)
180
  return sorted(files, key=lambda x: os.path.splitext(x)[0])
181
 
182
- def download_oss_file(oss_path: str, local_path: str):
183
- """从OSS下载文件到本地"""
184
- bucket.get_object_to_file(oss_path, local_path)
 
 
 
 
 
 
185
 
186
  def oss_file_exists(oss_path):
187
  try:
@@ -190,6 +197,7 @@ def oss_file_exists(oss_path):
190
  except Exception as e:
191
  print(f"Error checking if file exists in OSS: {str(e)}")
192
  return False
 
193
  def stream_simulation_results(result_folder: str, task_id: str, request: gr.Request, fps: int = 30):
194
  """
195
  流式输出仿真结果,从OSS读取图片
@@ -208,11 +216,11 @@ def stream_simulation_results(result_folder: str, task_id: str, request: gr.Requ
208
  image_folder = os.path.join(result_folder, "image")
209
  os.makedirs(image_folder, exist_ok=True)
210
  frame_buffer: List[np.ndarray] = []
211
- frames_per_segment = fps * 2 # 每2秒60帧
212
  processed_files = set()
213
  width, height = 0, 0
214
  last_status_check = 0
215
- status_check_interval = 1 # 每5秒检查一次后端状态
216
  max_time = 240
217
 
218
  # 创建临时目录存储下载的图片
@@ -245,7 +253,6 @@ def stream_simulation_results(result_folder: str, task_id: str, request: gr.Requ
245
  try:
246
  oss_files = list_oss_files(image_folder)
247
  new_files = [f for f in oss_files if f not in processed_files]
248
- has_new_frames = False
249
 
250
  for oss_path in new_files:
251
  try:
@@ -262,15 +269,13 @@ def stream_simulation_results(result_folder: str, task_id: str, request: gr.Requ
262
 
263
  frame_buffer.append(frame)
264
  processed_files.add(oss_path)
265
- has_new_frames = True
266
  except Exception as e:
267
  print(f"Error processing {oss_path}: {e}")
268
 
269
- # 如果有新帧且积累够60帧,输出视频片段
270
- if has_new_frames and len(frame_buffer) >= frames_per_segment:
271
- segment_frames = frame_buffer[:frames_per_segment]
272
- frame_buffer = frame_buffer[frames_per_segment:]
273
- yield create_video_segment(segment_frames, fps, width, height, request)
274
 
275
  except Exception as e:
276
  print(f"Error accessing OSS: {e}")
@@ -279,7 +284,6 @@ def stream_simulation_results(result_folder: str, task_id: str, request: gr.Requ
279
 
280
  if max_time <= 0:
281
  raise gr.Error("timeout 240s")
282
-
283
  def create_video_segment(frames: List[np.ndarray], fps: int, width: int, height: int, req: gr.Request) -> str:
284
  """创建视频片段"""
285
  user_dir = os.path.join(TMP_ROOT, str(req.session_hash))
@@ -424,6 +428,76 @@ def convert_to_h264(video_path):
424
  except Exception as e:
425
  raise gr.Error(f"转换过程中发生错误: {str(e)}")
426
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
427
  def run_simulation(
428
  scene: str,
429
  prompt: str,
@@ -488,14 +562,16 @@ def run_simulation(
488
  status = get_task_status(task_id)
489
  print("status: ", status)
490
  if status.get("status") == "completed":
491
- # time.sleep(3)
492
  oss_video_path = os.path.join(result_folder, "manipulation.mp4")
493
  local_video_path = os.path.join(user_dir, task_id, "tasks", "manipulation.mp4")
494
- download_oss_file(oss_video_path, local_video_path)
495
  print("oss_video_path: ", oss_video_path)
496
  print("local_video_path: ", local_video_path)
497
 
498
- video_path = convert_to_h264(local_video_path)
 
 
499
 
500
  # 创建新的历史记录条目
501
  new_entry = {
 
13
  import shutil
14
  from urllib.parse import urljoin
15
  import oss2
16
+ from natsort import natsorted
17
 
18
  TMP_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
19
  os.makedirs(TMP_ROOT, exist_ok=True)
 
169
  }
170
 
171
  auth = oss2.Auth(OSS_CONFIG["access_key_id"], OSS_CONFIG["access_key_secret"])
172
+ bucket = oss2.Bucket(auth, OSS_CONFIG["endpoint"], OSS_CONFIG["bucket_name"], enable_crc=False)
173
 
174
 
175
  def list_oss_files(folder_path: str) -> List[str]:
 
180
  files.append(obj.key)
181
  return sorted(files, key=lambda x: os.path.splitext(x)[0])
182
 
183
+ def download_oss_file(oss_path: str, local_path: str) -> bool:
184
+ """从OSS下载文件到本地,返回是否成功"""
185
+ try:
186
+ result = bucket.get_object_to_file(oss_path, local_path)
187
+ print(f"下载: {oss_path}, {result.status}")
188
+ return result.status == 200
189
+ except Exception as e:
190
+ print(f"下载失败: {e}")
191
+ return False
192
 
193
  def oss_file_exists(oss_path):
194
  try:
 
197
  except Exception as e:
198
  print(f"Error checking if file exists in OSS: {str(e)}")
199
  return False
200
+
201
  def stream_simulation_results(result_folder: str, task_id: str, request: gr.Request, fps: int = 30):
202
  """
203
  流式输出仿真结果,从OSS读取图片
 
216
  image_folder = os.path.join(result_folder, "image")
217
  os.makedirs(image_folder, exist_ok=True)
218
  frame_buffer: List[np.ndarray] = []
219
+ min_frames_per_segment = fps * 1 # 至少30帧才输出
220
  processed_files = set()
221
  width, height = 0, 0
222
  last_status_check = 0
223
+ status_check_interval = 5 # 每5秒检查一次后端状态
224
  max_time = 240
225
 
226
  # 创建临时目录存储下载的图片
 
253
  try:
254
  oss_files = list_oss_files(image_folder)
255
  new_files = [f for f in oss_files if f not in processed_files]
 
256
 
257
  for oss_path in new_files:
258
  try:
 
269
 
270
  frame_buffer.append(frame)
271
  processed_files.add(oss_path)
 
272
  except Exception as e:
273
  print(f"Error processing {oss_path}: {e}")
274
 
275
+ # 如果有新帧且积累够60帧以上,输出所有当前帧
276
+ if len(frame_buffer) >= min_frames_per_segment:
277
+ yield create_video_segment(frame_buffer, fps, width, height, request)
278
+ frame_buffer = [] # 清空缓冲区
 
279
 
280
  except Exception as e:
281
  print(f"Error accessing OSS: {e}")
 
284
 
285
  if max_time <= 0:
286
  raise gr.Error("timeout 240s")
 
287
  def create_video_segment(frames: List[np.ndarray], fps: int, width: int, height: int, req: gr.Request) -> str:
288
  """创建视频片段"""
289
  user_dir = os.path.join(TMP_ROOT, str(req.session_hash))
 
428
  except Exception as e:
429
  raise gr.Error(f"转换过程中发生错误: {str(e)}")
430
 
431
+
432
+ def generate_whole_video(task_id: str, request: gr.Request, fps: int = 30) -> str:
433
+ """
434
+ 从图片序列生成完整视频
435
+
436
+ Args:
437
+ task_id: 任务ID
438
+ fps: 视频帧率,默认为30
439
+
440
+ Returns:
441
+ 生成的视频文件路径
442
+ """
443
+ frame_buffer: List[np.ndarray] = []
444
+ user_dir = os.path.join(TMP_ROOT, str(request.session_hash))
445
+ image_folder = os.path.join(user_dir, task_id, "tasks", "images")
446
+
447
+
448
+ # 确保输出目录存在
449
+ result_folder = os.path.join(user_dir, task_id, "tasks", "video")
450
+ os.makedirs(result_folder, exist_ok=True)
451
+
452
+ # 获取所有图片文件并按自然顺序排序
453
+ image_files = [f for f in os.listdir(image_folder) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
454
+ image_files = natsorted(image_files) # 自然排序处理类似 'frame1, frame2, ..., frame10' 的情况
455
+
456
+ if not image_files:
457
+ raise ValueError(f"No image files found in {image_folder}")
458
+
459
+ # 初始化视频尺寸
460
+ width, height = 0, 0
461
+
462
+ for img_file in image_files:
463
+ img_path = os.path.join(image_folder, img_file)
464
+ try:
465
+ frame = cv2.imread(img_path)
466
+ if frame is not None:
467
+ if width == 0: # 第一次获取图像尺寸
468
+ height, width = frame.shape[:2]
469
+ frame_buffer.append(frame)
470
+ except Exception as e:
471
+ print(f"Error processing {img_path}: {e}")
472
+ continue
473
+
474
+ if not frame_buffer:
475
+ raise ValueError("No valid frames found to create video")
476
+
477
+ # 生成视频文件名
478
+ output_video_path = os.path.join(result_folder, f"manipulation.mp4")
479
+
480
+ # 创建视频
481
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v') # 或使用 'avc1' 更好的兼容性
482
+ video_writer = cv2.VideoWriter(
483
+ output_video_path,
484
+ fourcc,
485
+ fps,
486
+ (width, height)
487
+ )
488
+
489
+ for frame in frame_buffer:
490
+ video_writer.write(frame)
491
+
492
+ video_writer.release()
493
+
494
+ # 验证视频是否成功创建
495
+ if not os.path.exists(output_video_path) or os.path.getsize(output_video_path) == 0:
496
+ raise RuntimeError(f"Failed to create video at {output_video_path}")
497
+
498
+ return output_video_path
499
+
500
+
501
  def run_simulation(
502
  scene: str,
503
  prompt: str,
 
562
  status = get_task_status(task_id)
563
  print("status: ", status)
564
  if status.get("status") == "completed":
565
+ time.sleep(3)
566
  oss_video_path = os.path.join(result_folder, "manipulation.mp4")
567
  local_video_path = os.path.join(user_dir, task_id, "tasks", "manipulation.mp4")
568
+ # download_oss_file(oss_video_path, local_video_path)
569
  print("oss_video_path: ", oss_video_path)
570
  print("local_video_path: ", local_video_path)
571
 
572
+ video_path = generate_whole_video(task_id, request)
573
+
574
+ # video_path = convert_to_h264(local_video_path)
575
 
576
  # 创建新的历史记录条目
577
  new_entry = {
requirements.txt CHANGED
@@ -4,4 +4,5 @@ jsonlib-python3
4
  opencv-python
5
  numpy
6
  python-dateutil
7
- oss2
 
 
4
  opencv-python
5
  numpy
6
  python-dateutil
7
+ oss2
8
+ natsort