From 561d5cf085b49baf27b47f13b15074d654acbce2 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Mon, 23 Aug 2021 12:11:43 +0000 Subject: [PATCH] refactor feature, dict and argument for new config format --- .flake8 | 4 + deepspeech/exps/deepspeech2/bin/export.py | 3 + deepspeech/exps/deepspeech2/bin/test.py | 3 + deepspeech/exps/u2/bin/alignment.py | 3 + deepspeech/exps/u2/bin/export.py | 3 + deepspeech/exps/u2/bin/test.py | 3 + deepspeech/exps/u2_kaldi/bin/test.py | 9 +++ deepspeech/exps/u2_kaldi/model.py | 32 +++++--- deepspeech/exps/u2_st/bin/export.py | 3 + deepspeech/exps/u2_st/bin/test.py | 3 + deepspeech/frontend/featurizer/__init__.py | 3 + .../frontend/featurizer/audio_featurizer.py | 2 +- .../frontend/featurizer/speech_featurizer.py | 2 +- .../frontend/featurizer/text_featurizer.py | 73 +++++++------------ deepspeech/frontend/utility.py | 50 ++++++++++--- deepspeech/training/cli.py | 7 -- examples/aishell/s0/conf/augmentation.json | 2 +- examples/librispeech/s2/conf/transformer.yaml | 10 +-- examples/librispeech/s2/local/align.sh | 13 ++-- examples/librispeech/s2/local/export.sh | 3 +- examples/librispeech/s2/local/test.sh | 19 +++-- examples/librispeech/s2/run.sh | 5 +- examples/tiny/s0/conf/augmentation.json | 3 +- 23 files changed, 158 insertions(+), 100 deletions(-) diff --git a/.flake8 b/.flake8 index 72289943..44685f23 100644 --- a/.flake8 +++ b/.flake8 @@ -42,6 +42,10 @@ ignore = # these ignores are from flake8-comprehensions; please fix! C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415 + +per-file-ignores = + */__init__.py: F401 + # Specify the list of error codes you wish Flake8 to report. 
select = E, diff --git a/deepspeech/exps/deepspeech2/bin/export.py b/deepspeech/exps/deepspeech2/bin/export.py index f8764fde..7962d4fc 100644 --- a/deepspeech/exps/deepspeech2/bin/export.py +++ b/deepspeech/exps/deepspeech2/bin/export.py @@ -30,6 +30,9 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() + # save jit model to + parser.add_argument( + "--export_path", type=str, help="path of the jit model to save") parser.add_argument("--model_type") args = parser.parse_args() if args.model_type is None: diff --git a/deepspeech/exps/deepspeech2/bin/test.py b/deepspeech/exps/deepspeech2/bin/test.py index 376e18e3..f2fd3a39 100644 --- a/deepspeech/exps/deepspeech2/bin/test.py +++ b/deepspeech/exps/deepspeech2/bin/test.py @@ -31,6 +31,9 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() parser.add_argument("--model_type") + # save asr result to + parser.add_argument( + "--result_file", type=str, help="path of save the asr result") args = parser.parse_args() print_arguments(args, globals()) if args.model_type is None: diff --git a/deepspeech/exps/u2/bin/alignment.py b/deepspeech/exps/u2/bin/alignment.py index c1c9582f..cef9d1ab 100644 --- a/deepspeech/exps/u2/bin/alignment.py +++ b/deepspeech/exps/u2/bin/alignment.py @@ -30,6 +30,9 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() + # save asr result to + parser.add_argument( + "--result_file", type=str, help="path of save the asr result") args = parser.parse_args() print_arguments(args, globals()) diff --git a/deepspeech/exps/u2/bin/export.py b/deepspeech/exps/u2/bin/export.py index 292c7838..3dc41b70 100644 --- a/deepspeech/exps/u2/bin/export.py +++ b/deepspeech/exps/u2/bin/export.py @@ -30,6 +30,9 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() + # save jit model to + parser.add_argument( + "--export_path", type=str, help="path of the jit model to save") args = 
parser.parse_args() print_arguments(args, globals()) diff --git a/deepspeech/exps/u2/bin/test.py b/deepspeech/exps/u2/bin/test.py index c47f932c..f6127675 100644 --- a/deepspeech/exps/u2/bin/test.py +++ b/deepspeech/exps/u2/bin/test.py @@ -34,6 +34,9 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() + # save asr result to + parser.add_argument( + "--result_file", type=str, help="path of save the asr result") args = parser.parse_args() print_arguments(args, globals()) diff --git a/deepspeech/exps/u2_kaldi/bin/test.py b/deepspeech/exps/u2_kaldi/bin/test.py index c5064ec5..93a29ab1 100644 --- a/deepspeech/exps/u2_kaldi/bin/test.py +++ b/deepspeech/exps/u2_kaldi/bin/test.py @@ -13,6 +13,7 @@ # limitations under the License. """Evaluation for U2 model.""" import cProfile + from yacs.config import CfgNode from deepspeech.training.cli import default_argument_parser @@ -54,6 +55,14 @@ if __name__ == "__main__": type=str, default='test', help='run mode, e.g. 
test, align, export') + parser.add_argument( + '--dict-path', type=str, default=None, help='dict path.') + # save asr result to + parser.add_argument( + "--result-file", type=str, help="path of save the asr result") + # save jit model to + parser.add_argument( + "--export-path", type=str, help="path of the jit model to save") args = parser.parse_args() print_arguments(args, globals()) diff --git a/deepspeech/exps/u2_kaldi/model.py b/deepspeech/exps/u2_kaldi/model.py index 60f070a3..4f6ff4cb 100644 --- a/deepspeech/exps/u2_kaldi/model.py +++ b/deepspeech/exps/u2_kaldi/model.py @@ -25,6 +25,8 @@ import paddle from paddle import distributed as dist from yacs.config import CfgNode +from deepspeech.frontend.featurizer import TextFeaturizer +from deepspeech.frontend.utility import load_dict from deepspeech.io.dataloader import BatchDataLoader from deepspeech.models.u2 import U2Model from deepspeech.training.optimizer import OptimizerFactory @@ -80,8 +82,8 @@ class U2Trainer(Trainer): def train_batch(self, batch_index, batch_data, msg): train_conf = self.config.training start = time.time() - utt, audio, audio_len, text, text_len = batch_data + utt, audio, audio_len, text, text_len = batch_data loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, text_len) # loss div by `batch_size * accum_grad` @@ -124,6 +126,7 @@ class U2Trainer(Trainer): valid_losses = defaultdict(list) num_seen_utts = 1 total_loss = 0.0 + for i, batch in enumerate(self.valid_loader): utt, audio, audio_len, text, text_len = batch loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, @@ -305,10 +308,8 @@ class U2Trainer(Trainer): model_conf.output_dim = self.train_loader.vocab_size model_conf.freeze() model = U2Model.from_config(model_conf) - if self.parallel: model = paddle.DataParallel(model) - logger.info(f"{model}") layer_tools.print_params(model, logger.info) @@ -379,13 +380,13 @@ class U2Tester(U2Trainer): def __init__(self, config, args): super().__init__(config, args) 
- def ordid2token(self, texts, texts_len): + def id2token(self, texts, texts_len, text_feature): """ ord() id to chr() chr """ trans = [] for text, n in zip(texts, texts_len): n = n.numpy().item() ids = text[:n] - trans.append(''.join([chr(i) for i in ids])) + trans.append(text_feature.defeaturize(ids.numpy().tolist())) return trans def compute_metrics(self, @@ -401,8 +402,11 @@ class U2Tester(U2Trainer): error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer start_time = time.time() - text_feature = self.test_loader.collate_fn.text_feature - target_transcripts = self.ordid2token(texts, texts_len) + text_feature = TextFeaturizer( + unit_type=self.config.collator.unit_type, + vocab_filepath=self.config.collator.vocab_filepath, + spm_model_prefix=self.config.collator.spm_model_prefix) + target_transcripts = self.id2token(texts, texts_len, text_feature) result_transcripts = self.model.decode( audio, audio_len, @@ -450,7 +454,7 @@ class U2Tester(U2Trainer): self.model.eval() logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") - stride_ms = self.test_loader.collate_fn.stride_ms + stride_ms = self.config.collator.stride_ms error_rate_type = None errors_sum, len_refs, num_ins = 0.0, 0, 0 num_frames = 0.0 @@ -525,8 +529,9 @@ class U2Tester(U2Trainer): self.model.eval() logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}") - stride_ms = self.config.collate.stride_ms - token_dict = self.align_loader.collate_fn.vocab_list + stride_ms = self.config.collater.stride_ms + token_dict = self.args.char_list + with open(self.args.result_file, 'w') as fout: # one example in batch for i, batch in enumerate(self.align_loader): @@ -613,6 +618,11 @@ class U2Tester(U2Trainer): except KeyboardInterrupt: sys.exit(-1) + def setup_dict(self): + # load dictionary for debug log + self.args.char_list = load_dict(self.args.dict_path, + "maskctc" in self.args.model_name) + def setup(self): """Setup the experiment. 
""" @@ -624,6 +634,8 @@ class U2Tester(U2Trainer): self.setup_dataloader() self.setup_model() + self.setup_dict() + self.iteration = 0 self.epoch = 0 diff --git a/deepspeech/exps/u2_st/bin/export.py b/deepspeech/exps/u2_st/bin/export.py index f566ba5b..c7eb5d03 100644 --- a/deepspeech/exps/u2_st/bin/export.py +++ b/deepspeech/exps/u2_st/bin/export.py @@ -30,6 +30,9 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() + # save jit model to + parser.add_argument( + "--export_path", type=str, help="path of the jit model to save") args = parser.parse_args() print_arguments(args, globals()) diff --git a/deepspeech/exps/u2_st/bin/test.py b/deepspeech/exps/u2_st/bin/test.py index d66c7a26..81197dec 100644 --- a/deepspeech/exps/u2_st/bin/test.py +++ b/deepspeech/exps/u2_st/bin/test.py @@ -34,6 +34,9 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() + # save asr result to + parser.add_argument( + "--result_file", type=str, help="path of save the asr result") args = parser.parse_args() print_arguments(args, globals()) diff --git a/deepspeech/frontend/featurizer/__init__.py b/deepspeech/frontend/featurizer/__init__.py index 185a92b8..6992700d 100644 --- a/deepspeech/frontend/featurizer/__init__.py +++ b/deepspeech/frontend/featurizer/__init__.py @@ -11,3 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from .audio_featurizer import AudioFeaturizer #noqa: F401 +from .speech_featurizer import SpeechFeaturizer +from .text_featurizer import TextFeaturizer diff --git a/deepspeech/frontend/featurizer/audio_featurizer.py b/deepspeech/frontend/featurizer/audio_featurizer.py index 7e9efa36..4c40c847 100644 --- a/deepspeech/frontend/featurizer/audio_featurizer.py +++ b/deepspeech/frontend/featurizer/audio_featurizer.py @@ -18,7 +18,7 @@ from python_speech_features import logfbank from python_speech_features import mfcc -class AudioFeaturizer(object): +class AudioFeaturizer(): """Audio featurizer, for extracting features from audio contents of AudioSegment or SpeechSegment. diff --git a/deepspeech/frontend/featurizer/speech_featurizer.py b/deepspeech/frontend/featurizer/speech_featurizer.py index 0fbbc564..5082850d 100644 --- a/deepspeech/frontend/featurizer/speech_featurizer.py +++ b/deepspeech/frontend/featurizer/speech_featurizer.py @@ -16,7 +16,7 @@ from deepspeech.frontend.featurizer.audio_featurizer import AudioFeaturizer from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer -class SpeechFeaturizer(object): +class SpeechFeaturizer(): """Speech featurizer, for extracting features from both audio and transcript contents of SpeechSegment. 
diff --git a/deepspeech/frontend/featurizer/text_featurizer.py b/deepspeech/frontend/featurizer/text_featurizer.py index 1ba6ac7f..e4364f70 100644 --- a/deepspeech/frontend/featurizer/text_featurizer.py +++ b/deepspeech/frontend/featurizer/text_featurizer.py @@ -14,12 +14,19 @@ """Contains the text featurizer class.""" import sentencepiece as spm -from deepspeech.frontend.utility import EOS -from deepspeech.frontend.utility import UNK +from ..utility import EOS +from ..utility import load_dict +from ..utility import UNK +__all__ = ["TextFeaturizer"] -class TextFeaturizer(object): - def __init__(self, unit_type, vocab_filepath, spm_model_prefix=None): + +class TextFeaturizer(): + def __init__(self, + unit_type, + vocab_filepath, + spm_model_prefix=None, + maskctc=False): """Text featurizer, for processing or extracting features from text. Currently, it supports char/word/sentence-piece level tokenizing and conversion into @@ -34,11 +41,12 @@ class TextFeaturizer(object): assert unit_type in ('char', 'spm', 'word') self.unit_type = unit_type self.unk = UNK + self.maskctc = maskctc + if vocab_filepath: - self._vocab_dict, self._id2token, self._vocab_list = self._load_vocabulary_from_file( - vocab_filepath) - self.unk_id = self._vocab_list.index(self.unk) - self.eos_id = self._vocab_list.index(EOS) + self.vocab_dict, self._id2token, self.vocab_list, self.unk_id, self.eos_id = self._load_vocabulary_from_file( + vocab_filepath, maskctc) + self.vocab_size = len(self.vocab_list) if unit_type == 'spm': spm_model = spm_model_prefix + '.model' @@ -67,7 +75,7 @@ class TextFeaturizer(object): """Convert text string to a list of token indices. Args: - text (str): Text to process. + text (str): Text. Returns: List[int]: List of token indices. 
@@ -75,8 +83,8 @@ class TextFeaturizer(object): tokens = self.tokenize(text) ids = [] for token in tokens: - token = token if token in self._vocab_dict else self.unk - ids.append(self._vocab_dict[token]) + token = token if token in self.vocab_dict else self.unk + ids.append(self.vocab_dict[token]) return ids def defeaturize(self, idxs): @@ -87,7 +95,7 @@ class TextFeaturizer(object): idxs (List[int]): List of token indices. Returns: - str: Text to process. + str: Text. """ tokens = [] for idx in idxs: @@ -97,33 +105,6 @@ class TextFeaturizer(object): text = self.detokenize(tokens) return text - @property - def vocab_size(self): - """Return the vocabulary size. - - :return: Vocabulary size. - :rtype: int - """ - return len(self._vocab_list) - - @property - def vocab_list(self): - """Return the vocabulary in list. - - Returns: - List[str]: tokens. - """ - return self._vocab_list - - @property - def vocab_dict(self): - """Return the vocabulary in dict. - - Returns: - Dict[str, int]: token str -> int - """ - return self._vocab_dict - def char_tokenize(self, text): """Character tokenizer. 
@@ -206,14 +187,16 @@ class TextFeaturizer(object): return decode(tokens) - def _load_vocabulary_from_file(self, vocab_filepath): + def _load_vocabulary_from_file(self, vocab_filepath: str, maskctc: bool): """Load vocabulary from file.""" - vocab_lines = [] - with open(vocab_filepath, 'r', encoding='utf-8') as file: - vocab_lines.extend(file.readlines()) - vocab_list = [line[:-1] for line in vocab_lines] + vocab_list = load_dict(vocab_filepath, maskctc) + assert vocab_list is not None + id2token = dict( [(idx, token) for (idx, token) in enumerate(vocab_list)]) token2id = dict( [(token, idx) for (idx, token) in enumerate(vocab_list)]) - return token2id, id2token, vocab_list + + unk_id = vocab_list.index(UNK) + eos_id = vocab_list.index(EOS) + return token2id, id2token, vocab_list, unk_id, eos_id diff --git a/deepspeech/frontend/utility.py b/deepspeech/frontend/utility.py index b2dd9601..3d0683b0 100644 --- a/deepspeech/frontend/utility.py +++ b/deepspeech/frontend/utility.py @@ -15,6 +15,9 @@ import codecs import json import math +from typing import List +from typing import Optional +from typing import Text import numpy as np @@ -23,16 +26,35 @@ from deepspeech.utils.log import Log logger = Log(__name__).getlog() __all__ = [ - "load_cmvn", "read_manifest", "rms_to_db", "rms_to_dbfs", "max_dbfs", - "mean_dbfs", "gain_db_to_ratio", "normalize_audio", "SOS", "EOS", "UNK", - "BLANK" + "load_dict", "load_cmvn", "read_manifest", "rms_to_db", "rms_to_dbfs", + "max_dbfs", "mean_dbfs", "gain_db_to_ratio", "normalize_audio", "SOS", + "EOS", "UNK", "BLANK", "MASKCTC" ] IGNORE_ID = -1 -SOS = "" +# `sos` and `eos` using same token +SOS = "" EOS = SOS UNK = "" BLANK = "" +MASKCTC = "" + + +def load_dict(dict_path: Optional[Text], maskctc=False) -> Optional[List[Text]]: + if dict_path is None: + return None + + with open(dict_path, "r") as f: + dictionary = f.readlines() + char_list = [entry.split(" ")[0] for entry in dictionary] + if BLANK not in char_list: + char_list.insert(0, 
BLANK) + if EOS not in char_list: + char_list.append(EOS) + # for non-autoregressive maskctc model + if maskctc and MASKCTC not in char_list: + char_list.append(MASKCTC) + return char_list def read_manifest( @@ -47,12 +69,20 @@ def read_manifest( Args: manifest_path ([type]): Manifest file to load and parse. - max_input_len ([type], optional): maximum output seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to float('inf'). - min_input_len (float, optional): minimum input seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to 0.0. - max_output_len (float, optional): maximum input seq length, in modeling units. Defaults to 500.0. - min_output_len (float, optional): minimum input seq length, in modeling units. Defaults to 0.0. - max_output_input_ratio (float, optional): maximum output seq length/output seq length ratio. Defaults to 10.0. - min_output_input_ratio (float, optional): minimum output seq length/output seq length ratio. Defaults to 0.05. + max_input_len ([type], optional): maximum input seq length, + in seconds for raw wav, in frame numbers for feature data. + Defaults to float('inf'). + min_input_len (float, optional): minimum input seq length, + in seconds for raw wav, in frame numbers for feature data. + Defaults to 0.0. + max_output_len (float, optional): maximum output seq length, + in modeling units. Defaults to 500.0. + min_output_len (float, optional): minimum output seq length, + in modeling units. Defaults to 0.0. + max_output_input_ratio (float, optional): + maximum output seq length/input seq length ratio. Defaults to 10.0. + min_output_input_ratio (float, optional): + minimum output seq length/input seq length ratio. Defaults to 0.05. Raises: IOError: If failed to parse the manifest. 
diff --git a/deepspeech/training/cli.py b/deepspeech/training/cli.py index b83d989d..9d145645 100644 --- a/deepspeech/training/cli.py +++ b/deepspeech/training/cli.py @@ -47,18 +47,11 @@ def default_argument_parser(): # data and output parser.add_argument("--config", metavar="FILE", help="path of the config file to overwrite to default config with.") parser.add_argument("--dump-config", metavar="FILE", help="dump config to yaml file.") - # parser.add_argument("--data", metavar="DATA_DIR", help="path to the datatset.") parser.add_argument("--output", metavar="OUTPUT_DIR", help="path to save checkpoint and logs.") # load from saved checkpoint parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load") - # save jit model to - parser.add_argument("--export_path", type=str, help="path of the jit model to save") - - # save asr result to - parser.add_argument("--result_file", type=str, help="path of save the asr result") - # running parser.add_argument("--device", type=str, default='gpu', choices=["cpu", "gpu"], help="device type to use, cpu and gpu are supported.") diff --git a/examples/aishell/s0/conf/augmentation.json b/examples/aishell/s0/conf/augmentation.json index 39afe4e6..ac8a1c53 100644 --- a/examples/aishell/s0/conf/augmentation.json +++ b/examples/aishell/s0/conf/augmentation.json @@ -33,4 +33,4 @@ }, "prob": 1.0 } -] \ No newline at end of file +] diff --git a/examples/librispeech/s2/conf/transformer.yaml b/examples/librispeech/s2/conf/transformer.yaml index 7710d706..ded4f240 100644 --- a/examples/librispeech/s2/conf/transformer.yaml +++ b/examples/librispeech/s2/conf/transformer.yaml @@ -3,17 +3,11 @@ data: train_manifest: data/manifest.train dev_manifest: data/manifest.dev test_manifest: data/manifest.test-clean - min_input_len: 0.5 # second - max_input_len: 20.0 # second - min_output_len: 0.0 # tokens - max_output_len: 400.0 # tokens - min_output_input_ratio: 0.05 - max_output_input_ratio: 10.0 collator: - vocab_filepath: 
data/vocab.txt + vocab_filepath: data/train_960_unigram5000_units.txt unit_type: 'spm' - spm_model_prefix: 'data/bpe_unigram_5000' + spm_model_prefix: 'data/train_960_unigram5000' mean_std_filepath: "" augmentation_config: conf/augmentation.json batch_size: 64 diff --git a/examples/librispeech/s2/local/align.sh b/examples/librispeech/s2/local/align.sh index 94146ccf..b3d8fa5f 100755 --- a/examples/librispeech/s2/local/align.sh +++ b/examples/librispeech/s2/local/align.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 2 ];then - echo "usage: ${0} config_path ckpt_path_prefix" +if [ $# != 3 ];then + echo "usage: ${0} config_path dict_path ckpt_path_prefix" exit -1 fi @@ -13,7 +13,8 @@ if [ ${ngpu} == 0 ];then device=cpu fi config_path=$1 -ckpt_prefix=$2 +dict_path=$2 +ckpt_prefix=$3 batch_size=1 output_dir=${ckpt_prefix} @@ -22,11 +23,13 @@ mkdir -p ${output_dir} # align dump in `result_file` # .tier, .TextGrid dump in `dir of result_file` python3 -u ${BIN_DIR}/test.py \ ---run_mode 'align' \ +--model-name 'u2_kaldi' \ +--run-mode 'align' \ +--dict-path ${dict_path} \ --device ${device} \ --nproc 1 \ --config ${config_path} \ ---result_file ${output_dir}/${type}.align \ +--result-file ${output_dir}/${type}.align \ --checkpoint_path ${ckpt_prefix} \ --opts decoding.batch_size ${batch_size} diff --git a/examples/librispeech/s2/local/export.sh b/examples/librispeech/s2/local/export.sh index 7e42e011..efa70a2b 100755 --- a/examples/librispeech/s2/local/export.sh +++ b/examples/librispeech/s2/local/export.sh @@ -18,7 +18,8 @@ if [ ${ngpu} == 0 ];then fi python3 -u ${BIN_DIR}/test.py \ ---run_mode 'export' \ +--model-name 'u2_kaldi' \ +--run-mode 'export' \ --device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ diff --git a/examples/librispeech/s2/local/test.sh b/examples/librispeech/s2/local/test.sh index 762211c2..efd06f35 100755 --- a/examples/librispeech/s2/local/test.sh +++ b/examples/librispeech/s2/local/test.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 2 ];then - 
echo "usage: ${0} config_path ckpt_path_prefix" +if [ $# != 3 ];then + echo "usage: ${0} config_path dict_path ckpt_path_prefix" exit -1 fi @@ -14,7 +14,8 @@ if [ ${ngpu} == 0 ];then fi config_path=$1 -ckpt_prefix=$2 +dict_path=$2 +ckpt_prefix=$3 chunk_mode=false if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then @@ -38,11 +39,13 @@ for type in attention ctc_greedy_search; do batch_size=64 fi python3 -u ${BIN_DIR}/test.py \ - --run_mode test \ + --model-name u2_kaldi \ + --run-mode test \ + --dict-path ${dict_path} \ --device ${device} \ --nproc 1 \ --config ${config_path} \ - --result_file ${ckpt_prefix}.${type}.rsl \ + --result-file ${ckpt_prefix}.${type}.rsl \ --checkpoint_path ${ckpt_prefix} \ --opts decoding.decoding_method ${type} decoding.batch_size ${batch_size} @@ -56,11 +59,13 @@ for type in ctc_prefix_beam_search attention_rescoring; do echo "decoding ${type}" batch_size=1 python3 -u ${BIN_DIR}/test.py \ - --run_mode test \ + --model-name u2_kaldi \ + --run-mode test \ + --dict-path ${dict_path} \ --device ${device} \ --nproc 1 \ --config ${config_path} \ - --result_file ${ckpt_prefix}.${type}.rsl \ + --result-file ${ckpt_prefix}.${type}.rsl \ --checkpoint_path ${ckpt_prefix} \ --opts decoding.decoding_method ${type} decoding.batch_size ${batch_size} diff --git a/examples/librispeech/s2/run.sh b/examples/librispeech/s2/run.sh index def10ab0..26398dd1 100755 --- a/examples/librispeech/s2/run.sh +++ b/examples/librispeech/s2/run.sh @@ -5,6 +5,7 @@ source path.sh stage=0 stop_stage=100 conf_path=conf/transformer.yaml +dict_path=data/train_960_unigram5000_units.txt avg_num=5 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; @@ -29,12 +30,12 @@ fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then # test ckpt avg_n - CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 + CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} ${dict_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 fi if [ ${stage} -le 4 ] && 
[ ${stop_stage} -ge 4 ]; then # ctc alignment of test data - CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 + CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} ${dict_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 fi if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then diff --git a/examples/tiny/s0/conf/augmentation.json b/examples/tiny/s0/conf/augmentation.json index 83705516..4480307b 100644 --- a/examples/tiny/s0/conf/augmentation.json +++ b/examples/tiny/s0/conf/augmentation.json @@ -29,8 +29,7 @@ "adaptive_number_ratio": 0, "adaptive_size_ratio": 0, "max_n_time_masks": 20, - "replace_with_zero": true, - "warp_mode": "PIL" + "replace_with_zero": true }, "prob": 1.0 }