From b3947bf5fab9b2cb2420dd7658a73247832a2cda Mon Sep 17 00:00:00 2001
From: huangyuxin
Date: Wed, 25 Aug 2021 08:27:22 +0000
Subject: [PATCH] add static_forward_online and static_forward_offline

---
 deepspeech/exps/deepspeech2/model.py        | 233 +++++++-------
 deepspeech/models/ds2/deepspeech2.py        |   2 +-
 deepspeech/models/ds2_online/deepspeech2.py |  18 --
 3 files changed, 79 insertions(+), 174 deletions(-)

diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py
index e00439a02..f386336ad 100644
--- a/deepspeech/exps/deepspeech2/model.py
+++ b/deepspeech/exps/deepspeech2/model.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Contains DeepSpeech2 and DeepSpeech2Online model."""
+import os
 import time
 from collections import defaultdict
 from pathlib import Path
@@ -398,40 +399,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         self.output_dir = output_dir
 
 
-class DeepSpeech2ExportTester(DeepSpeech2Trainer):
-    @classmethod
-    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
-        # testing config
-        default = CfgNode(
-            dict(
-                alpha=2.5,  # Coef of LM for beam search.
-                beta=0.3,  # Coef of WC for beam search.
-                cutoff_prob=1.0,  # Cutoff probability for pruning.
-                cutoff_top_n=40,  # Cutoff number for pruning.
-                lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm',  # Filepath for language model.
-                decoding_method='ctc_beam_search',  # Decoding method. Options: ctc_beam_search, ctc_greedy
-                error_rate_type='wer',  # Error rate type for evaluation. Options `wer`, 'cer'
-                num_proc_bsearch=8,  # # of CPUs for beam search.
-                beam_size=500,  # Beam search width.
-                batch_size=128,  # decoding batch size
-            ))
-
-        if config is not None:
-            config.merge_from_other_cfg(default)
-        return default
-
+class DeepSpeech2ExportTester(DeepSpeech2Tester):
     def __init__(self, config, args):
         super().__init__(config, args)
 
-    def ordid2token(self, texts, texts_len):
-        """ ord() id to chr() chr """
-        trans = []
-        for text, n in zip(texts, texts_len):
-            n = n.numpy().item()
-            ids = text[:n]
-            trans.append(''.join([chr(i) for i in ids]))
-        return trans
-
     def compute_metrics(self,
                         utts,
                         audio,
@@ -447,9 +418,48 @@ class DeepSpeech2ExportTester(DeepSpeech2Trainer):
 
         vocab_list = self.test_loader.collate_fn.vocab_list
 
-        batch_size = self.config.decoding.batch_size
+        if self.args.model_type == "online":
+            output_probs_branch, output_lens_branch = self.static_forward_online(
+                audio, audio_len)
+        elif self.args.model_type == "offline":
+            output_probs_branch, output_lens_branch = self.static_forward_offline(
+                audio, audio_len)
+        else:
+            raise ValueError("wrong model_type: " + self.args.model_type)
+        self.predictor.clear_intermediate_tensor()
+        self.predictor.try_shrink_memory()
+        self.model.decoder.init_decode(cfg.alpha, cfg.beta, cfg.lang_model_path,
+                                       vocab_list, cfg.decoding_method)
+
+        result_transcripts = self.model.decoder.decode_probs(
+            output_probs_branch.numpy(), output_lens_branch, vocab_list,
+            cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
+            cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
+            cfg.num_proc_bsearch)
 
-        output_prob_list = []
+        target_transcripts = self.ordid2token(texts, texts_len)
+        for utt, target, result in zip(utts, target_transcripts,
+                                       result_transcripts):
+            errors, len_ref = errors_func(target, result)
+            errors_sum += errors
+            len_refs += len_ref
+            num_ins += 1
+            if fout:
+                fout.write(utt + " " + result + "\n")
+            logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
+                        (target, result))
+            logger.info("Current error rate [%s] = %f" %
+                        (cfg.error_rate_type, error_rate_func(target, result)))
+
+        return dict(
+            errors_sum=errors_sum,
+            len_refs=len_refs,
+            num_ins=num_ins,
+            error_rate=errors_sum / len_refs,
+            error_rate_type=cfg.error_rate_type)
+
+    def static_forward_online(self, audio, audio_len):
+        output_probs_list = []
         output_lens_list = []
         decoder_chunk_size = 8
         subsampling_rate = self.model.encoder.conv.subsampling_rate
@@ -459,15 +469,18 @@ class DeepSpeech2ExportTester(DeepSpeech2Trainer):
         ) * subsampling_rate + receptive_field_length
 
         x_batch = audio.numpy()
+        batch_size = x_batch.shape[0]
         x_len_batch = audio_len.numpy().astype(np.int64)
         max_len_batch = x_batch.shape[1]
         batch_padding_len = chunk_stride - (
             max_len_batch - chunk_size
         ) % chunk_stride  # The length of padding for the batch
-        x_list = np.split(x_batch, x_batch.shape[0], axis=0)
+        x_list = np.split(x_batch, batch_size, axis=0)
         x_len_list = np.split(x_len_batch, x_batch.shape[0], axis=0)
 
         for x, x_len in zip(x_list, x_len_list):
+            self.autolog.times.start()
+            self.autolog.times.stamp()
             assert (chunk_size <= x_len[0])
 
             eouts_chunk_list = []
@@ -536,38 +549,40 @@ class DeepSpeech2ExportTester(DeepSpeech2Trainer):
                 output_state_c_handle = self.predictor.get_output_handle(
                     output_names[3])
                 self.predictor.run()
-                output_chunk_prob = output_handle.copy_to_cpu()
+                output_chunk_probs = output_handle.copy_to_cpu()
                 output_chunk_lens = output_lens_handle.copy_to_cpu()
                 chunk_state_h_box = output_state_h_handle.copy_to_cpu()
                 chunk_state_c_box = output_state_c_handle.copy_to_cpu()
 
-                output_chunk_prob = paddle.to_tensor(output_chunk_prob)
+                output_chunk_probs = paddle.to_tensor(output_chunk_probs)
                 output_chunk_lens = paddle.to_tensor(output_chunk_lens)
-                probs_chunk_list.append(output_chunk_prob)
+                probs_chunk_list.append(output_chunk_probs)
                 probs_chunk_lens_list.append(output_chunk_lens)
-            output_prob = paddle.concat(probs_chunk_list, axis=1)
+            output_probs = paddle.concat(probs_chunk_list, axis=1)
             output_lens = paddle.add_n(probs_chunk_lens_list)
-            output_prob_padding_len = max_len_batch + batch_padding_len - output_prob.shape[
+            output_probs_padding_len = max_len_batch + batch_padding_len - output_probs.shape[
                 1]
-            output_prob_padding = paddle.zeros(
-                (1, output_prob_padding_len, output_prob.shape[2]),
+            output_probs_padding = paddle.zeros(
+                (1, output_probs_padding_len, output_probs.shape[2]),
                 dtype="float32")  # The prob padding for a piece of utterance
-            output_prob = paddle.concat(
-                [output_prob, output_prob_padding], axis=1)
-            output_prob_list.append(output_prob)
+            output_probs = paddle.concat(
+                [output_probs, output_probs_padding], axis=1)
+            output_probs_list.append(output_probs)
             output_lens_list.append(output_lens)
-        output_prob_branch = paddle.concat(output_prob_list, axis=0)
+            self.autolog.times.stamp()
+            self.autolog.times.stamp()
+            self.autolog.times.end()
+        output_probs_branch = paddle.concat(output_probs_list, axis=0)
         output_lens_branch = paddle.concat(output_lens_list, axis=0)
-        """
+        return output_probs_branch, output_lens_branch
+
+    def static_forward_offline(self, audio, audio_len):
         x = audio.numpy()
         x_len = audio_len.numpy().astype(np.int64)
 
         input_names = self.predictor.get_input_names()
         audio_handle = self.predictor.get_input_handle(input_names[0])
         audio_len_handle = self.predictor.get_input_handle(input_names[1])
