Merge pull request #1449 from Jackwaterveg/fix

[ASR] fix sortagrad
pull/1452/head
Hui Zhang 2 years ago committed by GitHub
commit 15706a8428
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -419,7 +419,7 @@ def make_batchset(
# sort it by input lengths (long to short)
sorted_data = sorted(
d.items(),
key=lambda data: int(data[1][batch_sort_key][batch_sort_axis]["shape"][0]),
key=lambda data: float(data[1][batch_sort_key][batch_sort_axis]["shape"][0]),
reverse=not shortest_first, )
logger.info("# utts: " + str(len(sorted_data)))

@ -61,7 +61,7 @@ class BatchDataLoader():
def __init__(self,
json_file: str,
train_mode: bool,
sortagrad: bool=False,
sortagrad: int=0,
batch_size: int=0,
maxlen_in: float=float('inf'),
maxlen_out: float=float('inf'),

Loading…
Cancel
Save