From b585684bf47bfa4afd7eff5e345b62c9fde53295 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Tue, 24 Aug 2021 06:56:50 +0000 Subject: [PATCH 1/5] add function: test export --- .../exps/deepspeech2/bin/test_export.py | 52 +++ deepspeech/exps/deepspeech2/model.py | 332 +++++++++++++++++- deepspeech/models/ds2_online/conv.py | 2 + deepspeech/models/ds2_online/deepspeech2.py | 18 + examples/aishell/s0/local/test_export.sh | 39 ++ examples/aishell/s0/run.sh | 5 + 6 files changed, 447 insertions(+), 1 deletion(-) create mode 100644 deepspeech/exps/deepspeech2/bin/test_export.py create mode 100755 examples/aishell/s0/local/test_export.sh diff --git a/deepspeech/exps/deepspeech2/bin/test_export.py b/deepspeech/exps/deepspeech2/bin/test_export.py new file mode 100644 index 00000000..cfc45c4b --- /dev/null +++ b/deepspeech/exps/deepspeech2/bin/test_export.py @@ -0,0 +1,52 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Evaluation for DeepSpeech2 model.""" +from deepspeech.exps.deepspeech2.config import get_cfg_defaults +from deepspeech.exps.deepspeech2.model import DeepSpeech2ExportTester as ExportTester +from deepspeech.training.cli import default_argument_parser +from deepspeech.utils.utility import print_arguments + + +def main_sp(config, args): + exp = ExportTester(config, args) + exp.setup() + exp.run_test() + + +def main(config, args): + main_sp(config, args) + + +if __name__ == "__main__": + parser = default_argument_parser() + parser.add_argument("--model_type") + args = parser.parse_args() + print_arguments(args, globals()) + if args.model_type is None: + args.model_type = 'offline' + print("model_type:{}".format(args.model_type)) + + # https://yaml.org/type/float.html + config = get_cfg_defaults(args.model_type) + if args.config: + config.merge_from_file(args.config) + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + print(config) + if args.dump_config: + with open(args.dump_config, 'w') as f: + print(config, file=f) + + main(config, args) diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 65c905a1..bbce273f 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -20,6 +20,7 @@ from typing import Optional import numpy as np import paddle from paddle import distributed as dist +from paddle import inference from paddle.io import DataLoader from yacs.config import CfgNode @@ -145,7 +146,7 @@ class DeepSpeech2Trainer(Trainer): learning_rate=config.training.lr, gamma=config.training.lr_decay, verbose=True) - optimizer = paddle.optimizer.Adam( + optimizer = paddle.optimizer.SGD( #Adam learning_rate=lr_scheduler, parameters=model.parameters(), weight_decay=paddle.regularizer.L2Decay( @@ -395,3 +396,332 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): output_dir.mkdir(parents=True, exist_ok=True) self.output_dir = output_dir + + +class DeepSpeech2ExportTester(DeepSpeech2Trainer): + @classmethod + def params(cls, config: 
Optional[CfgNode]=None) -> CfgNode: + # testing config + default = CfgNode( + dict( + alpha=2.5, # Coef of LM for beam search. + beta=0.3, # Coef of WC for beam search. + cutoff_prob=1.0, # Cutoff probability for pruning. + cutoff_top_n=40, # Cutoff number for pruning. + lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm', # Filepath for language model. + decoding_method='ctc_beam_search', # Decoding method. Options: ctc_beam_search, ctc_greedy + error_rate_type='wer', # Error rate type for evaluation. Options `wer`, 'cer' + num_proc_bsearch=8, # # of CPUs for beam search. + beam_size=500, # Beam search width. + batch_size=128, # decoding batch size + )) + + if config is not None: + config.merge_from_other_cfg(default) + return default + + def __init__(self, config, args): + super().__init__(config, args) + + def ordid2token(self, texts, texts_len): + """ ord() id to chr() chr """ + trans = [] + for text, n in zip(texts, texts_len): + n = n.numpy().item() + ids = text[:n] + trans.append(''.join([chr(i) for i in ids])) + return trans + + def compute_metrics(self, + utts, + audio, + audio_len, + texts, + texts_len, + fout=None): + cfg = self.config.decoding + + errors_sum, len_refs, num_ins = 0.0, 0, 0 + errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors + error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer + + vocab_list = self.test_loader.collate_fn.vocab_list + + batch_size = self.config.decoding.batch_size + + output_prob_list = [] + output_lens_list = [] + decoder_chunk_size = 8 + subsampling_rate = self.model.encoder.conv.subsampling_rate + receptive_field_length = self.model.encoder.conv.receptive_field_length + chunk_stride = subsampling_rate * decoder_chunk_size + chunk_size = (decoder_chunk_size - 1 + ) * subsampling_rate + receptive_field_length + + x_batch = audio.numpy() + x_len_batch = audio_len.numpy().astype(np.int64) + max_len_batch = x_batch.shape[1] + batch_padding_len = chunk_stride - ( + max_len_batch - chunk_size + ) % chunk_stride # The length of padding for the batch + x_list = np.split(x_batch, x_batch.shape[0], axis=0) + x_len_list = np.split(x_len_batch, x_batch.shape[0], axis=0) + + for x, x_len in zip(x_list, x_len_list): + assert (chunk_size <= x_len[0]) + + eouts_chunk_list = [] + eouts_chunk_lens_list = [] + + padding_len_x = chunk_stride - (x_len[0] - chunk_size + ) % chunk_stride + padding = np.zeros( + (x.shape[0], padding_len_x, x.shape[2]), dtype=np.float32) + padded_x = np.concatenate([x, padding], axis=1) + + num_chunk = (x_len[0] + padding_len_x - chunk_size + ) / chunk_stride + 1 + num_chunk = int(num_chunk) + + chunk_state_h_box = np.zeros( + (self.config.model.num_rnn_layers, 1, + self.config.model.rnn_layer_size), + dtype=np.float32) + chunk_state_c_box = np.zeros( + (self.config.model.num_rnn_layers, 1, + self.config.model.rnn_layer_size), + dtype=np.float32) + + input_names = self.predictor.get_input_names() + audio_handle = self.predictor.get_input_handle(input_names[0]) + audio_len_handle = self.predictor.get_input_handle(input_names[1]) + h_box_handle = self.predictor.get_input_handle(input_names[2]) + c_box_handle = self.predictor.get_input_handle(input_names[3]) + + probs_chunk_list = [] + probs_chunk_lens_list = [] + for i in range(0, num_chunk): + start = i * chunk_stride + end = start + chunk_size + x_chunk = padded_x[:, start:end, :] + x_len_left = np.where(x_len - i * chunk_stride < 0, + np.zeros_like(x_len, dtype=np.int64), + x_len - i * 
chunk_stride) + x_chunk_len_tmp = np.ones_like( + x_len, dtype=np.int64) * chunk_size + x_chunk_lens = np.where(x_len_left < x_chunk_len_tmp, + x_len_left, x_chunk_len_tmp) + if (x_chunk_lens[0] < + receptive_field_length): #means the number of input frames in the chunk is not enough for predicting one prob + break + audio_handle.reshape(x_chunk.shape) + audio_handle.copy_from_cpu(x_chunk) + + audio_len_handle.reshape(x_chunk_lens.shape) + audio_len_handle.copy_from_cpu(x_chunk_lens) + + h_box_handle.reshape(chunk_state_h_box.shape) + h_box_handle.copy_from_cpu(chunk_state_h_box) + + c_box_handle.reshape(chunk_state_c_box.shape) + c_box_handle.copy_from_cpu(chunk_state_c_box) + + output_names = self.predictor.get_output_names() + output_handle = self.predictor.get_output_handle( + output_names[0]) + output_lens_handle = self.predictor.get_output_handle( + output_names[1]) + output_state_h_handle = self.predictor.get_output_handle( + output_names[2]) + output_state_c_handle = self.predictor.get_output_handle( + output_names[3]) + self.predictor.run() + output_chunk_prob = output_handle.copy_to_cpu() + output_chunk_lens = output_lens_handle.copy_to_cpu() + chunk_state_h_box = output_state_h_handle.copy_to_cpu() + chunk_state_c_box = output_state_c_handle.copy_to_cpu() + output_chunk_prob = paddle.to_tensor(output_chunk_prob) + output_chunk_lens = paddle.to_tensor(output_chunk_lens) + + probs_chunk_list.append(output_chunk_prob) + probs_chunk_lens_list.append(output_chunk_lens) + output_prob = paddle.concat(probs_chunk_list, axis=1) + output_lens = paddle.add_n(probs_chunk_lens_list) + output_prob_padding_len = max_len_batch + batch_padding_len - output_prob.shape[ + 1] + output_prob_padding = paddle.zeros( + (1, output_prob_padding_len, output_prob.shape[2]), + dtype="float32") # The prob padding for a piece of utterance + output_prob = paddle.concat( + [output_prob, output_prob_padding], axis=1) + output_prob_list.append(output_prob) + output_lens_list.append(output_lens) + output_prob_branch = paddle.concat(output_prob_list, axis=0) + output_lens_branch = paddle.concat(output_lens_list, axis=0) + """ + x = audio.numpy() + x_len = audio_len.numpy().astype(np.int64) + + input_names = self.predictor.get_input_names() + audio_handle = self.predictor.get_input_handle(input_names[0]) + audio_len_handle = self.predictor.get_input_handle(input_names[1]) + h_box_handle = self.predictor.get_input_handle(input_names[2]) + c_box_handle = self.predictor.get_input_handle(input_names[3]) + + + audio_handle.reshape(x.shape) + audio_handle.copy_from_cpu(x) + + audio_len_handle.reshape(x_len.shape) + audio_len_handle.copy_from_cpu(x_len) + + init_state_h_box = np.zeros((self.config.model.num_rnn_layers, audio.shape[0], self.config.model.rnn_layer_size), dtype=np.float32) + init_state_c_box = np.zeros((self.config.model.num_rnn_layers, audio.shape[0], self.config.model.rnn_layer_size), dtype=np.float32) + h_box_handle.reshape(init_state_h_box.shape) + h_box_handle.copy_from_cpu(init_state_h_box) + + c_box_handle.reshape(init_state_c_box.shape) + c_box_handle.copy_from_cpu(init_state_c_box) + + #self.autolog.times.start() + #self.autolog.times.stamp() + self.predictor.run() + + output_names = self.predictor.get_output_names() + output_handle = self.predictor.get_output_handle(output_names[0]) + output_lens_handle = self.predictor.get_output_handle(output_names[1]) + output_state_h_handle = self.predictor.get_output_handle(output_names[2]) + output_state_c_handle = 
self.predictor.get_output_handle(output_names[3]) + output_prob = output_handle.copy_to_cpu() + output_lens = output_lens_handle.copy_to_cpu() + output_stata_h_box = output_state_h_handle.copy_to_cpu() + output_stata_c_box = output_state_c_handle.copy_to_cpu() + output_prob_branch = paddle.to_tensor(output_prob) + output_lens_branch = paddle.to_tensor(output_lens) + """ + + result_transcripts = self.model.decode_by_probs( + output_prob_branch, + output_lens_branch, + vocab_list, + decoding_method=cfg.decoding_method, + lang_model_path=cfg.lang_model_path, + beam_alpha=cfg.alpha, + beam_beta=cfg.beta, + beam_size=cfg.beam_size, + cutoff_prob=cfg.cutoff_prob, + cutoff_top_n=cfg.cutoff_top_n, + num_processes=cfg.num_proc_bsearch) + + #self.autolog.times.stamp() + #self.autolog.times.stamp() + #self.autolog.times.end() + target_transcripts = self.ordid2token(texts, texts_len) + for utt, target, result in zip(utts, target_transcripts, + result_transcripts): + errors, len_ref = errors_func(target, result) + errors_sum += errors + len_refs += len_ref + num_ins += 1 + if fout: + fout.write(utt + " " + result + "\n") + logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" % + (target, result)) + logger.info("Current error rate [%s] = %f" % + (cfg.error_rate_type, error_rate_func(target, result))) + + return dict( + errors_sum=errors_sum, + len_refs=len_refs, + num_ins=num_ins, + error_rate=errors_sum / len_refs, + error_rate_type=cfg.error_rate_type) + + @mp_tools.rank_zero_only + @paddle.no_grad() + def test(self): + logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") + #self.autolog = Autolog( + # batch_size=self.config.decoding.batch_size, + # model_name="deepspeech2", + # model_precision="fp32").getlog() + self.model.eval() + cfg = self.config + error_rate_type = None + errors_sum, len_refs, num_ins = 0.0, 0, 0 + with open(self.args.result_file, 'w') as fout: + for i, batch in enumerate(self.test_loader): + utts, audio, audio_len, texts, texts_len = batch + metrics = self.compute_metrics(utts, audio, audio_len, texts, + texts_len, fout) + errors_sum += metrics['errors_sum'] + len_refs += metrics['len_refs'] + num_ins += metrics['num_ins'] + error_rate_type = metrics['error_rate_type'] + logger.info("Error rate [%s] (%d/?) = %f" % + (error_rate_type, num_ins, errors_sum / len_refs)) + + # logging + msg = "Test: " + msg += "epoch: {}, ".format(self.epoch) + msg += "step: {}, ".format(self.iteration) + msg += "Final error rate [%s] (%d/%d) = %f" % ( + error_rate_type, num_ins, num_ins, errors_sum / len_refs) + logger.info(msg) + #self.autolog.report() + + def run_test(self): + try: + self.test() + except KeyboardInterrupt: + exit(-1) + + def run_export(self): + try: + self.export() + except KeyboardInterrupt: + exit(-1) + + def setup(self): + """Setup the experiment. + """ + paddle.set_device(self.args.device) + + self.setup_output_dir() + #self.setup_checkpointer() + + self.setup_dataloader() + self.setup_model() + + self.iteration = 0 + self.epoch = 0 + + def setup_output_dir(self): + """Create a directory used for output. 
+ """ + # output dir + if self.args.output: + output_dir = Path(self.args.output).expanduser() + output_dir.mkdir(parents=True, exist_ok=True) + else: + output_dir = Path(self.args.export_path).expanduser().parent.parent + output_dir.mkdir(parents=True, exist_ok=True) + + self.output_dir = output_dir + + def setup_model(self): + super().setup_model() + if self.args.model_type == 'online': + #inference_dir = "exp/deepspeech2_online/checkpoints/" + #inference_dir = "exp/deepspeech2_online_3rr_1fc_lr_decay0.91_lstm/checkpoints/" + #speedyspeech_config = inference.Config( + # str(Path(inference_dir) / "avg_1.jit.pdmodel"), + # str(Path(inference_dir) / "avg_1.jit.pdiparams")) + speedyspeech_config = inference.Config( + self.args.export_path + ".pdmodel", + self.args.export_path + ".pdiparams") + speedyspeech_config.enable_use_gpu(100, 0) + speedyspeech_config.enable_memory_optim() + speedyspeech_predictor = inference.create_predictor( + speedyspeech_config) + self.predictor = speedyspeech_predictor diff --git a/deepspeech/models/ds2_online/conv.py b/deepspeech/models/ds2_online/conv.py index 4a6fd5ab..a98786e6 100644 --- a/deepspeech/models/ds2_online/conv.py +++ b/deepspeech/models/ds2_online/conv.py @@ -30,4 +30,6 @@ class Conv2dSubsampling4Online(Conv2dSubsampling4): #b, c, t, f = paddle.shape(x) #not work under jit x = x.transpose([0, 2, 1, 3]).reshape([0, 0, -1]) x_len = ((x_len - 1) // 2 - 1) // 2 + x_len = paddle.where(x_len >= 0, x_len, + paddle.zeros_like(x_len.shape, "int64")) return x, x_len diff --git a/deepspeech/models/ds2_online/deepspeech2.py b/deepspeech/models/ds2_online/deepspeech2.py index d092b154..77311929 100644 --- a/deepspeech/models/ds2_online/deepspeech2.py +++ b/deepspeech/models/ds2_online/deepspeech2.py @@ -325,6 +325,24 @@ class DeepSpeech2ModelOnline(nn.Layer): lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, cutoff_top_n, num_processes) + @paddle.no_grad() + def decode_by_probs(self, probs, probs_len, vocab_list, decoding_method, + lang_model_path, beam_alpha, beam_beta, beam_size, + cutoff_prob, cutoff_top_n, num_processes): + # init once + # decoders only accept string encoded in utf-8 + self.decoder.init_decode( + beam_alpha=beam_alpha, + beam_beta=beam_beta, + lang_model_path=lang_model_path, + vocab_list=vocab_list, + decoding_method=decoding_method) + + return self.decoder.decode_probs( + probs.numpy(), probs_len, vocab_list, decoding_method, + lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, + cutoff_top_n, num_processes) + @classmethod def from_pretrained(cls, dataloader, config, checkpoint_path): """Build a DeepSpeech2Model model from a pretrained model. diff --git a/examples/aishell/s0/local/test_export.sh b/examples/aishell/s0/local/test_export.sh new file mode 100755 index 00000000..9c863377 --- /dev/null +++ b/examples/aishell/s0/local/test_export.sh @@ -0,0 +1,39 @@ +#!/bin/bash + +if [ $# != 3 ];then + echo "usage: ${0} config_path ckpt_path_prefix model_type" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +device=gpu +if [ ${ngpu} == 0 ];then + device=cpu +fi +config_path=$1 +jit_model_export_path=$2 +model_type=$3 + +# download language model +bash local/download_lm_ch.sh +if [ $? -ne 0 ]; then + exit 1 +fi + +python3 -u ${BIN_DIR}/test_export.py \ +--device ${device} \ +--nproc 1 \ +--config ${config_path} \ +--result_file ${ckpt_prefix}.rsl \ +--export_path ${jit_model_export_path} \ +--model_type ${model_type} + +if [ $? 
-ne 0 ]; then + echo "Failed in evaluation!" + exit 1 +fi + + +exit 0 diff --git a/examples/aishell/s0/run.sh b/examples/aishell/s0/run.sh index 7cd63999..e5ab12a5 100755 --- a/examples/aishell/s0/run.sh +++ b/examples/aishell/s0/run.sh @@ -39,3 +39,8 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then # export ckpt avg_n CUDA_VISIBLE_DEVICES=0 ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type} fi + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + # test export ckpt avg_n + CUDA_VISIBLE_DEVICES=0 ./local/test_export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type}|| exit -1 +fi From 0d0b581181988da7ab64fbe8c50047e468e1bba8 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Wed, 25 Aug 2021 08:27:22 +0000 Subject: [PATCH 2/5] add static_forward_online and static_forward_offline --- deepspeech/exps/deepspeech2/model.py | 233 +++++++------------- deepspeech/models/ds2/deepspeech2.py | 2 +- deepspeech/models/ds2_online/deepspeech2.py | 18 -- 3 files changed, 79 insertions(+), 174 deletions(-) diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index e00439a0..f386336a 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Contains DeepSpeech2 and DeepSpeech2Online model.""" +import os import time from collections import defaultdict from pathlib import Path @@ -398,40 +399,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): self.output_dir = output_dir -class DeepSpeech2ExportTester(DeepSpeech2Trainer): - @classmethod - def params(cls, config: Optional[CfgNode]=None) -> CfgNode: - # testing config - default = CfgNode( - dict( - alpha=2.5, # Coef of LM for beam search. - beta=0.3, # Coef of WC for beam search. - cutoff_prob=1.0, # Cutoff probability for pruning. - cutoff_top_n=40, # Cutoff number for pruning. - lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm', # Filepath for language model. - decoding_method='ctc_beam_search', # Decoding method. Options: ctc_beam_search, ctc_greedy - error_rate_type='wer', # Error rate type for evaluation. Options `wer`, 'cer' - num_proc_bsearch=8, # # of CPUs for beam search. - beam_size=500, # Beam search width. 
- batch_size=128, # decoding batch size - )) - - if config is not None: - config.merge_from_other_cfg(default) - return default - +class DeepSpeech2ExportTester(DeepSpeech2Tester): def __init__(self, config, args): super().__init__(config, args) - def ordid2token(self, texts, texts_len): - """ ord() id to chr() chr """ - trans = [] - for text, n in zip(texts, texts_len): - n = n.numpy().item() - ids = text[:n] - trans.append(''.join([chr(i) for i in ids])) - return trans - def compute_metrics(self, utts, audio, @@ -447,9 +418,48 @@ class DeepSpeech2ExportTester(DeepSpeech2Trainer): vocab_list = self.test_loader.collate_fn.vocab_list - batch_size = self.config.decoding.batch_size + if self.args.model_type == "online": + output_probs_branch, output_lens_branch = self.static_forward_online( + audio, audio_len) + elif self.args.model_type == "offline": + output_probs_branch, output_lens_branch = self.static_forward_offline( + audio, audio_len) + else: + raise Exception("wrong model type") + self.predictor.clear_intermediate_tensor() + self.predictor.try_shrink_memory() + self.model.decoder.init_decode(cfg.alpha, cfg.beta, cfg.lang_model_path, + vocab_list, cfg.decoding_method) + + result_transcripts = self.model.decoder.decode_probs( + output_probs_branch.numpy(), output_lens_branch, vocab_list, + cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, + cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, + cfg.num_proc_bsearch) - output_prob_list = [] + target_transcripts = self.ordid2token(texts, texts_len) + for utt, target, result in zip(utts, target_transcripts, + result_transcripts): + errors, len_ref = errors_func(target, result) + errors_sum += errors + len_refs += len_ref + num_ins += 1 + if fout: + fout.write(utt + " " + result + "\n") + logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" % + (target, result)) + logger.info("Current error rate [%s] = %f" % + (cfg.error_rate_type, error_rate_func(target, result))) + + return dict( + errors_sum=errors_sum, + len_refs=len_refs, + num_ins=num_ins, + error_rate=errors_sum / len_refs, + error_rate_type=cfg.error_rate_type) + + def static_forward_online(self, audio, audio_len): + output_probs_list = [] output_lens_list = [] decoder_chunk_size = 8 subsampling_rate = self.model.encoder.conv.subsampling_rate @@ -459,15 +469,18 @@ class DeepSpeech2ExportTester(DeepSpeech2Trainer): ) * subsampling_rate + receptive_field_length x_batch = audio.numpy() + batch_size = x_batch.shape[0] x_len_batch = audio_len.numpy().astype(np.int64) max_len_batch = x_batch.shape[1] batch_padding_len = chunk_stride - ( max_len_batch - chunk_size ) % chunk_stride # The length of padding for the batch - x_list = np.split(x_batch, x_batch.shape[0], axis=0) + x_list = np.split(x_batch, batch_size, axis=0) x_len_list = np.split(x_len_batch, x_batch.shape[0], axis=0) for x, x_len in zip(x_list, x_len_list): + self.autolog.times.start() + self.autolog.times.stamp() assert (chunk_size <= x_len[0]) eouts_chunk_list = [] @@ -536,38 +549,40 @@ class DeepSpeech2ExportTester(DeepSpeech2Trainer): output_state_c_handle = self.predictor.get_output_handle( output_names[3]) self.predictor.run() - output_chunk_prob = output_handle.copy_to_cpu() + output_chunk_probs = output_handle.copy_to_cpu() output_chunk_lens = output_lens_handle.copy_to_cpu() chunk_state_h_box = output_state_h_handle.copy_to_cpu() chunk_state_c_box = output_state_c_handle.copy_to_cpu() - output_chunk_prob = paddle.to_tensor(output_chunk_prob) + output_chunk_probs = 
paddle.to_tensor(output_chunk_probs) output_chunk_lens = paddle.to_tensor(output_chunk_lens) - probs_chunk_list.append(output_chunk_prob) + probs_chunk_list.append(output_chunk_probs) probs_chunk_lens_list.append(output_chunk_lens) - output_prob = paddle.concat(probs_chunk_list, axis=1) + output_probs = paddle.concat(probs_chunk_list, axis=1) output_lens = paddle.add_n(probs_chunk_lens_list) - output_prob_padding_len = max_len_batch + batch_padding_len - output_prob.shape[ + output_probs_padding_len = max_len_batch + batch_padding_len - output_probs.shape[ 1] - output_prob_padding = paddle.zeros( - (1, output_prob_padding_len, output_prob.shape[2]), + output_probs_padding = paddle.zeros( + (1, output_probs_padding_len, output_probs.shape[2]), dtype="float32") # The prob padding for a piece of utterance - output_prob = paddle.concat( - [output_prob, output_prob_padding], axis=1) - output_prob_list.append(output_prob) + output_probs = paddle.concat( + [output_probs, output_probs_padding], axis=1) + output_probs_list.append(output_probs) output_lens_list.append(output_lens) - output_prob_branch = paddle.concat(output_prob_list, axis=0) + self.autolog.times.stamp() + self.autolog.times.stamp() + self.autolog.times.end() + output_probs_branch = paddle.concat(output_probs_list, axis=0) output_lens_branch = paddle.concat(output_lens_list, axis=0) - """ + return output_probs_branch, output_lens_branch + + def static_forward_offline(self, audio, audio_len): x = audio.numpy() x_len = audio_len.numpy().astype(np.int64) input_names = self.predictor.get_input_names() audio_handle = self.predictor.get_input_handle(input_names[0]) audio_len_handle = self.predictor.get_input_handle(input_names[1]) - h_box_handle = self.predictor.get_input_handle(input_names[2]) - c_box_handle = self.predictor.get_input_handle(input_names[3]) - audio_handle.reshape(x.shape) audio_handle.copy_from_cpu(x) @@ -575,100 +590,21 @@ class DeepSpeech2ExportTester(DeepSpeech2Trainer): audio_len_handle.reshape(x_len.shape) audio_len_handle.copy_from_cpu(x_len) - init_state_h_box = np.zeros((self.config.model.num_rnn_layers, audio.shape[0], self.config.model.rnn_layer_size), dtype=np.float32) - init_state_c_box = np.zeros((self.config.model.num_rnn_layers, audio.shape[0], self.config.model.rnn_layer_size), dtype=np.float32) - h_box_handle.reshape(init_state_h_box.shape) - h_box_handle.copy_from_cpu(init_state_h_box) - - c_box_handle.reshape(init_state_c_box.shape) - c_box_handle.copy_from_cpu(init_state_c_box) - - #self.autolog.times.start() - #self.autolog.times.stamp() + self.autolog.times.start() + self.autolog.times.stamp() self.predictor.run() + self.autolog.times.stamp() + self.autolog.times.stamp() + self.autolog.times.end() output_names = self.predictor.get_output_names() output_handle = self.predictor.get_output_handle(output_names[0]) output_lens_handle = self.predictor.get_output_handle(output_names[1]) - output_state_h_handle = self.predictor.get_output_handle(output_names[2]) - output_state_c_handle = self.predictor.get_output_handle(output_names[3]) - output_prob = output_handle.copy_to_cpu() + output_probs = output_handle.copy_to_cpu() output_lens = output_lens_handle.copy_to_cpu() - output_stata_h_box = output_state_h_handle.copy_to_cpu() - output_stata_c_box = output_state_c_handle.copy_to_cpu() - output_prob_branch = paddle.to_tensor(output_prob) + output_probs_branch = paddle.to_tensor(output_probs) output_lens_branch = paddle.to_tensor(output_lens) - """ - - result_transcripts = self.model.decode_by_probs( - 
output_prob_branch, - output_lens_branch, - vocab_list, - decoding_method=cfg.decoding_method, - lang_model_path=cfg.lang_model_path, - beam_alpha=cfg.alpha, - beam_beta=cfg.beta, - beam_size=cfg.beam_size, - cutoff_prob=cfg.cutoff_prob, - cutoff_top_n=cfg.cutoff_top_n, - num_processes=cfg.num_proc_bsearch) - - #self.autolog.times.stamp() - #self.autolog.times.stamp() - #self.autolog.times.end() - target_transcripts = self.ordid2token(texts, texts_len) - for utt, target, result in zip(utts, target_transcripts, - result_transcripts): - errors, len_ref = errors_func(target, result) - errors_sum += errors - len_refs += len_ref - num_ins += 1 - if fout: - fout.write(utt + " " + result + "\n") - logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" % - (target, result)) - logger.info("Current error rate [%s] = %f" % - (cfg.error_rate_type, error_rate_func(target, result))) - - return dict( - errors_sum=errors_sum, - len_refs=len_refs, - num_ins=num_ins, - error_rate=errors_sum / len_refs, - error_rate_type=cfg.error_rate_type) - - @mp_tools.rank_zero_only - @paddle.no_grad() - def test(self): - logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") - #self.autolog = Autolog( - # batch_size=self.config.decoding.batch_size, - # model_name="deepspeech2", - # model_precision="fp32").getlog() - self.model.eval() - cfg = self.config - error_rate_type = None - errors_sum, len_refs, num_ins = 0.0, 0, 0 - with open(self.args.result_file, 'w') as fout: - for i, batch in enumerate(self.test_loader): - utts, audio, audio_len, texts, texts_len = batch - metrics = self.compute_metrics(utts, audio, audio_len, texts, - texts_len, fout) - errors_sum += metrics['errors_sum'] - len_refs += metrics['len_refs'] - num_ins += metrics['num_ins'] - error_rate_type = metrics['error_rate_type'] - logger.info("Error rate [%s] (%d/?) = %f" % - (error_rate_type, num_ins, errors_sum / len_refs)) - - # logging - msg = "Test: " - msg += "epoch: {}, ".format(self.epoch) - msg += "step: {}, ".format(self.iteration) - msg += "Final error rate [%s] (%d/%d) = %f" % ( - error_rate_type, num_ins, num_ins, errors_sum / len_refs) - logger.info(msg) - #self.autolog.report() + return output_probs_branch, output_lens_branch def run_test(self): try: @@ -676,19 +612,12 @@ class DeepSpeech2ExportTester(DeepSpeech2Trainer): except KeyboardInterrupt: exit(-1) - def run_export(self): - try: - self.export() - except KeyboardInterrupt: - exit(-1) - def setup(self): """Setup the experiment. 
""" paddle.set_device(self.args.device) self.setup_output_dir() - #self.setup_checkpointer() self.setup_dataloader() self.setup_model() @@ -711,17 +640,11 @@ class DeepSpeech2ExportTester(DeepSpeech2Trainer): def setup_model(self): super().setup_model() - if self.args.model_type == 'online': - #inference_dir = "exp/deepspeech2_online/checkpoints/" - #inference_dir = "exp/deepspeech2_online_3rr_1fc_lr_decay0.91_lstm/checkpoints/" - #speedyspeech_config = inference.Config( - # str(Path(inference_dir) / "avg_1.jit.pdmodel"), - # str(Path(inference_dir) / "avg_1.jit.pdiparams")) - speedyspeech_config = inference.Config( - self.args.export_path + ".pdmodel", - self.args.export_path + ".pdiparams") + speedyspeech_config = inference.Config( + self.args.export_path + ".pdmodel", + self.args.export_path + ".pdiparams") + if (os.environ['CUDA_VISIBLE_DEVICES'].strip() != ''): speedyspeech_config.enable_use_gpu(100, 0) speedyspeech_config.enable_memory_optim() - speedyspeech_predictor = inference.create_predictor( - speedyspeech_config) - self.predictor = speedyspeech_predictor + speedyspeech_predictor = inference.create_predictor(speedyspeech_config) + self.predictor = speedyspeech_predictor diff --git a/deepspeech/models/ds2/deepspeech2.py b/deepspeech/models/ds2/deepspeech2.py index 1ffd797b..5f8f3255 100644 --- a/deepspeech/models/ds2/deepspeech2.py +++ b/deepspeech/models/ds2/deepspeech2.py @@ -280,7 +280,7 @@ class DeepSpeech2InferModel(DeepSpeech2Model): """ eouts, eouts_len = self.encoder(audio, audio_len) probs = self.decoder.softmax(eouts) - return probs + return probs, eouts_len def export(self): static_model = paddle.jit.to_static( diff --git a/deepspeech/models/ds2_online/deepspeech2.py b/deepspeech/models/ds2_online/deepspeech2.py index 77311929..d092b154 100644 --- a/deepspeech/models/ds2_online/deepspeech2.py +++ b/deepspeech/models/ds2_online/deepspeech2.py @@ -325,24 +325,6 @@ class DeepSpeech2ModelOnline(nn.Layer): lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, cutoff_top_n, num_processes) - @paddle.no_grad() - def decode_by_probs(self, probs, probs_len, vocab_list, decoding_method, - lang_model_path, beam_alpha, beam_beta, beam_size, - cutoff_prob, cutoff_top_n, num_processes): - # init once - # decoders only accept string encoded in utf-8 - self.decoder.init_decode( - beam_alpha=beam_alpha, - beam_beta=beam_beta, - lang_model_path=lang_model_path, - vocab_list=vocab_list, - decoding_method=decoding_method) - - return self.decoder.decode_probs( - probs.numpy(), probs_len, vocab_list, decoding_method, - lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, - cutoff_top_n, num_processes) - @classmethod def from_pretrained(cls, dataloader, config, checkpoint_path): """Build a DeepSpeech2Model model from a pretrained model. 
From 1f050a4d018d771be26d742828c291dd340a9d4d Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Wed, 25 Aug 2021 11:54:50 +0000 Subject: [PATCH 3/5] make the code simple --- deepspeech/exps/deepspeech2/model.py | 96 ++++++++++------------------ 1 file changed, 34 insertions(+), 62 deletions(-) diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index f386336a..74d9b205 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -270,24 +270,9 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): vocab_list = self.test_loader.collate_fn.vocab_list target_transcripts = self.ordid2token(texts, texts_len) - self.autolog.times.start() - self.autolog.times.stamp() - result_transcripts = self.model.decode( - audio, - audio_len, - vocab_list, - decoding_method=cfg.decoding_method, - lang_model_path=cfg.lang_model_path, - beam_alpha=cfg.alpha, - beam_beta=cfg.beta, - beam_size=cfg.beam_size, - cutoff_prob=cfg.cutoff_prob, - cutoff_top_n=cfg.cutoff_top_n, - num_processes=cfg.num_proc_bsearch) - self.autolog.times.stamp() - self.autolog.times.stamp() - self.autolog.times.end() + result_transcripts = self.compute_result_transcripts(audio, audio_len, + vocab_list, cfg) for utt, target, result in zip(utts, target_transcripts, result_transcripts): errors, len_ref = errors_func(target, result) @@ -308,6 +293,26 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): error_rate=errors_sum / len_refs, error_rate_type=cfg.error_rate_type) + def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg): + self.autolog.times.start() + self.autolog.times.stamp() + result_transcripts = self.model.decode( + audio, + audio_len, + vocab_list, + decoding_method=cfg.decoding_method, + lang_model_path=cfg.lang_model_path, + beam_alpha=cfg.alpha, + beam_beta=cfg.beta, + beam_size=cfg.beam_size, + cutoff_prob=cfg.cutoff_prob, + cutoff_top_n=cfg.cutoff_top_n, + num_processes=cfg.num_proc_bsearch) + self.autolog.times.stamp() + self.autolog.times.stamp() + self.autolog.times.end() + return result_transcripts + @mp_tools.rank_zero_only @paddle.no_grad() def test(self): @@ -403,21 +408,7 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester): def __init__(self, config, args): super().__init__(config, args) - def compute_metrics(self, - utts, - audio, - audio_len, - texts, - texts_len, - fout=None): - cfg = self.config.decoding - - errors_sum, len_refs, num_ins = 0.0, 0, 0 - errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors - error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer - - vocab_list = self.test_loader.collate_fn.vocab_list - + def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg): if self.args.model_type == "online": output_probs_branch, output_lens_branch = self.static_forward_online( audio, audio_len) @@ -437,31 +428,12 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester): cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, cfg.num_proc_bsearch) - target_transcripts = self.ordid2token(texts, texts_len) - for utt, target, result in zip(utts, target_transcripts, - result_transcripts): - errors, len_ref = errors_func(target, result) - errors_sum += errors - len_refs += len_ref - num_ins += 1 - if fout: - fout.write(utt + " " + result + "\n") - logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" % - (target, result)) - logger.info("Current error rate [%s] = %f" % - (cfg.error_rate_type, error_rate_func(target, result))) - - return dict( - 
errors_sum=errors_sum, - len_refs=len_refs, - num_ins=num_ins, - error_rate=errors_sum / len_refs, - error_rate_type=cfg.error_rate_type) + return result_transcripts def static_forward_online(self, audio, audio_len): output_probs_list = [] output_lens_list = [] - decoder_chunk_size = 8 + decoder_chunk_size = 1 subsampling_rate = self.model.encoder.conv.subsampling_rate receptive_field_length = self.model.encoder.conv.receptive_field_length chunk_stride = subsampling_rate * decoder_chunk_size @@ -553,27 +525,27 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester): output_chunk_lens = output_lens_handle.copy_to_cpu() chunk_state_h_box = output_state_h_handle.copy_to_cpu() chunk_state_c_box = output_state_c_handle.copy_to_cpu() - output_chunk_probs = paddle.to_tensor(output_chunk_probs) - output_chunk_lens = paddle.to_tensor(output_chunk_lens) probs_chunk_list.append(output_chunk_probs) probs_chunk_lens_list.append(output_chunk_lens) - output_probs = paddle.concat(probs_chunk_list, axis=1) - output_lens = paddle.add_n(probs_chunk_lens_list) + output_probs = np.concatenate(probs_chunk_list, axis=1) + output_lens = np.sum(probs_chunk_lens_list, axis=0) output_probs_padding_len = max_len_batch + batch_padding_len - output_probs.shape[ 1] - output_probs_padding = paddle.zeros( + output_probs_padding = np.zeros( (1, output_probs_padding_len, output_probs.shape[2]), - dtype="float32") # The prob padding for a piece of utterance - output_probs = paddle.concat( + dtype=np.float32) # The prob padding for a piece of utterance + output_probs = np.concatenate( [output_probs, output_probs_padding], axis=1) output_probs_list.append(output_probs) output_lens_list.append(output_lens) self.autolog.times.stamp() self.autolog.times.stamp() self.autolog.times.end() - output_probs_branch = paddle.concat(output_probs_list, axis=0) - output_lens_branch = paddle.concat(output_lens_list, axis=0) + output_probs_branch = np.concatenate(output_probs_list, axis=0) + output_lens_branch = np.concatenate(output_lens_list, axis=0) + output_probs_branch = paddle.to_tensor(output_probs_branch) + output_lens_branch = paddle.to_tensor(output_lens_branch) return output_probs_branch, output_lens_branch def static_forward_offline(self, audio, audio_len): From 317ffea5e5b39d917c051543dfdf074242510e6d Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Thu, 26 Aug 2021 12:46:51 +0000 Subject: [PATCH 4/5] simplify the code --- deepspeech/exps/deepspeech2/model.py | 109 ++++++++++++++++----------- 1 file changed, 65 insertions(+), 44 deletions(-) diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 74d9b205..0e0e83c0 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -410,30 +410,42 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester): def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg): if self.args.model_type == "online": - output_probs_branch, output_lens_branch = self.static_forward_online( - audio, audio_len) + output_probs, output_lens = self.static_forward_online(audio, + audio_len) elif self.args.model_type == "offline": - output_probs_branch, output_lens_branch = self.static_forward_offline( - audio, audio_len) + output_probs, output_lens = self.static_forward_offline(audio, + audio_len) else: raise Exception("wrong model type") + self.predictor.clear_intermediate_tensor() self.predictor.try_shrink_memory() + self.model.decoder.init_decode(cfg.alpha, cfg.beta, cfg.lang_model_path, vocab_list, cfg.decoding_method) 
result_transcripts = self.model.decoder.decode_probs( - output_probs_branch.numpy(), output_lens_branch, vocab_list, - cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, - cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, - cfg.num_proc_bsearch) + output_probs, output_lens, vocab_list, cfg.decoding_method, + cfg.lang_model_path, cfg.alpha, cfg.beta, cfg.beam_size, + cfg.cutoff_prob, cfg.cutoff_top_n, cfg.num_proc_bsearch) return result_transcripts - def static_forward_online(self, audio, audio_len): + def static_forward_online(self, audio, audio_len, + decoder_chunk_size: int=1): + """ + Parameters + ---------- + audio (Tensor): shape[B, T, D] + audio_len (Tensor): shape[B] + decoder_chunk_size(int) + Returns + ------- + output_probs(numpy.array): shape[B, T, vocab_size] + output_lens(numpy.array): shape[B] + """ output_probs_list = [] output_lens_list = [] - decoder_chunk_size = 1 subsampling_rate = self.model.encoder.conv.subsampling_rate receptive_field_length = self.model.encoder.conv.receptive_field_length chunk_stride = subsampling_rate * decoder_chunk_size @@ -441,41 +453,42 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester): ) * subsampling_rate + receptive_field_length x_batch = audio.numpy() - batch_size = x_batch.shape[0] + batch_size, Tmax, x_dim = x_batch.shape x_len_batch = audio_len.numpy().astype(np.int64) - max_len_batch = x_batch.shape[1] - batch_padding_len = chunk_stride - ( - max_len_batch - chunk_size + + padding_len_batch = chunk_stride - ( + Tmax - chunk_size ) % chunk_stride # The length of padding for the batch x_list = np.split(x_batch, batch_size, axis=0) - x_len_list = np.split(x_len_batch, x_batch.shape[0], axis=0) + x_len_list = np.split(x_len_batch, batch_size, axis=0) for x, x_len in zip(x_list, x_len_list): self.autolog.times.start() self.autolog.times.stamp() - assert (chunk_size <= x_len[0]) + x_len = x_len[0] + assert (chunk_size <= x_len) - eouts_chunk_list = [] - eouts_chunk_lens_list = [] + if (x_len - chunk_size) % chunk_stride != 0: + padding_len_x = chunk_stride - (x_len - chunk_size + ) % chunk_stride + else: + padding_len_x = 0 - padding_len_x = chunk_stride - (x_len[0] - chunk_size - ) % chunk_stride padding = np.zeros( - (x.shape[0], padding_len_x, x.shape[2]), dtype=np.float32) + (x.shape[0], padding_len_x, x.shape[2]), dtype=x.dtype) padded_x = np.concatenate([x, padding], axis=1) - num_chunk = (x_len[0] + padding_len_x - chunk_size - ) / chunk_stride + 1 + num_chunk = (x_len + padding_len_x - chunk_size) / chunk_stride + 1 num_chunk = int(num_chunk) chunk_state_h_box = np.zeros( (self.config.model.num_rnn_layers, 1, self.config.model.rnn_layer_size), - dtype=np.float32) + dtype=x.dtype) chunk_state_c_box = np.zeros( (self.config.model.num_rnn_layers, 1, self.config.model.rnn_layer_size), - dtype=np.float32) + dtype=x.dtype) input_names = self.predictor.get_input_names() audio_handle = self.predictor.get_input_handle(input_names[0]) @@ -489,16 +502,15 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester): start = i * chunk_stride end = start + chunk_size x_chunk = padded_x[:, start:end, :] - x_len_left = np.where(x_len - i * chunk_stride < 0, - np.zeros_like(x_len, dtype=np.int64), - x_len - i * chunk_stride) - x_chunk_len_tmp = np.ones_like( - x_len, dtype=np.int64) * chunk_size - x_chunk_lens = np.where(x_len_left < x_chunk_len_tmp, - x_len_left, x_chunk_len_tmp) - if (x_chunk_lens[0] < + if x_len < i * chunk_stride: + x_chunk_lens = 0 + else: + x_chunk_lens = min(x_len - i * chunk_stride, chunk_size) + + if (x_chunk_lens < 
receptive_field_length): #means the number of input frames in the chunk is not enough for predicting one prob break + x_chunk_lens = np.array([x_chunk_lens]) audio_handle.reshape(x_chunk.shape) audio_handle.copy_from_cpu(x_chunk) @@ -530,11 +542,13 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester): probs_chunk_lens_list.append(output_chunk_lens) output_probs = np.concatenate(probs_chunk_list, axis=1) output_lens = np.sum(probs_chunk_lens_list, axis=0) - output_probs_padding_len = max_len_batch + batch_padding_len - output_probs.shape[ + vocab_size = output_probs.shape[2] + output_probs_padding_len = Tmax + padding_len_batch - output_probs.shape[ 1] output_probs_padding = np.zeros( - (1, output_probs_padding_len, output_probs.shape[2]), - dtype=np.float32) # The prob padding for a piece of utterance + (1, output_probs_padding_len, vocab_size), + dtype=output_probs. + dtype) # The prob padding for a piece of utterance output_probs = np.concatenate( [output_probs, output_probs_padding], axis=1) output_probs_list.append(output_probs) @@ -542,13 +556,22 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester): self.autolog.times.stamp() self.autolog.times.stamp() self.autolog.times.end() - output_probs_branch = np.concatenate(output_probs_list, axis=0) - output_lens_branch = np.concatenate(output_lens_list, axis=0) - output_probs_branch = paddle.to_tensor(output_probs_branch) - output_lens_branch = paddle.to_tensor(output_lens_branch) - return output_probs_branch, output_lens_branch + output_probs = np.concatenate(output_probs_list, axis=0) + output_lens = np.concatenate(output_lens_list, axis=0) + return output_probs, output_lens def static_forward_offline(self, audio, audio_len): + """ + Parameters + ---------- + audio (Tensor): shape[B, T, D] + audio_len (Tensor): shape[B] + + Returns + ------- + output_probs(numpy.array): shape[B, T, vocab_size] + output_lens(numpy.array): shape[B] + """ x = audio.numpy() x_len = audio_len.numpy().astype(np.int64) @@ -574,9 +597,7 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester): output_lens_handle = self.predictor.get_output_handle(output_names[1]) output_probs = output_handle.copy_to_cpu() output_lens = output_lens_handle.copy_to_cpu() - output_probs_branch = paddle.to_tensor(output_probs) - output_lens_branch = paddle.to_tensor(output_lens) - return output_probs_branch, output_lens_branch + return output_probs, output_lens def run_test(self): try: From 2451a177b0875f992119b4dc2377b422914d5fc9 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Fri, 27 Aug 2021 08:59:39 +0000 Subject: [PATCH 5/5] fix paddling len bug --- deepspeech/exps/deepspeech2/model.py | 10 ++++++---- deepspeech/models/ds2_online/deepspeech2.py | 4 ++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 0e0e83c0..f3e3fcad 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -455,10 +455,12 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester): x_batch = audio.numpy() batch_size, Tmax, x_dim = x_batch.shape x_len_batch = audio_len.numpy().astype(np.int64) - - padding_len_batch = chunk_stride - ( - Tmax - chunk_size - ) % chunk_stride # The length of padding for the batch + if (Tmax - chunk_size) % chunk_stride != 0: + padding_len_batch = chunk_stride - ( + Tmax - chunk_size + ) % chunk_stride # The length of padding for the batch + else: + padding_len_batch = 0 x_list = np.split(x_batch, batch_size, axis=0) x_len_list = np.split(x_len_batch, batch_size, 
axis=0) diff --git a/deepspeech/models/ds2_online/deepspeech2.py b/deepspeech/models/ds2_online/deepspeech2.py index d092b154..d0fbdcf6 100644 --- a/deepspeech/models/ds2_online/deepspeech2.py +++ b/deepspeech/models/ds2_online/deepspeech2.py @@ -100,12 +100,12 @@ class CRNNEncoder(nn.Layer): """Compute Encoder outputs Args: - x (Tensor): [B, feature_size, D] + x (Tensor): [B, T, D] x_lens (Tensor): [B] init_state_h_box(Tensor): init_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] init_state_c_box(Tensor): init_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] Return: - x (Tensor): encoder outputs, [B, size, D] + x (Tensor): encoder outputs, [B, T, D] x_lens (Tensor): encoder length, [B] final_state_h_box(Tensor): final_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] final_state_c_box(Tensor): final_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
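
The chunk-wise forward pass in static_forward_online rests on the chunk and padding arithmetic below. This is a standalone sketch with assumed values: the real subsampling rate and receptive field are read from self.model.encoder.conv, and decoder_chunk_size is 8 in the first patch and later becomes a parameter defaulting to 1. It also reproduces the padding guard that the last patch adds for the case where (T - chunk_size) is already a multiple of chunk_stride.

    import numpy as np

    # Assumed values; in the patches they come from the conv front end.
    subsampling_rate = 4
    receptive_field_length = 7
    decoder_chunk_size = 8

    chunk_stride = subsampling_rate * decoder_chunk_size
    chunk_size = (decoder_chunk_size - 1) * subsampling_rate + receptive_field_length

    x = np.zeros((1, 100, 161), dtype=np.float32)  # one utterance, [1, T, D]
    x_len = x.shape[1]
    assert chunk_size <= x_len

    # Pad so the last chunk is complete; no padding when the frames already
    # line up (the case the final patch guards against).
    if (x_len - chunk_size) % chunk_stride != 0:
        padding_len = chunk_stride - (x_len - chunk_size) % chunk_stride
    else:
        padding_len = 0
    padded_x = np.concatenate(
        [x, np.zeros((1, padding_len, x.shape[2]), dtype=x.dtype)], axis=1)

    num_chunk = int((x_len + padding_len - chunk_size) / chunk_stride + 1)
    for i in range(num_chunk):
        start = i * chunk_stride
        x_chunk = padded_x[:, start:start + chunk_size, :]
        # each chunk is fed to the predictor together with the carried
        # RNN h/c states; the returned states seed the next chunk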