Update chunkedTranscriber.py
chunkedTranscriber.py  CHANGED  (+8 -5)
@@ -3,18 +3,16 @@ import gc
 import sys
 import time
 import torch
+import spaces
 import torchaudio
 import numpy as np
 from scipy.signal import resample
 from pyannote.audio import Pipeline
 from dotenv import load_dotenv
 load_dotenv()
-import logging
-import time
 from difflib import SequenceMatcher
 from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor, Wav2Vec2ForCTC, AutoProcessor, AutoTokenizer, AutoModelForSeq2SeqLM
 from difflib import SequenceMatcher
-import gc

 class ChunkedTranscriber:
     def __init__(self, chunk_size=5, overlap=1, sample_rate=16000):
@@ -32,6 +30,7 @@ class ChunkedTranscriber:
         pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hf_token)
         return pipeline

+    @spaces.GPU(duration=60)
     def diarize_audio(self, audio_path):
         """
         Perform speaker diarization on the input audio.
@@ -45,7 +44,8 @@ class ChunkedTranscriber:
         model = Wav2Vec2ForSequenceClassification.from_pretrained(model_id)
         return processor, model

-
+
+    @spaces.GPU(duration=60)
     def language_identification(self, model, processor, chunk, device="cuda"):
         inputs = processor(chunk, sampling_rate=16_000, return_tensors="pt")
         model.to(device)
@@ -69,6 +69,7 @@ class ChunkedTranscriber:
         return model, processor


+    @spaces.GPU(duration=60)
     def mms_transcription(self, model, processor, chunk, device="cuda"):

         inputs = processor(chunk, sampling_rate=16_000, return_tensors="pt")
@@ -92,7 +93,8 @@ class ChunkedTranscriber:
         model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
         return model, tokenizer

-
+
+    @spaces.GPU(duration=60)
     def text2text_translation(self, translation_model, translation_tokenizer, transcript, device="cuda"):
         # model, tokenizer = load_translation_model()

@@ -108,6 +110,7 @@ class ChunkedTranscriber:
         gc.collect()
         return translation_tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]

+
     def preprocess_audio(self, audio):
         """
         Create overlapping chunks with improved timing logic
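The commit imports the Hugging Face spaces package, drops the duplicated logging/time/gc imports, and wraps each GPU-bound method (diarization, language identification, MMS transcription, text-to-text translation) in @spaces.GPU(duration=60), which requests a GPU for up to 60 seconds per call when the app runs on a ZeroGPU Space. A minimal sketch of the same pattern outside the class follows; it assumes a ZeroGPU Space, and the transcribe function and the MMS checkpoint name are illustrative examples, not part of this commit.

import spaces
import torch
from transformers import Wav2Vec2ForCTC, AutoProcessor

# Illustrative checkpoint; the commit only shows the decorator usage, not the model choice.
model_id = "facebook/mms-1b-all"
processor = AutoProcessor.from_pretrained(model_id)
model = Wav2Vec2ForCTC.from_pretrained(model_id)

@spaces.GPU(duration=60)  # a GPU is attached only while this call runs
def transcribe(chunk, sampling_rate=16_000):
    # Move the model to CUDA inside the decorated call, then run greedy CTC decoding.
    inputs = processor(chunk, sampling_rate=sampling_rate, return_tensors="pt")
    model.to("cuda")
    with torch.no_grad():
        logits = model(inputs.input_values.to("cuda")).logits
    ids = torch.argmax(logits, dim=-1)[0]
    return processor.decode(ids)

The decorator belongs on the function that actually touches CUDA; model loading can stay on CPU at import time, which is the split the decorated methods above follow.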
|