JulienHalgand committed on
Commit
e5007d2
·
1 Parent(s): b538a96

Ça marche sur le CPU

Browse files
Files changed (3) hide show
  1. README.md +2 -0
  2. app.py +67 -5
  3. environment.yml +40 -0
README.md CHANGED
@@ -10,3 +10,5 @@ pinned: false
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
13
+
14
+
app.py CHANGED
@@ -1,14 +1,18 @@
 
1
  import sys
2
  sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'amt/src')))
3
- import os
4
  import shutil
5
  import mimetypes
 
6
 
7
  import gradio as gr
 
8
 
9
  from model_helper import load_model_checkpoint, transcribe
10
  from prepare_media import prepare_media
11
 
 
 
12
  MODEL_NAME = 'YPTF.MoE+Multi (noPS)' # @param ["YMT3+", "YPTF+Single (noPS)", "YPTF+Multi (PS)", "YPTF.MoE+Multi (noPS)", "YPTF.MoE+Multi (PS)"]
13
  PRECISION = '16'# if torch.cuda.is_available() else '32'# @param ["32", "bf16-mixed", "16"]
14
  PROJECT = '2024'
@@ -44,24 +48,82 @@ MODELS = {
44
  }
45
  }
46
 
 
 
47
  model = load_model_checkpoint(args=MODELS[MODEL_NAME]["args"], device="cpu")
48
  #model.to("cuda")
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  def handle_audio(file_path):
52
  # Guess extension from MIME
53
  mime_type, _ = mimetypes.guess_type(file_path)
54
  ext = mimetypes.guess_extension(mime_type) or os.path.splitext(file_path)[1] or ".bin"
55
-
56
  output_path = f"received_audio{ext}"
57
  shutil.copy(file_path, output_path)
58
- return output_path
 
 
 
 
59
 
60
  demo = gr.Interface(
61
  fn=handle_audio,
62
  inputs=gr.Audio(type="filepath"),
63
- outputs=gr.File()
64
  )
65
 
66
  if __name__ == "__main__":
67
- demo.launch()
 
 
 
1
+ import os
2
  import sys
3
  sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'amt/src')))
 
4
  import shutil
5
  import mimetypes
6
+ import subprocess
7
 
8
  import gradio as gr
9
+ import torchaudio
10
 
11
  from model_helper import load_model_checkpoint, transcribe
12
  from prepare_media import prepare_media
13
 
14
+ from typing import Tuple, Dict, Literal
15
+
16
  MODEL_NAME = 'YPTF.MoE+Multi (noPS)' # @param ["YMT3+", "YPTF+Single (noPS)", "YPTF+Multi (PS)", "YPTF.MoE+Multi (noPS)", "YPTF.MoE+Multi (PS)"]
17
  PRECISION = '16'# if torch.cuda.is_available() else '32'# @param ["32", "bf16-mixed", "16"]
18
  PROJECT = '2024'
 
48
  }
49
  }
50
 
51
+ log_file = 'amt/log.txt'
52
+
53
  model = load_model_checkpoint(args=MODELS[MODEL_NAME]["args"], device="cpu")
54
  #model.to("cuda")
55
 
56
def prepare_media(source_path_or_url: os.PathLike,
                  source_type: Literal['audio_filepath', 'youtube_url'],
                  delete_video: bool = True,
                  simulate = False) -> Dict:
    """Resolve an audio source to a local file and return its audio info.

    Args:
        source_path_or_url: Path of a local audio file, or a YouTube URL,
            depending on ``source_type``.
        source_type: ``'audio_filepath'`` uses the given path as-is;
            ``'youtube_url'`` fetches the audio with ``yt-dlp`` first.
        delete_video: Kept for interface compatibility; it is not referenced
            in this function.
        simulate: When true, ``-s`` is appended to the ``yt-dlp`` command
            (presumably yt-dlp's simulate mode -- confirm against yt-dlp docs).

    Returns:
        A dict with the keys ``filepath``, ``track_name``, ``sample_rate``,
        ``bits_per_sample``, ``num_channels``, ``num_frames``, ``duration``
        (truncated to whole seconds) and ``encoding`` (lower-cased), filled
        from ``torchaudio.info``.

    Raises:
        ValueError: If ``source_type`` is neither of the supported values.
    """
    # Get audio_file
    if source_type == 'audio_filepath':
        audio_file = source_path_or_url
    elif source_type == 'youtube_url':
        audio_file = _download_youtube_audio(source_path_or_url, simulate)
    else:
        raise ValueError(source_type)

    # Create info
    info = torchaudio.info(audio_file)
    return {
        "filepath": audio_file,
        # Everything before the first '.' of the file name.
        "track_name": os.path.basename(audio_file).split('.')[0],
        "sample_rate": int(info.sample_rate),
        "bits_per_sample": int(info.bits_per_sample),
        "num_channels": int(info.num_channels),
        "num_frames": int(info.num_frames),
        "duration": int(info.num_frames / info.sample_rate),
        "encoding": str.lower(info.encoding),
    }


