import logging
import time
from typing import List, Literal

import librosa
from funasr import AutoModel
from resampy.core import resample
from tqdm.auto import tqdm
import torch

from corrector.Corrector import Corrector
from transcriber.TranscribeResult import TranscribeResult

logger = logging.getLogger(__name__)
device = "cuda" if torch.cuda.is_available() else "cpu"

class AutoTranscriber:
    """
    Transcriber class that uses FunASR's AutoModel for VAD and ASR
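
    Example (illustrative; the audio path is a placeholder):
        >>> transcriber = AutoTranscriber(corrector="opencc")
        >>> segments = transcriber.transcribe("audio.wav")
        >>> segments[0].text  # text of the first transcribed segment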
    """

    def __init__(
        self,
        corrector: Literal["opencc", "bert", None] = None,
        use_denoiser: bool = False,
        with_punct: bool = True,
        offset_in_seconds: float = -0.25,
        max_length_seconds: float = 5,
        sr: int = 16000,
    ):
        """
        Args:
            corrector: Post-processing for transcripts ("opencc" converts
                Simplified to Traditional Chinese, "bert" applies model-based
                correction, None disables post-processing).
            use_denoiser: Whether to denoise audio before transcribing
                (currently a no-op; see the commented-out block in `transcribe`).
            with_punct: Whether to apply inverse text normalization
                (including punctuation) to the ASR output.
            offset_in_seconds: Offset added to each segment's timestamps.
            max_length_seconds: Maximum length of a single VAD segment.
            sr: Sample rate used when loading audio.
        """
        self.corrector = corrector
        self.use_denoiser = use_denoiser
        self.with_punct = with_punct
        self.sr = sr
        self.offset_in_seconds = offset_in_seconds
        self.max_length_seconds = max_length_seconds

        # Initialize models
        self.vad_model = AutoModel(
            model="fsmn-vad",
            device=device,
            max_single_segment_time=self.max_length_seconds * 1000,
        )
        self.asr_model = AutoModel(
            model="iic/SenseVoiceSmall",
            vad_model=None,  # We'll handle VAD separately
            punc_model=None,
            ban_emo_unks=True,
            device=device,
        )

    def transcribe(
        self,
        audio_file: str,
    ) -> List[TranscribeResult]:
        """
        Transcribe audio file to text with timestamps.

        Args:
            audio_file (str): Path to audio file

        Returns:
            List[TranscribeResult]: List of transcription results
        """
        # Load and preprocess audio
        speech, sr = librosa.load(audio_file, sr=self.sr)

        # NOTE: denoising is currently disabled, so `use_denoiser` has no
        # effect until a denoiser implementation is wired in.
        # if self.use_denoiser:
        #     logger.info("Denoising speech...")
        #     speech, _ = denoiser(speech, sr)

        # librosa already resampled to self.sr; this guard only matters if
        # self.sr was set to something other than the 16 kHz the models
        # (and the ms-to-sample math below) expect
        if sr != 16_000:
            speech = resample(speech, sr, 16_000, filter="kaiser_best", parallel=True)

        # Get VAD segments
        logger.info("Segmenting speech...")

        start_time = time.time()
        vad_results = self.vad_model.generate(input=speech, disable_pbar=True)
        logger.info("VAD took %.2f seconds", time.time() - start_time)

        if not vad_results or not vad_results[0]["value"]:
            return []

        vad_segments = vad_results[0]["value"]

        # Process each segment
        results = []

        # Use a dedicated timer name: the per-segment `start_time` assigned
        # inside the loop below would otherwise clobber it
        asr_start = time.time()
        for segment in tqdm(vad_segments, desc="Transcribing"):
            start_sample = int(segment[0] * 16)  # Convert ms to samples
            end_sample = int(segment[1] * 16)
            segment_audio = speech[start_sample:end_sample]

            # Get ASR results for segment (language fixed to Cantonese; in
            # SenseVoice, use_itn also controls punctuation in the output)
            asr_result = self.asr_model.generate(
                input=segment_audio, language="yue", use_itn=self.with_punct
            )

            if not asr_result:
                continue

            # VAD timestamps are in ms; convert to seconds and apply offset
            start_time = max(0, segment[0] / 1000.0 + self.offset_in_seconds)
            end_time = segment[1] / 1000.0 + self.offset_in_seconds

            # Convert ASR result to TranscribeResult format
            segment_result = TranscribeResult(
                text=asr_result[0]["text"],
                start_time=start_time,
                end_time=end_time,
            )
            results.append(segment_result)

        logger.info("ASR took %.2f seconds", time.time() - start_time)

        # Apply Chinese conversion if needed
        start_time = time.time()
        results = self._convert_to_traditional_chinese(results)
        logger.info("Conversion took %.2f seconds", time.time() - start_time)

        return results

    def _convert_to_traditional_chinese(
        self, results: List[TranscribeResult]
    ) -> List[TranscribeResult]:
        """Convert simplified Chinese to traditional Chinese"""
        if not results or not self.corrector:
            return results

        corrector = Corrector(self.corrector)
        if self.corrector == "bert":
            for result in tqdm(
                results, total=len(results), desc="Converting to Traditional Chinese"
            ):
                result.text = corrector.correct(result.text)
        elif self.corrector == "opencc":
            # Use a delimiter that should not appear in Chinese text and is
            # assumed to pass through the converter unchanged, so segment
            # boundaries survive the round trip
            delimiter = "|||"
            # Concatenate all texts with delimiter
            combined_text = delimiter.join(result.text for result in results)
            # Convert all text at once
            converted_text = corrector.correct(combined_text)
            # Split back into individual results
            converted_parts = converted_text.split(delimiter)

            # Update results with converted text
            for result, converted in zip(results, converted_parts):
                result.text = converted

        return results
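

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: the audio path
    # below is a placeholder. Prints each segment with its timestamps.
    logging.basicConfig(level=logging.INFO)

    transcriber = AutoTranscriber(corrector="opencc")
    for seg in transcriber.transcribe("example.wav"):
        print(f"[{seg.start_time:6.2f}s - {seg.end_time:6.2f}s] {seg.text}")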