import io
import re

import librosa
import torch
import torchaudio
from cachetools import LRUCache, cached

CACHE_MAXSIZE = 10000  # max entries in the LRU cache used by cached_vqgan_batch_encode
MICRO_BATCH_SIZE = 8  # decode at most this many items per model.decode call
ASR_SAMPLE_RATE = 16000  # sample rate expected by the ASR model
HUGE_GAP_THRESHOLD = 4000  # silence gap threshold used by batch_asr (milliseconds)


@torch.no_grad()
@torch.autocast(device_type="cuda", dtype=torch.half)
def batch_encode(model, audios_list: list[bytes]):
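    """Encode a batch of audio clips into feature tensors.

    Raw ``bytes`` items are decoded with librosa at the model's sample rate;
    tensor items are passed through unchanged. All clips are right-padded to
    the longest clip, encoded in one forward pass, and the features are
    trimmed back to their true lengths before being returned as a list.
    """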
    audios: list[torch.Tensor] = [
        (
            torch.from_numpy(
                librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0]
            )[None]
            if isinstance(audio, bytes)
            else audio
        )
        for audio in audios_list
    ]

    lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device)
    max_length = lengths.max().item()

    print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s")

    # Right-pad every clip to the longest one so they can be stacked into a batch
    padded = torch.stack(
        [
            torch.nn.functional.pad(audio, (0, int(max_length - audio.shape[-1])))
            for audio in audios
        ]
    ).to(model.device)

    features, feature_lengths = model.encode(padded, audio_lengths=lengths)
    features, feature_lengths = features.cpu(), feature_lengths.cpu()

    # Trim the padded frames off each item's features
    return [feature[..., :length] for feature, length in zip(features, feature_lengths)]


@cached(
    cache=LRUCache(maxsize=CACHE_MAXSIZE),
    key=lambda model, audios: (model.device, tuple(audios)),
)
def cached_vqgan_batch_encode(model, audios: list[bytes]):
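    """Memoized batch_encode keyed on (model.device, tuple(audios))."""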
    return batch_encode(model, audios)


@torch.no_grad()
@torch.autocast(device_type="cuda", dtype=torch.half)
def batch_vqgan_decode(model, features):
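    """Decode a batch of feature tensors back into waveforms.

    Features are right-padded to a common length and decoded in chunks of
    MICRO_BATCH_SIZE; the outputs are returned as numpy arrays trimmed to
    their true lengths.
    """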
    lengths = torch.tensor(
        [feature.shape[-1] for feature in features], device=model.device
    )
    max_length = lengths.max().item()
    padded = torch.stack(
        [
            torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1]))
            for feature in features
        ]
    ).to(model.device)

    # Decode in chunks of MICRO_BATCH_SIZE so large batches are split into
    # several smaller model.decode calls
    audios, audio_lengths = [], []
    for i in range(0, padded.shape[0], MICRO_BATCH_SIZE):
        audio, audio_length = model.decode(
            padded[i : i + MICRO_BATCH_SIZE],
            feature_lengths=lengths[i : i + MICRO_BATCH_SIZE],
        )
        audios.append(audio)
        audio_lengths.append(audio_length)
    audios = torch.cat(audios, dim=0)
    audio_lengths = torch.cat(audio_lengths, dim=0)
    audios, audio_lengths = audios.cpu(), audio_lengths.cpu()

    # Trim the padded samples off each decoded waveform
    return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)]


@torch.no_grad()
def batch_asr(model, lock, audios, sr, language="auto"):
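    """Transcribe a batch of 1-D audio tensors sampled at ``sr``.

    Audio is resampled to ASR_SAMPLE_RATE and transcribed while holding
    ``lock``. Each result contains the cleaned transcript, the clip duration
    in milliseconds, and a ``huge_gap`` flag that is set when the returned
    timestamps show a silent stretch longer than HUGE_GAP_THRESHOLD.
    """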
    resampled_audios = []
    for audio in audios:
        # Resample each clip to the rate the ASR model expects
        audio = torchaudio.functional.resample(audio, sr, ASR_SAMPLE_RATE)
        assert audio.ndim == 1
        resampled_audios.append(audio)

    with lock:
        res = model.generate(
            input=resampled_audios,
            batch_size=len(resampled_audios),
            language=language,
            use_itn=True,
        )

    results = []
    for r, audio in zip(res, audios):
        text = r["text"]
        # Strip "<|...|>" special tags from the transcript
        text = re.sub(r"<\|.*?\|>", "", text)
        duration = len(audio) / sr * 1000  # clip length in milliseconds
        huge_gap = False

        if "timestamp" in r and len(r["timestamp"]) > 2:
            for timestamp_a, timestamp_b in zip(
                r["timestamp"][:-1], r["timestamp"][1:]
            ):
                # A long gap between two consecutive timestamps counts as a huge gap
                if timestamp_b[0] - timestamp_a[1] > HUGE_GAP_THRESHOLD:
                    huge_gap = True
                    break

            # Also flag a long gap between the last timestamp and the end of the clip
            if duration - r["timestamp"][-1][1] > HUGE_GAP_THRESHOLD:
                huge_gap = True

        results.append(
            {
                "text": text,
                "duration": duration,
                "huge_gap": huge_gap,
            }
        )

    return results