refactor the train and test config,test=asr

pull/1225/head
huangyuxin 3 years ago
parent 425b085f94
commit c40b6f4062

@@ -1,97 +1,93 @@
-# network architecture
-model:
-  cmvn_file:
-  cmvn_file_type: "json"
-  # encoder related
-  encoder: conformer
-  encoder_conf:
-    output_size: 256    # dimension of attention
-    attention_heads: 4
-    linear_units: 2048  # the number of units of position-wise feed forward
-    num_blocks: 12      # the number of encoder blocks
-    dropout_rate: 0.1
-    positional_dropout_rate: 0.1
-    attention_dropout_rate: 0.0
-    input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
-    normalize_before: True
-    cnn_module_kernel: 15
-    use_cnn_module: True
-    activation_type: 'swish'
-    pos_enc_layer_type: 'rel_pos'
-    selfattention_layer_type: 'rel_selfattn'
-
-  # decoder related
-  decoder: transformer
-  decoder_conf:
-    attention_heads: 4
-    linear_units: 2048
-    num_blocks: 6
-    dropout_rate: 0.1
-    positional_dropout_rate: 0.1
-    self_attention_dropout_rate: 0.0
-    src_attention_dropout_rate: 0.0
-
-  # hybrid CTC/attention
-  model_conf:
-    ctc_weight: 0.3
-    lsm_weight: 0.1     # label smoothing option
-    length_normalized_loss: false
-
-data:
-  train_manifest: data/manifest.train
-  dev_manifest: data/manifest.dev
-  test_manifest: data/manifest.test
-
-collator:
-  vocab_filepath: data/lang_char/vocab.txt
-  unit_type: 'char'
-  augmentation_config: conf/preprocess.yaml
-  feat_dim: 80
-  stride_ms: 10.0
-  window_ms: 25.0
-  sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
-  batch_size: 64
-  maxlen_in: 512  # if input length > maxlen-in, batchsize is automatically reduced
-  maxlen_out: 150  # if output length > maxlen-out, batchsize is automatically reduced
-  minibatches: 0 # for debug
-  batch_count: auto
-  batch_bins: 0
-  batch_frames_in: 0
-  batch_frames_out: 0
-  batch_frames_inout: 0
-  num_workers: 0
-  subsampling_factor: 1
-  num_encs: 1
-
-training:
-  n_epoch: 240
-  accum_grad: 2
-  global_grad_clip: 5.0
-  optim: adam
-  optim_conf:
-    lr: 0.002
-    weight_decay: 1e-6
-  scheduler: warmuplr
-  scheduler_conf:
-    warmup_steps: 25000
-    lr_decay: 1.0
-  log_interval: 100
-  checkpoint:
-    kbest_n: 50
-    latest_n: 5
-
-decoding:
-  beam_size: 10
-  batch_size: 128
-  error_rate_type: cer
-  decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
-  ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
-  decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
-  # <0: for decoding, use full chunk.
-  # >0: for decoding, use fixed chunk size as set.
-  # 0: used for training, it's prohibited here.
-  num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
-  simulate_streaming: False # simulate streaming inference. Defaults to False.
+############################################
+# Network Architecture                     #
+############################################
+#model:
+cmvn_file:
+cmvn_file_type: "json"
+# encoder related
+encoder: conformer
+encoder_conf:
+  output_size: 256    # dimension of attention
+  attention_heads: 4
+  linear_units: 2048  # the number of units of position-wise feed forward
+  num_blocks: 12      # the number of encoder blocks
+  dropout_rate: 0.1
+  positional_dropout_rate: 0.1
+  attention_dropout_rate: 0.0
+  input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
+  normalize_before: True
+  cnn_module_kernel: 15
+  use_cnn_module: True
+  activation_type: 'swish'
+  pos_enc_layer_type: 'rel_pos'
+  selfattention_layer_type: 'rel_selfattn'
+
+# decoder related
+decoder: transformer
+decoder_conf:
+  attention_heads: 4
+  linear_units: 2048
+  num_blocks: 6
+  dropout_rate: 0.1
+  positional_dropout_rate: 0.1
+  self_attention_dropout_rate: 0.0
+  src_attention_dropout_rate: 0.0
+
+# hybrid CTC/attention
+model_conf:
+  ctc_weight: 0.3
+  lsm_weight: 0.1     # label smoothing option
+  length_normalized_loss: false
+
+###########################################
+# Data                                    #
+###########################################
+#data:
+train_manifest: data/manifest.train
+dev_manifest: data/manifest.dev
+test_manifest: data/manifest.test
+
+###########################################
+# Dataloader                              #
+###########################################
+#collator:
+vocab_filepath: data/lang_char/vocab.txt
+unit_type: 'char'
+augmentation_config: conf/preprocess.yaml
+spm_model_prefix: ''
+feat_dim: 80
+stride_ms: 10.0
+window_ms: 25.0
+sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
+batch_size: 64
+maxlen_in: 512  # if input length > maxlen-in, batchsize is automatically reduced
+maxlen_out: 150  # if output length > maxlen-out, batchsize is automatically reduced
+minibatches: 0 # for debug
+batch_count: auto
+batch_bins: 0
+batch_frames_in: 0
+batch_frames_out: 0
+batch_frames_inout: 0
+num_workers: 0
+subsampling_factor: 1
+num_encs: 1
+
+###########################################
+# training                                #
+###########################################
+#training:
+n_epoch: 240
+accum_grad: 2
+global_grad_clip: 5.0
+optim: adam
+optim_conf:
+  lr: 0.002
+  weight_decay: 1.0e-6
+scheduler: warmuplr
+scheduler_conf:
+  warmup_steps: 25000
+  lr_decay: 1.0
+log_interval: 100
+checkpoint:
+  kbest_n: 50
+  latest_n: 5

@@ -0,0 +1,12 @@
+#decoding:
+beam_size: 10
+decode_batch_size: 128
+error_rate_type: cer
+decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
+ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
+decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
+# <0: for decoding, use full chunk.
+# >0: for decoding, use fixed chunk size as set.
+# 0: used for training, it's prohibited here.
+num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
+simulate_streaming: False # simulate streaming inference. Defaults to False.

@@ -1,7 +1,7 @@
 #!/bin/bash
-if [ $# != 2 ];then
-    echo "usage: ${0} config_path ckpt_path_prefix"
+if [ $# != 3 ];then
+    echo "usage: ${0} config_path decode_config_path ckpt_path_prefix"
     exit -1
 fi
@@ -9,7 +9,8 @@ ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
 echo "using $ngpu gpus..."
 config_path=$1
-ckpt_prefix=$2
+decode_config_path=$2
+ckpt_prefix=$3
 batch_size=1
 output_dir=${ckpt_prefix}
@@ -20,9 +21,10 @@ mkdir -p ${output_dir}
 python3 -u ${BIN_DIR}/alignment.py \
 --ngpu ${ngpu} \
 --config ${config_path} \
+--decode_config ${decode_config_path} \
 --result_file ${output_dir}/${type}.align \
 --checkpoint_path ${ckpt_prefix} \
---opts decoding.batch_size ${batch_size}
+--opts decoding.decode_batch_size ${batch_size}
 if [ $? -ne 0 ]; then
     echo "Failed in ctc alignment!"

@@ -1,7 +1,7 @@
 #!/bin/bash
-if [ $# != 2 ];then
-    echo "usage: ${0} config_path ckpt_path_prefix"
+if [ $# != 3 ];then
+    echo "usage: ${0} config_path decode_config_path ckpt_path_prefix"
     exit -1
 fi
@@ -9,7 +9,8 @@ ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
 echo "using $ngpu gpus..."
 config_path=$1
-ckpt_prefix=$2
+decode_config_path=$2
+ckpt_prefix=$3
 chunk_mode=false
 if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
@@ -36,10 +37,11 @@ for type in attention ctc_greedy_search; do
     python3 -u ${BIN_DIR}/test.py \
     --ngpu ${ngpu} \
     --config ${config_path} \
+    --decode_config ${decode_config_path} \
     --result_file ${output_dir}/${type}.rsl \
     --checkpoint_path ${ckpt_prefix} \
     --opts decoding.decoding_method ${type} \
-    --opts decoding.batch_size ${batch_size}
+    --opts decoding.decode_batch_size ${batch_size}
     if [ $? -ne 0 ]; then
         echo "Failed in evaluation!"
@@ -55,6 +57,7 @@ for type in ctc_prefix_beam_search attention_rescoring; do
     python3 -u ${BIN_DIR}/test.py \
    --ngpu ${ngpu} \
    --config ${config_path} \
+    --decode_config ${decode_config_path} \
    --result_file ${output_dir}/${type}.rsl \
    --checkpoint_path ${ckpt_prefix} \
    --opts decoding.decoding_method ${type} \

@@ -1,7 +1,7 @@
 #!/bin/bash
-if [ $# != 3 ];then
-    echo "usage: ${0} config_path ckpt_path_prefix audio_file"
+if [ $# != 4 ];then
+    echo "usage: ${0} config_path decode_config_path ckpt_path_prefix audio_file"
     exit -1
 fi
@@ -9,8 +9,9 @@ ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
 echo "using $ngpu gpus..."
 config_path=$1
-ckpt_prefix=$2
-audio_file=$3
+decode_config_path=$2
+ckpt_prefix=$3
+audio_file=$4
 mkdir -p data
 wget -nc https://paddlespeech.bj.bcebos.com/datasets/single_wav/zh/demo_01_03.wav -P data/
@@ -42,10 +43,11 @@ for type in attention_rescoring; do
     python3 -u ${BIN_DIR}/test_wav.py \
     --ngpu ${ngpu} \
     --config ${config_path} \
+    --decode_config ${decode_config_path} \
     --result_file ${output_dir}/${type}.rsl \
     --checkpoint_path ${ckpt_prefix} \
     --opts decoding.decoding_method ${type} \
-    --opts decoding.batch_size ${batch_size} \
+    --opts decoding.decode_batch_size ${batch_size} \
     --audio_file ${audio_file}
 if [ $? -ne 0 ]; then

@@ -6,6 +6,7 @@ gpus=0,1,2,3
 stage=0
 stop_stage=50
 conf_path=conf/conformer.yaml
+decode_conf_path=conf/decode.yaml
 avg_num=20
 audio_file=data/demo_01_03.wav
@@ -32,18 +33,18 @@ fi
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
     # test ckpt avg_n
-    CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
+    CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
 fi

 if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
     # ctc alignment of test data
-    CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
+    CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
 fi

 # Optionally, you can add LM and test it with runtime.
 if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
     # test a single .wav file
-    CUDA_VISIBLE_DEVICES=0 ./local/test_wav.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${audio_file} || exit -1
+    CUDA_VISIBLE_DEVICES=0 ./local/test_wav.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${audio_file} || exit -1
 fi

 # Not supported at now!!!

@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Alignment for U2 model."""
+from yacs.config import CfgNode
+
 from paddlespeech.s2t.exps.u2.config import get_cfg_defaults
 from paddlespeech.s2t.exps.u2.model import U2Tester as Tester
 from paddlespeech.s2t.training.cli import default_argument_parser
@@ -41,6 +43,10 @@ if __name__ == "__main__":
     config = get_cfg_defaults()
     if args.config:
         config.merge_from_file(args.config)
+    if args.decode_config:
+        decode_confs = CfgNode(new_allowed=True)
+        decode_confs.merge_from_file(args.decode_config)
+        config.decoding = decode_confs
     if args.opts:
         config.merge_from_list(args.opts)
     config.freeze()

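Note: the new --decode_config branch above uses a standard yacs pattern. A minimal standalone sketch, not part of this commit; the only real file name below is conf/decode.yaml from the hunk earlier, everything else is illustrative:

    # Sketch: merging a decode config the way the new branch does, assuming yacs is
    # installed and conf/decode.yaml exists with the keys added in this commit.
    from yacs.config import CfgNode

    config = CfgNode(new_allowed=True)          # stands in for get_cfg_defaults()
    decode_confs = CfgNode(new_allowed=True)    # new_allowed=True accepts arbitrary new keys
    decode_confs.merge_from_file("conf/decode.yaml")
    config.decoding = decode_confs              # decode options end up under config.decoding
    config.freeze()

    print(config.decoding.beam_size)            # -> 10 with the decode.yaml above
    print(config.decoding.decode_batch_size)    # -> 128
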
@@ -14,12 +14,14 @@
 """Evaluation for U2 model."""
 import cProfile

+from yacs.config import CfgNode
+
 from paddlespeech.s2t.exps.u2.config import get_cfg_defaults
 from paddlespeech.s2t.exps.u2.model import U2Tester as Tester
 from paddlespeech.s2t.training.cli import default_argument_parser
 from paddlespeech.s2t.utils.utility import print_arguments

 # TODO(hui zhang): dynamic load

 def main_sp(config, args):
@@ -35,7 +37,7 @@ def main(config, args):
 if __name__ == "__main__":
     parser = default_argument_parser()
     # save asr result to
     parser.add_argument(
         "--result_file", type=str, help="path of save the asr result")
     args = parser.parse_args()
@@ -45,6 +47,10 @@ if __name__ == "__main__":
     config = get_cfg_defaults()
     if args.config:
         config.merge_from_file(args.config)
+    if args.decode_config:
+        decode_confs = CfgNode(new_allowed=True)
+        decode_confs.merge_from_file(args.decode_config)
+        config.decoding = decode_confs
     if args.opts:
         config.merge_from_list(args.opts)
     config.freeze()

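Note: the --opts decoding.decode_batch_size ${batch_size} overrides in the shell scripts are applied by config.merge_from_list(args.opts) shown above. A small sketch of that mechanism, with made-up values, not the project defaults:

    # Sketch: how a dotted --opts override reaches the renamed decode_batch_size key.
    from yacs.config import CfgNode

    config = CfgNode(new_allowed=True)
    config.decoding = CfgNode(new_allowed=True)
    config.decoding.decode_batch_size = 128      # default, as in conf/decode.yaml

    opts = ["decoding.decode_batch_size", "1"]   # what `--opts ... 1` produces on the CLI
    config.merge_from_list(opts)                 # dotted key walks into config.decoding

    assert config.decoding.decode_batch_size == 1
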
@@ -18,6 +18,7 @@ from pathlib import Path
 import paddle
 import soundfile
+from yacs.config import CfgNode

 from paddlespeech.s2t.exps.u2.config import get_cfg_defaults
 from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
@@ -36,23 +37,22 @@ class U2Infer():
         self.args = args
         self.config = config
         self.audio_file = args.audio_file
-        self.sr = config.collator.target_sample_rate
-        self.preprocess_conf = config.collator.augmentation_config
+        self.preprocess_conf = config.augmentation_config
         self.preprocess_args = {"train": False}
         self.preprocessing = Transformation(self.preprocess_conf)
         self.text_feature = TextFeaturizer(
-            unit_type=config.collator.unit_type,
-            vocab=config.collator.vocab_filepath,
-            spm_model_prefix=config.collator.spm_model_prefix)
+            unit_type=config.unit_type,
+            vocab=config.vocab_filepath,
+            spm_model_prefix=config.spm_model_prefix)

         paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu')

         # model
-        model_conf = config.model
+        model_conf = config
         with UpdateConfig(model_conf):
-            model_conf.input_dim = config.collator.feat_dim
+            model_conf.input_dim = config.feat_dim
             model_conf.output_dim = self.text_feature.vocab_size
         model = U2Model.from_config(model_conf)
         self.model = model
@@ -70,10 +70,6 @@ class U2Infer():
         # read
         audio, sample_rate = soundfile.read(
             self.audio_file, dtype="int16", always_2d=True)
-        if sample_rate != self.sr:
-            logger.error(
-                f"sample rate error: {sample_rate}, need {self.sr} ")
-            sys.exit(-1)
         audio = audio[:, 0]
         logger.info(f"audio shape: {audio.shape}")
@@ -85,17 +81,17 @@ class U2Infer():
         ilen = paddle.to_tensor(feat.shape[0])
         xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(axis=0)
-        cfg = self.config.decoding
+        decode_config = self.config.decoding
         result_transcripts = self.model.decode(
             xs,
             ilen,
             text_feature=self.text_feature,
-            decoding_method=cfg.decoding_method,
-            beam_size=cfg.beam_size,
-            ctc_weight=cfg.ctc_weight,
-            decoding_chunk_size=cfg.decoding_chunk_size,
-            num_decoding_left_chunks=cfg.num_decoding_left_chunks,
-            simulate_streaming=cfg.simulate_streaming)
+            decoding_method=decode_config.decoding_method,
+            beam_size=decode_config.beam_size,
+            ctc_weight=decode_config.ctc_weight,
+            decoding_chunk_size=decode_config.decoding_chunk_size,
+            num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
+            simulate_streaming=decode_config.simulate_streaming)
         rsl = result_transcripts[0][0]
         utt = Path(self.audio_file).name
         logger.info(f"hyp: {utt} {result_transcripts[0][0]}")
@@ -136,6 +132,10 @@ if __name__ == "__main__":
     config = get_cfg_defaults()
     if args.config:
         config.merge_from_file(args.config)
+    if args.decode_config:
+        decode_confs = CfgNode(new_allowed=True)
+        decode_confs.merge_from_file(args.decode_config)
+        config.decoding = decode_confs
     if args.opts:
         config.merge_from_list(args.opts)
     config.freeze()

@@ -19,19 +19,18 @@ from paddlespeech.s2t.io.collator import SpeechCollator
 from paddlespeech.s2t.io.dataset import ManifestDataset
 from paddlespeech.s2t.models.u2 import U2Model

-_C = CfgNode()
-_C.data = ManifestDataset.params()
-_C.collator = SpeechCollator.params()
-_C.model = U2Model.params()
-_C.training = U2Trainer.params()
+_C = CfgNode(new_allowed=True)
+ManifestDataset.params(_C)
+SpeechCollator.params(_C)
+U2Model.params(_C)
+U2Trainer.params(_C)
 _C.decoding = U2Tester.params()

 def get_cfg_defaults():
     """Get a yacs CfgNode object with default values for my_project."""
     # Return a clone so that the defaults will not be altered

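Note: building all defaults into one flat CfgNode means most lookups drop one level (config.collator.batch_size becomes config.batch_size, config.training.n_epoch becomes config.n_epoch), while decode options stay nested under config.decoding. A toy sketch of the before/after access pattern; keys and values here are illustrative, not the real defaults:

    # Sketch: nested layout before this refactor vs. the flattened layout after it.
    from yacs.config import CfgNode

    old = CfgNode(new_allowed=True)              # before: per-group child nodes
    old.collator = CfgNode(new_allowed=True)
    old.collator.batch_size = 64
    old.training = CfgNode(new_allowed=True)
    old.training.n_epoch = 240

    new = CfgNode(new_allowed=True)              # after: same keys at the top level
    new.batch_size = 64
    new.n_epoch = 240

    assert old.collator.batch_size == new.batch_size
    assert old.training.n_epoch == new.n_epoch
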
@@ -77,7 +77,7 @@ class U2Trainer(Trainer):
         super().__init__(config, args)

     def train_batch(self, batch_index, batch_data, msg):
-        train_conf = self.config.training
+        train_conf = self.config
         start = time.time()

         # forward
@@ -120,7 +120,7 @@ class U2Trainer(Trainer):
             for k, v in losses_np.items():
                 report(k, v)
-            report("batch_size", self.config.collator.batch_size)
+            report("batch_size", self.config.batch_size)
             report("accum", train_conf.accum_grad)
             report("step_cost", iteration_time)
@@ -153,7 +153,7 @@ class U2Trainer(Trainer):
             if ctc_loss:
                 valid_losses['val_ctc_loss'].append(float(ctc_loss))

-            if (i + 1) % self.config.training.log_interval == 0:
+            if (i + 1) % self.config.log_interval == 0:
                 valid_dump = {k: np.mean(v) for k, v in valid_losses.items()}
                 valid_dump['val_history_loss'] = total_loss / num_seen_utts
@@ -182,7 +182,7 @@ class U2Trainer(Trainer):
         self.before_train()

         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
-        while self.epoch < self.config.training.n_epoch:
+        while self.epoch < self.config.n_epoch:
             with Timer("Epoch-Train Time Cost: {}"):
                 self.model.train()
                 try:
@@ -214,8 +214,7 @@ class U2Trainer(Trainer):
                                k.split(',')) == 2 else ""
                            msg += ","
                        msg = msg[:-1]  # remove the last ","
-                        if (batch_index + 1
-                            ) % self.config.training.log_interval == 0:
+                        if (batch_index + 1) % self.config.log_interval == 0:
                            logger.info(msg)
                        data_start_time = time.time()
                except Exception as e:
@@ -252,29 +251,29 @@ class U2Trainer(Trainer):
         if self.train:
             # train/valid dataset, return token ids
             self.train_loader = BatchDataLoader(
-                json_file=config.data.train_manifest,
+                json_file=config.train_manifest,
                 train_mode=True,
-                sortagrad=config.collator.sortagrad,
-                batch_size=config.collator.batch_size,
-                maxlen_in=config.collator.maxlen_in,
-                maxlen_out=config.collator.maxlen_out,
-                minibatches=config.collator.minibatches,
+                sortagrad=config.sortagrad,
+                batch_size=config.batch_size,
+                maxlen_in=config.maxlen_in,
+                maxlen_out=config.maxlen_out,
+                minibatches=config.minibatches,
                 mini_batch_size=self.args.ngpu,
-                batch_count=config.collator.batch_count,
-                batch_bins=config.collator.batch_bins,
-                batch_frames_in=config.collator.batch_frames_in,
-                batch_frames_out=config.collator.batch_frames_out,
-                batch_frames_inout=config.collator.batch_frames_inout,
-                preprocess_conf=config.collator.augmentation_config,
-                n_iter_processes=config.collator.num_workers,
+                batch_count=config.batch_count,
+                batch_bins=config.batch_bins,
+                batch_frames_in=config.batch_frames_in,
+                batch_frames_out=config.batch_frames_out,
+                batch_frames_inout=config.batch_frames_inout,
+                preprocess_conf=config.augmentation_config,
+                n_iter_processes=config.num_workers,
                 subsampling_factor=1,
                 num_encs=1)

             self.valid_loader = BatchDataLoader(
-                json_file=config.data.dev_manifest,
+                json_file=config.dev_manifest,
                 train_mode=False,
                 sortagrad=False,
-                batch_size=config.collator.batch_size,
+                batch_size=config.batch_size,
                 maxlen_in=float('inf'),
                 maxlen_out=float('inf'),
                 minibatches=0,
@@ -284,18 +283,18 @@ class U2Trainer(Trainer):
                 batch_frames_in=0,
                 batch_frames_out=0,
                 batch_frames_inout=0,
-                preprocess_conf=config.collator.augmentation_config,
-                n_iter_processes=config.collator.num_workers,
+                preprocess_conf=config.augmentation_config,
+                n_iter_processes=config.num_workers,
                 subsampling_factor=1,
                 num_encs=1)
             logger.info("Setup train/valid Dataloader!")
         else:
             # test dataset, return raw text
             self.test_loader = BatchDataLoader(
-                json_file=config.data.test_manifest,
+                json_file=config.test_manifest,
                 train_mode=False,
                 sortagrad=False,
-                batch_size=config.decoding.batch_size,
+                batch_size=config.decoding.decode_batch_size,
                 maxlen_in=float('inf'),
                 maxlen_out=float('inf'),
                 minibatches=0,
@@ -305,16 +304,16 @@ class U2Trainer(Trainer):
                 batch_frames_in=0,
                 batch_frames_out=0,
                 batch_frames_inout=0,
-                preprocess_conf=config.collator.augmentation_config,
+                preprocess_conf=config.augmentation_config,
                 n_iter_processes=1,
                 subsampling_factor=1,
                 num_encs=1)

             self.align_loader = BatchDataLoader(
-                json_file=config.data.test_manifest,
+                json_file=config.test_manifest,
                 train_mode=False,
                 sortagrad=False,
-                batch_size=config.decoding.batch_size,
+                batch_size=config.decoding.decode_batch_size,
                 maxlen_in=float('inf'),
                 maxlen_out=float('inf'),
                 minibatches=0,
@@ -324,7 +323,7 @@ class U2Trainer(Trainer):
                 batch_frames_in=0,
                 batch_frames_out=0,
                 batch_frames_inout=0,
-                preprocess_conf=config.collator.augmentation_config,
+                preprocess_conf=config.augmentation_config,
                 n_iter_processes=1,
                 subsampling_factor=1,
                 num_encs=1)
@@ -332,7 +331,7 @@ class U2Trainer(Trainer):
     def setup_model(self):
         config = self.config

-        model_conf = config.model
+        model_conf = config
         with UpdateConfig(model_conf):
             if self.train:
@@ -355,7 +354,7 @@ class U2Trainer(Trainer):
         if not self.train:
             return

-        train_config = config.training
+        train_config = config
         optim_type = train_config.optim
         optim_conf = train_config.optim_conf
         scheduler_type = train_config.scheduler
@@ -375,7 +374,7 @@ class U2Trainer(Trainer):
            config,
            parameters,
            lr_scheduler=None, ):
-        train_config = config.training
+        train_config = config
         optim_type = train_config.optim
         optim_conf = train_config.optim_conf
         scheduler_type = train_config.scheduler
@@ -415,7 +414,7 @@ class U2Tester(U2Trainer):
            error_rate_type='wer',  # Error rate type for evaluation. Options `wer`, 'cer'
            num_proc_bsearch=8,  # # of CPUs for beam search.
            beam_size=10,  # Beam search width.
-            batch_size=16,  # decoding batch size
+            decode_batch_size=16,  # decoding batch size
            ctc_weight=0.0,  # ctc weight for attention rescoring decode mode.
            decoding_chunk_size=-1,  # decoding chunk size. Defaults to -1.
            #  <0: for decoding, use full chunk.
@@ -432,9 +431,9 @@ class U2Tester(U2Trainer):
     def __init__(self, config, args):
         super().__init__(config, args)
         self.text_feature = TextFeaturizer(
-            unit_type=self.config.collator.unit_type,
-            vocab=self.config.collator.vocab_filepath,
-            spm_model_prefix=self.config.collator.spm_model_prefix)
+            unit_type=self.config.unit_type,
+            vocab=self.config.vocab_filepath,
+            spm_model_prefix=self.config.spm_model_prefix)
         self.vocab_list = self.text_feature.vocab_list

     def id2token(self, texts, texts_len, text_feature):
@@ -453,10 +452,10 @@ class U2Tester(U2Trainer):
                        texts,
                        texts_len,
                        fout=None):
-        cfg = self.config.decoding
+        decode_config = self.config.decoding
         errors_sum, len_refs, num_ins = 0.0, 0, 0
-        errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
-        error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
+        errors_func = error_rate.char_errors if decode_config.error_rate_type == 'cer' else error_rate.word_errors
+        error_rate_func = error_rate.cer if decode_config.error_rate_type == 'cer' else error_rate.wer

         start_time = time.time()
         target_transcripts = self.id2token(texts, texts_len, self.text_feature)
@@ -464,12 +463,12 @@ class U2Tester(U2Trainer):
             audio,
             audio_len,
             text_feature=self.text_feature,
-            decoding_method=cfg.decoding_method,
-            beam_size=cfg.beam_size,
-            ctc_weight=cfg.ctc_weight,
-            decoding_chunk_size=cfg.decoding_chunk_size,
-            num_decoding_left_chunks=cfg.num_decoding_left_chunks,
-            simulate_streaming=cfg.simulate_streaming)
+            decoding_method=decode_config.decoding_method,
+            beam_size=decode_config.beam_size,
+            ctc_weight=decode_config.ctc_weight,
+            decoding_chunk_size=decode_config.decoding_chunk_size,
+            num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
+            simulate_streaming=decode_config.simulate_streaming)
         decode_time = time.time() - start_time

         for utt, target, result, rec_tids in zip(
@@ -488,15 +487,15 @@ class U2Tester(U2Trainer):
             logger.info(f"Utt: {utt}")
             logger.info(f"Ref: {target}")
             logger.info(f"Hyp: {result}")
-            logger.info("One example error rate [%s] = %f" %
-                        (cfg.error_rate_type, error_rate_func(target, result)))
+            logger.info("One example error rate [%s] = %f" % (
+                decode_config.error_rate_type, error_rate_func(target, result)))

         return dict(
             errors_sum=errors_sum,
             len_refs=len_refs,
             num_ins=num_ins,  # num examples
             error_rate=errors_sum / len_refs,
-            error_rate_type=cfg.error_rate_type,
+            error_rate_type=decode_config.error_rate_type,
             num_frames=audio_len.sum().numpy().item(),
             decode_time=decode_time)
@@ -507,7 +506,7 @@ class U2Tester(U2Trainer):
         self.model.eval()
         logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")

-        stride_ms = self.config.collator.stride_ms
+        stride_ms = self.config.stride_ms
         error_rate_type = None
         errors_sum, len_refs, num_ins = 0.0, 0, 0
         num_frames = 0.0
@@ -558,15 +557,15 @@ class U2Tester(U2Trainer):
                     "ref_len":
                     len_refs,
                     "decode_method":
-                    self.config.decoding.decoding_method,
+                    self.config.decoding_method,
                 })
                 f.write(data + '\n')

     @paddle.no_grad()
     def align(self):
         ctc_utils.ctc_align(self.config, self.model, self.align_loader,
-                            self.config.decoding.batch_size,
-                            self.config.collator.stride_ms, self.vocab_list,
+                            self.config.decoding.decode_batch_size,
+                            self.config.stride_ms, self.vocab_list,
                             self.args.result_file)

     def load_inferspec(self):
@@ -577,10 +576,10 @@ class U2Tester(U2Trainer):
             List[paddle.static.InputSpec]: input spec.
         """
         from paddlespeech.s2t.models.u2 import U2InferModel
-        infer_model = U2InferModel.from_pretrained(self.test_loader,
-                                                   self.config.model.clone(),
+        infer_model = U2InferModel.from_pretrained(self.train_loader,
+                                                   self.config.clone(),
                                                    self.args.checkpoint_path)
-        feat_dim = self.test_loader.feat_dim
+        feat_dim = self.train_loader.feat_dim
         input_spec = [
             paddle.static.InputSpec(shape=[1, None, feat_dim],
                                     dtype='float32'),  # audio, [B,T,D]

@@ -97,6 +97,14 @@ def default_argument_parser(parser=None):
     train_group.add_argument(
         "--dump-config", metavar="FILE", help="dump config to `this` file.")

+    test_group = parser.add_argument_group(
+        title='Test Options', description=None)
+    test_group.add_argument(
+        "--decode_config",
+        metavar="DECODE_CONFIG_FILE",
+        help="decode config file.")
+
     profile_group = parser.add_argument_group(
         title='Benchmark Options', description=None)
     profile_group.add_argument(

@@ -117,8 +117,8 @@ class Trainer():
             self.init_parallel()

         self.checkpoint = Checkpoint(
-            kbest_n=self.config.training.checkpoint.kbest_n,
-            latest_n=self.config.training.checkpoint.latest_n)
+            kbest_n=self.config.checkpoint.kbest_n,
+            latest_n=self.config.checkpoint.latest_n)

         # set random seed if needed
         if args.seed:
@@ -129,8 +129,8 @@ class Trainer():
         if hasattr(self.args,
                    "benchmark_batch_size") and self.args.benchmark_batch_size:
             with UpdateConfig(self.config):
-                self.config.collator.batch_size = self.args.benchmark_batch_size
-                self.config.training.log_interval = 1
+                self.config.batch_size = self.args.benchmark_batch_size
+                self.config.log_interval = 1
             logger.info(
                 f"Benchmark reset batch-size: {self.args.benchmark_batch_size}")
@@ -260,7 +260,7 @@ class Trainer():
         self.before_train()

         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
-        while self.epoch < self.config.training.n_epoch:
+        while self.epoch < self.config.n_epoch:
             with Timer("Epoch-Train Time Cost: {}"):
                 self.model.train()
                 try:

@@ -130,7 +130,7 @@ def get_subsample(config):
     Returns:
         int: subsample rate.
     """
-    input_layer = config["model"]["encoder_conf"]["input_layer"]
+    input_layer = config["encoder_conf"]["input_layer"]
     assert input_layer in ["conv2d", "conv2d6", "conv2d8"]
     if input_layer == "conv2d":
         return 4
