Follow comments.

pull/2/head
yangyaming 8 years ago
parent 26eb54eb37
commit def66a3223

@ -2,14 +2,20 @@
"""This module provides functions to calculate error rate in different level. """This module provides functions to calculate error rate in different level.
e.g. wer for word-level, cer for char-level. e.g. wer for word-level, cer for char-level.
""" """
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import numpy as np import numpy as np
def levenshtein_distance(ref, hyp): def _levenshtein_distance(ref, hyp):
"""Levenshtein distance is a string metric for measuring the difference between
two sequences. Informally, the Levenshtein distance is defined as the minimum
number of single-character edits (substitutions, insertions or deletions)
required to change one word into the other. We can naturally extend the edits to
word level when calculating the Levenshtein distance between two sentences.
"""
ref_len = len(ref) ref_len = len(ref)
hyp_len = len(hyp) hyp_len = len(hyp)
@ -72,7 +78,7 @@ def wer(reference, hypothesis, ignore_case=False, delimiter=' '):
:type delimiter: char :type delimiter: char
:return: Word error rate. :return: Word error rate.
:rtype: float :rtype: float
:raises ValueError: If reference length is zero. :raises ValueError: If the reference length is zero.
""" """
if ignore_case == True: if ignore_case == True:
reference = reference.lower() reference = reference.lower()
@ -84,7 +90,7 @@ def wer(reference, hypothesis, ignore_case=False, delimiter=' '):
if len(ref_words) == 0: if len(ref_words) == 0:
raise ValueError("Reference's word number should be greater than 0.") raise ValueError("Reference's word number should be greater than 0.")
edit_distance = levenshtein_distance(ref_words, hyp_words) edit_distance = _levenshtein_distance(ref_words, hyp_words)
wer = float(edit_distance) / len(ref_words) wer = float(edit_distance) / len(ref_words)
return wer return wer
@ -118,7 +124,7 @@ def cer(reference, hypothesis, ignore_case=False):
:type ignore_case: bool :type ignore_case: bool
:return: Character error rate. :return: Character error rate.
:rtype: float :rtype: float
:raises ValueError: If reference length is zero. :raises ValueError: If the reference length is zero.
""" """
if ignore_case == True: if ignore_case == True:
reference = reference.lower() reference = reference.lower()
@ -130,6 +136,6 @@ def cer(reference, hypothesis, ignore_case=False):
if len(reference) == 0: if len(reference) == 0:
raise ValueError("Length of reference should be greater than 0.") raise ValueError("Length of reference should be greater than 0.")
edit_distance = levenshtein_distance(reference, hypothesis) edit_distance = _levenshtein_distance(reference, hypothesis)
cer = float(edit_distance) / len(reference) cer = float(edit_distance) / len(reference)
return cer return cer

@ -23,10 +23,8 @@ class TestParse(unittest.TestCase):
def test_wer_3(self): def test_wer_3(self):
ref = ' ' ref = ' '
hyp = 'Hypothesis sentence' hyp = 'Hypothesis sentence'
try: with self.assertRaises(ValueError):
word_error_rate = error_rate.wer(ref, hyp) word_error_rate = error_rate.wer(ref, hyp)
except Exception as e:
self.assertTrue(isinstance(e, ValueError))
def test_cer_1(self): def test_cer_1(self):
ref = 'werewolf' ref = 'werewolf'
@ -53,10 +51,8 @@ class TestParse(unittest.TestCase):
def test_cer_5(self): def test_cer_5(self):
ref = '' ref = ''
hyp = 'Hypothesis' hyp = 'Hypothesis'
try: with self.assertRaises(ValueError):
char_error_rate = error_rate.cer(ref, hyp) char_error_rate = error_rate.cer(ref, hyp)
except Exception as e:
self.assertTrue(isinstance(e, ValueError))
if __name__ == '__main__': if __name__ == '__main__':

Loading…
Cancel
Save