You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
450 lines
13 KiB
450 lines
13 KiB
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "primary-organic",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 38,
|
|
"id": "stopped-semester",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def mask_finished_scores(score: torch.Tensor,\n",
|
|
" flag: torch.Tensor) -> torch.Tensor:\n",
|
|
" \"\"\"\n",
|
|
" If a sequence is finished, we only allow one alive branch. This function\n",
|
|
" aims to give one branch a zero score and the rest -inf score.\n",
|
|
" Args:\n",
|
|
" score (torch.Tensor): A real value array with shape\n",
|
|
" (batch_size * beam_size, beam_size).\n",
|
|
" flag (torch.Tensor): A bool array with shape\n",
|
|
" (batch_size * beam_size, 1).\n",
|
|
" Returns:\n",
|
|
" torch.Tensor: (batch_size * beam_size, beam_size).\n",
|
|
" \"\"\"\n",
|
|
" beam_size = score.size(-1)\n",
|
|
" zero_mask = torch.zeros_like(flag, dtype=torch.bool)\n",
|
|
" if beam_size > 1:\n",
|
|
" unfinished = torch.cat((zero_mask, flag.repeat([1, beam_size - 1])),\n",
|
|
" dim=1)\n",
|
|
" finished = torch.cat((flag, zero_mask.repeat([1, beam_size - 1])),\n",
|
|
" dim=1)\n",
|
|
" else:\n",
|
|
" unfinished = zero_mask\n",
|
|
" finished = flag\n",
|
|
" print(unfinished)\n",
|
|
" print(finished)\n",
|
|
" score.masked_fill_(unfinished, -float('inf'))\n",
|
|
" score.masked_fill_(finished, 0)\n",
|
|
" return score"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 58,
|
|
"id": "agreed-portuguese",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"tensor([[ True],\n",
|
|
" [False]])\n",
|
|
"tensor([[-0.8841, 0.7381, -0.9986],\n",
|
|
" [ 0.2675, -0.7971, 0.3798]])\n",
|
|
"tensor([[ True, True],\n",
|
|
" [False, False]])\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"score = torch.randn((2, 3))\n",
|
|
"flag = torch.ones((2, 1), dtype=torch.bool)\n",
|
|
"flag[1] = False\n",
|
|
"print(flag)\n",
|
|
"print(score)\n",
|
|
"print(flag.repeat([1, 2]))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 59,
|
|
"id": "clean-aspect",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"tensor([[False, True, True],\n",
|
|
" [False, False, False]])\n",
|
|
"tensor([[ True, False, False],\n",
|
|
" [False, False, False]])\n",
|
|
"tensor([[ 0.0000, -inf, -inf],\n",
|
|
" [ 0.2675, -0.7971, 0.3798]])\n",
|
|
"tensor([[ 0.0000, -inf, -inf],\n",
|
|
" [ 0.2675, -0.7971, 0.3798]])\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"r = mask_finished_scores(score, flag)\n",
|
|
"print(r)\n",
|
|
"print(score)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 55,
|
|
"id": "thrown-airline",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Tensor(shape=[2, 1], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
|
|
" [[True ],\n",
|
|
" [False]])\n",
|
|
"Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
|
|
" [[ 2.05994511, 1.87704289, 0.01988174],\n",
|
|
" [-0.40165186, 0.77547729, -0.64469045]])\n",
|
|
"Tensor(shape=[2, 2], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
|
|
" [[True , True ],\n",
|
|
" [False, False]])\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import paddle\n",
|
|
"\n",
|
|
"score = paddle.randn((2, 3))\n",
|
|
"flag = paddle.ones((2, 1), dtype='bool')\n",
|
|
"flag[1] = False\n",
|
|
"print(flag)\n",
|
|
"print(score)\n",
|
|
"print(flag.tile([1, 2]))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 56,
|
|
"id": "internal-patent",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Tensor(shape=[2, 3], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
|
|
" [[False, True , True ],\n",
|
|
" [False, False, False]])\n",
|
|
"Tensor(shape=[2, 3], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
|
|
" [[True , False, False],\n",
|
|
" [False, False, False]])\n",
|
|
"x Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
|
|
" [[ 2.05994511, 1.87704289, 0.01988174],\n",
|
|
" [-0.40165186, 0.77547729, -0.64469045]])\n",
|
|
"2 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
|
|
" [[ 2.05994511, 1.87704289, 0.01988174],\n",
|
|
" [-0.40165186, 0.77547729, -0.64469045]])\n",
|
|
"3 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
|
|
" [[ 2.05994511, -inf. , -inf. ],\n",
|
|
" [-0.40165186, 0.77547729, -0.64469045]])\n",
|
|
"x Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
|
|
" [[ 2.05994511, -inf. , -inf. ],\n",
|
|
" [-0.40165186, 0.77547729, -0.64469045]])\n",
|
|
"2 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
|
|
" [[ 2.05994511, -inf. , -inf. ],\n",
|
|
" [-0.40165186, 0.77547729, -0.64469045]])\n",
|
|
"3 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
|
|
" [[ 0. , -inf. , -inf. ],\n",
|
|
" [-0.40165186, 0.77547729, -0.64469045]])\n",
|
|
"Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
|
|
" [[ 0. , -inf. , -inf. ],\n",
|
|
" [-0.40165186, 0.77547729, -0.64469045]])\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"paddle.bool = 'bool'\n",
|
|
"\n",
|
|
"def masked_fill(xs:paddle.Tensor, mask:paddle.Tensor, value:float):\n",
|
|
" print(xs)\n",
|
|
" trues = paddle.ones_like(xs) * value\n",
|
|
" assert xs.shape == mask.shape\n",
|
|
" xs = paddle.where(mask, trues, xs)\n",
|
|
" return xs\n",
|
|
"\n",
|
|
"def masked_fill_(xs:paddle.Tensor, mask:paddle.Tensor, value:float):\n",
|
|
" print('x', xs)\n",
|
|
" trues = paddle.ones_like(xs) * value\n",
|
|
" assert xs.shape == mask.shape\n",
|
|
" ret = paddle.where(mask, trues, xs)\n",
|
|
" print('2', xs)\n",
|
|
" paddle.assign(ret, output=xs)\n",
|
|
" print('3', xs)\n",
|
|
"\n",
|
|
"paddle.Tensor.masked_fill = masked_fill\n",
|
|
"paddle.Tensor.masked_fill_ = masked_fill_\n",
|
|
"\n",
|
|
"def mask_finished_scores_pd(score: paddle.Tensor,\n",
|
|
" flag: paddle.Tensor) -> paddle.Tensor:\n",
|
|
" \"\"\"\n",
|
|
" If a sequence is finished, we only allow one alive branch. This function\n",
|
|
" aims to give one branch a zero score and the rest -inf score.\n",
|
|
" Args:\n",
|
|
" score (torch.Tensor): A real value array with shape\n",
|
|
" (batch_size * beam_size, beam_size).\n",
|
|
" flag (torch.Tensor): A bool array with shape\n",
|
|
" (batch_size * beam_size, 1).\n",
|
|
" Returns:\n",
|
|
" torch.Tensor: (batch_size * beam_size, beam_size).\n",
|
|
" \"\"\"\n",
|
|
" beam_size = score.shape[-1]\n",
|
|
" zero_mask = paddle.zeros_like(flag, dtype=paddle.bool)\n",
|
|
" if beam_size > 1:\n",
|
|
" unfinished = paddle.concat((zero_mask, flag.tile([1, beam_size - 1])),\n",
|
|
" axis=1)\n",
|
|
" finished = paddle.concat((flag, zero_mask.tile([1, beam_size - 1])),\n",
|
|
" axis=1)\n",
|
|
" else:\n",
|
|
" unfinished = zero_mask\n",
|
|
" finished = flag\n",
|
|
" print(unfinished)\n",
|
|
" print(finished)\n",
|
|
" \n",
|
|
" #score.masked_fill_(unfinished, -float('inf'))\n",
|
|
" #score.masked_fill_(finished, 0)\n",
|
|
"# infs = paddle.ones_like(score) * -float('inf')\n",
|
|
"# score = paddle.where(unfinished, infs, score)\n",
|
|
"# score = paddle.where(finished, paddle.zeros_like(score), score)\n",
|
|
"\n",
|
|
"# score = score.masked_fill(unfinished, -float('inf'))\n",
|
|
"# score = score.masked_fill(finished, 0)\n",
|
|
" score.masked_fill_(unfinished, -float('inf'))\n",
|
|
" score.masked_fill_(finished, 0)\n",
|
|
" return score\n",
|
|
"\n",
|
|
"r = mask_finished_scores_pd(score, flag)\n",
|
|
"print(r)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 57,
|
|
"id": "vocal-prime",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"<bound method PyCapsule.value of Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
|
|
" [[ 0. , -inf. , -inf. ],\n",
|
|
" [-0.40165186, 0.77547729, -0.64469045]])>"
|
|
]
|
|
},
|
|
"execution_count": 57,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"score.value"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 71,
|
|
"id": "bacterial-adolescent",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from typing import Union, Any"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 72,
|
|
"id": "absent-fiber",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def repeat(xs : paddle.Tensor, *size: Any):\n",
|
|
" print(size)\n",
|
|
" return paddle.tile(xs, size)\n",
|
|
"paddle.Tensor.repeat = repeat"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 73,
|
|
"id": "material-harbor",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"(1, 2)\n",
|
|
"Tensor(shape=[2, 2], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
|
|
" [[True , True ],\n",
|
|
" [False, False]])\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"flag = paddle.ones((2, 1), dtype='bool')\n",
|
|
"flag[1] = False\n",
|
|
"print(flag.repeat(1, 2))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 84,
|
|
"id": "acute-brighton",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"(Tensor(shape=[1], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
|
|
" [1]), 2)\n",
|
|
"Tensor(shape=[2, 2], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
|
|
" [[True , True ],\n",
|
|
" [False, False]])\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"flag = paddle.ones((2, 1), dtype='bool')\n",
|
|
"flag[1] = False\n",
|
|
"print(flag.repeat(paddle.to_tensor(1), 2))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 85,
|
|
"id": "european-rugby",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def size(xs, *args: int):\n",
|
|
" nargs = len(args)\n",
|
|
" s = paddle.shape(xs)\n",
|
|
" assert(nargs <= 1)\n",
|
|
" if nargs == 1:\n",
|
|
" return s[args[0]]\n",
|
|
" else:\n",
|
|
" return s\n",
|
|
"paddle.Tensor.size = size"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 86,
|
|
"id": "moral-special",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"Tensor(shape=[2], dtype=int32, place=CPUPlace, stop_gradient=True,\n",
|
|
" [2, 1])"
|
|
]
|
|
},
|
|
"execution_count": 86,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"flag.size()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 87,
|
|
"id": "ahead-coach",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,\n",
|
|
" [1])"
|
|
]
|
|
},
|
|
"execution_count": 87,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"flag.size(1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 88,
|
|
"id": "incomplete-fitness",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,\n",
|
|
" [2])"
|
|
]
|
|
},
|
|
"execution_count": 88,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"flag.size(0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "upset-connectivity",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.0"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|