Merge pull request #867 from LittleChenCc/develop

update the results of TIMIT and Ted-ST
pull/870/head
Hui Zhang 3 years ago committed by GitHub
commit 84f77ecdf5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -113,7 +113,8 @@ class U2STBaseModel(nn.Layer):
asr_weight: float=0.0, asr_weight: float=0.0,
ignore_id: int=IGNORE_ID, ignore_id: int=IGNORE_ID,
lsm_weight: float=0.0, lsm_weight: float=0.0,
length_normalized_loss: bool=False): length_normalized_loss: bool=False,
**kwargs):
assert 0.0 <= ctc_weight <= 1.0, ctc_weight assert 0.0 <= ctc_weight <= 1.0, ctc_weight
super().__init__() super().__init__()
@ -650,7 +651,7 @@ class U2STModel(U2STBaseModel):
odim=vocab_size, odim=vocab_size,
enc_n_units=encoder.output_size(), enc_n_units=encoder.output_size(),
blank_id=0, blank_id=0,
dropout_rate=model_conf['ctc_dropout_rate'], dropout_rate=model_conf['ctc_dropoutrate'],
reduction=True, # sum reduction=True, # sum
batch_average=True, # sum / batch_size batch_average=True, # sum / batch_size
grad_norm_type=model_conf['ctc_grad_norm_type']) grad_norm_type=model_conf['ctc_grad_norm_type'])

@ -8,3 +8,8 @@
| data/manifest.train | 0.942 ~ 60 | | data/manifest.train | 0.942 ~ 60 |
| data/manifest.dev | 1.151 ~ 39 | | data/manifest.dev | 1.151 ~ 39 |
| data/manifest.test | 1.1 ~ 42.746 | | data/manifest.test | 1.1 ~ 42.746 |
## Transformer
| Model | Params | Config | Char-BLEU |
| --- | --- | --- | --- |
| Transformer+ASR MTL | 50.26M | conf/transformer_joint_noam.yaml | 17.38 |

@ -1,3 +1,11 @@
# TIMIT # TIMIT
Results will be organized and updated soon.
### Transformer
| Model | Params | Config | Decode method | PER |
| --- | --- | --- | --- | --- |
| transformer | 5.17M | conf/transformer.yaml | attention | 0.5531 |
| transformer | 5.17M | conf/transformer.yaml | ctc_greedy_search | 0.3922 |
| transformer | 5.17M | conf/transformer.yaml | ctc_prefix_beam_search | 0.3768 |

@ -3,12 +3,12 @@ data:
train_manifest: data/manifest.train train_manifest: data/manifest.train
dev_manifest: data/manifest.dev dev_manifest: data/manifest.dev
test_manifest: data/manifest.test test_manifest: data/manifest.test
min_input_len: 0.5 # second min_input_len: 0.0 # second
max_input_len: 30.0 # second max_input_len: 10.0 # second
min_output_len: 0.0 # tokens min_output_len: 0.0 # tokens
max_output_len: 400.0 # tokens max_output_len: 150.0 # tokens
min_output_input_ratio: 0.05 min_output_input_ratio: 0.005
max_output_input_ratio: 100.0 max_output_input_ratio: 1000.0
collator: collator:
vocab_filepath: data/vocab.txt vocab_filepath: data/vocab.txt
@ -42,10 +42,10 @@ model:
# encoder related # encoder related
encoder: transformer encoder: transformer
encoder_conf: encoder_conf:
output_size: 256 # dimension of attention output_size: 128 # dimension of attention
attention_heads: 4 attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward linear_units: 1024 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks num_blocks: 6 # the number of encoder blocks
dropout_rate: 0.1 dropout_rate: 0.1
positional_dropout_rate: 0.1 positional_dropout_rate: 0.1
attention_dropout_rate: 0.0 attention_dropout_rate: 0.0
@ -56,7 +56,7 @@ model:
decoder: transformer decoder: transformer
decoder_conf: decoder_conf:
attention_heads: 4 attention_heads: 4
linear_units: 2048 linear_units: 1024
num_blocks: 6 num_blocks: 6
dropout_rate: 0.1 dropout_rate: 0.1
positional_dropout_rate: 0.1 positional_dropout_rate: 0.1
@ -65,26 +65,26 @@ model:
# hybrid CTC/attention # hybrid CTC/attention
model_conf: model_conf:
ctc_weight: 0.3 ctc_weight: 0.5
ctc_dropoutrate: 0.0 ctc_dropoutrate: 0.0
ctc_grad_norm_type: instance ctc_grad_norm_type: batch
lsm_weight: 0.1 # label smoothing option lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false length_normalized_loss: false
training: training:
n_epoch: 120 n_epoch: 200
accum_grad: 2 accum_grad: 2
global_grad_clip: 5.0 global_grad_clip: 5.0
optim: adam optim: adam
optim_conf: optim_conf:
lr: 0.002 lr: 0.004
weight_decay: 1e-06 weight_decay: 1e-06
scheduler: warmuplr # pytorch v1.1.0+ required scheduler: warmuplr # pytorch v1.1.0+ required
scheduler_conf: scheduler_conf:
warmup_steps: 400 warmup_steps: 2000
lr_decay: 1.0 lr_decay: 1.0
log_interval: 100 log_interval: 10
checkpoint: checkpoint:
kbest_n: 50 kbest_n: 50
latest_n: 5 latest_n: 5

Loading…
Cancel
Save