You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
PaddleSpeech/paddlespeech/s2t/utils/error_rate.py

365 lines
13 KiB

3 years ago
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
7 years ago
"""This module provides functions to calculate error rates at different levels,
e.g. WER at the word level and CER at the character level.
"""
3 years ago
from itertools import groupby
import editdistance
import numpy as np
__all__ = ['word_errors', 'char_errors', 'wer', 'cer', "ErrorCalculator"]
7 years ago
def _levenshtein_distance(ref, hyp):
    """Return the Levenshtein distance between two sequences.

    The Levenshtein distance is the minimum number of single-element edits
    (substitutions, insertions or deletions) required to change one sequence
    into the other. Passing two strings gives a character-level distance;
    passing two lists of words gives a word-level distance.

    :param ref: The reference sequence (a ``str`` or a list of words).
    :param hyp: The hypothesis sequence, of the same kind as ``ref``.
    :return: The edit distance between ``ref`` and ``hyp``.
    :rtype: int
    """
    m = len(ref)
    n = len(hyp)

    # Trivial cases that need no dynamic programming.
    if ref == hyp:
        return 0
    if m == 0:
        return n
    if n == 0:
        return m

    # The distance is symmetric, so put the shorter sequence on the column
    # axis; only two rows of length min(m, n) + 1 are ever kept.
    if m < n:
        ref, hyp = hyp, ref
        m, n = n, m

    # Row 0: turning an empty prefix of `ref` into hyp[:j] costs j insertions.
    prev_row = list(range(n + 1))
    for i in range(1, m + 1):
        ref_item = ref[i - 1]
        # Column 0: turning ref[:i] into an empty sequence costs i deletions.
        cur_row = [i] + [0] * n
        for j in range(1, n + 1):
            if ref_item == hyp[j - 1]:
                cur_row[j] = prev_row[j - 1]
            else:
                cur_row[j] = 1 + min(
                    prev_row[j - 1],  # substitution
                    cur_row[j - 1],  # insertion
                    prev_row[j],  # deletion
                )
        prev_row = cur_row
    return prev_row[n]
def word_errors(reference, hypothesis, ignore_case=False, delimiter=' '):
    """Compute the word-level Levenshtein distance between two sentences.

    Both sentences are split on ``delimiter`` and empty items are dropped
    before the distance is computed.

    :param reference: The reference sentence.
    :type reference: str
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: str
    :param ignore_case: If True, lowercase both sentences before comparing.
    :type ignore_case: bool
    :param delimiter: Delimiter of input sentences.
    :type delimiter: char
    :return: ``(edit_distance, ref_word_count)``: the Levenshtein distance
             as a float and the number of words in the reference sentence.
    :rtype: tuple
    """
    if ignore_case:
        reference = reference.lower()
        hypothesis = hypothesis.lower()

    ref_words = [word for word in reference.split(delimiter) if word]
    hyp_words = [word for word in hypothesis.split(delimiter) if word]

    # NOTE(review): the module's own `_levenshtein_distance` is used here
    # instead of `editdistance.eval`; the original author's note says the
    # latter is less precise -- confirm before switching.
    edit_distance = _levenshtein_distance(ref_words, hyp_words)
    return float(edit_distance), len(ref_words)
def char_errors(reference, hypothesis, ignore_case=False, remove_space=False):
    """Compute the char-level Levenshtein distance between two sentences.

    Both sentences are split on the space character and empty items are
    dropped, so leading/trailing spaces are removed and runs of spaces are
    collapsed. The pieces are then re-joined with a single space, or with
    nothing when ``remove_space`` is True.

    :param reference: The reference sentence.
    :type reference: str
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: str
    :param ignore_case: If True, lowercase both sentences before comparing.
    :type ignore_case: bool
    :param remove_space: Whether to remove internal space characters.
    :type remove_space: bool
    :return: ``(edit_distance, ref_length)``: the Levenshtein distance as a
             float and the length of the normalized reference sentence.
    :rtype: tuple
    """
    if ignore_case:
        reference = reference.lower()
        hypothesis = hypothesis.lower()

    join_char = '' if remove_space else ' '
    reference = join_char.join(filter(None, reference.split(' ')))
    hypothesis = join_char.join(filter(None, hypothesis.split(' ')))

    # NOTE(review): the module's own `_levenshtein_distance` is used here
    # instead of `editdistance.eval`; the original author's note says the
    # latter is less precise -- confirm before switching.
    edit_distance = _levenshtein_distance(reference, hypothesis)
    return float(edit_distance), len(reference)