def _download_youtube_audio(url, simulate=False):
    """Run yt-dlp on *url*, mirror key status lines to ``log_file``, return the MP3 path."""
    output_stem = './downloaded/yt_audio'
    mp3_path = output_stem + '.mp3'

    # Remove a stale result from a previous run. The check has to target the
    # same path that is returned below; probing '/download/yt_audio.mp3'
    # (a different directory) never matched the downloaded file.
    if os.path.exists(mp3_path):
        os.remove(mp3_path)

    command = ['yt-dlp', '-x', url, '-f', 'bestaudio',
               '-o', output_stem, '--audio-format', 'mp3', '--restrict-filenames',
               '--extractor-retries', '10',
               '--force-overwrites', '--username', 'oauth2', '--password', '', '-v']
    if simulate:
        command = command + ['-s']

    # Popen as a context manager closes the pipe and waits for the child even
    # if writing to the log raises part-way through.
    with open(log_file, 'w') as lf, subprocess.Popen(
            command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
            text=True) as process:
        for line in iter(process.stdout.readline, ''):
            print(line)
            # Only a few messages are copied to the log file.
            if "www.google.com/device" in line:
                # Highlight the device-login URL and the last token of the
                # line with ANSI colour codes.
                hl_text = line.replace("https://www.google.com/device", "\033[93mhttps://www.google.com/device\x1b[0m").split()
                hl_text[-1] = "\x1b[31;1m" + hl_text[-1] + "\x1b[0m"
                lf.write(' '.join(hl_text))
                lf.flush()
            elif "Authorization successful" in line or "Video unavailable" in line:
                lf.write(line)
                lf.flush()

    return mp3_path
107
 
108
  def handle_audio(file_path):
109
  # Guess extension from MIME
110
  mime_type, _ = mimetypes.guess_type(file_path)
111
  ext = mimetypes.guess_extension(mime_type) or os.path.splitext(file_path)[1] or ".bin"
 
112
  output_path = f"received_audio{ext}"
113
  shutil.copy(file_path, output_path)
114
+
115
+ audio_info = prepare_media(output_path, source_type='audio_filepath')
116
+ midifile_path = transcribe(model, audio_info)
117
+
118
+ return midifile_path
119
 
120
  demo = gr.Interface(
121
  fn=handle_audio,
122
  inputs=gr.Audio(type="filepath"),
123
+ outputs=gr.File(),
124
  )
125
 
126
  if __name__ == "__main__":
127
+ demo.launch(
128
+ server_port=7860
129
+ )
environment.yml ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: gradio
2
+ channels:
3
+ - conda-forge
4
+ dependencies:
5
+ - python=3.12
6
+ - lightning>=2.2.1
7
+ - deprecated
8
+ - librosa
9
+ - einops
10
+ - transformers==4.45.1
11
+ - numpy==1.26.4
12
+ - wandb
13
+ - annotated-types==0.7.0
14
+ - anyio==4.9.0
15
+ - blinker==1.9.0
16
+ - certifi==2025.4.26
17
+ - click==8.1.8
18
+ - fastapi==0.115.12
19
+ - Flask==3.1.0
20
+ - h11==0.16.0
21
+ - idna==3.10
22
+ - importlib_metadata==8.6.1
23
+ - itsdangerous==2.2.0
24
+ - Jinja2==3.1.6
25
+ - MarkupSafe==2.1.5
26
+ - pip==25.1.1
27
+ - pydantic==2.11.4
28
+ - pydantic_core==2.33.2
29
+ - python-multipart==0.0.20
30
+ - setuptools==80.1.0
31
+ - sniffio==1.3.1
32
+ - starlette==0.46.2
33
+ - typing_extensions==4.13.2
34
+ - typing-inspection==0.4.0
35
+ - uvicorn==0.34.2
36
+ - Werkzeug==3.1.3
37
+ - wheel==0.45.1
38
+ - zipp==3.21.0
39
+ pip:
40
+ - --extra-index-url https://download.pytorch.org/whl/cu113