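# ASR helpers built on faster-whisper: batched transcription of audio clips,
# Traditional-to-Simplified Chinese normalisation of the output, and a
# character-level WER utility for comparing transcripts.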
import os
os.environ["MODELSCOPE_CACHE"] = ".cache/"
import string
import time
from threading import Lock
import librosa
import numpy as np
import opencc
import torch
from faster_whisper import WhisperModel
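
# OpenCC converter used to normalise Traditional Chinese transcripts to Simplified Chinese.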
t2s_converter = opencc.OpenCC("t2s")
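

# Load the faster-whisper "medium" checkpoint; defaults to FP16 inference on CUDA.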
def load_model(*, device="cuda"):
    model = WhisperModel(
        "medium",
        device=device,
        compute_type="float16",
        download_root="faster_whisper",
    )
    print("faster_whisper loaded!")
    return model
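

# Transcribe a batch of audio clips: resample each one to 16 kHz, run
# faster-whisper on it, join the segment texts, convert them to Simplified
# Chinese, and flag clips that contain a silence gap longer than 3 seconds
# between consecutive segments.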
@torch.no_grad()
def batch_asr_internal(model: WhisperModel, audios, sr):
    resampled_audios = []
    for audio in audios:
        # Accept either numpy arrays or torch tensors; work on a mono float tensor.
        if isinstance(audio, np.ndarray):
            audio = torch.from_numpy(audio).float()
        if audio.dim() > 1:
            audio = audio.squeeze()
        assert audio.dim() == 1
        # Whisper models expect 16 kHz mono input.
        audio_np = audio.numpy()
        resampled_audio = librosa.resample(audio_np, orig_sr=sr, target_sr=16000)
        resampled_audios.append(resampled_audio)
    trans_results = []
    for resampled_audio in resampled_audios:
        # language=None lets faster-whisper auto-detect the language; the initial
        # prompt nudges the model towards emitting punctuation.
        segments, info = model.transcribe(
            resampled_audio,
            language=None,
            beam_size=5,
            initial_prompt="Punctuation is needed in any language.",
        )
        trans_results.append(list(segments))
    results = []
    for trans_res, audio in zip(trans_results, audios):
        duration = len(audio) / sr * 1000  # clip duration in milliseconds
        huge_gap = False
        max_gap = 0.0
        text = ""
        last_tr = None
        for tr in trans_res:
            delta = tr.text.strip()
            if last_tr is not None:
                # Track the longest silence between consecutive segments.
                max_gap = max(tr.start - last_tr.end, max_gap)
            text += delta
            last_tr = tr
            if max_gap > 3.0:
                huge_gap = True
                break
        # Normalise any Traditional Chinese in the transcript to Simplified Chinese.
        sim_text = t2s_converter.convert(text)
        results.append(
            {
                "text": sim_text,
                "duration": duration,
                "huge_gap": huge_gap,
            }
        )
    return results


global_lock = Lock()


def batch_asr(model, audios, sr):
    # Serialise concurrent calls so only one thread drives the shared model at a time.
    with global_lock:
        return batch_asr_internal(model, audios, sr)


def is_chinese(text):
    # Placeholder: every transcript is currently treated as Chinese.
    return True
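

# Character-level error rate between two strings: Levenshtein edit distance over
# the punctuation-stripped characters, normalised by the longer length. Two
# rolling DP rows keep memory at O(min(len1, len2)).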
def calculate_wer(text1, text2, debug=False):
    chars1 = remove_punctuation(text1)
    chars2 = remove_punctuation(text2)
    gt_chars, pred_chars = chars1, chars2  # keep the originals for debug output
    m, n = len(chars1), len(chars2)
    # Always iterate over the shorter string so the DP rows stay small.
    if m > n:
        chars1, chars2 = chars2, chars1
        m, n = n, m
    prev = list(range(m + 1))  # row 0 distance: [0, 1, 2, ...]
    curr = [0] * (m + 1)
    for j in range(1, n + 1):
        curr[0] = j
        for i in range(1, m + 1):
            if chars1[i - 1] == chars2[j - 1]:
                curr[i] = prev[i - 1]
            else:
                curr[i] = min(prev[i], curr[i - 1], prev[i - 1]) + 1
        prev, curr = curr, prev
    edits = prev[m]
    tot = max(len(chars1), len(chars2))
    wer = edits / tot if tot > 0 else 0.0  # guard against two empty inputs
    if debug:
        print("            gt:   ", gt_chars)
        print("          pred:   ", pred_chars)
        print(" edits/tot = wer: ", edits, "/", tot, "=", wer)
    return wer
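

# Remove ASCII punctuation, CJK punctuation, and whitespace characters from text.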
def remove_punctuation(text):
    chinese_punctuation = (
        " \n\t”“！？。。＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～⦅⦆「」、、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—"
        '‛“”„‟…‧﹏'
    )
    all_punctuation = string.punctuation + chinese_punctuation
    translator = str.maketrans("", "", all_punctuation)
    text_without_punctuation = text.translate(translator)
    return text_without_punctuation
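

# Quick smoke test: load two local wav files, run batch ASR once, then time ten
# repeated runs.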
if __name__ == "__main__":
    model = load_model()
    audios = [
        librosa.load("44100.wav", sr=44100)[0],
        librosa.load("lengyue.wav", sr=44100)[0],
    ]
    print(np.array(audios[0]))
    print(batch_asr(model, audios, 44100))
    start_time = time.time()
    for _ in range(10):
        print(batch_asr(model, audios, 44100))
    print("Time taken:", time.time() - start_time)