7 years ago
def wer(reference, hypothesis, ignore_case=False, delimiter=' '):
    """Calculate word error rate (WER).

    WER compares reference text and hypothesis text at the word level and
    is defined as:

    .. math::
        WER = (Sw + Dw + Iw) / Nw

    where

    .. code-block:: text

        Sw is the number of words substituted,
        Dw is the number of words deleted,
        Iw is the number of words inserted,
        Nw is the number of words in the reference

    The numerator is obtained as the word-level Levenshtein distance. Note
    that empty items are removed when splitting sentences by ``delimiter``.

    :param reference: The reference sentence.
    :type reference: str
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: str
    :param ignore_case: If True, lowercase both sentences before comparing.
    :type ignore_case: bool
    :param delimiter: Delimiter of input sentences.
    :type delimiter: char
    :return: Word error rate.
    :rtype: float
    :raises ValueError: If word number of reference is zero.
    """
    edit_distance, ref_len = word_errors(reference, hypothesis, ignore_case,
                                         delimiter)

    if ref_len == 0:
        raise ValueError("Reference's word number should be greater than 0.")

    return float(edit_distance) / ref_len
def cer(reference, hypothesis, ignore_case=False, remove_space=False):
    """Calculate character error rate (CER).

    CER compares reference text and hypothesis text at the character level
    and is defined as:

    .. math::
        CER = (Sc + Dc + Ic) / Nc

    where

    .. code-block:: text

        Sc is the number of characters substituted,
        Dc is the number of characters deleted,
        Ic is the number of characters inserted,
        Nc is the number of characters in the reference

    The numerator is obtained as the char-level Levenshtein distance.
    Chinese input should be encoded to unicode. Note that leading and
    trailing space characters are truncated and multiple consecutive space
    characters in a sentence are replaced by one space character.

    :param reference: The reference sentence.
    :type reference: str
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: str
    :param ignore_case: If True, lowercase both sentences before comparing.
    :type ignore_case: bool
    :param remove_space: Whether to remove internal space characters.
    :type remove_space: bool
    :return: Character error rate.
    :rtype: float
    :raises ValueError: If the reference length is zero.
    """
    edit_distance, ref_len = char_errors(reference, hypothesis, ignore_case,
                                         remove_space)

    if ref_len == 0:
        raise ValueError("Length of reference should be greater than 0.")

    return float(edit_distance) / ref_len
