|
import os
import warnings

from transformers.utils import logging as transformers_logging

# Silence verbose library logging before the heavier imports below.
transformers_logging.set_verbosity_error()
warnings.filterwarnings("ignore", category=UserWarning)

import json

import gradio as gr
import soundfile as sf
import torch
from huggingface_hub import login
from jiwer import wer
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import (
    SpeechT5Processor,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    pipeline,
)
|
|
# Log in to the Hugging Face Hub with a token from the environment
# (skipped when the HF_Key secret is not set).
HF_Key = os.environ.get("HF_Key")
if HF_Key:
    login(token=HF_Key)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
|
def cosine_sim_wer_single(reference, prediction):
    """
    Calculate a WER-like metric based on cosine similarity for a single
    reference-prediction pair.

    Args:
        reference: Single reference transcript (string).
        prediction: Single model prediction (string).

    Returns:
        Error rate based on cosine similarity (100% - similarity%).
    """
    ref = reference.strip() if reference else ""
    pred = prediction.strip() if prediction else ""

    if not ref or not pred:
        print("Warning: Empty reference or prediction")
        return 100.0

    try:
        # Character n-grams (with word boundaries) are more forgiving of
        # small spelling differences than whole-word tokens.
        vectorizer = CountVectorizer(analyzer='char_wb', ngram_range=(2, 3))
        vectors = vectorizer.fit_transform([ref, pred])

        similarity = cosine_similarity(vectors[0:1], vectors[1:2])[0][0] * 100
        error_rate = 100.0 - similarity

        print(f"Similarity: {similarity:.2f}%")
        print(f"Error rate: {error_rate:.2f}%")
        return error_rate
    except Exception as e:
        print(f"Error calculating similarity: {e}")
        return 100.0
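# Sanity check (hypothetical strings): near-identical sentences share most of
# their character n-grams, so the reported error rate should be close to zero:
#   cosine_sim_wer_single("habari ya leo", "habari za leo")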
|
|
|
|
|
|
|
|
|
# Load the speaker embedding that conditions the TTS voice.
speaker_file_path = 'speaker2.json'
model_id = 'eolang/speecht5_v4-2'

with open(speaker_file_path, 'r') as file:
    example = json.load(file)

speaker_embeddings = torch.tensor(example).unsqueeze(0)

l_model = SpeechT5ForTextToSpeech.from_pretrained(model_id)
l_processor = SpeechT5Processor.from_pretrained(model_id)
l_vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
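# SpeechT5 conditions on a 512-dimensional x-vector speaker embedding, so the
# tensor above should end up with shape [1, 512] after unsqueeze(0).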
|
|
|
def synthesize(input_text):
    """Synthesize speech for input_text and write it to test_output.wav."""
    inputs = l_processor(text=input_text, return_tensors="pt")
    speech = l_model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=l_vocoder)

    # SpeechT5 outputs audio at a 16 kHz sampling rate.
    sf.write('test_output.wav', speech.numpy(), 16000)
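# Example (assumes speaker2.json is present next to this script):
#   synthesize("Habari ya asubuhi")  # writes test_output.wav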
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ASR pipeline for the fine-tuned Swahili Whisper model.
tuned_pipeline = pipeline(
    "automatic-speech-recognition",
    model="eolang/whisper-small-sw-WER-13-zindi",
    device=device,
    torch_dtype=torch_dtype,  # fp16 on GPU, fp32 on CPU (defined above)
    return_timestamps=True,
    generate_kwargs={
        "no_repeat_ngram_size": 3,
        "repetition_penalty": 1.5,
    },
)
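# Note: no_repeat_ngram_size bans repeating 3-grams during decoding and
# repetition_penalty > 1 discourages repeated tokens; both curb Whisper's
# tendency to loop on short or noisy audio.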
|
|
|
|
|
def tuned_transcribe(filepath):
    """Transcribe an audio file with the fine-tuned model."""
    transcription = tuned_pipeline(filepath, return_timestamps=True)
    return transcription["text"]
|
|
# Baseline ASR pipeline using the stock openai/whisper-small checkpoint.
openai_pipeline = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    device=device,
    torch_dtype=torch_dtype,  # fp16 on GPU, fp32 on CPU (defined above)
    return_timestamps=True,
    generate_kwargs={
        "no_repeat_ngram_size": 3,
        "repetition_penalty": 1.5,
    },
)


def openai_transcribe(filepath):
    """Transcribe an audio file with the baseline model."""
    transcription = openai_pipeline(filepath, return_timestamps=True)
    return transcription["text"]
|
|
def full_loop(ref_text):
    """Synthesize ref_text, transcribe the audio with both ASR models, and
    report WER and cosine-similarity error rates for each."""
    synthesize(ref_text)

    tuned_transcription = tuned_transcribe('test_output.wav')
    openai_transcription = openai_transcribe('test_output.wav')

    tuned_wer = wer(ref_text, tuned_transcription)
    base_wer = wer(ref_text, openai_transcription)

    result = f'Tuned model transcription: {tuned_transcription}\n'
    result += f"Word error rate for the tuned model: {round(tuned_wer, 2)}\n"
    tuned_cosine = cosine_sim_wer_single(ref_text, tuned_transcription)
    result += f"Cosine-similarity error rate for the tuned model: {round(tuned_cosine, 2)}%\n"

    result += f'\nBase model transcription: {openai_transcription}\n'
    result += f"Word error rate for the base (un-tuned) model: {round(base_wer, 2)}\n"
    base_cosine = cosine_sim_wer_single(ref_text, openai_transcription)
    result += f"Cosine-similarity error rate for the base model: {round(base_cosine, 2)}%\n"

    return 'test_output.wav', result
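# Example round trip outside Gradio (hypothetical reference sentence):
#   audio_path, report = full_loop("Habari ya leo")
#   print(report)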
|
|
|
|
|
|
|
|
|
|
|
demo = gr.Interface(
    fn=full_loop,
    inputs=gr.Textbox(value="Kuna mambo kadhaa yanayoitajika kuzingatiwa wakati wa kufundisha modeli."),
    outputs=[gr.Audio(label="Synthesized speech"), gr.Textbox(label="Evaluation report")],
    title="TTS-STT Evaluation",
)


if __name__ == "__main__":
    demo.launch()