add static_forward_online and static_forward_offline

pull/786/head
huangyuxin 4 years ago
parent 92617f0802
commit 0d0b581181

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Contains DeepSpeech2 and DeepSpeech2Online model."""
+import os
 import time
 from collections import defaultdict
 from pathlib import Path
@@ -398,40 +399,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         self.output_dir = output_dir


-class DeepSpeech2ExportTester(DeepSpeech2Trainer):
+class DeepSpeech2ExportTester(DeepSpeech2Tester):
-    @classmethod
-    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
-        # testing config
-        default = CfgNode(
-            dict(
-                alpha=2.5,  # Coef of LM for beam search.
-                beta=0.3,  # Coef of WC for beam search.
-                cutoff_prob=1.0,  # Cutoff probability for pruning.
-                cutoff_top_n=40,  # Cutoff number for pruning.
-                lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm',  # Filepath for language model.
-                decoding_method='ctc_beam_search',  # Decoding method. Options: ctc_beam_search, ctc_greedy
-                error_rate_type='wer',  # Error rate type for evaluation. Options `wer`, 'cer'
-                num_proc_bsearch=8,  # # of CPUs for beam search.
-                beam_size=500,  # Beam search width.
-                batch_size=128,  # decoding batch size
-            ))
-        if config is not None:
-            config.merge_from_other_cfg(default)
-        return default
-
     def __init__(self, config, args):
         super().__init__(config, args)

-    def ordid2token(self, texts, texts_len):
-        """ ord() id to chr() chr """
-        trans = []
-        for text, n in zip(texts, texts_len):
-            n = n.numpy().item()
-            ids = text[:n]
-            trans.append(''.join([chr(i) for i in ids]))
-        return trans
-
     def compute_metrics(self,
                         utts,
                         audio,
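Note: the params() and ordid2token() helpers removed here are presumably inherited from the new base class DeepSpeech2Tester rather than dropped. For reference, a minimal standalone sketch of what ordid2token does: target token ids are stored as Unicode code points, so chr() maps them back into characters.

    import numpy as np

    def ordid2token(texts, texts_len):
        """Turn padded arrays of Unicode code points back into strings."""
        trans = []
        for text, n in zip(texts, texts_len):
            n = int(n)  # valid (unpadded) length of this utterance
            trans.append(''.join(chr(int(i)) for i in text[:n]))
        return trans

    print(ordid2token([np.array([104, 105, 0])], [2]))  # ['hi']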
@@ -447,9 +418,48 @@ class DeepSpeech2ExportTester(DeepSpeech2Trainer):
         vocab_list = self.test_loader.collate_fn.vocab_list

-        batch_size = self.config.decoding.batch_size
+        if self.args.model_type == "online":
+            output_probs_branch, output_lens_branch = self.static_forward_online(
+                audio, audio_len)
+        elif self.args.model_type == "offline":
+            output_probs_branch, output_lens_branch = self.static_forward_offline(
+                audio, audio_len)
+        else:
+            raise Exception("wrong model type")
+        self.predictor.clear_intermediate_tensor()
+        self.predictor.try_shrink_memory()
+
+        self.model.decoder.init_decode(cfg.alpha, cfg.beta, cfg.lang_model_path,
+                                       vocab_list, cfg.decoding_method)
+
+        result_transcripts = self.model.decoder.decode_probs(
+            output_probs_branch.numpy(), output_lens_branch, vocab_list,
+            cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
+            cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
+            cfg.num_proc_bsearch)
-        output_prob_list = []
+
+        target_transcripts = self.ordid2token(texts, texts_len)
+        for utt, target, result in zip(utts, target_transcripts,
+                                       result_transcripts):
+            errors, len_ref = errors_func(target, result)
+            errors_sum += errors
+            len_refs += len_ref
+            num_ins += 1
+            if fout:
+                fout.write(utt + " " + result + "\n")
+            logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
+                        (target, result))
+            logger.info("Current error rate [%s] = %f" %
+                        (cfg.error_rate_type, error_rate_func(target, result)))
+
+        return dict(
+            errors_sum=errors_sum,
+            len_refs=len_refs,
+            num_ins=num_ins,
+            error_rate=errors_sum / len_refs,
+            error_rate_type=cfg.error_rate_type)
+
+    def static_forward_online(self, audio, audio_len):
+        output_probs_list = []
         output_lens_list = []
         decoder_chunk_size = 8
         subsampling_rate = self.model.encoder.conv.subsampling_rate
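compute_metrics returns raw errors_sum and len_refs alongside error_rate, so the caller can micro-average over the whole test set instead of averaging per-utterance rates. A minimal sketch of that accumulation, with a hypothetical char_errors helper standing in for errors_func:

    def char_errors(target, result):
        """Hypothetical stand-in for errors_func: (edit distance, len(target))."""
        m, n = len(target), len(result)
        dp = list(range(n + 1))  # one-row Levenshtein DP
        for i in range(1, m + 1):
            prev, dp[0] = dp[0], i
            for j in range(1, n + 1):
                cur = dp[j]
                dp[j] = min(dp[j] + 1,      # deletion
                            dp[j - 1] + 1,  # insertion
                            prev + (target[i - 1] != result[j - 1]))  # substitution
                prev = cur
        return dp[n], m

    errors_sum, len_refs = 0.0, 0
    for target, result in [("hello world", "hello word"),
                           ("good morning", "good morning")]:
        errors, len_ref = char_errors(target, result)
        errors_sum += errors
        len_refs += len_ref
    print(errors_sum / len_refs)  # 1/23: corpus-level rate, not a mean of per-utterance rates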
@@ -459,15 +469,18 @@ class DeepSpeech2ExportTester(DeepSpeech2Trainer):
         ) * subsampling_rate + receptive_field_length

         x_batch = audio.numpy()
+        batch_size = x_batch.shape[0]
         x_len_batch = audio_len.numpy().astype(np.int64)
         max_len_batch = x_batch.shape[1]
         batch_padding_len = chunk_stride - (
             max_len_batch - chunk_size
         ) % chunk_stride  # The length of padding for the batch
-        x_list = np.split(x_batch, x_batch.shape[0], axis=0)
+        x_list = np.split(x_batch, batch_size, axis=0)
         x_len_list = np.split(x_len_batch, x_batch.shape[0], axis=0)

         for x, x_len in zip(x_list, x_len_list):
+            self.autolog.times.start()
+            self.autolog.times.stamp()
             assert (chunk_size <= x_len[0])

             eouts_chunk_list = []
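Worked numbers for the padding arithmetic above. The concrete values (subsampling_rate=4, receptive_field_length=15, max_len_batch=180) are illustrative, and the chunk_stride formula is assumed from the surrounding code rather than shown in this hunk:

    decoder_chunk_size = 8
    subsampling_rate = 4           # illustrative
    receptive_field_length = 15    # illustrative
    chunk_size = (decoder_chunk_size - 1) * subsampling_rate + receptive_field_length  # 43
    chunk_stride = subsampling_rate * decoder_chunk_size  # 32, assumed formula
    max_len_batch = 180            # longest utterance in the batch, in frames
    batch_padding_len = chunk_stride - (max_len_batch - chunk_size) % chunk_stride  # 23
    # Padding makes the sliding chunks tile the utterance exactly:
    assert (max_len_batch + batch_padding_len - chunk_size) % chunk_stride == 0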
@@ -536,38 +549,40 @@ class DeepSpeech2ExportTester(DeepSpeech2Trainer):
                 output_state_h_handle = self.predictor.get_output_handle(
                     output_names[2])
                 output_state_c_handle = self.predictor.get_output_handle(
                     output_names[3])
                 self.predictor.run()
-                output_chunk_prob = output_handle.copy_to_cpu()
+                output_chunk_probs = output_handle.copy_to_cpu()
                 output_chunk_lens = output_lens_handle.copy_to_cpu()
                 chunk_state_h_box = output_state_h_handle.copy_to_cpu()
                 chunk_state_c_box = output_state_c_handle.copy_to_cpu()
-                output_chunk_prob = paddle.to_tensor(output_chunk_prob)
+                output_chunk_probs = paddle.to_tensor(output_chunk_probs)
                 output_chunk_lens = paddle.to_tensor(output_chunk_lens)
-                probs_chunk_list.append(output_chunk_prob)
+                probs_chunk_list.append(output_chunk_probs)
                 probs_chunk_lens_list.append(output_chunk_lens)
-            output_prob = paddle.concat(probs_chunk_list, axis=1)
+            output_probs = paddle.concat(probs_chunk_list, axis=1)
             output_lens = paddle.add_n(probs_chunk_lens_list)
-            output_prob_padding_len = max_len_batch + batch_padding_len - output_prob.shape[
+            output_probs_padding_len = max_len_batch + batch_padding_len - output_probs.shape[
                 1]
-            output_prob_padding = paddle.zeros(
-                (1, output_prob_padding_len, output_prob.shape[2]),
+            output_probs_padding = paddle.zeros(
+                (1, output_probs_padding_len, output_probs.shape[2]),
                 dtype="float32")  # The prob padding for a piece of utterance
-            output_prob = paddle.concat(
-                [output_prob, output_prob_padding], axis=1)
-            output_prob_list.append(output_prob)
+            output_probs = paddle.concat(
+                [output_probs, output_probs_padding], axis=1)
+            output_probs_list.append(output_probs)
             output_lens_list.append(output_lens)
-        output_prob_branch = paddle.concat(output_prob_list, axis=0)
+            self.autolog.times.stamp()
+            self.autolog.times.stamp()
+            self.autolog.times.end()
+        output_probs_branch = paddle.concat(output_probs_list, axis=0)
         output_lens_branch = paddle.concat(output_lens_list, axis=0)
-        """
+        return output_probs_branch, output_lens_branch
+
+    def static_forward_offline(self, audio, audio_len):
         x = audio.numpy()
         x_len = audio_len.numpy().astype(np.int64)

         input_names = self.predictor.get_input_names()
         audio_handle = self.predictor.get_input_handle(input_names[0])
         audio_len_handle = self.predictor.get_input_handle(input_names[1])
-        h_box_handle = self.predictor.get_input_handle(input_names[2])
-        c_box_handle = self.predictor.get_input_handle(input_names[3])

         audio_handle.reshape(x.shape)
         audio_handle.copy_from_cpu(x)
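A hedged sketch of the chunk loop this hunk renames variables in: each predictor.run() consumes the h/c RNN states produced by the previous chunk, which is what makes the exported online graph usable for streaming. The handle calls follow the Paddle Inference API already used in the patch; predictor, chunks, and the state sizes are assumptions for illustration.

    import numpy as np

    num_rnn_layers, rnn_layer_size = 3, 1024  # illustrative sizes
    chunk_state_h_box = np.zeros((num_rnn_layers, 1, rnn_layer_size), dtype=np.float32)
    chunk_state_c_box = np.zeros_like(chunk_state_h_box)
    probs_chunk_list, probs_chunk_lens_list = [], []

    input_names = predictor.get_input_names()    # predictor assumed, see setup_model
    output_names = predictor.get_output_names()
    for x_chunk, x_chunk_len in chunks:          # assumed iterable of (data, lens)
        for name, arr in zip(input_names, [x_chunk, x_chunk_len,
                                           chunk_state_h_box, chunk_state_c_box]):
            handle = predictor.get_input_handle(name)
            handle.reshape(arr.shape)
            handle.copy_from_cpu(arr)
        predictor.run()
        probs_chunk_list.append(
            predictor.get_output_handle(output_names[0]).copy_to_cpu())
        probs_chunk_lens_list.append(
            predictor.get_output_handle(output_names[1]).copy_to_cpu())
        # States emitted for this chunk seed the next run; this loop-carried
        # state is what the h/c input and output handles exist for.
        chunk_state_h_box = predictor.get_output_handle(output_names[2]).copy_to_cpu()
        chunk_state_c_box = predictor.get_output_handle(output_names[3]).copy_to_cpu()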
@@ -575,100 +590,21 @@ class DeepSpeech2ExportTester(DeepSpeech2Trainer):
         audio_len_handle.reshape(x_len.shape)
         audio_len_handle.copy_from_cpu(x_len)

-        init_state_h_box = np.zeros((self.config.model.num_rnn_layers, audio.shape[0], self.config.model.rnn_layer_size), dtype=np.float32)
-        init_state_c_box = np.zeros((self.config.model.num_rnn_layers, audio.shape[0], self.config.model.rnn_layer_size), dtype=np.float32)
-        h_box_handle.reshape(init_state_h_box.shape)
-        h_box_handle.copy_from_cpu(init_state_h_box)
-        c_box_handle.reshape(init_state_c_box.shape)
-        c_box_handle.copy_from_cpu(init_state_c_box)
-
-        #self.autolog.times.start()
-        #self.autolog.times.stamp()
+        self.autolog.times.start()
+        self.autolog.times.stamp()
         self.predictor.run()
+        self.autolog.times.stamp()
+        self.autolog.times.stamp()
+        self.autolog.times.end()

         output_names = self.predictor.get_output_names()
         output_handle = self.predictor.get_output_handle(output_names[0])
         output_lens_handle = self.predictor.get_output_handle(output_names[1])
-        output_state_h_handle = self.predictor.get_output_handle(output_names[2])
-        output_state_c_handle = self.predictor.get_output_handle(output_names[3])

-        output_prob = output_handle.copy_to_cpu()
+        output_probs = output_handle.copy_to_cpu()
         output_lens = output_lens_handle.copy_to_cpu()
-        output_stata_h_box = output_state_h_handle.copy_to_cpu()
-        output_stata_c_box = output_state_c_handle.copy_to_cpu()

-        output_prob_branch = paddle.to_tensor(output_prob)
+        output_probs_branch = paddle.to_tensor(output_probs)
         output_lens_branch = paddle.to_tensor(output_lens)
-        """
+        return output_probs_branch, output_lens_branch

-        result_transcripts = self.model.decode_by_probs(
-            output_prob_branch,
-            output_lens_branch,
-            vocab_list,
-            decoding_method=cfg.decoding_method,
-            lang_model_path=cfg.lang_model_path,
-            beam_alpha=cfg.alpha,
-            beam_beta=cfg.beta,
-            beam_size=cfg.beam_size,
-            cutoff_prob=cfg.cutoff_prob,
-            cutoff_top_n=cfg.cutoff_top_n,
-            num_processes=cfg.num_proc_bsearch)
-        #self.autolog.times.stamp()
-        #self.autolog.times.stamp()
-        #self.autolog.times.end()
-
-        target_transcripts = self.ordid2token(texts, texts_len)
-        for utt, target, result in zip(utts, target_transcripts,
-                                       result_transcripts):
-            errors, len_ref = errors_func(target, result)
-            errors_sum += errors
-            len_refs += len_ref
-            num_ins += 1
-            if fout:
-                fout.write(utt + " " + result + "\n")
-            logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
-                        (target, result))
-            logger.info("Current error rate [%s] = %f" %
-                        (cfg.error_rate_type, error_rate_func(target, result)))
-
-        return dict(
-            errors_sum=errors_sum,
-            len_refs=len_refs,
-            num_ins=num_ins,
-            error_rate=errors_sum / len_refs,
-            error_rate_type=cfg.error_rate_type)
-
-    @mp_tools.rank_zero_only
-    @paddle.no_grad()
-    def test(self):
-        logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
-        #self.autolog = Autolog(
-        #    batch_size=self.config.decoding.batch_size,
-        #    model_name="deepspeech2",
-        #    model_precision="fp32").getlog()
-        self.model.eval()
-        cfg = self.config
-        error_rate_type = None
-        errors_sum, len_refs, num_ins = 0.0, 0, 0
-        with open(self.args.result_file, 'w') as fout:
-            for i, batch in enumerate(self.test_loader):
-                utts, audio, audio_len, texts, texts_len = batch
-                metrics = self.compute_metrics(utts, audio, audio_len, texts,
-                                               texts_len, fout)
-                errors_sum += metrics['errors_sum']
-                len_refs += metrics['len_refs']
-                num_ins += metrics['num_ins']
-                error_rate_type = metrics['error_rate_type']
-                logger.info("Error rate [%s] (%d/?) = %f" %
-                            (error_rate_type, num_ins, errors_sum / len_refs))
-
-        # logging
-        msg = "Test: "
-        msg += "epoch: {}, ".format(self.epoch)
-        msg += "step: {}, ".format(self.iteration)
-        msg += "Final error rate [%s] (%d/%d) = %f" % (
-            error_rate_type, num_ins, num_ins, errors_sum / len_refs)
-        logger.info(msg)
-        #self.autolog.report()
-
     def run_test(self):
         try:
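The start/stamp/stamp/stamp/end calls added around predictor.run() follow PaddlePaddle's AutoLog timing convention: the stamps split one measured run into preprocess, inference, and postprocess spans. A stand-in timer showing the same bracketing (Autolog itself is assumed, not reimplemented here):

    import time

    class Times:
        """Stand-in with the same start/stamp/end bracketing as autolog.times."""
        def __init__(self):
            self.marks = []
        def start(self):
            self.marks = [time.time()]      # begin one measured run
        def stamp(self):
            self.marks.append(time.time())  # span boundary
        def end(self):
            self.marks.append(time.time())  # close the run
            return [b - a for a, b in zip(self.marks, self.marks[1:])]

    times = Times()
    times.start()
    times.stamp()        # preprocessing done
    time.sleep(0.01)     # stands in for predictor.run()
    times.stamp()        # inference done
    times.stamp()        # postprocessing done
    print(times.end())   # per-span durations, ready to aggregate into a report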
@@ -676,19 +612,12 @@ class DeepSpeech2ExportTester(DeepSpeech2Trainer):
         except KeyboardInterrupt:
             exit(-1)

-    def run_export(self):
-        try:
-            self.export()
-        except KeyboardInterrupt:
-            exit(-1)
-
     def setup(self):
         """Setup the experiment.
         """
         paddle.set_device(self.args.device)

         self.setup_output_dir()
-        #self.setup_checkpointer()
         self.setup_dataloader()
         self.setup_model()
@@ -711,17 +640,11 @@ class DeepSpeech2ExportTester(DeepSpeech2Trainer):
     def setup_model(self):
         super().setup_model()

-        if self.args.model_type == 'online':
-            #inference_dir = "exp/deepspeech2_online/checkpoints/"
-            #inference_dir = "exp/deepspeech2_online_3rr_1fc_lr_decay0.91_lstm/checkpoints/"
-            #speedyspeech_config = inference.Config(
-            #    str(Path(inference_dir) / "avg_1.jit.pdmodel"),
-            #    str(Path(inference_dir) / "avg_1.jit.pdiparams"))
         speedyspeech_config = inference.Config(
             self.args.export_path + ".pdmodel",
             self.args.export_path + ".pdiparams")
+        if (os.environ['CUDA_VISIBLE_DEVICES'].strip() != ''):
             speedyspeech_config.enable_use_gpu(100, 0)
             speedyspeech_config.enable_memory_optim()
-        speedyspeech_predictor = inference.create_predictor(
-            speedyspeech_config)
+        speedyspeech_predictor = inference.create_predictor(speedyspeech_config)
         self.predictor = speedyspeech_predictor
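Sketch of the device-selection logic added above: the predictor runs on GPU only when CUDA_VISIBLE_DEVICES names at least one device. The model paths are placeholders; enable_use_gpu(100, 0) requests a 100 MB initial memory pool on GPU 0. This variant uses os.environ.get to avoid a KeyError when the variable is unset (the patch indexes os.environ directly):

    import os
    from paddle import inference

    config = inference.Config("model.pdmodel", "model.pdiparams")  # placeholder paths
    if os.environ.get('CUDA_VISIBLE_DEVICES', '').strip() != '':
        config.enable_use_gpu(100, 0)  # 100 MB initial pool on GPU 0
        config.enable_memory_optim()
    predictor = inference.create_predictor(config)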

@@ -280,7 +280,7 @@ class DeepSpeech2InferModel(DeepSpeech2Model):
         """
         eouts, eouts_len = self.encoder(audio, audio_len)
         probs = self.decoder.softmax(eouts)
-        return probs
+        return probs, eouts_len

     def export(self):
         static_model = paddle.jit.to_static(
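forward() now also returns eouts_len, so the exported static graph exposes output lengths for decode_probs to trim padded frames. A hedged sketch of the export step that consumes this forward; the InputSpec shapes, the 161-dim feature size, and the output path are illustrative, not taken from the patch:

    import paddle
    from paddle.static import InputSpec

    # model is assumed to be a DeepSpeech2InferModel whose forward
    # returns (probs, eouts_len).
    static_model = paddle.jit.to_static(
        model,
        input_spec=[
            InputSpec(shape=[None, None, 161], dtype='float32'),  # audio features
            InputSpec(shape=[None], dtype='int64'),               # feature lengths
        ])
    paddle.jit.save(static_model, "exp/deepspeech2")  # writes .pdmodel / .pdiparams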

@@ -325,24 +325,6 @@ class DeepSpeech2ModelOnline(nn.Layer):
             lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
             cutoff_top_n, num_processes)

-    @paddle.no_grad()
-    def decode_by_probs(self, probs, probs_len, vocab_list, decoding_method,
-                        lang_model_path, beam_alpha, beam_beta, beam_size,
-                        cutoff_prob, cutoff_top_n, num_processes):
-        # init once
-        # decoders only accept string encoded in utf-8
-        self.decoder.init_decode(
-            beam_alpha=beam_alpha,
-            beam_beta=beam_beta,
-            lang_model_path=lang_model_path,
-            vocab_list=vocab_list,
-            decoding_method=decoding_method)
-
-        return self.decoder.decode_probs(
-            probs.numpy(), probs_len, vocab_list, decoding_method,
-            lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
-            cutoff_top_n, num_processes)
-
     @classmethod
     def from_pretrained(cls, dataloader, config, checkpoint_path):
         """Build a DeepSpeech2Model model from a pretrained model.
