fix augmentation

pull/768/head
Hui Zhang 3 years ago
parent 0ab299a842
commit 9dace62581

@ -1,54 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Alignment for U2 model."""
from deepspeech.exps.u2.model import get_cfg_defaults
from deepspeech.exps.u2.model import U2Tester as Tester
from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.dynamic_import import dynamic_import
from deepspeech.utils.utility import print_arguments
def main_sp(config, args):
    """Run the alignment job in the current process.

    Instantiates ``Tester`` with ``config`` and ``args``, calls its
    ``setup()`` method, and then ``run_align()``.

    Args:
        config: Experiment configuration forwarded to ``Tester``.
        args: Parsed command-line arguments forwarded to ``Tester``.
    """
    tester = Tester(config, args)
    tester.setup()
    tester.run_align()
def main(config, args):
    """Script entry point; hands ``config`` and ``args`` to ``main_sp``."""
    main_sp(config, args)
if __name__ == "__main__":
    # Start from the shared CLI parser and add the model selector option.
    parser = default_argument_parser()
    # The registration method is `add_argument` (singular), matching the
    # standard argparse API; `add_arguments` does not exist on
    # argparse.ArgumentParser.
    # NOTE(review): assumes default_argument_parser() returns an
    # argparse.ArgumentParser, as the parse_args() call below suggests.
    parser.add_argument(
        '--model-name',
        type=str,
        default='u2',
        help='model name, e.g: deepspeech2, u2, u2_kaldi, u2_st')
    args = parser.parse_args()
    print_arguments(args, globals())

    # https://yaml.org/type/float.html
    # Layer the configuration: defaults first, then the config file (if
    # given), then the command-line overrides in `args.opts` (if given).
    config = get_cfg_defaults()
    if args.config:
        config.merge_from_file(args.config)
    if args.opts:
        config.merge_from_list(args.opts)
    config.freeze()
    print(config)

    # Optionally write the final configuration to the requested path.
    if args.dump_config:
        with open(args.dump_config, 'w') as f:
            print(config, file=f)

    main(config, args)

@ -1,48 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Export for U2 model."""
from deepspeech.exps.u2.model import get_cfg_defaults
from deepspeech.exps.u2.model import U2Tester as Tester
from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.utility import print_arguments
def main_sp(config, args):
    """Run the model export job in the current process.

    Instantiates ``Tester`` with ``config`` and ``args``, calls its
    ``setup()`` method, and then ``run_export()``.

    Args:
        config: Experiment configuration forwarded to ``Tester``.
        args: Parsed command-line arguments forwarded to ``Tester``.
    """
    tester = Tester(config, args)
    tester.setup()
    tester.run_export()
def main(config, args):
    """Script entry point; hands ``config`` and ``args`` to ``main_sp``."""
    main_sp(config, args)
if __name__ == "__main__":
    # Parse the command line with the shared parser and echo the arguments.
    # The names `parser`, `args` and `config` are kept as-is because
    # globals() is handed to print_arguments.
    parser = default_argument_parser()
    args = parser.parse_args()
    print_arguments(args, globals())

    # https://yaml.org/type/float.html
    # Defaults first, then the config file, then command-line overrides.
    config = get_cfg_defaults()
    if args.config:
        config.merge_from_file(args.config)
    if args.opts:
        config.merge_from_list(args.opts)
    config.freeze()
    print(config)

    # Write the final configuration to a file when a dump path was given.
    if args.dump_config:
        with open(args.dump_config, 'w') as dump_file:
            print(config, file=dump_file)

    main(config, args)

@ -18,17 +18,23 @@ from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.dynamic_import import dynamic_import
from deepspeech.utils.utility import print_arguments
model_alias = {
model_test_alias = {
"u2": "deepspeech.exps.u2.model:U2Tester",
"u2_kaldi": "deepspeech.exps.u2_kaldi.model:U2Tester",
}
def main_sp(config, args):
class_obj = dynamic_import(args.model_name, model_alias)
class_obj = dynamic_import(args.model_name, model_test_alias)
exp = class_obj(config, args)
exp.setup()
if args.run_mode == 'test':
exp.run_test()
elif args.run_mode == 'export':
exp.run_export()
elif args.run_mode == 'align':
exp.run_align()
def main(config, args):
@ -42,6 +48,11 @@ if __name__ == "__main__":
type=str,
default='u2_kaldi',
help='model name, e.g: deepspeech2, u2, u2_kaldi, u2_st')
parser.add_argument(
'--run-mode',
type=str,
default='test',
help='run mode, e.g. test, align, export')
args = parser.parse_args()
print_arguments(args, globals())

@ -22,14 +22,14 @@ from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.dynamic_import import dynamic_import
from deepspeech.utils.utility import print_arguments
model_alias = {
model_train_alias = {
"u2": "deepspeech.exps.u2.model:U2Trainer",
"u2_kaldi": "deepspeech.exps.u2_kaldi.model:U2Trainer",
}
def main_sp(config, args):
class_obj = dynamic_import(args.model_name, model_alias)
class_obj = dynamic_import(args.model_name, model_train_alias)
exp = class_obj(config, args)
exp.setup()
exp.run()

@ -97,14 +97,14 @@ class AugmentationPipeline():
ValueError: If the augmentation json config is in an incorrect format.
"""
# Augmentor config "type" values checked with `in` when splitting the
# pipeline into feature-level and audio-level configs.
# The trailing comma is required: ('specaug') is just the string 'specaug',
# which would turn the `in` check into a substring test ('spec' or '' would
# match) instead of tuple membership.
SPEC_TYPES = ('specaug', )
def __init__(self, augmentation_config: str, random_seed: int=0):
self._rng = np.random.RandomState(random_seed)
self._spec_types = ('specaug')
if augmentation_config is None:
self.conf = {}
else:
self.conf = json.loads(augmentation_config)
self.conf = {'mode': 'sequential', 'process': []}
if augmentation_config:
process = json.loads(augmentation_config)
self.conf['process'] += process
self._augmentors, self._rates = self._parse_pipeline_from('all')
self._audio_augmentors, self._audio_rates = self._parse_pipeline_from(
@ -188,7 +188,7 @@ class AugmentationPipeline():
all_confs = []
for config in self.conf:
all_confs.append(config)
if config["type"] in self._spec_types:
if config["type"] in self.SPEC_TYPES:
feature_confs.append(config)
else:
audio_confs.append(config)

@ -21,7 +21,8 @@ mkdir -p ${output_dir}
# align dump in `result_file`
# .tier, .TextGrid dump in `dir of result_file`
python3 -u ${BIN_DIR}/alignment.py \
python3 -u ${BIN_DIR}/test.py \
--run_mode 'align' \
--device ${device} \
--nproc 1 \
--config ${config_path} \

@ -17,7 +17,8 @@ if [ ${ngpu} == 0 ];then
device=cpu
fi
python3 -u ${BIN_DIR}/export.py \
python3 -u ${BIN_DIR}/test.py \
--run_mode 'export' \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \

@ -38,6 +38,7 @@ for type in attention ctc_greedy_search; do
batch_size=64
fi
python3 -u ${BIN_DIR}/test.py \
--run_mode test \
--device ${device} \
--nproc 1 \
--config ${config_path} \
@ -55,6 +56,7 @@ for type in ctc_prefix_beam_search attention_rescoring; do
echo "decoding ${type}"
batch_size=1
python3 -u ${BIN_DIR}/test.py \
--run_mode test \
--device ${device} \
--nproc 1 \
--config ${config_path} \

Loading…
Cancel
Save