Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -27,13 +27,16 @@ checkpoint_path = load_model()
|
|
27 |
|
28 |
# 将 detect 函数移到 GPU 装饰器下
|
29 |
@spaces.GPU
|
30 |
-
def detect_on_gpu(dataset):
|
31 |
"""在 GPU 上进行音频伪造检测"""
|
32 |
print("\n=== 开始音频检测 ===")
|
33 |
|
34 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
35 |
print(f"使用设备: {device}")
|
36 |
|
|
|
|
|
|
|
37 |
print("正在初始化模型...")
|
38 |
model = SpoofVerificationModel().to(device)
|
39 |
|
@@ -59,8 +62,9 @@ def detect_on_gpu(dataset):
|
|
59 |
|
60 |
print("\n开始处理音频数据...")
|
61 |
with torch.no_grad():
|
62 |
-
for batch_idx, batch in enumerate(dataset):
|
63 |
print(f"\n处理批次 {batch_idx + 1}")
|
|
|
64 |
waveforms = batch['waveforms'].numpy() # [B, T]
|
65 |
features = feature_extractor(waveforms, sampling_rate=16000, return_attention_mask=True, padding_value=0, return_tensors="pt").to(device)
|
66 |
outputs = model(features)
|
@@ -77,12 +81,9 @@ def detect_on_gpu(dataset):
|
|
77 |
return result
|
78 |
|
79 |
def audio_deepfake_detection(audio_path):
|
80 |
-
|
81 |
-
#
|
82 |
-
|
83 |
-
|
84 |
-
# 调用 GPU 检测函数
|
85 |
-
result = detect_on_gpu(audio_dataset)
|
86 |
is_fake = "是/Yes" if result["is_fake"] else "否/No"
|
87 |
confidence = f"{100*result['confidence']:.2f}%"
|
88 |
|
|
|
27 |
|
28 |
# 将 detect 函数移到 GPU 装饰器下
|
29 |
@spaces.GPU
|
30 |
+
def detect_on_gpu(audio_path):
|
31 |
"""在 GPU 上进行音频伪造检测"""
|
32 |
print("\n=== 开始音频检测 ===")
|
33 |
|
34 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
35 |
print(f"使用设备: {device}")
|
36 |
|
37 |
+
# 数据集处理移到GPU函数内部
|
38 |
+
audio_dataset = dataset.DemoDataset(audio_path)
|
39 |
+
|
40 |
print("正在初始化模型...")
|
41 |
model = SpoofVerificationModel().to(device)
|
42 |
|
|
|
62 |
|
63 |
print("\n开始处理音频数据...")
|
64 |
with torch.no_grad():
|
65 |
+
for batch_idx, batch in enumerate(audio_dataset):
|
66 |
print(f"\n处理批次 {batch_idx + 1}")
|
67 |
+
print('waveforms shape:', batch['waveforms'].shape)
|
68 |
waveforms = batch['waveforms'].numpy() # [B, T]
|
69 |
features = feature_extractor(waveforms, sampling_rate=16000, return_attention_mask=True, padding_value=0, return_tensors="pt").to(device)
|
70 |
outputs = model(features)
|
|
|
81 |
return result
|
82 |
|
83 |
def audio_deepfake_detection(audio_path):
|
84 |
+
# 移除了数据集处理步骤
|
85 |
+
# 直接传递音频路径到GPU函数
|
86 |
+
result = detect_on_gpu(audio_path)
|
|
|
|
|
|
|
87 |
is_fake = "是/Yes" if result["is_fake"] else "否/No"
|
88 |
confidence = f"{100*result['confidence']:.2f}%"
|
89 |
|