|
|
|
@ -29,6 +29,8 @@ from deepspeech.io.sampler import SortagradBatchSampler
|
|
|
|
|
from deepspeech.io.sampler import SortagradDistributedBatchSampler
|
|
|
|
|
from deepspeech.models.ds2 import DeepSpeech2InferModel
|
|
|
|
|
from deepspeech.models.ds2 import DeepSpeech2Model
|
|
|
|
|
from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline
|
|
|
|
|
from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
|
|
|
|
|
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
|
|
|
|
|
from deepspeech.training.trainer import Trainer
|
|
|
|
|
from deepspeech.utils import error_rate
|
|
|
|
@ -122,13 +124,27 @@ class DeepSpeech2Trainer(Trainer):
|
|
|
|
|
|
|
|
|
|
def setup_model(self):
    """Instantiate the training model chosen by ``self.args.model_type``.

    ``'offline'`` builds a ``DeepSpeech2Model`` and ``'online'`` builds a
    ``DeepSpeech2ModelOnline``.  Feature and vocabulary sizes are read from
    ``self.train_loader.collate_fn``; the remaining hyper-parameters are
    read from ``self.config.model``.  When ``self.parallel`` is truthy the
    model is wrapped in ``paddle.DataParallel``.

    Raises:
        ValueError: if ``self.args.model_type`` is neither ``'offline'``
            nor ``'online'``.  ``ValueError`` derives from ``Exception``,
            so handlers written for the previous plain ``Exception`` still
            catch it.
    """
    model_type = self.args.model_type
    # Validate first so an unsupported type fails before any model is built.
    if model_type not in ('offline', 'online'):
        raise ValueError("wrong model type")

    model_conf = self.config.model
    collate_fn = self.train_loader.collate_fn
    # Keyword arguments that both model variants are constructed with.
    shared_kwargs = dict(
        feat_size=collate_fn.feature_size,
        dict_size=collate_fn.vocab_size,
        num_conv_layers=model_conf.num_conv_layers,
        num_rnn_layers=model_conf.num_rnn_layers,
        rnn_size=model_conf.rnn_layer_size,
        use_gru=model_conf.use_gru)

    # Build exactly one model; the earlier unconditional DeepSpeech2Model
    # construction was always overwritten by this branch and is dropped.
    if model_type == 'offline':
        model = DeepSpeech2Model(
            share_rnn_weights=model_conf.share_rnn_weights,
            **shared_kwargs)
    else:
        model = DeepSpeech2ModelOnline(
            num_fc_layers=model_conf.num_fc_layers,
            fc_layers_size_list=model_conf.fc_layers_size_list,
            **shared_kwargs)

    if self.parallel:
        model = paddle.DataParallel(model)
    # NOTE(review): the reviewed excerpt ends here; any statements that
    # follow in the full method are outside this view.
|
|
|
|
|
|
|
|
|
@ -329,8 +345,14 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
|
|
|
|
|
exit(-1)
|
|
|
|
|
|
|
|
|
|
def export(self):
    """Load the inference model for ``self.args.model_type`` for export.

    ``'offline'`` uses ``DeepSpeech2InferModel`` and ``'online'`` uses
    ``DeepSpeech2InferModelOnline``.  The chosen class is loaded through
    ``from_pretrained`` with ``self.test_loader``, ``self.config`` and
    ``self.args.checkpoint_path``, then switched to eval mode.

    Raises:
        ValueError: if ``self.args.model_type`` is neither ``'offline'``
            nor ``'online'``.  ``ValueError`` derives from ``Exception``,
            so handlers written for the previous plain ``Exception`` still
            catch it.
    """
    model_type = self.args.model_type
    if model_type == 'offline':
        infer_cls = DeepSpeech2InferModel
    elif model_type == 'online':
        infer_cls = DeepSpeech2InferModelOnline
    else:
        # Message typo ("tyep") corrected.
        raise ValueError("wrong model type")

    # Load the checkpoint once, with the class matching the model type; the
    # earlier unconditional DeepSpeech2InferModel load was always replaced
    # by the branch result and is dropped.
    infer_model = infer_cls.from_pretrained(
        self.test_loader, self.config, self.args.checkpoint_path)

    infer_model.eval()
    feat_dim = self.test_loader.collate_fn.feature_size
    # NOTE(review): the reviewed excerpt ends here; the statements that use
    # ``infer_model`` and ``feat_dim`` are outside this view.
|
|
|
|
@ -368,13 +390,27 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
|
|
|
|
|
|
|
|
|
|
def setup_model(self):
    """Instantiate the test-time model chosen by ``self.args.model_type``.

    ``'offline'`` builds a ``DeepSpeech2Model`` and ``'online'`` builds a
    ``DeepSpeech2ModelOnline``.  Feature and vocabulary sizes are read from
    ``self.test_loader.collate_fn``; the remaining hyper-parameters are
    read from ``self.config.model``.  The model is stored on
    ``self.model``.

    Raises:
        ValueError: if ``self.args.model_type`` is neither ``'offline'``
            nor ``'online'``.  ``ValueError`` derives from ``Exception``,
            so handlers written for the previous plain ``Exception`` still
            catch it.
    """
    model_type = self.args.model_type
    # Validate first so an unsupported type fails before any model is built.
    if model_type not in ('offline', 'online'):
        raise ValueError("Wrong model type")

    model_conf = self.config.model
    collate_fn = self.test_loader.collate_fn
    # Keyword arguments that both model variants are constructed with.
    shared_kwargs = dict(
        feat_size=collate_fn.feature_size,
        dict_size=collate_fn.vocab_size,
        num_conv_layers=model_conf.num_conv_layers,
        num_rnn_layers=model_conf.num_rnn_layers,
        rnn_size=model_conf.rnn_layer_size,
        use_gru=model_conf.use_gru)

    # Build exactly one model; the earlier unconditional DeepSpeech2Model
    # construction was always overwritten by this branch and is dropped.
    if model_type == 'offline':
        model = DeepSpeech2Model(
            share_rnn_weights=model_conf.share_rnn_weights,
            **shared_kwargs)
    else:
        model = DeepSpeech2ModelOnline(
            num_fc_layers=model_conf.num_fc_layers,
            fc_layers_size_list=model_conf.fc_layers_size_list,
            **shared_kwargs)

    self.model = model
    logger.info("Setup model!")
|
|
|
|
|