Merge pull request #756 from PaddlePaddle/filter
test w/ all example & fix ctc api & add new io
commit f05f367cc5
@@ -0,0 +1,469 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools

import numpy as np

from deepspeech.utils.log import Log

__all__ = ["make_batchset"]

logger = Log(__name__).getlog()


def batchfy_by_seq(
        sorted_data,
        batch_size,
        max_length_in,
        max_length_out,
        min_batch_size=1,
        shortest_first=False,
        ikey="input",
        iaxis=0,
        okey="output",
        oaxis=0, ):
    """Make batch set from json dictionary

    :param List[(str, Dict[str, Any])] sorted_data: dictionary loaded from data.json
    :param int batch_size: batch size
    :param int max_length_in: maximum length of input to decide adaptive batch size
    :param int max_length_out: maximum length of output to decide adaptive batch size
    :param int min_batch_size: minimum batch size (for multi-gpu)
    :param bool shortest_first: Sort from batch with shortest samples
        to longest if true, otherwise reverse
    :param str ikey: key to access input
        (for ASR ikey="input", for TTS, MT ikey="output".)
    :param int iaxis: dimension to access input
        (for ASR, TTS iaxis=0, for MT iaxis=1.)
    :param str okey: key to access output
        (for ASR, MT okey="output". for TTS okey="input".)
    :param int oaxis: dimension to access output
        (for ASR, TTS, MT oaxis=0, reserved for future research, -1 means all axis.)
    :return: List[List[Tuple[str, dict]]] list of batches
    """
    if batch_size <= 0:
        raise ValueError(f"Invalid batch_size={batch_size}")

    # check #utts is more than min_batch_size
    if len(sorted_data) < min_batch_size:
        raise ValueError(
            f"#utts({len(sorted_data)}) is less than min_batch_size({min_batch_size})."
        )

    # make list of minibatches
    minibatches = []
    start = 0
    while True:
        _, info = sorted_data[start]
        ilen = int(info[ikey][iaxis]["shape"][0])
        olen = (int(info[okey][oaxis]["shape"][0]) if oaxis >= 0 else
                max(map(lambda x: int(x["shape"][0]), info[okey])))
        factor = max(int(ilen / max_length_in), int(olen / max_length_out))
        # change batchsize depending on the input and output length
        # if ilen = 1000 and max_length_in = 800
        # then b = batchsize / 2
        # and max(min_batches, .) avoids batchsize = 0
        bs = max(min_batch_size, int(batch_size / (1 + factor)))
        end = min(len(sorted_data), start + bs)
        minibatch = sorted_data[start:end]
        if shortest_first:
            minibatch.reverse()

        # check each batch is more than minimum batchsize
        if len(minibatch) < min_batch_size:
            mod = min_batch_size - len(minibatch) % min_batch_size
            additional_minibatch = [
                sorted_data[i] for i in np.random.randint(0, start, mod)
            ]
            if shortest_first:
                additional_minibatch.reverse()
            minibatch.extend(additional_minibatch)
        minibatches.append(minibatch)

        if end == len(sorted_data):
            break
        start = end

    # batch: List[List[Tuple[str, dict]]]
    return minibatches


def batchfy_by_bin(
        sorted_data,
        batch_bins,
        num_batches=0,
        min_batch_size=1,
        shortest_first=False,
        ikey="input",
        okey="output", ):
    """Make variably sized batch set, which maximizes
    the number of bins up to `batch_bins`.

    :param List[(str, Dict[str, Any])] sorted_data: dictionary loaded from data.json
    :param int batch_bins: Maximum frames of a batch
    :param int num_batches: # number of batches to use (for debug)
    :param int min_batch_size: minimum batch size (for multi-gpu)
    :param bool shortest_first: Sort from batch with shortest samples
        to longest if true, otherwise reverse

    :param str ikey: key to access input (for ASR ikey="input", for TTS ikey="output".)
    :param str okey: key to access output (for ASR okey="output". for TTS okey="input".)

    :return: List[Tuple[str, Dict[str, List[Dict[str, Any]]]] list of batches
    """
    if batch_bins <= 0:
        raise ValueError(f"invalid batch_bins={batch_bins}")
    length = len(sorted_data)
    idim = int(sorted_data[0][1][ikey][0]["shape"][1])
    odim = int(sorted_data[0][1][okey][0]["shape"][1])
    logger.info("# utts: " + str(len(sorted_data)))
    minibatches = []
    start = 0
    n = 0
    while True:
        # Dynamic batch size depending on size of samples
        b = 0
        next_size = 0
        max_olen = 0
        while next_size < batch_bins and (start + b) < length:
            ilen = int(sorted_data[start + b][1][ikey][0]["shape"][0]) * idim
            olen = int(sorted_data[start + b][1][okey][0]["shape"][0]) * odim
            if olen > max_olen:
                max_olen = olen
            next_size = (max_olen + ilen) * (b + 1)
            if next_size <= batch_bins:
                b += 1
            elif next_size == 0:
                raise ValueError(
                    f"Can't fit one sample in batch_bins ({batch_bins}): "
                    f"Please increase the value")
        end = min(length, start + max(min_batch_size, b))
        batch = sorted_data[start:end]
        if shortest_first:
            batch.reverse()
        minibatches.append(batch)
        # Check for min_batch_size and fixes the batches if needed
        i = -1
        while len(minibatches[i]) < min_batch_size:
            missing = min_batch_size - len(minibatches[i])
            if -i == len(minibatches):
                minibatches[i + 1].extend(minibatches[i])
                minibatches = minibatches[1:]
                break
            else:
                minibatches[i].extend(minibatches[i - 1][:missing])
                minibatches[i - 1] = minibatches[i - 1][missing:]
                i -= 1
        if end == length:
            break
        start = end
        n += 1
    if num_batches > 0:
        minibatches = minibatches[:num_batches]
    lengths = [len(x) for x in minibatches]
    logger.info(
        str(len(minibatches)) + " batches containing from " + str(min(lengths))
        + " to " + str(max(lengths)) + " samples " + "(avg " + str(
            int(np.mean(lengths))) + " samples).")
    return minibatches


def batchfy_by_frame(
        sorted_data,
        max_frames_in,
        max_frames_out,
        max_frames_inout,
        num_batches=0,
        min_batch_size=1,
        shortest_first=False,
        ikey="input",
        okey="output", ):
    """Make variable batch set, which maximizes the number of frames to max_batch_frame.

    :param List[(str, Dict[str, Any])] sorted_data: dictionary loaded from data.json
    :param int max_frames_in: Maximum input frames of a batch
    :param int max_frames_out: Maximum output frames of a batch
    :param int max_frames_inout: Maximum input+output frames of a batch
    :param int num_batches: # number of batches to use (for debug)
    :param int min_batch_size: minimum batch size (for multi-gpu)
    :param bool shortest_first: Sort from batch with shortest samples
        to longest if true, otherwise reverse

    :param str ikey: key to access input (for ASR ikey="input", for TTS ikey="output".)
    :param str okey: key to access output (for ASR okey="output". for TTS okey="input".)

    :return: List[Tuple[str, Dict[str, List[Dict[str, Any]]]] list of batches
    """
    if max_frames_in <= 0 and max_frames_out <= 0 and max_frames_inout <= 0:
        raise ValueError(
            "At least one of `--batch-frames-in`, `--batch-frames-out` or "
            "`--batch-frames-inout` should be > 0")
    length = len(sorted_data)
    minibatches = []
    start = 0
    end = 0
    while end != length:
        # Dynamic batch size depending on size of samples
        b = 0
        max_olen = 0
        max_ilen = 0
        while (start + b) < length:
            ilen = int(sorted_data[start + b][1][ikey][0]["shape"][0])
            if ilen > max_frames_in and max_frames_in != 0:
                raise ValueError(
                    f"Can't fit one sample in --batch-frames-in ({max_frames_in}): "
                    f"Please increase the value")
            olen = int(sorted_data[start + b][1][okey][0]["shape"][0])
            if olen > max_frames_out and max_frames_out != 0:
                raise ValueError(
                    f"Can't fit one sample in --batch-frames-out ({max_frames_out}): "
                    f"Please increase the value")
            if ilen + olen > max_frames_inout and max_frames_inout != 0:
                raise ValueError(
                    f"Can't fit one sample in --batch-frames-inout ({max_frames_inout}): "
                    f"Please increase the value")
            max_olen = max(max_olen, olen)
            max_ilen = max(max_ilen, ilen)
            in_ok = max_ilen * (b + 1) <= max_frames_in or max_frames_in == 0
            out_ok = max_olen * (b + 1) <= max_frames_out or max_frames_out == 0
            inout_ok = (max_ilen + max_olen) * (
                b + 1) <= max_frames_inout or max_frames_inout == 0
            if in_ok and out_ok and inout_ok:
                # add more seq in the minibatch
                b += 1
            else:
                # no more seq in the minibatch
                break
        end = min(length, start + b)
        batch = sorted_data[start:end]
        if shortest_first:
            batch.reverse()
        minibatches.append(batch)
        # Check for min_batch_size and fixes the batches if needed
        i = -1
        while len(minibatches[i]) < min_batch_size:
            missing = min_batch_size - len(minibatches[i])
            if -i == len(minibatches):
                minibatches[i + 1].extend(minibatches[i])
                minibatches = minibatches[1:]
                break
            else:
                minibatches[i].extend(minibatches[i - 1][:missing])
                minibatches[i - 1] = minibatches[i - 1][missing:]
                i -= 1
        start = end
    if num_batches > 0:
        minibatches = minibatches[:num_batches]
    lengths = [len(x) for x in minibatches]
    logger.info(
        str(len(minibatches)) + " batches containing from " + str(min(lengths))
        + " to " + str(max(lengths)) + " samples " + "(avg " + str(
            int(np.mean(lengths))) + " samples).")

    return minibatches


def batchfy_shuffle(data, batch_size, min_batch_size, num_batches,
                    shortest_first):
    import random

    logger.info("use shuffled batch.")
    sorted_data = random.sample(list(data.items()), len(data))
    logger.info("# utts: " + str(len(sorted_data)))
    # make list of minibatches
    minibatches = []
    start = 0
    while True:
        end = min(len(sorted_data), start + batch_size)
        # check each batch is more than minimum batchsize
        minibatch = sorted_data[start:end]
        if shortest_first:
            minibatch.reverse()
        if len(minibatch) < min_batch_size:
            mod = min_batch_size - len(minibatch) % min_batch_size
            additional_minibatch = [
                sorted_data[i] for i in np.random.randint(0, start, mod)
            ]
            if shortest_first:
                additional_minibatch.reverse()
            minibatch.extend(additional_minibatch)
        minibatches.append(minibatch)
        if end == len(sorted_data):
            break
        start = end

    # for debugging
    if num_batches > 0:
        minibatches = minibatches[:num_batches]
    logger.info("# minibatches: " + str(len(minibatches)))
    return minibatches


BATCH_COUNT_CHOICES = ["auto", "seq", "bin", "frame"]
BATCH_SORT_KEY_CHOICES = ["input", "output", "shuffle"]


def make_batchset(
        data,
        batch_size=0,
        max_length_in=float("inf"),
        max_length_out=float("inf"),
        num_batches=0,
        min_batch_size=1,
        shortest_first=False,
        batch_sort_key="input",
        count="auto",
        batch_bins=0,
        batch_frames_in=0,
        batch_frames_out=0,
        batch_frames_inout=0,
        iaxis=0,
        oaxis=0, ):
    """Make batch set from json dictionary

    if utts have "category" value,

    >>> data = [{'category': 'A', 'input': ..., 'utt':'utt1'},
    ...         {'category': 'B', 'input': ..., 'utt':'utt2'},
    ...         {'category': 'B', 'input': ..., 'utt':'utt3'},
    ...         {'category': 'A', 'input': ..., 'utt':'utt4'}]
    >>> make_batchset(data, batch_size=2, ...)
    [[('utt1', ...), ('utt4', ...)], [('utt2', ...), ('utt3', ...)]]

    Note that if any utt doesn't have "category",
    this performs the same as batchfy_by_{count}.

    :param List[Dict[str, Any]] data: dictionary loaded from data.json
    :param int batch_size: maximum number of sequences in a minibatch.
    :param int batch_bins: maximum number of bins (frames x dim) in a minibatch.
    :param int batch_frames_in: maximum number of input frames in a minibatch.
    :param int batch_frames_out: maximum number of output frames in a minibatch.
    :param int batch_frames_inout: maximum number of input+output frames in a minibatch.
    :param str count: strategy to count maximum size of batch.
        For choices, see espnet.asr.batchfy.BATCH_COUNT_CHOICES

    :param int max_length_in: maximum length of input to decide adaptive batch size
    :param int max_length_out: maximum length of output to decide adaptive batch size
    :param int num_batches: # number of batches to use (for debug)
    :param int min_batch_size: minimum batch size (for multi-gpu)
    :param bool shortest_first: Sort from batch with shortest samples
        to longest if true, otherwise reverse
    :param str batch_sort_key: how to sort data before creating minibatches
        ["input", "output", "shuffle"]
    :param int iaxis: dimension to access input
        (for ASR, TTS iaxis=0, for MT iaxis=1.)
    :param int oaxis: dimension to access output (for ASR, TTS, MT oaxis=0,
        reserved for future research, -1 means all axis.)
    :return: List[List[Tuple[str, dict]]] list of batches
    """
    # check args
    if count not in BATCH_COUNT_CHOICES:
        raise ValueError(
            f"arg 'count' ({count}) should be one of {BATCH_COUNT_CHOICES}")
    if batch_sort_key not in BATCH_SORT_KEY_CHOICES:
        raise ValueError(f"arg 'batch_sort_key' ({batch_sort_key}) should be "
                         f"one of {BATCH_SORT_KEY_CHOICES}")

    ikey = "input"
    okey = "output"
    batch_sort_axis = 0  # index of list
    if count == "auto":
        if batch_size != 0:
            count = "seq"
        elif batch_bins != 0:
            count = "bin"
        elif batch_frames_in != 0 or batch_frames_out != 0 or batch_frames_inout != 0:
            count = "frame"
        else:
            raise ValueError(
                f"cannot detect `count`; manually set one of {BATCH_COUNT_CHOICES}"
            )
        logger.info(f"count is auto detected as {count}")

    if count != "seq" and batch_sort_key == "shuffle":
        raise ValueError(
            "batch_sort_key=shuffle is only available if batch_count=seq")

    category2data = {}  # Dict[str, dict]
    for v in data:
        k = v['utt']
        category2data.setdefault(v.get("category"), {})[k] = v

    batches_list = []  # List[List[List[Tuple[str, dict]]]]
    for d in category2data.values():
        if batch_sort_key == "shuffle":
            batches = batchfy_shuffle(d, batch_size, min_batch_size,
                                      num_batches, shortest_first)
            batches_list.append(batches)
            continue

        # sort it by input lengths (long to short)
        sorted_data = sorted(
            d.items(),
            key=lambda data: int(data[1][batch_sort_key][batch_sort_axis]["shape"][0]),
            reverse=not shortest_first, )
        logger.info("# utts: " + str(len(sorted_data)))

        if count == "seq":
            batches = batchfy_by_seq(
                sorted_data,
                batch_size=batch_size,
                max_length_in=max_length_in,
                max_length_out=max_length_out,
                min_batch_size=min_batch_size,
                shortest_first=shortest_first,
                ikey=ikey,
                iaxis=iaxis,
                okey=okey,
                oaxis=oaxis, )
        if count == "bin":
            batches = batchfy_by_bin(
                sorted_data,
                batch_bins=batch_bins,
                min_batch_size=min_batch_size,
                shortest_first=shortest_first,
                ikey=ikey,
                okey=okey, )
        if count == "frame":
            batches = batchfy_by_frame(
                sorted_data,
                max_frames_in=batch_frames_in,
                max_frames_out=batch_frames_out,
                max_frames_inout=batch_frames_inout,
                min_batch_size=min_batch_size,
                shortest_first=shortest_first,
                ikey=ikey,
                okey=okey, )
        batches_list.append(batches)

    if len(batches_list) == 1:
        batches = batches_list[0]
    else:
        # Concat list. This way is faster than "sum(batch_list, [])"
        batches = list(itertools.chain(*batches_list))

    # for debugging
    if num_batches > 0:
        batches = batches[:num_batches]
    logger.info("# minibatches: " + str(len(batches)))

    # batch: List[List[Tuple[str, dict]]]
    return batches
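
As an aside for reviewers (not part of the diff): a minimal sketch of driving `make_batchset` with a tiny, hypothetical manifest list. Shapes follow the `[length, dim]` convention of the `data.json` entries above; because `batch_size != 0`, `count="auto"` resolves to the `seq` strategy.

    from deepspeech.io.batchfy import make_batchset

    # Fabricated manifest entries: [num_frames, feat_dim] inputs,
    # [num_tokens, vocab_size] outputs.
    toy_data = [
        {"utt": "utt1", "category": "A",
         "input": [{"name": "input1", "shape": [310, 83]}],
         "output": [{"name": "target1", "shape": [12, 5002]}]},
        {"utt": "utt2", "category": "A",
         "input": [{"name": "input1", "shape": [205, 83]}],
         "output": [{"name": "target1", "shape": [8, 5002]}]},
        {"utt": "utt3", "category": "B",
         "input": [{"name": "input1", "shape": [158, 83]}],
         "output": [{"name": "target1", "shape": [5, 5002]}]},
    ]

    batches = make_batchset(toy_data, batch_size=2)
    # One group per category, each element a (utt_id, manifest_entry) tuple,
    # sorted by input length (long to short by default).
    for batch in batches:
        print([utt for utt, _ in batch])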
@@ -0,0 +1,80 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np

from deepspeech.io.utility import pad_list
from deepspeech.utils.log import Log

__all__ = ["CustomConverter"]

logger = Log(__name__).getlog()


class CustomConverter():
    """Custom batch converter.

    Args:
        subsampling_factor (int): The subsampling factor.
        dtype (np.dtype): Data type to convert.

    """

    def __init__(self, subsampling_factor=1, dtype=np.float32):
        """Construct a CustomConverter object."""
        self.subsampling_factor = subsampling_factor
        self.ignore_id = -1
        self.dtype = dtype

    def __call__(self, batch):
        """Transform a batch: subsample, pad and collect lengths.

        Args:
            batch (list): The batch to transform; a single-element list
                of ((xs, ys), uttid_list).

        Returns:
            tuple(list, np.ndarray, np.ndarray, np.ndarray, np.ndarray):
                utterance ids, padded inputs, input lengths,
                padded targets, target lengths.

        """
        # batch should be located in list
        assert len(batch) == 1
        (xs, ys), utts = batch[0]

        # perform subsampling
        if self.subsampling_factor > 1:
            xs = [x[::self.subsampling_factor, :] for x in xs]

        # get batch of lengths of input sequences
        ilens = np.array([x.shape[0] for x in xs])

        # perform padding and convert to tensor
        # currently only support real number
        if xs[0].dtype.kind == "c":
            xs_pad_real = pad_list([x.real for x in xs], 0).astype(self.dtype)
            xs_pad_imag = pad_list([x.imag for x in xs], 0).astype(self.dtype)
            # Note(kamo):
            # {'real': ..., 'imag': ...} will be changed to ComplexTensor in E2E.
            # Don't create ComplexTensor and give it to E2E here
            # because torch.nn.DataParallel can't handle it.
            xs_pad = {"real": xs_pad_real, "imag": xs_pad_imag}
        else:
            xs_pad = pad_list(xs, 0).astype(self.dtype)

        # NOTE: this is for multi-output (e.g., speech translation)
        ys_pad = pad_list(
            [np.array(y[0][:]) if isinstance(y, tuple) else y for y in ys],
            self.ignore_id)

        olens = np.array(
            [y[0].shape[0] if isinstance(y, tuple) else y.shape[0] for y in ys])
        return utts, xs_pad, ilens, ys_pad, olens
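
Likewise, a hedged sketch of `CustomConverter` on two fabricated utterances; `pad_list` is assumed to right-pad each sequence to the longest one, as its use above implies.

    import numpy as np

    from deepspeech.io.converter import CustomConverter

    converter = CustomConverter(subsampling_factor=1, dtype=np.float32)

    # Two fabricated utterances: 83-dim features and int64 token-id targets.
    xs = [np.random.randn(20, 83), np.random.randn(15, 83)]
    ys = [np.array([3, 7, 9], dtype=np.int64), np.array([2, 5], dtype=np.int64)]
    utts = ["utt1", "utt2"]

    utt_ids, xs_pad, ilens, ys_pad, olens = converter([((xs, ys), utts)])
    # xs_pad: float32 array of shape (2, 20, 83), zero-padded
    # ilens:  [20, 15]
    # ys_pad: targets padded with ignore_id (-1), shape (2, 3)
    # olens:  [3, 2]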
@@ -0,0 +1,138 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from paddle.io import DataLoader

from deepspeech.frontend.utility import read_manifest
from deepspeech.io.batchfy import make_batchset
from deepspeech.io.converter import CustomConverter
from deepspeech.io.dataset import TransformDataset
from deepspeech.io.reader import LoadInputsAndTargets
from deepspeech.utils.log import Log

__all__ = ["BatchDataLoader"]

logger = Log(__name__).getlog()


class BatchDataLoader():
    def __init__(self,
                 json_file: str,
                 train_mode: bool,
                 sortagrad: bool=False,
                 batch_size: int=0,
                 maxlen_in: float=float('inf'),
                 maxlen_out: float=float('inf'),
                 minibatches: int=0,
                 mini_batch_size: int=1,
                 batch_count: str='auto',
                 batch_bins: int=0,
                 batch_frames_in: int=0,
                 batch_frames_out: int=0,
                 batch_frames_inout: int=0,
                 preprocess_conf=None,
                 n_iter_processes: int=1,
                 subsampling_factor: int=1,
                 num_encs: int=1):
        self.json_file = json_file
        self.train_mode = train_mode
        self.use_sortagrad = sortagrad == -1 or sortagrad > 0
        self.batch_size = batch_size
        self.maxlen_in = maxlen_in
        self.maxlen_out = maxlen_out
        self.batch_count = batch_count
        self.batch_bins = batch_bins
        self.batch_frames_in = batch_frames_in
        self.batch_frames_out = batch_frames_out
        self.batch_frames_inout = batch_frames_inout
        self.subsampling_factor = subsampling_factor
        self.num_encs = num_encs
        self.preprocess_conf = preprocess_conf
        self.n_iter_processes = n_iter_processes

        # read json data
        self.data_json = read_manifest(json_file)

        # make minibatch list (variable length)
        self.minibatches = make_batchset(
            self.data_json,
            batch_size,
            maxlen_in,
            maxlen_out,
            minibatches,  # for debug
            min_batch_size=mini_batch_size,
            shortest_first=self.use_sortagrad,
            count=batch_count,
            batch_bins=batch_bins,
            batch_frames_in=batch_frames_in,
            batch_frames_out=batch_frames_out,
            batch_frames_inout=batch_frames_inout,
            iaxis=0,
            oaxis=0, )

        # data reader
        self.reader = LoadInputsAndTargets(
            mode="asr",
            load_output=True,
            preprocess_conf=preprocess_conf,
            preprocess_args={"train": train_mode},  # Switch the mode of preprocessing
        )

        # Setup a converter
        if num_encs == 1:
            self.converter = CustomConverter(
                subsampling_factor=subsampling_factor, dtype=np.float32)
        else:
            raise NotImplementedError("not impl CustomConverterMulEnc.")

        # hack to make batchsize argument as 1
        # actual batchsize is included in a list
        # default collate function converts numpy array to paddle tensor
        # we used an empty collate function instead which returns list
        self.dataset = TransformDataset(
            self.minibatches,
            lambda data: self.converter([self.reader(data, return_uttid=True)]))
        self.dataloader = DataLoader(
            dataset=self.dataset,
            batch_size=1,
            shuffle=not self.use_sortagrad if train_mode else False,
            collate_fn=lambda x: x[0],
            num_workers=n_iter_processes, )

    def __repr__(self):
        echo = f"<{self.__class__.__module__}.{self.__class__.__name__} object at {hex(id(self))}> "
        echo += f"train_mode: {self.train_mode}, "
        echo += f"sortagrad: {self.use_sortagrad}, "
        echo += f"batch_size: {self.batch_size}, "
        echo += f"maxlen_in: {self.maxlen_in}, "
        echo += f"maxlen_out: {self.maxlen_out}, "
        echo += f"batch_count: {self.batch_count}, "
        echo += f"batch_bins: {self.batch_bins}, "
        echo += f"batch_frames_in: {self.batch_frames_in}, "
        echo += f"batch_frames_out: {self.batch_frames_out}, "
        echo += f"batch_frames_inout: {self.batch_frames_inout}, "
        echo += f"subsampling_factor: {self.subsampling_factor}, "
        echo += f"num_encs: {self.num_encs}, "
        echo += f"num_workers: {self.n_iter_processes}, "
        echo += f"file: {self.json_file}"
        return echo

    def __len__(self):
        return len(self.dataloader)

    def __iter__(self):
        return self.dataloader.__iter__()

    def __call__(self):
        return self.__iter__()
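
A hypothetical construction of the loader above; the manifest path and bin budget are placeholders, not files or values introduced by this diff. With `batch_size=0` and `batch_bins` set, `make_batchset` auto-detects the `bin` strategy.

    # BatchDataLoader is the class defined above; import path omitted here.
    dev_loader = BatchDataLoader(
        json_file="data/manifest.dev",   # placeholder manifest in jsonline format
        train_mode=False,
        sortagrad=False,
        batch_size=0,
        batch_count="auto",
        batch_bins=4000000,
        preprocess_conf=None,
        n_iter_processes=2,
        subsampling_factor=1,
        num_encs=1)

    print(dev_loader)  # __repr__ summarizes the configuration
    for utts, xs_pad, ilens, ys_pad, olens in dev_loader:
        # each step yields one pre-built minibatch produced by CustomConverter
        break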
@@ -0,0 +1,409 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
from collections import OrderedDict

import h5py
import kaldiio
import numpy as np
import soundfile

from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline
from deepspeech.utils.log import Log

__all__ = ["LoadInputsAndTargets"]

logger = Log(__name__).getlog()


class LoadInputsAndTargets():
    """Create a mini-batch from a list of dicts

    >>> batch = [('utt1',
    ...           dict(input=[dict(feat='some.ark:123',
    ...                            filetype='mat',
    ...                            name='input1',
    ...                            shape=[100, 80])],
    ...                output=[dict(tokenid='1 2 3 4',
    ...                             name='target1',
    ...                             shape=[4, 31])]))]
    >>> l = LoadInputsAndTargets()
    >>> feat, target = l(batch)

    :param: str mode: Specify the task mode, "asr" or "tts"
    :param: str preprocess_conf: The path of a json file for pre-processing
    :param: bool load_input: If False, not to load the input data
    :param: bool load_output: If False, not to load the output data
    :param: bool sort_in_input_length: Sort the mini-batch in descending order
        of the input length
    :param: bool use_speaker_embedding: Used for tts mode only
    :param: bool use_second_target: Used for tts mode only
    :param: dict preprocess_args: Set some optional arguments for preprocessing
    :param: bool keep_all_data_on_mem: Cache loaded arrays in memory
    """

    def __init__(
            self,
            mode="asr",
            preprocess_conf=None,
            load_input=True,
            load_output=True,
            sort_in_input_length=True,
            preprocess_args=None,
            keep_all_data_on_mem=False, ):
        self._loaders = {}

        if mode not in ["asr"]:
            raise ValueError("Only 'asr' is allowed: mode={}".format(mode))

        if preprocess_conf is not None:
            self.preprocessing = AugmentationPipeline(preprocess_conf)
            logger.warning(
                "[Experimental feature] Some preprocessing will be done "
                "for the mini-batch creation using {}".format(
                    self.preprocessing))
        else:
            # If conf doesn't exist, this function doesn't touch anything.
            self.preprocessing = None

        self.mode = mode
        self.load_output = load_output
        self.load_input = load_input
        self.sort_in_input_length = sort_in_input_length
        if preprocess_args is None:
            self.preprocess_args = {}
        else:
            assert isinstance(preprocess_args, dict), type(preprocess_args)
            self.preprocess_args = dict(preprocess_args)

        self.keep_all_data_on_mem = keep_all_data_on_mem

    def __call__(self, batch, return_uttid=False):
        """Function to load inputs and targets from list of dicts

        :param List[Tuple[str, dict]] batch: list of dict which is subset of
            loaded data.json
        :param bool return_uttid: return utterance ID information for visualization
        :return: list of input feature sequences
            [(T_1, D), (T_2, D), ..., (T_B, D)]
        :rtype: list of float ndarray
        :return: list of target token id sequences [(L_1), (L_2), ..., (L_B)]
        :rtype: list of int ndarray

        """
        x_feats_dict = OrderedDict()  # OrderedDict[str, List[np.ndarray]]
        y_feats_dict = OrderedDict()  # OrderedDict[str, List[np.ndarray]]
        uttid_list = []  # List[str]

        for uttid, info in batch:
            uttid_list.append(uttid)

            if self.load_input:
                # Note(kamo): This for-loop is for multiple inputs
                for idx, inp in enumerate(info["input"]):
                    # {"input":
                    #   [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
                    #     "filetype": "hdf5",
                    #     "name": "input1", ...}], ...}
                    x = self._get_from_loader(
                        filepath=inp["feat"],
                        filetype=inp.get("filetype", "mat"))
                    x_feats_dict.setdefault(inp["name"], []).append(x)

            if self.load_output:
                for idx, inp in enumerate(info["output"]):
                    if "tokenid" in inp:
                        # ======= Legacy format for output =======
                        # {"output": [{"tokenid": "1 2 3 4"}])
                        x = np.fromiter(
                            map(int, inp["tokenid"].split()), dtype=np.int64)
                    else:
                        # ======= New format =======
                        # {"output":
                        #   [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
                        #     "filetype": "hdf5",
                        #     "name": "target1", ...}], ...}
                        x = self._get_from_loader(
                            filepath=inp["feat"],
                            filetype=inp.get("filetype", "mat"))

                    y_feats_dict.setdefault(inp["name"], []).append(x)

        if self.mode == "asr":
            return_batch, uttid_list = self._create_batch_asr(
                x_feats_dict, y_feats_dict, uttid_list)
        else:
            raise NotImplementedError(self.mode)

        if self.preprocessing is not None:
            # Apply pre-processing to all input features
            for x_name in return_batch.keys():
                if x_name.startswith("input"):
                    return_batch[x_name] = self.preprocessing(
                        return_batch[x_name], uttid_list,
                        **self.preprocess_args)

        if return_uttid:
            return tuple(return_batch.values()), uttid_list

        # Doesn't return the names now.
        return tuple(return_batch.values())

    def _create_batch_asr(self, x_feats_dict, y_feats_dict, uttid_list):
        """Create an OrderedDict for the mini-batch

        :param OrderedDict x_feats_dict:
            e.g. {"input1": [ndarray, ndarray, ...],
                  "input2": [ndarray, ndarray, ...]}
        :param OrderedDict y_feats_dict:
            e.g. {"target1": [ndarray, ndarray, ...],
                  "target2": [ndarray, ndarray, ...]}
        :param: List[str] uttid_list:
            Give uttid_list to sort in the same order as the mini-batch
        :return: batch, uttid_list
        :rtype: Tuple[OrderedDict, List[str]]
        """
        # handle single-input and multi-input (parallel) asr mode
        xs = list(x_feats_dict.values())

        if self.load_output:
            ys = list(y_feats_dict.values())
            assert len(xs[0]) == len(ys[0]), (len(xs[0]), len(ys[0]))

            # get index of non-zero length samples
            nonzero_idx = list(
                filter(lambda i: len(ys[0][i]) > 0, range(len(ys[0]))))
            for n in range(1, len(y_feats_dict)):
                nonzero_idx = filter(lambda i: len(ys[n][i]) > 0, nonzero_idx)
        else:
            # Note(kamo): Be careful not to make nonzero_idx to a generator
            nonzero_idx = list(range(len(xs[0])))

        if self.sort_in_input_length:
            # sort in input lengths based on the first input
            nonzero_sorted_idx = sorted(
                nonzero_idx, key=lambda i: -len(xs[0][i]))
        else:
            nonzero_sorted_idx = nonzero_idx

        if len(nonzero_sorted_idx) != len(xs[0]):
            logger.warning(
                "Target sequences include empty tokenid (batch {} -> {}).".
                format(len(xs[0]), len(nonzero_sorted_idx)))

        # remove zero-length samples
        xs = [[x[i] for i in nonzero_sorted_idx] for x in xs]
        uttid_list = [uttid_list[i] for i in nonzero_sorted_idx]

        x_names = list(x_feats_dict.keys())
        if self.load_output:
            ys = [[y[i] for i in nonzero_sorted_idx] for y in ys]
            y_names = list(y_feats_dict.keys())

            # Keeping x_name and y_name, e.g. input1, for future extension
            return_batch = OrderedDict([
                *[(x_name, x) for x_name, x in zip(x_names, xs)],
                *[(y_name, y) for y_name, y in zip(y_names, ys)],
            ])
        else:
            return_batch = OrderedDict(
                [(x_name, x) for x_name, x in zip(x_names, xs)])
        return return_batch, uttid_list

    def _get_from_loader(self, filepath, filetype):
        """Return ndarray

        In order to open file descriptors only at the first reference,
        the loaders are stored in self._loaders.

        >>> ndarray = loader._get_from_loader(
        ...     'some/path.h5:F01_050C0101_PED_REAL', filetype='hdf5')

        :param: str filepath:
        :param: str filetype:
        :return:
        :rtype: np.ndarray
        """
        if filetype == "hdf5":
            # e.g.
            #    {"input": [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
            #                "filetype": "hdf5",
            # -> filepath = "some/path.h5", key = "F01_050C0101_PED_REAL"
            filepath, key = filepath.split(":", 1)

            loader = self._loaders.get(filepath)
            if loader is None:
                # To avoid disk access, create loader only for the first time
                loader = h5py.File(filepath, "r")
                self._loaders[filepath] = loader
            return loader[key][()]
        elif filetype == "sound.hdf5":
            # e.g.
            #    {"input": [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
            #                "filetype": "sound.hdf5",
            # -> filepath = "some/path.h5", key = "F01_050C0101_PED_REAL"
            filepath, key = filepath.split(":", 1)

            loader = self._loaders.get(filepath)
            if loader is None:
                # To avoid disk access, create loader only for the first time
                loader = SoundHDF5File(filepath, "r", dtype="int16")
                self._loaders[filepath] = loader
            array, rate = loader[key]
            return array
        elif filetype == "sound":
            # e.g.
            #    {"input": [{"feat": "some/path.wav",
            #                "filetype": "sound"},
            # Assume PCM16
            if not self.keep_all_data_on_mem:
                array, _ = soundfile.read(filepath, dtype="int16")
                return array
            if filepath not in self._loaders:
                array, _ = soundfile.read(filepath, dtype="int16")
                self._loaders[filepath] = array
            return self._loaders[filepath]
        elif filetype == "npz":
            # e.g.
            #    {"input": [{"feat": "some/path.npz:F01_050C0101_PED_REAL",
            #                "filetype": "npz",
            filepath, key = filepath.split(":", 1)

            loader = self._loaders.get(filepath)
            if loader is None:
                # To avoid disk access, create loader only for the first time
                loader = np.load(filepath)
                self._loaders[filepath] = loader
            return loader[key]
        elif filetype == "npy":
            # e.g.
            #    {"input": [{"feat": "some/path.npy",
            #                "filetype": "npy"},
            if not self.keep_all_data_on_mem:
                return np.load(filepath)
            if filepath not in self._loaders:
                self._loaders[filepath] = np.load(filepath)
            return self._loaders[filepath]
        elif filetype in ["mat", "vec"]:
            # e.g.
            #    {"input": [{"feat": "some/path.ark:123",
            #                "filetype": "mat"}]},
            # In this case, "123" indicates the starting points of the matrix
            # load_mat can load both matrix and vector
            if not self.keep_all_data_on_mem:
                return kaldiio.load_mat(filepath)
            if filepath not in self._loaders:
                self._loaders[filepath] = kaldiio.load_mat(filepath)
            return self._loaders[filepath]
        elif filetype == "scp":
            # e.g.
            #    {"input": [{"feat": "some/path.scp:F01_050C0101_PED_REAL",
            #                "filetype": "scp",
            filepath, key = filepath.split(":", 1)
            loader = self._loaders.get(filepath)
            if loader is None:
                # To avoid disk access, create loader only for the first time
                loader = kaldiio.load_scp(filepath)
                self._loaders[filepath] = loader
            return loader[key]
        else:
            raise NotImplementedError(
                "Not supported: loader_type={}".format(filetype))


class SoundHDF5File():
    """Collecting sound files to a HDF5 file

    >>> f = SoundHDF5File('a.flac.h5', mode='a')
    >>> array = np.random.randint(0, 100, 100, dtype=np.int16)
    >>> f['id'] = (array, 16000)
    >>> array, rate = f['id']

    :param: str filepath:
    :param: str mode:
    :param: str format: The type used when saving wav. flac, nist, htk, etc.
    :param: str dtype:

    """

    def __init__(self,
                 filepath,
                 mode="r+",
                 format=None,
                 dtype="int16",
                 **kwargs):
        self.filepath = filepath
        self.mode = mode
        self.dtype = dtype

        self.file = h5py.File(filepath, mode, **kwargs)
        if format is None:
            # filepath = a.flac.h5 -> format = flac
            second_ext = os.path.splitext(os.path.splitext(filepath)[0])[1]
            format = second_ext[1:]
            if format.upper() not in soundfile.available_formats():
                # If not found, flac is selected
                format = "flac"

        # This format affects only saving
        self.format = format

    def __repr__(self):
        return '<SoundHDF5 file "{}" (mode {}, format {}, type {})>'.format(
            self.filepath, self.mode, self.format, self.dtype)

    def create_dataset(self, name, shape=None, data=None, **kwds):
        f = io.BytesIO()
        array, rate = data
        soundfile.write(f, array, rate, format=self.format)
        self.file.create_dataset(
            name, shape=shape, data=np.void(f.getvalue()), **kwds)

    def __setitem__(self, name, data):
        self.create_dataset(name, data=data)

    def __getitem__(self, key):
        data = self.file[key][()]
        f = io.BytesIO(data.tobytes())
        array, rate = soundfile.read(f, dtype=self.dtype)
        return array, rate

    def keys(self):
        return self.file.keys()

    def values(self):
        for k in self.file:
            yield self[k]

    def items(self):
        for k in self.file:
            yield k, self[k]

    def __iter__(self):
        return iter(self.file)

    def __contains__(self, item):
        return item in self.file

    def __len__(self):
        return len(self.file)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def close(self):
        self.file.close()
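
A sketch of using `LoadInputsAndTargets` on its own, mirroring the class docstring; the ark path and offset are placeholders, and the legacy `tokenid` target format is used so no target file is needed.

    from deepspeech.io.reader import LoadInputsAndTargets

    reader = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        preprocess_conf=None,
        preprocess_args={"train": False})

    batch = [("utt1",
              {"input": [{"feat": "feats.1.ark:842920",  # placeholder Kaldi ark offset
                          "filetype": "mat",
                          "name": "input1", "shape": [349, 83]}],
               "output": [{"tokenid": "3144 482 352 269 317",
                           "name": "target1", "shape": [5, 5002]}]})]

    (feats, tokens), utt_ids = reader(batch, return_uttid=True)
    # feats[0]:  float ndarray of shape (349, 83), loaded via kaldiio.load_mat
    # tokens[0]: int64 ndarray [3144, 482, 352, 269, 317]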
@@ -0,0 +1,36 @@
#!/usr/bin/env python
"""Convert an espnet data.json file to a manifest jsonline file."""
import argparse
import json


def main(args):
    with open(args.json_file, 'r') as fin:
        data_json = json.load(fin)

    # manifest format:
    # {"input": [
    #     {"feat": "dev/deltafalse/feats.1.ark:842920", "name": "input1", "shape": [349, 83]}
    #  ],
    #  "output": [
    #     {"name": "target1", "shape": [12, 5002], "text": "NO APOLLO", "token": "▁NO ▁A PO LL O", "tokenid": "3144 482 352 269 317"}
    #  ],
    #  "utt2spk": "116-288045",
    #  "utt": "116-288045-0019"}
    with open(args.manifest_file, 'w') as fout:
        for key, value in data_json['utts'].items():
            value['utt'] = key
            fout.write(json.dumps(value, ensure_ascii=False))
            fout.write("\n")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        '--json-file', type=str, default=None, help="espnet data json file.")
    parser.add_argument(
        '--manifest-file',
        type=str,
        default='manifest.train',
        help='manifest data json line file.')
    args = parser.parse_args()
    main(args)
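
After running the script with `--json-file` and `--manifest-file`, each output line is one standalone JSON object. A hedged sketch of reading the result back (the path is a placeholder):

    import json

    with open("manifest.train", "r") as f:  # placeholder output path
        manifest = [json.loads(line) for line in f]

    # Every entry keeps the original espnet fields plus the added "utt" key.
    print(manifest[0]["utt"], manifest[0]["output"][0]["text"])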