"""
Whisper Model WER Evaluation - Fine-tunes vs Commercial APIs

Compares local fine-tuned models against commercial STT providers via EdenAI.
"""

import os
import json
import shutil
import time
from datetime import datetime
from pathlib import Path

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import requests
import jiwer
from dotenv import load_dotenv

load_dotenv()

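# Evaluation inputs: a single test clip plus its reference transcript, and the
# directory where reports, JSON results, and per-model transcriptions are written.
# Paths are assumed to be relative to the directory the script is run from.
# COMMERCIAL_PROVIDERS lists the EdenAI provider identifiers to benchmark against.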
AUDIO_FILE = "eval/test-audio.wav"
TRUTH_FILE = "eval/truth.txt"
RESULTS_DIR = "results"

COMMERCIAL_PROVIDERS = ["deepgram", "openai", "assembly", "gladia"]

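# Registry of local fine-tuned checkpoints to evaluate. Each entry points at a
# Hugging Face-format model directory on disk; the "type" field marks these as
# local models so they can be distinguished from commercial API results later.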
MODELS = {
    "whisper-base-ft": {
        "type": "local",
        "path": "/home/daniel/ai/models/stt/finetunes/daniel-whisper-base-finetune",
        "description": "Fine-tuned Whisper Base"
    },
    "whisper-small-ft": {
        "type": "local",
        "path": "/home/daniel/ai/models/stt/finetunes/whisper-small-en-futo",
        "description": "Fine-tuned Whisper Small"
    },
    "whisper-tiny-ft": {
        "type": "local",
        "path": "/home/daniel/ai/models/stt/finetunes/whisper-tiny-en-futo",
        "description": "Fine-tuned Whisper Tiny"
    },
    "whisper-large-turbo-ft": {
        "type": "local",
        "path": "/home/daniel/ai/models/stt/finetunes/whisper-large-turbo-finetune",
        "description": "Fine-tuned Whisper Large Turbo"
    }
}


def load_ground_truth(truth_file):
    """Load ground truth transcription"""
    with open(truth_file, 'r') as f:
        return f.read().strip()


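# transcribe_local_model() loads one checkpoint at a time, runs it through the
# transformers ASR pipeline (30 s chunking matches Whisper's native input window,
# batch_size=16 controls how many chunks are decoded at once), and then frees the
# model and GPU cache so successive fine-tunes don't accumulate in memory.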
def transcribe_local_model(model_path, audio_file):
    """Transcribe audio using a local model"""
    try:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        print(f" Loading model from {model_path}...")
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_path,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True
        )
        model.to(device)

        processor = AutoProcessor.from_pretrained(model_path)

        pipe = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            max_new_tokens=128,
            chunk_length_s=30,
            batch_size=16,
            return_timestamps=False,
            torch_dtype=torch_dtype,
            device=device,
        )

        print(" Transcribing...")
        result = pipe(audio_file)
        transcription = result["text"]

        del model
        del processor
        del pipe
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return transcription.strip()

    except Exception as e:
        print(f" ERROR: {str(e)}")
        return None


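# transcribe_edenai() drives EdenAI's asynchronous speech-to-text flow: submit the
# audio once per provider, then poll the job endpoint every 2 seconds (up to 60
# attempts, roughly two minutes) until the job reports "finished" or "failed".
# The response is assumed to follow EdenAI's async STT shape, with the transcript
# found under results[<provider>]["text"].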
def transcribe_edenai(audio_file, providers, api_key):
    """Transcribe audio using EdenAI with multiple providers"""
    results = {}

    for provider in providers:
        print(f"\n Testing {provider}...")
        try:
            url = "https://api.edenai.run/v2/audio/speech_to_text_async"
            headers = {"Authorization": f"Bearer {api_key}"}

            data = {
                "providers": provider,
                "language": "en"
            }

            with open(audio_file, 'rb') as f:
                files = {'file': f}
                response = requests.post(url, data=data, files=files, headers=headers)

            if response.status_code != 200:
                print(f" ❌ Failed to submit job: {response.status_code}")
                print(f" Response: {response.text}")
                results[provider] = None
                continue

            job_data = response.json()
            public_id = job_data.get("public_id")

            if not public_id:
                print(" ❌ No job ID returned")
                results[provider] = None
                continue

            print(f" Job ID: {public_id}")
            print(" Polling for results...")

            result_url = f"https://api.edenai.run/v2/audio/speech_to_text_async/{public_id}"
            max_attempts = 60
            attempt = 0

            while attempt < max_attempts:
                time.sleep(2)
                result_response = requests.get(result_url, headers=headers)

                if result_response.status_code != 200:
                    print(f" ❌ Failed to get results: {result_response.status_code}")
                    results[provider] = None
                    break

                result_data = result_response.json()
                status = result_data.get("status")

                if status == "finished":
                    provider_result = result_data.get("results", {}).get(provider, {})
                    transcription = provider_result.get("text", "")

                    if transcription:
                        print(" ✓ Transcription received")
                        results[provider] = transcription.strip()
                    else:
                        print(" ⚠️ No transcription in response")
                        results[provider] = None
                    break
                elif status == "failed":
                    error = result_data.get("results", {}).get(provider, {}).get("error")
                    print(f" ❌ Job failed: {error}")
                    results[provider] = None
                    break

                attempt += 1
                if attempt % 10 == 0:
                    print(f" Still waiting... ({attempt}/{max_attempts})")

            if attempt >= max_attempts:
                print(" ⏱️ Timeout waiting for results")
                results[provider] = None

        except Exception as e:
            print(f" ❌ Error: {str(e)}")
            results[provider] = None

    return results


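# calculate_metrics() wraps jiwer.process_words(), which aligns the hypothesis
# against the reference and reports, alongside the raw hit/substitution/deletion/
# insertion counts:
#   WER - word error rate: (S + D + I) / reference word count
#   MER - match error rate: (S + D + I) / (H + S + D + I)
#   WIP - word information preserved: (H / N_ref) * (H / N_hyp); WIL = 1 - WIP
#   e.g. jiwer.process_words("the cat sat", "the cat sit").wer == 1/3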
def calculate_metrics(reference, hypothesis):
    """Calculate WER and other metrics"""
    output = jiwer.process_words(reference, hypothesis)
    return {
        "wer": output.wer,
        "mer": output.mer,
        "wil": output.wil,
        "wip": output.wip,
        "hits": output.hits,
        "substitutions": output.substitutions,
        "deletions": output.deletions,
        "insertions": output.insertions
    }


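# Each raw transcription is persisted under results/transcriptions/ so it can be
# inspected or diffed against the ground truth after the run.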
def save_transcription(model_name, transcription):
    """Save transcription to file"""
    transcriptions_dir = Path(RESULTS_DIR) / "transcriptions"
    transcriptions_dir.mkdir(parents=True, exist_ok=True)

    output_file = transcriptions_dir / f"transcription_{model_name}.txt"
    with open(output_file, 'w') as f:
        f.write(transcription)

    return output_file


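# Report helpers: format_results_table() renders a Markdown-style table of the
# ranked results, and generate_comparison_chart() draws ASCII bars scaled against
# the worst WER so relative differences are visible at a glance.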
def format_results_table(results):
    """Format results as ASCII table"""
    header = "| Rank | Model | Type | WER | MER | WIL | WIP |"
    separator = "|------|-------|------|-----|-----|-----|-----|"

    lines = [header, separator]
    for i, result in enumerate(results, 1):
        if result["model_type"] == "local":
            model_type = "Fine-tune"
        else:
            model_type = "Commercial"
        line = f"| {i} | {result['model_name']} | {model_type} | {result['wer']:.2%} | {result['mer']:.2%} | {result['wil']:.2%} | {result['wip']:.2%} |"
        lines.append(line)

    return "\n".join(lines)


def generate_comparison_chart(results):
    """Generate ASCII bar chart of WER results"""
    lines = ["WER Comparison (lower is better)", "=" * 80, ""]

    max_wer = max(r['wer'] for r in results) if results else 1
    max_bar_length = 60

    for result in results:
        wer = result['wer']
        bar_length = int((wer / max_wer) * max_bar_length) if max_wer > 0 else 0
        bar = "█" * bar_length
        model_type = "FT" if result["model_type"] == "local" else "CM"
        line = f"{result['model_name'][:30]:<30} [{model_type}] {bar} {wer:.2%}"
        lines.append(line)

    lines.append("")
    lines.append("Legend: [FT] = Fine-tuned (local), [CM] = Commercial API")
    return "\n".join(lines)


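# main() orchestration: check for an EdenAI key, load the reference transcript,
# evaluate every local fine-tune, optionally evaluate the commercial providers,
# rank everything by WER, then write a timestamped report and JSON dump plus
# stable copies under results/latest/.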
def main():
    print("=" * 80)
    print("Whisper Model WER Evaluation - Fine-tunes vs Commercial APIs")
    print("=" * 80)
    print()

    print("Checking EdenAI API key...")
    api_key = os.environ.get("EDENAI_API_KEY")
    if not api_key:
        print("⚠️ Warning: EDENAI_API_KEY not set.")
        print(" Export EDENAI_API_KEY=your_key to enable commercial API comparison")
        print(" Continuing with local models only...")
    else:
        print("✓ EDENAI_API_KEY found")
    print()

    print(f"Loading ground truth from {TRUTH_FILE}...")
    reference = load_ground_truth(TRUTH_FILE)
    print(f"Ground truth loaded: {len(reference.split())} words")
    print()

    Path(RESULTS_DIR).mkdir(exist_ok=True)

    results = []
    failed_models = []

    print("Evaluating local fine-tuned models...")
    print("-" * 80)

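    # Local fine-tunes: each checkpoint is loaded, transcribed, scored, and
    # unloaded in turn; missing paths and failed transcriptions are collected in
    # failed_models rather than aborting the whole evaluation.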
    for model_name, config in MODELS.items():
        print(f"\n{model_name} ({config['description']})")

        start_time = time.time()

        if not Path(config["path"]).exists():
            print(f" ⚠️ Model path not found: {config['path']}")
            failed_models.append({
                "model_name": model_name,
                "description": config["description"],
                "error": "Model path not found"
            })
            continue

        transcription = transcribe_local_model(config["path"], AUDIO_FILE)

        elapsed_time = time.time() - start_time

        if transcription is None:
            failed_models.append({
                "model_name": model_name,
                "description": config["description"],
                "error": "Transcription failed"
            })
            continue

        save_transcription(model_name, transcription)
        print(" Saved transcription")

        metrics = calculate_metrics(reference, transcription)

        results.append({
            "model_name": model_name,
            "description": config["description"],
            "model_type": "local",
            "transcription": transcription,
            "processing_time": elapsed_time,
            **metrics
        })

        print(f" WER: {metrics['wer']:.2%}")
        print(f" Processing time: {elapsed_time:.2f}s")

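    # Commercial providers run only when an EdenAI key is available. Their
    # processing_time is recorded as 0 (the report skips the timing line when the
    # value is zero), since the async upload-and-poll round trip would not be a
    # fair comparison against local inference time.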
    if api_key:
        print("\n" + "=" * 80)
        print("Evaluating commercial STT providers via EdenAI...")
        print("-" * 80)

        commercial_results = transcribe_edenai(AUDIO_FILE, COMMERCIAL_PROVIDERS, api_key)

        for provider, transcription in commercial_results.items():
            if transcription:
                model_name = f"{provider}-api"

                save_transcription(model_name, transcription)

                metrics = calculate_metrics(reference, transcription)

                results.append({
                    "model_name": model_name,
                    "description": f"{provider.title()} STT API",
                    "model_type": "commercial",
                    "transcription": transcription,
                    "processing_time": 0,
                    **metrics
                })

                print(f"\n✓ {provider}: WER {metrics['wer']:.2%}")
            else:
                failed_models.append({
                    "model_name": f"{provider}-api",
                    "description": f"{provider.title()} STT API",
                    "error": "Transcription failed"
                })

    results.sort(key=lambda x: x['wer'])

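    # Results are sorted ascending by WER, so results[0] is the best performer.
    # Each run writes a timestamped report and JSON dump; the copies placed under
    # results/latest/ always reflect the most recent run.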
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    report_file = Path(RESULTS_DIR) / f"commercial_comparison_report_{timestamp}.txt"
    json_file = Path(RESULTS_DIR) / f"commercial_comparison_results_{timestamp}.json"

    results_table = format_results_table(results)

    comparison_chart = generate_comparison_chart(results)

    with open(report_file, 'w') as f:
        f.write("=" * 80 + "\n")
        f.write("WHISPER MODEL WER EVALUATION - FINE-TUNES VS COMMERCIAL APIS\n")
        f.write("=" * 80 + "\n\n")
        f.write(f"Evaluation Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Test Audio: {AUDIO_FILE}\n")
        f.write(f"Ground Truth: {TRUTH_FILE}\n")
        f.write(f"Reference Word Count: {len(reference.split())} words\n")
        f.write(f"Commercial Providers: {', '.join(COMMERCIAL_PROVIDERS)}\n\n")

        f.write("RESULTS RANKED BY WER (BEST TO WORST)\n")
        f.write("=" * 80 + "\n\n")
        f.write(results_table + "\n\n")

        f.write("WER COMPARISON CHART\n")
        f.write("=" * 80 + "\n\n")
        f.write(comparison_chart + "\n\n")

        f.write("DETAILED METRICS\n")
        f.write("=" * 80 + "\n\n")

        for result in results:
            f.write(f"{result['model_name']} - {result['description']}\n")
            f.write(f" Type: {result['model_type'].title()}\n")
            f.write(f" WER: {result['wer']:.2%}\n")
            f.write(f" MER: {result['mer']:.2%}\n")
            f.write(f" WIL: {result['wil']:.2%}\n")
            f.write(f" WIP: {result['wip']:.2%}\n")
            f.write(f" Hits: {result['hits']}\n")
            f.write(f" Substitutions: {result['substitutions']}\n")
            f.write(f" Deletions: {result['deletions']}\n")
            f.write(f" Insertions: {result['insertions']}\n")
            if result['processing_time'] > 0:
                f.write(f" Processing Time: {result['processing_time']:.2f}s\n")
            f.write("\n")

        if failed_models:
            f.write("FAILED MODELS\n")
            f.write("=" * 80 + "\n\n")
            for failed in failed_models:
                f.write(f"{failed['model_name']} - {failed['description']}\n")
                f.write(f" Error: {failed['error']}\n\n")

        f.write("CONCLUSIONS\n")
        f.write("=" * 80 + "\n\n")

        if results:
            best = results[0]
            worst = results[-1]

            f.write(f"Best Performer: {best['model_name']} ({best['description']})\n")
            f.write(f" WER: {best['wer']:.2%}\n")
            f.write(f" Type: {best['model_type'].title()}\n\n")

            f.write(f"Worst Performer: {worst['model_name']} ({worst['description']})\n")
            f.write(f" WER: {worst['wer']:.2%}\n")
            f.write(f" Type: {worst['model_type'].title()}\n\n")

            local_models = [r for r in results if r['model_type'] == 'local']
            commercial_models = [r for r in results if r['model_type'] == 'commercial']

            if local_models and commercial_models:
                best_local = local_models[0]
                best_commercial = commercial_models[0]

                f.write(f"Best Fine-tune: {best_local['model_name']} - WER {best_local['wer']:.2%}\n")
                f.write(f"Best Commercial: {best_commercial['model_name']} - WER {best_commercial['wer']:.2%}\n\n")

                if best_local['wer'] < best_commercial['wer']:
                    improvement = ((best_commercial['wer'] - best_local['wer']) / best_commercial['wer']) * 100
                    f.write(f"🎯 Fine-tuning Improvement: {improvement:.1f}% better WER than best commercial API\n")
                else:
                    difference = ((best_local['wer'] - best_commercial['wer']) / best_commercial['wer']) * 100
                    f.write(f"Commercial API Advantage: {difference:.1f}% better WER than best fine-tune\n")

    with open(json_file, 'w') as f:
        json.dump({
            "timestamp": timestamp,
            "audio_file": AUDIO_FILE,
            "truth_file": TRUTH_FILE,
            "reference_word_count": len(reference.split()),
            "commercial_providers": COMMERCIAL_PROVIDERS,
            "results": results,
            "failed_models": failed_models
        }, f, indent=2)

    latest_dir = Path(RESULTS_DIR) / "latest"
    latest_dir.mkdir(exist_ok=True)

    shutil.copy(report_file, latest_dir / "commercial_comparison_report.txt")
    shutil.copy(json_file, latest_dir / "commercial_comparison_results.json")

    with open(latest_dir / "commercial_comparison_chart.txt", 'w') as f:
        f.write(comparison_chart)

print("\n" + "=" * 80) |
|
|
print("EVALUATION COMPLETE") |
|
|
print("=" * 80) |
|
|
print(f"\nResults saved to:") |
|
|
print(f" Report: {report_file}") |
|
|
print(f" JSON: {json_file}") |
|
|
print(f" Latest: {latest_dir}/") |
|
|
print(f"\nEvaluated {len(results)} models successfully") |
|
|
if failed_models: |
|
|
print(f"Failed to evaluate {len(failed_models)} models") |
|
|
print() |
|
|
print(results_table) |
|
|
print() |
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|