File size: 5,825 Bytes
e5007d2
b538a96
 
 
 
e5007d2
b538a96
40a0ff4
e5007d2
162bd64
40a0ff4
b538a96
 
 
e5007d2
 
b538a96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e5007d2
 
b538a96
162bd64
b538a96
e5007d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162bd64
b538a96
 
 
 
 
 
e5007d2
 
 
 
 
40a0ff4
b538a96
 
 
e5007d2
b538a96
40a0ff4
 
e5007d2
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'amt/src')))
import shutil
import mimetypes
import subprocess

import gradio as gr
import torchaudio
import spaces

from model_helper import load_model_checkpoint, transcribe
from prepare_media import prepare_media

from typing import Tuple, Dict, Literal

MODEL_NAME = 'YPTF.MoE+Multi (noPS)'  # @param ["YMT3+", "YPTF+Single (noPS)", "YPTF+Multi (PS)", "YPTF.MoE+Multi (noPS)", "YPTF.MoE+Multi (PS)"]
PRECISION = '16'  # @param ["32", "bf16-mixed", "16"]
PROJECT = '2024'

# Encoder/decoder flags shared by both YPTF.MoE+Multi checkpoints.
_MOE_MULTI_ARGS = (
    '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
    '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
    '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
    '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1',
)


def _model_entry(checkpoint, *model_args):
    """Build one MODELS entry: the checkpoint name plus its load_model_checkpoint args.

    Every argument list has the shape
    ``[checkpoint, '-p', PROJECT, *model_args, '-pr', PRECISION]``, so the project
    and precision flags are written in exactly one place and cannot be swapped.
    """
    return {
        "checkpoint": checkpoint,
        "args": [checkpoint, '-p', PROJECT, *model_args, '-pr', PRECISION],
    }


# Selectable checkpoints, keyed by the names listed in MODEL_NAME's @param marker.
MODELS = {
    # NOTE(review): "[email protected]" looks like e-mail obfuscation residue from a
    # web-page copy rather than a real checkpoint file name — restore the original.
    "YMT3+": _model_entry("[email protected]"),
    "YPTF+Single (noPS)": _model_entry(
        "ptf_all_cross_rebal5_mirst_xk2_edr005_attend_c_full_plus_b100@model.ckpt",
        '-enc', 'perceiver-tf', '-ac', 'spec', '-hop', '300', '-atc', '1'),
    "YPTF+Multi (PS)": _model_entry(
        "mc13_256_all_cross_v6_xk5_amp0811_edr005_attend_c_full_plus_2psn_nl26_sb_b26r_800k@model.ckpt",
        '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5', '-nl', '26',
        '-enc', 'perceiver-tf', '-ac', 'spec', '-hop', '300', '-atc', '1'),
    "YPTF.MoE+Multi (noPS)": _model_entry(
        "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt",
        *_MOE_MULTI_ARGS),
    "YPTF.MoE+Multi (PS)": _model_entry(
        "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt",
        *_MOE_MULTI_ARGS),
}

log_file = 'amt/log.txt'  # destination for the filtered yt-dlp messages written by prepare_media

# Instantiate the selected checkpoint on the CPU, then relocate it to the GPU.
model = load_model_checkpoint(
    device="cpu",
    args=MODELS[MODEL_NAME]["args"],
)
model.to("cuda")

def _download_youtube_audio(url: str, simulate: bool = False) -> str:
    """Download the audio of *url* as MP3 with yt-dlp and return the MP3 path.

    Every line yt-dlp prints is echoed to stdout. Device-login prompts (with the
    URL and the last word highlighted via ANSI colours), "Authorization successful"
    and "Video unavailable" lines are also written to ``log_file``, which is
    truncated on each call.

    Args:
        url: Video URL handed to yt-dlp.
        simulate: When True, pass ``-s`` so yt-dlp runs without downloading.

    Returns:
        The expected output path, ``./downloaded/yt_audio.mp3``.

    Raises:
        RuntimeError: If yt-dlp exits with a non-zero status.
    """
    audio_stem = './downloaded/yt_audio'
    audio_file = audio_stem + '.mp3'
    # Remove a previous download first so a failed run cannot silently reuse stale
    # audio. This must be the same path yt-dlp produces from ``-o audio_stem``.
    if os.path.exists(audio_file):
        os.remove(audio_file)

    command = ['yt-dlp', '-x', '-f', 'bestaudio',
               '-o', audio_stem, '--audio-format', 'mp3', '--restrict-filenames',
               '--extractor-retries', '10',
               '--force-overwrites', '--username', 'oauth2', '--password', '', '-v']
    if simulate:
        command.append('-s')
    # '--' ends option parsing, so a URL starting with '-' is never read as an option.
    command += ['--', url]

    with open(log_file, 'w', encoding='utf-8') as lf, subprocess.Popen(
            command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
            text=True, encoding='utf-8', errors='replace') as process:
        for line in process.stdout:
            print(line, end='')  # line already carries its own newline
            if "www.google.com/device" in line:
                # Highlight the login URL (yellow) and the line's last word (bold red).
                hl_text = line.replace(
                    "https://www.google.com/device",
                    "\033[93mhttps://www.google.com/device\x1b[0m").split()
                hl_text[-1] = "\x1b[31;1m" + hl_text[-1] + "\x1b[0m"
                # split() dropped the newline; restore it so log lines stay separate.
                lf.write(' '.join(hl_text) + '\n')
                lf.flush()
            elif "Authorization successful" in line or "Video unavailable" in line:
                lf.write(line)
                lf.flush()
    # Leaving the Popen context closes stdout and waits, so returncode is set here.
    if process.returncode != 0:
        raise RuntimeError(
            f"yt-dlp exited with status {process.returncode} for {url!r}")
    return audio_file


