fix decoding

pull/578/head
Hui Zhang 4 years ago
parent 8c5b8e355b
commit b5339633e3

@@ -123,11 +123,7 @@ if not hasattr(paddle, 'cat'):
 ########### hcak paddle.Tensor #############
 def item(x: paddle.Tensor):
-    if x.dtype == paddle.fluid.core_avx.VarDesc.VarType.FP32:
-        return float(x)
-    else:
-        raise ValueError("not support")
+    return x.numpy().item()

 if not hasattr(paddle.Tensor, 'item'):
     logger.warn(
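
Note: the old item() shim handled only FP32 tensors and raised for every other dtype; routing through numpy() yields a Python scalar for any dtype. A minimal, self-contained sketch of the monkey-patch pattern used here, assuming a paddle 2.x build:

    import paddle

    def item(x: paddle.Tensor):
        # numpy().item() turns a single-element tensor of any dtype
        # into the corresponding Python scalar
        return x.numpy().item()

    # no-op on builds where Tensor already provides .item()
    if not hasattr(paddle.Tensor, 'item'):
        paddle.Tensor.item = item

    print(paddle.to_tensor(3.5).item())               # 3.5
    print(paddle.to_tensor(2, dtype='int64').item())  # 2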

@@ -381,8 +381,8 @@ class U2Tester(U2Trainer):
             decoding_chunk_size=cfg.decoding_chunk_size,
             num_decoding_left_chunks=cfg.num_decoding_left_chunks,
             simulate_streaming=cfg.simulate_streaming)
-        decode_time = time.time()
+        decode_time = time.time() - start_time

         for target, result in zip(target_transcripts, result_transcripts):
            errors, len_ref = errors_func(target, result)
            errors_sum += errors
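
Note: with start_time presumably captured just before the decode call, decode_time now records elapsed decoding seconds instead of an absolute timestamp, which is what the RTF computation further down expects. The pattern, with the decode call stubbed:

    import time

    start_time = time.time()
    # ... run CTC / attention decoding on the batch ...
    decode_time = time.time() - start_time  # elapsed seconds, not a wall-clock timestamp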
@@ -392,13 +392,13 @@ class U2Tester(U2Trainer):
                 fout.write(result + "\n")
                 logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
                             (target, result))
-                logger.info("Current error rate [%s] = %f" %
+                logger.info("One example error rate [%s] = %f" %
                             (cfg.error_rate_type, error_rate_func(target, result)))
         return dict(
             errors_sum=errors_sum,
             len_refs=len_refs,
             num_ins=num_ins,  # num examples
             error_rate=errors_sum / len_refs,
             error_rate_type=cfg.error_rate_type,
             num_frames=audio_len.sum().numpy().item(),
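
Note: the log label changes from "Current error rate" to "One example error rate" because the value printed inside the loop is the rate of a single utterance pair. The corpus-level rate is formed from the accumulated counts in the returned dict, and the two are not interchangeable, as a quick check shows:

    # corpus-level vs. averaged per-example error rate (worked example)
    errors = [1, 5]   # edit distances for two utterances
    lens = [3, 7]     # reference lengths for the same utterances
    corpus = sum(errors) / sum(lens)                         # 6/10 = 0.60
    averaged = sum(e / l for e, l in zip(errors, lens)) / 2  # ~0.52, not the same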
@@ -411,6 +411,7 @@ class U2Tester(U2Trainer):
         self.model.eval()
         logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
+        stride_ms = self.test_loader.dataset.stride_ms
         error_rate_type = None
         errors_sum, len_refs, num_ins = 0.0, 0, 0
         num_frames = 0.0
@@ -424,11 +425,12 @@ class U2Tester(U2Trainer):
             len_refs += metrics['len_refs']
             num_ins += metrics['num_ins']
             error_rate_type = metrics['error_rate_type']
-            logger.info("Error rate [%s] (%d/?) = %f" %
-                        (error_rate_type, num_ins, errors_sum / len_refs))
+            rtf = num_time / (num_frames * stride_ms)
+            logger.info(
+                "RTF: %f, Error rate [%s] (%d/?) = %f" %
+                (rtf, error_rate_type, num_ins, errors_sum / len_refs))

-        rtf = num_time / (num_frames * self.test_loader.dataset.stride_ms / 1000.0)
-        # logging
+        rtf = num_time / (num_frames * stride_ms)
         msg = "Test: "
         msg += "epoch: {}, ".format(self.epoch)
         msg += "step: {}, ".format(self.iteration)

@@ -108,7 +108,7 @@ class AudioFeaturizer(object):
     @property
     def stride_ms(self):
         return self._stride_ms
-
+
     @property
     def feature_size(self):
         """audio feature size"""

@@ -148,7 +148,7 @@ class SpeechFeaturizer(object):
             float: time(ms)/frame
         """
         return self._audio_featurizer.stride_ms
-
+
     @property
     def text_feature(self):
         """Return the text feature object.

@@ -63,7 +63,7 @@ class ManifestDataset(Dataset):
             specgram_type='linear',  # 'linear', 'mfcc', 'fbank'
             feat_dim=0,  # 'mfcc', 'fbank'
             delta_delta=False,  # 'mfcc', 'fbank'
-            dither=1.0, # feature dither
+            dither=1.0,  # feature dither
             target_sample_rate=16000,  # target sample rate
             use_dB_normalization=True,
             target_dB=-20,
@@ -188,8 +188,7 @@ class ManifestDataset(Dataset):
         super().__init__()
         self._stride_ms = stride_ms
         self._target_sample_rate = target_sample_rate
-
         self._normalizer = FeatureNormalizer(
             mean_std_filepath) if mean_std_filepath else None
         self._augmentation_pipeline = AugmentationPipeline(
@@ -251,7 +250,7 @@ class ManifestDataset(Dataset):
     @property
     def feature_size(self):
         return self._speech_featurizer.feature_size
-
+
     @property
     def stride_ms(self):
         return self._speech_featurizer.stride_ms

@@ -49,10 +49,10 @@ from deepspeech.utils.tensor_utils import pad_sequence
 from deepspeech.utils.tensor_utils import th_accuracy
 from deepspeech.utils.utility import log_add

-logger = Log(__name__).getlog()
-
 __all__ = ["U2Model", "U2InferModel"]

+logger = Log(__name__).getlog()
+

 class U2BaseModel(nn.Module):
     """CTC-Attention hybrid Encoder-Decoder model"""
@@ -398,14 +398,17 @@ class U2BaseModel(nn.Module):
         assert decoding_chunk_size != 0
         batch_size = speech.shape[0]
         # Let's assume B = batch_size
+        # encoder_out: (B, maxlen, encoder_dim)
+        # encoder_mask: (B, 1, Tmax)
         encoder_out, encoder_mask = self._forward_encoder(
             speech, speech_lengths, decoding_chunk_size,
-            num_decoding_left_chunks,
-            simulate_streaming)  # (B, maxlen, encoder_dim)
+            num_decoding_left_chunks, simulate_streaming)
         maxlen = encoder_out.size(1)
-        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
+        # (TODO Hui Zhang): bool no support reduce_sum
+        # encoder_out_lens = encoder_mask.squeeze(1).sum(1)
+        encoder_out_lens = encoder_mask.squeeze(1).astype(paddle.int).sum(1)
         ctc_probs = self.ctc.log_softmax(encoder_out)  # (B, maxlen, vocab_size)
-        topk_prob, topk_index = ctc_probs.topk(1, dim=2)  # (B, maxlen, 1)
+        topk_prob, topk_index = ctc_probs.topk(1, axis=2)  # (B, maxlen, 1)
         topk_index = topk_index.view(batch_size, maxlen)  # (B, maxlen)
         pad_mask = make_pad_mask(encoder_out_lens)  # (B, maxlen)
         topk_index = topk_index.masked_fill_(pad_mask, self.eos)  # (B, maxlen)
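
Note: the TODO records a Paddle limitation of the time, reduce_sum rejected bool tensors, so the padding mask is cast to an integer dtype before summing valid frames. A small self-contained sketch of the same length computation (string dtype used for portability):

    import paddle

    # (B=2, 1, T=4) mask: True marks a valid (non-padded) frame
    encoder_mask = paddle.to_tensor([[[True, True, True, False]],
                                     [[True, True, False, False]]])
    encoder_out_lens = encoder_mask.squeeze(1).astype('int64').sum(1)
    print(encoder_out_lens.numpy())  # [3 2], valid frames per utterance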
@@ -573,11 +576,11 @@ class U2BaseModel(nn.Module):
         hyps_lens = hyps_lens + 1  # Add <sos> at begining
         encoder_out = encoder_out.repeat(beam_size, 1, 1)
         encoder_mask = paddle.ones(
-            beam_size, 1, encoder_out.size(1), dtype=paddle.bool)
+            (beam_size, 1, encoder_out.size(1)), dtype=paddle.bool)
         decoder_out, _ = self.decoder(
             encoder_out, encoder_mask, hyps_pad,
             hyps_lens)  # (beam_size, max_hyps_len, vocab_size)
-        decoder_out = paddle.nn.functional.log_softmax(decoder_out, dim=-1)
+        decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
         decoder_out = decoder_out.numpy()
         # Only use decoder score for rescoring
         best_score = -float('inf')
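
Note: two torch-to-paddle API differences are fixed here: paddle.ones takes the shape as a single list/tuple rather than as varargs, and paddle's log_softmax names its dimension argument axis rather than dim. A quick sketch (shapes arbitrary):

    import paddle
    import paddle.nn.functional as F

    mask = paddle.ones((2, 1, 5), dtype=paddle.bool)  # shape passed as one tuple
    logits = paddle.randn((2, 4, 10))
    log_probs = F.log_softmax(logits, axis=-1)        # axis, not torch's dim
    print(mask.shape, log_probs.shape)                # [2, 1, 5] [2, 4, 10]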

@@ -66,7 +66,9 @@ def pad_sequence(sequences: List[paddle.Tensor],
     # assuming trailing dimensions and type of all the Tensors
     # in sequences are same and fetching those from sequences[0]
     max_size = sequences[0].size()
-    trailing_dims = max_size[1:]
+    # (TODO Hui Zhang): slice not supprot `end==start`
+    # trailing_dims = max_size[1:]
+    trailing_dims = max_size[1:] if max_size.ndim >= 2 else ()
     max_len = max([s.size(0) for s in sequences])
     if batch_first:
         out_dims = (len(sequences), max_len) + trailing_dims
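
Note: the guard covers 1-D sequences, where max_size[1:] is an empty slice (end == start) that paddle slicing did not support at the time, per the TODO. A sketch of the resulting padding-shape computation for 1-D inputs, using plain .shape lists where the empty slice is safe:

    import paddle

    seqs = [paddle.to_tensor([1, 2, 3]), paddle.to_tensor([4, 5])]
    max_size = seqs[0].shape                                   # [3] for 1-D tensors
    trailing_dims = tuple(max_size[1:]) if len(max_size) >= 2 else ()
    max_len = max(s.shape[0] for s in seqs)
    out_dims = (len(seqs), max_len) + trailing_dims            # (2, 3), batch_first
    print(out_dims)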

@@ -1,71 +1,83 @@
 #! /usr/bin/env bash

+stage=-1
+stop_stage=100
+
+source ${MAIN_ROOT}/utils/parse_options.sh
+
 mkdir -p data
 TARGET_DIR=${MAIN_ROOT}/examples/dataset
 mkdir -p ${TARGET_DIR}

-# download data, generate manifests
-python3 ${TARGET_DIR}/aishell/aishell.py \
---manifest_prefix="data/manifest" \
---target_dir="${TARGET_DIR}/aishell"
-
-if [ $? -ne 0 ]; then
-    echo "Prepare Aishell failed. Terminated."
-    exit 1
-fi
-
-for dataset in train dev test; do
-    mv data/manifest.${dataset} data/manifest.${dataset}.raw
-done
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+    # download data, generate manifests
+    python3 ${TARGET_DIR}/aishell/aishell.py \
+    --manifest_prefix="data/manifest" \
+    --target_dir="${TARGET_DIR}/aishell"
+
+    if [ $? -ne 0 ]; then
+        echo "Prepare Aishell failed. Terminated."
+        exit 1
+    fi
+
+    for dataset in train dev test; do
+        mv data/manifest.${dataset} data/manifest.${dataset}.raw
+    done
+fi

-# build vocabulary
-python3 ${MAIN_ROOT}/utils/build_vocab.py \
---unit_type="char" \
---count_threshold=0 \
---vocab_path="data/vocab.txt" \
---manifest_paths "data/manifest.train.raw"
-
-if [ $? -ne 0 ]; then
-    echo "Build vocabulary failed. Terminated."
-    exit 1
-fi
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    # download data, generate manifests
+    # build vocabulary
+    python3 ${MAIN_ROOT}/utils/build_vocab.py \
+    --unit_type="char" \
+    --count_threshold=0 \
+    --vocab_path="data/vocab.txt" \
+    --manifest_paths "data/manifest.train.raw"
+
+    if [ $? -ne 0 ]; then
+        echo "Build vocabulary failed. Terminated."
+        exit 1
+    fi
+fi

-# compute mean and stddev for normalizer
-python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
---manifest_path="data/manifest.train.raw" \
---specgram_type="fbank" \
---feat_dim=80 \
---delta_delta=false \
---stride_ms=10.0 \
---window_ms=25.0 \
---sample_rate=16000 \
---num_samples=2000 \
---num_workers=0 \
---output_path="data/mean_std.json"
-
-if [ $? -ne 0 ]; then
-    echo "Compute mean and stddev failed. Terminated."
-    exit 1
-fi
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # compute mean and stddev for normalizer
+    python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
+    --manifest_path="data/manifest.train.raw" \
+    --specgram_type="fbank" \
+    --feat_dim=80 \
+    --delta_delta=false \
+    --stride_ms=10.0 \
+    --window_ms=25.0 \
+    --sample_rate=16000 \
+    --num_samples=-1 \
+    --num_workers=16 \
+    --output_path="data/mean_std.json"
+
+    if [ $? -ne 0 ]; then
+        echo "Compute mean and stddev failed. Terminated."
+        exit 1
+    fi
+fi

-# format manifest with tokenids, vocab size
-for dataset in train dev test; do
-    python3 ${MAIN_ROOT}/utils/format_data.py \
-    --feat_type "raw" \
-    --cmvn_path "data/mean_std.npz" \
-    --unit_type "char" \
-    --vocab_path="data/vocab.txt" \
-    --manifest_path="data/manifest.${dataset}.raw" \
-    --output_path="data/manifest.${dataset}"
-done
-
-if [ $? -ne 0 ]; then
-    echo "Formt mnaifest failed. Terminated."
-    exit 1
-fi
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    # format manifest with tokenids, vocab size
+    for dataset in train dev test; do
+        python3 ${MAIN_ROOT}/utils/format_data.py \
+        --feat_type "raw" \
+        --cmvn_path "data/mean_std.json" \
+        --unit_type "char" \
+        --vocab_path="data/vocab.txt" \
+        --manifest_path="data/manifest.${dataset}.raw" \
+        --output_path="data/manifest.${dataset}"
+    done
+
+    if [ $? -ne 0 ]; then
+        echo "Formt mnaifest failed. Terminated."
+        exit 1
+    fi
+fi

 echo "Aishell data preparation done."

@@ -1,6 +1,6 @@
 #! /usr/bin/env bash

-ngpu=$(echo ${CUDA_VISIBLE_DEVICES} | python -c 'import sys; a = sys.stdin.read(); print(len(a.split(",")));')
+ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
 echo "using $ngpu gpus..."

 python3 -u ${BIN_DIR}/train.py \
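
Note: both versions count the comma-separated device ids in CUDA_VISIBLE_DEVICES (e.g. "0,1,2" gives 3); the awk field count simply avoids spawning a Python interpreter for it. One behavioral edge to be aware of: for an empty or unset variable, awk's NF yields 0, while the old len("".split(",")) form yielded 1.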

@@ -1,74 +1,89 @@
 #! /usr/bin/env bash

-mkdir -p data
-TARGET_DIR=${MAIN_ROOT}/examples/dataset
-mkdir -p ${TARGET_DIR}
-
-# download data, generate manifests
-python3 ${TARGET_DIR}/librispeech/librispeech.py \
---manifest_prefix="data/manifest" \
---target_dir="${TARGET_DIR}/librispeech" \
---full_download="False"
-
-if [ $? -ne 0 ]; then
-    echo "Prepare LibriSpeech failed. Terminated."
-    exit 1
-fi
-
-head -n 64 data/manifest.dev-clean > data/manifest.tiny.raw
+stage=-1
+stop_stage=100

 # bpemode (unigram or bpe)
 nbpe=200
 bpemode=unigram
 bpeprefix="data/bpe_${bpemode}_${nbpe}"

-# build vocabulary
-python3 ${MAIN_ROOT}/utils/build_vocab.py \
---unit_type "spm" \
---spm_vocab_size=${nbpe} \
---spm_mode ${bpemode} \
---spm_model_prefix ${bpeprefix} \
---vocab_path="data/vocab.txt" \
---manifest_paths="data/manifest.tiny.raw"
+source ${MAIN_ROOT}/utils/parse_options.sh

-if [ $? -ne 0 ]; then
-    echo "Build vocabulary failed. Terminated."
-    exit 1
-fi
+mkdir -p data
+TARGET_DIR=${MAIN_ROOT}/examples/dataset
+mkdir -p ${TARGET_DIR}

-# compute mean and stddev for normalizer
-python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
---manifest_path="data/manifest.tiny.raw" \
---num_samples=64 \
---specgram_type="fbank" \
---feat_dim=80 \
---delta_delta=false \
---sample_rate=16000 \
---stride_ms=10.0 \
---window_ms=25.0 \
---num_workers=0 \
---output_path="data/mean_std.json"
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+    # download data, generate manifests
+    python3 ${TARGET_DIR}/librispeech/librispeech.py \
+    --manifest_prefix="data/manifest" \
+    --target_dir="${TARGET_DIR}/librispeech" \
+    --full_download="False"
+
+    if [ $? -ne 0 ]; then
+        echo "Prepare LibriSpeech failed. Terminated."
+        exit 1
+    fi
+
+    head -n 64 data/manifest.dev-clean > data/manifest.tiny.raw
+fi

-if [ $? -ne 0 ]; then
-    echo "Compute mean and stddev failed. Terminated."
-    exit 1
-fi
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    # build vocabulary
+    python3 ${MAIN_ROOT}/utils/build_vocab.py \
+    --unit_type "spm" \
+    --spm_vocab_size=${nbpe} \
+    --spm_mode ${bpemode} \
+    --spm_model_prefix ${bpeprefix} \
+    --vocab_path="data/vocab.txt" \
+    --manifest_paths="data/manifest.tiny.raw"
+
+    if [ $? -ne 0 ]; then
+        echo "Build vocabulary failed. Terminated."
+        exit 1
+    fi
+fi

-# format manifest with tokenids, vocab size
-python3 ${MAIN_ROOT}/utils/format_data.py \
---feat_type "raw" \
---cmvn_path "data/mean_std.npz" \
---unit_type "spm" \
---spm_model_prefix ${bpeprefix} \
---vocab_path="data/vocab.txt" \
---manifest_path="data/manifest.tiny.raw" \
---output_path="data/manifest.tiny"
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # compute mean and stddev for normalizer
+    python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
+    --manifest_path="data/manifest.tiny.raw" \
+    --num_samples=64 \
+    --specgram_type="fbank" \
+    --feat_dim=80 \
+    --delta_delta=false \
+    --sample_rate=16000 \
+    --stride_ms=10.0 \
+    --window_ms=25.0 \
+    --num_workers=2 \
+    --output_path="data/mean_std.json"
+
+    if [ $? -ne 0 ]; then
+        echo "Compute mean and stddev failed. Terminated."
+        exit 1
+    fi
+fi

-if [ $? -ne 0 ]; then
-    echo "Formt mnaifest failed. Terminated."
-    exit 1
-fi
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    # format manifest with tokenids, vocab size
+    python3 ${MAIN_ROOT}/utils/format_data.py \
+    --feat_type "raw" \
+    --cmvn_path "data/mean_std.json" \
+    --unit_type "spm" \
+    --spm_model_prefix ${bpeprefix} \
+    --vocab_path="data/vocab.txt" \
+    --manifest_path="data/manifest.tiny.raw" \
+    --output_path="data/manifest.tiny"
+
+    if [ $? -ne 0 ]; then
+        echo "Formt mnaifest failed. Terminated."
+        exit 1
+    fi
+fi

 echo "LibriSpeech Data preparation done."

@@ -1,12 +1,11 @@
 #! /usr/bin/env bash

 # download language model
-bash local/download_lm_en.sh
-if [ $? -ne 0 ]; then
-    exit 1
-fi
+#bash local/download_lm_en.sh
+#if [ $? -ne 0 ]; then
+#    exit 1
+#fi

-CUDA_VISIBLE_DEVICES=0 \
 python3 -u ${BIN_DIR}/test.py \
 --device 'gpu' \
 --nproc 1 \
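
Note: the English LM download is commented out rather than deleted, presumably because the U2 test path scores hypotheses with the model's own CTC and attention decoder and does not consume the external n-gram LM that the DeepSpeech2 recipes fetch. Dropping the hard-coded CUDA_VISIBLE_DEVICES=0 prefix leaves GPU selection to the caller's environment.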

@@ -1,6 +1,6 @@
 #! /usr/bin/env bash

-ngpu=$(echo ${CUDA_VISIBLE_DEVICES} | python -c 'import sys; a = sys.stdin.read(); print(len(a.split(",")));')
+ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
 echo "using $ngpu gpus..."

 python3 -u ${BIN_DIR}/train.py \
