fix io; add test

pull/756/head
Hui Zhang 3 years ago
parent 4b5410eecd
commit e3d73acd37

@ -0,0 +1,10 @@
# Locales
export LC_ALL=en_US.UTF-8
export LANG=en_US.UTF-8
export LANGUAGE=en_US.UTF-8
# Aliases
alias nvs="nvidia-smi"
alias rsync="rsync --progress -raz"
alias his="history"

File diff suppressed because it is too large Load Diff

@ -347,7 +347,7 @@ def make_batchset(
Note that if any utts doesn't have "category",
perform as same as batchfy_by_{count}
:param Dict[str, Dict[str, Any]] data: dictionary loaded from data.json
:param List[Dict[str, Any]] data: dictionary loaded from data.json
:param int batch_size: maximum number of sequences in a minibatch.
:param int batch_bins: maximum number of bins (frames x dim) in a minibatch.
:param int batch_frames_in: maximum number of input frames in a minibatch.
@ -374,7 +374,6 @@ def make_batchset(
reserved for future research, -1 means all axis.)
:return: List[List[Tuple[str, dict]]] list of batches
"""
# check args
if count not in BATCH_COUNT_CHOICES:
raise ValueError(
@ -386,7 +385,6 @@ def make_batchset(
ikey = "input"
okey = "output"
batch_sort_axis = 0 # index of list
if count == "auto":
if batch_size != 0:
count = "seq"
@ -405,7 +403,8 @@ def make_batchset(
"batch_sort_key=shuffle is only available if batch_count=seq")
category2data = {} # Dict[str, dict]
for k, v in data.items():
for v in data:
k = v['utt']
category2data.setdefault(v.get("category"), {})[k] = v
batches_list = [] # List[List[List[Tuple[str, dict]]]]
@ -422,6 +421,7 @@ def make_batchset(
key=lambda data: int(data[1][batch_sort_key][batch_sort_axis]["shape"][0]),
reverse=not shortest_first, )
logger.info("# utts: " + str(len(sorted_data)))
if count == "seq":
batches = batchfy_by_seq(
sorted_data,
@ -466,4 +466,4 @@ def make_batchset(
logger.info("# minibatches: " + str(len(batches)))
# batch: List[List[Tuple[str, dict]]]
return batches
return batches

@ -16,7 +16,7 @@ from typing import Optional
from paddle.io import Dataset
from yacs.config import CfgNode
from deepspeech.frontend.utility import read_manifest
from deepspeech.utils.log import Log
__all__ = ["ManifestDataset", "TripletManifestDataset", "TransformDataset"]

Loading…
Cancel
Save