|
|
|
@ -12,6 +12,7 @@
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
"""Contains DeepSpeech2 and DeepSpeech2Online model."""
|
|
|
|
|
import os
|
|
|
|
|
import time
|
|
|
|
|
from collections import defaultdict
|
|
|
|
|
from pathlib import Path
|
|
|
|
@ -398,40 +399,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
|
|
|
|
|
self.output_dir = output_dir
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DeepSpeech2ExportTester(DeepSpeech2Trainer):
|
|
|
|
|
@classmethod
|
|
|
|
|
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
|
|
|
|
|
# testing config
|
|
|
|
|
default = CfgNode(
|
|
|
|
|
dict(
|
|
|
|
|
alpha=2.5, # Coef of LM for beam search.
|
|
|
|
|
beta=0.3, # Coef of WC for beam search.
|
|
|
|
|
cutoff_prob=1.0, # Cutoff probability for pruning.
|
|
|
|
|
cutoff_top_n=40, # Cutoff number for pruning.
|
|
|
|
|
lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm', # Filepath for language model.
|
|
|
|
|
decoding_method='ctc_beam_search', # Decoding method. Options: ctc_beam_search, ctc_greedy
|
|
|
|
|
error_rate_type='wer', # Error rate type for evaluation. Options `wer`, 'cer'
|
|
|
|
|
num_proc_bsearch=8, # # of CPUs for beam search.
|
|
|
|
|
beam_size=500, # Beam search width.
|
|
|
|
|
batch_size=128, # decoding batch size
|
|
|
|
|
))
|
|
|
|
|
|
|
|
|
|
if config is not None:
|
|
|
|
|
config.merge_from_other_cfg(default)
|
|
|
|
|
return default
|
|
|
|
|
|
|
|
|
|
class DeepSpeech2ExportTester(DeepSpeech2Tester):
|
|
|
|
|
def __init__(self, config, args):
|
|
|
|
|
super().__init__(config, args)
|
|
|
|
|
|
|
|
|
|
def ordid2token(self, texts, texts_len):
|
|
|
|
|
""" ord() id to chr() chr """
|
|
|
|
|
trans = []
|
|
|
|
|
for text, n in zip(texts, texts_len):
|
|
|
|
|
n = n.numpy().item()
|
|
|
|
|
ids = text[:n]
|
|
|
|
|
trans.append(''.join([chr(i) for i in ids]))
|
|
|
|
|
return trans
|
|
|
|
|
|
|
|
|
|
def compute_metrics(self,
|
|
|
|
|
utts,
|
|
|
|
|
audio,
|
|
|
|
@ -447,9 +418,48 @@ class DeepSpeech2ExportTester(DeepSpeech2Trainer):
|
|
|
|
|
|
|
|
|
|
vocab_list = self.test_loader.collate_fn.vocab_list
|
|
|
|
|
|
|
|
|
|
batch_size = self.config.decoding.batch_size
|
|
|
|
|
if self.args.model_type == "online":
|
|
|
|
|
output_probs_branch, output_lens_branch = self.static_forward_online(
|
|
|
|
|
audio, audio_len)
|
|
|
|
|
elif self.args.model_type == "offline":
|
|
|
|
|
output_probs_branch, output_lens_branch = self.static_forward_offline(
|
|
|
|
|
audio, audio_len)
|
|
|
|
|
else:
|
|
|
|
|
raise Exception("wrong model type")
|
|
|
|
|
self.predictor.clear_intermediate_tensor()
|
|
|
|
|
self.predictor.try_shrink_memory()
|
|
|
|
|
self.model.decoder.init_decode(cfg.alpha, cfg.beta, cfg.lang_model_path,
|
|
|
|
|
vocab_list, cfg.decoding_method)
|
|
|
|
|
|
|
|
|
|
result_transcripts = self.model.decoder.decode_probs(
|
|
|
|
|
output_probs_branch.numpy(), output_lens_branch, vocab_list,
|
|
|
|
|
cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
|
|
|
|
|
cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
|
|
|
|
|
cfg.num_proc_bsearch)
|
|
|
|
|
|
|
|
|
|
output_prob_list = []
|
|
|
|
|
target_transcripts = self.ordid2token(texts, texts_len)
|
|
|
|
|
for utt, target, result in zip(utts, target_transcripts,
|
|
|
|
|
result_transcripts):
|
|
|
|
|
errors, len_ref = errors_func(target, result)
|
|
|
|
|
errors_sum += errors
|
|
|
|
|
len_refs += len_ref
|
|
|
|
|
num_ins += 1
|
|
|
|
|
if fout:
|
|
|
|
|
fout.write(utt + " " + result + "\n")
|
|
|
|
|
logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
|
|
|
|
|
(target, result))
|
|
|
|
|
logger.info("Current error rate [%s] = %f" %
|
|
|
|
|
(cfg.error_rate_type, error_rate_func(target, result)))
|
|
|
|
|
|
|
|
|
|
return dict(
|
|
|
|
|
errors_sum=errors_sum,
|
|
|
|
|
len_refs=len_refs,
|
|
|
|
|
num_ins=num_ins,
|
|
|
|
|
error_rate=errors_sum / len_refs,
|
|
|
|
|
error_rate_type=cfg.error_rate_type)
|
|
|
|
|
|
|
|
|
|
def static_forward_online(self, audio, audio_len):
|
|
|
|
|
output_probs_list = []
|
|
|
|
|
output_lens_list = []
|
|
|
|
|
decoder_chunk_size = 8
|
|
|
|
|
subsampling_rate = self.model.encoder.conv.subsampling_rate
|
|
|
|
@ -459,15 +469,18 @@ class DeepSpeech2ExportTester(DeepSpeech2Trainer):
|
|
|
|
|
) * subsampling_rate + receptive_field_length
|
|
|
|
|
|
|
|
|
|
x_batch = audio.numpy()
|
|
|
|
|
batch_size = x_batch.shape[0]
|
|
|
|
|
x_len_batch = audio_len.numpy().astype(np.int64)
|
|
|
|
|
max_len_batch = x_batch.shape[1]
|
|
|
|
|
batch_padding_len = chunk_stride - (
|
|
|
|
|
max_len_batch - chunk_size
|
|
|
|
|
) % chunk_stride # The length of padding for the batch
|
|
|
|
|
x_list = np.split(x_batch, x_batch.shape[0], axis=0)
|
|
|
|
|
x_list = np.split(x_batch, batch_size, axis=0)
|
|
|
|
|
x_len_list = np.split(x_len_batch, x_batch.shape[0], axis=0)
|
|
|
|
|
|
|
|
|
|
for x, x_len in zip(x_list, x_len_list):
|
|
|
|
|
self.autolog.times.start()
|
|
|
|
|
self.autolog.times.stamp()
|
|
|
|
|
assert (chunk_size <= x_len[0])
|
|
|
|
|
|
|
|
|
|
eouts_chunk_list = []
|
|
|
|
@ -536,38 +549,40 @@ class DeepSpeech2ExportTester(DeepSpeech2Trainer):
|
|
|
|
|
output_state_c_handle = self.predictor.get_output_handle(
|
|
|
|
|
output_names[3])
|
|
|
|
|
self.predictor.run()
|
|
|
|
|
output_chunk_prob = output_handle.copy_to_cpu()
|
|
|
|
|
output_chunk_probs = output_handle.copy_to_cpu()
|
|
|
|
|
output_chunk_lens = output_lens_handle.copy_to_cpu()
|
|
|
|
|
chunk_state_h_box = output_state_h_handle.copy_to_cpu()
|
|
|
|
|
chunk_state_c_box = output_state_c_handle.copy_to_cpu()
|
|
|
|
|
output_chunk_prob = paddle.to_tensor(output_chunk_prob)
|
|
|
|
|
output_chunk_probs = paddle.to_tensor(output_chunk_probs)
|
|
|
|
|
output_chunk_lens = paddle.to_tensor(output_chunk_lens)
|
|
|
|
|
|
|
|
|
|
probs_chunk_list.append(output_chunk_prob)
|
|
|
|
|
probs_chunk_list.append(output_chunk_probs)
|
|
|
|
|
probs_chunk_lens_list.append(output_chunk_lens)
|
|
|
|
|
output_prob = paddle.concat(probs_chunk_list, axis=1)
|
|
|
|
|
output_probs = paddle.concat(probs_chunk_list, axis=1)
|
|
|
|
|
output_lens = paddle.add_n(probs_chunk_lens_list)
|
|
|
|
|
output_prob_padding_len = max_len_batch + batch_padding_len - output_prob.shape[
|
|
|
|
|
output_probs_padding_len = max_len_batch + batch_padding_len - output_probs.shape[
|
|
|
|
|
1]
|
|
|
|
|
output_prob_padding = paddle.zeros(
|
|
|
|
|
(1, output_prob_padding_len, output_prob.shape[2]),
|
|
|
|
|
output_probs_padding = paddle.zeros(
|
|
|
|
|
(1, output_probs_padding_len, output_probs.shape[2]),
|
|
|
|
|
dtype="float32") # The prob padding for a piece of utterance
|
|
|
|
|
output_prob = paddle.concat(
|
|
|
|
|
[output_prob, output_prob_padding], axis=1)
|
|
|
|
|
output_prob_list.append(output_prob)
|
|
|
|
|
output_probs = paddle.concat(
|
|
|
|
|
[output_probs, output_probs_padding], axis=1)
|
|
|
|
|
output_probs_list.append(output_probs)
|
|
|
|
|
output_lens_list.append(output_lens)
|
|
|
|
|
output_prob_branch = paddle.concat(output_prob_list, axis=0)
|
|
|
|
|
self.autolog.times.stamp()
|
|
|
|
|
self.autolog.times.stamp()
|
|
|
|
|
self.autolog.times.end()
|
|
|
|
|
output_probs_branch = paddle.concat(output_probs_list, axis=0)
|
|
|
|
|
output_lens_branch = paddle.concat(output_lens_list, axis=0)
|
|
|
|
|
"""
|
|
|
|
|
return output_probs_branch, output_lens_branch
|
|
|
|
|
|
|
|
|
|
def static_forward_offline(self, audio, audio_len):
|
|
|
|
|
x = audio.numpy()
|
|
|
|
|
x_len = audio_len.numpy().astype(np.int64)
|
|
|
|
|
|
|
|
|
|
input_names = self.predictor.get_input_names()
|
|
|
|
|
audio_handle = self.predictor.get_input_handle(input_names[0])
|
|
|
|
|
audio_len_handle = self.predictor.get_input_handle(input_names[1])
|
|
|
|
|
h_box_handle = self.predictor.get_input_handle(input_names[2])
|
|
|
|
|
c_box_handle = self.predictor.get_input_handle(input_names[3])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audio_handle.reshape(x.shape)
|
|
|
|
|
audio_handle.copy_from_cpu(x)
|
|
|
|
@ -575,100 +590,21 @@ class DeepSpeech2ExportTester(DeepSpeech2Trainer):
|
|
|
|
|
audio_len_handle.reshape(x_len.shape)
|
|
|
|
|
audio_len_handle.copy_from_cpu(x_len)
|
|
|
|
|
|
|
|
|
|
init_state_h_box = np.zeros((self.config.model.num_rnn_layers, audio.shape[0], self.config.model.rnn_layer_size), dtype=np.float32)
|
|
|
|
|
init_state_c_box = np.zeros((self.config.model.num_rnn_layers, audio.shape[0], self.config.model.rnn_layer_size), dtype=np.float32)
|
|
|
|
|
h_box_handle.reshape(init_state_h_box.shape)
|
|
|
|
|
h_box_handle.copy_from_cpu(init_state_h_box)
|
|
|
|
|
|
|
|
|
|
c_box_handle.reshape(init_state_c_box.shape)
|
|
|
|
|
c_box_handle.copy_from_cpu(init_state_c_box)
|
|
|
|
|
|
|
|
|
|
#self.autolog.times.start()
|
|
|
|
|
#self.autolog.times.stamp()
|
|
|
|
|
self.autolog.times.start()
|
|
|
|
|
self.autolog.times.stamp()
|
|
|
|
|
self.predictor.run()
|
|
|
|
|
self.autolog.times.stamp()
|
|
|
|
|
self.autolog.times.stamp()
|
|
|
|
|
self.autolog.times.end()
|
|
|
|
|
|
|
|
|
|
output_names = self.predictor.get_output_names()
|
|
|
|
|
output_handle = self.predictor.get_output_handle(output_names[0])
|
|
|
|
|
output_lens_handle = self.predictor.get_output_handle(output_names[1])
|
|
|
|
|
output_state_h_handle = self.predictor.get_output_handle(output_names[2])
|
|
|
|
|
output_state_c_handle = self.predictor.get_output_handle(output_names[3])
|
|
|
|
|
output_prob = output_handle.copy_to_cpu()
|
|
|
|
|
output_probs = output_handle.copy_to_cpu()
|
|
|
|
|
output_lens = output_lens_handle.copy_to_cpu()
|
|
|
|
|
output_stata_h_box = output_state_h_handle.copy_to_cpu()
|
|
|
|
|
output_stata_c_box = output_state_c_handle.copy_to_cpu()
|
|
|
|
|
output_prob_branch = paddle.to_tensor(output_prob)
|
|
|
|
|
output_probs_branch = paddle.to_tensor(output_probs)
|
|
|
|
|
output_lens_branch = paddle.to_tensor(output_lens)
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
result_transcripts = self.model.decode_by_probs(
|
|
|
|
|
output_prob_branch,
|
|
|
|
|
output_lens_branch,
|
|
|
|
|
vocab_list,
|
|
|
|
|
decoding_method=cfg.decoding_method,
|
|
|
|
|
lang_model_path=cfg.lang_model_path,
|
|
|
|
|
beam_alpha=cfg.alpha,
|
|
|
|
|
beam_beta=cfg.beta,
|
|
|
|
|
beam_size=cfg.beam_size,
|
|
|
|
|
cutoff_prob=cfg.cutoff_prob,
|
|
|
|
|
cutoff_top_n=cfg.cutoff_top_n,
|
|
|
|
|
num_processes=cfg.num_proc_bsearch)
|
|
|
|
|
|
|
|
|
|
#self.autolog.times.stamp()
|
|
|
|
|
#self.autolog.times.stamp()
|
|
|
|
|
#self.autolog.times.end()
|
|
|
|
|
target_transcripts = self.ordid2token(texts, texts_len)
|
|
|
|
|
for utt, target, result in zip(utts, target_transcripts,
|
|
|
|
|
result_transcripts):
|
|
|
|
|
errors, len_ref = errors_func(target, result)
|
|
|
|
|
errors_sum += errors
|
|
|
|
|
len_refs += len_ref
|
|
|
|
|
num_ins += 1
|
|
|
|
|
if fout:
|
|
|
|
|
fout.write(utt + " " + result + "\n")
|
|
|
|
|
logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
|
|
|
|
|
(target, result))
|
|
|
|
|
logger.info("Current error rate [%s] = %f" %
|
|
|
|
|
(cfg.error_rate_type, error_rate_func(target, result)))
|
|
|
|
|
|
|
|
|
|
return dict(
|
|
|
|
|
errors_sum=errors_sum,
|
|
|
|
|
len_refs=len_refs,
|
|
|
|
|
num_ins=num_ins,
|
|
|
|
|
error_rate=errors_sum / len_refs,
|
|
|
|
|
error_rate_type=cfg.error_rate_type)
|
|
|
|
|
|
|
|
|
|
@mp_tools.rank_zero_only
|
|
|
|
|
@paddle.no_grad()
|
|
|
|
|
def test(self):
|
|
|
|
|
logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
|
|
|
|
|
#self.autolog = Autolog(
|
|
|
|
|
# batch_size=self.config.decoding.batch_size,
|
|
|
|
|
# model_name="deepspeech2",
|
|
|
|
|
# model_precision="fp32").getlog()
|
|
|
|
|
self.model.eval()
|
|
|
|
|
cfg = self.config
|
|
|
|
|
error_rate_type = None
|
|
|
|
|
errors_sum, len_refs, num_ins = 0.0, 0, 0
|
|
|
|
|
with open(self.args.result_file, 'w') as fout:
|
|
|
|
|
for i, batch in enumerate(self.test_loader):
|
|
|
|
|
utts, audio, audio_len, texts, texts_len = batch
|
|
|
|
|
metrics = self.compute_metrics(utts, audio, audio_len, texts,
|
|
|
|
|
texts_len, fout)
|
|
|
|
|
errors_sum += metrics['errors_sum']
|
|
|
|
|
len_refs += metrics['len_refs']
|
|
|
|
|
num_ins += metrics['num_ins']
|
|
|
|
|
error_rate_type = metrics['error_rate_type']
|
|
|
|
|
logger.info("Error rate [%s] (%d/?) = %f" %
|
|
|
|
|
(error_rate_type, num_ins, errors_sum / len_refs))
|
|
|
|
|
|
|
|
|
|
# logging
|
|
|
|
|
msg = "Test: "
|
|
|
|
|
msg += "epoch: {}, ".format(self.epoch)
|
|
|
|
|
msg += "step: {}, ".format(self.iteration)
|
|
|
|
|
msg += "Final error rate [%s] (%d/%d) = %f" % (
|
|
|
|
|
error_rate_type, num_ins, num_ins, errors_sum / len_refs)
|
|
|
|
|
logger.info(msg)
|
|
|
|
|
#self.autolog.report()
|
|
|
|
|
return output_probs_branch, output_lens_branch
|
|
|
|
|
|
|
|
|
|
def run_test(self):
|
|
|
|
|
try:
|
|
|
|
@ -676,19 +612,12 @@ class DeepSpeech2ExportTester(DeepSpeech2Trainer):
|
|
|
|
|
except KeyboardInterrupt:
|
|
|
|
|
exit(-1)
|
|
|
|
|
|
|
|
|
|
def run_export(self):
|
|
|
|
|
try:
|
|
|
|
|
self.export()
|
|
|
|
|
except KeyboardInterrupt:
|
|
|
|
|
exit(-1)
|
|
|
|
|
|
|
|
|
|
def setup(self):
|
|
|
|
|
"""Setup the experiment.
|
|
|
|
|
"""
|
|
|
|
|
paddle.set_device(self.args.device)
|
|
|
|
|
|
|
|
|
|
self.setup_output_dir()
|
|
|
|
|
#self.setup_checkpointer()
|
|
|
|
|
|
|
|
|
|
self.setup_dataloader()
|
|
|
|
|
self.setup_model()
|
|
|
|
@ -711,17 +640,11 @@ class DeepSpeech2ExportTester(DeepSpeech2Trainer):
|
|
|
|
|
|
|
|
|
|
def setup_model(self):
|
|
|
|
|
super().setup_model()
|
|
|
|
|
if self.args.model_type == 'online':
|
|
|
|
|
#inference_dir = "exp/deepspeech2_online/checkpoints/"
|
|
|
|
|
#inference_dir = "exp/deepspeech2_online_3rr_1fc_lr_decay0.91_lstm/checkpoints/"
|
|
|
|
|
#speedyspeech_config = inference.Config(
|
|
|
|
|
# str(Path(inference_dir) / "avg_1.jit.pdmodel"),
|
|
|
|
|
# str(Path(inference_dir) / "avg_1.jit.pdiparams"))
|
|
|
|
|
speedyspeech_config = inference.Config(
|
|
|
|
|
self.args.export_path + ".pdmodel",
|
|
|
|
|
self.args.export_path + ".pdiparams")
|
|
|
|
|
speedyspeech_config = inference.Config(
|
|
|
|
|
self.args.export_path + ".pdmodel",
|
|
|
|
|
self.args.export_path + ".pdiparams")
|
|
|
|
|
if (os.environ['CUDA_VISIBLE_DEVICES'].strip() != ''):
|
|
|
|
|
speedyspeech_config.enable_use_gpu(100, 0)
|
|
|
|
|
speedyspeech_config.enable_memory_optim()
|
|
|
|
|
speedyspeech_predictor = inference.create_predictor(
|
|
|
|
|
speedyspeech_config)
|
|
|
|
|
self.predictor = speedyspeech_predictor
|
|
|
|
|
speedyspeech_predictor = inference.create_predictor(speedyspeech_config)
|
|
|
|
|
self.predictor = speedyspeech_predictor
|
|
|
|
|