|
|
@ -250,12 +250,9 @@ class U2Trainer(Trainer):
|
|
|
|
model_conf.output_dim = self.train_loader.vocab_size
|
|
|
|
model_conf.output_dim = self.train_loader.vocab_size
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
model_conf.input_dim = self.test_loader.feat_dim
|
|
|
|
model_conf.input_dim = self.test_loader.feat_dim
|
|
|
|
model_conf.output_dim = 5538
|
|
|
|
model_conf.output_dim = self.test_loader.vocab_size
|
|
|
|
|
|
|
|
|
|
|
|
model = U2Model.from_config(model_conf)
|
|
|
|
model = U2Model.from_config(model_conf)
|
|
|
|
# params = model.state_dict()
|
|
|
|
|
|
|
|
# paddle.save(params, 'for_torch/test.pdparams')
|
|
|
|
|
|
|
|
# exit()
|
|
|
|
|
|
|
|
if self.parallel:
|
|
|
|
if self.parallel:
|
|
|
|
model = paddle.DataParallel(model)
|
|
|
|
model = paddle.DataParallel(model)
|
|
|
|
|
|
|
|
|
|
|
@ -319,6 +316,7 @@ class U2Tester(U2Trainer):
|
|
|
|
vocab=self.config.vocab_filepath,
|
|
|
|
vocab=self.config.vocab_filepath,
|
|
|
|
spm_model_prefix=self.config.spm_model_prefix)
|
|
|
|
spm_model_prefix=self.config.spm_model_prefix)
|
|
|
|
self.vocab_list = self.text_feature.vocab_list
|
|
|
|
self.vocab_list = self.text_feature.vocab_list
|
|
|
|
|
|
|
|
self.reverse_weight = getattr(config, 'reverse_weight', '0.0')
|
|
|
|
|
|
|
|
|
|
|
|
def id2token(self, texts, texts_len, text_feature):
|
|
|
|
def id2token(self, texts, texts_len, text_feature):
|
|
|
|
""" ord() id to chr() chr """
|
|
|
|
""" ord() id to chr() chr """
|
|
|
@ -353,7 +351,7 @@ class U2Tester(U2Trainer):
|
|
|
|
decoding_chunk_size=decode_config.decoding_chunk_size,
|
|
|
|
decoding_chunk_size=decode_config.decoding_chunk_size,
|
|
|
|
num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
|
|
|
|
num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
|
|
|
|
simulate_streaming=decode_config.simulate_streaming,
|
|
|
|
simulate_streaming=decode_config.simulate_streaming,
|
|
|
|
reverse_weight=self.config.model_conf.reverse_weight)
|
|
|
|
reverse_weight=self.reverse_weight)
|
|
|
|
decode_time = time.time() - start_time
|
|
|
|
decode_time = time.time() - start_time
|
|
|
|
|
|
|
|
|
|
|
|
for utt, target, result, rec_tids in zip(
|
|
|
|
for utt, target, result, rec_tids in zip(
|
|
|
|