complete model export for ds2_online

pull/735/head
huangyuxin 3 years ago
parent e8a3913422
commit 3fb9f6885a

@@ -32,7 +32,8 @@ if __name__ == "__main__":
     parser = default_argument_parser()
     parser.add_argument("--model_type")
     args = parser.parse_args()
+    if args.model_type is None:
+        args.model_type = 'offline'
     print_arguments(args)
     # https://yaml.org/type/float.html

@@ -33,6 +33,8 @@ if __name__ == "__main__":
     parser.add_argument("--model_type")
     args = parser.parse_args()
     print_arguments(args, globals())
+    if args.model_type is None:
+        args.model_type = 'offline'
     # https://yaml.org/type/float.html
     config = get_cfg_defaults(args.model_type)

@@ -37,6 +37,8 @@ if __name__ == "__main__":
     parser = default_argument_parser()
     parser.add_argument("--model_type")
     args = parser.parse_args()
+    if args.model_type is None:
+        args.model_type = 'offline'
     print_arguments(args, globals())
     # https://yaml.org/type/float.html

@@ -21,7 +21,7 @@ from deepspeech.models.ds2 import DeepSpeech2Model
 from deepspeech.models.ds2_online import DeepSpeech2ModelOnline


-def get_cfg_defaults(model_type):
+def get_cfg_defaults(model_type='offline'):
     _C = CfgNode()
     if (model_type == 'offline'):
         _C.data = ManifestDataset.params()

@@ -134,6 +134,7 @@ class DeepSpeech2Trainer(Trainer):
                 use_gru=config.model.use_gru,
                 share_rnn_weights=config.model.share_rnn_weights)
         elif self.args.model_type == 'online':
+            print("fc_layers_size_list", config.model.fc_layers_size_list)
             model = DeepSpeech2ModelOnline(
                 feat_size=self.train_loader.collate_fn.feature_size,
                 dict_size=self.train_loader.collate_fn.vocab_size,
@@ -352,19 +353,43 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
             infer_model = DeepSpeech2InferModelOnline.from_pretrained(
                 self.test_loader, self.config, self.args.checkpoint_path)
         else:
-            raise Exception("wrong model tyep")
+            raise Exception("wrong model type")

         infer_model.eval()
         feat_dim = self.test_loader.collate_fn.feature_size
-        static_model = paddle.jit.to_static(
-            infer_model,
-            input_spec=[
-                paddle.static.InputSpec(
-                    shape=[None, None, feat_dim],
-                    dtype='float32'),  # audio, [B,T,D]
-                paddle.static.InputSpec(shape=[None],
-                                        dtype='int64'),  # audio_length, [B]
-            ])
+        if self.args.model_type == 'offline':
+            static_model = paddle.jit.to_static(
+                infer_model,
+                input_spec=[
+                    paddle.static.InputSpec(
+                        shape=[None, None, feat_dim],
+                        dtype='float32'),  # audio, [B,T,D]
+                    paddle.static.InputSpec(shape=[None],
+                                            dtype='int64'),  # audio_length, [B]
+                ])
+        elif self.args.model_type == 'online':
+            static_model = paddle.jit.to_static(
+                infer_model,
+                input_spec=[
+                    paddle.static.InputSpec(
+                        shape=[None, None,
+                               feat_dim],  # [B, chunk_size, feat_dim]
+                        dtype='float32'),  # audio, [B,T,D]
+                    paddle.static.InputSpec(shape=[None],
+                                            dtype='int64'),  # audio_length, [B]
+                    [
+                        (
+                            paddle.static.InputSpec(
+                                shape=[None, None, None], dtype='float32'
+                            ),  # num_rnn_layers * num_directions, rnn_size
+                            paddle.static.InputSpec(
+                                shape=[None, None, None], dtype='float32'
+                            )  # num_rnn_layers * num_directions, rnn_size
+                        ) for i in range(self.config.model.num_rnn_layers)
+                    ]
+                ])
+        else:
+            raise Exception("wrong model type")
         logger.info(f"Export code: {static_model.forward.code}")
         paddle.jit.save(static_model, self.args.export_path)
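For the online export, the saved program takes a third input: a list with one pair of state tensors per RNN layer, mirroring the chunk_state_list argument of DeepSpeech2InferModelOnline.forward below. A minimal loading sketch under stated assumptions — the export path, feature size, layer count, and rnn_size here are illustrative, not values from this commit:

    import paddle

    # Load the program written by paddle.jit.save; the path is hypothetical.
    loaded = paddle.jit.load("exp/deepspeech2_online/export")
    loaded.eval()

    feat_dim, num_rnn_layers, rnn_size = 161, 3, 1024  # assumed config values
    audio_chunk = paddle.randn([1, 16, feat_dim])             # [B, chunk_size, D]
    audio_chunk_lens = paddle.to_tensor([16], dtype='int64')  # [B]
    # One pair of zero states per RNN layer, matching the [None, None, None]
    # InputSpecs above: [num_directions, batch_size, rnn_size].
    chunk_state_list = [(paddle.zeros([1, 1, rnn_size]),
                         paddle.zeros([1, 1, rnn_size]))
                        for _ in range(num_rnn_layers)]

    probs_chunk, final_state_list = loaded(audio_chunk, audio_chunk_lens,
                                           chunk_state_list)
    # Feed final_state_list back in with the next chunk to continue the stream.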

@@ -29,7 +29,7 @@ class Conv2dSubsampling4Online(Conv2dSubsampling4):
                 x_len: paddle.Tensor) -> [paddle.Tensor, paddle.Tensor]:
         x = x.unsqueeze(1)  # (b, c=1, t, f)
         x = self.conv(x)
-        b, c, t, f = paddle.shape(x)
-        x = x.transpose([0, 2, 1, 3]).reshape([b, t, c * f])
+        #b, c, t, f = paddle.shape(x)  # does not work under jit
+        x = x.transpose([0, 2, 1, 3]).reshape([0, 0, -1])
         x_len = ((x_len - 1) // 2 - 1) // 2
         return x, x_len
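The reshape([0, 0, -1]) form is what makes this traceable under paddle.jit.to_static: in Paddle's reshape, a 0 keeps the corresponding input dimension and -1 is inferred, so no concrete shape values have to be read out of the tensor at trace time. A standalone sketch of the equivalence (the sizes are arbitrary):

    import paddle

    # (b, c, t, f) -> (b, t, c*f) without unpacking paddle.shape(x):
    # in reshape, 0 copies that dim from the input and -1 is inferred.
    x = paddle.randn([2, 32, 10, 41])  # b, c, t, f
    y = x.transpose([0, 2, 1, 3]).reshape([0, 0, -1])
    assert y.shape == [2, 10, 32 * 41]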

@@ -61,7 +61,7 @@ class CRNNEncoder(nn.Layer):
                 rnn_input_size = i_size
             else:
                 rnn_input_size = rnn_size
-            if (use_gru == True):
+            if use_gru == True:
                 self.rnn.append(
                     nn.GRU(
                         input_size=rnn_input_size,
@@ -146,6 +146,17 @@ class CRNNEncoder(nn.Layer):
         return x, x_lens, chunk_final_state_list

     def forward_chunk_by_chunk(self, x, x_lens, decoder_chunk_size=8):
+        """Compute Encoder outputs
+
+        Args:
+            x (Tensor): [B, T, D]
+            x_lens (Tensor): [B]
+            decoder_chunk_size: The chunk size of decoder
+        Returns:
+            eouts_chunk_list (List of Tensor): The list of encoder outputs in chunk_size, [B, chunk_size, D] * num_chunks
+            eouts_chunk_lens_list (List of Tensor): The list of encoder lengths in chunk_size, [B] * num_chunks
+            final_chunk_state_list: list of final states for RNN layers, [num_directions, batch_size, hidden_size] * num_rnn_layers
+        """
         subsampling_rate = self.conv.subsampling_rate
         receptive_field_length = self.conv.receptive_field_length
         chunk_size = (decoder_chunk_size - 1
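As the updated unit test later in this commit does, the per-chunk outputs can be stitched back into a full-utterance encoding. A short sketch, assuming encoder is a CRNNEncoder instance and x, x_lens are batched features as in the docstring above:

    import paddle

    # Run the encoder chunk by chunk, then reassemble the whole utterance.
    eouts_chunk_list, eouts_chunk_lens_list, final_chunk_state_list = \
        encoder.forward_chunk_by_chunk(x, x_lens, decoder_chunk_size=8)
    eouts = paddle.concat(eouts_chunk_list, axis=1)   # [B, T', D]
    eouts_lens = paddle.add_n(eouts_chunk_lens_list)  # [B]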
@@ -183,8 +194,8 @@ class CRNNEncoder(nn.Layer):
             eouts_chunk_list.append(eouts_chunk)
             eouts_chunk_lens_list.append(eouts_chunk_lens)
-        return eouts_chunk_list, eouts_chunk_lens_list, chunk_state_list
+        final_chunk_state_list = chunk_state_list
+        return eouts_chunk_list, eouts_chunk_lens_list, final_chunk_state_list
 class DeepSpeech2ModelOnline(nn.Layer):
@@ -208,7 +219,6 @@ class DeepSpeech2ModelOnline(nn.Layer):
     :type rnn_size: int
     :param use_gru: Use gru if set True. Use simple rnn if set False.
     :type use_gru: bool
-    :type share_weights: bool
     :return: A tuple of an output unnormalized log probability layer (
              before softmax) and a ctc cost layer.
     :rtype: tuple of LayerOutput
@@ -295,97 +305,6 @@ class DeepSpeech2ModelOnline(nn.Layer):
             probs.numpy(), eouts_len, vocab_list, decoding_method,
             lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
             cutoff_top_n, num_processes)
-    """
-    @paddle.no_grad()
-    def decode_by_chunk(self, eouts_prefix, eouts_len_prefix, chunk_state_list,
-                        audio_chunk, audio_len_chunk, vocab_list,
-                        decoding_method, lang_model_path, beam_alpha, beam_beta,
-                        beam_size, cutoff_prob, cutoff_top_n, num_processes):
-        # init once
-        # decoders only accept string encoded in utf-8
-        self.decoder.init_decode(
-            beam_alpha=beam_alpha,
-            beam_beta=beam_beta,
-            lang_model_path=lang_model_path,
-            vocab_list=vocab_list,
-            decoding_method=decoding_method)
-
-        eouts_chunk, eouts_chunk_len, final_state_list = self.encoder.forward_chunk(
-            audio_chunk, audio_len_chunk, chunk_state_list)
-        if eouts_prefix is not None:
-            eouts = paddle.concat([eouts_prefix, eouts_chunk], axis=1)
-            eouts_len = paddle.add_n([eouts_len_prefix, eouts_chunk_len])
-        else:
-            eouts = eouts_chunk
-            eouts_len = eouts_chunk_len
-
-        probs = self.decoder.softmax(eouts)
-        return self.decoder.decode_probs(
-            probs.numpy(), eouts_len, vocab_list, decoding_method,
-            lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
-            cutoff_top_n, num_processes), eouts, eouts_len, final_state_list
-
-    @paddle.no_grad()
-    def decode_chunk_by_chunk(self, audio, audio_len, vocab_list,
-                              decoding_method, lang_model_path, beam_alpha,
-                              beam_beta, beam_size, cutoff_prob, cutoff_top_n,
-                              num_processes):
-        # init once
-        # decoders only accept string encoded in utf-8
-        self.decoder.init_decode(
-            beam_alpha=beam_alpha,
-            beam_beta=beam_beta,
-            lang_model_path=lang_model_path,
-            vocab_list=vocab_list,
-            decoding_method=decoding_method)
-
-        eouts_chunk_list, eouts_chunk_len_list, final_state_list = self.encoder.forward_chunk_by_chunk(
-            audio, audio_len)
-        eouts = paddle.concat(eouts_chunk_list, axis=1)
-        eouts_len = paddle.add_n(eouts_chunk_len_list)
-
-        probs = self.decoder.softmax(eouts)
-        return self.decoder.decode_probs(
-            probs.numpy(), eouts_len, vocab_list, decoding_method,
-            lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
-            cutoff_top_n, num_processes)
-    """
-
-    """
-    decocd_prob,
-    decode_prob_chunk_by_chunk
-    decode_prob_by_chunk
-    is only used for test
-    """
-    """
-    @paddle.no_grad()
-    def decode_prob(self, audio, audio_len):
-        eouts, eouts_len, final_state_list = self.encoder(audio, audio_len)
-        probs = self.decoder.softmax(eouts)
-        return probs, eouts, eouts_len, final_state_list
-
-    @paddle.no_grad()
-    def decode_prob_chunk_by_chunk(self, audio, audio_len, decoder_chunk_size):
-        eouts_chunk_list, eouts_chunk_len_list, final_state_list = self.encoder.forward_chunk_by_chunk(
-            audio, audio_len, decoder_chunk_size)
-        eouts = paddle.concat(eouts_chunk_list, axis=1)
-        eouts_len = paddle.add_n(eouts_chunk_len_list)
-        probs = self.decoder.softmax(eouts)
-        return probs, eouts, eouts_len, final_state_list
-
-    @paddle.no_grad()
-    def decode_prob_by_chunk(self, audio, audio_len, eouts_prefix,
-                             eouts_lens_prefix, chunk_state_list):
-        eouts_chunk, eouts_chunk_lens, final_state_list = self.encoder.forward_chunk(
-            audio, audio_len, chunk_state_list)
-        if eouts_prefix is not None:
-            eouts = paddle.concat([eouts_prefix, eouts_chunk], axis=1)
-            eouts_lens = paddle.add_n([eouts_lens_prefix, eouts_chunk_lens])
-        else:
-            eouts = eouts_chunk
-            eouts_lens = eouts_chunk_lens
-
-        probs = self.decoder.softmax(eouts)
-        return probs, eouts, eouts_lens, final_state_list
-    """

     @classmethod
     def from_pretrained(cls, dataloader, config, checkpoint_path):
@@ -443,42 +362,8 @@ class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
             fc_layers_size_list=fc_layers_size_list,
             use_gru=use_gru)

-    def forward(self, audio, audio_len):
-        """export model function
-
-        Args:
-            audio (Tensor): [B, T, D]
-            audio_len (Tensor): [B]
-
-        Returns:
-            probs: probs after softmax
-        """
-        eouts, eouts_len, final_state_list = self.encoder(audio, audio_len)
-        probs = self.decoder.softmax(eouts)
-        return probs
-
-    def forward_chunk(self, audio_chunk, audio_chunk_lens):
-        eouts_chunkt, eouts_chunk_lens, final_state_list = self.encoder.forward_chunk(
-            audio_chunk, audio_chunk_lens)
-        probs = self.decoder.softmax(eouts)
-        return probs
-
-    def forward(self, eouts_chunk_prefix, eouts_chunk_lens_prefix, audio_chunk,
-                audio_chunk_lens, chunk_state_list):
-        """export model function
-
-        Args:
-            audio_chunk (Tensor): [B, T, D]
-            audio_chunk_len (Tensor): [B]
-
-        Returns:
-            probs: probs after softmax
-        """
+    def forward(self, audio_chunk, audio_chunk_lens, chunk_state_list):
         eouts_chunk, eouts_chunk_lens, final_state_list = self.encoder.forward_chunk(
             audio_chunk, audio_chunk_lens, chunk_state_list)
-        eouts_chunk_new_prefix = paddle.concat(
-            [eouts_chunk_prefix, eouts_chunk], axis=1)
-        eouts_chunk_lens_new_prefix = paddle.add(eouts_chunk_lens_prefix,
-                                                 eouts_chunk_lens)
-        probs_chunk = self.decoder.softmax(eouts_chunk_new_prefix)
-        return probs_chunk, eouts_chunk_new_prefix, eouts_chunk_lens_new_prefix, final_state_list
+        probs_chunk = self.decoder.softmax(eouts_chunk)
+        return probs_chunk, final_state_list
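With the prefix bookkeeping dropped, the exported forward becomes a pure step function: each call consumes one chunk and returns the updated RNN states, and the caller owns the state across calls. A streaming loop might look like the following sketch, where infer_model, chunks, chunk_lens, and init_states are assumed to be prepared by the caller (init_states being a zero state list shaped as in the export input_spec):

    import paddle

    states = init_states
    probs_list = []
    for audio_chunk, audio_chunk_lens in zip(chunks, chunk_lens):
        # One chunk in, per-chunk posteriors and updated states out.
        probs_chunk, states = infer_model(audio_chunk, audio_chunk_lens, states)
        probs_list.append(probs_chunk)
    probs = paddle.concat(probs_list, axis=1)  # per-frame posteriors, [B, T, V]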

@@ -7,7 +7,7 @@
 stage=0
 stop_stage=100
 conf_path=conf/deepspeech2.yaml
 avg_num=1
-model_type=online
+model_type=offline

 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

@@ -4,10 +4,10 @@
 source path.sh

 gpus=7
 stage=1
-stop_stage=100
-conf_path=conf/deepspeech2.yaml
+stop_stage=1
+conf_path=conf/deepspeech2_online.yaml
 avg_num=1
-model_type=online
+model_type=online #online | offline

 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

@@ -19,7 +19,6 @@ import paddle
 from deepspeech.models.ds2 import DeepSpeech2Model


 class TestDeepSpeech2Model(unittest.TestCase):
     def setUp(self):
         paddle.set_device('cpu')

@@ -119,14 +119,14 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase):
         paddle.device.set_device("cpu")
         de_ch_size = 9

-        eouts, eouts_lens, final_state_list = model.encoder(
-            self.audio, self.audio_len)
+        eouts, eouts_lens, final_state_list = model.encoder(self.audio,
+                                                            self.audio_len)
         eouts_by_chk_list, eouts_lens_by_chk_list, final_state_list_by_chk = model.encoder.forward_chunk_by_chunk(
             self.audio, self.audio_len, de_ch_size)
-        eouts_by_chk = paddle.concat(eouts_by_chk_list, axis = 1)
+        eouts_by_chk = paddle.concat(eouts_by_chk_list, axis=1)
         eouts_lens_by_chk = paddle.add_n(eouts_lens_by_chk_list)
         decode_max_len = eouts.shape[1]
-        print ("dml", decode_max_len)
+        print("dml", decode_max_len)
         eouts_by_chk = eouts_by_chk[:, :decode_max_len, :]
         self.assertEqual(
             paddle.sum(
@@ -149,6 +149,7 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase):
         print (paddle.sum(paddle.abs(paddle.subtract(eouts, eouts_by_chk))))
         print (paddle.allclose(eouts[:,:,:], eouts_by_chk[:,:,:]))
         """
+        """

     def split_into_chunk(self, x, x_lens, decoder_chunk_size, subsampling_rate,
                          receptive_field_length):
