Spaces:
Sleeping
Sleeping
| import re | |
| import torch | |
| import torchaudio | |
| from transformers import Wav2Vec2ForCTC, Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer, Wav2Vec2Processor | |
| from tortoise.utils.audio import load_audio | |
def max_alignment(s1, s2, skip_character='~', record=None):
    """
    Align s1 to s2 as best it can and return the aligned version of s1.

    The result walks s1 from left to right: a character of s1 that can be matched, in order, against a character
    of s2 is kept as-is; every other character of s1 is replaced by skip_character. Characters of s2 that match
    nothing are simply skipped over, so the result has exactly one entry per character of s1.

    The alignment is computed with an iterative dynamic-programming table rather than recursion, so inputs
    whose combined length exceeds the interpreter's recursion limit no longer raise RecursionError.

    :param s1: The string to align (every character of it appears in the output, possibly as skip_character).
    :param s2: The string to align against.
    :param skip_character: Placeholder emitted for characters of s1 that are not matched in s2. Must not occur
                           in s1.
    :param record: Unused. Retained only so that existing callers passing it keep working.
    :return: The aligned string.
    :raises AssertionError: If skip_character occurs in s1.
    """
    assert skip_character not in s1, f"Found the skip character {skip_character} in the provided string, {s1}"
    if s1 == s2:
        # Identical strings align trivially; skip building the table.
        return s1

    n, m = len(s1), len(s2)

    # matched[i][j] is the number of characters of s1[i:] that are kept when s1[i:] is aligned to s2[j:].
    # The last row and last column stay 0: nothing can be kept once either string is exhausted.
    matched = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(n - 1, -1, -1):
        row = matched[i]
        below = matched[i + 1]
        char = s1[i]
        for j in range(m - 1, -1, -1):
            if char == s2[j]:
                # Equal leading characters are always matched with each other.
                row[j] = below[j + 1] + 1
            else:
                row[j] = max(row[j + 1], below[j])

    # Walk the table from the start of both strings to build the output.
    pieces = []
    i = j = 0
    while i < n:
        if j == m:
            # s2 is exhausted: everything left in s1 is unmatched.
            pieces.append(skip_character * (n - i))
            break
        if s1[i] == s2[j]:
            pieces.append(s1[i])
            i += 1
            j += 1
        elif matched[i][j + 1] > matched[i + 1][j]:
            # Dropping the s2 character keeps strictly more of s1.
            j += 1
        else:
            # On a tie (or when it is better), give up on this s1 character.
            pieces.append(skip_character)
            i += 1
    return ''.join(pieces)
