reconstruct the exp/model.py and the model.export()

pull/735/head
huangyuxin 4 years ago
parent 319228653e
commit 85d5021475

@@ -27,7 +27,7 @@ def get_cfg_defaults(model_type='offline'):
     _C.collator = SpeechCollator.params()
     _C.training = DeepSpeech2Trainer.params()
     _C.decoding = DeepSpeech2Tester.params()
-    if (model_type == 'offline'):
+    if model_type == 'offline':
         _C.model = DeepSpeech2Model.params()
     else:
         _C.model = DeepSpeech2ModelOnline.params()
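
For reference, a minimal sketch of how this helper is typically consumed downstream; the YAML path here is hypothetical and not part of this commit:

    config = get_cfg_defaults(model_type='online')           # selects DeepSpeech2ModelOnline.params()
    config.merge_from_file('conf/deepspeech2_online.yaml')   # hypothetical config file
    config.freeze()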

@@ -124,10 +124,23 @@ class DeepSpeech2Trainer(Trainer):
     def setup_model(self):
         config = self.config
+        if hasattr(self, "train_loader"):
+            config.defrost()
+            config.model.feat_size = self.train_loader.collate_fn.feature_size
+            config.model.dict_size = self.train_loader.collate_fn.vocab_size
+            config.freeze()
+        elif hasattr(self, "test_loader"):
+            config.defrost()
+            config.model.feat_size = self.test_loader.collate_fn.feature_size
+            config.model.dict_size = self.test_loader.collate_fn.vocab_size
+            config.freeze()
+        else:
+            raise Exception("Please setup the dataloader first")
         if self.args.model_type == 'offline':
             model = DeepSpeech2Model(
-                feat_size=self.train_loader.collate_fn.feature_size,
-                dict_size=self.train_loader.collate_fn.vocab_size,
+                feat_size=config.model.feat_size,
+                dict_size=config.model.dict_size,
                 num_conv_layers=config.model.num_conv_layers,
                 num_rnn_layers=config.model.num_rnn_layers,
                 rnn_size=config.model.rnn_layer_size,
@@ -135,8 +148,8 @@ class DeepSpeech2Trainer(Trainer):
                 share_rnn_weights=config.model.share_rnn_weights)
         elif self.args.model_type == 'online':
             model = DeepSpeech2ModelOnline(
-                feat_size=self.train_loader.collate_fn.feature_size,
-                dict_size=self.train_loader.collate_fn.vocab_size,
+                feat_size=config.model.feat_size,
+                dict_size=config.model.dict_size,
                 num_conv_layers=config.model.num_conv_layers,
                 num_rnn_layers=config.model.num_rnn_layers,
                 rnn_size=config.model.rnn_layer_size,
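
For context, setup_model now derives feat_size and dict_size from whichever dataloader exists and writes them back into the yacs config before the model is constructed. A minimal sketch of that defrost/freeze pattern, with hypothetical numeric values:

    from yacs.config import CfgNode

    config = CfgNode()
    config.model = CfgNode()
    config.model.feat_size = -1      # placeholder until a dataloader exists
    config.model.dict_size = -1
    config.freeze()                  # configs stay immutable by default

    config.defrost()                 # temporarily unlock the config
    config.model.feat_size = 161     # e.g. collate_fn.feature_size
    config.model.dict_size = 29      # e.g. collate_fn.vocab_size
    config.freeze()                  # lock it again before building the model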
@@ -209,6 +222,7 @@ class DeepSpeech2Trainer(Trainer):
             batch_sampler=batch_sampler,
             collate_fn=collate_fn_train,
             num_workers=config.collator.num_workers)
+        print("feature_size", self.train_loader.collate_fn.feature_size)
         self.valid_loader = DataLoader(
             dev_dataset,
             batch_size=config.collator.batch_size,
@@ -368,8 +382,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
                     dtype='int64'),  # audio_length, [B]
             ])
         elif self.args.model_type == 'online':
-            static_model = DeepSpeech2InferModelOnline.export(infer_model,
-                                                              feat_dim)
+            static_model = infer_model.export()
         else:
             raise Exception("wrong model type")
         logger.info(f"Export code: {static_model.forward.code}")
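
With this change the tester asks the online model to export itself instead of going through a classmethod. A minimal usage sketch, assuming infer_model is an already loaded DeepSpeech2InferModelOnline and the save path is illustrative only:

    # Convert the dygraph model to a static graph; the input specs now live
    # inside DeepSpeech2InferModelOnline.export() itself.
    static_model = infer_model.export()

    # Persist the static graph for deployment (path is hypothetical).
    paddle.jit.save(static_model, 'exp/deepspeech2_online/checkpoints/avg_1.jit')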
@@ -395,6 +408,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         self.iteration = 0
         self.epoch = 0

+    '''
     def setup_model(self):
         config = self.config
         if self.args.model_type == 'offline':
@@ -422,6 +436,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         self.model = model
         logger.info("Setup model!")
+    '''

     def setup_dataloader(self):
         config = self.config.clone()

@@ -88,55 +88,7 @@ class CRNNEncoder(nn.Layer):
     def output_size(self):
         return self.fc_layers_size_list[-1]

-    def forward(self, x, x_lens):
-        """Compute Encoder outputs
-        Args:
-            x (Tensor): [B, T_input, D]
-            x_lens (Tensor): [B]
-        Returns:
-            x (Tensor): encoder outputs, [B, T_output, D]
-            x_lens (Tensor): encoder length, [B]
-            final_state_h_box(Tensor): final_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
-            final_state_c_box(Tensor): final_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
-        """
-        # [B, T, D]
-        # convolution group
-        x, x_lens = self.conv(x, x_lens)
-        # convert data from convolution feature map to sequence of vectors
-        #B, C, D, T = paddle.shape(x)  # not work under jit
-        #x = x.transpose([0, 3, 1, 2])  #[B, T, C, D]
-        #x = x.reshape([B, T, C * D])  #[B, T, C*D]  # not work under jit
-        #x = x.reshape([0, 0, -1])  #[B, T, C*D]
-        # remove padding part
-        init_state = None
-        final_state_list = []
-        for i in range(0, self.num_rnn_layers):
-            x, final_state = self.rnn[i](x, init_state, x_lens)  #[B, T, D]
-            final_state_list.append(final_state)
-            x = self.layernorm_list[i](x)
-        for i in range(self.num_fc_layers):
-            x = self.fc_layers_list[i](x)
-            x = F.relu(x)
-        if self.use_gru == True:
-            final_state_h_box = paddle.concat(final_state_list, axis=0)
-            final_state_c_box = paddle.zeros_like(final_state_h_box)
-        else:
-            final_state_h_list = [
-                final_state_list[i][0] for i in range(self.num_rnn_layers)
-            ]
-            final_state_c_list = [
-                final_state_list[i][1] for i in range(self.num_rnn_layers)
-            ]
-            final_state_h_box = paddle.concat(final_state_h_list, axis=0)
-            final_state_c_box = paddle.concat(final_state_c_list, axis=0)
-        return x, x_lens, final_state_h_box, final_state_c_box
-
-    def forward_chunk(self, x, x_lens, init_state_h_box, init_state_c_box):
+    def forward(self, x, x_lens, init_state_h_box=None, init_state_c_box=None):
         """Compute Encoder outputs
         Args:
@@ -152,13 +104,16 @@ class CRNNEncoder(nn.Layer):
         """
         if init_state_h_box is not None:
             init_state_list = None
-            init_state_h_list = paddle.split(
-                init_state_h_box, self.num_rnn_layers, axis=0)
-            init_state_c_list = paddle.split(
-                init_state_c_box, self.num_rnn_layers, axis=0)
             if self.use_gru == True:
+                init_state_h_list = paddle.split(
+                    init_state_h_box, self.num_rnn_layers, axis=0)
                 init_state_list = init_state_h_list
             else:
+                init_state_h_list = paddle.split(
+                    init_state_h_box, self.num_rnn_layers, axis=0)
+                init_state_c_list = paddle.split(
+                    init_state_c_box, self.num_rnn_layers, axis=0)
                 init_state_list = [(init_state_h_list[i], init_state_c_list[i])
                                    for i in range(self.num_rnn_layers)]
         else:
@@ -179,7 +134,7 @@ class CRNNEncoder(nn.Layer):
         if self.use_gru == True:
             final_chunk_state_h_box = paddle.concat(
                 final_chunk_state_list, axis=0)
-            final_chunk_state_c_box = paddle.zeros_like(final_chunk_state_h_box)
+            final_chunk_state_c_box = init_state_c_box  #paddle.zeros_like(final_chunk_state_h_box)
         else:
             final_chunk_state_h_list = [
                 final_chunk_state_list[i][0] for i in range(self.num_rnn_layers)
@@ -242,13 +197,13 @@ class CRNNEncoder(nn.Layer):
             x_chunk_lens = paddle.where(x_len_left < x_chunk_len_tmp,
                                         x_len_left, x_chunk_len_tmp)

-            eouts_chunk, eouts_chunk_lens, chunk_state_h_box, chunk_state_c_box = self.forward_chunk(
+            eouts_chunk, eouts_chunk_lens, chunk_state_h_box, chunk_state_c_box = self.forward(
                 x_chunk, x_chunk_lens, chunk_state_h_box, chunk_state_c_box)

             eouts_chunk_list.append(eouts_chunk)
             eouts_chunk_lens_list.append(eouts_chunk_lens)

         final_state_h_box = chunk_state_h_box
         final_state_c_box = chunk_state_c_box
         return eouts_chunk_list, eouts_chunk_lens_list, final_state_h_box, final_state_c_box
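
After this refactor the CRNNEncoder exposes a single forward for both the full-utterance and the streaming path; a minimal sketch of the two call patterns, where encoder and the chunk iterator are assumptions for illustration:

    # Full-utterance path (training / offline decoding): states default to None.
    eouts, eouts_len, h_box, c_box = encoder(x, x_lens)

    # Streaming path: feed chunks and carry the RNN states across calls.
    h_box, c_box = None, None
    for x_chunk, x_chunk_lens in chunk_iter:     # hypothetical chunk iterator
        eouts_chunk, eouts_chunk_lens, h_box, c_box = encoder(
            x_chunk, x_chunk_lens, h_box, c_box)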
@@ -297,7 +252,7 @@ class DeepSpeech2ModelOnline(nn.Layer):
                  feat_size,
                  dict_size,
                  num_conv_layers=2,
-                 num_rnn_layers=3,
+                 num_rnn_layers=4,
                  rnn_size=1024,
                  rnn_direction='forward',
                  num_fc_layers=2,
@@ -337,7 +292,7 @@ class DeepSpeech2ModelOnline(nn.Layer):
             loss (Tenosr): [1]
         """
         eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder(
-            audio, audio_len)
+            audio, audio_len, None, None)
         loss = self.decoder(eouts, eouts_len, text, text_len)
         return loss
@@ -355,7 +310,7 @@ class DeepSpeech2ModelOnline(nn.Layer):
             decoding_method=decoding_method)
         eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder(
-            audio, audio_len)
+            audio, audio_len, None, None)
         probs = self.decoder.softmax(eouts)
         return self.decoder.decode_probs(
             probs.numpy(), eouts_len, vocab_list, decoding_method,
@@ -401,7 +356,7 @@ class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
                  feat_size,
                  dict_size,
                  num_conv_layers=2,
-                 num_rnn_layers=3,
+                 num_rnn_layers=4,
                  rnn_size=1024,
                  rnn_direction='forward',
                  num_fc_layers=2,
@@ -420,18 +375,18 @@ class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
     def forward(self, audio_chunk, audio_chunk_lens, chunk_state_h_box,
                 chunk_state_c_box):
-        eouts_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box = self.encoder.forward_chunk(
+        eouts_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box = self.encoder(
             audio_chunk, audio_chunk_lens, chunk_state_h_box, chunk_state_c_box)
         probs_chunk = self.decoder.softmax(eouts_chunk)
         return probs_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box

-    @classmethod
-    def export(self, infer_model, feat_dim):
+    def export(self):
         static_model = paddle.jit.to_static(
-            infer_model,
+            self,
             input_spec=[
                 paddle.static.InputSpec(
-                    shape=[None, None, feat_dim],  #[B, chunk_size, feat_dim]
+                    shape=[None, None, self.encoder.feat_size
+                           ],  #[B, chunk_size, feat_dim]
                     dtype='float32'),  # audio, [B,T,D]
                 paddle.static.InputSpec(shape=[None],
                                         dtype='int64'),  # audio_length, [B]
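
Because the exported forward takes the running RNN states as explicit inputs and returns the updated states, chunk-by-chunk decoding becomes a plain loop. A minimal sketch against the dygraph model, where the chunk size, feature tensor, and concatenation step are assumptions rather than part of this commit:

    import paddle

    # model: a DeepSpeech2InferModelOnline in eval mode (hypothetical setup)
    # audio: [1, T, feat_size] feature tensor; audio_len: total frame count
    chunk_size = 16                      # frames per chunk, illustrative only
    h_box, c_box = None, None
    all_probs = []
    for start in range(0, int(audio_len), chunk_size):
        chunk = audio[:, start:start + chunk_size, :]
        chunk_len = paddle.to_tensor([chunk.shape[1]], dtype='int64')
        probs_chunk, probs_len, h_box, c_box = model(chunk, chunk_len, h_box, c_box)
        all_probs.append(probs_chunk)
    probs = paddle.concat(all_probs, axis=1)     # [1, T_out, vocab] posteriors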
