From 95d5274aef8c31fce2668b79aca1e17ecb82335d Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Tue, 15 Feb 2022 05:26:25 +0000 Subject: [PATCH] fix sortagrad, test=asr --- paddlespeech/s2t/io/batchfy.py | 2 +- paddlespeech/s2t/io/dataloader.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/paddlespeech/s2t/io/batchfy.py b/paddlespeech/s2t/io/batchfy.py index f59fb24c..f3630f2e 100644 --- a/paddlespeech/s2t/io/batchfy.py +++ b/paddlespeech/s2t/io/batchfy.py @@ -419,7 +419,7 @@ def make_batchset( # sort it by input lengths (long to short) sorted_data = sorted( d.items(), - key=lambda data: int(data[1][batch_sort_key][batch_sort_axis]["shape"][0]), + key=lambda data: float(data[1][batch_sort_key][batch_sort_axis]["shape"][0]), reverse=not shortest_first, ) logger.info("# utts: " + str(len(sorted_data))) diff --git a/paddlespeech/s2t/io/dataloader.py b/paddlespeech/s2t/io/dataloader.py index 920de34f..55aa13ff 100644 --- a/paddlespeech/s2t/io/dataloader.py +++ b/paddlespeech/s2t/io/dataloader.py @@ -61,7 +61,7 @@ class BatchDataLoader(): def __init__(self, json_file: str, train_mode: bool, - sortagrad: bool=False, + sortagrad: int=0, batch_size: int=0, maxlen_in: float=float('inf'), maxlen_out: float=float('inf'),