# Spaces: Running on Zero
import mimetypes
import os
import shutil
import sys

import gradio as gr

# model_helper / prepare_media live under ./amt/src, so that directory must be
# on sys.path before they are imported (and after `os` is imported, since the
# path computation below needs it).
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'amt/src')))

from model_helper import load_model_checkpoint, transcribe  # noqa: E402
from prepare_media import prepare_media  # noqa: E402
# Model selection. The trailing "# @param" comments appear to be Colab-style
# form annotations listing the accepted choices.
MODEL_NAME = 'YPTF.MoE+Multi (noPS)'  # @param ["YMT3+", "YPTF+Single (noPS)", "YPTF+Multi (PS)", "YPTF.MoE+Multi (noPS)", "YPTF.MoE+Multi (PS)"]
# Precision forwarded to every model via '-pr'. A commented-out variant chose
# '16' only when CUDA was available and '32' otherwise; here it is pinned.
# NOTE(review): the model is loaded with device="cpu" in this file; confirm
# half precision is supported there.
PRECISION = '16'  # @param ["32", "bf16-mixed", "16"]
PROJECT = '2024'


def _model_spec(checkpoint, *extra_args):
    """Build one MODELS entry for *checkpoint*.

    The argument list is always
    ``[checkpoint, '-p', PROJECT, *extra_args, '-pr', PRECISION]``, so the
    checkpoint name is written only once and every model receives the same
    project and precision flags.

    Args:
        checkpoint: Checkpoint identifier, also used as the first CLI arg.
        *extra_args: Model-specific CLI flags, inserted between the project
            and precision flags in the given order.

    Returns:
        A dict with keys ``"checkpoint"`` and ``"args"``.
    """
    return {
        "checkpoint": checkpoint,
        "args": [checkpoint, '-p', PROJECT, *extra_args, '-pr', PRECISION],
    }


# Architecture flags shared verbatim by both YPTF.MoE+Multi variants.
_MOE_ARGS = (
    '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
    '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
    '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
    '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1',
)

MODELS = {
    "YMT3+": _model_spec(
        # NOTE(review): this literal looks mangled by e-mail obfuscation
        # ("[email protected]" is what such filters substitute for text of
        # the form "name@domain"). Restore the real checkpoint filename. It is
        # possibly "notask_all_cross_v6_xk2_amp0811_gm_ext_plus_nops_b72@model.ckpt";
        # verify against the upstream YourMT3 release.
        "[email protected]",
    ),
    "YPTF+Single (noPS)": _model_spec(
        "ptf_all_cross_rebal5_mirst_xk2_edr005_attend_c_full_plus_b100@model.ckpt",
        '-enc', 'perceiver-tf', '-ac', 'spec', '-hop', '300', '-atc', '1',
    ),
    "YPTF+Multi (PS)": _model_spec(
        "mc13_256_all_cross_v6_xk5_amp0811_edr005_attend_c_full_plus_2psn_nl26_sb_b26r_800k@model.ckpt",
        '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5', '-nl', '26',
        '-enc', 'perceiver-tf', '-ac', 'spec', '-hop', '300', '-atc', '1',
    ),
    "YPTF.MoE+Multi (noPS)": _model_spec(
        "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt",
        *_MOE_ARGS,
    ),
    "YPTF.MoE+Multi (PS)": _model_spec(
        "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt",
        *_MOE_ARGS,
    ),
}
# Load the selected checkpoint once, at import time, onto the CPU. Moving it
# to a GPU (model.to("cuda")) is intentionally not done here.
# NOTE(review): nothing in the UI below uses `model` (nor the imported
# `transcribe` / `prepare_media`); confirm whether transcription is meant to
# be wired into handle_audio.
model = load_model_checkpoint(
    device="cpu",
    args=MODELS[MODEL_NAME]["args"],
)
def handle_audio(file_path):
    """Copy an audio file into the current directory and return the copy's path.

    The copy is named ``received_audio`` plus an extension chosen in order of
    preference:

    1. the canonical extension for the MIME type guessed from *file_path*;
    2. *file_path*'s own extension;
    3. ``.bin``.

    Args:
        file_path: Path of the audio file, or ``None`` when no audio was
            supplied (presumably what ``gr.Audio(type="filepath")`` sends for
            an empty input; confirm against the Gradio version in use).

    Returns:
        The relative path of the copy, or ``None`` if *file_path* is ``None``.
    """
    if file_path is None:
        # Nothing to copy; mimetypes.guess_type(None) would raise TypeError.
        return None

    mime_type, _ = mimetypes.guess_type(file_path)
    # guess_type() returns None for unrecognised names, and guess_extension()
    # raises on None, so only consult it when a type was actually guessed.
    # Without this guard the splitext/".bin" fallbacks were unreachable.
    guessed_ext = mimetypes.guess_extension(mime_type) if mime_type else None
    ext = guessed_ext or os.path.splitext(file_path)[1] or ".bin"

    # NOTE(review): every call writes to the same fixed name in the working
    # directory, so a later request overwrites an earlier result. Consider a
    # per-request unique name if several users share this app.
    output_path = f"received_audio{ext}"
    shutil.copy(file_path, output_path)
    return output_path
# Web UI: a single audio input, handed to the callback as a file path, and a
# single file output carrying whatever path the callback returns.
demo = gr.Interface(handle_audio, gr.Audio(type="filepath"), gr.File())

if __name__ == "__main__":
    # Only start the server when run as a script, not when imported.
    demo.launch()