Merge pull request #786 from Jackwaterveg/ds2_online

[Static model test] Add the test process for export model
pull/794/head
Hui Zhang 3 years ago committed by GitHub
commit 7181e427af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,58 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation for DeepSpeech2 model."""
from deepspeech.exps.deepspeech2.config import get_cfg_defaults
from deepspeech.exps.deepspeech2.model import DeepSpeech2ExportTester as ExportTester
from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.utility import print_arguments
def main_sp(config, args):
exp = ExportTester(config, args)
exp.setup()
exp.run_test()
def main(config, args):
main_sp(config, args)
if __name__ == "__main__":
parser = default_argument_parser()
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
#load jit model from
parser.add_argument(
"--export_path", type=str, help="path of the jit model to save")
parser.add_argument("--model_type")
args = parser.parse_args()
print_arguments(args, globals())
if args.model_type is None:
args.model_type = 'offline'
print("model_type:{}".format(args.model_type))
# https://yaml.org/type/float.html
config = get_cfg_defaults(args.model_type)
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
main(config, args)

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains DeepSpeech2 and DeepSpeech2Online model."""
import os
import time
from collections import defaultdict
from pathlib import Path
@ -20,6 +21,7 @@ from typing import Optional
import numpy as np
import paddle
from paddle import distributed as dist
from paddle import inference
from paddle.io import DataLoader
from yacs.config import CfgNode
@ -268,24 +270,9 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
vocab_list = self.test_loader.collate_fn.vocab_list
target_transcripts = self.ordid2token(texts, texts_len)
self.autolog.times.start()
self.autolog.times.stamp()
result_transcripts = self.model.decode(
audio,
audio_len,
vocab_list,
decoding_method=cfg.decoding_method,
lang_model_path=cfg.lang_model_path,
beam_alpha=cfg.alpha,
beam_beta=cfg.beta,
beam_size=cfg.beam_size,
cutoff_prob=cfg.cutoff_prob,
cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch)
self.autolog.times.stamp()
self.autolog.times.stamp()
self.autolog.times.end()
result_transcripts = self.compute_result_transcripts(audio, audio_len,
vocab_list, cfg)
for utt, target, result in zip(utts, target_transcripts,
result_transcripts):
errors, len_ref = errors_func(target, result)
@ -306,6 +293,26 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
error_rate=errors_sum / len_refs,
error_rate_type=cfg.error_rate_type)
def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg):
self.autolog.times.start()
self.autolog.times.stamp()
result_transcripts = self.model.decode(
audio,
audio_len,
vocab_list,
decoding_method=cfg.decoding_method,
lang_model_path=cfg.lang_model_path,
beam_alpha=cfg.alpha,
beam_beta=cfg.beta,
beam_size=cfg.beam_size,
cutoff_prob=cfg.cutoff_prob,
cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch)
self.autolog.times.stamp()
self.autolog.times.stamp()
self.autolog.times.end()
return result_transcripts
@mp_tools.rank_zero_only
@paddle.no_grad()
def test(self):
@ -395,3 +402,244 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
output_dir.mkdir(parents=True, exist_ok=True)
self.output_dir = output_dir
class DeepSpeech2ExportTester(DeepSpeech2Tester):
def __init__(self, config, args):
super().__init__(config, args)
def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg):
if self.args.model_type == "online":
output_probs, output_lens = self.static_forward_online(audio,
audio_len)
elif self.args.model_type == "offline":
output_probs, output_lens = self.static_forward_offline(audio,
audio_len)
else:
raise Exception("wrong model type")
self.predictor.clear_intermediate_tensor()
self.predictor.try_shrink_memory()
self.model.decoder.init_decode(cfg.alpha, cfg.beta, cfg.lang_model_path,
vocab_list, cfg.decoding_method)
result_transcripts = self.model.decoder.decode_probs(
output_probs, output_lens, vocab_list, cfg.decoding_method,
cfg.lang_model_path, cfg.alpha, cfg.beta, cfg.beam_size,
cfg.cutoff_prob, cfg.cutoff_top_n, cfg.num_proc_bsearch)
return result_transcripts
def static_forward_online(self, audio, audio_len,
decoder_chunk_size: int=1):
"""
Parameters
----------
audio (Tensor): shape[B, T, D]
audio_len (Tensor): shape[B]
decoder_chunk_size(int)
Returns
-------
output_probs(numpy.array): shape[B, T, vocab_size]
output_lens(numpy.array): shape[B]
"""
output_probs_list = []
output_lens_list = []
subsampling_rate = self.model.encoder.conv.subsampling_rate
receptive_field_length = self.model.encoder.conv.receptive_field_length
chunk_stride = subsampling_rate * decoder_chunk_size
chunk_size = (decoder_chunk_size - 1
) * subsampling_rate + receptive_field_length
x_batch = audio.numpy()
batch_size, Tmax, x_dim = x_batch.shape
x_len_batch = audio_len.numpy().astype(np.int64)
if (Tmax - chunk_size) % chunk_stride != 0:
padding_len_batch = chunk_stride - (
Tmax - chunk_size
) % chunk_stride # The length of padding for the batch
else:
padding_len_batch = 0
x_list = np.split(x_batch, batch_size, axis=0)
x_len_list = np.split(x_len_batch, batch_size, axis=0)
for x, x_len in zip(x_list, x_len_list):
self.autolog.times.start()
self.autolog.times.stamp()
x_len = x_len[0]
assert (chunk_size <= x_len)
if (x_len - chunk_size) % chunk_stride != 0:
padding_len_x = chunk_stride - (x_len - chunk_size
) % chunk_stride
else:
padding_len_x = 0
padding = np.zeros(
(x.shape[0], padding_len_x, x.shape[2]), dtype=x.dtype)
padded_x = np.concatenate([x, padding], axis=1)
num_chunk = (x_len + padding_len_x - chunk_size) / chunk_stride + 1
num_chunk = int(num_chunk)
chunk_state_h_box = np.zeros(
(self.config.model.num_rnn_layers, 1,
self.config.model.rnn_layer_size),
dtype=x.dtype)
chunk_state_c_box = np.zeros(
(self.config.model.num_rnn_layers, 1,
self.config.model.rnn_layer_size),
dtype=x.dtype)
input_names = self.predictor.get_input_names()
audio_handle = self.predictor.get_input_handle(input_names[0])
audio_len_handle = self.predictor.get_input_handle(input_names[1])
h_box_handle = self.predictor.get_input_handle(input_names[2])
c_box_handle = self.predictor.get_input_handle(input_names[3])
probs_chunk_list = []
probs_chunk_lens_list = []
for i in range(0, num_chunk):
start = i * chunk_stride
end = start + chunk_size
x_chunk = padded_x[:, start:end, :]
if x_len < i * chunk_stride:
x_chunk_lens = 0
else:
x_chunk_lens = min(x_len - i * chunk_stride, chunk_size)
if (x_chunk_lens <
receptive_field_length): #means the number of input frames in the chunk is not enough for predicting one prob
break
x_chunk_lens = np.array([x_chunk_lens])
audio_handle.reshape(x_chunk.shape)
audio_handle.copy_from_cpu(x_chunk)
audio_len_handle.reshape(x_chunk_lens.shape)
audio_len_handle.copy_from_cpu(x_chunk_lens)
h_box_handle.reshape(chunk_state_h_box.shape)
h_box_handle.copy_from_cpu(chunk_state_h_box)
c_box_handle.reshape(chunk_state_c_box.shape)
c_box_handle.copy_from_cpu(chunk_state_c_box)
output_names = self.predictor.get_output_names()
output_handle = self.predictor.get_output_handle(
output_names[0])
output_lens_handle = self.predictor.get_output_handle(
output_names[1])
output_state_h_handle = self.predictor.get_output_handle(
output_names[2])
output_state_c_handle = self.predictor.get_output_handle(
output_names[3])
self.predictor.run()
output_chunk_probs = output_handle.copy_to_cpu()
output_chunk_lens = output_lens_handle.copy_to_cpu()
chunk_state_h_box = output_state_h_handle.copy_to_cpu()
chunk_state_c_box = output_state_c_handle.copy_to_cpu()
probs_chunk_list.append(output_chunk_probs)
probs_chunk_lens_list.append(output_chunk_lens)
output_probs = np.concatenate(probs_chunk_list, axis=1)
output_lens = np.sum(probs_chunk_lens_list, axis=0)
vocab_size = output_probs.shape[2]
output_probs_padding_len = Tmax + padding_len_batch - output_probs.shape[
1]
output_probs_padding = np.zeros(
(1, output_probs_padding_len, vocab_size),
dtype=output_probs.
dtype) # The prob padding for a piece of utterance
output_probs = np.concatenate(
[output_probs, output_probs_padding], axis=1)
output_probs_list.append(output_probs)
output_lens_list.append(output_lens)
self.autolog.times.stamp()
self.autolog.times.stamp()
self.autolog.times.end()
output_probs = np.concatenate(output_probs_list, axis=0)
output_lens = np.concatenate(output_lens_list, axis=0)
return output_probs, output_lens
def static_forward_offline(self, audio, audio_len):
"""
Parameters
----------
audio (Tensor): shape[B, T, D]
audio_len (Tensor): shape[B]
Returns
-------
output_probs(numpy.array): shape[B, T, vocab_size]
output_lens(numpy.array): shape[B]
"""
x = audio.numpy()
x_len = audio_len.numpy().astype(np.int64)
input_names = self.predictor.get_input_names()
audio_handle = self.predictor.get_input_handle(input_names[0])
audio_len_handle = self.predictor.get_input_handle(input_names[1])
audio_handle.reshape(x.shape)
audio_handle.copy_from_cpu(x)
audio_len_handle.reshape(x_len.shape)
audio_len_handle.copy_from_cpu(x_len)
self.autolog.times.start()
self.autolog.times.stamp()
self.predictor.run()
self.autolog.times.stamp()
self.autolog.times.stamp()
self.autolog.times.end()
output_names = self.predictor.get_output_names()
output_handle = self.predictor.get_output_handle(output_names[0])
output_lens_handle = self.predictor.get_output_handle(output_names[1])
output_probs = output_handle.copy_to_cpu()
output_lens = output_lens_handle.copy_to_cpu()
return output_probs, output_lens
def run_test(self):
try:
self.test()
except KeyboardInterrupt:
exit(-1)
def setup(self):
"""Setup the experiment.
"""
paddle.set_device(self.args.device)
self.setup_output_dir()
self.setup_dataloader()
self.setup_model()
self.iteration = 0
self.epoch = 0
def setup_output_dir(self):
"""Create a directory used for output.
"""
# output dir
if self.args.output:
output_dir = Path(self.args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
else:
output_dir = Path(self.args.export_path).expanduser().parent.parent
output_dir.mkdir(parents=True, exist_ok=True)
self.output_dir = output_dir
def setup_model(self):
super().setup_model()
speedyspeech_config = inference.Config(
self.args.export_path + ".pdmodel",
self.args.export_path + ".pdiparams")
if (os.environ['CUDA_VISIBLE_DEVICES'].strip() != ''):
speedyspeech_config.enable_use_gpu(100, 0)
speedyspeech_config.enable_memory_optim()
speedyspeech_predictor = inference.create_predictor(speedyspeech_config)
self.predictor = speedyspeech_predictor

@ -280,7 +280,7 @@ class DeepSpeech2InferModel(DeepSpeech2Model):
"""
eouts, eouts_len = self.encoder(audio, audio_len)
probs = self.decoder.softmax(eouts)
return probs
return probs, eouts_len
def export(self):
static_model = paddle.jit.to_static(

@ -100,12 +100,12 @@ class CRNNEncoder(nn.Layer):
"""Compute Encoder outputs
Args:
x (Tensor): [B, feature_size, D]
x (Tensor): [B, T, D]
x_lens (Tensor): [B]
init_state_h_box(Tensor): init_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
init_state_c_box(Tensor): init_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
Return:
x (Tensor): encoder outputs, [B, size, D]
x (Tensor): encoder outputs, [B, T, D]
x_lens (Tensor): encoder length, [B]
final_state_h_box(Tensor): final_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
final_state_c_box(Tensor): final_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]

@ -0,0 +1,39 @@
#!/bin/bash
if [ $# != 3 ];then
echo "usage: ${0} config_path ckpt_path_prefix model_type"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1
jit_model_export_path=$2
model_type=$3
# download language model
bash local/download_lm_ch.sh
if [ $? -ne 0 ]; then
exit 1
fi
python3 -u ${BIN_DIR}/test_export.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${jit_model_export_path}.rsl \
--export_path ${jit_model_export_path} \
--model_type ${model_type}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
exit 0

@ -39,3 +39,8 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# export ckpt avg_n
CUDA_VISIBLE_DEVICES=0 ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type}
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# test export ckpt avg_n
CUDA_VISIBLE_DEVICES=0 ./local/test_export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type}|| exit -1
fi

Loading…
Cancel
Save