diff --git a/deepspeech/models/ds2_online/deepspeech2.py b/deepspeech/models/ds2_online/deepspeech2.py index 3c82f325..75a6f044 100644 --- a/deepspeech/models/ds2_online/deepspeech2.py +++ b/deepspeech/models/ds2_online/deepspeech2.py @@ -56,12 +56,17 @@ class CRNNEncoder(nn.Layer): self.rnn = nn.LayerList() self.layernorm_list = nn.LayerList() self.fc_layers_list = nn.LayerList() - layernorm_size = rnn_size + if rnn_direction == 'bidirect' or rnn_direction == 'bidirectional': + layernorm_size = 2 * rnn_size + elif rnn_direction == 'forward': + layernorm_size = rnn_size + else: + raise Exception("Wrong rnn direction") for i in range(0, num_rnn_layers): if i == 0: rnn_input_size = i_size else: - rnn_input_size = rnn_size + rnn_input_size = layernorm_size if use_gru == True: self.rnn.append( nn.GRU( @@ -78,7 +83,7 @@ class CRNNEncoder(nn.Layer): direction=rnn_direction)) self.layernorm_list.append(nn.LayerNorm(layernorm_size)) - fc_input_size = rnn_size + fc_input_size = layernorm_size for i in range(self.num_fc_layers): self.fc_layers_list.append( nn.Linear(fc_input_size, fc_layers_size_list[i])) @@ -385,8 +390,8 @@ class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline): self, input_spec=[ paddle.static.InputSpec( - shape=[None, None, self.encoder.feat_size - ], #[B, chunk_size, feat_dim] + shape=[None, None, + self.encoder.feat_size], #[B, chunk_size, feat_dim] dtype='float32'), # audio, [B,T,D] paddle.static.InputSpec(shape=[None], dtype='int64'), # audio_length, [B] diff --git a/examples/aishell/s0/conf/deepspeech2_online.yaml b/examples/aishell/s0/conf/deepspeech2_online.yaml new file mode 100644 index 00000000..60df8d17 --- /dev/null +++ b/examples/aishell/s0/conf/deepspeech2_online.yaml @@ -0,0 +1,67 @@ +# https://yaml.org/type/float.html +data: + train_manifest: data/manifest.train + dev_manifest: data/manifest.dev + test_manifest: data/manifest.test + min_input_len: 0.0 + max_input_len: 27.0 # second + min_output_len: 0.0 + max_output_len: .inf + min_output_input_ratio: 0.00 + max_output_input_ratio: .inf + +collator: + batch_size: 32 # one gpu + mean_std_filepath: data/mean_std.json + unit_type: char + vocab_filepath: data/vocab.txt + augmentation_config: conf/augmentation.json + random_seed: 0 + spm_model_prefix: + specgram_type: linear #linear, mfcc, fbank + feat_dim: + delta_delta: False + stride_ms: 10.0 + window_ms: 20.0 + n_fft: None + max_freq: None + target_sample_rate: 16000 + use_dB_normalization: True + target_dB: -20 + dither: 1.0 + keep_transcription_text: False + sortagrad: True + shuffle_method: batch_shuffle + num_workers: 0 + +model: + num_conv_layers: 2 + num_rnn_layers: 4 + rnn_layer_size: 1024 + rnn_direction: bidirect + num_fc_layers: 2 + fc_layers_size_list: 512, 256 + use_gru: True + +training: + n_epoch: 50 + lr: 2e-3 + lr_decay: 0.83 + weight_decay: 1e-06 + global_grad_clip: 3.0 + log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 + +decoding: + batch_size: 64 + error_rate_type: cer + decoding_method: ctc_beam_search + lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm + alpha: 1.9 + beta: 5.0 + beam_size: 300 + cutoff_prob: 0.99 + cutoff_top_n: 40 + num_proc_bsearch: 10 diff --git a/examples/librispeech/s0/conf/deepspeech2_online.yaml b/examples/librispeech/s0/conf/deepspeech2_online.yaml new file mode 100644 index 00000000..2e4aed40 --- /dev/null +++ b/examples/librispeech/s0/conf/deepspeech2_online.yaml @@ -0,0 +1,67 @@ +# https://yaml.org/type/float.html +data: + train_manifest: data/manifest.train + dev_manifest: data/manifest.dev-clean + test_manifest: data/manifest.test-clean + min_input_len: 0.0 + max_input_len: 27.0 # second + min_output_len: 0.0 + max_output_len: .inf + min_output_input_ratio: 0.00 + max_output_input_ratio: .inf + +collator: + batch_size: 20 + mean_std_filepath: data/mean_std.json + unit_type: char + vocab_filepath: data/vocab.txt + augmentation_config: conf/augmentation.json + random_seed: 0 + spm_model_prefix: + specgram_type: linear + target_sample_rate: 16000 + max_freq: None + n_fft: None + stride_ms: 10.0 + window_ms: 20.0 + delta_delta: False + dither: 1.0 + use_dB_normalization: True + target_dB: -20 + random_seed: 0 + keep_transcription_text: False + sortagrad: True + shuffle_method: batch_shuffle + num_workers: 0 + +model: + num_conv_layers: 2 + num_rnn_layers: 3 + rnn_layer_size: 2048 + rnn_direction: forward + num_fc_layers: 2 + fc_layers_size_list: 512, 256 + use_gru: False + +training: + n_epoch: 50 + lr: 1e-3 + lr_decay: 0.83 + weight_decay: 1e-06 + global_grad_clip: 5.0 + log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 + +decoding: + batch_size: 128 + error_rate_type: wer + decoding_method: ctc_beam_search + lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm + alpha: 1.9 + beta: 0.3 + beam_size: 500 + cutoff_prob: 1.0 + cutoff_top_n: 40 + num_proc_bsearch: 8 diff --git a/examples/tiny/s0/conf/deepspeech2_online.yaml b/examples/tiny/s0/conf/deepspeech2_online.yaml new file mode 100644 index 00000000..333c2b9a --- /dev/null +++ b/examples/tiny/s0/conf/deepspeech2_online.yaml @@ -0,0 +1,69 @@ +# https://yaml.org/type/float.html +data: + train_manifest: data/manifest.tiny + dev_manifest: data/manifest.tiny + test_manifest: data/manifest.tiny + min_input_len: 0.0 + max_input_len: 27.0 + min_output_len: 0.0 + max_output_len: 400.0 + min_output_input_ratio: 0.05 + max_output_input_ratio: 10.0 + + +collator: + mean_std_filepath: data/mean_std.json + unit_type: char + vocab_filepath: data/vocab.txt + augmentation_config: conf/augmentation.json + random_seed: 0 + spm_model_prefix: + specgram_type: linear + feat_dim: + delta_delta: False + stride_ms: 10.0 + window_ms: 20.0 + n_fft: None + max_freq: None + target_sample_rate: 16000 + use_dB_normalization: True + target_dB: -20 + dither: 1.0 + keep_transcription_text: False + sortagrad: True + shuffle_method: batch_shuffle + num_workers: 0 + batch_size: 4 + +model: + num_conv_layers: 2 + num_rnn_layers: 4 + rnn_layer_size: 2048 + rnn_direction: forward + num_fc_layers: 2 + fc_layers_size_list: 512, 256 + use_gru: True + +training: + n_epoch: 10 + lr: 1e-5 + lr_decay: 1.0 + weight_decay: 1e-06 + global_grad_clip: 5.0 + log_interval: 1 + checkpoint: + kbest_n: 3 + latest_n: 2 + + +decoding: + batch_size: 128 + error_rate_type: wer + decoding_method: ctc_beam_search + lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm + alpha: 2.5 + beta: 0.3 + beam_size: 500 + cutoff_prob: 1.0 + cutoff_top_n: 40 + num_proc_bsearch: 8