From 9f42ec4bc225d7b23ff5d2d29a4a7a1e870a942b Mon Sep 17 00:00:00 2001 From: gongel Date: Thu, 18 Nov 2021 08:10:07 +0000 Subject: [PATCH] feat: add ted_en_zh t1 --- examples/ted_en_zh/t1/.gitignore | 3 + examples/ted_en_zh/t1/README.md | 15 +++ examples/ted_en_zh/t1/conf/transformer.yaml | 112 ++++++++++++++++++ .../t1/conf/transformer_joint_noam.yaml | 112 ++++++++++++++++++ .../t1/local/convert_torch_to_paddle.py | 88 ++++++++++++++ examples/ted_en_zh/t1/local/data.sh | 110 +++++++++++++++++ examples/ted_en_zh/t1/local/test.sh | 31 +++++ examples/ted_en_zh/t1/local/train.sh | 39 ++++++ examples/ted_en_zh/t1/path.sh | 15 +++ examples/ted_en_zh/t1/run.sh | 42 +++++++ 10 files changed, 567 insertions(+) create mode 100644 examples/ted_en_zh/t1/.gitignore create mode 100644 examples/ted_en_zh/t1/README.md create mode 100644 examples/ted_en_zh/t1/conf/transformer.yaml create mode 100644 examples/ted_en_zh/t1/conf/transformer_joint_noam.yaml create mode 100644 examples/ted_en_zh/t1/local/convert_torch_to_paddle.py create mode 100755 examples/ted_en_zh/t1/local/data.sh create mode 100755 examples/ted_en_zh/t1/local/test.sh create mode 100755 examples/ted_en_zh/t1/local/train.sh create mode 100644 examples/ted_en_zh/t1/path.sh create mode 100755 examples/ted_en_zh/t1/run.sh diff --git a/examples/ted_en_zh/t1/.gitignore b/examples/ted_en_zh/t1/.gitignore new file mode 100644 index 00000000..123e5174 --- /dev/null +++ b/examples/ted_en_zh/t1/.gitignore @@ -0,0 +1,3 @@ +TED_EnZh +data +exp diff --git a/examples/ted_en_zh/t1/README.md b/examples/ted_en_zh/t1/README.md new file mode 100644 index 00000000..66a5dbec --- /dev/null +++ b/examples/ted_en_zh/t1/README.md @@ -0,0 +1,15 @@ + +# TED En-Zh + +## Dataset + +| Data Subset | Duration in Seconds | +| --- | --- | +| data/manifest.train | 0.942 ~ 60 | +| data/manifest.dev | 1.151 ~ 39 | +| data/manifest.test | 1.1 ~ 42.746 | + +## Transformer +| Model | Params | Config | Char-BLEU | +| --- | --- | --- | --- | +| Transformer+ASR MTL | 50.26M | conf/transformer_joint_noam.yaml | 17.38 | diff --git a/examples/ted_en_zh/t1/conf/transformer.yaml b/examples/ted_en_zh/t1/conf/transformer.yaml new file mode 100644 index 00000000..d9637286 --- /dev/null +++ b/examples/ted_en_zh/t1/conf/transformer.yaml @@ -0,0 +1,112 @@ +# https://yaml.org/type/float.html +data: + train_manifest: data/manifest.train.tiny + dev_manifest: data/manifest.dev + test_manifest: data/manifest.test + min_input_len: 5.0 # frame + max_input_len: 3000.0 # frame + min_output_len: 0.0 # tokens + max_output_len: 400.0 # tokens + min_output_input_ratio: 0.01 + max_output_input_ratio: 20.0 + +collator: + vocab_filepath: data/vocab.txt + unit_type: 'spm' + spm_model_prefix: data/bpe_unigram_8000 + mean_std_filepath: "" + # augmentation_config: conf/augmentation.json + batch_size: 10 + raw_wav: True # use raw_wav or kaldi feature + spectrum_type: fbank #linear, mfcc, fbank + feat_dim: 83 + delta_delta: False + dither: 1.0 + target_sample_rate: 16000 + max_freq: None + n_fft: None + stride_ms: 10.0 + window_ms: 25.0 + use_dB_normalization: True + target_dB: -20 + random_seed: 0 + keep_transcription_text: False + sortagrad: True + shuffle_method: batch_shuffle + num_workers: 2 + + +# network architecture +model: + cmvn_file: None + cmvn_file_type: "json" + # encoder related + encoder: transformer + encoder_conf: + output_size: 256 # dimension of attention + attention_heads: 4 + linear_units: 2048 # the number of units of position-wise feed forward + num_blocks: 12 # the number of encoder blocks + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.0 + input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 + normalize_before: true + + # decoder related + decoder: transformer + decoder_conf: + attention_heads: 4 + linear_units: 2048 + num_blocks: 6 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.0 + src_attention_dropout_rate: 0.0 + + # hybrid CTC/attention + model_conf: + asr_weight: 0.0 + ctc_weight: 0.0 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: null + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false + + +training: + n_epoch: 20 + accum_grad: 2 + global_grad_clip: 5.0 + optim: adam + optim_conf: + lr: 0.004 + weight_decay: 1e-06 + scheduler: warmuplr # pytorch v1.1.0+ required + scheduler_conf: + warmup_steps: 25000 + lr_decay: 1.0 + log_interval: 5 + checkpoint: + kbest_n: 50 + latest_n: 5 + + +decoding: + batch_size: 5 + error_rate_type: char-bleu + decoding_method: fullsentence # 'fullsentence', 'simultaneous' + alpha: 2.5 + beta: 0.3 + beam_size: 10 + word_reward: 0.7 + cutoff_prob: 1.0 + cutoff_top_n: 0 + num_proc_bsearch: 8 + ctc_weight: 0.5 # ctc weight for attention rescoring decode mode. + decoding_chunk_size: -1 # decoding chunk size. Defaults to -1. + # <0: for decoding, use full chunk. + # >0: for decoding, use fixed chunk size as set. + # 0: used for training, it's prohibited here. + num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1. + simulate_streaming: False # simulate streaming inference. Defaults to False. diff --git a/examples/ted_en_zh/t1/conf/transformer_joint_noam.yaml b/examples/ted_en_zh/t1/conf/transformer_joint_noam.yaml new file mode 100644 index 00000000..ea38d6ee --- /dev/null +++ b/examples/ted_en_zh/t1/conf/transformer_joint_noam.yaml @@ -0,0 +1,112 @@ +# https://yaml.org/type/float.html +data: + train_manifest: data/manifest.train + dev_manifest: data/manifest.dev + test_manifest: data/manifest.test + min_input_len: 5.0 # frame + max_input_len: 3000.0 # frame + min_output_len: 0.0 # tokens + max_output_len: 400.0 # tokens + min_output_input_ratio: 0.01 + max_output_input_ratio: 20.0 + +collator: + vocab_filepath: data/vocab.txt + unit_type: 'spm' + spm_model_prefix: data/train_sp.en-zh-nlpr.zh-nlpr_bpe8000_tc + mean_std_filepath: "" + # augmentation_config: conf/augmentation.json + batch_size: 10 + raw_wav: True # use raw_wav or kaldi feature + spectrum_type: fbank #linear, mfcc, fbank + feat_dim: 83 + delta_delta: False + dither: 1.0 + target_sample_rate: 16000 + max_freq: None + n_fft: None + stride_ms: 10.0 + window_ms: 25.0 + use_dB_normalization: True + target_dB: -20 + random_seed: 0 + keep_transcription_text: False + sortagrad: True + shuffle_method: batch_shuffle + num_workers: 2 + + +# network architecture +model: + cmvn_file: None + cmvn_file_type: "json" + # encoder related + encoder: transformer + encoder_conf: + output_size: 256 # dimension of attention + attention_heads: 4 + linear_units: 2048 # the number of units of position-wise feed forward + num_blocks: 12 # the number of encoder blocks + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.0 + input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 + normalize_before: true + + # decoder related + decoder: transformer + decoder_conf: + attention_heads: 4 + linear_units: 2048 + num_blocks: 6 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.0 + src_attention_dropout_rate: 0.0 + + # hybrid CTC/attention + model_conf: + asr_weight: 0.5 + ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: null + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false + + +training: + n_epoch: 20 + accum_grad: 2 + global_grad_clip: 5.0 + optim: adam + optim_conf: + lr: 2.5 + weight_decay: 1e-06 + scheduler: noam + scheduler_conf: + warmup_steps: 25000 + lr_decay: 1.0 + log_interval: 5 + checkpoint: + kbest_n: 50 + latest_n: 5 + + +decoding: + batch_size: 5 + error_rate_type: char-bleu + decoding_method: fullsentence # 'fullsentence', 'simultaneous' + alpha: 2.5 + beta: 0.3 + beam_size: 10 + word_reward: 0.7 + cutoff_prob: 1.0 + cutoff_top_n: 0 + num_proc_bsearch: 8 + ctc_weight: 0.5 # ctc weight for attention rescoring decode mode. + decoding_chunk_size: -1 # decoding chunk size. Defaults to -1. + # <0: for decoding, use full chunk. + # >0: for decoding, use fixed chunk size as set. + # 0: used for training, it's prohibited here. + num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1. + simulate_streaming: False # simulate streaming inference. Defaults to False. \ No newline at end of file diff --git a/examples/ted_en_zh/t1/local/convert_torch_to_paddle.py b/examples/ted_en_zh/t1/local/convert_torch_to_paddle.py new file mode 100644 index 00000000..93fa1595 --- /dev/null +++ b/examples/ted_en_zh/t1/local/convert_torch_to_paddle.py @@ -0,0 +1,88 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse + +import paddle +import torch + +from paddlespeech.s2t.utils.log import Log + +logger = Log(__name__).getlog() + + +def torch2paddle(args): + paddle.set_device('cpu') + paddle_model_dict = {} + torch_model = torch.load(args.torch_ckpt, map_location='cpu') + cnt = 0 + for k, v in torch_model['model'].items(): + if k.startswith('encoder.embed'): + if v.ndim == 2: + v = v.transpose(0, 1) + paddle_model_dict[k] = v.numpy() + cnt += 1 + logger.info( + f"Convert torch weight: {k} to paddlepaddle weight: {k}, shape is {v.shape}" + ) + if k.startswith('encoder.after_norm'): + paddle_model_dict[k] = v.numpy() + cnt += 1 + paddle_model_dict[k.replace('en', 'de')] = v.numpy() + logger.info( + f"Convert torch weight: {k} to paddlepaddle weight: {k.replace('en','de')}, shape is {v.shape}" + ) + paddle_model_dict['st_' + k.replace('en', 'de')] = v.numpy() + logger.info( + f"Convert torch weight: {k} to paddlepaddle weight: {'st_'+ k.replace('en','de')}, shape is {v.shape}" + ) + cnt += 2 + if k.startswith('encoder.encoders'): + if v.ndim == 2: + v = v.transpose(0, 1) + paddle_model_dict[k] = v.numpy() + logger.info( + f"Convert torch weight: {k} to paddlepaddle weight: {k}, shape is {v.shape}" + ) + cnt += 1 + origin_k = k + k_split = k.split('.') + if int(k_split[2]) >= 6: + k = k.replace(k_split[2], str(int(k_split[2]) - 6)) + paddle_model_dict[k.replace('en', 'de')] = v.numpy() + logger.info( + f"Convert torch weight: {origin_k} to paddlepaddle weight: {k.replace('en','de')}, shape is {v.shape}" + ) + paddle_model_dict['st_' + k.replace('en', 'de')] = v.numpy() + logger.info( + f"Convert torch weight: {origin_k} to paddlepaddle weight: {'st_'+ k.replace('en','de')}, shape is {v.shape}" + ) + cnt += 2 + logger.info(f"Convert {cnt} weights totally from torch to paddlepaddle") + paddle.save(paddle_model_dict, args.paddle_ckpt) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + '--torch_ckpt', + type=str, + default='/home/snapshot.ep.98', + help="Path to torch checkpoint.") + parser.add_argument( + '--paddle_ckpt', + type=str, + default='paddle.98.pdparams', + help="Path to save paddlepaddle checkpoint.") + args = parser.parse_args() + torch2paddle(args) diff --git a/examples/ted_en_zh/t1/local/data.sh b/examples/ted_en_zh/t1/local/data.sh new file mode 100755 index 00000000..b080a5b4 --- /dev/null +++ b/examples/ted_en_zh/t1/local/data.sh @@ -0,0 +1,110 @@ +#!/bin/bash + +set -e + +stage=-1 +stop_stage=100 + +# bpemode (unigram or bpe) +nbpe=8000 +bpemode=unigram +bpeprefix="data/bpe_${bpemode}_${nbpe}" +data_dir=./TED_EnZh + + +source ${MAIN_ROOT}/utils/parse_options.sh + +TARGET_DIR=${MAIN_ROOT}/examples/dataset +mkdir -p ${TARGET_DIR} +mkdir -p data + + +if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then + if [ ! -e ${data_dir} ]; then + echo "Error: Dataset is not avaiable. Please download and unzip the dataset" + echo "Download Link: https://pan.baidu.com/s/18L-59wgeS96WkObISrytQQ Passwd: bva0" + echo "The tree of the directory should be:" + echo "." + echo "|-- En-Zh" + echo "|-- test-segment" + echo " |-- tst2010" + echo " |-- ..." + echo "|-- train-split" + echo " |-- train-segment" + echo "|-- README.md" + + exit 1 + fi + + # generate manifests + python3 ${TARGET_DIR}/ted_en_zh/ted_en_zh.py \ + --manifest_prefix="data/manifest" \ + --src_dir="${data_dir}" + + echo "Complete raw data pre-process." +fi + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # compute mean and stddev for normalizer + num_workers=$(nproc) + python3 ${MAIN_ROOT}/utils/compute_mean_std.py \ + --manifest_path="data/manifest.train.raw" \ + --num_samples=-1 \ + --spectrum_type="fbank" \ + --feat_dim=80 \ + --delta_delta=false \ + --sample_rate=16000 \ + --stride_ms=10.0 \ + --window_ms=25.0 \ + --use_dB_normalization=False \ + --num_workers=${num_workers} \ + --output_path="data/mean_std.json" + + if [ $? -ne 0 ]; then + echo "Compute mean and stddev failed. Terminated." + exit 1 + fi +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # build vocabulary + python3 ${MAIN_ROOT}/utils/build_vocab.py \ + --unit_type "spm" \ + --spm_vocab_size=${nbpe} \ + --spm_mode ${bpemode} \ + --spm_model_prefix ${bpeprefix} \ + --vocab_path="data/vocab.txt" \ + --text_keys 'text' 'text1' \ + --manifest_paths="data/manifest.train.raw" + + + if [ $? -ne 0 ]; then + echo "Build vocabulary failed. Terminated." + exit 1 + fi +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # format manifest with tokenids, vocab size + for set in train dev test; do + { + python3 ${MAIN_ROOT}/utils/format_triplet_data.py \ + --feat_type "raw" \ + --cmvn_path "data/mean_std.json" \ + --unit_type "spm" \ + --spm_model_prefix ${bpeprefix} \ + --vocab_path="data/vocab.txt" \ + --manifest_path="data/manifest.${set}.raw" \ + --output_path="data/manifest.${set}" + + if [ $? -ne 0 ]; then + echo "Formt mnaifest failed. Terminated." + exit 1 + fi + }& + done + wait +fi + +echo "Ted En-Zh Data preparation done." +exit 0 diff --git a/examples/ted_en_zh/t1/local/test.sh b/examples/ted_en_zh/t1/local/test.sh new file mode 100755 index 00000000..7235c6f9 --- /dev/null +++ b/examples/ted_en_zh/t1/local/test.sh @@ -0,0 +1,31 @@ +#! /usr/bin/env bash + +if [ $# != 2 ];then + echo "usage: ${0} config_path ckpt_path_prefix" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +config_path=$1 +ckpt_prefix=$2 + +for type in fullsentence; do + echo "decoding ${type}" + batch_size=32 + python3 -u ${BIN_DIR}/test.py \ + --nproc ${ngpu} \ + --config ${config_path} \ + --result_file ${ckpt_prefix}.${type}.rsl \ + --checkpoint_path ${ckpt_prefix} \ + --opts decoding.decoding_method ${type} \ + --opts decoding.batch_size ${batch_size} + + if [ $? -ne 0 ]; then + echo "Failed in evaluation!" + exit 1 + fi +done + +exit 0 diff --git a/examples/ted_en_zh/t1/local/train.sh b/examples/ted_en_zh/t1/local/train.sh new file mode 100755 index 00000000..36701121 --- /dev/null +++ b/examples/ted_en_zh/t1/local/train.sh @@ -0,0 +1,39 @@ +#!/bin/bash + +if [ $# != 3 ];then + echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ckpt_path" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +config_path=$1 +ckpt_name=$2 +ckpt_path=$3 + +mkdir -p exp + +# seed may break model convergence +seed=0 +if [ ${seed} != 0 ]; then + export FLAGS_cudnn_deterministic=True +fi + +python3 -u ${BIN_DIR}/train.py \ +--nproc ${ngpu} \ +--config ${config_path} \ +--output exp/${ckpt_name} \ +--checkpoint_path ${ckpt_path} \ +--seed ${seed} + +if [ ${seed} != 0 ]; then + unset FLAGS_cudnn_deterministic +fi + +if [ $? -ne 0 ]; then + echo "Failed in training!" + exit 1 +fi + +exit 0 \ No newline at end of file diff --git a/examples/ted_en_zh/t1/path.sh b/examples/ted_en_zh/t1/path.sh new file mode 100644 index 00000000..fd537917 --- /dev/null +++ b/examples/ted_en_zh/t1/path.sh @@ -0,0 +1,15 @@ +export MAIN_ROOT=`realpath ${PWD}/../../../` + +export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} +export LC_ALL=C + +export PYTHONDONTWRITEBYTECODE=1 +# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} + +export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/ + + +MODEL=u2_st +export BIN_DIR=${MAIN_ROOT}/paddlespeech/s2t/exps/${MODEL}/bin diff --git a/examples/ted_en_zh/t1/run.sh b/examples/ted_en_zh/t1/run.sh new file mode 100755 index 00000000..ddb155ad --- /dev/null +++ b/examples/ted_en_zh/t1/run.sh @@ -0,0 +1,42 @@ +#!/bin/bash +set -e +source path.sh + +gpus=0,1,2,3 +stage=1 +stop_stage=100 +conf_path=conf/transformer_joint_noam.yaml +ckpt_path=paddle.98 +avg_num=5 +data_path=./TED_EnZh # path to unzipped data +source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; + +avg_ckpt=avg_${avg_num} +ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}') +echo "checkpoint name ${ckpt}" + + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # prepare data + bash ./local/data.sh --data_dir ${data_path} || exit -1 +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # train model, all `ckpt` under `exp` dir + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${ckpt_path} +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # avg n best model + avg.sh best exp/${ckpt}/checkpoints ${avg_num} +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # test ckpt avg_n + CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 +fi + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + # export ckpt avg_n + CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit +fi