You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1334 lines
57 KiB
1334 lines
57 KiB
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "extensive-venice",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/workspace/zhanghui/DeepSpeech-2.x\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'/workspace/zhanghui/DeepSpeech-2.x'"
|
|
]
|
|
},
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"%cd ..\n",
|
|
"%pwd"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "correct-window",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"manifest.dev\t manifest.test-clean\t manifest.train\r\n",
|
|
"manifest.dev.raw manifest.test-clean.raw manifest.train.raw\r\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"!ls /workspace/zhanghui/DeepSpeech-2.x/examples/librispeech/s2/data/"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "exceptional-cheese",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"dev_data='/workspace/zhanghui/DeepSpeech-2.x/examples/librispeech/s2/data/manifest.dev'"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "extraordinary-orleans",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"grep: warning: GREP_OPTIONS is deprecated; please use an alias or script\n",
|
|
"register user softmax to paddle, remove this when fixed!\n",
|
|
"register user log_softmax to paddle, remove this when fixed!\n",
|
|
"register user sigmoid to paddle, remove this when fixed!\n",
|
|
"register user log_sigmoid to paddle, remove this when fixed!\n",
|
|
"register user relu to paddle, remove this when fixed!\n",
|
|
"override cat of paddle if exists or register, remove this when fixed!\n",
|
|
"override long of paddle.Tensor if exists or register, remove this when fixed!\n",
|
|
"override new_full of paddle.Tensor if exists or register, remove this when fixed!\n",
|
|
"override eq of paddle.Tensor if exists or register, remove this when fixed!\n",
|
|
"override eq of paddle if exists or register, remove this when fixed!\n",
|
|
"override contiguous of paddle.Tensor if exists or register, remove this when fixed!\n",
|
|
"override size of paddle.Tensor (`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!\n",
|
|
"register user view to paddle.Tensor, remove this when fixed!\n",
|
|
"register user view_as to paddle.Tensor, remove this when fixed!\n",
|
|
"register user masked_fill to paddle.Tensor, remove this when fixed!\n",
|
|
"register user masked_fill_ to paddle.Tensor, remove this when fixed!\n",
|
|
"register user fill_ to paddle.Tensor, remove this when fixed!\n",
|
|
"register user repeat to paddle.Tensor, remove this when fixed!\n",
|
|
"register user softmax to paddle.Tensor, remove this when fixed!\n",
|
|
"register user sigmoid to paddle.Tensor, remove this when fixed!\n",
|
|
"register user relu to paddle.Tensor, remove this when fixed!\n",
|
|
"register user type_as to paddle.Tensor, remove this when fixed!\n",
|
|
"register user to to paddle.Tensor, remove this when fixed!\n",
|
|
"register user float to paddle.Tensor, remove this when fixed!\n",
|
|
"register user int to paddle.Tensor, remove this when fixed!\n",
|
|
"register user GLU to paddle.nn, remove this when fixed!\n",
|
|
"register user ConstantPad2d to paddle.nn, remove this when fixed!\n",
|
|
"register user export to paddle.jit, remove this when fixed!\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from deepspeech.frontend.utility import read_manifest"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"id": "returning-lighter",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"dev_json = read_manifest(dev_data)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"id": "western-founder",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"{'input': [{'feat': '/workspace/zhanghui/asr/espnet/egs/librispeech/asr1/dump/dev/deltafalse/feats.1.ark:16',\n",
|
|
" 'name': 'input1',\n",
|
|
" 'shape': [1063, 83]}],\n",
|
|
" 'output': [{'name': 'target1',\n",
|
|
" 'shape': [41, 5002],\n",
|
|
" 'text': 'AS I APPROACHED THE CITY I HEARD BELLS RINGING AND A '\n",
|
|
" 'LITTLE LATER I FOUND THE STREETS ASTIR WITH THRONGS OF '\n",
|
|
" 'WELL DRESSED PEOPLE IN FAMILY GROUPS WENDING THEIR WAY '\n",
|
|
" 'HITHER AND THITHER',\n",
|
|
" 'token': '▁AS ▁I ▁APPROACHED ▁THE ▁CITY ▁I ▁HEARD ▁BELL S ▁RING '\n",
|
|
" 'ING ▁AND ▁A ▁LITTLE ▁LATER ▁I ▁FOUND ▁THE ▁STREETS ▁AS '\n",
|
|
" 'T IR ▁WITH ▁THRONG S ▁OF ▁WELL ▁DRESSED ▁PEOPLE ▁IN '\n",
|
|
" '▁FAMILY ▁GROUP S ▁WE ND ING ▁THEIR ▁WAY ▁HITHER ▁AND '\n",
|
|
" '▁THITHER',\n",
|
|
" 'tokenid': '713 2458 676 4502 1155 2458 2351 849 389 3831 206 627 '\n",
|
|
" '482 2812 2728 2458 2104 4502 4316 713 404 212 4925 '\n",
|
|
" '4549 389 3204 4861 1677 3339 2495 1950 2279 389 4845 '\n",
|
|
" '302 206 4504 4843 2394 627 4526'}],\n",
|
|
" 'utt': '116-288045-0000',\n",
|
|
" 'utt2spk': '116-288045'}\n",
|
|
"5542\n",
|
|
"<class 'list'>\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from pprint import pprint\n",
|
|
"pprint(dev_json[0])\n",
|
|
"print(len(dev_json))\n",
|
|
"print(type(dev_json))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"id": "motivated-receptor",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.\n",
|
|
"#\n",
|
|
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
|
|
"# you may not use this file except in compliance with the License.\n",
|
|
"# You may obtain a copy of the License at\n",
|
|
"#\n",
|
|
"# http://www.apache.org/licenses/LICENSE-2.0\n",
|
|
"#\n",
|
|
"# Unless required by applicable law or agreed to in writing, software\n",
|
|
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
|
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
|
"# See the License for the specific language governing permissions and\n",
|
|
"# limitations under the License.\n",
|
|
"import itertools\n",
|
|
"\n",
|
|
"import numpy as np\n",
|
|
"\n",
|
|
"from deepspeech.utils.log import Log\n",
|
|
"\n",
|
|
"__all__ = [\"make_batchset\"]\n",
|
|
"\n",
|
|
"logger = Log(__name__).getlog()\n",
|
|
"\n",
|
|
"\n",
|
|
"def batchfy_by_seq(\n",
|
|
" sorted_data,\n",
|
|
" batch_size,\n",
|
|
" max_length_in,\n",
|
|
" max_length_out,\n",
|
|
" min_batch_size=1,\n",
|
|
" shortest_first=False,\n",
|
|
" ikey=\"input\",\n",
|
|
" iaxis=0,\n",
|
|
" okey=\"output\",\n",
|
|
" oaxis=0, ):\n",
|
|
" \"\"\"Make batch set from json dictionary\n",
|
|
"\n",
|
|
" :param List[(str, Dict[str, Any])] sorted_data: dictionary loaded from data.json\n",
|
|
" :param int batch_size: batch size\n",
|
|
" :param int max_length_in: maximum length of input to decide adaptive batch size\n",
|
|
" :param int max_length_out: maximum length of output to decide adaptive batch size\n",
|
|
" :param int min_batch_size: minimum batch size (for multi-gpu)\n",
|
|
" :param bool shortest_first: Sort from batch with shortest samples\n",
|
|
" to longest if true, otherwise reverse\n",
|
|
" :param str ikey: key to access input\n",
|
|
" (for ASR ikey=\"input\", for TTS, MT ikey=\"output\".)\n",
|
|
" :param int iaxis: dimension to access input\n",
|
|
" (for ASR, TTS iaxis=0, for MT iaxis=1.)\n",
|
|
" :param str okey: key to access output\n",
|
|
" (for ASR, MT okey=\"output\". for TTS okey=\"input\".)\n",
|
|
" :param int oaxis: dimension to access output\n",
|
|
" (for ASR, TTS, MT oaxis=0, reserved for future research, -1 means all axis.)\n",
|
|
" :return: List[List[Tuple[str, dict]]] list of batches\n",
|
|
" \"\"\"\n",
|
|
" if batch_size <= 0:\n",
|
|
" raise ValueError(f\"Invalid batch_size={batch_size}\")\n",
|
|
"\n",
|
|
" # check #utts is more than min_batch_size\n",
|
|
" if len(sorted_data) < min_batch_size:\n",
|
|
" raise ValueError(\n",
|
|
" f\"#utts({len(sorted_data)}) is less than min_batch_size({min_batch_size}).\"\n",
|
|
" )\n",
|
|
"\n",
|
|
" # make list of minibatches\n",
|
|
" minibatches = []\n",
|
|
" start = 0\n",
|
|
" while True:\n",
|
|
" _, info = sorted_data[start]\n",
|
|
" ilen = int(info[ikey][iaxis][\"shape\"][0])\n",
|
|
" olen = (int(info[okey][oaxis][\"shape\"][0]) if oaxis >= 0 else\n",
|
|
" max(map(lambda x: int(x[\"shape\"][0]), info[okey])))\n",
|
|
" factor = max(int(ilen / max_length_in), int(olen / max_length_out))\n",
|
|
" # change batchsize depending on the input and output length\n",
|
|
" # if ilen = 1000 and max_length_in = 800\n",
|
|
" # then b = batchsize / 2\n",
|
|
" # and max(min_batch_size, .) avoids batchsize = 0\n",
|
|
" bs = max(min_batch_size, int(batch_size / (1 + factor)))\n",
|
|
" end = min(len(sorted_data), start + bs)\n",
|
|
" minibatch = sorted_data[start:end]\n",
|
|
" if shortest_first:\n",
|
|
" minibatch.reverse()\n",
|
|
"\n",
|
|
" # check each batch is more than minimum batchsize\n",
|
|
" if len(minibatch) < min_batch_size:\n",
|
|
" mod = min_batch_size - len(minibatch) % min_batch_size\n",
|
|
" additional_minibatch = [\n",
|
|
" sorted_data[i] for i in np.random.randint(0, start, mod)\n",
|
|
" ]\n",
|
|
" if shortest_first:\n",
|
|
" additional_minibatch.reverse()\n",
|
|
" minibatch.extend(additional_minibatch)\n",
|
|
" minibatches.append(minibatch)\n",
|
|
"\n",
|
|
" if end == len(sorted_data):\n",
|
|
" break\n",
|
|
" start = end\n",
|
|
"\n",
|
|
" # batch: List[List[Tuple[str, dict]]]\n",
|
|
" return minibatches\n",
|
|
"\n",
|
|
"\n",
|
|
"def batchfy_by_bin(\n",
|
|
" sorted_data,\n",
|
|
" batch_bins,\n",
|
|
" num_batches=0,\n",
|
|
" min_batch_size=1,\n",
|
|
" shortest_first=False,\n",
|
|
" ikey=\"input\",\n",
|
|
" okey=\"output\", ):\n",
|
|
" \"\"\"Make variably sized batch set, which maximizes\n",
|
|
"\n",
|
|
" the number of bins up to `batch_bins`.\n",
|
|
"\n",
|
|
" :param List[(str, Dict[str, Any])] sorted_data: dictionary loaded from data.json\n",
|
|
" :param int batch_bins: Maximum frames of a batch\n",
|
|
" :param int num_batches: # number of batches to use (for debug)\n",
|
|
" :param int min_batch_size: minimum batch size (for multi-gpu)\n",
|
|
" :param int test: Return only every `test` batches\n",
|
|
" :param bool shortest_first: Sort from batch with shortest samples\n",
|
|
" to longest if true, otherwise reverse\n",
|
|
"\n",
|
|
" :param str ikey: key to access input (for ASR ikey=\"input\", for TTS ikey=\"output\".)\n",
|
|
" :param str okey: key to access output (for ASR okey=\"output\". for TTS okey=\"input\".)\n",
|
|
"\n",
|
|
" :return: List[List[Tuple[str, Dict[str, List[Dict[str, Any]]]]]] list of batches\n",
|
|
" \"\"\"\n",
|
|
" if batch_bins <= 0:\n",
|
|
" raise ValueError(f\"invalid batch_bins={batch_bins}\")\n",
|
|
" length = len(sorted_data)\n",
|
|
" idim = int(sorted_data[0][1][ikey][0][\"shape\"][1])\n",
|
|
" odim = int(sorted_data[0][1][okey][0][\"shape\"][1])\n",
|
|
" logger.info(\"# utts: \" + str(len(sorted_data)))\n",
|
|
" minibatches = []\n",
|
|
" start = 0\n",
|
|
" n = 0\n",
|
|
" while True:\n",
|
|
" # Dynamic batch size depending on size of samples\n",
|
|
" b = 0\n",
|
|
" next_size = 0\n",
|
|
" max_olen = 0\n",
|
|
" while next_size < batch_bins and (start + b) < length:\n",
|
|
" ilen = int(sorted_data[start + b][1][ikey][0][\"shape\"][0]) * idim\n",
|
|
" olen = int(sorted_data[start + b][1][okey][0][\"shape\"][0]) * odim\n",
|
|
" if olen > max_olen:\n",
|
|
" max_olen = olen\n",
|
|
" next_size = (max_olen + ilen) * (b + 1)\n",
|
|
" if next_size <= batch_bins:\n",
|
|
" b += 1\n",
|
|
" elif next_size == 0:\n",
|
|
" raise ValueError(\n",
|
|
" f\"Can't fit one sample in batch_bins ({batch_bins}): \"\n",
|
|
" f\"Please increase the value\")\n",
|
|
" end = min(length, start + max(min_batch_size, b))\n",
|
|
" batch = sorted_data[start:end]\n",
|
|
" if shortest_first:\n",
|
|
" batch.reverse()\n",
|
|
" minibatches.append(batch)\n",
|
|
" # Check for min_batch_size and fixes the batches if needed\n",
|
|
" i = -1\n",
|
|
" while len(minibatches[i]) < min_batch_size:\n",
|
|
" missing = min_batch_size - len(minibatches[i])\n",
|
|
" if -i == len(minibatches):\n",
|
|
" minibatches[i + 1].extend(minibatches[i])\n",
|
|
" minibatches = minibatches[1:]\n",
|
|
" break\n",
|
|
" else:\n",
|
|
" minibatches[i].extend(minibatches[i - 1][:missing])\n",
|
|
" minibatches[i - 1] = minibatches[i - 1][missing:]\n",
|
|
" i -= 1\n",
|
|
" if end == length:\n",
|
|
" break\n",
|
|
" start = end\n",
|
|
" n += 1\n",
|
|
" if num_batches > 0:\n",
|
|
" minibatches = minibatches[:num_batches]\n",
|
|
" lengths = [len(x) for x in minibatches]\n",
|
|
" logger.info(\n",
|
|
" str(len(minibatches)) + \" batches containing from \" + str(min(lengths))\n",
|
|
" + \" to \" + str(max(lengths)) + \" samples \" + \"(avg \" + str(\n",
|
|
" int(np.mean(lengths))) + \" samples).\")\n",
|
|
" return minibatches\n",
|
|
"\n",
|
|
"\n",
|
|
"def batchfy_by_frame(\n",
|
|
" sorted_data,\n",
|
|
" max_frames_in,\n",
|
|
" max_frames_out,\n",
|
|
" max_frames_inout,\n",
|
|
" num_batches=0,\n",
|
|
" min_batch_size=1,\n",
|
|
" shortest_first=False,\n",
|
|
" ikey=\"input\",\n",
|
|
" okey=\"output\", ):\n",
|
|
" \"\"\"Make variable batch set, which maximizes the number of frames to max_batch_frame.\n",
|
|
"\n",
|
|
" :param List[(str, Dict[str, Any])] sorted_data: dictionary loaded from data.json\n",
|
|
" :param int max_frames_in: Maximum input frames of a batch\n",
|
|
" :param int max_frames_out: Maximum output frames of a batch\n",
|
|
" :param int max_frames_inout: Maximum input+output frames of a batch\n",
|
|
" :param int num_batches: # number of batches to use (for debug)\n",
|
|
" :param int min_batch_size: minimum batch size (for multi-gpu)\n",
|
|
" :param int test: Return only every `test` batches\n",
|
|
" :param bool shortest_first: Sort from batch with shortest samples\n",
|
|
" to longest if true, otherwise reverse\n",
|
|
"\n",
|
|
" :param str ikey: key to access input (for ASR ikey=\"input\", for TTS ikey=\"output\".)\n",
|
|
" :param str okey: key to access output (for ASR okey=\"output\". for TTS okey=\"input\".)\n",
|
|
"\n",
|
|
" :return: List[List[Tuple[str, Dict[str, List[Dict[str, Any]]]]]] list of batches\n",
|
|
" \"\"\"\n",
|
|
" if max_frames_in <= 0 and max_frames_out <= 0 and max_frames_inout <= 0:\n",
|
|
" raise ValueError(\n",
|
|
" \"At least, one of `--batch-frames-in`, `--batch-frames-out` or \"\n",
|
|
" \"`--batch-frames-inout` should be > 0\")\n",
|
|
" length = len(sorted_data)\n",
|
|
" minibatches = []\n",
|
|
" start = 0\n",
|
|
" end = 0\n",
|
|
" while end != length:\n",
|
|
" # Dynamic batch size depending on size of samples\n",
|
|
" b = 0\n",
|
|
" max_olen = 0\n",
|
|
" max_ilen = 0\n",
|
|
" while (start + b) < length:\n",
|
|
" ilen = int(sorted_data[start + b][1][ikey][0][\"shape\"][0])\n",
|
|
" if ilen > max_frames_in and max_frames_in != 0:\n",
|
|
" raise ValueError(\n",
|
|
" f\"Can't fit one sample in --batch-frames-in ({max_frames_in}): \"\n",
|
|
" f\"Please increase the value\")\n",
|
|
" olen = int(sorted_data[start + b][1][okey][0][\"shape\"][0])\n",
|
|
" if olen > max_frames_out and max_frames_out != 0:\n",
|
|
" raise ValueError(\n",
|
|
" f\"Can't fit one sample in --batch-frames-out ({max_frames_out}): \"\n",
|
|
" f\"Please increase the value\")\n",
|
|
" if ilen + olen > max_frames_inout and max_frames_inout != 0:\n",
|
|
" raise ValueError(\n",
|
|
" f\"Can't fit one sample in --batch-frames-out ({max_frames_inout}): \"\n",
|
|
" f\"Please increase the value\")\n",
|
|
" max_olen = max(max_olen, olen)\n",
|
|
" max_ilen = max(max_ilen, ilen)\n",
|
|
" in_ok = max_ilen * (b + 1) <= max_frames_in or max_frames_in == 0\n",
|
|
" out_ok = max_olen * (b + 1) <= max_frames_out or max_frames_out == 0\n",
|
|
" inout_ok = (max_ilen + max_olen) * (\n",
|
|
" b + 1) <= max_frames_inout or max_frames_inout == 0\n",
|
|
" if in_ok and out_ok and inout_ok:\n",
|
|
" # add more seq in the minibatch\n",
|
|
" b += 1\n",
|
|
" else:\n",
|
|
" # no more seq in the minibatch\n",
|
|
" break\n",
|
|
" end = min(length, start + b)\n",
|
|
" batch = sorted_data[start:end]\n",
|
|
" if shortest_first:\n",
|
|
" batch.reverse()\n",
|
|
" minibatches.append(batch)\n",
|
|
" # Check for min_batch_size and fixes the batches if needed\n",
|
|
" i = -1\n",
|
|
" while len(minibatches[i]) < min_batch_size:\n",
|
|
" missing = min_batch_size - len(minibatches[i])\n",
|
|
" if -i == len(minibatches):\n",
|
|
" minibatches[i + 1].extend(minibatches[i])\n",
|
|
" minibatches = minibatches[1:]\n",
|
|
" break\n",
|
|
" else:\n",
|
|
" minibatches[i].extend(minibatches[i - 1][:missing])\n",
|
|
" minibatches[i - 1] = minibatches[i - 1][missing:]\n",
|
|
" i -= 1\n",
|
|
" start = end\n",
|
|
" if num_batches > 0:\n",
|
|
" minibatches = minibatches[:num_batches]\n",
|
|
" lengths = [len(x) for x in minibatches]\n",
|
|
" logger.info(\n",
|
|
" str(len(minibatches)) + \" batches containing from \" + str(min(lengths))\n",
|
|
" + \" to \" + str(max(lengths)) + \" samples\" + \"(avg \" + str(\n",
|
|
" int(np.mean(lengths))) + \" samples).\")\n",
|
|
"\n",
|
|
" return minibatches\n",
|
|
"\n",
|
|
"\n",
|
|
"def batchfy_shuffle(data, batch_size, min_batch_size, num_batches,\n",
|
|
" shortest_first):\n",
|
|
" import random\n",
|
|
"\n",
|
|
" logger.info(\"use shuffled batch.\")\n",
|
|
" sorted_data = random.sample(data.items(), len(data.items()))\n",
|
|
" logger.info(\"# utts: \" + str(len(sorted_data)))\n",
|
|
" # make list of minibatches\n",
|
|
" minibatches = []\n",
|
|
" start = 0\n",
|
|
" while True:\n",
|
|
" end = min(len(sorted_data), start + batch_size)\n",
|
|
" # check each batch is more than minimum batchsize\n",
|
|
" minibatch = sorted_data[start:end]\n",
|
|
" if shortest_first:\n",
|
|
" minibatch.reverse()\n",
|
|
" if len(minibatch) < min_batch_size:\n",
|
|
" mod = min_batch_size - len(minibatch) % min_batch_size\n",
|
|
" additional_minibatch = [\n",
|
|
" sorted_data[i] for i in np.random.randint(0, start, mod)\n",
|
|
" ]\n",
|
|
" if shortest_first:\n",
|
|
" additional_minibatch.reverse()\n",
|
|
" minibatch.extend(additional_minibatch)\n",
|
|
" minibatches.append(minibatch)\n",
|
|
" if end == len(sorted_data):\n",
|
|
" break\n",
|
|
" start = end\n",
|
|
"\n",
|
|
" # for debugging\n",
|
|
" if num_batches > 0:\n",
|
|
" minibatches = minibatches[:num_batches]\n",
|
|
" logger.info(\"# minibatches: \" + str(len(minibatches)))\n",
|
|
" return minibatches\n",
|
|
"\n",
|
|
"\n",
|
|
"BATCH_COUNT_CHOICES = [\"auto\", \"seq\", \"bin\", \"frame\"]\n",
|
|
"BATCH_SORT_KEY_CHOICES = [\"input\", \"output\", \"shuffle\"]\n",
|
|
"\n",
|
|
"\n",
|
|
"def make_batchset(\n",
|
|
" data,\n",
|
|
" batch_size=0,\n",
|
|
" max_length_in=float(\"inf\"),\n",
|
|
" max_length_out=float(\"inf\"),\n",
|
|
" num_batches=0,\n",
|
|
" min_batch_size=1,\n",
|
|
" shortest_first=False,\n",
|
|
" batch_sort_key=\"input\",\n",
|
|
" count=\"auto\",\n",
|
|
" batch_bins=0,\n",
|
|
" batch_frames_in=0,\n",
|
|
" batch_frames_out=0,\n",
|
|
" batch_frames_inout=0,\n",
|
|
" iaxis=0,\n",
|
|
" oaxis=0, ):\n",
|
|
" \"\"\"Make batch set from json dictionary\n",
|
|
"\n",
|
|
" if utts have \"category\" value,\n",
|
|
"\n",
|
|
" >>> data = {'utt1': {'category': 'A', 'input': ...},\n",
|
|
" ... 'utt2': {'category': 'B', 'input': ...},\n",
|
|
" ... 'utt3': {'category': 'B', 'input': ...},\n",
|
|
" ... 'utt4': {'category': 'A', 'input': ...}}\n",
|
|
" >>> make_batchset(data, batch_size=2, ...)\n",
|
|
" [[('utt1', ...), ('utt4', ...)], [('utt2', ...), ('utt3', ...)]]\n",
|
|
"\n",
|
|
" Note that if any utts doesn't have \"category\",\n",
|
|
" perform as same as batchfy_by_{count}\n",
|
|
"\n",
|
|
" :param List[Dict[str, Any]] data: list of utterance dicts (each with an 'utt' key) loaded from data.json\n",
|
|
" :param int batch_size: maximum number of sequences in a minibatch.\n",
|
|
" :param int batch_bins: maximum number of bins (frames x dim) in a minibatch.\n",
|
|
" :param int batch_frames_in: maximum number of input frames in a minibatch.\n",
|
|
" :param int batch_frames_out: maximum number of output frames in a minibatch.\n",
|
|
" :param int batch_frames_inout: maximum number of input+output frames in a minibatch.\n",
|
|
" :param str count: strategy to count maximum size of batch.\n",
|
|
" For choices, see espnet.asr.batchfy.BATCH_COUNT_CHOICES\n",
|
|
"\n",
|
|
" :param int max_length_in: maximum length of input to decide adaptive batch size\n",
|
|
" :param int max_length_out: maximum length of output to decide adaptive batch size\n",
|
|
" :param int num_batches: # number of batches to use (for debug)\n",
|
|
" :param int min_batch_size: minimum batch size (for multi-gpu)\n",
|
|
" :param bool shortest_first: Sort from batch with shortest samples\n",
|
|
" to longest if true, otherwise reverse\n",
|
|
" :param str batch_sort_key: how to sort data before creating minibatches\n",
|
|
" [\"input\", \"output\", \"shuffle\"]\n",
|
|
" :param bool swap_io: if True, use \"input\" as output and \"output\"\n",
|
|
" as input in `data` dict\n",
|
|
" :param bool mt: if True, use 0-axis of \"output\" as output and 1-axis of \"output\"\n",
|
|
" as input in `data` dict\n",
|
|
" :param int iaxis: dimension to access input\n",
|
|
" (for ASR, TTS iaxis=0, for MT iaxis=1.)\n",
|
|
" :param int oaxis: dimension to access output (for ASR, TTS, MT oaxis=0,\n",
|
|
" reserved for future research, -1 means all axis.)\n",
|
|
" :return: List[List[Tuple[str, dict]]] list of batches\n",
|
|
" \"\"\"\n",
|
|
"\n",
|
|
" # check args\n",
|
|
" if count not in BATCH_COUNT_CHOICES:\n",
|
|
" raise ValueError(\n",
|
|
" f\"arg 'count' ({count}) should be one of {BATCH_COUNT_CHOICES}\")\n",
|
|
" if batch_sort_key not in BATCH_SORT_KEY_CHOICES:\n",
|
|
" raise ValueError(f\"arg 'batch_sort_key' ({batch_sort_key}) should be \"\n",
|
|
" f\"one of {BATCH_SORT_KEY_CHOICES}\")\n",
|
|
"\n",
|
|
" ikey = \"input\"\n",
|
|
" okey = \"output\"\n",
|
|
" batch_sort_axis = 0 # index of list \n",
|
|
"\n",
|
|
" if count == \"auto\":\n",
|
|
" if batch_size != 0:\n",
|
|
" count = \"seq\"\n",
|
|
" elif batch_bins != 0:\n",
|
|
" count = \"bin\"\n",
|
|
" elif batch_frames_in != 0 or batch_frames_out != 0 or batch_frames_inout != 0:\n",
|
|
" count = \"frame\"\n",
|
|
" else:\n",
|
|
" raise ValueError(\n",
|
|
" f\"cannot detect `count` manually set one of {BATCH_COUNT_CHOICES}\"\n",
|
|
" )\n",
|
|
" logger.info(f\"count is auto detected as {count}\")\n",
|
|
"\n",
|
|
" if count != \"seq\" and batch_sort_key == \"shuffle\":\n",
|
|
" raise ValueError(\n",
|
|
" \"batch_sort_key=shuffle is only available if batch_count=seq\")\n",
|
|
"\n",
|
|
" category2data = {} # Dict[str, dict]\n",
|
|
" for v in data:\n",
|
|
" k = v['utt']\n",
|
|
" category2data.setdefault(v.get(\"category\"), {})[k] = v\n",
|
|
"\n",
|
|
" batches_list = [] # List[List[List[Tuple[str, dict]]]]\n",
|
|
" for d in category2data.values():\n",
|
|
" if batch_sort_key == \"shuffle\":\n",
|
|
" batches = batchfy_shuffle(d, batch_size, min_batch_size,\n",
|
|
" num_batches, shortest_first)\n",
|
|
" batches_list.append(batches)\n",
|
|
" continue\n",
|
|
"\n",
|
|
" # sort it by input lengths (long to short)\n",
|
|
" sorted_data = sorted(\n",
|
|
" d.items(),\n",
|
|
" key=lambda data: int(data[1][batch_sort_key][batch_sort_axis][\"shape\"][0]),\n",
|
|
" reverse=not shortest_first, )\n",
|
|
" logger.info(\"# utts: \" + str(len(sorted_data)))\n",
|
|
" \n",
|
|
" if count == \"seq\":\n",
|
|
" batches = batchfy_by_seq(\n",
|
|
" sorted_data,\n",
|
|
" batch_size=batch_size,\n",
|
|
" max_length_in=max_length_in,\n",
|
|
" max_length_out=max_length_out,\n",
|
|
" min_batch_size=min_batch_size,\n",
|
|
" shortest_first=shortest_first,\n",
|
|
" ikey=ikey,\n",
|
|
" iaxis=iaxis,\n",
|
|
" okey=okey,\n",
|
|
" oaxis=oaxis, )\n",
|
|
" if count == \"bin\":\n",
|
|
" batches = batchfy_by_bin(\n",
|
|
" sorted_data,\n",
|
|
" batch_bins=batch_bins,\n",
|
|
" min_batch_size=min_batch_size,\n",
|
|
" shortest_first=shortest_first,\n",
|
|
" ikey=ikey,\n",
|
|
" okey=okey, )\n",
|
|
" if count == \"frame\":\n",
|
|
" batches = batchfy_by_frame(\n",
|
|
" sorted_data,\n",
|
|
" max_frames_in=batch_frames_in,\n",
|
|
" max_frames_out=batch_frames_out,\n",
|
|
" max_frames_inout=batch_frames_inout,\n",
|
|
" min_batch_size=min_batch_size,\n",
|
|
" shortest_first=shortest_first,\n",
|
|
" ikey=ikey,\n",
|
|
" okey=okey, )\n",
|
|
" batches_list.append(batches)\n",
|
|
"\n",
|
|
" if len(batches_list) == 1:\n",
|
|
" batches = batches_list[0]\n",
|
|
" else:\n",
|
|
" # Concat list. This way is faster than \"sum(batch_list, [])\"\n",
|
|
" batches = list(itertools.chain(*batches_list))\n",
|
|
"\n",
|
|
" # for debugging\n",
|
|
" if num_batches > 0:\n",
|
|
" batches = batches[:num_batches]\n",
|
|
" logger.info(\"# minibatches: \" + str(len(batches)))\n",
|
|
"\n",
|
|
" # batch: List[List[Tuple[str, dict]]]\n",
|
|
" return batches\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"id": "acquired-hurricane",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"555\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"batch_size=10\n",
|
|
"maxlen_in=300\n",
|
|
"maxlen_out=400\n",
|
|
"minibatches=0 # for debug\n",
|
|
"min_batch_size=2\n",
|
|
"use_sortagrad=True\n",
|
|
"batch_count='seq'\n",
|
|
"batch_bins=0\n",
|
|
"batch_frames_in=3000\n",
|
|
"batch_frames_out=0\n",
|
|
"batch_frames_inout=0\n",
|
|
" \n",
|
|
"dev_data = make_batchset(\n",
|
|
" dev_json,\n",
|
|
" batch_size,\n",
|
|
" maxlen_in,\n",
|
|
" maxlen_out,\n",
|
|
" minibatches, # for debug\n",
|
|
" min_batch_size=min_batch_size,\n",
|
|
" shortest_first=use_sortagrad,\n",
|
|
" batch_sort_key=\"shuffle\",\n",
|
|
" count=batch_count,\n",
|
|
" batch_bins=batch_bins,\n",
|
|
" batch_frames_in=batch_frames_in,\n",
|
|
" batch_frames_out=batch_frames_out,\n",
|
|
" batch_frames_inout=batch_frames_inout,\n",
|
|
" iaxis=0,\n",
|
|
" oaxis=0, )\n",
|
|
"print(len(dev_data))\n",
|
|
"# for i in range(len(dev_data)):\n",
|
|
"# print(len(dev_data[i]))\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"id": "warming-malpractice",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Collecting kaldiio\n",
|
|
" Downloading kaldiio-2.17.2.tar.gz (24 kB)\n",
|
|
"Requirement already satisfied: numpy in ./tools/venv/lib/python3.7/site-packages/numpy-1.21.2-py3.7-linux-x86_64.egg (from kaldiio) (1.21.2)\n",
|
|
"Building wheels for collected packages: kaldiio\n",
|
|
" Building wheel for kaldiio (setup.py) ... \u001b[?25ldone\n",
|
|
"\u001b[?25h Created wheel for kaldiio: filename=kaldiio-2.17.2-py3-none-any.whl size=24468 sha256=cd6e066764dcc8c24a9dfe3f7bd8acda18761a6fbcb024995729da8debdb466e\n",
|
|
" Stored in directory: /root/.cache/pip/wheels/04/07/e8/45641287c59bf6ce41e22259f8680b521c31e6306cb88392ac\n",
|
|
"Successfully built kaldiio\n",
|
|
"Installing collected packages: kaldiio\n",
|
|
"Successfully installed kaldiio-2.17.2\n",
|
|
"\u001b[33mWARNING: You are using pip version 20.3.3; however, version 21.2.4 is available.\n",
|
|
"You should consider upgrading via the '/workspace/zhanghui/DeepSpeech-2.x/tools/venv/bin/python -m pip install --upgrade pip' command.\u001b[0m\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"!pip install kaldiio"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "equipped-subject",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 19,
|
|
"id": "superb-methodology",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from collections import OrderedDict\n",
|
|
"import kaldiio\n",
|
|
"\n",
|
|
"class LoadInputsAndTargets():\n",
|
|
" \"\"\"Create a mini-batch from a list of dicts\n",
|
|
"\n",
|
|
" >>> batch = [('utt1',\n",
|
|
" ... dict(input=[dict(feat='some.ark:123',\n",
|
|
" ... filetype='mat',\n",
|
|
" ... name='input1',\n",
|
|
" ... shape=[100, 80])],\n",
|
|
" ... output=[dict(tokenid='1 2 3 4',\n",
|
|
" ... name='target1',\n",
|
|
" ... shape=[4, 31])]))]\n",
|
|
" >>> l = LoadInputsAndTargets()\n",
|
|
" >>> feat, target = l(batch)\n",
|
|
"\n",
|
|
" :param: str mode: Specify the task mode; only \"asr\" is accepted here\n",
|
|
" :param: str preprocess_conf: The path of a json file for pre-processing\n",
|
|
" :param: bool load_input: If False, not to load the input data\n",
|
|
" :param: bool load_output: If False, not to load the output data\n",
|
|
" :param: bool sort_in_input_length: Sort the mini-batch in descending order\n",
|
|
" of the input length\n",
|
|
" :param: bool use_speaker_embedding: Used for tts mode only\n",
|
|
" :param: bool use_second_target: Used for tts mode only\n",
|
|
" :param: dict preprocess_args: Set some optional arguments for preprocessing\n",
|
|
" :param: Optional[dict] preprocess_args: Used for tts mode only\n",
|
|
" \"\"\"\n",
|
|
"\n",
|
|
" def __init__(\n",
|
|
" self,\n",
|
|
" mode=\"asr\",\n",
|
|
" preprocess_conf=None,\n",
|
|
" load_input=True,\n",
|
|
" load_output=True,\n",
|
|
" sort_in_input_length=True,\n",
|
|
" preprocess_args=None,\n",
|
|
" keep_all_data_on_mem=False, ):\n",
|
|
" self._loaders = {}\n",
|
|
"\n",
|
|
" if mode not in [\"asr\"]:\n",
|
|
" raise ValueError(\"Only asr are allowed: mode={}\".format(mode))\n",
|
|
"\n",
|
|
" if preprocess_conf is not None:\n",
|
|
" self.preprocessing = AugmentationPipeline(preprocess_conf)\n",
|
|
" logging.warning(\n",
|
|
" \"[Experimental feature] Some preprocessing will be done \"\n",
|
|
" \"for the mini-batch creation using {}\".format(\n",
|
|
" self.preprocessing))\n",
|
|
" else:\n",
|
|
" # If conf doesn't exist, this function doesn't touch anything.\n",
|
|
" self.preprocessing = None\n",
|
|
"\n",
|
|
" self.mode = mode\n",
|
|
" self.load_output = load_output\n",
|
|
" self.load_input = load_input\n",
|
|
" self.sort_in_input_length = sort_in_input_length\n",
|
|
" if preprocess_args is None:\n",
|
|
" self.preprocess_args = {}\n",
|
|
" else:\n",
|
|
" assert isinstance(preprocess_args, dict), type(preprocess_args)\n",
|
|
" self.preprocess_args = dict(preprocess_args)\n",
|
|
"\n",
|
|
" self.keep_all_data_on_mem = keep_all_data_on_mem\n",
|
|
"\n",
|
|
" def __call__(self, batch, return_uttid=False):\n",
|
|
" \"\"\"Function to load inputs and targets from list of dicts\n",
|
|
"\n",
|
|
" :param List[Tuple[str, dict]] batch: list of dict which is subset of\n",
|
|
" loaded data.json\n",
|
|
" :param bool return_uttid: return utterance ID information for visualization\n",
|
|
" :return: list of input token id sequences [(L_1), (L_2), ..., (L_B)]\n",
|
|
" :return: list of input feature sequences\n",
|
|
" [(T_1, D), (T_2, D), ..., (T_B, D)]\n",
|
|
" :rtype: list of float ndarray\n",
|
|
" :return: list of target token id sequences [(L_1), (L_2), ..., (L_B)]\n",
|
|
" :rtype: list of int ndarray\n",
|
|
"\n",
|
|
" \"\"\"\n",
|
|
" x_feats_dict = OrderedDict() # OrderedDict[str, List[np.ndarray]]\n",
|
|
" y_feats_dict = OrderedDict() # OrderedDict[str, List[np.ndarray]]\n",
|
|
" uttid_list = [] # List[str]\n",
|
|
"\n",
|
|
" for uttid, info in batch:\n",
|
|
" uttid_list.append(uttid)\n",
|
|
"\n",
|
|
" if self.load_input:\n",
|
|
" # Note(kamo): This for-loop is for multiple inputs\n",
|
|
" for idx, inp in enumerate(info[\"input\"]):\n",
|
|
" # {\"input\":\n",
|
|
" # [{\"feat\": \"some/path.h5:F01_050C0101_PED_REAL\",\n",
|
|
" # \"filetype\": \"hdf5\",\n",
|
|
" # \"name\": \"input1\", ...}], ...}\n",
|
|
" x = self._get_from_loader(\n",
|
|
" filepath=inp[\"feat\"],\n",
|
|
" filetype=inp.get(\"filetype\", \"mat\"))\n",
|
|
" x_feats_dict.setdefault(inp[\"name\"], []).append(x)\n",
|
|
"\n",
|
|
" if self.load_output:\n",
|
|
" for idx, inp in enumerate(info[\"output\"]):\n",
|
|
" if \"tokenid\" in inp:\n",
|
|
" # ======= Legacy format for output =======\n",
|
|
" # {\"output\": [{\"tokenid\": \"1 2 3 4\"}])\n",
|
|
" x = np.fromiter(\n",
|
|
" map(int, inp[\"tokenid\"].split()), dtype=np.int64)\n",
|
|
" else:\n",
|
|
" # ======= New format =======\n",
|
|
"                        # {\"output\":\n",
|
|
" # [{\"feat\": \"some/path.h5:F01_050C0101_PED_REAL\",\n",
|
|
" # \"filetype\": \"hdf5\",\n",
|
|
" # \"name\": \"target1\", ...}], ...}\n",
|
|
" x = self._get_from_loader(\n",
|
|
" filepath=inp[\"feat\"],\n",
|
|
" filetype=inp.get(\"filetype\", \"mat\"))\n",
|
|
"\n",
|
|
" y_feats_dict.setdefault(inp[\"name\"], []).append(x)\n",
|
|
"\n",
|
|
" if self.mode == \"asr\":\n",
|
|
" return_batch, uttid_list = self._create_batch_asr(\n",
|
|
" x_feats_dict, y_feats_dict, uttid_list)\n",
|
|
" else:\n",
|
|
" raise NotImplementedError(self.mode)\n",
|
|
"\n",
|
|
" if self.preprocessing is not None:\n",
|
|
" # Apply pre-processing all input features\n",
|
|
" for x_name in return_batch.keys():\n",
|
|
" if x_name.startswith(\"input\"):\n",
|
|
" return_batch[x_name] = self.preprocessing(\n",
|
|
" return_batch[x_name], uttid_list,\n",
|
|
" **self.preprocess_args)\n",
|
|
"\n",
|
|
" if return_uttid:\n",
|
|
" return tuple(return_batch.values()), uttid_list\n",
|
|
"\n",
|
|
" # Doesn't return the names now.\n",
|
|
" return tuple(return_batch.values())\n",
|
|
"\n",
|
|
" def _create_batch_asr(self, x_feats_dict, y_feats_dict, uttid_list):\n",
|
|
"        \"\"\"Create an OrderedDict for the mini-batch\n",
|
|
"\n",
|
|
" :param OrderedDict x_feats_dict:\n",
|
|
" e.g. {\"input1\": [ndarray, ndarray, ...],\n",
|
|
" \"input2\": [ndarray, ndarray, ...]}\n",
|
|
" :param OrderedDict y_feats_dict:\n",
|
|
" e.g. {\"target1\": [ndarray, ndarray, ...],\n",
|
|
" \"target2\": [ndarray, ndarray, ...]}\n",
|
|
"        :param List[str] uttid_list:\n",
|
|
" Give uttid_list to sort in the same order as the mini-batch\n",
|
|
" :return: batch, uttid_list\n",
|
|
" :rtype: Tuple[OrderedDict, List[str]]\n",
|
|
" \"\"\"\n",
|
|
"        # handle single-input and multi-input (parallel) asr mode\n",
|
|
" xs = list(x_feats_dict.values())\n",
|
|
"\n",
|
|
" if self.load_output:\n",
|
|
" ys = list(y_feats_dict.values())\n",
|
|
" assert len(xs[0]) == len(ys[0]), (len(xs[0]), len(ys[0]))\n",
|
|
"\n",
|
|
" # get index of non-zero length samples\n",
|
|
" nonzero_idx = list(\n",
|
|
" filter(lambda i: len(ys[0][i]) > 0, range(len(ys[0]))))\n",
|
|
" for n in range(1, len(y_feats_dict)):\n",
|
|
" nonzero_idx = filter(lambda i: len(ys[n][i]) > 0, nonzero_idx)\n",
|
|
" else:\n",
|
|
"            # Note(kamo): Be careful not to turn nonzero_idx into a generator\n",
|
|
" nonzero_idx = list(range(len(xs[0])))\n",
|
|
"\n",
|
|
" if self.sort_in_input_length:\n",
|
|
" # sort in input lengths based on the first input\n",
|
|
" nonzero_sorted_idx = sorted(\n",
|
|
" nonzero_idx, key=lambda i: -len(xs[0][i]))\n",
|
|
" else:\n",
|
|
" nonzero_sorted_idx = nonzero_idx\n",
|
|
"\n",
|
|
" if len(nonzero_sorted_idx) != len(xs[0]):\n",
|
|
" logging.warning(\n",
|
|
" \"Target sequences include empty tokenid (batch {} -> {}).\".\n",
|
|
" format(len(xs[0]), len(nonzero_sorted_idx)))\n",
|
|
"\n",
|
|
" # remove zero-length samples\n",
|
|
" xs = [[x[i] for i in nonzero_sorted_idx] for x in xs]\n",
|
|
" uttid_list = [uttid_list[i] for i in nonzero_sorted_idx]\n",
|
|
"\n",
|
|
" x_names = list(x_feats_dict.keys())\n",
|
|
" if self.load_output:\n",
|
|
" ys = [[y[i] for i in nonzero_sorted_idx] for y in ys]\n",
|
|
" y_names = list(y_feats_dict.keys())\n",
|
|
"\n",
|
|
" # Keeping x_name and y_name, e.g. input1, for future extension\n",
|
|
" return_batch = OrderedDict([\n",
|
|
" * [(x_name, x) for x_name, x in zip(x_names, xs)],\n",
|
|
" * [(y_name, y) for y_name, y in zip(y_names, ys)],\n",
|
|
" ])\n",
|
|
" else:\n",
|
|
" return_batch = OrderedDict(\n",
|
|
" [(x_name, x) for x_name, x in zip(x_names, xs)])\n",
|
|
" return return_batch, uttid_list\n",
|
|
"\n",
|
|
" def _get_from_loader(self, filepath, filetype):\n",
|
|
" \"\"\"Return ndarray\n",
|
|
"\n",
|
|
"        So that file descriptors are opened only on first use,\n",
|
|
"        the loaders are cached in self._loaders\n",
|
|
"\n",
|
|
"        >>> ndarray = loader._get_from_loader(\n",
|
|
" ... 'some/path.h5:F01_050C0101_PED_REAL', filetype='hdf5')\n",
|
|
"\n",
|
|
"        :param str filepath:\n",
|
|
"        :param str filetype:\n",
|
|
" :return:\n",
|
|
" :rtype: np.ndarray\n",
|
|
" \"\"\"\n",
|
|
" if filetype == \"hdf5\":\n",
|
|
" # e.g.\n",
|
|
" # {\"input\": [{\"feat\": \"some/path.h5:F01_050C0101_PED_REAL\",\n",
|
|
" # \"filetype\": \"hdf5\",\n",
|
|
" # -> filepath = \"some/path.h5\", key = \"F01_050C0101_PED_REAL\"\n",
|
|
" filepath, key = filepath.split(\":\", 1)\n",
|
|
"\n",
|
|
" loader = self._loaders.get(filepath)\n",
|
|
" if loader is None:\n",
|
|
" # To avoid disk access, create loader only for the first time\n",
|
|
" loader = h5py.File(filepath, \"r\")\n",
|
|
" self._loaders[filepath] = loader\n",
|
|
" return loader[key][()]\n",
|
|
" elif filetype == \"sound.hdf5\":\n",
|
|
" # e.g.\n",
|
|
" # {\"input\": [{\"feat\": \"some/path.h5:F01_050C0101_PED_REAL\",\n",
|
|
" # \"filetype\": \"sound.hdf5\",\n",
|
|
" # -> filepath = \"some/path.h5\", key = \"F01_050C0101_PED_REAL\"\n",
|
|
" filepath, key = filepath.split(\":\", 1)\n",
|
|
"\n",
|
|
" loader = self._loaders.get(filepath)\n",
|
|
" if loader is None:\n",
|
|
" # To avoid disk access, create loader only for the first time\n",
|
|
" loader = SoundHDF5File(filepath, \"r\", dtype=\"int16\")\n",
|
|
" self._loaders[filepath] = loader\n",
|
|
" array, rate = loader[key]\n",
|
|
" return array\n",
|
|
" elif filetype == \"sound\":\n",
|
|
" # e.g.\n",
|
|
" # {\"input\": [{\"feat\": \"some/path.wav\",\n",
|
|
" # \"filetype\": \"sound\"},\n",
|
|
" # Assume PCM16\n",
|
|
" if not self.keep_all_data_on_mem:\n",
|
|
" array, _ = soundfile.read(filepath, dtype=\"int16\")\n",
|
|
" return array\n",
|
|
" if filepath not in self._loaders:\n",
|
|
" array, _ = soundfile.read(filepath, dtype=\"int16\")\n",
|
|
" self._loaders[filepath] = array\n",
|
|
" return self._loaders[filepath]\n",
|
|
" elif filetype == \"npz\":\n",
|
|
" # e.g.\n",
|
|
" # {\"input\": [{\"feat\": \"some/path.npz:F01_050C0101_PED_REAL\",\n",
|
|
" # \"filetype\": \"npz\",\n",
|
|
" filepath, key = filepath.split(\":\", 1)\n",
|
|
"\n",
|
|
" loader = self._loaders.get(filepath)\n",
|
|
" if loader is None:\n",
|
|
" # To avoid disk access, create loader only for the first time\n",
|
|
" loader = np.load(filepath)\n",
|
|
" self._loaders[filepath] = loader\n",
|
|
" return loader[key]\n",
|
|
" elif filetype == \"npy\":\n",
|
|
" # e.g.\n",
|
|
" # {\"input\": [{\"feat\": \"some/path.npy\",\n",
|
|
" # \"filetype\": \"npy\"},\n",
|
|
" if not self.keep_all_data_on_mem:\n",
|
|
" return np.load(filepath)\n",
|
|
" if filepath not in self._loaders:\n",
|
|
" self._loaders[filepath] = np.load(filepath)\n",
|
|
" return self._loaders[filepath]\n",
|
|
" elif filetype in [\"mat\", \"vec\"]:\n",
|
|
" # e.g.\n",
|
|
" # {\"input\": [{\"feat\": \"some/path.ark:123\",\n",
|
|
" # \"filetype\": \"mat\"}]},\n",
|
|
"            # In this case, \"123\" indicates the starting point of the matrix\n",
|
|
" # load_mat can load both matrix and vector\n",
|
|
" if not self.keep_all_data_on_mem:\n",
|
|
" return kaldiio.load_mat(filepath)\n",
|
|
" if filepath not in self._loaders:\n",
|
|
" self._loaders[filepath] = kaldiio.load_mat(filepath)\n",
|
|
" return self._loaders[filepath]\n",
|
|
" elif filetype == \"scp\":\n",
|
|
" # e.g.\n",
|
|
" # {\"input\": [{\"feat\": \"some/path.scp:F01_050C0101_PED_REAL\",\n",
|
|
" # \"filetype\": \"scp\",\n",
|
|
" filepath, key = filepath.split(\":\", 1)\n",
|
|
" loader = self._loaders.get(filepath)\n",
|
|
" if loader is None:\n",
|
|
" # To avoid disk access, create loader only for the first time\n",
|
|
" loader = kaldiio.load_scp(filepath)\n",
|
|
" self._loaders[filepath] = loader\n",
|
|
" return loader[key]\n",
|
|
" else:\n",
|
|
" raise NotImplementedError(\n",
|
|
" \"Not supported: loader_type={}\".format(filetype))\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 20,
|
|
"id": "monthly-muscle",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"preprocess_conf=None\n",
|
|
"train_mode=True\n",
|
|
"load = LoadInputsAndTargets(\n",
|
|
" mode=\"asr\",\n",
|
|
" load_output=True,\n",
|
|
" preprocess_conf=preprocess_conf,\n",
|
|
" preprocess_args={\"train\":\n",
|
|
" train_mode}, # Switch the mode of preprocessing\n",
|
|
" )"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 23,
|
|
"id": "periodic-senegal",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"res = load(dev_data[0])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 34,
|
|
"id": "7f0307eb",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"<class 'tuple'>\n",
|
|
"2\n",
|
|
"10\n",
|
|
"10\n",
|
|
"(1763, 83) float32\n",
|
|
"(73,) int64\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(type(res))\n",
|
|
"print(len(res))\n",
|
|
"print(len(res[0]))\n",
|
|
"print(len(res[1]))\n",
|
|
"print(res[0][0].shape, res[0][0].dtype)\n",
|
|
"print(res[1][0].shape, res[1][0].dtype)\n",
|
|
"# Tuple[Tuple[np.ndarray], Tuple[np.ndarray]]\n",
|
|
"# 2[10, 10]\n",
|
|
"# feats, labels"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 36,
|
|
"id": "humanitarian-container",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"(inputs, outputs), utts = load(dev_data[0], return_uttid=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 39,
|
|
"id": "heard-prize",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"['1673-143396-0008', '1650-173552-0000', '2803-154320-0000', '6267-65525-0045', '7641-96684-0029', '5338-284437-0010', '8173-294714-0033', '5543-27761-0047', '8254-115543-0043', '6467-94831-0038'] 10\n",
|
|
"10\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(utts, len(utts))\n",
|
|
"print(len(inputs))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 83,
|
|
"id": "convinced-animation",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import paddle\n",
|
|
"from deepspeech.io.utility import pad_list\n",
|
|
"class CustomConverter():\n",
|
|
" \"\"\"Custom batch converter.\n",
|
|
"\n",
|
|
" Args:\n",
|
|
" subsampling_factor (int): The subsampling factor.\n",
|
|
"        dtype (np.dtype): Data type the padded features are cast to.\n",
|
|
"\n",
|
|
" \"\"\"\n",
|
|
"\n",
|
|
" def __init__(self, subsampling_factor=1, dtype=np.float32):\n",
|
|
" \"\"\"Construct a CustomConverter object.\"\"\"\n",
|
|
" self.subsampling_factor = subsampling_factor\n",
|
|
" self.ignore_id = -1\n",
|
|
" self.dtype = dtype\n",
|
|
"\n",
|
|
" def __call__(self, batch):\n",
|
|
"        \"\"\"Transform a batch into padded arrays.\n",
|
|
"\n",
|
|
" Args:\n",
|
|
" batch (list): The batch to transform.\n",
|
|
"\n",
|
|
" Returns:\n",
|
|
"            tuple: (utts, xs_pad, ilens, ys_pad, olens)\n",
|
|
"\n",
|
|
" \"\"\"\n",
|
|
" # batch should be located in list\n",
|
|
" assert len(batch) == 1\n",
|
|
" (xs, ys), utts = batch[0]\n",
|
|
"\n",
|
|
" # perform subsampling\n",
|
|
" if self.subsampling_factor > 1:\n",
|
|
" xs = [x[::self.subsampling_factor, :] for x in xs]\n",
|
|
"\n",
|
|
" # get batch of lengths of input sequences\n",
|
|
" ilens = np.array([x.shape[0] for x in xs])\n",
|
|
"\n",
|
|
" # perform padding and convert to tensor\n",
|
|
" # currently only support real number\n",
|
|
" if xs[0].dtype.kind == \"c\":\n",
|
|
" xs_pad_real = pad_list([x.real for x in xs], 0).astype(self.dtype)\n",
|
|
" xs_pad_imag = pad_list([x.imag for x in xs], 0).astype(self.dtype)\n",
|
|
" # Note(kamo):\n",
|
|
" # {'real': ..., 'imag': ...} will be changed to ComplexTensor in E2E.\n",
|
|
" # Don't create ComplexTensor and give it E2E here\n",
|
|
"            # because torch.nn.DataParallel can't handle it.\n",
|
|
" xs_pad = {\"real\": xs_pad_real, \"imag\": xs_pad_imag}\n",
|
|
" else:\n",
|
|
" xs_pad = pad_list(xs, 0).astype(self.dtype)\n",
|
|
"\n",
|
|
" # NOTE: this is for multi-output (e.g., speech translation)\n",
|
|
" ys_pad = pad_list(\n",
|
|
" [np.array(y[0][:]) if isinstance(y, tuple) else y for y in ys],\n",
|
|
" self.ignore_id)\n",
|
|
"\n",
|
|
" olens = np.array([y[0].shape[0] if isinstance(y, tuple) else y.shape[0] for y in ys])\n",
|
|
" return utts, xs_pad, ilens, ys_pad, olens"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 84,
|
|
"id": "1b6508fc",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"convert = CustomConverter()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 85,
|
|
"id": "25d655c0",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"utts, xs, ilen, ys, olen = convert([load(dev_data[0], return_uttid=True)])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 87,
|
|
"id": "a28e5141",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"['1673-143396-0008', '1650-173552-0000', '2803-154320-0000', '6267-65525-0045', '7641-96684-0029', '5338-284437-0010', '8173-294714-0033', '5543-27761-0047', '8254-115543-0043', '6467-94831-0038']\n",
|
|
"(10, 1763, 83)\n",
|
|
"(10,)\n",
|
|
"[1763 1214 1146 757 751 661 625 512 426 329]\n",
|
|
"(10, 73)\n",
|
|
"[[2896 621 4502 2176 404 198 3538 391 278 407 389 3719 4577 846\n",
|
|
" 4501 482 1004 103 116 178 4222 624 4689 176 459 89 101 3465\n",
|
|
" 3204 4502 2029 1834 2298 829 3366 278 4705 4925 482 2920 3204 2481\n",
|
|
" 448 627 1254 404 20 202 36 2047 627 2495 4504 481 479 99\n",
|
|
" 18 2079 4502 1628 202 226 4512 3267 210 278 483 234 367 4502\n",
|
|
" 2438 3204 1141]\n",
|
|
" [ 742 4501 4768 4569 742 4483 2495 4502 3040 3204 4502 3961 3204 3992\n",
|
|
" 3089 4832 4258 621 2391 4642 3218 4502 3439 235 270 313 2385 2833\n",
|
|
" 742 4502 3282 332 3 280 4237 3252 830 2387 -1 -1 -1 -1\n",
|
|
" -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
|
|
" -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
|
|
" -1 -1 -1]\n",
|
|
" [2099 278 4904 2302 124 4832 3158 482 2888 2495 482 2450 627 1560\n",
|
|
" 3158 4729 482 3514 3204 1027 3233 2391 2862 399 389 4962 2495 121\n",
|
|
" 221 7 2340 1216 1658 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
|
|
" -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
|
|
" -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
|
|
" -1 -1 -1]\n",
|
|
" [2458 2659 1362 2 404 4975 4995 487 3079 2785 2371 3158 824 2603\n",
|
|
" 4832 2323 999 2603 4832 4156 4678 627 1784 -1 -1 -1 -1 -1\n",
|
|
" -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
|
|
" -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
|
|
" -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
|
|
" -1 -1 -1]\n",
|
|
" [2458 2340 1661 101 4723 2138 4502 4690 463 332 251 2345 4534 4502\n",
|
|
" 2396 444 4501 2287 389 4531 4894 1466 959 389 1658 2584 4502 3681\n",
|
|
" 279 3204 4502 2228 3204 4502 4690 463 332 251 -1 -1 -1 -1\n",
|
|
" -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
|
|
" -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
|
|
" -1 -1 -1]\n",
|
|
" [2368 1248 208 4832 3158 482 1473 3401 999 482 4159 3838 389 478\n",
|
|
" 4572 404 3158 3063 1481 113 4499 4501 3204 4643 2 389 4111 -1\n",
|
|
" -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
|
|
" -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
|
|
" -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
|
|
" -1 -1 -1]\n",
|
|
" [2882 2932 4329 1808 4577 4350 4577 482 1636 2 389 1841 3204 3079\n",
|
|
" 1091 389 3204 2816 2079 4172 4986 4990 -1 -1 -1 -1 -1 -1\n",
|
|
" -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
|
|
" -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
|
|
" -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
|
|
" -1 -1 -1]\n",
|
|
" [4869 2598 2603 1976 96 389 478 3 4031 721 4925 2263 1259 2598\n",
|
|
" 4508 653 4979 4925 2741 252 72 236 -1 -1 -1 -1 -1 -1\n",
|
|
" -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
|
|
" -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
|
|
" -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
|
|
" -1 -1 -1]\n",
|
|
" [2458 4447 4505 713 624 3207 206 4577 4502 2404 3837 3458 2812 4936\n",
|
|
" -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
|
|
" -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
|
|
" -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
|
|
" -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
|
|
" -1 -1 -1]\n",
|
|
" [1501 3897 2537 278 2601 2 404 2603 482 2235 3388 -1 -1 -1\n",
|
|
" -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
|
|
" -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
|
|
" -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
|
|
" -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
|
|
" -1 -1 -1]]\n",
|
|
"[73 38 33 23 38 27 22 22 14 11]\n",
|
|
"float32\n",
|
|
"int64\n",
|
|
"int64\n",
|
|
"int64\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(utts)\n",
|
|
"print(xs.shape)\n",
|
|
"print(ilen.shape)\n",
|
|
"print(ilen)\n",
|
|
"print(ys.shape)\n",
|
|
"print(ys)\n",
|
|
"print(olen)\n",
|
|
"print(xs.dtype)\n",
|
|
"print(ilen.dtype)\n",
|
|
"print(ys.dtype)\n",
|
|
"print(olen.dtype)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "1d981df4",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.0"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|