refactor the code

pull/931/head
huangyuxin 4 years ago
parent ae9f379547
commit ed6bb7a54b

@ -25,7 +25,6 @@ from .utils import add_results_to_json
from deepspeech.exps import dynamic_import_tester
from deepspeech.io.reader import LoadInputsAndTargets
from deepspeech.models.asr_interface import ASRInterface
from deepspeech.models.lm.transformer import TransformerLM
from deepspeech.models.lm_interface import dynamic_import_lm
from deepspeech.utils.log import Log
# from espnet.asr.asr_utils import get_model_conf
@ -51,12 +50,14 @@ def load_trained_model(args):
model = exp.model
return model, char_list, exp, confs
def get_config(config_path):
stream = open(config_path, mode='r', encoding="utf-8")
config = yaml.load(stream, Loader=yaml.FullLoader)
stream.close()
return config
def recog_v2(args):
"""Decode with custom models that implements ScorerInterface.
@ -85,8 +86,8 @@ def recog_v2(args):
if args.preprocess_conf is None else args.preprocess_conf,
preprocess_args={"train": False}, )
if args.use_lm:
lm_path = args.rnnlm_path
if args.rnnlm:
lm_path = args.rnnlm
lm_config_path = args.rnnlm_conf
lm_config = get_config(lm_config_path)
lm_class = dynamic_import_lm("transformer")

@ -397,7 +397,6 @@ class DeepSpeech2ModelOnline(nn.Layer):
class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
def __init__(self, *args, **kwargs):
print("*args", *args)
super().__init__(*args, **kwargs)
def forward(self, audio_chunk, audio_chunk_lens, chunk_state_h_box,

@ -1,3 +1,16 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
@ -17,4 +30,5 @@ def check_kwargs(func, kwargs, name=None):
name = func.__name__
for k in kwargs.keys():
if k not in params:
raise TypeError(f"{name}() got an unexpected keyword argument '{k}'")
raise TypeError(
f"{name}() got an unexpected keyword argument '{k}'")

@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -18,11 +18,11 @@ collator:
augmentation_config: conf/augmentation.json
random_seed: 0
spm_model_prefix:
spectrum_type: fbank #linear, mfcc, fbank
feat_dim: 80
spectrum_type: linear #linear, mfcc, fbank
feat_dim:
delta_delta: False
stride_ms: 10.0
window_ms: 25.0
window_ms: 20.0
n_fft: None
max_freq: None
target_sample_rate: 16000
@ -36,17 +36,17 @@ collator:
model:
num_conv_layers: 2
num_rnn_layers: 4
num_rnn_layers: 5
rnn_layer_size: 1024
rnn_direction: forward # [forward, bidirect]
num_fc_layers: 1
fc_layers_size_list: 1024,
num_fc_layers: 0
fc_layers_size_list: -1,
use_gru: False
blank_id: 0
ctc_grad_norm_type: null
training:
n_epoch: 80
n_epoch: 50
accum_grad: 1
lr: 2e-3
lr_decay: 0.9 # 0.83

@ -30,11 +30,10 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
num_workers=$(nproc)
python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
--manifest_path="data/manifest.train.raw" \
--spectrum_type="fbank" \
--feat_dim=80 \
--spectrum_type="linear" \
--delta_delta=false \
--stride_ms=10.0 \
--window_ms=25.0 \
--window_ms=20.0 \
--sample_rate=16000 \
--use_dB_normalization=True \
--num_samples=2000 \

@ -23,5 +23,5 @@
| test-clean | join_ctc_w/o_lm | 2620 | 52576 | 97.2 | 2.6 | 0.3 | 0.4 | 3.2 | 34.9 |
| test-clean | join_ctc_w_lm | 2620 | 52576 | 97.9 | 1.8 | 0.2 | 0.3 | 2.4 | 27.8 |
Compare with [ESPNET](https://github.com/espnet/espnet/blob/master/egs/librispeech/asr1/RESULTS.md#pytorch-large-transformer-with-specaug-4-gpus--transformer-lm-4-gpus)
Compare with [ESPNET](https://github.com/espnet/espnet/blob/master/egs/librispeech/asr1/RESULTS.md#pytorch-large-transformer-with-specaug-4-gpus--transformer-lm-4-gpus)
we using 8gpu, but model size (aheads4-adim256) small than it.

@ -11,7 +11,6 @@ tag=
decode_config=conf/decode/decode.yaml
# lm params
use_lm=true
lang_model=transformerLM.pdparams
lmexpdir=exp/lm/transformer
lmtag='nolm'
@ -95,9 +94,8 @@ for dmethd in join_ctc; do
--result-label ${decode_dir}/data.JOB.json \
--model-conf ${config_path} \
--model ${ckpt_prefix}.pdparams \
--use_rnnlm ${use_lm} \
--rnnlm-conf ${rnnlm_config_path} \
--rnnlm-path ${lmexpdir}/${lang_model}
--rnnlm ${lmexpdir}/${lang_model}
score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel} --wer false ${decode_dir} ${dict}

@ -10,6 +10,7 @@ stop_stage=100
conf_path=conf/transformer.yaml
dict_path=data/bpe_unigram_5000_units.txt
avg_num=10
use_lm=true
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
@ -46,3 +47,11 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# export ckpt avg_n
CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
fi
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ] && ${use_lm} == true; then
# use transformerlm to score
if [ ! -f exp/lm/transformer/transformerLM.pdparams ]; then
wget https://deepspeech.bj.bcebos.com/transformer_lm/transformerLM.pdparams exp/lm/transformer/
fi
bash local/recog.sh --ckpt_prefix exp/${ckpt}/checkpoints/${avg_ckpt}
fi

@ -12,33 +12,32 @@ from deepspeech.utils.cli_utils import is_scipy_wav_style
def get_parser():
parser = argparse.ArgumentParser(
description="convert feature to its shape",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option")
formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
parser.add_argument(
"--verbose", "-V", default=0, type=int, help="Verbose option")
parser.add_argument(
"--filetype",
type=str,
default="mat",
choices=["mat", "hdf5", "sound.hdf5", "sound"],
help="Specify the file format for the rspecifier. "
'"mat" is the matrix format in kaldi',
)
'"mat" is the matrix format in kaldi', )
parser.add_argument(
"--preprocess-conf",
type=str,
default=None,
help="The configuration file for the pre-processing",
)
help="The configuration file for the pre-processing", )
parser.add_argument(
"rspecifier", type=str, help="Read specifier for feats. e.g. ark:some.ark"
)
"rspecifier",
type=str,
help="Read specifier for feats. e.g. ark:some.ark")
parser.add_argument(
"out",
nargs="?",
type=argparse.FileType("w"),
default=sys.stdout,
help="The output filename. " "If omitted, then output to sys.stdout",
)
help="The output filename. "
"If omitted, then output to sys.stdout", )
return parser
@ -64,8 +63,7 @@ def main():
# so change to file_reader_helper to return shape.
# This make sense only with filetype="hdf5".
for utt, mat in file_reader_helper(
args.rspecifier, args.filetype, return_shape=preprocessing is None
):
args.rspecifier, args.filetype, return_shape=preprocessing is None):
if preprocessing is not None:
if is_scipy_wav_style(mat):
# If data is sound file, then got as Tuple[int, ndarray]

Loading…
Cancel
Save