{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "primary-organic", "metadata": {}, "outputs": [], "source": [ "import torch" ] }, { "cell_type": "code", "execution_count": 38, "id": "stopped-semester", "metadata": {}, "outputs": [], "source": [ "def mask_finished_scores(score: torch.Tensor,\n", " flag: torch.Tensor) -> torch.Tensor:\n", " \"\"\"\n", " If a sequence is finished, we only allow one alive branch. This function\n", " aims to give one branch a zero score and the rest -inf score.\n", " Args:\n", " score (torch.Tensor): A real value array with shape\n", " (batch_size * beam_size, beam_size).\n", " flag (torch.Tensor): A bool array with shape\n", " (batch_size * beam_size, 1).\n", " Returns:\n", " torch.Tensor: (batch_size * beam_size, beam_size).\n", " \"\"\"\n", " beam_size = score.size(-1)\n", " zero_mask = torch.zeros_like(flag, dtype=torch.bool)\n", " if beam_size > 1:\n", " unfinished = torch.cat((zero_mask, flag.repeat([1, beam_size - 1])),\n", " dim=1)\n", " finished = torch.cat((flag, zero_mask.repeat([1, beam_size - 1])),\n", " dim=1)\n", " else:\n", " unfinished = zero_mask\n", " finished = flag\n", " print(unfinished)\n", " print(finished)\n", " score.masked_fill_(unfinished, -float('inf'))\n", " score.masked_fill_(finished, 0)\n", " return score" ] }, { "cell_type": "code", "execution_count": 58, "id": "agreed-portuguese", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[ True],\n", " [False]])\n", "tensor([[-0.8841, 0.7381, -0.9986],\n", " [ 0.2675, -0.7971, 0.3798]])\n", "tensor([[ True, True],\n", " [False, False]])\n" ] } ], "source": [ "score = torch.randn((2, 3))\n", "flag = torch.ones((2, 1), dtype=torch.bool)\n", "flag[1] = False\n", "print(flag)\n", "print(score)\n", "print(flag.repeat([1, 2]))" ] }, { "cell_type": "code", "execution_count": 59, "id": "clean-aspect", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[False, True, True],\n", " [False, False, False]])\n", "tensor([[ True, False, False],\n", " [False, False, False]])\n", "tensor([[ 0.0000, -inf, -inf],\n", " [ 0.2675, -0.7971, 0.3798]])\n", "tensor([[ 0.0000, -inf, -inf],\n", " [ 0.2675, -0.7971, 0.3798]])\n" ] } ], "source": [ "r = mask_finished_scores(score, flag)\n", "print(r)\n", "print(score)" ] }, { "cell_type": "code", "execution_count": 55, "id": "thrown-airline", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tensor(shape=[2, 1], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n", " [[True ],\n", " [False]])\n", "Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", " [[ 2.05994511, 1.87704289, 0.01988174],\n", " [-0.40165186, 0.77547729, -0.64469045]])\n", "Tensor(shape=[2, 2], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n", " [[True , True ],\n", " [False, False]])\n" ] } ], "source": [ "import paddle\n", "\n", "score = paddle.randn((2, 3))\n", "flag = paddle.ones((2, 1), dtype='bool')\n", "flag[1] = False\n", "print(flag)\n", "print(score)\n", "print(flag.tile([1, 2]))" ] }, { "cell_type": "code", "execution_count": 56, "id": "internal-patent", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tensor(shape=[2, 3], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n", " [[False, True , True ],\n", " [False, False, False]])\n", "Tensor(shape=[2, 3], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n", " [[True , False, False],\n", " [False, False, False]])\n", "x Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", " [[ 2.05994511, 1.87704289, 0.01988174],\n", " [-0.40165186, 0.77547729, -0.64469045]])\n", "2 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", " [[ 2.05994511, 1.87704289, 0.01988174],\n", " [-0.40165186, 0.77547729, -0.64469045]])\n", "3 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", " [[ 2.05994511, -inf. , -inf. ],\n", " [-0.40165186, 0.77547729, -0.64469045]])\n", "x Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", " [[ 2.05994511, -inf. , -inf. ],\n", " [-0.40165186, 0.77547729, -0.64469045]])\n", "2 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", " [[ 2.05994511, -inf. , -inf. ],\n", " [-0.40165186, 0.77547729, -0.64469045]])\n", "3 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", " [[ 0. , -inf. , -inf. ],\n", " [-0.40165186, 0.77547729, -0.64469045]])\n", "Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", " [[ 0. , -inf. , -inf. ],\n", " [-0.40165186, 0.77547729, -0.64469045]])\n" ] } ], "source": [ "paddle.bool = 'bool'\n", "\n", "def masked_fill(xs:paddle.Tensor, mask:paddle.Tensor, value:float):\n", " print(xs)\n", " trues = paddle.ones_like(xs) * value\n", " assert xs.shape == mask.shape\n", " xs = paddle.where(mask, trues, xs)\n", " return xs\n", "\n", "def masked_fill_(xs:paddle.Tensor, mask:paddle.Tensor, value:float):\n", " print('x', xs)\n", " trues = paddle.ones_like(xs) * value\n", " assert xs.shape == mask.shape\n", " ret = paddle.where(mask, trues, xs)\n", " print('2', xs)\n", " paddle.assign(ret, output=xs)\n", " print('3', xs)\n", "\n", "paddle.Tensor.masked_fill = masked_fill\n", "paddle.Tensor.masked_fill_ = masked_fill_\n", "\n", "def mask_finished_scores_pd(score: paddle.Tensor,\n", " flag: paddle.Tensor) -> paddle.Tensor:\n", " \"\"\"\n", " If a sequence is finished, we only allow one alive branch. This function\n", " aims to give one branch a zero score and the rest -inf score.\n", " Args:\n", " score (torch.Tensor): A real value array with shape\n", " (batch_size * beam_size, beam_size).\n", " flag (torch.Tensor): A bool array with shape\n", " (batch_size * beam_size, 1).\n", " Returns:\n", " torch.Tensor: (batch_size * beam_size, beam_size).\n", " \"\"\"\n", " beam_size = score.shape[-1]\n", " zero_mask = paddle.zeros_like(flag, dtype=paddle.bool)\n", " if beam_size > 1:\n", " unfinished = paddle.concat((zero_mask, flag.tile([1, beam_size - 1])),\n", " axis=1)\n", " finished = paddle.concat((flag, zero_mask.tile([1, beam_size - 1])),\n", " axis=1)\n", " else:\n", " unfinished = zero_mask\n", " finished = flag\n", " print(unfinished)\n", " print(finished)\n", " \n", " #score.masked_fill_(unfinished, -float('inf'))\n", " #score.masked_fill_(finished, 0)\n", "# infs = paddle.ones_like(score) * -float('inf')\n", "# score = paddle.where(unfinished, infs, score)\n", "# score = paddle.where(finished, paddle.zeros_like(score), score)\n", "\n", "# score = score.masked_fill(unfinished, -float('inf'))\n", "# score = score.masked_fill(finished, 0)\n", " score.masked_fill_(unfinished, -float('inf'))\n", " score.masked_fill_(finished, 0)\n", " return score\n", "\n", "r = mask_finished_scores_pd(score, flag)\n", "print(r)" ] }, { "cell_type": "code", "execution_count": 57, "id": "vocal-prime", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "source": [ "score.value" ] }, { "cell_type": "code", "execution_count": 71, "id": "bacterial-adolescent", "metadata": {}, "outputs": [], "source": [ "from typing import Union, Any" ] }, { "cell_type": "code", "execution_count": 72, "id": "absent-fiber", "metadata": {}, "outputs": [], "source": [ "def repeat(xs : paddle.Tensor, *size: Any):\n", " print(size)\n", " return paddle.tile(xs, size)\n", "paddle.Tensor.repeat = repeat" ] }, { "cell_type": "code", "execution_count": 73, "id": "material-harbor", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(1, 2)\n", "Tensor(shape=[2, 2], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n", " [[True , True ],\n", " [False, False]])\n" ] } ], "source": [ "flag = paddle.ones((2, 1), dtype='bool')\n", "flag[1] = False\n", "print(flag.repeat(1, 2))" ] }, { "cell_type": "code", "execution_count": 84, "id": "acute-brighton", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(Tensor(shape=[1], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n", " [1]), 2)\n", "Tensor(shape=[2, 2], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n", " [[True , True ],\n", " [False, False]])\n" ] } ], "source": [ "flag = paddle.ones((2, 1), dtype='bool')\n", "flag[1] = False\n", "print(flag.repeat(paddle.to_tensor(1), 2))" ] }, { "cell_type": "code", "execution_count": 85, "id": "european-rugby", "metadata": {}, "outputs": [], "source": [ "def size(xs, *args: int):\n", " nargs = len(args)\n", " s = paddle.shape(xs)\n", " assert(nargs <= 1)\n", " if nargs == 1:\n", " return s[args[0]]\n", " else:\n", " return s\n", "paddle.Tensor.size = size" ] }, { "cell_type": "code", "execution_count": 86, "id": "moral-special", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Tensor(shape=[2], dtype=int32, place=CPUPlace, stop_gradient=True,\n", " [2, 1])" ] }, "execution_count": 86, "metadata": {}, "output_type": "execute_result" } ], "source": [ "flag.size()" ] }, { "cell_type": "code", "execution_count": 87, "id": "ahead-coach", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,\n", " [1])" ] }, "execution_count": 87, "metadata": {}, "output_type": "execute_result" } ], "source": [ "flag.size(1)" ] }, { "cell_type": "code", "execution_count": 88, "id": "incomplete-fitness", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,\n", " [2])" ] }, "execution_count": 88, "metadata": {}, "output_type": "execute_result" } ], "source": [ "flag.size(0)" ] }, { "cell_type": "code", "execution_count": null, "id": "upset-connectivity", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.0" } }, "nbformat": 4, "nbformat_minor": 5 }