File size: 1,904 Bytes
5381499 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import re
from typing import Sequence, Tuple, Union
def levenshtein_distance(source: Sequence, target: Sequence) -> int:
    """
    Compute the Levenshtein (edit) distance between two sequences.

    The distance is the minimum number of single-element insertions,
    deletions and substitutions needed to turn ``source`` into ``target``.
    Elements are compared with ``!=``, so any indexable sequence works:
    a string gives a character-level distance, a list/tuple of tokens a
    word-level one.

    Only two rows of the dynamic-programming table are kept, so extra space
    is proportional to the shorter input.

    Args:
        source: first sequence.
        target: second sequence.

    Returns:
        The edit distance (a non-negative int). Two empty inputs give 0.
    """
    n, m = len(source), len(target)
    if n > m:
        # Make sure n <= m so each row has length min(n, m) + 1.
        source, target = target, source
        n, m = m, n
    # Row 0: turning source[:j] into the empty prefix of target costs j edits.
    current_row = list(range(n + 1))
    for i in range(1, m + 1):
        # current_row becomes row i; its column 0 is i (i insertions).
        previous_row, current_row = current_row, [i] + [0] * n
        for j in range(1, n + 1):
            add = previous_row[j] + 1
            delete = current_row[j - 1] + 1
            change = previous_row[j - 1]
            if source[j - 1] != target[i - 1]:
                change += 1
            current_row[j] = min(add, delete, change)
    # No explicit `del` of the rows: previous_row is unbound when m == 0,
    # which previously made two empty inputs raise UnboundLocalError.
    return current_row[n]
def word_error_rate(
    predicted: Union[str, Tuple[str, ...]], transcript: Union[str, Tuple[str, ...]]
) -> float:
    """
    Compute the corpus-level word error rate (WER).

    Each text is split into words, which are maximal runs of regex word
    characters. Punctuation and whitespace only separate words and are never
    counted as words themselves. WER is the word-level Levenshtein distance
    summed over all (prediction, transcript) pairs, divided by the total
    number of words in the transcripts.

    Args:
        predicted: one hypothesis string, or a sequence of them.
        transcript: one reference string, or a sequence of them, paired
            element-wise with ``predicted``.

    Returns:
        The word error rate as a float. Returns 0.0 when neither the
        transcripts nor the predictions contain any words.

    Raises:
        ValueError: if ``predicted`` and ``transcript`` have different lengths.
        ZeroDivisionError: if the transcripts contain no words but the
            predictions do (the rate is undefined).
    """
    if isinstance(predicted, str):
        predicted = (predicted,)
    if isinstance(transcript, str):
        transcript = (transcript,)
    # Materialize so iterables of any kind can be length-checked.
    predicted, transcript = tuple(predicted), tuple(transcript)
    if len(predicted) != len(transcript):
        # zip() would silently drop the unmatched tail and skew the rate.
        raise ValueError(
            f"got {len(predicted)} predictions but {len(transcript)} transcripts"
        )
    # findall(r"\w+") yields the same words as split(r"\W+") but without the
    # empty strings split() produces at leading/trailing punctuation or
    # whitespace (e.g. "hi." -> ["hi", ""]), which were counted as words.
    word = re.compile(r"\w+")
    err, total = 0, 0
    for pred, tgt in zip(predicted, transcript):
        pred_tokens = word.findall(pred)
        tgt_tokens = word.findall(tgt)
        if pred_tokens or tgt_tokens:
            # Two empty token lists need zero edits; skip the distance call.
            err += levenshtein_distance(pred_tokens, tgt_tokens)
        total += len(tgt_tokens)
    if total == 0:
        if err == 0:
            return 0.0
        raise ZeroDivisionError(
            "word error rate is undefined: the transcripts contain no words"
        )
    return err / total
def character_error_rate(
    predicted: Union[str, Tuple[str, ...]], transcript: Union[str, Tuple[str, ...]]
) -> float:
    """
    Compute the corpus-level character error rate (CER).

    CER is the character-level Levenshtein distance summed over all
    (prediction, transcript) pairs, divided by the total number of
    characters in the transcripts. Every character counts, including
    whitespace and punctuation.

    Args:
        predicted: one hypothesis string, or a sequence of them.
        transcript: one reference string, or a sequence of them, paired
            element-wise with ``predicted``.

    Returns:
        The character error rate as a float. Returns 0.0 when every
        transcript and every prediction is empty.

    Raises:
        ValueError: if ``predicted`` and ``transcript`` have different lengths.
        ZeroDivisionError: if the transcripts are all empty but some
            prediction is not (the rate is undefined).
    """
    if isinstance(predicted, str):
        predicted = (predicted,)
    if isinstance(transcript, str):
        transcript = (transcript,)
    # Materialize so iterables of any kind can be length-checked.
    predicted, transcript = tuple(predicted), tuple(transcript)
    if len(predicted) != len(transcript):
        # zip() would silently drop the unmatched tail and skew the rate.
        raise ValueError(
            f"got {len(predicted)} predictions but {len(transcript)} transcripts"
        )
    err, total = 0, 0
    for pred, tgt in zip(predicted, transcript):
        if pred or tgt:
            # Two empty strings need zero edits; skip the distance call.
            err += levenshtein_distance(pred, tgt)
        total += len(tgt)
    if total == 0:
        if err == 0:
            return 0.0
        raise ZeroDivisionError(
            "character error rate is undefined: the transcripts are empty"
        )
    return err / total
|