import torch
import torchaudio
import gradio as gr

from model import M11

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_PATH = "./model.ckpt"

# Load the trained checkpoint, move it to the inference device, and switch to
# evaluation mode before serving predictions.
classifier = M11.load_from_checkpoint(MODEL_PATH)
classifier.to(DEVICE)
classifier.eval()


def preprocess(signal, sr, device):
    """Resample a (channels, samples) waveform to 8 kHz and mix it down to mono."""
    signal = signal.to(device)

    if sr != 8_000:
        resampler = torchaudio.transforms.Resample(sr, 8_000).to(device)
        signal = resampler(signal)

    if signal.shape[0] > 1:
        signal = torch.mean(signal, dim=0, keepdim=True)

    return signal


def get_likely_index(tensor):
    """Pick the class index with the highest score along the last dimension."""
    return tensor.argmax(dim=-1)


def pipeline(audio_input):
    sample_rate, audio = audio_input

    # Gradio's "numpy" audio type yields int16 PCM shaped (samples,) or
    # (samples, channels); scale it to [-1, 1] floats and move channels first.
    signal = torch.from_numpy(audio).float() / 32768.0
    signal = signal.unsqueeze(0) if signal.dim() == 1 else signal.t()

    processed_audio = preprocess(signal, sample_rate, DEVICE)

    with torch.no_grad():
        pred = get_likely_index(classifier(processed_audio.unsqueeze(0))).view(-1)

    # Return the predicted class index as text for the Gradio output box.
    return str(pred[0].item())


inputs = gr.Audio(label="Input Audio", type="numpy")
outputs = "text"
title = "Threat Detection From Bengali Voice Calls"
description = "Gradio demo for audio classification. Upload your own audio, or click one of the examples to load it. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2005.07143' target='_blank'>ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in TDNN Based Speaker Verification</a> | <a href='https://github.com/speechbrain/speechbrain' target='_blank'>Github Repo</a></p>"
examples = [
    ["sample_audio.wav"],
]

gr.Interface(pipeline, inputs, outputs, title=title, description=description, article=article, examples=examples).launch()