Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,925 Bytes
1d7163f 7ee40ec 1d7163f 7ee40ec 1d7163f 972e738 1d7163f 972e738 1d7163f 972e738 1d7163f f5bbafc 1d7163f 7ee40ec 1d7163f 972e738 1d7163f 972e738 1d7163f f5bbafc 1d7163f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import logging
import time
from typing import List, Literal
import librosa
from funasr import AutoModel
from resampy.core import resample
from tqdm.auto import tqdm
import torch
from corrector.Corrector import Corrector
from transcriber.TranscribeResult import TranscribeResult
logger = logging.getLogger(__name__)
device = "cuda" if torch.cuda.is_available() else "cpu"
class AutoTranscriber:
    """
    Transcriber that chains FunASR AutoModels: FSMN VAD for segmentation,
    then SenseVoiceSmall ASR (Cantonese, ``language="yue"``) per segment,
    with optional simplified->traditional Chinese correction.
    """

    def __init__(
        self,
        corrector: Literal["opencc", "bert", None] = None,
        use_denoiser: bool = False,
        with_punct: bool = True,
        offset_in_seconds: float = -0.25,
        max_length_seconds: float = 5,
        sr: int = 16000,
    ):
        """
        Args:
            corrector: Text-correction backend ("opencc" or "bert"),
                or None to skip correction entirely.
            use_denoiser: Reserved flag; the denoising step is currently
                disabled (see commented-out call in ``transcribe``).
            with_punct: Forwarded to the ASR model as ``use_itn``
                (inverse text normalization / punctuation).
            offset_in_seconds: Shift applied to segment timestamps when
                converting VAD boundaries to result times.
            max_length_seconds: Maximum single VAD segment length in
                seconds (VAD model takes milliseconds).
            sr: Sample rate requested when loading the audio file.
        """
        self.corrector = corrector
        self.use_denoiser = use_denoiser
        self.with_punct = with_punct
        self.sr = sr
        self.offset_in_seconds = offset_in_seconds
        self.max_length_seconds = max_length_seconds

        # Initialize models once; VAD and ASR run as separate passes.
        self.vad_model = AutoModel(
            model="fsmn-vad",
            device=device,
            max_single_segment_time=self.max_length_seconds * 1000,
        )
        self.asr_model = AutoModel(
            model="iic/SenseVoiceSmall",
            vad_model=None,  # We'll handle VAD separately
            punc_model=None,
            ban_emo_unks=True,
            device=device,
        )

    def transcribe(
        self,
        audio_file: str,
    ) -> "List[TranscribeResult]":
        """
        Transcribe audio file to text with timestamps.

        Args:
            audio_file (str): Path to audio file

        Returns:
            List[TranscribeResult]: One result per voiced segment, with
            start/end times in seconds (offset by ``offset_in_seconds``).
        """
        # Load and preprocess audio at the configured sample rate.
        speech, sr = librosa.load(audio_file, sr=self.sr)
        # Denoising is currently disabled:
        # if self.use_denoiser:
        #     speech, _ = denoiser(speech, sr)

        # The models expect 16 kHz input; resample if loaded differently.
        if sr != 16_000:
            speech = resample(speech, sr, 16_000, filter="kaiser_best", parallel=True)

        # Get VAD segments (list of [start_ms, end_ms] pairs).
        logger.info("Segmenting speech...")
        vad_start = time.time()
        vad_results = self.vad_model.generate(input=speech, disable_pbar=True)
        logger.info("VAD took %.2f seconds", time.time() - vad_start)

        if not vad_results or not vad_results[0]["value"]:
            return []

        vad_segments = vad_results[0]["value"]

        # Run ASR on each voiced segment.
        results = []
        # BUG FIX: the wall-clock timer must not be named `start_time` —
        # the original code overwrote it inside the loop with the
        # segment's audio timestamp, making the elapsed-time log wrong.
        asr_start = time.time()
        for segment in tqdm(vad_segments, desc="Transcribing"):
            # VAD timestamps are in ms; audio is 16 kHz -> 16 samples/ms.
            start_sample = int(segment[0] * 16)
            end_sample = int(segment[1] * 16)
            segment_audio = speech[start_sample:end_sample]

            # Get ASR results for this segment.
            asr_result = self.asr_model.generate(
                input=segment_audio, language="yue", use_itn=self.with_punct
            )
            if not asr_result:
                continue

            # Convert ms -> seconds, apply the offset, clamp start at 0.
            seg_start = max(0, segment[0] / 1000.0 + self.offset_in_seconds)
            seg_end = segment[1] / 1000.0 + self.offset_in_seconds
            results.append(
                TranscribeResult(
                    text=asr_result[0]["text"],
                    start_time=seg_start,
                    end_time=seg_end,
                )
            )
        logger.info("ASR took %.2f seconds", time.time() - asr_start)

        # Apply Chinese conversion if a corrector was configured.
        conv_start = time.time()
        results = self._convert_to_traditional_chinese(results)
        logger.info("Conversion took %.2f seconds", time.time() - conv_start)

        return results

    def _convert_to_traditional_chinese(
        self, results: "List[TranscribeResult]"
    ) -> "List[TranscribeResult]":
        """Convert simplified Chinese to traditional Chinese in place.

        Returns ``results`` unchanged when it is empty or no corrector
        is configured; otherwise mutates each result's ``text``.
        """
        if not results or not self.corrector:
            return results

        corrector = Corrector(self.corrector)
        if self.corrector == "bert":
            # BERT correction works per-sentence; show progress.
            for result in tqdm(
                results, total=len(results), desc="Converting to Traditional Chinese"
            ):
                result.text = corrector.correct(result.text)
        elif self.corrector == "opencc":
            # OpenCC is cheap per call but has per-call overhead, so
            # convert all texts in one pass joined by a delimiter that
            # won't appear in Chinese text.
            delimiter = "|||"
            combined_text = delimiter.join(result.text for result in results)
            converted_text = corrector.correct(combined_text)
            converted_parts = converted_text.split(delimiter)
            for result, converted in zip(results, converted_parts):
                result.text = converted

        return results
|