wli3221134 commited on
Commit
4151dd8
·
verified ·
1 Parent(s): c63d90f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -8
app.py CHANGED
@@ -27,13 +27,16 @@ checkpoint_path = load_model()
27
 
28
  # 将 detect 函数移到 GPU 装饰器下
29
  @spaces.GPU
30
- def detect_on_gpu(dataset):
31
  """在 GPU 上进行音频伪造检测"""
32
  print("\n=== 开始音频检测 ===")
33
 
34
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
  print(f"使用设备: {device}")
36
 
 
 
 
37
  print("正在初始化模型...")
38
  model = SpoofVerificationModel().to(device)
39
 
@@ -59,8 +62,9 @@ def detect_on_gpu(dataset):
59
 
60
  print("\n开始处理音频数据...")
61
  with torch.no_grad():
62
- for batch_idx, batch in enumerate(dataset):
63
  print(f"\n处理批次 {batch_idx + 1}")
 
64
  waveforms = batch['waveforms'].numpy() # [B, T]
65
  features = feature_extractor(waveforms, sampling_rate=16000, return_attention_mask=True, padding_value=0, return_tensors="pt").to(device)
66
  outputs = model(features)
@@ -77,12 +81,9 @@ def detect_on_gpu(dataset):
77
  return result
78
 
79
  def audio_deepfake_detection(audio_path):
80
-
81
- # 数据集处理
82
- audio_dataset = dataset.DemoDataset(audio_path)
83
-
84
- # 调用 GPU 检测函数
85
- result = detect_on_gpu(audio_dataset)
86
  is_fake = "是/Yes" if result["is_fake"] else "否/No"
87
  confidence = f"{100*result['confidence']:.2f}%"
88
 
 
27
 
28
  # 将 detect 函数移到 GPU 装饰器下
29
  @spaces.GPU
30
+ def detect_on_gpu(audio_path):
31
  """在 GPU 上进行音频伪造检测"""
32
  print("\n=== 开始音频检测 ===")
33
 
34
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
  print(f"使用设备: {device}")
36
 
37
+ # 数据集处理移到GPU函数内部
38
+ audio_dataset = dataset.DemoDataset(audio_path)
39
+
40
  print("正在初始化模型...")
41
  model = SpoofVerificationModel().to(device)
42
 
 
62
 
63
  print("\n开始处理音频数据...")
64
  with torch.no_grad():
65
+ for batch_idx, batch in enumerate(audio_dataset):
66
  print(f"\n处理批次 {batch_idx + 1}")
67
+ print('waveforms shape:', batch['waveforms'].shape)
68
  waveforms = batch['waveforms'].numpy() # [B, T]
69
  features = feature_extractor(waveforms, sampling_rate=16000, return_attention_mask=True, padding_value=0, return_tensors="pt").to(device)
70
  outputs = model(features)
 
81
  return result
82
 
83
  def audio_deepfake_detection(audio_path):
84
+ # 移除了数据集处理步骤
85
+ # 直接传递音频路径到GPU函数
86
+ result = detect_on_gpu(audio_path)
 
 
 
87
  is_fake = "是/Yes" if result["is_fake"] else "否/No"
88
  confidence = f"{100*result['confidence']:.2f}%"
89