refactor feature, dict and argument for new config format

pull/782/head
Hui Zhang 3 years ago
parent 27daa92a81
commit 561d5cf085

@@ -42,6 +42,10 @@ ignore =
     # these ignores are from flake8-comprehensions; please fix!
     C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415
+per-file-ignores =
+    */__init__.py: F401
 # Specify the list of error codes you wish Flake8 to report.
 select =
     E,

@@ -30,6 +30,9 @@ def main(config, args):
 if __name__ == "__main__":
     parser = default_argument_parser()
+    # save jit model to
+    parser.add_argument(
+        "--export_path", type=str, help="path of the jit model to save")
     parser.add_argument("--model_type")
     args = parser.parse_args()
     if args.model_type is None:

@@ -31,6 +31,9 @@ def main(config, args):
 if __name__ == "__main__":
     parser = default_argument_parser()
     parser.add_argument("--model_type")
+    # save asr result to
+    parser.add_argument(
+        "--result_file", type=str, help="path of save the asr result")
     args = parser.parse_args()
     print_arguments(args, globals())
     if args.model_type is None:

@@ -30,6 +30,9 @@ def main(config, args):
 if __name__ == "__main__":
     parser = default_argument_parser()
+    # save asr result to
+    parser.add_argument(
+        "--result_file", type=str, help="path of save the asr result")
     args = parser.parse_args()
     print_arguments(args, globals())

@@ -30,6 +30,9 @@ def main(config, args):
 if __name__ == "__main__":
     parser = default_argument_parser()
+    # save jit model to
+    parser.add_argument(
+        "--export_path", type=str, help="path of the jit model to save")
     args = parser.parse_args()
     print_arguments(args, globals())

@@ -34,6 +34,9 @@ def main(config, args):
 if __name__ == "__main__":
     parser = default_argument_parser()
+    # save asr result to
+    parser.add_argument(
+        "--result_file", type=str, help="path of save the asr result")
     args = parser.parse_args()
     print_arguments(args, globals())

@@ -13,6 +13,7 @@
 # limitations under the License.
 """Evaluation for U2 model."""
 import cProfile
 from yacs.config import CfgNode
 from deepspeech.training.cli import default_argument_parser
@@ -54,6 +55,14 @@ if __name__ == "__main__":
         type=str,
         default='test',
         help='run mode, e.g. test, align, export')
+    parser.add_argument(
+        '--dict-path', type=str, default=None, help='dict path.')
+    # save asr result to
+    parser.add_argument(
+        "--result-file", type=str, help="path of save the asr result")
+    # save jit model to
+    parser.add_argument(
+        "--export-path", type=str, help="path of the jit model to save")
     args = parser.parse_args()
     print_arguments(args, globals())
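A note on the new flags: argparse exposes hyphenated option names as underscore attributes, so `--dict-path`, `--result-file` and `--export-path` are read back as `args.dict_path`, `args.result_file` and `args.export_path`. A minimal standalone sketch, not the project's actual parser; the sample values are made up:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--dict-path", type=str, default=None, help="dict path.")
    parser.add_argument("--result-file", type=str, help="path to save the asr result")
    parser.add_argument("--export-path", type=str, help="path of the jit model to save")

    # hyphens in option names become underscores in the parsed Namespace
    args = parser.parse_args(
        ["--dict-path", "data/units.txt", "--result-file", "exp/test.rsl"])
    print(args.dict_path, args.result_file, args.export_path)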

@@ -25,6 +25,8 @@ import paddle
 from paddle import distributed as dist
 from yacs.config import CfgNode
+from deepspeech.frontend.featurizer import TextFeaturizer
+from deepspeech.frontend.utility import load_dict
 from deepspeech.io.dataloader import BatchDataLoader
 from deepspeech.models.u2 import U2Model
 from deepspeech.training.optimizer import OptimizerFactory
@@ -80,8 +82,8 @@ class U2Trainer(Trainer):
     def train_batch(self, batch_index, batch_data, msg):
         train_conf = self.config.training
         start = time.time()
-        utt, audio, audio_len, text, text_len = batch_data
+        utt, audio, audio_len, text, text_len = batch_data
         loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
                                                     text_len)
         # loss div by `batch_size * accum_grad`
@@ -124,6 +126,7 @@ class U2Trainer(Trainer):
         valid_losses = defaultdict(list)
         num_seen_utts = 1
         total_loss = 0.0
         for i, batch in enumerate(self.valid_loader):
             utt, audio, audio_len, text, text_len = batch
             loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
@@ -305,10 +308,8 @@ class U2Trainer(Trainer):
         model_conf.output_dim = self.train_loader.vocab_size
         model_conf.freeze()
         model = U2Model.from_config(model_conf)
         if self.parallel:
             model = paddle.DataParallel(model)
         logger.info(f"{model}")
         layer_tools.print_params(model, logger.info)
@@ -379,13 +380,13 @@ class U2Tester(U2Trainer):
     def __init__(self, config, args):
         super().__init__(config, args)
-    def ordid2token(self, texts, texts_len):
+    def id2token(self, texts, texts_len, text_feature):
         """ ord() id to chr() chr """
         trans = []
         for text, n in zip(texts, texts_len):
             n = n.numpy().item()
             ids = text[:n]
-            trans.append(''.join([chr(i) for i in ids]))
+            trans.append(text_feature.defeaturize(ids.numpy().tolist()))
         return trans
     def compute_metrics(self,
@@ -401,8 +402,11 @@ class U2Tester(U2Trainer):
         error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
         start_time = time.time()
-        text_feature = self.test_loader.collate_fn.text_feature
-        target_transcripts = self.ordid2token(texts, texts_len)
+        text_feature = TextFeaturizer(
+            unit_type=self.config.collator.unit_type,
+            vocab_filepath=self.config.collator.vocab_filepath,
+            spm_model_prefix=self.config.collator.spm_model_prefix)
+        target_transcripts = self.id2token(texts, texts_len, text_feature)
         result_transcripts = self.model.decode(
             audio,
             audio_len,
@@ -450,7 +454,7 @@ class U2Tester(U2Trainer):
         self.model.eval()
         logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
-        stride_ms = self.test_loader.collate_fn.stride_ms
+        stride_ms = self.config.collator.stride_ms
         error_rate_type = None
         errors_sum, len_refs, num_ins = 0.0, 0, 0
         num_frames = 0.0
@@ -525,8 +529,9 @@ class U2Tester(U2Trainer):
         self.model.eval()
         logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}")
-        stride_ms = self.config.collate.stride_ms
-        token_dict = self.align_loader.collate_fn.vocab_list
+        stride_ms = self.config.collator.stride_ms
+        token_dict = self.args.char_list
         with open(self.args.result_file, 'w') as fout:
             # one example in batch
             for i, batch in enumerate(self.align_loader):
@@ -613,6 +618,11 @@ class U2Tester(U2Trainer):
         except KeyboardInterrupt:
             sys.exit(-1)
+    def setup_dict(self):
+        # load dictionary for debug log
+        self.args.char_list = load_dict(self.args.dict_path,
+                                        "maskctc" in self.args.model_name)
     def setup(self):
         """Setup the experiment.
         """
@@ -624,6 +634,8 @@ class U2Tester(U2Trainer):
         self.setup_dataloader()
         self.setup_model()
+        self.setup_dict()
         self.iteration = 0
         self.epoch = 0
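Taken together, the tester hunks above drop the collate_fn-based vocabulary lookups: target ids are converted to text through a TextFeaturizer built from `config.collator`, and the alignment path takes its token list from the dictionary file loaded by `setup_dict`. A rough sketch of the new id-to-text path, assuming the names from the diff; the featurizer arguments and sample tensors are illustrative, not taken from the repo's configs:

    import paddle
    from deepspeech.frontend.featurizer import TextFeaturizer

    def id2token(texts, texts_len, text_feature):
        """Same logic as the refactored U2Tester.id2token above."""
        trans = []
        for text, n in zip(texts, texts_len):
            n = n.numpy().item()          # valid length of this utterance
            ids = text[:n]                # strip padding
            trans.append(text_feature.defeaturize(ids.numpy().tolist()))
        return trans

    # hypothetical usage with a char-level vocabulary file
    text_feature = TextFeaturizer(
        unit_type='char',
        vocab_filepath='data/vocab.txt',   # placeholder path
        spm_model_prefix=None)
    texts = paddle.to_tensor([[12, 7, 33, 0]])   # padded token ids, batch of one
    texts_len = paddle.to_tensor([3])            # valid lengths
    print(id2token(texts, texts_len, text_feature))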

@@ -30,6 +30,9 @@ def main(config, args):
 if __name__ == "__main__":
     parser = default_argument_parser()
+    # save jit model to
+    parser.add_argument(
+        "--export_path", type=str, help="path of the jit model to save")
     args = parser.parse_args()
     print_arguments(args, globals())

@@ -34,6 +34,9 @@ def main(config, args):
 if __name__ == "__main__":
     parser = default_argument_parser()
+    # save asr result to
+    parser.add_argument(
+        "--result_file", type=str, help="path of save the asr result")
     args = parser.parse_args()
     print_arguments(args, globals())

@@ -11,3 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from .audio_featurizer import AudioFeaturizer #noqa: F401
+from .speech_featurizer import SpeechFeaturizer
+from .text_featurizer import TextFeaturizer
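These re-exports let callers import the featurizers from the package itself, which is exactly what the trainer hunk above relies on. Equivalent spellings after this change:

    # package-level imports made possible by the new __init__.py; the longer
    # submodule paths (e.g. deepspeech.frontend.featurizer.text_featurizer)
    # keep working as before.
    from deepspeech.frontend.featurizer import AudioFeaturizer
    from deepspeech.frontend.featurizer import SpeechFeaturizer
    from deepspeech.frontend.featurizer import TextFeaturizer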

@@ -18,7 +18,7 @@ from python_speech_features import logfbank
 from python_speech_features import mfcc
-class AudioFeaturizer(object):
+class AudioFeaturizer():
     """Audio featurizer, for extracting features from audio contents of
     AudioSegment or SpeechSegment.

@@ -16,7 +16,7 @@ from deepspeech.frontend.featurizer.audio_featurizer import AudioFeaturizer
 from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
-class SpeechFeaturizer(object):
+class SpeechFeaturizer():
     """Speech featurizer, for extracting features from both audio and transcript
     contents of SpeechSegment.

@@ -14,12 +14,19 @@
 """Contains the text featurizer class."""
 import sentencepiece as spm
-from deepspeech.frontend.utility import EOS
-from deepspeech.frontend.utility import UNK
+from ..utility import EOS
+from ..utility import load_dict
+from ..utility import UNK
+__all__ = ["TextFeaturizer"]
-class TextFeaturizer(object):
-    def __init__(self, unit_type, vocab_filepath, spm_model_prefix=None):
+class TextFeaturizer():
+    def __init__(self,
+                 unit_type,
+                 vocab_filepath,
+                 spm_model_prefix=None,
+                 maskctc=False):
         """Text featurizer, for processing or extracting features from text.
         Currently, it supports char/word/sentence-piece level tokenizing and conversion into
@@ -34,11 +41,12 @@ class TextFeaturizer(object):
         assert unit_type in ('char', 'spm', 'word')
         self.unit_type = unit_type
         self.unk = UNK
+        self.maskctc = maskctc
         if vocab_filepath:
-            self._vocab_dict, self._id2token, self._vocab_list = self._load_vocabulary_from_file(
-                vocab_filepath)
-            self.unk_id = self._vocab_list.index(self.unk)
-            self.eos_id = self._vocab_list.index(EOS)
+            self.vocab_dict, self._id2token, self.vocab_list, self.unk_id, self.eos_id = self._load_vocabulary_from_file(
+                vocab_filepath, maskctc)
+            self.vocab_size = len(self.vocab_list)
         if unit_type == 'spm':
             spm_model = spm_model_prefix + '.model'
@@ -67,7 +75,7 @@ class TextFeaturizer(object):
         """Convert text string to a list of token indices.
         Args:
-            text (str): Text to process.
+            text (str): Text.
         Returns:
             List[int]: List of token indices.
@@ -75,8 +83,8 @@ class TextFeaturizer(object):
         tokens = self.tokenize(text)
         ids = []
         for token in tokens:
-            token = token if token in self._vocab_dict else self.unk
-            ids.append(self._vocab_dict[token])
+            token = token if token in self.vocab_dict else self.unk
+            ids.append(self.vocab_dict[token])
         return ids
     def defeaturize(self, idxs):
@@ -87,7 +95,7 @@ class TextFeaturizer(object):
             idxs (List[int]): List of token indices.
         Returns:
-            str: Text to process.
+            str: Text.
         """
         tokens = []
         for idx in idxs:
@@ -97,33 +105,6 @@ class TextFeaturizer(object):
         text = self.detokenize(tokens)
         return text
-    @property
-    def vocab_size(self):
-        """Return the vocabulary size.
-        :return: Vocabulary size.
-        :rtype: int
-        """
-        return len(self._vocab_list)
-    @property
-    def vocab_list(self):
-        """Return the vocabulary in list.
-        Returns:
-            List[str]: tokens.
-        """
-        return self._vocab_list
-    @property
-    def vocab_dict(self):
-        """Return the vocabulary in dict.
-        Returns:
-            Dict[str, int]: token str -> int
-        """
-        return self._vocab_dict
     def char_tokenize(self, text):
         """Character tokenizer.
@@ -206,14 +187,16 @@ class TextFeaturizer(object):
         return decode(tokens)
-    def _load_vocabulary_from_file(self, vocab_filepath):
+    def _load_vocabulary_from_file(self, vocab_filepath: str, maskctc: bool):
         """Load vocabulary from file."""
-        vocab_lines = []
-        with open(vocab_filepath, 'r', encoding='utf-8') as file:
-            vocab_lines.extend(file.readlines())
-        vocab_list = [line[:-1] for line in vocab_lines]
+        vocab_list = load_dict(vocab_filepath, maskctc)
+        assert vocab_list is not None
         id2token = dict(
             [(idx, token) for (idx, token) in enumerate(vocab_list)])
         token2id = dict(
             [(token, idx) for (idx, token) in enumerate(vocab_list)])
-        return token2id, id2token, vocab_list
+        unk_id = vocab_list.index(UNK)
+        eos_id = vocab_list.index(EOS)
+        return token2id, id2token, vocab_list, unk_id, eos_id
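After this refactor, `vocab_dict`, `vocab_list` and `vocab_size` are plain attributes set in `__init__` (the old read-only properties are gone), and `unk_id`/`eos_id` come back from `_load_vocabulary_from_file` together with the vocabulary; `maskctc=True` is forwarded to `load_dict` so `<mask>` is included. A hedged sketch of the new surface, with placeholder path and unit type:

    from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer

    featurizer = TextFeaturizer(
        unit_type='char',
        vocab_filepath='data/vocab.txt',   # placeholder vocabulary file
        maskctc=True)                      # ensures "<mask>" is in the loaded dict

    print(featurizer.vocab_size)           # attribute now, not a @property
    print(featurizer.vocab_list[:5])       # plain list attribute
    print(featurizer.unk_id, featurizer.eos_id)

    # token ids via the plain vocab_dict attribute, back to text via defeaturize
    ids = [featurizer.vocab_dict.get(tok, featurizer.unk_id) for tok in "hi"]
    print(featurizer.defeaturize(ids))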

@@ -15,6 +15,9 @@
 import codecs
 import json
 import math
+from typing import List
+from typing import Optional
+from typing import Text
 import numpy as np
@@ -23,16 +26,35 @@ from deepspeech.utils.log import Log
 logger = Log(__name__).getlog()
 __all__ = [
-    "load_cmvn", "read_manifest", "rms_to_db", "rms_to_dbfs", "max_dbfs",
-    "mean_dbfs", "gain_db_to_ratio", "normalize_audio", "SOS", "EOS", "UNK",
-    "BLANK"
+    "load_dict", "load_cmvn", "read_manifest", "rms_to_db", "rms_to_dbfs",
+    "max_dbfs", "mean_dbfs", "gain_db_to_ratio", "normalize_audio", "SOS",
+    "EOS", "UNK", "BLANK", "MASKCTC"
 ]
 IGNORE_ID = -1
-SOS = "<sos/eos>"
+# `sos` and `eos` using same token
+SOS = "<eos>"
 EOS = SOS
 UNK = "<unk>"
 BLANK = "<blank>"
+MASKCTC = "<mask>"
+def load_dict(dict_path: Optional[Text], maskctc=False) -> Optional[List[Text]]:
+    if dict_path is None:
+        return None
+    with open(dict_path, "r") as f:
+        dictionary = f.readlines()
+    char_list = [entry.split(" ")[0] for entry in dictionary]
+    if BLANK not in char_list:
+        char_list.insert(0, BLANK)
+    if EOS not in char_list:
+        char_list.append(EOS)
+    # for non-autoregressive maskctc model
+    if maskctc and MASKCTC not in char_list:
+        char_list.append(MASKCTC)
+    return char_list
 def read_manifest(
@@ -47,12 +69,20 @@ read_manifest(
     Args:
         manifest_path ([type]): Manifest file to load and parse.
-        max_input_len ([type], optional): maximum output seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to float('inf').
-        min_input_len (float, optional): minimum input seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to 0.0.
-        max_output_len (float, optional): maximum input seq length, in modeling units. Defaults to 500.0.
-        min_output_len (float, optional): minimum input seq length, in modeling units. Defaults to 0.0.
-        max_output_input_ratio (float, optional): maximum output seq length/output seq length ratio. Defaults to 10.0.
-        min_output_input_ratio (float, optional): minimum output seq length/output seq length ratio. Defaults to 0.05.
+        max_input_len ([type], optional): maximum output seq length,
+            in seconds for raw wav, in frame numbers for feature data.
+            Defaults to float('inf').
+        min_input_len (float, optional): minimum input seq length,
+            in seconds for raw wav, in frame numbers for feature data.
+            Defaults to 0.0.
+        max_output_len (float, optional): maximum input seq length,
+            in modeling units. Defaults to 500.0.
+        min_output_len (float, optional): minimum input seq length,
+            in modeling units. Defaults to 0.0.
+        max_output_input_ratio (float, optional):
+            maximum output seq length/output seq length ratio. Defaults to 10.0.
+        min_output_input_ratio (float, optional):
+            minimum output seq length/output seq length ratio. Defaults to 0.05.
     Raises:
         IOError: If failed to parse the manifest.
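The new `load_dict` helper reads one token per line, keeps only the first whitespace-separated field, inserts `<blank>` at index 0 and appends `<eos>` when they are missing, and appends `<mask>` for maskctc models; `None` passes straight through, so a missing `--dict-path` stays optional. A small self-contained check; the file name and contents are made up:

    from deepspeech.frontend.utility import load_dict

    # toy dictionary file where each line is "token index"
    with open("toy_units.txt", "w") as f:
        f.write("<unk> 1\n")
        f.write("a 2\n")
        f.write("b 3\n")

    char_list = load_dict("toy_units.txt", maskctc=True)
    # "<blank>" is inserted at index 0, "<eos>" appended, and "<mask>" appended
    # because maskctc=True.
    print(char_list)   # ['<blank>', '<unk>', 'a', 'b', '<eos>', '<mask>']

    assert load_dict(None) is None   # no dict path -> no char list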

@@ -47,18 +47,11 @@ def default_argument_parser():
     # data and output
     parser.add_argument("--config", metavar="FILE", help="path of the config file to overwrite to default config with.")
     parser.add_argument("--dump-config", metavar="FILE", help="dump config to yaml file.")
-    # parser.add_argument("--data", metavar="DATA_DIR", help="path to the datatset.")
     parser.add_argument("--output", metavar="OUTPUT_DIR", help="path to save checkpoint and logs.")
     # load from saved checkpoint
     parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load")
-    # save jit model to
-    parser.add_argument("--export_path", type=str, help="path of the jit model to save")
-    # save asr result to
-    parser.add_argument("--result_file", type=str, help="path of save the asr result")
     # running
     parser.add_argument("--device", type=str, default='gpu', choices=["cpu", "gpu"],
                         help="device type to use, cpu and gpu are supported.")

@@ -33,4 +33,4 @@
         },
         "prob": 1.0
     }
 ]

@@ -3,17 +3,11 @@ data:
   train_manifest: data/manifest.train
   dev_manifest: data/manifest.dev
   test_manifest: data/manifest.test-clean
-  min_input_len: 0.5 # second
-  max_input_len: 20.0 # second
-  min_output_len: 0.0 # tokens
-  max_output_len: 400.0 # tokens
-  min_output_input_ratio: 0.05
-  max_output_input_ratio: 10.0
 collator:
-  vocab_filepath: data/vocab.txt
+  vocab_filepath: data/train_960_unigram5000_units.txt
   unit_type: 'spm'
-  spm_model_prefix: 'data/bpe_unigram_5000'
+  spm_model_prefix: 'data/train_960_unigram5000'
   mean_std_filepath: ""
   augmentation_config: conf/augmentation.json
   batch_size: 64

@@ -1,7 +1,7 @@
 #!/bin/bash
-if [ $# != 2 ];then
-    echo "usage: ${0} config_path ckpt_path_prefix"
+if [ $# != 3 ];then
+    echo "usage: ${0} config_path dict_path ckpt_path_prefix"
     exit -1
 fi
@@ -13,7 +13,8 @@ if [ ${ngpu} == 0 ];then
     device=cpu
 fi
 config_path=$1
-ckpt_prefix=$2
+dict_path=$2
+ckpt_prefix=$3
 batch_size=1
 output_dir=${ckpt_prefix}
@@ -22,11 +23,13 @@ mkdir -p ${output_dir}
 # align dump in `result_file`
 # .tier, .TextGrid dump in `dir of result_file`
 python3 -u ${BIN_DIR}/test.py \
---run_mode 'align' \
+--model-name 'u2_kaldi' \
+--run-mode 'align' \
+--dict-path ${dict_path} \
 --device ${device} \
 --nproc 1 \
 --config ${config_path} \
---result_file ${output_dir}/${type}.align \
+--result-file ${output_dir}/${type}.align \
 --checkpoint_path ${ckpt_prefix} \
 --opts decoding.batch_size ${batch_size}

@@ -18,7 +18,8 @@ if [ ${ngpu} == 0 ];then
 fi
 python3 -u ${BIN_DIR}/test.py \
---run_mode 'export' \
+--model-name 'u2_kaldi' \
+--run-mode 'export' \
 --device ${device} \
 --nproc ${ngpu} \
 --config ${config_path} \

@@ -1,7 +1,7 @@
 #!/bin/bash
-if [ $# != 2 ];then
-    echo "usage: ${0} config_path ckpt_path_prefix"
+if [ $# != 3 ];then
+    echo "usage: ${0} config_path dict_path ckpt_path_prefix"
     exit -1
 fi
@@ -14,7 +14,8 @@ if [ ${ngpu} == 0 ];then
 fi
 config_path=$1
-ckpt_prefix=$2
+dict_path=$2
+ckpt_prefix=$3
 chunk_mode=false
 if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
@@ -38,11 +39,13 @@ for type in attention ctc_greedy_search; do
         batch_size=64
     fi
     python3 -u ${BIN_DIR}/test.py \
-    --run_mode test \
+    --model-name u2_kaldi \
+    --run-mode test \
+    --dict-path ${dict_path} \
     --device ${device} \
     --nproc 1 \
     --config ${config_path} \
-    --result_file ${ckpt_prefix}.${type}.rsl \
+    --result-file ${ckpt_prefix}.${type}.rsl \
     --checkpoint_path ${ckpt_prefix} \
     --opts decoding.decoding_method ${type} decoding.batch_size ${batch_size}
@@ -56,11 +59,13 @@ for type in ctc_prefix_beam_search attention_rescoring; do
     echo "decoding ${type}"
     batch_size=1
     python3 -u ${BIN_DIR}/test.py \
-    --run_mode test \
+    --model-name u2_kaldi \
+    --run-mode test \
+    --dict-path ${dict_path} \
     --device ${device} \
     --nproc 1 \
     --config ${config_path} \
-    --result_file ${ckpt_prefix}.${type}.rsl \
+    --result-file ${ckpt_prefix}.${type}.rsl \
     --checkpoint_path ${ckpt_prefix} \
     --opts decoding.decoding_method ${type} decoding.batch_size ${batch_size}

@@ -5,6 +5,7 @@ source path.sh
 stage=0
 stop_stage=100
 conf_path=conf/transformer.yaml
+dict_path=data/train_960_unigram5000_units.txt
 avg_num=5
 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
@@ -29,12 +30,12 @@ fi
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
     # test ckpt avg_n
-    CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
+    CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} ${dict_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
 fi
 if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
     # ctc alignment of test data
-    CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
+    CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} ${dict_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
 fi
 if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then

@@ -29,8 +29,7 @@
             "adaptive_number_ratio": 0,
             "adaptive_size_ratio": 0,
             "max_n_time_masks": 20,
-            "replace_with_zero": true,
-            "warp_mode": "PIL"
+            "replace_with_zero": true
         },
         "prob": 1.0
     }
