From a8793039f3cd29959a8110ab22c45a32d4b52587 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Fri, 1 Dec 2017 19:36:07 +0800 Subject: [PATCH 1/2] Expose edit distance for error_rate.py --- utils/error_rate.py | 100 +++++++++++++++++++++++++++++++------------- 1 file changed, 71 insertions(+), 29 deletions(-) diff --git a/utils/error_rate.py b/utils/error_rate.py index ea829f470..2ff3f6960 100644 --- a/utils/error_rate.py +++ b/utils/error_rate.py @@ -56,6 +56,70 @@ def _levenshtein_distance(ref, hyp): return distance[m % 2][n] +def word_errors(reference, hypothesis, ignore_case=False, delimiter=' '): + """Compute the levenshtein distance between reference sequence and + hypothesis sequence in word-level. + + :param reference: The reference sentence. + :type reference: basestring + :param hypothesis: The hypothesis sentence. + :type hypothesis: basestring + :param ignore_case: Whether case-sensitive or not. + :type ignore_case: bool + :param delimiter: Delimiter of input sentences. + :type delimiter: char + :return: Levenshtein distance and word number of reference sentence. + :rtype: list + :raises ValueError: If word number of reference sentence is zero. + """ + if ignore_case == True: + reference = reference.lower() + hypothesis = hypothesis.lower() + + ref_words = filter(None, reference.split(delimiter)) + hyp_words = filter(None, hypothesis.split(delimiter)) + + if len(ref_words) == 0: + raise ValueError("Reference's word number should be greater than 0.") + + edit_distance = _levenshtein_distance(ref_words, hyp_words) + return float(edit_distance), len(ref_words) + + +def char_errors(reference, hypothesis, ignore_case=False, remove_space=False): + """Compute the levenshtein distance between reference sequence and + hypothesis sequence in char-level. + + :param reference: The reference sentence. + :type reference: basestring + :param hypothesis: The hypothesis sentence. + :type hypothesis: basestring + :param ignore_case: Whether case-sensitive or not. + :type ignore_case: bool + :param remove_space: Whether remove internal space characters + :type remove_space: bool + :return: Levenshtein distance and length of reference sentence. + :rtype: list + :raises ValueError: If the reference length is zero. + """ + if ignore_case == True: + reference = reference.lower() + hypothesis = hypothesis.lower() + + join_char = ' ' + if remove_space == True: + join_char = '' + + reference = join_char.join(filter(None, reference.split(' '))) + hypothesis = join_char.join(filter(None, hypothesis.split(' '))) + + if len(reference) == 0: + raise ValueError("Length of reference should be greater than 0.") + + edit_distance = _levenshtein_distance(reference, hypothesis) + return float(edit_distance), len(reference) + + def wer(reference, hypothesis, ignore_case=False, delimiter=' '): """Calculate word error rate (WER). WER compares reference text and hypothesis text in word-level. WER is defined as: @@ -85,20 +149,11 @@ def wer(reference, hypothesis, ignore_case=False, delimiter=' '): :type delimiter: char :return: Word error rate. :rtype: float - :raises ValueError: If the reference length is zero. + :raises ValueError: If word number of reference is zero. """ - if ignore_case == True: - reference = reference.lower() - hypothesis = hypothesis.lower() - - ref_words = filter(None, reference.split(delimiter)) - hyp_words = filter(None, hypothesis.split(delimiter)) - - if len(ref_words) == 0: - raise ValueError("Reference's word number should be greater than 0.") - - edit_distance = _levenshtein_distance(ref_words, hyp_words) - wer = float(edit_distance) / len(ref_words) + edit_distance, ref_len = word_errors(reference, hypothesis, ignore_case, + delimiter) + wer = float(edit_distance) / ref_len return wer @@ -135,20 +190,7 @@ def cer(reference, hypothesis, ignore_case=False, remove_space=False): :rtype: float :raises ValueError: If the reference length is zero. """ - if ignore_case == True: - reference = reference.lower() - hypothesis = hypothesis.lower() - - join_char = ' ' - if remove_space == True: - join_char = '' - - reference = join_char.join(filter(None, reference.split(' '))) - hypothesis = join_char.join(filter(None, hypothesis.split(' '))) - - if len(reference) == 0: - raise ValueError("Length of reference should be greater than 0.") - - edit_distance = _levenshtein_distance(reference, hypothesis) - cer = float(edit_distance) / len(reference) + edit_distance, ref_len = char_errors(reference, hypothesis, ignore_case, + remove_space) + cer = float(edit_distance) / len(ref_len) return cer From 0f9b3ebf0e75ed16e4748717589b962ec4747576 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Fri, 1 Dec 2017 19:46:29 +0800 Subject: [PATCH 2/2] Move exception throwing logic to cer and wer. --- utils/error_rate.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/utils/error_rate.py b/utils/error_rate.py index 2ff3f6960..9aa900174 100644 --- a/utils/error_rate.py +++ b/utils/error_rate.py @@ -70,7 +70,6 @@ def word_errors(reference, hypothesis, ignore_case=False, delimiter=' '): :type delimiter: char :return: Levenshtein distance and word number of reference sentence. :rtype: list - :raises ValueError: If word number of reference sentence is zero. """ if ignore_case == True: reference = reference.lower() @@ -79,9 +78,6 @@ def word_errors(reference, hypothesis, ignore_case=False, delimiter=' '): ref_words = filter(None, reference.split(delimiter)) hyp_words = filter(None, hypothesis.split(delimiter)) - if len(ref_words) == 0: - raise ValueError("Reference's word number should be greater than 0.") - edit_distance = _levenshtein_distance(ref_words, hyp_words) return float(edit_distance), len(ref_words) @@ -100,7 +96,6 @@ def char_errors(reference, hypothesis, ignore_case=False, remove_space=False): :type remove_space: bool :return: Levenshtein distance and length of reference sentence. :rtype: list - :raises ValueError: If the reference length is zero. """ if ignore_case == True: reference = reference.lower() @@ -113,9 +108,6 @@ def char_errors(reference, hypothesis, ignore_case=False, remove_space=False): reference = join_char.join(filter(None, reference.split(' '))) hypothesis = join_char.join(filter(None, hypothesis.split(' '))) - if len(reference) == 0: - raise ValueError("Length of reference should be greater than 0.") - edit_distance = _levenshtein_distance(reference, hypothesis) return float(edit_distance), len(reference) @@ -153,6 +145,10 @@ def wer(reference, hypothesis, ignore_case=False, delimiter=' '): """ edit_distance, ref_len = word_errors(reference, hypothesis, ignore_case, delimiter) + + if ref_len == 0: + raise ValueError("Reference's word number should be greater than 0.") + wer = float(edit_distance) / ref_len return wer @@ -192,5 +188,9 @@ def cer(reference, hypothesis, ignore_case=False, remove_space=False): """ edit_distance, ref_len = char_errors(reference, hypothesis, ignore_case, remove_space) - cer = float(edit_distance) / len(ref_len) + + if ref_len == 0: + raise ValueError("Length of reference should be greater than 0.") + + cer = float(edit_distance) / ref_len return cer