You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
PaddleSpeech/.notebook/WarmupLR.ipynb

340 lines
38 KiB

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "d6a0e098",
"metadata": {},
"outputs": [],
"source": [
"from typing import Union\n",
"\n",
"import torch\n",
"from torch.optim.lr_scheduler import _LRScheduler\n",
"\n",
"from typeguard import check_argument_types\n",
"\n",
"\n",
"class WarmupLR(_LRScheduler):\n",
"    \"\"\"The WarmupLR scheduler.\n",
"\n",
"    This scheduler is almost the same as the NoamLR scheduler, except for the\n",
"    following difference:\n",
"\n",
"    NoamLR:\n",
"        lr = optimizer.lr * model_size ** -0.5\n",
"             * min(step ** -0.5, step * warmup_step ** -1.5)\n",
"    WarmupLR:\n",
"        lr = optimizer.lr * warmup_step ** 0.5\n",
"             * min(step ** -0.5, step * warmup_step ** -1.5)\n",
"\n",
"    The lr grows linearly for the first ``warmup_steps`` steps, reaches\n",
"    exactly ``optimizer.lr`` at ``step == warmup_steps`` (so the maximum lr\n",
"    equals ``optimizer.lr``), and then decays proportionally to\n",
"    ``step ** -0.5``.\n",
"\n",
"    Edge cases handled in ``get_lr``:\n",
"      * ``step <= 0`` (e.g. calling ``get_lr()`` right after\n",
"        ``set_step(-1)``) gives ``lr = 0.0`` -- the value of the linear\n",
"        warmup term at step 0 -- instead of raising ``ZeroDivisionError``\n",
"        on ``0 ** -0.5``.\n",
"      * ``warmup_steps == 0`` disables warmup\n",
"        (``lr = optimizer.lr * step ** -0.5``) instead of raising\n",
"        ``ZeroDivisionError`` on ``0 ** -1.5``.\n",
"\n",
"    Args:\n",
"        optimizer: wrapped optimizer; the initial lr of each param group is\n",
"            that group's peak lr.\n",
"        warmup_steps: number of linear warmup steps.\n",
"        last_epoch: index of the last completed step; -1 starts fresh.\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        optimizer: torch.optim.Optimizer,\n",
"        warmup_steps: Union[int, float] = 25000,\n",
"        last_epoch: int = -1,\n",
"    ):\n",
"        assert check_argument_types()\n",
"        # warmup_steps must be assigned *before* super().__init__(), because\n",
"        # the base __init__ invokes step(), which calls get_lr().\n",
"        self.warmup_steps = warmup_steps\n",
"        super().__init__(optimizer, last_epoch)\n",
"\n",
"    def __repr__(self):\n",
"        return f\"{self.__class__.__name__}(warmup_steps={self.warmup_steps})\"\n",
"\n",
"    def get_lr(self):\n",
"        \"\"\"Return one learning rate per param group for ``last_epoch``.\"\"\"\n",
"        # The base __init__ already stepped once, so in normal use\n",
"        # last_epoch >= 0 and step_num starts at 1.\n",
"        step_num = self.last_epoch + 1\n",
"        if step_num <= 0:\n",
"            # 0 ** -0.5 raises ZeroDivisionError and a negative base yields a\n",
"            # complex number; use the warmup ramp's value at step 0 instead.\n",
"            return [0.0 for _ in self.base_lrs]\n",
"        if self.warmup_steps == 0:\n",
"            # No warmup: pure inverse-square-root decay (avoids 0 ** -1.5).\n",
"            return [lr * step_num ** -0.5 for lr in self.base_lrs]\n",
"        return [\n",
"            lr\n",
"            * self.warmup_steps ** 0.5\n",
"            * min(step_num ** -0.5, step_num * self.warmup_steps ** -1.5)\n",
"            for lr in self.base_lrs\n",
"        ]\n",
"\n",
"    def set_step(self, step: int):\n",
"        \"\"\"Set the index of the last completed step (e.g. when resuming).\n",
"\n",
"        Only ``last_epoch`` is changed here; the optimizer's lr is updated\n",
"        on the next ``step()`` call.\n",
"        \"\"\"\n",
"        self.last_epoch = step"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "0d496677",
"metadata": {},
"outputs": [],
"source": [
"import torch.optim as optim\n",
"\n",
"# Throwaway model: only its parameters matter, so that the optimizer has\n",
"# something to schedule. Adam's default lr becomes the scheduler's peak lr.\n",
"model = torch.nn.Linear(in_features=10, out_features=200)\n",
"optimizer = optim.Adam(params=model.parameters())\n",
"scheduler = WarmupLR(optimizer=optimizer, warmup_steps=25000)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "e3e3f3dc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 0.0 -1\n"
]
}
],
"source": [
"# Stand-in for checkpoint metadata (presumably loaded when resuming); it is\n",
"# empty here, so every lookup falls back to its default.\n",
"infos = dict()\n",
"step = infos.get('step', -1)\n",
"cv_loss = infos.get('cv_loss', 0.0)\n",
"start_epoch = 1 + infos.get('epoch', -1)\n",
"print(start_epoch, cv_loss, step)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "dc3d550c",
"metadata": {},
"outputs": [],
"source": [
"# Seed the scheduler with the resume step (-1 here, since `infos` is empty).\n",
"scheduler.set_step(step=step)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "e527634e",
"metadata": {},
"outputs": [],
"source": [
"NUM_STEPS = 100000\n",
"lrs = []\n",
"for _ in range(NUM_STEPS):\n",
"    scheduler.step()\n",
"    # get_last_lr() returns the lr that step() just applied (one value per\n",
"    # param group) without recomputing it; PyTorch recommends it over get_lr().\n",
"    lrs.append(scheduler.get_last_lr())"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "f1452db9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting matplotlib\n",
" Downloading matplotlib-3.4.1-cp38-cp38-manylinux1_x86_64.whl (10.3 MB)\n",
"\u001b[K |████████████████████████████████| 10.3 MB 575 kB/s eta 0:00:01\n",
"\u001b[?25hCollecting kiwisolver>=1.0.1\n",
" Downloading kiwisolver-1.3.1-cp38-cp38-manylinux1_x86_64.whl (1.2 MB)\n",
"\u001b[K |████████████████████████████████| 1.2 MB 465 kB/s eta 0:00:01\n",
"\u001b[?25hRequirement already satisfied: pillow>=6.2.0 in /workspace/wenet/venv/lib/python3.8/site-packages (from matplotlib) (8.1.2)\n",
"Requirement already satisfied: numpy>=1.16 in /workspace/wenet/venv/lib/python3.8/site-packages (from matplotlib) (1.20.1)\n",
"Requirement already satisfied: python-dateutil>=2.7 in /workspace/wenet/venv/lib/python3.8/site-packages (from matplotlib) (2.8.1)\n",
"Collecting cycler>=0.10\n",
" Downloading cycler-0.10.0-py2.py3-none-any.whl (6.5 kB)\n",
"Requirement already satisfied: pyparsing>=2.2.1 in /workspace/wenet/venv/lib/python3.8/site-packages (from matplotlib) (2.4.7)\n",
"Requirement already satisfied: six in /workspace/wenet/venv/lib/python3.8/site-packages (from cycler>=0.10->matplotlib) (1.15.0)\n",
"Installing collected packages: kiwisolver, cycler, matplotlib\n",
"Successfully installed cycler-0.10.0 kiwisolver-1.3.1 matplotlib-3.4.1\n"
]
}
],
"source": [
"# %pip (not !pip) installs into the running kernel's environment; the pin\n",
"# matches the version this notebook was last run with.\n",
"%pip install -q matplotlib==3.4.1\n",
"import matplotlib.pyplot as plt\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "0f36d04f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f0c39aa82e0>]"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Derive the x-axis from lrs so it cannot drift from the loop's step count\n",
"# (a length mismatch makes plot() raise ValueError).\n",
"xs = list(range(len(lrs)))\n",
"fig, ax = plt.subplots()\n",
"ax.plot(xs, lrs)\n",
"ax.set(xlabel=\"step\", ylabel=\"learning rate\", title=\"WarmupLR (torch)\");"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "4f4e282c",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/wenet/venv/lib/python3.8/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n"
]
}
],
"source": [
"from typing import Union\n",
"\n",
"from paddle.optimizer.lr import LRScheduler\n",
"from typeguard import check_argument_types\n",
"\n",
"class WarmupLR(LRScheduler):\n",
"    \"\"\"The WarmupLR scheduler (Paddle version).\n",
"\n",
"    This scheduler is almost the same as the NoamLR scheduler, except for the\n",
"    following difference:\n",
"\n",
"    NoamLR:\n",
"        lr = optimizer.lr * model_size ** -0.5\n",
"             * min(step ** -0.5, step * warmup_step ** -1.5)\n",
"    WarmupLR:\n",
"        lr = optimizer.lr * warmup_step ** 0.5\n",
"             * min(step ** -0.5, step * warmup_step ** -1.5)\n",
"\n",
"    The lr grows linearly for the first ``warmup_steps`` steps, reaches\n",
"    exactly ``learning_rate`` at ``step == warmup_steps`` (so the maximum lr\n",
"    equals ``learning_rate``), and then decays proportionally to\n",
"    ``step ** -0.5``.\n",
"\n",
"    Edge cases handled in ``get_lr``:\n",
"      * ``step <= 0`` (e.g. ``set_step(-1)``) gives ``lr = 0.0`` -- the value\n",
"        of the linear warmup term at step 0 -- instead of raising\n",
"        ``ZeroDivisionError`` on ``0 ** -0.5``.\n",
"      * ``warmup_steps == 0`` disables warmup\n",
"        (``lr = learning_rate * step ** -0.5``) instead of raising\n",
"        ``ZeroDivisionError`` on ``0 ** -1.5``.\n",
"\n",
"    Args:\n",
"        warmup_steps: number of linear warmup steps.\n",
"        learning_rate: peak learning rate, passed to ``LRScheduler`` and\n",
"            read back as ``self.base_lr``.\n",
"        last_epoch: index of the last completed step; -1 starts fresh.\n",
"        verbose: forwarded unchanged to ``LRScheduler``.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self,\n",
"                 warmup_steps: Union[int, float]=25000,\n",
"                 learning_rate=1.0,\n",
"                 last_epoch=-1,\n",
"                 verbose=False):\n",
"        assert check_argument_types()\n",
"        # Set before super().__init__(): the base initializer is expected to\n",
"        # call step() -> get_lr(), which reads this field.\n",
"        self.warmup_steps = warmup_steps\n",
"        super().__init__(learning_rate, last_epoch, verbose)\n",
"\n",
"    def __repr__(self):\n",
"        return f\"{self.__class__.__name__}(warmup_steps={self.warmup_steps})\"\n",
"\n",
"    def get_lr(self):\n",
"        \"\"\"Return the learning rate for the current ``last_epoch``.\"\"\"\n",
"        step_num = self.last_epoch + 1\n",
"        if step_num <= 0:\n",
"            # 0 ** -0.5 raises ZeroDivisionError (reachable via set_step(-1))\n",
"            # and a negative base yields a complex number; use the warmup\n",
"            # ramp's value at step 0 instead.\n",
"            return 0.0\n",
"        if self.warmup_steps == 0:\n",
"            # No warmup: pure inverse-square-root decay (avoids 0 ** -1.5).\n",
"            return self.base_lr * step_num**-0.5\n",
"        return self.base_lr * self.warmup_steps**0.5 * min(\n",
"            step_num**-0.5, step_num * self.warmup_steps**-1.5)\n",
"\n",
"    def set_step(self, step: int):\n",
"        \"\"\"Jump to ``step`` (e.g. when resuming) via ``self.step(step)``.\"\"\"\n",
"        self.step(step)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "8c40b202",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-1\n"
]
}
],
"source": [
"sc = WarmupLR(learning_rate=0.001, warmup_steps=25000)\n",
"print(step)\n",
"# NOTE(review): the resume value `step` (-1 above) is not passed on; with an\n",
"# unguarded get_lr(), set_step(-1) evaluates 0 ** -0.5 and raises\n",
"# ZeroDivisionError -- presumably why an explicit 0 is used here.\n",
"sc.set_step(0)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "ecbc7e37",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f0ba6dd9c40>]"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"NUM_STEPS = 100000\n",
"lrs = []\n",
"for _ in range(NUM_STEPS):\n",
"    sc.step()\n",
"    lrs.append(sc.get_lr())\n",
"# Derive the x-axis from lrs so it cannot drift from the loop's step count.\n",
"xs = list(range(len(lrs)))\n",
"fig, ax = plt.subplots()\n",
"ax.plot(xs, lrs)\n",
"ax.set(xlabel=\"step\", ylabel=\"learning rate\", title=\"WarmupLR (paddle)\");"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e613fe16",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "f0fd9f40",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}