diff --git a/deepspeech/models/u2_st.py b/deepspeech/models/u2_st.py
index a3d99942..8f87f6da 100644
--- a/deepspeech/models/u2_st.py
+++ b/deepspeech/models/u2_st.py
@@ -113,7 +113,8 @@ class U2STBaseModel(nn.Layer):
                  asr_weight: float=0.0,
                  ignore_id: int=IGNORE_ID,
                  lsm_weight: float=0.0,
-                 length_normalized_loss: bool=False):
+                 length_normalized_loss: bool=False,
+                 **kwargs):
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
 
         super().__init__()
@@ -650,7 +651,7 @@ class U2STModel(U2STBaseModel):
                 odim=vocab_size,
                 enc_n_units=encoder.output_size(),
                 blank_id=0,
-                dropout_rate=model_conf['ctc_dropout_rate'],
+                dropout_rate=model_conf['ctc_dropoutrate'],
                 reduction=True,  # sum
                 batch_average=True,  # sum / batch_size
                 grad_norm_type=model_conf['ctc_grad_norm_type'])
diff --git a/examples/ted_en_zh/t0/README.md b/examples/ted_en_zh/t0/README.md
index e2443d36..9bca2643 100644
--- a/examples/ted_en_zh/t0/README.md
+++ b/examples/ted_en_zh/t0/README.md
@@ -6,5 +6,10 @@
 | Data Subset | Duration in Seconds |
 | --- | --- |
 | data/manifest.train | 0.942 ~ 60 |
-| data/manifest.dev | 1.151 ~ 39 |
+| data/manifest.dev | 1.151 ~ 39 |
 | data/manifest.test | 1.1 ~ 42.746 |
+
+## Transformer
+| Model | Params | Config | Char-BLEU |
+| --- | --- | --- | --- |
+| Transformer+ASR MTL | 50.26M | conf/transformer_joint_noam.yaml | 17.38 |
\ No newline at end of file
diff --git a/examples/timit/s1/README.md b/examples/timit/s1/README.md
index 4d9b146a..6d719a7d 100644
--- a/examples/timit/s1/README.md
+++ b/examples/timit/s1/README.md
@@ -1,3 +1,11 @@
 # TIMIT
 
-Results will be organized and updated soon.
+
+
+
+### Transformer
+| Model | Params | Config | Decode method | PER |
+| --- | --- | --- | --- | --- |
+| transformer | 5.17M | conf/transformer.yaml | attention | 0.5531 |
+| transformer | 5.17M | conf/transformer.yaml | ctc_greedy_search | 0.3922 |
+| transformer | 5.17M | conf/transformer.yaml | ctc_prefix_beam_search | 0.3768 |
\ No newline at end of file
diff --git a/examples/timit/s1/conf/transformer.yaml b/examples/timit/s1/conf/transformer.yaml
index 1ae9acd0..a55dcc43 100644
--- a/examples/timit/s1/conf/transformer.yaml
+++ b/examples/timit/s1/conf/transformer.yaml
@@ -3,12 +3,12 @@ data:
   train_manifest: data/manifest.train
   dev_manifest: data/manifest.dev
   test_manifest: data/manifest.test
-  min_input_len: 0.5 # second
-  max_input_len: 30.0 # second
+  min_input_len: 0.0 # second
+  max_input_len: 10.0 # second
   min_output_len: 0.0 # tokens
-  max_output_len: 400.0 # tokens
-  min_output_input_ratio: 0.05
-  max_output_input_ratio: 100.0
+  max_output_len: 150.0 # tokens
+  min_output_input_ratio: 0.005
+  max_output_input_ratio: 1000.0
 
 collator:
   vocab_filepath: data/vocab.txt
@@ -42,10 +42,10 @@ model:
   # encoder related
   encoder: transformer
   encoder_conf:
-    output_size: 256 # dimension of attention
+    output_size: 128 # dimension of attention
     attention_heads: 4
-    linear_units: 2048 # the number of units of position-wise feed forward
-    num_blocks: 12 # the number of encoder blocks
+    linear_units: 1024 # the number of units of position-wise feed forward
+    num_blocks: 6 # the number of encoder blocks
     dropout_rate: 0.1
     positional_dropout_rate: 0.1
     attention_dropout_rate: 0.0
@@ -56,7 +56,7 @@ model:
   decoder: transformer
   decoder_conf:
     attention_heads: 4
-    linear_units: 2048
+    linear_units: 1024
     num_blocks: 6
     dropout_rate: 0.1
     positional_dropout_rate: 0.1
@@ -65,26 +65,26 @@ model:
   # hybrid CTC/attention
   model_conf:
-    ctc_weight: 0.3
+    ctc_weight: 0.5
     ctc_dropoutrate: 0.0
-    ctc_grad_norm_type: instance
+    ctc_grad_norm_type: batch
     lsm_weight: 0.1 # label smoothing option
     length_normalized_loss: false
 
 
 training:
-  n_epoch: 120
+  n_epoch: 200
   accum_grad: 2
   global_grad_clip: 5.0
   optim: adam
   optim_conf:
-    lr: 0.002
+    lr: 0.004
     weight_decay: 1e-06
   scheduler: warmuplr # pytorch v1.1.0+ required
   scheduler_conf:
-    warmup_steps: 400
+    warmup_steps: 2000
     lr_decay: 1.0
-  log_interval: 100
+  log_interval: 10
   checkpoint:
     kbest_n: 50
     latest_n: 5