|
|
@ -347,7 +347,7 @@ def make_batchset(
|
|
|
|
Note that if any utts doesn't have "category",
|
|
|
|
Note that if any utts doesn't have "category",
|
|
|
|
perform as same as batchfy_by_{count}
|
|
|
|
perform as same as batchfy_by_{count}
|
|
|
|
|
|
|
|
|
|
|
|
:param Dict[str, Dict[str, Any]] data: dictionary loaded from data.json
|
|
|
|
:param List[Dict[str, Any]] data: dictionary loaded from data.json
|
|
|
|
:param int batch_size: maximum number of sequences in a minibatch.
|
|
|
|
:param int batch_size: maximum number of sequences in a minibatch.
|
|
|
|
:param int batch_bins: maximum number of bins (frames x dim) in a minibatch.
|
|
|
|
:param int batch_bins: maximum number of bins (frames x dim) in a minibatch.
|
|
|
|
:param int batch_frames_in: maximum number of input frames in a minibatch.
|
|
|
|
:param int batch_frames_in: maximum number of input frames in a minibatch.
|
|
|
@ -374,7 +374,6 @@ def make_batchset(
|
|
|
|
reserved for future research, -1 means all axis.)
|
|
|
|
reserved for future research, -1 means all axis.)
|
|
|
|
:return: List[List[Tuple[str, dict]]] list of batches
|
|
|
|
:return: List[List[Tuple[str, dict]]] list of batches
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
# check args
|
|
|
|
# check args
|
|
|
|
if count not in BATCH_COUNT_CHOICES:
|
|
|
|
if count not in BATCH_COUNT_CHOICES:
|
|
|
|
raise ValueError(
|
|
|
|
raise ValueError(
|
|
|
@ -386,7 +385,6 @@ def make_batchset(
|
|
|
|
ikey = "input"
|
|
|
|
ikey = "input"
|
|
|
|
okey = "output"
|
|
|
|
okey = "output"
|
|
|
|
batch_sort_axis = 0 # index of list
|
|
|
|
batch_sort_axis = 0 # index of list
|
|
|
|
|
|
|
|
|
|
|
|
if count == "auto":
|
|
|
|
if count == "auto":
|
|
|
|
if batch_size != 0:
|
|
|
|
if batch_size != 0:
|
|
|
|
count = "seq"
|
|
|
|
count = "seq"
|
|
|
@ -405,7 +403,8 @@ def make_batchset(
|
|
|
|
"batch_sort_key=shuffle is only available if batch_count=seq")
|
|
|
|
"batch_sort_key=shuffle is only available if batch_count=seq")
|
|
|
|
|
|
|
|
|
|
|
|
category2data = {} # Dict[str, dict]
|
|
|
|
category2data = {} # Dict[str, dict]
|
|
|
|
for k, v in data.items():
|
|
|
|
for v in data:
|
|
|
|
|
|
|
|
k = v['utt']
|
|
|
|
category2data.setdefault(v.get("category"), {})[k] = v
|
|
|
|
category2data.setdefault(v.get("category"), {})[k] = v
|
|
|
|
|
|
|
|
|
|
|
|
batches_list = [] # List[List[List[Tuple[str, dict]]]]
|
|
|
|
batches_list = [] # List[List[List[Tuple[str, dict]]]]
|
|
|
@ -422,6 +421,7 @@ def make_batchset(
|
|
|
|
key=lambda data: int(data[1][batch_sort_key][batch_sort_axis]["shape"][0]),
|
|
|
|
key=lambda data: int(data[1][batch_sort_key][batch_sort_axis]["shape"][0]),
|
|
|
|
reverse=not shortest_first, )
|
|
|
|
reverse=not shortest_first, )
|
|
|
|
logger.info("# utts: " + str(len(sorted_data)))
|
|
|
|
logger.info("# utts: " + str(len(sorted_data)))
|
|
|
|
|
|
|
|
|
|
|
|
if count == "seq":
|
|
|
|
if count == "seq":
|
|
|
|
batches = batchfy_by_seq(
|
|
|
|
batches = batchfy_by_seq(
|
|
|
|
sorted_data,
|
|
|
|
sorted_data,
|
|
|
|