{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 147,
   "id": "extensive-venice",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'/'"
      ]
     },
     "execution_count": 147,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Move to the parent directory. The recorded output shows this lands on\n",
    "# '/'; since the path is relative, re-running the cell moves up again.\n",
    "%cd ..\n",
    "%pwd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 148,
   "id": "correct-window",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "manifest.dev\t  manifest.test-clean\t   manifest.train\r\n",
      "manifest.dev.raw  manifest.test-clean.raw  manifest.train.raw\r\n"
     ]
    }
   ],
   "source": [
    "# List the manifest files of the LibriSpeech s2 recipe.\n",
    "# NOTE(review): hardcoded absolute path; adjust for your environment.\n",
    "!ls /workspace/zhanghui/DeepSpeech-2.x/examples/librispeech/s2/data/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 149,
   "id": "exceptional-cheese",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path of the dev-set manifest, loaded with read_manifest below.\n",
    "# NOTE(review): hardcoded absolute path; adjust for your environment.\n",
    "dev_data = ('/workspace/zhanghui/DeepSpeech-2.x/examples/librispeech/s2/data/'\n",
    "            'manifest.dev')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 150,
   "id": "extraordinary-orleans",
   "metadata": {},
   "outputs": [],
   "source": [
    "from deepspeech.frontend.utility import read_manifest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 151,
   "id": "returning-lighter",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the dev manifest. The next cell shows the result is a list with one\n",
    "# dict per utterance (keys such as 'input', 'output', 'utt').\n",
    "dev_json = read_manifest(dev_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 152,
   "id": "western-founder",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'input': [{'feat': '/workspace/zhanghui/asr/espnet/egs/librispeech/asr1/dump/dev/deltafalse/feats.1.ark:16',\n",
      "            'name': 'input1',\n",
      "            'shape': [1063, 83]}],\n",
      " 'output': [{'name': 'target1',\n",
      "             'shape': [41, 5002],\n",
      "             'text': 'AS I APPROACHED THE CITY I HEARD BELLS RINGING AND A '\n",
      "                     'LITTLE LATER I FOUND THE STREETS ASTIR WITH THRONGS OF '\n",
      "                     'WELL DRESSED PEOPLE IN FAMILY GROUPS WENDING THEIR WAY '\n",
      "                     'HITHER AND THITHER',\n",
      "             'token': '▁AS ▁I ▁APPROACHED ▁THE ▁CITY ▁I ▁HEARD ▁BELL S ▁RING '\n",
      "                      'ING ▁AND ▁A ▁LITTLE ▁LATER ▁I ▁FOUND ▁THE ▁STREETS ▁AS '\n",
      "                      'T IR ▁WITH ▁THRONG S ▁OF ▁WELL ▁DRESSED ▁PEOPLE ▁IN '\n",
      "                      '▁FAMILY ▁GROUP S ▁WE ND ING ▁THEIR ▁WAY ▁HITHER ▁AND '\n",
      "                      '▁THITHER',\n",
      "             'tokenid': '713 2458 676 4502 1155 2458 2351 849 389 3831 206 627 '\n",
      "                        '482 2812 2728 2458 2104 4502 4316 713 404 212 4925 '\n",
      "                        '4549 389 3204 4861 1677 3339 2495 1950 2279 389 4845 '\n",
      "                        '302 206 4504 4843 2394 627 4526'}],\n",
      " 'utt': '116-288045-0000',\n",
      " 'utt2spk': '116-288045'}\n",
      "5542\n",
      "<class 'list'>\n"
     ]
    }
   ],
   "source": [
    "from pprint import pprint\n",
    "\n",
    "# First manifest entry, then the number of entries and the container type.\n",
    "pprint(dev_json[0])\n",
    "print(len(dev_json), type(dev_json), sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "id": "motivated-receptor",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "import itertools\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from deepspeech.utils.log import Log\n",
    "\n",
    "# Only make_batchset is exported by `from <module> import *`; the\n",
    "# batchfy_* helpers are implementation details.\n",
    "__all__ = [\"make_batchset\"]\n",
    "\n",
    "# Module-level logger shared by the batching helpers in this cell.\n",
    "logger = Log(__name__).getlog()\n",
    "\n",
    "\n",
    "def batchfy_by_seq(\n",
    "        sorted_data,\n",
    "        batch_size,\n",
    "        max_length_in,\n",
    "        max_length_out,\n",
    "        min_batch_size=1,\n",
    "        shortest_first=False,\n",
    "        ikey=\"input\",\n",
    "        iaxis=0,\n",
    "        okey=\"output\",\n",
    "        oaxis=0, ):\n",
    "    \"\"\"Split length-sorted utterances into minibatches of adaptive size.\n",
    "\n",
    "    For each minibatch the lengths of its first utterance are compared with\n",
    "    ``max_length_in`` / ``max_length_out``: the nominal ``batch_size`` is\n",
    "    divided by ``1 + max(int(ilen / max_length_in), int(olen / max_length_out))``\n",
    "    and never drops below ``min_batch_size`` (e.g. ilen=1000 with\n",
    "    max_length_in=800 halves it). Only the last minibatch can come out shorter\n",
    "    than ``min_batch_size``; it is padded with utterances drawn via\n",
    "    ``np.random.randint`` from the part of ``sorted_data`` already batched.\n",
    "\n",
    "    :param List[Tuple[str, Dict[str, Any]]] sorted_data: (utt_id, info) pairs,\n",
    "        already sorted by length\n",
    "    :param int batch_size: nominal batch size, must be > 0\n",
    "    :param max_length_in: input length threshold for shrinking the batch\n",
    "    :param max_length_out: output length threshold for shrinking the batch\n",
    "    :param int min_batch_size: minimum batch size (for multi-gpu)\n",
    "    :param bool shortest_first: if True, reverse the order inside each minibatch\n",
    "    :param str ikey: key to access input\n",
    "        (for ASR ikey=\"input\", for TTS, MT ikey=\"output\".)\n",
    "    :param int iaxis: index of the input entry inside ``info[ikey]``\n",
    "        (for ASR, TTS iaxis=0, for MT iaxis=1.)\n",
    "    :param str okey: key to access output\n",
    "        (for ASR, MT okey=\"output\". for TTS okey=\"input\".)\n",
    "    :param int oaxis: index of the output entry inside ``info[okey]``; a\n",
    "        negative value uses the longest of all output entries\n",
    "    :return: List[List[Tuple[str, dict]]] list of minibatches\n",
    "    :raises ValueError: if ``batch_size <= 0`` or ``sorted_data`` holds fewer\n",
    "        than ``min_batch_size`` utterances\n",
    "    \"\"\"\n",
    "    if batch_size <= 0:\n",
    "        raise ValueError(f\"Invalid batch_size={batch_size}\")\n",
    "\n",
    "    num_utts = len(sorted_data)\n",
    "    # check #utts is more than min_batch_size\n",
    "    if num_utts < min_batch_size:\n",
    "        raise ValueError(\n",
    "            f\"#utts({num_utts}) is less than min_batch_size({min_batch_size}).\"\n",
    "        )\n",
    "\n",
    "    def _lengths(info):\n",
    "        # (input length, output length) of one utterance\n",
    "        in_len = int(info[ikey][iaxis][\"shape\"][0])\n",
    "        if oaxis >= 0:\n",
    "            out_len = int(info[okey][oaxis][\"shape\"][0])\n",
    "        else:\n",
    "            out_len = max(int(entry[\"shape\"][0]) for entry in info[okey])\n",
    "        return in_len, out_len\n",
    "\n",
    "    minibatches = []\n",
    "    begin = 0\n",
    "    while True:\n",
    "        _, head_info = sorted_data[begin]\n",
    "        in_len, out_len = _lengths(head_info)\n",
    "        shrink = max(int(in_len / max_length_in), int(out_len / max_length_out))\n",
    "        size = max(min_batch_size, int(batch_size / (1 + shrink)))\n",
    "        stop = min(num_utts, begin + size)\n",
    "\n",
    "        chunk = sorted_data[begin:stop]\n",
    "        if shortest_first:\n",
    "            chunk.reverse()\n",
    "\n",
    "        if len(chunk) < min_batch_size:\n",
    "            # top up with utterances that were already batched\n",
    "            n_extra = min_batch_size - len(chunk) % min_batch_size\n",
    "            extra = [\n",
    "                sorted_data[j] for j in np.random.randint(0, begin, n_extra)\n",
    "            ]\n",
    "            if shortest_first:\n",
    "                extra.reverse()\n",
    "            chunk.extend(extra)\n",
    "        minibatches.append(chunk)\n",
    "\n",
    "        if stop == num_utts:\n",
    "            # batch: List[List[Tuple[str, dict]]]\n",
    "            return minibatches\n",
    "        begin = stop\n",
    "\n",
    "\n",
    "def batchfy_by_bin(\n",
    "        sorted_data,\n",
    "        batch_bins,\n",
    "        num_batches=0,\n",
    "        min_batch_size=1,\n",
    "        shortest_first=False,\n",
    "        ikey=\"input\",\n",
    "        okey=\"output\", ):\n",
    "    \"\"\"Make a variably sized batch set, packing up to `batch_bins` bins per batch.\n",
    "\n",
    "    Sizes are measured in bins (frames x dim, with dims read from the first\n",
    "    sample). A batch keeps growing while\n",
    "    ``(largest output size so far + input size of the newest sample) * n_samples``\n",
    "    stays within ``batch_bins``. Each batch still holds at least\n",
    "    ``max(1, min_batch_size)`` samples (a warning is logged when a single\n",
    "    sample already exceeds ``batch_bins``).\n",
    "\n",
    "    :param List[Tuple[str, Dict[str, Any]]] sorted_data: (utt_id, info) pairs\n",
    "        sorted by length\n",
    "    :param int batch_bins: Maximum bins of a batch\n",
    "    :param int num_batches: # number of batches to use (for debug)\n",
    "    :param int min_batch_size: minimum batch size (for multi-gpu)\n",
    "    :param bool shortest_first: if True, reverse the order inside each minibatch\n",
    "    :param str ikey: key to access input (for ASR ikey=\"input\", for TTS ikey=\"output\".)\n",
    "    :param str okey: key to access output (for ASR okey=\"output\". for TTS okey=\"input\".)\n",
    "    :return: List[List[Tuple[str, dict]]] list of batches\n",
    "    :raises ValueError: if ``batch_bins <= 0``, ``sorted_data`` is empty, or it\n",
    "        holds fewer than ``min_batch_size`` samples\n",
    "    \"\"\"\n",
    "    if batch_bins <= 0:\n",
    "        raise ValueError(f\"invalid batch_bins={batch_bins}\")\n",
    "    length = len(sorted_data)\n",
    "    if length == 0:\n",
    "        raise ValueError(\"sorted_data is empty\")\n",
    "    # Without this check a lone undersized batch is merged into itself and\n",
    "    # dropped by the fix-up loop below, and min() then fails on an empty list.\n",
    "    if length < min_batch_size:\n",
    "        raise ValueError(\n",
    "            f\"#utts({length}) is less than min_batch_size({min_batch_size}).\")\n",
    "    idim = int(sorted_data[0][1][ikey][0][\"shape\"][1])\n",
    "    odim = int(sorted_data[0][1][okey][0][\"shape\"][1])\n",
    "    logger.info(\"# utts: \" + str(len(sorted_data)))\n",
    "    minibatches = []\n",
    "    start = 0\n",
    "    while True:\n",
    "        # Dynamic batch size depending on size of samples\n",
    "        b = 0\n",
    "        next_size = 0\n",
    "        max_olen = 0\n",
    "        while next_size < batch_bins and (start + b) < length:\n",
    "            ilen = int(sorted_data[start + b][1][ikey][0][\"shape\"][0]) * idim\n",
    "            olen = int(sorted_data[start + b][1][okey][0][\"shape\"][0]) * odim\n",
    "            if olen > max_olen:\n",
    "                max_olen = olen\n",
    "            next_size = (max_olen + ilen) * (b + 1)\n",
    "            if next_size <= batch_bins:\n",
    "                b += 1\n",
    "            elif b == 0:\n",
    "                # The first sample alone exceeds batch_bins (the former\n",
    "                # `next_size == 0` test here could never be true). The sample\n",
    "                # is still batched below, so only report it.\n",
    "                logger.warning(\n",
    "                    f\"Can't fit one sample in batch_bins ({batch_bins}): \"\n",
    "                    f\"Please increase the value\")\n",
    "        # Take at least one sample so the outer loop always advances.\n",
    "        end = min(length, start + max(1, min_batch_size, b))\n",
    "        batch = sorted_data[start:end]\n",
    "        if shortest_first:\n",
    "            batch.reverse()\n",
    "        minibatches.append(batch)\n",
    "        # Check for min_batch_size and fix the batches if needed: borrow the\n",
    "        # missing samples from the front of the previous batch, walking further\n",
    "        # back while that leaves a batch short; at the first batch, merge it\n",
    "        # into the second one instead.\n",
    "        i = -1\n",
    "        while len(minibatches[i]) < min_batch_size:\n",
    "            missing = min_batch_size - len(minibatches[i])\n",
    "            if -i == len(minibatches):\n",
    "                minibatches[i + 1].extend(minibatches[i])\n",
    "                minibatches = minibatches[1:]\n",
    "                break\n",
    "            else:\n",
    "                minibatches[i].extend(minibatches[i - 1][:missing])\n",
    "                minibatches[i - 1] = minibatches[i - 1][missing:]\n",
    "                i -= 1\n",
    "        if end == length:\n",
    "            break\n",
    "        start = end\n",
    "    if num_batches > 0:\n",
    "        minibatches = minibatches[:num_batches]\n",
    "    lengths = [len(x) for x in minibatches]\n",
    "    logger.info(\n",
    "        str(len(minibatches)) + \" batches containing from \" + str(min(lengths))\n",
    "        + \" to \" + str(max(lengths)) + \" samples \" + \"(avg \" + str(\n",
    "            int(np.mean(lengths))) + \" samples).\")\n",
    "    return minibatches\n",
    "\n",
    "\n",
    "def batchfy_by_frame(\n",
    "        sorted_data,\n",
    "        max_frames_in,\n",
    "        max_frames_out,\n",
    "        max_frames_inout,\n",
    "        num_batches=0,\n",
    "        min_batch_size=1,\n",
    "        shortest_first=False,\n",
    "        ikey=\"input\",\n",
    "        okey=\"output\", ):\n",
    "    \"\"\"Make a variable batch set that packs as many frames as the limits allow.\n",
    "\n",
    "    Samples are added to the current batch while the padded frame counts\n",
    "    ``max_ilen * n``, ``max_olen * n`` and ``(max_ilen + max_olen) * n`` stay\n",
    "    within their limits; a limit of 0 disables that check.\n",
    "\n",
    "    :param List[Tuple[str, Dict[str, Any]]] sorted_data: (utt_id, info) pairs\n",
    "        sorted by length\n",
    "    :param int max_frames_in: Maximum input frames of a batch (0: no limit)\n",
    "    :param int max_frames_out: Maximum output frames of a batch (0: no limit)\n",
    "    :param int max_frames_inout: Maximum input+output frames of a batch\n",
    "        (0: no limit)\n",
    "    :param int num_batches: # number of batches to use (for debug)\n",
    "    :param int min_batch_size: minimum batch size (for multi-gpu)\n",
    "    :param bool shortest_first: if True, reverse the order inside each minibatch\n",
    "    :param str ikey: key to access input (for ASR ikey=\"input\", for TTS ikey=\"output\".)\n",
    "    :param str okey: key to access output (for ASR okey=\"output\". for TTS okey=\"input\".)\n",
    "    :return: List[List[Tuple[str, dict]]] list of batches\n",
    "    :raises ValueError: if every limit is <= 0, a single sample exceeds a\n",
    "        limit, ``sorted_data`` is empty, or it holds fewer than\n",
    "        ``min_batch_size`` samples\n",
    "    \"\"\"\n",
    "    if max_frames_in <= 0 and max_frames_out <= 0 and max_frames_inout <= 0:\n",
    "        raise ValueError(\n",
    "            \"At least, one of `--batch-frames-in`, `--batch-frames-out` or \"\n",
    "            \"`--batch-frames-inout` should be > 0\")\n",
    "    length = len(sorted_data)\n",
    "    if length == 0:\n",
    "        raise ValueError(\"sorted_data is empty\")\n",
    "    # Without this check a lone undersized batch is merged into itself and\n",
    "    # dropped by the fix-up loop below, and min() then fails on an empty list.\n",
    "    if length < min_batch_size:\n",
    "        raise ValueError(\n",
    "            f\"#utts({length}) is less than min_batch_size({min_batch_size}).\")\n",
    "    minibatches = []\n",
    "    start = 0\n",
    "    end = 0\n",
    "    while end != length:\n",
    "        # Dynamic batch size depending on size of samples\n",
    "        b = 0\n",
    "        max_olen = 0\n",
    "        max_ilen = 0\n",
    "        while (start + b) < length:\n",
    "            ilen = int(sorted_data[start + b][1][ikey][0][\"shape\"][0])\n",
    "            if ilen > max_frames_in and max_frames_in != 0:\n",
    "                raise ValueError(\n",
    "                    f\"Can't fit one sample in --batch-frames-in ({max_frames_in}): \"\n",
    "                    f\"Please increase the value\")\n",
    "            olen = int(sorted_data[start + b][1][okey][0][\"shape\"][0])\n",
    "            if olen > max_frames_out and max_frames_out != 0:\n",
    "                raise ValueError(\n",
    "                    f\"Can't fit one sample in --batch-frames-out ({max_frames_out}): \"\n",
    "                    f\"Please increase the value\")\n",
    "            if ilen + olen > max_frames_inout and max_frames_inout != 0:\n",
    "                raise ValueError(\n",
    "                    f\"Can't fit one sample in --batch-frames-inout ({max_frames_inout}): \"\n",
    "                    f\"Please increase the value\")\n",
    "            max_olen = max(max_olen, olen)\n",
    "            max_ilen = max(max_ilen, ilen)\n",
    "            in_ok = max_ilen * (b + 1) <= max_frames_in or max_frames_in == 0\n",
    "            out_ok = max_olen * (b + 1) <= max_frames_out or max_frames_out == 0\n",
    "            inout_ok = (max_ilen + max_olen) * (\n",
    "                b + 1) <= max_frames_inout or max_frames_inout == 0\n",
    "            if in_ok and out_ok and inout_ok:\n",
    "                # add more seq in the minibatch\n",
    "                b += 1\n",
    "            else:\n",
    "                # no more seq in the minibatch\n",
    "                break\n",
    "        # b >= 1 here: a first sample exceeding any limit raised above\n",
    "        end = min(length, start + b)\n",
    "        batch = sorted_data[start:end]\n",
    "        if shortest_first:\n",
    "            batch.reverse()\n",
    "        minibatches.append(batch)\n",
    "        # Check for min_batch_size and fix the batches if needed: borrow the\n",
    "        # missing samples from the front of the previous batch, walking further\n",
    "        # back while that leaves a batch short; at the first batch, merge it\n",
    "        # into the second one instead.\n",
    "        i = -1\n",
    "        while len(minibatches[i]) < min_batch_size:\n",
    "            missing = min_batch_size - len(minibatches[i])\n",
    "            if -i == len(minibatches):\n",
    "                minibatches[i + 1].extend(minibatches[i])\n",
    "                minibatches = minibatches[1:]\n",
    "                break\n",
    "            else:\n",
    "                minibatches[i].extend(minibatches[i - 1][:missing])\n",
    "                minibatches[i - 1] = minibatches[i - 1][missing:]\n",
    "                i -= 1\n",
    "        start = end\n",
    "    if num_batches > 0:\n",
    "        minibatches = minibatches[:num_batches]\n",
    "    lengths = [len(x) for x in minibatches]\n",
    "    logger.info(\n",
    "        str(len(minibatches)) + \" batches containing from \" + str(min(lengths))\n",
    "        + \" to \" + str(max(lengths)) + \" samples \" + \"(avg \" + str(\n",
    "            int(np.mean(lengths))) + \" samples).\")\n",
    "\n",
    "    return minibatches\n",
    "\n",
    "\n",
    "def batchfy_shuffle(data, batch_size, min_batch_size, num_batches,\n",
    "                    shortest_first):\n",
    "    \"\"\"Shuffle utterances at random and cut them into fixed-size minibatches.\n",
    "\n",
    "    :param Dict[str, dict] data: utt_id -> info mapping\n",
    "    :param int batch_size: number of utterances per minibatch, must be > 0 and\n",
    "        >= ``min_batch_size``\n",
    "    :param int min_batch_size: minimum batch size; a short last minibatch is\n",
    "        padded with utterances drawn via ``np.random.randint`` from earlier ones\n",
    "    :param int num_batches: if > 0, keep only the first ``num_batches`` (debug)\n",
    "    :param bool shortest_first: if True, reverse the order inside each minibatch\n",
    "    :return: List[List[Tuple[str, dict]]] list of minibatches\n",
    "    :raises ValueError: if ``batch_size`` is <= 0 or below ``min_batch_size``,\n",
    "        or ``data`` holds fewer than ``min_batch_size`` utterances\n",
    "    \"\"\"\n",
    "    import random\n",
    "\n",
    "    if batch_size <= 0:\n",
    "        raise ValueError(f\"Invalid batch_size={batch_size}\")\n",
    "    # The two cases below would otherwise fail inside\n",
    "    # np.random.randint(0, 0, ...) when the very first minibatch needs padding.\n",
    "    if batch_size < min_batch_size:\n",
    "        raise ValueError(\n",
    "            f\"batch_size({batch_size}) is less than min_batch_size({min_batch_size}).\")\n",
    "    if len(data) < min_batch_size:\n",
    "        raise ValueError(\n",
    "            f\"#utts({len(data)}) is less than min_batch_size({min_batch_size}).\")\n",
    "\n",
    "    logger.info(\"use shuffled batch.\")\n",
    "    # random.sample() needs a sequence: a dict_items view is a set-like\n",
    "    # population, deprecated since Python 3.9 and a TypeError since 3.11.\n",
    "    items = list(data.items())\n",
    "    sorted_data = random.sample(items, len(items))\n",
    "    logger.info(\"# utts: \" + str(len(sorted_data)))\n",
    "    # make list of minibatches\n",
    "    minibatches = []\n",
    "    start = 0\n",
    "    while True:\n",
    "        end = min(len(sorted_data), start + batch_size)\n",
    "        minibatch = sorted_data[start:end]\n",
    "        if shortest_first:\n",
    "            minibatch.reverse()\n",
    "        # check each batch is more than minimum batchsize; after the checks\n",
    "        # above only the last one can be short, so start > 0 here\n",
    "        if len(minibatch) < min_batch_size:\n",
    "            mod = min_batch_size - len(minibatch) % min_batch_size\n",
    "            additional_minibatch = [\n",
    "                sorted_data[i] for i in np.random.randint(0, start, mod)\n",
    "            ]\n",
    "            if shortest_first:\n",
    "                additional_minibatch.reverse()\n",
    "            minibatch.extend(additional_minibatch)\n",
    "        minibatches.append(minibatch)\n",
    "        if end == len(sorted_data):\n",
    "            break\n",
    "        start = end\n",
    "\n",
    "    # for debugging\n",
    "    if num_batches > 0:\n",
    "        minibatches = minibatches[:num_batches]\n",
    "        logger.info(\"# minibatches: \" + str(len(minibatches)))\n",
    "    return minibatches\n",
    "\n",
    "\n",
    "# Accepted values of make_batchset's `count` (how the batch size is measured).\n",
    "BATCH_COUNT_CHOICES = [\"auto\", \"seq\", \"bin\", \"frame\"]\n",
    "# Accepted values of make_batchset's `batch_sort_key` (how data is ordered\n",
    "# before batching; \"shuffle\" requires count == \"seq\").\n",
    "BATCH_SORT_KEY_CHOICES = [\"input\", \"output\", \"shuffle\"]\n",
    "\n",
    "\n",
    "def make_batchset(\n",
    "        data,\n",
    "        batch_size=0,\n",
    "        max_length_in=float(\"inf\"),\n",
    "        max_length_out=float(\"inf\"),\n",
    "        num_batches=0,\n",
    "        min_batch_size=1,\n",
    "        shortest_first=False,\n",
    "        batch_sort_key=\"input\",\n",
    "        count=\"auto\",\n",
    "        batch_bins=0,\n",
    "        batch_frames_in=0,\n",
    "        batch_frames_out=0,\n",
    "        batch_frames_inout=0,\n",
    "        iaxis=0,\n",
    "        oaxis=0, ):\n",
    "    \"\"\"Make batch set from a list of manifest entries.\n",
    "\n",
    "    Entries are grouped by their optional \"category\" value and batches never\n",
    "    mix categories:\n",
    "\n",
    "        >>> data = [{'utt': 'utt1', 'category': 'A', 'input': ...},\n",
    "        ...         {'utt': 'utt2', 'category': 'B', 'input': ...},\n",
    "        ...         {'utt': 'utt3', 'category': 'B', 'input': ...},\n",
    "        ...         {'utt': 'utt4', 'category': 'A', 'input': ...}]\n",
    "        >>> make_batchset(data, batch_size=2, ...)\n",
    "        [[('utt1', ...), ('utt4', ...)], [('utt2', ...), ('utt3', ...)]]\n",
    "\n",
    "    Entries without \"category\" share one group, so when no entry has one this\n",
    "    behaves like a single call to batchfy_by_{count} (or batchfy_shuffle).\n",
    "\n",
    "    :param List[Dict[str, Any]] data: manifest entries (e.g. from\n",
    "        read_manifest); each needs an 'utt' key plus the 'input'/'output'\n",
    "        lists. Within a category, a repeated 'utt' replaces the earlier entry.\n",
    "    :param int batch_size: maximum number of sequences in a minibatch.\n",
    "    :param max_length_in: maximum length of input to decide adaptive batch size\n",
    "    :param max_length_out: maximum length of output to decide adaptive batch size\n",
    "    :param int num_batches: # number of batches to use (for debug)\n",
    "    :param int min_batch_size: minimum batch size (for multi-gpu)\n",
    "    :param bool shortest_first: Sort from batch with shortest samples\n",
    "        to longest if true, otherwise reverse\n",
    "    :param str batch_sort_key: how to sort data before creating minibatches\n",
    "        [\"input\", \"output\", \"shuffle\"]\n",
    "    :param str count: strategy to count maximum size of batch, one of\n",
    "        BATCH_COUNT_CHOICES; \"auto\" picks \"seq\", \"bin\" or \"frame\" from the\n",
    "        first non-zero of batch_size, batch_bins, batch_frames_*.\n",
    "    :param int batch_bins: maximum number of bins (frames x dim) in a minibatch.\n",
    "    :param int batch_frames_in: maximum number of input frames in a minibatch.\n",
    "    :param int batch_frames_out: maximum number of output frames in a minibatch.\n",
    "    :param int batch_frames_inout: maximum number of input+output frames in a minibatch.\n",
    "    :param int iaxis: dimension to access input (for ASR, TTS iaxis=0, for MT\n",
    "        iaxis=1); only used when count=\"seq\".\n",
    "    :param int oaxis: dimension to access output (for ASR, TTS, MT oaxis=0,\n",
    "        -1 means all axis); only used when count=\"seq\".\n",
    "    :return: List[List[Tuple[str, dict]]] list of batches\n",
    "    :raises ValueError: for an invalid ``count`` or ``batch_sort_key``, when\n",
    "        ``count=\"auto\"`` cannot be resolved, or when \"shuffle\" is combined\n",
    "        with a count other than \"seq\"\n",
    "    \"\"\"\n",
    "\n",
    "    # check args\n",
    "    if count not in BATCH_COUNT_CHOICES:\n",
    "        raise ValueError(\n",
    "            f\"arg 'count' ({count}) should be one of {BATCH_COUNT_CHOICES}\")\n",
    "    if batch_sort_key not in BATCH_SORT_KEY_CHOICES:\n",
    "        raise ValueError(f\"arg 'batch_sort_key' ({batch_sort_key}) should be \"\n",
    "                         f\"one of {BATCH_SORT_KEY_CHOICES}\")\n",
    "\n",
    "    ikey = \"input\"\n",
    "    okey = \"output\"\n",
    "    # index of the entry in info[batch_sort_key] whose length is sorted on\n",
    "    batch_sort_axis = 0\n",
    "\n",
    "    if count == \"auto\":\n",
    "        if batch_size != 0:\n",
    "            count = \"seq\"\n",
    "        elif batch_bins != 0:\n",
    "            count = \"bin\"\n",
    "        elif batch_frames_in != 0 or batch_frames_out != 0 or batch_frames_inout != 0:\n",
    "            count = \"frame\"\n",
    "        else:\n",
    "            raise ValueError(\n",
    "                f\"cannot detect `count` manually set one of {BATCH_COUNT_CHOICES}\"\n",
    "            )\n",
    "        logger.info(f\"count is auto detected as {count}\")\n",
    "\n",
    "    if count != \"seq\" and batch_sort_key == \"shuffle\":\n",
    "        raise ValueError(\n",
    "            \"batch_sort_key=shuffle is only available if batch_count=seq\")\n",
    "\n",
    "    # Group entries by their optional \"category\" (None when absent).\n",
    "    category2data = {}  # Dict[Optional[str], Dict[str, dict]]\n",
    "    for entry in data:\n",
    "        utt_id = entry['utt']\n",
    "        category2data.setdefault(entry.get(\"category\"), {})[utt_id] = entry\n",
    "\n",
    "    batches_list = []  # List[List[List[Tuple[str, dict]]]]\n",
    "    for group in category2data.values():\n",
    "        if batch_sort_key == \"shuffle\":\n",
    "            batches = batchfy_shuffle(group, batch_size, min_batch_size,\n",
    "                                      num_batches, shortest_first)\n",
    "            batches_list.append(batches)\n",
    "            continue\n",
    "\n",
    "        # Sort by the length of info[batch_sort_key] (\"input\" or \"output\"):\n",
    "        # longest first, or shortest first when shortest_first is True.\n",
    "        sorted_data = sorted(\n",
    "            group.items(),\n",
    "            key=lambda item: int(item[1][batch_sort_key][batch_sort_axis][\"shape\"][0]),\n",
    "            reverse=not shortest_first, )\n",
    "        logger.info(\"# utts: \" + str(len(sorted_data)))\n",
    "\n",
    "        if count == \"seq\":\n",
    "            batches = batchfy_by_seq(\n",
    "                sorted_data,\n",
    "                batch_size=batch_size,\n",
    "                max_length_in=max_length_in,\n",
    "                max_length_out=max_length_out,\n",
    "                min_batch_size=min_batch_size,\n",
    "                shortest_first=shortest_first,\n",
    "                ikey=ikey,\n",
    "                iaxis=iaxis,\n",
    "                okey=okey,\n",
    "                oaxis=oaxis, )\n",
    "        elif count == \"bin\":\n",
    "            batches = batchfy_by_bin(\n",
    "                sorted_data,\n",
    "                batch_bins=batch_bins,\n",
    "                min_batch_size=min_batch_size,\n",
    "                shortest_first=shortest_first,\n",
    "                ikey=ikey,\n",
    "                okey=okey, )\n",
    "        else:  # count == \"frame\"; other values were rejected above\n",
    "            batches = batchfy_by_frame(\n",
    "                sorted_data,\n",
    "                max_frames_in=batch_frames_in,\n",
    "                max_frames_out=batch_frames_out,\n",
    "                max_frames_inout=batch_frames_inout,\n",
    "                min_batch_size=min_batch_size,\n",
    "                shortest_first=shortest_first,\n",
    "                ikey=ikey,\n",
    "                okey=okey, )\n",
    "        batches_list.append(batches)\n",
    "\n",
    "    if len(batches_list) == 1:\n",
    "        batches = batches_list[0]\n",
    "    else:\n",
    "        # Concat list. This way is faster than \"sum(batch_list, [])\"\n",
    "        batches = list(itertools.chain.from_iterable(batches_list))\n",
    "\n",
    "    # for debugging\n",
    "    if num_batches > 0:\n",
    "        batches = batches[:num_batches]\n",
    "    logger.info(\"# minibatches: \" + str(len(batches)))\n",
    "\n",
    "    # batch: List[List[Tuple[str, dict]]]\n",
    "    return batches\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "id": "acquired-hurricane",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[INFO 2021/08/18 06:57:10 1445365138.py:284] use shuffled batch.\n",
      "[INFO 2021/08/18 06:57:10 1445365138.py:286] # utts: 5542\n",
      "[INFO 2021/08/18 06:57:10 1445365138.py:468] # minibatches: 555\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "555\n"
     ]
    }
   ],
   "source": [
    "batch_size = 10\n",
    "maxlen_in = 300\n",
    "maxlen_out = 400\n",
    "minibatches = 0  # for debug: > 0 keeps only the first N batches\n",
    "min_batch_size = 2\n",
    "use_sortagrad = True\n",
    "batch_count = 'seq'\n",
    "batch_bins = 0\n",
    "batch_frames_in = 3000\n",
    "batch_frames_out = 0\n",
    "batch_frames_inout = 0\n",
    "\n",
    "# With batch_sort_key=\"shuffle\", make_batchset calls batchfy_shuffle, which\n",
    "# ignores maxlen_in/maxlen_out; batch_bins and batch_frames_* only matter\n",
    "# when count is 'bin' / 'frame'.\n",
    "# NOTE(review): this rebinds `dev_data`, which held the manifest path read by\n",
    "# `read_manifest(dev_data)`; re-running that cell after this one will fail.\n",
    "dev_data = make_batchset(\n",
    "    dev_json,\n",
    "    batch_size=batch_size,\n",
    "    max_length_in=maxlen_in,\n",
    "    max_length_out=maxlen_out,\n",
    "    num_batches=minibatches,\n",
    "    min_batch_size=min_batch_size,\n",
    "    shortest_first=use_sortagrad,\n",
    "    batch_sort_key=\"shuffle\",\n",
    "    count=batch_count,\n",
    "    batch_bins=batch_bins,\n",
    "    batch_frames_in=batch_frames_in,\n",
    "    batch_frames_out=batch_frames_out,\n",
    "    batch_frames_inout=batch_frames_inout,\n",
    "    iaxis=0,\n",
    "    oaxis=0, )\n",
    "print(len(dev_data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "id": "warming-malpractice",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: kaldiio in ./DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (2.17.2)\n",
      "Requirement already satisfied: numpy in ./DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numpy-1.21.2-py3.7-linux-x86_64.egg (from kaldiio) (1.21.2)\n",
      "\u001b[33mWARNING: You are using pip version 20.3.3; however, version 21.2.4 is available.\n",
      "You should consider upgrading via the '/workspace/zhanghui/DeepSpeech-2.x/tools/venv/bin/python -m pip install --upgrade pip' command.\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# Pinned to the version already installed in the kernel's venv (see output);\n",
    "# %pip installs into the running kernel's environment, unlike !pip.\n",
    "%pip install kaldiio==2.17.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "equipped-subject",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "id": "superb-methodology",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import OrderedDict\n",
    "import kaldiio\n",
    "\n",
    "class LoadInputsAndTargets():\n",
    "    \"\"\"Create a mini-batch from a list of dicts\n",
    "\n",
    "    >>> batch = [('utt1',\n",
    "    ...           dict(input=[dict(feat='some.ark:123',\n",
    "    ...                            filetype='mat',\n",
    "    ...                            name='input1',\n",
    "    ...                            shape=[100, 80])],\n",
    "    ...                output=[dict(tokenid='1 2 3 4',\n",
    "    ...                             name='target1',\n",
    "    ...                             shape=[4, 31])]]))\n",
    "    >>> l = LoadInputsAndTargets()\n",
    "    >>> feat, target = l(batch)\n",
    "\n",
    "    :param: str mode: Specify the task mode, \"asr\" or \"tts\"\n",
    "    :param: str preprocess_conf: The path of a json file for pre-processing\n",
    "    :param: bool load_input: If False, not to load the input data\n",
    "    :param: bool load_output: If False, not to load the output data\n",
    "    :param: bool sort_in_input_length: Sort the mini-batch in descending order\n",
    "        of the input length\n",
    "    :param: bool use_speaker_embedding: Used for tts mode only\n",
    "    :param: bool use_second_target: Used for tts mode only\n",
    "    :param: dict preprocess_args: Set some optional arguments for preprocessing\n",
    "    :param: Optional[dict] preprocess_args: Used for tts mode only\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "            self,\n",
    "            mode=\"asr\",\n",
    "            preprocess_conf=None,\n",
    "            load_input=True,\n",
    "            load_output=True,\n",
    "            sort_in_input_length=True,\n",
    "            preprocess_args=None,\n",
    "            keep_all_data_on_mem=False, ):\n",
    "        \"\"\"Validate ``mode`` and store the loading options.\n",
    "\n",
    "        See the class docstring for the meaning of most parameters.\n",
    "\n",
    "        :param bool keep_all_data_on_mem: stored as ``self.keep_all_data_on_mem``\n",
    "        :raises ValueError: if ``mode`` is not \"asr\"\n",
    "        :raises AssertionError: if ``preprocess_args`` is given but is not a dict\n",
    "        \"\"\"\n",
    "        # This cell only imports OrderedDict and kaldiio; import logging here so\n",
    "        # the warning below does not depend on another cell having done it.\n",
    "        import logging\n",
    "\n",
    "        self._loaders = {}\n",
    "\n",
    "        if mode not in [\"asr\"]:\n",
    "            raise ValueError(\"Only asr are allowed: mode={}\".format(mode))\n",
    "\n",
    "        if preprocess_conf is not None:\n",
    "            # NOTE(review): AugmentationPipeline is not imported in this cell;\n",
    "            # it must be defined elsewhere before this branch can run.\n",
    "            self.preprocessing = AugmentationPipeline(preprocess_conf)\n",
    "            logging.warning(\n",
    "                \"[Experimental feature] Some preprocessing will be done \"\n",
    "                \"for the mini-batch creation using {}\".format(\n",
    "                    self.preprocessing))\n",
    "        else:\n",
    "            # Without a config, preprocessing is disabled (set to None).\n",
    "            self.preprocessing = None\n",
    "\n",
    "        self.mode = mode\n",
    "        self.load_output = load_output\n",
    "        self.load_input = load_input\n",
    "        self.sort_in_input_length = sort_in_input_length\n",
    "        if preprocess_args is None:\n",
    "            self.preprocess_args = {}\n",
    "        else:\n",
    "            assert isinstance(preprocess_args, dict), type(preprocess_args)\n",
    "            # shallow copy so later edits to the caller's dict do not leak in\n",
    "            self.preprocess_args = dict(preprocess_args)\n",
    "\n",
    "        self.keep_all_data_on_mem = keep_all_data_on_mem\n",
    "\n",
    "    def __call__(self, batch, return_uttid=False):\n",
    "        \"\"\"Function to load inputs and targets from list of dicts\n",
    "\n",
    "        :param List[Tuple[str, dict]] batch: list of dict which is subset of\n",
    "            loaded data.json\n",
    "        :param bool return_uttid: return utterance ID information for visualization\n",
    "        :return: list of input token id sequences [(L_1), (L_2), ..., (L_B)]\n",
    "        :return: list of input feature sequences\n",
    "            [(T_1, D), (T_2, D), ..., (T_B, D)]\n",
    "        :rtype: list of float ndarray\n",
    "        :return: list of target token id sequences [(L_1), (L_2), ..., (L_B)]\n",
    "        :rtype: list of int ndarray\n",
    "\n",
    "        \"\"\"\n",
    "        x_feats_dict = OrderedDict()  # OrderedDict[str, List[np.ndarray]]\n",
    "        y_feats_dict = OrderedDict()  # OrderedDict[str, List[np.ndarray]]\n",
    "        uttid_list = []  # List[str]\n",
    "\n",
    "        for uttid, info in batch:\n",
    "            uttid_list.append(uttid)\n",
    "\n",
    "            if self.load_input:\n",
    "                # Note(kamo): This for-loop is for multiple inputs\n",
    "                for idx, inp in enumerate(info[\"input\"]):\n",
    "                    # {\"input\":\n",
    "                    #  [{\"feat\": \"some/path.h5:F01_050C0101_PED_REAL\",\n",
    "                    #    \"filetype\": \"hdf5\",\n",
    "                    #    \"name\": \"input1\", ...}], ...}\n",
    "                    x = self._get_from_loader(\n",
    "                        filepath=inp[\"feat\"],\n",
    "                        filetype=inp.get(\"filetype\", \"mat\"))\n",
    "                    x_feats_dict.setdefault(inp[\"name\"], []).append(x)\n",
    "\n",
    "            if self.load_output:\n",
    "                for idx, inp in enumerate(info[\"output\"]):\n",
    "                    if \"tokenid\" in inp:\n",
    "                        # ======= Legacy format for output =======\n",
    "                        # {\"output\": [{\"tokenid\": \"1 2 3 4\"}])\n",
    "                        x = np.fromiter(\n",
    "                            map(int, inp[\"tokenid\"].split()), dtype=np.int64)\n",
    "                    else:\n",
    "                        # ======= New format =======\n",
    "                        # {\"input\":\n",
    "                        #  [{\"feat\": \"some/path.h5:F01_050C0101_PED_REAL\",\n",
    "                        #    \"filetype\": \"hdf5\",\n",
    "                        #    \"name\": \"target1\", ...}], ...}\n",
    "                        x = self._get_from_loader(\n",
    "                            filepath=inp[\"feat\"],\n",
    "                            filetype=inp.get(\"filetype\", \"mat\"))\n",
    "\n",
    "                    y_feats_dict.setdefault(inp[\"name\"], []).append(x)\n",
    "\n",
    "        if self.mode == \"asr\":\n",
    "            return_batch, uttid_list = self._create_batch_asr(\n",
    "                x_feats_dict, y_feats_dict, uttid_list)\n",
    "        else:\n",
    "            raise NotImplementedError(self.mode)\n",
    "\n",
    "        if self.preprocessing is not None:\n",
    "            # Apply pre-processing all input features\n",
    "            for x_name in return_batch.keys():\n",
    "                if x_name.startswith(\"input\"):\n",
    "                    return_batch[x_name] = self.preprocessing(\n",
    "                        return_batch[x_name], uttid_list,\n",
    "                        **self.preprocess_args)\n",
    "\n",
    "        if return_uttid:\n",
    "            return tuple(return_batch.values()), uttid_list\n",
    "\n",
    "        # Doesn't return the names now.\n",
    "        return tuple(return_batch.values())\n",
    "\n",
    "    def _create_batch_asr(self, x_feats_dict, y_feats_dict, uttid_list):\n",
    "        \"\"\"Create a OrderedDict for the mini-batch\n",
    "\n",
    "        :param OrderedDict x_feats_dict:\n",
    "            e.g. {\"input1\": [ndarray, ndarray, ...],\n",
    "                  \"input2\": [ndarray, ndarray, ...]}\n",
    "        :param OrderedDict y_feats_dict:\n",
    "            e.g. {\"target1\": [ndarray, ndarray, ...],\n",
    "                  \"target2\": [ndarray, ndarray, ...]}\n",
    "        :param: List[str] uttid_list:\n",
    "            Give uttid_list to sort in the same order as the mini-batch\n",
    "        :return: batch, uttid_list\n",
    "        :rtype: Tuple[OrderedDict, List[str]]\n",
    "        \"\"\"\n",
    "        # handle single-input and multi-input (paralell) asr mode\n",
    "        xs = list(x_feats_dict.values())\n",
    "\n",
    "        if self.load_output:\n",
    "            ys = list(y_feats_dict.values())\n",
    "            assert len(xs[0]) == len(ys[0]), (len(xs[0]), len(ys[0]))\n",
    "\n",
    "            # get index of non-zero length samples\n",
    "            nonzero_idx = list(\n",
    "                filter(lambda i: len(ys[0][i]) > 0, range(len(ys[0]))))\n",
    "            for n in range(1, len(y_feats_dict)):\n",
    "                nonzero_idx = filter(lambda i: len(ys[n][i]) > 0, nonzero_idx)\n",
    "        else:\n",
    "            # Note(kamo): Be careful not to make nonzero_idx to a generator\n",
    "            nonzero_idx = list(range(len(xs[0])))\n",
    "\n",
    "        if self.sort_in_input_length:\n",
    "            # sort in input lengths based on the first input\n",
    "            nonzero_sorted_idx = sorted(\n",
    "                nonzero_idx, key=lambda i: -len(xs[0][i]))\n",
    "        else:\n",
    "            nonzero_sorted_idx = nonzero_idx\n",
    "\n",
    "        if len(nonzero_sorted_idx) != len(xs[0]):\n",
    "            logging.warning(\n",
    "                \"Target sequences include empty tokenid (batch {} -> {}).\".\n",
    "                format(len(xs[0]), len(nonzero_sorted_idx)))\n",
    "\n",
    "        # remove zero-length samples\n",
    "        xs = [[x[i] for i in nonzero_sorted_idx] for x in xs]\n",
    "        uttid_list = [uttid_list[i] for i in nonzero_sorted_idx]\n",
    "\n",
    "        x_names = list(x_feats_dict.keys())\n",
    "        if self.load_output:\n",
    "            ys = [[y[i] for i in nonzero_sorted_idx] for y in ys]\n",
    "            y_names = list(y_feats_dict.keys())\n",
    "\n",
    "            # Keeping x_name and y_name, e.g. input1, for future extension\n",
    "            return_batch = OrderedDict([\n",
    "                * [(x_name, x) for x_name, x in zip(x_names, xs)],\n",
    "                * [(y_name, y) for y_name, y in zip(y_names, ys)],\n",
    "            ])\n",
    "        else:\n",
    "            return_batch = OrderedDict(\n",
    "                [(x_name, x) for x_name, x in zip(x_names, xs)])\n",
    "        return return_batch, uttid_list\n",
    "\n",
    "    def _get_from_loader(self, filepath, filetype):\n",
    "        \"\"\"Return ndarray\n",
    "\n",
    "        In order to make the fds to be opened only at the first referring,\n",
    "        the loader are stored in self._loaders\n",
    "\n",
    "        >>> ndarray = loader.get_from_loader(\n",
    "        ...     'some/path.h5:F01_050C0101_PED_REAL', filetype='hdf5')\n",
    "\n",
    "        :param: str filepath:\n",
    "        :param: str filetype:\n",
    "        :return:\n",
    "        :rtype: np.ndarray\n",
    "        \"\"\"\n",
    "        if filetype == \"hdf5\":\n",
    "            # e.g.\n",
    "            #    {\"input\": [{\"feat\": \"some/path.h5:F01_050C0101_PED_REAL\",\n",
    "            #                \"filetype\": \"hdf5\",\n",
    "            # -> filepath = \"some/path.h5\", key = \"F01_050C0101_PED_REAL\"\n",
    "            filepath, key = filepath.split(\":\", 1)\n",
    "\n",
    "            loader = self._loaders.get(filepath)\n",
    "            if loader is None:\n",
    "                # To avoid disk access, create loader only for the first time\n",
    "                loader = h5py.File(filepath, \"r\")\n",
    "                self._loaders[filepath] = loader\n",
    "            return loader[key][()]\n",
    "        elif filetype == \"sound.hdf5\":\n",
    "            # e.g.\n",
    "            #    {\"input\": [{\"feat\": \"some/path.h5:F01_050C0101_PED_REAL\",\n",
    "            #                \"filetype\": \"sound.hdf5\",\n",
    "            # -> filepath = \"some/path.h5\", key = \"F01_050C0101_PED_REAL\"\n",
    "            filepath, key = filepath.split(\":\", 1)\n",
    "\n",
    "            loader = self._loaders.get(filepath)\n",
    "            if loader is None:\n",
    "                # To avoid disk access, create loader only for the first time\n",
    "                loader = SoundHDF5File(filepath, \"r\", dtype=\"int16\")\n",
    "                self._loaders[filepath] = loader\n",
    "            array, rate = loader[key]\n",
    "            return array\n",
    "        elif filetype == \"sound\":\n",
    "            # e.g.\n",
    "            #    {\"input\": [{\"feat\": \"some/path.wav\",\n",
    "            #                \"filetype\": \"sound\"},\n",
    "            # Assume PCM16\n",
    "            if not self.keep_all_data_on_mem:\n",
    "                array, _ = soundfile.read(filepath, dtype=\"int16\")\n",
    "                return array\n",
    "            if filepath not in self._loaders:\n",
    "                array, _ = soundfile.read(filepath, dtype=\"int16\")\n",
    "                self._loaders[filepath] = array\n",
    "            return self._loaders[filepath]\n",
    "        elif filetype == \"npz\":\n",
    "            # e.g.\n",
    "            #    {\"input\": [{\"feat\": \"some/path.npz:F01_050C0101_PED_REAL\",\n",
    "            #                \"filetype\": \"npz\",\n",
    "            filepath, key = filepath.split(\":\", 1)\n",
    "\n",
    "            loader = self._loaders.get(filepath)\n",
    "            if loader is None:\n",
    "                # To avoid disk access, create loader only for the first time\n",
    "                loader = np.load(filepath)\n",
    "                self._loaders[filepath] = loader\n",
    "            return loader[key]\n",
    "        elif filetype == \"npy\":\n",
    "            # e.g.\n",
    "            #    {\"input\": [{\"feat\": \"some/path.npy\",\n",
    "            #                \"filetype\": \"npy\"},\n",
    "            if not self.keep_all_data_on_mem:\n",
    "                return np.load(filepath)\n",
    "            if filepath not in self._loaders:\n",
    "                self._loaders[filepath] = np.load(filepath)\n",
    "            return self._loaders[filepath]\n",
    "        elif filetype in [\"mat\", \"vec\"]:\n",
    "            # e.g.\n",
    "            #    {\"input\": [{\"feat\": \"some/path.ark:123\",\n",
    "            #                \"filetype\": \"mat\"}]},\n",
    "            # In this case, \"123\" indicates the starting points of the matrix\n",
    "            # load_mat can load both matrix and vector\n",
    "            if not self.keep_all_data_on_mem:\n",
    "                return kaldiio.load_mat(filepath)\n",
    "            if filepath not in self._loaders:\n",
    "                self._loaders[filepath] = kaldiio.load_mat(filepath)\n",
    "            return self._loaders[filepath]\n",
    "        elif filetype == \"scp\":\n",
    "            # e.g.\n",
    "            #    {\"input\": [{\"feat\": \"some/path.scp:F01_050C0101_PED_REAL\",\n",
    "            #                \"filetype\": \"scp\",\n",
    "            filepath, key = filepath.split(\":\", 1)\n",
    "            loader = self._loaders.get(filepath)\n",
    "            if loader is None:\n",
    "                # To avoid disk access, create loader only for the first time\n",
    "                loader = kaldiio.load_scp(filepath)\n",
    "                self._loaders[filepath] = loader\n",
    "            return loader[key]\n",
    "        else:\n",
    "            raise NotImplementedError(\n",
    "                \"Not supported: loader_type={}\".format(filetype))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "id": "monthly-muscle",
   "metadata": {},
   "outputs": [],
   "source": [
    "preprocess_conf = None\n",
    "train_mode = True\n",
    "\n",
    "# `train` is forwarded as a keyword to the preprocessing pipeline\n",
    "# (unused here because preprocess_conf is None).\n",
    "load = LoadInputsAndTargets(\n",
    "    mode=\"asr\",\n",
    "    preprocess_conf=preprocess_conf,\n",
    "    load_output=True,\n",
    "    preprocess_args={\"train\": train_mode},\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "id": "periodic-senegal",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): `dev_data` must be a list of minibatches here, but an\n",
    "# earlier cell binds it to the manifest path; re-check before Run All.\n",
    "res = load(batch=dev_data[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "id": "502d3f4d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'tuple'>\n",
      "2\n",
      "10\n",
      "10\n",
      "(1174, 83) float32\n",
      "(29,) int64\n"
     ]
    }
   ],
   "source": [
    "print(type(res), len(res), len(res[0]), len(res[1]), sep=\"\\n\")\n",
    "print(res[0][0].shape, res[0][0].dtype)\n",
    "print(res[1][0].shape, res[1][0].dtype)\n",
    "# res: Tuple[List[np.ndarray], List[np.ndarray]] -> (feats, labels),\n",
    "# each list holding one array per utterance (10 in this batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "id": "humanitarian-container",
   "metadata": {},
   "outputs": [],
   "source": [
    "# With return_uttid=True the reader returns ((feats, labels), uttids).\n",
    "(inputs, outputs), utts = load(batch=dev_data[0], return_uttid=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "id": "heard-prize",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['4572-112383-0005', '6313-66125-0015', '251-137823-0022', '2277-149896-0030', '652-130726-0032', '5895-34615-0013', '1462-170138-0002', '777-126732-0008', '3660-172182-0021', '2277-149896-0027'] 10\n",
      "10\n"
     ]
    }
   ],
   "source": [
    "print(f\"{utts} {len(utts)}\")\n",
    "print(len(inputs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "id": "convinced-animation",
   "metadata": {},
   "outputs": [],
   "source": [
    "import paddle\n",
    "from deepspeech.io.utility import pad_list\n",
    "class CustomConverter():\n",
    "    \"\"\"Custom batch converter: pads one reader mini-batch into numpy arrays.\n",
    "\n",
    "    Args:\n",
    "        subsampling_factor (int): Keep every ``subsampling_factor``-th input\n",
    "            frame; no subsampling when it is <= 1.\n",
    "        dtype (numpy dtype): Data type the padded input features are cast to.\n",
    "\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, subsampling_factor=1, dtype=np.float32):\n",
    "        \"\"\"Construct a CustomConverter object.\"\"\"\n",
    "        self.subsampling_factor = subsampling_factor\n",
    "        # Padding value for the target token ids.\n",
    "        self.ignore_id = -1\n",
    "        self.dtype = dtype\n",
    "\n",
    "    def __call__(self, batch):\n",
    "        \"\"\"Subsample and pad one mini-batch.\n",
    "\n",
    "        Args:\n",
    "            batch (list): One-element list holding the reader output\n",
    "                ``((xs, ys), utts)`` produced with ``return_uttid=True``.\n",
    "\n",
    "        Returns:\n",
    "            tuple: ``(utts, xs_pad, ilens, ys_pad, olens)`` with numpy arrays\n",
    "                (not paddle tensors). ``xs_pad`` is cast to ``self.dtype``\n",
    "                (a ``{\"real\": ..., \"imag\": ...}`` dict for complex inputs),\n",
    "                ``ys_pad`` is padded with ``self.ignore_id``, and\n",
    "                ``ilens`` / ``olens`` are int64 length arrays.\n",
    "\n",
    "        \"\"\"\n",
    "        # batch should be located in list\n",
    "        assert len(batch) == 1\n",
    "        (xs, ys), utts = batch[0]\n",
    "\n",
    "        # perform subsampling along the first (frame) axis\n",
    "        if self.subsampling_factor > 1:\n",
    "            xs = [x[::self.subsampling_factor, :] for x in xs]\n",
    "\n",
    "        # lengths of the (subsampled) input sequences; int64 explicitly so the\n",
    "        # dtype does not depend on the platform's default integer type\n",
    "        ilens = np.array([x.shape[0] for x in xs], dtype=np.int64)\n",
    "\n",
    "        # pad the inputs; complex features are split into real/imag parts\n",
    "        if xs[0].dtype.kind == \"c\":\n",
    "            xs_pad_real = pad_list([x.real for x in xs], 0).astype(self.dtype)\n",
    "            xs_pad_imag = pad_list([x.imag for x in xs], 0).astype(self.dtype)\n",
    "            # Note(kamo):\n",
    "            # {'real': ..., 'imag': ...} will be changed to ComplexTensor in E2E.\n",
    "            # Don't create ComplexTensor and give it E2E here\n",
    "            # because torch.nn.DataParellel can't handle it.\n",
    "            xs_pad = {\"real\": xs_pad_real, \"imag\": xs_pad_imag}\n",
    "        else:\n",
    "            xs_pad = pad_list(xs, 0).astype(self.dtype)\n",
    "\n",
    "        # NOTE: this is for multi-output (e.g., speech translation)\n",
    "        ys_pad = pad_list(\n",
    "            [np.array(y[0][:]) if isinstance(y, tuple) else y for y in ys],\n",
    "            self.ignore_id)\n",
    "\n",
    "        olens = np.array(\n",
    "            [y[0].shape[0] if isinstance(y, tuple) else y.shape[0] for y in ys],\n",
    "            dtype=np.int64)\n",
    "        return utts, xs_pad, ilens, ys_pad, olens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "id": "0b92ade5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Defaults spelled out: no frame subsampling, float32 features.\n",
    "convert = CustomConverter(subsampling_factor=1, dtype=np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "id": "8dbd847c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The converter expects the reader output wrapped in a one-element list.\n",
    "utts, xs, ilen, ys, olen = convert([load(batch=dev_data[0], return_uttid=True)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "id": "31c085f4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['4572-112383-0005', '6313-66125-0015', '251-137823-0022', '2277-149896-0030', '652-130726-0032', '5895-34615-0013', '1462-170138-0002', '777-126732-0008', '3660-172182-0021', '2277-149896-0027']\n",
      "(10, 1174, 83)\n",
      "(10,)\n",
      "[1174  821  716  628  597  473  463  441  419  358]\n",
      "(10, 32)\n",
      "[[4502 2404 4223 3204 4502  587 1018 3861 2932  713 2458 2916  253 4508\n",
      "   627 1395  713 4504  957 2761  209 2967 3173 3918 2598 4100    3 2816\n",
      "  4990   -1   -1   -1]\n",
      " [1005  451  210  278 3411  206  482 2307  573 4502 3848 4577 4273 2388\n",
      "  4444   89 4919  278 1264 4501 2371    3  139  113 2603 4962 3158 3325\n",
      "  4577  814 4587 1422]\n",
      " [2345 4144 2291  200  713 2345  532  999 2458 3076  545 2458 4832 3038\n",
      "  4499  482 2812 1260 3080   -1   -1   -1   -1   -1   -1   -1   -1   -1\n",
      "    -1   -1   -1   -1]\n",
      " [2345  832 4577 4920 4501 2345 2298 1236  381  288  389  101 2495 4172\n",
      "  4843 3233 3245 4501 2345 2298 3987 4502 3023 3353 2345 1361 1635 2603\n",
      "  4723 2371   -1   -1]\n",
      " [4502 4207  432 3204 4502 2396  125  935  433 2598  483   18  327    2\n",
      "   389  627 4512 2340  713  482 1981 4525 4031  269 2030 1340  101 2495\n",
      "  4013 4844   -1   -1]\n",
      " [4502 4892 3204 1892 3780  389  482 2774 3013   89  192 2495 4502 3475\n",
      "   389   66  370  343  404   -1   -1   -1   -1   -1   -1   -1   -1   -1\n",
      "    -1   -1   -1   -1]\n",
      " [2458 2314 4577 2340 2863 1254  303  269    2  389  932 2079 4577  299\n",
      "   195 3233 4508    2   89  814 3144 1091 3204 3250 2193 3414   -1   -1\n",
      "    -1   -1   -1   -1]\n",
      " [2391 1785  443   78   39 4962 2340  829  599 4593  278 4681  202  407\n",
      "   269  194  182 4577  482 4308   -1   -1   -1   -1   -1   -1   -1   -1\n",
      "    -1   -1   -1   -1]\n",
      " [ 627 4873 2175  363  202  404 1018 4577 4502 3412 4875 2286  107  122\n",
      "  4832 2345 3896   89 2368   -1   -1   -1   -1   -1   -1   -1   -1   -1\n",
      "    -1   -1   -1   -1]\n",
      " [ 481  174  474  599 1881 3252 2842  742 4502 2545  107   88 3204 4525\n",
      "  4517   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1\n",
      "    -1   -1   -1   -1]]\n",
      "[29 32 19 30 30 19 26 20 19 15]\n",
      "float32\n",
      "int64\n",
      "int64\n",
      "int64\n"
     ]
    }
   ],
   "source": [
    "print(utts)\n",
    "# shapes / values first, then dtypes, one item per line\n",
    "print(xs.shape, ilen.shape, ilen, ys.shape, ys, olen, sep=\"\\n\")\n",
    "print(xs.dtype, ilen.dtype, ys.dtype, olen.dtype, sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "id": "72e9ba60",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 230,
   "id": "64593e5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import numpy as np\n",
    "from paddle.io import DataLoader\n",
    "\n",
    "from deepspeech.frontend.utility import read_manifest\n",
    "from deepspeech.io.batchfy import make_batchset\n",
    "from deepspeech.io.converter import CustomConverter\n",
    "from deepspeech.io.dataset import TransformDataset\n",
    "from deepspeech.io.reader import LoadInputsAndTargets\n",
    "from deepspeech.utils.log import Log\n",
    "\n",
    "\n",
    "logger = Log(__name__).getlog()\n",
    "\n",
    "\n",
    "class BatchDataLoader():\n",
    "    \"\"\"Iterate over converted mini-batches built from a json manifest.\n",
    "\n",
    "    Minibatches are grouped by ``make_batchset``, read by\n",
    "    ``LoadInputsAndTargets`` and padded by ``CustomConverter``; every\n",
    "    DataLoader step yields one converted minibatch.\n",
    "\n",
    "    Note: this cell imports ``CustomConverter`` and ``LoadInputsAndTargets``\n",
    "    from ``deepspeech.io``, shadowing the notebook-defined classes above.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 json_file: str,\n",
    "                 train_mode: bool,\n",
    "                 sortagrad: bool=False,\n",
    "                 batch_size: int=0,\n",
    "                 maxlen_in: float=float('inf'),\n",
    "                 maxlen_out: float=float('inf'),\n",
    "                 minibatches: int=0,\n",
    "                 mini_batch_size: int=1,\n",
    "                 batch_count: str='auto',\n",
    "                 batch_bins: int=0,\n",
    "                 batch_frames_in: int=0,\n",
    "                 batch_frames_out: int=0,\n",
    "                 batch_frames_inout: int=0,\n",
    "                 preprocess_conf=None,\n",
    "                 n_iter_processes: int=1,\n",
    "                 subsampling_factor: int=1,\n",
    "                 num_encs: int=1):\n",
    "        \"\"\"Build the minibatch list, reader, converter and DataLoader.\n",
    "\n",
    "        Args:\n",
    "            json_file: manifest path, read with ``read_manifest``.\n",
    "            train_mode: shuffle minibatches (unless sortagrad is on) and\n",
    "                pass ``{\"train\": train_mode}`` to the preprocessing.\n",
    "            sortagrad: -1 or a positive value enables shortest-first\n",
    "                batching and disables shuffling.\n",
    "            n_iter_processes: number of DataLoader workers.\n",
    "            Remaining arguments are forwarded to ``make_batchset`` /\n",
    "            ``CustomConverter``.\n",
    "\n",
    "        Raises:\n",
    "            NotImplementedError: if ``num_encs != 1``.\n",
    "        \"\"\"\n",
    "        # Only the single-encoder converter exists; fail here instead of\n",
    "        # on the first iteration (an `assert` on an exception instance is\n",
    "        # always true and would never fire).\n",
    "        if num_encs != 1:\n",
    "            raise NotImplementedError(\"not impl CustomConverterMulEnc.\")\n",
    "\n",
    "        self.json_file = json_file\n",
    "        self.train_mode = train_mode\n",
    "        self.use_sortagrad = sortagrad == -1 or sortagrad > 0\n",
    "        self.batch_size = batch_size\n",
    "        self.maxlen_in = maxlen_in\n",
    "        self.maxlen_out = maxlen_out\n",
    "        self.batch_count = batch_count\n",
    "        self.batch_bins = batch_bins\n",
    "        self.batch_frames_in = batch_frames_in\n",
    "        self.batch_frames_out = batch_frames_out\n",
    "        self.batch_frames_inout = batch_frames_inout\n",
    "        self.subsampling_factor = subsampling_factor\n",
    "        self.num_encs = num_encs\n",
    "        self.preprocess_conf = preprocess_conf\n",
    "        self.n_iter_processes = n_iter_processes\n",
    "\n",
    "        # read json data\n",
    "        self.data_json = read_manifest(json_file)\n",
    "\n",
    "        # make minibatch list (variable length)\n",
    "        # NOTE: attribute name `minibaches` (sic) kept for existing callers.\n",
    "        self.minibaches = make_batchset(\n",
    "            self.data_json,\n",
    "            batch_size,\n",
    "            maxlen_in,\n",
    "            maxlen_out,\n",
    "            minibatches,  # for debug\n",
    "            min_batch_size=mini_batch_size,\n",
    "            shortest_first=self.use_sortagrad,\n",
    "            count=batch_count,\n",
    "            batch_bins=batch_bins,\n",
    "            batch_frames_in=batch_frames_in,\n",
    "            batch_frames_out=batch_frames_out,\n",
    "            batch_frames_inout=batch_frames_inout,\n",
    "            iaxis=0,\n",
    "            oaxis=0, )\n",
    "\n",
    "        # data reader\n",
    "        self.reader = LoadInputsAndTargets(\n",
    "            mode=\"asr\",\n",
    "            load_output=True,\n",
    "            preprocess_conf=preprocess_conf,\n",
    "            preprocess_args={\"train\":\n",
    "                             train_mode},  # Switch the mode of preprocessing\n",
    "        )\n",
    "\n",
    "        # Setup a converter\n",
    "        self.converter = CustomConverter(\n",
    "            subsampling_factor=subsampling_factor, dtype=np.float32)\n",
    "\n",
    "        # Each dataset item is already a whole minibatch, so the DataLoader\n",
    "        # uses batch_size=1 and collate_fn just unwraps that single item.\n",
    "        self.dataset = TransformDataset(\n",
    "            self.minibaches,\n",
    "            lambda data: self.converter([self.reader(data, return_uttid=True)]))\n",
    "        self.dataloader = DataLoader(\n",
    "            dataset=self.dataset,\n",
    "            batch_size=1,\n",
    "            # Shuffle only in training, and never with sortagrad, which\n",
    "            # relies on the shortest-first minibatch order.\n",
    "            shuffle=not self.use_sortagrad if train_mode else False,\n",
    "            collate_fn=lambda x: x[0],\n",
    "            num_workers=n_iter_processes, )\n",
    "\n",
    "    def __repr__(self):\n",
    "        echo = f\"<{self.__class__.__module__}.{self.__class__.__name__} object at {hex(id(self))}> \"\n",
    "        echo += f\"train_mode: {self.train_mode}, \"\n",
    "        echo += f\"sortagrad: {self.use_sortagrad}, \"\n",
    "        echo += f\"batch_size: {self.batch_size}, \"\n",
    "        echo += f\"maxlen_in: {self.maxlen_in}, \"\n",
    "        echo += f\"maxlen_out: {self.maxlen_out}, \"\n",
    "        echo += f\"batch_count: {self.batch_count}, \"\n",
    "        echo += f\"batch_bins: {self.batch_bins}, \"\n",
    "        echo += f\"batch_frames_in: {self.batch_frames_in}, \"\n",
    "        echo += f\"batch_frames_out: {self.batch_frames_out}, \"\n",
    "        echo += f\"batch_frames_inout: {self.batch_frames_inout}, \"\n",
    "        echo += f\"subsampling_factor: {self.subsampling_factor}, \"\n",
    "        echo += f\"num_encs: {self.num_encs}, \"\n",
    "        echo += f\"num_workers: {self.n_iter_processes}, \"\n",
    "        echo += f\"file: {self.json_file}\"\n",
    "        return echo\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of minibatches per epoch.\"\"\"\n",
    "        return len(self.dataloader)\n",
    "\n",
    "    def __iter__(self):\n",
    "        \"\"\"Iterate over the underlying DataLoader.\"\"\"\n",
    "        return self.dataloader.__iter__()\n",
    "\n",
    "    def __call__(self):\n",
    "        \"\"\"Alias for iter(self).\"\"\"\n",
    "        return self.__iter__()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 231,
   "id": "fcea3fd0",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[INFO 2021/08/18 07:42:23 batchfy.py:399] count is auto detected as seq\n",
      "[INFO 2021/08/18 07:42:23 batchfy.py:423] # utts: 5542\n",
      "[INFO 2021/08/18 07:42:23 batchfy.py:466] # minibatches: 278\n"
     ]
    }
   ],
   "source": [
    "train = BatchDataLoader(json_file=dev_data, train_mode=True, batch_size=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 232,
   "id": "e2a2c9a8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "278\n",
      "['__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', 'auto_collate_batch', 'batch_sampler', 'batch_size', 'collate_fn', 'dataset', 'dataset_kind', 'feed_list', 'from_dataset', 'from_generator', 'num_workers', 'pin_memory', 'places', 'return_list', 'timeout', 'use_buffer_reader', 'use_shared_memory', 'worker_init_fn']\n",
      "<__main__.BatchDataLoader object at 0x7fdddba35470> train_mode: True, sortagrad: False, batch_size: 20, maxlen_in: inf, maxlen_out: inf, batch_count: auto, batch_bins: 0, batch_frames_in: 0, batch_frames_out: 0, batch_frames_inout: 0, subsampling_factor: 1, num_encs: 1, num_workers: 1, file: /workspace/zhanghui/DeepSpeech-2.x/examples/librispeech/s2/data/manifest.dev\n",
      "278\n"
     ]
    }
   ],
   "source": [
    "print(len(train.dataloader), dir(train.dataloader), train, len(train),\n",
    "      sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 220,
   "id": "a5ba7d6e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['7601-101619-0003', '1255-138279-0000', '1272-128104-0004', '6123-59150-0027', '2078-142845-0025', '7850-73752-0018', '4570-24733-0004', '2506-169427-0002', '7601-101619-0004', '3170-137482-0000', '6267-53049-0019', '4570-14911-0009', '174-168635-0018', '7601-291468-0004', '3576-138058-0022', '1919-142785-0007', '6467-62797-0007', '4153-61735-0005', '1686-142278-0003', '2506-169427-0000']\n",
      "Tensor(shape=[20, 2961, 83], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n",
      "       [[[-1.99415934, -1.80315673, -1.88801885, ...,  0.86933994, -0.59853148,  0.02596200],\n",
      "         [-1.95346808, -1.84891188, -2.17492867, ...,  0.83640492, -0.59853148, -0.11333394],\n",
      "         [-2.27899861, -2.21495342, -2.58480024, ...,  0.91874266, -0.59853148, -0.31453922],\n",
      "         ...,\n",
      "         [-2.64522028, -2.35221887, -2.91269732, ...,  1.48994756, -0.16100442,  0.36646330],\n",
      "         [-2.40107250, -2.21495342, -2.37986445, ...,  1.44072104, -0.13220564,  0.12656468],\n",
      "         [-2.15692472, -1.89466715, -2.25690317, ...,  1.31273174, -0.09620714, -0.15202725]],\n",
      "\n",
      "        [[-0.28859532, -0.29033494, -0.86576819, ...,  1.37753224, -0.30570769,  0.25806731],\n",
      "         [-0.20149794, -0.17814466, -0.59891301, ...,  1.35188794, -0.30570769, -0.02964944],\n",
      "         [-0.34947991, -0.33597648, -0.96877253, ...,  1.38394332, -0.30570769, -0.38376236],\n",
      "         ...,\n",
      "         [ 0.        ,  0.        ,  0.        , ...,  0.        ,  0.        ,  0.        ],\n",
      "         [ 0.        ,  0.        ,  0.        , ...,  0.        ,  0.        ,  0.        ],\n",
      "         [ 0.        ,  0.        ,  0.        , ...,  0.        ,  0.        ,  0.        ]],\n",
      "\n",
      "        [[-0.44914246, -0.33902276, -0.78237975, ...,  1.38218808,  0.29214793, -0.16815147],\n",
      "         [-0.55490732, -0.41596055, -0.84425378, ...,  1.34530187,  0.25002354, -0.04004869],\n",
      "         [-0.83694696, -0.62112784, -1.07112527, ...,  1.19160914,  0.20789915,  0.37984371],\n",
      "         ...,\n",
      "         [ 0.        ,  0.        ,  0.        , ...,  0.        ,  0.        ,  0.        ],\n",
      "         [ 0.        ,  0.        ,  0.        , ...,  0.        ,  0.        ,  0.        ],\n",
      "         [ 0.        ,  0.        ,  0.        , ...,  0.        ,  0.        ,  0.        ]],\n",
      "\n",
      "        ...,\n",
      "\n",
      "        [[-1.24343657, -0.94188881, -1.41092563, ...,  0.96716309,  0.60345763,  0.15360183],\n",
      "         [-1.19466043, -0.80585432, -0.49723154, ...,  1.06735480,  0.60345763,  0.14511746],\n",
      "         [-0.94079566, -0.59330046, -0.40948665, ...,  0.82244170,  0.55614340,  0.28086722],\n",
      "         ...,\n",
      "         [ 0.        ,  0.        ,  0.        , ...,  0.        ,  0.        ,  0.        ],\n",
      "         [ 0.        ,  0.        ,  0.        , ...,  0.        ,  0.        ,  0.        ],\n",
      "         [ 0.        ,  0.        ,  0.        , ...,  0.        ,  0.        ,  0.        ]],\n",
      "\n",
      "        [[ 0.21757117,  0.11361472, -0.33262897, ...,  0.76338506, -0.10711290, -0.57754958],\n",
      "         [-1.00205481, -0.61152041, -0.47124696, ...,  1.11897349, -0.10711290,  0.24931324],\n",
      "         [-1.03929281, -1.20336759, -1.16433656, ...,  0.88888687, -0.10711290, -0.04115745],\n",
      "         ...,\n",
      "         [ 0.        ,  0.        ,  0.        , ...,  0.        ,  0.        ,  0.        ],\n",
      "         [ 0.        ,  0.        ,  0.        , ...,  0.        ,  0.        ,  0.        ],\n",
      "         [ 0.        ,  0.        ,  0.        , ...,  0.        ,  0.        ,  0.        ]],\n",
      "\n",
      "        [[-1.25289667, -1.05046368, -0.82881606, ...,  1.23991334,  0.61702502,  0.05275881],\n",
      "         [-1.19659519, -0.78677225, -0.80407262, ...,  1.27644968,  0.61702502, -0.35079369],\n",
      "         [-1.49687004, -1.01750231, -0.82881606, ...,  1.29106426,  0.65006059,  0.17958963],\n",
      "         ...,\n",
      "         [ 0.        ,  0.        ,  0.        , ...,  0.        ,  0.        ,  0.        ],\n",
      "         [ 0.        ,  0.        ,  0.        , ...,  0.        ,  0.        ,  0.        ],\n",
      "         [ 0.        ,  0.        ,  0.        , ...,  0.        ,  0.        ,  0.        ]]])\n",
      "Tensor(shape=[20], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
      "       [2961, 2948, 2938, 2907, 2904, 2838, 2832, 2819, 2815, 2797, 2775, 2710, 2709, 2696, 2688, 2661, 2616, 2595, 2589, 2576])\n",
      "Tensor(shape=[20, 133], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
      "       [[3098, 1595,  389, ..., -1  , -1  , -1  ],\n",
      "        [2603, 4832,  482, ..., -1  , -1  , -1  ],\n",
      "        [2796,  303,  269, ..., -1  , -1  , -1  ],\n",
      "        ...,\n",
      "        [3218, 3673,  206, ..., -1  , -1  , -1  ],\n",
      "        [2371, 4832, 4031, ..., -1  , -1  , -1  ],\n",
      "        [2570, 2433, 4285, ..., -1  , -1  , -1  ]])\n",
      "Tensor(shape=[20], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
      "       [80 , 83 , 102, 133, 82 , 102, 71 , 91 , 68 , 81 , 86 , 67 , 71 , 95 , 65 , 88 , 97 , 98 , 89 , 72 ])\n"
     ]
    }
   ],
   "source": [
    "# Peek at the first batch only: unpack it, print every field on its own line, then stop.\n",
    "for batch in train:\n",
    "    utts, xs, ilens, ys, olens = batch\n",
    "    print(utts, xs, ilens, ys, olens, sep='\\n')\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c974a1e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}