跑通了deepspeech_online的流程

pull/735/head
huangyuxin 4 years ago
parent 2c8d28111a
commit 6baf9f0620

@ -29,8 +29,8 @@ from deepspeech.io.sampler import SortagradBatchSampler
from deepspeech.io.sampler import SortagradDistributedBatchSampler
from deepspeech.models.ds2 import DeepSpeech2InferModel
from deepspeech.models.ds2 import DeepSpeech2Model
#from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline
#from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline
from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
from deepspeech.training.trainer import Trainer
from deepspeech.utils import error_rate
@ -122,6 +122,7 @@ class DeepSpeech2Trainer(Trainer):
def setup_model(self):
config = self.config
if config.model.apply_online == True:
model = DeepSpeech2Model(
feat_size=self.train_loader.collate_fn.feature_size,
dict_size=self.train_loader.collate_fn.vocab_size,
@ -130,6 +131,17 @@ class DeepSpeech2Trainer(Trainer):
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
else:
model = DeepSpeech2ModelOnline(
feat_size=self.train_loader.collate_fn.feature_size,
dict_size=self.train_loader.collate_fn.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
num_fc_layers=config.model.num_fc_layers,
fc_layers_size_list=config.model.fc_layers_size_list,
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
if self.parallel:
model = paddle.DataParallel(model)
@ -331,6 +343,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
exit(-1)
def export(self):
if self.config.model.apply_online == True:
infer_model = DeepSpeech2InferModelOnline.from_pretrained(
self.test_loader, self.config, self.args.checkpoint_path)
else:
infer_model = DeepSpeech2InferModel.from_pretrained(
self.test_loader, self.config, self.args.checkpoint_path)
@ -370,6 +386,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
def setup_model(self):
config = self.config
if config.model.apply_online == True:
model = DeepSpeech2Model(
feat_size=self.test_loader.collate_fn.feature_size,
dict_size=self.test_loader.collate_fn.vocab_size,
@ -378,6 +395,17 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
else:
model = DeepSpeech2ModelOnline(
feat_size=self.train_loader.collate_fn.feature_size,
dict_size=self.train_loader.collate_fn.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
num_fc_layers=config.model.num_fc_layers,
fc_layers_size_list=config.model.fc_layers_size_list,
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
self.model = model
logger.info("Setup model!")

@ -19,8 +19,8 @@ from paddle import nn
from yacs.config import CfgNode
from deepspeech.models.ds2.conv import ConvStack
from deepspeech.modules.ctc import CTCDecoder
from deepspeech.models.ds2.rnn import RNNStack
from deepspeech.modules.ctc import CTCDecoder
from deepspeech.utils import layer_tools
from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log

@ -1,7 +1,17 @@
from .deepspeech2 import DeepSpeech2ModelOnline
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .deepspeech2 import DeepSpeech2InferModelOnline
from .deepspeech2 import DeepSpeech2ModelOnline
__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline']

@ -16,8 +16,10 @@ import unittest
import numpy as np
import paddle
#from deepspeech.models.deepspeech2 import DeepSpeech2Model
from deepspeech.models.ds2_online import DeepSpeech2ModelOnline as DeepSpeech2Model
from deepspeech.models.deepspeech2 import DeepSpeech2Model
from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
class TestDeepSpeech2Model(unittest.TestCase):
def setUp(self):

Loading…
Cancel
Save