Spaces:
Runtime error
Runtime error
Upload predict.py
Browse files- predict.py +75 -0
predict.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import data.utils
|
| 5 |
+
import model.utils as model_utils
|
| 6 |
+
|
| 7 |
+
from test import predict_song
|
| 8 |
+
from model.waveunet import Waveunet
|
| 9 |
+
|
| 10 |
+
def main(args):
|
| 11 |
+
# MODEL
|
| 12 |
+
num_features = [args.features*i for i in range(1, args.levels+1)] if args.feature_growth == "add" else \
|
| 13 |
+
[args.features*2**i for i in range(0, args.levels)]
|
| 14 |
+
target_outputs = int(args.output_size * args.sr)
|
| 15 |
+
model = Waveunet(args.channels, num_features, args.channels, args.instruments, kernel_size=args.kernel_size,
|
| 16 |
+
target_output_size=target_outputs, depth=args.depth, strides=args.strides,
|
| 17 |
+
conv_type=args.conv_type, res=args.res, separate=args.separate)
|
| 18 |
+
|
| 19 |
+
if args.cuda:
|
| 20 |
+
model = model_utils.DataParallel(model)
|
| 21 |
+
print("move model to gpu")
|
| 22 |
+
model.cuda()
|
| 23 |
+
|
| 24 |
+
print("Loading model from checkpoint " + str(args.load_model))
|
| 25 |
+
state = model_utils.load_model(model, None, args.load_model, args.cuda)
|
| 26 |
+
print('Step', state['step'])
|
| 27 |
+
|
| 28 |
+
preds = predict_song(args, args.input, model)
|
| 29 |
+
|
| 30 |
+
output_folder = os.path.dirname(args.input) if args.output is None else args.output
|
| 31 |
+
for inst in preds.keys():
|
| 32 |
+
data.utils.write_wav(os.path.join(output_folder, os.path.basename(args.input) + "_" + inst + ".wav"), preds[inst], args.sr)
|
| 33 |
+
|
| 34 |
+
if __name__ == '__main__':
|
| 35 |
+
parser = argparse.ArgumentParser()
|
| 36 |
+
parser.add_argument('--instruments', type=str, nargs='+', default=["bass", "drums", "other", "vocals"],
|
| 37 |
+
help="List of instruments to separate (default: \"bass drums other vocals\")")
|
| 38 |
+
parser.add_argument('--cuda', action='store_true',
|
| 39 |
+
help='Use CUDA (default: False)')
|
| 40 |
+
parser.add_argument('--features', type=int, default=32,
|
| 41 |
+
help='Number of feature channels per layer')
|
| 42 |
+
parser.add_argument('--load_model', type=str, default='checkpoints/waveunet/model',
|
| 43 |
+
help='Reload a previously trained model')
|
| 44 |
+
parser.add_argument('--batch_size', type=int, default=4,
|
| 45 |
+
help="Batch size")
|
| 46 |
+
parser.add_argument('--levels', type=int, default=6,
|
| 47 |
+
help="Number of DS/US blocks")
|
| 48 |
+
parser.add_argument('--depth', type=int, default=1,
|
| 49 |
+
help="Number of convs per block")
|
| 50 |
+
parser.add_argument('--sr', type=int, default=44100,
|
| 51 |
+
help="Sampling rate")
|
| 52 |
+
parser.add_argument('--channels', type=int, default=2,
|
| 53 |
+
help="Number of input audio channels")
|
| 54 |
+
parser.add_argument('--kernel_size', type=int, default=5,
|
| 55 |
+
help="Filter width of kernels. Has to be an odd number")
|
| 56 |
+
parser.add_argument('--output_size', type=float, default=2.0,
|
| 57 |
+
help="Output duration")
|
| 58 |
+
parser.add_argument('--strides', type=int, default=4,
|
| 59 |
+
help="Strides in Waveunet")
|
| 60 |
+
parser.add_argument('--conv_type', type=str, default="gn",
|
| 61 |
+
help="Type of convolution (normal, BN-normalised, GN-normalised): normal/bn/gn")
|
| 62 |
+
parser.add_argument('--res', type=str, default="fixed",
|
| 63 |
+
help="Resampling strategy: fixed sinc-based lowpass filtering or learned conv layer: fixed/learned")
|
| 64 |
+
parser.add_argument('--separate', type=int, default=1,
|
| 65 |
+
help="Train separate model for each source (1) or only one (0)")
|
| 66 |
+
parser.add_argument('--feature_growth', type=str, default="double",
|
| 67 |
+
help="How the features in each layer should grow, either (add) the initial number of features each time, or multiply by 2 (double)")
|
| 68 |
+
|
| 69 |
+
parser.add_argument('--input', type=str, default=os.path.join("audio_examples", "Cristina Vane - So Easy", "mix.mp3"),
|
| 70 |
+
help="Path to input mixture to be separated")
|
| 71 |
+
parser.add_argument('--output', type=str, default=None, help="Output path (same folder as input path if not set)")
|
| 72 |
+
|
| 73 |
+
args = parser.parse_args()
|
| 74 |
+
|
| 75 |
+
main(args)
|