[Hackathon 7th] Fix `s2t` example errors (#3950)

* [Fix] s2t

* [Fix] s2t test
megemini committed b4c2f3bae3 (parent 1b9217f9f6)

@@ -127,7 +127,7 @@ source path.sh
bash ./local/data.sh
CUDA_VISIBLE_DEVICES= ./local/train.sh conf/transformer_mtl_noam.yaml transformer_mtl_noam
avg.sh latest exp/transformer_mtl_noam/checkpoints 5
CUDA_VISIBLE_DEVICES= ./local/test.sh conf/transformer_mtl_noam.yaml exp/transformer_mtl_noam/checkpoints/avg_5
CUDA_VISIBLE_DEVICES= ./local/test.sh conf/transformer_mtl_noam.yaml conf/tuning/decode.yaml exp/transformer_mtl_noam/checkpoints/avg_5
```
The performance of the released models is shown below:
### Transformer
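The updated `local/test.sh` call above passes the decoding options as a separate `conf/tuning/decode.yaml` argument instead of relying on the model config alone. As a rough illustration only (plain PyYAML, hypothetical function name, not the project's actual loader), a test entry point could combine the two files like this:

```python
import yaml


def load_test_configs(config_path: str, decode_config_path: str) -> dict:
    """Merge the model config with a separate decoding config (illustrative sketch)."""
    with open(config_path) as f:
        config = yaml.safe_load(f)  # e.g. conf/transformer_mtl_noam.yaml
    with open(decode_config_path) as f:
        config["decode"] = yaml.safe_load(f)  # e.g. conf/tuning/decode.yaml
    return config
```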

@@ -34,9 +34,6 @@ def main(config, args):
if __name__ == "__main__":
    parser = default_argument_parser()
    # save asr result to
    parser.add_argument(
        "--result_file", type=str, help="path of save the asr result")
    args = parser.parse_args()
    print_arguments(args, globals())
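The removed lines registered `--result_file` on the parser returned by `default_argument_parser()`. If that helper already defines the option, registering it a second time fails, which would explain dropping the duplicate here; the diff itself does not state the reason. A standalone illustration of argparse's behaviour with a duplicated option:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--result_file", type=str, help="path to save the asr result")
try:
    # Registering the same option string twice is rejected by argparse.
    parser.add_argument("--result_file", type=str)
except argparse.ArgumentError as err:
    print(err)  # conflicting option string: --result_file
```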

@@ -115,6 +115,10 @@ class TextFeaturizer():
"""
assert self.vocab_path_or_list, "toidx need vocab path or vocab list"
tokens = []
# unwrap `idxs`` like `[[1,2,3]]`
if idxs and isinstance(idxs[0], (list, tuple)) and len(idxs) == 1:
idxs = idxs[0]
for idx in idxs:
if idx == self.eos_id:
break
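The added guard only flattens a singly nested id list such as `[[1, 2, 3]]`; flat lists (and lists with more than one sub-list) pass through unchanged. A self-contained sketch of the same check, as a standalone function rather than the class method:

```python
def unwrap_ids(idxs):
    """Flatten [[1, 2, 3]] -> [1, 2, 3]; leave other inputs unchanged."""
    if idxs and isinstance(idxs[0], (list, tuple)) and len(idxs) == 1:
        idxs = idxs[0]
    return idxs


assert unwrap_ids([[1, 2, 3]]) == [1, 2, 3]
assert unwrap_ids([1, 2, 3]) == [1, 2, 3]
assert unwrap_ids([]) == []
```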

@@ -404,6 +404,12 @@ class DataLoaderFactory():
            config['subsampling_factor'] = 1
            config['num_encs'] = 1
            config['shortest_first'] = False
            config['minibatches'] = 0
            config['batch_count'] = 'auto'
            config['batch_bins'] = 0
            config['batch_frames_in'] = 0
            config['batch_frames_out'] = 0
            config['batch_frames_inout'] = 0
        elif mode == 'valid':
            config['manifest'] = config.dev_manifest
            config['train_mode'] = False
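For readability, the batching-related defaults this branch now fills in can be collected in one place. The key meanings below are my reading of ESPnet-style batchify options, not something stated in this diff; the values are exactly the defaults set above:

```python
# Assumed meanings (ESPnet-style batchify options); values match the diff above.
DEFAULT_BATCH_OPTS = {
    "minibatches": 0,          # 0: do not truncate the number of minibatches
    "batch_count": "auto",     # batch counting strategy
    "batch_bins": 0,           # 0: bin-size limit disabled
    "batch_frames_in": 0,      # 0: input-frame limit disabled
    "batch_frames_out": 0,     # 0: output-frame limit disabled
    "batch_frames_inout": 0,   # 0: combined-frame limit disabled
}

config = {"subsampling_factor": 1, "num_encs": 1, "shortest_first": False}
config.update(DEFAULT_BATCH_OPTS)  # same effect as the per-key assignments
```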

@@ -170,8 +170,8 @@ class U2STBaseModel(nn.Layer):
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.st_decoder(encoder_out, encoder_mask, ys_in_pad,
                                         ys_in_lens)
        decoder_out, *_ = self.st_decoder(encoder_out, encoder_mask, ys_in_pad,
                                          ys_in_lens)

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
@@ -203,8 +203,8 @@ class U2STBaseModel(nn.Layer):
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.decoder(encoder_out, encoder_mask, ys_in_pad,
                                      ys_in_lens)
        decoder_out, *_ = self.decoder(encoder_out, encoder_mask, ys_in_pad,
                                       ys_in_lens)

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
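Why `*_` replaces `_` in both call sites: the decoder's `forward` now returns three values instead of two (see the `TransformerDecoder` hunk below), and star-unpacking keeps these call sites valid however many outputs follow `decoder_out`. A tiny illustration with a stand-in function (the extra output names are illustrative, not taken from the diff):

```python
def fake_decoder():
    # Stand-in for the decoder call, which now yields three outputs.
    return "decoder_out", "r_decoder_out", "olens"


# `decoder_out, _ = fake_decoder()` would raise "too many values to unpack";
# star-unpacking absorbs any number of trailing outputs.
decoder_out, *_ = fake_decoder()
assert decoder_out == "decoder_out"
```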

@@ -110,14 +110,14 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
                concat_after=concat_after, ) for _ in range(num_blocks)
        ])

    def forward(
            self,
            memory: paddle.Tensor,
            memory_mask: paddle.Tensor,
            ys_in_pad: paddle.Tensor,
            ys_in_lens: paddle.Tensor,
            r_ys_in_pad: paddle.Tensor=paddle.empty([0]),
            reverse_weight: float=0.0) -> Tuple[paddle.Tensor, paddle.Tensor]:
    def forward(self,
                memory: paddle.Tensor,
                memory_mask: paddle.Tensor,
                ys_in_pad: paddle.Tensor,
                ys_in_lens: paddle.Tensor,
                r_ys_in_pad: paddle.Tensor=paddle.empty([0]),
                reverse_weight: float=0.0
                ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
        """Forward decoder.
        Args:
            memory: encoded memory, float32 (batch, maxlen_in, feat)
