DeepLearning101 commited on
Commit
a12f4d9
·
verified ·
1 Parent(s): a23d52e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -50
app.py CHANGED
@@ -1,60 +1,137 @@
 
 
 
 
 
 
 
1
  import os
2
  import time
3
- import json
4
- import gradio as gr
5
- import torch
6
- import torchaudio
7
- import numpy as np
8
- from denoiser.demucs import Demucs
9
- from pydub import AudioSegment
10
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
11
-
12
- # 设置 Hugging Face Hub 的 Access Token
13
- auth_token = os.getenv("HF_TOKEN")
14
-
15
- # 加载私有模型
16
- model_id = "DeepLearning101/Speech-Quality-Inspection_Meta-Denoiser"
17
- model = AutoModelForSequenceClassification.from_pretrained(model_id, token=auth_token)
18
- tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token)
19
-
20
- def transcribe(file_upload, microphone):
21
- file = microphone if microphone is not None else file_upload
22
- demucs_model = Demucs(hidden=64)
23
- state_dict = torch.load("path_to_model_checkpoint", map_location='cpu') # 请确保提供正确的模型文件路径
24
- demucs_model.load_state_dict(state_dict)
25
- x, sr = torchaudio.load(file)
26
- out = demucs_model(x[None])[0]
27
- out = out / max(out.abs().max().item(), 1)
28
- torchaudio.save('enhanced.wav', out, sr)
29
- enhanced = AudioSegment.from_wav('enhanced.wav') # 只有去完噪的需要降bitrate再做语音识别
30
- enhanced.export('enhanced.wav', format="wav", bitrate="256k")
31
-
32
- # 假设模型是用于文本分类
33
- inputs = tokenizer("enhanced.wav", return_tensors="pt")
34
- outputs = model(**inputs)
35
- predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
36
-
37
- return "enhanced.wav", predictions
38
-
39
- demo = gr.Interface(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  fn=transcribe,
41
  inputs=[
42
- gr.Audio(type="filepath", label="语音质检麦克风实时录音"),
43
- gr.Audio(type="filepath", label="语音质检原始音档"),
44
  ],
45
- outputs=[
46
- gr.Audio(type="filepath", label="Output"),
47
- gr.Textbox(label="Model Predictions")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  ],
49
- title="<p style='text-align: center'><a href='https://www.twman.org/AI' target='_blank'>语音质检噪音去除 (语音增强):Meta Denoiser</a>",
50
- description="为了提升语音识别的效果,可以在识别前先进行噪音去除",
 
 
 
 
 
 
 
51
  allow_flagging="never",
52
- examples=[
53
- ["exampleAudio/15s_2020-03-27_sep1.wav"],
54
- ["exampleAudio/13s_2020-03-27_sep2.wav"],
55
- ["exampleAudio/30s_2020-04-23_sep1.wav"],
56
- ["exampleAudio/15s_2020-04-23_sep2.wav"],
 
 
57
  ],
 
 
 
 
 
 
 
 
 
 
58
  )
59
 
60
- demo.launch(debug=True)
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ import yt_dlp as youtube_dl
4
+ from transformers import pipeline
5
+ from transformers.pipelines.audio_utils import ffmpeg_read
6
+
7
+ import tempfile
8
  import os
9
  import time
