Merge pull request #786 from Jackwaterveg/ds2_online

[Static model test] Add the test process for export model
pull/794/head
Hui Zhang 3 years ago committed by GitHub
commit 7181e427af

@ -0,0 +1,58 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation for DeepSpeech2 model."""
from deepspeech.exps.deepspeech2.config import get_cfg_defaults
from deepspeech.exps.deepspeech2.model import DeepSpeech2ExportTester as ExportTester
from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.utility import print_arguments

def main_sp(config, args):
    exp = ExportTester(config, args)
    exp.setup()
    exp.run_test()


def main(config, args):
    main_sp(config, args)


if __name__ == "__main__":
    parser = default_argument_parser()
    # save the asr result to this file
    parser.add_argument(
        "--result_file", type=str, help="path to save the asr result")
    # load the exported jit model from this path prefix
    parser.add_argument(
        "--export_path", type=str, help="path of the jit model to load")
    parser.add_argument(
        "--model_type", type=str, help="offline or online model type")
    args = parser.parse_args()
    print_arguments(args, globals())

    if args.model_type is None:
        args.model_type = 'offline'
    print("model_type:{}".format(args.model_type))

    # https://yaml.org/type/float.html
    config = get_cfg_defaults(args.model_type)
    if args.config:
        config.merge_from_file(args.config)
    if args.opts:
        config.merge_from_list(args.opts)
    config.freeze()
    print(config)
    if args.dump_config:
        with open(args.dump_config, 'w') as f:
            print(config, file=f)

    main(config, args)
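
The configuration handling above follows the usual yacs pattern: defaults from get_cfg_defaults(), then a YAML file and --opts overrides merged in, then freeze(). A minimal sketch of that merge behavior with a toy CfgNode; the keys below are made up for illustration and are not this recipe's config schema:

from yacs.config import CfgNode

# Toy defaults standing in for get_cfg_defaults(); keys are illustrative only.
cfg = CfgNode()
cfg.model = CfgNode()
cfg.model.rnn_layer_size = 1024
cfg.decoding = CfgNode()
cfg.decoding.beam_size = 500

# "--opts"-style overrides arrive as a flat key/value list.
cfg.merge_from_list(["decoding.beam_size", 300])
cfg.freeze()                      # further mutation now raises an error
print(cfg.decoding.beam_size)     # 300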

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains DeepSpeech2 and DeepSpeech2Online model."""
import os
import time
from collections import defaultdict
from pathlib import Path
@ -20,6 +21,7 @@ from typing import Optional
import numpy as np
import paddle
from paddle import distributed as dist
from paddle import inference
from paddle.io import DataLoader
from yacs.config import CfgNode
@ -268,24 +270,9 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
        vocab_list = self.test_loader.collate_fn.vocab_list
        target_transcripts = self.ordid2token(texts, texts_len)

        self.autolog.times.start()
        self.autolog.times.stamp()
        result_transcripts = self.model.decode(
            audio,
            audio_len,
            vocab_list,
            decoding_method=cfg.decoding_method,
            lang_model_path=cfg.lang_model_path,
            beam_alpha=cfg.alpha,
            beam_beta=cfg.beta,
            beam_size=cfg.beam_size,
            cutoff_prob=cfg.cutoff_prob,
            cutoff_top_n=cfg.cutoff_top_n,
            num_processes=cfg.num_proc_bsearch)
        self.autolog.times.stamp()
        self.autolog.times.stamp()
        self.autolog.times.end()

        result_transcripts = self.compute_result_transcripts(audio, audio_len,
                                                             vocab_list, cfg)
        for utt, target, result in zip(utts, target_transcripts,
                                       result_transcripts):
            errors, len_ref = errors_func(target, result)
@ -306,6 +293,26 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
            error_rate=errors_sum / len_refs,
            error_rate_type=cfg.error_rate_type)
    def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg):
        self.autolog.times.start()
        self.autolog.times.stamp()
        result_transcripts = self.model.decode(
            audio,
            audio_len,
            vocab_list,
            decoding_method=cfg.decoding_method,
            lang_model_path=cfg.lang_model_path,
            beam_alpha=cfg.alpha,
            beam_beta=cfg.beta,
            beam_size=cfg.beam_size,
            cutoff_prob=cfg.cutoff_prob,
            cutoff_top_n=cfg.cutoff_top_n,
            num_processes=cfg.num_proc_bsearch)
        self.autolog.times.stamp()
        self.autolog.times.stamp()
        self.autolog.times.end()
        return result_transcripts
    @mp_tools.rank_zero_only
    @paddle.no_grad()
    def test(self):
@ -395,3 +402,244 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
        output_dir.mkdir(parents=True, exist_ok=True)
        self.output_dir = output_dir
class DeepSpeech2ExportTester(DeepSpeech2Tester):
    def __init__(self, config, args):
        super().__init__(config, args)

    def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg):
        if self.args.model_type == "online":
            output_probs, output_lens = self.static_forward_online(audio,
                                                                   audio_len)
        elif self.args.model_type == "offline":
            output_probs, output_lens = self.static_forward_offline(audio,
                                                                    audio_len)
        else:
            raise Exception("wrong model type: {}".format(self.args.model_type))
        self.predictor.clear_intermediate_tensor()
        self.predictor.try_shrink_memory()

        self.model.decoder.init_decode(cfg.alpha, cfg.beta, cfg.lang_model_path,
                                       vocab_list, cfg.decoding_method)
        result_transcripts = self.model.decoder.decode_probs(
            output_probs, output_lens, vocab_list, cfg.decoding_method,
            cfg.lang_model_path, cfg.alpha, cfg.beta, cfg.beam_size,
            cfg.cutoff_prob, cfg.cutoff_top_n, cfg.num_proc_bsearch)
        return result_transcripts
    def static_forward_online(self, audio, audio_len,
                              decoder_chunk_size: int=1):
        """
        Parameters
        ----------
        audio (Tensor): shape[B, T, D]
        audio_len (Tensor): shape[B]
        decoder_chunk_size (int): number of decoder (output) frames per chunk

        Returns
        -------
        output_probs (numpy.array): shape[B, T, vocab_size]
        output_lens (numpy.array): shape[B]
        """
        output_probs_list = []
        output_lens_list = []
        subsampling_rate = self.model.encoder.conv.subsampling_rate
        receptive_field_length = self.model.encoder.conv.receptive_field_length
        chunk_stride = subsampling_rate * decoder_chunk_size
        chunk_size = (decoder_chunk_size - 1
                      ) * subsampling_rate + receptive_field_length

        x_batch = audio.numpy()
        batch_size, Tmax, x_dim = x_batch.shape
        x_len_batch = audio_len.numpy().astype(np.int64)
        if (Tmax - chunk_size) % chunk_stride != 0:
            # length of padding for the whole batch
            padding_len_batch = chunk_stride - (Tmax - chunk_size) % chunk_stride
        else:
            padding_len_batch = 0

        x_list = np.split(x_batch, batch_size, axis=0)
        x_len_list = np.split(x_len_batch, batch_size, axis=0)
        for x, x_len in zip(x_list, x_len_list):
            self.autolog.times.start()
            self.autolog.times.stamp()
            x_len = x_len[0]
            assert (chunk_size <= x_len)
            if (x_len - chunk_size) % chunk_stride != 0:
                padding_len_x = chunk_stride - (x_len - chunk_size
                                                ) % chunk_stride
            else:
                padding_len_x = 0
            padding = np.zeros(
                (x.shape[0], padding_len_x, x.shape[2]), dtype=x.dtype)
            padded_x = np.concatenate([x, padding], axis=1)
            num_chunk = (x_len + padding_len_x - chunk_size) / chunk_stride + 1
            num_chunk = int(num_chunk)

            chunk_state_h_box = np.zeros(
                (self.config.model.num_rnn_layers, 1,
                 self.config.model.rnn_layer_size),
                dtype=x.dtype)
            chunk_state_c_box = np.zeros(
                (self.config.model.num_rnn_layers, 1,
                 self.config.model.rnn_layer_size),
                dtype=x.dtype)

            input_names = self.predictor.get_input_names()
            audio_handle = self.predictor.get_input_handle(input_names[0])
            audio_len_handle = self.predictor.get_input_handle(input_names[1])
            h_box_handle = self.predictor.get_input_handle(input_names[2])
            c_box_handle = self.predictor.get_input_handle(input_names[3])

            probs_chunk_list = []
            probs_chunk_lens_list = []
            for i in range(0, num_chunk):
                start = i * chunk_stride
                end = start + chunk_size
                x_chunk = padded_x[:, start:end, :]
                if x_len < i * chunk_stride:
                    x_chunk_lens = 0
                else:
                    x_chunk_lens = min(x_len - i * chunk_stride, chunk_size)

                if (x_chunk_lens < receptive_field_length):
                    # not enough input frames left in this chunk to
                    # produce even one output frame, so stop here
                    break
                x_chunk_lens = np.array([x_chunk_lens])

                audio_handle.reshape(x_chunk.shape)
                audio_handle.copy_from_cpu(x_chunk)
                audio_len_handle.reshape(x_chunk_lens.shape)
                audio_len_handle.copy_from_cpu(x_chunk_lens)
                h_box_handle.reshape(chunk_state_h_box.shape)
                h_box_handle.copy_from_cpu(chunk_state_h_box)
                c_box_handle.reshape(chunk_state_c_box.shape)
                c_box_handle.copy_from_cpu(chunk_state_c_box)

                output_names = self.predictor.get_output_names()
                output_handle = self.predictor.get_output_handle(
                    output_names[0])
                output_lens_handle = self.predictor.get_output_handle(
                    output_names[1])
                output_state_h_handle = self.predictor.get_output_handle(
                    output_names[2])
                output_state_c_handle = self.predictor.get_output_handle(
                    output_names[3])
                self.predictor.run()
                output_chunk_probs = output_handle.copy_to_cpu()
                output_chunk_lens = output_lens_handle.copy_to_cpu()
                # carry the RNN hidden/cell states over to the next chunk
                chunk_state_h_box = output_state_h_handle.copy_to_cpu()
                chunk_state_c_box = output_state_c_handle.copy_to_cpu()

                probs_chunk_list.append(output_chunk_probs)
                probs_chunk_lens_list.append(output_chunk_lens)
            output_probs = np.concatenate(probs_chunk_list, axis=1)
            output_lens = np.sum(probs_chunk_lens_list, axis=0)
            vocab_size = output_probs.shape[2]
            # pad the probs of this utterance so all utterances in the batch
            # share the same time dimension
            output_probs_padding_len = Tmax + padding_len_batch - output_probs.shape[1]
            output_probs_padding = np.zeros(
                (1, output_probs_padding_len, vocab_size),
                dtype=output_probs.dtype)
            output_probs = np.concatenate(
                [output_probs, output_probs_padding], axis=1)
            output_probs_list.append(output_probs)
            output_lens_list.append(output_lens)
            self.autolog.times.stamp()
            self.autolog.times.stamp()
            self.autolog.times.end()
        output_probs = np.concatenate(output_probs_list, axis=0)
        output_lens = np.concatenate(output_lens_list, axis=0)
        return output_probs, output_lens
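
The chunking arithmetic above is easiest to see with numbers. A small sketch using made-up values for subsampling_rate and receptive_field_length (the real ones come from the encoder's convolution stack, self.model.encoder.conv):

# Illustrative values only, mirroring the padding/num_chunk logic above.
subsampling_rate = 4
receptive_field_length = 15
decoder_chunk_size = 1

chunk_stride = subsampling_rate * decoder_chunk_size                                # 4
chunk_size = (decoder_chunk_size - 1) * subsampling_rate + receptive_field_length   # 15

x_len = 101  # input frames of one utterance
if (x_len - chunk_size) % chunk_stride != 0:
    padding_len_x = chunk_stride - (x_len - chunk_size) % chunk_stride
else:
    padding_len_x = 0
num_chunk = (x_len + padding_len_x - chunk_size) // chunk_stride + 1

print(padding_len_x, num_chunk)  # 2 23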
    def static_forward_offline(self, audio, audio_len):
        """
        Parameters
        ----------
        audio (Tensor): shape[B, T, D]
        audio_len (Tensor): shape[B]

        Returns
        -------
        output_probs (numpy.array): shape[B, T, vocab_size]
        output_lens (numpy.array): shape[B]
        """
        x = audio.numpy()
        x_len = audio_len.numpy().astype(np.int64)

        input_names = self.predictor.get_input_names()
        audio_handle = self.predictor.get_input_handle(input_names[0])
        audio_len_handle = self.predictor.get_input_handle(input_names[1])

        audio_handle.reshape(x.shape)
        audio_handle.copy_from_cpu(x)
        audio_len_handle.reshape(x_len.shape)
        audio_len_handle.copy_from_cpu(x_len)

        self.autolog.times.start()
        self.autolog.times.stamp()
        self.predictor.run()
        self.autolog.times.stamp()
        self.autolog.times.stamp()
        self.autolog.times.end()

        output_names = self.predictor.get_output_names()
        output_handle = self.predictor.get_output_handle(output_names[0])
        output_lens_handle = self.predictor.get_output_handle(output_names[1])
        output_probs = output_handle.copy_to_cpu()
        output_lens = output_lens_handle.copy_to_cpu()
        return output_probs, output_lens
    def run_test(self):
        try:
            self.test()
        except KeyboardInterrupt:
            exit(-1)

    def setup(self):
        """Setup the experiment.
        """
        paddle.set_device(self.args.device)

        self.setup_output_dir()
        self.setup_dataloader()
        self.setup_model()

        self.iteration = 0
        self.epoch = 0

    def setup_output_dir(self):
        """Create a directory used for output.
        """
        # output dir
        if self.args.output:
            output_dir = Path(self.args.output).expanduser()
            output_dir.mkdir(parents=True, exist_ok=True)
        else:
            output_dir = Path(self.args.export_path).expanduser().parent.parent
            output_dir.mkdir(parents=True, exist_ok=True)

        self.output_dir = output_dir

    def setup_model(self):
        super().setup_model()
        # build a Paddle Inference predictor from the exported static model
        inference_config = inference.Config(
            self.args.export_path + ".pdmodel",
            self.args.export_path + ".pdiparams")
        if (os.environ.get('CUDA_VISIBLE_DEVICES', '').strip() != ''):
            inference_config.enable_use_gpu(100, 0)
        inference_config.enable_memory_optim()
        self.predictor = inference.create_predictor(inference_config)
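
setup_model() builds the Paddle Inference predictor that static_forward_offline/online then drive through input and output handles. For reference, a condensed, self-contained sketch of that load-and-run cycle; the export path and input shapes below are placeholders, not values from this recipe:

import numpy as np
from paddle import inference

export_path = "exp/deepspeech2/checkpoints/avg_1.jit"  # placeholder prefix

config = inference.Config(export_path + ".pdmodel", export_path + ".pdiparams")
config.enable_memory_optim()  # GPU could be enabled with enable_use_gpu(100, 0)
predictor = inference.create_predictor(config)

# one dummy utterance: [B, T, D] features and [B] lengths (shapes illustrative)
audio = np.random.rand(1, 200, 161).astype("float32")
audio_len = np.array([200], dtype="int64")

input_names = predictor.get_input_names()
audio_handle = predictor.get_input_handle(input_names[0])
audio_len_handle = predictor.get_input_handle(input_names[1])
audio_handle.reshape(audio.shape)
audio_handle.copy_from_cpu(audio)
audio_len_handle.reshape(audio_len.shape)
audio_len_handle.copy_from_cpu(audio_len)

predictor.run()

output_names = predictor.get_output_names()
probs = predictor.get_output_handle(output_names[0]).copy_to_cpu()
lens = predictor.get_output_handle(output_names[1]).copy_to_cpu()
print(probs.shape, lens)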

@ -280,7 +280,7 @@ class DeepSpeech2InferModel(DeepSpeech2Model):
""" """
eouts, eouts_len = self.encoder(audio, audio_len) eouts, eouts_len = self.encoder(audio, audio_len)
probs = self.decoder.softmax(eouts) probs = self.decoder.softmax(eouts)
return probs return probs, eouts_len
def export(self): def export(self):
static_model = paddle.jit.to_static( static_model = paddle.jit.to_static(
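
The export() body is cut off by the diff view here. As orientation only, a hypothetical sketch of the usual paddle.jit export flow that produces the .pdmodel/.pdiparams pair the tester's setup_model() later loads; the input specs and save prefix are assumptions, not the repository's actual code:

import paddle
from paddle.static import InputSpec

def export_static(model, save_prefix="exp/deepspeech2/checkpoints/avg_1.jit"):
    # Hypothetical shapes/dtypes; not the values used by DeepSpeech2InferModel.export().
    static_model = paddle.jit.to_static(
        model,
        input_spec=[
            InputSpec(shape=[None, None, 161], dtype='float32'),  # audio [B, T, D]
            InputSpec(shape=[None], dtype='int64'),               # audio_len [B]
        ])
    # writes save_prefix + ".pdmodel" and save_prefix + ".pdiparams"
    paddle.jit.save(static_model, save_prefix)
    return static_model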

@ -100,12 +100,12 @@ class CRNNEncoder(nn.Layer):
"""Compute Encoder outputs """Compute Encoder outputs
Args: Args:
x (Tensor): [B, feature_size, D] x (Tensor): [B, T, D]
x_lens (Tensor): [B] x_lens (Tensor): [B]
init_state_h_box(Tensor): init_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] init_state_h_box(Tensor): init_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
init_state_c_box(Tensor): init_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] init_state_c_box(Tensor): init_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
Return: Return:
x (Tensor): encoder outputs, [B, size, D] x (Tensor): encoder outputs, [B, T, D]
x_lens (Tensor): encoder length, [B] x_lens (Tensor): encoder length, [B]
final_state_h_box(Tensor): final_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] final_state_h_box(Tensor): final_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
final_state_c_box(Tensor): final_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] final_state_c_box(Tensor): final_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]

@ -0,0 +1,39 @@
#!/bin/bash

if [ $# != 3 ];then
    echo "usage: ${0} config_path ckpt_path_prefix model_type"
    exit -1
fi

ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."

device=gpu
if [ ${ngpu} == 0 ];then
    device=cpu
fi
config_path=$1
jit_model_export_path=$2
model_type=$3

# download language model
bash local/download_lm_ch.sh
if [ $? -ne 0 ]; then
    exit 1
fi

python3 -u ${BIN_DIR}/test_export.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${jit_model_export_path}.rsl \
--export_path ${jit_model_export_path} \
--model_type ${model_type}

if [ $? -ne 0 ]; then
    echo "Failed in evaluation!"
    exit 1
fi

exit 0

@ -39,3 +39,8 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    # export ckpt avg_n
    CUDA_VISIBLE_DEVICES=0 ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type}
fi

if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
    # test export ckpt avg_n
    CUDA_VISIBLE_DEVICES=0 ./local/test_export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type} || exit -1
fi
