import os
import numpy as np
import torch
import spaces  # Hugging Face Spaces GPU decorator, required for @spaces.GPU below
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, WhisperFeatureExtractor
from moviepy.editor import VideoFileClip, AudioFileClip
import nltk
from nltk.tokenize import sent_tokenize

nltk.download('punkt', quiet=True)


@spaces.GPU(duration=300)
def transcribe(video_file, transcribe_to_text=True, transcribe_to_srt=True, target_language='en'):
    # Use GPU with half precision when available; otherwise fall back to CPU/float32.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    # Load Whisper large-v3 together with its processor and feature extractor.
    model_id = "openai/whisper-large-v3"
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
    )
    model.to(device)
    processor = AutoProcessor.from_pretrained(model_id)
    feature_extractor = WhisperFeatureExtractor.from_pretrained(model_id)
    # Extract the audio track and transcribe it in 60-second chunks.
    video = VideoFileClip(video_file)
    audio = video.audio
    duration = audio.duration
    chunk_duration = 60
    n_chunks = int(np.ceil(duration / chunk_duration))
    full_transcription = ""

    for i in range(n_chunks):
        start_time = i * chunk_duration
        end_time = min((i + 1) * chunk_duration, duration)

        # Write the chunk to a temporary WAV file, then reload it as a 16 kHz waveform.
        audio_chunk = audio.subclip(start_time, end_time)
        temp_file_path = f"temp_audio_chunk_{i}.wav"
        audio_chunk.write_audiofile(temp_file_path, codec='pcm_s16le')
        sound_array = AudioFileClip(temp_file_path).to_soundarray(fps=16000)
        if sound_array.ndim > 1:
            sound_array = np.mean(sound_array, axis=1)  # downmix stereo to mono

        # Convert the waveform to log-mel input features for Whisper.
        input_features = feature_extractor(sound_array, sampling_rate=16000, return_tensors="pt").input_features
        input_features = input_features.to(device=device, dtype=torch_dtype)

        with torch.no_grad():
            if target_language:
                # Force decoding in the requested language.
                model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
                    language=target_language, task="transcribe"
                )
            generated_ids = model.generate(input_features, max_length=448)

        transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        full_transcription += transcription + " "

        os.remove(temp_file_path)
        print(f"Processed chunk {i + 1}/{n_chunks}")
    # Split the transcription into sentences
    sentences = sent_tokenize(full_transcription.strip())

    # Estimate time for each sentence based on its length relative to the total transcription
    total_chars = sum(len(s) for s in sentences)
    sentence_times = []
    current_time = 0
    for sentence in sentences:
        sentence_duration = (len(sentence) / total_chars) * duration
        sentence_times.append((current_time, current_time + sentence_duration))
        current_time += sentence_duration
    output = ""
    if transcribe_to_text:
        output += "Text Transcription:\n" + full_transcription + "\n\n"
    if transcribe_to_srt:
        output += "SRT Transcription:\n"
        for i, (sentence, (start, end)) in enumerate(zip(sentences, sentence_times), 1):
            output += f"{i}\n{format_time(start)} --> {format_time(end)}\n{sentence}\n\n"
    return output


def format_time(seconds):
    # Format seconds as an SRT timestamp: HH:MM:SS,mmm
    m, s = divmod(seconds, 60)
    h, m = divmod(m, 60)
    return f"{int(h):02d}:{int(m):02d}:{s:06.3f}".replace('.', ',')
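

# Minimal usage sketch (an illustrative assumption, not part of the Space's app code):
# calling transcribe() directly on a local video file. "example.mp4" is a placeholder
# path, and running this requires the `spaces` package plus the model weights.
if __name__ == "__main__":
    result = transcribe(
        "example.mp4",
        transcribe_to_text=True,
        transcribe_to_srt=True,
        target_language='en',
    )
    print(result)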