diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 5c6c5fa18..91b7a1bfe 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -29,8 +29,8 @@ from deepspeech.io.sampler import SortagradBatchSampler from deepspeech.io.sampler import SortagradDistributedBatchSampler from deepspeech.models.ds2 import DeepSpeech2InferModel from deepspeech.models.ds2 import DeepSpeech2Model -from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline -from deepspeech.models.ds2_online import DeepSpeech2ModelOnline +#from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline +#from deepspeech.models.ds2_online import DeepSpeech2ModelOnline from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog from deepspeech.training.trainer import Trainer from deepspeech.utils import error_rate @@ -122,25 +122,15 @@ class DeepSpeech2Trainer(Trainer): def setup_model(self): config = self.config - if (config.model.apply_online == False): - model = DeepSpeech2Model( - feat_size=self.train_loader.collate_fn.feature_size, - dict_size=self.train_loader.collate_fn.vocab_size, - num_conv_layers=config.model.num_conv_layers, - num_rnn_layers=config.model.num_rnn_layers, - rnn_size=config.model.rnn_layer_size, - use_gru=config.model.use_gru, - share_rnn_weights=config.model.share_rnn_weights) - else: - model = DeepSpeech2ModelOnline( - feat_size=self.train_loader.collate_fn.feature_size, - dict_size=self.train_loader.collate_fn.vocab_size, - num_conv_layers=config.model.num_conv_layers, - num_rnn_layers=config.model.num_rnn_layers, - rnn_size=config.model.rnn_layer_size, - use_gru=config.model.use_gru, - share_rnn_weights=config.model.share_rnn_weights) - + model = DeepSpeech2Model( + feat_size=self.train_loader.collate_fn.feature_size, + dict_size=self.train_loader.collate_fn.vocab_size, + num_conv_layers=config.model.num_conv_layers, + num_rnn_layers=config.model.num_rnn_layers, + rnn_size=config.model.rnn_layer_size, + use_gru=config.model.use_gru, + share_rnn_weights=config.model.share_rnn_weights) + if self.parallel: model = paddle.DataParallel(model) @@ -347,7 +337,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): else: infer_model = DeepSpeech2InferModelOnline.from_pretrained( self.test_loader, self.config, self.args.checkpoint_path) - + infer_model.eval() feat_dim = self.test_loader.collate_fn.feature_size static_model = paddle.jit.to_static( @@ -384,25 +374,15 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): def setup_model(self): config = self.config - if config.model.apply_online == False: - model = DeepSpeech2Model( - feat_size=self.test_loader.collate_fn.feature_size, - dict_size=self.test_loader.collate_fn.vocab_size, - num_conv_layers=config.model.num_conv_layers, - num_rnn_layers=config.model.num_rnn_layers, - rnn_size=config.model.rnn_layer_size, - use_gru=config.model.use_gru, - share_rnn_weights=config.model.share_rnn_weights) - else: - model = DeepSpeech2ModelOnline( - feat_size=self.test_loader.collate_fn.feature_size, - dict_size=self.test_loader.collate_fn.vocab_size, - num_conv_layers=config.model.num_conv_layers, - num_rnn_layers=config.model.num_rnn_layers, - rnn_size=config.model.rnn_layer_size, - use_gru=config.model.use_gru, - share_rnn_weights=config.model.share_rnn_weights) - + model = DeepSpeech2Model( + feat_size=self.test_loader.collate_fn.feature_size, + dict_size=self.test_loader.collate_fn.vocab_size, + num_conv_layers=config.model.num_conv_layers, + num_rnn_layers=config.model.num_rnn_layers, + rnn_size=config.model.rnn_layer_size, + use_gru=config.model.use_gru, + share_rnn_weights=config.model.share_rnn_weights) + self.model = model logger.info("Setup model!")