import os
import sys
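# Make the bundled amt/src package importable before its helper modules are imported below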
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'amt/src')))
import shutil
import mimetypes
import subprocess
import gradio as gr
import torchaudio
import spaces
from model_helper import load_model_checkpoint, transcribe
from prepare_media import prepare_media
from typing import Tuple, Dict, Literal
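# Model and precision selection (the trailing "@param" comments are Colab-style choice lists)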
MODEL_NAME = 'YPTF.MoE+Multi (noPS)' # @param ["YMT3+", "YPTF+Single (noPS)", "YPTF+Multi (PS)", "YPTF.MoE+Multi (noPS)", "YPTF.MoE+Multi (PS)"]
PRECISION = '16'  # if torch.cuda.is_available() else '32'  # @param ["32", "bf16-mixed", "16"]
PROJECT = '2024'
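# Map each model name to its checkpoint file and the CLI-style args passed to load_model_checkpoint()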
MODELS = {
    "YMT3+": {
        "checkpoint": "[email protected]",
        "args": ["[email protected]", '-p', PROJECT, '-pr', PRECISION]
    },
    "YPTF+Single (noPS)": {
        "checkpoint": "ptf_all_cross_rebal5_mirst_xk2_edr005_attend_c_full_plus_b100@model.ckpt",
        "args": ["ptf_all_cross_rebal5_mirst_xk2_edr005_attend_c_full_plus_b100@model.ckpt", '-p', PROJECT, '-enc', 'perceiver-tf', '-ac', 'spec',
                 '-hop', '300', '-atc', '1', '-pr', PRECISION]
    },
    "YPTF+Multi (PS)": {
        "checkpoint": "mc13_256_all_cross_v6_xk5_amp0811_edr005_attend_c_full_plus_2psn_nl26_sb_b26r_800k@model.ckpt",
        "args": ["mc13_256_all_cross_v6_xk5_amp0811_edr005_attend_c_full_plus_2psn_nl26_sb_b26r_800k@model.ckpt", '-p', PROJECT, '-tk', 'mc13_full_plus_256',
                 '-dec', 'multi-t5', '-nl', '26', '-enc', 'perceiver-tf', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', PRECISION]
    },
    "YPTF.MoE+Multi (noPS)": {
        "checkpoint": "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt",
        "args": ["mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt", '-p', PROJECT, '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
                 '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
                 '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
                 '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', PRECISION]
    },
    "YPTF.MoE+Multi (PS)": {
        "checkpoint": "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt",
        "args": ["mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt", '-p', PROJECT, '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
                 '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
                 '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
                 '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', PRECISION]
    }
}
log_file = 'amt/log.txt'
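# Load the selected checkpoint on CPU, then move the model to the GPU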
model = load_model_checkpoint(args=MODELS[MODEL_NAME]["args"], device="cpu")
model.to("cuda")
def prepare_media(source_path_or_url: os.PathLike,
                  source_type: Literal['audio_filepath', 'youtube_url'],
                  delete_video: bool = True,
                  simulate: bool = False) -> Dict:
    """Prepare media from a local file or a YouTube URL, and return its audio info."""
    # Get audio_file
    if source_type == 'audio_filepath':
        audio_file = source_path_or_url
    elif source_type == 'youtube_url':
        if os.path.exists('./downloaded/yt_audio.mp3'):
            os.remove('./downloaded/yt_audio.mp3')
        # Download from youtube
        with open(log_file, 'w') as lf:
            audio_file = './downloaded/yt_audio'
            command = ['yt-dlp', '-x', source_path_or_url, '-f', 'bestaudio',
                       '-o', audio_file, '--audio-format', 'mp3', '--restrict-filenames',
                       '--extractor-retries', '10',
                       '--force-overwrites', '--username', 'oauth2', '--password', '', '-v']
            if simulate:
                command = command + ['-s']
            process = subprocess.Popen(command,
                                       stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
            for line in iter(process.stdout.readline, ''):
                # Echo yt-dlp output; only the OAuth device prompt and key status lines go to the log file
                print(line)
                if "www.google.com/device" in line:
                    hl_text = line.replace("https://www.google.com/device", "\033[93mhttps://www.google.com/device\x1b[0m").split()
                    hl_text[-1] = "\x1b[31;1m" + hl_text[-1] + "\x1b[0m"
                    lf.write(' '.join(hl_text)); lf.flush()
                elif "Authorization successful" in line or "Video unavailable" in line:
                    lf.write(line); lf.flush()
            process.stdout.close()
            process.wait()
        audio_file += '.mp3'
    else:
        raise ValueError(source_type)
    # Create info
    info = torchaudio.info(audio_file)
    return {
        "filepath": audio_file,
        "track_name": os.path.basename(audio_file).split('.')[0],
        "sample_rate": int(info.sample_rate),
        "bits_per_sample": int(info.bits_per_sample),
        "num_channels": int(info.num_channels),
        "num_frames": int(info.num_frames),
        "duration": int(info.num_frames / info.sample_rate),
        "encoding": str.lower(info.encoding),
    }
@spaces.GPU
def handle_audio(file_path):
    # Guess extension from MIME type, falling back to the original extension, then ".bin"
    mime_type, _ = mimetypes.guess_type(file_path)
    ext = (mimetypes.guess_extension(mime_type) if mime_type else None) or os.path.splitext(file_path)[1] or ".bin"
    output_path = f"received_audio{ext}"
    shutil.copy(file_path, output_path)
    audio_info = prepare_media(output_path, source_type='audio_filepath')
    midifile_path = transcribe(model, audio_info)
    return midifile_path
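# Minimal Gradio interface: an uploaded audio file in, the transcribed MIDI file out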
demo = gr.Interface(
    fn=handle_audio,
    inputs=gr.Audio(type="filepath"),
    outputs=gr.File(),
)
if __name__ == "__main__":
    demo.launch(
        server_port=7860
    )