import os
import random
from typing import Union
import soundfile as sf
import torch
import yaml
import json
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
from pprint import pprint
from scipy.io import wavfile
import warnings
import torchaudio
warnings.filterwarnings("ignore")
import look2hear.models
import look2hear.datas
from look2hear.metrics import MetricsTracker
from look2hear.utils import tensors_to_device, RichProgressBarTheme, MyMetricsTextColumn, BatchesProcessedColumn
from rich.progress import (
    BarColumn,
    Progress,
    TextColumn,
    TimeRemainingColumn,
    TransferSpeedColumn,
)
parser = argparse.ArgumentParser()
parser.add_argument("--conf_dir",
default="local/mixit_conf.yml",
help="Full path to save best validation model")
compute_metrics = ["si_sdr", "sdr"]
os.environ['CUDA_VISIBLE_DEVICES'] = "8"
def main(config):
    metricscolumn = MyMetricsTextColumn(style=RichProgressBarTheme.metrics)
    progress = Progress(
        TextColumn("[bold blue]Testing", justify="right"),
        BarColumn(bar_width=None),
        "•",
        BatchesProcessedColumn(style=RichProgressBarTheme.batch_progress),
        "•",
        TransferSpeedColumn(),
        "•",
        TimeRemainingColumn(),
        "•",
        metricscolumn,
    )
    # Resolve the experiment directory and load the best checkpoint from it.
    config["train_conf"]["main_args"]["exp_dir"] = os.path.join(
        os.getcwd(), "Experiments", "checkpoint", config["train_conf"]["exp"]["exp_name"]
    )
    model_path = os.path.join(config["train_conf"]["main_args"]["exp_dir"], "best_model.pth")
    model = getattr(look2hear.models, config["train_conf"]["audionet"]["audionet_name"]).from_pretrain(
        model_path,
        sample_rate=config["train_conf"]["datamodule"]["data_config"]["sample_rate"],
        **config["train_conf"]["audionet"]["audionet_config"],
    )
if config["train_conf"]["training"]["gpus"]:
device = "cuda"
model.to(device)
model_device = next(model.parameters()).device
datamodule: object = getattr(look2hear.datas, config["train_conf"]["datamodule"]["data_name"])(
**config["train_conf"]["datamodule"]["data_config"]
)
datamodule.setup()
_, _ , test_set = datamodule.make_sets
# Randomly choose the indexes of sentences to save.
ex_save_dir = os.path.join(config["train_conf"]["main_args"]["exp_dir"], "results/")
os.makedirs(ex_save_dir, exist_ok=True)
metrics = MetricsTracker(
save_file=os.path.join(ex_save_dir, "metrics.csv"))
torch.no_grad().__enter__()
    with progress:
        for idx in progress.track(range(len(test_set))):
            # Only index 825 of the test set is separated and saved here;
            # remove this check to run over the full test set.
            if idx == 825:
                # Forward the network on the mixture.
                mix, sources, key = tensors_to_device(test_set[idx],
                                                      device=model_device)
                est_sources = model(mix[None])
                mix_np = mix
                sources_np = sources
                est_sources_np = est_sources.squeeze(0)
                # Metric computation is disabled; only the separated sources are written out.
                # metrics(mix=mix_np,
                #         clean=sources_np,
                #         estimate=est_sources_np,
                #         key=key)
                save_dir = os.path.join("./result/TIGER", "idx{}".format(idx))
                # est_sources_np = normalize_tensor_wav(est_sources_np)
                # Write each estimated source to result/TIGER/idx<idx>/s<i>/<utterance name>.
                for i in range(est_sources_np.shape[0]):
                    os.makedirs(os.path.join(save_dir, "s{}/".format(i + 1)), exist_ok=True)
                    torchaudio.save(os.path.join(save_dir, "s{}/".format(i + 1)) + key.split("/")[-1], est_sources_np[i].unsqueeze(0).cpu(), 16000)
                # if idx % 50 == 0:
                #     metricscolumn.update(metrics.update())
    metrics.final()
if __name__ == "__main__":
    args = parser.parse_args()
    arg_dic = dict(vars(args))
    # Load training config
    with open(args.conf_dir, "rb") as f:
        train_conf = yaml.safe_load(f)
    arg_dic["train_conf"] = train_conf
    # print(arg_dic)
    main(arg_dic)