|
|
|
@ -29,8 +29,8 @@ from deepspeech.io.sampler import SortagradBatchSampler
|
|
|
|
|
from deepspeech.io.sampler import SortagradDistributedBatchSampler
|
|
|
|
|
from deepspeech.models.ds2 import DeepSpeech2InferModel
|
|
|
|
|
from deepspeech.models.ds2 import DeepSpeech2Model
|
|
|
|
|
#from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline
|
|
|
|
|
#from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
|
|
|
|
|
from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline
|
|
|
|
|
from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
|
|
|
|
|
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
|
|
|
|
|
from deepspeech.training.trainer import Trainer
|
|
|
|
|
from deepspeech.utils import error_rate
|
|
|
|
@ -122,6 +122,7 @@ class DeepSpeech2Trainer(Trainer):
|
|
|
|
|
|
|
|
|
|
def setup_model(self):
|
|
|
|
|
config = self.config
|
|
|
|
|
if config.model.apply_online == True:
|
|
|
|
|
model = DeepSpeech2Model(
|
|
|
|
|
feat_size=self.train_loader.collate_fn.feature_size,
|
|
|
|
|
dict_size=self.train_loader.collate_fn.vocab_size,
|
|
|
|
@ -130,6 +131,17 @@ class DeepSpeech2Trainer(Trainer):
|
|
|
|
|
rnn_size=config.model.rnn_layer_size,
|
|
|
|
|
use_gru=config.model.use_gru,
|
|
|
|
|
share_rnn_weights=config.model.share_rnn_weights)
|
|
|
|
|
else:
|
|
|
|
|
model = DeepSpeech2ModelOnline(
|
|
|
|
|
feat_size=self.train_loader.collate_fn.feature_size,
|
|
|
|
|
dict_size=self.train_loader.collate_fn.vocab_size,
|
|
|
|
|
num_conv_layers=config.model.num_conv_layers,
|
|
|
|
|
num_rnn_layers=config.model.num_rnn_layers,
|
|
|
|
|
num_fc_layers=config.model.num_fc_layers,
|
|
|
|
|
fc_layers_size_list=config.model.fc_layers_size_list,
|
|
|
|
|
rnn_size=config.model.rnn_layer_size,
|
|
|
|
|
use_gru=config.model.use_gru,
|
|
|
|
|
share_rnn_weights=config.model.share_rnn_weights)
|
|
|
|
|
|
|
|
|
|
if self.parallel:
|
|
|
|
|
model = paddle.DataParallel(model)
|
|
|
|
@ -331,6 +343,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
|
|
|
|
|
exit(-1)
|
|
|
|
|
|
|
|
|
|
def export(self):
|
|
|
|
|
if self.config.model.apply_online == True:
|
|
|
|
|
infer_model = DeepSpeech2InferModelOnline.from_pretrained(
|
|
|
|
|
self.test_loader, self.config, self.args.checkpoint_path)
|
|
|
|
|
else:
|
|
|
|
|
infer_model = DeepSpeech2InferModel.from_pretrained(
|
|
|
|
|
self.test_loader, self.config, self.args.checkpoint_path)
|
|
|
|
|
|
|
|
|
@ -370,6 +386,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
|
|
|
|
|
|
|
|
|
|
def setup_model(self):
|
|
|
|
|
config = self.config
|
|
|
|
|
if config.model.apply_online == True:
|
|
|
|
|
model = DeepSpeech2Model(
|
|
|
|
|
feat_size=self.test_loader.collate_fn.feature_size,
|
|
|
|
|
dict_size=self.test_loader.collate_fn.vocab_size,
|
|
|
|
@ -378,6 +395,17 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
|
|
|
|
|
rnn_size=config.model.rnn_layer_size,
|
|
|
|
|
use_gru=config.model.use_gru,
|
|
|
|
|
share_rnn_weights=config.model.share_rnn_weights)
|
|
|
|
|
else:
|
|
|
|
|
model = DeepSpeech2ModelOnline(
|
|
|
|
|
feat_size=self.train_loader.collate_fn.feature_size,
|
|
|
|
|
dict_size=self.train_loader.collate_fn.vocab_size,
|
|
|
|
|
num_conv_layers=config.model.num_conv_layers,
|
|
|
|
|
num_rnn_layers=config.model.num_rnn_layers,
|
|
|
|
|
num_fc_layers=config.model.num_fc_layers,
|
|
|
|
|
fc_layers_size_list=config.model.fc_layers_size_list,
|
|
|
|
|
rnn_size=config.model.rnn_layer_size,
|
|
|
|
|
use_gru=config.model.use_gru,
|
|
|
|
|
share_rnn_weights=config.model.share_rnn_weights)
|
|
|
|
|
|
|
|
|
|
self.model = model
|
|
|
|
|
logger.info("Setup model!")
|
|
|
|
|