-        h_box_handle = self.predictor.get_input_handle(input_names[2])
-        c_box_handle = self.predictor.get_input_handle(input_names[3])
-
         audio_handle.reshape(x.shape)
         audio_handle.copy_from_cpu(x)
@@ -575,100 +590,21 @@ class DeepSpeech2ExportTester(DeepSpeech2Trainer):
         audio_len_handle.reshape(x_len.shape)
         audio_len_handle.copy_from_cpu(x_len)
 
-        init_state_h_box = np.zeros((self.config.model.num_rnn_layers, audio.shape[0], self.config.model.rnn_layer_size), dtype=np.float32)
-        init_state_c_box = np.zeros((self.config.model.num_rnn_layers, audio.shape[0], self.config.model.rnn_layer_size), dtype=np.float32)
-        h_box_handle.reshape(init_state_h_box.shape)
-        h_box_handle.copy_from_cpu(init_state_h_box)
-
-        c_box_handle.reshape(init_state_c_box.shape)
-        c_box_handle.copy_from_cpu(init_state_c_box)
-
-        #self.autolog.times.start()
-        #self.autolog.times.stamp()
+        self.autolog.times.start()
+        self.autolog.times.stamp()
         self.predictor.run()
+        self.autolog.times.stamp()
+        self.autolog.times.stamp()
+        self.autolog.times.end()
 
         output_names = self.predictor.get_output_names()
         output_handle = self.predictor.get_output_handle(output_names[0])
         output_lens_handle = self.predictor.get_output_handle(output_names[1])
