fix sortagrad, test=asr

pull/1449/head
huangyuxin 3 years ago
parent 7f970bb255
commit 95d5274aef

@ -419,7 +419,7 @@ def make_batchset(
# sort it by input lengths (long to short)
sorted_data = sorted(
d.items(),
key=lambda data: int(data[1][batch_sort_key][batch_sort_axis]["shape"][0]),
key=lambda data: float(data[1][batch_sort_key][batch_sort_axis]["shape"][0]),
reverse=not shortest_first, )
logger.info("# utts: " + str(len(sorted_data)))

@ -61,7 +61,7 @@ class BatchDataLoader():
def __init__(self,
json_file: str,
train_mode: bool,
sortagrad: bool=False,
sortagrad: int=0,
batch_size: int=0,
maxlen_in: float=float('inf'),
maxlen_out: float=float('inf'),

Loading…
Cancel
Save