跑通了deeppseech_online的流程

pull/735/head
huangyuxin 4 years ago
parent 2c8d28111a
commit 6baf9f0620

@ -29,8 +29,8 @@ from deepspeech.io.sampler import SortagradBatchSampler
from deepspeech.io.sampler import SortagradDistributedBatchSampler from deepspeech.io.sampler import SortagradDistributedBatchSampler
from deepspeech.models.ds2 import DeepSpeech2InferModel from deepspeech.models.ds2 import DeepSpeech2InferModel
from deepspeech.models.ds2 import DeepSpeech2Model from deepspeech.models.ds2 import DeepSpeech2Model
#from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline
#from deepspeech.models.ds2_online import DeepSpeech2ModelOnline from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
from deepspeech.training.trainer import Trainer from deepspeech.training.trainer import Trainer
from deepspeech.utils import error_rate from deepspeech.utils import error_rate
@ -122,6 +122,7 @@ class DeepSpeech2Trainer(Trainer):
def setup_model(self): def setup_model(self):
config = self.config config = self.config
if config.model.apply_online == True:
model = DeepSpeech2Model( model = DeepSpeech2Model(
feat_size=self.train_loader.collate_fn.feature_size, feat_size=self.train_loader.collate_fn.feature_size,
dict_size=self.train_loader.collate_fn.vocab_size, dict_size=self.train_loader.collate_fn.vocab_size,
@ -130,6 +131,17 @@ class DeepSpeech2Trainer(Trainer):
rnn_size=config.model.rnn_layer_size, rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru, use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights) share_rnn_weights=config.model.share_rnn_weights)
else:
model = DeepSpeech2ModelOnline(
feat_size=self.train_loader.collate_fn.feature_size,
dict_size=self.train_loader.collate_fn.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
num_fc_layers=config.model.num_fc_layers,
fc_layers_size_list=config.model.fc_layers_size_list,
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
if self.parallel: if self.parallel:
model = paddle.DataParallel(model) model = paddle.DataParallel(model)
@ -331,6 +343,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
exit(-1) exit(-1)
def export(self): def export(self):
if self.config.model.apply_online == True:
infer_model = DeepSpeech2InferModelOnline.from_pretrained(
self.test_loader, self.config, self.args.checkpoint_path)
else:
infer_model = DeepSpeech2InferModel.from_pretrained( infer_model = DeepSpeech2InferModel.from_pretrained(
self.test_loader, self.config, self.args.checkpoint_path) self.test_loader, self.config, self.args.checkpoint_path)
@ -370,6 +386,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
def setup_model(self): def setup_model(self):
config = self.config config = self.config
if config.model.apply_online == True:
model = DeepSpeech2Model( model = DeepSpeech2Model(
feat_size=self.test_loader.collate_fn.feature_size, feat_size=self.test_loader.collate_fn.feature_size,
dict_size=self.test_loader.collate_fn.vocab_size, dict_size=self.test_loader.collate_fn.vocab_size,
@ -378,6 +395,17 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
rnn_size=config.model.rnn_layer_size, rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru, use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights) share_rnn_weights=config.model.share_rnn_weights)
else:
model = DeepSpeech2ModelOnline(
feat_size=self.train_loader.collate_fn.feature_size,
dict_size=self.train_loader.collate_fn.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
num_fc_layers=config.model.num_fc_layers,
fc_layers_size_list=config.model.fc_layers_size_list,
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
self.model = model self.model = model
logger.info("Setup model!") logger.info("Setup model!")

@ -19,8 +19,8 @@ from paddle import nn
from yacs.config import CfgNode from yacs.config import CfgNode
from deepspeech.models.ds2.conv import ConvStack from deepspeech.models.ds2.conv import ConvStack
from deepspeech.modules.ctc import CTCDecoder
from deepspeech.models.ds2.rnn import RNNStack from deepspeech.models.ds2.rnn import RNNStack
from deepspeech.modules.ctc import CTCDecoder
from deepspeech.utils import layer_tools from deepspeech.utils import layer_tools
from deepspeech.utils.checkpoint import Checkpoint from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log from deepspeech.utils.log import Log

@ -1,7 +1,17 @@
from .deepspeech2 import DeepSpeech2ModelOnline # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .deepspeech2 import DeepSpeech2InferModelOnline from .deepspeech2 import DeepSpeech2InferModelOnline
from .deepspeech2 import DeepSpeech2ModelOnline
__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline'] __all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline']

@ -16,8 +16,10 @@ import unittest
import numpy as np import numpy as np
import paddle import paddle
#from deepspeech.models.deepspeech2 import DeepSpeech2Model from deepspeech.models.deepspeech2 import DeepSpeech2Model
from deepspeech.models.ds2_online import DeepSpeech2ModelOnline as DeepSpeech2Model
from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
class TestDeepSpeech2Model(unittest.TestCase): class TestDeepSpeech2Model(unittest.TestCase):
def setUp(self): def setUp(self):

Loading…
Cancel
Save