From 718ae52e3ff3204ab02a3f45852ec47897873742 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Tue, 10 Aug 2021 09:28:23 +0000 Subject: [PATCH] add from_config function to ds2_oneline and ds2 --- deepspeech/exps/deepspeech2/model.py | 135 ++++-------------- deepspeech/models/ds2/deepspeech2.py | 33 +++++ deepspeech/models/ds2_online/deepspeech2.py | 44 ++++-- .../aishell/s0/conf/deepspeech2_online.yaml | 12 +- tests/deepspeech2_online_model_test.py | 119 ++++++--------- 5 files changed, 141 insertions(+), 202 deletions(-) diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 03fe8c6f..dfd81241 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Contains DeepSpeech2 model.""" +"""Contains DeepSpeech2 and DeepSpeech2Online model.""" import time from collections import defaultdict from pathlib import Path @@ -38,8 +38,6 @@ from deepspeech.utils import layer_tools from deepspeech.utils import mp_tools from deepspeech.utils.log import Autolog from deepspeech.utils.log import Log -#from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline -#from deepspeech.models.ds2_online import DeepSpeech2ModelOnline logger = Log(__name__).getlog() @@ -123,40 +121,20 @@ class DeepSpeech2Trainer(Trainer): return total_loss, num_seen_utts def setup_model(self): - config = self.config - if hasattr(self, "train_loader"): - config.defrost() - config.model.feat_size = self.train_loader.collate_fn.feature_size - config.model.dict_size = self.train_loader.collate_fn.vocab_size - config.freeze() - elif hasattr(self, "test_loader"): - config.defrost() - config.model.feat_size = self.test_loader.collate_fn.feature_size - config.model.dict_size = self.test_loader.collate_fn.vocab_size - config.freeze() - else: - raise Exception("Please setup the dataloader first") + config = self.config.clone() + config.defrost() + assert (self.train_loader.collate_fn.feature_size == + self.test_loader.collate_fn.feature_size) + assert (self.train_loader.collate_fn.vocab_size == + self.test_loader.collate_fn.vocab_size) + config.model.feat_size = self.train_loader.collate_fn.feature_size + config.model.dict_size = self.train_loader.collate_fn.vocab_size + config.freeze() if self.args.model_type == 'offline': - model = DeepSpeech2Model( - feat_size=config.model.feat_size, - dict_size=config.model.dict_size, - num_conv_layers=config.model.num_conv_layers, - num_rnn_layers=config.model.num_rnn_layers, - rnn_size=config.model.rnn_layer_size, - use_gru=config.model.use_gru, - share_rnn_weights=config.model.share_rnn_weights) + model = DeepSpeech2Model.from_config(config.model) elif self.args.model_type == 'online': - model = DeepSpeech2ModelOnline( - feat_size=config.model.feat_size, - dict_size=config.model.dict_size, - num_conv_layers=config.model.num_conv_layers, - num_rnn_layers=config.model.num_rnn_layers, - rnn_size=config.model.rnn_layer_size, - rnn_direction=config.model.rnn_direction, - num_fc_layers=config.model.num_fc_layers, - fc_layers_size_list=config.model.fc_layers_size_list, - use_gru=config.model.use_gru) + model = DeepSpeech2ModelOnline.from_config(config.model) else: raise Exception("wrong model type") if self.parallel: @@ -194,6 +172,9 @@ class DeepSpeech2Trainer(Trainer): config.data.manifest = config.data.dev_manifest dev_dataset = ManifestDataset.from_config(config) + config.data.manifest = config.data.test_manifest + test_dataset = ManifestDataset.from_config(config) + if self.parallel: batch_sampler = SortagradDistributedBatchSampler( train_dataset, @@ -217,19 +198,29 @@ class DeepSpeech2Trainer(Trainer): config.collator.augmentation_config = "" collate_fn_dev = SpeechCollator.from_config(config) + + config.collator.keep_transcription_text = True + config.collator.augmentation_config = "" + collate_fn_test = SpeechCollator.from_config(config) + self.train_loader = DataLoader( train_dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_train, num_workers=config.collator.num_workers) - print("feature_size", self.train_loader.collate_fn.feature_size) self.valid_loader = DataLoader( dev_dataset, batch_size=config.collator.batch_size, shuffle=False, drop_last=False, collate_fn=collate_fn_dev) - logger.info("Setup train/valid Dataloader!") + self.test_loader = DataLoader( + test_dataset, + batch_size=config.decoding.batch_size, + shuffle=False, + drop_last=False, + collate_fn=collate_fn_test) + logger.info("Setup train/valid/test Dataloader!") class DeepSpeech2Tester(DeepSpeech2Trainer): @@ -371,20 +362,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): infer_model.eval() feat_dim = self.test_loader.collate_fn.feature_size - if self.args.model_type == 'offline': - static_model = paddle.jit.to_static( - infer_model, - input_spec=[ - paddle.static.InputSpec( - shape=[None, None, feat_dim], - dtype='float32'), # audio, [B,T,D] - paddle.static.InputSpec(shape=[None], - dtype='int64'), # audio_length, [B] - ]) - elif self.args.model_type == 'online': - static_model = infer_model.export() - else: - raise Exception("wrong model type") + static_model = infer_model.export() logger.info(f"Export code: {static_model.forward.code}") paddle.jit.save(static_model, self.args.export_path) @@ -408,63 +386,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): self.iteration = 0 self.epoch = 0 - ''' - def setup_model(self): - config = self.config - if self.args.model_type == 'offline': - model = DeepSpeech2Model( - feat_size=self.test_loader.collate_fn.feature_size, - dict_size=self.test_loader.collate_fn.vocab_size, - num_conv_layers=config.model.num_conv_layers, - num_rnn_layers=config.model.num_rnn_layers, - rnn_size=config.model.rnn_layer_size, - use_gru=config.model.use_gru, - share_rnn_weights=config.model.share_rnn_weights) - elif self.args.model_type == 'online': - model = DeepSpeech2ModelOnline( - feat_size=self.test_loader.collate_fn.feature_size, - dict_size=self.test_loader.collate_fn.vocab_size, - num_conv_layers=config.model.num_conv_layers, - num_rnn_layers=config.model.num_rnn_layers, - rnn_size=config.model.rnn_layer_size, - rnn_direction=config.model.rnn_direction, - num_fc_layers=config.model.num_fc_layers, - fc_layers_size_list=config.model.fc_layers_size_list, - use_gru=config.model.use_gru) - else: - raise Exception("Wrong model type") - - self.model = model - logger.info("Setup model!") - ''' - - def setup_dataloader(self): - config = self.config.clone() - config.defrost() - # return raw text - - config.data.manifest = config.data.test_manifest - # filter test examples, will cause less examples, but no mismatch with training - # and can use large batch size , save training time, so filter test egs now. - # config.data.min_input_len = 0.0 # second - # config.data.max_input_len = float('inf') # second - # config.data.min_output_len = 0.0 # tokens - # config.data.max_output_len = float('inf') # tokens - # config.data.min_output_input_ratio = 0.00 - # config.data.max_output_input_ratio = float('inf') - test_dataset = ManifestDataset.from_config(config) - - config.collator.keep_transcription_text = True - config.collator.augmentation_config = "" - # return text ord id - self.test_loader = DataLoader( - test_dataset, - batch_size=config.decoding.batch_size, - shuffle=False, - drop_last=False, - collate_fn=SpeechCollator.from_config(config)) - logger.info("Setup test Dataloader!") - def setup_output_dir(self): """Create a directory used for output. """ diff --git a/deepspeech/models/ds2/deepspeech2.py b/deepspeech/models/ds2/deepspeech2.py index 8d737e80..1ffd797b 100644 --- a/deepspeech/models/ds2/deepspeech2.py +++ b/deepspeech/models/ds2/deepspeech2.py @@ -228,6 +228,27 @@ class DeepSpeech2Model(nn.Layer): layer_tools.summary(model) return model + @classmethod + def from_config(cls, config): + """Build a DeepSpeec2Model from config + Parameters + + config: yacs.config.CfgNode + config.model + Returns + ------- + DeepSpeech2Model + The model built from config. + """ + model = cls(feat_size=config.feat_size, + dict_size=config.dict_size, + num_conv_layers=config.num_conv_layers, + num_rnn_layers=config.num_rnn_layers, + rnn_size=config.rnn_layer_size, + use_gru=config.use_gru, + share_rnn_weights=config.share_rnn_weights) + return model + class DeepSpeech2InferModel(DeepSpeech2Model): def __init__(self, @@ -260,3 +281,15 @@ class DeepSpeech2InferModel(DeepSpeech2Model): eouts, eouts_len = self.encoder(audio, audio_len) probs = self.decoder.softmax(eouts) return probs + + def export(self): + static_model = paddle.jit.to_static( + self, + input_spec=[ + paddle.static.InputSpec( + shape=[None, None, self.encoder.feat_size], + dtype='float32'), # audio, [B,T,D] + paddle.static.InputSpec(shape=[None], + dtype='int64'), # audio_length, [B] + ]) + return static_model diff --git a/deepspeech/models/ds2_online/deepspeech2.py b/deepspeech/models/ds2_online/deepspeech2.py index 75a6f044..3083e4b2 100644 --- a/deepspeech/models/ds2_online/deepspeech2.py +++ b/deepspeech/models/ds2_online/deepspeech2.py @@ -51,8 +51,9 @@ class CRNNEncoder(nn.Layer): self.use_gru = use_gru self.conv = Conv2dSubsampling4Online(feat_size, 32, dropout_rate=0.0) - i_size = self.conv.output_dim + self.output_dim = self.conv.output_dim + i_size = self.conv.output_dim self.rnn = nn.LayerList() self.layernorm_list = nn.LayerList() self.fc_layers_list = nn.LayerList() @@ -82,16 +83,18 @@ class CRNNEncoder(nn.Layer): num_layers=1, direction=rnn_direction)) self.layernorm_list.append(nn.LayerNorm(layernorm_size)) + self.output_dim = layernorm_size fc_input_size = layernorm_size for i in range(self.num_fc_layers): self.fc_layers_list.append( nn.Linear(fc_input_size, fc_layers_size_list[i])) fc_input_size = fc_layers_size_list[i] + self.output_dim = fc_layers_size_list[i] @property def output_size(self): - return self.fc_layers_size_list[-1] + return self.output_dim def forward(self, x, x_lens, init_state_h_box=None, init_state_c_box=None): """Compute Encoder outputs @@ -190,9 +193,6 @@ class CRNNEncoder(nn.Layer): for i in range(0, num_chunk): start = i * chunk_stride end = start + chunk_size - # end = min(start + chunk_size, max_len) - # if (end - start < receptive_field_length): - # break x_chunk = padded_x[:, start:end, :] x_len_left = paddle.where(x_lens - i * chunk_stride < 0, @@ -221,8 +221,6 @@ class DeepSpeech2ModelOnline(nn.Layer): :type text_data: Variable :param audio_len: Valid sequence length data layer. :type audio_len: Variable - :param masks: Masks data layer to reset padding. - :type masks: Variable :param dict_size: Dictionary size for tokenized transcription. :type dict_size: int :param num_conv_layers: Number of stacking convolution layers. @@ -231,6 +229,10 @@ class DeepSpeech2ModelOnline(nn.Layer): :type num_rnn_layers: int :param rnn_size: RNN layer size (dimension of RNN cells). :type rnn_size: int + :param num_fc_layers: Number of stacking FC layers. + :type num_fc_layers: int + :param fc_layers_size_list: The list of FC layer sizes. + :type fc_layers_size_list: [int,] :param use_gru: Use gru if set True. Use simple rnn if set False. :type use_gru: bool :return: A tuple of an output unnormalized log probability layer ( @@ -274,7 +276,6 @@ class DeepSpeech2ModelOnline(nn.Layer): fc_layers_size_list=fc_layers_size_list, rnn_size=rnn_size, use_gru=use_gru) - assert (self.encoder.output_size == fc_layers_size_list[-1]) self.decoder = CTCDecoder( odim=dict_size, # is in vocab @@ -337,7 +338,7 @@ class DeepSpeech2ModelOnline(nn.Layer): Returns ------- - DeepSpeech2Model + DeepSpeech2ModelOnline The model built from pretrained result. """ model = cls(feat_size=dataloader.collate_fn.feature_size, @@ -355,6 +356,29 @@ class DeepSpeech2ModelOnline(nn.Layer): layer_tools.summary(model) return model + @classmethod + def from_config(cls, config): + """Build a DeepSpeec2ModelOnline from config + Parameters + + config: yacs.config.CfgNode + config.model + Returns + ------- + DeepSpeech2ModelOnline + The model built from config. + """ + model = cls(feat_size=config.feat_size, + dict_size=config.dict_size, + num_conv_layers=config.num_conv_layers, + num_rnn_layers=config.num_rnn_layers, + rnn_size=config.rnn_layer_size, + rnn_direction=config.rnn_direction, + num_fc_layers=config.num_fc_layers, + fc_layers_size_list=config.fc_layers_size_list, + use_gru=config.use_gru) + return model + class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline): def __init__(self, @@ -392,7 +416,7 @@ class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline): paddle.static.InputSpec( shape=[None, None, self.encoder.feat_size], #[B, chunk_size, feat_dim] - dtype='float32'), # audio, [B,T,D] + dtype='float32'), paddle.static.InputSpec(shape=[None], dtype='int64'), # audio_length, [B] paddle.static.InputSpec( diff --git a/examples/aishell/s0/conf/deepspeech2_online.yaml b/examples/aishell/s0/conf/deepspeech2_online.yaml index 60df8d17..33030a52 100644 --- a/examples/aishell/s0/conf/deepspeech2_online.yaml +++ b/examples/aishell/s0/conf/deepspeech2_online.yaml @@ -36,17 +36,17 @@ collator: model: num_conv_layers: 2 - num_rnn_layers: 4 + num_rnn_layers: 3 rnn_layer_size: 1024 - rnn_direction: bidirect - num_fc_layers: 2 - fc_layers_size_list: 512, 256 + rnn_direction: forward # [forward, bidirect] + num_fc_layers: 1 + fc_layers_size_list: 512, use_gru: True training: n_epoch: 50 lr: 2e-3 - lr_decay: 0.83 + lr_decay: 0.83 # 0.83 weight_decay: 1e-06 global_grad_clip: 3.0 log_interval: 100 @@ -55,7 +55,7 @@ training: latest_n: 5 decoding: - batch_size: 64 + batch_size: 32 error_rate_type: cer decoding_method: ctc_beam_search lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm diff --git a/tests/deepspeech2_online_model_test.py b/tests/deepspeech2_online_model_test.py index fd1dfc4b..87f04887 100644 --- a/tests/deepspeech2_online_model_test.py +++ b/tests/deepspeech2_online_model_test.py @@ -106,18 +106,34 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase): self.assertEqual(loss.numel(), 1) def test_ds2_6(self): + model = DeepSpeech2ModelOnline( + feat_size=self.feat_dim, + dict_size=10, + num_conv_layers=2, + num_rnn_layers=3, + rnn_size=1024, + rnn_direction='bidirect', + num_fc_layers=2, + fc_layers_size_list=[512, 256], + use_gru=False) + loss = model(self.audio, self.audio_len, self.text, self.text_len) + self.assertEqual(loss.numel(), 1) + + def test_ds2_7(self): + use_gru = False model = DeepSpeech2ModelOnline( feat_size=self.feat_dim, dict_size=10, num_conv_layers=2, num_rnn_layers=1, rnn_size=1024, + rnn_direction='forward', num_fc_layers=2, fc_layers_size_list=[512, 256], - use_gru=True) + use_gru=use_gru) model.eval() paddle.device.set_device("cpu") - de_ch_size = 9 + de_ch_size = 8 eouts, eouts_lens, final_state_h_box, final_state_c_box = model.encoder( self.audio, self.audio_len) @@ -126,99 +142,44 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase): eouts_by_chk = paddle.concat(eouts_by_chk_list, axis=1) eouts_lens_by_chk = paddle.add_n(eouts_lens_by_chk_list) decode_max_len = eouts.shape[1] - print("dml", decode_max_len) eouts_by_chk = eouts_by_chk[:, :decode_max_len, :] - self.assertEqual( - paddle.sum( - paddle.abs(paddle.subtract(eouts_lens, eouts_lens_by_chk))), 0) - self.assertEqual( - paddle.sum(paddle.abs(paddle.subtract(eouts, eouts_by_chk))), 0) self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True) self.assertEqual( paddle.allclose(final_state_h_box, final_state_h_box_chk), True) - self.assertEqual( - paddle.allclose(final_state_c_box, final_state_c_box_chk), True) - """ - print ("conv_x", conv_x) - print ("conv_x_by_chk", conv_x_by_chk) - print ("final_state_list", final_state_list) - #print ("final_state_list_by_chk", final_state_list_by_chk) - print (paddle.sum(paddle.abs(paddle.subtract(eouts[:,:de_ch_size,:], eouts_by_chk[:,:de_ch_size,:])))) - print (paddle.allclose(eouts[:,:de_ch_size,:], eouts_by_chk[:,:de_ch_size,:])) - print (paddle.sum(paddle.abs(paddle.subtract(eouts[:,de_ch_size:de_ch_size*2,:], eouts_by_chk[:,de_ch_size:de_ch_size*2,:])))) - print (paddle.allclose(eouts[:,de_ch_size:de_ch_size*2,:], eouts_by_chk[:,de_ch_size:de_ch_size*2,:])) - print (paddle.sum(paddle.abs(paddle.subtract(eouts[:,de_ch_size*2:de_ch_size*3,:], eouts_by_chk[:,de_ch_size*2:de_ch_size*3,:])))) - print (paddle.allclose(eouts[:,de_ch_size*2:de_ch_size*3,:], eouts_by_chk[:,de_ch_size*2:de_ch_size*3,:])) - print (paddle.sum(paddle.abs(paddle.subtract(eouts, eouts_by_chk)))) - print (paddle.sum(paddle.abs(paddle.subtract(eouts, eouts_by_chk)))) - print (paddle.allclose(eouts[:,:,:], eouts_by_chk[:,:,:])) - """ - - """ - def split_into_chunk(self, x, x_lens, decoder_chunk_size, subsampling_rate, - receptive_field_length): - chunk_size = (decoder_chunk_size - 1 - ) * subsampling_rate + receptive_field_length - chunk_stride = subsampling_rate * decoder_chunk_size - max_len = x.shape[1] - assert (chunk_size <= max_len) - x_chunk_list = [] - x_chunk_lens_list = [] - padding_len = chunk_stride - (max_len - chunk_size) % chunk_stride - padding = paddle.zeros((x.shape[0], padding_len, x.shape[2])) - padded_x = paddle.concat([x, padding], axis=1) - num_chunk = (max_len + padding_len - chunk_size) / chunk_stride + 1 - num_chunk = int(num_chunk) - for i in range(0, num_chunk): - start = i * chunk_stride - end = start + chunk_size - x_chunk = padded_x[:, start:end, :] - x_len_left = paddle.where(x_lens - i * chunk_stride < 0, - paddle.zeros_like(x_lens), - x_lens - i * chunk_stride) - x_chunk_len_tmp = paddle.ones_like(x_lens) * chunk_size - x_chunk_lens = paddle.where(x_len_left < x_chunk_len_tmp, - x_len_left, x_chunk_len_tmp) - x_chunk_list.append(x_chunk) - x_chunk_lens_list.append(x_chunk_lens) - - return x_chunk_list, x_chunk_lens_list + if use_gru == False: + self.assertEqual( + paddle.allclose(final_state_c_box, final_state_c_box_chk), True) - def test_ds2_7(self): + def test_ds2_8(self): + use_gru = True model = DeepSpeech2ModelOnline( feat_size=self.feat_dim, dict_size=10, num_conv_layers=2, num_rnn_layers=1, rnn_size=1024, + rnn_direction='forward', num_fc_layers=2, fc_layers_size_list=[512, 256], - use_gru=True) + use_gru=use_gru) model.eval() paddle.device.set_device("cpu") - de_ch_size = 9 - - audio_chunk_list, audio_chunk_lens_list = self.split_into_chunk( - self.audio, self.audio_len, de_ch_size, - model.encoder.conv.subsampling_rate, - model.encoder.conv.receptive_field_length) - eouts_prefix = None - eouts_lens_prefix = None - chunk_state_list = [None] * model.encoder.num_rnn_layers - for i, audio_chunk in enumerate(audio_chunk_list): - audio_chunk_lens = audio_chunk_lens_list[i] - eouts_prefix, eouts_lens_prefix, chunk_state_list = model.decode_prob_by_chunk( - audio_chunk, audio_chunk_lens, eouts_prefix, eouts_lens_prefix, - chunk_state_list) - # print (i, probs_pre_chunks.shape) - - probs, eouts, eouts_lens, final_state_list = model.decode_prob( - self.audio, self.audio_len) + de_ch_size = 8 - decode_max_len = probs.shape[1] - probs_pre_chunks = probs_pre_chunks[:, :decode_max_len, :] - self.assertEqual(paddle.allclose(probs, probs_pre_chunks), True) - """ + eouts, eouts_lens, final_state_h_box, final_state_c_box = model.encoder( + self.audio, self.audio_len) + eouts_by_chk_list, eouts_lens_by_chk_list, final_state_h_box_chk, final_state_c_box_chk = model.encoder.forward_chunk_by_chunk( + self.audio, self.audio_len, de_ch_size) + eouts_by_chk = paddle.concat(eouts_by_chk_list, axis=1) + eouts_lens_by_chk = paddle.add_n(eouts_lens_by_chk_list) + decode_max_len = eouts.shape[1] + eouts_by_chk = eouts_by_chk[:, :decode_max_len, :] + self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True) + self.assertEqual( + paddle.allclose(final_state_h_box, final_state_h_box_chk), True) + if use_gru == False: + self.assertEqual( + paddle.allclose(final_state_c_box, final_state_c_box_chk), True) if __name__ == '__main__':