diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py index a7ccba485..99a0434d5 100644 --- a/paddlespeech/s2t/exps/u2/model.py +++ b/paddlespeech/s2t/exps/u2/model.py @@ -250,12 +250,9 @@ class U2Trainer(Trainer): model_conf.output_dim = self.train_loader.vocab_size else: model_conf.input_dim = self.test_loader.feat_dim - model_conf.output_dim = 5538 + model_conf.output_dim = self.test_loader.vocab_size model = U2Model.from_config(model_conf) - # params = model.state_dict() - # paddle.save(params, 'for_torch/test.pdparams') - # exit() if self.parallel: model = paddle.DataParallel(model) @@ -319,6 +316,7 @@ class U2Tester(U2Trainer): vocab=self.config.vocab_filepath, spm_model_prefix=self.config.spm_model_prefix) self.vocab_list = self.text_feature.vocab_list + self.reverse_weight = getattr(config, 'reverse_weight', '0.0') def id2token(self, texts, texts_len, text_feature): """ ord() id to chr() chr """ @@ -353,7 +351,7 @@ class U2Tester(U2Trainer): decoding_chunk_size=decode_config.decoding_chunk_size, num_decoding_left_chunks=decode_config.num_decoding_left_chunks, simulate_streaming=decode_config.simulate_streaming, - reverse_weight=self.config.model_conf.reverse_weight) + reverse_weight=self.reverse_weight) decode_time = time.time() - start_time for utt, target, result, rec_tids in zip(