diff --git a/examples/aishell/asr0/conf/deepspeech2.yaml b/examples/aishell/asr0/conf/deepspeech2.yaml index fb6998647..913354f5d 100644 --- a/examples/aishell/asr0/conf/deepspeech2.yaml +++ b/examples/aishell/asr0/conf/deepspeech2.yaml @@ -15,50 +15,53 @@ max_output_input_ratio: .inf ########################################### # Dataloader # ########################################### -batch_size: 64 # one gpu -mean_std_filepath: data/mean_std.json -unit_type: char vocab_filepath: data/lang_char/vocab.txt -augmentation_config: conf/augmentation.json -random_seed: 0 -spm_model_prefix: -spectrum_type: linear +spm_model_prefix: '' +unit_type: 'char' +preprocess_config: conf/preprocess.yaml feat_dim: 161 -delta_delta: False stride_ms: 10.0 -window_ms: 20.0 -n_fft: None -max_freq: None -target_sample_rate: 16000 -use_dB_normalization: True -target_dB: -20 -dither: 1.0 -keep_transcription_text: False -sortagrad: True -shuffle_method: batch_shuffle -num_workers: 2 +window_ms: 25.0 +sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs +batch_size: 64 +maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced +maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced +minibatches: 0 # for debug +batch_count: auto +batch_bins: 0 +batch_frames_in: 0 +batch_frames_out: 0 +batch_frames_inout: 0 +num_workers: 8 +subsampling_factor: 1 +num_encs: 1 ############################################ # Network Architecture # ############################################ num_conv_layers: 2 -num_rnn_layers: 3 +num_rnn_layers: 5 rnn_layer_size: 1024 -use_gru: True -share_rnn_weights: False +rnn_direction: bidirect # [forward, bidirect] +num_fc_layers: 0 +fc_layers_size_list: -1, +use_gru: False blank_id: 0 -ctc_grad_norm_type: instance - + + ########################################### # Training # ########################################### -n_epoch: 80 +n_epoch: 50 accum_grad: 1 -lr: 2.0e-3 -lr_decay: 0.83 +lr: 5.0e-4 +lr_decay: 0.93 weight_decay: 1.0e-6 global_grad_clip: 3.0 -log_interval: 100 +dist_sampler: False +log_interval: 1 checkpoint: kbest_n: 50 latest_n: 5 + + diff --git a/examples/aishell/asr0/conf/deepspeech2_online.yaml b/examples/aishell/asr0/conf/deepspeech2_online.yaml index ef01ac595..a53e19f37 100644 --- a/examples/aishell/asr0/conf/deepspeech2_online.yaml +++ b/examples/aishell/asr0/conf/deepspeech2_online.yaml @@ -15,28 +15,26 @@ max_output_input_ratio: .inf ########################################### # Dataloader # ########################################### -batch_size: 64 # one gpu -mean_std_filepath: data/mean_std.json -unit_type: char vocab_filepath: data/lang_char/vocab.txt -augmentation_config: conf/augmentation.json -random_seed: 0 -spm_model_prefix: -spectrum_type: linear #linear, mfcc, fbank +spm_model_prefix: '' +unit_type: 'char' +preprocess_config: conf/preprocess.yaml feat_dim: 161 -delta_delta: False stride_ms: 10.0 -window_ms: 20.0 -n_fft: None -max_freq: None -target_sample_rate: 16000 -use_dB_normalization: True -target_dB: -20 -dither: 1.0 -keep_transcription_text: False -sortagrad: True -shuffle_method: batch_shuffle -num_workers: 0 +window_ms: 25.0 +sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs +batch_size: 64 +maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced +maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically 
reduced +minibatches: 0 # for debug +batch_count: auto +batch_bins: 0 +batch_frames_in: 0 +batch_frames_out: 0 +batch_frames_inout: 0 +num_workers: 8 +subsampling_factor: 1 +num_encs: 1 ############################################ # Network Architecture # @@ -54,12 +52,13 @@ blank_id: 0 ########################################### # Training # ########################################### -n_epoch: 65 +n_epoch: 30 accum_grad: 1 lr: 5.0e-4 lr_decay: 0.93 weight_decay: 1.0e-6 global_grad_clip: 3.0 +dist_sampler: False log_interval: 100 checkpoint: kbest_n: 50 diff --git a/examples/aishell/asr0/conf/tuning/decode.yaml b/examples/aishell/asr0/conf/tuning/decode.yaml index 5778e6565..7dbc6fa82 100644 --- a/examples/aishell/asr0/conf/tuning/decode.yaml +++ b/examples/aishell/asr0/conf/tuning/decode.yaml @@ -2,9 +2,9 @@ decode_batch_size: 128 error_rate_type: cer decoding_method: ctc_beam_search lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm -alpha: 1.9 -beta: 5.0 -beam_size: 300 +alpha: 2.2 +beta: 4.3 +beam_size: 500 cutoff_prob: 0.99 cutoff_top_n: 40 num_proc_bsearch: 10 diff --git a/examples/aishell/asr0/local/data.sh b/examples/aishell/asr0/local/data.sh index ec692eba6..8722c1ca3 100755 --- a/examples/aishell/asr0/local/data.sh +++ b/examples/aishell/asr0/local/data.sh @@ -33,12 +33,13 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then num_workers=$(nproc) python3 ${MAIN_ROOT}/utils/compute_mean_std.py \ --manifest_path="data/manifest.train.raw" \ - --spectrum_type="linear" \ + --spectrum_type="fbank" \ + --feat_dim=161 \ --delta_delta=false \ --stride_ms=10 \ - --window_ms=20 \ + --window_ms=25 \ --sample_rate=16000 \ - --use_dB_normalization=True \ + --use_dB_normalization=False \ --num_samples=2000 \ --num_workers=${num_workers} \ --output_path="data/mean_std.json" diff --git a/examples/aishell/asr0/run.sh b/examples/aishell/asr0/run.sh index 114af5a97..0358b821d 100755 --- a/examples/aishell/asr0/run.sh +++ b/examples/aishell/asr0/run.sh @@ -7,8 +7,7 @@ stage=0 stop_stage=100 conf_path=conf/deepspeech2.yaml #conf/deepspeech2.yaml or conf/deepspeech2_online.yaml decode_conf_path=conf/tuning/decode.yaml -avg_num=1 -model_type=offline # offline or online +avg_num=10 audio_file=data/demo_01_03.wav source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; diff --git a/examples/librispeech/asr0/conf/deepspeech2.yaml b/examples/librispeech/asr0/conf/deepspeech2.yaml index 0307b9f39..cca695fe5 100644 --- a/examples/librispeech/asr0/conf/deepspeech2.yaml +++ b/examples/librispeech/asr0/conf/deepspeech2.yaml @@ -15,51 +15,51 @@ max_output_input_ratio: .inf ########################################### # Dataloader # ########################################### -batch_size: 20 -mean_std_filepath: data/mean_std.json -unit_type: char -vocab_filepath: data/lang_char/vocab.txt -augmentation_config: conf/augmentation.json -random_seed: 0 -spm_model_prefix: -spectrum_type: linear -feat_dim: -target_sample_rate: 16000 -max_freq: None -n_fft: None +vocab_filepath: data/lang_char/vocab.txt +spm_model_prefix: '' +unit_type: 'char' +preprocess_config: conf/preprocess.yaml +feat_dim: 161 stride_ms: 10.0 -window_ms: 20.0 -delta_delta: False -dither: 1.0 -use_dB_normalization: True -target_dB: -20 -random_seed: 0 -keep_transcription_text: False -sortagrad: True -shuffle_method: batch_shuffle -num_workers: 2 +window_ms: 25.0 +sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs +batch_size: 64 +maxlen_in: 512 # if input length > 
maxlen-in, batchsize is automatically reduced +maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced +minibatches: 0 # for debug +batch_count: auto +batch_bins: 0 +batch_frames_in: 0 +batch_frames_out: 0 +batch_frames_inout: 0 +num_workers: 8 +subsampling_factor: 1 +num_encs: 1 ############################################ # Network Architecture # ############################################ num_conv_layers: 2 -num_rnn_layers: 3 -rnn_layer_size: 2048 +num_rnn_layers: 5 +rnn_layer_size: 1024 +rnn_direction: bidirect +num_fc_layers: 0 +fc_layers_size_list: -1 use_gru: False -share_rnn_weights: True blank_id: 0 ########################################### # Training # ########################################### -n_epoch: 50 +n_epoch: 15 accum_grad: 1 -lr: 1.0e-3 -lr_decay: 0.83 +lr: 5.0e-4 +lr_decay: 0.93 weight_decay: 1.0e-6 global_grad_clip: 5.0 -log_interval: 100 +dist_sampler: False +log_interval: 1 checkpoint: kbest_n: 50 latest_n: 5 diff --git a/examples/librispeech/asr0/conf/deepspeech2_online.yaml b/examples/librispeech/asr0/conf/deepspeech2_online.yaml index a0d2bcfe2..93421ef44 100644 --- a/examples/librispeech/asr0/conf/deepspeech2_online.yaml +++ b/examples/librispeech/asr0/conf/deepspeech2_online.yaml @@ -15,39 +15,36 @@ max_output_input_ratio: .inf ########################################### # Dataloader # ########################################### -batch_size: 15 -mean_std_filepath: data/mean_std.json -unit_type: char -vocab_filepath: data/lang_char/vocab.txt -augmentation_config: conf/augmentation.json -random_seed: 0 -spm_model_prefix: -spectrum_type: linear -feat_dim: -target_sample_rate: 16000 -max_freq: None -n_fft: None +vocab_filepath: data/lang_char/vocab.txt +spm_model_prefix: '' +unit_type: 'char' +preprocess_config: conf/preprocess.yaml +feat_dim: 161 stride_ms: 10.0 -window_ms: 20.0 -delta_delta: False -dither: 1.0 -use_dB_normalization: True -target_dB: -20 -random_seed: 0 -keep_transcription_text: False -sortagrad: True -shuffle_method: batch_shuffle -num_workers: 0 +window_ms: 25.0 +sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs +batch_size: 64 +maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced +maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced +minibatches: 0 # for debug +batch_count: auto +batch_bins: 0 +batch_frames_in: 0 +batch_frames_out: 0 +batch_frames_inout: 0 +num_workers: 8 +subsampling_factor: 1 +num_encs: 1 ############################################ # Network Architecture # ############################################ num_conv_layers: 2 -num_rnn_layers: 3 -rnn_layer_size: 2048 +num_rnn_layers: 5 +rnn_layer_size: 1024 rnn_direction: forward -num_fc_layers: 2 -fc_layers_size_list: 512, 256 +num_fc_layers: 0 +fc_layers_size_list: -1 use_gru: False blank_id: 0 @@ -55,13 +52,13 @@ blank_id: 0 ########################################### # Training # ########################################### -n_epoch: 50 -accum_grad: 4 -lr: 1.0e-3 -lr_decay: 0.83 +n_epoch: 65 +accum_grad: 1 +lr: 5.0e-4 +lr_decay: 0.93 weight_decay: 1.0e-6 global_grad_clip: 5.0 -log_interval: 100 +log_interval: 1 checkpoint: kbest_n: 50 latest_n: 5 diff --git a/examples/librispeech/asr0/local/data.sh b/examples/librispeech/asr0/local/data.sh index b97e8c211..a28fddc96 100755 --- a/examples/librispeech/asr0/local/data.sh +++ b/examples/librispeech/asr0/local/data.sh @@ -49,12 +49,13 @@ if [ ${stage} -le 0 ] && [ 
${stop_stage} -ge 0 ]; then python3 ${MAIN_ROOT}/utils/compute_mean_std.py \ --manifest_path="data/manifest.train.raw" \ --num_samples=2000 \ - --spectrum_type="linear" \ + --spectrum_type="fbank" \ + --feat_dim=161 \ --delta_delta=false \ --sample_rate=16000 \ --stride_ms=10 \ - --window_ms=20 \ - --use_dB_normalization=True \ + --window_ms=25 \ + --use_dB_normalization=False \ --num_workers=${num_workers} \ --output_path="data/mean_std.json" diff --git a/examples/librispeech/asr0/local/test.sh b/examples/librispeech/asr0/local/test.sh index ea40046b1..5654a8794 100755 --- a/examples/librispeech/asr0/local/test.sh +++ b/examples/librispeech/asr0/local/test.sh @@ -4,6 +4,8 @@ if [ $# != 4 ];then echo "usage: ${0} config_path decode_config_path ckpt_path_prefix model_type" exit -1 fi +stage=0 +stop_stage=100 ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') echo "using $ngpu gpus..." @@ -19,17 +21,44 @@ if [ $? -ne 0 ]; then exit 1 fi -python3 -u ${BIN_DIR}/test.py \ ---ngpu ${ngpu} \ ---config ${config_path} \ ---decode_cfg ${decode_config_path} \ ---result_file ${ckpt_prefix}.rsl \ ---checkpoint_path ${ckpt_prefix} \ ---model_type ${model_type} +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # format the reference test file + python utils/format_rsl.py \ + --origin_ref data/manifest.test-clean.raw \ + --trans_ref data/manifest.test-clean.text -if [ $? -ne 0 ]; then - echo "Failed in evaluation!" - exit 1 + python3 -u ${BIN_DIR}/test.py \ + --ngpu ${ngpu} \ + --config ${config_path} \ + --decode_cfg ${decode_config_path} \ + --result_file ${ckpt_prefix}.rsl \ + --checkpoint_path ${ckpt_prefix} \ + --model_type ${model_type} + + if [ $? -ne 0 ]; then + echo "Failed in evaluation!" + exit 1 + fi + + python utils/format_rsl.py \ + --origin_hyp ${ckpt_prefix}.rsl \ + --trans_hyp ${ckpt_prefix}.rsl.text + + python utils/compute-wer.py --char=1 --v=1 \ + data/manifest.test-clean.text ${ckpt_prefix}.rsl.text > ${ckpt_prefix}.error +fi + +if [ ${stage} -le 101 ] && [ ${stop_stage} -ge 101 ]; then + python utils/format_rsl.py \ + --origin_ref data/manifest.test-clean.raw \ + --trans_ref_sclite data/manifest.test.text-clean.sclite + + python utils/format_rsl.py \ + --origin_hyp ${ckpt_prefix}.rsl \ + --trans_hyp_sclite ${ckpt_prefix}.rsl.text.sclite + + mkdir -p ${ckpt_prefix}_sclite + sclite -i wsj -r data/manifest.test-clean.text.sclite -h ${ckpt_prefix}.rsl.text.sclite -e utf-8 -o all -O ${ckpt_prefix}_sclite -c NOASCII fi diff --git a/examples/librispeech/asr0/run.sh b/examples/librispeech/asr0/run.sh index ca2c2b9da..d96f65823 100755 --- a/examples/librispeech/asr0/run.sh +++ b/examples/librispeech/asr0/run.sh @@ -2,13 +2,12 @@ set -e source path.sh -gpus=0,1,2,3,4,5,6,7 +gpus=0,1,2,3 stage=0 stop_stage=100 conf_path=conf/deepspeech2.yaml decode_conf_path=conf/tuning/decode.yaml -avg_num=30 -model_type=offline +avg_num=5 audio_file=data/demo_002_en.wav source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; @@ -43,6 +42,11 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then fi if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + # test export ckpt avg_n + CUDA_VISIBLE_DEVICES=0 ./local/test_export.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type}|| exit -1 +fi + +if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then # test a single .wav file CUDA_VISIBLE_DEVICES=0 ./local/test_wav.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${model_type} ${audio_file} || exit -1 fi diff --git 
a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 92f9b0e41..f26901a1e 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -138,6 +138,7 @@ class ASRExecutor(BaseExecutor): tag = model_type + '-' + lang + '-' + sample_rate_str self.task_resource.set_task_model(tag, version=None) self.res_path = self.task_resource.res_dir + self.cfg_path = os.path.join( self.res_path, self.task_resource.res_dict['cfg_path']) self.ckpt_path = os.path.join( @@ -158,15 +159,18 @@ class ASRExecutor(BaseExecutor): self.config.merge_from_file(self.cfg_path) with UpdateConfig(self.config): - if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: - from paddlespeech.s2t.io.collator import SpeechCollator - self.vocab = self.config.vocab_filepath + if self.config.spm_model_prefix: + self.config.spm_model_prefix = os.path.join( + self.res_path, self.config.spm_model_prefix) + self.text_feature = TextFeaturizer( + unit_type=self.config.unit_type, + vocab=self.config.vocab_filepath, + spm_model_prefix=self.config.spm_model_prefix) + if "deepspeech2" in model_type: self.config.decode.lang_model_path = os.path.join( MODEL_HOME, 'language_model', self.config.decode.lang_model_path) - self.collate_fn_test = SpeechCollator.from_config(self.config) - self.text_feature = TextFeaturizer( - unit_type=self.config.unit_type, vocab=self.vocab) + lm_url = self.task_resource.res_dict['lm_url'] lm_md5 = self.task_resource.res_dict['lm_md5'] self.download_lm( @@ -174,12 +178,6 @@ class ASRExecutor(BaseExecutor): os.path.dirname(self.config.decode.lang_model_path), lm_md5) elif "conformer" in model_type or "transformer" in model_type: - self.config.spm_model_prefix = os.path.join( - self.res_path, self.config.spm_model_prefix) - self.text_feature = TextFeaturizer( - unit_type=self.config.unit_type, - vocab=self.config.vocab_filepath, - spm_model_prefix=self.config.spm_model_prefix) self.config.decode.decoding_method = decode_method else: @@ -222,19 +220,7 @@ class ASRExecutor(BaseExecutor): logger.info("Preprocess audio_file:" + audio_file) # Get the object for feature extraction - if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: - audio, _ = self.collate_fn_test.process_utterance( - audio_file=audio_file, transcript=" ") - audio_len = audio.shape[0] - audio = paddle.to_tensor(audio, dtype='float32') - audio_len = paddle.to_tensor(audio_len) - audio = paddle.unsqueeze(audio, axis=0) - # vocab_list = collate_fn_test.vocab_list - self._inputs["audio"] = audio - self._inputs["audio_len"] = audio_len - logger.info(f"audio feat shape: {audio.shape}") - - elif "conformer" in model_type or "transformer" in model_type: + if "deepspeech2" in model_type or "conformer" in model_type or "transformer" in model_type: logger.info("get the preprocess conf") preprocess_conf = self.config.preprocess_config preprocess_args = {"train": False} @@ -242,7 +228,6 @@ class ASRExecutor(BaseExecutor): logger.info("read the audio file") audio, audio_sample_rate = soundfile.read( audio_file, dtype="int16", always_2d=True) - if self.change_format: if audio.shape[1] >= 2: audio = audio.mean(axis=1, dtype=np.int16) @@ -285,7 +270,7 @@ class ASRExecutor(BaseExecutor): cfg = self.config.decode audio = self._inputs["audio"] audio_len = self._inputs["audio_len"] - if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: + if "deepspeech2" in model_type: decode_batch_size = audio.shape[0] self.model.decoder.init_decoder( decode_batch_size, 
self.text_feature.vocab_list, diff --git a/paddlespeech/resource/model_alias.py b/paddlespeech/resource/model_alias.py index 2b19ed065..0aeda13d4 100644 --- a/paddlespeech/resource/model_alias.py +++ b/paddlespeech/resource/model_alias.py @@ -23,7 +23,7 @@ model_alias = { # --------------------------------- "deepspeech2offline": ["paddlespeech.s2t.models.ds2:DeepSpeech2Model"], "deepspeech2online": - ["paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline"], + ["paddlespeech.s2t.models.ds2:DeepSpeech2Model"], "conformer": ["paddlespeech.s2t.models.u2:U2Model"], "conformer_online": ["paddlespeech.s2t.models.u2:U2Model"], "transformer": ["paddlespeech.s2t.models.u2:U2Model"], diff --git a/paddlespeech/resource/pretrained_models.py b/paddlespeech/resource/pretrained_models.py index 84362f967..f01f9165f 100644 --- a/paddlespeech/resource/pretrained_models.py +++ b/paddlespeech/resource/pretrained_models.py @@ -136,9 +136,9 @@ asr_dynamic_pretrained_models = { "deepspeech2online_wenetspeech-zh-16k": { '1.0': { 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz', + 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.1.model.tar.gz', 'md5': - 'e393d4d274af0f6967db24fc146e8074', + 'd1be86a3e786042ab64f05161b5fae62', 'cfg_path': 'model.yaml', 'ckpt_path': @@ -152,13 +152,13 @@ asr_dynamic_pretrained_models = { "deepspeech2offline_aishell-zh-16k": { '1.0': { 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz', + 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_offline_aishell_ckpt_1.0.1.model.tar.gz', 'md5': - '932c3593d62fe5c741b59b31318aa314', + '4d26066c6f19f52087425dc722ae5b13', 'cfg_path': 'model.yaml', 'ckpt_path': - 'exp/deepspeech2/checkpoints/avg_1', + 'exp/deepspeech2/checkpoints/avg_10', 'lm_url': 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', 'lm_md5': @@ -168,9 +168,9 @@ asr_dynamic_pretrained_models = { "deepspeech2online_aishell-zh-16k": { '1.0': { 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_0.2.1.model.tar.gz', + 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.1.model.tar.gz', 'md5': - '98b87b171b7240b7cae6e07d8d0bc9be', + 'df5ddeac8b679a470176649ac4b78726', 'cfg_path': 'model.yaml', 'ckpt_path': @@ -188,13 +188,13 @@ asr_dynamic_pretrained_models = { "deepspeech2offline_librispeech-en-16k": { '1.0': { 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_librispeech_ckpt_0.1.1.model.tar.gz', + 'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_offline_librispeech_ckpt_1.0.1.model.tar.gz', 'md5': - 'f5666c81ad015c8de03aac2bc92e5762', + 'ed9e2b008a65268b3484020281ab048c', 'cfg_path': 'model.yaml', 'ckpt_path': - 'exp/deepspeech2/checkpoints/avg_1', + 'exp/deepspeech2/checkpoints/avg_5', 'lm_url': 'https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm', 'lm_md5': @@ -207,17 +207,17 @@ asr_static_pretrained_models = { "deepspeech2offline_aishell-zh-16k": { '1.0': { 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz', + 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_offline_aishell_ckpt_1.0.1.model.tar.gz', 'md5': - '932c3593d62fe5c741b59b31318aa314', + 
'4d26066c6f19f52087425dc722ae5b13', 'cfg_path': 'model.yaml', 'ckpt_path': - 'exp/deepspeech2/checkpoints/avg_1', + 'exp/deepspeech2/checkpoints/avg_10', 'model': - 'exp/deepspeech2/checkpoints/avg_1.jit.pdmodel', + 'exp/deepspeech2/checkpoints/avg_10.jit.pdmodel', 'params': - 'exp/deepspeech2/checkpoints/avg_1.jit.pdiparams', + 'exp/deepspeech2/checkpoints/avg_10.jit.pdiparams', 'lm_url': 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', 'lm_md5': @@ -830,7 +830,7 @@ vector_dynamic_pretrained_models = { 'cfg_path': 'conf/model.yaml', # the yaml config path 'ckpt_path': - 'model/model', # the format is ${dir}/{model_name}, + 'model/model', # the format is ${dir}/{model_name}, # so the first 'model' is dir, the second 'model' is the name # this means we have a model stored as model/model.pdparams }, diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/export.py b/paddlespeech/s2t/exps/deepspeech2/bin/export.py index 62bf191df..049e7b688 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/export.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/export.py @@ -32,11 +32,9 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() - # save jit model to + # save jit model to parser.add_argument( "--export_path", type=str, help="path of the jit model to save") - parser.add_argument( - "--model_type", type=str, default='offline', help="offline/online") parser.add_argument( '--nxpu', type=int, @@ -44,7 +42,6 @@ if __name__ == "__main__": choices=[0, 1], help="if nxpu == 0 and ngpu == 0, use cpu.") args = parser.parse_args() - print("model_type:{}".format(args.model_type)) print_arguments(args) # https://yaml.org/type/float.html diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/test.py b/paddlespeech/s2t/exps/deepspeech2/bin/test.py index a7d99e02a..a9828f6e7 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/test.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/test.py @@ -32,9 +32,7 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() - parser.add_argument( - "--model_type", type=str, default='offline', help='offline/online') - # save asr result to + # save asr result to parser.add_argument( "--result_file", type=str, help="path of save the asr result") parser.add_argument( @@ -45,7 +43,6 @@ if __name__ == "__main__": help="if nxpu == 0 and ngpu == 0, use cpu.") args = parser.parse_args() print_arguments(args, globals()) - print("model_type:{}".format(args.model_type)) # https://yaml.org/type/float.html config = CfgNode(new_allowed=True) diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/test_export.py b/paddlespeech/s2t/exps/deepspeech2/bin/test_export.py index 58815ad63..8db081e7b 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/test_export.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/test_export.py @@ -38,8 +38,6 @@ if __name__ == "__main__": #load jit model from parser.add_argument( "--export_path", type=str, help="path of the jit model to save") - parser.add_argument( - "--model_type", type=str, default='offline', help='offline/online') parser.add_argument( '--nxpu', type=int, @@ -50,7 +48,6 @@ if __name__ == "__main__": "--enable-auto-log", action="store_true", help="use auto log") args = parser.parse_args() print_arguments(args, globals()) - print("model_type:{}".format(args.model_type)) # https://yaml.org/type/float.html config = CfgNode(new_allowed=True) diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/test_wav.py b/paddlespeech/s2t/exps/deepspeech2/bin/test_wav.py index a909dd416..90b7d8a18 100644 
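The resource tables above pair every model tarball URL with an md5 and the paths packed inside it (cfg_path, ckpt_path, and for static models the exported pdmodel/pdiparams). Below is a minimal, hedged sketch of verifying and unpacking one of those archives by hand; only the dictionary keys and the md5/paths mirror the `deepspeech2offline_aishell-zh-16k` entry, the local filename and output directory are made up, and PaddleSpeech's own downloader normally handles this step.

```python
import hashlib
import tarfile

# Keys/values copied from the "deepspeech2offline_aishell-zh-16k" 1.0 entry above;
# the local archive name and target directory are hypothetical.
entry = {
    'md5': '4d26066c6f19f52087425dc722ae5b13',
    'cfg_path': 'model.yaml',
    'ckpt_path': 'exp/deepspeech2/checkpoints/avg_10',
}
archive = 'asr0_deepspeech2_offline_aishell_ckpt_1.0.1.model.tar.gz'

def md5sum(path: str, chunk: int = 1 << 20) -> str:
    h = hashlib.md5()
    with open(path, 'rb') as f:
        for block in iter(lambda: f.read(chunk), b''):
            h.update(block)
    return h.hexdigest()

assert md5sum(archive) == entry['md5'], 'checksum mismatch, re-download the tarball'
with tarfile.open(archive, 'r:gz') as tar:
    tar.extractall('./ds2_offline_aishell')  # contains model.yaml and the avg_10 checkpoint
```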
--- a/paddlespeech/s2t/exps/deepspeech2/bin/test_wav.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/test_wav.py @@ -23,7 +23,6 @@ from yacs.config import CfgNode from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.io.collator import SpeechCollator from paddlespeech.s2t.models.ds2 import DeepSpeech2Model -from paddlespeech.s2t.models.ds2_online import DeepSpeech2ModelOnline from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.utils import mp_tools from paddlespeech.s2t.utils.checkpoint import Checkpoint @@ -113,12 +112,7 @@ class DeepSpeech2Tester_hub(): config.input_dim = self.collate_fn_test.feature_size config.output_dim = self.collate_fn_test.vocab_size - if self.args.model_type == 'offline': - model = DeepSpeech2Model.from_config(config) - elif self.args.model_type == 'online': - model = DeepSpeech2ModelOnline.from_config(config) - else: - raise Exception("wrong model type") + model = DeepSpeech2Model.from_config(config) self.model = model @@ -172,8 +166,6 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() - parser.add_argument( - "--model_type", type=str, default='offline', help='offline/online') parser.add_argument("--audio_file", type=str, help='audio file path') # save asr result to parser.add_argument( @@ -184,7 +176,6 @@ if __name__ == "__main__": print("Please input the audio file path") sys.exit(-1) check(args.audio_file) - print("model_type:{}".format(args.model_type)) # https://yaml.org/type/float.html config = CfgNode(new_allowed=True) diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/train.py b/paddlespeech/s2t/exps/deepspeech2/bin/train.py index 3906b2fc6..fee7079d9 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/train.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/train.py @@ -31,8 +31,6 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() - parser.add_argument( - "--model_type", type=str, default='offline', help='offline/online') parser.add_argument( '--nxpu', type=int, @@ -40,7 +38,6 @@ if __name__ == "__main__": choices=[0, 1], help="if nxpu == 0 and ngpu == 0, use cpu.") args = parser.parse_args() - print("model_type:{}".format(args.model_type)) print_arguments(args, globals()) # https://yaml.org/type/float.html diff --git a/paddlespeech/s2t/exps/deepspeech2/model.py b/paddlespeech/s2t/exps/deepspeech2/model.py index 3c2eaab72..5aa8e3743 100644 --- a/paddlespeech/s2t/exps/deepspeech2/model.py +++ b/paddlespeech/s2t/exps/deepspeech2/model.py @@ -23,16 +23,12 @@ import paddle from paddle import distributed as dist from paddle import inference from paddle.io import DataLoader +from paddlespeech.s2t.io.dataloader import BatchDataLoader from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer -from paddlespeech.s2t.io.collator import SpeechCollator from paddlespeech.s2t.io.dataset import ManifestDataset -from paddlespeech.s2t.io.sampler import SortagradBatchSampler -from paddlespeech.s2t.io.sampler import SortagradDistributedBatchSampler from paddlespeech.s2t.models.ds2 import DeepSpeech2InferModel from paddlespeech.s2t.models.ds2 import DeepSpeech2Model -from paddlespeech.s2t.models.ds2_online import DeepSpeech2InferModelOnline -from paddlespeech.s2t.models.ds2_online import DeepSpeech2ModelOnline from paddlespeech.s2t.training.gradclip import ClipGradByGlobalNormWithLog from paddlespeech.s2t.training.reporter import report from paddlespeech.s2t.training.timer import Timer @@ -136,18 
+132,13 @@ class DeepSpeech2Trainer(Trainer): config = self.config.clone() with UpdateConfig(config): if self.train: - config.input_dim = self.train_loader.collate_fn.feature_size - config.output_dim = self.train_loader.collate_fn.vocab_size + config.input_dim = self.train_loader.feat_dim + config.output_dim = self.train_loader.vocab_size else: - config.input_dim = self.test_loader.collate_fn.feature_size - config.output_dim = self.test_loader.collate_fn.vocab_size + config.input_dim = self.test_loader.feat_dim + config.output_dim = self.test_loader.vocab_size - if self.args.model_type == 'offline': - model = DeepSpeech2Model.from_config(config) - elif self.args.model_type == 'online': - model = DeepSpeech2ModelOnline.from_config(config) - else: - raise Exception("wrong model type") + model = DeepSpeech2Model.from_config(config) if self.parallel: model = paddle.DataParallel(model) @@ -175,76 +166,81 @@ class DeepSpeech2Trainer(Trainer): config = self.config.clone() config.defrost() if self.train: - # train - config.manifest = config.train_manifest - train_dataset = ManifestDataset.from_config(config) - if self.parallel: - batch_sampler = SortagradDistributedBatchSampler( - train_dataset, - batch_size=config.batch_size, - num_replicas=None, - rank=None, - shuffle=True, - drop_last=True, - sortagrad=config.sortagrad, - shuffle_method=config.shuffle_method) - else: - batch_sampler = SortagradBatchSampler( - train_dataset, - shuffle=True, - batch_size=config.batch_size, - drop_last=True, - sortagrad=config.sortagrad, - shuffle_method=config.shuffle_method) - - config.keep_transcription_text = False - collate_fn_train = SpeechCollator.from_config(config) - self.train_loader = DataLoader( - train_dataset, - batch_sampler=batch_sampler, - collate_fn=collate_fn_train, - num_workers=config.num_workers) - - # dev - config.manifest = config.dev_manifest - dev_dataset = ManifestDataset.from_config(config) - - config.augmentation_config = "" - config.keep_transcription_text = False - collate_fn_dev = SpeechCollator.from_config(config) - self.valid_loader = DataLoader( - dev_dataset, - batch_size=int(config.batch_size), - shuffle=False, - drop_last=False, - collate_fn=collate_fn_dev, - num_workers=config.num_workers) - logger.info("Setup train/valid Dataloader!") + # train/valid dataset, return token ids + self.train_loader = BatchDataLoader( + json_file=config.train_manifest, + train_mode=True, + sortagrad=config.sortagrad, + batch_size=config.batch_size, + maxlen_in=config.maxlen_in, + maxlen_out=config.maxlen_out, + minibatches=config.minibatches, + mini_batch_size=self.args.ngpu, + batch_count=config.batch_count, + batch_bins=config.batch_bins, + batch_frames_in=config.batch_frames_in, + batch_frames_out=config.batch_frames_out, + batch_frames_inout=config.batch_frames_inout, + preprocess_conf=config.preprocess_config, + n_iter_processes=config.num_workers, + subsampling_factor=1, + num_encs=1, + dist_sampler=config.get('dist_sampler', False), + shortest_first=False) + + self.valid_loader = BatchDataLoader( + json_file=config.dev_manifest, + train_mode=False, + sortagrad=False, + batch_size=config.batch_size, + maxlen_in=float('inf'), + maxlen_out=float('inf'), + minibatches=0, + mini_batch_size=self.args.ngpu, + batch_count='auto', + batch_bins=0, + batch_frames_in=0, + batch_frames_out=0, + batch_frames_inout=0, + preprocess_conf=config.preprocess_config, + n_iter_processes=config.num_workers, + subsampling_factor=1, + num_encs=1, + dist_sampler=config.get('dist_sampler', False), + 
shortest_first=False) + logger.info("Setup train/valid Dataloader!") else: - # test - config.manifest = config.test_manifest - test_dataset = ManifestDataset.from_config(config) - - config.augmentation_config = "" - config.keep_transcription_text = True - collate_fn_test = SpeechCollator.from_config(config) decode_batch_size = config.get('decode', dict()).get( 'decode_batch_size', 1) - self.test_loader = DataLoader( - test_dataset, + # test dataset, return raw text + self.test_loader = BatchDataLoader( + json_file=config.test_manifest, + train_mode=False, + sortagrad=False, batch_size=decode_batch_size, - shuffle=False, - drop_last=False, - collate_fn=collate_fn_test, - num_workers=config.num_workers) - logger.info("Setup test Dataloader!") + maxlen_in=float('inf'), + maxlen_out=float('inf'), + minibatches=0, + mini_batch_size=1, + batch_count='auto', + batch_bins=0, + batch_frames_in=0, + batch_frames_out=0, + batch_frames_inout=0, + preprocess_conf=config.preprocess_config, + n_iter_processes=1, + subsampling_factor=1, + num_encs=1) + logger.info("Setup test/align Dataloader!") class DeepSpeech2Tester(DeepSpeech2Trainer): def __init__(self, config, args): super().__init__(config, args) self._text_featurizer = TextFeaturizer( - unit_type=config.unit_type, vocab=None) + unit_type=config.unit_type, + vocab=config.vocab_filepath) + self.vocab_list = self._text_featurizer.vocab_list def ordid2token(self, texts, texts_len): """ ord() id to chr() chr """ @@ -252,7 +248,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): for text, n in zip(texts, texts_len): n = n.numpy().item() ids = text[:n] - trans.append(''.join([chr(i) for i in ids])) + #trans.append(''.join([chr(i) for i in ids])) + trans.append(self._text_featurizer.defeaturize(ids.numpy().tolist())) return trans def compute_metrics(self, @@ -307,8 +304,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): # Initialized the decoder in model decode_cfg = self.config.decode - vocab_list = self.test_loader.collate_fn.vocab_list - decode_batch_size = self.test_loader.batch_size + vocab_list = self.vocab_list + decode_batch_size = decode_cfg.decode_batch_size self.model.decoder.init_decoder( decode_batch_size, vocab_list, decode_cfg.decoding_method, decode_cfg.lang_model_path, decode_cfg.alpha, decode_cfg.beta, @@ -338,17 +335,9 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): @paddle.no_grad() def export(self): - if self.args.model_type == 'offline': - infer_model = DeepSpeech2InferModel.from_pretrained( - self.test_loader, self.config, self.args.checkpoint_path) - elif self.args.model_type == 'online': - infer_model = DeepSpeech2InferModelOnline.from_pretrained( - self.test_loader, self.config, self.args.checkpoint_path) - else: - raise Exception("wrong model type") - + infer_model = DeepSpeech2InferModel.from_pretrained( + self.test_loader, self.config, self.args.checkpoint_path) infer_model.eval() - feat_dim = self.test_loader.collate_fn.feature_size static_model = infer_model.export() logger.info(f"Export code: {static_model.forward.code}") paddle.jit.save(static_model, self.args.export_path) @@ -376,10 +365,10 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester): # Initialized the decoder in model decode_cfg = self.config.decode - vocab_list = self.test_loader.collate_fn.vocab_list - if self.args.model_type == "online": + vocab_list = self.vocab_list + if self.config.rnn_direction == "forward": decode_batch_size = 1 - elif self.args.model_type == "offline": + elif self.config.rnn_direction == "bidirect": decode_batch_size = 
self.test_loader.batch_size else: raise Exception("wrong model type") @@ -412,11 +401,11 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester): self.model.decoder.del_decoder() def compute_result_transcripts(self, audio, audio_len): - if self.args.model_type == "online": + if self.config.rnn_direction == "forward": output_probs, output_lens, trans_batch = self.static_forward_online( audio, audio_len, decoder_chunk_size=1) result_transcripts = [trans[-1] for trans in trans_batch] - elif self.args.model_type == "offline": + elif self.config.rnn_direction == "bidirect": output_probs, output_lens = self.static_forward_offline(audio, audio_len) batch_size = output_probs.shape[0] diff --git a/paddlespeech/s2t/models/ds2/conv.py b/paddlespeech/s2t/models/ds2/conv.py index 4e766e793..448d4d1bb 100644 --- a/paddlespeech/s2t/models/ds2/conv.py +++ b/paddlespeech/s2t/models/ds2/conv.py @@ -11,161 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from paddle import nn -from paddle.nn import functional as F +import paddle -from paddlespeech.s2t.modules.activation import brelu -from paddlespeech.s2t.modules.mask import make_non_pad_mask -from paddlespeech.s2t.utils.log import Log +from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling4 -logger = Log(__name__).getlog() -__all__ = ['ConvStack', "conv_output_size"] +class Conv2dSubsampling4Pure(Conv2dSubsampling4): + def __init__(self, idim: int, odim: int, dropout_rate: float): + super().__init__(idim, odim, dropout_rate, None) + self.output_dim = ((idim - 1) // 2 - 1) // 2 * odim + self.receptive_field_length = 2 * ( + 3 - 1) + 3 # stride_1 * (kernel_size_2 - 1) + kerel_size_1 - -def conv_output_size(I, F, P, S): - # https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters - # Output size after Conv: - # By noting I the length of the input volume size, - # F the length of the filter, - # P the amount of zero padding, - # S the stride, - # then the output size O of the feature map along that dimension is given by: - # O = (I - F + Pstart + Pend) // S + 1 - # When Pstart == Pend == P, we can replace Pstart + Pend by 2P. - # When Pstart == Pend == 0 - # O = (I - F - S) // S - # https://iq.opengenus.org/output-size-of-convolution/ - # Output height = (Input height + padding height top + padding height bottom - kernel height) / (stride height) + 1 - # Output width = (Output width + padding width right + padding width left - kernel width) / (stride width) + 1 - return (I - F + 2 * P - S) // S - - -# receptive field calculator -# https://fomoro.com/research/article/receptive-field-calculator -# https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters -# https://distill.pub/2019/computing-receptive-fields/ -# Rl-1 = Sl * Rl + (Kl - Sl) - - -class ConvBn(nn.Layer): - """Convolution layer with batch normalization. - - :param kernel_size: The x dimension of a filter kernel. Or input a tuple for - two image dimension. - :type kernel_size: int|tuple|list - :param num_channels_in: Number of input channels. - :type num_channels_in: int - :param num_channels_out: Number of output channels. - :type num_channels_out: int - :param stride: The x dimension of the stride. Or input a tuple for two - image dimension. - :type stride: int|tuple|list - :param padding: The x dimension of the padding. Or input a tuple for two - image dimension. 
- :type padding: int|tuple|list - :param act: Activation type, relu|brelu - :type act: string - :return: Batch norm layer after convolution layer. - :rtype: Variable - - """ - - def __init__(self, num_channels_in, num_channels_out, kernel_size, stride, - padding, act): - - super().__init__() - assert len(kernel_size) == 2 - assert len(stride) == 2 - assert len(padding) == 2 - self.kernel_size = kernel_size - self.stride = stride - self.padding = padding - - self.conv = nn.Conv2D( - num_channels_in, - num_channels_out, - kernel_size=kernel_size, - stride=stride, - padding=padding, - weight_attr=None, - bias_attr=False, - data_format='NCHW') - - self.bn = nn.BatchNorm2D( - num_channels_out, - weight_attr=None, - bias_attr=None, - data_format='NCHW') - self.act = F.relu if act == 'relu' else brelu - - def forward(self, x, x_len): - """ - x(Tensor): audio, shape [B, C, D, T] - """ + def forward(self, x: paddle.Tensor, + x_len: paddle.Tensor) -> [paddle.Tensor, paddle.Tensor]: + x = x.unsqueeze(1) # (b, c=1, t, f) x = self.conv(x) - x = self.bn(x) - x = self.act(x) - - x_len = (x_len - self.kernel_size[1] + 2 * self.padding[1] - ) // self.stride[1] + 1 - - # reset padding part to 0 - masks = make_non_pad_mask(x_len) #[B, T] - masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T] - # TODO(Hui Zhang): not support bool multiply - # masks = masks.type_as(x) - masks = masks.astype(x.dtype) - x = x.multiply(masks) - return x, x_len - - -class ConvStack(nn.Layer): - """Convolution group with stacked convolution layers. - - :param feat_size: audio feature dim. - :type feat_size: int - :param num_stacks: Number of stacked convolution layers. - :type num_stacks: int - """ - - def __init__(self, feat_size, num_stacks): - super().__init__() - self.feat_size = feat_size # D - self.num_stacks = num_stacks - - self.conv_in = ConvBn( - num_channels_in=1, - num_channels_out=32, - kernel_size=(41, 11), #[D, T] - stride=(2, 3), - padding=(20, 5), - act='brelu') - - out_channel = 32 - convs = [ - ConvBn( - num_channels_in=32, - num_channels_out=out_channel, - kernel_size=(21, 11), - stride=(2, 1), - padding=(10, 5), - act='brelu') for i in range(num_stacks - 1) - ] - self.conv_stack = nn.LayerList(convs) - - # conv output feat_dim - output_height = (feat_size - 1) // 2 + 1 - for i in range(self.num_stacks - 1): - output_height = (output_height - 1) // 2 + 1 - self.output_height = out_channel * output_height - - def forward(self, x, x_len): - """ - x: shape [B, C, D, T] - x_len : shape [B] - """ - x, x_len = self.conv_in(x, x_len) - for i, conv in enumerate(self.conv_stack): - x, x_len = conv(x, x_len) + #b, c, t, f = paddle.shape(x) #not work under jit + x = x.transpose([0, 2, 1, 3]).reshape([0, 0, -1]) + x_len = ((x_len - 1) // 2 - 1) // 2 return x, x_len diff --git a/paddlespeech/s2t/models/ds2/deepspeech2.py b/paddlespeech/s2t/models/ds2/deepspeech2.py index 9c6b66c25..3f2c76ced 100644 --- a/paddlespeech/s2t/models/ds2/deepspeech2.py +++ b/paddlespeech/s2t/models/ds2/deepspeech2.py @@ -13,15 +13,14 @@ # limitations under the License. 
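The new `Conv2dSubsampling4Pure` front end above replaces the old `ConvStack`/`ConvBn` pair; its output feature size and length bookkeeping follow directly from the two stride-2 convolutions it inherits from `Conv2dSubsampling4`. A quick sketch of that arithmetic for the 161-dim fbank features used in the updated configs (plain Python; the concrete numbers are just an example):

```python
def subsampled_feat_dim(idim: int, odim: int = 32) -> int:
    # mirrors self.output_dim = ((idim - 1) // 2 - 1) // 2 * odim
    return ((idim - 1) // 2 - 1) // 2 * odim

def subsampled_len(x_len: int) -> int:
    # mirrors x_len = ((x_len - 1) // 2 - 1) // 2  (two stride-2 convs, ~1/4 the frames)
    return ((x_len - 1) // 2 - 1) // 2

print(subsampled_feat_dim(161))  # 1248 -> input size of the first RNN layer
print(subsampled_len(35))        # 8 frames survive from a 35-frame chunk
# receptive_field_length = 2 * (3 - 1) + 3 = 7 input frames per output frame
```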
"""Deepspeech2 ASR Model""" import paddle +import paddle.nn.functional as F from paddle import nn -from paddlespeech.s2t.models.ds2.conv import ConvStack -from paddlespeech.s2t.models.ds2.rnn import RNNStack +from paddlespeech.s2t.models.ds2.conv import Conv2dSubsampling4Pure from paddlespeech.s2t.modules.ctc import CTCDecoder from paddlespeech.s2t.utils import layer_tools from paddlespeech.s2t.utils.checkpoint import Checkpoint from paddlespeech.s2t.utils.log import Log - logger = Log(__name__).getlog() __all__ = ['DeepSpeech2Model', 'DeepSpeech2InferModel'] @@ -32,72 +31,197 @@ class CRNNEncoder(nn.Layer): feat_size, dict_size, num_conv_layers=2, - num_rnn_layers=3, + num_rnn_layers=4, rnn_size=1024, - use_gru=False, - share_rnn_weights=True): + rnn_direction='forward', + num_fc_layers=2, + fc_layers_size_list=[512, 256], + use_gru=False): super().__init__() self.rnn_size = rnn_size self.feat_size = feat_size # 161 for linear self.dict_size = dict_size - - self.conv = ConvStack(feat_size, num_conv_layers) - - i_size = self.conv.output_height # H after conv stack - self.rnn = RNNStack( - i_size=i_size, - h_size=rnn_size, - num_stacks=num_rnn_layers, - use_gru=use_gru, - share_rnn_weights=share_rnn_weights) + self.num_rnn_layers = num_rnn_layers + self.num_fc_layers = num_fc_layers + self.rnn_direction = rnn_direction + self.fc_layers_size_list = fc_layers_size_list + self.use_gru = use_gru + self.conv = Conv2dSubsampling4Pure(feat_size, 32, dropout_rate=0.0) + + self.output_dim = self.conv.output_dim + + i_size = self.conv.output_dim + self.rnn = nn.LayerList() + self.layernorm_list = nn.LayerList() + self.fc_layers_list = nn.LayerList() + if rnn_direction == 'bidirect' or rnn_direction == 'bidirectional': + layernorm_size = 2 * rnn_size + elif rnn_direction == 'forward': + layernorm_size = rnn_size + else: + raise Exception("Wrong rnn direction") + for i in range(0, num_rnn_layers): + if i == 0: + rnn_input_size = i_size + else: + rnn_input_size = layernorm_size + if use_gru is True: + self.rnn.append( + nn.GRU( + input_size=rnn_input_size, + hidden_size=rnn_size, + num_layers=1, + direction=rnn_direction)) + else: + self.rnn.append( + nn.LSTM( + input_size=rnn_input_size, + hidden_size=rnn_size, + num_layers=1, + direction=rnn_direction)) + self.layernorm_list.append(nn.LayerNorm(layernorm_size)) + self.output_dim = layernorm_size + + fc_input_size = layernorm_size + for i in range(self.num_fc_layers): + self.fc_layers_list.append( + nn.Linear(fc_input_size, fc_layers_size_list[i])) + fc_input_size = fc_layers_size_list[i] + self.output_dim = fc_layers_size_list[i] @property def output_size(self): - return self.rnn_size * 2 + return self.output_dim - def forward(self, audio, audio_len): + def forward(self, x, x_lens, init_state_h_box=None, init_state_c_box=None): """Compute Encoder outputs Args: - audio (Tensor): [B, Tmax, D] - text (Tensor): [B, Umax] - audio_len (Tensor): [B] - text_len (Tensor): [B] - Returns: + x (Tensor): [B, T, D] + x_lens (Tensor): [B] + init_state_h_box(Tensor): init_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] + init_state_c_box(Tensor): init_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] + Return: x (Tensor): encoder outputs, [B, T, D] x_lens (Tensor): encoder length, [B] + final_state_h_box(Tensor): final_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] + final_state_c_box(Tensor): final_states c for RNN layers: [num_rnn_layers * num_directions, 
batch_size, hidden_size] """ - # [B, T, D] -> [B, D, T] - audio = audio.transpose([0, 2, 1]) - # [B, D, T] -> [B, C=1, D, T] - x = audio.unsqueeze(1) - x_lens = audio_len + if init_state_h_box is not None: + init_state_list = None + + if self.use_gru is True: + init_state_h_list = paddle.split( + init_state_h_box, self.num_rnn_layers, axis=0) + init_state_list = init_state_h_list + else: + init_state_h_list = paddle.split( + init_state_h_box, self.num_rnn_layers, axis=0) + init_state_c_list = paddle.split( + init_state_c_box, self.num_rnn_layers, axis=0) + init_state_list = [(init_state_h_list[i], init_state_c_list[i]) + for i in range(self.num_rnn_layers)] + else: + init_state_list = [None] * self.num_rnn_layers - # convolution group x, x_lens = self.conv(x, x_lens) + final_chunk_state_list = [] + for i in range(0, self.num_rnn_layers): + x, final_state = self.rnn[i](x, init_state_list[i], + x_lens) #[B, T, D] + final_chunk_state_list.append(final_state) + x = self.layernorm_list[i](x) + + for i in range(self.num_fc_layers): + x = self.fc_layers_list[i](x) + x = F.relu(x) + + if self.use_gru is True: + final_chunk_state_h_box = paddle.concat( + final_chunk_state_list, axis=0) + final_chunk_state_c_box = init_state_c_box + else: + final_chunk_state_h_list = [ + final_chunk_state_list[i][0] for i in range(self.num_rnn_layers) + ] + final_chunk_state_c_list = [ + final_chunk_state_list[i][1] for i in range(self.num_rnn_layers) + ] + final_chunk_state_h_box = paddle.concat( + final_chunk_state_h_list, axis=0) + final_chunk_state_c_box = paddle.concat( + final_chunk_state_c_list, axis=0) + + return x, x_lens, final_chunk_state_h_box, final_chunk_state_c_box + + def forward_chunk_by_chunk(self, x, x_lens, decoder_chunk_size=8): + """Compute Encoder outputs - # convert data from convolution feature map to sequence of vectors - #B, C, D, T = paddle.shape(x) # not work under jit - x = x.transpose([0, 3, 1, 2]) #[B, T, C, D] - #x = x.reshape([B, T, C * D]) #[B, T, C*D] # not work under jit - x = x.reshape([0, 0, -1]) #[B, T, C*D] - - # remove padding part - x, x_lens = self.rnn(x, x_lens) #[B, T, D] - return x, x_lens + Args: + x (Tensor): [B, T, D] + x_lens (Tensor): [B] + decoder_chunk_size: The chunk size of decoder + Returns: + eouts_list (List of Tensor): The list of encoder outputs in chunk_size: [B, chunk_size, D] * num_chunks + eouts_lens_list (List of Tensor): The list of encoder length in chunk_size: [B] * num_chunks + final_state_h_box(Tensor): final_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] + final_state_c_box(Tensor): final_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] + """ + subsampling_rate = self.conv.subsampling_rate + receptive_field_length = self.conv.receptive_field_length + chunk_size = (decoder_chunk_size - 1 + ) * subsampling_rate + receptive_field_length + chunk_stride = subsampling_rate * decoder_chunk_size + max_len = x.shape[1] + assert (chunk_size <= max_len) + + eouts_chunk_list = [] + eouts_chunk_lens_list = [] + if (max_len - chunk_size) % chunk_stride != 0: + padding_len = chunk_stride - (max_len - chunk_size) % chunk_stride + else: + padding_len = 0 + padding = paddle.zeros((x.shape[0], padding_len, x.shape[2])) + padded_x = paddle.concat([x, padding], axis=1) + num_chunk = (max_len + padding_len - chunk_size) / chunk_stride + 1 + num_chunk = int(num_chunk) + chunk_state_h_box = None + chunk_state_c_box = None + final_state_h_box = None + final_state_c_box = None + for i in range(0, 
num_chunk): + start = i * chunk_stride + end = start + chunk_size + x_chunk = padded_x[:, start:end, :] + + x_len_left = paddle.where(x_lens - i * chunk_stride < 0, + paddle.zeros_like(x_lens), + x_lens - i * chunk_stride) + x_chunk_len_tmp = paddle.ones_like(x_lens) * chunk_size + x_chunk_lens = paddle.where(x_len_left < x_chunk_len_tmp, + x_len_left, x_chunk_len_tmp) + + eouts_chunk, eouts_chunk_lens, chunk_state_h_box, chunk_state_c_box = self.forward( + x_chunk, x_chunk_lens, chunk_state_h_box, chunk_state_c_box) + + eouts_chunk_list.append(eouts_chunk) + eouts_chunk_lens_list.append(eouts_chunk_lens) + final_state_h_box = chunk_state_h_box + final_state_c_box = chunk_state_c_box + return eouts_chunk_list, eouts_chunk_lens_list, final_state_h_box, final_state_c_box class DeepSpeech2Model(nn.Layer): """The DeepSpeech2 network structure. - :param audio_data: Audio spectrogram data layer. - :type audio_data: Variable - :param text_data: Transcription text data layer. - :type text_data: Variable + :param audio: Audio spectrogram data layer. + :type audio: Variable + :param text: Transcription text data layer. + :type text: Variable :param audio_len: Valid sequence length data layer. :type audio_len: Variable - :param masks: Masks data layer to reset padding. - :type masks: Variable + :param feat_size: feature size for audio. + :type feat_size: int :param dict_size: Dictionary size for tokenized transcription. :type dict_size: int :param num_conv_layers: Number of stacking convolution layers. @@ -106,37 +230,41 @@ class DeepSpeech2Model(nn.Layer): :type num_rnn_layers: int :param rnn_size: RNN layer size (dimension of RNN cells). :type rnn_size: int + :param num_fc_layers: Number of stacking FC layers. + :type num_fc_layers: int + :param fc_layers_size_list: The list of FC layer sizes. + :type fc_layers_size_list: [int,] :param use_gru: Use gru if set True. Use simple rnn if set False. :type use_gru: bool - :param share_rnn_weights: Whether to share input-hidden weights between - forward and backward direction RNNs. - It is only available when use_gru=False. - :type share_weights: bool :return: A tuple of an output unnormalized log probability layer ( before softmax) and a ctc cost layer. 
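`forward_chunk_by_chunk` above derives its raw-frame window from the decoder chunk size, the conv subsampling rate, and the receptive field, then zero-pads the utterance so the last chunk is full. A hedged sketch of that bookkeeping with the values implied by `Conv2dSubsampling4Pure` (subsampling rate 4, receptive field 7); the utterance length is made up:

```python
# Chunk geometry used by CRNNEncoder.forward_chunk_by_chunk (plain-Python sketch).
subsampling_rate = 4          # Conv2dSubsampling4Pure
receptive_field_length = 7    # 2 * (3 - 1) + 3
decoder_chunk_size = 8        # default argument above

chunk_size = (decoder_chunk_size - 1) * subsampling_rate + receptive_field_length  # 35
chunk_stride = subsampling_rate * decoder_chunk_size                               # 32

max_len = 300                 # hypothetical number of input frames
rem = (max_len - chunk_size) % chunk_stride
padding_len = 0 if rem == 0 else chunk_stride - rem
num_chunk = (max_len + padding_len - chunk_size) // chunk_stride + 1
print(chunk_size, chunk_stride, padding_len, num_chunk)  # 35 32 23 10
```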
:rtype: tuple of LayerOutput """ - def __init__(self, - feat_size, - dict_size, - num_conv_layers=2, - num_rnn_layers=3, - rnn_size=1024, - use_gru=False, - share_rnn_weights=True, - blank_id=0, - ctc_grad_norm_type=None): + def __init__( + self, + feat_size, + dict_size, + num_conv_layers=2, + num_rnn_layers=4, + rnn_size=1024, + rnn_direction='forward', + num_fc_layers=2, + fc_layers_size_list=[512, 256], + use_gru=False, + blank_id=0, + ctc_grad_norm_type=None, ): super().__init__() self.encoder = CRNNEncoder( feat_size=feat_size, dict_size=dict_size, num_conv_layers=num_conv_layers, num_rnn_layers=num_rnn_layers, + rnn_direction=rnn_direction, + num_fc_layers=num_fc_layers, + fc_layers_size_list=fc_layers_size_list, rnn_size=rnn_size, - use_gru=use_gru, - share_rnn_weights=share_rnn_weights) - assert (self.encoder.output_size == rnn_size * 2) + use_gru=use_gru) self.decoder = CTCDecoder( odim=dict_size, # is in vocab @@ -151,7 +279,7 @@ class DeepSpeech2Model(nn.Layer): """Compute Model loss Args: - audio (Tensors): [B, T, D] + audio (Tensor): [B, T, D] audio_len (Tensor): [B] text (Tensor): [B, U] text_len (Tensor): [B] @@ -159,22 +287,22 @@ class DeepSpeech2Model(nn.Layer): Returns: loss (Tensor): [1] """ - eouts, eouts_len = self.encoder(audio, audio_len) + eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder( + audio, audio_len, None, None) loss = self.decoder(eouts, eouts_len, text, text_len) return loss @paddle.no_grad() def decode(self, audio, audio_len): # decoders only accept string encoded in utf-8 - # Make sure the decoder has been initialized - eouts, eouts_len = self.encoder(audio, audio_len) + eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder( + audio, audio_len, None, None) probs = self.decoder.softmax(eouts) batch_size = probs.shape[0] self.decoder.reset_decoder(batch_size=batch_size) self.decoder.next(probs, eouts_len) trans_best, trans_beam = self.decoder.decode() - return trans_best @classmethod @@ -196,13 +324,15 @@ class DeepSpeech2Model(nn.Layer): The model built from pretrained result. 
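`DeepSpeech2Model.from_config` now forwards the new encoder options (`rnn_direction`, `num_fc_layers`, `fc_layers_size_list`) instead of `share_rnn_weights`. A minimal construction sketch with the values from the updated aishell `deepspeech2.yaml`; it assumes a PaddleSpeech build that contains this refactor, and the vocabulary size is a placeholder (in training it comes from the BatchDataLoader):

```python
from paddlespeech.s2t.models.ds2 import DeepSpeech2Model

# feat_size matches feat_dim: 161 (fbank) in the updated config; dict_size of 4300
# is invented -- the trainer reads it from the dataloader's vocab.
model = DeepSpeech2Model(
    feat_size=161,
    dict_size=4300,
    num_conv_layers=2,
    num_rnn_layers=5,
    rnn_size=1024,
    rnn_direction='bidirect',
    num_fc_layers=0,
    fc_layers_size_list=[-1],  # ignored when num_fc_layers == 0
    use_gru=False,
    blank_id=0)
print(model.encoder.output_size)  # 2048: bidirectional LSTM, 2 * rnn_size
```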
""" model = cls( - feat_size=dataloader.collate_fn.feature_size, - dict_size=dataloader.collate_fn.vocab_size, + feat_size=dataloader.feat_dim, + dict_size=dataloader.vocab_size, num_conv_layers=config.num_conv_layers, num_rnn_layers=config.num_rnn_layers, rnn_size=config.rnn_layer_size, + rnn_direction=config.rnn_direction, + num_fc_layers=config.num_fc_layers, + fc_layers_size_list=config.fc_layers_size_list, use_gru=config.use_gru, - share_rnn_weights=config.share_rnn_weights, blank_id=config.blank_id, ctc_grad_norm_type=config.get('ctc_grad_norm_type', None), ) infos = Checkpoint().load_parameters( @@ -229,8 +359,10 @@ class DeepSpeech2Model(nn.Layer): num_conv_layers=config.num_conv_layers, num_rnn_layers=config.num_rnn_layers, rnn_size=config.rnn_layer_size, + rnn_direction=config.rnn_direction, + num_fc_layers=config.num_fc_layers, + fc_layers_size_list=config.fc_layers_size_list, use_gru=config.use_gru, - share_rnn_weights=config.share_rnn_weights, blank_id=config.blank_id, ctc_grad_norm_type=config.get('ctc_grad_norm_type', None), ) return model @@ -240,28 +372,46 @@ class DeepSpeech2InferModel(DeepSpeech2Model): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def forward(self, audio, audio_len): - """export model function - - Args: - audio (Tensor): [B, T, D] - audio_len (Tensor): [B] - - Returns: - probs: probs after softmax - """ - eouts, eouts_len = self.encoder(audio, audio_len) - probs = self.decoder.softmax(eouts) - return probs, eouts_len + def forward(self, audio_chunk, audio_chunk_lens, chunk_state_h_box=None, + chunk_state_c_box=None): + if self.encoder.rnn_direction == "forward": + eouts_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box = self.encoder( + audio_chunk, audio_chunk_lens, chunk_state_h_box, chunk_state_c_box) + probs_chunk = self.decoder.softmax(eouts_chunk) + return probs_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box + elif self.encoder.rnn_direction == "bidirect": + eouts, eouts_len, _, _ = self.encoder(audio_chunk, audio_chunk_lens) + probs = self.decoder.softmax(eouts) + return probs, eouts_len + else: + raise Exception("wrong model type") def export(self): - static_model = paddle.jit.to_static( - self, - input_spec=[ - paddle.static.InputSpec( - shape=[None, None, self.encoder.feat_size], - dtype='float32'), # audio, [B,T,D] - paddle.static.InputSpec(shape=[None], - dtype='int64'), # audio_length, [B] - ]) + if self.encoder.rnn_direction == "forward": + static_model = paddle.jit.to_static( + self, + input_spec=[ + paddle.static.InputSpec( + shape=[None, None, + self.encoder.feat_size], #[B, chunk_size, feat_dim] + dtype='float32'), + paddle.static.InputSpec(shape=[None], + dtype='int64'), # audio_length, [B] + paddle.static.InputSpec( + shape=[None, None, None], dtype='float32'), + paddle.static.InputSpec( + shape=[None, None, None], dtype='float32') + ]) + elif self.encoder.rnn_direction == "bidirect": + static_model = paddle.jit.to_static( + self, + input_spec=[ + paddle.static.InputSpec( + shape=[None, None, self.encoder.feat_size], + dtype='float32'), # audio, [B,T,D] + paddle.static.InputSpec(shape=[None], + dtype='int64'), # audio_length, [B] + ]) + else: + raise Exception("wrong model type") return static_model diff --git a/paddlespeech/s2t/models/ds2/rnn.py b/paddlespeech/s2t/models/ds2/rnn.py deleted file mode 100644 index f655b2d82..000000000 --- a/paddlespeech/s2t/models/ds2/rnn.py +++ /dev/null @@ -1,315 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import math - -import paddle -from paddle import nn -from paddle.nn import functional as F -from paddle.nn import initializer as I - -from paddlespeech.s2t.modules.activation import brelu -from paddlespeech.s2t.modules.mask import make_non_pad_mask -from paddlespeech.s2t.utils.log import Log - -logger = Log(__name__).getlog() - -__all__ = ['RNNStack'] - - -class RNNCell(nn.RNNCellBase): - r""" - Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it - computes the outputs and updates states. - The formula used is as follows: - .. math:: - h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh}) - y_{t} & = h_{t} - - where :math:`act` is for :attr:`activation`. - """ - - def __init__(self, - hidden_size: int, - activation="tanh", - weight_ih_attr=None, - weight_hh_attr=None, - bias_ih_attr=None, - bias_hh_attr=None, - name=None): - super().__init__() - std = 1.0 / math.sqrt(hidden_size) - self.weight_hh = self.create_parameter( - (hidden_size, hidden_size), - weight_hh_attr, - default_initializer=I.Uniform(-std, std)) - self.bias_ih = None - self.bias_hh = self.create_parameter( - (hidden_size, ), - bias_hh_attr, - is_bias=True, - default_initializer=I.Uniform(-std, std)) - - self.hidden_size = hidden_size - if activation not in ["tanh", "relu", "brelu"]: - raise ValueError( - "activation for SimpleRNNCell should be tanh or relu, " - "but get {}".format(activation)) - self.activation = activation - self._activation_fn = paddle.tanh \ - if activation == "tanh" \ - else F.relu - if activation == 'brelu': - self._activation_fn = brelu - - def forward(self, inputs, states=None): - if states is None: - states = self.get_initial_states(inputs, self.state_shape) - pre_h = states - i2h = inputs - if self.bias_ih is not None: - i2h += self.bias_ih - h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True) - if self.bias_hh is not None: - h2h += self.bias_hh - h = self._activation_fn(i2h + h2h) - return h, h - - @property - def state_shape(self): - return (self.hidden_size, ) - - -class GRUCell(nn.RNNCellBase): - r""" - Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states, - it computes the outputs and updates states. - The formula for GRU used is as follows: - .. math:: - r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr}) - z_{t} & = \sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz}) - \widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc})) - h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t} - y_{t} & = h_{t} - - where :math:`\sigma` is the sigmoid fucntion, and * is the elemetwise - multiplication operator. 
- """ - - def __init__(self, - input_size: int, - hidden_size: int, - weight_ih_attr=None, - weight_hh_attr=None, - bias_ih_attr=None, - bias_hh_attr=None, - name=None): - super().__init__() - std = 1.0 / math.sqrt(hidden_size) - self.weight_hh = self.create_parameter( - (3 * hidden_size, hidden_size), - weight_hh_attr, - default_initializer=I.Uniform(-std, std)) - self.bias_ih = None - self.bias_hh = self.create_parameter( - (3 * hidden_size, ), - bias_hh_attr, - is_bias=True, - default_initializer=I.Uniform(-std, std)) - - self.hidden_size = hidden_size - self.input_size = input_size - self._gate_activation = F.sigmoid - self._activation = paddle.tanh - - def forward(self, inputs, states=None): - if states is None: - states = self.get_initial_states(inputs, self.state_shape) - - pre_hidden = states - x_gates = inputs - if self.bias_ih is not None: - x_gates = x_gates + self.bias_ih - h_gates = paddle.matmul(pre_hidden, self.weight_hh, transpose_y=True) - if self.bias_hh is not None: - h_gates = h_gates + self.bias_hh - - x_r, x_z, x_c = paddle.split(x_gates, num_or_sections=3, axis=1) - h_r, h_z, h_c = paddle.split(h_gates, num_or_sections=3, axis=1) - - r = self._gate_activation(x_r + h_r) - z = self._gate_activation(x_z + h_z) - c = self._activation(x_c + r * h_c) # apply reset gate after mm - h = (pre_hidden - c) * z + c - # https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/layers/dynamic_gru_cn.html#dynamic-gru - - return h, h - - @property - def state_shape(self): - r""" - The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch - size would be automatically inserted into shape). The shape corresponds - to the shape of :math:`h_{t-1}`. - """ - return (self.hidden_size, ) - - -class BiRNNWithBN(nn.Layer): - """Bidirectonal simple rnn layer with sequence-wise batch normalization. - The batch normalization is only performed on input-state weights. - - :param size: Dimension of RNN cells. - :type size: int - :param share_weights: Whether to share input-hidden weights between - forward and backward directional RNNs. - :type share_weights: bool - :return: Bidirectional simple rnn layer. - :rtype: Variable - """ - - def __init__(self, i_size: int, h_size: int, share_weights: bool): - super().__init__() - self.share_weights = share_weights - if self.share_weights: - #input-hidden weights shared between bi-directional rnn. 
- self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False) - # batch norm is only performed on input-state projection - self.fw_bn = nn.BatchNorm1D( - h_size, bias_attr=None, data_format='NLC') - self.bw_fc = self.fw_fc - self.bw_bn = self.fw_bn - else: - self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False) - self.fw_bn = nn.BatchNorm1D( - h_size, bias_attr=None, data_format='NLC') - self.bw_fc = nn.Linear(i_size, h_size, bias_attr=False) - self.bw_bn = nn.BatchNorm1D( - h_size, bias_attr=None, data_format='NLC') - - self.fw_cell = RNNCell(hidden_size=h_size, activation='brelu') - self.bw_cell = RNNCell(hidden_size=h_size, activation='brelu') - self.fw_rnn = nn.RNN( - self.fw_cell, is_reverse=False, time_major=False) #[B, T, D] - self.bw_rnn = nn.RNN( - self.fw_cell, is_reverse=True, time_major=False) #[B, T, D] - - def forward(self, x: paddle.Tensor, x_len: paddle.Tensor): - # x, shape [B, T, D] - fw_x = self.fw_bn(self.fw_fc(x)) - bw_x = self.bw_bn(self.bw_fc(x)) - fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len) - bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len) - x = paddle.concat([fw_x, bw_x], axis=-1) - return x, x_len - - -class BiGRUWithBN(nn.Layer): - """Bidirectonal gru layer with sequence-wise batch normalization. - The batch normalization is only performed on input-state weights. - - :param name: Name of the layer. - :type name: string - :param input: Input layer. - :type input: Variable - :param size: Dimension of GRU cells. - :type size: int - :param act: Activation type. - :type act: string - :return: Bidirectional GRU layer. - :rtype: Variable - """ - - def __init__(self, i_size: int, h_size: int): - super().__init__() - hidden_size = h_size * 3 - - self.fw_fc = nn.Linear(i_size, hidden_size, bias_attr=False) - self.fw_bn = nn.BatchNorm1D( - hidden_size, bias_attr=None, data_format='NLC') - self.bw_fc = nn.Linear(i_size, hidden_size, bias_attr=False) - self.bw_bn = nn.BatchNorm1D( - hidden_size, bias_attr=None, data_format='NLC') - - self.fw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size) - self.bw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size) - self.fw_rnn = nn.RNN( - self.fw_cell, is_reverse=False, time_major=False) #[B, T, D] - self.bw_rnn = nn.RNN( - self.fw_cell, is_reverse=True, time_major=False) #[B, T, D] - - def forward(self, x, x_len): - # x, shape [B, T, D] - fw_x = self.fw_bn(self.fw_fc(x)) - bw_x = self.bw_bn(self.bw_fc(x)) - fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len) - bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len) - x = paddle.concat([fw_x, bw_x], axis=-1) - return x, x_len - - -class RNNStack(nn.Layer): - """RNN group with stacked bidirectional simple RNN or GRU layers. - - :param input: Input layer. - :type input: Variable - :param size: Dimension of RNN cells in each layer. - :type size: int - :param num_stacks: Number of stacked rnn layers. - :type num_stacks: int - :param use_gru: Use gru if set True. Use simple rnn if set False. - :type use_gru: bool - :param share_rnn_weights: Whether to share input-hidden weights between - forward and backward directional RNNs. - It is only available when use_gru=False. - :type share_weights: bool - :return: Output layer of the RNN group. 
- :rtype: Variable - """ - - def __init__(self, - i_size: int, - h_size: int, - num_stacks: int, - use_gru: bool, - share_rnn_weights: bool): - super().__init__() - rnn_stacks = [] - for i in range(num_stacks): - if use_gru: - #default:GRU using tanh - rnn_stacks.append(BiGRUWithBN(i_size=i_size, h_size=h_size)) - else: - rnn_stacks.append( - BiRNNWithBN( - i_size=i_size, - h_size=h_size, - share_weights=share_rnn_weights)) - i_size = h_size * 2 - - self.rnn_stacks = nn.LayerList(rnn_stacks) - - def forward(self, x: paddle.Tensor, x_len: paddle.Tensor): - """ - x: shape [B, T, D] - x_len: shpae [B] - """ - for i, rnn in enumerate(self.rnn_stacks): - x, x_len = rnn(x, x_len) - masks = make_non_pad_mask(x_len) #[B, T] - masks = masks.unsqueeze(-1) # [B, T, 1] - # TODO(Hui Zhang): not support bool multiply - masks = masks.astype(x.dtype) - x = x.multiply(masks) - - return x, x_len diff --git a/paddlespeech/s2t/models/ds2_online/__init__.py b/paddlespeech/s2t/models/ds2_online/__init__.py deleted file mode 100644 index de772b645..000000000 --- a/paddlespeech/s2t/models/ds2_online/__init__.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from .deepspeech2 import DeepSpeech2InferModelOnline -from .deepspeech2 import DeepSpeech2ModelOnline -from paddlespeech.s2t.utils import dynamic_pip_install -import sys - -try: - import paddlespeech_ctcdecoders -except ImportError: - try: - package_name = 'paddlespeech_ctcdecoders' - if sys.platform != "win32": - dynamic_pip_install.install(package_name) - except Exception: - raise RuntimeError( - "Can not install package paddlespeech_ctcdecoders on your system. \ - The DeepSpeech2 model is not supported for your system") - -__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline'] diff --git a/paddlespeech/s2t/models/ds2_online/conv.py b/paddlespeech/s2t/models/ds2_online/conv.py deleted file mode 100644 index 25a9715a3..000000000 --- a/paddlespeech/s2t/models/ds2_online/conv.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import paddle - -from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling4 - - -class Conv2dSubsampling4Online(Conv2dSubsampling4): - def __init__(self, idim: int, odim: int, dropout_rate: float): - super().__init__(idim, odim, dropout_rate, None) - self.output_dim = ((idim - 1) // 2 - 1) // 2 * odim - self.receptive_field_length = 2 * ( - 3 - 1) + 3 # stride_1 * (kernel_size_2 - 1) + kerel_size_1 - - def forward(self, x: paddle.Tensor, - x_len: paddle.Tensor) -> [paddle.Tensor, paddle.Tensor]: - x = x.unsqueeze(1) # (b, c=1, t, f) - x = self.conv(x) - #b, c, t, f = paddle.shape(x) #not work under jit - x = x.transpose([0, 2, 1, 3]).reshape([0, 0, -1]) - x_len = ((x_len - 1) // 2 - 1) // 2 - return x, x_len diff --git a/paddlespeech/s2t/models/ds2_online/deepspeech2.py b/paddlespeech/s2t/models/ds2_online/deepspeech2.py deleted file mode 100644 index 9574a62bd..000000000 --- a/paddlespeech/s2t/models/ds2_online/deepspeech2.py +++ /dev/null @@ -1,397 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Deepspeech2 ASR Online Model""" -import paddle -import paddle.nn.functional as F -from paddle import nn - -from paddlespeech.s2t.models.ds2_online.conv import Conv2dSubsampling4Online -from paddlespeech.s2t.modules.ctc import CTCDecoder -from paddlespeech.s2t.utils import layer_tools -from paddlespeech.s2t.utils.checkpoint import Checkpoint -from paddlespeech.s2t.utils.log import Log -logger = Log(__name__).getlog() - -__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline'] - - -class CRNNEncoder(nn.Layer): - def __init__(self, - feat_size, - dict_size, - num_conv_layers=2, - num_rnn_layers=4, - rnn_size=1024, - rnn_direction='forward', - num_fc_layers=2, - fc_layers_size_list=[512, 256], - use_gru=False): - super().__init__() - self.rnn_size = rnn_size - self.feat_size = feat_size # 161 for linear - self.dict_size = dict_size - self.num_rnn_layers = num_rnn_layers - self.num_fc_layers = num_fc_layers - self.rnn_direction = rnn_direction - self.fc_layers_size_list = fc_layers_size_list - self.use_gru = use_gru - self.conv = Conv2dSubsampling4Online(feat_size, 32, dropout_rate=0.0) - - self.output_dim = self.conv.output_dim - - i_size = self.conv.output_dim - self.rnn = nn.LayerList() - self.layernorm_list = nn.LayerList() - self.fc_layers_list = nn.LayerList() - if rnn_direction == 'bidirect' or rnn_direction == 'bidirectional': - layernorm_size = 2 * rnn_size - elif rnn_direction == 'forward': - layernorm_size = rnn_size - else: - raise Exception("Wrong rnn direction") - for i in range(0, num_rnn_layers): - if i == 0: - rnn_input_size = i_size - else: - rnn_input_size = layernorm_size - if use_gru is True: - self.rnn.append( - nn.GRU( - input_size=rnn_input_size, - hidden_size=rnn_size, - num_layers=1, - direction=rnn_direction)) - else: - self.rnn.append( - nn.LSTM( - input_size=rnn_input_size, - hidden_size=rnn_size, - num_layers=1, - direction=rnn_direction)) - 
self.layernorm_list.append(nn.LayerNorm(layernorm_size)) - self.output_dim = layernorm_size - - fc_input_size = layernorm_size - for i in range(self.num_fc_layers): - self.fc_layers_list.append( - nn.Linear(fc_input_size, fc_layers_size_list[i])) - fc_input_size = fc_layers_size_list[i] - self.output_dim = fc_layers_size_list[i] - - @property - def output_size(self): - return self.output_dim - - def forward(self, x, x_lens, init_state_h_box=None, init_state_c_box=None): - """Compute Encoder outputs - - Args: - x (Tensor): [B, T, D] - x_lens (Tensor): [B] - init_state_h_box(Tensor): init_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] - init_state_c_box(Tensor): init_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] - Return: - x (Tensor): encoder outputs, [B, T, D] - x_lens (Tensor): encoder length, [B] - final_state_h_box(Tensor): final_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] - final_state_c_box(Tensor): final_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] - """ - if init_state_h_box is not None: - init_state_list = None - - if self.use_gru is True: - init_state_h_list = paddle.split( - init_state_h_box, self.num_rnn_layers, axis=0) - init_state_list = init_state_h_list - else: - init_state_h_list = paddle.split( - init_state_h_box, self.num_rnn_layers, axis=0) - init_state_c_list = paddle.split( - init_state_c_box, self.num_rnn_layers, axis=0) - init_state_list = [(init_state_h_list[i], init_state_c_list[i]) - for i in range(self.num_rnn_layers)] - else: - init_state_list = [None] * self.num_rnn_layers - - x, x_lens = self.conv(x, x_lens) - final_chunk_state_list = [] - for i in range(0, self.num_rnn_layers): - x, final_state = self.rnn[i](x, init_state_list[i], - x_lens) #[B, T, D] - final_chunk_state_list.append(final_state) - x = self.layernorm_list[i](x) - - for i in range(self.num_fc_layers): - x = self.fc_layers_list[i](x) - x = F.relu(x) - - if self.use_gru is True: - final_chunk_state_h_box = paddle.concat( - final_chunk_state_list, axis=0) - final_chunk_state_c_box = init_state_c_box - else: - final_chunk_state_h_list = [ - final_chunk_state_list[i][0] for i in range(self.num_rnn_layers) - ] - final_chunk_state_c_list = [ - final_chunk_state_list[i][1] for i in range(self.num_rnn_layers) - ] - final_chunk_state_h_box = paddle.concat( - final_chunk_state_h_list, axis=0) - final_chunk_state_c_box = paddle.concat( - final_chunk_state_c_list, axis=0) - - return x, x_lens, final_chunk_state_h_box, final_chunk_state_c_box - - def forward_chunk_by_chunk(self, x, x_lens, decoder_chunk_size=8): - """Compute Encoder outputs - - Args: - x (Tensor): [B, T, D] - x_lens (Tensor): [B] - decoder_chunk_size: The chunk size of decoder - Returns: - eouts_list (List of Tensor): The list of encoder outputs in chunk_size: [B, chunk_size, D] * num_chunks - eouts_lens_list (List of Tensor): The list of encoder length in chunk_size: [B] * num_chunks - final_state_h_box(Tensor): final_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] - final_state_c_box(Tensor): final_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] - """ - subsampling_rate = self.conv.subsampling_rate - receptive_field_length = self.conv.receptive_field_length - chunk_size = (decoder_chunk_size - 1 - ) * subsampling_rate + receptive_field_length - chunk_stride = subsampling_rate * decoder_chunk_size - max_len = 
x.shape[1] - assert (chunk_size <= max_len) - - eouts_chunk_list = [] - eouts_chunk_lens_list = [] - if (max_len - chunk_size) % chunk_stride != 0: - padding_len = chunk_stride - (max_len - chunk_size) % chunk_stride - else: - padding_len = 0 - padding = paddle.zeros((x.shape[0], padding_len, x.shape[2])) - padded_x = paddle.concat([x, padding], axis=1) - num_chunk = (max_len + padding_len - chunk_size) / chunk_stride + 1 - num_chunk = int(num_chunk) - chunk_state_h_box = None - chunk_state_c_box = None - final_state_h_box = None - final_state_c_box = None - for i in range(0, num_chunk): - start = i * chunk_stride - end = start + chunk_size - x_chunk = padded_x[:, start:end, :] - - x_len_left = paddle.where(x_lens - i * chunk_stride < 0, - paddle.zeros_like(x_lens), - x_lens - i * chunk_stride) - x_chunk_len_tmp = paddle.ones_like(x_lens) * chunk_size - x_chunk_lens = paddle.where(x_len_left < x_chunk_len_tmp, - x_len_left, x_chunk_len_tmp) - - eouts_chunk, eouts_chunk_lens, chunk_state_h_box, chunk_state_c_box = self.forward( - x_chunk, x_chunk_lens, chunk_state_h_box, chunk_state_c_box) - - eouts_chunk_list.append(eouts_chunk) - eouts_chunk_lens_list.append(eouts_chunk_lens) - final_state_h_box = chunk_state_h_box - final_state_c_box = chunk_state_c_box - return eouts_chunk_list, eouts_chunk_lens_list, final_state_h_box, final_state_c_box - - -class DeepSpeech2ModelOnline(nn.Layer): - """The DeepSpeech2 network structure for online. - - :param audio: Audio spectrogram data layer. - :type audio: Variable - :param text: Transcription text data layer. - :type text: Variable - :param audio_len: Valid sequence length data layer. - :type audio_len: Variable - :param feat_size: feature size for audio. - :type feat_size: int - :param dict_size: Dictionary size for tokenized transcription. - :type dict_size: int - :param num_conv_layers: Number of stacking convolution layers. - :type num_conv_layers: int - :param num_rnn_layers: Number of stacking RNN layers. - :type num_rnn_layers: int - :param rnn_size: RNN layer size (dimension of RNN cells). - :type rnn_size: int - :param num_fc_layers: Number of stacking FC layers. - :type num_fc_layers: int - :param fc_layers_size_list: The list of FC layer sizes. - :type fc_layers_size_list: [int,] - :param use_gru: Use gru if set True. Use simple rnn if set False. - :type use_gru: bool - :return: A tuple of an output unnormalized log probability layer ( - before softmax) and a ctc cost layer. 
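The chunk framing in `forward_chunk_by_chunk` follows directly from the convolutional front end: `Conv2dSubsampling4Online` subsamples time by 4 and has a receptive field of 2 * (3 - 1) + 3 = 7 input frames, so a decoder chunk of 8 output steps consumes 35 input frames and slides forward by 32. A small self-contained check of that arithmetic (the input length is an arbitrary illustrative value):

```python
# values implied by Conv2dSubsampling4Online: two stride-2, kernel-3 convolutions
subsampling_rate = 4
receptive_field_length = 2 * (3 - 1) + 3   # stride_1 * (kernel_size_2 - 1) + kernel_size_1
decoder_chunk_size = 8                     # encoder output steps per chunk

chunk_size = (decoder_chunk_size - 1) * subsampling_rate + receptive_field_length
chunk_stride = subsampling_rate * decoder_chunk_size

max_len = 200                              # illustrative number of input feature frames
rem = (max_len - chunk_size) % chunk_stride
padding_len = 0 if rem == 0 else chunk_stride - rem
num_chunks = (max_len + padding_len - chunk_size) // chunk_stride + 1

print(chunk_size, chunk_stride, padding_len, num_chunks)   # 35 32 27 7
```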
- :rtype: tuple of LayerOutput - """ - - def __init__( - self, - feat_size, - dict_size, - num_conv_layers=2, - num_rnn_layers=4, - rnn_size=1024, - rnn_direction='forward', - num_fc_layers=2, - fc_layers_size_list=[512, 256], - use_gru=False, - blank_id=0, - ctc_grad_norm_type=None, ): - super().__init__() - self.encoder = CRNNEncoder( - feat_size=feat_size, - dict_size=dict_size, - num_conv_layers=num_conv_layers, - num_rnn_layers=num_rnn_layers, - rnn_direction=rnn_direction, - num_fc_layers=num_fc_layers, - fc_layers_size_list=fc_layers_size_list, - rnn_size=rnn_size, - use_gru=use_gru) - - self.decoder = CTCDecoder( - odim=dict_size, # is in vocab - enc_n_units=self.encoder.output_size, - blank_id=blank_id, - dropout_rate=0.0, - reduction=True, # sum - batch_average=True, # sum / batch_size - grad_norm_type=ctc_grad_norm_type) - - def forward(self, audio, audio_len, text, text_len): - """Compute Model loss - - Args: - audio (Tensor): [B, T, D] - audio_len (Tensor): [B] - text (Tensor): [B, U] - text_len (Tensor): [B] - - Returns: - loss (Tensor): [1] - """ - eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder( - audio, audio_len, None, None) - loss = self.decoder(eouts, eouts_len, text, text_len) - return loss - - @paddle.no_grad() - def decode(self, audio, audio_len): - # decoders only accept string encoded in utf-8 - # Make sure the decoder has been initialized - eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder( - audio, audio_len, None, None) - probs = self.decoder.softmax(eouts) - batch_size = probs.shape[0] - self.decoder.reset_decoder(batch_size=batch_size) - self.decoder.next(probs, eouts_len) - trans_best, trans_beam = self.decoder.decode() - return trans_best - - @classmethod - def from_pretrained(cls, dataloader, config, checkpoint_path): - """Build a DeepSpeech2Model model from a pretrained model. - Parameters - ---------- - dataloader: paddle.io.DataLoader - - config: yacs.config.CfgNode - model configs - - checkpoint_path: Path or str - the path of pretrained model checkpoint, without extension name - - Returns - ------- - DeepSpeech2ModelOnline - The model built from pretrained result. - """ - model = cls( - feat_size=dataloader.collate_fn.feature_size, - dict_size=dataloader.collate_fn.vocab_size, - num_conv_layers=config.num_conv_layers, - num_rnn_layers=config.num_rnn_layers, - rnn_size=config.rnn_layer_size, - rnn_direction=config.rnn_direction, - num_fc_layers=config.num_fc_layers, - fc_layers_size_list=config.fc_layers_size_list, - use_gru=config.use_gru, - blank_id=config.blank_id, - ctc_grad_norm_type=config.get('ctc_grad_norm_type', None), ) - infos = Checkpoint().load_parameters( - model, checkpoint_path=checkpoint_path) - logger.info(f"checkpoint info: {infos}") - layer_tools.summary(model) - return model - - @classmethod - def from_config(cls, config): - """Build a DeepSpeec2ModelOnline from config - Parameters - - config: yacs.config.CfgNode - config - Returns - ------- - DeepSpeech2ModelOnline - The model built from config. 
- """ - model = cls( - feat_size=config.input_dim, - dict_size=config.output_dim, - num_conv_layers=config.num_conv_layers, - num_rnn_layers=config.num_rnn_layers, - rnn_size=config.rnn_layer_size, - rnn_direction=config.rnn_direction, - num_fc_layers=config.num_fc_layers, - fc_layers_size_list=config.fc_layers_size_list, - use_gru=config.use_gru, - blank_id=config.blank_id, - ctc_grad_norm_type=config.get('ctc_grad_norm_type', None), ) - return model - - -class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def forward(self, audio_chunk, audio_chunk_lens, chunk_state_h_box, - chunk_state_c_box): - eouts_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box = self.encoder( - audio_chunk, audio_chunk_lens, chunk_state_h_box, chunk_state_c_box) - probs_chunk = self.decoder.softmax(eouts_chunk) - return probs_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box - - def export(self): - static_model = paddle.jit.to_static( - self, - input_spec=[ - paddle.static.InputSpec( - shape=[None, None, - self.encoder.feat_size], #[B, chunk_size, feat_dim] - dtype='float32'), - paddle.static.InputSpec(shape=[None], - dtype='int64'), # audio_length, [B] - paddle.static.InputSpec( - shape=[None, None, None], dtype='float32'), - paddle.static.InputSpec( - shape=[None, None, None], dtype='float32') - ]) - return static_model diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index 14715bf35..a52219730 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -25,7 +25,6 @@ from paddlespeech.cli.log import logger from paddlespeech.cli.utils import MODEL_HOME from paddlespeech.resource import CommonTaskResource from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer -from paddlespeech.s2t.frontend.speech import SpeechSegment from paddlespeech.s2t.modules.ctc import CTCDecoder from paddlespeech.s2t.transform.transformation import Transformation from paddlespeech.s2t.utils.tensor_utils import add_sos_eos @@ -66,10 +65,13 @@ class PaddleASRConnectionHanddler: self.text_feature = self.asr_engine.executor.text_feature if "deepspeech2" in self.model_type: - from paddlespeech.s2t.io.collator import SpeechCollator self.am_predictor = self.asr_engine.executor.am_predictor - self.collate_fn_test = SpeechCollator.from_config(self.model_config) + # extract feat, new only fbank in conformer model + self.preprocess_conf = self.model_config.preprocess_config + self.preprocess_args = {"train": False} + self.preprocessing = Transformation(self.preprocess_conf) + self.decoder = CTCDecoder( odim=self.model_config.output_dim, # is in vocab enc_n_units=self.model_config.rnn_layer_size * 2, @@ -89,10 +91,8 @@ class PaddleASRConnectionHanddler: cfg.num_proc_bsearch) # frame window and frame shift, in samples unit - self.win_length = int(self.model_config.window_ms / 1000 * - self.sample_rate) - self.n_shift = int(self.model_config.stride_ms / 1000 * - self.sample_rate) + self.win_length = self.preprocess_conf.process[0]['win_length'] + self.n_shift = self.preprocess_conf.process[0]['n_shift'] elif "conformer" in self.model_type or "transformer" in self.model_type: # acoustic model @@ -114,20 +114,15 @@ class PaddleASRConnectionHanddler: raise ValueError(f"Not supported: {self.model_type}") def extract_feat(self, samples): - # we compute the elapsed time of first char occuring 
+ # we compute the elapsed time of first char occuring # and we record the start time at the first pcm sample arraving if "deepspeech2online" in self.model_type: - # self.reamined_wav stores all the samples, + # self.reamined_wav stores all the samples, # include the original remained_wav and this package samples samples = np.frombuffer(samples, dtype=np.int16) assert samples.ndim == 1 - # pcm16 -> pcm 32 - # pcm2float will change the orignal samples, - # so we shoule do pcm2float before concatenate - samples = pcm2float(samples) - if self.remained_wav is None: self.remained_wav = samples else: @@ -137,26 +132,11 @@ class PaddleASRConnectionHanddler: f"The connection remain the audio samples: {self.remained_wav.shape}" ) - # read audio - speech_segment = SpeechSegment.from_pcm( - self.remained_wav, self.sample_rate, transcript=" ") - # audio augment - self.collate_fn_test.augmentation.transform_audio(speech_segment) - - # extract speech feature - spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize( - speech_segment, self.collate_fn_test.keep_transcription_text) - # CMVN spectrum - if self.collate_fn_test._normalizer: - spectrum = self.collate_fn_test._normalizer.apply(spectrum) - - # spectrum augment - feat = self.collate_fn_test.augmentation.transform_feature(spectrum) - - # audio_len is frame num - frame_num = feat.shape[0] - feat = paddle.to_tensor(feat, dtype='float32') - feat = paddle.unsqueeze(feat, axis=0) + # fbank + feat = self.preprocessing(self.remained_wav, + **self.preprocess_args) + feat = paddle.to_tensor( + feat, dtype="float32").unsqueeze(axis=0) if self.cached_feat is None: self.cached_feat = feat @@ -170,8 +150,11 @@ class PaddleASRConnectionHanddler: if self.device is None: self.device = self.cached_feat.place - self.num_frames += frame_num - self.remained_wav = self.remained_wav[self.n_shift * frame_num:] + # cur frame step + num_frames = feat.shape[1] + + self.num_frames += num_frames + self.remained_wav = self.remained_wav[self.n_shift * num_frames:] logger.info( f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}" @@ -190,7 +173,7 @@ class PaddleASRConnectionHanddler: f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}" ) - # self.reamined_wav stores all the samples, + # self.reamined_wav stores all the samples, # include the original remained_wav and this package samples if self.remained_wav is None: self.remained_wav = samples @@ -246,7 +229,7 @@ class PaddleASRConnectionHanddler: def reset(self): if "deepspeech2" in self.model_type: - # for deepspeech2 + # for deepspeech2 # init state self.chunk_state_h_box = np.zeros( (self.model_config.num_rnn_layers, 1, @@ -275,7 +258,7 @@ class PaddleASRConnectionHanddler: ## conformer - # cache for conformer online + # cache for conformer online self.subsampling_cache = None self.elayers_output_cache = None self.conformer_cnn_cache = None @@ -359,7 +342,7 @@ class PaddleASRConnectionHanddler: # update feat cache self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :] - # return trans_best[0] + # return trans_best[0] elif "conformer" in self.model_type or "transformer" in self.model_type: try: logger.info( @@ -565,7 +548,7 @@ class PaddleASRConnectionHanddler: @paddle.no_grad() def rescoring(self): - """Second-Pass Decoding, + """Second-Pass Decoding, only for conformer and transformer model. 
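With `SpeechCollator` gone, the streaming handler builds its front end from the model's `preprocess_config` and feeds raw int16 samples straight into it. The snippet below is a stripped-down sketch of that path; it assumes a `conf/preprocess.yaml` defining the fbank extractor is available at that relative path, while the individual calls mirror the ones added in this patch.

```python
import numpy as np
import paddle

from paddlespeech.s2t.transform.transformation import Transformation

# the same preprocess file referenced by `preprocess_config` in the model yaml
# (path is an assumption for this sketch)
preprocessing = Transformation("conf/preprocess.yaml")
preprocess_args = {"train": False}

# 1 s of silence as a stand-in for the pcm16 samples received by the handler
samples = np.zeros(16000, dtype=np.int16)

feat = preprocessing(samples, **preprocess_args)                  # [T, feat_dim] ndarray
feat = paddle.to_tensor(feat, dtype="float32").unsqueeze(axis=0)  # [1, T, feat_dim]
print(feat.shape)
```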
""" if "deepspeech2" in self.model_type: @@ -652,11 +635,11 @@ class PaddleASRConnectionHanddler: ## asr results # hyps[0][0]: the sentence word-id in the vocab with a tuple # hyps[0][1]: the sentence decoding probability with all paths - ## timestamp + ## timestamp # hyps[0][2]: viterbi_blank ending probability # hyps[0][3]: viterbi_non_blank dending probability # hyps[0][4]: current_token_prob, - # hyps[0][5]: times_viterbi_blank ending timestamp, + # hyps[0][5]: times_viterbi_blank ending timestamp, # hyps[0][6]: times_titerbi_non_blank encding timestamp. self.hyps = [hyps[best_index][0]] logger.info(f"best hyp ids: {self.hyps}") @@ -752,16 +735,19 @@ class ASRServerExecutor(ASRExecutor): self.config = CfgNode(new_allowed=True) self.config.merge_from_file(self.cfg_path) + if self.config.spm_model_prefix: + self.config.spm_model_prefix = os.path.join( + self.res_path, self.config.spm_model_prefix) + self.text_feature = TextFeaturizer( + unit_type=self.config.unit_type, + vocab=self.config.vocab_filepath, + spm_model_prefix=self.config.spm_model_prefix) + self.vocab = self.config.vocab_filepath with UpdateConfig(self.config): if "deepspeech2" in model_type: - from paddlespeech.s2t.io.collator import SpeechCollator - self.vocab = self.config.vocab_filepath self.config.decode.lang_model_path = os.path.join( MODEL_HOME, 'language_model', self.config.decode.lang_model_path) - self.collate_fn_test = SpeechCollator.from_config(self.config) - self.text_feature = TextFeaturizer( - unit_type=self.config.unit_type, vocab=self.vocab) lm_url = self.task_resource.res_dict['lm_url'] lm_md5 = self.task_resource.res_dict['lm_md5'] @@ -772,14 +758,6 @@ class ASRServerExecutor(ASRExecutor): elif "conformer" in model_type or "transformer" in model_type: logger.info("start to create the stream conformer asr engine") - if self.config.spm_model_prefix: - self.config.spm_model_prefix = os.path.join( - self.res_path, self.config.spm_model_prefix) - self.vocab = self.config.vocab_filepath - self.text_feature = TextFeaturizer( - unit_type=self.config.unit_type, - vocab=self.config.vocab_filepath, - spm_model_prefix=self.config.spm_model_prefix) # update the decoding method if decode_method: self.config.decode.decoding_method = decode_method diff --git a/paddlespeech/server/engine/asr/paddleinference/asr_engine.py b/paddlespeech/server/engine/asr/paddleinference/asr_engine.py index b030293f7..c83f52936 100644 --- a/paddlespeech/server/engine/asr/paddleinference/asr_engine.py +++ b/paddlespeech/server/engine/asr/paddleinference/asr_engine.py @@ -54,6 +54,7 @@ class ASRServerExecutor(ASRExecutor): sample_rate_str = '16k' if sample_rate == 16000 else '8k' tag = model_type + '-' + lang + '-' + sample_rate_str + self.max_len = 50 self.task_resource.set_task_model(model_tag=tag) if cfg_path is None or am_model is None or am_params is None: self.res_path = self.task_resource.res_dir @@ -80,22 +81,24 @@ class ASRServerExecutor(ASRExecutor): self.config.merge_from_file(self.cfg_path) with UpdateConfig(self.config): - if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: - from paddlespeech.s2t.io.collator import SpeechCollator + if "deepspeech2" in model_type: self.vocab = self.config.vocab_filepath + if self.config.spm_model_prefix: + self.config.spm_model_prefix = os.path.join( + self.res_path, self.config.spm_model_prefix) + self.text_feature = TextFeaturizer( + unit_type=self.config.unit_type, vocab=self.vocab, + spm_model_prefix=self.config.spm_model_prefix) self.config.decode.lang_model_path = 
os.path.join( MODEL_HOME, 'language_model', self.config.decode.lang_model_path) - self.collate_fn_test = SpeechCollator.from_config(self.config) - self.text_feature = TextFeaturizer( - unit_type=self.config.unit_type, vocab=self.vocab) lm_url = self.task_resource.res_dict['lm_url'] lm_md5 = self.task_resource.res_dict['lm_md5'] self.download_lm( lm_url, os.path.dirname(self.config.decode.lang_model_path), lm_md5) - elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: + elif "conformer" in model_type or "transformer" in model_type: raise Exception("wrong type") else: raise Exception("wrong type") @@ -125,7 +128,7 @@ class ASRServerExecutor(ASRExecutor): cfg = self.config.decode audio = self._inputs["audio"] audio_len = self._inputs["audio_len"] - if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: + if "deepspeech2" in model_type: decode_batch_size = audio.shape[0] # init once self.decoder.init_decoder( @@ -192,7 +195,7 @@ class ASREngine(BaseEngine): return True def run(self, audio_data): - """engine run + """engine run Args: audio_data (bytes): base64.b64decode