先不暴露出online

pull/735/head
huangyuxin 3 years ago
parent 6079a2495d
commit 5dd9e2f8ec

@ -29,8 +29,8 @@ from deepspeech.io.sampler import SortagradBatchSampler
from deepspeech.io.sampler import SortagradDistributedBatchSampler from deepspeech.io.sampler import SortagradDistributedBatchSampler
from deepspeech.models.ds2 import DeepSpeech2InferModel from deepspeech.models.ds2 import DeepSpeech2InferModel
from deepspeech.models.ds2 import DeepSpeech2Model from deepspeech.models.ds2 import DeepSpeech2Model
from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline #from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline
from deepspeech.models.ds2_online import DeepSpeech2ModelOnline #from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
from deepspeech.training.trainer import Trainer from deepspeech.training.trainer import Trainer
from deepspeech.utils import error_rate from deepspeech.utils import error_rate
@ -122,7 +122,6 @@ class DeepSpeech2Trainer(Trainer):
def setup_model(self): def setup_model(self):
config = self.config config = self.config
if (config.model.apply_online == False):
model = DeepSpeech2Model( model = DeepSpeech2Model(
feat_size=self.train_loader.collate_fn.feature_size, feat_size=self.train_loader.collate_fn.feature_size,
dict_size=self.train_loader.collate_fn.vocab_size, dict_size=self.train_loader.collate_fn.vocab_size,
@ -131,15 +130,6 @@ class DeepSpeech2Trainer(Trainer):
rnn_size=config.model.rnn_layer_size, rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru, use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights) share_rnn_weights=config.model.share_rnn_weights)
else:
model = DeepSpeech2ModelOnline(
feat_size=self.train_loader.collate_fn.feature_size,
dict_size=self.train_loader.collate_fn.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
if self.parallel: if self.parallel:
model = paddle.DataParallel(model) model = paddle.DataParallel(model)
@ -384,7 +374,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
def setup_model(self): def setup_model(self):
config = self.config config = self.config
if config.model.apply_online == False:
model = DeepSpeech2Model( model = DeepSpeech2Model(
feat_size=self.test_loader.collate_fn.feature_size, feat_size=self.test_loader.collate_fn.feature_size,
dict_size=self.test_loader.collate_fn.vocab_size, dict_size=self.test_loader.collate_fn.vocab_size,
@ -393,15 +382,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
rnn_size=config.model.rnn_layer_size, rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru, use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights) share_rnn_weights=config.model.share_rnn_weights)
else:
model = DeepSpeech2ModelOnline(
feat_size=self.test_loader.collate_fn.feature_size,
dict_size=self.test_loader.collate_fn.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
self.model = model self.model = model
logger.info("Setup model!") logger.info("Setup model!")

Loading…
Cancel
Save