Optimize the 1xt2x (v1.8-to-v2.x model conversion) code

pull/851/head
huangyuxin 3 years ago
parent 6dc32d185b
commit 30b3e237e2

@ -84,8 +84,6 @@ class TextFeaturizer():
tokens = self.tokenize(text) tokens = self.tokenize(text)
ids = [] ids = []
for token in tokens: for token in tokens:
if '' in self.vocab_dict and token == ' ':
token = ''
token = token if token in self.vocab_dict else self.unk token = token if token in self.vocab_dict else self.unk
ids.append(self.vocab_dict[token]) ids.append(self.vocab_dict[token])
return ids return ids
@ -201,7 +199,6 @@ class TextFeaturizer():
"""Load vocabulary from file.""" """Load vocabulary from file."""
vocab_list = load_dict(vocab_filepath, maskctc) vocab_list = load_dict(vocab_filepath, maskctc)
assert vocab_list is not None assert vocab_list is not None
assert SPACE in vocab_list
id2token = dict( id2token = dict(
[(idx, token) for (idx, token) in enumerate(vocab_list)]) [(idx, token) for (idx, token) in enumerate(vocab_list)])

@ -10,7 +10,7 @@ ckpt_dir=$1
. ${MAIN_ROOT}/utils/utility.sh . ${MAIN_ROOT}/utils/utility.sh
URL='https://deepspeech.bj.bcebos.com/mandarin_models/aishell_model_v1.8_to_v2.x.tar.gz' URL='https://deepspeech.bj.bcebos.com/mandarin_models/aishell_model_v1.8_to_v2.x.tar.gz'
MD5=4ade113c69ea291b8ce5ec6a03296659 MD5=87e7577d4bea737dbf3e8daab37aa808
TARGET=${ckpt_dir}/aishell_model_v1.8_to_v2.x.tar.gz TARGET=${ckpt_dir}/aishell_model_v1.8_to_v2.x.tar.gz

@ -7,6 +7,7 @@ stop_stage=100
conf_path=conf/deepspeech2.yaml conf_path=conf/deepspeech2.yaml
avg_num=1 avg_num=1
model_type=offline model_type=offline
gpus=2
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
@ -22,6 +23,6 @@ fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# test ckpt avg_n # test ckpt avg_n
CUDA_VISIBLE_DEVICES=2 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1 CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1
fi fi

@ -10,7 +10,7 @@ ckpt_dir=$1
. ${MAIN_ROOT}/utils/utility.sh . ${MAIN_ROOT}/utils/utility.sh
URL='https://deepspeech.bj.bcebos.com/eng_models/baidu_en8k_v1.8_to_v2.x.tar.gz' URL='https://deepspeech.bj.bcebos.com/eng_models/baidu_en8k_v1.8_to_v2.x.tar.gz'
MD5=fdabeb6c96963ac85d9188f0275c6a1b MD5=c1676be8505cee436e6f312823e9008c
TARGET=${ckpt_dir}/baidu_en8k_v1.8_to_v2.x.tar.gz TARGET=${ckpt_dir}/baidu_en8k_v1.8_to_v2.x.tar.gz

@ -7,6 +7,7 @@ stop_stage=100
conf_path=conf/deepspeech2.yaml conf_path=conf/deepspeech2.yaml
avg_num=1 avg_num=1
model_type=offline model_type=offline
gpus=0
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
@ -22,6 +23,6 @@ fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# test ckpt avg_n # test ckpt avg_n
CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1 CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1
fi fi

@ -10,7 +10,7 @@ ckpt_dir=$1
. ${MAIN_ROOT}/utils/utility.sh . ${MAIN_ROOT}/utils/utility.sh
URL='https://deepspeech.bj.bcebos.com/eng_models/librispeech_v1.8_to_v2.x.tar.gz' URL='https://deepspeech.bj.bcebos.com/eng_models/librispeech_v1.8_to_v2.x.tar.gz'
MD5=7b0f582fe2f5a840b840e7ee52246bc5 MD5=a06d9aadb560ea113984dc98d67232c8
TARGET=${ckpt_dir}/librispeech_v1.8_to_v2.x.tar.gz TARGET=${ckpt_dir}/librispeech_v1.8_to_v2.x.tar.gz

@ -7,6 +7,7 @@ stop_stage=100
conf_path=conf/deepspeech2.yaml conf_path=conf/deepspeech2.yaml
avg_num=1 avg_num=1
model_type=offline model_type=offline
gpus=1
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
@ -22,5 +23,5 @@ fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# test ckpt avg_n # test ckpt avg_n
CUDA_VISIBLE_DEVICES=1 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1 CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1
fi fi

@ -26,6 +26,7 @@ from src_deepspeech2x.models.ds2 import DeepSpeech2InferModel
from src_deepspeech2x.models.ds2 import DeepSpeech2Model from src_deepspeech2x.models.ds2 import DeepSpeech2Model
from yacs.config import CfgNode from yacs.config import CfgNode
from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
from deepspeech.io.collator import SpeechCollator from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset from deepspeech.io.dataset import ManifestDataset
from deepspeech.io.sampler import SortagradBatchSampler from deepspeech.io.sampler import SortagradBatchSampler
@ -38,7 +39,6 @@ from deepspeech.utils import error_rate
from deepspeech.utils import layer_tools from deepspeech.utils import layer_tools
from deepspeech.utils import mp_tools from deepspeech.utils import mp_tools
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
#from deepspeech.utils.log import Autolog
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
@ -272,6 +272,9 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
return default return default
def __init__(self, config, args): def __init__(self, config, args):
self._text_featurizer = TextFeaturizer(
unit_type=config.collator.unit_type, vocab_filepath=None)
super().__init__(config, args) super().__init__(config, args)
def ordid2token(self, texts, texts_len): def ordid2token(self, texts, texts_len):
@ -296,9 +299,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
vocab_list = self.test_loader.collate_fn.vocab_list vocab_list = self.test_loader.collate_fn.vocab_list
if "" in vocab_list:
space_id = vocab_list.index("")
vocab_list[space_id] = " "
target_transcripts = self.ordid2token(texts, texts_len) target_transcripts = self.ordid2token(texts, texts_len)
@ -337,6 +337,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
cutoff_prob=cfg.cutoff_prob, cutoff_prob=cfg.cutoff_prob,
cutoff_top_n=cfg.cutoff_top_n, cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch) num_processes=cfg.num_proc_bsearch)
result_transcripts = [
self._text_featurizer.detokenize(item)
for item in result_transcripts
]
return result_transcripts return result_transcripts
@mp_tools.rank_zero_only @mp_tools.rank_zero_only
@ -367,8 +371,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
error_rate_type, num_ins, num_ins, errors_sum / len_refs) error_rate_type, num_ins, num_ins, errors_sum / len_refs)
logger.info(msg) logger.info(msg)
# self.autolog.report()
def run_test(self): def run_test(self):
self.resume_or_scratch() self.resume_or_scratch()
try: try:

Loading…
Cancel
Save