# TTS-STT / app.py
import os
import warnings
from transformers.utils import logging as transformers_logging
# Silence all transformers warnings
transformers_logging.set_verbosity_error()
warnings.filterwarnings("ignore", category=UserWarning)
import gradio as gr
import torch
from transformers import (
    SpeechT5Processor,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    pipeline
)
import json
import soundfile as sf
import numpy as np
from huggingface_hub import login
from jiwer import wer
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity
# -------------------------------------------------------------------------------------------------------------------
# Authentication & Env Setup
HF_Key = os.environ.get("HF_Key")
login(token=HF_Key)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# -------------------------------------------------------------------------------------------------------------------
def cosine_sim_wer_single(reference, prediction):
    """
    Calculate a WER-like metric based on cosine similarity for a single
    reference-prediction pair.
    Args:
        reference: Single reference transcript (string)
        prediction: Single model prediction (string)
    Returns:
        Error rate based on cosine similarity (100% - similarity%)
    """
    # Clean inputs
    ref = reference.strip() if reference else ""
    pred = prediction.strip() if prediction else ""
    # Handle empty inputs
    if not ref or not pred:
        print("Warning: Empty reference or prediction")
        return 100.0  # Return 100% error for invalid input
    try:
        # Use character n-grams to handle morphological variations better
        vectorizer = CountVectorizer(analyzer='char_wb', ngram_range=(2, 3))
        # Fit and transform both strings into n-gram count vectors
        vectors = vectorizer.fit_transform([ref, pred])
        # Calculate cosine similarity as a percentage
        similarity = cosine_similarity(vectors[0:1], vectors[1:2])[0][0] * 100
        # Convert to error rate (100% - similarity%)
        error_rate = 100.0 - similarity
        print(f"Similarity: {similarity:.2f}%")
        print(f"Error rate: {error_rate:.2f}%")
        return error_rate
    except Exception as e:
        print(f"Error calculating similarity: {e}")
        return 100.0  # Return 100% error in case of calculation failure
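# Quick sanity check of the metric (hypothetical strings; identical inputs
# should score ~0% error, while unrelated ones approach 100%):
#   cosine_sim_wer_single("habari ya leo", "habari ya leo")  # error rate ~0.0
#   cosine_sim_wer_single("habari ya leo", "jambo rafiki")   # high error rate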
# -------------------------------------------------------------------------------------------------------------------
## TTS Module
speaker_file_path = 'speaker2.json'
model_id = 'eolang/speecht5_v4-2'
with open(speaker_file_path, 'r') as file:
    example = json.load(file)
speaker_embeddings = torch.tensor(example).unsqueeze(0)
l_model = SpeechT5ForTextToSpeech.from_pretrained(model_id)
l_processor = SpeechT5Processor.from_pretrained(model_id)
l_vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
def synthesize(input_text):
    """Synthesize speech from text and write it to test_output.wav (16 kHz)."""
    inputs = l_processor(text=input_text, return_tensors="pt")
    speech = l_model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=l_vocoder)
    sf.write('test_output.wav', speech.numpy(), 16000)
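# Example usage (hypothetical input text); this overwrites test_output.wav
# in the working directory with 16 kHz mono audio:
#   synthesize("Habari ya asubuhi")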
# -------------------------------------------------------------------------------------------------------------------
## STT Module
### Custom/Tuned Whisper
tuned_pipeline = pipeline(
    "automatic-speech-recognition",
    model="eolang/whisper-small-sw-WER-13-zindi",
    device=device,
    return_timestamps=True,
    generate_kwargs={
        "no_repeat_ngram_size": 3,  # Block repeating 3-grams
        "repetition_penalty": 1.5,  # Penalize repetitions (1.0 = no penalty)
    }
)
def tuned_transcribe(filepath):
    """Transcribe an audio file with the fine-tuned Whisper pipeline."""
    transcription = tuned_pipeline(filepath, return_timestamps=True)
    return transcription["text"]
### OpenAI Whisper (untuned)
openai_pipeline = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    device=device,
    return_timestamps=True,
    generate_kwargs={
        "no_repeat_ngram_size": 3,  # Block repeating 3-grams
        "repetition_penalty": 1.5,  # Penalize repetitions (1.0 = no penalty)
    }
)
def openai_transcribe(filepath):
    """Transcribe an audio file with the base (untuned) Whisper pipeline."""
    transcription = openai_pipeline(filepath, return_timestamps=True)
    return transcription["text"]
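# Both wrappers take a path to an audio file and return plain text, e.g.
# (hypothetical path):
#   tuned_text = tuned_transcribe("sample.wav")
#   base_text = openai_transcribe("sample.wav")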
# -------------------------------------------------------------------------------------------------------------------
## Full Loop module
def full_loop(ref_text):
    """Synthesize ref_text, transcribe the result with both models, and report WER."""
    # Synthesize the reference text to test_output.wav
    synthesize(ref_text)
    # Get transcriptions using the wrapper functions that return just text
    tuned_transcription = tuned_transcribe('test_output.wav')
    openai_transcription = openai_transcribe('test_output.wav')
    tuned_WER = wer(ref_text, tuned_transcription)
    base_WER = wer(ref_text, openai_transcription)
    result = f'Tuned model transcription: {tuned_transcription}\n'
    result += f"Word error rate for the tuned model: {round(tuned_WER, 2)}\n"
    # Cosine-similarity error rate for the tuned model (printed to the console)
    cosine_sim_wer_single(ref_text, tuned_transcription)
    result += f'\nBase model transcription: {openai_transcription}\n'
    result += f"Word error rate for the base (untuned) model: {round(base_WER, 2)}\n"
    # Cosine-similarity error rate for the base model (printed to the console)
    cosine_sim_wer_single(ref_text, openai_transcription)
    return 'test_output.wav', result
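# Standalone example (hypothetical reference sentence), mirroring what the
# Gradio app does on submit:
#   audio_path, report = full_loop("Habari ya dunia")
#   print(report)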
# -------------------------------------------------------------------------------------------------------------------
# Minimal Gradio wrapper around the full loop
demo = gr.Interface(
    fn=full_loop,  # Use the existing function without modification
    inputs=gr.Textbox(value="Kuna mambo kadhaa yanayoitajika kuzingatiwa wakati wa kufundisha modeli."),  # Swahili sample sentence
    outputs=[gr.Audio(), gr.Textbox()],
    title="TTS-STT Evaluation"
)
# Launch the interface
if __name__ == "__main__":
    demo.launch()