-        output_state_h_handle = self.predictor.get_output_handle(output_names[2])
-        output_state_c_handle = self.predictor.get_output_handle(output_names[3])
-        output_prob = output_handle.copy_to_cpu()
+        output_probs = output_handle.copy_to_cpu()
         output_lens = output_lens_handle.copy_to_cpu()
-        output_stata_h_box = output_state_h_handle.copy_to_cpu()
-        output_stata_c_box = output_state_c_handle.copy_to_cpu()
 
-        output_prob_branch = paddle.to_tensor(output_prob)
+        output_probs_branch = paddle.to_tensor(output_probs)
         output_lens_branch = paddle.to_tensor(output_lens)
-        """
-
-        result_transcripts = self.model.decode_by_probs(
-            output_prob_branch,
-            output_lens_branch,
-            vocab_list,
-            decoding_method=cfg.decoding_method,
-            lang_model_path=cfg.lang_model_path,
-            beam_alpha=cfg.alpha,
-            beam_beta=cfg.beta,
-            beam_size=cfg.beam_size,
-            cutoff_prob=cfg.cutoff_prob,
-            cutoff_top_n=cfg.cutoff_top_n,
-            num_processes=cfg.num_proc_bsearch)
-
-        #self.autolog.times.stamp()
-        #self.autolog.times.stamp()
-        #self.autolog.times.end()
-        target_transcripts = self.ordid2token(texts, texts_len)
-        for utt, target, result in zip(utts, target_transcripts,
-                                       result_transcripts):
-            errors, len_ref = errors_func(target, result)
-            errors_sum += errors
-            len_refs += len_ref
-            num_ins += 1
-            if fout:
-                fout.write(utt + " " + result + "\n")
-            logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
-                        (target, result))
-            logger.info("Current error rate [%s] = %f" %
-                        (cfg.error_rate_type, error_rate_func(target, result)))
-
-        return dict(
-            errors_sum=errors_sum,
-            len_refs=len_refs,
-            num_ins=num_ins,
-            error_rate=errors_sum / len_refs,
-            error_rate_type=cfg.error_rate_type)
-
-    @mp_tools.rank_zero_only
-    @paddle.no_grad()
-    def test(self):
-        logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
-        #self.autolog = Autolog(
-        #    batch_size=self.config.decoding.batch_size,
-        #    model_name="deepspeech2",
-        #    model_precision="fp32").getlog()
-        self.model.eval()
-        cfg = self.config
-        error_rate_type = None
-        errors_sum, len_refs, num_ins = 0.0, 0, 0
-        with open(self.args.result_file, 'w') as fout:
-            for i, batch in enumerate(self.test_loader):
-                utts, audio, audio_len, texts, texts_len = batch
-                metrics = self.compute_metrics(utts, audio, audio_len, texts,
-                                               texts_len, fout)
-                errors_sum += metrics['errors_sum']
-                len_refs += metrics['len_refs']
-                num_ins += metrics['num_ins']
-                error_rate_type = metrics['error_rate_type']
-                logger.info("Error rate [%s] (%d/?) = %f" %
= %f" % - (error_rate_type, num_ins, errors_sum / len_refs)) - - # logging - msg = "Test: " - msg += "epoch: {}, ".format(self.epoch) - msg += "step: {}, ".format(self.iteration) - msg += "Final error rate [%s] (%d/%d) = %f" % ( - error_rate_type, num_ins, num_ins, errors_sum / len_refs) - logger.info(msg) - #self.autolog.report() + return output_probs_branch, output_lens_branch def run_test(self): try: @@ -676,19 +612,12 @@ class DeepSpeech2ExportTester(DeepSpeech2Trainer): except KeyboardInterrupt: exit(-1) - def run_export(self): - try: - self.export() - except KeyboardInterrupt: - exit(-1) - def setup(self): """Setup the experiment. """ paddle.set_device(self.args.device) self.setup_output_dir() - #self.setup_checkpointer() self.setup_dataloader() self.setup_model() @@ -711,17 +640,11 @@ class DeepSpeech2ExportTester(DeepSpeech2Trainer): def setup_model(self): super().setup_model() - if self.args.model_type == 'online': - #inference_dir = "exp/deepspeech2_online/checkpoints/" - #inference_dir = "exp/deepspeech2_online_3rr_1fc_lr_decay0.91_lstm/checkpoints/" - #speedyspeech_config = inference.Config( - # str(Path(inference_dir) / "avg_1.jit.pdmodel"), - # str(Path(inference_dir) / "avg_1.jit.pdiparams")) - speedyspeech_config = inference.Config( - self.args.export_path + ".pdmodel", - self.args.export_path + ".pdiparams") + speedyspeech_config = inference.Config( + self.args.export_path + ".pdmodel", + self.args.export_path + ".pdiparams") + if (os.environ['CUDA_VISIBLE_DEVICES'].strip() != ''): speedyspeech_config.enable_use_gpu(100, 0) speedyspeech_config.enable_memory_optim() - speedyspeech_predictor = inference.create_predictor( - speedyspeech_config) - self.predictor = speedyspeech_predictor + speedyspeech_predictor = inference.create_predictor(speedyspeech_config) + self.predictor = speedyspeech_predictor diff --git a/deepspeech/models/ds2/deepspeech2.py b/deepspeech/models/ds2/deepspeech2.py index 1ffd797b4..5f8f32557 100644 --- a/deepspeech/models/ds2/deepspeech2.py +++ b/deepspeech/models/ds2/deepspeech2.py @@ -280,7 +280,7 @@ class DeepSpeech2InferModel(DeepSpeech2Model): """ eouts, eouts_len = self.encoder(audio, audio_len) probs = self.decoder.softmax(eouts) - return probs + return probs, eouts_len def export(self): static_model = paddle.jit.to_static( diff --git a/deepspeech/models/ds2_online/deepspeech2.py b/deepspeech/models/ds2_online/deepspeech2.py index 773119296..d092b154b 100644 --- a/deepspeech/models/ds2_online/deepspeech2.py +++ b/deepspeech/models/ds2_online/deepspeech2.py @@ -325,24 +325,6 @@ class DeepSpeech2ModelOnline(nn.Layer): lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, cutoff_top_n, num_processes) - @paddle.no_grad() - def decode_by_probs(self, probs, probs_len, vocab_list, decoding_method, - lang_model_path, beam_alpha, beam_beta, beam_size, - cutoff_prob, cutoff_top_n, num_processes): - # init once - # decoders only accept string encoded in utf-8 - self.decoder.init_decode( - beam_alpha=beam_alpha, - beam_beta=beam_beta, - lang_model_path=lang_model_path, - vocab_list=vocab_list, - decoding_method=decoding_method) - - return self.decoder.decode_probs( - probs.numpy(), probs_len, vocab_list, decoding_method, - lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, - cutoff_top_n, num_processes) - @classmethod def from_pretrained(cls, dataloader, config, checkpoint_path): """Build a DeepSpeech2Model model from a pretrained model.