def _audio_info(audio_file) -> Dict:
    """Return the metadata dict for *audio_file*, read via ``torchaudio.info``."""
    info = torchaudio.info(audio_file)
    return {
        "filepath": audio_file,
        # splitext keeps dots inside the stem ("my.song.wav" -> "my.song").
        "track_name": os.path.splitext(os.path.basename(audio_file))[0],
        "sample_rate": int(info.sample_rate),
        "bits_per_sample": int(info.bits_per_sample),
        "num_channels": int(info.num_channels),
        "num_frames": int(info.num_frames),
        "duration": int(info.num_frames / info.sample_rate),  # whole seconds, truncated
        "encoding": info.encoding.lower(),
    }


def prepare_media(source_path_or_url: os.PathLike,
                  source_type: Literal['audio_filepath', 'youtube_url'],
                  delete_video: bool = True,
                  simulate: bool = False) -> Dict:
    """Resolve an audio source to a local file and return its metadata.

    Args:
        source_path_or_url: Local audio file path, or a YouTube URL.
        source_type: 'audio_filepath' uses the path as-is; 'youtube_url' first
            downloads the audio to ``./downloaded/yt_audio.mp3`` with yt-dlp.
        delete_video: Unused in this implementation; kept for interface compatibility.
        simulate: For YouTube sources, forwarded to yt-dlp as ``-s``.

    Returns:
        Dict with keys filepath, track_name, sample_rate, bits_per_sample,
        num_channels, num_frames, duration (whole seconds) and encoding (lower-case).

    Raises:
        ValueError: If *source_type* is not one of the supported values.
        RuntimeError: If the yt-dlp download fails.
    """
    if source_type == 'audio_filepath':
        audio_file = source_path_or_url
    elif source_type == 'youtube_url':
        audio_file = _download_youtube_audio(source_path_or_url, simulate=simulate)
    else:
        raise ValueError(source_type)
    return _audio_info(audio_file)
@spaces.GPU
def handle_audio(file_path):
    """Transcribe one uploaded audio file and return the resulting MIDI file path.

    The upload is copied into the working directory as ``received_audio<ext>``.
    That copy is then passed to ``prepare_media`` and ``transcribe``.

    Args:
        file_path: Path supplied by ``gr.Audio(type="filepath")``; presumably
            ``None`` when the user submits without audio — confirm against Gradio.

    Returns:
        The value returned by ``transcribe``, served to the user via ``gr.File``.

    Raises:
        gr.Error: If no audio file was provided.
    """
    if file_path is None:
        raise gr.Error("Please upload or record an audio file first.")

    mime_type, _ = mimetypes.guess_type(file_path)
    # guess_extension(None) raises AttributeError, so only consult it when the MIME
    # type is known; otherwise fall back to the upload's own suffix.
    ext = ((mimetypes.guess_extension(mime_type) if mime_type else None)
           or os.path.splitext(file_path)[1]
           or ".bin")
    output_path = f"received_audio{ext}"
    shutil.copy(file_path, output_path)

    audio_info = prepare_media(output_path, source_type='audio_filepath')
    midifile_path = transcribe(model, audio_info)
    return midifile_path

# Minimal UI: an audio file path goes in, the transcribed MIDI file comes out.
demo = gr.Interface(fn=handle_audio, inputs=gr.Audio(type="filepath"), outputs=gr.File())

if __name__ == "__main__":
    # Start the server only when executed as a script, on port 7860.
    demo.launch(server_port=7860)