diff --git a/deepspeech/__init__.py b/deepspeech/__init__.py index 1c664aa3f..1c11de385 100644 --- a/deepspeech/__init__.py +++ b/deepspeech/__init__.py @@ -123,11 +123,7 @@ if not hasattr(paddle, 'cat'): ########### hcak paddle.Tensor ############# def item(x: paddle.Tensor): - if x.dtype == paddle.fluid.core_avx.VarDesc.VarType.FP32: - return float(x) - else: - raise ValueError("not support") - + return x.numpy().item() if not hasattr(paddle.Tensor, 'item'): logger.warn( diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index f124d3298..2fee8ded6 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -381,8 +381,8 @@ class U2Tester(U2Trainer): decoding_chunk_size=cfg.decoding_chunk_size, num_decoding_left_chunks=cfg.num_decoding_left_chunks, simulate_streaming=cfg.simulate_streaming) - decode_time = time.time() - + decode_time = time.time() - start_time + for target, result in zip(target_transcripts, result_transcripts): errors, len_ref = errors_func(target, result) errors_sum += errors @@ -392,13 +392,13 @@ class U2Tester(U2Trainer): fout.write(result + "\n") logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" % (target, result)) - logger.info("Current error rate [%s] = %f" % + logger.info("One example error rate [%s] = %f" % (cfg.error_rate_type, error_rate_func(target, result))) return dict( errors_sum=errors_sum, len_refs=len_refs, - num_ins=num_ins, # num examples + num_ins=num_ins, # num examples error_rate=errors_sum / len_refs, error_rate_type=cfg.error_rate_type, num_frames=audio_len.sum().numpy().item(), @@ -411,6 +411,7 @@ class U2Tester(U2Trainer): self.model.eval() logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") + stride_ms = self.test_loader.dataset.stride_ms error_rate_type = None errors_sum, len_refs, num_ins = 0.0, 0, 0 num_frames = 0.0 @@ -424,11 +425,12 @@ class U2Tester(U2Trainer): len_refs += metrics['len_refs'] num_ins += metrics['num_ins'] error_rate_type = 
metrics['error_rate_type'] - logger.info("Error rate [%s] (%d/?) = %f" % - (error_rate_type, num_ins, errors_sum / len_refs)) + rtf = num_time / (num_frames * stride_ms) + logger.info( + "RTF: %f, Error rate [%s] (%d/?) = %f" % + (rtf, error_rate_type, num_ins, errors_sum / len_refs)) - rtf = num_time / (num_frames * self.test_loader.dataset.stride_ms / 1000.0) - # logging + rtf = num_time / (num_frames * stride_ms) msg = "Test: " msg += "epoch: {}, ".format(self.epoch) msg += "step: {}, ".format(self.iteration) diff --git a/deepspeech/frontend/featurizer/audio_featurizer.py b/deepspeech/frontend/featurizer/audio_featurizer.py index 5a1ffe79b..11c1fa2d4 100644 --- a/deepspeech/frontend/featurizer/audio_featurizer.py +++ b/deepspeech/frontend/featurizer/audio_featurizer.py @@ -108,7 +108,7 @@ class AudioFeaturizer(object): @property def stride_ms(self): return self._stride_ms - + @property def feature_size(self): """audio feature size""" diff --git a/deepspeech/frontend/featurizer/speech_featurizer.py b/deepspeech/frontend/featurizer/speech_featurizer.py index 19cb3be23..e6761cb52 100644 --- a/deepspeech/frontend/featurizer/speech_featurizer.py +++ b/deepspeech/frontend/featurizer/speech_featurizer.py @@ -148,7 +148,7 @@ class SpeechFeaturizer(object): float: time(ms)/frame """ return self._audio_featurizer.stride_ms - + @property def text_feature(self): """Return the text feature object. 
diff --git a/deepspeech/io/dataset.py b/deepspeech/io/dataset.py index 05fe408b9..da8b3f502 100644 --- a/deepspeech/io/dataset.py +++ b/deepspeech/io/dataset.py @@ -63,7 +63,7 @@ class ManifestDataset(Dataset): specgram_type='linear', # 'linear', 'mfcc', 'fbank' feat_dim=0, # 'mfcc', 'fbank' delta_delta=False, # 'mfcc', 'fbank' - dither=1.0, # feature dither + dither=1.0, # feature dither target_sample_rate=16000, # target sample rate use_dB_normalization=True, target_dB=-20, @@ -188,8 +188,7 @@ class ManifestDataset(Dataset): super().__init__() self._stride_ms = stride_ms self._target_sample_rate = target_sample_rate - - + self._normalizer = FeatureNormalizer( mean_std_filepath) if mean_std_filepath else None self._augmentation_pipeline = AugmentationPipeline( @@ -251,7 +250,7 @@ class ManifestDataset(Dataset): @property def feature_size(self): return self._speech_featurizer.feature_size - + @property def stride_ms(self): return self._speech_featurizer.stride_ms diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2.py index f34aac771..6b93d089d 100644 --- a/deepspeech/models/u2.py +++ b/deepspeech/models/u2.py @@ -49,10 +49,10 @@ from deepspeech.utils.tensor_utils import pad_sequence from deepspeech.utils.tensor_utils import th_accuracy from deepspeech.utils.utility import log_add -logger = Log(__name__).getlog() - __all__ = ["U2Model", "U2InferModel"] +logger = Log(__name__).getlog() + class U2BaseModel(nn.Module): """CTC-Attention hybrid Encoder-Decoder model""" @@ -398,14 +398,17 @@ class U2BaseModel(nn.Module): assert decoding_chunk_size != 0 batch_size = speech.shape[0] # Let's assume B = batch_size + # encoder_out: (B, maxlen, encoder_dim) + # encoder_mask: (B, 1, Tmax) encoder_out, encoder_mask = self._forward_encoder( speech, speech_lengths, decoding_chunk_size, - num_decoding_left_chunks, - simulate_streaming) # (B, maxlen, encoder_dim) + num_decoding_left_chunks, simulate_streaming) maxlen = encoder_out.size(1) - encoder_out_lens = 
encoder_mask.squeeze(1).sum(1) + # (TODO Hui Zhang): bool does not support reduce_sum + # encoder_out_lens = encoder_mask.squeeze(1).sum(1) + encoder_out_lens = encoder_mask.squeeze(1).astype(paddle.int).sum(1) ctc_probs = self.ctc.log_softmax(encoder_out) # (B, maxlen, vocab_size) - topk_prob, topk_index = ctc_probs.topk(1, dim=2) # (B, maxlen, 1) + topk_prob, topk_index = ctc_probs.topk(1, axis=2) # (B, maxlen, 1) topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen) pad_mask = make_pad_mask(encoder_out_lens) # (B, maxlen) topk_index = topk_index.masked_fill_(pad_mask, self.eos) # (B, maxlen) @@ -573,11 +576,11 @@ class U2BaseModel(nn.Module): hyps_lens = hyps_lens + 1 # Add at begining encoder_out = encoder_out.repeat(beam_size, 1, 1) encoder_mask = paddle.ones( - beam_size, 1, encoder_out.size(1), dtype=paddle.bool) + (beam_size, 1, encoder_out.size(1)), dtype=paddle.bool) decoder_out, _ = self.decoder( encoder_out, encoder_mask, hyps_pad, hyps_lens) # (beam_size, max_hyps_len, vocab_size) - decoder_out = paddle.nn.functional.log_softmax(decoder_out, dim=-1) + decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1) decoder_out = decoder_out.numpy() # Only use decoder score for rescoring best_score = -float('inf') diff --git a/deepspeech/utils/tensor_utils.py b/deepspeech/utils/tensor_utils.py index 227306447..7679d9e1c 100644 --- a/deepspeech/utils/tensor_utils.py +++ b/deepspeech/utils/tensor_utils.py @@ -66,7 +66,9 @@ def pad_sequence(sequences: List[paddle.Tensor], # assuming trailing dimensions and type of all the Tensors # in sequences are same and fetching those from sequences[0] max_size = sequences[0].size() - trailing_dims = max_size[1:] + # (TODO Hui Zhang): slice does not support `end==start` + # trailing_dims = max_size[1:] + trailing_dims = max_size[1:] if max_size.ndim >= 2 else () max_len = max([s.size(0) for s in sequences]) if batch_first: out_dims = (len(sequences), max_len) + trailing_dims diff --git 
a/examples/aishell/s0/local/data.sh b/examples/aishell/s0/local/data.sh index d29fb8bf3..b814fc754 100644 --- a/examples/aishell/s0/local/data.sh +++ b/examples/aishell/s0/local/data.sh @@ -1,71 +1,83 @@ #! /usr/bin/env bash +stage=-1 +stop_stage=100 + +source ${MAIN_ROOT}/utils/parse_options.sh + mkdir -p data TARGET_DIR=${MAIN_ROOT}/examples/dataset mkdir -p ${TARGET_DIR} -# download data, generate manifests -python3 ${TARGET_DIR}/aishell/aishell.py \ ---manifest_prefix="data/manifest" \ ---target_dir="${TARGET_DIR}/aishell" +if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then + # download data, generate manifests + python3 ${TARGET_DIR}/aishell/aishell.py \ + --manifest_prefix="data/manifest" \ + --target_dir="${TARGET_DIR}/aishell" + + if [ $? -ne 0 ]; then + echo "Prepare Aishell failed. Terminated." + exit 1 + fi -if [ $? -ne 0 ]; then - echo "Prepare Aishell failed. Terminated." - exit 1 + for dataset in train dev test; do + mv data/manifest.${dataset} data/manifest.${dataset}.raw + done fi - -for dataset in train dev test; do - mv data/manifest.${dataset} data/manifest.${dataset}.raw -done - - -# build vocabulary -python3 ${MAIN_ROOT}/utils/build_vocab.py \ ---unit_type="char" \ ---count_threshold=0 \ ---vocab_path="data/vocab.txt" \ ---manifest_paths "data/manifest.train.raw" - -if [ $? -ne 0 ]; then - echo "Build vocabulary failed. Terminated." - exit 1 +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # download data, generate manifests + # build vocabulary + python3 ${MAIN_ROOT}/utils/build_vocab.py \ + --unit_type="char" \ + --count_threshold=0 \ + --vocab_path="data/vocab.txt" \ + --manifest_paths "data/manifest.train.raw" + + if [ $? -ne 0 ]; then + echo "Build vocabulary failed. Terminated." 
+ exit 1 + fi fi -# compute mean and stddev for normalizer -python3 ${MAIN_ROOT}/utils/compute_mean_std.py \ ---manifest_path="data/manifest.train.raw" \ ---specgram_type="fbank" \ ---feat_dim=80 \ ---delta_delta=false \ ---stride_ms=10.0 \ ---window_ms=25.0 \ ---sample_rate=16000 \ ---num_samples=2000 \ ---num_workers=0 \ ---output_path="data/mean_std.json" - -if [ $? -ne 0 ]; then - echo "Compute mean and stddev failed. Terminated." - exit 1 +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # compute mean and stddev for normalizer + python3 ${MAIN_ROOT}/utils/compute_mean_std.py \ + --manifest_path="data/manifest.train.raw" \ + --specgram_type="fbank" \ + --feat_dim=80 \ + --delta_delta=false \ + --stride_ms=10.0 \ + --window_ms=25.0 \ + --sample_rate=16000 \ + --num_samples=-1 \ + --num_workers=16 \ + --output_path="data/mean_std.json" + + if [ $? -ne 0 ]; then + echo "Compute mean and stddev failed. Terminated." + exit 1 + fi fi -# format manifest with tokenids, vocab size -for dataset in train dev test; do - python3 ${MAIN_ROOT}/utils/format_data.py \ - --feat_type "raw" \ - --cmvn_path "data/mean_std.npz" \ - --unit_type "char" \ - --vocab_path="data/vocab.txt" \ - --manifest_path="data/manifest.${dataset}.raw" \ - --output_path="data/manifest.${dataset}" -done - -if [ $? -ne 0 ]; then - echo "Formt mnaifest failed. Terminated." - exit 1 +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # format manifest with tokenids, vocab size + for dataset in train dev test; do + python3 ${MAIN_ROOT}/utils/format_data.py \ + --feat_type "raw" \ + --cmvn_path "data/mean_std.json" \ + --unit_type "char" \ + --vocab_path="data/vocab.txt" \ + --manifest_path="data/manifest.${dataset}.raw" \ + --output_path="data/manifest.${dataset}" + done + + if [ $? -ne 0 ]; then + echo "Formt mnaifest failed. Terminated." + exit 1 + fi fi echo "Aishell data preparation done." 
diff --git a/examples/aishell/s1/local/train.sh b/examples/aishell/s1/local/train.sh index d20395d0f..0f032fa74 100644 --- a/examples/aishell/s1/local/train.sh +++ b/examples/aishell/s1/local/train.sh @@ -1,6 +1,6 @@ #! /usr/bin/env bash -ngpu=$(echo ${CUDA_VISIBLE_DEVICES} | python -c 'import sys; a = sys.stdin.read(); print(len(a.split(",")));') +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') echo "using $ngpu gpus..." python3 -u ${BIN_DIR}/train.py \ diff --git a/examples/tiny/s0/local/data.sh b/examples/tiny/s0/local/data.sh index 234e87d28..bf3f1b85c 100644 --- a/examples/tiny/s0/local/data.sh +++ b/examples/tiny/s0/local/data.sh @@ -1,74 +1,89 @@ #! /usr/bin/env bash -mkdir -p data -TARGET_DIR=${MAIN_ROOT}/examples/dataset -mkdir -p ${TARGET_DIR} - -# download data, generate manifests -python3 ${TARGET_DIR}/librispeech/librispeech.py \ ---manifest_prefix="data/manifest" \ ---target_dir="${TARGET_DIR}/librispeech" \ ---full_download="False" - -if [ $? -ne 0 ]; then - echo "Prepare LibriSpeech failed. Terminated." - exit 1 -fi - -head -n 64 data/manifest.dev-clean > data/manifest.tiny.raw +stage=-1 +stop_stage=100 # bpemode (unigram or bpe) nbpe=200 bpemode=unigram bpeprefix="data/bpe_${bpemode}_${nbpe}" -# build vocabulary -python3 ${MAIN_ROOT}/utils/build_vocab.py \ ---unit_type "spm" \ ---spm_vocab_size=${nbpe} \ ---spm_mode ${bpemode} \ ---spm_model_prefix ${bpeprefix} \ ---vocab_path="data/vocab.txt" \ ---manifest_paths="data/manifest.tiny.raw" -if [ $? -ne 0 ]; then - echo "Build vocabulary failed. Terminated." 
- exit 1 -fi +source ${MAIN_ROOT}/utils/parse_options.sh -# compute mean and stddev for normalizer -python3 ${MAIN_ROOT}/utils/compute_mean_std.py \ ---manifest_path="data/manifest.tiny.raw" \ ---num_samples=64 \ ---specgram_type="fbank" \ ---feat_dim=80 \ ---delta_delta=false \ ---sample_rate=16000 \ ---stride_ms=10.0 \ ---window_ms=25.0 \ ---num_workers=0 \ ---output_path="data/mean_std.json" +mkdir -p data +TARGET_DIR=${MAIN_ROOT}/examples/dataset +mkdir -p ${TARGET_DIR} -if [ $? -ne 0 ]; then - echo "Compute mean and stddev failed. Terminated." - exit 1 +if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then + # download data, generate manifests + python3 ${TARGET_DIR}/librispeech/librispeech.py \ + --manifest_prefix="data/manifest" \ + --target_dir="${TARGET_DIR}/librispeech" \ + --full_download="False" + + if [ $? -ne 0 ]; then + echo "Prepare LibriSpeech failed. Terminated." + exit 1 + fi + + head -n 64 data/manifest.dev-clean > data/manifest.tiny.raw +fi + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # build vocabulary + python3 ${MAIN_ROOT}/utils/build_vocab.py \ + --unit_type "spm" \ + --spm_vocab_size=${nbpe} \ + --spm_mode ${bpemode} \ + --spm_model_prefix ${bpeprefix} \ + --vocab_path="data/vocab.txt" \ + --manifest_paths="data/manifest.tiny.raw" + + if [ $? -ne 0 ]; then + echo "Build vocabulary failed. Terminated." 
+ exit 1 + fi fi -# format manifest with tokenids, vocab size -python3 ${MAIN_ROOT}/utils/format_data.py \ ---feat_type "raw" \ ---cmvn_path "data/mean_std.npz" \ ---unit_type "spm" \ ---spm_model_prefix ${bpeprefix} \ ---vocab_path="data/vocab.txt" \ ---manifest_path="data/manifest.tiny.raw" \ ---output_path="data/manifest.tiny" +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # compute mean and stddev for normalizer + python3 ${MAIN_ROOT}/utils/compute_mean_std.py \ + --manifest_path="data/manifest.tiny.raw" \ + --num_samples=64 \ + --specgram_type="fbank" \ + --feat_dim=80 \ + --delta_delta=false \ + --sample_rate=16000 \ + --stride_ms=10.0 \ + --window_ms=25.0 \ + --num_workers=2 \ + --output_path="data/mean_std.json" + + if [ $? -ne 0 ]; then + echo "Compute mean and stddev failed. Terminated." + exit 1 + fi +fi -if [ $? -ne 0 ]; then - echo "Formt mnaifest failed. Terminated." - exit 1 +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # format manifest with tokenids, vocab size + python3 ${MAIN_ROOT}/utils/format_data.py \ + --feat_type "raw" \ + --cmvn_path "data/mean_std.json" \ + --unit_type "spm" \ + --spm_model_prefix ${bpeprefix} \ + --vocab_path="data/vocab.txt" \ + --manifest_path="data/manifest.tiny.raw" \ + --output_path="data/manifest.tiny" + + + if [ $? -ne 0 ]; then + echo "Formt mnaifest failed. Terminated." + exit 1 + fi fi echo "LibriSpeech Data preparation done." diff --git a/examples/tiny/s1/local/test.sh b/examples/tiny/s1/local/test.sh index 475e941e7..e7ecc9b40 100644 --- a/examples/tiny/s1/local/test.sh +++ b/examples/tiny/s1/local/test.sh @@ -1,12 +1,11 @@ #! /usr/bin/env bash # download language model -bash local/download_lm_en.sh -if [ $? -ne 0 ]; then - exit 1 -fi +#bash local/download_lm_en.sh +#if [ $? 
-ne 0 ]; then +# exit 1 +#fi -CUDA_VISIBLE_DEVICES=0 \ python3 -u ${BIN_DIR}/test.py \ --device 'gpu' \ --nproc 1 \ diff --git a/examples/tiny/s1/local/train.sh b/examples/tiny/s1/local/train.sh index a0598e17a..2f73e9be9 100644 --- a/examples/tiny/s1/local/train.sh +++ b/examples/tiny/s1/local/train.sh @@ -1,6 +1,6 @@ #! /usr/bin/env bash -ngpu=$(echo ${CUDA_VISIBLE_DEVICES} | python -c 'import sys; a = sys.stdin.read(); print(len(a.split(",")));') +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') echo "using $ngpu gpus..." python3 -u ${BIN_DIR}/train.py \