Commit cd681db · committed by sachin · Parent(s): 62cf342
test-multiple mode;

Files changed: src/asr_api.py (+70 -139)

src/asr_api.py CHANGED
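Summary of the change, for orientation: the old single-model design (one asr_manager.model plus a default_language that was swapped whenever a request arrived in another language) is replaced by an ASRModelManager that preloads a configurable list of language models at startup and serves them from an in-memory dictionary via get_model(). Languages that were not preloaded fall back to loading and caching on first request. Both endpoints also gain explicit temporary-file cleanup in their finally blocks, and the batch endpoint now builds its timing and response inside the handler.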
@@ -1,8 +1,7 @@
 import torch
 import nemo.collections.asr as nemo_asr
 from fastapi import FastAPI, File, UploadFile, HTTPException, Query
-from fastapi.responses import RedirectResponse
-from fastapi.responses import JSONResponse
+from fastapi.responses import RedirectResponse, JSONResponse
 from pydantic import BaseModel
 from pydub import AudioSegment
 import os
@@ -18,44 +17,25 @@ import argparse
 import uvicorn
 import shutil
 
-
 # Configure logging with log rotation
 logging.basicConfig(
     level=logging.INFO,
     format='%(asctime)s - %(levelname)s - %(message)s',
     handlers=[
-        RotatingFileHandler("transcription_api.log", maxBytes=10*1024*1024, backupCount=5),
-        logging.StreamHandler()
+        RotatingFileHandler("transcription_api.log", maxBytes=10*1024*1024, backupCount=5),
+        logging.StreamHandler()
     ]
 )
 
 class ASRModelManager:
-    def __init__(self, default_language="kn", device_type="cuda"):
-        self.default_language = default_language
+    def __init__(self, languages_to_load=["kn", "hi", "ta", "te", "ml"], device_type="cuda"):
         self.device_type = device_type
         self.model_language = {
-            "kannada": "kn",
-            "hindi": "hi",
-            "malayalam": "ml",
-            "assamese": "as",
-            "bengali": "bn",
-            "bodo": "brx",
-            "dogri": "doi",
-            "gujarati": "gu",
-            "kashmiri": "ks",
-            "konkani": "kok",
-            "maithili": "mai",
-            "manipuri": "mni",
-            "marathi": "mr",
-            "nepali": "ne",
-            "odia": "or",
-            "punjabi": "pa",
-            "sanskrit": "sa",
-            "santali": "sat",
-            "sindhi": "sd",
-            "tamil": "ta",
-            "telugu": "te",
-            "urdu": "ur"
+            "kannada": "kn", "hindi": "hi", "malayalam": "ml", "assamese": "as", "bengali": "bn",
+            "bodo": "brx", "dogri": "doi", "gujarati": "gu", "kashmiri": "ks", "konkani": "kok",
+            "maithili": "mai", "manipuri": "mni", "marathi": "mr", "nepali": "ne", "odia": "or",
+            "punjabi": "pa", "sanskrit": "sa", "santali": "sat", "sindhi": "sd", "tamil": "ta",
+            "telugu": "te", "urdu": "ur"
         }
         self.config_models = {
             "as": "ai4bharat/indicconformer_stt_as_hybrid_rnnt_large",
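For readers skimming the hunk above: resolving a request is a two-step table lookup. The language name from the query string maps to a language ID, and the ID maps to a pretrained checkpoint name. A standalone sketch of the chain (trimmed tables; the kn and hi checkpoint names are assumed to follow the same naming pattern as the as/te/ur entries visible in this diff):

# Two-step lookup: language name -> language ID -> pretrained model name.
model_language = {"kannada": "kn", "hindi": "hi"}  # trimmed copy of the real table
config_models = {                                  # assumed entries, same pattern as as/te/ur
    "kn": "ai4bharat/indicconformer_stt_kn_hybrid_rnnt_large",
    "hi": "ai4bharat/indicconformer_stt_hi_hybrid_rnnt_large",
}

language = "hindi"                                # the ?language= query parameter
language_id = model_language.get(language, "kn")  # same Kannada fallback as the endpoints
model_name = config_models.get(language_id, config_models["kn"])
print(model_name)  # ai4bharat/indicconformer_stt_hi_hybrid_rnnt_large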
@@ -81,65 +61,68 @@ class ASRModelManager:
             "te": "ai4bharat/indicconformer_stt_te_hybrid_rnnt_large",
             "ur": "ai4bharat/indicconformer_stt_ur_hybrid_rnnt_large"
         }
-
+        # Load models for specified languages on startup
+        self.models = {}
+        self.load_initial_models(languages_to_load)
+
+    def load_initial_models(self, languages):
+        device = torch.device(self.device_type if torch.cuda.is_available() and self.device_type == "cuda" else "cpu")
+        logging.info(f"Loading models on device: {device}")
+        for lang_id in languages:
+            if lang_id not in self.config_models:
+                logging.warning(f"No model available for language ID: {lang_id}. Skipping.")
+                continue
+            try:
+                model_name = self.config_models[lang_id]
+                logging.info(f"Loading model for {lang_id}: {model_name}")
+                model = nemo_asr.models.ASRModel.from_pretrained(model_name)
+                model.freeze()  # Set to inference mode
+                model = model.to(device)
+                self.models[lang_id] = model
+                logging.info(f"Successfully loaded model for {lang_id}")
+            except Exception as e:
+                logging.error(f"Failed to load model for {lang_id}: {str(e)}")
 
-    def load_model(self, language_id):
+    def get_model(self, language_id):
+        if language_id not in self.models:
+            logging.warning(f"Model for {language_id} not pre-loaded. Loading now...")
+            model = self.load_model(language_id)
+            self.models[language_id] = model
+        return self.models[language_id]
+
+    def load_model(self, language_id):
         model_name = self.config_models.get(language_id, self.config_models["kn"])
         model = nemo_asr.models.ASRModel.from_pretrained(model_name)
-
         device = torch.device(self.device_type if torch.cuda.is_available() and self.device_type == "cuda" else "cpu")
-        model.freeze()
-        model = model.to(device)
-
+        model.freeze()
+        model = model.to(device)
         return model
 
     def split_audio(self, file_path, chunk_duration_ms=15000):
-        """
-        Splits an audio file into chunks of specified duration if the audio duration exceeds the chunk duration.
-
-        :param file_path: Path to the audio file.
-        :param chunk_duration_ms: Duration of each chunk in milliseconds (default is 15000 ms or 15 seconds).
-        """
-        # Load the audio file
         audio = AudioSegment.from_file(file_path)
-
-        # Get the duration of the audio in milliseconds
         duration_ms = len(audio)
-
-        # Check if the duration is more than the specified chunk duration
         if duration_ms > chunk_duration_ms:
-
-            num_chunks = (duration_ms + chunk_duration_ms - 1) // chunk_duration_ms  # This handles the remainder correctly
-
-            # Split the audio into chunks
+            num_chunks = (duration_ms + chunk_duration_ms - 1) // chunk_duration_ms
             chunks = [audio[i*chunk_duration_ms:min((i+1)*chunk_duration_ms, duration_ms)] for i in range(num_chunks)]
-
-            # Create a directory to save the chunks
             output_dir = "audio_chunks"
             os.makedirs(output_dir, exist_ok=True)
-
-            # Export each chunk to separate files
             chunk_file_paths = []
             for i, chunk in enumerate(chunks):
                 chunk_file_path = os.path.join(output_dir, f"chunk_{i}.wav")
                 chunk.export(chunk_file_path, format="wav")
                 chunk_file_paths.append(chunk_file_path)
-                print(f"Chunk {i} exported successfully to {chunk_file_path}.")
-
             return chunk_file_paths
         else:
             return [file_path]
-
-    def cleanup(self):
-        # Create a directory to save the chunks
+
+    def cleanup(self):
         output_dir = "audio_chunks"
         if os.path.exists(output_dir):
             shutil.rmtree(output_dir)
 
 app = FastAPI()
-asr_manager = ASRModelManager()
+asr_manager = ASRModelManager(languages_to_load=["kn", "hi", "ta", "te", "ml"])  # Load Kannada, Hindi, Tamil, Telugu, Malayalam
 
-# Define the response model
 class TranscriptionResponse(BaseModel):
     text: str
 
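A note on split_audio's arithmetic, since this commit strips the explanatory comments: the ceiling division keeps the trailing partial chunk instead of truncating it, and the min() bound in the slice stops the last chunk at the true end of the audio. A minimal check of the chunk count:

# Ceiling division: a 38 s file at 15 s chunks yields 3 chunks, not 2.
duration_ms, chunk_duration_ms = 38_000, 15_000
num_chunks = (duration_ms + chunk_duration_ms - 1) // chunk_duration_ms
print(num_chunks)  # 3 -> slices of 15 s, 15 s, and the remaining 8 s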
@@ -150,75 +133,52 @@ class BatchTranscriptionResponse(BaseModel):
 async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
     start_time = time()
     try:
-        # Check file extension
         file_extension = file.filename.split(".")[-1].lower()
         if file_extension not in ["wav", "mp3"]:
             logging.warning(f"Unsupported file format: {file_extension}")
             raise HTTPException(status_code=400, detail="Unsupported file format. Please upload a WAV or MP3 file.")
 
-        # Read the file content
         file_content = await file.read()
-
-        # Convert MP3 to WAV if necessary
         if file_extension == "mp3":
             audio = AudioSegment.from_mp3(io.BytesIO(file_content))
         else:
             audio = AudioSegment.from_wav(io.BytesIO(file_content))
 
-
-        sample_rate = audio.frame_rate
-
-        # Convert WAV to the required format using ffmpeg if necessary
-        if sample_rate != 16000:
+        if audio.frame_rate != 16000:
             audio = audio.set_frame_rate(16000).set_channels(1)
 
-        # Export the audio to a temporary WAV file
         with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
             audio.export(tmp_file.name, format="wav")
             tmp_file_path = tmp_file.name
 
-        # Split the audio if necessary
         chunk_file_paths = asr_manager.split_audio(tmp_file_path)
 
         try:
-
-
-
-            if language_id != asr_manager.default_language:
-                asr_manager.model = asr_manager.load_model(language_id)
-                asr_manager.default_language = language_id
-
-            asr_manager.model.cur_decoder = "rnnt"
+            language_id = asr_manager.model_language.get(language, "kn")
+            model = asr_manager.get_model(language_id)
+            model.cur_decoder = "rnnt"
 
             transcriptions = []
             for chunk_file_path in chunk_file_paths:
-                rnnt_texts = asr_manager.model.transcribe([chunk_file_path], batch_size=1, language_id=language_id)[0]
+                rnnt_texts = model.transcribe([chunk_file_path], batch_size=1, language_id=language_id)[0]
                 if isinstance(rnnt_texts, list) and len(rnnt_texts) > 0:
                     transcriptions.append(rnnt_texts[0])
                 else:
                     transcriptions.append(rnnt_texts)
 
             joined_transcriptions = ' '.join(transcriptions)
-
-            # Process rnnt_texts as needed
-            #with torch.amp.autocast('cuda', dtype=torch.bfloat16):
-            #    rnnt_texts = asr_manager.model.transcribe(chunk_file_paths, batch_size=1, language_id=language_id)
-            #print(joined_transcriptions)
             end_time = time()
             logging.info(f"Transcription completed in {end_time - start_time:.2f} seconds")
             return JSONResponse(content={"text": joined_transcriptions})
-        except subprocess.CalledProcessError as e:
-            logging.error(f"FFmpeg conversion failed: {str(e)}")
-            raise HTTPException(status_code=500, detail=f"FFmpeg conversion failed: {str(e)}")
-        except Exception as e:
-            logging.error(f"An error occurred during processing: {str(e)}")
-            raise HTTPException(status_code=500, detail=f"An error occurred during processing: {str(e)}")
+
         finally:
-            # Clean up temporary files
             for chunk_file_path in chunk_file_paths:
                 if os.path.exists(chunk_file_path):
                     os.remove(chunk_file_path)
+            if os.path.exists(tmp_file_path):
+                os.remove(tmp_file_path)
             asr_manager.cleanup()
+
     except HTTPException as e:
         logging.error(f"HTTPException: {str(e)}")
         raise e
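A minimal client sketch for this endpoint (assumptions: the server runs locally on the default port 8888; the route path /transcribe/ is inferred, since the decorator sits just above this hunk and is not shown; sample.wav is a placeholder path):

import requests

# Upload one WAV or MP3 file; `language` must be a name from model_language.
with open("sample.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:8888/transcribe/",
        params={"language": "kannada"},
        files={"file": ("sample.wav", f, "audio/wav")},
    )
resp.raise_for_status()
print(resp.json()["text"])  # joined transcription of all chunks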
@@ -233,83 +193,57 @@ async def home():
 @app.post("/transcribe_batch/", response_model=BatchTranscriptionResponse)
 async def transcribe_audio_batch(files: List[UploadFile] = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
     start_time = time()
-    tmp_file_paths = []
     all_transcriptions = []
     try:
         for file in files:
-            # Check file extension
             file_extension = file.filename.split(".")[-1].lower()
             if file_extension not in ["wav", "mp3"]:
                 logging.warning(f"Unsupported file format: {file_extension}")
                 raise HTTPException(status_code=400, detail="Unsupported file format. Please upload WAV or MP3 files.")
 
-            # Read the file content
             file_content = await file.read()
-
-            # Convert MP3 to WAV if necessary
             if file_extension == "mp3":
                 audio = AudioSegment.from_mp3(io.BytesIO(file_content))
             else:
                 audio = AudioSegment.from_wav(io.BytesIO(file_content))
 
-
-            sample_rate = audio.frame_rate
-
-            # Convert WAV to the required format using ffmpeg if necessary
-            if sample_rate != 16000:
+            if audio.frame_rate != 16000:
                 audio = audio.set_frame_rate(16000).set_channels(1)
 
-            # Export the audio to a temporary WAV file
             with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
                 audio.export(tmp_file.name, format="wav")
                 tmp_file_path = tmp_file.name
 
-            # Split the audio if necessary
             chunk_file_paths = asr_manager.split_audio(tmp_file_path)
-            #tmp_file_paths.extend(chunk_file_paths)
 
-            logging.info(f"Temporary file paths: {chunk_file_paths}")
             try:
-
-
-
-                if language_id != asr_manager.default_language:
-                    asr_manager.model = asr_manager.load_model(language_id)
-                    asr_manager.default_language = language_id
-
-                asr_manager.model.cur_decoder = "rnnt"
+                language_id = asr_manager.model_language.get(language, "kn")
+                model = asr_manager.get_model(language_id)
+                model.cur_decoder = "rnnt"
 
                 transcriptions = []
                 for chunk_file_path in chunk_file_paths:
-                    rnnt_texts = asr_manager.model.transcribe([chunk_file_path], batch_size=1, language_id=language_id)[0]
+                    rnnt_texts = model.transcribe([chunk_file_path], batch_size=1, language_id=language_id)[0]
                     if isinstance(rnnt_texts, list) and len(rnnt_texts) > 0:
                         transcriptions.append(rnnt_texts[0])
                     else:
                         transcriptions.append(rnnt_texts)
 
                 joined_transcriptions = ' '.join(transcriptions)
-
-
-
-                logging.info(f"Raw transcriptions from model: {joined_transcriptions}")
-                end_time = time()
-                logging.info(f"Transcription completed in {end_time - start_time:.2f} seconds")
-
                 all_transcriptions.append(joined_transcriptions)
-
-                #transcriptions = [text for sublist in rnnt_texts for text in sublist]
-            except subprocess.CalledProcessError as e:
-                logging.error(f"FFmpeg conversion failed: {str(e)}")
-                raise HTTPException(status_code=500, detail=f"FFmpeg conversion failed: {str(e)}")
-            except Exception as e:
-                logging.error(f"An error occurred during processing: {str(e)}")
-                raise HTTPException(status_code=500, detail=f"An error occurred during processing: {str(e)}")
+
             finally:
-
-
-
-
+                for chunk_file_path in chunk_file_paths:
+                    if os.path.exists(chunk_file_path):
+                        os.remove(chunk_file_path)
+                if os.path.exists(tmp_file_path):
+                    os.remove(tmp_file_path)
                 asr_manager.cleanup()
+
+        end_time = time()
+        logging.info(f"Batch transcription completed in {end_time - start_time:.2f} seconds")
+        return JSONResponse(content={"transcriptions": all_transcriptions})
+
     except HTTPException as e:
         logging.error(f"HTTPException: {str(e)}")
         raise e
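A matching client sketch for the batch route above (same assumptions as before; repeating the `files` field is how FastAPI populates List[UploadFile]):

from pathlib import Path
import requests

paths = ["clip1.wav", "clip2.mp3"]  # placeholder file names
files = [("files", (p, Path(p).read_bytes())) for p in paths]
resp = requests.post(
    "http://localhost:8888/transcribe_batch/",
    params={"language": "hindi"},
    files=files,
)
resp.raise_for_status()
print(resp.json()["transcriptions"])  # one joined transcription per uploaded file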
@@ -317,15 +251,12 @@ async def transcribe_audio_batch(files: List[UploadFile] = File(...), language:
         logging.error(f"An unexpected error occurred: {str(e)}")
         raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
 
-    return JSONResponse(content={"transcriptions": all_transcriptions})
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Run the FastAPI server for ASR.")
     parser.add_argument("--port", type=int, default=8888, help="Port to run the server on.")
-    parser.add_argument("--language", type=str, default="kn", help="Default language for the ASR model.")
     parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on.")
     parser.add_argument("--device", type=str, default="cuda", help="Device type to run the model on (cuda or cpu).")
     args = parser.parse_args()
 
-    asr_manager = ASRModelManager(default_language=args.language, device_type=args.device)
+    asr_manager = ASRModelManager(languages_to_load=["kn", "hi", "ta", "te", "ml"], device_type=args.device)
     uvicorn.run(app, host=args.host, port=args.port)
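One behavioural note on this last hunk: when the file is run directly (e.g. python src/asr_api.py --device cpu), the module-level asr_manager is constructed first, preloading all five models with the default device_type="cuda", and is then rebuilt under __main__ with args.device, so every checkpoint is loaded twice before uvicorn.run() starts serving. Routing the device and language list through a single construction site would avoid the duplicate load.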