Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
e5007d2
1
Parent(s):
b538a96
Ça marche sur le CPU
Browse files- README.md +2 -0
- app.py +67 -5
- environment.yml +40 -0
README.md
CHANGED
@@ -10,3 +10,5 @@ pinned: false
|
|
10 |
---
|
11 |
|
12 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
|
|
10 |
---
|
11 |
|
12 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
13 |
+
|
14 |
+
|
app.py
CHANGED
@@ -1,14 +1,18 @@
|
|
|
|
1 |
import sys
|
2 |
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'amt/src')))
|
3 |
-
import os
|
4 |
import shutil
|
5 |
import mimetypes
|
|
|
6 |
|
7 |
import gradio as gr
|
|
|
8 |
|
9 |
from model_helper import load_model_checkpoint, transcribe
|
10 |
from prepare_media import prepare_media
|
11 |
|
|
|
|
|
12 |
MODEL_NAME = 'YPTF.MoE+Multi (noPS)' # @param ["YMT3+", "YPTF+Single (noPS)", "YPTF+Multi (PS)", "YPTF.MoE+Multi (noPS)", "YPTF.MoE+Multi (PS)"]
|
13 |
PRECISION = '16'# if torch.cuda.is_available() else '32'# @param ["32", "bf16-mixed", "16"]
|
14 |
PROJECT = '2024'
|
@@ -44,24 +48,82 @@ MODELS = {
|
|
44 |
}
|
45 |
}
|
46 |
|
|
|
|
|
47 |
model = load_model_checkpoint(args=MODELS[MODEL_NAME]["args"], device="cpu")
|
48 |
#model.to("cuda")
|
49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
def handle_audio(file_path):
|
52 |
# Guess extension from MIME
|
53 |
mime_type, _ = mimetypes.guess_type(file_path)
|
54 |
ext = mimetypes.guess_extension(mime_type) or os.path.splitext(file_path)[1] or ".bin"
|
55 |
-
|
56 |
output_path = f"received_audio{ext}"
|
57 |
shutil.copy(file_path, output_path)
|
58 |
-
|
|
|
|
|
|
|
|
|
59 |
|
60 |
demo = gr.Interface(
|
61 |
fn=handle_audio,
|
62 |
inputs=gr.Audio(type="filepath"),
|
63 |
-
outputs=gr.File()
|
64 |
)
|
65 |
|
66 |
if __name__ == "__main__":
|
67 |
-
demo.launch(
|
|
|
|
|
|
1 |
+
import os
|
2 |
import sys
|
3 |
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'amt/src')))
|
|
|
4 |
import shutil
|
5 |
import mimetypes
|
6 |
+
import subprocess
|
7 |
|
8 |
import gradio as gr
|
9 |
+
import torchaudio
|
10 |
|
11 |
from model_helper import load_model_checkpoint, transcribe
|
12 |
from prepare_media import prepare_media
|
13 |
|
14 |
+
from typing import Tuple, Dict, Literal
|
15 |
+
|
16 |
# Which checkpoint from MODELS to load (Colab-style @param options preserved).
MODEL_NAME = 'YPTF.MoE+Multi (noPS)' # @param ["YMT3+", "YPTF+Single (noPS)", "YPTF+Multi (PS)", "YPTF.MoE+Multi (noPS)", "YPTF.MoE+Multi (PS)"]
# Inference precision; the GPU-conditional choice is left commented out since this Space runs on CPU.
PRECISION = '16'# if torch.cuda.is_available() else '32'# @param ["32", "bf16-mixed", "16"]
PROJECT = '2024'
|
|
|
48 |
}
|
49 |
}
|
50 |
|
51 |
+
# File that mirrors yt-dlp's auth/progress output (written by prepare_media below).
log_file = 'amt/log.txt'

