diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index fab94ced8..9e870b13e 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -29,8 +29,8 @@ from deepspeech.io.sampler import SortagradBatchSampler from deepspeech.io.sampler import SortagradDistributedBatchSampler from deepspeech.models.ds2 import DeepSpeech2InferModel from deepspeech.models.ds2 import DeepSpeech2Model -#from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline -#from deepspeech.models.ds2_online import DeepSpeech2ModelOnline +from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline +from deepspeech.models.ds2_online import DeepSpeech2ModelOnline from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog from deepspeech.training.trainer import Trainer from deepspeech.utils import error_rate @@ -122,14 +122,26 @@ class DeepSpeech2Trainer(Trainer): def setup_model(self): config = self.config - model = DeepSpeech2Model( - feat_size=self.train_loader.collate_fn.feature_size, - dict_size=self.train_loader.collate_fn.vocab_size, - num_conv_layers=config.model.num_conv_layers, - num_rnn_layers=config.model.num_rnn_layers, - rnn_size=config.model.rnn_layer_size, - use_gru=config.model.use_gru, - share_rnn_weights=config.model.share_rnn_weights) + if config.model.apply_online == True: + model = DeepSpeech2Model( + feat_size=self.train_loader.collate_fn.feature_size, + dict_size=self.train_loader.collate_fn.vocab_size, + num_conv_layers=config.model.num_conv_layers, + num_rnn_layers=config.model.num_rnn_layers, + rnn_size=config.model.rnn_layer_size, + use_gru=config.model.use_gru, + share_rnn_weights=config.model.share_rnn_weights) + else: + model = DeepSpeech2ModelOnline( + feat_size=self.train_loader.collate_fn.feature_size, + dict_size=self.train_loader.collate_fn.vocab_size, + num_conv_layers=config.model.num_conv_layers, + num_rnn_layers=config.model.num_rnn_layers, + num_fc_layers=config.model.num_fc_layers, + fc_layers_size_list=config.model.fc_layers_size_list, + rnn_size=config.model.rnn_layer_size, + use_gru=config.model.use_gru, + share_rnn_weights=config.model.share_rnn_weights) if self.parallel: model = paddle.DataParallel(model) @@ -331,8 +343,12 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): exit(-1) def export(self): - infer_model = DeepSpeech2InferModel.from_pretrained( - self.test_loader, self.config, self.args.checkpoint_path) + if self.config.model.apply_online == True: + infer_model = DeepSpeech2InferModelOnline.from_pretrained( + self.test_loader, self.config, self.args.checkpoint_path) + else: + infer_model = DeepSpeech2InferModel.from_pretrained( + self.test_loader, self.config, self.args.checkpoint_path) infer_model.eval() feat_dim = self.test_loader.collate_fn.feature_size @@ -370,14 +386,26 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): def setup_model(self): config = self.config - model = DeepSpeech2Model( - feat_size=self.test_loader.collate_fn.feature_size, - dict_size=self.test_loader.collate_fn.vocab_size, - num_conv_layers=config.model.num_conv_layers, - num_rnn_layers=config.model.num_rnn_layers, - rnn_size=config.model.rnn_layer_size, - use_gru=config.model.use_gru, - share_rnn_weights=config.model.share_rnn_weights) + if config.model.apply_online == True: + model = DeepSpeech2Model( + feat_size=self.test_loader.collate_fn.feature_size, + dict_size=self.test_loader.collate_fn.vocab_size, + num_conv_layers=config.model.num_conv_layers, + num_rnn_layers=config.model.num_rnn_layers, + rnn_size=config.model.rnn_layer_size, + use_gru=config.model.use_gru, + share_rnn_weights=config.model.share_rnn_weights) + else: + model = DeepSpeech2ModelOnline( + feat_size=self.train_loader.collate_fn.feature_size, + dict_size=self.train_loader.collate_fn.vocab_size, + num_conv_layers=config.model.num_conv_layers, + num_rnn_layers=config.model.num_rnn_layers, + num_fc_layers=config.model.num_fc_layers, + fc_layers_size_list=config.model.fc_layers_size_list, + rnn_size=config.model.rnn_layer_size, + use_gru=config.model.use_gru, + share_rnn_weights=config.model.share_rnn_weights) self.model = model logger.info("Setup model!") diff --git a/deepspeech/models/ds2/deepspeech2.py b/deepspeech/models/ds2/deepspeech2.py index 4026c89a7..8d737e800 100644 --- a/deepspeech/models/ds2/deepspeech2.py +++ b/deepspeech/models/ds2/deepspeech2.py @@ -19,8 +19,8 @@ from paddle import nn from yacs.config import CfgNode from deepspeech.models.ds2.conv import ConvStack -from deepspeech.modules.ctc import CTCDecoder from deepspeech.models.ds2.rnn import RNNStack +from deepspeech.modules.ctc import CTCDecoder from deepspeech.utils import layer_tools from deepspeech.utils.checkpoint import Checkpoint from deepspeech.utils.log import Log diff --git a/deepspeech/models/ds2_online/__init__.py b/deepspeech/models/ds2_online/__init__.py index 88076667c..255000eeb 100644 --- a/deepspeech/models/ds2_online/__init__.py +++ b/deepspeech/models/ds2_online/__init__.py @@ -1,7 +1,17 @@ -from .deepspeech2 import DeepSpeech2ModelOnline +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from .deepspeech2 import DeepSpeech2InferModelOnline +from .deepspeech2 import DeepSpeech2ModelOnline __all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline'] - - - diff --git a/tests/deepspeech2_model_test.py b/tests/deepspeech2_model_test.py index bb40802d7..1938f7147 100644 --- a/tests/deepspeech2_model_test.py +++ b/tests/deepspeech2_model_test.py @@ -16,8 +16,10 @@ import unittest import numpy as np import paddle -#from deepspeech.models.deepspeech2 import DeepSpeech2Model -from deepspeech.models.ds2_online import DeepSpeech2ModelOnline as DeepSpeech2Model +from deepspeech.models.deepspeech2 import DeepSpeech2Model + +from deepspeech.models.ds2_online import DeepSpeech2ModelOnline + class TestDeepSpeech2Model(unittest.TestCase): def setUp(self):