UTMOSv2 / app.py
kAIto47802
Resolved conflict in README.md
b55d767
raw
history blame
2.2 kB
import importlib
from types import SimpleNamespace
import gradio as gr
import pandas as pd
# import spaces
import torch
from utmosv2.utils import get_dataset, get_model
# Markdown rendered at the top of the demo page.
description = (
    "# 🚀 UTMOSv2 demo\n\n"
    "This is a demonstration of MOS prediction using UTMOSv2. "
    "This demonstration only accepts `.wav` format. Best at 16 kHz sampling rate."
)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the run configuration from every non-dunder attribute of the
# `fusion_stage3` config module, then override the fields used for inference.
config = importlib.import_module("utmosv2.config.fusion_stage3")
cfg = SimpleNamespace(
    **{name: value for name, value in vars(config).items() if not name.startswith("__")}
)
cfg.reproduce = False
cfg.config = "fusion_stage3"
cfg.print_config = False
cfg.data_config = None
cfg.phase = "inference"
# No explicit weight file; presumably get_model then picks the per-fold
# checkpoint from cfg.now_fold -- TODO confirm against utmosv2.utils.
cfg.weight = None
cfg.num_workers = 1
# Number of fold checkpoints ensembled per prediction.
NUM_FOLDS = 5
# Number of predictions per fold, each on a freshly built test dataset
# (presumably the dataset samples randomly, e.g. cropping -- TODO confirm).
NUM_REPEATS = 5


# @spaces.GPU
def predict_mos(audio_path: str, domain: str) -> float:
    """Predict the MOS of one audio file with the UTMOSv2 fold ensemble.

    For each of ``NUM_FOLDS`` folds, the fold's model is built with
    ``get_model`` and applied ``NUM_REPEATS`` times to a freshly built
    ``"test"`` dataset. The result is the mean of all
    ``NUM_FOLDS * NUM_REPEATS`` predictions.

    Args:
        audio_path: Path of the uploaded ``.wav`` file. Gradio passes
            ``None`` when nothing was uploaded.
        domain: Data-domain ID written into the ``dataset`` column.

    Returns:
        The averaged predicted MOS as a Python ``float``.

    Raises:
        gr.Error: If no audio file or no domain was provided.
    """
    if not audio_path:
        raise gr.Error("Please upload a .wav file.")
    if not domain:
        raise gr.Error("Please select a data-domain ID.")

    data = pd.DataFrame({"file_path": [audio_path]})
    data["dataset"] = domain
    # Placeholder target column; presumably required by get_dataset but not
    # fed to the model (the last dataset item is dropped below) -- TODO confirm.
    data["mos"] = 0

    total = 0.0
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for fold in range(NUM_FOLDS):
            cfg.now_fold = fold
            model = get_model(cfg, device)
            # The mode get_model returns the model in is not visible here;
            # eval() is idempotent and disables dropout-style training behavior.
            model.eval()
            for _ in range(NUM_REPEATS):
                test_dataset = get_dataset(cfg, data, "test")
                # All items but the last are model inputs. Add a batch axis and
                # move them to the model's device (get_model received `device`).
                inputs = [
                    torch.as_tensor(t).unsqueeze(0).to(device)
                    for t in test_dataset[0][:-1]
                ]
                pred = model(*inputs)
                # Assumes one score for the single sample in the batch.
                total += pred[0].item()
    return total / (NUM_FOLDS * NUM_REPEATS)
# Data-domain IDs offered in the dropdown; the selection is passed to
# predict_mos as its `domain` argument.
DOMAINS = [
    "sarulab",
    "bvcc",
    "somos",
    "blizzard2008",
    "blizzard2009",
    "blizzard2010-EH1",
    "blizzard2010-EH2",
    "blizzard2010-ES1",
    "blizzard2010-ES3",
    "blizzard2011",
]

with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            audio = gr.Audio(type="filepath", label="Audio")
            domain = gr.Dropdown(
                DOMAINS,
                # Pre-select a domain so Submit never sends domain=None.
                value=DOMAINS[0],
                label="Data-domain ID for the MOS prediction",
            )
            submit = gr.Button(value="Submit")
        with gr.Column():
            output = gr.Textbox(label="Predicted MOS", type="text")
    submit.click(fn=predict_mos, inputs=[audio, domain], outputs=[output])

# Start the server only when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.queue().launch()