# Load the selected checkpoint on CPU; the CUDA move is deliberately disabled
# (this commit makes the Space run on CPU).
model = load_model_checkpoint(args=MODELS[MODEL_NAME]["args"], device="cpu")
#model.to("cuda")
|
55 |
|
56 |
+
def prepare_media(source_path_or_url: os.PathLike,
                  source_type: Literal['audio_filepath', 'youtube_url'],
                  delete_video: bool = True,
                  simulate=False) -> Dict:
    """Prepare media from a source path or YouTube URL and return audio info.

    Args:
        source_path_or_url: Local audio file path, or a YouTube URL.
        source_type: 'audio_filepath' to use the path as-is,
            'youtube_url' to download the audio track with yt-dlp.
        delete_video: Unused in this implementation; kept for interface
            compatibility with callers.
        simulate: If True, pass '-s' to yt-dlp (simulate, no download).

    Returns:
        Dict with the audio file path and metadata read via torchaudio:
        filepath, track_name, sample_rate, bits_per_sample, num_channels,
        num_frames, duration (whole seconds), encoding.

    Raises:
        ValueError: If source_type is not one of the two supported values.

    NOTE(review): this module-level definition shadows the `prepare_media`
    imported from the `prepare_media` module above — presumably an
    intentional local override; confirm whether that import is still needed.
    """
    # Get audio_file
    if source_type == 'audio_filepath':
        audio_file = source_path_or_url
    elif source_type == 'youtube_url':
        audio_file = './downloaded/yt_audio'
        # BUG FIX: the stale-file cleanup previously removed
        # '/download/yt_audio.mp3', a path this function never writes to.
        # Remove the file yt-dlp actually produces so every run starts clean.
        if os.path.exists(audio_file + '.mp3'):
            os.remove(audio_file + '.mp3')
        # Download from youtube, mirroring yt-dlp output into log_file.
        with open(log_file, 'w') as lf:
            command = ['yt-dlp', '-x', source_path_or_url, '-f', 'bestaudio',
                       '-o', audio_file, '--audio-format', 'mp3', '--restrict-filenames',
                       '--extractor-retries', '10',
                       '--force-overwrites', '--username', 'oauth2', '--password', '', '-v']
            if simulate:
                command = command + ['-s']
            process = subprocess.Popen(command,
                                       stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)

            for line in iter(process.stdout.readline, ''):
                # Echo everything; persist only the OAuth device prompt and
                # the final authorization/availability status to the log file.
                print(line)
                if "www.google.com/device" in line:
                    # Highlight the device-auth URL (yellow) and code (red) with ANSI codes.
                    hl_text = line.replace("https://www.google.com/device",
                                           "\033[93mhttps://www.google.com/device\x1b[0m").split()
                    hl_text[-1] = "\x1b[31;1m" + hl_text[-1] + "\x1b[0m"
                    lf.write(' '.join(hl_text)); lf.flush()
                elif "Authorization successful" in line or "Video unavailable" in line:
                    lf.write(line); lf.flush()
            process.stdout.close()
            process.wait()

        audio_file += '.mp3'
    else:
        raise ValueError(source_type)

    # Create info
    info = torchaudio.info(audio_file)
    return {
        "filepath": audio_file,
        "track_name": os.path.basename(audio_file).split('.')[0],
        "sample_rate": int(info.sample_rate),
        "bits_per_sample": int(info.bits_per_sample),
        "num_channels": int(info.num_channels),
        "num_frames": int(info.num_frames),
        "duration": int(info.num_frames / info.sample_rate),
        "encoding": str.lower(info.encoding),
    }
|
107 |
|
108 |
def handle_audio(file_path):
|
109 |
# Guess extension from MIME
|
110 |
mime_type, _ = mimetypes.guess_type(file_path)
|
111 |
ext = mimetypes.guess_extension(mime_type) or os.path.splitext(file_path)[1] or ".bin"
|
|
|
112 |
output_path = f"received_audio{ext}"
|
113 |
shutil.copy(file_path, output_path)
|
114 |
+
|
115 |
+
audio_info = prepare_media(output_path, source_type='audio_filepath')
|
116 |
+
midifile_path = transcribe(model, audio_info)
|
117 |
+
|
118 |
+
return midifile_path
|
119 |
|
120 |
# Minimal Gradio UI: one audio upload in, one (MIDI) file out.
demo = gr.Interface(
    fn=handle_audio,
    inputs=gr.Audio(type="filepath"),
    outputs=gr.File(),
)

if __name__ == "__main__":
    # 7860 is the port Hugging Face Spaces expects the app to listen on.
    demo.launch(
        server_port=7860
    )
|
environment.yml
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: gradio
|
2 |
+
channels:
|
3 |
+
- conda-forge
|
4 |
+
dependencies:
|
5 |
+
- python=3.12
|
6 |
+
- lightning>=2.2.1
|
7 |
+
- deprecated
|
8 |
+
- librosa
|
9 |
+
- einops
|
10 |
+
- transformers==4.45.1
|
11 |
+
- numpy==1.26.4
|
12 |
+
- wandb
|
13 |
+
- annotated-types==0.7.0
|
14 |
+
- anyio==4.9.0
|
15 |
+
- blinker==1.9.0
|
16 |
+
- certifi==2025.4.26
|
17 |
+
- click==8.1.8
|
18 |
+
- fastapi==0.115.12
|
19 |
+
- Flask==3.1.0
|
20 |
+
- h11==0.16.0
|
21 |
+
- idna==3.10
|
22 |
+
- importlib_metadata==8.6.1
|
23 |
+
- itsdangerous==2.2.0
|
24 |
+
- Jinja2==3.1.6
|
25 |
+
- MarkupSafe==2.1.5
|
26 |
+
- pip==25.1.1
|
27 |
+
- pydantic==2.11.4
|
28 |
+
- pydantic_core==2.33.2
|
29 |
+
- python-multipart==0.0.20
|
30 |
+
- setuptools==80.1.0
|
31 |
+
- sniffio==1.3.1
|
32 |
+
- starlette==0.46.2
|
33 |
+
- typing_extensions==4.13.2
|
34 |
+
- typing-inspection==0.4.0
|
35 |
+
- uvicorn==0.34.2
|
36 |
+
- Werkzeug==3.1.3
|
37 |
+
- wheel==0.45.1
|
38 |
+
- zipp==3.21.0
|
39 |
+
- pip:
|
40 |
+
- --extra-index-url https://download.pytorch.org/whl/cu113
|