# Hugging Face Space app (ZeroGPU): YourMT3+ audio-to-MIDI transcription demo.
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'amt/src')))
import shutil
import mimetypes
import subprocess
import gradio as gr
import torchaudio
import spaces
from model_helper import load_model_checkpoint, transcribe
from prepare_media import prepare_media
from typing import Tuple, Dict, Literal
# Which model variant to load at startup; must be a key of MODELS below.
MODEL_NAME = 'YPTF.MoE+Multi (noPS)' # @param ["YMT3+", "YPTF+Single (noPS)", "YPTF+Multi (PS)", "YPTF.MoE+Multi (PS)", "YPTF.MoE+Multi (noPS)"]
# Numeric precision forwarded to every model via '-pr' ('32', 'bf16-mixed', '16').
PRECISION = '16'# if torch.cuda.is_available() else '32'# @param ["32", "bf16-mixed", "16"]
# Project/experiment id forwarded via '-p'.
PROJECT = '2024'

# Per-variant checkpoint file plus the CLI-style argument list consumed by
# load_model_checkpoint(). args[0] is always the checkpoint name itself.
MODELS = {
    "YMT3+": {
        # NOTE(review): "[email protected]" looks like a scraped email-obfuscation
        # artifact rather than a real checkpoint filename — confirm against repo.
        "checkpoint": "[email protected]",
        "args": ["[email protected]", '-p', PROJECT, '-pr', PRECISION]
    },
    "YPTF+Single (noPS)": {
        "checkpoint": "ptf_all_cross_rebal5_mirst_xk2_edr005_attend_c_full_plus_b100@model.ckpt",
        "args": ["ptf_all_cross_rebal5_mirst_xk2_edr005_attend_c_full_plus_b100@model.ckpt", '-p', PROJECT, '-enc', 'perceiver-tf', '-ac', 'spec',
                 '-hop', '300', '-atc', '1', '-pr', PRECISION]
    },
    "YPTF+Multi (PS)": {
        "checkpoint": "mc13_256_all_cross_v6_xk5_amp0811_edr005_attend_c_full_plus_2psn_nl26_sb_b26r_800k@model.ckpt",
        # BUGFIX: '-pr' previously received PROJECT ('2024') instead of
        # PRECISION, unlike every other entry in this table.
        "args": ["mc13_256_all_cross_v6_xk5_amp0811_edr005_attend_c_full_plus_2psn_nl26_sb_b26r_800k@model.ckpt", '-p', PROJECT, '-tk', 'mc13_full_plus_256',
                 '-dec', 'multi-t5', '-nl', '26', '-enc', 'perceiver-tf', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', PRECISION]
    },
    "YPTF.MoE+Multi (noPS)": {
        "checkpoint": "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt",
        "args": ["mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt", '-p', PROJECT, '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
                 '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
                 '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
                 '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', PRECISION]
    },
    "YPTF.MoE+Multi (PS)": {
        "checkpoint": "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt",
        "args": ["mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt", '-p', PROJECT, '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
                 '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
                 '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
                 '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', PRECISION]
    }
}
# File where selected yt-dlp log lines (OAuth prompts, errors) are mirrored.
log_file = 'amt/log.txt'
# Load the chosen checkpoint on CPU, then move it to the GPU.
# NOTE(review): on ZeroGPU Spaces, CUDA is normally only usable inside
# @spaces.GPU-decorated calls — confirm this module-level .to("cuda") succeeds.
model = load_model_checkpoint(args=MODELS[MODEL_NAME]["args"], device="cpu")
model.to("cuda")
def prepare_media(source_path_or_url: os.PathLike,
                  source_type: Literal['audio_filepath', 'youtube_url'],
                  delete_video: bool = True,
                  simulate: bool = False) -> Dict:
    """Prepare audio from a local file or a YouTube URL and return its metadata.

    Args:
        source_path_or_url: Local audio path, or a YouTube URL when
            ``source_type == 'youtube_url'``.
        source_type: One of ``'audio_filepath'`` or ``'youtube_url'``.
        delete_video: Unused in this implementation; kept for interface
            compatibility with the original helper.
        simulate: If True, pass ``-s`` to yt-dlp (simulate, no download).

    Returns:
        Dict with filepath, track_name, sample_rate, bits_per_sample,
        num_channels, num_frames, duration (whole seconds), and encoding.

    Raises:
        ValueError: If ``source_type`` is not a supported literal.
    """
    # Get audio_file
    if source_type == 'audio_filepath':
        audio_file = source_path_or_url
    elif source_type == 'youtube_url':
        audio_file = './downloaded/yt_audio'
        # BUGFIX: the stale-file check previously looked at
        # '/download/yt_audio.mp3', which never matched the actual output
        # path used below, so old downloads were never removed.
        if os.path.exists(audio_file + '.mp3'):
            os.remove(audio_file + '.mp3')
        # Download from youtube, mirroring selected log lines into log_file.
        with open(log_file, 'w') as lf:
            command = ['yt-dlp', '-x', source_path_or_url, '-f', 'bestaudio',
                       '-o', audio_file, '--audio-format', 'mp3', '--restrict-filenames',
                       '--extractor-retries', '10',
                       '--force-overwrites', '--username', 'oauth2', '--password', '', '-v']
            if simulate:
                command = command + ['-s']
            process = subprocess.Popen(command,
                                       stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
            for line in iter(process.stdout.readline, ''):
                # Filter out unnecessary messages
                print(line)
                if "www.google.com/device" in line:
                    # Surface the OAuth device-login URL/code with ANSI highlight.
                    hl_text = line.replace("https://www.google.com/device", "\033[93mhttps://www.google.com/device\x1b[0m").split()
                    hl_text[-1] = "\x1b[31;1m" + hl_text[-1] + "\x1b[0m"
                    lf.write(' '.join(hl_text)); lf.flush()
                elif "Authorization successful" in line or "Video unavailable" in line:
                    lf.write(line); lf.flush()
            process.stdout.close()
            process.wait()
            audio_file += '.mp3'
    else:
        raise ValueError(source_type)
    # Create info
    info = torchaudio.info(audio_file)
    return {
        "filepath": audio_file,
        "track_name": os.path.basename(audio_file).split('.')[0],
        "sample_rate": int(info.sample_rate),
        "bits_per_sample": int(info.bits_per_sample),
        "num_channels": int(info.num_channels),
        "num_frames": int(info.num_frames),
        "duration": int(info.num_frames / info.sample_rate),
        "encoding": str.lower(info.encoding),
    }
@spaces.GPU
def handle_audio(file_path):
    """Gradio callback: copy the uploaded audio locally and transcribe it.

    Args:
        file_path: Path to the uploaded audio file (gr.Audio type="filepath").

    Returns:
        Path to the MIDI file produced by ``transcribe``.
    """
    # Guess extension from MIME; fall back to the original suffix, then ".bin".
    mime_type, _ = mimetypes.guess_type(file_path)
    # BUGFIX: guess_extension(None) raises AttributeError, so only consult it
    # when a MIME type was actually detected.
    ext = ((mimetypes.guess_extension(mime_type) if mime_type else None)
           or os.path.splitext(file_path)[1] or ".bin")
    output_path = f"received_audio{ext}"
    shutil.copy(file_path, output_path)
    audio_info = prepare_media(output_path, source_type='audio_filepath')
    midifile_path = transcribe(model, audio_info)
    return midifile_path
# Minimal Gradio UI: one audio upload in, one (MIDI) file download out.
demo = gr.Interface(
    fn=handle_audio,
    inputs=gr.Audio(type="filepath"),
    outputs=gr.File(),
)
if __name__ == "__main__":
    # 7860 is the standard port expected by Hugging Face Spaces.
    demo.launch(
        server_port=7860
    )