class Wav2VecAlignment:
    """
    Uses wav2vec2 to perform audio<->text alignment.
    """

    def __init__(
        self,
        device='mps' if torch.backends.mps.is_available() else ('cuda' if torch.cuda.is_available() else 'cpu'),
    ):
        """
        Load the pretrained CTC model, feature extractor and tokenizer.

        :param device: Device the model is moved to for each forward pass. Defaults to 'mps' when available,
                       otherwise 'cuda' when available, otherwise 'cpu'.
        """
        # The model is kept on the CPU between calls; align() moves it to `device` only while it runs.
        self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu()
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large-960h")
        self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron-symbols')
        self.device = device

    @staticmethod
    def _interpolate_unaligned(alignments, total_len):
        """
        Return a copy of alignments in which every -1 entry is replaced by an interpolated position.

        Each run of -1 entries is spread evenly (integer arithmetic) between the last known position before the
        run and the next known position after it; total_len acts as the known position past the end of the list.
        """
        filled = alignments + [total_len]  # The sentinel guarantees every run of -1 has a right-hand anchor.
        for i in range(len(filled)):
            if filled[i] != -1:
                continue
            next_found = next(j for j in range(i + 1, len(filled)) if filled[j] != -1)
            anchor = filled[i - 1]
            gap = filled[next_found] - anchor
            steps = next_found - i + 1
            for j in range(i, next_found):
                filled[j] = (j - i + 1) * gap // steps + anchor
        return filled[:-1]

    def align(self, audio, expected_text, audio_sample_rate=24000):
        """
        Compute, for each character of expected_text, the position in audio at which it starts.

        The audio is resampled to 16000, normalized to zero mean and unit variance and run through the CTC
        model. The greedy decoding of the model output is used to decide which characters of the expected text
        were actually recognised; those are anchored to the first frame that predicts them and the remaining
        characters are interpolated between their neighbours.

        :param audio: Audio tensor. NOTE(review): presumably shaped (1, samples) - only the first element of
                      the model output is used; confirm against callers.
        :param expected_text: The text that is spoken in the audio.
        :param audio_sample_rate: Sample rate of audio.
        :return: List of positions (in units of the last dimension of the input audio), one per character of
                 expected_text, or [0] when the text encodes to a single token.
        :raises AssertionError: If the alignment could not be completed. The (resampled) audio and the text are
                                dumped to 'alignment_debug.pth' first.
        """
        orig_len = audio.shape[-1]

        with torch.no_grad():
            self.model = self.model.to(self.device)
            try:
                audio = audio.to(self.device)
                audio = torchaudio.functional.resample(audio, audio_sample_rate, 16000)
                clip_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
                logits = self.model(clip_norm).logits
            finally:
                # Always release the accelerator, even when the forward pass fails.
                self.model = self.model.cpu()

        logits = logits[0]
        top_tokens = logits.argmax(-1).tolist()  # Greedy prediction for every output frame.
        pred_string = self.tokenizer.decode(top_tokens)

        fixed_expectation = max_alignment(expected_text.lower(), pred_string)
        w2v_compression = orig_len // logits.shape[0]  # Input positions covered by one output frame.
        # NOTE(review): the two lists below are consumed in lockstep, which presumes the tokenizer yields
        # one token per character - confirm.
        expected_tokens = self.tokenizer.encode(fixed_expectation)
        expected_chars = list(fixed_expectation)
        if len(expected_tokens) == 1:
            return [0]  # The alignment is simple; there is only one token.
        expected_tokens.pop(0)  # The first token is a given.
        expected_chars.pop(0)

        alignments = [0]

        def pop_till_you_win():
            """Return the next token that was recognised, recording -1 for every skipped character before it."""
            if len(expected_tokens) == 0:
                return None
            popped = expected_tokens.pop(0)
            popped_char = expected_chars.pop(0)
            while popped_char == '~':
                alignments.append(-1)
                if len(expected_tokens) == 0:
                    return None
                popped = expected_tokens.pop(0)
                popped_char = expected_chars.pop(0)
            return popped

        next_expected_token = pop_till_you_win()
        for i, top in enumerate(top_tokens):
            if next_expected_token == top:
                alignments.append(i * w2v_compression)
                if len(expected_tokens) > 0:
                    next_expected_token = pop_till_you_win()
                else:
                    break

        pop_till_you_win()  # Flush any skipped characters left at the end of the text.
        if not (len(expected_tokens) == 0 and len(alignments) == len(expected_text)):
            torch.save([audio, expected_text], 'alignment_debug.pth')
            # Raised explicitly (rather than `assert False`) so the failure is not skipped under `python -O`.
            raise AssertionError(
                "Something went wrong with the alignment algorithm. I've dumped a file, 'alignment_debug.pth' to "
                "your current working directory. Please report this along with the file so it can get fixed."
            )

        # Now fix up alignments. Anything with -1 should be interpolated.
        return self._interpolate_unaligned(alignments, orig_len)

    def redact(self, audio, expected_text, audio_sample_rate=24000):
        """
        Cut out of audio the parts that correspond to text enclosed in square brackets.

        :param audio: Audio tensor, sliced along its second dimension. NOTE(review): presumably (1, samples).
        :param expected_text: The spoken text, with the parts to remove wrapped in "[...]" (no nesting).
        :param audio_sample_rate: Sample rate of audio, forwarded to align().
        :return: audio itself when expected_text contains no "[", otherwise the concatenation of the kept
                 segments (an empty clip when nothing is kept).
        :raises AssertionError: If a "[" has no "]" after it.
        """
        if '[' not in expected_text:
            return audio
        splitted = expected_text.split('[')
        fully_split = [splitted[0]]
        for spl in splitted[1:]:
            if ']' not in spl:
                raise AssertionError('Every "[" character must be paired with a "]" with no nesting.')
            fully_split.extend(spl.split(']'))

        # At this point, fully_split is a list of strings, with every other string being something that should be
        # redacted.
        non_redacted_intervals = []
        last_point = 0
        for i, piece in enumerate(fully_split):
            if i % 2 == 0 and piece != "":  # Empty pieces have no character to index into.
                end_interval = max(0, last_point + len(piece) - 1)
                non_redacted_intervals.append((last_point, end_interval))
            last_point += len(piece)

        if not non_redacted_intervals:
            # Everything was redacted; torch.cat() would reject an empty list, so return an empty clip instead.
            return audio[:, :0]

        bare_text = ''.join(fully_split)
        alignments = self.align(audio, bare_text, audio_sample_rate)

        # Each interval is (index of first kept character, index of last kept character); the slice runs from
        # the start of the former to the start of the latter.
        output_audio = [audio[:, alignments[start]:alignments[stop]] for start, stop in non_redacted_intervals]
        return torch.cat(output_audio, dim=-1)