diff --git a/examples/aishell/s1/conf/transformer.yaml b/examples/aishell/s1/conf/transformer.yaml
new file mode 100644
index 00000000..7803097a
--- /dev/null
+++ b/examples/aishell/s1/conf/transformer.yaml
@@ -0,0 +1,112 @@
+# https://yaml.org/type/float.html
+data:
+  train_manifest: data/manifest.train
+  dev_manifest: data/manifest.dev
+  test_manifest: data/manifest.test
+  min_input_len: 0.5
+  max_input_len: 20.0  # second
+  min_output_len: 0.0
+  max_output_len: 400.0
+  min_output_input_ratio: 0.05
+  max_output_input_ratio: 10.0
+
+
+collator:
+  vocab_filepath: data/vocab.txt
+  unit_type: 'char'
+  spm_model_prefix: ''
+  augmentation_config: conf/preprocess.yaml
+  batch_size: 64
+  raw_wav: True  # use raw_wav or kaldi feature
+  spectrum_type: fbank  # linear, mfcc, fbank
+  feat_dim: 80
+  delta_delta: False
+  dither: 1.0
+  target_sample_rate: 16000
+  max_freq: None
+  n_fft: None
+  stride_ms: 10.0
+  window_ms: 25.0
+  use_dB_normalization: True
+  target_dB: -20
+  random_seed: 0
+  keep_transcription_text: False
+  sortagrad: True
+  shuffle_method: batch_shuffle
+  num_workers: 2
+
+# network architecture
+model:
+  cmvn_file:
+  cmvn_file_type: "json"
+  # encoder related
+  encoder: transformer
+  encoder_conf:
+    output_size: 256    # dimension of attention
+    attention_heads: 4
+    linear_units: 2048  # the number of units of position-wise feed forward
+    num_blocks: 12      # the number of encoder blocks
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.0
+    input_layer: conv2d  # encoder input type, you can choose conv2d, conv2d6 or conv2d8
+    normalize_before: true
+
+  # decoder related
+  decoder: transformer
+  decoder_conf:
+    attention_heads: 4
+    linear_units: 2048
+    num_blocks: 6
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    self_attention_dropout_rate: 0.0
+    src_attention_dropout_rate: 0.0
+
+  # hybrid CTC/attention
+  model_conf:
+    ctc_weight: 0.3
+    ctc_dropoutrate: 0.0
+    ctc_grad_norm_type: null
+    lsm_weight: 0.1  # label smoothing option
+    length_normalized_loss: false
+
+
+training:
+  n_epoch: 240
+  accum_grad: 2
+  global_grad_clip: 5.0
+  optim: adam
+  optim_conf:
+    lr: 0.002
+    weight_decay: 1e-6
+  scheduler: warmuplr
+  scheduler_conf:
+    warmup_steps: 25000
+    lr_decay: 1.0
+  log_interval: 100
+  checkpoint:
+    kbest_n: 50
+    latest_n: 5
+
+
+decoding:
+  batch_size: 128
+  error_rate_type: cer
+  decoding_method: attention  # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
+  lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
+  alpha: 2.5
+  beta: 0.3
+  beam_size: 10
+  cutoff_prob: 1.0
+  cutoff_top_n: 0
+  num_proc_bsearch: 8
+  ctc_weight: 0.5  # ctc weight for attention rescoring decode mode.
+  decoding_chunk_size: -1  # decoding chunk size. Defaults to -1.
+                           # <0: for decoding, use full chunk.
+                           # >0: for decoding, use fixed chunk size as set.
+                           # 0: used for training, it's prohibited here.
+  num_decoding_left_chunks: -1  # number of left chunks for decoding. Defaults to -1.
+  simulate_streaming: False  # simulate streaming inference. Defaults to False.
+
+
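For reference, a minimal sketch of inspecting this config with PyYAML. The loader the s1 recipe actually uses is not part of this diff, so the loading code below is an assumption; it also illustrates why the file's first line links to the YAML float spec: under YAML 1.1 rules, bare `None` parses as the string `"None"`, and `1e-6` (no decimal point) parses as a string rather than a float.

```python
# Minimal sketch: reading conf/transformer.yaml with PyYAML.
# Assumption: the recipe's real config loader is not shown in this diff.
import yaml

with open("examples/aishell/s1/conf/transformer.yaml") as f:
    config = yaml.safe_load(f)

print(config["model"]["encoder_conf"]["output_size"])  # 256
print(config["training"]["n_epoch"])                   # 240

# YAML 1.1 gotchas this config is exposed to:
print(repr(config["collator"]["max_freq"]))            # 'None'  (a string, not null)
print(repr(config["training"]["optim_conf"]["weight_decay"]))  # '1e-6' (a string, not a float)
```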
diff --git a/examples/dataset/aishell/aishell.py b/examples/dataset/aishell/aishell.py
index 66e06901..95ed0408 100644
--- a/examples/dataset/aishell/aishell.py
+++ b/examples/dataset/aishell/aishell.py
@@ -22,6 +22,7 @@ import argparse
 import codecs
 import json
 import os
+from pathlib import Path
 
 import soundfile
 
@@ -81,6 +82,8 @@ def create_manifest(data_dir, manifest_path_prefix):
             # if no transcription for audio then skipped
             if audio_id not in transcript_dict:
                 continue
+
+            utt2spk = Path(audio_path).parent.name
             audio_data, samplerate = soundfile.read(audio_path)
             duration = float(len(audio_data) / samplerate)
             text = transcript_dict[audio_id]
@@ -88,6 +91,7 @@ def create_manifest(data_dir, manifest_path_prefix):
                 json.dumps(
                     {
                         'utt': audio_id,
+                        'utt2spk': str(utt2spk),
                         'feat': audio_path,
                         'feat_shape': (duration, ),  # second
                         'text': text
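The new `utt2spk` manifest field is simply the audio file's parent directory name, which in AISHELL-1's per-speaker directory layout is the speaker ID. A standalone illustration of the derivation added above; the sample path is hypothetical but follows the dataset's structure:

```python
from pathlib import Path

# Hypothetical path following AISHELL-1's layout: one directory per speaker.
audio_path = "data_aishell/wav/train/S0002/BAC009S0002W0122.wav"

utt2spk = Path(audio_path).parent.name  # parent directory name == speaker ID
print(utt2spk)  # S0002
```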