先不暴露出online

pull/735/head
huangyuxin 4 years ago
parent 6079a2495d
commit 5dd9e2f8ec

@ -29,8 +29,8 @@ from deepspeech.io.sampler import SortagradBatchSampler
from deepspeech.io.sampler import SortagradDistributedBatchSampler from deepspeech.io.sampler import SortagradDistributedBatchSampler
from deepspeech.models.ds2 import DeepSpeech2InferModel from deepspeech.models.ds2 import DeepSpeech2InferModel
from deepspeech.models.ds2 import DeepSpeech2Model from deepspeech.models.ds2 import DeepSpeech2Model
from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline #from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline
from deepspeech.models.ds2_online import DeepSpeech2ModelOnline #from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
from deepspeech.training.trainer import Trainer from deepspeech.training.trainer import Trainer
from deepspeech.utils import error_rate from deepspeech.utils import error_rate
@ -122,25 +122,15 @@ class DeepSpeech2Trainer(Trainer):
def setup_model(self): def setup_model(self):
config = self.config config = self.config
if (config.model.apply_online == False): model = DeepSpeech2Model(
model = DeepSpeech2Model( feat_size=self.train_loader.collate_fn.feature_size,
feat_size=self.train_loader.collate_fn.feature_size, dict_size=self.train_loader.collate_fn.vocab_size,
dict_size=self.train_loader.collate_fn.vocab_size, num_conv_layers=config.model.num_conv_layers,
num_conv_layers=config.model.num_conv_layers, num_rnn_layers=config.model.num_rnn_layers,
num_rnn_layers=config.model.num_rnn_layers, rnn_size=config.model.rnn_layer_size,
rnn_size=config.model.rnn_layer_size, use_gru=config.model.use_gru,
use_gru=config.model.use_gru, share_rnn_weights=config.model.share_rnn_weights)
share_rnn_weights=config.model.share_rnn_weights)
else:
model = DeepSpeech2ModelOnline(
feat_size=self.train_loader.collate_fn.feature_size,
dict_size=self.train_loader.collate_fn.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
if self.parallel: if self.parallel:
model = paddle.DataParallel(model) model = paddle.DataParallel(model)
@ -347,7 +337,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
else: else:
infer_model = DeepSpeech2InferModelOnline.from_pretrained( infer_model = DeepSpeech2InferModelOnline.from_pretrained(
self.test_loader, self.config, self.args.checkpoint_path) self.test_loader, self.config, self.args.checkpoint_path)
infer_model.eval() infer_model.eval()
feat_dim = self.test_loader.collate_fn.feature_size feat_dim = self.test_loader.collate_fn.feature_size
static_model = paddle.jit.to_static( static_model = paddle.jit.to_static(
@ -384,25 +374,15 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
def setup_model(self): def setup_model(self):
config = self.config config = self.config
if config.model.apply_online == False: model = DeepSpeech2Model(
model = DeepSpeech2Model( feat_size=self.test_loader.collate_fn.feature_size,
feat_size=self.test_loader.collate_fn.feature_size, dict_size=self.test_loader.collate_fn.vocab_size,
dict_size=self.test_loader.collate_fn.vocab_size, num_conv_layers=config.model.num_conv_layers,
num_conv_layers=config.model.num_conv_layers, num_rnn_layers=config.model.num_rnn_layers,
num_rnn_layers=config.model.num_rnn_layers, rnn_size=config.model.rnn_layer_size,
rnn_size=config.model.rnn_layer_size, use_gru=config.model.use_gru,
use_gru=config.model.use_gru, share_rnn_weights=config.model.share_rnn_weights)
share_rnn_weights=config.model.share_rnn_weights)
else:
model = DeepSpeech2ModelOnline(
feat_size=self.test_loader.collate_fn.feature_size,
dict_size=self.test_loader.collate_fn.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
self.model = model self.model = model
logger.info("Setup model!") logger.info("Setup model!")

Loading…
Cancel
Save