import logging
import time
from typing import List, Literal

import librosa
import torch
from funasr import AutoModel
from resampy.core import resample
from tqdm.auto import tqdm

from corrector.Corrector import Corrector
from transcriber.TranscribeResult import TranscribeResult

logger = logging.getLogger(__name__)

device = "cuda" if torch.cuda.is_available() else "cpu"


class AutoTranscriber:
    """
    Transcriber class that uses FunASR's AutoModel for VAD and ASR.
    """

    def __init__(
        self,
        corrector: Literal["opencc", "bert", None] = None,
        use_denoiser=False,
        with_punct=True,
        offset_in_seconds=-0.25,
        max_length_seconds=5,
        sr=16000,
    ):
        self.corrector = corrector
        self.use_denoiser = use_denoiser
        self.with_punct = with_punct
        self.sr = sr
        self.offset_in_seconds = offset_in_seconds
        self.max_length_seconds = max_length_seconds

        # Initialize models: FSMN-VAD for segmentation, SenseVoiceSmall for ASR
        self.vad_model = AutoModel(
            model="fsmn-vad",
            device=device,
            max_single_segment_time=self.max_length_seconds * 1000,  # ms
        )
        self.asr_model = AutoModel(
            model="iic/SenseVoiceSmall",
            vad_model=None,  # We'll handle VAD separately
            punc_model=None,
            ban_emo_unks=True,
            device=device,
        )
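
    # Pipeline: FSMN-VAD splits speech into segments capped at max_length_seconds,
    # then SenseVoiceSmall transcribes each segment. Timestamps come from the VAD
    # boundaries, shifted by offset_in_seconds.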
    def transcribe(
        self,
        audio_file: str,
    ) -> List[TranscribeResult]:
        """
        Transcribe an audio file to text with timestamps.

        Args:
            audio_file (str): Path to the audio file.

        Returns:
            List[TranscribeResult]: List of transcription results.
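
        Example (illustrative; the audio path is hypothetical):
            >>> transcriber = AutoTranscriber(corrector="opencc")
            >>> for r in transcriber.transcribe("audio.wav"):
            ...     print(r.start_time, r.end_time, r.text)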
"""
        # Load and preprocess audio; librosa resamples to self.sr on load
        speech, sr = librosa.load(audio_file, sr=self.sr)

        # NOTE: requires a denoiser helper that is not currently imported
        # if self.use_denoiser:
        #     logger.info("Denoising speech...")
        #     speech, _ = denoiser(speech, sr)

        # The models expect 16 kHz input; resample if a different rate was configured
        if sr != 16_000:
            speech = resample(speech, sr, 16_000, filter="kaiser_best", parallel=True)

        # Get VAD segments
        logger.info("Segmenting speech...")
        vad_start = time.time()
        vad_results = self.vad_model.generate(input=speech, disable_pbar=True)
        logger.info("VAD took %.2f seconds", time.time() - vad_start)

        if not vad_results or not vad_results[0]["value"]:
            return []

        vad_segments = vad_results[0]["value"]
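        # Each segment is a [start_ms, end_ms] pair; the audio is at 16 kHz,
        # so 1 ms corresponds to 16 samples.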
        # Process each segment
        results = []
        asr_start = time.time()
        for segment in tqdm(vad_segments, desc="Transcribing"):
            start_sample = int(segment[0] * 16)  # Convert ms to samples
            end_sample = int(segment[1] * 16)
            segment_audio = speech[start_sample:end_sample]

            # Get ASR results for segment; SenseVoice emits punctuation via use_itn
            asr_result = self.asr_model.generate(
                input=segment_audio, language="yue", use_itn=self.with_punct
            )
            if not asr_result:
                continue

            # Convert ms to seconds and apply the configured offset
            start_time = max(0, segment[0] / 1000.0 + self.offset_in_seconds)
            end_time = segment[1] / 1000.0 + self.offset_in_seconds

            # Convert ASR result to TranscribeResult format
            segment_result = TranscribeResult(
                text=asr_result[0]["text"],
                start_time=start_time,
                end_time=end_time,
            )
            results.append(segment_result)
        logger.info("ASR took %.2f seconds", time.time() - asr_start)
        # Apply Chinese conversion if needed
        conversion_start = time.time()
        results = self._convert_to_traditional_chinese(results)
        logger.info("Conversion took %.2f seconds", time.time() - conversion_start)

        return results

    def _convert_to_traditional_chinese(
        self, results: List[TranscribeResult]
    ) -> List[TranscribeResult]:
        """Convert simplified Chinese to traditional Chinese."""
        if not results or not self.corrector:
            return results

        corrector = Corrector(self.corrector)
        if self.corrector == "bert":
            for result in tqdm(
                results, total=len(results), desc="Converting to Traditional Chinese"
            ):
                result.text = corrector.correct(result.text)
        elif self.corrector == "opencc":
            # Use a special delimiter that won't appear in Chinese text
            delimiter = "|||"
            # Concatenate all texts with the delimiter
            combined_text = delimiter.join(result.text for result in results)
            # Convert all text at once
            converted_text = corrector.correct(combined_text)
            # Split back into individual results
            converted_parts = converted_text.split(delimiter)
            # Update results with converted text
            for result, converted in zip(results, converted_parts):
                result.text = converted

        return results
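

# Minimal usage sketch (illustrative, not part of the library): transcribe a file
# and print SRT-style lines. The audio path is hypothetical; TranscribeResult is
# assumed to expose the text/start_time/end_time fields constructed above.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    transcriber = AutoTranscriber(corrector="opencc", with_punct=True)
    segments = transcriber.transcribe("sample.wav")  # hypothetical path

    def fmt(seconds: float) -> str:
        # Render seconds as an SRT timestamp, e.g. 00:01:02,345
        ms = int(round(seconds * 1000))
        h, rem = divmod(ms, 3_600_000)
        m, rem = divmod(rem, 60_000)
        s, ms = divmod(rem, 1000)
        return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"

    for i, seg in enumerate(segments, start=1):
        print(f"{i}\n{fmt(seg.start_time)} --> {fmt(seg.end_time)}\n{seg.text}\n")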