|
|
|
@ -139,9 +139,10 @@ class DeepSpeech2Trainer(Trainer):
|
|
|
|
|
dict_size=self.train_loader.collate_fn.vocab_size,
|
|
|
|
|
num_conv_layers=config.model.num_conv_layers,
|
|
|
|
|
num_rnn_layers=config.model.num_rnn_layers,
|
|
|
|
|
rnn_size=config.model.rnn_layer_size,
|
|
|
|
|
rnn_direction=config.model.rnn_direction,
|
|
|
|
|
num_fc_layers=config.model.num_fc_layers,
|
|
|
|
|
fc_layers_size_list=config.model.fc_layers_size_list,
|
|
|
|
|
rnn_size=config.model.rnn_layer_size,
|
|
|
|
|
use_gru=config.model.use_gru)
|
|
|
|
|
else:
|
|
|
|
|
raise Exception("wrong model type")
|
|
|
|
@ -411,9 +412,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
|
|
|
|
|
dict_size=self.test_loader.collate_fn.vocab_size,
|
|
|
|
|
num_conv_layers=config.model.num_conv_layers,
|
|
|
|
|
num_rnn_layers=config.model.num_rnn_layers,
|
|
|
|
|
rnn_size=config.model.rnn_layer_size,
|
|
|
|
|
rnn_direction=config.model.rnn_direction,
|
|
|
|
|
num_fc_layers=config.model.num_fc_layers,
|
|
|
|
|
fc_layers_size_list=config.model.fc_layers_size_list,
|
|
|
|
|
rnn_size=config.model.rnn_layer_size,
|
|
|
|
|
use_gru=config.model.use_gru)
|
|
|
|
|
else:
|
|
|
|
|
raise Exception("Wrong model type")
|
|
|
|
|