|
|
|
@ -29,8 +29,8 @@ from deepspeech.io.sampler import SortagradBatchSampler
|
|
|
|
|
from deepspeech.io.sampler import SortagradDistributedBatchSampler
|
|
|
|
|
from deepspeech.models.ds2 import DeepSpeech2InferModel
|
|
|
|
|
from deepspeech.models.ds2 import DeepSpeech2Model
|
|
|
|
|
from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline
|
|
|
|
|
from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
|
|
|
|
|
#from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline
|
|
|
|
|
#from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
|
|
|
|
|
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
|
|
|
|
|
from deepspeech.training.trainer import Trainer
|
|
|
|
|
from deepspeech.utils import error_rate
|
|
|
|
@ -122,7 +122,6 @@ class DeepSpeech2Trainer(Trainer):
|
|
|
|
|
|
|
|
|
|
def setup_model(self):
|
|
|
|
|
config = self.config
|
|
|
|
|
if (config.model.apply_online == False):
|
|
|
|
|
model = DeepSpeech2Model(
|
|
|
|
|
feat_size=self.train_loader.collate_fn.feature_size,
|
|
|
|
|
dict_size=self.train_loader.collate_fn.vocab_size,
|
|
|
|
@ -131,15 +130,6 @@ class DeepSpeech2Trainer(Trainer):
|
|
|
|
|
rnn_size=config.model.rnn_layer_size,
|
|
|
|
|
use_gru=config.model.use_gru,
|
|
|
|
|
share_rnn_weights=config.model.share_rnn_weights)
|
|
|
|
|
else:
|
|
|
|
|
model = DeepSpeech2ModelOnline(
|
|
|
|
|
feat_size=self.train_loader.collate_fn.feature_size,
|
|
|
|
|
dict_size=self.train_loader.collate_fn.vocab_size,
|
|
|
|
|
num_conv_layers=config.model.num_conv_layers,
|
|
|
|
|
num_rnn_layers=config.model.num_rnn_layers,
|
|
|
|
|
rnn_size=config.model.rnn_layer_size,
|
|
|
|
|
use_gru=config.model.use_gru,
|
|
|
|
|
share_rnn_weights=config.model.share_rnn_weights)
|
|
|
|
|
|
|
|
|
|
if self.parallel:
|
|
|
|
|
model = paddle.DataParallel(model)
|
|
|
|
@ -384,7 +374,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
|
|
|
|
|
|
|
|
|
|
def setup_model(self):
|
|
|
|
|
config = self.config
|
|
|
|
|
if config.model.apply_online == False:
|
|
|
|
|
model = DeepSpeech2Model(
|
|
|
|
|
feat_size=self.test_loader.collate_fn.feature_size,
|
|
|
|
|
dict_size=self.test_loader.collate_fn.vocab_size,
|
|
|
|
@ -393,15 +382,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
|
|
|
|
|
rnn_size=config.model.rnn_layer_size,
|
|
|
|
|
use_gru=config.model.use_gru,
|
|
|
|
|
share_rnn_weights=config.model.share_rnn_weights)
|
|
|
|
|
else:
|
|
|
|
|
model = DeepSpeech2ModelOnline(
|
|
|
|
|
feat_size=self.test_loader.collate_fn.feature_size,
|
|
|
|
|
dict_size=self.test_loader.collate_fn.vocab_size,
|
|
|
|
|
num_conv_layers=config.model.num_conv_layers,
|
|
|
|
|
num_rnn_layers=config.model.num_rnn_layers,
|
|
|
|
|
rnn_size=config.model.rnn_layer_size,
|
|
|
|
|
use_gru=config.model.use_gru,
|
|
|
|
|
share_rnn_weights=config.model.share_rnn_weights)
|
|
|
|
|
|
|
|
|
|
self.model = model
|
|
|
|
|
logger.info("Setup model!")
|
|
|
|
|