[Hackathon 7th] Fix `s2t` example errors (#3950)

* [Fix] s2t

* [Fix] s2t test

@@ -127,7 +127,7 @@ source path.sh
 bash ./local/data.sh
 CUDA_VISIBLE_DEVICES= ./local/train.sh conf/transformer_mtl_noam.yaml transformer_mtl_noam
 avg.sh latest exp/transformer_mtl_noam/checkpoints 5
-CUDA_VISIBLE_DEVICES= ./local/test.sh conf/transformer_mtl_noam.yaml exp/transformer_mtl_noam/checkpoints/avg_5
+CUDA_VISIBLE_DEVICES= ./local/test.sh conf/transformer_mtl_noam.yaml conf/tuning/decode.yaml exp/transformer_mtl_noam/checkpoints/avg_5
 ```
 The performance of the released models are shown below:
 ### Transformer
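The substance of this hunk: the `test.sh` example now passes the decode configuration `conf/tuning/decode.yaml` as a second positional argument, between the model config and the checkpoint prefix, presumably to match the script's updated argument order.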

@@ -34,9 +34,6 @@ def main(config, args):


 if __name__ == "__main__":
     parser = default_argument_parser()
-    # save asr result to
-    parser.add_argument(
-        "--result_file", type=str, help="path of save the asr result")
     args = parser.parse_args()
     print_arguments(args, globals())
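A plausible reason for the removal (an assumption; the commit message does not say): `default_argument_parser()` already registers `--result_file`, so re-adding it makes argparse raise at startup. A minimal reproduction of that failure mode:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--result_file", type=str, help="path to save the asr result")
try:
    # Registering the same option string twice is an error in argparse.
    parser.add_argument("--result_file", type=str)
except argparse.ArgumentError as e:
    print(e)  # argument --result_file: conflicting option string: --result_file
```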

@@ -115,6 +115,10 @@ class TextFeaturizer():
         """
         assert self.vocab_path_or_list, "toidx need vocab path or vocab list"
         tokens = []
+        # unwrap `idxs` like `[[1, 2, 3]]`
+        if idxs and isinstance(idxs[0], (list, tuple)) and len(idxs) == 1:
+            idxs = idxs[0]
         for idx in idxs:
             if idx == self.eos_id:
                 break
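In isolation, the guard accepts both a flat index sequence and a singleton-nested batch; a standalone sketch (the function name `unwrap_idxs` is hypothetical, in the PR the check sits inside the `TextFeaturizer` method):

```python
def unwrap_idxs(idxs):
    # [1, 2, 3] and [[1, 2, 3]] both yield [1, 2, 3];
    # genuine multi-item batches are left untouched.
    if idxs and isinstance(idxs[0], (list, tuple)) and len(idxs) == 1:
        idxs = idxs[0]
    return list(idxs)

assert unwrap_idxs([1, 2, 3]) == [1, 2, 3]
assert unwrap_idxs([[1, 2, 3]]) == [1, 2, 3]
assert unwrap_idxs([[1, 2], [3, 4]]) == [[1, 2], [3, 4]]
```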

@@ -404,6 +404,12 @@ class DataLoaderFactory():
             config['subsampling_factor'] = 1
             config['num_encs'] = 1
             config['shortest_first'] = False
+            config['minibatches'] = 0
+            config['batch_count'] = 'auto'
+            config['batch_bins'] = 0
+            config['batch_frames_in'] = 0
+            config['batch_frames_out'] = 0
+            config['batch_frames_inout'] = 0
         elif mode == 'valid':
             config['manifest'] = config.dev_manifest
             config['train_mode'] = False
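These six keys look like ESPnet-style batchfy options (an assumption about their consumer); the fix guarantees the batch builder finds every key even when the YAML omits them. The same idea as a plain dict-merge sketch, with hypothetical names:

```python
# Defaults from the hunk, collected in one place.
BATCH_DEFAULTS = {
    'minibatches': 0,        # presumably 0 = no cap on minibatch count
    'batch_count': 'auto',   # batching-strategy selector
    'batch_bins': 0,
    'batch_frames_in': 0,
    'batch_frames_out': 0,
    'batch_frames_inout': 0,
}

def with_batch_defaults(config: dict) -> dict:
    # Values already present in `config` win; defaults only fill gaps.
    return {**BATCH_DEFAULTS, **config}
```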

@@ -170,8 +170,8 @@ class U2STBaseModel(nn.Layer):
         ys_in_lens = ys_pad_lens + 1

         # 1. Forward decoder
-        decoder_out, _ = self.st_decoder(encoder_out, encoder_mask, ys_in_pad,
-                                         ys_in_lens)
+        decoder_out, *_ = self.st_decoder(encoder_out, encoder_mask, ys_in_pad,
+                                          ys_in_lens)

         # 2. Compute attention loss
         loss_att = self.criterion_att(decoder_out, ys_out_pad)
@@ -203,8 +203,8 @@ class U2STBaseModel(nn.Layer):
         ys_in_lens = ys_pad_lens + 1

         # 1. Forward decoder
-        decoder_out, _ = self.decoder(encoder_out, encoder_mask, ys_in_pad,
-                                      ys_in_lens)
+        decoder_out, *_ = self.decoder(encoder_out, encoder_mask, ys_in_pad,
+                                       ys_in_lens)

         # 2. Compute attention loss
         loss_att = self.criterion_att(decoder_out, ys_out_pad)
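The `_` to `*_` switch is the actual fix at these call sites: the decoder now returns three values (see the signature change below), and extended unpacking absorbs the extra trailing returns, whereas the old two-target form raises. A tiny illustration (return-value names are made up):

```python
def decoder_three():
    return "decoder_out", "r_decoder_out", "olens"

decoder_out, *_ = decoder_three()     # fine: *_ soaks up the trailing values

try:
    decoder_out, _ = decoder_three()  # old pattern expects exactly two values
except ValueError as e:
    print(e)  # too many values to unpack (expected 2)
```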

@@ -110,14 +110,14 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
             concat_after=concat_after, ) for _ in range(num_blocks)
         ])

-    def forward(
-            self,
-            memory: paddle.Tensor,
-            memory_mask: paddle.Tensor,
-            ys_in_pad: paddle.Tensor,
-            ys_in_lens: paddle.Tensor,
-            r_ys_in_pad: paddle.Tensor=paddle.empty([0]),
-            reverse_weight: float=0.0) -> Tuple[paddle.Tensor, paddle.Tensor]:
+    def forward(self,
+                memory: paddle.Tensor,
+                memory_mask: paddle.Tensor,
+                ys_in_pad: paddle.Tensor,
+                ys_in_lens: paddle.Tensor,
+                r_ys_in_pad: paddle.Tensor=paddle.empty([0]),
+                reverse_weight: float=0.0
+                ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
         """Forward decoder.
         Args:
             memory: encoded memory, float32 (batch, maxlen_in, feat)
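The return annotation widening from a 2-tuple to a 3-tuple is what the `*_` call sites above accommodate; a hedged caller sketch, assuming the left-to-right logits stay in the first slot (the meaning of the other slots, e.g. a reverse-decoder output, is an assumption based on `r_ys_in_pad`/`reverse_weight`):

```python
import paddle
from paddle import nn

def decode_l2r_only(decoder: nn.Layer,
                    memory: paddle.Tensor,
                    memory_mask: paddle.Tensor,
                    ys_in_pad: paddle.Tensor,
                    ys_in_lens: paddle.Tensor) -> paddle.Tensor:
    # With reverse_weight left at 0.0 and r_ys_in_pad at its empty default,
    # only the first element of the 3-tuple is needed; discard the rest.
    decoder_out, *_ = decoder(memory, memory_mask, ys_in_pad, ys_in_lens)
    return decoder_out
```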