10
+
11
+ MODEL_NAME = "openai/whisper-large-v3"
12
+ BATCH_SIZE = 8
13
+ FILE_LIMIT_MB = 1000
14
+ YT_LENGTH_LIMIT_S = 3600 # limit to 1 hour YouTube files
15
+
16
+ device = 0 if torch.cuda.is_available() else "cpu"
17
+
18
+ pipe = pipeline(
19
+ task="automatic-speech-recognition",
20
+ model=MODEL_NAME,
21
+ chunk_length_s=30,
22
+ device=device,
23
+ )
24
+
25
+ def transcribe(inputs, task):
26
+ if inputs is None:
27
+ raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
28
+
29
+ text = pipe(inputs, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)["text"]
30
+ return text
31
+
32
+ def _return_yt_html_embed(yt_url):
33
+ video_id = yt_url.split("?v=")[-1]
34
+ HTML_str = (
35
+ f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
36
+ " </center>"
37
+ )
38
+ return HTML_str
39
+
40
+ def download_yt_audio(yt_url, filename):
41
+ info_loader = youtube_dl.YoutubeDL()
42
+
43
+ try:
44
+ info = info_loader.extract_info(yt_url, download=False)
45
+ except youtube_dl.utils.DownloadError as err:
46
+ raise gr.Error(str(err))
47
+
48
+ file_length = info["duration"]
49
+ file_length_s = int(file_length)
50
+
51
+ if file_length_s > YT_LENGTH_LIMIT_S:
52
+ yt_length_limit_hms = time.strftime("%H:%M:%S", time.gmtime(YT_LENGTH_LIMIT_S))
53
+ file_length_hms = time.strftime("%H:%M:%S", time.gmtime(file_length_s))
54
+ raise gr.Error(f"Maximum YouTube length is {yt_length_limit_hms}, got {file_length_hms} YouTube video.")
55
+
56
+ ydl_opts = {"outtmpl": filename, "format": "bestaudio/best"}
57
+
58
+ with youtube_dl.YoutubeDL(ydl_opts) as ydl:
59
+ try:
60
+ ydl.download([yt_url])
61
+ except youtube_dl.utils.ExtractorError as err:
62
+ raise gr.Error(str(err))
63
+
64
+ def yt_transcribe(yt_url, task, max_filesize=75.0):
65
+ html_embed_str = _return_yt_html_embed(yt_url)
66
+
67
+ with tempfile.TemporaryDirectory() as tmpdirname:
68
+ filepath = os.path.join(tmpdirname, "audio.m4a")
69
+ download_yt_audio(yt_url, filepath)
70
+ with open(filepath, "rb") as f:
71
+ inputs = f.read()
72
+
73
+ inputs = ffmpeg_read(inputs, pipe.feature_extractor.sampling_rate)
74
+ inputs = {"array": inputs, "sampling_rate": pipe.feature_extractor.sampling_rate}
75
+
76
+ text = pipe(inputs, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)["text"]
77
+
78
+ return html_embed_str, text
79
+
80
+ mf_transcribe = gr.Interface(
81
  fn=transcribe,
82
  inputs=[
83
+ gr.Audio(type="filepath"),
84
+ gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
85
  ],
86
+ outputs="text",
87
+ layout="horizontal",
88
+ theme="huggingface",
89
+ title="Whisper Large V3: Transcribe Audio",
90
+ description=(
91
+ "Transcribe long-form microphone or audio inputs with the click of a button! Demo uses the"
92
+ f" checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers to transcribe audio files"
93
+ " of arbitrary length."
94
+ ),
95
+ allow_flagging="never",
96
+ )
97
+
98
+ file_transcribe = gr.Interface(
99
+ fn=transcribe,
100
+ inputs=[
101
+ gr.Audio(type="filepath", label="Audio file"),
102
+ gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
103
  ],
104
+ outputs="text",
105
+ layout="horizontal",
106
+ theme="huggingface",
107
+ title="Whisper Large V3: Transcribe Audio",
108
+ description=(
109
+ "Transcribe long-form microphone or audio inputs with the click of a button! Demo uses the"
110
+ f" checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers to transcribe audio files"
111
+ " of arbitrary length."
112
+ ),
113
  allow_flagging="never",
114
+ )
115
+
116
+ yt_transcribe = gr.Interface(
117
+ fn=yt_transcribe,
118
+ inputs=[
119
+ gr.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL"),
120
+ gr.Radio(["transcribe", "translate"], label="Task", value="transcribe")
121
  ],
122
+ outputs=["html", "text"],
123
+ layout="horizontal",
124
+ theme="huggingface",
125
+ title="Whisper Large V3: Transcribe YouTube",
126
+ description=(
127
+ "Transcribe long-form YouTube videos with the click of a button! Demo uses the checkpoint"
128
+ f" [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers to transcribe video files of"
129
+ " arbitrary length."
130
+ ),
131
+ allow_flagging="never",
132
  )
133
 
134
+ with gr.Blocks() as demo:
135
+ gr.TabbedInterface([mf_transcribe, file_transcribe, yt_transcribe], ["Microphone", "Audio file", "YouTube"])
136
+
137
+ demo.launch(enable_queue=True)