support bitransformer decoder

pull/2415/head
tianhao zhang 3 years ago
parent ecbf324286
commit 455379b88e

@ -250,12 +250,9 @@ class U2Trainer(Trainer):
model_conf.output_dim = self.train_loader.vocab_size model_conf.output_dim = self.train_loader.vocab_size
else: else:
model_conf.input_dim = self.test_loader.feat_dim model_conf.input_dim = self.test_loader.feat_dim
model_conf.output_dim = 5538 model_conf.output_dim = self.test_loader.vocab_size
model = U2Model.from_config(model_conf) model = U2Model.from_config(model_conf)
# params = model.state_dict()
# paddle.save(params, 'for_torch/test.pdparams')
# exit()
if self.parallel: if self.parallel:
model = paddle.DataParallel(model) model = paddle.DataParallel(model)
@ -319,6 +316,7 @@ class U2Tester(U2Trainer):
vocab=self.config.vocab_filepath, vocab=self.config.vocab_filepath,
spm_model_prefix=self.config.spm_model_prefix) spm_model_prefix=self.config.spm_model_prefix)
self.vocab_list = self.text_feature.vocab_list self.vocab_list = self.text_feature.vocab_list
self.reverse_weight = getattr(config, 'reverse_weight', '0.0')
def id2token(self, texts, texts_len, text_feature): def id2token(self, texts, texts_len, text_feature):
""" ord() id to chr() chr """ """ ord() id to chr() chr """
@ -353,7 +351,7 @@ class U2Tester(U2Trainer):
decoding_chunk_size=decode_config.decoding_chunk_size, decoding_chunk_size=decode_config.decoding_chunk_size,
num_decoding_left_chunks=decode_config.num_decoding_left_chunks, num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
simulate_streaming=decode_config.simulate_streaming, simulate_streaming=decode_config.simulate_streaming,
reverse_weight=self.config.model_conf.reverse_weight) reverse_weight=self.reverse_weight)
decode_time = time.time() - start_time decode_time = time.time() - start_time
for utt, target, result, rec_tids in zip( for utt, target, result, rec_tids in zip(

Loading…
Cancel
Save