optimize the 1xt2x

pull/851/head
huangyuxin 3 years ago
parent 6dc32d185b
commit 30b3e237e2

@@ -84,8 +84,6 @@ class TextFeaturizer():
         tokens = self.tokenize(text)
         ids = []
         for token in tokens:
-            if '<space>' in self.vocab_dict and token == ' ':
-                token = '<space>'
             token = token if token in self.vocab_dict else self.unk
             ids.append(self.vocab_dict[token])
         return ids
@@ -201,7 +199,6 @@ class TextFeaturizer():
         """Load vocabulary from file."""
         vocab_list = load_dict(vocab_filepath, maskctc)
         assert vocab_list is not None
-        assert SPACE in vocab_list
         id2token = dict(
             [(idx, token) for (idx, token) in enumerate(vocab_list)])
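Taken together, the two hunks above drop the featurizer's special-casing of spaces: a vocabulary no longer has to contain the SPACE symbol, and featurize() simply falls back to the unknown token. A minimal stand-alone sketch of the resulting behavior (hypothetical vocabulary, not the repo's code):

vocab_dict = {'<unk>': 0, 'n': 1, 'i': 2}  # hypothetical vocab with no '<space>' entry
unk = '<unk>'

def featurize(tokens):
    # Unknown tokens (including a bare ' ') now map to <unk> instead of '<space>'.
    return [vocab_dict[t if t in vocab_dict else unk] for t in tokens]

print(featurize(['n', 'i', ' ']))  # -> [1, 2, 0]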

@@ -10,7 +10,7 @@ ckpt_dir=$1
 . ${MAIN_ROOT}/utils/utility.sh
 URL='https://deepspeech.bj.bcebos.com/mandarin_models/aishell_model_v1.8_to_v2.x.tar.gz'
-MD5=4ade113c69ea291b8ce5ec6a03296659
+MD5=87e7577d4bea737dbf3e8daab37aa808
 TARGET=${ckpt_dir}/aishell_model_v1.8_to_v2.x.tar.gz
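The checksum bump means a previously downloaded tarball will fail verification and be fetched again. A hedged illustration of the check with plain hashlib (the actual verification lives in utils/utility.sh, whose internals are not shown here):

import hashlib

def md5_of(path, chunk_size=1 << 20):
    # Stream the file in 1 MiB chunks so large tarballs don't load into memory.
    h = hashlib.md5()
    with open(path, 'rb') as f:
        for block in iter(lambda: f.read(chunk_size), b''):
            h.update(block)
    return h.hexdigest()

# e.g. md5_of('aishell_model_v1.8_to_v2.x.tar.gz') should now equal
# '87e7577d4bea737dbf3e8daab37aa808' for the refreshed model.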

@@ -7,6 +7,7 @@ stop_stage=100
 conf_path=conf/deepspeech2.yaml
 avg_num=1
 model_type=offline
+gpus=2
 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
@@ -22,6 +23,6 @@ fi
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     # test ckpt avg_n
-    CUDA_VISIBLE_DEVICES=2 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1
 fi
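Because gpus is defined before parse_options.sh is sourced, the test device is now a command-line option rather than a hard-coded number; the same pattern is applied to the two recipes below. Illustrative invocation:

bash run.sh --gpus 3   # overrides this recipe's default of gpus=2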

@@ -10,7 +10,7 @@ ckpt_dir=$1
 . ${MAIN_ROOT}/utils/utility.sh
 URL='https://deepspeech.bj.bcebos.com/eng_models/baidu_en8k_v1.8_to_v2.x.tar.gz'
-MD5=fdabeb6c96963ac85d9188f0275c6a1b
+MD5=c1676be8505cee436e6f312823e9008c
 TARGET=${ckpt_dir}/baidu_en8k_v1.8_to_v2.x.tar.gz

@@ -7,6 +7,7 @@ stop_stage=100
 conf_path=conf/deepspeech2.yaml
 avg_num=1
 model_type=offline
+gpus=0
 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
@@ -22,6 +23,6 @@ fi
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     # test ckpt avg_n
-    CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1
 fi

@@ -10,7 +10,7 @@ ckpt_dir=$1
 . ${MAIN_ROOT}/utils/utility.sh
 URL='https://deepspeech.bj.bcebos.com/eng_models/librispeech_v1.8_to_v2.x.tar.gz'
-MD5=7b0f582fe2f5a840b840e7ee52246bc5
+MD5=a06d9aadb560ea113984dc98d67232c8
 TARGET=${ckpt_dir}/librispeech_v1.8_to_v2.x.tar.gz

@@ -7,6 +7,7 @@ stop_stage=100
 conf_path=conf/deepspeech2.yaml
 avg_num=1
 model_type=offline
+gpus=1
 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
@@ -22,5 +23,5 @@ fi
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     # test ckpt avg_n
-    CUDA_VISIBLE_DEVICES=1 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1
 fi

@@ -26,6 +26,7 @@ from src_deepspeech2x.models.ds2 import DeepSpeech2InferModel
 from src_deepspeech2x.models.ds2 import DeepSpeech2Model
 from yacs.config import CfgNode
+from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
 from deepspeech.io.collator import SpeechCollator
 from deepspeech.io.dataset import ManifestDataset
 from deepspeech.io.sampler import SortagradBatchSampler
@@ -38,7 +39,6 @@ from deepspeech.utils import error_rate
 from deepspeech.utils import layer_tools
 from deepspeech.utils import mp_tools
 from deepspeech.utils.log import Log
-#from deepspeech.utils.log import Autolog
 logger = Log(__name__).getlog()
@@ -272,6 +272,9 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         return default
 
     def __init__(self, config, args):
+        self._text_featurizer = TextFeaturizer(
+            unit_type=config.collator.unit_type, vocab_filepath=None)
         super().__init__(config, args)
 
     def ordid2token(self, texts, texts_len):
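With vocab_filepath=None the featurizer is built without loading a vocabulary (possible now that the SPACE assert was dropped above); the tester only needs it to detokenize hypotheses later. A toy stand-in for the construction order, with a lambda in place of TextFeaturizer (illustrative names, not the repo's classes):

class Base:
    def __init__(self):
        print('parent setup runs after the helper exists')

class Tester(Base):
    def __init__(self):
        # Create the helper first so anything triggered by the
        # parent constructor can already use it.
        self._detok = lambda s: s.replace('<space>', ' ')  # stand-in featurizer
        super().__init__()

print(Tester()._detok('ni<space>hao'))  # -> 'ni hao'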
@@ -296,9 +299,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
         vocab_list = self.test_loader.collate_fn.vocab_list
-        if "<space>" in vocab_list:
-            space_id = vocab_list.index("<space>")
-            vocab_list[space_id] = " "
         target_transcripts = self.ordid2token(texts, texts_len)
@@ -337,6 +337,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
             cutoff_prob=cfg.cutoff_prob,
             cutoff_top_n=cfg.cutoff_top_n,
             num_processes=cfg.num_proc_bsearch)
+        result_transcripts = [
+            self._text_featurizer.detokenize(item)
+            for item in result_transcripts
+        ]
         return result_transcripts
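The added comprehension runs every beam-search hypothesis through the featurizer, which replaces the in-place vocab_list patching deleted above. A hedged stand-in for what detokenize accomplishes on char-unit output (detok below is illustrative, not the TextFeaturizer implementation):

result_transcripts = ['ni<space>hao', 'zai<space>jian']  # hypothetical decoder output

def detok(text):
    # Restore real spaces from the literal '<space>' symbol.
    return text.replace('<space>', ' ')

print([detok(t) for t in result_transcripts])  # -> ['ni hao', 'zai jian']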
@mp_tools.rank_zero_only
@@ -367,8 +371,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
             error_rate_type, num_ins, num_ins, errors_sum / len_refs)
         logger.info(msg)
-        # self.autolog.report()
 
     def run_test(self):
         self.resume_or_scratch()
         try:
