|
|
@ -102,8 +102,8 @@ class DeepSpeech2Trainer(Trainer):
|
|
|
|
def setup_model(self):
|
|
|
|
def setup_model(self):
|
|
|
|
config = self.config
|
|
|
|
config = self.config
|
|
|
|
model = DeepSpeech2Model(
|
|
|
|
model = DeepSpeech2Model(
|
|
|
|
feat_size=self.train_loader.dataset.feature_size,
|
|
|
|
feat_size=self.train_loader.collate_fn.feature_size,
|
|
|
|
dict_size=self.train_loader.dataset.vocab_size,
|
|
|
|
dict_size=self.train_loader.collate_fn.vocab_size,
|
|
|
|
num_conv_layers=config.model.num_conv_layers,
|
|
|
|
num_conv_layers=config.model.num_conv_layers,
|
|
|
|
num_rnn_layers=config.model.num_rnn_layers,
|
|
|
|
num_rnn_layers=config.model.num_rnn_layers,
|
|
|
|
rnn_size=config.model.rnn_layer_size,
|
|
|
|
rnn_size=config.model.rnn_layer_size,
|
|
|
@ -199,7 +199,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
|
|
|
|
errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
|
|
|
|
errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
|
|
|
|
error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
|
|
|
|
error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
|
|
|
|
|
|
|
|
|
|
|
|
vocab_list = self.test_loader.dataset.vocab_list
|
|
|
|
vocab_list = self.test_loader.collate_fn.vocab_list
|
|
|
|
|
|
|
|
|
|
|
|
target_transcripts = self.ordid2token(texts, texts_len)
|
|
|
|
target_transcripts = self.ordid2token(texts, texts_len)
|
|
|
|
result_transcripts = self.model.decode(
|
|
|
|
result_transcripts = self.model.decode(
|
|
|
@ -272,7 +272,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
|
|
|
|
infer_model = DeepSpeech2InferModel.from_pretrained(
|
|
|
|
infer_model = DeepSpeech2InferModel.from_pretrained(
|
|
|
|
self.test_loader.dataset, self.config, self.args.checkpoint_path)
|
|
|
|
self.test_loader.dataset, self.config, self.args.checkpoint_path)
|
|
|
|
infer_model.eval()
|
|
|
|
infer_model.eval()
|
|
|
|
feat_dim = self.test_loader.dataset.feature_size
|
|
|
|
feat_dim = self.test_loader.collate_fn.feature_size
|
|
|
|
static_model = paddle.jit.to_static(
|
|
|
|
static_model = paddle.jit.to_static(
|
|
|
|
infer_model,
|
|
|
|
infer_model,
|
|
|
|
input_spec=[
|
|
|
|
input_spec=[
|
|
|
@ -308,8 +308,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
|
|
|
|
def setup_model(self):
|
|
|
|
def setup_model(self):
|
|
|
|
config = self.config
|
|
|
|
config = self.config
|
|
|
|
model = DeepSpeech2Model(
|
|
|
|
model = DeepSpeech2Model(
|
|
|
|
feat_size=self.test_loader.dataset.feature_size,
|
|
|
|
feat_size=self.test_loader.collate_fn.feature_size,
|
|
|
|
dict_size=self.test_loader.dataset.vocab_size,
|
|
|
|
dict_size=self.test_loader.collate_fn.vocab_size,
|
|
|
|
num_conv_layers=config.model.num_conv_layers,
|
|
|
|
num_conv_layers=config.model.num_conv_layers,
|
|
|
|
num_rnn_layers=config.model.num_rnn_layers,
|
|
|
|
num_rnn_layers=config.model.num_rnn_layers,
|
|
|
|
rnn_size=config.model.rnn_layer_size,
|
|
|
|
rnn_size=config.model.rnn_layer_size,
|
|
|
|