diff --git a/.flake8 b/.flake8
index 722899439..44685f23a 100644
--- a/.flake8
+++ b/.flake8
@@ -42,6 +42,10 @@ ignore =
     # these ignores are from flake8-comprehensions; please fix!
     C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415
 
+
+per-file-ignores =
+    */__init__.py: F401
+
 # Specify the list of error codes you wish Flake8 to report.
 select =
     E,
diff --git a/.notebook/espnet_dataloader.ipynb b/.notebook/espnet_dataloader.ipynb
new file mode 100644
index 000000000..1bfc13e3c
--- /dev/null
+++ b/.notebook/espnet_dataloader.ipynb
@@ -0,0 +1,1541 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 147,
+   "id": "extensive-venice",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "/\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "'/'"
+      ]
+     },
+     "execution_count": 147,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "%cd ..\n",
+    "%pwd"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 148,
+   "id": "correct-window",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "manifest.dev\t  manifest.test-clean\t   manifest.train\r\n",
+      "manifest.dev.raw  manifest.test-clean.raw  manifest.train.raw\r\n"
+     ]
+    }
+   ],
+   "source": [
+    "!ls /workspace/zhanghui/DeepSpeech-2.x/examples/librispeech/s2/data/"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 149,
+   "id": "exceptional-cheese",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dev_data='/workspace/zhanghui/DeepSpeech-2.x/examples/librispeech/s2/data/manifest.dev'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 150,
+   "id": "extraordinary-orleans",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from deepspeech.frontend.utility import read_manifest"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 151,
+   "id": "returning-lighter",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dev_json = read_manifest(dev_data)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 152,
+   "id": "western-founder",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'input': [{'feat': '/workspace/zhanghui/asr/espnet/egs/librispeech/asr1/dump/dev/deltafalse/feats.1.ark:16',\n",
+      "            'name': 'input1',\n",
+      "            'shape': [1063, 83]}],\n",
+      " 'output': [{'name': 'target1',\n",
+      "             'shape': [41, 5002],\n",
+      "             'text': 'AS I APPROACHED THE CITY I HEARD BELLS RINGING AND A '\n",
+      "                     'LITTLE LATER I FOUND THE STREETS ASTIR WITH THRONGS OF '\n",
+      "                     'WELL DRESSED PEOPLE IN FAMILY GROUPS WENDING THEIR WAY '\n",
+      "                     'HITHER AND THITHER',\n",
+      "             'token': '▁AS ▁I ▁APPROACHED ▁THE ▁CITY ▁I ▁HEARD ▁BELL S ▁RING '\n",
+      "                      'ING ▁AND ▁A ▁LITTLE ▁LATER ▁I ▁FOUND ▁THE ▁STREETS ▁AS '\n",
+      "                      'T IR ▁WITH ▁THRONG S ▁OF ▁WELL ▁DRESSED ▁PEOPLE ▁IN '\n",
+      "                      '▁FAMILY ▁GROUP S ▁WE ND ING ▁THEIR ▁WAY ▁HITHER ▁AND '\n",
+      "                      '▁THITHER',\n",
+      "             'tokenid': '713 2458 676 4502 1155 2458 2351 849 389 3831 206 627 '\n",
+      "                        '482 2812 2728 2458 2104 4502 4316 713 404 212 4925 '\n",
+      "                        '4549 389 3204 4861 1677 3339 2495 1950 2279 389 4845 '\n",
+      "                        '302 206 4504 4843 2394 627 4526'}],\n",
+      " 'utt': '116-288045-0000',\n",
+      " 'utt2spk': '116-288045'}\n",
+      "5542\n",
+      "<class 'list'>\n"
+     ]
+    }
+   ],
+   "source": [
+    "from pprint import pprint\n",
+    "pprint(dev_json[0])\n",
+    "print(len(dev_json))\n",
+    "print(type(dev_json))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 97,
+   "id": "motivated-receptor",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.\n",
+    "#\n",
+    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+    "# you may not use this file except in compliance with the License.\n",
+    "# You may obtain a copy of the License at\n",
+    "#\n",
+    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
+    "#\n",
+    "# Unless required by applicable law or agreed to in writing, software\n",
+    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+    "# See the License for the specific language governing permissions and\n",
+    "# limitations under the License.\n",
+    "import itertools\n",
+    "\n",
+    "import numpy as np\n",
+    "\n",
+    "from deepspeech.utils.log import Log\n",
+    "\n",
+    "__all__ = [\"make_batchset\"]\n",
+    "\n",
+    "logger = Log(__name__).getlog()\n",
+    "\n",
+    "\n",
+    "def batchfy_by_seq(\n",
+    "        sorted_data,\n",
+    "        batch_size,\n",
+    "        max_length_in,\n",
+    "        max_length_out,\n",
+    "        min_batch_size=1,\n",
+    "        shortest_first=False,\n",
+    "        ikey=\"input\",\n",
+    "        iaxis=0,\n",
+    "        okey=\"output\",\n",
+    "        oaxis=0, ):\n",
+    "    \"\"\"Make batch set from json dictionary\n",
+    "\n",
+    "    :param List[(str, Dict[str, Any])] sorted_data: dictionary loaded from data.json\n",
+    "    :param int batch_size: batch size\n",
+    "    :param int max_length_in: maximum length of input to decide adaptive batch size\n",
+    "    :param int max_length_out: maximum length of output to decide adaptive batch size\n",
+    "    :param int min_batch_size: mininum batch size (for multi-gpu)\n",
+    "    :param bool shortest_first: Sort from batch with shortest samples\n",
+    "        to longest if true, otherwise reverse\n",
+    "    :param str ikey: key to access input\n",
+    "        (for ASR ikey=\"input\", for TTS, MT ikey=\"output\".)\n",
+    "    :param int iaxis: dimension to access input\n",
+    "        (for ASR, TTS iaxis=0, for MT iaxis=\"1\".)\n",
+    "    :param str okey: key to access output\n",
+    "        (for ASR, MT okey=\"output\". for TTS okey=\"input\".)\n",
+    "    :param int oaxis: dimension to access output\n",
+    "        (for ASR, TTS, MT oaxis=0, reserved for future research, -1 means all axis.)\n",
+    "    :return: List[List[Tuple[str, dict]]] list of batches\n",
+    "    \"\"\"\n",
+    "    if batch_size <= 0:\n",
+    "        raise ValueError(f\"Invalid batch_size={batch_size}\")\n",
+    "\n",
+    "    # check #utts is more than min_batch_size\n",
+    "    if len(sorted_data) < min_batch_size:\n",
+    "        raise ValueError(\n",
+    "            f\"#utts({len(sorted_data)}) is less than min_batch_size({min_batch_size}).\"\n",
+    "        )\n",
+    "\n",
+    "    # make list of minibatches\n",
+    "    minibatches = []\n",
+    "    start = 0\n",
+    "    while True:\n",
+    "        _, info = sorted_data[start]\n",
+    "        ilen = int(info[ikey][iaxis][\"shape\"][0])\n",
+    "        olen = (int(info[okey][oaxis][\"shape\"][0]) if oaxis >= 0 else\n",
+    "                max(map(lambda x: int(x[\"shape\"][0]), info[okey])))\n",
+    "        factor = max(int(ilen / max_length_in), int(olen / max_length_out))\n",
+    "        # change batchsize depending on the input and output length\n",
+    "        # if ilen = 1000 and max_length_in = 800\n",
+    "        # then b = batchsize / 2\n",
+    "        # and max(min_batches, .) avoids batchsize = 0\n",
+    "        bs = max(min_batch_size, int(batch_size / (1 + factor)))\n",
+    "        end = min(len(sorted_data), start + bs)\n",
+    "        minibatch = sorted_data[start:end]\n",
+    "        if shortest_first:\n",
+    "            minibatch.reverse()\n",
+    "\n",
+    "        # check each batch is more than minimum batchsize\n",
+    "        if len(minibatch) < min_batch_size:\n",
+    "            mod = min_batch_size - len(minibatch) % min_batch_size\n",
+    "            additional_minibatch = [\n",
+    "                sorted_data[i] for i in np.random.randint(0, start, mod)\n",
+    "            ]\n",
+    "            if shortest_first:\n",
+    "                additional_minibatch.reverse()\n",
+    "            minibatch.extend(additional_minibatch)\n",
+    "        minibatches.append(minibatch)\n",
+    "\n",
+    "        if end == len(sorted_data):\n",
+    "            break\n",
+    "        start = end\n",
+    "\n",
+    "    # batch: List[List[Tuple[str, dict]]]\n",
+    "    return minibatches\n",
+    "\n",
+    "\n",
+    "def batchfy_by_bin(\n",
+    "        sorted_data,\n",
+    "        batch_bins,\n",
+    "        num_batches=0,\n",
+    "        min_batch_size=1,\n",
+    "        shortest_first=False,\n",
+    "        ikey=\"input\",\n",
+    "        okey=\"output\", ):\n",
+    "    \"\"\"Make variably sized batch set, which maximizes\n",
+    "\n",
+    "    the number of bins up to `batch_bins`.\n",
+    "\n",
+    "    :param List[(str, Dict[str, Any])] sorted_data: dictionary loaded from data.json\n",
+    "    :param int batch_bins: Maximum frames of a batch\n",
+    "    :param int num_batches: # number of batches to use (for debug)\n",
+    "    :param int min_batch_size: minimum batch size (for multi-gpu)\n",
+    "    :param int test: Return only every `test` batches\n",
+    "    :param bool shortest_first: Sort from batch with shortest samples\n",
+    "        to longest if true, otherwise reverse\n",
+    "\n",
+    "    :param str ikey: key to access input (for ASR ikey=\"input\", for TTS ikey=\"output\".)\n",
+    "    :param str okey: key to access output (for ASR okey=\"output\". for TTS okey=\"input\".)\n",
+    "\n",
+    "    :return: List[Tuple[str, Dict[str, List[Dict[str, Any]]]] list of batches\n",
+    "    \"\"\"\n",
+    "    if batch_bins <= 0:\n",
+    "        raise ValueError(f\"invalid batch_bins={batch_bins}\")\n",
+    "    length = len(sorted_data)\n",
+    "    idim = int(sorted_data[0][1][ikey][0][\"shape\"][1])\n",
+    "    odim = int(sorted_data[0][1][okey][0][\"shape\"][1])\n",
+    "    logger.info(\"# utts: \" + str(len(sorted_data)))\n",
+    "    minibatches = []\n",
+    "    start = 0\n",
+    "    n = 0\n",
+    "    while True:\n",
+    "        # Dynamic batch size depending on size of samples\n",
+    "        b = 0\n",
+    "        next_size = 0\n",
+    "        max_olen = 0\n",
+    "        while next_size < batch_bins and (start + b) < length:\n",
+    "            ilen = int(sorted_data[start + b][1][ikey][0][\"shape\"][0]) * idim\n",
+    "            olen = int(sorted_data[start + b][1][okey][0][\"shape\"][0]) * odim\n",
+    "            if olen > max_olen:\n",
+    "                max_olen = olen\n",
+    "            next_size = (max_olen + ilen) * (b + 1)\n",
+    "            if next_size <= batch_bins:\n",
+    "                b += 1\n",
+    "            elif next_size == 0:\n",
+    "                raise ValueError(\n",
+    "                    f\"Can't fit one sample in batch_bins ({batch_bins}): \"\n",
+    "                    f\"Please increase the value\")\n",
+    "        end = min(length, start + max(min_batch_size, b))\n",
+    "        batch = sorted_data[start:end]\n",
+    "        if shortest_first:\n",
+    "            batch.reverse()\n",
+    "        minibatches.append(batch)\n",
+    "        # Check for min_batch_size and fixes the batches if needed\n",
+    "        i = -1\n",
+    "        while len(minibatches[i]) < min_batch_size:\n",
+    "            missing = min_batch_size - len(minibatches[i])\n",
+    "            if -i == len(minibatches):\n",
+    "                minibatches[i + 1].extend(minibatches[i])\n",
+    "                minibatches = minibatches[1:]\n",
+    "                break\n",
+    "            else:\n",
+    "                minibatches[i].extend(minibatches[i - 1][:missing])\n",
+    "                minibatches[i - 1] = minibatches[i - 1][missing:]\n",
+    "                i -= 1\n",
+    "        if end == length:\n",
+    "            break\n",
+    "        start = end\n",
+    "        n += 1\n",
+    "    if num_batches > 0:\n",
+    "        minibatches = minibatches[:num_batches]\n",
+    "    lengths = [len(x) for x in minibatches]\n",
+    "    logger.info(\n",
+    "        str(len(minibatches)) + \" batches containing from \" + str(min(lengths))\n",
+    "        + \" to \" + str(max(lengths)) + \" samples \" + \"(avg \" + str(\n",
+    "            int(np.mean(lengths))) + \" samples).\")\n",
+    "    return minibatches\n",
+    "\n",
+    "\n",
+    "def batchfy_by_frame(\n",
+    "        sorted_data,\n",
+    "        max_frames_in,\n",
+    "        max_frames_out,\n",
+    "        max_frames_inout,\n",
+    "        num_batches=0,\n",
+    "        min_batch_size=1,\n",
+    "        shortest_first=False,\n",
+    "        ikey=\"input\",\n",
+    "        okey=\"output\", ):\n",
+    "    \"\"\"Make variable batch set, which maximizes the number of frames to max_batch_frame.\n",
+    "\n",
+    "    :param List[(str, Dict[str, Any])] sorteddata: dictionary loaded from data.json\n",
+    "    :param int max_frames_in: Maximum input frames of a batch\n",
+    "    :param int max_frames_out: Maximum output frames of a batch\n",
+    "    :param int max_frames_inout: Maximum input+output frames of a batch\n",
+    "    :param int num_batches: # number of batches to use (for debug)\n",
+    "    :param int min_batch_size: minimum batch size (for multi-gpu)\n",
+    "    :param int test: Return only every `test` batches\n",
+    "    :param bool shortest_first: Sort from batch with shortest samples\n",
+    "        to longest if true, otherwise reverse\n",
+    "\n",
+    "    :param str ikey: key to access input (for ASR ikey=\"input\", for TTS ikey=\"output\".)\n",
+    "    :param str okey: key to access output (for ASR okey=\"output\". for TTS okey=\"input\".)\n",
+    "\n",
+    "    :return: List[Tuple[str, Dict[str, List[Dict[str, Any]]]] list of batches\n",
+    "    \"\"\"\n",
+    "    if max_frames_in <= 0 and max_frames_out <= 0 and max_frames_inout <= 0:\n",
+    "        raise ValueError(\n",
+    "            \"At least, one of `--batch-frames-in`, `--batch-frames-out` or \"\n",
+    "            \"`--batch-frames-inout` should be > 0\")\n",
+    "    length = len(sorted_data)\n",
+    "    minibatches = []\n",
+    "    start = 0\n",
+    "    end = 0\n",
+    "    while end != length:\n",
+    "        # Dynamic batch size depending on size of samples\n",
+    "        b = 0\n",
+    "        max_olen = 0\n",
+    "        max_ilen = 0\n",
+    "        while (start + b) < length:\n",
+    "            ilen = int(sorted_data[start + b][1][ikey][0][\"shape\"][0])\n",
+    "            if ilen > max_frames_in and max_frames_in != 0:\n",
+    "                raise ValueError(\n",
+    "                    f\"Can't fit one sample in --batch-frames-in ({max_frames_in}): \"\n",
+    "                    f\"Please increase the value\")\n",
+    "            olen = int(sorted_data[start + b][1][okey][0][\"shape\"][0])\n",
+    "            if olen > max_frames_out and max_frames_out != 0:\n",
+    "                raise ValueError(\n",
+    "                    f\"Can't fit one sample in --batch-frames-out ({max_frames_out}): \"\n",
+    "                    f\"Please increase the value\")\n",
+    "            if ilen + olen > max_frames_inout and max_frames_inout != 0:\n",
+    "                raise ValueError(\n",
+    "                    f\"Can't fit one sample in --batch-frames-out ({max_frames_inout}): \"\n",
+    "                    f\"Please increase the value\")\n",
+    "            max_olen = max(max_olen, olen)\n",
+    "            max_ilen = max(max_ilen, ilen)\n",
+    "            in_ok = max_ilen * (b + 1) <= max_frames_in or max_frames_in == 0\n",
+    "            out_ok = max_olen * (b + 1) <= max_frames_out or max_frames_out == 0\n",
+    "            inout_ok = (max_ilen + max_olen) * (\n",
+    "                b + 1) <= max_frames_inout or max_frames_inout == 0\n",
+    "            if in_ok and out_ok and inout_ok:\n",
+    "                # add more seq in the minibatch\n",
+    "                b += 1\n",
+    "            else:\n",
+    "                # no more seq in the minibatch\n",
+    "                break\n",
+    "        end = min(length, start + b)\n",
+    "        batch = sorted_data[start:end]\n",
+    "        if shortest_first:\n",
+    "            batch.reverse()\n",
+    "        minibatches.append(batch)\n",
+    "        # Check for min_batch_size and fixes the batches if needed\n",
+    "        i = -1\n",
+    "        while len(minibatches[i]) < min_batch_size:\n",
+    "            missing = min_batch_size - len(minibatches[i])\n",
+    "            if -i == len(minibatches):\n",
+    "                minibatches[i + 1].extend(minibatches[i])\n",
+    "                minibatches = minibatches[1:]\n",
+    "                break\n",
+    "            else:\n",
+    "                minibatches[i].extend(minibatches[i - 1][:missing])\n",
+    "                minibatches[i - 1] = minibatches[i - 1][missing:]\n",
+    "                i -= 1\n",
+    "        start = end\n",
+    "    if num_batches > 0:\n",
+    "        minibatches = minibatches[:num_batches]\n",
+    "    lengths = [len(x) for x in minibatches]\n",
+    "    logger.info(\n",
+    "        str(len(minibatches)) + \" batches containing from \" + str(min(lengths))\n",
+    "        + \" to \" + str(max(lengths)) + \" samples\" + \"(avg \" + str(\n",
+    "            int(np.mean(lengths))) + \" samples).\")\n",
+    "\n",
+    "    return minibatches\n",
+    "\n",
+    "\n",
+    "def batchfy_shuffle(data, batch_size, min_batch_size, num_batches,\n",
+    "                    shortest_first):\n",
+    "    import random\n",
+    "\n",
+    "    logger.info(\"use shuffled batch.\")\n",
+    "    sorted_data = random.sample(data.items(), len(data.items()))\n",
+    "    logger.info(\"# utts: \" + str(len(sorted_data)))\n",
+    "    # make list of minibatches\n",
+    "    minibatches = []\n",
+    "    start = 0\n",
+    "    while True:\n",
+    "        end = min(len(sorted_data), start + batch_size)\n",
+    "        # check each batch is more than minimum batchsize\n",
+    "        minibatch = sorted_data[start:end]\n",
+    "        if shortest_first:\n",
+    "            minibatch.reverse()\n",
+    "        if len(minibatch) < min_batch_size:\n",
+    "            mod = min_batch_size - len(minibatch) % min_batch_size\n",
+    "            additional_minibatch = [\n",
+    "                sorted_data[i] for i in np.random.randint(0, start, mod)\n",
+    "            ]\n",
+    "            if shortest_first:\n",
+    "                additional_minibatch.reverse()\n",
+    "            minibatch.extend(additional_minibatch)\n",
+    "        minibatches.append(minibatch)\n",
+    "        if end == len(sorted_data):\n",
+    "            break\n",
+    "        start = end\n",
+    "\n",
+    "    # for debugging\n",
+    "    if num_batches > 0:\n",
+    "        minibatches = minibatches[:num_batches]\n",
+    "        logger.info(\"# minibatches: \" + str(len(minibatches)))\n",
+    "    return minibatches\n",
+    "\n",
+    "\n",
+    "BATCH_COUNT_CHOICES = [\"auto\", \"seq\", \"bin\", \"frame\"]\n",
+    "BATCH_SORT_KEY_CHOICES = [\"input\", \"output\", \"shuffle\"]\n",
+    "\n",
+    "\n",
+    "def make_batchset(\n",
+    "        data,\n",
+    "        batch_size=0,\n",
+    "        max_length_in=float(\"inf\"),\n",
+    "        max_length_out=float(\"inf\"),\n",
+    "        num_batches=0,\n",
+    "        min_batch_size=1,\n",
+    "        shortest_first=False,\n",
+    "        batch_sort_key=\"input\",\n",
+    "        count=\"auto\",\n",
+    "        batch_bins=0,\n",
+    "        batch_frames_in=0,\n",
+    "        batch_frames_out=0,\n",
+    "        batch_frames_inout=0,\n",
+    "        iaxis=0,\n",
+    "        oaxis=0, ):\n",
+    "    \"\"\"Make batch set from json dictionary\n",
+    "\n",
+    "    if utts have \"category\" value,\n",
+    "\n",
+    "        >>> data = {'utt1': {'category': 'A', 'input': ...},\n",
+    "        ...         'utt2': {'category': 'B', 'input': ...},\n",
+    "        ...         'utt3': {'category': 'B', 'input': ...},\n",
+    "        ...         'utt4': {'category': 'A', 'input': ...}}\n",
+    "        >>> make_batchset(data, batchsize=2, ...)\n",
+    "        [[('utt1', ...), ('utt4', ...)], [('utt2', ...), ('utt3': ...)]]\n",
+    "\n",
+    "    Note that if any utts doesn't have \"category\",\n",
+    "    perform as same as batchfy_by_{count}\n",
+    "\n",
+    "    :param List[Dict[str, Any]] data: dictionary loaded from data.json\n",
+    "    :param int batch_size: maximum number of sequences in a minibatch.\n",
+    "    :param int batch_bins: maximum number of bins (frames x dim) in a minibatch.\n",
+    "    :param int batch_frames_in:  maximum number of input frames in a minibatch.\n",
+    "    :param int batch_frames_out: maximum number of output frames in a minibatch.\n",
+    "    :param int batch_frames_out: maximum number of input+output frames in a minibatch.\n",
+    "    :param str count: strategy to count maximum size of batch.\n",
+    "        For choices, see espnet.asr.batchfy.BATCH_COUNT_CHOICES\n",
+    "\n",
+    "    :param int max_length_in: maximum length of input to decide adaptive batch size\n",
+    "    :param int max_length_out: maximum length of output to decide adaptive batch size\n",
+    "    :param int num_batches: # number of batches to use (for debug)\n",
+    "    :param int min_batch_size: minimum batch size (for multi-gpu)\n",
+    "    :param bool shortest_first: Sort from batch with shortest samples\n",
+    "        to longest if true, otherwise reverse\n",
+    "    :param str batch_sort_key: how to sort data before creating minibatches\n",
+    "        [\"input\", \"output\", \"shuffle\"]\n",
+    "    :param bool swap_io: if True, use \"input\" as output and \"output\"\n",
+    "        as input in `data` dict\n",
+    "    :param bool mt: if True, use 0-axis of \"output\" as output and 1-axis of \"output\"\n",
+    "        as input in `data` dict\n",
+    "    :param int iaxis: dimension to access input\n",
+    "        (for ASR, TTS iaxis=0, for MT iaxis=\"1\".)\n",
+    "    :param int oaxis: dimension to access output (for ASR, TTS, MT oaxis=0,\n",
+    "        reserved for future research, -1 means all axis.)\n",
+    "    :return: List[List[Tuple[str, dict]]] list of batches\n",
+    "    \"\"\"\n",
+    "\n",
+    "    # check args\n",
+    "    if count not in BATCH_COUNT_CHOICES:\n",
+    "        raise ValueError(\n",
+    "            f\"arg 'count' ({count}) should be one of {BATCH_COUNT_CHOICES}\")\n",
+    "    if batch_sort_key not in BATCH_SORT_KEY_CHOICES:\n",
+    "        raise ValueError(f\"arg 'batch_sort_key' ({batch_sort_key}) should be \"\n",
+    "                         f\"one of {BATCH_SORT_KEY_CHOICES}\")\n",
+    "\n",
+    "    ikey = \"input\"\n",
+    "    okey = \"output\"\n",
+    "    batch_sort_axis = 0  # index of list \n",
+    "\n",
+    "    if count == \"auto\":\n",
+    "        if batch_size != 0:\n",
+    "            count = \"seq\"\n",
+    "        elif batch_bins != 0:\n",
+    "            count = \"bin\"\n",
+    "        elif batch_frames_in != 0 or batch_frames_out != 0 or batch_frames_inout != 0:\n",
+    "            count = \"frame\"\n",
+    "        else:\n",
+    "            raise ValueError(\n",
+    "                f\"cannot detect `count` manually set one of {BATCH_COUNT_CHOICES}\"\n",
+    "            )\n",
+    "        logger.info(f\"count is auto detected as {count}\")\n",
+    "\n",
+    "    if count != \"seq\" and batch_sort_key == \"shuffle\":\n",
+    "        raise ValueError(\n",
+    "            \"batch_sort_key=shuffle is only available if batch_count=seq\")\n",
+    "\n",
+    "    category2data = {}  # Dict[str, dict]\n",
+    "    for v in data:\n",
+    "        k = v['utt']\n",
+    "        category2data.setdefault(v.get(\"category\"), {})[k] = v\n",
+    "\n",
+    "    batches_list = []  # List[List[List[Tuple[str, dict]]]]\n",
+    "    for d in category2data.values():\n",
+    "        if batch_sort_key == \"shuffle\":\n",
+    "            batches = batchfy_shuffle(d, batch_size, min_batch_size,\n",
+    "                                      num_batches, shortest_first)\n",
+    "            batches_list.append(batches)\n",
+    "            continue\n",
+    "\n",
+    "        # sort it by input lengths (long to short)\n",
+    "        sorted_data = sorted(\n",
+    "            d.items(),\n",
+    "            key=lambda data: int(data[1][batch_sort_key][batch_sort_axis][\"shape\"][0]),\n",
+    "            reverse=not shortest_first, )\n",
+    "        logger.info(\"# utts: \" + str(len(sorted_data)))\n",
+    "        \n",
+    "        if count == \"seq\":\n",
+    "            batches = batchfy_by_seq(\n",
+    "                sorted_data,\n",
+    "                batch_size=batch_size,\n",
+    "                max_length_in=max_length_in,\n",
+    "                max_length_out=max_length_out,\n",
+    "                min_batch_size=min_batch_size,\n",
+    "                shortest_first=shortest_first,\n",
+    "                ikey=ikey,\n",
+    "                iaxis=iaxis,\n",
+    "                okey=okey,\n",
+    "                oaxis=oaxis, )\n",
+    "        if count == \"bin\":\n",
+    "            batches = batchfy_by_bin(\n",
+    "                sorted_data,\n",
+    "                batch_bins=batch_bins,\n",
+    "                min_batch_size=min_batch_size,\n",
+    "                shortest_first=shortest_first,\n",
+    "                ikey=ikey,\n",
+    "                okey=okey, )\n",
+    "        if count == \"frame\":\n",
+    "            batches = batchfy_by_frame(\n",
+    "                sorted_data,\n",
+    "                max_frames_in=batch_frames_in,\n",
+    "                max_frames_out=batch_frames_out,\n",
+    "                max_frames_inout=batch_frames_inout,\n",
+    "                min_batch_size=min_batch_size,\n",
+    "                shortest_first=shortest_first,\n",
+    "                ikey=ikey,\n",
+    "                okey=okey, )\n",
+    "        batches_list.append(batches)\n",
+    "\n",
+    "    if len(batches_list) == 1:\n",
+    "        batches = batches_list[0]\n",
+    "    else:\n",
+    "        # Concat list. This way is faster than \"sum(batch_list, [])\"\n",
+    "        batches = list(itertools.chain(*batches_list))\n",
+    "\n",
+    "    # for debugging\n",
+    "    if num_batches > 0:\n",
+    "        batches = batches[:num_batches]\n",
+    "    logger.info(\"# minibatches: \" + str(len(batches)))\n",
+    "\n",
+    "    # batch: List[List[Tuple[str, dict]]]\n",
+    "    return batches\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 98,
+   "id": "acquired-hurricane",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "[INFO 2021/08/18 06:57:10 1445365138.py:284] use shuffled batch.\n",
+      "[INFO 2021/08/18 06:57:10 1445365138.py:286] # utts: 5542\n",
+      "[INFO 2021/08/18 06:57:10 1445365138.py:468] # minibatches: 555\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "555\n"
+     ]
+    }
+   ],
+   "source": [
+    "batch_size=10\n",
+    "maxlen_in=300\n",
+    "maxlen_out=400\n",
+    "minibatches=0  # for debug\n",
+    "min_batch_size=2\n",
+    "use_sortagrad=True\n",
+    "batch_count='seq'\n",
+    "batch_bins=0\n",
+    "batch_frames_in=3000\n",
+    "batch_frames_out=0\n",
+    "batch_frames_inout=0\n",
+    "            \n",
+    "dev_data = make_batchset(\n",
+    "            dev_json,\n",
+    "            batch_size,\n",
+    "            maxlen_in,\n",
+    "            maxlen_out,\n",
+    "            minibatches,  # for debug\n",
+    "            min_batch_size=min_batch_size,\n",
+    "            shortest_first=use_sortagrad,\n",
+    "            batch_sort_key=\"shuffle\",\n",
+    "            count=batch_count,\n",
+    "            batch_bins=batch_bins,\n",
+    "            batch_frames_in=batch_frames_in,\n",
+    "            batch_frames_out=batch_frames_out,\n",
+    "            batch_frames_inout=batch_frames_inout,\n",
+    "            iaxis=0,\n",
+    "            oaxis=0, )\n",
+    "print(len(dev_data))\n",
+    "# for i in range(len(dev_data)):\n",
+    "#     print(len(dev_data[i]))\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 99,
+   "id": "warming-malpractice",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Requirement already satisfied: kaldiio in ./DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (2.17.2)\n",
+      "Requirement already satisfied: numpy in ./DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numpy-1.21.2-py3.7-linux-x86_64.egg (from kaldiio) (1.21.2)\n",
+      "\u001b[33mWARNING: You are using pip version 20.3.3; however, version 21.2.4 is available.\n",
+      "You should consider upgrading via the '/workspace/zhanghui/DeepSpeech-2.x/tools/venv/bin/python -m pip install --upgrade pip' command.\u001b[0m\n"
+     ]
+    }
+   ],
+   "source": [
+    "!pip install kaldiio"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "equipped-subject",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 100,
+   "id": "superb-methodology",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from collections import OrderedDict\n",
+    "import kaldiio\n",
+    "\n",
+    "class LoadInputsAndTargets():\n",
+    "    \"\"\"Create a mini-batch from a list of dicts\n",
+    "\n",
+    "    >>> batch = [('utt1',\n",
+    "    ...           dict(input=[dict(feat='some.ark:123',\n",
+    "    ...                            filetype='mat',\n",
+    "    ...                            name='input1',\n",
+    "    ...                            shape=[100, 80])],\n",
+    "    ...                output=[dict(tokenid='1 2 3 4',\n",
+    "    ...                             name='target1',\n",
+    "    ...                             shape=[4, 31])]]))\n",
+    "    >>> l = LoadInputsAndTargets()\n",
+    "    >>> feat, target = l(batch)\n",
+    "\n",
+    "    :param: str mode: Specify the task mode, \"asr\" or \"tts\"\n",
+    "    :param: str preprocess_conf: The path of a json file for pre-processing\n",
+    "    :param: bool load_input: If False, not to load the input data\n",
+    "    :param: bool load_output: If False, not to load the output data\n",
+    "    :param: bool sort_in_input_length: Sort the mini-batch in descending order\n",
+    "        of the input length\n",
+    "    :param: bool use_speaker_embedding: Used for tts mode only\n",
+    "    :param: bool use_second_target: Used for tts mode only\n",
+    "    :param: dict preprocess_args: Set some optional arguments for preprocessing\n",
+    "    :param: Optional[dict] preprocess_args: Used for tts mode only\n",
+    "    \"\"\"\n",
+    "\n",
+    "    def __init__(\n",
+    "            self,\n",
+    "            mode=\"asr\",\n",
+    "            preprocess_conf=None,\n",
+    "            load_input=True,\n",
+    "            load_output=True,\n",
+    "            sort_in_input_length=True,\n",
+    "            preprocess_args=None,\n",
+    "            keep_all_data_on_mem=False, ):\n",
+    "        self._loaders = {}\n",
+    "\n",
+    "        if mode not in [\"asr\"]:\n",
+    "            raise ValueError(\"Only asr are allowed: mode={}\".format(mode))\n",
+    "\n",
+    "        if preprocess_conf is not None:\n",
+    "            self.preprocessing = AugmentationPipeline(preprocess_conf)\n",
+    "            logging.warning(\n",
+    "                \"[Experimental feature] Some preprocessing will be done \"\n",
+    "                \"for the mini-batch creation using {}\".format(\n",
+    "                    self.preprocessing))\n",
+    "        else:\n",
+    "            # If conf doesn't exist, this function don't touch anything.\n",
+    "            self.preprocessing = None\n",
+    "\n",
+    "        self.mode = mode\n",
+    "        self.load_output = load_output\n",
+    "        self.load_input = load_input\n",
+    "        self.sort_in_input_length = sort_in_input_length\n",
+    "        if preprocess_args is None:\n",
+    "            self.preprocess_args = {}\n",
+    "        else:\n",
+    "            assert isinstance(preprocess_args, dict), type(preprocess_args)\n",
+    "            self.preprocess_args = dict(preprocess_args)\n",
+    "\n",
+    "        self.keep_all_data_on_mem = keep_all_data_on_mem\n",
+    "\n",
+    "    def __call__(self, batch, return_uttid=False):\n",
+    "        \"\"\"Function to load inputs and targets from list of dicts\n",
+    "\n",
+    "        :param List[Tuple[str, dict]] batch: list of dict which is subset of\n",
+    "            loaded data.json\n",
+    "        :param bool return_uttid: return utterance ID information for visualization\n",
+    "        :return: list of input token id sequences [(L_1), (L_2), ..., (L_B)]\n",
+    "        :return: list of input feature sequences\n",
+    "            [(T_1, D), (T_2, D), ..., (T_B, D)]\n",
+    "        :rtype: list of float ndarray\n",
+    "        :return: list of target token id sequences [(L_1), (L_2), ..., (L_B)]\n",
+    "        :rtype: list of int ndarray\n",
+    "\n",
+    "        \"\"\"\n",
+    "        x_feats_dict = OrderedDict()  # OrderedDict[str, List[np.ndarray]]\n",
+    "        y_feats_dict = OrderedDict()  # OrderedDict[str, List[np.ndarray]]\n",
+    "        uttid_list = []  # List[str]\n",
+    "\n",
+    "        for uttid, info in batch:\n",
+    "            uttid_list.append(uttid)\n",
+    "\n",
+    "            if self.load_input:\n",
+    "                # Note(kamo): This for-loop is for multiple inputs\n",
+    "                for idx, inp in enumerate(info[\"input\"]):\n",
+    "                    # {\"input\":\n",
+    "                    #  [{\"feat\": \"some/path.h5:F01_050C0101_PED_REAL\",\n",
+    "                    #    \"filetype\": \"hdf5\",\n",
+    "                    #    \"name\": \"input1\", ...}], ...}\n",
+    "                    x = self._get_from_loader(\n",
+    "                        filepath=inp[\"feat\"],\n",
+    "                        filetype=inp.get(\"filetype\", \"mat\"))\n",
+    "                    x_feats_dict.setdefault(inp[\"name\"], []).append(x)\n",
+    "\n",
+    "            if self.load_output:\n",
+    "                for idx, inp in enumerate(info[\"output\"]):\n",
+    "                    if \"tokenid\" in inp:\n",
+    "                        # ======= Legacy format for output =======\n",
+    "                        # {\"output\": [{\"tokenid\": \"1 2 3 4\"}])\n",
+    "                        x = np.fromiter(\n",
+    "                            map(int, inp[\"tokenid\"].split()), dtype=np.int64)\n",
+    "                    else:\n",
+    "                        # ======= New format =======\n",
+    "                        # {\"input\":\n",
+    "                        #  [{\"feat\": \"some/path.h5:F01_050C0101_PED_REAL\",\n",
+    "                        #    \"filetype\": \"hdf5\",\n",
+    "                        #    \"name\": \"target1\", ...}], ...}\n",
+    "                        x = self._get_from_loader(\n",
+    "                            filepath=inp[\"feat\"],\n",
+    "                            filetype=inp.get(\"filetype\", \"mat\"))\n",
+    "\n",
+    "                    y_feats_dict.setdefault(inp[\"name\"], []).append(x)\n",
+    "\n",
+    "        if self.mode == \"asr\":\n",
+    "            return_batch, uttid_list = self._create_batch_asr(\n",
+    "                x_feats_dict, y_feats_dict, uttid_list)\n",
+    "        else:\n",
+    "            raise NotImplementedError(self.mode)\n",
+    "\n",
+    "        if self.preprocessing is not None:\n",
+    "            # Apply pre-processing all input features\n",
+    "            for x_name in return_batch.keys():\n",
+    "                if x_name.startswith(\"input\"):\n",
+    "                    return_batch[x_name] = self.preprocessing(\n",
+    "                        return_batch[x_name], uttid_list,\n",
+    "                        **self.preprocess_args)\n",
+    "\n",
+    "        if return_uttid:\n",
+    "            return tuple(return_batch.values()), uttid_list\n",
+    "\n",
+    "        # Doesn't return the names now.\n",
+    "        return tuple(return_batch.values())\n",
+    "\n",
+    "    def _create_batch_asr(self, x_feats_dict, y_feats_dict, uttid_list):\n",
+    "        \"\"\"Create a OrderedDict for the mini-batch\n",
+    "\n",
+    "        :param OrderedDict x_feats_dict:\n",
+    "            e.g. {\"input1\": [ndarray, ndarray, ...],\n",
+    "                  \"input2\": [ndarray, ndarray, ...]}\n",
+    "        :param OrderedDict y_feats_dict:\n",
+    "            e.g. {\"target1\": [ndarray, ndarray, ...],\n",
+    "                  \"target2\": [ndarray, ndarray, ...]}\n",
+    "        :param: List[str] uttid_list:\n",
+    "            Give uttid_list to sort in the same order as the mini-batch\n",
+    "        :return: batch, uttid_list\n",
+    "        :rtype: Tuple[OrderedDict, List[str]]\n",
+    "        \"\"\"\n",
+    "        # handle single-input and multi-input (paralell) asr mode\n",
+    "        xs = list(x_feats_dict.values())\n",
+    "\n",
+    "        if self.load_output:\n",
+    "            ys = list(y_feats_dict.values())\n",
+    "            assert len(xs[0]) == len(ys[0]), (len(xs[0]), len(ys[0]))\n",
+    "\n",
+    "            # get index of non-zero length samples\n",
+    "            nonzero_idx = list(\n",
+    "                filter(lambda i: len(ys[0][i]) > 0, range(len(ys[0]))))\n",
+    "            for n in range(1, len(y_feats_dict)):\n",
+    "                nonzero_idx = filter(lambda i: len(ys[n][i]) > 0, nonzero_idx)\n",
+    "        else:\n",
+    "            # Note(kamo): Be careful not to make nonzero_idx to a generator\n",
+    "            nonzero_idx = list(range(len(xs[0])))\n",
+    "\n",
+    "        if self.sort_in_input_length:\n",
+    "            # sort in input lengths based on the first input\n",
+    "            nonzero_sorted_idx = sorted(\n",
+    "                nonzero_idx, key=lambda i: -len(xs[0][i]))\n",
+    "        else:\n",
+    "            nonzero_sorted_idx = nonzero_idx\n",
+    "\n",
+    "        if len(nonzero_sorted_idx) != len(xs[0]):\n",
+    "            logging.warning(\n",
+    "                \"Target sequences include empty tokenid (batch {} -> {}).\".\n",
+    "                format(len(xs[0]), len(nonzero_sorted_idx)))\n",
+    "\n",
+    "        # remove zero-length samples\n",
+    "        xs = [[x[i] for i in nonzero_sorted_idx] for x in xs]\n",
+    "        uttid_list = [uttid_list[i] for i in nonzero_sorted_idx]\n",
+    "\n",
+    "        x_names = list(x_feats_dict.keys())\n",
+    "        if self.load_output:\n",
+    "            ys = [[y[i] for i in nonzero_sorted_idx] for y in ys]\n",
+    "            y_names = list(y_feats_dict.keys())\n",
+    "\n",
+    "            # Keeping x_name and y_name, e.g. input1, for future extension\n",
+    "            return_batch = OrderedDict([\n",
+    "                * [(x_name, x) for x_name, x in zip(x_names, xs)],\n",
+    "                * [(y_name, y) for y_name, y in zip(y_names, ys)],\n",
+    "            ])\n",
+    "        else:\n",
+    "            return_batch = OrderedDict(\n",
+    "                [(x_name, x) for x_name, x in zip(x_names, xs)])\n",
+    "        return return_batch, uttid_list\n",
+    "\n",
+    "    def _get_from_loader(self, filepath, filetype):\n",
+    "        \"\"\"Return ndarray\n",
+    "\n",
+    "        In order to make the fds to be opened only at the first referring,\n",
+    "        the loader are stored in self._loaders\n",
+    "\n",
+    "        >>> ndarray = loader.get_from_loader(\n",
+    "        ...     'some/path.h5:F01_050C0101_PED_REAL', filetype='hdf5')\n",
+    "\n",
+    "        :param: str filepath:\n",
+    "        :param: str filetype:\n",
+    "        :return:\n",
+    "        :rtype: np.ndarray\n",
+    "        \"\"\"\n",
+    "        if filetype == \"hdf5\":\n",
+    "            # e.g.\n",
+    "            #    {\"input\": [{\"feat\": \"some/path.h5:F01_050C0101_PED_REAL\",\n",
+    "            #                \"filetype\": \"hdf5\",\n",
+    "            # -> filepath = \"some/path.h5\", key = \"F01_050C0101_PED_REAL\"\n",
+    "            filepath, key = filepath.split(\":\", 1)\n",
+    "\n",
+    "            loader = self._loaders.get(filepath)\n",
+    "            if loader is None:\n",
+    "                # To avoid disk access, create loader only for the first time\n",
+    "                loader = h5py.File(filepath, \"r\")\n",
+    "                self._loaders[filepath] = loader\n",
+    "            return loader[key][()]\n",
+    "        elif filetype == \"sound.hdf5\":\n",
+    "            # e.g.\n",
+    "            #    {\"input\": [{\"feat\": \"some/path.h5:F01_050C0101_PED_REAL\",\n",
+    "            #                \"filetype\": \"sound.hdf5\",\n",
+    "            # -> filepath = \"some/path.h5\", key = \"F01_050C0101_PED_REAL\"\n",
+    "            filepath, key = filepath.split(\":\", 1)\n",
+    "\n",
+    "            loader = self._loaders.get(filepath)\n",
+    "            if loader is None:\n",
+    "                # To avoid disk access, create loader only for the first time\n",
+    "                loader = SoundHDF5File(filepath, \"r\", dtype=\"int16\")\n",
+    "                self._loaders[filepath] = loader\n",
+    "            array, rate = loader[key]\n",
+    "            return array\n",
+    "        elif filetype == \"sound\":\n",
+    "            # e.g.\n",
+    "            #    {\"input\": [{\"feat\": \"some/path.wav\",\n",
+    "            #                \"filetype\": \"sound\"},\n",
+    "            # Assume PCM16\n",
+    "            if not self.keep_all_data_on_mem:\n",
+    "                array, _ = soundfile.read(filepath, dtype=\"int16\")\n",
+    "                return array\n",
+    "            if filepath not in self._loaders:\n",
+    "                array, _ = soundfile.read(filepath, dtype=\"int16\")\n",
+    "                self._loaders[filepath] = array\n",
+    "            return self._loaders[filepath]\n",
+    "        elif filetype == \"npz\":\n",
+    "            # e.g.\n",
+    "            #    {\"input\": [{\"feat\": \"some/path.npz:F01_050C0101_PED_REAL\",\n",
+    "            #                \"filetype\": \"npz\",\n",
+    "            filepath, key = filepath.split(\":\", 1)\n",
+    "\n",
+    "            loader = self._loaders.get(filepath)\n",
+    "            if loader is None:\n",
+    "                # To avoid disk access, create loader only for the first time\n",
+    "                loader = np.load(filepath)\n",
+    "                self._loaders[filepath] = loader\n",
+    "            return loader[key]\n",
+    "        elif filetype == \"npy\":\n",
+    "            # e.g.\n",
+    "            #    {\"input\": [{\"feat\": \"some/path.npy\",\n",
+    "            #                \"filetype\": \"npy\"},\n",
+    "            if not self.keep_all_data_on_mem:\n",
+    "                return np.load(filepath)\n",
+    "            if filepath not in self._loaders:\n",
+    "                self._loaders[filepath] = np.load(filepath)\n",
+    "            return self._loaders[filepath]\n",
+    "        elif filetype in [\"mat\", \"vec\"]:\n",
+    "            # e.g.\n",
+    "            #    {\"input\": [{\"feat\": \"some/path.ark:123\",\n",
+    "            #                \"filetype\": \"mat\"}]},\n",
+    "            # In this case, \"123\" indicates the starting points of the matrix\n",
+    "            # load_mat can load both matrix and vector\n",
+    "            if not self.keep_all_data_on_mem:\n",
+    "                return kaldiio.load_mat(filepath)\n",
+    "            if filepath not in self._loaders:\n",
+    "                self._loaders[filepath] = kaldiio.load_mat(filepath)\n",
+    "            return self._loaders[filepath]\n",
+    "        elif filetype == \"scp\":\n",
+    "            # e.g.\n",
+    "            #    {\"input\": [{\"feat\": \"some/path.scp:F01_050C0101_PED_REAL\",\n",
+    "            #                \"filetype\": \"scp\",\n",
+    "            filepath, key = filepath.split(\":\", 1)\n",
+    "            loader = self._loaders.get(filepath)\n",
+    "            if loader is None:\n",
+    "                # To avoid disk access, create loader only for the first time\n",
+    "                loader = kaldiio.load_scp(filepath)\n",
+    "                self._loaders[filepath] = loader\n",
+    "            return loader[key]\n",
+    "        else:\n",
+    "            raise NotImplementedError(\n",
+    "                \"Not supported: loader_type={}\".format(filetype))\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 101,
+   "id": "monthly-muscle",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "preprocess_conf=None\n",
+    "train_mode=True\n",
+    "load = LoadInputsAndTargets(\n",
+    "            mode=\"asr\",\n",
+    "            load_output=True,\n",
+    "            preprocess_conf=preprocess_conf,\n",
+    "            preprocess_args={\"train\":\n",
+    "                             train_mode},  # Switch the mode of preprocessing\n",
+    "        )"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 102,
+   "id": "periodic-senegal",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "res = load(dev_data[0])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 103,
+   "id": "502d3f4d",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "<class 'tuple'>\n",
+      "2\n",
+      "10\n",
+      "10\n",
+      "(1174, 83) float32\n",
+      "(29,) int64\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(type(res))\n",
+    "print(len(res))\n",
+    "print(len(res[0]))\n",
+    "print(len(res[1]))\n",
+    "print(res[0][0].shape, res[0][0].dtype)\n",
+    "print(res[1][0].shape, res[1][0].dtype)\n",
+    "# Tuple[Tuple[np.ndarry], Tuple[np.ndarry]]\n",
+    "# 2[10, 10]\n",
+    "# feats, labels"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 104,
+   "id": "humanitarian-container",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "(inputs, outputs), utts = load(dev_data[0], return_uttid=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 105,
+   "id": "heard-prize",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "['4572-112383-0005', '6313-66125-0015', '251-137823-0022', '2277-149896-0030', '652-130726-0032', '5895-34615-0013', '1462-170138-0002', '777-126732-0008', '3660-172182-0021', '2277-149896-0027'] 10\n",
+      "10\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(utts, len(utts))\n",
+    "print(len(inputs))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 106,
+   "id": "convinced-animation",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import paddle\n",
+    "from deepspeech.io.utility import pad_list\n",
+    "class CustomConverter():\n",
+    "    \"\"\"Custom batch converter.\n",
+    "\n",
+    "    Args:\n",
+    "        subsampling_factor (int): The subsampling factor.\n",
+    "        dtype (paddle.dtype): Data type to convert.\n",
+    "\n",
+    "    \"\"\"\n",
+    "\n",
+    "    def __init__(self, subsampling_factor=1, dtype=np.float32):\n",
+    "        \"\"\"Construct a CustomConverter object.\"\"\"\n",
+    "        self.subsampling_factor = subsampling_factor\n",
+    "        self.ignore_id = -1\n",
+    "        self.dtype = dtype\n",
+    "\n",
+    "    def __call__(self, batch):\n",
+    "        \"\"\"Transform a batch and send it to a device.\n",
+    "\n",
+    "        Args:\n",
+    "            batch (list): The batch to transform.\n",
+    "\n",
+    "        Returns:\n",
+    "            tuple(paddle.Tensor, paddle.Tensor, paddle.Tensor)\n",
+    "\n",
+    "        \"\"\"\n",
+    "        # batch should be located in list\n",
+    "        assert len(batch) == 1\n",
+    "        (xs, ys), utts = batch[0]\n",
+    "\n",
+    "        # perform subsampling\n",
+    "        if self.subsampling_factor > 1:\n",
+    "            xs = [x[::self.subsampling_factor, :] for x in xs]\n",
+    "\n",
+    "        # get batch of lengths of input sequences\n",
+    "        ilens = np.array([x.shape[0] for x in xs])\n",
+    "\n",
+    "        # perform padding and convert to tensor\n",
+    "        # currently only support real number\n",
+    "        if xs[0].dtype.kind == \"c\":\n",
+    "            xs_pad_real = pad_list([x.real for x in xs], 0).astype(self.dtype)\n",
+    "            xs_pad_imag = pad_list([x.imag for x in xs], 0).astype(self.dtype)\n",
+    "            # Note(kamo):\n",
+    "            # {'real': ..., 'imag': ...} will be changed to ComplexTensor in E2E.\n",
+    "            # Don't create ComplexTensor and give it E2E here\n",
+    "            # because torch.nn.DataParellel can't handle it.\n",
+    "            xs_pad = {\"real\": xs_pad_real, \"imag\": xs_pad_imag}\n",
+    "        else:\n",
+    "            xs_pad = pad_list(xs, 0).astype(self.dtype)\n",
+    "\n",
+    "        # NOTE: this is for multi-output (e.g., speech translation)\n",
+    "        ys_pad = pad_list(\n",
+    "            [np.array(y[0][:]) if isinstance(y, tuple) else y for y in ys],\n",
+    "            self.ignore_id)\n",
+    "\n",
+    "        olens = np.array([y[0].shape[0] if isinstance(y, tuple) else y.shape[0] for y in ys])\n",
+    "        return utts, xs_pad, ilens, ys_pad, olens"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 107,
+   "id": "0b92ade5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "convert = CustomConverter()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 108,
+   "id": "8dbd847c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "utts, xs, ilen, ys, olen = convert([load(dev_data[0], return_uttid=True)])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 109,
+   "id": "31c085f4",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "['4572-112383-0005', '6313-66125-0015', '251-137823-0022', '2277-149896-0030', '652-130726-0032', '5895-34615-0013', '1462-170138-0002', '777-126732-0008', '3660-172182-0021', '2277-149896-0027']\n",
+      "(10, 1174, 83)\n",
+      "(10,)\n",
+      "[1174  821  716  628  597  473  463  441  419  358]\n",
+      "(10, 32)\n",
+      "[[4502 2404 4223 3204 4502  587 1018 3861 2932  713 2458 2916  253 4508\n",
+      "   627 1395  713 4504  957 2761  209 2967 3173 3918 2598 4100    3 2816\n",
+      "  4990   -1   -1   -1]\n",
+      " [1005  451  210  278 3411  206  482 2307  573 4502 3848 4577 4273 2388\n",
+      "  4444   89 4919  278 1264 4501 2371    3  139  113 2603 4962 3158 3325\n",
+      "  4577  814 4587 1422]\n",
+      " [2345 4144 2291  200  713 2345  532  999 2458 3076  545 2458 4832 3038\n",
+      "  4499  482 2812 1260 3080   -1   -1   -1   -1   -1   -1   -1   -1   -1\n",
+      "    -1   -1   -1   -1]\n",
+      " [2345  832 4577 4920 4501 2345 2298 1236  381  288  389  101 2495 4172\n",
+      "  4843 3233 3245 4501 2345 2298 3987 4502 3023 3353 2345 1361 1635 2603\n",
+      "  4723 2371   -1   -1]\n",
+      " [4502 4207  432 3204 4502 2396  125  935  433 2598  483   18  327    2\n",
+      "   389  627 4512 2340  713  482 1981 4525 4031  269 2030 1340  101 2495\n",
+      "  4013 4844   -1   -1]\n",
+      " [4502 4892 3204 1892 3780  389  482 2774 3013   89  192 2495 4502 3475\n",
+      "   389   66  370  343  404   -1   -1   -1   -1   -1   -1   -1   -1   -1\n",
+      "    -1   -1   -1   -1]\n",
+      " [2458 2314 4577 2340 2863 1254  303  269    2  389  932 2079 4577  299\n",
+      "   195 3233 4508    2   89  814 3144 1091 3204 3250 2193 3414   -1   -1\n",
+      "    -1   -1   -1   -1]\n",
+      " [2391 1785  443   78   39 4962 2340  829  599 4593  278 4681  202  407\n",
+      "   269  194  182 4577  482 4308   -1   -1   -1   -1   -1   -1   -1   -1\n",
+      "    -1   -1   -1   -1]\n",
+      " [ 627 4873 2175  363  202  404 1018 4577 4502 3412 4875 2286  107  122\n",
+      "  4832 2345 3896   89 2368   -1   -1   -1   -1   -1   -1   -1   -1   -1\n",
+      "    -1   -1   -1   -1]\n",
+      " [ 481  174  474  599 1881 3252 2842  742 4502 2545  107   88 3204 4525\n",
+      "  4517   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1\n",
+      "    -1   -1   -1   -1]]\n",
+      "[29 32 19 30 30 19 26 20 19 15]\n",
+      "float32\n",
+      "int64\n",
+      "int64\n",
+      "int64\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(utts)\n",
+    "print(xs.shape)\n",
+    "print(ilen.shape)\n",
+    "print(ilen)\n",
+    "print(ys.shape)\n",
+    "print(ys)\n",
+    "print(olen)\n",
+    "print(xs.dtype)\n",
+    "print(ilen.dtype)\n",
+    "print(ys.dtype)\n",
+    "print(olen.dtype)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 110,
+   "id": "72e9ba60",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 230,
+   "id": "64593e5f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "from paddle.io import DataLoader\n",
+    "\n",
+    "from deepspeech.frontend.utility import read_manifest\n",
+    "from deepspeech.io.batchfy import make_batchset\n",
+    "from deepspeech.io.converter import CustomConverter\n",
+    "from deepspeech.io.dataset import TransformDataset\n",
+    "from deepspeech.io.reader import LoadInputsAndTargets\n",
+    "from deepspeech.utils.log import Log\n",
+    "\n",
+    "\n",
+    "logger = Log(__name__).getlog()\n",
+    "\n",
+    "\n",
+    "class BatchDataLoader():\n",
+    "    def __init__(self,\n",
+    "                 json_file: str,\n",
+    "                 train_mode: bool,\n",
+    "                 sortagrad: bool=False,\n",
+    "                 batch_size: int=0,\n",
+    "                 maxlen_in: float=float('inf'),\n",
+    "                 maxlen_out: float=float('inf'),\n",
+    "                 minibatches: int=0,\n",
+    "                 mini_batch_size: int=1,\n",
+    "                 batch_count: str='auto',\n",
+    "                 batch_bins: int=0,\n",
+    "                 batch_frames_in: int=0,\n",
+    "                 batch_frames_out: int=0,\n",
+    "                 batch_frames_inout: int=0,\n",
+    "                 preprocess_conf=None,\n",
+    "                 n_iter_processes: int=1,\n",
+    "                 subsampling_factor: int=1,\n",
+    "                 num_encs: int=1):\n",
+    "        self.json_file = json_file\n",
+    "        self.train_mode = train_mode\n",
+    "        self.use_sortagrad = sortagrad == -1 or sortagrad > 0\n",
+    "        self.batch_size = batch_size\n",
+    "        self.maxlen_in = maxlen_in\n",
+    "        self.maxlen_out = maxlen_out\n",
+    "        self.batch_count = batch_count\n",
+    "        self.batch_bins = batch_bins\n",
+    "        self.batch_frames_in = batch_frames_in\n",
+    "        self.batch_frames_out = batch_frames_out\n",
+    "        self.batch_frames_inout = batch_frames_inout\n",
+    "        self.subsampling_factor = subsampling_factor\n",
+    "        self.num_encs = num_encs\n",
+    "        self.preprocess_conf = preprocess_conf\n",
+    "        self.n_iter_processes = n_iter_processes\n",
+    "\n",
+    "        \n",
+    "        # read json data\n",
+    "        self.data_json = read_manifest(json_file)\n",
+    "\n",
+    "        # make minibatch list (variable length)\n",
+    "        self.minibaches = make_batchset(\n",
+    "            self.data_json,\n",
+    "            batch_size,\n",
+    "            maxlen_in,\n",
+    "            maxlen_out,\n",
+    "            minibatches,  # for debug\n",
+    "            min_batch_size=mini_batch_size,\n",
+    "            shortest_first=self.use_sortagrad,\n",
+    "            count=batch_count,\n",
+    "            batch_bins=batch_bins,\n",
+    "            batch_frames_in=batch_frames_in,\n",
+    "            batch_frames_out=batch_frames_out,\n",
+    "            batch_frames_inout=batch_frames_inout,\n",
+    "            iaxis=0,\n",
+    "            oaxis=0, )\n",
+    "\n",
+    "        # data reader\n",
+    "        self.reader = LoadInputsAndTargets(\n",
+    "            mode=\"asr\",\n",
+    "            load_output=True,\n",
+    "            preprocess_conf=preprocess_conf,\n",
+    "            preprocess_args={\"train\":\n",
+    "                             train_mode},  # Switch the mode of preprocessing\n",
+    "        )\n",
+    "\n",
+    "        # Setup a converter\n",
+    "        if num_encs == 1:\n",
+    "            self.converter = CustomConverter(\n",
+    "                subsampling_factor=subsampling_factor, dtype=np.float32)\n",
+    "        else:\n",
+    "            assert NotImplementedError(\"not impl CustomConverterMulEnc.\")\n",
+    "\n",
+    "        # hack to make batchsize argument as 1\n",
+    "        # actual bathsize is included in a list\n",
+    "        # default collate function converts numpy array to pytorch tensor\n",
+    "        # we used an empty collate function instead which returns list\n",
+    "        self.dataset = TransformDataset(self.minibaches, \n",
+    "                                        lambda data: self.converter([self.reader(data, return_uttid=True)]))\n",
+    "        self.dataloader = DataLoader(\n",
+    "            dataset=self.dataset,\n",
+    "            batch_size=1,\n",
+    "            shuffle=not use_sortagrad if train_mode else False,\n",
+    "            collate_fn=lambda x: x[0],\n",
+    "            num_workers=n_iter_processes, )\n",
+    "\n",
+    "    def __repr__(self):\n",
+    "        echo = f\"<{self.__class__.__module__}.{self.__class__.__name__} object at {hex(id(self))}> \"\n",
+    "        echo += f\"train_mode: {self.train_mode}, \"\n",
+    "        echo += f\"sortagrad: {self.use_sortagrad}, \"\n",
+    "        echo += f\"batch_size: {self.batch_size}, \"\n",
+    "        echo += f\"maxlen_in: {self.maxlen_in}, \"\n",
+    "        echo += f\"maxlen_out: {self.maxlen_out}, \"\n",
+    "        echo += f\"batch_count: {self.batch_count}, \"\n",
+    "        echo += f\"batch_bins: {self.batch_bins}, \"\n",
+    "        echo += f\"batch_frames_in: {self.batch_frames_in}, \"\n",
+    "        echo += f\"batch_frames_out: {self.batch_frames_out}, \"\n",
+    "        echo += f\"batch_frames_inout: {self.batch_frames_inout}, \"\n",
+    "        echo += f\"subsampling_factor: {self.subsampling_factor}, \"\n",
+    "        echo += f\"num_encs: {self.num_encs}, \"\n",
+    "        echo += f\"num_workers: {self.n_iter_processes}, \"\n",
+    "        echo += f\"file: {self.json_file}\"\n",
+    "        return echo\n",
+    "    \n",
+    "    def __len__(self):\n",
+    "        return len(self.dataloader)\n",
+    "    \n",
+    "    def __iter__(self):\n",
+    "        return self.dataloader.__iter__()\n",
+    "    \n",
+    "    def __call__(self):\n",
+    "        return self.__iter__()\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 231,
+   "id": "fcea3fd0",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "[INFO 2021/08/18 07:42:23 batchfy.py:399] count is auto detected as seq\n",
+      "[INFO 2021/08/18 07:42:23 batchfy.py:423] # utts: 5542\n",
+      "[INFO 2021/08/18 07:42:23 batchfy.py:466] # minibatches: 278\n"
+     ]
+    }
+   ],
+   "source": [
+    "train = BatchDataLoader(dev_data, True, batch_size=20)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 232,
+   "id": "e2a2c9a8",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "278\n",
+      "['__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', 'auto_collate_batch', 'batch_sampler', 'batch_size', 'collate_fn', 'dataset', 'dataset_kind', 'feed_list', 'from_dataset', 'from_generator', 'num_workers', 'pin_memory', 'places', 'return_list', 'timeout', 'use_buffer_reader', 'use_shared_memory', 'worker_init_fn']\n",
+      "<__main__.BatchDataLoader object at 0x7fdddba35470> train_mode: True, sortagrad: False, batch_size: 20, maxlen_in: inf, maxlen_out: inf, batch_count: auto, batch_bins: 0, batch_frames_in: 0, batch_frames_out: 0, batch_frames_inout: 0, subsampling_factor: 1, num_encs: 1, num_workers: 1, file: /workspace/zhanghui/DeepSpeech-2.x/examples/librispeech/s2/data/manifest.dev\n",
+      "278\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(len(train.dataloader))\n",
+    "print(dir(train.dataloader))\n",
+    "print(train)\n",
+    "print(len(train))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 220,
+   "id": "a5ba7d6e",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "['7601-101619-0003', '1255-138279-0000', '1272-128104-0004', '6123-59150-0027', '2078-142845-0025', '7850-73752-0018', '4570-24733-0004', '2506-169427-0002', '7601-101619-0004', '3170-137482-0000', '6267-53049-0019', '4570-14911-0009', '174-168635-0018', '7601-291468-0004', '3576-138058-0022', '1919-142785-0007', '6467-62797-0007', '4153-61735-0005', '1686-142278-0003', '2506-169427-0000']\n",
+      "Tensor(shape=[20, 2961, 83], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n",
+      "       [[[-1.99415934, -1.80315673, -1.88801885, ...,  0.86933994, -0.59853148,  0.02596200],\n",
+      "         [-1.95346808, -1.84891188, -2.17492867, ...,  0.83640492, -0.59853148, -0.11333394],\n",
+      "         [-2.27899861, -2.21495342, -2.58480024, ...,  0.91874266, -0.59853148, -0.31453922],\n",
+      "         ...,\n",
+      "         [-2.64522028, -2.35221887, -2.91269732, ...,  1.48994756, -0.16100442,  0.36646330],\n",
+      "         [-2.40107250, -2.21495342, -2.37986445, ...,  1.44072104, -0.13220564,  0.12656468],\n",
+      "         [-2.15692472, -1.89466715, -2.25690317, ...,  1.31273174, -0.09620714, -0.15202725]],\n",
+      "\n",
+      "        [[-0.28859532, -0.29033494, -0.86576819, ...,  1.37753224, -0.30570769,  0.25806731],\n",
+      "         [-0.20149794, -0.17814466, -0.59891301, ...,  1.35188794, -0.30570769, -0.02964944],\n",
+      "         [-0.34947991, -0.33597648, -0.96877253, ...,  1.38394332, -0.30570769, -0.38376236],\n",
+      "         ...,\n",
+      "         [ 0.        ,  0.        ,  0.        , ...,  0.        ,  0.        ,  0.        ],\n",
+      "         [ 0.        ,  0.        ,  0.        , ...,  0.        ,  0.        ,  0.        ],\n",
+      "         [ 0.        ,  0.        ,  0.        , ...,  0.        ,  0.        ,  0.        ]],\n",
+      "\n",
+      "        [[-0.44914246, -0.33902276, -0.78237975, ...,  1.38218808,  0.29214793, -0.16815147],\n",
+      "         [-0.55490732, -0.41596055, -0.84425378, ...,  1.34530187,  0.25002354, -0.04004869],\n",
+      "         [-0.83694696, -0.62112784, -1.07112527, ...,  1.19160914,  0.20789915,  0.37984371],\n",
+      "         ...,\n",
+      "         [ 0.        ,  0.        ,  0.        , ...,  0.        ,  0.        ,  0.        ],\n",
+      "         [ 0.        ,  0.        ,  0.        , ...,  0.        ,  0.        ,  0.        ],\n",
+      "         [ 0.        ,  0.        ,  0.        , ...,  0.        ,  0.        ,  0.        ]],\n",
+      "\n",
+      "        ...,\n",
+      "\n",
+      "        [[-1.24343657, -0.94188881, -1.41092563, ...,  0.96716309,  0.60345763,  0.15360183],\n",
+      "         [-1.19466043, -0.80585432, -0.49723154, ...,  1.06735480,  0.60345763,  0.14511746],\n",
+      "         [-0.94079566, -0.59330046, -0.40948665, ...,  0.82244170,  0.55614340,  0.28086722],\n",
+      "         ...,\n",
+      "         [ 0.        ,  0.        ,  0.        , ...,  0.        ,  0.        ,  0.        ],\n",
+      "         [ 0.        ,  0.        ,  0.        , ...,  0.        ,  0.        ,  0.        ],\n",
+      "         [ 0.        ,  0.        ,  0.        , ...,  0.        ,  0.        ,  0.        ]],\n",
+      "\n",
+      "        [[ 0.21757117,  0.11361472, -0.33262897, ...,  0.76338506, -0.10711290, -0.57754958],\n",
+      "         [-1.00205481, -0.61152041, -0.47124696, ...,  1.11897349, -0.10711290,  0.24931324],\n",
+      "         [-1.03929281, -1.20336759, -1.16433656, ...,  0.88888687, -0.10711290, -0.04115745],\n",
+      "         ...,\n",
+      "         [ 0.        ,  0.        ,  0.        , ...,  0.        ,  0.        ,  0.        ],\n",
+      "         [ 0.        ,  0.        ,  0.        , ...,  0.        ,  0.        ,  0.        ],\n",
+      "         [ 0.        ,  0.        ,  0.        , ...,  0.        ,  0.        ,  0.        ]],\n",
+      "\n",
+      "        [[-1.25289667, -1.05046368, -0.82881606, ...,  1.23991334,  0.61702502,  0.05275881],\n",
+      "         [-1.19659519, -0.78677225, -0.80407262, ...,  1.27644968,  0.61702502, -0.35079369],\n",
+      "         [-1.49687004, -1.01750231, -0.82881606, ...,  1.29106426,  0.65006059,  0.17958963],\n",
+      "         ...,\n",
+      "         [ 0.        ,  0.        ,  0.        , ...,  0.        ,  0.        ,  0.        ],\n",
+      "         [ 0.        ,  0.        ,  0.        , ...,  0.        ,  0.        ,  0.        ],\n",
+      "         [ 0.        ,  0.        ,  0.        , ...,  0.        ,  0.        ,  0.        ]]])\n",
+      "Tensor(shape=[20], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
+      "       [2961, 2948, 2938, 2907, 2904, 2838, 2832, 2819, 2815, 2797, 2775, 2710, 2709, 2696, 2688, 2661, 2616, 2595, 2589, 2576])\n",
+      "Tensor(shape=[20, 133], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
+      "       [[3098, 1595,  389, ..., -1  , -1  , -1  ],\n",
+      "        [2603, 4832,  482, ..., -1  , -1  , -1  ],\n",
+      "        [2796,  303,  269, ..., -1  , -1  , -1  ],\n",
+      "        ...,\n",
+      "        [3218, 3673,  206, ..., -1  , -1  , -1  ],\n",
+      "        [2371, 4832, 4031, ..., -1  , -1  , -1  ],\n",
+      "        [2570, 2433, 4285, ..., -1  , -1  , -1  ]])\n",
+      "Tensor(shape=[20], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
+      "       [80 , 83 , 102, 133, 82 , 102, 71 , 91 , 68 , 81 , 86 , 67 , 71 , 95 , 65 , 88 , 97 , 98 , 89 , 72 ])\n"
+     ]
+    }
+   ],
+   "source": [
+    "for batch in train:\n",
+    "    utts, xs, ilens, ys, olens = batch\n",
+    "    print(utts)\n",
+    "    print(xs)\n",
+    "    print(ilens)\n",
+    "    print(ys)\n",
+    "    print(olens)\n",
+    "    break"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3c974a1e",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.0"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/.notebook/jit_infer.ipynb b/.notebook/jit_infer.ipynb
index ba50d8743..20882c1ae 100644
--- a/.notebook/jit_infer.ipynb
+++ b/.notebook/jit_infer.ipynb
@@ -83,8 +83,8 @@
     "from deepspeech.frontend.utility import read_manifest\n",
     "from deepspeech.utils.utility import add_arguments, print_arguments\n",
     "\n",
-    "from deepspeech.models.deepspeech2 import DeepSpeech2Model\n",
-    "from deepspeech.models.deepspeech2 import DeepSpeech2InferModel\n",
+    "from deepspeech.models.ds2 import DeepSpeech2Model\n",
+    "from deepspeech.models.ds2 import DeepSpeech2InferModel\n",
     "from deepspeech.io.dataset import ManifestDataset\n",
     "\n",
     "\n",
@@ -669,4 +669,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 2
-}
\ No newline at end of file
+}
diff --git a/.notebook/u2_confermer_model_wenet.ipynb b/.notebook/u2_confermer_model_wenet.ipynb
index 4f2c9632f..a425e16cb 100644
--- a/.notebook/u2_confermer_model_wenet.ipynb
+++ b/.notebook/u2_confermer_model_wenet.ipynb
@@ -3431,7 +3431,7 @@
     "        convolution_layer_args = (output_size, cnn_module_kernel, activation,\n",
     "                                  cnn_module_norm, causal)\n",
     "\n",
-    "        self.encoders = nn.ModuleList([\n",
+    "        self.encoders = nn.LayerList([\n",
     "            ConformerEncoderLayer(\n",
     "                size=output_size,\n",
     "                self_attn=encoder_selfattn_layer(*encoder_selfattn_layer_args),\n",
diff --git a/README.md b/README.md
index f7d1e0882..de24abe2f 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,12 @@
 [中文版](README_cn.md)
 
-# PaddlePaddle ASR toolkit
+# PaddlePaddle Speech to Any toolkit
 
 ![License](https://img.shields.io/badge/license-Apache%202-red.svg)
 ![python version](https://img.shields.io/badge/python-3.7+-orange.svg)
 ![support os](https://img.shields.io/badge/os-linux-yellow.svg)
 
-*PaddleASR* is an open-source implementation of end-to-end Automatic Speech Recognition (ASR) engine, with [PaddlePaddle](https://github.com/PaddlePaddle/Paddle) platform. Our vision is to empower both industrial application and academic research on speech recognition, via an easy-to-use, efficient, samller and scalable implementation, including training, inference & testing module, and deployment.
+*DeepSpeech* is an open-source implementation of an end-to-end Automatic Speech Recognition engine, built on the [PaddlePaddle](https://github.com/PaddlePaddle/Paddle) platform. Our vision is to empower both industrial applications and academic research on speech recognition via an easy-to-use, efficient, smaller and scalable implementation, including training, inference & testing modules, and deployment.
 
 
 ## Features
@@ -15,6 +15,8 @@
 
 ## Setup
 
+All tested under:  
+* Ubuntu 16.04
 * python>=3.7
 * paddlepaddle>=2.1.2
 
diff --git a/README_cn.md b/README_cn.md
index 019b38c15..4b9273625 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -1,12 +1,12 @@
 [English](README.md)
 
-# PaddlePaddle ASR toolkit
+# PaddlePaddle Speech to Any toolkit
 
 ![License](https://img.shields.io/badge/license-Apache%202-red.svg)
 ![python version](https://img.shields.io/badge/python-3.7+-orange.svg)
 ![support os](https://img.shields.io/badge/os-linux-yellow.svg)
 
-*PaddleASR*是一个采用[PaddlePaddle](https://github.com/PaddlePaddle/Paddle)平台的端到端自动语音识别(ASR)引擎的开源项目,
+*DeepSpeech*是一个采用[PaddlePaddle](https://github.com/PaddlePaddle/Paddle)平台的端到端自动语音识别引擎的开源项目,
 我们的愿景是为语音识别在工业应用和学术研究上,提供易于使用、高效、小型化和可扩展的工具,包括训练,推理,以及  部署。
 
 ## 特性
@@ -16,6 +16,9 @@
 
 ## 安装
 
+在以下环境测试验证过:  
+
+* Ubuntu 16.04
 * python>=3.7
 * paddlepaddle>=2.1.2
 
diff --git a/deepspeech/__init__.py b/deepspeech/__init__.py
index 37531657e..fbec5a5e8 100644
--- a/deepspeech/__init__.py
+++ b/deepspeech/__init__.py
@@ -30,24 +30,13 @@ logger = Log(__name__).getlog()
 logger.warn = logger.warning
 
 ########### hcak paddle #############
-paddle.bool = 'bool'
-paddle.float16 = 'float16'
 paddle.half = 'float16'
-paddle.float32 = 'float32'
 paddle.float = 'float32'
-paddle.float64 = 'float64'
 paddle.double = 'float64'
-paddle.int8 = 'int8'
-paddle.int16 = 'int16'
 paddle.short = 'int16'
-paddle.int32 = 'int32'
 paddle.int = 'int32'
-paddle.int64 = 'int64'
 paddle.long = 'int64'
-paddle.uint8 = 'uint8'
 paddle.uint16 = 'uint16'
-paddle.complex64 = 'complex64'
-paddle.complex128 = 'complex128'
 paddle.cdouble = 'complex128'
 
 
@@ -363,85 +352,8 @@ if not hasattr(paddle.Tensor, 'tolist'):
         "register user tolist to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'tolist', tolist)
 
-########### hcak paddle.nn.functional #############
-
-
-def glu(x: paddle.Tensor, axis=-1) -> paddle.Tensor:
-    """The gated linear unit (GLU) activation."""
-    a, b = x.split(2, axis=axis)
-    act_b = F.sigmoid(b)
-    return a * act_b
-
-
-if not hasattr(paddle.nn.functional, 'glu'):
-    logger.warn(
-        "register user glu to paddle.nn.functional, remove this when fixed!")
-    setattr(paddle.nn.functional, 'glu', glu)
-
-# def softplus(x):
-#     """Softplus function."""
-#     if hasattr(paddle.nn.functional, 'softplus'):
-#         #return paddle.nn.functional.softplus(x.float()).type_as(x)
-#         return paddle.nn.functional.softplus(x)
-#     else:
-#         raise NotImplementedError
-
-# def gelu_accurate(x):
-#     """Gaussian Error Linear Units (GELU) activation."""
-#     # [reference] https://github.com/pytorch/fairseq/blob/e75cff5f2c1d62f12dc911e0bf420025eb1a4e33/fairseq/modules/gelu.py
-#     if not hasattr(gelu_accurate, "_a"):
-#         gelu_accurate._a = math.sqrt(2 / math.pi)
-#     return 0.5 * x * (1 + paddle.tanh(gelu_accurate._a *
-#                                       (x + 0.044715 * paddle.pow(x, 3))))
-
-# def gelu(x):
-#     """Gaussian Error Linear Units (GELU) activation."""
-#     if hasattr(nn.functional, 'gelu'):
-#         #return nn.functional.gelu(x.float()).type_as(x)
-#         return nn.functional.gelu(x)
-#     else:
-#         return x * 0.5 * (1.0 + paddle.erf(x / math.sqrt(2.0)))
-
-
-# hack loss
-def ctc_loss(logits,
-             labels,
-             input_lengths,
-             label_lengths,
-             blank=0,
-             reduction='mean',
-             norm_by_times=True):
-    #logger.info("my ctc loss with norm by times")
-    ## https://github.com/PaddlePaddle/Paddle/blob/f5ca2db2cc/paddle/fluid/operators/warpctc_op.h#L403
-    loss_out = paddle.fluid.layers.warpctc(logits, labels, blank, norm_by_times,
-                                           input_lengths, label_lengths)
-
-    loss_out = paddle.fluid.layers.squeeze(loss_out, [-1])
-    assert reduction in ['mean', 'sum', 'none']
-    if reduction == 'mean':
-        loss_out = paddle.mean(loss_out / label_lengths)
-    elif reduction == 'sum':
-        loss_out = paddle.sum(loss_out)
-    return loss_out
-
-
-logger.warn(
-    "override ctc_loss of paddle.nn.functional if exists, remove this when fixed!"
-)
-F.ctc_loss = ctc_loss
 
 ########### hcak paddle.nn #############
-if not hasattr(paddle.nn, 'Module'):
-    logger.warn("register user Module to paddle.nn, remove this when fixed!")
-    setattr(paddle.nn, 'Module', paddle.nn.Layer)
-
-# maybe cause assert isinstance(sublayer, core.Layer)
-if not hasattr(paddle.nn, 'ModuleList'):
-    logger.warn(
-        "register user ModuleList to paddle.nn, remove this when fixed!")
-    setattr(paddle.nn, 'ModuleList', paddle.nn.LayerList)
-
-
 class GLU(nn.Layer):
     """Gated Linear Units (GLU) Layer"""
 
@@ -450,48 +362,9 @@ class GLU(nn.Layer):
         self.dim = dim
 
     def forward(self, xs):
-        return glu(xs, dim=self.dim)
+        return F.glu(xs, dim=self.dim)
 
 
 if not hasattr(paddle.nn, 'GLU'):
     logger.warn("register user GLU to paddle.nn, remove this when fixed!")
     setattr(paddle.nn, 'GLU', GLU)
-
-
-# TODO(Hui Zhang): remove this Layer
-class ConstantPad2d(nn.Layer):
-    """Pads the input tensor boundaries with a constant value.
-    For N-dimensional padding, use paddle.nn.functional.pad().
-    """
-
-    def __init__(self, padding: Union[tuple, list, int], value: float):
-        """
-        Args:
-            paddle ([tuple]): the size of the padding.
-                If is int, uses the same padding in all boundaries.
-                If a 4-tuple, uses (padding_left, padding_right, padding_top, padding_bottom)
-            value ([flaot]): pad value
-        """
-        self.padding = padding if isinstance(padding,
-                                             [tuple, list]) else [padding] * 4
-        self.value = value
-
-    def forward(self, xs: paddle.Tensor) -> paddle.Tensor:
-        return nn.functional.pad(
-            xs,
-            self.padding,
-            mode='constant',
-            value=self.value,
-            data_format='NCHW')
-
-
-if not hasattr(paddle.nn, 'ConstantPad2d'):
-    logger.warn(
-        "register user ConstantPad2d to paddle.nn, remove this when fixed!")
-    setattr(paddle.nn, 'ConstantPad2d', ConstantPad2d)
-
-########### hcak paddle.jit #############
-
-if not hasattr(paddle.jit, 'export'):
-    logger.warn("register user export to paddle.jit, remove this when fixed!")
-    setattr(paddle.jit, 'export', paddle.jit.to_static)
diff --git a/deepspeech/decoders/swig/setup.py b/deepspeech/decoders/swig/setup.py
index 86af475af..3da5ce8bf 100644
--- a/deepspeech/decoders/swig/setup.py
+++ b/deepspeech/decoders/swig/setup.py
@@ -84,8 +84,9 @@ FILES = glob.glob('kenlm/util/*.cc') \
 FILES += glob.glob('openfst-1.6.3/src/lib/*.cc')
 
 FILES = [
-    fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc')
-                               or fn.endswith('unittest.cc'))
+    fn for fn in FILES
+    if not (fn.endswith('main.cc') or fn.endswith('test.cc') or fn.endswith(
+        'unittest.cc'))
 ]
 
 LIBS = ['stdc++']
diff --git a/deepspeech/exps/deepspeech2/bin/deploy/runtime.py b/deepspeech/exps/deepspeech2/bin/deploy/runtime.py
index 01f01b651..21ffa6bf4 100644
--- a/deepspeech/exps/deepspeech2/bin/deploy/runtime.py
+++ b/deepspeech/exps/deepspeech2/bin/deploy/runtime.py
@@ -23,7 +23,7 @@ from paddle.io import DataLoader
 from deepspeech.exps.deepspeech2.config import get_cfg_defaults
 from deepspeech.io.collator import SpeechCollator
 from deepspeech.io.dataset import ManifestDataset
-from deepspeech.models.deepspeech2 import DeepSpeech2Model
+from deepspeech.models.ds2 import DeepSpeech2Model
 from deepspeech.training.cli import default_argument_parser
 from deepspeech.utils.socket_server import AsrRequestHandler
 from deepspeech.utils.socket_server import AsrTCPServer
diff --git a/deepspeech/exps/deepspeech2/bin/deploy/server.py b/deepspeech/exps/deepspeech2/bin/deploy/server.py
index b473a8fd4..583e90950 100644
--- a/deepspeech/exps/deepspeech2/bin/deploy/server.py
+++ b/deepspeech/exps/deepspeech2/bin/deploy/server.py
@@ -21,7 +21,7 @@ from paddle.io import DataLoader
 from deepspeech.exps.deepspeech2.config import get_cfg_defaults
 from deepspeech.io.collator import SpeechCollator
 from deepspeech.io.dataset import ManifestDataset
-from deepspeech.models.deepspeech2 import DeepSpeech2Model
+from deepspeech.models.ds2 import DeepSpeech2Model
 from deepspeech.training.cli import default_argument_parser
 from deepspeech.utils.socket_server import AsrRequestHandler
 from deepspeech.utils.socket_server import AsrTCPServer
diff --git a/deepspeech/exps/deepspeech2/bin/export.py b/deepspeech/exps/deepspeech2/bin/export.py
index f8764fde3..7962d4fc0 100644
--- a/deepspeech/exps/deepspeech2/bin/export.py
+++ b/deepspeech/exps/deepspeech2/bin/export.py
@@ -30,6 +30,9 @@ def main(config, args):
 
 if __name__ == "__main__":
     parser = default_argument_parser()
+    # path to save the exported jit model
+    parser.add_argument(
+        "--export_path", type=str, help="path of the jit model to save")
     parser.add_argument("--model_type")
     args = parser.parse_args()
     if args.model_type is None:
diff --git a/deepspeech/exps/deepspeech2/bin/test.py b/deepspeech/exps/deepspeech2/bin/test.py
index 376e18e38..f2fd3a394 100644
--- a/deepspeech/exps/deepspeech2/bin/test.py
+++ b/deepspeech/exps/deepspeech2/bin/test.py
@@ -31,6 +31,9 @@ def main(config, args):
 if __name__ == "__main__":
     parser = default_argument_parser()
     parser.add_argument("--model_type")
+    # path to save the asr result
+    parser.add_argument(
+        "--result_file", type=str, help="path to save the asr result")
     args = parser.parse_args()
     print_arguments(args, globals())
     if args.model_type is None:
diff --git a/deepspeech/exps/deepspeech2/bin/tune.py b/deepspeech/exps/deepspeech2/bin/tune.py
index f10dc27ce..94a9b6c47 100644
--- a/deepspeech/exps/deepspeech2/bin/tune.py
+++ b/deepspeech/exps/deepspeech2/bin/tune.py
@@ -21,7 +21,7 @@ from paddle.io import DataLoader
 from deepspeech.exps.deepspeech2.config import get_cfg_defaults
 from deepspeech.io.collator import SpeechCollator
 from deepspeech.io.dataset import ManifestDataset
-from deepspeech.models.deepspeech2 import DeepSpeech2Model
+from deepspeech.models.ds2 import DeepSpeech2Model
 from deepspeech.training.cli import default_argument_parser
 from deepspeech.utils import error_rate
 from deepspeech.utils.utility import add_arguments
diff --git a/deepspeech/exps/u2/bin/alignment.py b/deepspeech/exps/u2/bin/alignment.py
index c1c9582f8..cef9d1ab9 100644
--- a/deepspeech/exps/u2/bin/alignment.py
+++ b/deepspeech/exps/u2/bin/alignment.py
@@ -30,6 +30,9 @@ def main(config, args):
 
 if __name__ == "__main__":
     parser = default_argument_parser()
+    # path to save the asr result
+    parser.add_argument(
+        "--result_file", type=str, help="path to save the asr result")
     args = parser.parse_args()
     print_arguments(args, globals())
 
diff --git a/deepspeech/exps/u2/bin/export.py b/deepspeech/exps/u2/bin/export.py
index 292c78389..3dc41b706 100644
--- a/deepspeech/exps/u2/bin/export.py
+++ b/deepspeech/exps/u2/bin/export.py
@@ -30,6 +30,9 @@ def main(config, args):
 
 if __name__ == "__main__":
     parser = default_argument_parser()
+    # path to save the exported jit model
+    parser.add_argument(
+        "--export_path", type=str, help="path of the jit model to save")
     args = parser.parse_args()
     print_arguments(args, globals())
 
diff --git a/deepspeech/exps/u2/bin/test.py b/deepspeech/exps/u2/bin/test.py
index c47f932c7..f6127675e 100644
--- a/deepspeech/exps/u2/bin/test.py
+++ b/deepspeech/exps/u2/bin/test.py
@@ -34,6 +34,9 @@ def main(config, args):
 
 if __name__ == "__main__":
     parser = default_argument_parser()
+    # path to save the asr result
+    parser.add_argument(
+        "--result_file", type=str, help="path to save the asr result")
     args = parser.parse_args()
     print_arguments(args, globals())
 
diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py
index d661f078d..0662e38d9 100644
--- a/deepspeech/exps/u2/model.py
+++ b/deepspeech/exps/u2/model.py
@@ -264,12 +264,12 @@ class U2Trainer(Trainer):
         config.data.manifest = config.data.test_manifest
         # filter test examples, will cause less examples, but no mismatch with training
         # and can use large batch size , save training time, so filter test egs now.
-        # config.data.min_input_len = 0.0  # second
-        # config.data.max_input_len = float('inf')  # second
-        # config.data.min_output_len = 0.0  # tokens
-        # config.data.max_output_len = float('inf')  # tokens
-        # config.data.min_output_input_ratio = 0.00
-        # config.data.max_output_input_ratio = float('inf')
+        config.data.min_input_len = 0.0  # second
+        config.data.max_input_len = float('inf')  # second
+        config.data.min_output_len = 0.0  # tokens
+        config.data.max_output_len = float('inf')  # tokens
+        config.data.min_output_input_ratio = 0.00
+        config.data.max_output_input_ratio = float('inf')
 
         test_dataset = ManifestDataset.from_config(config)
         # return text ord id
diff --git a/deepspeech/exps/u2_kaldi/__init__.py b/deepspeech/exps/u2_kaldi/__init__.py
new file mode 100644
index 000000000..185a92b8d
--- /dev/null
+++ b/deepspeech/exps/u2_kaldi/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/deepspeech/exps/u2_kaldi/bin/test.py b/deepspeech/exps/u2_kaldi/bin/test.py
new file mode 100644
index 000000000..93a29ab15
--- /dev/null
+++ b/deepspeech/exps/u2_kaldi/bin/test.py
@@ -0,0 +1,83 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Evaluation for U2 model."""
+import cProfile
+
+from yacs.config import CfgNode
+
+from deepspeech.training.cli import default_argument_parser
+from deepspeech.utils.dynamic_import import dynamic_import
+from deepspeech.utils.utility import print_arguments
+
+model_test_alias = {
+    "u2": "deepspeech.exps.u2.model:U2Tester",
+    "u2_kaldi": "deepspeech.exps.u2_kaldi.model:U2Tester",
+}
+
+
+def main_sp(config, args):
+    class_obj = dynamic_import(args.model_name, model_test_alias)
+    exp = class_obj(config, args)
+    exp.setup()
+
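+    # dispatch on run mode: decode the test set, export a jit model, or force-align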
+    if args.run_mode == 'test':
+        exp.run_test()
+    elif args.run_mode == 'export':
+        exp.run_export()
+    elif args.run_mode == 'align':
+        exp.run_align()
+
+
+def main(config, args):
+    main_sp(config, args)
+
+
+if __name__ == "__main__":
+    parser = default_argument_parser()
+    parser.add_argument(
+        '--model-name',
+        type=str,
+        default='u2_kaldi',
+        help='model name, e.g.: deepspeech2, u2, u2_kaldi, u2_st')
+    parser.add_argument(
+        '--run-mode',
+        type=str,
+        default='test',
+        help='run mode, e.g. test, align, export')
+    parser.add_argument(
+        '--dict-path', type=str, default=None, help='dict path.')
+    # path to save the asr result
+    parser.add_argument(
+        "--result-file", type=str, help="path to save the asr result")
+    # path to save the exported jit model
+    parser.add_argument(
+        "--export-path", type=str, help="path of the jit model to save")
+    args = parser.parse_args()
+    print_arguments(args, globals())
+
+    config = CfgNode()
+    config.set_new_allowed(True)
+    config.merge_from_file(args.config)
+    if args.opts:
+        config.merge_from_list(args.opts)
+    config.freeze()
+    print(config)
+    if args.dump_config:
+        with open(args.dump_config, 'w') as f:
+            print(config, file=f)
+
+    # Setting for profiling
+    pr = cProfile.Profile()
+    pr.runcall(main, config, args)
+    pr.dump_stats('test.profile')
diff --git a/deepspeech/exps/u2_kaldi/bin/train.py b/deepspeech/exps/u2_kaldi/bin/train.py
new file mode 100644
index 000000000..1dcd154d3
--- /dev/null
+++ b/deepspeech/exps/u2_kaldi/bin/train.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Trainer for U2 model."""
+import cProfile
+import os
+
+from paddle import distributed as dist
+from yacs.config import CfgNode
+
+from deepspeech.training.cli import default_argument_parser
+from deepspeech.utils.dynamic_import import dynamic_import
+from deepspeech.utils.utility import print_arguments
+
+model_train_alias = {
+    "u2": "deepspeech.exps.u2.model:U2Trainer",
+    "u2_kaldi": "deepspeech.exps.u2_kaldi.model:U2Trainer",
+}
+
+
+def main_sp(config, args):
+    class_obj = dynamic_import(args.model_name, model_train_alias)
+    exp = class_obj(config, args)
+    exp.setup()
+    exp.run()
+
+
+def main(config, args):
+    if args.device == "gpu" and args.nprocs > 1:
+        dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
+    else:
+        main_sp(config, args)
+
+
+if __name__ == "__main__":
+    parser = default_argument_parser()
+    parser.add_argument(
+        '--model-name',
+        type=str,
+        default='u2_kaldi',
+        help='model name, e.g.: deepspeech2, u2, u2_kaldi, u2_st')
+    args = parser.parse_args()
+    print_arguments(args, globals())
+
+    config = CfgNode()
+    config.set_new_allowed(True)
+    config.merge_from_file(args.config)
+    if args.opts:
+        config.merge_from_list(args.opts)
+    config.freeze()
+    print(config)
+    if args.dump_config:
+        with open(args.dump_config, 'w') as f:
+            print(config, file=f)
+
+    # Setting for profiling
+    pr = cProfile.Profile()
+    pr.runcall(main, config, args)
+    pr.dump_stats(os.path.join(args.output, 'train.profile'))
diff --git a/deepspeech/exps/u2_kaldi/model.py b/deepspeech/exps/u2_kaldi/model.py
new file mode 100644
index 000000000..4f6ff4cb9
--- /dev/null
+++ b/deepspeech/exps/u2_kaldi/model.py
@@ -0,0 +1,654 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains U2 model."""
+import json
+import os
+import sys
+import time
+from collections import defaultdict
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+import paddle
+from paddle import distributed as dist
+from yacs.config import CfgNode
+
+from deepspeech.frontend.featurizer import TextFeaturizer
+from deepspeech.frontend.utility import load_dict
+from deepspeech.io.dataloader import BatchDataLoader
+from deepspeech.models.u2 import U2Model
+from deepspeech.training.optimizer import OptimizerFactory
+from deepspeech.training.scheduler import LRSchedulerFactory
+from deepspeech.training.trainer import Trainer
+from deepspeech.utils import ctc_utils
+from deepspeech.utils import error_rate
+from deepspeech.utils import layer_tools
+from deepspeech.utils import mp_tools
+from deepspeech.utils import text_grid
+from deepspeech.utils import utility
+from deepspeech.utils.log import Log
+
+logger = Log(__name__).getlog()
+
+
+def get_cfg_defaults():
+    """Get a yacs CfgNode object with default values for my_project."""
+    # Return a clone so that the defaults will not be altered
+    # This is for the "local variable" use pattern
+    _C = CfgNode()
+
+    _C.model = U2Model.params()
+
+    _C.training = U2Trainer.params()
+
+    _C.decoding = U2Tester.params()
+
+    config = _C.clone()
+    config.set_new_allowed(True)
+    return config
+
+
+class U2Trainer(Trainer):
+    @classmethod
+    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
+        # training config
+        default = CfgNode(
+            dict(
+                n_epoch=50,  # train epochs
+                log_interval=100,  # steps
+                accum_grad=1,  # accum grad by # steps
+                checkpoint=dict(
+                    kbest_n=50,
+                    latest_n=5, ), ))
+        if config is not None:
+            config.merge_from_other_cfg(default)
+        return default
+
+    def __init__(self, config, args):
+        super().__init__(config, args)
+
+    def train_batch(self, batch_index, batch_data, msg):
+        train_conf = self.config.training
+        start = time.time()
+
+        utt, audio, audio_len, text, text_len = batch_data
+        loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
+                                                    text_len)
+        # loss div by `batch_size * accum_grad`
+        loss /= train_conf.accum_grad
+        loss.backward()
+        layer_tools.print_grads(self.model, print_func=None)
+
+        losses_np = {'loss': float(loss) * train_conf.accum_grad}
+        if attention_loss:
+            losses_np['att_loss'] = float(attention_loss)
+        if ctc_loss:
+            losses_np['ctc_loss'] = float(ctc_loss)
+
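+        # gradient accumulation: only step the optimizer and scheduler every `accum_grad` batches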
+        if (batch_index + 1) % train_conf.accum_grad == 0:
+            self.optimizer.step()
+            self.optimizer.clear_grad()
+            self.lr_scheduler.step()
+            self.iteration += 1
+
+        iteration_time = time.time() - start
+
+        if (batch_index + 1) % train_conf.log_interval == 0:
+            msg += "train time: {:>.3f}s, ".format(iteration_time)
+            msg += "batch size: {}, ".format(self.config.collator.batch_size)
+            msg += "accum: {}, ".format(train_conf.accum_grad)
+            msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                             for k, v in losses_np.items())
+            logger.info(msg)
+
+            if dist.get_rank() == 0 and self.visualizer:
+                losses_np_v = losses_np.copy()
+                losses_np_v.update({"lr": self.lr_scheduler()})
+                self.visualizer.add_scalars("step", losses_np_v,
+                                            self.iteration - 1)
+
+    @paddle.no_grad()
+    def valid(self):
+        self.model.eval()
+        logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
+        valid_losses = defaultdict(list)
+        num_seen_utts = 1
+        total_loss = 0.0
+
+        for i, batch in enumerate(self.valid_loader):
+            utt, audio, audio_len, text, text_len = batch
+            loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
+                                                        text_len)
+            if paddle.isfinite(loss):
+                num_utts = batch[1].shape[0]
+                num_seen_utts += num_utts
+                total_loss += float(loss) * num_utts
+                valid_losses['val_loss'].append(float(loss))
+                if attention_loss:
+                    valid_losses['val_att_loss'].append(float(attention_loss))
+                if ctc_loss:
+                    valid_losses['val_ctc_loss'].append(float(ctc_loss))
+
+            if (i + 1) % self.config.training.log_interval == 0:
+                valid_dump = {k: np.mean(v) for k, v in valid_losses.items()}
+                valid_dump['val_history_loss'] = total_loss / num_seen_utts
+
+                # logging
+                msg = f"Valid: Rank: {dist.get_rank()}, "
+                msg += "epoch: {}, ".format(self.epoch)
+                msg += "step: {}, ".format(self.iteration)
+                msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
+                msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                                 for k, v in valid_dump.items())
+                logger.info(msg)
+
+        logger.info('Rank {} Val info val_loss {}'.format(
+            dist.get_rank(), total_loss / num_seen_utts))
+        return total_loss, num_seen_utts
+
+    def train(self):
+        """The training process control by step."""
+        # !!!IMPORTANT!!!
+        # Try to export the model by script, if fails, we should refine
+        # the code to satisfy the script export requirements
+        # script_model = paddle.jit.to_static(self.model)
+        # script_model_path = str(self.checkpoint_dir / 'init')
+        # paddle.jit.save(script_model, script_model_path)
+
+        from_scratch = self.resume_or_scratch()
+        if from_scratch:
+            # save init model, i.e. 0 epoch
+            self.save(tag='init')
+
+        self.lr_scheduler.step(self.iteration)
+        if self.parallel:
+            self.train_loader.batch_sampler.set_epoch(self.epoch)
+
+        logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
+        while self.epoch < self.config.training.n_epoch:
+            self.model.train()
+            try:
+                data_start_time = time.time()
+                for batch_index, batch in enumerate(self.train_loader):
+                    dataload_time = time.time() - data_start_time
+                    msg = "Train: Rank: {}, ".format(dist.get_rank())
+                    msg += "epoch: {}, ".format(self.epoch)
+                    msg += "step: {}, ".format(self.iteration)
+                    msg += "batch : {}/{}, ".format(batch_index + 1,
+                                                    len(self.train_loader))
+                    msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
+                    msg += "data time: {:>.3f}s, ".format(dataload_time)
+                    self.train_batch(batch_index, batch, msg)
+                    data_start_time = time.time()
+            except Exception as e:
+                logger.error(e)
+                raise e
+
+            total_loss, num_seen_utts = self.valid()
+            if dist.get_world_size() > 1:
+                num_seen_utts = paddle.to_tensor(num_seen_utts)
+                # the default operator in all_reduce function is sum.
+                dist.all_reduce(num_seen_utts)
+                total_loss = paddle.to_tensor(total_loss)
+                dist.all_reduce(total_loss)
+                cv_loss = total_loss / num_seen_utts
+                cv_loss = float(cv_loss)
+            else:
+                cv_loss = total_loss / num_seen_utts
+
+            logger.info(
+                'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
+            if self.visualizer:
+                self.visualizer.add_scalars(
+                    'epoch', {'cv_loss': cv_loss,
+                              'lr': self.lr_scheduler()}, self.epoch)
+            self.save(tag=self.epoch, infos={'val_loss': cv_loss})
+            self.new_epoch()
+
+    def setup_dataloader(self):
+        config = self.config.clone()
+        # train/valid dataset, return token ids
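+        # only the train loader uses the augmentation config as its preprocess_conf;
+        # the valid/test/align loaders pass preprocess_conf=None (raw features)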
+        self.train_loader = BatchDataLoader(
+            json_file=config.data.train_manifest,
+            train_mode=True,
+            sortagrad=False,
+            batch_size=config.collator.batch_size,
+            maxlen_in=float('inf'),
+            maxlen_out=float('inf'),
+            minibatches=0,
+            mini_batch_size=1,
+            batch_count='auto',
+            batch_bins=0,
+            batch_frames_in=0,
+            batch_frames_out=0,
+            batch_frames_inout=0,
+            preprocess_conf=config.collator.augmentation_config,
+            n_iter_processes=config.collator.num_workers,
+            subsampling_factor=1,
+            num_encs=1)
+
+        self.valid_loader = BatchDataLoader(
+            json_file=config.data.dev_manifest,
+            train_mode=False,
+            sortagrad=False,
+            batch_size=config.collator.batch_size,
+            maxlen_in=float('inf'),
+            maxlen_out=float('inf'),
+            minibatches=0,
+            mini_batch_size=1,
+            batch_count='auto',
+            batch_bins=0,
+            batch_frames_in=0,
+            batch_frames_out=0,
+            batch_frames_inout=0,
+            preprocess_conf=None,
+            n_iter_processes=1,
+            subsampling_factor=1,
+            num_encs=1)
+
+        # test dataset, return raw text
+        self.test_loader = BatchDataLoader(
+            json_file=config.data.test_manifest,
+            train_mode=False,
+            sortagrad=False,
+            batch_size=config.collator.batch_size,
+            maxlen_in=float('inf'),
+            maxlen_out=float('inf'),
+            minibatches=0,
+            mini_batch_size=1,
+            batch_count='auto',
+            batch_bins=0,
+            batch_frames_in=0,
+            batch_frames_out=0,
+            batch_frames_inout=0,
+            preprocess_conf=None,
+            n_iter_processes=1,
+            subsampling_factor=1,
+            num_encs=1)
+
+        self.align_loader = BatchDataLoader(
+            json_file=config.data.test_manifest,
+            train_mode=False,
+            sortagrad=False,
+            batch_size=config.collator.batch_size,
+            maxlen_in=float('inf'),
+            maxlen_out=float('inf'),
+            minibatches=0,
+            mini_batch_size=1,
+            batch_count='auto',
+            batch_bins=0,
+            batch_frames_in=0,
+            batch_frames_out=0,
+            batch_frames_inout=0,
+            preprocess_conf=None,
+            n_iter_processes=1,
+            subsampling_factor=1,
+            num_encs=1)
+        logger.info("Setup train/valid/test/align Dataloader!")
+
+    def setup_model(self):
+        config = self.config
+
+        # model
+        model_conf = config.model
+        model_conf.defrost()
+        model_conf.input_dim = self.train_loader.feat_dim
+        model_conf.output_dim = self.train_loader.vocab_size
+        model_conf.freeze()
+        model = U2Model.from_config(model_conf)
+        if self.parallel:
+            model = paddle.DataParallel(model)
+        logger.info(f"{model}")
+        layer_tools.print_params(model, logger.info)
+
+        # lr
+        scheduler_conf = config.scheduler_conf
+        scheduler_args = {
+            "learning_rate": scheduler_conf.lr,
+            "warmup_steps": scheduler_conf.warmup_steps,
+            "gamma": scheduler_conf.lr_decay,
+            "d_model": model_conf.encoder_conf.output_size,
+            "verbose": False,
+        }
+        lr_scheduler = LRSchedulerFactory.from_args(config.scheduler,
+                                                    scheduler_args)
+
+        # opt
+        def optimizer_args(
+                config,
+                parameters,
+                lr_scheduler=None, ):
+            optim_conf = config.optim_conf
+            return {
+                "grad_clip": optim_conf.global_grad_clip,
+                "weight_decay": optim_conf.weight_decay,
+                "learning_rate": lr_scheduler,
+                "parameters": parameters,
+            }
+
+        optim_args = optimizer_args(config, model.parameters(), lr_scheduler)
+        optimizer = OptimizerFactory.from_args(config.optim, optim_args)
+
+        self.model = model
+        self.lr_scheduler = lr_scheduler
+        self.optimizer = optimizer
+        logger.info("Setup model/optimizer/lr_scheduler!")
+
+
+class U2Tester(U2Trainer):
+    @classmethod
+    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
+        # decoding config
+        default = CfgNode(
+            dict(
+                alpha=2.5,  # Coef of LM for beam search.
+                beta=0.3,  # Coef of WC for beam search.
+                cutoff_prob=1.0,  # Cutoff probability for pruning.
+                cutoff_top_n=40,  # Cutoff number for pruning.
+                lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm',  # Filepath for language model.
+                decoding_method='attention',  # Decoding method. Options: 'attention', 'ctc_greedy_search',
+                # 'ctc_prefix_beam_search', 'attention_rescoring'
+                error_rate_type='wer',  # Error rate type for evaluation. Options `wer`, 'cer'
+                num_proc_bsearch=8,  # # of CPUs for beam search.
+                beam_size=10,  # Beam search width.
+                batch_size=16,  # decoding batch size
+                ctc_weight=0.0,  # ctc weight for attention rescoring decode mode.
+                decoding_chunk_size=-1,  # decoding chunk size. Defaults to -1.
+                # <0: for decoding, use full chunk.
+                # >0: for decoding, use fixed chunk size as set.
+                # 0: used for training, it's prohibited here.
+                num_decoding_left_chunks=-1,  # number of left chunks for decoding. Defaults to -1.
+                simulate_streaming=False,  # simulate streaming inference. Defaults to False.
+            ))
+
+        if config is not None:
+            config.merge_from_other_cfg(default)
+        return default
+
+    def __init__(self, config, args):
+        super().__init__(config, args)
+
+    def id2token(self, texts, texts_len, text_feature):
+        """ ord() id to chr() chr """
+        trans = []
+        for text, n in zip(texts, texts_len):
+            n = n.numpy().item()
+            ids = text[:n]
+            trans.append(text_feature.defeaturize(ids.numpy().tolist()))
+        return trans
+
+    def compute_metrics(self,
+                        utts,
+                        audio,
+                        audio_len,
+                        texts,
+                        texts_len,
+                        fout=None):
+        cfg = self.config.decoding
+        errors_sum, len_refs, num_ins = 0.0, 0, 0
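+        # choose char- or word-level error counting according to error_rate_type (cer vs. wer)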
+        errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
+        error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
+
+        start_time = time.time()
+        text_feature = TextFeaturizer(
+            unit_type=self.config.collator.unit_type,
+            vocab_filepath=self.config.collator.vocab_filepath,
+            spm_model_prefix=self.config.collator.spm_model_prefix)
+        target_transcripts = self.id2token(texts, texts_len, text_feature)
+        result_transcripts = self.model.decode(
+            audio,
+            audio_len,
+            text_feature=text_feature,
+            decoding_method=cfg.decoding_method,
+            lang_model_path=cfg.lang_model_path,
+            beam_alpha=cfg.alpha,
+            beam_beta=cfg.beta,
+            beam_size=cfg.beam_size,
+            cutoff_prob=cfg.cutoff_prob,
+            cutoff_top_n=cfg.cutoff_top_n,
+            num_processes=cfg.num_proc_bsearch,
+            ctc_weight=cfg.ctc_weight,
+            decoding_chunk_size=cfg.decoding_chunk_size,
+            num_decoding_left_chunks=cfg.num_decoding_left_chunks,
+            simulate_streaming=cfg.simulate_streaming)
+        decode_time = time.time() - start_time
+
+        for utt, target, result in zip(utts, target_transcripts,
+                                       result_transcripts):
+            errors, len_ref = errors_func(target, result)
+            errors_sum += errors
+            len_refs += len_ref
+            num_ins += 1
+            if fout:
+                fout.write(utt + " " + result + "\n")
+            logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
+                        (target, result))
+            logger.info("One example error rate [%s] = %f" %
+                        (cfg.error_rate_type, error_rate_func(target, result)))
+
+        return dict(
+            errors_sum=errors_sum,
+            len_refs=len_refs,
+            num_ins=num_ins,  # num examples
+            error_rate=errors_sum / len_refs,
+            error_rate_type=cfg.error_rate_type,
+            num_frames=audio_len.sum().numpy().item(),
+            decode_time=decode_time)
+
+    @mp_tools.rank_zero_only
+    @paddle.no_grad()
+    def test(self):
+        assert self.args.result_file
+        self.model.eval()
+        logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
+
+        stride_ms = self.config.collator.stride_ms
+        error_rate_type = None
+        errors_sum, len_refs, num_ins = 0.0, 0, 0
+        num_frames = 0.0
+        num_time = 0.0
+        with open(self.args.result_file, 'w') as fout:
+            for i, batch in enumerate(self.test_loader):
+                metrics = self.compute_metrics(*batch, fout=fout)
+                num_frames += metrics['num_frames']
+                num_time += metrics["decode_time"]
+                errors_sum += metrics['errors_sum']
+                len_refs += metrics['len_refs']
+                num_ins += metrics['num_ins']
+                error_rate_type = metrics['error_rate_type']
+                rtf = num_time / (num_frames * stride_ms)
+                logger.info(
+                    "RTF: %f, Error rate [%s] (%d/?) = %f" %
+                    (rtf, error_rate_type, num_ins, errors_sum / len_refs))
+
+        rtf = num_time / (num_frames * stride_ms)
+        msg = "Test: "
+        msg += "epoch: {}, ".format(self.epoch)
+        msg += "step: {}, ".format(self.iteration)
+        msg += "RTF: {}, ".format(rtf)
+        msg += "Final error rate [%s] (%d/%d) = %f" % (
+            error_rate_type, num_ins, num_ins, errors_sum / len_refs)
+        logger.info(msg)
+
+        # test meta results
+        err_meta_path = os.path.splitext(self.args.result_file)[0] + '.err'
+        err_type_str = "{}".format(error_rate_type)
+        with open(err_meta_path, 'w') as f:
+            data = json.dumps({
+                "epoch":
+                self.epoch,
+                "step":
+                self.iteration,
+                "rtf":
+                rtf,
+                error_rate_type:
+                errors_sum / len_refs,
+                "dataset_hour": (num_frames * stride_ms) / 1000.0 / 3600.0,
+                "process_hour":
+                num_time / 1000.0 / 3600.0,
+                "num_examples":
+                num_ins,
+                "err_sum":
+                errors_sum,
+                "ref_len":
+                len_refs,
+                "decode_method":
+                self.config.decoding.decoding_method,
+            })
+            f.write(data + '\n')
+
+    def run_test(self):
+        self.resume_or_scratch()
+        try:
+            self.test()
+        except KeyboardInterrupt:
+            sys.exit(-1)
+
+    @paddle.no_grad()
+    def align(self):
+        if self.config.decoding.batch_size > 1:
+            logger.fatal('alignment mode must be running with batch_size == 1')
+            sys.exit(1)
+
+        # xxx.align
+        assert self.args.result_file and self.args.result_file.endswith(
+            '.align')
+
+        self.model.eval()
+        logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}")
+
+        stride_ms = self.config.collator.stride_ms
+        token_dict = self.args.char_list
+
+        with open(self.args.result_file, 'w') as fout:
+            # one example in batch
+            for i, batch in enumerate(self.align_loader):
+                key, feat, feats_length, target, target_length = batch
+
+                # 1. Encoder
+                encoder_out, encoder_mask = self.model._forward_encoder(
+                    feat, feats_length)  # (B, maxlen, encoder_dim)
+                maxlen = encoder_out.size(1)
+                ctc_probs = self.model.ctc.log_softmax(
+                    encoder_out)  # (1, maxlen, vocab_size)
+
+                # 2. alignment
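+                # forced alignment of the target token sequence to the frame-level CTC posteriors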
+                ctc_probs = ctc_probs.squeeze(0)
+                target = target.squeeze(0)
+                alignment = ctc_utils.forced_align(ctc_probs, target)
+                logger.info("align ids", key[0], alignment)
+                fout.write('{} {}\n'.format(key[0], alignment))
+
+                # 3. gen praat
+                # segment alignment
+                align_segs = text_grid.segment_alignment(alignment)
+                logger.info(f"align tokens: {key[0]} {align_segs}")
+                # IntervalTier, List["start end token\n"]
+                subsample = utility.get_subsample(self.config)
+                tierformat = text_grid.align_to_tierformat(
+                    align_segs, subsample, token_dict)
+                # write tier
+                align_output_path = os.path.join(
+                    os.path.dirname(self.args.result_file), "align")
+                tier_path = os.path.join(align_output_path, key[0] + ".tier")
+                with open(tier_path, 'w') as f:
+                    f.writelines(tierformat)
+                # write textgrid
+                textgrid_path = os.path.join(align_output_path,
+                                             key[0] + ".TextGrid")
+                second_per_frame = 1. / (1000. /
+                                         stride_ms)  # 25ms window, 10ms stride
+                second_per_example = (
+                    len(alignment) + 1) * subsample * second_per_frame
+                text_grid.generate_textgrid(
+                    maxtime=second_per_example,
+                    intervals=tierformat,
+                    output=textgrid_path)
+
+    def run_align(self):
+        self.resume_or_scratch()
+        try:
+            self.align()
+        except KeyboardInterrupt:
+            sys.exit(-1)
+
+    def load_inferspec(self):
+        """infer model and input spec.
+
+        Returns:
+            nn.Layer: inference model
+            List[paddle.static.InputSpec]: input spec.
+        """
+        from deepspeech.models.u2 import U2InferModel
+        infer_model = U2InferModel.from_pretrained(self.test_loader,
+                                                   self.config.model.clone(),
+                                                   self.args.checkpoint_path)
+        feat_dim = self.test_loader.feat_dim
+        input_spec = [
+            paddle.static.InputSpec(shape=[1, None, feat_dim],
+                                    dtype='float32'),  # audio, [B,T,D]
+            paddle.static.InputSpec(shape=[1],
+                                    dtype='int64'),  # audio_length, [B]
+        ]
+        return infer_model, input_spec
+
+    def export(self):
+        infer_model, input_spec = self.load_inferspec()
+        assert isinstance(input_spec, list), type(input_spec)
+        infer_model.eval()
+        static_model = paddle.jit.to_static(infer_model, input_spec=input_spec)
+        logger.info(f"Export code: {static_model.forward.code}")
+        paddle.jit.save(static_model, self.args.export_path)
+
+    def run_export(self):
+        try:
+            self.export()
+        except KeyboardInterrupt:
+            sys.exit(-1)
+
+    def setup_dict(self):
+        # load dictionary for debug log
+        self.args.char_list = load_dict(self.args.dict_path,
+                                        "maskctc" in self.args.model_name)
+
+    def setup(self):
+        """Setup the experiment.
+        """
+        paddle.set_device(self.args.device)
+
+        self.setup_output_dir()
+        self.setup_checkpointer()
+
+        self.setup_dataloader()
+        self.setup_model()
+
+        self.setup_dict()
+
+        self.iteration = 0
+        self.epoch = 0
+
+    def setup_output_dir(self):
+        """Create a directory used for output.
+        """
+        # output dir
+        if self.args.output:
+            output_dir = Path(self.args.output).expanduser()
+            output_dir.mkdir(parents=True, exist_ok=True)
+        else:
+            output_dir = Path(
+                self.args.checkpoint_path).expanduser().parent.parent
+            output_dir.mkdir(parents=True, exist_ok=True)
+
+        self.output_dir = output_dir
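
A quick sanity check on the RTF bookkeeping in test() above, as a minimal sketch with made-up numbers; it assumes decode_time is accumulated in milliseconds, consistent with the process_hour conversion in the meta file:

    num_frames = 360000            # accumulated feature frames
    stride_ms = 10.0               # frame stride in milliseconds
    num_time = 900000.0            # accumulated decode time in milliseconds (assumed unit)

    audio_ms = num_frames * stride_ms          # 3.6e6 ms, i.e. 1 hour of audio
    rtf = num_time / audio_ms                  # 0.25 -> decoding runs 4x faster than real time
    dataset_hour = audio_ms / 1000.0 / 3600.0  # 1.0
    process_hour = num_time / 1000.0 / 3600.0  # 0.25
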
diff --git a/deepspeech/exps/u2_st/bin/export.py b/deepspeech/exps/u2_st/bin/export.py
index f566ba5bb..c7eb5d03b 100644
--- a/deepspeech/exps/u2_st/bin/export.py
+++ b/deepspeech/exps/u2_st/bin/export.py
@@ -30,6 +30,9 @@ def main(config, args):
 
 if __name__ == "__main__":
     parser = default_argument_parser()
+    # path to save the jit model
+    parser.add_argument(
+        "--export_path", type=str, help="path of the jit model to save")
     args = parser.parse_args()
     print_arguments(args, globals())
 
diff --git a/deepspeech/exps/u2_st/bin/test.py b/deepspeech/exps/u2_st/bin/test.py
index d66c7a26c..81197decf 100644
--- a/deepspeech/exps/u2_st/bin/test.py
+++ b/deepspeech/exps/u2_st/bin/test.py
@@ -34,6 +34,9 @@ def main(config, args):
 
 if __name__ == "__main__":
     parser = default_argument_parser()
+    # path to save the asr result
+    parser.add_argument(
+        "--result_file", type=str, help="path to save the asr result")
     args = parser.parse_args()
     print_arguments(args, globals())
 
diff --git a/deepspeech/frontend/augmentor/augmentation.py b/deepspeech/frontend/augmentor/augmentation.py
index cc0564daf..17abcf605 100644
--- a/deepspeech/frontend/augmentor/augmentation.py
+++ b/deepspeech/frontend/augmentor/augmentation.py
@@ -13,18 +13,28 @@
 # limitations under the License.
 """Contains the data augmentation pipeline."""
 import json
+from collections.abc import Sequence
+from inspect import signature
 
 import numpy as np
 
-from deepspeech.frontend.augmentor.impulse_response import ImpulseResponseAugmentor
-from deepspeech.frontend.augmentor.noise_perturb import NoisePerturbAugmentor
-from deepspeech.frontend.augmentor.online_bayesian_normalization import \
-    OnlineBayesianNormalizationAugmentor
-from deepspeech.frontend.augmentor.resample import ResampleAugmentor
-from deepspeech.frontend.augmentor.shift_perturb import ShiftPerturbAugmentor
-from deepspeech.frontend.augmentor.spec_augment import SpecAugmentor
-from deepspeech.frontend.augmentor.speed_perturb import SpeedPerturbAugmentor
-from deepspeech.frontend.augmentor.volume_perturb import VolumePerturbAugmentor
+from deepspeech.frontend.augmentor.base import AugmentorBase
+from deepspeech.utils.dynamic_import import dynamic_import
+from deepspeech.utils.log import Log
+
+__all__ = ["AugmentationPipeline"]
+
+logger = Log(__name__).getlog()
+
+import_alias = dict(
+    volume="deepspeech.frontend.augmentor.volume_perturb:VolumePerturbAugmentor",
+    shift="deepspeech.frontend.augmentor.shift_perturb:ShiftPerturbAugmentor",
+    speed="deepspeech.frontend.augmentor.speed_perturb:SpeedPerturbAugmentor",
+    resample="deepspeech.frontend.augmentor.resample:ResampleAugmentor",
+    bayesian_normal="deepspeech.frontend.augmentor.online_bayesian_normalization:OnlineBayesianNormalizationAugmentor",
+    noise="deepspeech.frontend.augmentor.noise_perturb:NoisePerturbAugmentor",
+    impulse="deepspeech.frontend.augmentor.impulse_response:ImpulseResponseAugmentor",
+    specaug="deepspeech.frontend.augmentor.spec_augment:SpecAugmentor", )
 
 
 class AugmentationPipeline():
@@ -78,20 +88,74 @@ class AugmentationPipeline():
     augmentor to take effect. If "prob" is zero, the augmentor does not take
     effect.
 
-    :param augmentation_config: Augmentation configuration in json string.
-    :type augmentation_config: str
-    :param random_seed: Random seed.
-    :type random_seed: int
-    :raises ValueError: If the augmentation json config is in incorrect format".
+    Params:
+        augmentation_config(str): Augmentation configuration in json string.
+        random_seed(int): Random seed.
+        train(bool): whether in train mode.
+
+    Raises:
+        ValueError: If the augmentation json config is in an incorrect format.
     """
 
-    def __init__(self, augmentation_config: str, random_seed=0):
+    SPEC_TYPES = {'specaug'}
+
+    def __init__(self, augmentation_config: str, random_seed: int=0):
         self._rng = np.random.RandomState(random_seed)
-        self._spec_types = ('specaug')
-        self._augmentors, self._rates = self._parse_pipeline_from(
-            augmentation_config, 'audio')
+        self.conf = {'mode': 'sequential', 'process': []}
+        if augmentation_config:
+            process = json.loads(augmentation_config)
+            self.conf['process'] += process
+
+        self._augmentors, self._rates = self._parse_pipeline_from('all')
+        self._audio_augmentors, self._audio_rates = self._parse_pipeline_from(
+            'audio')
         self._spec_augmentors, self._spec_rates = self._parse_pipeline_from(
-            augmentation_config, 'feature')
+            'feature')
+
+    def __call__(self, xs, uttid_list=None, **kwargs):
+        if not isinstance(xs, Sequence):
+            is_batch = False
+            xs = [xs]
+        else:
+            is_batch = True
+
+        if isinstance(uttid_list, str):
+            uttid_list = [uttid_list for _ in range(len(xs))]
+
+        if self.conf.get("mode", "sequential") == "sequential":
+            for idx, (func, rate) in enumerate(
+                    zip(self._augmentors, self._rates), 0):
+                if self._rng.uniform(0., 1.) >= rate:
+                    continue
+
+                # Derive only the args which the func has
+                try:
+                    param = signature(func).parameters
+                except ValueError:
+                    # Some functions, e.g. built-in functions, do not provide a signature
+                    param = {}
+                _kwargs = {k: v for k, v in kwargs.items() if k in param}
+
+                try:
+                    if uttid_list is not None and "uttid" in param:
+                        xs = [
+                            func(x, u, **_kwargs)
+                            for x, u in zip(xs, uttid_list)
+                        ]
+                    else:
+                        xs = [func(x, **_kwargs) for x in xs]
+                except Exception:
+                    logger.fatal("Caught an exception from the {}th func: {}".format(
+                        idx, func))
+                    raise
+        else:
+            raise NotImplementedError(
+                "Not supporting mode={}".format(self.conf["mode"]))
+
+        if is_batch:
+            return xs
+        else:
+            return xs[0]
 
     def transform_audio(self, audio_segment):
         """Run the pre-processing pipeline for data augmentation.
@@ -101,7 +165,7 @@ class AugmentationPipeline():
         :param audio_segment: Audio segment to process.
         :type audio_segment: AudioSegmenet|SpeechSegment
         """
-        for augmentor, rate in zip(self._augmentors, self._rates):
+        for augmentor, rate in zip(self._audio_augmentors, self._audio_rates):
             if self._rng.uniform(0., 1.) < rate:
                 augmentor.transform_audio(audio_segment)
 
@@ -116,52 +180,39 @@ class AugmentationPipeline():
                 spec_segment = augmentor.transform_feature(spec_segment)
         return spec_segment
 
-    def _parse_pipeline_from(self, config_json, aug_type='audio'):
+    def _parse_pipeline_from(self, aug_type='all'):
         """Parse the config json to build a augmentation pipelien."""
-        assert aug_type in ('audio', 'feature'), aug_type
-        try:
-            configs = json.loads(config_json)
-            audio_confs = []
-            feature_confs = []
-            for config in configs:
-                if config["type"] in self._spec_types:
-                    feature_confs.append(config)
-                else:
-                    audio_confs.append(config)
-
-            if aug_type == 'audio':
-                aug_confs = audio_confs
-            elif aug_type == 'feature':
-                aug_confs = feature_confs
-
-            augmentors = [
-                self._get_augmentor(config["type"], config["params"])
-                for config in aug_confs
-            ]
-            rates = [config["prob"] for config in aug_confs]
-
-        except Exception as e:
-            raise ValueError("Failed to parse the augmentation config json: "
-                             "%s" % str(e))
+        assert aug_type in ('audio', 'feature', 'all'), aug_type
+        audio_confs = []
+        feature_confs = []
+        all_confs = []
+        for config in self.conf['process']:
+            all_confs.append(config)
+            if config["type"] in self.SPEC_TYPES:
+                feature_confs.append(config)
+            else:
+                audio_confs.append(config)
+
+        if aug_type == 'audio':
+            aug_confs = audio_confs
+        elif aug_type == 'feature':
+            aug_confs = feature_confs
+        else:
+            aug_confs = all_confs
+
+        augmentors = [
+            self._get_augmentor(config["type"], config["params"])
+            for config in aug_confs
+        ]
+        rates = [config["prob"] for config in aug_confs]
         return augmentors, rates
 
     def _get_augmentor(self, augmentor_type, params):
         """Return an augmentation model by the type name, and pass in params."""
-        if augmentor_type == "volume":
-            return VolumePerturbAugmentor(self._rng, **params)
-        elif augmentor_type == "shift":
-            return ShiftPerturbAugmentor(self._rng, **params)
-        elif augmentor_type == "speed":
-            return SpeedPerturbAugmentor(self._rng, **params)
-        elif augmentor_type == "resample":
-            return ResampleAugmentor(self._rng, **params)
-        elif augmentor_type == "bayesian_normal":
-            return OnlineBayesianNormalizationAugmentor(self._rng, **params)
-        elif augmentor_type == "noise":
-            return NoisePerturbAugmentor(self._rng, **params)
-        elif augmentor_type == "impulse":
-            return ImpulseResponseAugmentor(self._rng, **params)
-        elif augmentor_type == "specaug":
-            return SpecAugmentor(self._rng, **params)
-        else:
+        class_obj = dynamic_import(augmentor_type, import_alias)
+        assert issubclass(class_obj, AugmentorBase)
+        try:
+            obj = class_obj(self._rng, **params)
+        except Exception:
             raise ValueError("Unknown augmentor type [%s]." % augmentor_type)
+        return obj
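
A minimal sketch of driving the rewritten pipeline through the JSON config format described in the class docstring; the specaug parameters below are illustrative, not taken from any recipe:

    import json

    import numpy as np

    from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline

    # each entry carries "type" (a key of import_alias), "params" and "prob"
    conf = json.dumps([{
        "type": "specaug",
        "params": {"F": 10, "T": 50, "n_freq_masks": 2, "n_time_masks": 2,
                   "p": 1.0, "W": 40},
        "prob": 1.0,
    }])
    pipeline = AugmentationPipeline(augmentation_config=conf, random_seed=0)

    feat = np.random.randn(200, 80).astype('float32')  # (time, freq) spectrogram
    feat = pipeline.transform_feature(feat)             # specaug is routed to the feature stage
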
diff --git a/deepspeech/frontend/augmentor/base.py b/deepspeech/frontend/augmentor/base.py
index e6f5c1e9f..18d003c0b 100644
--- a/deepspeech/frontend/augmentor/base.py
+++ b/deepspeech/frontend/augmentor/base.py
@@ -28,6 +28,10 @@ class AugmentorBase():
     def __init__(self):
         pass
 
+    @abstractmethod
+    def __call__(self, xs):
+        raise NotImplementedError("AugmentorBase: Not impl __call__")
+
     @abstractmethod
     def transform_audio(self, audio_segment):
         """Adds various effects to the input audio segment. Such effects
@@ -40,7 +44,7 @@ class AugmentorBase():
         :param audio_segment: Audio segment to add effects to.
         :type audio_segment: AudioSegmenet|SpeechSegment
         """
-        raise NotImplementedError
+        raise NotImplementedError("AugmentorBase: Not impl transform_audio")
 
     @abstractmethod
     def transform_feature(self, spec_segment):
@@ -52,4 +56,4 @@ class AugmentorBase():
         Args:
             spec_segment (Spectrogram): Spectrogram segment to add effects to.
         """
-        raise NotImplementedError
+        raise NotImplementedError("AugmentorBase: Not impl transform_feature")
diff --git a/deepspeech/frontend/augmentor/impulse_response.py b/deepspeech/frontend/augmentor/impulse_response.py
index fbd617b42..818251ed8 100644
--- a/deepspeech/frontend/augmentor/impulse_response.py
+++ b/deepspeech/frontend/augmentor/impulse_response.py
@@ -30,6 +30,12 @@ class ImpulseResponseAugmentor(AugmentorBase):
         self._rng = rng
         self._impulse_manifest = read_manifest(impulse_manifest_path)
 
+    def __call__(self, x, uttid=None, train=True):
+        if not train:
+            return x
+        self.transform_audio(x)
+        return x
+
     def transform_audio(self, audio_segment):
         """Add impulse response effect.
 
diff --git a/deepspeech/frontend/augmentor/noise_perturb.py b/deepspeech/frontend/augmentor/noise_perturb.py
index b3c07f5c1..790b0c396 100644
--- a/deepspeech/frontend/augmentor/noise_perturb.py
+++ b/deepspeech/frontend/augmentor/noise_perturb.py
@@ -36,6 +36,12 @@ class NoisePerturbAugmentor(AugmentorBase):
         self._rng = rng
         self._noise_manifest = read_manifest(manifest_path=noise_manifest_path)
 
+    def __call__(self, x, uttid=None, train=True):
+        if not train:
+            return x
+        self.transform_audio(x)
+        return x
+
     def transform_audio(self, audio_segment):
         """Add background noise audio.
 
diff --git a/deepspeech/frontend/augmentor/online_bayesian_normalization.py b/deepspeech/frontend/augmentor/online_bayesian_normalization.py
index 5af3b9b03..0f9d3ef6f 100644
--- a/deepspeech/frontend/augmentor/online_bayesian_normalization.py
+++ b/deepspeech/frontend/augmentor/online_bayesian_normalization.py
@@ -44,6 +44,12 @@ class OnlineBayesianNormalizationAugmentor(AugmentorBase):
         self._rng = rng
         self._startup_delay = startup_delay
 
+    def __call__(self, x, uttid=None, train=True):
+        if not train:
+            return x
+        self.transform_audio(x)
+        return x
+
     def transform_audio(self, audio_segment):
         """Normalizes the input audio using the online Bayesian approach.
 
diff --git a/deepspeech/frontend/augmentor/resample.py b/deepspeech/frontend/augmentor/resample.py
index 9afce635d..509fe003d 100644
--- a/deepspeech/frontend/augmentor/resample.py
+++ b/deepspeech/frontend/augmentor/resample.py
@@ -31,6 +31,12 @@ class ResampleAugmentor(AugmentorBase):
         self._new_sample_rate = new_sample_rate
         self._rng = rng
 
+    def __call__(self, x, uttid=None, train=True):
+        if not train:
+            return x
+        self.transform_audio(x)
+        return x
+
     def transform_audio(self, audio_segment):
         """Resamples the input audio to a target sample rate.
 
diff --git a/deepspeech/frontend/augmentor/shift_perturb.py b/deepspeech/frontend/augmentor/shift_perturb.py
index 9cc3fe2d0..8b7439fe5 100644
--- a/deepspeech/frontend/augmentor/shift_perturb.py
+++ b/deepspeech/frontend/augmentor/shift_perturb.py
@@ -31,6 +31,12 @@ class ShiftPerturbAugmentor(AugmentorBase):
         self._max_shift_ms = max_shift_ms
         self._rng = rng
 
+    def __call__(self, x, uttid=None, train=True):
+        if not train:
+            return x
+        self.transform_audio(x)
+        return x
+
     def transform_audio(self, audio_segment):
         """Shift audio.
 
diff --git a/deepspeech/frontend/augmentor/spec_augment.py b/deepspeech/frontend/augmentor/spec_augment.py
index 1c2e09fc7..7c23b628e 100644
--- a/deepspeech/frontend/augmentor/spec_augment.py
+++ b/deepspeech/frontend/augmentor/spec_augment.py
@@ -12,7 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Contains the volume perturb augmentation model."""
+import random
+
 import numpy as np
+from PIL import Image
+from PIL.Image import BICUBIC
 
 from deepspeech.frontend.augmentor.base import AugmentorBase
 from deepspeech.utils.log import Log
@@ -41,7 +45,9 @@ class SpecAugmentor(AugmentorBase):
                  W=40,
                  adaptive_number_ratio=0,
                  adaptive_size_ratio=0,
-                 max_n_time_masks=20):
+                 max_n_time_masks=20,
+                 replace_with_zero=True,
+                 warp_mode='PIL'):
         """SpecAugment class.
         Args:
             rng (random.Random): random generator object.
@@ -54,17 +60,22 @@ class SpecAugmentor(AugmentorBase):
             adaptive_number_ratio (float): adaptive multiplicity ratio for time masking
             adaptive_size_ratio (float): adaptive size ratio for time masking
             max_n_time_masks (int): maximum number of time masking
+            replace_with_zero (bool): fill masks with zero if true, else with the feature mean
+            warp_mode (str): "PIL" (default, fast, not differentiable)
+                or "sparse_image_warp" (slow, differentiable)
         """
         super().__init__()
         self._rng = rng
+        self.inplace = True
+        self.replace_with_zero = replace_with_zero
 
+        self.mode = warp_mode
         self.W = W
         self.F = F
         self.T = T
         self.n_freq_masks = n_freq_masks
         self.n_time_masks = n_time_masks
         self.p = p
-        #logger.info(f"specaug: F-{F}, T-{T}, F-n-{n_freq_masks}, T-n-{n_time_masks}")
 
         # adaptive SpecAugment
         self.adaptive_number_ratio = adaptive_number_ratio
@@ -121,21 +132,83 @@ class SpecAugmentor(AugmentorBase):
     def time_mask(self):
         return self._time_mask
 
-    def time_warp(xs, W=40):
-        raise NotImplementedError
+    def __repr__(self):
+        return (f"specaug: F-{self.F}, T-{self.T}, "
+                f"F-n-{self.n_freq_masks}, T-n-{self.n_time_masks}")
+
+    def time_warp(self, x, mode='PIL'):
+        """time warp for spec augment
+        move random center frame by the random width ~ uniform(-window, window)
+
+        Args:
+            x (np.ndarray): spectrogram (time, freq)
+            mode (str): PIL or sparse_image_warp
+
+        Raises:
+            NotImplementedError: if mode is "sparse_image_warp" or an unknown warp mode
+
+        Returns:
+            np.ndarray: time warped spectrogram (time, freq)
+        """
+        window = max_time_warp = self.W
+        if mode == "PIL":
+            t = x.shape[0]
+            if t - window <= window:
+                return x
+            # NOTE: randrange(a, b) emits a, a + 1, ..., b - 1
+            center = random.randrange(window, t - window)
+            warped = random.randrange(center - window, center +
+                                      window) + 1  # 1 ... t - 1
+
+            left = Image.fromarray(x[:center]).resize((x.shape[1], warped),
+                                                      BICUBIC)
+            right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped),
+                                                       BICUBIC)
+            if self.inplace:
+                x[:warped] = left
+                x[warped:] = right
+                return x
+            return np.concatenate((left, right), 0)
+        elif mode == "sparse_image_warp":
+            raise NotImplementedError('sparse_image_warp')
+        else:
+            raise NotImplementedError(
+                "unknown resize mode: " + mode +
+                ", choose one from (PIL, sparse_image_warp).")
+
+    def mask_freq(self, x, replace_with_zero=False):
+        """freq mask
 
-    def mask_freq(self, xs, replace_with_zero=False):
-        n_bins = xs.shape[0]
+        Args:
+            x (np.ndarray): spectrogram (time, freq)
+            replace_with_zero (bool, optional): Defaults to False.
+
+        Returns:
+            np.ndarray: freq mask spectrogram (time, freq)
+        """
+        n_bins = x.shape[1]
         for i in range(0, self.n_freq_masks):
             f = int(self._rng.uniform(low=0, high=self.F))
             f_0 = int(self._rng.uniform(low=0, high=n_bins - f))
-            xs[f_0:f_0 + f, :] = 0
             assert f_0 <= f_0 + f
+            if replace_with_zero:
+                x[:, f_0:f_0 + f] = 0
+            else:
+                x[:, f_0:f_0 + f] = x.mean()
             self._freq_mask = (f_0, f_0 + f)
-        return xs
+        return x
+
+    def mask_time(self, x, replace_with_zero=False):
+        """time mask
 
-    def mask_time(self, xs, replace_with_zero=False):
-        n_frames = xs.shape[1]
+        Args:
+            x (np.ndarray): spectrogram (time, freq)
+            replace_with_zero (bool, optional): Defaults to False.
+
+        Returns:
+            np.ndarray: time mask spectrogram (time, freq)
+        """
+        n_frames = x.shape[0]
 
         if self.adaptive_number_ratio > 0:
             n_masks = int(n_frames * self.adaptive_number_ratio)
@@ -152,19 +225,29 @@ class SpecAugmentor(AugmentorBase):
             t = int(self._rng.uniform(low=0, high=T))
             t = min(t, int(n_frames * self.p))
             t_0 = int(self._rng.uniform(low=0, high=n_frames - t))
-            xs[:, t_0:t_0 + t] = 0
             assert t_0 <= t_0 + t
+            if replace_with_zero:
+                x[t_0:t_0 + t, :] = 0
+            else:
+                x[t_0:t_0 + t, :] = x.mean()
             self._time_mask = (t_0, t_0 + t)
-        return xs
+        return x
+
+    def __call__(self, x, train=True):
+        if not train:
+            return x
+        return self.transform_feature(x)
 
-    def transform_feature(self, xs: np.ndarray):
+    def transform_feature(self, x: np.ndarray):
         """
         Args:
-            xs (FloatTensor): `[F, T]`
+            x (np.ndarray): `[T, F]`
         Returns:
-            xs (FloatTensor): `[F, T]`
+            x (np.ndarray): `[T, F]`
         """
-        # xs = self.time_warp(xs)
-        xs = self.mask_freq(xs)
-        xs = self.mask_time(xs)
-        return xs
+        assert isinstance(x, np.ndarray)
+        assert x.ndim == 2
+        x = self.time_warp(x, self.mode)
+        x = self.mask_freq(x, self.replace_with_zero)
+        x = self.mask_time(x, self.replace_with_zero)
+        return x
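
The reworked SpecAugmentor now consumes (time, freq) numpy arrays directly; a minimal usage sketch (parameter values are illustrative):

    import numpy as np

    from deepspeech.frontend.augmentor.spec_augment import SpecAugmentor

    rng = np.random.RandomState(0)
    aug = SpecAugmentor(rng, F=10, T=50, n_freq_masks=2, n_time_masks=2, p=1.0,
                        W=40, replace_with_zero=True, warp_mode='PIL')

    x = np.random.randn(300, 83).astype('float32')  # (time, freq), e.g. fbank+pitch
    y = aug(x)            # __call__ with train=True: PIL time warp + freq/time masks
    assert y.shape == (300, 83)
    # with replace_with_zero=False the masked regions are filled with x.mean() instead of 0
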
diff --git a/deepspeech/frontend/augmentor/speed_perturb.py b/deepspeech/frontend/augmentor/speed_perturb.py
index d0977c131..ce8dfde0a 100644
--- a/deepspeech/frontend/augmentor/speed_perturb.py
+++ b/deepspeech/frontend/augmentor/speed_perturb.py
@@ -79,6 +79,12 @@ class SpeedPerturbAugmentor(AugmentorBase):
             self._rates = np.linspace(
                 self._min_rate, self._max_rate, self._num_rates, endpoint=True)
 
+    def __call__(self, x, uttid=None, train=True):
+        if not train:
+            return x
+        self.transform_audio(x)
+        return x
+
     def transform_audio(self, audio_segment):
         """Sample a new speed rate from the given range and
         changes the speed of the given audio clip.
diff --git a/deepspeech/frontend/augmentor/volume_perturb.py b/deepspeech/frontend/augmentor/volume_perturb.py
index 0d76e7a05..70cb28897 100644
--- a/deepspeech/frontend/augmentor/volume_perturb.py
+++ b/deepspeech/frontend/augmentor/volume_perturb.py
@@ -37,6 +37,12 @@ class VolumePerturbAugmentor(AugmentorBase):
         self._max_gain_dBFS = max_gain_dBFS
         self._rng = rng
 
+    def __call__(self, x, uttid=None, train=True):
+        if not train:
+            return x
+        self.transform_audio(x)
+        return x
+
     def transform_audio(self, audio_segment):
         """Change audio loadness.
 
diff --git a/deepspeech/frontend/featurizer/__init__.py b/deepspeech/frontend/featurizer/__init__.py
index 185a92b8d..6992700d9 100644
--- a/deepspeech/frontend/featurizer/__init__.py
+++ b/deepspeech/frontend/featurizer/__init__.py
@@ -11,3 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from .audio_featurizer import AudioFeaturizer  #noqa: F401
+from .speech_featurizer import SpeechFeaturizer
+from .text_featurizer import TextFeaturizer
diff --git a/deepspeech/frontend/featurizer/audio_featurizer.py b/deepspeech/frontend/featurizer/audio_featurizer.py
index 11c1fa2d4..4c40c8472 100644
--- a/deepspeech/frontend/featurizer/audio_featurizer.py
+++ b/deepspeech/frontend/featurizer/audio_featurizer.py
@@ -18,7 +18,7 @@ from python_speech_features import logfbank
 from python_speech_features import mfcc
 
 
-class AudioFeaturizer(object):
+class AudioFeaturizer():
     """Audio featurizer, for extracting features from audio contents of
     AudioSegment or SpeechSegment.
 
@@ -167,32 +167,6 @@ class AudioFeaturizer(object):
             raise ValueError("Unknown specgram_type %s. "
                              "Supported values: linear." % self._specgram_type)
 
-    def _compute_linear_specgram(self,
-                                 samples,
-                                 sample_rate,
-                                 stride_ms=10.0,
-                                 window_ms=20.0,
-                                 max_freq=None,
-                                 eps=1e-14):
-        """Compute the linear spectrogram from FFT energy."""
-        if max_freq is None:
-            max_freq = sample_rate / 2
-        if max_freq > sample_rate / 2:
-            raise ValueError("max_freq must not be greater than half of "
-                             "sample rate.")
-        if stride_ms > window_ms:
-            raise ValueError("Stride size must not be greater than "
-                             "window size.")
-        stride_size = int(0.001 * sample_rate * stride_ms)
-        window_size = int(0.001 * sample_rate * window_ms)
-        specgram, freqs = self._specgram_real(
-            samples,
-            window_size=window_size,
-            stride_size=stride_size,
-            sample_rate=sample_rate)
-        ind = np.where(freqs <= max_freq)[0][-1] + 1
-        return np.log(specgram[:ind, :] + eps)
-
     def _specgram_real(self, samples, window_size, stride_size, sample_rate):
         """Compute the spectrogram for samples from a real signal."""
         # extract strided windows
@@ -217,26 +191,65 @@ class AudioFeaturizer(object):
         freqs = float(sample_rate) / window_size * np.arange(fft.shape[0])
         return fft, freqs
 
+    def _compute_linear_specgram(self,
+                                 samples,
+                                 sample_rate,
+                                 stride_ms=10.0,
+                                 window_ms=20.0,
+                                 max_freq=None,
+                                 eps=1e-14):
+        """Compute the linear spectrogram from FFT energy.
+
+        Args:
+            samples (np.ndarray): audio samples.
+            sample_rate (int): audio sample rate.
+            stride_ms (float, optional): stride size in milliseconds. Defaults to 10.0.
+            window_ms (float, optional): window size in milliseconds. Defaults to 20.0.
+            max_freq (float, optional): highest frequency to keep; defaults to sample_rate / 2.
+            eps (float, optional): added before taking the log for numerical stability. Defaults to 1e-14.
+
+        Raises:
+            ValueError: if max_freq is greater than half of sample_rate.
+            ValueError: if stride_ms is greater than window_ms.
+
+        Returns:
+            np.ndarray: log spectrogram, (time, freq)
+        """
+        if max_freq is None:
+            max_freq = sample_rate / 2
+        if max_freq > sample_rate / 2:
+            raise ValueError("max_freq must not be greater than half of "
+                             "sample rate.")
+        if stride_ms > window_ms:
+            raise ValueError("Stride size must not be greater than "
+                             "window size.")
+        stride_size = int(0.001 * sample_rate * stride_ms)
+        window_size = int(0.001 * sample_rate * window_ms)
+        specgram, freqs = self._specgram_real(
+            samples,
+            window_size=window_size,
+            stride_size=stride_size,
+            sample_rate=sample_rate)
+        ind = np.where(freqs <= max_freq)[0][-1] + 1
+        # (freq, time)
+        spec = np.log(specgram[:ind, :] + eps)
+        return np.transpose(spec)
+
     def _concat_delta_delta(self, feat):
         """append delat, delta-delta feature.
 
         Args:
-            feat (np.ndarray): (D, T)
+            feat (np.ndarray): (T, D)
 
         Returns:
-            np.ndarray: feat with delta-delta, (3*D, T)
+            np.ndarray: feat with delta-delta, (T, 3*D)
         """
-        feat = np.transpose(feat)
         # Deltas
         d_feat = delta(feat, 2)
         # Deltas-Deltas
         dd_feat = delta(feat, 2)
-        # transpose
-        feat = np.transpose(feat)
-        d_feat = np.transpose(d_feat)
-        dd_feat = np.transpose(dd_feat)
         # concat above three features
-        concat_feat = np.concatenate((feat, d_feat, dd_feat))
+        concat_feat = np.concatenate((feat, d_feat, dd_feat), axis=1)
         return concat_feat
 
     def _compute_mfcc(self,
@@ -292,7 +305,6 @@ class AudioFeaturizer(object):
             ceplifter=22,
             useEnergy=True,
             winfunc='povey')
-        mfcc_feat = np.transpose(mfcc_feat)
         if delta_delta:
             mfcc_feat = self._concat_delta_delta(mfcc_feat)
         return mfcc_feat
@@ -346,8 +358,6 @@ class AudioFeaturizer(object):
             remove_dc_offset=True,
             preemph=0.97,
             wintype='povey')
-
-        fbank_feat = np.transpose(fbank_feat)
         if delta_delta:
             fbank_feat = self._concat_delta_delta(fbank_feat)
         return fbank_feat
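
With the featurizer now returning (T, D) arrays, the delta concatenation no longer needs the transposes; a small sketch of the resulting shapes (delta comes from python_speech_features, as imported above):

    import numpy as np
    from python_speech_features import delta

    feat = np.random.randn(200, 80)       # (T, D) fbank-style features
    d_feat = delta(feat, 2)               # deltas, still (T, D)
    dd_feat = delta(feat, 2)              # as in _concat_delta_delta above
                                          # (delta-delta is conventionally delta(d_feat, 2))
    concat = np.concatenate((feat, d_feat, dd_feat), axis=1)
    assert concat.shape == (200, 240)     # (T, 3 * D)
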
diff --git a/deepspeech/frontend/featurizer/speech_featurizer.py b/deepspeech/frontend/featurizer/speech_featurizer.py
index 0fbbc5648..5082850d6 100644
--- a/deepspeech/frontend/featurizer/speech_featurizer.py
+++ b/deepspeech/frontend/featurizer/speech_featurizer.py
@@ -16,7 +16,7 @@ from deepspeech.frontend.featurizer.audio_featurizer import AudioFeaturizer
 from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
 
 
-class SpeechFeaturizer(object):
+class SpeechFeaturizer():
     """Speech featurizer, for extracting features from both audio and transcript
     contents of SpeechSegment.
 
diff --git a/deepspeech/frontend/featurizer/text_featurizer.py b/deepspeech/frontend/featurizer/text_featurizer.py
index 1ba6ac7f9..e4364f70a 100644
--- a/deepspeech/frontend/featurizer/text_featurizer.py
+++ b/deepspeech/frontend/featurizer/text_featurizer.py
@@ -14,12 +14,19 @@
 """Contains the text featurizer class."""
 import sentencepiece as spm
 
-from deepspeech.frontend.utility import EOS
-from deepspeech.frontend.utility import UNK
+from ..utility import EOS
+from ..utility import load_dict
+from ..utility import UNK
 
+__all__ = ["TextFeaturizer"]
 
-class TextFeaturizer(object):
-    def __init__(self, unit_type, vocab_filepath, spm_model_prefix=None):
+
+class TextFeaturizer():
+    def __init__(self,
+                 unit_type,
+                 vocab_filepath,
+                 spm_model_prefix=None,
+                 maskctc=False):
         """Text featurizer, for processing or extracting features from text.
 
         Currently, it supports char/word/sentence-piece level tokenizing and conversion into
@@ -34,11 +41,12 @@ class TextFeaturizer(object):
         assert unit_type in ('char', 'spm', 'word')
         self.unit_type = unit_type
         self.unk = UNK
+        self.maskctc = maskctc
+
         if vocab_filepath:
-            self._vocab_dict, self._id2token, self._vocab_list = self._load_vocabulary_from_file(
-                vocab_filepath)
-            self.unk_id = self._vocab_list.index(self.unk)
-            self.eos_id = self._vocab_list.index(EOS)
+            self.vocab_dict, self._id2token, self.vocab_list, self.unk_id, self.eos_id = self._load_vocabulary_from_file(
+                vocab_filepath, maskctc)
+            self.vocab_size = len(self.vocab_list)
 
         if unit_type == 'spm':
             spm_model = spm_model_prefix + '.model'
@@ -67,7 +75,7 @@ class TextFeaturizer(object):
         """Convert text string to a list of token indices.
 
         Args:
-            text (str): Text to process.
+            text (str): Text.
         
         Returns:
             List[int]: List of token indices.
@@ -75,8 +83,8 @@ class TextFeaturizer(object):
         tokens = self.tokenize(text)
         ids = []
         for token in tokens:
-            token = token if token in self._vocab_dict else self.unk
-            ids.append(self._vocab_dict[token])
+            token = token if token in self.vocab_dict else self.unk
+            ids.append(self.vocab_dict[token])
         return ids
 
     def defeaturize(self, idxs):
@@ -87,7 +95,7 @@ class TextFeaturizer(object):
             idxs (List[int]): List of token indices.
 
         Returns:
-            str: Text to process.
+            str: Text.
         """
         tokens = []
         for idx in idxs:
@@ -97,33 +105,6 @@ class TextFeaturizer(object):
         text = self.detokenize(tokens)
         return text
 
-    @property
-    def vocab_size(self):
-        """Return the vocabulary size.
-
-        :return: Vocabulary size.
-        :rtype: int
-        """
-        return len(self._vocab_list)
-
-    @property
-    def vocab_list(self):
-        """Return the vocabulary in list.
-
-        Returns:
-            List[str]: tokens.
-        """
-        return self._vocab_list
-
-    @property
-    def vocab_dict(self):
-        """Return the vocabulary in dict.
-
-        Returns:
-            Dict[str, int]: token str -> int
-        """
-        return self._vocab_dict
-
     def char_tokenize(self, text):
         """Character tokenizer.
 
@@ -206,14 +187,16 @@ class TextFeaturizer(object):
 
         return decode(tokens)
 
-    def _load_vocabulary_from_file(self, vocab_filepath):
+    def _load_vocabulary_from_file(self, vocab_filepath: str, maskctc: bool):
         """Load vocabulary from file."""
-        vocab_lines = []
-        with open(vocab_filepath, 'r', encoding='utf-8') as file:
-            vocab_lines.extend(file.readlines())
-        vocab_list = [line[:-1] for line in vocab_lines]
+        vocab_list = load_dict(vocab_filepath, maskctc)
+        assert vocab_list is not None
+
         id2token = dict(
             [(idx, token) for (idx, token) in enumerate(vocab_list)])
         token2id = dict(
             [(token, idx) for (idx, token) in enumerate(vocab_list)])
-        return token2id, id2token, vocab_list
+
+        unk_id = vocab_list.index(UNK)
+        eos_id = vocab_list.index(EOS)
+        return token2id, id2token, vocab_list, unk_id, eos_id
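
A sketch of what _load_vocabulary_from_file now returns, substituting a toy token list for the output of load_dict(vocab_filepath, maskctc):

    vocab_list = ['<blank>', 'a', 'b', 'c', '<unk>', '<eos>']  # toy stand-in for load_dict(...)
    token2id = {token: idx for idx, token in enumerate(vocab_list)}
    id2token = {idx: token for idx, token in enumerate(vocab_list)}
    unk_id = vocab_list.index('<unk>')   # 4
    eos_id = vocab_list.index('<eos>')   # 5
    vocab_size = len(vocab_list)         # 6, now exposed as a plain attribute
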
diff --git a/deepspeech/frontend/normalizer.py b/deepspeech/frontend/normalizer.py
index 287b51e58..73b3a4ba6 100644
--- a/deepspeech/frontend/normalizer.py
+++ b/deepspeech/frontend/normalizer.py
@@ -40,21 +40,21 @@ class CollateFunc(object):
         number = 0
         for item in batch:
             audioseg = AudioSegment.from_file(item['feat'])
-            feat = self.feature_func(audioseg)  #(D, T)
+            feat = self.feature_func(audioseg)  #(T, D)
 
-            sums = np.sum(feat, axis=1)
+            sums = np.sum(feat, axis=0)
             if mean_stat is None:
                 mean_stat = sums
             else:
                 mean_stat += sums
 
-            square_sums = np.sum(np.square(feat), axis=1)
+            square_sums = np.sum(np.square(feat), axis=0)
             if var_stat is None:
                 var_stat = square_sums
             else:
                 var_stat += square_sums
 
-            number += feat.shape[1]
+            number += feat.shape[0]
         return number, mean_stat, var_stat
 
 
@@ -120,7 +120,7 @@ class FeatureNormalizer(object):
         """Normalize features to be of zero mean and unit stddev.
 
         :param features: Input features to be normalized.
-        :type features: ndarray, shape (D, T)
+        :type features: ndarray, shape (T, D)
         :param eps:  added to stddev to provide numerical stablibity.
         :type eps: float
         :return: Normalized features.
@@ -131,8 +131,8 @@ class FeatureNormalizer(object):
     def _read_mean_std_from_file(self, filepath, eps=1e-20):
         """Load mean and std from file."""
         mean, istd = load_cmvn(filepath, filetype='json')
-        self._mean = np.expand_dims(mean, axis=-1)
-        self._istd = np.expand_dims(istd, axis=-1)
+        self._mean = np.expand_dims(mean, axis=0)
+        self._istd = np.expand_dims(istd, axis=0)
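
With features in (T, D) layout, the per-dimension statistics are reduced over axis 0 and the mean/istd broadcast over time; a rough sketch of the arithmetic (the istd derivation here uses the usual E[x^2] - E[x]^2 form and is assumed, not quoted from the stats computation):

    import numpy as np

    feats = [np.random.randn(t, 80) for t in (120, 95, 200)]     # per-utterance (T, D)
    number = sum(f.shape[0] for f in feats)                       # total frame count
    mean_stat = sum(np.sum(f, axis=0) for f in feats)             # (D,)
    var_stat = sum(np.sum(np.square(f), axis=0) for f in feats)   # (D,)

    mean = mean_stat / number
    istd = 1.0 / np.sqrt(var_stat / number - np.square(mean) + 1e-20)

    # mean/istd are expanded to (1, D) so they broadcast over the time axis
    normalized = (feats[0] - mean[np.newaxis, :]) * istd[np.newaxis, :]
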
 
     def write_to_file(self, filepath):
         """Write the mean and stddev to the file.
diff --git a/deepspeech/frontend/utility.py b/deepspeech/frontend/utility.py
index b2dd9601f..3d0683b0a 100644
--- a/deepspeech/frontend/utility.py
+++ b/deepspeech/frontend/utility.py
@@ -15,6 +15,9 @@
 import codecs
 import json
 import math
+from typing import List
+from typing import Optional
+from typing import Text
 
 import numpy as np
 
@@ -23,16 +26,35 @@ from deepspeech.utils.log import Log
 logger = Log(__name__).getlog()
 
 __all__ = [
-    "load_cmvn", "read_manifest", "rms_to_db", "rms_to_dbfs", "max_dbfs",
-    "mean_dbfs", "gain_db_to_ratio", "normalize_audio", "SOS", "EOS", "UNK",
-    "BLANK"
+    "load_dict", "load_cmvn", "read_manifest", "rms_to_db", "rms_to_dbfs",
+    "max_dbfs", "mean_dbfs", "gain_db_to_ratio", "normalize_audio", "SOS",
+    "EOS", "UNK", "BLANK", "MASKCTC"
 ]
 
 IGNORE_ID = -1
-SOS = "<sos/eos>"
+# `sos` and `eos` using same token
+SOS = "<eos>"
 EOS = SOS
 UNK = "<unk>"
 BLANK = "<blank>"
+MASKCTC = "<mask>"
+
+
+def load_dict(dict_path: Optional[Text], maskctc=False) -> Optional[List[Text]]:
+    if dict_path is None:
+        return None
+
+    with open(dict_path, "r") as f:
+        dictionary = f.readlines()
+    char_list = [entry.split(" ")[0] for entry in dictionary]
+    if BLANK not in char_list:
+        char_list.insert(0, BLANK)
+    if EOS not in char_list:
+        char_list.append(EOS)
+    # for non-autoregressive maskctc model
+    if maskctc and MASKCTC not in char_list:
+        char_list.append(MASKCTC)
+    return char_list
 
 
 def read_manifest(
@@ -47,12 +69,20 @@ def read_manifest(
 
     Args:
         manifest_path ([type]): Manifest file to load and parse.
-        max_input_len ([type], optional): maximum output seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to float('inf').
-        min_input_len (float, optional): minimum input seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to 0.0.
-        max_output_len (float, optional): maximum input seq length, in modeling units. Defaults to 500.0.
-        min_output_len (float, optional): minimum input seq length, in modeling units. Defaults to 0.0.
-        max_output_input_ratio (float, optional): maximum output seq length/output seq length ratio. Defaults to 10.0.
-        min_output_input_ratio (float, optional): minimum output seq length/output seq length ratio. Defaults to 0.05.
+        max_input_len (float, optional): maximum input seq length,
+            in seconds for raw wav, in frame numbers for feature data.
+            Defaults to float('inf').
+        min_input_len (float, optional): minimum input seq length,
+            in seconds for raw wav, in frame numbers for feature data.
+            Defaults to 0.0.
+        max_output_len (float, optional): maximum output seq length,
+            in modeling units. Defaults to 500.0.
+        min_output_len (float, optional): minimum output seq length,
+            in modeling units. Defaults to 0.0.
+        max_output_input_ratio (float, optional):
+            maximum output seq length / input seq length ratio. Defaults to 10.0.
+        min_output_input_ratio (float, optional):
+            minimum output seq length / input seq length ratio. Defaults to 0.05.
 
     Raises:
         IOError: If failed to parse the manifest.
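
A minimal sketch of the new load_dict helper on a hypothetical ESPnet-style unit file (one "<token> <id>" pair per line; the file name is made up):

    from deepspeech.frontend.utility import load_dict

    # contents of data/train_units.txt (hypothetical):
    #   ▁the 1
    #   ▁a 2
    #   <unk> 3
    char_list = load_dict('data/train_units.txt', maskctc=True)
    # '<blank>' is inserted at index 0, '<eos>' appended if missing, and '<mask>'
    # appended for the non-autoregressive maskctc case:
    # ['<blank>', '▁the', '▁a', '<unk>', '<eos>', '<mask>']
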
diff --git a/deepspeech/io/__init__.py b/deepspeech/io/__init__.py
index e180f18ee..185a92b8d 100644
--- a/deepspeech/io/__init__.py
+++ b/deepspeech/io/__init__.py
@@ -11,139 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import numpy as np
-from paddle.io import DataLoader
-
-from deepspeech.io.collator import SpeechCollator
-from deepspeech.io.dataset import ManifestDataset
-from deepspeech.io.sampler import SortagradBatchSampler
-from deepspeech.io.sampler import SortagradDistributedBatchSampler
-
-
-def create_dataloader(manifest_path,
-                      unit_type,
-                      vocab_filepath,
-                      mean_std_filepath,
-                      spm_model_prefix,
-                      augmentation_config='{}',
-                      max_input_len=float('inf'),
-                      min_input_len=0.0,
-                      max_output_len=float('inf'),
-                      min_output_len=0.0,
-                      max_output_input_ratio=float('inf'),
-                      min_output_input_ratio=0.0,
-                      stride_ms=10.0,
-                      window_ms=20.0,
-                      max_freq=None,
-                      specgram_type='linear',
-                      feat_dim=None,
-                      delta_delta=False,
-                      use_dB_normalization=True,
-                      random_seed=0,
-                      keep_transcription_text=False,
-                      is_training=False,
-                      batch_size=1,
-                      num_workers=0,
-                      sortagrad=False,
-                      shuffle_method=None,
-                      dist=False):
-
-    dataset = ManifestDataset(
-        manifest_path=manifest_path,
-        unit_type=unit_type,
-        vocab_filepath=vocab_filepath,
-        mean_std_filepath=mean_std_filepath,
-        spm_model_prefix=spm_model_prefix,
-        augmentation_config=augmentation_config,
-        max_input_len=max_input_len,
-        min_input_len=min_input_len,
-        max_output_len=max_output_len,
-        min_output_len=min_output_len,
-        max_output_input_ratio=max_output_input_ratio,
-        min_output_input_ratio=min_output_input_ratio,
-        stride_ms=stride_ms,
-        window_ms=window_ms,
-        max_freq=max_freq,
-        specgram_type=specgram_type,
-        feat_dim=feat_dim,
-        delta_delta=delta_delta,
-        use_dB_normalization=use_dB_normalization,
-        random_seed=random_seed,
-        keep_transcription_text=keep_transcription_text)
-
-    if dist:
-        batch_sampler = SortagradDistributedBatchSampler(
-            dataset,
-            batch_size,
-            num_replicas=None,
-            rank=None,
-            shuffle=is_training,
-            drop_last=is_training,
-            sortagrad=is_training,
-            shuffle_method=shuffle_method)
-    else:
-        batch_sampler = SortagradBatchSampler(
-            dataset,
-            shuffle=is_training,
-            batch_size=batch_size,
-            drop_last=is_training,
-            sortagrad=is_training,
-            shuffle_method=shuffle_method)
-
-    def padding_batch(batch,
-                      padding_to=-1,
-                      flatten=False,
-                      keep_transcription_text=True):
-        """	
-        Padding audio features with zeros to make them have the same shape (or	
-        a user-defined shape) within one bach.	
-
-        If ``padding_to`` is -1, the maximun shape in the batch will be used	
-        as the target shape for padding. Otherwise, `padding_to` will be the	
-        target shape (only refers to the second axis).	
-
-        If `flatten` is True, features will be flatten to 1darray.	
-        """
-        new_batch = []
-        # get target shape	
-        max_length = max([audio.shape[1] for audio, text in batch])
-        if padding_to != -1:
-            if padding_to < max_length:
-                raise ValueError("If padding_to is not -1, it should be larger "
-                                 "than any instance's shape in the batch")
-            max_length = padding_to
-        max_text_length = max([len(text) for audio, text in batch])
-        # padding	
-        padded_audios = []
-        audio_lens = []
-        texts, text_lens = [], []
-        for audio, text in batch:
-            padded_audio = np.zeros([audio.shape[0], max_length])
-            padded_audio[:, :audio.shape[1]] = audio
-            if flatten:
-                padded_audio = padded_audio.flatten()
-            padded_audios.append(padded_audio)
-            audio_lens.append(audio.shape[1])
-
-            padded_text = np.zeros([max_text_length])
-            if keep_transcription_text:
-                padded_text[:len(text)] = [ord(t) for t in text]  # string
-            else:
-                padded_text[:len(text)] = text  # ids
-            texts.append(padded_text)
-            text_lens.append(len(text))
-
-        padded_audios = np.array(padded_audios).astype('float32')
-        audio_lens = np.array(audio_lens).astype('int64')
-        texts = np.array(texts).astype('int32')
-        text_lens = np.array(text_lens).astype('int64')
-        return padded_audios, audio_lens, texts, text_lens
-
-    # collate_fn=functools.partial(padding_batch, keep_transcription_text=keep_transcription_text),
-    collate_fn = SpeechCollator(keep_transcription_text=keep_transcription_text)
-    loader = DataLoader(
-        dataset,
-        batch_sampler=batch_sampler,
-        collate_fn=collate_fn,
-        num_workers=num_workers)
-    return loader
diff --git a/deepspeech/io/batchfy.py b/deepspeech/io/batchfy.py
new file mode 100644
index 000000000..de29d0546
--- /dev/null
+++ b/deepspeech/io/batchfy.py
@@ -0,0 +1,469 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import itertools
+
+import numpy as np
+
+from deepspeech.utils.log import Log
+
+__all__ = ["make_batchset"]
+
+logger = Log(__name__).getlog()
+
+
+def batchfy_by_seq(
+        sorted_data,
+        batch_size,
+        max_length_in,
+        max_length_out,
+        min_batch_size=1,
+        shortest_first=False,
+        ikey="input",
+        iaxis=0,
+        okey="output",
+        oaxis=0, ):
+    """Make batch set from json dictionary
+
+    :param List[(str, Dict[str, Any])] sorted_data: dictionary loaded from data.json
+    :param int batch_size: batch size
+    :param int max_length_in: maximum length of input to decide adaptive batch size
+    :param int max_length_out: maximum length of output to decide adaptive batch size
+    :param int min_batch_size: minimum batch size (for multi-gpu)
+    :param bool shortest_first: Sort from batch with shortest samples
+        to longest if true, otherwise reverse
+    :param str ikey: key to access input
+        (for ASR ikey="input", for TTS, MT ikey="output".)
+    :param int iaxis: dimension to access input
+        (for ASR, TTS iaxis=0, for MT iaxis=1.)
+    :param str okey: key to access output
+        (for ASR, MT okey="output". for TTS okey="input".)
+    :param int oaxis: dimension to access output
+        (for ASR, TTS, MT oaxis=0, reserved for future research, -1 means all axis.)
+    :return: List[List[Tuple[str, dict]]] list of batches
+    """
+    if batch_size <= 0:
+        raise ValueError(f"Invalid batch_size={batch_size}")
+
+    # check #utts is more than min_batch_size
+    if len(sorted_data) < min_batch_size:
+        raise ValueError(
+            f"#utts({len(sorted_data)}) is less than min_batch_size({min_batch_size})."
+        )
+
+    # make list of minibatches
+    minibatches = []
+    start = 0
+    while True:
+        _, info = sorted_data[start]
+        ilen = int(info[ikey][iaxis]["shape"][0])
+        olen = (int(info[okey][oaxis]["shape"][0]) if oaxis >= 0 else
+                max(map(lambda x: int(x["shape"][0]), info[okey])))
+        factor = max(int(ilen / max_length_in), int(olen / max_length_out))
+        # change batchsize depending on the input and output length
+        # if ilen = 1000 and max_length_in = 800
+        # then b = batchsize / 2
+        # and max(min_batches, .) avoids batchsize = 0
+        bs = max(min_batch_size, int(batch_size / (1 + factor)))
+        end = min(len(sorted_data), start + bs)
+        minibatch = sorted_data[start:end]
+        if shortest_first:
+            minibatch.reverse()
+
+        # check each batch is more than minimum batchsize
+        if len(minibatch) < min_batch_size:
+            mod = min_batch_size - len(minibatch) % min_batch_size
+            additional_minibatch = [
+                sorted_data[i] for i in np.random.randint(0, start, mod)
+            ]
+            if shortest_first:
+                additional_minibatch.reverse()
+            minibatch.extend(additional_minibatch)
+        minibatches.append(minibatch)
+
+        if end == len(sorted_data):
+            break
+        start = end
+
+    # batch: List[List[Tuple[str, dict]]]
+    return minibatches
+
+
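
To make the adaptive-batch-size comment in batchfy_by_seq concrete, a small worked example with illustrative numbers:

    # ilen=1000, olen=30, max_length_in=800, max_length_out=150, batch_size=32
    factor = max(int(1000 / 800), int(30 / 150))  # max(1, 0) = 1
    bs = max(1, int(32 / (1 + factor)))           # 16: long utterances get half-sized batches
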
+def batchfy_by_bin(
+        sorted_data,
+        batch_bins,
+        num_batches=0,
+        min_batch_size=1,
+        shortest_first=False,
+        ikey="input",
+        okey="output", ):
+    """Make variably sized batch set, which maximizes
+
+    the number of bins up to `batch_bins`.
+
+    :param List[(str, Dict[str, Any])] sorted_data: dictionary loaded from data.json
+    :param int batch_bins: Maximum frames of a batch
+    :param int num_batches: # number of batches to use (for debug)
+    :param int min_batch_size: minimum batch size (for multi-gpu)
+    :param int test: Return only every `test` batches
+    :param bool shortest_first: Sort from batch with shortest samples
+        to longest if true, otherwise reverse
+
+    :param str ikey: key to access input (for ASR ikey="input", for TTS ikey="output".)
+    :param str okey: key to access output (for ASR okey="output". for TTS okey="input".)
+
+    :return: List[Tuple[str, Dict[str, List[Dict[str, Any]]]] list of batches
+    """
+    if batch_bins <= 0:
+        raise ValueError(f"invalid batch_bins={batch_bins}")
+    length = len(sorted_data)
+    idim = int(sorted_data[0][1][ikey][0]["shape"][1])
+    odim = int(sorted_data[0][1][okey][0]["shape"][1])
+    logger.info("# utts: " + str(len(sorted_data)))
+    minibatches = []
+    start = 0
+    n = 0
+    while True:
+        # Dynamic batch size depending on size of samples
+        b = 0
+        next_size = 0
+        max_olen = 0
+        while next_size < batch_bins and (start + b) < length:
+            ilen = int(sorted_data[start + b][1][ikey][0]["shape"][0]) * idim
+            olen = int(sorted_data[start + b][1][okey][0]["shape"][0]) * odim
+            if olen > max_olen:
+                max_olen = olen
+            next_size = (max_olen + ilen) * (b + 1)
+            if next_size <= batch_bins:
+                b += 1
+            elif next_size == 0:
+                raise ValueError(
+                    f"Can't fit one sample in batch_bins ({batch_bins}): "
+                    f"Please increase the value")
+        end = min(length, start + max(min_batch_size, b))
+        batch = sorted_data[start:end]
+        if shortest_first:
+            batch.reverse()
+        minibatches.append(batch)
+        # Check for min_batch_size and fixes the batches if needed
+        i = -1
+        while len(minibatches[i]) < min_batch_size:
+            missing = min_batch_size - len(minibatches[i])
+            if -i == len(minibatches):
+                minibatches[i + 1].extend(minibatches[i])
+                minibatches = minibatches[1:]
+                break
+            else:
+                minibatches[i].extend(minibatches[i - 1][:missing])
+                minibatches[i - 1] = minibatches[i - 1][missing:]
+                i -= 1
+        if end == length:
+            break
+        start = end
+        n += 1
+    if num_batches > 0:
+        minibatches = minibatches[:num_batches]
+    lengths = [len(x) for x in minibatches]
+    logger.info(
+        str(len(minibatches)) + " batches containing from " + str(min(lengths))
+        + " to " + str(max(lengths)) + " samples " + "(avg " + str(
+            int(np.mean(lengths))) + " samples).")
+    return minibatches
+
+
+def batchfy_by_frame(
+        sorted_data,
+        max_frames_in,
+        max_frames_out,
+        max_frames_inout,
+        num_batches=0,
+        min_batch_size=1,
+        shortest_first=False,
+        ikey="input",
+        okey="output", ):
+    """Make variable batch set, which maximizes the number of frames to max_batch_frame.
+
+    :param List[(str, Dict[str, Any])] sorted_data: dictionary loaded from data.json
+    :param int max_frames_in: Maximum input frames of a batch
+    :param int max_frames_out: Maximum output frames of a batch
+    :param int max_frames_inout: Maximum input+output frames of a batch
+    :param int num_batches: # number of batches to use (for debug)
+    :param int min_batch_size: minimum batch size (for multi-gpu)
+    :param int test: Return only every `test` batches
+    :param bool shortest_first: Sort from batch with shortest samples
+        to longest if true, otherwise reverse
+
+    :param str ikey: key to access input (for ASR ikey="input", for TTS ikey="output".)
+    :param str okey: key to access output (for ASR okey="output". for TTS okey="input".)
+
+    :return: List[Tuple[str, Dict[str, List[Dict[str, Any]]]] list of batches
+    """
+    if max_frames_in <= 0 and max_frames_out <= 0 and max_frames_inout <= 0:
+        raise ValueError(
+            "At least, one of `--batch-frames-in`, `--batch-frames-out` or "
+            "`--batch-frames-inout` should be > 0")
+    length = len(sorted_data)
+    minibatches = []
+    start = 0
+    end = 0
+    while end != length:
+        # Dynamic batch size depending on size of samples
+        b = 0
+        max_olen = 0
+        max_ilen = 0
+        while (start + b) < length:
+            ilen = int(sorted_data[start + b][1][ikey][0]["shape"][0])
+            if ilen > max_frames_in and max_frames_in != 0:
+                raise ValueError(
+                    f"Can't fit one sample in --batch-frames-in ({max_frames_in}): "
+                    f"Please increase the value")
+            olen = int(sorted_data[start + b][1][okey][0]["shape"][0])
+            if olen > max_frames_out and max_frames_out != 0:
+                raise ValueError(
+                    f"Can't fit one sample in --batch-frames-out ({max_frames_out}): "
+                    f"Please increase the value")
+            if ilen + olen > max_frames_inout and max_frames_inout != 0:
+                raise ValueError(
+                    f"Can't fit one sample in --batch-frames-out ({max_frames_inout}): "
+                    f"Please increase the value")
+            max_olen = max(max_olen, olen)
+            max_ilen = max(max_ilen, ilen)
+            in_ok = max_ilen * (b + 1) <= max_frames_in or max_frames_in == 0
+            out_ok = max_olen * (b + 1) <= max_frames_out or max_frames_out == 0
+            inout_ok = (max_ilen + max_olen) * (
+                b + 1) <= max_frames_inout or max_frames_inout == 0
+            if in_ok and out_ok and inout_ok:
+                # add more seq in the minibatch
+                b += 1
+            else:
+                # no more seq in the minibatch
+                break
+        end = min(length, start + b)
+        batch = sorted_data[start:end]
+        if shortest_first:
+            batch.reverse()
+        minibatches.append(batch)
+        # Check for min_batch_size and fix the batches if needed
+        i = -1
+        while len(minibatches[i]) < min_batch_size:
+            missing = min_batch_size - len(minibatches[i])
+            if -i == len(minibatches):
+                minibatches[i + 1].extend(minibatches[i])
+                minibatches = minibatches[1:]
+                break
+            else:
+                minibatches[i].extend(minibatches[i - 1][:missing])
+                minibatches[i - 1] = minibatches[i - 1][missing:]
+                i -= 1
+        start = end
+    if num_batches > 0:
+        minibatches = minibatches[:num_batches]
+    lengths = [len(x) for x in minibatches]
+    logger.info(
+        str(len(minibatches)) + " batches containing from " + str(min(lengths))
+        + " to " + str(max(lengths)) + " samples" + "(avg " + str(
+            int(np.mean(lengths))) + " samples).")
+
+    return minibatches
+
+
+def batchfy_shuffle(data, batch_size, min_batch_size, num_batches,
+                    shortest_first):
+    import random
+
+    logger.info("use shuffled batch.")
+    sorted_data = random.sample(data.items(), len(data.items()))
+    logger.info("# utts: " + str(len(sorted_data)))
+    # make list of minibatches
+    minibatches = []
+    start = 0
+    while True:
+        end = min(len(sorted_data), start + batch_size)
+        # check each batch is more than minimum batchsize
+        minibatch = sorted_data[start:end]
+        if shortest_first:
+            minibatch.reverse()
+        if len(minibatch) < min_batch_size:
+            mod = min_batch_size - len(minibatch) % min_batch_size
+            additional_minibatch = [
+                sorted_data[i] for i in np.random.randint(0, start, mod)
+            ]
+            if shortest_first:
+                additional_minibatch.reverse()
+            minibatch.extend(additional_minibatch)
+        minibatches.append(minibatch)
+        if end == len(sorted_data):
+            break
+        start = end
+
+    # for debugging
+    if num_batches > 0:
+        minibatches = minibatches[:num_batches]
+        logger.info("# minibatches: " + str(len(minibatches)))
+    return minibatches
+
+
+BATCH_COUNT_CHOICES = ["auto", "seq", "bin", "frame"]
+BATCH_SORT_KEY_CHOICES = ["input", "output", "shuffle"]
+
+
+def make_batchset(
+        data,
+        batch_size=0,
+        max_length_in=float("inf"),
+        max_length_out=float("inf"),
+        num_batches=0,
+        min_batch_size=1,
+        shortest_first=False,
+        batch_sort_key="input",
+        count="auto",
+        batch_bins=0,
+        batch_frames_in=0,
+        batch_frames_out=0,
+        batch_frames_inout=0,
+        iaxis=0,
+        oaxis=0, ):
+    """Make batch set from json dictionary
+
+    if utts have "category" value,
+
+        >>> data = [{'category': 'A', 'input': ..., 'utt':'utt1'},
+        ...         {'category': 'B', 'input': ..., 'utt':'utt2'},
+        ...         {'category': 'B', 'input': ..., 'utt':'utt3'},
+        ...         {'category': 'A', 'input': ..., 'utt':'utt4'}]
+        >>> make_batchset(data, batchsize=2, ...)
+        [[('utt1', ...), ('utt4', ...)], [('utt2', ...), ('utt3': ...)]]
+
+    Note that if any utts doesn't have "category",
+    perform as same as batchfy_by_{count}
+
+    :param List[Dict[str, Any]] data: dictionary loaded from data.json
+    :param int batch_size: maximum number of sequences in a minibatch.
+    :param int batch_bins: maximum number of bins (frames x dim) in a minibatch.
+    :param int batch_frames_in:  maximum number of input frames in a minibatch.
+    :param int batch_frames_out: maximum number of output frames in a minibatch.
+    :param int batch_frames_inout: maximum number of input+output frames in a minibatch.
+    :param str count: strategy to count maximum size of batch.
+        For choices, see espnet.asr.batchfy.BATCH_COUNT_CHOICES
+
+    :param int max_length_in: maximum length of input to decide adaptive batch size
+    :param int max_length_out: maximum length of output to decide adaptive batch size
+    :param int num_batches: number of batches to use (for debug)
+    :param int min_batch_size: minimum batch size (for multi-gpu)
+    :param bool shortest_first: Sort from batch with shortest samples
+        to longest if true, otherwise reverse
+    :param str batch_sort_key: how to sort data before creating minibatches
+        ["input", "output", "shuffle"]
+    :param int iaxis: dimension to access input
+        (for ASR, TTS iaxis=0, for MT iaxis=1.)
+    :param int oaxis: dimension to access output (for ASR, TTS, MT oaxis=0,
+        reserved for future research, -1 means all axis.)
+    :return: List[List[Tuple[str, dict]]] list of batches
+    """
+    # check args
+    if count not in BATCH_COUNT_CHOICES:
+        raise ValueError(
+            f"arg 'count' ({count}) should be one of {BATCH_COUNT_CHOICES}")
+    if batch_sort_key not in BATCH_SORT_KEY_CHOICES:
+        raise ValueError(f"arg 'batch_sort_key' ({batch_sort_key}) should be "
+                         f"one of {BATCH_SORT_KEY_CHOICES}")
+
+    ikey = "input"
+    okey = "output"
+    batch_sort_axis = 0  # index of list 
+    if count == "auto":
+        if batch_size != 0:
+            count = "seq"
+        elif batch_bins != 0:
+            count = "bin"
+        elif batch_frames_in != 0 or batch_frames_out != 0 or batch_frames_inout != 0:
+            count = "frame"
+        else:
+            raise ValueError(
+                f"cannot detect `count` manually set one of {BATCH_COUNT_CHOICES}"
+            )
+        logger.info(f"count is auto detected as {count}")
+
+    if count != "seq" and batch_sort_key == "shuffle":
+        raise ValueError(
+            "batch_sort_key=shuffle is only available if batch_count=seq")
+
+    category2data = {}  # Dict[str, dict]
+    for v in data:
+        k = v['utt']
+        category2data.setdefault(v.get("category"), {})[k] = v
+
+    batches_list = []  # List[List[List[Tuple[str, dict]]]]
+    for d in category2data.values():
+        if batch_sort_key == "shuffle":
+            batches = batchfy_shuffle(d, batch_size, min_batch_size,
+                                      num_batches, shortest_first)
+            batches_list.append(batches)
+            continue
+
+        # sort it by input lengths (long to short)
+        sorted_data = sorted(
+            d.items(),
+            key=lambda data: int(data[1][batch_sort_key][batch_sort_axis]["shape"][0]),
+            reverse=not shortest_first, )
+        logger.info("# utts: " + str(len(sorted_data)))
+
+        if count == "seq":
+            batches = batchfy_by_seq(
+                sorted_data,
+                batch_size=batch_size,
+                max_length_in=max_length_in,
+                max_length_out=max_length_out,
+                min_batch_size=min_batch_size,
+                shortest_first=shortest_first,
+                ikey=ikey,
+                iaxis=iaxis,
+                okey=okey,
+                oaxis=oaxis, )
+        if count == "bin":
+            batches = batchfy_by_bin(
+                sorted_data,
+                batch_bins=batch_bins,
+                min_batch_size=min_batch_size,
+                shortest_first=shortest_first,
+                ikey=ikey,
+                okey=okey, )
+        if count == "frame":
+            batches = batchfy_by_frame(
+                sorted_data,
+                max_frames_in=batch_frames_in,
+                max_frames_out=batch_frames_out,
+                max_frames_inout=batch_frames_inout,
+                min_batch_size=min_batch_size,
+                shortest_first=shortest_first,
+                ikey=ikey,
+                okey=okey, )
+        batches_list.append(batches)
+
+    if len(batches_list) == 1:
+        batches = batches_list[0]
+    else:
+        # Concat list. This way is faster than "sum(batch_list, [])"
+        batches = list(itertools.chain(*batches_list))
+
+    # for debugging
+    if num_batches > 0:
+        batches = batches[:num_batches]
+    logger.info("# minibatches: " + str(len(batches)))
+
+    # batch: List[List[Tuple[str, dict]]]
+    return batches
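
A minimal sketch of how the helpers above could be exercised; the toy manifest entries below are invented and only mimic the data.json fields (shape = [frames, dim]) referenced in the docstrings:

    from deepspeech.io.batchfy import make_batchset

    # eight fake utterances of increasing length; shapes are [frames, dim]
    data = [{"utt": "utt%d" % i,
             "input": [{"name": "input1", "shape": [100 + 10 * i, 83]}],
             "output": [{"name": "target1", "shape": [20 + i, 5002]}]}
            for i in range(8)]

    # count="bin" caps each minibatch at roughly batch_bins (frames x dim)
    batches = make_batchset(data, batch_bins=2_000_000, count="bin")
    print([[utt for utt, _ in batch] for batch in batches])
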
diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py
index 2ef119666..df3004790 100644
--- a/deepspeech/io/collator.py
+++ b/deepspeech/io/collator.py
@@ -23,7 +23,7 @@ from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer
 from deepspeech.frontend.normalizer import FeatureNormalizer
 from deepspeech.frontend.speech import SpeechSegment
 from deepspeech.frontend.utility import IGNORE_ID
-from deepspeech.io.utility import pad_sequence
+from deepspeech.io.utility import pad_list
 from deepspeech.utils.log import Log
 
 __all__ = ["SpeechCollator"]
@@ -242,7 +242,6 @@ class SpeechCollator():
 
         # specgram augment
         specgram = self._augmentation_pipeline.transform_feature(specgram)
-        specgram = specgram.transpose([1, 0])
         return specgram, transcript_part
 
     def __call__(self, batch):
@@ -250,7 +249,7 @@ class SpeechCollator():
 
         Args:
             batch ([List]): batch is (audio, text)
-                audio (np.ndarray) shape (D, T)
+                audio (np.ndarray) shape (T, D)
                 text (List[int] or str): shape (U,)
 
         Returns:
@@ -286,13 +285,12 @@ class SpeechCollator():
             texts.append(tokens)
             text_lens.append(tokens.shape[0])
 
-        padded_audios = pad_sequence(
-            audios, padding_value=0.0).astype(np.float32)  #[B, T, D]
-        audio_lens = np.array(audio_lens).astype(np.int64)
-        padded_texts = pad_sequence(
-            texts, padding_value=IGNORE_ID).astype(np.int64)
-        text_lens = np.array(text_lens).astype(np.int64)
-        return utts, padded_audios, audio_lens, padded_texts, text_lens
+        #[B, T, D]
+        xs_pad = pad_list(audios, 0.0).astype(np.float32)
+        ilens = np.array(audio_lens).astype(np.int64)
+        ys_pad = pad_list(texts, IGNORE_ID).astype(np.int64)
+        olens = np.array(text_lens).astype(np.int64)
+        return utts, xs_pad, ilens, ys_pad, olens
 
     @property
     def manifest(self):
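
With the transpose removed, features flow through the collator as (T, D) arrays and are padded along the time axis. A rough sketch of the resulting shapes, using random stand-in arrays rather than real spectrograms:

    import numpy as np
    from deepspeech.io.utility import pad_list

    # two utterances with different frame counts, feature dim 80
    audios = [np.random.randn(120, 80), np.random.randn(95, 80)]
    xs_pad = pad_list(audios, 0.0).astype(np.float32)
    ilens = np.array([a.shape[0] for a in audios], dtype=np.int64)
    print(xs_pad.shape, ilens)   # (2, 120, 80) [120  95]
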
diff --git a/deepspeech/io/collator_st.py b/deepspeech/io/collator_st.py
index 1ee361900..28573366b 100644
--- a/deepspeech/io/collator_st.py
+++ b/deepspeech/io/collator_st.py
@@ -217,6 +217,34 @@ class SpeechCollator():
         return self._local_data.tar2object[tarpath].extractfile(
             self._local_data.tar2info[tarpath][filename])
 
+    @property
+    def manifest(self):
+        return self._manifest
+
+    @property
+    def vocab_size(self):
+        return self._speech_featurizer.vocab_size
+
+    @property
+    def vocab_list(self):
+        return self._speech_featurizer.vocab_list
+
+    @property
+    def vocab_dict(self):
+        return self._speech_featurizer.vocab_dict
+
+    @property
+    def text_feature(self):
+        return self._speech_featurizer.text_feature
+
+    @property
+    def feature_size(self):
+        return self._speech_featurizer.feature_size
+
+    @property
+    def stride_ms(self):
+        return self._speech_featurizer.stride_ms
+
     def process_utterance(self, audio_file, translation):
         """Load, augment, featurize and normalize for speech data.
 
@@ -244,7 +272,6 @@ class SpeechCollator():
 
         # specgram augment
         specgram = self._augmentation_pipeline.transform_feature(specgram)
-        specgram = specgram.transpose([1, 0])
         return specgram, translation_part
 
     def __call__(self, batch):
@@ -252,7 +279,7 @@ class SpeechCollator():
 
         Args:
             batch ([List]): batch is (audio, text)
-                audio (np.ndarray) shape (D, T)
+                audio (np.ndarray) shape (T, D)
                 text (List[int] or str): shape (U,)
 
         Returns:
@@ -296,34 +323,6 @@ class SpeechCollator():
         text_lens = np.array(text_lens).astype(np.int64)
         return utts, padded_audios, audio_lens, padded_texts, text_lens
 
-    @property
-    def manifest(self):
-        return self._manifest
-
-    @property
-    def vocab_size(self):
-        return self._speech_featurizer.vocab_size
-
-    @property
-    def vocab_list(self):
-        return self._speech_featurizer.vocab_list
-
-    @property
-    def vocab_dict(self):
-        return self._speech_featurizer.vocab_dict
-
-    @property
-    def text_feature(self):
-        return self._speech_featurizer.text_feature
-
-    @property
-    def feature_size(self):
-        return self._speech_featurizer.feature_size
-
-    @property
-    def stride_ms(self):
-        return self._speech_featurizer.stride_ms
-
 
 class TripletSpeechCollator(SpeechCollator):
     def process_utterance(self, audio_file, translation, transcript):
@@ -355,7 +354,6 @@ class TripletSpeechCollator(SpeechCollator):
 
         # specgram augment
         specgram = self._augmentation_pipeline.transform_feature(specgram)
-        specgram = specgram.transpose([1, 0])
         return specgram, translation_part, transcript_part
 
     def __call__(self, batch):
@@ -363,7 +361,7 @@ class TripletSpeechCollator(SpeechCollator):
 
         Args:
             batch ([List]): batch is (audio, text)
-                audio (np.ndarray) shape (D, T)
+                audio (np.ndarray) shape (T, D)
                 text (List[int] or str): shape (U,)
 
         Returns:
@@ -524,49 +522,19 @@ class KaldiPrePorocessedCollator(SpeechCollator):
         :rtype: tuple of (2darray, list)
         """
         specgram = kaldiio.load_mat(audio_file)
-        specgram = specgram.transpose([1, 0])
         assert specgram.shape[
-            0] == self._feat_dim, 'expect feat dim {}, but got {}'.format(
-                self._feat_dim, specgram.shape[0])
+            1] == self._feat_dim, 'expect feat dim {}, but got {}'.format(
+                self._feat_dim, specgram.shape[1])
 
         # specgram augment
         specgram = self._augmentation_pipeline.transform_feature(specgram)
 
-        specgram = specgram.transpose([1, 0])
         if self._keep_transcription_text:
             return specgram, translation
         else:
             text_ids = self._text_featurizer.featurize(translation)
             return specgram, text_ids
 
-    @property
-    def manifest(self):
-        return self._manifest
-
-    @property
-    def vocab_size(self):
-        return self._text_featurizer.vocab_size
-
-    @property
-    def vocab_list(self):
-        return self._text_featurizer.vocab_list
-
-    @property
-    def vocab_dict(self):
-        return self._text_featurizer.vocab_dict
-
-    @property
-    def text_feature(self):
-        return self._text_featurizer
-
-    @property
-    def feature_size(self):
-        return self._feat_dim
-
-    @property
-    def stride_ms(self):
-        return self._stride_ms
-
 
 class TripletKaldiPrePorocessedCollator(KaldiPrePorocessedCollator):
     def process_utterance(self, audio_file, translation, transcript):
@@ -583,15 +551,13 @@ class TripletKaldiPrePorocessedCollator(KaldiPrePorocessedCollator):
         :rtype: tuple of (2darray, (list, list))
         """
         specgram = kaldiio.load_mat(audio_file)
-        specgram = specgram.transpose([1, 0])
         assert specgram.shape[
-            0] == self._feat_dim, 'expect feat dim {}, but got {}'.format(
-                self._feat_dim, specgram.shape[0])
+            1] == self._feat_dim, 'expect feat dim {}, but got {}'.format(
+                self._feat_dim, specgram.shape[1])
 
         # specgram augment
         specgram = self._augmentation_pipeline.transform_feature(specgram)
 
-        specgram = specgram.transpose([1, 0])
         if self._keep_transcription_text:
             return specgram, translation, transcript
         else:
@@ -604,7 +570,7 @@ class TripletKaldiPrePorocessedCollator(KaldiPrePorocessedCollator):
 
         Args:
             batch ([List]): batch is (audio, text)
-                audio (np.ndarray) shape (D, T)
+                audio (np.ndarray) shape (T, D)
                 translation (List[int] or str): shape (U,)
                 transcription (List[int] or str): shape (V,)
 
diff --git a/deepspeech/io/converter.py b/deepspeech/io/converter.py
new file mode 100644
index 000000000..3bfcc1b1e
--- /dev/null
+++ b/deepspeech/io/converter.py
@@ -0,0 +1,81 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+
+from deepspeech.io.utility import pad_list
+from deepspeech.utils.log import Log
+
+__all__ = ["CustomConverter"]
+
+logger = Log(__name__).getlog()
+
+
+class CustomConverter():
+    """Custom batch converter.
+
+    Args:
+        subsampling_factor (int): The subsampling factor.
+        dtype (np.dtype): Data type to convert.
+
+    """
+
+    def __init__(self, subsampling_factor=1, dtype=np.float32):
+        """Construct a CustomConverter object."""
+        self.subsampling_factor = subsampling_factor
+        self.ignore_id = -1
+        self.dtype = dtype
+
+    def __call__(self, batch):
+        """Transform a batch and send it to a device.
+
+        Args:
+            batch (list): The batch to transform.
+
+        Returns:
+            tuple(paddle.Tensor, paddle.Tensor, paddle.Tensor)
+
+        """
+        # batch is a one-element list wrapping ((xs, ys), utts)
+        assert len(batch) == 1
+        (xs, ys), utts = batch[0]
+        assert xs[0] is not None, "please check Reader and Augmentation impl."
+
+        # perform subsampling
+        if self.subsampling_factor > 1:
+            xs = [x[::self.subsampling_factor, :] for x in xs]
+
+        # get batch of lengths of input sequences
+        ilens = np.array([x.shape[0] for x in xs])
+
+        # perform padding and convert to tensor
+        # currently only support real number
+        if xs[0].dtype.kind == "c":
+            xs_pad_real = pad_list([x.real for x in xs], 0).astype(self.dtype)
+            xs_pad_imag = pad_list([x.imag for x in xs], 0).astype(self.dtype)
+            # Note(kamo):
+            # {'real': ..., 'imag': ...} will be changed to ComplexTensor in E2E.
+            # Don't create a ComplexTensor and give it to E2E here,
+            # because torch.nn.DataParallel can't handle it.
+            xs_pad = {"real": xs_pad_real, "imag": xs_pad_imag}
+        else:
+            xs_pad = pad_list(xs, 0).astype(self.dtype)
+
+        # NOTE: this is for multi-output (e.g., speech translation)
+        ys_pad = pad_list(
+            [np.array(y[0][:]) if isinstance(y, tuple) else y for y in ys],
+            self.ignore_id)
+
+        olens = np.array(
+            [y[0].shape[0] if isinstance(y, tuple) else y.shape[0] for y in ys])
+        return utts, xs_pad, ilens, ys_pad, olens
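
A small sketch of the batch layout CustomConverter expects, i.e. a single ((xs, ys), utts) tuple wrapped in a list; the features and token ids below are made up:

    import numpy as np
    from deepspeech.io.converter import CustomConverter

    converter = CustomConverter(subsampling_factor=1, dtype=np.float32)

    xs = [np.random.randn(50, 83), np.random.randn(30, 83)]   # (T, D) features
    ys = [np.array([1, 2, 3]), np.array([4, 5])]               # token ids
    utts = ["utt1", "utt2"]

    utt_ids, xs_pad, ilens, ys_pad, olens = converter([((xs, ys), utts)])
    print(xs_pad.shape, ilens, ys_pad.shape, olens)
    # (2, 50, 83) [50 30] (2, 3) [3 2]
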
diff --git a/deepspeech/io/dataloader.py b/deepspeech/io/dataloader.py
new file mode 100644
index 000000000..115fe4617
--- /dev/null
+++ b/deepspeech/io/dataloader.py
@@ -0,0 +1,158 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Text
+
+import numpy as np
+from paddle.io import DataLoader
+
+from deepspeech.frontend.utility import read_manifest
+from deepspeech.io.batchfy import make_batchset
+from deepspeech.io.converter import CustomConverter
+from deepspeech.io.dataset import TransformDataset
+from deepspeech.io.reader import LoadInputsAndTargets
+from deepspeech.utils.log import Log
+
+__all__ = ["BatchDataLoader"]
+
+logger = Log(__name__).getlog()
+
+
+def feat_dim_and_vocab_size(data_json: List[Dict[Text, Any]],
+                            mode: Text="asr",
+                            iaxis=0,
+                            oaxis=0):
+    if mode == 'asr':
+        feat_dim = data_json[0]['input'][iaxis]['shape'][1]
+        vocab_size = data_json[0]['output'][oaxis]['shape'][1]
+    else:
+        raise ValueError(f"{mode} mode not support!")
+    return feat_dim, vocab_size
+
+
+class BatchDataLoader():
+    def __init__(self,
+                 json_file: str,
+                 train_mode: bool,
+                 sortagrad: bool=False,
+                 batch_size: int=0,
+                 maxlen_in: float=float('inf'),
+                 maxlen_out: float=float('inf'),
+                 minibatches: int=0,
+                 mini_batch_size: int=1,
+                 batch_count: str='auto',
+                 batch_bins: int=0,
+                 batch_frames_in: int=0,
+                 batch_frames_out: int=0,
+                 batch_frames_inout: int=0,
+                 preprocess_conf=None,
+                 n_iter_processes: int=1,
+                 subsampling_factor: int=1,
+                 num_encs: int=1):
+        self.json_file = json_file
+        self.train_mode = train_mode
+        self.use_sortagrad = sortagrad == -1 or sortagrad > 0
+        self.batch_size = batch_size
+        self.maxlen_in = maxlen_in
+        self.maxlen_out = maxlen_out
+        self.batch_count = batch_count
+        self.batch_bins = batch_bins
+        self.batch_frames_in = batch_frames_in
+        self.batch_frames_out = batch_frames_out
+        self.batch_frames_inout = batch_frames_inout
+        self.subsampling_factor = subsampling_factor
+        self.num_encs = num_encs
+        self.preprocess_conf = preprocess_conf
+        self.n_iter_processes = n_iter_processes
+
+        # read json data
+        self.data_json = read_manifest(json_file)
+        self.feat_dim, self.vocab_size = feat_dim_and_vocab_size(
+            self.data_json, mode='asr')
+
+        # make minibatch list (variable length)
+        self.minibatches = make_batchset(
+            self.data_json,
+            batch_size,
+            maxlen_in,
+            maxlen_out,
+            minibatches,  # for debug
+            min_batch_size=mini_batch_size,
+            shortest_first=self.use_sortagrad,
+            count=batch_count,
+            batch_bins=batch_bins,
+            batch_frames_in=batch_frames_in,
+            batch_frames_out=batch_frames_out,
+            batch_frames_inout=batch_frames_inout,
+            iaxis=0,
+            oaxis=0, )
+
+        # data reader
+        self.reader = LoadInputsAndTargets(
+            mode="asr",
+            load_output=True,
+            preprocess_conf=preprocess_conf,
+            preprocess_args={"train":
+                             train_mode},  # Switch the mode of preprocessing
+        )
+
+        # Setup a converter
+        if num_encs == 1:
+            self.converter = CustomConverter(
+                subsampling_factor=subsampling_factor, dtype=np.float32)
+        else:
+            raise NotImplementedError("CustomConverterMulEnc is not implemented.")
+
+        # hack to keep the DataLoader batch_size argument at 1:
+        # the actual batch size is determined by the pre-built minibatch list,
+        # and the default collate function would convert numpy arrays to tensors,
+        # so a pass-through collate function returns each item as-is
+        self.dataset = TransformDataset(
+            self.minibatches,
+            lambda data: self.converter([self.reader(data, return_uttid=True)]))
+        self.dataloader = DataLoader(
+            dataset=self.dataset,
+            batch_size=1,
+            shuffle=not self.use_sortagrad if train_mode else False,
+            collate_fn=lambda x: x[0],
+            num_workers=n_iter_processes, )
+
+    def __repr__(self):
+        echo = f"<{self.__class__.__module__}.{self.__class__.__name__} object at {hex(id(self))}> "
+        echo += f"train_mode: {self.train_mode}, "
+        echo += f"sortagrad: {self.use_sortagrad}, "
+        echo += f"batch_size: {self.batch_size}, "
+        echo += f"maxlen_in: {self.maxlen_in}, "
+        echo += f"maxlen_out: {self.maxlen_out}, "
+        echo += f"batch_count: {self.batch_count}, "
+        echo += f"batch_bins: {self.batch_bins}, "
+        echo += f"batch_frames_in: {self.batch_frames_in}, "
+        echo += f"batch_frames_out: {self.batch_frames_out}, "
+        echo += f"batch_frames_inout: {self.batch_frames_inout}, "
+        echo += f"subsampling_factor: {self.subsampling_factor}, "
+        echo += f"num_encs: {self.num_encs}, "
+        echo += f"num_workers: {self.n_iter_processes}, "
+        echo += f"file: {self.json_file}"
+        return echo
+
+    def __len__(self):
+        return len(self.dataloader)
+
+    def __iter__(self):
+        return self.dataloader.__iter__()
+
+    def __call__(self):
+        return self.__iter__()
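
A hypothetical wiring of the new loader; the manifest path, bin budget and worker count below are placeholders rather than values taken from this patch:

    from deepspeech.io.dataloader import BatchDataLoader

    loader = BatchDataLoader(
        json_file="data/manifest.train",  # placeholder path
        train_mode=True,
        sortagrad=False,
        batch_bins=4_000_000,             # batch_count='auto' resolves to "bin"
        mini_batch_size=1,
        n_iter_processes=2)

    print(loader)  # __repr__ summarises the batching configuration
    for utts, xs_pad, ilens, ys_pad, olens in loader:
        print(len(utts), xs_pad.shape, ys_pad.shape)
        break
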
diff --git a/deepspeech/io/dataset.py b/deepspeech/io/dataset.py
index ac7be1f9e..74c08b461 100644
--- a/deepspeech/io/dataset.py
+++ b/deepspeech/io/dataset.py
@@ -19,7 +19,7 @@ from yacs.config import CfgNode
 from deepspeech.frontend.utility import read_manifest
 from deepspeech.utils.log import Log
 
-__all__ = ["ManifestDataset", "TripletManifestDataset"]
+__all__ = ["ManifestDataset", "TripletManifestDataset", "TransformDataset"]
 
 logger = Log(__name__).getlog()
 
@@ -76,12 +76,18 @@ class ManifestDataset(Dataset):
 
         Args:
             manifest_path (str): manifest josn file path
-            max_input_len ([type], optional): maximum output seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to float('inf').
-            min_input_len (float, optional): minimum input seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to 0.0.
-            max_output_len (float, optional): maximum input seq length, in modeling units. Defaults to 500.0.
-            min_output_len (float, optional): minimum input seq length, in modeling units. Defaults to 0.0.
-            max_output_input_ratio (float, optional): maximum output seq length/output seq length ratio. Defaults to 10.0.
-            min_output_input_ratio (float, optional): minimum output seq length/output seq length ratio. Defaults to 0.05.
+            max_input_len ([type], optional): maximum input seq length,
+                in seconds for raw wav, in frame numbers for feature data. Defaults to float('inf').
+            min_input_len (float, optional): minimum input seq length,
+                in seconds for raw wav, in frame numbers for feature data. Defaults to 0.0.
+            max_output_len (float, optional): maximum output seq length,
+                in modeling units. Defaults to 500.0.
+            min_output_len (float, optional): minimum output seq length,
+                in modeling units. Defaults to 0.0.
+            max_output_input_ratio (float, optional): maximum output seq length / input seq length ratio.
+                Defaults to 10.0.
+            min_output_input_ratio (float, optional): minimum output seq length / input seq length ratio.
+                Defaults to 0.05.
         
         """
         super().__init__()
@@ -116,3 +122,27 @@ class TripletManifestDataset(ManifestDataset):
         instance = self._manifest[idx]
         return instance["utt"], instance["feat"], instance["text"], instance[
             "text1"]
+
+
+class TransformDataset(Dataset):
+    """Transform Dataset.
+
+    Args:
+        data: list object from make_batchset
+        transform: transform function applied to each item
+
+    """
+
+    def __init__(self, data, transform):
+        """Init function."""
+        super().__init__()
+        self.data = data
+        self.transform = transform
+
+    def __len__(self):
+        """Len function."""
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        """[] operator."""
+        return self.transform(self.data[idx])
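
TransformDataset simply defers to the given callable per item; in the dataloader above each item is one pre-built minibatch. A toy illustration with an identity transform:

    from deepspeech.io.dataset import TransformDataset

    minibatches = [[("utt1", {}), ("utt2", {})], [("utt3", {})]]
    dataset = TransformDataset(minibatches, lambda batch: batch)
    print(len(dataset), dataset[0])   # 2 [('utt1', {}), ('utt2', {})]
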
diff --git a/deepspeech/io/reader.py b/deepspeech/io/reader.py
new file mode 100644
index 000000000..95cdbb951
--- /dev/null
+++ b/deepspeech/io/reader.py
@@ -0,0 +1,410 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import io
+import os
+from collections import OrderedDict
+
+import h5py
+import kaldiio
+import numpy as np
+import soundfile
+
+from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline
+from deepspeech.utils.log import Log
+
+__all__ = ["LoadInputsAndTargets"]
+
+logger = Log(__name__).getlog()
+
+
+class LoadInputsAndTargets():
+    """Create a mini-batch from a list of dicts
+
+    >>> batch = [('utt1',
+    ...           dict(input=[dict(feat='some.ark:123',
+    ...                            filetype='mat',
+    ...                            name='input1',
+    ...                            shape=[100, 80])],
+    ...                output=[dict(tokenid='1 2 3 4',
+    ...                             name='target1',
+    ...                             shape=[4, 31])]))]
+    >>> l = LoadInputsAndTargets()
+    >>> feat, target = l(batch)
+
+    :param: str mode: Specify the task mode, "asr" or "tts"
+    :param: str preprocess_conf: The path of a json file for pre-processing
+    :param: bool load_input: If False, not to load the input data
+    :param: bool load_output: If False, not to load the output data
+    :param: bool sort_in_input_length: Sort the mini-batch in descending order
+        of the input length
+    :param: dict preprocess_args: Set some optional arguments for preprocessing
+    :param: bool keep_all_data_on_mem: If True, cache loaded data in memory
+    """
+
+    def __init__(
+            self,
+            mode="asr",
+            preprocess_conf=None,
+            load_input=True,
+            load_output=True,
+            sort_in_input_length=True,
+            preprocess_args=None,
+            keep_all_data_on_mem=False, ):
+        self._loaders = {}
+
+        if mode not in ["asr"]:
+            raise ValueError("Only asr are allowed: mode={}".format(mode))
+
+        if preprocess_conf is not None:
+            with open(preprocess_conf, 'r') as fin:
+                self.preprocessing = AugmentationPipeline(fin.read())
+            logger.warning(
+                "[Experimental feature] Some preprocessing will be done "
+                "for the mini-batch creation using {}".format(
+                    self.preprocessing))
+        else:
+            # If no conf is given, this function doesn't touch anything.
+            self.preprocessing = None
+
+        self.mode = mode
+        self.load_output = load_output
+        self.load_input = load_input
+        self.sort_in_input_length = sort_in_input_length
+        if preprocess_args is None:
+            self.preprocess_args = {}
+        else:
+            assert isinstance(preprocess_args, dict), type(preprocess_args)
+            self.preprocess_args = dict(preprocess_args)
+
+        self.keep_all_data_on_mem = keep_all_data_on_mem
+
+    def __call__(self, batch, return_uttid=False):
+        """Function to load inputs and targets from list of dicts
+
+        :param List[Tuple[str, dict]] batch: list of dict which is subset of
+            loaded data.json
+        :param bool return_uttid: return utterance ID information for visualization
+        :return: list of input token id sequences [(L_1), (L_2), ..., (L_B)]
+        :return: list of input feature sequences
+            [(T_1, D), (T_2, D), ..., (T_B, D)]
+        :rtype: list of float ndarray
+        :return: list of target token id sequences [(L_1), (L_2), ..., (L_B)]
+        :rtype: list of int ndarray
+
+        """
+        x_feats_dict = OrderedDict()  # OrderedDict[str, List[np.ndarray]]
+        y_feats_dict = OrderedDict()  # OrderedDict[str, List[np.ndarray]]
+        uttid_list = []  # List[str]
+
+        for uttid, info in batch:
+            uttid_list.append(uttid)
+
+            if self.load_input:
+                # Note(kamo): This for-loop is for multiple inputs
+                for idx, inp in enumerate(info["input"]):
+                    # {"input":
+                    #  [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
+                    #    "filetype": "hdf5",
+                    #    "name": "input1", ...}], ...}
+                    x = self._get_from_loader(
+                        filepath=inp["feat"],
+                        filetype=inp.get("filetype", "mat"))
+                    x_feats_dict.setdefault(inp["name"], []).append(x)
+
+            if self.load_output:
+                for idx, inp in enumerate(info["output"]):
+                    if "tokenid" in inp:
+                        # ======= Legacy format for output =======
+                        # {"output": [{"tokenid": "1 2 3 4"}])
+                        x = np.fromiter(
+                            map(int, inp["tokenid"].split()), dtype=np.int64)
+                    else:
+                        # ======= New format =======
+                        # {"input":
+                        #  [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
+                        #    "filetype": "hdf5",
+                        #    "name": "target1", ...}], ...}
+                        x = self._get_from_loader(
+                            filepath=inp["feat"],
+                            filetype=inp.get("filetype", "mat"))
+
+                    y_feats_dict.setdefault(inp["name"], []).append(x)
+
+        if self.mode == "asr":
+            return_batch, uttid_list = self._create_batch_asr(
+                x_feats_dict, y_feats_dict, uttid_list)
+        else:
+            raise NotImplementedError(self.mode)
+
+        if self.preprocessing is not None:
+            # Apply pre-processing all input features
+            for x_name in return_batch.keys():
+                if x_name.startswith("input"):
+                    return_batch[x_name] = self.preprocessing(
+                        return_batch[x_name], uttid_list,
+                        **self.preprocess_args)
+
+        if return_uttid:
+            return tuple(return_batch.values()), uttid_list
+
+        # Doesn't return the names now.
+        return tuple(return_batch.values())
+
+    def _create_batch_asr(self, x_feats_dict, y_feats_dict, uttid_list):
+        """Create a OrderedDict for the mini-batch
+
+        :param OrderedDict x_feats_dict:
+            e.g. {"input1": [ndarray, ndarray, ...],
+                  "input2": [ndarray, ndarray, ...]}
+        :param OrderedDict y_feats_dict:
+            e.g. {"target1": [ndarray, ndarray, ...],
+                  "target2": [ndarray, ndarray, ...]}
+        :param: List[str] uttid_list:
+            Give uttid_list to sort in the same order as the mini-batch
+        :return: batch, uttid_list
+        :rtype: Tuple[OrderedDict, List[str]]
+        """
+        # handle single-input and multi-input (parallel) asr mode
+        xs = list(x_feats_dict.values())
+
+        if self.load_output:
+            ys = list(y_feats_dict.values())
+            assert len(xs[0]) == len(ys[0]), (len(xs[0]), len(ys[0]))
+
+            # get index of non-zero length samples
+            nonzero_idx = list(
+                filter(lambda i: len(ys[0][i]) > 0, range(len(ys[0]))))
+            for n in range(1, len(y_feats_dict)):
+                nonzero_idx = filter(lambda i: len(ys[n][i]) > 0, nonzero_idx)
+        else:
+            # Note(kamo): Be careful not to turn nonzero_idx into a generator
+            nonzero_idx = list(range(len(xs[0])))
+
+        if self.sort_in_input_length:
+            # sort in input lengths based on the first input
+            nonzero_sorted_idx = sorted(
+                nonzero_idx, key=lambda i: -len(xs[0][i]))
+        else:
+            nonzero_sorted_idx = nonzero_idx
+
+        if len(nonzero_sorted_idx) != len(xs[0]):
+            logger.warning(
+                "Target sequences include empty tokenid (batch {} -> {}).".
+                format(len(xs[0]), len(nonzero_sorted_idx)))
+
+        # remove zero-length samples
+        xs = [[x[i] for i in nonzero_sorted_idx] for x in xs]
+        uttid_list = [uttid_list[i] for i in nonzero_sorted_idx]
+
+        x_names = list(x_feats_dict.keys())
+        if self.load_output:
+            ys = [[y[i] for i in nonzero_sorted_idx] for y in ys]
+            y_names = list(y_feats_dict.keys())
+
+            # Keeping x_name and y_name, e.g. input1, for future extension
+            return_batch = OrderedDict([
+                * [(x_name, x) for x_name, x in zip(x_names, xs)],
+                * [(y_name, y) for y_name, y in zip(y_names, ys)],
+            ])
+        else:
+            return_batch = OrderedDict(
+                [(x_name, x) for x_name, x in zip(x_names, xs)])
+        return return_batch, uttid_list
+
+    def _get_from_loader(self, filepath, filetype):
+        """Return ndarray
+
+        To ensure file descriptors are opened only on the first access,
+        the loaders are cached in self._loaders.
+
+        >>> ndarray = loader.get_from_loader(
+        ...     'some/path.h5:F01_050C0101_PED_REAL', filetype='hdf5')
+
+        :param: str filepath:
+        :param: str filetype:
+        :return:
+        :rtype: np.ndarray
+        """
+        if filetype == "hdf5":
+            # e.g.
+            #    {"input": [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
+            #                "filetype": "hdf5",
+            # -> filepath = "some/path.h5", key = "F01_050C0101_PED_REAL"
+            filepath, key = filepath.split(":", 1)
+
+            loader = self._loaders.get(filepath)
+            if loader is None:
+                # To avoid disk access, create loader only for the first time
+                loader = h5py.File(filepath, "r")
+                self._loaders[filepath] = loader
+            return loader[key][()]
+        elif filetype == "sound.hdf5":
+            # e.g.
+            #    {"input": [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
+            #                "filetype": "sound.hdf5",
+            # -> filepath = "some/path.h5", key = "F01_050C0101_PED_REAL"
+            filepath, key = filepath.split(":", 1)
+
+            loader = self._loaders.get(filepath)
+            if loader is None:
+                # To avoid disk access, create loader only for the first time
+                loader = SoundHDF5File(filepath, "r", dtype="int16")
+                self._loaders[filepath] = loader
+            array, rate = loader[key]
+            return array
+        elif filetype == "sound":
+            # e.g.
+            #    {"input": [{"feat": "some/path.wav",
+            #                "filetype": "sound"},
+            # Assume PCM16
+            if not self.keep_all_data_on_mem:
+                array, _ = soundfile.read(filepath, dtype="int16")
+                return array
+            if filepath not in self._loaders:
+                array, _ = soundfile.read(filepath, dtype="int16")
+                self._loaders[filepath] = array
+            return self._loaders[filepath]
+        elif filetype == "npz":
+            # e.g.
+            #    {"input": [{"feat": "some/path.npz:F01_050C0101_PED_REAL",
+            #                "filetype": "npz",
+            filepath, key = filepath.split(":", 1)
+
+            loader = self._loaders.get(filepath)
+            if loader is None:
+                # To avoid disk access, create loader only for the first time
+                loader = np.load(filepath)
+                self._loaders[filepath] = loader
+            return loader[key]
+        elif filetype == "npy":
+            # e.g.
+            #    {"input": [{"feat": "some/path.npy",
+            #                "filetype": "npy"},
+            if not self.keep_all_data_on_mem:
+                return np.load(filepath)
+            if filepath not in self._loaders:
+                self._loaders[filepath] = np.load(filepath)
+            return self._loaders[filepath]
+        elif filetype in ["mat", "vec"]:
+            # e.g.
+            #    {"input": [{"feat": "some/path.ark:123",
+            #                "filetype": "mat"}]},
+            # In this case, "123" indicates the starting points of the matrix
+            # load_mat can load both matrix and vector
+            if not self.keep_all_data_on_mem:
+                return kaldiio.load_mat(filepath)
+            if filepath not in self._loaders:
+                self._loaders[filepath] = kaldiio.load_mat(filepath)
+            return self._loaders[filepath]
+        elif filetype == "scp":
+            # e.g.
+            #    {"input": [{"feat": "some/path.scp:F01_050C0101_PED_REAL",
+            #                "filetype": "scp",
+            filepath, key = filepath.split(":", 1)
+            loader = self._loaders.get(filepath)
+            if loader is None:
+                # To avoid disk access, create loader only for the first time
+                loader = kaldiio.load_scp(filepath)
+                self._loaders[filepath] = loader
+            return loader[key]
+        else:
+            raise NotImplementedError(
+                "Not supported: loader_type={}".format(filetype))
+
+
+class SoundHDF5File():
+    """Collecting sound files to a HDF5 file
+
+    >>> f = SoundHDF5File('a.flac.h5', mode='a')
+    >>> array = np.random.randint(0, 100, 100, dtype=np.int16)
+    >>> f['id'] = (array, 16000)
+    >>> array, rate = f['id']
+
+
+    :param: str filepath:
+    :param: str mode:
+    :param: str format: The type used when saving wav. flac, nist, htk, etc.
+    :param: str dtype:
+
+    """
+
+    def __init__(self,
+                 filepath,
+                 mode="r+",
+                 format=None,
+                 dtype="int16",
+                 **kwargs):
+        self.filepath = filepath
+        self.mode = mode
+        self.dtype = dtype
+
+        self.file = h5py.File(filepath, mode, **kwargs)
+        if format is None:
+            # filepath = a.flac.h5 -> format = flac
+            second_ext = os.path.splitext(os.path.splitext(filepath)[0])[1]
+            format = second_ext[1:]
+            if format.upper() not in soundfile.available_formats():
+                # If not found, flac is selected
+                format = "flac"
+
+        # This format affects only saving
+        self.format = format
+
+    def __repr__(self):
+        return '<SoundHDF5 file "{}" (mode {}, format {}, type {})>'.format(
+            self.filepath, self.mode, self.format, self.dtype)
+
+    def create_dataset(self, name, shape=None, data=None, **kwds):
+        f = io.BytesIO()
+        array, rate = data
+        soundfile.write(f, array, rate, format=self.format)
+        self.file.create_dataset(
+            name, shape=shape, data=np.void(f.getvalue()), **kwds)
+
+    def __setitem__(self, name, data):
+        self.create_dataset(name, data=data)
+
+    def __getitem__(self, key):
+        data = self.file[key][()]
+        f = io.BytesIO(data.tobytes())
+        array, rate = soundfile.read(f, dtype=self.dtype)
+        return array, rate
+
+    def keys(self):
+        return self.file.keys()
+
+    def values(self):
+        for k in self.file:
+            yield self[k]
+
+    def items(self):
+        for k in self.file:
+            yield k, self[k]
+
+    def __iter__(self):
+        return iter(self.file)
+
+    def __contains__(self, item):
+        return item in self.file
+
+    def __len__(self):
+        return len(self.file)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.file.close()
+
+    def close(self):
+        self.file.close()
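
A minimal sketch of the reader on a single fake utterance; the "npy" filetype is used so no Kaldi ark files are needed, and the path, shapes and token ids are invented:

    import numpy as np
    from deepspeech.io.reader import LoadInputsAndTargets

    np.save("/tmp/feat_utt1.npy", np.random.randn(40, 83).astype(np.float32))

    reader = LoadInputsAndTargets(mode="asr", load_output=True)
    batch = [("utt1", {
        "input": [{"name": "input1", "feat": "/tmp/feat_utt1.npy",
                   "filetype": "npy"}],
        "output": [{"name": "target1", "tokenid": "3 7 11",
                    "shape": [3, 5002]}],
    })]
    (xs, ys), utts = reader(batch, return_uttid=True)
    print(utts, xs[0].shape, ys[0])   # ['utt1'] (40, 83) [ 3  7 11]
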
diff --git a/deepspeech/io/utility.py b/deepspeech/io/utility.py
index 0cd37428b..99487a0af 100644
--- a/deepspeech/io/utility.py
+++ b/deepspeech/io/utility.py
@@ -17,11 +17,16 @@ import numpy as np
 
 from deepspeech.utils.log import Log
 
-__all__ = ["pad_sequence"]
+__all__ = ["pad_list", "pad_sequence"]
 
 logger = Log(__name__).getlog()
 
 
+def pad_list(sequences: List[np.ndarray],
+             padding_value: float=0.0) -> np.ndarray:
+    return pad_sequence(sequences, True, padding_value)
+
+
 def pad_sequence(sequences: List[np.ndarray],
                  batch_first: bool=True,
                  padding_value: float=0.0) -> np.ndarray:
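
pad_list is a batch-first wrapper over the existing pad_sequence; a quick check of the expected behaviour:

    import numpy as np
    from deepspeech.io.utility import pad_list, pad_sequence

    seqs = [np.array([1, 2, 3]), np.array([4, 5])]
    padded = pad_list(seqs, 0)
    assert np.array_equal(padded, pad_sequence(seqs, True, 0))
    print(padded)   # [[1 2 3]
                    #  [4 5 0]]
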
diff --git a/deepspeech/models/ds2/rnn.py b/deepspeech/models/ds2/rnn.py
index 01b55c4a2..0d8c9fd2c 100644
--- a/deepspeech/models/ds2/rnn.py
+++ b/deepspeech/models/ds2/rnn.py
@@ -297,7 +297,7 @@ class RNNStack(nn.Layer):
                         share_weights=share_rnn_weights))
             i_size = h_size * 2
 
-        self.rnn_stacks = nn.ModuleList(rnn_stacks)
+        self.rnn_stacks = nn.LayerList(rnn_stacks)
 
     def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
         """
diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2.py
index f1d466a27..c1a35560a 100644
--- a/deepspeech/models/u2.py
+++ b/deepspeech/models/u2.py
@@ -54,7 +54,7 @@ __all__ = ["U2Model", "U2InferModel"]
 logger = Log(__name__).getlog()
 
 
-class U2BaseModel(nn.Module):
+class U2BaseModel(nn.Layer):
     """CTC-Attention hybrid Encoder-Decoder model"""
 
     @classmethod
@@ -612,32 +612,32 @@ class U2BaseModel(nn.Module):
                 best_index = i
         return hyps[best_index][0]
 
-    #@jit.export
+    #@jit.to_static
     def subsampling_rate(self) -> int:
         """ Export interface for c++ call, return subsampling_rate of the
             model
         """
         return self.encoder.embed.subsampling_rate
 
-    #@jit.export
+    #@jit.to_static
     def right_context(self) -> int:
         """ Export interface for c++ call, return right_context of the model
         """
         return self.encoder.embed.right_context
 
-    #@jit.export
+    #@jit.to_static
     def sos_symbol(self) -> int:
         """ Export interface for c++ call, return sos symbol id of the model
         """
         return self.sos
 
-    #@jit.export
+    #@jit.to_static
     def eos_symbol(self) -> int:
         """ Export interface for c++ call, return eos symbol id of the model
         """
         return self.eos
 
-    @jit.export
+    @jit.to_static
     def forward_encoder_chunk(
             self,
             xs: paddle.Tensor,
@@ -667,7 +667,7 @@ class U2BaseModel(nn.Module):
             xs, offset, required_cache_size, subsampling_cache,
             elayers_output_cache, conformer_cnn_cache)
 
-    # @jit.export([
+    # @jit.to_static([
     #         paddle.static.InputSpec(shape=[1, None, feat_dim],dtype='float32'),  # audio feat, [B,T,D]
     #     ])
     def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
@@ -680,7 +680,7 @@ class U2BaseModel(nn.Module):
         """
         return self.ctc.log_softmax(xs)
 
-    @jit.export
+    @jit.to_static
     def forward_attention_decoder(
             self,
             hyps: paddle.Tensor,
diff --git a/deepspeech/models/u2_st.py b/deepspeech/models/u2_st.py
index a73f52e99..b725cc359 100644
--- a/deepspeech/models/u2_st.py
+++ b/deepspeech/models/u2_st.py
@@ -48,7 +48,7 @@ __all__ = ["U2STModel", "U2STInferModel"]
 logger = Log(__name__).getlog()
 
 
-class U2STBaseModel(nn.Module):
+class U2STBaseModel(nn.Layer):
     """CTC-Attention hybrid Encoder-Decoder model"""
 
     @classmethod
@@ -417,32 +417,32 @@ class U2STBaseModel(nn.Module):
         best_hyps = best_hyps[:, 1:]
         return best_hyps
 
-    @jit.export
+    @jit.to_static
     def subsampling_rate(self) -> int:
         """ Export interface for c++ call, return subsampling_rate of the
             model
         """
         return self.encoder.embed.subsampling_rate
 
-    @jit.export
+    @jit.to_static
     def right_context(self) -> int:
         """ Export interface for c++ call, return right_context of the model
         """
         return self.encoder.embed.right_context
 
-    @jit.export
+    @jit.to_static
     def sos_symbol(self) -> int:
         """ Export interface for c++ call, return sos symbol id of the model
         """
         return self.sos
 
-    @jit.export
+    @jit.to_static
     def eos_symbol(self) -> int:
         """ Export interface for c++ call, return eos symbol id of the model
         """
         return self.eos
 
-    @jit.export
+    @jit.to_static
     def forward_encoder_chunk(
             self,
             xs: paddle.Tensor,
@@ -472,7 +472,7 @@ class U2STBaseModel(nn.Module):
             xs, offset, required_cache_size, subsampling_cache,
             elayers_output_cache, conformer_cnn_cache)
 
-    @jit.export
+    @jit.to_static
     def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
         """ Export interface for c++ call, apply linear transform and log
             softmax before ctc
@@ -483,7 +483,7 @@ class U2STBaseModel(nn.Module):
         """
         return self.ctc.log_softmax(xs)
 
-    @jit.export
+    @jit.to_static
     def forward_attention_decoder(
             self,
             hyps: paddle.Tensor,
diff --git a/deepspeech/modules/activation.py b/deepspeech/modules/activation.py
index 0fe66b739..30132775e 100644
--- a/deepspeech/modules/activation.py
+++ b/deepspeech/modules/activation.py
@@ -69,7 +69,7 @@ class ConvGLUBlock(nn.Layer):
                 dim=0)
             self.dropout_residual = nn.Dropout(p=dropout)
 
-        self.pad_left = ConstantPad2d((0, 0, kernel_size - 1, 0), 0)
+        self.pad_left = nn.Pad2D((0, 0, kernel_size - 1, 0), value=0.)
 
         layers = OrderedDict()
         if bottlececk_dim == 0:
diff --git a/deepspeech/modules/decoder.py b/deepspeech/modules/decoder.py
index 696a6315b..87c9fa492 100644
--- a/deepspeech/modules/decoder.py
+++ b/deepspeech/modules/decoder.py
@@ -33,7 +33,7 @@ logger = Log(__name__).getlog()
 __all__ = ["TransformerDecoder"]
 
 
-class TransformerDecoder(nn.Module):
+class TransformerDecoder(nn.Layer):
     """Base class of Transfomer decoder module.
     Args:
         vocab_size: output dim
@@ -86,7 +86,7 @@ class TransformerDecoder(nn.Module):
         self.use_output_layer = use_output_layer
         self.output_layer = nn.Linear(attention_dim, vocab_size)
 
-        self.decoders = nn.ModuleList([
+        self.decoders = nn.LayerList([
             DecoderLayer(
                 size=attention_dim,
                 self_attn=MultiHeadedAttention(attention_heads, attention_dim,
diff --git a/deepspeech/modules/decoder_layer.py b/deepspeech/modules/decoder_layer.py
index c6fac5412..47c42615e 100644
--- a/deepspeech/modules/decoder_layer.py
+++ b/deepspeech/modules/decoder_layer.py
@@ -25,15 +25,15 @@ logger = Log(__name__).getlog()
 __all__ = ["DecoderLayer"]
 
 
-class DecoderLayer(nn.Module):
+class DecoderLayer(nn.Layer):
     """Single decoder layer module.
     Args:
         size (int): Input dimension.
-        self_attn (nn.Module): Self-attention module instance.
+        self_attn (nn.Layer): Self-attention module instance.
             `MultiHeadedAttention` instance can be used as the argument.
-        src_attn (nn.Module): Self-attention module instance.
+        src_attn (nn.Layer): Self-attention module instance.
             `MultiHeadedAttention` instance can be used as the argument.
-        feed_forward (nn.Module): Feed-forward module instance.
+        feed_forward (nn.Layer): Feed-forward module instance.
             `PositionwiseFeedForward` instance can be used as the argument.
         dropout_rate (float): Dropout rate.
         normalize_before (bool):
@@ -48,9 +48,9 @@ class DecoderLayer(nn.Module):
     def __init__(
             self,
             size: int,
-            self_attn: nn.Module,
-            src_attn: nn.Module,
-            feed_forward: nn.Module,
+            self_attn: nn.Layer,
+            src_attn: nn.Layer,
+            feed_forward: nn.Layer,
             dropout_rate: float,
             normalize_before: bool=True,
             concat_after: bool=False, ):
diff --git a/deepspeech/modules/encoder.py b/deepspeech/modules/encoder.py
index 27e0f8d78..71ec61a0e 100644
--- a/deepspeech/modules/encoder.py
+++ b/deepspeech/modules/encoder.py
@@ -358,7 +358,7 @@ class TransformerEncoder(BaseEncoder):
                          pos_enc_layer_type, normalize_before, concat_after,
                          static_chunk_size, use_dynamic_chunk, global_cmvn,
                          use_dynamic_left_chunk)
-        self.encoders = nn.ModuleList([
+        self.encoders = nn.LayerList([
             TransformerEncoderLayer(
                 size=output_size,
                 self_attn=MultiHeadedAttention(attention_heads, output_size,
@@ -438,7 +438,7 @@ class ConformerEncoder(BaseEncoder):
         convolution_layer_args = (output_size, cnn_module_kernel, activation,
                                   cnn_module_norm, causal)
 
-        self.encoders = nn.ModuleList([
+        self.encoders = nn.LayerList([
             ConformerEncoderLayer(
                 size=output_size,
                 self_attn=encoder_selfattn_layer(*encoder_selfattn_layer_args),
diff --git a/deepspeech/modules/loss.py b/deepspeech/modules/loss.py
index 3e441bbbc..8918ca669 100644
--- a/deepspeech/modules/loss.py
+++ b/deepspeech/modules/loss.py
@@ -48,7 +48,8 @@ class CTCLoss(nn.Layer):
         logits = logits.transpose([1, 0, 2])
         # (TODO:Hui Zhang) ctc loss does not support int64 labels
         ys_pad = ys_pad.astype(paddle.int32)
-        loss = self.loss(logits, ys_pad, hlens, ys_lens)
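+        # norm_by_times normalizes by the number of time steps; here it is tied to the batch_average flag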
+        loss = self.loss(
+            logits, ys_pad, hlens, ys_lens, norm_by_times=self.batch_average)
         if self.batch_average:
             # Batch-size average
             loss = loss / B
diff --git a/deepspeech/modules/rnn.py b/deepspeech/modules/rnn.py
index 01b55c4a2..0d8c9fd2c 100644
--- a/deepspeech/modules/rnn.py
+++ b/deepspeech/modules/rnn.py
@@ -297,7 +297,7 @@ class RNNStack(nn.Layer):
                         share_weights=share_rnn_weights))
             i_size = h_size * 2
 
-        self.rnn_stacks = nn.ModuleList(rnn_stacks)
+        self.rnn_stacks = nn.LayerList(rnn_stacks)
 
     def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
         """
diff --git a/deepspeech/training/cli.py b/deepspeech/training/cli.py
index d3b853554..ecd7a8f26 100644
--- a/deepspeech/training/cli.py
+++ b/deepspeech/training/cli.py
@@ -47,18 +47,11 @@ def default_argument_parser():
     # data and output
     parser.add_argument("--config", metavar="FILE", help="path of the config file to overwrite to default config with.")
     parser.add_argument("--dump-config", metavar="FILE", help="dump config to yaml file.")
-    # parser.add_argument("--data", metavar="DATA_DIR", help="path to the datatset.")
     parser.add_argument("--output", metavar="OUTPUT_DIR", help="path to save checkpoint and logs.")
 
     # load from saved checkpoint
     parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load")
 
-    # save jit model to
-    parser.add_argument("--export_path", type=str, help="path of the jit model to save")
-
-    # save asr result to
-    parser.add_argument("--result_file", type=str, help="path of save the asr result")
-
     # running
     parser.add_argument("--device", type=str, default='gpu', choices=["cpu", "gpu"],
                         help="device type to use, cpu and gpu are supported.")
diff --git a/deepspeech/training/optimizer.py b/deepspeech/training/optimizer.py
index f7933f8d4..db7069c98 100644
--- a/deepspeech/training/optimizer.py
+++ b/deepspeech/training/optimizer.py
@@ -15,6 +15,7 @@ from typing import Any
 from typing import Dict
 from typing import Text
 
+import paddle
 from paddle.optimizer import Optimizer
 from paddle.regularizer import L2Decay
 
@@ -43,6 +44,40 @@ def register_optimizer(cls):
     return cls
 
 
+@register_optimizer
+class Noam(paddle.optimizer.Adam):
+    """Seem to: espnet/nets/pytorch_backend/transformer/optimizer.py """
+
+    def __init__(self,
+                 learning_rate=0,
+                 beta1=0.9,
+                 beta2=0.98,
+                 epsilon=1e-9,
+                 parameters=None,
+                 weight_decay=None,
+                 grad_clip=None,
+                 lazy_mode=False,
+                 multi_precision=False,
+                 name=None):
+        super().__init__(
+            learning_rate=learning_rate,
+            beta1=beta1,
+            beta2=beta2,
+            epsilon=epsilon,
+            parameters=parameters,
+            weight_decay=weight_decay,
+            grad_clip=grad_clip,
+            lazy_mode=lazy_mode,
+            multi_precision=multi_precision,
+            name=name)
+
+    def __repr__(self):
+        echo = f"<{self.__class__.__module__}.{self.__class__.__name__} object at {hex(id(self))}> "
+        echo += f"learning_rate: {self._learning_rate}, "
+        echo += f"(beta1: {self._beta1} beta2: {self._beta2}), "
+        echo += f"epsilon: {self._epsilon}"
+
+
 def dynamic_import_optimizer(module):
     """Import Optimizer class dynamically.
 
@@ -69,15 +104,18 @@ class OptimizerFactory():
             args['grad_clip']) if "grad_clip" in args else None
         weight_decay = L2Decay(
             args['weight_decay']) if "weight_decay" in args else None
-        module_class = dynamic_import_optimizer(name.lower())
-
         if weight_decay:
-            logger.info(f'WeightDecay: {weight_decay}')
+            logger.info(f'<WeightDecay - {weight_decay}>')
         if grad_clip:
-            logger.info(f'GradClip: {grad_clip}')
-        logger.info(
-            f"Optimizer: {module_class.__name__} {args['learning_rate']}")
+            logger.info(f'<GradClip - {grad_clip}>')
 
+        module_class = dynamic_import_optimizer(name.lower())
         args.update({"grad_clip": grad_clip, "weight_decay": weight_decay})
-
-        return instance_class(module_class, args)
+        opt = instance_class(module_class, args)
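+        # log the optimizer's own __repr__ when its class defines one, otherwise a generic summary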
+        if "__repr__" in vars(opt):
+            logger.info(f"{opt}")
+        else:
+            logger.info(
+                f"<Optimizer {module_class.__module__}.{module_class.__name__}> LR: {args['learning_rate']}"
+            )
+        return opt
diff --git a/deepspeech/training/scheduler.py b/deepspeech/training/scheduler.py
index b8f3ece7c..bb53281a8 100644
--- a/deepspeech/training/scheduler.py
+++ b/deepspeech/training/scheduler.py
@@ -41,22 +41,6 @@ def register_scheduler(cls):
     return cls
 
 
-def dynamic_import_scheduler(module):
-    """Import Scheduler class dynamically.
-
-    Args:
-        module (str): module_name:class_name or alias in `SCHEDULER_DICT`
-
-    Returns:
-        type: Scheduler class
-
-    """
-    module_class = dynamic_import(module, SCHEDULER_DICT)
-    assert issubclass(module_class,
-                      LRScheduler), f"{module} does not implement LRScheduler"
-    return module_class
-
-
 @register_scheduler
 class WarmupLR(LRScheduler):
     """The WarmupLR scheduler
@@ -102,6 +86,41 @@ class WarmupLR(LRScheduler):
         self.step(epoch=step)
 
 
+@register_scheduler
+class ConstantLR(LRScheduler):
+    """
+    Args:
+        learning_rate (float): The initial learning rate. It is a python float number.
+        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
+        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .
+    
+    Returns:
+        ``ConstantLR`` instance to schedule learning rate.
+    """
+
+    def __init__(self, learning_rate, last_epoch=-1, verbose=False):
+        super().__init__(learning_rate, last_epoch, verbose)
+
+    def get_lr(self):
+        return self.base_lr
+
+
+def dynamic_import_scheduler(module):
+    """Import Scheduler class dynamically.
+
+    Args:
+        module (str): module_name:class_name or alias in `SCHEDULER_DICT`
+
+    Returns:
+        type: Scheduler class
+
+    """
+    module_class = dynamic_import(module, SCHEDULER_DICT)
+    assert issubclass(module_class,
+                      LRScheduler), f"{module} does not implement LRScheduler"
+    return module_class
+
+
 class LRSchedulerFactory():
     @classmethod
     def from_args(cls, name: str, args: Dict[Text, Any]):
diff --git a/doc/src/feature_list.md b/doc/src/feature_list.md
index 573669fa2..b675d8100 100644
--- a/doc/src/feature_list.md
+++ b/doc/src/feature_list.md
@@ -1,4 +1,4 @@
-# Featrues
+# Features
 
 ### Speech Recognition
 
diff --git a/env.sh b/env.sh
index 9d22259df..461586e7d 100644
--- a/env.sh
+++ b/env.sh
@@ -1,6 +1,6 @@
 export MAIN_ROOT=${PWD}
 
-export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
+export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:/usr/local/bin:${PATH}
 export LC_ALL=C
 
 # Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
diff --git a/examples/aishell/s0/conf/augmentation.json b/examples/aishell/s0/conf/augmentation.json
index 1987ad424..ac8a1c530 100644
--- a/examples/aishell/s0/conf/augmentation.json
+++ b/examples/aishell/s0/conf/augmentation.json
@@ -19,15 +19,17 @@
   {
     "type": "specaug",
     "params": {
-      "F": 10,
-      "T": 50,
+      "W": 5,
+      "warp_mode": "PIL",
+      "F": 30,
       "n_freq_masks": 2,
+      "T": 40,
       "n_time_masks": 2,
       "p": 1.0,
-      "W": 80,
       "adaptive_number_ratio": 0,
       "adaptive_size_ratio": 0,
-      "max_n_time_masks": 20
+      "max_n_time_masks": 20,
+      "replace_with_zero": false
     },
     "prob": 1.0
   }
diff --git a/examples/aishell/s1/conf/augmentation.json b/examples/aishell/s1/conf/augmentation.json
index 1987ad424..d0409b142 100644
--- a/examples/aishell/s1/conf/augmentation.json
+++ b/examples/aishell/s1/conf/augmentation.json
@@ -27,7 +27,9 @@
       "W": 80,
       "adaptive_number_ratio": 0,
       "adaptive_size_ratio": 0,
-      "max_n_time_masks": 20
+      "max_n_time_masks": 20,
+      "replace_with_zero": true,
+      "warp_mode": "PIL"
     },
     "prob": 1.0
   }
diff --git a/examples/aug_conf/augmentation.json b/examples/aug_conf/augmentation.json
deleted file mode 100644
index a1a759e67..000000000
--- a/examples/aug_conf/augmentation.json
+++ /dev/null
@@ -1,10 +0,0 @@
-[
-  {
-    "type": "shift",
-    "params": {
-      "min_shift_ms": -5,
-      "max_shift_ms": 5
-    },
-    "prob": 1.0
-  }
-]
diff --git a/examples/aug_conf/augmentation.example.json b/examples/augmentation/augmentation.json
similarity index 91%
rename from examples/aug_conf/augmentation.example.json
rename to examples/augmentation/augmentation.json
index efae2e5e3..c99299d6c 100644
--- a/examples/aug_conf/augmentation.example.json
+++ b/examples/augmentation/augmentation.json
@@ -52,16 +52,18 @@
   {
     "type": "specaug",
     "params": {
+      "W": 80,
+      "warp_mode": "PIL",
       "F": 10,
-      "T": 50,
       "n_freq_masks": 2,
+      "T": 50,
       "n_time_masks": 2,
       "p": 1.0,
-      "W": 80,
       "adaptive_number_ratio": 0,
       "adaptive_size_ratio": 0,
-      "max_n_time_masks": 20
+      "max_n_time_masks": 20,
+      "replace_with_zero": false
     },
-    "prob": 0.0
+    "prob": 1.0
   }
 ]
diff --git a/examples/callcenter/s1/conf/augmentation.json b/examples/callcenter/s1/conf/augmentation.json
index 1987ad424..81d110b0b 100644
--- a/examples/callcenter/s1/conf/augmentation.json
+++ b/examples/callcenter/s1/conf/augmentation.json
@@ -27,7 +27,8 @@
       "W": 80,
       "adaptive_number_ratio": 0,
       "adaptive_size_ratio": 0,
-      "max_n_time_masks": 20
+      "max_n_time_masks": 20,
+      "replace_with_zero": true
     },
     "prob": 1.0
   }
diff --git a/examples/librispeech/s0/conf/augmentation.json b/examples/librispeech/s0/conf/augmentation.json
index 1987ad424..d0409b142 100644
--- a/examples/librispeech/s0/conf/augmentation.json
+++ b/examples/librispeech/s0/conf/augmentation.json
@@ -27,7 +27,9 @@
       "W": 80,
       "adaptive_number_ratio": 0,
       "adaptive_size_ratio": 0,
-      "max_n_time_masks": 20
+      "max_n_time_masks": 20,
+      "replace_with_zero": true,
+      "warp_mode": "PIL"
     },
     "prob": 1.0
   }
diff --git a/examples/librispeech/s1/README.md b/examples/librispeech/s1/README.md
index daa4d175b..4cb3629de 100644
--- a/examples/librispeech/s1/README.md
+++ b/examples/librispeech/s1/README.md
@@ -21,7 +21,6 @@
 | --- | --- | --- | --- | --- | --- | --- | --- |
 | conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean-all | attention | 6.35 | 0.057117 |  
 
-
 ## Chunk Conformer
 | Model | Params | Config | Augmentation| Test set | Decode method | Chunk Size & Left Chunks | Loss | WER |  
 | --- | --- | --- | --- | --- | --- | --- | --- | --- |  
@@ -39,4 +38,7 @@
 ### Test w/o length filter
 | Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |  
 | --- | --- | --- | --- | --- | --- | --- | --- |
-| transformer | 32.52 M | conf/transformer.yaml | spec_aug + shift | test-clean-all | attention | 6.98 | 0.066500 |  
+| transformer | 32.52 M | conf/transformer.yaml | spec_aug + shift | test-clean-all | attention | 7.63 | 0.056832 |  
+| transformer | 32.52 M | conf/transformer.yaml | spec_aug + shift | test-clean-all | ctc_greedy_search | 7.63 | 0.059742 |  
+| transformer | 32.52 M | conf/transformer.yaml | spec_aug + shift | test-clean-all | ctc_prefix_beam_search | 7.63 | 0.059057 |  
+| transformer | 32.52 M | conf/transformer.yaml | spec_aug + shift | test-clean-all | attention_rescoring | 7.63 | 0.047417 |  
diff --git a/examples/librispeech/s1/conf/augmentation.json b/examples/librispeech/s1/conf/augmentation.json
index c1078393d..8e6e97040 100644
--- a/examples/librispeech/s1/conf/augmentation.json
+++ b/examples/librispeech/s1/conf/augmentation.json
@@ -27,7 +27,9 @@
       "W": 80,
       "adaptive_number_ratio": 0,
       "adaptive_size_ratio": 0,
-      "max_n_time_masks": 20
+      "max_n_time_masks": 20,
+      "replace_with_zero": true,
+      "warp_mode": "PIL"
     },
     "prob": 1.0
   }
diff --git a/examples/librispeech/s1/conf/transformer.yaml b/examples/librispeech/s1/conf/transformer.yaml
index 8a769dca4..bc2ec6061 100644
--- a/examples/librispeech/s1/conf/transformer.yaml
+++ b/examples/librispeech/s1/conf/transformer.yaml
@@ -4,7 +4,7 @@ data:
   dev_manifest: data/manifest.dev
   test_manifest: data/manifest.test-clean
   min_input_len: 0.5  # second
-  max_input_len: 20.0 # second
+  max_input_len: 30.0 # second
   min_output_len: 0.0 # tokens
   max_output_len: 400.0 # tokens
   min_output_input_ratio: 0.05
diff --git a/examples/librispeech/s1/run.sh b/examples/librispeech/s1/run.sh
index 2a8f2e2d1..def10ab05 100755
--- a/examples/librispeech/s1/run.sh
+++ b/examples/librispeech/s1/run.sh
@@ -5,7 +5,7 @@ source path.sh
 stage=0
 stop_stage=100
 conf_path=conf/transformer.yaml
-avg_num=30
+avg_num=5
 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
 
 avg_ckpt=avg_${avg_num}
diff --git a/examples/librispeech/s2/conf/augmentation.json b/examples/librispeech/s2/conf/augmentation.json
index c1078393d..e20fc1997 100644
--- a/examples/librispeech/s2/conf/augmentation.json
+++ b/examples/librispeech/s2/conf/augmentation.json
@@ -1,21 +1,4 @@
 [
-  {
-    "type": "shift",
-    "params": {
-      "min_shift_ms": -5,
-      "max_shift_ms": 5
-    },
-    "prob": 1.0
-  },
-  {
-    "type": "speed",
-    "params": {
-      "min_speed_rate": 0.9,
-      "max_speed_rate": 1.1,
-      "num_rates": 3
-    },
-    "prob": 0.0
-  },
   {
     "type": "specaug",
     "params": {
@@ -27,7 +10,9 @@
       "W": 80,
       "adaptive_number_ratio": 0,
       "adaptive_size_ratio": 0,
-      "max_n_time_masks": 20
+      "max_n_time_masks": 20,
+      "replace_with_zero": true,
+      "warp_mode": "PIL"
     },
     "prob": 1.0
   }
diff --git a/examples/librispeech/s2/conf/transformer.yaml b/examples/librispeech/s2/conf/transformer.yaml
index 8a769dca4..ded4f2408 100644
--- a/examples/librispeech/s2/conf/transformer.yaml
+++ b/examples/librispeech/s2/conf/transformer.yaml
@@ -3,23 +3,17 @@ data:
   train_manifest: data/manifest.train
   dev_manifest: data/manifest.dev
   test_manifest: data/manifest.test-clean
-  min_input_len: 0.5  # second
-  max_input_len: 20.0 # second
-  min_output_len: 0.0 # tokens
-  max_output_len: 400.0 # tokens
-  min_output_input_ratio: 0.05
-  max_output_input_ratio: 10.0
 
 collator:
-  vocab_filepath: data/vocab.txt
+  vocab_filepath: data/train_960_unigram5000_units.txt
   unit_type: 'spm'
-  spm_model_prefix: 'data/bpe_unigram_5000'
+  spm_model_prefix: 'data/train_960_unigram5000'
   mean_std_filepath: ""
   augmentation_config: conf/augmentation.json
   batch_size: 64
   raw_wav: True  # use raw_wav or kaldi feature
   specgram_type: fbank #linear, mfcc, fbank
-  feat_dim: 80
+  feat_dim: 83
   delta_delta: False
   dither: 1.0
   target_sample_rate: 16000
@@ -38,7 +32,7 @@ collator:
 
 # network architecture
 model:
-    cmvn_file: "data/mean_std.json"
+    cmvn_file:  
     cmvn_file_type: "json"
     # encoder related
     encoder: transformer
@@ -74,20 +68,20 @@ model:
 training:
   n_epoch: 120
   accum_grad: 2
-  global_grad_clip: 5.0
-  optim: adam
-  optim_conf:
-    lr: 0.004
-    weight_decay: 1e-06
-  scheduler: warmuplr     # pytorch v1.1.0+ required
-  scheduler_conf:
-    warmup_steps: 25000
-    lr_decay: 1.0
   log_interval: 100
   checkpoint:
     kbest_n: 50
     latest_n: 5
 
+optim: adam
+optim_conf:
+  global_grad_clip: 5.0
+  weight_decay: 1.0e-06
+scheduler: warmuplr     # pytorch v1.1.0+ required
+scheduler_conf:
+  lr: 0.004
+  warmup_steps: 25000
+  lr_decay: 1.0
 
 decoding:
   batch_size: 64
diff --git a/examples/librispeech/s2/local/align.sh b/examples/librispeech/s2/local/align.sh
index ad6c84bc8..b3d8fa5f5 100755
--- a/examples/librispeech/s2/local/align.sh
+++ b/examples/librispeech/s2/local/align.sh
@@ -1,7 +1,7 @@
 #!/bin/bash
 
-if [ $# != 2 ];then
-    echo "usage: ${0} config_path ckpt_path_prefix"
+if [ $# != 3 ];then
+    echo "usage: ${0} config_path dict_path ckpt_path_prefix"
     exit -1
 fi
 
@@ -13,7 +13,8 @@ if [ ${ngpu} == 0 ];then
     device=cpu
 fi
 config_path=$1
-ckpt_prefix=$2
+dict_path=$2
+ckpt_prefix=$3
 
 batch_size=1
 output_dir=${ckpt_prefix}
@@ -21,11 +22,14 @@ mkdir -p ${output_dir}
 
 # align dump in `result_file`
 # .tier, .TextGrid dump in `dir of result_file`
-python3 -u ${BIN_DIR}/alignment.py \
+python3 -u ${BIN_DIR}/test.py \
+--model-name 'u2_kaldi' \
+--run-mode 'align' \
+--dict-path ${dict_path} \
 --device ${device} \
 --nproc 1 \
 --config ${config_path} \
---result_file ${output_dir}/${type}.align \
+--result-file ${output_dir}/${type}.align \
 --checkpoint_path ${ckpt_prefix} \
 --opts decoding.batch_size ${batch_size}
 
diff --git a/examples/librispeech/s2/local/espnet_json_to_manifest.py b/examples/librispeech/s2/local/espnet_json_to_manifest.py
new file mode 100755
index 000000000..acfa46681
--- /dev/null
+++ b/examples/librispeech/s2/local/espnet_json_to_manifest.py
@@ -0,0 +1,36 @@
+#!/usr/bin/env python
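+# Convert an espnet data.json into a manifest file with one json line per utterance.
+# Example (illustrative paths):
+#   python3 local/espnet_json_to_manifest.py --json-file dump/dev/deltafalse/data.json --manifest-file data/manifest.dev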
+import argparse
+import json
+
+
+def main(args):
+    with open(args.json_file, 'r') as fin:
+        data_json = json.load(fin)
+
+    # manifest format:
+    # {"input": [
+    #       {"feat": "dev/deltafalse/feats.1.ark:842920", "name": "input1", "shape": [349, 83]}
+    #  ], 
+    #  "output": [
+    #       {"name": "target1", "shape": [12, 5002], "text": "NO APOLLO", "token": "▁NO ▁A PO LL O", "tokenid": "3144 482 352 269 317"}
+    #  ], 
+    #  "utt2spk": "116-288045", 
+    #  "utt": "116-288045-0019"}
+    with open(args.manifest_file, 'w') as fout:
+        for key, value in data_json['utts'].items():
+            value['utt'] = key
+            fout.write(json.dumps(value, ensure_ascii=False))
+            fout.write("\n")
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument(
+        '--json-file', type=str, default=None, help="espnet data json file.")
+    parser.add_argument(
+        '--manifest-file',
+        type=str,
+        default='manifest.train',
+        help='manifest data json line file.')
+    args = parser.parse_args()
+    main(args)
diff --git a/examples/librispeech/s2/local/export.sh b/examples/librispeech/s2/local/export.sh
index f99a15bad..efa70a2b9 100755
--- a/examples/librispeech/s2/local/export.sh
+++ b/examples/librispeech/s2/local/export.sh
@@ -17,7 +17,9 @@ if [ ${ngpu} == 0 ];then
     device=cpu
 fi
 
-python3 -u ${BIN_DIR}/export.py \
+python3 -u ${BIN_DIR}/test.py \
+--model-name 'u2_kaldi' \
+--run-mode 'export' \
 --device ${device} \
 --nproc ${ngpu} \
 --config ${config_path} \
diff --git a/examples/librispeech/s2/local/test.sh b/examples/librispeech/s2/local/test.sh
index 3bd3f0bba..efd06f35e 100755
--- a/examples/librispeech/s2/local/test.sh
+++ b/examples/librispeech/s2/local/test.sh
@@ -1,7 +1,7 @@
 #!/bin/bash
 
-if [ $# != 2 ];then
-    echo "usage: ${0} config_path ckpt_path_prefix"
+if [ $# != 3 ];then
+    echo "usage: ${0} config_path dict_path ckpt_path_prefix"
     exit -1
 fi
 
@@ -14,7 +14,8 @@ if [ ${ngpu} == 0 ];then
 fi
 
 config_path=$1
-ckpt_prefix=$2
+dict_path=$2
+ckpt_prefix=$3
 
 chunk_mode=false
 if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
@@ -38,10 +39,13 @@ for type in attention ctc_greedy_search; do
         batch_size=64
     fi
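+    # u2_kaldi uses the single test.py entrypoint; --run-mode selects test/align/export (cf. local/align.sh, local/export.sh)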
     python3 -u ${BIN_DIR}/test.py \
+    --model-name u2_kaldi \
+    --run-mode test \
+    --dict-path ${dict_path} \
     --device ${device} \
     --nproc 1 \
     --config ${config_path} \
-    --result_file ${ckpt_prefix}.${type}.rsl \
+    --result-file ${ckpt_prefix}.${type}.rsl \
     --checkpoint_path ${ckpt_prefix} \
     --opts decoding.decoding_method ${type} decoding.batch_size ${batch_size}
 
@@ -55,10 +59,13 @@ for type in ctc_prefix_beam_search attention_rescoring; do
     echo "decoding ${type}"
     batch_size=1
     python3 -u ${BIN_DIR}/test.py \
+    --model-name u2_kaldi \
+    --run-mode test \
+    --dict-path ${dict_path} \
     --device ${device} \
     --nproc 1 \
     --config ${config_path} \
-    --result_file ${ckpt_prefix}.${type}.rsl \
+    --result-file ${ckpt_prefix}.${type}.rsl \
     --checkpoint_path ${ckpt_prefix} \
     --opts decoding.decoding_method ${type} decoding.batch_size ${batch_size}
 
diff --git a/examples/librispeech/s2/local/train.sh b/examples/librispeech/s2/local/train.sh
index ec17054ab..c75252594 100755
--- a/examples/librispeech/s2/local/train.sh
+++ b/examples/librispeech/s2/local/train.sh
@@ -25,6 +25,7 @@ if [ ${seed} ]; then
 fi
 
 python3 -u ${BIN_DIR}/train.py \
+--model-name u2_kaldi \
 --device ${device} \
 --nproc ${ngpu} \
 --config ${config_path} \
diff --git a/examples/librispeech/s2/path.sh b/examples/librispeech/s2/path.sh
index 457f7e548..c90e27821 100644
--- a/examples/librispeech/s2/path.sh
+++ b/examples/librispeech/s2/path.sh
@@ -10,5 +10,5 @@ export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
 export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
 
 
-MODEL=u2
+MODEL=u2_kaldi
 export BIN_DIR=${MAIN_ROOT}/deepspeech/exps/${MODEL}/bin
diff --git a/examples/librispeech/s2/run.sh b/examples/librispeech/s2/run.sh
index 2a8f2e2d1..26398dd14 100755
--- a/examples/librispeech/s2/run.sh
+++ b/examples/librispeech/s2/run.sh
@@ -5,7 +5,8 @@ source path.sh
 stage=0
 stop_stage=100
 conf_path=conf/transformer.yaml
-avg_num=30
+dict_path=data/train_960_unigram5000_units.txt
+avg_num=5
 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
 
 avg_ckpt=avg_${avg_num}
@@ -29,12 +30,12 @@ fi
 
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
     # test ckpt avg_n
-    CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
+    CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} ${dict_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
 fi
 
 if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
     # ctc alignment of test data
-    CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
+    CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} ${dict_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
 fi
 
 if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
diff --git a/examples/punctuation_restoration/README.md b/examples/punctuation_restoration/README.md
new file mode 100644
index 000000000..f2ca76996
--- /dev/null
+++ b/examples/punctuation_restoration/README.md
@@ -0,0 +1,3 @@
+# Punctuation Restoration
+
+Please use `https://github.com/745165806/PaddleSpeechTask` for this task.
diff --git a/examples/ted_en_zh/t0/.gitignore b/examples/ted_en_zh/t0/.gitignore
new file mode 100644
index 000000000..469c61715
--- /dev/null
+++ b/examples/ted_en_zh/t0/.gitignore
@@ -0,0 +1,3 @@
+TED-En-Zh
+data
+exp
diff --git a/examples/ted_en_zh/t0/README.md b/examples/ted_en_zh/t0/README.md
new file mode 100644
index 000000000..e2443d363
--- /dev/null
+++ b/examples/ted_en_zh/t0/README.md
@@ -0,0 +1,10 @@
+
+# TED En-Zh
+
+## Dataset
+
+| Data Subset | Duration in Seconds |
+| --- | --- |
+| data/manifest.train | 0.942 ~ 60   |
+| data/manifest.dev   | 1.151 ~ 39   |  
+| data/manifest.test  | 1.1 ~ 42.746 |
diff --git a/examples/ted_en_zh/t0/local/data.sh b/examples/ted_en_zh/t0/local/data.sh
index 0a5c58aa5..32cfd9d7a 100755
--- a/examples/ted_en_zh/t0/local/data.sh
+++ b/examples/ted_en_zh/t0/local/data.sh
@@ -7,37 +7,37 @@ stop_stage=100
 nbpe=8000
 bpemode=unigram
 bpeprefix="data/bpe_${bpemode}_${nbpe}"
-DATA_DIR= 
+data_dir=/mnt/dataset/TED_EnZh
 
 
 source ${MAIN_ROOT}/utils/parse_options.sh
 
-
-mkdir -p data
 TARGET_DIR=${MAIN_ROOT}/examples/dataset
 mkdir -p ${TARGET_DIR}
+mkdir -p data
 
-if [ ! -d ${SOURCE_DIR} ]; then
-    echo "Error: Dataset is not avaiable. Please download and unzip the dataset"
-    echo "Download Link: https://pan.baidu.com/s/18L-59wgeS96WkObISrytQQ Passwd: bva0"
-    echo "The tree of the directory should be:"
-    echo "."
-    echo "|-- En-Zh"
-    echo "|-- test-segment"
-    echo "    |-- tst2010"
-    echo "    |-- ..."
-    echo "|-- train-split"
-    echo "    |-- train-segment"
-    echo "|-- README.md"
-
-    exit 1
-fi
 
 if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+    if [ ! -e ${data_dir} ]; then
+        echo "Error: Dataset is not avaiable. Please download and unzip the dataset"
+        echo "Download Link: https://pan.baidu.com/s/18L-59wgeS96WkObISrytQQ Passwd: bva0"
+        echo "The tree of the directory should be:"
+        echo "."
+        echo "|-- En-Zh"
+        echo "|-- test-segment"
+        echo "    |-- tst2010"
+        echo "    |-- ..."
+        echo "|-- train-split"
+        echo "    |-- train-segment"
+        echo "|-- README.md"
+
+        exit 1
+    fi
+
     # generate manifests
     python3 ${TARGET_DIR}/ted_en_zh/ted_en_zh.py \
     --manifest_prefix="data/manifest" \
-    --src_dir="${DATA_DIR}" 
+    --src_dir="${data_dir}"
 
     echo "Complete raw data pre-process."
 fi
diff --git a/examples/ted_en_zh/t0/run.sh b/examples/ted_en_zh/t0/run.sh
index 89048f3dd..26fadb608 100755
--- a/examples/ted_en_zh/t0/run.sh
+++ b/examples/ted_en_zh/t0/run.sh
@@ -16,7 +16,7 @@ echo "checkpoint name ${ckpt}"
 
 if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     # prepare data
-    bash ./local/data.sh --DATA_DIR ${data_path} || exit -1
+    bash ./local/data.sh --data_dir ${data_path} || exit -1
 fi
 
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
diff --git a/examples/thchs30/a0/local/data.sh b/examples/thchs30/a0/local/data.sh
index 169367acc..8614a0415 100644
--- a/examples/thchs30/a0/local/data.sh
+++ b/examples/thchs30/a0/local/data.sh
@@ -20,27 +20,33 @@ if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
         echo "Prepare THCHS-30 failed. Terminated."
         exit 1
     fi
-    
 fi
 
-# dump manifest to data/
-python3 ${MAIN_ROOT}/utils/dump_manifest.py --manifest-path=data/manifest.train --output-dir=data
-
-# copy files to data/dict to gen word.lexicon
-cp  ${TARGET_DIR}/thchs30/data_thchs30/lm_word/lexicon.txt data/dict/lm_word_lexicon_1
-cp  ${TARGET_DIR}/thchs30/resource/dict/lexicon.txt data/dict/lm_word_lexicon_2
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    # dump manifest to data/
+    python3 ${MAIN_ROOT}/utils/dump_manifest.py --manifest-path=data/manifest.train --output-dir=data
+fi
 
-# copy phone.lexicon to data/dict
-cp  ${TARGET_DIR}/thchs30/data_thchs30/lm_phone/lexicon.txt data/dict/phone.lexicon
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # copy files to data/dict to gen word.lexicon
+    cp  ${TARGET_DIR}/thchs30/data_thchs30/lm_word/lexicon.txt data/dict/lm_word_lexicon_1
+    cp  ${TARGET_DIR}/thchs30/resource/dict/lexicon.txt data/dict/lm_word_lexicon_2
+    # copy phone.lexicon to data/dict
+    cp  ${TARGET_DIR}/thchs30/data_thchs30/lm_phone/lexicon.txt data/dict/phone.lexicon
+fi
 
-# gen word.lexicon
-python local/gen_word2phone.py  --root-dir=data/dict --output-dir=data/dict
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    # gen word.lexicon
+    python local/gen_word2phone.py  --lexicon-files="data/dict/lm_word_lexicon_1 data/dict/lm_word_lexicon_2" --output-path=data/dict/word.lexicon
+fi
 
-# reorganize dataset for MFA
-if [ ! -d $EXP_DIR/thchs30_corpus ]; then
-    echo "reorganizing thchs30 corpus..."
-    python local/reorganize_thchs30.py --root-dir=data --output-dir=data/thchs30_corpus --script-type=$LEXICON_NAME
-    echo "reorganization done."
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    # reorganize dataset for MFA
+    if [ ! -d $EXP_DIR/thchs30_corpus ]; then
+        echo "reorganizing thchs30 corpus..."
+        python local/reorganize_thchs30.py --root-dir=data --output-dir=data/thchs30_corpus --script-type=$LEXICON_NAME
+        echo "reorganization done."
+    fi
 fi
 
 echo "THCHS-30  data preparation done."
diff --git a/examples/thchs30/a0/local/gen_word2phone.py b/examples/thchs30/a0/local/gen_word2phone.py
index cd584fcdc..9bc0249bf 100644
--- a/examples/thchs30/a0/local/gen_word2phone.py
+++ b/examples/thchs30/a0/local/gen_word2phone.py
@@ -18,6 +18,7 @@ file2: THCHS-30/resource/dict/lexicon.txt
 import argparse
 from collections import defaultdict
 from pathlib import Path
+from typing import List
 from typing import Union
 
 # key: (cn, ('ee', 'er4')),value: count
@@ -34,7 +35,7 @@ def is_Chinese(ch):
     return False
 
 
-def proc_line(line):
+def proc_line(line: str):
     line = line.strip()
     if is_Chinese(line[0]):
         line_list = line.split()
@@ -49,20 +50,25 @@ def proc_line(line):
                 cn_phones_counter[(cn, phones)] += 1
 
 
-def gen_lexicon(root_dir: Union[str, Path], output_dir: Union[str, Path]):
-    root_dir = Path(root_dir).expanduser()
-    output_dir = Path(output_dir).expanduser()
-    output_dir.mkdir(parents=True, exist_ok=True)
-    file1 = root_dir / "lm_word_lexicon_1"
-    file2 = root_dir / "lm_word_lexicon_2"
-    write_file = output_dir / "word.lexicon"
+"""
+example lines of output
+the first column is a Chinese character
+the second is the probability of this pronunciation
+and the rest are the phones of this pronunciation
+一 0.22 ii i1
+一 0.45 ii i4
+一 0.32 ii i2
+一 0.01 ii i5
+"""
+
+
+def gen_lexicon(lexicon_files: List[Union[str, Path]],
+                output_path: Union[str, Path]):
+    for file_path in lexicon_files:
+        with open(file_path, "r") as f1:
+            for line in f1:
+                proc_line(line)
 
-    with open(file1, "r") as f1:
-        for line in f1:
-            proc_line(line)
-    with open(file2, "r") as f2:
-        for line in f2:
-            proc_line(line)
     for key in cn_phones_counter:
         cn = key[0]
         cn_counter[cn].append((key[1], cn_phones_counter[key]))
@@ -75,7 +81,8 @@ def gen_lexicon(root_dir: Union[str, Path], output_dir: Union[str, Path]):
             p = round(p, 2)
             if p > 0:
                 cn_counter_p[key].append((item[0], p))
-    with open(write_file, "w") as wf:
+
+    with open(output_path, "w") as wf:
         for key in cn_counter_p:
             phone_p_list = cn_counter_p[key]
             for item in phone_p_list:
@@ -87,8 +94,21 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Gen Chinese characters to phone lexicon for THCHS-30 dataset"
     )
+    # A line of word_lexicon:
+    # 一丁点 ii i4 d ing1 d ian3
+    # the first column is the word, and the rest are its phones; the number of phones is twice the length of the word
+    parser.add_argument(
+        "--lexicon-files",
+        type=str,
+        default="data/dict/lm_word_lexicon_1 data/dict/lm_word_lexicon_2",
+        help="lm_word_lexicon files")
     parser.add_argument(
-        "--root-dir", type=str, help="dir to thchs30 lm_word_lexicons")
-    parser.add_argument("--output-dir", type=str, help="path to save outputs")
+        "--output-path",
+        type=str,
+        default="data/dict/word.lexicon",
+        help="path to save output word2phone lexicon")
     args = parser.parse_args()
-    gen_lexicon(args.root_dir, args.output_dir)
+    lexicon_files = args.lexicon_files.split(" ")
+    output_path = Path(args.output_path).expanduser()
+
+    gen_lexicon(lexicon_files, output_path)
diff --git a/examples/thchs30/a0/local/reorganize_thchs30.py b/examples/thchs30/a0/local/reorganize_thchs30.py
index 9df6bc6a9..c7c6248bc 100644
--- a/examples/thchs30/a0/local/reorganize_thchs30.py
+++ b/examples/thchs30/a0/local/reorganize_thchs30.py
@@ -58,8 +58,6 @@ def write_lab(root_dir: Union[str, Path],
 def reorganize_thchs30(root_dir: Union[str, Path],
                        output_dir: Union[str, Path]=None,
                        script_type='phone'):
-    root_dir = Path(root_dir).expanduser()
-    output_dir = Path(output_dir).expanduser()
     output_dir.mkdir(parents=True, exist_ok=True)
     link_wav(root_dir, output_dir)
     write_lab(root_dir, output_dir, script_type)
@@ -72,12 +70,15 @@ if __name__ == "__main__":
     parser.add_argument(
         "--output-dir",
         type=str,
-        help="path to save outputs(audio and transcriptions)")
+        help="path to save outputs (audio and transcriptions)")
 
     parser.add_argument(
         "--script-type",
         type=str,
         default="phone",
         help="type of lab ('word'/'syllable'/'phone')")
+
     args = parser.parse_args()
-    reorganize_thchs30(args.root_dir, args.output_dir, args.script_type)
+    root_dir = Path(args.root_dir).expanduser()
+    output_dir = Path(args.output_dir).expanduser()
+    reorganize_thchs30(root_dir, output_dir, args.script_type)
diff --git a/examples/thchs30/a0/run.sh b/examples/thchs30/a0/run.sh
index 53f96b378..5081b612a 100755
--- a/examples/thchs30/a0/run.sh
+++ b/examples/thchs30/a0/run.sh
@@ -14,14 +14,17 @@ source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
 # gen lexicon relink gen dump
 if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     # prepare data
-    bash ./local/data.sh $LEXICON_NAME|| exit -1
+    echo "Start prepare thchs30 data for MFA ..."
+    bash ./local/data.sh $LEXICON_NAME || exit -1
 fi
 
-# run MFA
-if [ ! -d "$EXP_DIR/thchs30_alignment" ]; then
-    echo "Start MFA training..."
-    mfa_train_and_align data/thchs30_corpus data/dict/$LEXICON_NAME.lexicon $EXP_DIR/thchs30_alignment -o $EXP_DIR/thchs30_model --clean --verbose --temp_directory exp/.mfa_train_and_align --num_jobs $NUM_JOBS
-    echo "training done! \nresults: $EXP_DIR/thchs30_alignment \nmodel: $EXP_DIR/thchs30_model\n"
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # run MFA
+    if [ ! -d "$EXP_DIR/thchs30_alignment" ]; then
+        echo "Start MFA training ..."
+        mfa_train_and_align data/thchs30_corpus data/dict/$LEXICON_NAME.lexicon $EXP_DIR/thchs30_alignment -o $EXP_DIR/thchs30_model --clean --verbose --temp_directory exp/.mfa_train_and_align --num_jobs $NUM_JOBS
+        echo "MFA training done! \nresults: $EXP_DIR/thchs30_alignment \nmodel: $EXP_DIR/thchs30_model\n"
+    fi
 fi
 
 
diff --git a/examples/timit/s1/conf/augmentation.json b/examples/timit/s1/conf/augmentation.json
index c1078393d..8e6e97040 100644
--- a/examples/timit/s1/conf/augmentation.json
+++ b/examples/timit/s1/conf/augmentation.json
@@ -27,7 +27,9 @@
       "W": 80,
       "adaptive_number_ratio": 0,
       "adaptive_size_ratio": 0,
-      "max_n_time_masks": 20
+      "max_n_time_masks": 20,
+      "replace_with_zero": true,
+      "warp_mode": "PIL"
     },
     "prob": 1.0
   }
diff --git a/examples/tiny/s0/conf/augmentation.json b/examples/tiny/s0/conf/augmentation.json
index a1a759e67..4480307b9 100644
--- a/examples/tiny/s0/conf/augmentation.json
+++ b/examples/tiny/s0/conf/augmentation.json
@@ -1,4 +1,13 @@
 [
+  {
+    "type": "speed",
+    "params": {
+      "min_speed_rate": 0.9,
+      "max_speed_rate": 1.1,
+      "num_rates": 3
+    },
+    "prob": 0.0
+  },
   {
     "type": "shift",
     "params": {
@@ -6,5 +15,22 @@
       "max_shift_ms": 5
     },
     "prob": 1.0
+  },
+  {
+    "type": "specaug",
+    "params": {
+      "W": 5,
+      "warp_mode": "PIL",
+      "F": 30,
+      "n_freq_masks": 2,
+      "T": 40,
+      "n_time_masks": 2,
+      "p": 1.0,
+      "adaptive_number_ratio": 0,
+      "adaptive_size_ratio": 0,
+      "max_n_time_masks": 20,
+      "replace_with_zero": true
+    },
+    "prob": 1.0
   }
 ]
diff --git a/examples/tiny/s1/conf/augmentation.json b/examples/tiny/s1/conf/augmentation.json
index f26c282e7..6010c2e47 100644
--- a/examples/tiny/s1/conf/augmentation.json
+++ b/examples/tiny/s1/conf/augmentation.json
@@ -27,7 +27,9 @@
       "W": 80,
       "adaptive_number_ratio": 0,
       "adaptive_size_ratio": 0,
-      "max_n_time_masks": 20
+      "max_n_time_masks": 20,
+      "replace_with_zero": true,
+      "warp_mode": "PIL"
     },
     "prob": 1.0
   }
diff --git a/requirements.txt b/requirements.txt
index baaa9ba9b..08f2f258c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,7 @@
 coverage
 gpustat
+kaldiio
+Pillow
 pre-commit
 pybind11
 resampy==0.2.2
diff --git a/speechnn/.gitignore b/speechnn/.gitignore
new file mode 100644
index 000000000..378eac25d
--- /dev/null
+++ b/speechnn/.gitignore
@@ -0,0 +1 @@
+build
diff --git a/speechnn/CMakeLists.txt b/speechnn/CMakeLists.txt
index 878374bab..88182eb4c 100644
--- a/speechnn/CMakeLists.txt
+++ b/speechnn/CMakeLists.txt
@@ -1,77 +1,56 @@
 cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
 
-project(deepspeech VERSION 0.1)
+project(speechnn VERSION 0.1)
 
-set(CMAKE_VERBOSE_MAKEFILE on)
-# set std-14
-set(CMAKE_CXX_STANDARD 14)
+if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
+  set(CMAKE_INSTALL_PREFIX ${CMAKE_CURRENT_SOURCE_DIR}/src CACHE PATH "Install path prefix." FORCE)
+endif(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
+set(CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake;${CMAKE_MODULE_PATH}")
 
 # include file 
-include(FetchContent)
-include(ExternalProject)
-# fc_patch dir
-set(FETCHCONTENT_QUIET off)
-get_filename_component(fc_patch "fc_patch" REALPATH BASE_DIR "${CMAKE_SOURCE_DIR}")
-set(FETCHCONTENT_BASE_DIR ${fc_patch})
-
-
-###############################################################################
-# Option Configurations
-###############################################################################
-# option configurations 
-option(TEST_DEBUG "option for debug" OFF)
-
-
-###############################################################################
-# Include third party
-###############################################################################
-# #example for include third party
-# FetchContent_Declare()
-# # FetchContent_MakeAvailable was not added until CMake 3.14
-# FetchContent_MakeAvailable()
-# include_directories()
-
-# ABSEIL-CPP
-include(FetchContent)
-FetchContent_Declare(
-  absl
-  GIT_REPOSITORY "https://github.com/abseil/abseil-cpp.git"
-  GIT_TAG "20210324.1"
-)
-FetchContent_MakeAvailable(absl)
+include(cmake/third_party.cmake)
 
-# libsndfile
-include(FetchContent)
-FetchContent_Declare(
-  libsndfile
-  GIT_REPOSITORY "https://github.com/libsndfile/libsndfile.git"
-  GIT_TAG "1.0.31"
-)
-FetchContent_MakeAvailable(libsndfile)
-
-
-###############################################################################
-# Add local library
-###############################################################################
-# system lib 
-find_package()
-# if dir have CmakeLists.txt 
-add_subdirectory()
-# if dir do not have CmakeLists.txt 
-add_library(lib_name STATIC file.cc)
-target_link_libraries(lib_name item0 item1)
-add_dependencies(lib_name depend-target)
-
-
-###############################################################################
-# Library installation
-###############################################################################
-install()
 
+set(CMAKE_VERBOSE_MAKEFILE on)
+# set std-14
+set(CMAKE_CXX_STANDARD 14)
 
-###############################################################################
-# Build binary file
-###############################################################################
-add_executable()
-target_link_libraries()
 
+# # fc_patch dir
+# set(FETCHCONTENT_QUIET off)
+# get_filename_component(fc_patch "fc_patch" REALPATH BASE_DIR "${CMAKE_SOURCE_DIR}")
+# set(FETCHCONTENT_BASE_DIR ${fc_patch})
+# 
+# 
+# ###############################################################################
+# # Option Configurations
+# ###############################################################################
+# # option configurations 
+# option(TEST_DEBUG "option for debug" OFF)
+# 
+# 
+# ###############################################################################
+# # Add local library
+# ###############################################################################
+# # system lib 
+# find_package()
+# # if dir have CmakeLists.txt 
+# add_subdirectory()
+# # if dir do not have CmakeLists.txt 
+# add_library(lib_name STATIC file.cc)
+# target_link_libraries(lib_name item0 item1)
+# add_dependencies(lib_name depend-target)
+# 
+# 
+# ###############################################################################
+# # Library installation
+# ###############################################################################
+# install()
+# 
+# 
+# ###############################################################################
+# # Build binary file
+# ###############################################################################
+# add_executable()
+# target_link_libraries()
+# 
diff --git a/speechnn/cmake/third_party.cmake b/speechnn/cmake/third_party.cmake
new file mode 100644
index 000000000..fdd7b53c2
--- /dev/null
+++ b/speechnn/cmake/third_party.cmake
@@ -0,0 +1,197 @@
+include(ExternalProject)
+# Create a target named "third_party", which can compile external dependencies on all platforms (windows/linux/mac)
+
+set(THIRD_PARTY_PATH  "${CMAKE_BINARY_DIR}/third_party" CACHE STRING
+    "A path setting third party libraries download & build directories.")
+set(THIRD_PARTY_CACHE_PATH     "${CMAKE_SOURCE_DIR}"    CACHE STRING
+    "A path cache third party source code to avoid repeated download.")
+
+set(THIRD_PARTY_BUILD_TYPE Release)
+set(third_party_deps)
+
+
+# cache function to avoid repeatedly downloading third_party source code.
+# This function has 4 parameters, URL / REPOSITORY / TAG / DIR:
+# 1. URL:           specify download url of 3rd party
+# 2. REPOSITORY:    specify git REPOSITORY of 3rd party
+# 3. TAG:           specify git tag/branch/commitID of 3rd party
+# 4. DIR:           the variable whose value overwrites the original SOURCE_DIR when caching is enabled
+#
+# The function returns 1 PARENT_SCOPE variable:
+#  - ${TARGET}_DOWNLOAD_CMD: Simply place "${TARGET}_DOWNLOAD_CMD" in ExternalProject_Add,
+#                            and you no longer need to set any download steps in ExternalProject_Add.
+# For example:
+#    Cache_third_party(${TARGET}
+#            REPOSITORY ${TARGET_REPOSITORY}
+#            TAG        ${TARGET_TAG}
+#            DIR        ${TARGET_SOURCE_DIR})
+
+FUNCTION(cache_third_party TARGET)
+    SET(options "")
+    SET(oneValueArgs URL REPOSITORY TAG DIR)
+    SET(multiValueArgs "")
+    cmake_parse_arguments(cache_third_party "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
+
+    STRING(REPLACE "extern_" "" TARGET_NAME ${TARGET})
+    STRING(REGEX REPLACE "[0-9]+" "" TARGET_NAME ${TARGET_NAME})
+    STRING(TOUPPER ${TARGET_NAME} TARGET_NAME)
+    IF(cache_third_party_REPOSITORY)
+        SET(${TARGET_NAME}_DOWNLOAD_CMD
+                GIT_REPOSITORY  ${cache_third_party_REPOSITORY})
+        IF(cache_third_party_TAG)
+            LIST(APPEND   ${TARGET_NAME}_DOWNLOAD_CMD
+                    GIT_TAG     ${cache_third_party_TAG})
+        ENDIF()
+    ELSEIF(cache_third_party_URL)
+        SET(${TARGET_NAME}_DOWNLOAD_CMD
+                URL             ${cache_third_party_URL})
+    ELSE()
+        MESSAGE(FATAL_ERROR    "Download link (Git repo or URL) must be specified for cache!")
+    ENDIF()
+    IF(WITH_TP_CACHE)
+        IF(NOT cache_third_party_DIR)
+            MESSAGE(FATAL_ERROR   "Please input the ${TARGET_NAME}_SOURCE_DIR for overwriting when -DWITH_TP_CACHE=ON")
+        ENDIF()
+        # Generate and verify cache dir for third_party source code
+        SET(cache_third_party_REPOSITORY ${cache_third_party_REPOSITORY} ${cache_third_party_URL})
+        IF(cache_third_party_REPOSITORY AND cache_third_party_TAG)
+            STRING(MD5 HASH_REPO ${cache_third_party_REPOSITORY})
+            STRING(MD5 HASH_GIT ${cache_third_party_TAG})
+            STRING(SUBSTRING ${HASH_REPO} 0 8 HASH_REPO)
+            STRING(SUBSTRING ${HASH_GIT} 0 8 HASH_GIT)
+            STRING(CONCAT HASH ${HASH_REPO} ${HASH_GIT})
+            # overwrite the original SOURCE_DIR when caching is enabled
+            SET(${cache_third_party_DIR} ${THIRD_PARTY_CACHE_PATH}/third_party/${TARGET}_${HASH})
+        ELSEIF(cache_third_party_REPOSITORY)
+            STRING(MD5 HASH_REPO ${cache_third_party_REPOSITORY})
+            STRING(SUBSTRING ${HASH_REPO} 0 16 HASH)
+            # overwrite the original SOURCE_DIR when caching is enabled
+            SET(${cache_third_party_DIR} ${THIRD_PARTY_CACHE_PATH}/third_party/${TARGET}_${HASH})
+        ENDIF()
+
+        IF(EXISTS ${${cache_third_party_DIR}})
+            # judge whether the cache dir is empty
+            FILE(GLOB files ${${cache_third_party_DIR}}/*)
+            LIST(LENGTH files files_len)
+            IF(files_len GREATER 0)
+                list(APPEND ${TARGET_NAME}_DOWNLOAD_CMD DOWNLOAD_COMMAND "")
+            ENDIF()
+        ENDIF()
+        SET(${cache_third_party_DIR} ${${cache_third_party_DIR}} PARENT_SCOPE)
+    ENDIF()
+
+    # Pass ${TARGET_NAME}_DOWNLOAD_CMD to parent scope, the double quotation marks can't be removed
+    SET(${TARGET_NAME}_DOWNLOAD_CMD "${${TARGET_NAME}_DOWNLOAD_CMD}" PARENT_SCOPE)
+ENDFUNCTION()
+
+MACRO(UNSET_VAR VAR_NAME)
+    UNSET(${VAR_NAME} CACHE)
+    UNSET(${VAR_NAME})
+ENDMACRO()
+
+# Function to download the dependencies during compilation
+# This function has 2 parameters, URL / NAME:
+# 1. URL:           The download url of the 3rd-party dependency
+# 2. NAME:          The name of the file, which determines the dirname
+#
+FUNCTION(file_download_and_uncompress URL NAME)
+  set(options "")
+  set(oneValueArgs MD5)
+  set(multiValueArgs "")
+  cmake_parse_arguments(URL "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
+  MESSAGE(STATUS "Download dependence[${NAME}] from ${URL}, MD5: ${URL_MD5}")
+  SET(${NAME}_INCLUDE_DIR ${THIRD_PARTY_PATH}/${NAME}/data PARENT_SCOPE)
+  ExternalProject_Add(
+      download_${NAME}
+      ${EXTERNAL_PROJECT_LOG_ARGS}
+      PREFIX                ${THIRD_PARTY_PATH}/${NAME}
+      URL                   ${URL}
+      URL_MD5               ${URL_MD5}
+      TIMEOUT               120
+      DOWNLOAD_DIR          ${THIRD_PARTY_PATH}/${NAME}/data/
+      SOURCE_DIR            ${THIRD_PARTY_PATH}/${NAME}/data/
+      DOWNLOAD_NO_PROGRESS  1
+      CONFIGURE_COMMAND     ""
+      BUILD_COMMAND         ""
+      UPDATE_COMMAND        ""
+      INSTALL_COMMAND       ""
+    )
+  set(third_party_deps ${third_party_deps} download_${NAME} PARENT_SCOPE)
+ENDFUNCTION()
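+# Example usage (hypothetical URL and MD5):
+#   file_download_and_uncompress(https://example.com/dep.tar.gz some_dep MD5 0123456789abcdef0123456789abcdef)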
+
+
+# Correct flags on different platforms (WIN/MAC) and print warning messages
+if (APPLE)
+    if(WITH_MKL)
+        MESSAGE(WARNING
+            "Mac is not supported with MKL in Paddle yet. Force WITH_MKL=OFF.")
+        set(WITH_MKL OFF CACHE STRING "Disable MKL for building on mac" FORCE)
+    endif()
+endif()
+
+if(WIN32 OR APPLE)
+    MESSAGE(STATUS "Disable XBYAK in Windows and MacOS")
+    SET(WITH_XBYAK OFF CACHE STRING "Disable XBYAK in Windows and MacOS" FORCE)
+
+    if(WITH_LIBXSMM)
+        MESSAGE(WARNING
+            "Windows, Mac are not supported with libxsmm in Paddle yet."
+            "Force WITH_LIBXSMM=OFF")
+        SET(WITH_LIBXSMM OFF CACHE STRING "Disable LIBXSMM in Windows and MacOS" FORCE)
+    endif()
+
+    if(WITH_BOX_PS)
+        MESSAGE(WARNING
+            "Windows or Mac is not supported with BOX_PS in Paddle yet."
+            "Force WITH_BOX_PS=OFF")
+        SET(WITH_BOX_PS OFF CACHE STRING "Disable BOX_PS package in Windows and MacOS" FORCE)
+    endif()
+
+    if(WITH_PSLIB)
+        MESSAGE(WARNING
+            "Windows or Mac is not supported with PSLIB in Paddle yet."
+            "Force WITH_PSLIB=OFF")
+        SET(WITH_PSLIB OFF CACHE STRING "Disable PSLIB package in Windows and MacOS" FORCE)
+    endif()
+
+    if(WITH_LIBMCT)
+        MESSAGE(WARNING
+            "Windows or Mac is not supported with LIBMCT in Paddle yet."
+            "Force WITH_LIBMCT=OFF")
+        SET(WITH_LIBMCT OFF CACHE STRING "Disable LIBMCT package in Windows and MacOS" FORCE)
+    endif()
+
+    if(WITH_PSLIB_BRPC)
+        MESSAGE(WARNING
+            "Windows or Mac is not supported with PSLIB_BRPC in Paddle yet."
+            "Force WITH_PSLIB_BRPC=OFF")
+        SET(WITH_PSLIB_BRPC OFF CACHE STRING "Disable PSLIB_BRPC package in Windows and MacOS" FORCE)
+    endif()
+endif()
+
+set(WITH_MKLML ${WITH_MKL})
+if(NOT DEFINED WITH_MKLDNN)
+    if(WITH_MKL AND AVX2_FOUND)
+        set(WITH_MKLDNN ON)
+    else()
+        message(STATUS "Do not have AVX2 intrinsics and disabled MKL-DNN")
+        set(WITH_MKLDNN OFF)
+    endif()
+endif()
+
+if(WIN32 OR APPLE OR NOT WITH_GPU OR ON_INFER)
+    set(WITH_DGC OFF)
+endif()
+
+if(${CMAKE_VERSION} VERSION_GREATER "3.5.2")
+    set(SHALLOW_CLONE "GIT_SHALLOW TRUE") # adds --depth=1 arg to git clone of External_Projects
+endif()
+
+
+########################### include third_party according to flags ###############################
+include(third_party/libsndfile)      # download, build, install libsndfile 
+include(third_party/boost)     # download boost
+include(third_party/eigen)     # download eigen3
+include(third_party/threadpool)     # download threadpool 
+
+
diff --git a/speechnn/cmake/third_party/absl.cmake b/speechnn/cmake/third_party/absl.cmake
new file mode 100644
index 000000000..c2a8eceb5
--- /dev/null
+++ b/speechnn/cmake/third_party/absl.cmake
@@ -0,0 +1,13 @@
+cmake_minimum_required(VERSION 3.14)
+include(ExternalProject)
+include(FetchContent)
+
+FetchContent_Declare(
+  absl
+  GIT_REPOSITORY "https://github.com/abseil/abseil-cpp.git"
+  GIT_TAG "20210324.1"
+)
+
+FetchContent_MakeAvailable(absl)
+
+
diff --git a/speechnn/cmake/third_party/boost.cmake b/speechnn/cmake/third_party/boost.cmake
new file mode 100644
index 000000000..eb0b2c150
--- /dev/null
+++ b/speechnn/cmake/third_party/boost.cmake
@@ -0,0 +1,49 @@
+include(ExternalProject)
+
+set(BOOST_PROJECT       "extern_boost")
+# To release PaddlePaddle as a pip package, we have to follow the
+# manylinux1 standard, which features as old Linux kernels and
+# compilers as possible and recommends CentOS 5. Indeed, the earliest
+# CentOS version that works with NVIDIA CUDA is CentOS 6.  And a new
+# version of boost, say, 1.66.0, doesn't build on CentOS 6.  We
+# checked that the devtools package of CentOS 6 installs boost 1.41.0.
+# So we use 1.41.0 here.
+set(BOOST_VER   "1.41.0")
+set(BOOST_TAR   "boost_1_41_0" CACHE STRING "" FORCE)
+set(BOOST_URL   "http://paddlepaddledeps.bj.bcebos.com/${BOOST_TAR}.tar.gz" CACHE STRING "" FORCE)
+
+MESSAGE(STATUS "BOOST_VERSION: ${BOOST_VER}, BOOST_URL: ${BOOST_URL}")
+
+set(BOOST_PREFIX_DIR ${THIRD_PARTY_PATH}/boost)
+set(BOOST_SOURCE_DIR ${THIRD_PARTY_PATH}/boost/src/extern_boost)
+cache_third_party(${BOOST_PROJECT}
+        URL       ${BOOST_URL}
+        DIR       BOOST_SOURCE_DIR)
+
+set(BOOST_INCLUDE_DIR "${BOOST_SOURCE_DIR}" CACHE PATH "boost include directory." FORCE)
+set_directory_properties(PROPERTIES CLEAN_NO_CUSTOM 1)
+include_directories(${BOOST_INCLUDE_DIR})
+
+if(WIN32 AND MSVC_VERSION GREATER_EQUAL 1600)
+    add_definitions(-DBOOST_HAS_STATIC_ASSERT)
+endif()
+
+ExternalProject_Add(
+    ${BOOST_PROJECT}
+    ${EXTERNAL_PROJECT_LOG_ARGS}
+    "${BOOST_DOWNLOAD_CMD}"
+    URL_MD5               f891e8c2c9424f0565f0129ad9ab4aff
+    PREFIX                ${BOOST_PREFIX_DIR}
+    DOWNLOAD_DIR          ${BOOST_SOURCE_DIR}
+    SOURCE_DIR            ${BOOST_SOURCE_DIR}
+    DOWNLOAD_NO_PROGRESS  1
+    CONFIGURE_COMMAND     ""
+    BUILD_COMMAND         ""
+    INSTALL_COMMAND       ""
+    UPDATE_COMMAND        ""
+    )
+
+add_library(boost INTERFACE)
+
+add_dependencies(boost ${BOOST_PROJECT})
+set(Boost_INCLUDE_DIR ${BOOST_INCLUDE_DIR})
diff --git a/speechnn/cmake/third_party/eigen.cmake b/speechnn/cmake/third_party/eigen.cmake
new file mode 100644
index 000000000..6a0323071
--- /dev/null
+++ b/speechnn/cmake/third_party/eigen.cmake
@@ -0,0 +1,53 @@
+include(ExternalProject)
+
+# update eigen to the commit id f612df27 on 03/16/2021
+set(EIGEN_PREFIX_DIR ${THIRD_PARTY_PATH}/eigen3)
+set(EIGEN_SOURCE_DIR ${THIRD_PARTY_PATH}/eigen3/src/extern_eigen3)
+set(EIGEN_REPOSITORY https://gitlab.com/libeigen/eigen.git)
+set(EIGEN_TAG        f612df273689a19d25b45ca4f8269463207c4fee)
+
+cache_third_party(extern_eigen3
+    REPOSITORY    ${EIGEN_REPOSITORY}
+    TAG           ${EIGEN_TAG}
+    DIR           EIGEN_SOURCE_DIR)
+
+if(WIN32)
+    add_definitions(-DEIGEN_STRONG_INLINE=inline)
+elseif(LINUX)
+    if(WITH_ROCM)
+        # For HIPCC Eigen::internal::device::numeric_limits is not EIGEN_DEVICE_FUNC
+        # which will cause a compiler error from using a __host__ function in __host__ __device__
+        file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/eigen/Meta.h native_src)
+        file(TO_NATIVE_PATH ${EIGEN_SOURCE_DIR}/Eigen/src/Core/util/Meta.h native_dst)
+        file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/eigen/TensorReductionGpu.h native_src1)
+        file(TO_NATIVE_PATH ${EIGEN_SOURCE_DIR}/unsupported/Eigen/CXX11/src/Tensor/TensorReductionGpu.h native_dst1)
+        set(EIGEN_PATCH_COMMAND cp ${native_src} ${native_dst} && cp ${native_src1} ${native_dst1})
+    endif()
+endif()
+
+set(EIGEN_INCLUDE_DIR ${EIGEN_SOURCE_DIR})
+INCLUDE_DIRECTORIES(${EIGEN_INCLUDE_DIR})
+
+ExternalProject_Add(
+    extern_eigen3
+    ${EXTERNAL_PROJECT_LOG_ARGS}
+    ${SHALLOW_CLONE}
+    "${EIGEN_DOWNLOAD_CMD}"
+    PREFIX          ${EIGEN_PREFIX_DIR}
+    SOURCE_DIR      ${EIGEN_SOURCE_DIR}
+    UPDATE_COMMAND    ""
+    PATCH_COMMAND     ${EIGEN_PATCH_COMMAND}
+    CONFIGURE_COMMAND ""
+    BUILD_COMMAND     ""
+    INSTALL_COMMAND   ""
+    TEST_COMMAND      ""
+)
+
+add_library(eigen3 INTERFACE)
+
+add_dependencies(eigen3 extern_eigen3)
+
+# sw does not support thread_local semantics
+if(WITH_SW)
+  add_definitions(-DEIGEN_AVOID_THREAD_LOCAL)
+endif()
diff --git a/speechnn/cmake/third_party/libsndfile.cmake b/speechnn/cmake/third_party/libsndfile.cmake
new file mode 100644
index 000000000..05d5c6ed4
--- /dev/null
+++ b/speechnn/cmake/third_party/libsndfile.cmake
@@ -0,0 +1,11 @@
+cmake_minimum_required(VERSION 3.14)
+include(ExternalProject)
+include(FetchContent)
+
+FetchContent_Declare(
+        libsndfile 
+        GIT_REPOSITORY  https://github.com/libsndfile/libsndfile.git
+        GIT_TAG         v1.0.30 # tag v1.0.30
+)
+
+FetchContent_GetProperties(libsndfile)
diff --git a/speechnn/cmake/third_party/openfst.cmake b/speechnn/cmake/third_party/openfst.cmake
new file mode 100644
index 000000000..39f335a1c
--- /dev/null
+++ b/speechnn/cmake/third_party/openfst.cmake
@@ -0,0 +1,26 @@
+cmake_minimum_required(VERSION 3.14)
+include(ExternalProject)
+include(FetchContent)
+
+FetchContent_Declare(
+        openfst
+        GIT_REPOSITORY  https://github.com/kkm000/openfst
+        GIT_TAG         338225416178ac36b8002d70387f5556e44c8d05 # tag win/1.7.2.1
+)
+
+FetchContent_GetProperties(openfst)
+if(NOT openfst_POPULATED)
+    FetchContent_Populate(openfst)
+    include_directories(${openfst_SOURCE_DIR}/src/include)
+
+    add_subdirectory(${openfst_SOURCE_DIR} ${openfst_BINARY_DIR})
+
+    install(DIRECTORY ${openfst_SOURCE_DIR}/src/include/ DESTINATION include/
+            FILES_MATCHING PATTERN "*.h")
+
+    install(TARGETS fst
+            EXPORT kaldi-targets
+            ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
+            LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
+            RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR})
+endif()
diff --git a/speechnn/cmake/third_party/openfst_lib_target.cmake b/speechnn/cmake/third_party/openfst_lib_target.cmake
new file mode 100644
index 000000000..dde5efc40
--- /dev/null
+++ b/speechnn/cmake/third_party/openfst_lib_target.cmake
@@ -0,0 +1,31 @@
+if(NOT OPENFST_ROOT_DIR)
+    message(FATAL_ERROR "OPENFST_ROOT_DIR is not set")
+endif()
+
+set(fst_source_dir ${OPENFST_ROOT_DIR}/src/lib)
+set(fst_include_dir ${OPENFST_ROOT_DIR}/src/include)
+
+include_directories(${fst_include_dir})
+file(GLOB fst_sources "${fst_source_dir}/*.cc")
+
+add_library(fst ${fst_sources})
+target_include_directories(fst PUBLIC
+     $<BUILD_INTERFACE:${fst_include_dir}>
+     $<INSTALL_INTERFACE:include/openfst>
+)
+
+install(TARGETS fst
+    EXPORT kaldi-targets
+    ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
+    LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
+    RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
+)
+
+install(DIRECTORY ${fst_include_dir}/fst
+    DESTINATION include/openfst
+    PATTERN "test/*.h" EXCLUDE
+)
+
+unset(fst_source_dir)
+unset(fst_include_dir)
+unset(fst_sources)
diff --git a/speechnn/cmake/third_party/threadpool.cmake b/speechnn/cmake/third_party/threadpool.cmake
new file mode 100644
index 000000000..d2c249e9b
--- /dev/null
+++ b/speechnn/cmake/third_party/threadpool.cmake
@@ -0,0 +1,36 @@
+INCLUDE(ExternalProject)
+
+SET(THREADPOOL_PREFIX_DIR ${THIRD_PARTY_PATH}/threadpool)
+SET(THREADPOOL_SOURCE_DIR ${THIRD_PARTY_PATH}/threadpool/src/extern_threadpool)
+if(WITH_ASCEND OR WITH_ASCEND_CL)
+    SET(THREADPOOL_REPOSITORY https://gitee.com/tianjianhe/ThreadPool.git)
+else()
+    SET(THREADPOOL_REPOSITORY ${GIT_URL}/progschj/ThreadPool.git)
+endif()
+SET(THREADPOOL_TAG        9a42ec1329f259a5f4881a291db1dcb8f2ad9040)
+
+cache_third_party(extern_threadpool
+    REPOSITORY   ${THREADPOOL_REPOSITORY}
+    TAG          ${THREADPOOL_TAG}
+    DIR          THREADPOOL_SOURCE_DIR)
+
+SET(THREADPOOL_INCLUDE_DIR ${THREADPOOL_SOURCE_DIR})
+INCLUDE_DIRECTORIES(${THREADPOOL_INCLUDE_DIR})
+
+ExternalProject_Add(
+    extern_threadpool
+    ${EXTERNAL_PROJECT_LOG_ARGS}
+    ${SHALLOW_CLONE}
+    "${THREADPOOL_DOWNLOAD_CMD}"
+    PREFIX          ${THREADPOOL_PREFIX_DIR}
+    SOURCE_DIR      ${THREADPOOL_SOURCE_DIR}
+    UPDATE_COMMAND  ""
+    CONFIGURE_COMMAND ""
+    BUILD_COMMAND     ""
+    INSTALL_COMMAND   ""
+    TEST_COMMAND      ""
+)
+
+add_library(simple_threadpool INTERFACE)
+
+add_dependencies(simple_threadpool extern_threadpool)
diff --git a/speechnn/cmake/third_party/version.cmake b/speechnn/cmake/third_party/version.cmake
new file mode 100644
index 000000000..c3780ee69
--- /dev/null
+++ b/speechnn/cmake/third_party/version.cmake
@@ -0,0 +1,15 @@
+function(get_version)
+    file(READ ${CMAKE_CURRENT_SOURCE_DIR}/src/.version version)
+    string(STRIP ${version} version)
+    execute_process(COMMAND git log -n1 --format=%H src/.version
+                    WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
+                    OUTPUT_VARIABLE version_commit
+                    OUTPUT_STRIP_TRAILING_WHITESPACE)
+    execute_process(COMMAND git rev-list --count "${version_commit}..HEAD"
+                    WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
+                    OUTPUT_VARIABLE patch_number)
+    string(STRIP ${patch_number} patch_number)
+
+    set(KALDI_VERSION ${version} PARENT_SCOPE)
+    set(KALDI_PATCH_NUMBER ${patch_number} PARENT_SCOPE)
+endfunction()
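+
+# Hypothetical caller sketch: invoke the function, then use the variables it
+# exports to the parent scope, e.g.
+#   get_version()
+#   message(STATUS "version: ${KALDI_VERSION}, patch: ${KALDI_PATCH_NUMBER}")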
diff --git a/speechnn/core/CMakeLists.txt b/speechnn/core/transformers/.gitkeep
similarity index 100%
rename from speechnn/core/CMakeLists.txt
rename to speechnn/core/transformers/.gitkeep
diff --git a/speechnn/core/transformers/README.md b/speechnn/core/transformers/README.md
new file mode 100644
index 000000000..edbcb9cc3
--- /dev/null
+++ b/speechnn/core/transformers/README.md
@@ -0,0 +1,9 @@
+# Fast Transformers for Speech
+
+- Conformer
+- Transformer
+
+## Reference
+
+* https://github.com/NVIDIA/FasterTransformer.git
+* https://github.com/idiap/fast-transformers
diff --git a/speechnn/core/frontend/CMakeLists.txt b/speechnn/examples/.gitkeep
similarity index 100%
rename from speechnn/core/frontend/CMakeLists.txt
rename to speechnn/examples/.gitkeep
diff --git a/speechnn/core/frontend/audio/CMakeLists.txt b/speechnn/examples/CMakeLists.txt
similarity index 100%
rename from speechnn/core/frontend/audio/CMakeLists.txt
rename to speechnn/examples/CMakeLists.txt
diff --git a/speechnn/core/frontend/text/CMakeLists.txt b/speechnn/speechnn/CMakeLists.txt
similarity index 100%
rename from speechnn/core/frontend/text/CMakeLists.txt
rename to speechnn/speechnn/CMakeLists.txt
diff --git a/speechnn/core/decoder/CMakeLists.txt b/speechnn/speechnn/decoder/CMakeLists.txt
similarity index 100%
rename from speechnn/core/decoder/CMakeLists.txt
rename to speechnn/speechnn/decoder/CMakeLists.txt
diff --git a/speechnn/core/model/CMakeLists.txt b/speechnn/speechnn/frontend/CMakeLists.txt
similarity index 100%
rename from speechnn/core/model/CMakeLists.txt
rename to speechnn/speechnn/frontend/CMakeLists.txt
diff --git a/speechnn/core/protocol/CMakeLists.txt b/speechnn/speechnn/frontend/audio/CMakeLists.txt
similarity index 100%
rename from speechnn/core/protocol/CMakeLists.txt
rename to speechnn/speechnn/frontend/audio/CMakeLists.txt
diff --git a/speechnn/core/utils/CMakeLists.txt b/speechnn/speechnn/frontend/text/CMakeLists.txt
similarity index 100%
rename from speechnn/core/utils/CMakeLists.txt
rename to speechnn/speechnn/frontend/text/CMakeLists.txt
diff --git a/speechnn/third_party/CMakeLists.txt b/speechnn/speechnn/model/CMakeLists.txt
similarity index 100%
rename from speechnn/third_party/CMakeLists.txt
rename to speechnn/speechnn/model/CMakeLists.txt
diff --git a/speechnn/speechnn/nn/CMakeLists.txt b/speechnn/speechnn/nn/CMakeLists.txt
new file mode 100644
index 000000000..e69de29bb
diff --git a/speechnn/speechnn/protocol/CMakeLists.txt b/speechnn/speechnn/protocol/CMakeLists.txt
new file mode 100644
index 000000000..e69de29bb
diff --git a/speechnn/speechnn/utils/CMakeLists.txt b/speechnn/speechnn/utils/CMakeLists.txt
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/chains/ds2_params_lite_train_infer.txt b/tests/chains/ds2_params_lite_train_infer.txt
new file mode 100644
index 000000000..82a9da9a9
--- /dev/null
+++ b/tests/chains/ds2_params_lite_train_infer.txt
@@ -0,0 +1,51 @@
+===========================train_params===========================
+model_name:deepspeech2
+python:python3.8
+gpu_list:0
+null:null
+null:null
+null:null
+--output:null
+null:null
+--checkpoint_path:
+train_model_name:checkpoints/9
+null:null
+null:null
+##
+trainer:norm_train
+norm_train: ../../../deepspeech/exps/deepspeech2/bin/train.py --nproc 1 --config conf/deepspeech2.yaml --model_type offline --device gpu
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval: ../../../deepspeech/exps/deepspeech2/bin/test.py --nproc 1 --config conf/deepspeech2.yaml --result_file tests/9.rsl  --model_type offline --device gpu
+null:null
+##
+===========================infer_params===========================
+--export_path:checkpoints/9.jit
+--checkpoint_path:checkpoints/9
+norm_export: ../../../deepspeech/exps/deepspeech2/bin/export.py --nproc 1 --config conf/deepspeech2.yaml --model_type offline
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+##
+infer_model:null
+infer_export:null
+infer_quant:null
+inference:null
+--use_gpu:null
+--enable_mkldnn:null
+--cpu_threads:null
+--rec_batch_num:null
+--use_tensorrt:null
+--precision:null
+--det_model_dir:null
+--image_dir:null
+--save_log_path:null
+--benchmark:null
+null:null
diff --git a/tests/chains/ds2_params_whole_train_infer.txt b/tests/chains/ds2_params_whole_train_infer.txt
new file mode 100644
index 000000000..e97051c41
--- /dev/null
+++ b/tests/chains/ds2_params_whole_train_infer.txt
@@ -0,0 +1,51 @@
+===========================train_params===========================
+model_name:deepspeech2
+python:python3.8
+gpu_list:0
+null:null
+null:null
+null:null
+--output:null
+null:null
+--checkpoint_path:
+train_model_name:checkpoints/1
+null:null
+null:null
+##
+trainer:norm_train
+norm_train: ../../../deepspeech/exps/deepspeech2/bin/train.py --nproc 1 --config conf/deepspeech2.yaml --model_type offline --device gpu
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval: ../../../deepspeech/exps/deepspeech2/bin/test.py --nproc 1 --config conf/deepspeech2.yaml --result_file tests/1.rsl  --model_type offline --device gpu
+null:null
+##
+===========================infer_params===========================
+--export_path:checkpoints/1.jit
+--checkpoint_path:checkpoints/1
+norm_export: ../../../deepspeech/exps/deepspeech2/bin/export.py --nproc 1 --config conf/deepspeech2.yaml --model_type offline
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+##
+infer_model:null
+infer_export:null
+infer_quant:null
+inference:null
+--use_gpu:null
+--enable_mkldnn:null
+--cpu_threads:null
+--rec_batch_num:null
+--use_tensorrt:null
+--precision:null
+--det_model_dir:null
+--image_dir:null
+--save_log_path:null
+--benchmark:null
+null:null
diff --git a/tests/chains/lite_train_infer.sh b/tests/chains/lite_train_infer.sh
new file mode 100644
index 000000000..76b22a38c
--- /dev/null
+++ b/tests/chains/lite_train_infer.sh
@@ -0,0 +1,5 @@
+bash prepare.sh ds2_params_lite_train_infer.txt lite_train_infer
+cd ../../examples/tiny/s0
+source path.sh
+bash ../../../tests/chains/test.sh ../../../tests/chains/ds2_params_lite_train_infer.txt lite_train_infer
+cd ../../../tests/chains
diff --git a/tests/chains/prepare.sh b/tests/chains/prepare.sh
new file mode 100644
index 000000000..73a302836
--- /dev/null
+++ b/tests/chains/prepare.sh
@@ -0,0 +1,84 @@
+#!/bin/bash
+FILENAME=$1
+# MODE must be one of ['lite_train_infer', 'whole_infer', 'whole_train_infer', 'infer']
+MODE=$2
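+# usage example (as invoked by lite_train_infer.sh):
+#   bash prepare.sh ds2_params_lite_train_infer.txt lite_train_infer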
+
+dataline=$(cat ${FILENAME})
+
+# parse params
+IFS=$'\n'
+lines=(${dataline})
+function func_parser_key(){
+    strs=$1
+    IFS=":"
+    array=(${strs})
+    tmp=${array[0]}
+    echo ${tmp}
+}
+function func_parser_value(){
+    strs=$1
+    IFS=":"
+    array=(${strs})
+    tmp=${array[1]}
+    echo ${tmp}
+}
+IFS=$'\n'
+# The training params
+model_name=$(func_parser_value "${lines[1]}")
+
+trainer_list=$(func_parser_value "${lines[14]}")
+
+# MODE must be one of ['lite_train_infer', 'whole_infer', 'whole_train_infer']
+MODE=$2
+
+if [ ${MODE} = "lite_train_infer" ];then
+    # pretrain lite train data
+    curPath=$(readlink -f "$(dirname "$0")")
+    cd ${curPath}/../../examples/tiny/s0
+    source path.sh
+    # download audio data
+    bash ./local/data.sh || exit -1
+    # download language model
+    bash local/download_lm_en.sh
+    if [ $? -ne 0 ]; then
+        exit 1
+    fi
+    cd ${curPath}
+
+elif [ ${MODE} = "whole_train_infer" ];then
+    curPath=$(readlink -f "$(dirname "$0")")
+    cd ${curPath}/../../examples/aishell/s0
+    source path.sh
+    # download audio data
+    bash ./local/data.sh || exit -1
+    # download language model
+    bash local/download_lm_ch.sh
+    if [ $? -ne 0 ]; then
+        exit 1
+    fi
+    cd ${curPath}
+elif [ ${MODE} = "whole_infer" ];then
+    curPath=$(readlink -f "$(dirname "$0")")
+    cd ${curPath}/../../examples/aishell/s0
+    source path.sh
+    # download audio data
+    bash ./local/data.sh || exit -1
+    # download language model
+    bash local/download_lm_ch.sh
+    if [ $? -ne 0 ]; then
+        exit 1
+    fi
+    cd ${curPath}
+else
+    curPath=$(readlink -f "$(dirname "$0")")
+    cd ${curPath}/../../examples/aishell/s0
+    source path.sh
+    # download audio data
+    bash ./local/data.sh || exit -1
+    # download language model
+    bash local/download_lm_ch.sh
+    if [ $? -ne 0 ]; then
+        exit 1
+    fi
+    cd ${curPath}
+fi
diff --git a/tests/chains/test.sh b/tests/chains/test.sh
new file mode 100644
index 000000000..6a48ba765
--- /dev/null
+++ b/tests/chains/test.sh
@@ -0,0 +1,365 @@
+#!/bin/bash
+FILENAME=$1
+# MODE must be one of ['lite_train_infer', 'whole_infer', 'whole_train_infer', 'infer']
+MODE=$2
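+# usage example (see lite_train_infer.sh / whole_train_infer.sh):
+#   bash test.sh <params_file> <mode>, e.g.
+#   bash test.sh ds2_params_lite_train_infer.txt lite_train_infer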
+
+dataline=$(cat ${FILENAME})
+
+# parse params
+IFS=$'\n'
+lines=(${dataline})
+
+function func_parser_key(){
+    strs=$1
+    IFS=":"
+    array=(${strs})
+    tmp=${array[0]}
+    echo ${tmp}
+}
+function func_parser_value(){
+    strs=$1
+    IFS=":"
+    array=(${strs})
+    tmp=${array[1]}
+    echo ${tmp}
+}
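+# example: given a config line "model_name:deepspeech2",
+#   func_parser_key   returns "model_name"
+#   func_parser_value returns "deepspeech2"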
+function func_set_params(){
+    key=$1
+    value=$2
+    if [ ${key} = "null" ];then
+        echo " "
+    elif [[ ${value} = "null" ]] || [[ ${value} = " " ]] || [ ${#value} -le 0 ];then
+        echo " "
+    else
+        echo "${key}=${value}"
+    fi
+}
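+# example: func_set_params "--export_path" "checkpoints/9.jit" prints "--export_path=checkpoints/9.jit";
+# a "null" key or value prints a single space, so the flag is dropped from the command line.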
+function func_parser_params(){
+    strs=$1
+    IFS=":"
+    array=(${strs})
+    key=${array[0]}
+    tmp=${array[1]}
+    IFS="|"
+    res=""
+    for _params in ${tmp[*]}; do
+        IFS="="
+        array=(${_params})
+        mode=${array[0]}
+        value=${array[1]}
+        if [[ ${mode} = ${MODE} ]]; then
+            IFS="|"
+            #echo $(func_set_params "${mode}" "${value}")
+            echo $value
+            break
+        fi
+        IFS="|"
+    done
+    echo ${res}
+}
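+# hypothetical example: with MODE=lite_train_infer, a config line such as
+# "epoch_num:lite_train_infer=2|whole_train_infer=50" yields "2"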
+function status_check(){
+    last_status=$1   # the exit code
+    run_command=$2
+    run_log=$3
+    if [ $last_status -eq 0 ]; then
+        echo -e "\033[33m Run successfully with command - ${run_command}!  \033[0m" | tee -a ${run_log}
+    else
+        echo -e "\033[33m Run failed with command - ${run_command}!  \033[0m" | tee -a ${run_log}
+    fi
+}
+
+IFS=$'\n'
+# The training params
+model_name=$(func_parser_value "${lines[1]}")
+python=$(func_parser_value "${lines[2]}")
+gpu_list=$(func_parser_value "${lines[3]}")
+train_use_gpu_key=$(func_parser_key "${lines[4]}")
+train_use_gpu_value=$(func_parser_value "${lines[4]}")
+autocast_list=$(func_parser_value "${lines[5]}")
+autocast_key=$(func_parser_key "${lines[5]}")
+epoch_key=$(func_parser_key "${lines[6]}")
+epoch_num=$(func_parser_params "${lines[6]}")
+save_model_key=$(func_parser_key "${lines[7]}")
+train_batch_key=$(func_parser_key "${lines[8]}")
+train_batch_value=$(func_parser_params "${lines[8]}")
+pretrain_model_key=$(func_parser_key "${lines[9]}")
+pretrain_model_value=$(func_parser_value "${lines[9]}")
+train_model_name=$(func_parser_value "${lines[10]}")
+train_infer_img_dir=$(func_parser_value "${lines[11]}")
+train_param_key1=$(func_parser_key "${lines[12]}")
+train_param_value1=$(func_parser_value "${lines[12]}")
+
+trainer_list=$(func_parser_value "${lines[14]}")
+trainer_norm=$(func_parser_key "${lines[15]}")
+norm_trainer=$(func_parser_value "${lines[15]}")
+pact_key=$(func_parser_key "${lines[16]}")
+pact_trainer=$(func_parser_value "${lines[16]}")
+fpgm_key=$(func_parser_key "${lines[17]}")
+fpgm_trainer=$(func_parser_value "${lines[17]}")
+distill_key=$(func_parser_key "${lines[18]}")
+distill_trainer=$(func_parser_value "${lines[18]}")
+trainer_key1=$(func_parser_key "${lines[19]}")
+trainer_value1=$(func_parser_value "${lines[19]}")
+trainer_key2=$(func_parser_key "${lines[20]}")
+trainer_value2=$(func_parser_value "${lines[20]}")
+
+eval_py=$(func_parser_value "${lines[23]}")
+eval_key1=$(func_parser_key "${lines[24]}")
+eval_value1=$(func_parser_value "${lines[24]}")
+
+save_infer_key=$(func_parser_key "${lines[27]}")
+export_weight=$(func_parser_key "${lines[28]}")
+norm_export=$(func_parser_value "${lines[29]}")
+pact_export=$(func_parser_value "${lines[30]}")
+fpgm_export=$(func_parser_value "${lines[31]}")
+distill_export=$(func_parser_value "${lines[32]}")
+export_key1=$(func_parser_key "${lines[33]}")
+export_value1=$(func_parser_value "${lines[33]}")
+export_key2=$(func_parser_key "${lines[34]}")
+export_value2=$(func_parser_value "${lines[34]}")
+
+# parse inference model params
+infer_model_dir_list=$(func_parser_value "${lines[36]}")
+infer_export_list=$(func_parser_value "${lines[37]}")
+infer_is_quant=$(func_parser_value "${lines[38]}")
+# parse inference params
+inference_py=$(func_parser_value "${lines[39]}")
+use_gpu_key=$(func_parser_key "${lines[40]}")
+use_gpu_list=$(func_parser_value "${lines[40]}")
+use_mkldnn_key=$(func_parser_key "${lines[41]}")
+use_mkldnn_list=$(func_parser_value "${lines[41]}")
+cpu_threads_key=$(func_parser_key "${lines[42]}")
+cpu_threads_list=$(func_parser_value "${lines[42]}")
+batch_size_key=$(func_parser_key "${lines[43]}")
+batch_size_list=$(func_parser_value "${lines[43]}")
+use_trt_key=$(func_parser_key "${lines[44]}")
+use_trt_list=$(func_parser_value "${lines[44]}")
+precision_key=$(func_parser_key "${lines[45]}")
+precision_list=$(func_parser_value "${lines[45]}")
+infer_model_key=$(func_parser_key "${lines[46]}")
+image_dir_key=$(func_parser_key "${lines[47]}")
+infer_img_dir=$(func_parser_value "${lines[47]}")
+save_log_key=$(func_parser_key "${lines[48]}")
+benchmark_key=$(func_parser_key "${lines[49]}")
+benchmark_value=$(func_parser_value "${lines[49]}")
+infer_key1=$(func_parser_key "${lines[50]}")
+infer_value1=$(func_parser_value "${lines[50]}")
+
+LOG_PATH="./tests/output"
+mkdir -p ${LOG_PATH}
+status_log="${LOG_PATH}/results.log"
+
+
+function func_inference(){
+    IFS='|'
+    _python=$1
+    _script=$2
+    _model_dir=$3
+    _log_path=$4
+    _img_dir=$5
+    _flag_quant=$6
+    # inference
+    for use_gpu in ${use_gpu_list[*]}; do
+        if [ ${use_gpu} = "False" ] || [ ${use_gpu} = "cpu" ]; then
+            for use_mkldnn in ${use_mkldnn_list[*]}; do
+                if [ ${use_mkldnn} = "False" ] && [ ${_flag_quant} = "True" ]; then
+                    continue
+                fi
+                for threads in ${cpu_threads_list[*]}; do
+                    for batch_size in ${batch_size_list[*]}; do
+                        _save_log_path="${_log_path}/infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_batchsize_${batch_size}.log"
+                        set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}")
+                        set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
+                        set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
+                        set_cpu_threads=$(func_set_params "${cpu_threads_key}" "${threads}")
+                        set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}")
+                        set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}")
+                        command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} ${set_infer_params1} > ${_save_log_path} 2>&1 "
+                        eval $command
+                        last_status=${PIPESTATUS[0]}
+                        eval "cat ${_save_log_path}"
+                        status_check $last_status "${command}" "${status_log}"
+                    done
+                done
+            done
+        elif [ ${use_gpu} = "True" ] || [ ${use_gpu} = "gpu" ]; then
+            for use_trt in ${use_trt_list[*]}; do
+                for precision in ${precision_list[*]}; do
+                    if [[ ${_flag_quant} = "False" ]] && [[ ${precision} =~ "int8" ]]; then
+                        continue
+                    fi
+                    if [[ ${precision} =~ "fp16" || ${precision} =~ "int8" ]] && [ ${use_trt} = "False" ]; then
+                        continue
+                    fi
+                    if [[ ${use_trt} = "False" || ${precision} =~ "int8" ]] && [ ${_flag_quant} = "True" ]; then
+                        continue
+                    fi
+                    for batch_size in ${batch_size_list[*]}; do
+                        _save_log_path="${_log_path}/infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}.log"
+                        set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}")
+                        set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
+                        set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
+                        set_tensorrt=$(func_set_params "${use_trt_key}" "${use_trt}")
+                        set_precision=$(func_set_params "${precision_key}" "${precision}")
+                        set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}")
+                        set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}")
+                        command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${set_tensorrt} ${set_precision} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} ${set_infer_params1} > ${_save_log_path} 2>&1 "
+                        eval $command
+                        last_status=${PIPESTATUS[0]}
+                        eval "cat ${_save_log_path}"
+                        status_check $last_status "${command}" "${status_log}"
+
+                    done
+                done
+            done
+        else
+            echo "Does not support hardware other than CPU and GPU Currently!"
+        fi
+    done
+}
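+# example invocation (as used in the infer branch below):
+#   func_inference "${python}" "${inference_py}" "${save_infer_dir}" "${LOG_PATH}" "${infer_img_dir}" ${is_quant}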
+
+if [ ${MODE} = "infer" ]; then
+    GPUID=$3
+    if [ ${#GPUID} -le 0 ];then
+        env=" "
+    else
+        env="export CUDA_VISIBLE_DEVICES=${GPUID}"
+    fi
+    # set CUDA_VISIBLE_DEVICES
+    eval $env
+    export Count=0
+    IFS="|"
+    infer_run_exports=(${infer_export_list})
+    infer_quant_flag=(${infer_is_quant})
+    for infer_model in ${infer_model_dir_list[*]}; do
+        # run export
+        if [ ${infer_run_exports[Count]} != "null" ];then
+            save_infer_dir=$(dirname $infer_model)
+            set_export_weight=$(func_set_params "${export_weight}" "${infer_model}")
+            set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_dir}")
+            export_cmd="${python} ${norm_export} ${set_export_weight} ${set_save_infer_key}"
+            eval $export_cmd
+            status_export=$?
+            if [ ${status_export} = 0 ];then
+                status_check $status_export "${export_cmd}" "${status_log}"
+            fi
+        else
+            save_infer_dir=${infer_model}
+        fi
+        #run inference
+        is_quant=${infer_quant_flag[Count]}
+        func_inference "${python}" "${inference_py}" "${save_infer_dir}" "${LOG_PATH}" "${infer_img_dir}" ${is_quant}
+        Count=$(($Count + 1))
+    done
+
+else
+    IFS="|"
+    export Count=0
+    USE_GPU_KEY=(${train_use_gpu_value})
+    for gpu in ${gpu_list[*]}; do
+        use_gpu=${USE_GPU_KEY[Count]}
+        Count=$(($Count + 1))
+        if [ ${gpu} = "-1" ];then
+            env=""
+        elif [ ${#gpu} -le 1 ];then
+            env="export CUDA_VISIBLE_DEVICES=${gpu}"
+            eval ${env}
+        elif [ ${#gpu} -le 15 ];then
+            IFS=","
+            array=(${gpu})
+            env="export CUDA_VISIBLE_DEVICES=${array[0]}"
+            IFS="|"
+        else
+            IFS=";"
+            array=(${gpu})
+            ips=${array[0]}
+            gpu=${array[1]}
+            IFS="|"
+            env=" "
+        fi
+        for autocast in ${autocast_list[*]}; do
+            for trainer in ${trainer_list[*]}; do
+                flag_quant=False
+                if [ ${trainer} = ${pact_key} ]; then
+                    run_train=${pact_trainer}
+                    run_export=${pact_export}
+                    flag_quant=True
+                elif [ ${trainer} = "${fpgm_key}" ]; then
+                    run_train=${fpgm_trainer}
+                    run_export=${fpgm_export}
+                elif [ ${trainer} = "${distill_key}" ]; then
+                    run_train=${distill_trainer}
+                    run_export=${distill_export}
+                elif [ ${trainer} = ${trainer_key1} ]; then
+                    run_train=${trainer_value1}
+                    run_export=${export_value1}
+                elif [[ ${trainer} = ${trainer_key2} ]]; then
+                    run_train=${trainer_value2}
+                    run_export=${export_value2}
+                else
+                    run_train=${norm_trainer}
+                    run_export=${norm_export}
+                fi
+
+                if [ ${run_train} = "null" ]; then
+                    continue
+                fi
+
+                set_autocast=$(func_set_params "${autocast_key}" "${autocast}")
+                set_epoch=$(func_set_params "${epoch_key}" "${epoch_num}")
+                set_pretrain=$(func_set_params "${pretrain_model_key}" "${pretrain_model_value}")
+                set_batchsize=$(func_set_params "${train_batch_key}" "${train_batch_value}")
+                set_train_params1=$(func_set_params "${train_param_key1}" "${train_param_value1}")
+                set_use_gpu=$(func_set_params "${train_use_gpu_key}" "${use_gpu}")
+                save_log="${LOG_PATH}/${trainer}_gpus_${gpu}_autocast_${autocast}"
+
+                # load pretrain from norm training if current trainer is pact or fpgm trainer
+                if [ ${trainer} = ${pact_key} ] || [ ${trainer} = ${fpgm_key} ]; then
+                    set_pretrain="${load_norm_train_model}"
+                fi
+
+                set_save_model=$(func_set_params "${save_model_key}" "${save_log}")
+                if [ ${#gpu} -le 2 ];then  # train with cpu or single gpu
+                    cmd="${python} ${run_train} ${set_use_gpu}  ${set_save_model} ${set_epoch} ${set_pretrain} ${set_autocast} ${set_batchsize} ${set_train_params1} "
+                elif [ ${#gpu} -le 15 ];then  # train with multi-gpu
+                    cmd="${python} -m paddle.distributed.launch --gpus=${gpu} ${run_train} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_autocast} ${set_batchsize} ${set_train_params1}"
+                else     # train with multi-machine
+                    cmd="${python} -m paddle.distributed.launch --ips=${ips} --gpus=${gpu} ${run_train} ${set_save_model} ${set_pretrain} ${set_epoch} ${set_autocast} ${set_batchsize} ${set_train_params1}"
+                fi
+                # run train
+                #eval "unset CUDA_VISIBLE_DEVICES"
+                eval $cmd
+                status_check $? "${cmd}" "${status_log}"
+
+                set_eval_pretrain=$(func_set_params "${pretrain_model_key}" "${save_log}/${train_model_name}")
+                # save norm trained models to set pretrain for pact training and fpgm training
+                if [ ${trainer} = ${trainer_norm} ]; then
+                    load_norm_train_model=${set_eval_pretrain}
+                fi
+                # run eval
+                if [ ${eval_py} != "null" ]; then
+                    set_eval_params1=$(func_set_params "${eval_key1}" "${eval_value1}")
+                    eval_cmd="${python} ${eval_py} ${set_eval_pretrain} ${set_use_gpu} ${set_eval_params1}"
+                    eval $eval_cmd
+                    status_check $? "${eval_cmd}" "${status_log}"
+                fi
+                # run export model
+                if [ ${run_export} != "null" ]; then
+                    # run export model
+                    save_infer_path="${save_log}"
+                    set_export_weight=$(func_set_params "${export_weight}" "${save_log}/${train_model_name}")
+                    set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_path}")
+                    export_cmd="${python} ${run_export} ${set_export_weight} ${set_save_infer_key}"
+                    eval $export_cmd
+                    status_check $? "${export_cmd}" "${status_log}"
+
+                    #run inference
+                    eval $env
+                    save_infer_path="${save_log}"
+                    func_inference "${python}" "${inference_py}" "${save_infer_path}" "${LOG_PATH}" "${train_infer_img_dir}" "${flag_quant}"
+                    eval "unset CUDA_VISIBLE_DEVICES"
+                fi
+            done  # done with:    for trainer in ${trainer_list[*]}; do
+        done      # done with:    for autocast in ${autocast_list[*]}; do
+    done          # done with:    for gpu in ${gpu_list[*]}; do
+fi  # end if [ ${MODE} = "infer" ]; then
diff --git a/tests/chains/whole_train_infer.sh b/tests/chains/whole_train_infer.sh
new file mode 100644
index 000000000..496041a7b
--- /dev/null
+++ b/tests/chains/whole_train_infer.sh
@@ -0,0 +1,5 @@
+bash prepare.sh ds2_params_whole_train_infer.txt whole_train_infer
+cd ../../examples/aishell/s0
+source path.sh
+bash ../../../tests/chains/test.sh ../../../tests/chains/ds2_params_whole_train_infer.txt whole_train_infer
+cd ../../../tests/chains
diff --git a/tools/extras/install_mfa.sh b/tools/extras/install_mfa.sh
index b0a4cf990..ae126fa62 100755
--- a/tools/extras/install_mfa.sh
+++ b/tools/extras/install_mfa.sh
@@ -4,7 +4,7 @@
 
 test -d Montreal-Forced-Aligner || git clone https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner.git
 
-pushd Montreal-Forced-Aligner && git checkout v2.0.0a7 &&  python setup.py install
+pushd Montreal-Forced-Aligner && python setup.py install && popd
 
 test -d kaldi || { echo "need install kaldi first"; exit 1;}