class ErrorCalculator:
    """Calculate CER and WER for E2E_ASR and CTC models during training.

    :param list char_list: Vocabulary; ``char_list[i]`` is the token string
        for token id ``i``.
    :param str sym_space: Token that stands for a space, e.g. ``<space>``.
    :param str sym_blank: Token that stands for the blank, e.g. ``<blank>``.
    :param bool report_cer: Whether ``__call__`` computes the CER.
    :param bool report_wer: Whether ``__call__`` computes the WER.
    """

    def __init__(self,
                 char_list,
                 sym_space,
                 sym_blank,
                 report_cer=False,
                 report_wer=False):
        """Construct an ErrorCalculator object.

        :raises ValueError: If ``sym_blank`` is not in ``char_list``.
        """
        super().__init__()

        self.report_cer = report_cer
        self.report_wer = report_wer

        self.char_list = char_list
        self.space = sym_space
        self.blank = sym_blank
        self.idx_blank = self.char_list.index(self.blank)
        # The space symbol is optional in the vocabulary.
        if self.space in self.char_list:
            self.idx_space = self.char_list.index(self.space)
        else:
            self.idx_space = None

    def __call__(self, ys_hat, ys_pad, is_ctc=False):
        """Calculate sentence-level WER/CER score.

        :param paddle.Tensor ys_hat: prediction (batch, seqlen)
        :param paddle.Tensor ys_pad: reference (batch, seqlen)
        :param bool is_ctc: calculate CER score for CTC
        :return: When ``is_ctc`` is True, the CTC CER alone (float or None),
            regardless of the ``report_*`` flags. Otherwise a ``(cer, wer)``
            tuple in which a score that was not requested, or that could
            not be computed, is None.
        """
        cer, wer = None, None
        if is_ctc:
            return self.calculate_cer_ctc(ys_hat, ys_pad)
        if not self.report_cer and not self.report_wer:
            return cer, wer

        seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad)
        if self.report_cer:
            cer = self.calculate_cer(seqs_hat, seqs_true)

        if self.report_wer:
            wer = self.calculate_wer(seqs_hat, seqs_true)
        return cer, wer

    def _ids_to_ctc_chars(self, ids):
        """Map token ids to one string, dropping padding, blank and space."""
        tokens = []
        for idx in ids:
            idx = int(idx)
            if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
                tokens.append(self.char_list[idx])
        return "".join(tokens)

    def calculate_cer_ctc(self, ys_hat, ys_pad):
        """Calculate sentence-level CER score for CTC.

        Consecutive repeated ids in each prediction are collapsed first.
        Ids equal to -1, the blank id and the space id are then dropped from
        both prediction and reference. Sentences whose reference is empty
        after filtering are skipped.

        :param paddle.Tensor ys_hat: prediction (batch, seqlen)
        :param paddle.Tensor ys_pad: reference (batch, seqlen)
        :return: average sentence-level CER score, or None if no sentence
            had a non-empty reference
        :rtype float
        """
        cers, char_ref_lens = [], []
        for i, y in enumerate(ys_hat):
            # Collapse runs of identical ids.
            y_hat = [x[0] for x in groupby(y)]
            y_true = ys_pad[i]

            hyp_chars = self._ids_to_ctc_chars(y_hat)
            ref_chars = self._ids_to_ctc_chars(y_true)
            if len(ref_chars) > 0:
                cers.append(editdistance.eval(hyp_chars, ref_chars))
                char_ref_lens.append(len(ref_chars))

        cer_ctc = float(sum(cers)) / sum(char_ref_lens) if cers else None
        return cer_ctc

    def convert_to_char(self, ys_hat, ys_pad):
        """Convert index to character.

        The space symbol is replaced by ``" "`` in both outputs, and the
        blank symbol is removed from the predictions.

        :param paddle.Tensor ys_hat: prediction (batch, seqlen)
        :param paddle.Tensor ys_pad: reference (batch, seqlen)
        :return: token list of prediction
        :rtype list
        :return: token list of reference
        :rtype list
        """
        seqs_hat, seqs_true = [], []
        for i, y_hat in enumerate(ys_hat):
            y_true = ys_pad[i]
            eos_true = np.where(y_true == -1)[0]
            ymax = eos_true[0] if len(eos_true) > 0 else len(y_true)
            # NOTE: padding index (-1) in y_true is used to pad y_hat
            seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]]
            seq_true = [
                self.char_list[int(idx)] for idx in y_true if int(idx) != -1
            ]
            seq_hat_text = "".join(seq_hat).replace(self.space, " ")
            seq_hat_text = seq_hat_text.replace(self.blank, "")
            seq_true_text = "".join(seq_true).replace(self.space, " ")
            seqs_hat.append(seq_hat_text)
            seqs_true.append(seq_true_text)
        return seqs_hat, seqs_true

    def calculate_cer(self, seqs_hat, seqs_true):
        """Calculate sentence-level CER score.

        Spaces are removed from both texts before comparison.

        :param list seqs_hat: prediction
        :param list seqs_true: reference
        :return: average sentence-level CER score, or None if the references
            contain no characters at all (the rate is undefined)
        :rtype float
        """
        char_eds, char_ref_lens = [], []
        for i, seq_hat_text in enumerate(seqs_hat):
            seq_true_text = seqs_true[i]
            hyp_chars = seq_hat_text.replace(" ", "")
            ref_chars = seq_true_text.replace(" ", "")
            char_eds.append(editdistance.eval(hyp_chars, ref_chars))
            char_ref_lens.append(len(ref_chars))

        total_ref_len = sum(char_ref_lens)
        # Avoid ZeroDivisionError on an empty batch or all-empty references.
        if total_ref_len == 0:
            return None
        return float(sum(char_eds)) / total_ref_len

    def calculate_wer(self, seqs_hat, seqs_true):
        """Calculate sentence-level WER score.

        Texts are split on whitespace into words before comparison.

        :param list seqs_hat: prediction
        :param list seqs_true: reference
        :return: average sentence-level WER score, or None if the references
            contain no words at all (the rate is undefined)
        :rtype float
        """
        word_eds, word_ref_lens = [], []
        for i, seq_hat_text in enumerate(seqs_hat):
            seq_true_text = seqs_true[i]
            hyp_words = seq_hat_text.split()
            ref_words = seq_true_text.split()
            word_eds.append(editdistance.eval(hyp_words, ref_words))
            word_ref_lens.append(len(ref_words))

        total_ref_len = sum(word_ref_lens)
        # Avoid ZeroDivisionError on an empty batch or all-empty references.
        if total_ref_len == 0:
            return None
        return float(sum(word_eds)) / total_ref_len